Il modello lineare gerarchico#
In questo capitolo, approfondiremo la metodologia della regressione lineare gerarchica Bayesiana, utilizzando le API della libreria Bambi. L’obiettivo è di applicare questo metodo statistico a un dataset multilivello. In particolare, ci concentreremo su un dataset che comprende diverse unità di osservazione (i soggetti), ognuna delle quali ha più misurazioni associate.
Preparazione del Notebook#
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import bambi as bmb
import pymc.sampling_jax
import xarray as xr
import pingouin as pg
import arviz as az
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=Warning)
/Users/corrado/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
sns.set_theme(palette="colorblind")
Abbiamo già introdotto il concetto di modellazione gerarchica Bayesiana, focalizzandoci in precedenza sulla stima dei parametri di una distribuzione di probabilità (per ulteriori dettagli, si rimanda al capitolo Modello gerarchico beta-binomiale). In questo capitolo, estenderemo quel concetto alla stima dei parametri di un modello di regressione lineare, dove i dati sono raggruppati in diversi cluster.
Esamineremo e metteremo a confronto tre tipi di modelli gerarchici.
Complete Pooling: Questo modello ignora completamente la struttura gerarchica dei dati. In altre parole, tratta tutte le unità di osservazione come se fossero provenienti da un’unica popolazione.
No Pooling: Questo modello, al contrario, considera ogni cluster come completamente indipendente dagli altri, senza alcuna struttura gerarchica.
Partial Pooling (o Modello Multi-Livello): Questo modello è il più sofisticato tra i tre. Assume che le pendenze e le intercette di ciascun cluster siano realizzazioni di variabili casuali, estratte da una distribuzione normale con media e varianza comuni a tutti i cluster.
Il modello di “partial pooling” si distingue per la sua capacità di riconoscere la natura intrinsecamente gerarchica dei dati. Fornisce un equilibrio tra il “no pooling”, che tratta ogni cluster come indipendente, e il “complete pooling”, che normalizza eccessivamente i dati. In questo modello, le stime dei parametri per ciascun cluster sono influenzate sia dai dati specifici del cluster stesso sia dalla distribuzione globale dei dati tra tutti i cluster. Questo permette un “shrinkage” dei parametri, mitigando l’effetto di outliers o di cluster di dimensioni ridotte.
Per illustrare l’applicazione pratica di questi concetti, utilizzeremo il dataset sleepstudy
. Questo dataset proviene dalla ricerca di Belenky et al. (2003) e contiene i tempi di reazione medi giornalieri (misurati in millisecondi) di un gruppo di partecipanti sottoposti a deprivazione del sonno. Il design dello studio prevede una fase iniziale di calibrazione e adattamento nei primi due giorni, seguita da una misurazione di baseline al terzo giorno. Dopo questo periodo, inizia la fase di deprivazione del sonno, durante la quale i partecipanti sono limitati a un regime di solo 3 ore di sonno per notte.
EDA#
Iniziamo importando i dati e ispezionando la struttura delle osservazioni suddivise nei diversi cluster.
data = bmb.load_data("sleepstudy")
data.head()
Reaction | Days | Subject | |
---|---|---|---|
0 | 249.5600 | 0 | 308 |
1 | 258.7047 | 1 | 308 |
2 | 250.8006 | 2 | 308 |
3 | 321.4398 | 3 | 308 |
4 | 356.8519 | 4 | 308 |
Eliminiamo le righe in cui la colonna “Days” ha valore 0 o 1 dal dataset “sleepstudy” utilizzando il seguente codice:
data = data[data['Days'].isin([0, 1]) == False]
data.head()
Reaction | Days | Subject | |
---|---|---|---|
2 | 250.8006 | 2 | 308 |
3 | 321.4398 | 3 | 308 |
4 | 356.8519 | 4 | 308 |
5 | 414.6901 | 5 | 308 |
6 | 382.2038 | 6 | 308 |
Analizziamo il tempo di reazione medio in relazione ai giorni di deprivazione del sonno, osservando come questo varia per ciascun soggetto coinvolto nello studio.
def plot_data(data):
fig, axes = plt.subplots(3, 6, figsize=(16, 8), sharey=True, sharex=True, dpi=300, constrained_layout=True)
fig.subplots_adjust(left=0.075, right=0.975, bottom=0.075, top=0.925, wspace=0.03)
axes_flat = axes.ravel()
for i, subject in enumerate(data["Subject"].unique()):
ax = axes_flat[i]
idx = data.index[data["Subject"] == subject].tolist()
days = data.loc[idx, "Days"].values
reaction = data.loc[idx, "Reaction"].values
# Plot observed data points
ax.scatter(days, reaction, color="C0", ec="black", alpha=0.7)
# Add a title
ax.set_title(f"Subject: {subject}", fontsize=14)
# Remove axis labels for individual plots
for ax in axes_flat:
ax.set_xlabel('')
ax.set_ylabel('')
# Set x-axis ticks for the last row
for ax in axes[-1]:
ax.xaxis.set_ticks([0, 2, 4, 6, 8])
return axes
plot_data(data)
plt.tight_layout();
![../_images/d2cc72c2ea25733cd74f9c2a2e9e2c8c5d24df7e9dbb096a05a9d24f1013a733.png](../_images/d2cc72c2ea25733cd74f9c2a2e9e2c8c5d24df7e9dbb096a05a9d24f1013a733.png)
Modello complete pooling#
Il modello complete pooling tratta tutte le osservazioni come se fossero indipendenti, aggregandole in un unico gruppo. In questo modello, le rette di regressione lineare per tutti i soggetti hanno la stessa pendenza e la stessa intercetta. Il modello può essere descritto esplicitamente come segue:
Se disponiamo di \( m \) soggetti e ciascun soggetto \( i \) ha \( n_i \) osservazioni, il modello può essere definito da:
dove:
\(\text{Reaction}_{ij}\) è il tempo di reazione per il soggetto \( i \) al giorno \( j \).
\(\text{Days}_{ij}\) è il numero di giorni per il soggetto \( i \) all’osservazione \( j \).
\(\alpha\) è l’intercetta comune a tutti i soggetti.
\(\beta\) è la pendenza comune a tutti i soggetti.
\(\epsilon_{ij}\) è il termine di errore casuale per il soggetto \( i \) all’osservazione \( j \), che si suppone sia distribuito normalmente con media 0 e varianza costante \( \sigma^2 \).
Questo modello non distingue tra i gruppi di osservazoni che appartengono a soggetti diversi e stima un’unica pendenza e un’unica intercetta dai dati di tutti i soggetti. In Bambi, questo modello può essere specificato utilizzando solo la variabile Days
come predittore, senza includere il Subject
come fattore.
model_pooling = bmb.Model("Reaction ~ 1 + Days", data)
Procediamo con l’esecuzione del campionamento MCMC, utilizzando il metodo NUTS specifico per il campionatore JAX. Questo può essere fatto semplicemente passando l’opzione method="nuts_numpyro"
durante la chiamata al campionamento. In questo modo, stiamo invocando direttamente il campionatore JAX, sfruttando le sue caratteristiche avanzate.
results_pooling = model_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
Show code cell output
Compiling...
Compilation time = 0:00:02.468250
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 898.30it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 899.14it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 899.96it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 900.65it/s]
Sampling time = 0:00:02.580513
Transforming variables...
Transformation time = 0:00:00.106217
Computing Log Likelihood...
Log Likelihood time = 0:00:00.195465
Un sommario numerico delle distribuzioni a posteriori dei parametri si ottiene con az.summary
.
az.summary(results_pooling, round_to=2)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
Intercept | 245.47 | 10.98 | 224.78 | 265.60 | 0.18 | 0.13 | 3723.47 | 2899.64 | 1.0 |
Days | 11.36 | 1.86 | 7.72 | 14.67 | 0.03 | 0.02 | 4393.83 | 2897.52 | 1.0 |
Reaction_sigma | 51.22 | 3.06 | 45.23 | 56.66 | 0.05 | 0.03 | 4422.30 | 3025.60 | 1.0 |
Modello no-pooling#
Il modello no-pooling tratta ogni soggetto come indipendente e adatta una retta di regressione separata per ciascun soggetto. Se disponiamo di \( m \) soggetti e ciascun soggetto \( i \) ha \( n_i \) osservazioni, il modello può essere definito da:
dove:
\(\text{Reaction}_{ij}\) è il tempo di reazione per il soggetto \( i \) al giorno \( j \).
\(\text{Days}_{ij}\) è il numero di giorni per il soggetto \( i \) all’osservazione \( j \).
\(\alpha_i\) è l’intercetta per il soggetto \( i \).
\(\beta_i\) è la pendenza per il soggetto \( i \).
\(\epsilon_{ij}\) è il termine di errore casuale per il soggetto \( i \) all’osservazione \( j \), che si suppone sia distribuito normalmente con media 0 e varianza costante \( \sigma^2 \).
Questo modello non fa alcuna ipotesi sulle relazioni tra diversi soggetti e stima la pendenza e l’intercetta di ciascun soggetto indipendentemente dagli altri soggetti. In Bambi, questo modello viene specificato con l’interazione tra Days
e Subject
, come descritto in seguito.
model_no_pooling = bmb.Model("Reaction ~ Days * C(Subject)", data=data)
results_no_pooling = model_no_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
Show code cell output
Compiling...
Compilation time = 0:00:01.611346
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 5%|██▋ | 100/2000 [00:03<00:06, 289.27it/s]
Running chain 1: 5%|██▋ | 100/2000 [00:03<00:06, 286.57it/s]
Running chain 3: 5%|██▋ | 100/2000 [00:03<00:06, 282.98it/s]
Running chain 2: 5%|██▋ | 100/2000 [00:03<00:06, 281.84it/s]
Running chain 0: 35%|██████████████████▌ | 700/2000 [00:03<00:00, 1955.20it/s]
Running chain 1: 35%|██████████████████▌ | 700/2000 [00:03<00:00, 1934.42it/s]
Running chain 2: 35%|██████████████████▌ | 700/2000 [00:03<00:00, 1911.00it/s]
Running chain 3: 40%|█████████████████████▏ | 800/2000 [00:03<00:00, 2132.80it/s]
Running chain 0: 70%|████████████████████████████████████▍ | 1400/2000 [00:03<00:00, 3429.17it/s]
Running chain 1: 70%|████████████████████████████████████▍ | 1400/2000 [00:03<00:00, 3402.91it/s]
Running chain 2: 75%|███████████████████████████████████████ | 1500/2000 [00:03<00:00, 3540.94it/s]
Running chain 3: 75%|███████████████████████████████████████ | 1500/2000 [00:03<00:00, 3425.99it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 595.64it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 595.88it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 596.15it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 596.41it/s]
Sampling time = 0:00:03.532101
Transforming variables...
Transformation time = 0:00:00.136384
Computing Log Likelihood...
Log Likelihood time = 0:00:00.199458
az.summary(results_no_pooling, round_to=2)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
Intercept | 245.47 | 22.74 | 200.27 | 285.68 | 1.18 | 0.84 | 368.26 | 905.46 | 1.01 |
Days | 21.57 | 3.81 | 14.35 | 28.80 | 0.20 | 0.14 | 366.37 | 752.88 | 1.01 |
C(Subject)[309] | -53.93 | 32.91 | -114.95 | 6.81 | 1.30 | 0.92 | 639.40 | 1308.64 | 1.00 |
C(Subject)[310] | -27.14 | 32.61 | -90.89 | 31.03 | 1.25 | 0.89 | 677.38 | 1785.75 | 1.00 |
C(Subject)[330] | 11.78 | 33.16 | -50.57 | 73.10 | 1.29 | 0.91 | 660.26 | 1394.69 | 1.00 |
C(Subject)[331] | 42.86 | 32.81 | -19.87 | 102.82 | 1.29 | 0.91 | 653.68 | 1560.82 | 1.00 |
C(Subject)[332] | 65.76 | 32.90 | 1.41 | 125.85 | 1.26 | 0.89 | 682.22 | 1492.93 | 1.00 |
C(Subject)[333] | 18.15 | 33.10 | -43.79 | 81.87 | 1.25 | 0.88 | 702.09 | 1317.04 | 1.00 |
C(Subject)[334] | -43.69 | 33.12 | -105.27 | 19.50 | 1.31 | 0.93 | 640.53 | 1726.88 | 1.00 |
C(Subject)[335] | 25.69 | 32.60 | -33.21 | 86.50 | 1.24 | 0.88 | 685.77 | 1710.30 | 1.00 |
C(Subject)[337] | 22.57 | 32.46 | -36.22 | 84.38 | 1.22 | 0.87 | 703.65 | 1680.08 | 1.00 |
C(Subject)[349] | -49.07 | 33.07 | -111.84 | 11.51 | 1.31 | 0.92 | 641.53 | 1342.73 | 1.00 |
C(Subject)[350] | -44.41 | 32.84 | -105.18 | 16.52 | 1.24 | 0.88 | 700.89 | 1868.37 | 1.00 |
C(Subject)[351] | 1.64 | 33.13 | -58.61 | 63.51 | 1.28 | 0.91 | 667.28 | 1394.80 | 1.00 |
C(Subject)[352] | 71.43 | 32.47 | 14.88 | 135.93 | 1.31 | 0.92 | 615.99 | 1461.22 | 1.00 |
C(Subject)[369] | -6.47 | 33.45 | -71.90 | 54.22 | 1.32 | 0.94 | 640.54 | 1405.03 | 1.00 |
C(Subject)[370] | -51.98 | 33.21 | -112.61 | 12.17 | 1.28 | 0.91 | 674.26 | 1338.50 | 1.00 |
C(Subject)[371] | -12.24 | 32.45 | -74.69 | 45.52 | 1.27 | 0.90 | 651.84 | 1817.77 | 1.00 |
C(Subject)[372] | 23.20 | 32.37 | -35.94 | 86.27 | 1.31 | 0.93 | 608.36 | 1572.94 | 1.00 |
Days:C(Subject)[309] | -17.19 | 5.53 | -27.78 | -7.06 | 0.22 | 0.16 | 628.65 | 1360.40 | 1.00 |
Days:C(Subject)[310] | -17.72 | 5.48 | -27.90 | -7.32 | 0.21 | 0.15 | 659.47 | 1831.34 | 1.00 |
Days:C(Subject)[330] | -13.59 | 5.60 | -24.78 | -3.66 | 0.22 | 0.16 | 637.48 | 1275.37 | 1.00 |
Days:C(Subject)[331] | -16.70 | 5.56 | -26.65 | -6.12 | 0.22 | 0.15 | 665.48 | 1545.23 | 1.00 |
Days:C(Subject)[332] | -19.16 | 5.57 | -30.16 | -9.08 | 0.21 | 0.15 | 683.89 | 1556.75 | 1.00 |
Days:C(Subject)[333] | -10.70 | 5.53 | -21.59 | -0.93 | 0.21 | 0.15 | 695.48 | 1459.69 | 1.00 |
Days:C(Subject)[334] | -3.45 | 5.55 | -14.15 | 6.72 | 0.22 | 0.15 | 649.91 | 1157.77 | 1.00 |
Days:C(Subject)[335] | -25.75 | 5.50 | -35.53 | -15.06 | 0.21 | 0.15 | 694.83 | 1765.07 | 1.00 |
Days:C(Subject)[337] | 0.87 | 5.43 | -9.19 | 11.12 | 0.21 | 0.15 | 670.37 | 1467.96 | 1.00 |
Days:C(Subject)[349] | -5.22 | 5.49 | -15.06 | 5.63 | 0.22 | 0.15 | 640.98 | 1308.49 | 1.00 |
Days:C(Subject)[350] | 1.77 | 5.48 | -8.10 | 11.99 | 0.21 | 0.15 | 694.17 | 1696.08 | 1.00 |
Days:C(Subject)[351] | -13.07 | 5.50 | -23.51 | -3.40 | 0.21 | 0.15 | 664.18 | 1178.23 | 1.00 |
Days:C(Subject)[352] | -14.31 | 5.50 | -24.27 | -3.36 | 0.22 | 0.16 | 607.89 | 1619.49 | 1.00 |
Days:C(Subject)[369] | -7.80 | 5.59 | -18.55 | 2.41 | 0.23 | 0.16 | 613.97 | 1287.20 | 1.00 |
Days:C(Subject)[370] | -0.93 | 5.53 | -11.93 | 8.84 | 0.21 | 0.15 | 668.68 | 1172.90 | 1.00 |
Days:C(Subject)[371] | -9.25 | 5.42 | -20.02 | 0.22 | 0.22 | 0.15 | 635.56 | 1838.60 | 1.00 |
Days:C(Subject)[372] | -10.50 | 5.45 | -20.94 | -0.50 | 0.22 | 0.16 | 594.14 | 1628.94 | 1.00 |
Reaction_sigma | 25.76 | 1.78 | 22.33 | 29.05 | 0.03 | 0.02 | 3478.85 | 2467.45 | 1.00 |
Per ricavare i coefficienti \(\alpha\) delle regressioni individuali, dobbiamo sommare Intercept
al valore del singolo soggetto. Per esempio, per il soggetto 309 abbiamo
246.98 + -55.29
191.69
Facciamo lo stesso per la pendenza individuale delle rette di regressione. Per il soggetto 309 otteniamo
21.30 + -16.97
4.330000000000002
Questi valori sono identici a quelli che si otterrebbero se adattassimo il modello di regressione separatamente per ciascun soggetto. In effetti, abbiamo fatto proprio questo, anche se utilizzando un modello unico. Per esempio, esaminiamo solo i dati del soggetto 309.
data_subject_309 = data[data["Subject"] == 309]
data_subject_309.shape
(8, 3)
Per questi dati, stimiamo l’intercetta e la pendenza della retta di regressione usando l’approccio frequentista mediante la funzione linear_regression
del pacchetto pingouin
.
result = pg.linear_regression(data_subject_309["Days"], data_subject_309["Reaction"])
print(result)
names coef se T pval r2 \
0 Intercept 191.576970 3.723259 51.454104 3.615788e-09 0.890144
1 Days 4.357144 0.624898 6.972569 4.325982e-04 0.890144
adj_r2 CI[2.5%] CI[97.5%]
0 0.871834 182.466483 200.687457
1 0.871834 2.828074 5.886214
Si noti che i risultati ottenuti sono sostanzialmente gli stessi, con solo qualche minima differenza numerica. Questa discrepanza deriva dalla diversità degli approcci utilizzati: in un caso abbiamo applicato un metodo bayesiano, mentre nell’altro abbiamo adottato una tecnica di stima frequentista.
Modello partial pooling#
Il modello gerarchico, conosciuto anche come modello di “partial pooling”, consente di gestire la complessità presente nei dati raggruppati o clusterizzati, come nel caso presente. La regressione lineare classica presume che ogni osservazione sia indipendente dalle altre, ma questa ipotesi viene meno quando i dati sono organizzati in gruppi. Le osservazioni all’interno dello stesso gruppo tendono ad essere più correlate tra loro rispetto a quelle in gruppi diversi. Trascurare questa struttura gerarchica potrebbe portare a stime errate e conclusioni fuorvianti.
Il modello gerarchico affronta questo problema introducendo la nozione di effetti casuali, in contrapposizione agli effetti fissi del modello classico. Gli effetti fissi rappresentano l’effetto medio di una variabile predittiva su tutti gli individui o gruppi, mentre gli effetti casuali considerano come l’effetto di una variabile possa variare da un gruppo all’altro. Mentre gli effetti fissi sono comuni a tutto il dataset, gli effetti casuali tengono conto delle differenze tra i gruppi.
Questo modello gerarchico unisce effetti fissi e casuali per fornire una rappresentazione più accurata dei dati, quando questi mostrano relazioni gerarchiche o raggruppate. Il modello gerarchico di “partial pooling” considera le somiglianze tra i soggetti stimando un’intercetta e una pendenza comuni, ma consente anche variazioni individuali attorno a questi valori medi.
Possiamo rappresentare matematicamente il modello come segue:
dove:
\(\text{Reaction}_{ij}\) è il tempo di reazione del soggetto \(i\) al giorno \(j\).
\(\text{Days}_{ij}\) è il numero di giorni per il soggetto \(i\) all’osservazione \(j\).
\(\alpha_i\) è l’intercetta per il soggetto \(i\), che segue la distribuzione \(\alpha_i \sim \mathcal{N}(\alpha, \tau_\alpha^2)\).
\(\beta_i\) è la pendenza per il soggetto \(i\), che segue la distribuzione \(\beta_i \sim \mathcal{N}(\beta, \tau_\beta^2)\).
\(\epsilon_{ij}\) è l’errore casuale per il soggetto \(i\) all’osservazione \(j\), distribuito normalmente con media 0 e varianza costante \(\sigma^2\).
I parametri \(\alpha\) e \(\beta\) rappresentano l’intercetta e la pendenza medie per tutti i soggetti, e le varianze \(\tau_\alpha^2\) e \(\tau_\beta^2\) quantificano le differenze tra gli individui.
In questo modo, il modello gerarchico riesce a rappresentare sia le informazioni comuni a tutti i soggetti, sia le differenze individuali, considerando sia gli effetti fissi che quelli casuali. Può quindi offrire una visione più completa e realistica dei dati, tenendo conto della loro struttura gerarchica. In Bambi, questo modello può essere specificato utilizzando la variabile Days
come predittore e includendo Subject
come effetto casuale.
model_partial_pooling = bmb.Model(
"Reaction ~ 1 + Days + (Days | Subject)", data, categorical="Subject"
)
Eseguiamo il campionamento.
results_partial_pooling = model_partial_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
Show code cell output
Compiling...
Compilation time = 0:00:01.874257
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 0: 20%|██████████▌ | 400/2000 [00:03<00:00, 3976.11it/s]
Running chain 1: 20%|██████████▌ | 400/2000 [00:03<00:00, 3767.55it/s]
Running chain 3: 20%|██████████▌ | 400/2000 [00:03<00:00, 3565.60it/s]
Running chain 2: 20%|██████████▌ | 400/2000 [00:03<00:00, 3478.10it/s]
Running chain 1: 45%|███████████████████████▊ | 900/2000 [00:03<00:00, 4298.09it/s]
Running chain 0: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4671.07it/s]
Running chain 3: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4490.92it/s]
Running chain 2: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4388.12it/s]
Running chain 1: 70%|████████████████████████████████████▍ | 1400/2000 [00:03<00:00, 4540.22it/s]
Running chain 0: 75%|███████████████████████████████████████ | 1500/2000 [00:03<00:00, 4545.55it/s]
Running chain 3: 80%|█████████████████████████████████████████▌ | 1600/2000 [00:03<00:00, 5010.95it/s]
Running chain 2: 75%|███████████████████████████████████████ | 1500/2000 [00:03<00:00, 4393.03it/s]
Running chain 1: 95%|█████████████████████████████████████████████████▍ | 1900/2000 [00:03<00:00, 4697.24it/s]
Running chain 0: 100%|████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 4418.82it/s]
Running chain 2: 100%|████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 4358.10it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 553.15it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 553.34it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 553.58it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 553.87it/s]
Sampling time = 0:00:03.847272
Transforming variables...
Transformation time = 0:00:00.178392
Computing Log Likelihood...
Log Likelihood time = 0:00:00.245859
Esaminiamo i risultati.
az.summary(results_partial_pooling, round_to=2)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
Intercept | 245.56 | 9.49 | 226.92 | 262.34 | 0.19 | 0.13 | 2605.88 | 2973.88 | 1.0 |
Days | 11.38 | 1.87 | 7.66 | 14.76 | 0.05 | 0.03 | 1542.81 | 1711.13 | 1.0 |
Reaction_sigma | 26.09 | 1.82 | 22.88 | 29.66 | 0.03 | 0.02 | 3713.58 | 2671.94 | 1.0 |
1|Subject_sigma | 30.90 | 8.65 | 15.71 | 47.55 | 0.20 | 0.14 | 1789.00 | 1992.80 | 1.0 |
Days|Subject_sigma | 6.80 | 1.61 | 4.10 | 9.91 | 0.04 | 0.03 | 1717.49 | 2221.50 | 1.0 |
1|Subject[308] | 10.17 | 18.52 | -26.00 | 43.72 | 0.30 | 0.26 | 3859.40 | 2845.22 | 1.0 |
1|Subject[309] | -42.59 | 19.92 | -79.89 | -6.30 | 0.34 | 0.24 | 3300.36 | 2634.49 | 1.0 |
1|Subject[310] | -25.90 | 19.19 | -63.19 | 9.42 | 0.31 | 0.24 | 3799.22 | 2945.74 | 1.0 |
1|Subject[330] | 4.27 | 18.50 | -31.48 | 38.53 | 0.29 | 0.27 | 4050.65 | 2933.51 | 1.0 |
1|Subject[331] | 21.44 | 18.97 | -13.98 | 56.67 | 0.33 | 0.24 | 3327.01 | 3008.24 | 1.0 |
1|Subject[332] | 34.44 | 19.27 | -0.73 | 70.54 | 0.35 | 0.25 | 3108.77 | 2415.76 | 1.0 |
1|Subject[333] | 10.93 | 18.12 | -20.66 | 47.89 | 0.30 | 0.25 | 3613.28 | 2732.10 | 1.0 |
1|Subject[334] | -21.90 | 18.58 | -56.19 | 13.14 | 0.34 | 0.24 | 2964.93 | 3023.44 | 1.0 |
1|Subject[335] | 1.14 | 19.31 | -35.92 | 36.41 | 0.31 | 0.29 | 3934.37 | 3136.78 | 1.0 |
1|Subject[337] | 26.07 | 19.48 | -10.55 | 62.44 | 0.32 | 0.24 | 3583.87 | 3246.23 | 1.0 |
1|Subject[349] | -26.95 | 18.64 | -62.97 | 7.27 | 0.32 | 0.23 | 3389.64 | 3372.79 | 1.0 |
1|Subject[350] | -17.16 | 18.18 | -49.45 | 19.09 | 0.29 | 0.23 | 3883.39 | 2972.05 | 1.0 |
1|Subject[351] | -2.63 | 18.18 | -35.60 | 32.44 | 0.29 | 0.28 | 3914.73 | 3278.06 | 1.0 |
1|Subject[352] | 42.69 | 19.53 | 6.70 | 80.27 | 0.35 | 0.25 | 3114.44 | 2621.43 | 1.0 |
1|Subject[369] | -1.88 | 17.96 | -35.95 | 31.78 | 0.28 | 0.28 | 4191.50 | 2747.54 | 1.0 |
1|Subject[370] | -24.92 | 19.32 | -60.57 | 10.62 | 0.32 | 0.24 | 3580.67 | 2963.86 | 1.0 |
1|Subject[371] | -7.14 | 18.09 | -43.23 | 24.20 | 0.29 | 0.24 | 3951.87 | 2930.37 | 1.0 |
1|Subject[372] | 14.49 | 18.04 | -18.72 | 48.82 | 0.31 | 0.23 | 3304.70 | 2874.23 | 1.0 |
Days|Subject[308] | 8.14 | 3.37 | 2.00 | 14.52 | 0.06 | 0.04 | 3017.59 | 3064.36 | 1.0 |
Days|Subject[309] | -8.33 | 3.51 | -14.86 | -2.00 | 0.07 | 0.05 | 2616.61 | 2795.29 | 1.0 |
Days|Subject[310] | -7.33 | 3.38 | -13.61 | -0.76 | 0.06 | 0.04 | 3019.15 | 2867.75 | 1.0 |
Days|Subject[330] | -2.09 | 3.29 | -8.34 | 3.94 | 0.06 | 0.05 | 2644.52 | 2752.60 | 1.0 |
Days|Subject[331] | -3.04 | 3.35 | -9.64 | 2.75 | 0.07 | 0.05 | 2614.82 | 2598.17 | 1.0 |
Days|Subject[332] | -3.90 | 3.41 | -10.24 | 2.53 | 0.07 | 0.05 | 2585.41 | 3028.24 | 1.0 |
Days|Subject[333] | 0.56 | 3.21 | -5.38 | 6.81 | 0.06 | 0.04 | 3018.22 | 2857.95 | 1.0 |
Days|Subject[334] | 3.16 | 3.32 | -2.84 | 9.61 | 0.07 | 0.05 | 2588.82 | 3057.97 | 1.0 |
Days|Subject[335] | -11.13 | 3.44 | -17.49 | -4.72 | 0.06 | 0.04 | 3053.41 | 2724.33 | 1.0 |
Days|Subject[337] | 9.88 | 3.45 | 3.27 | 16.26 | 0.07 | 0.05 | 2818.15 | 2759.29 | 1.0 |
Days|Subject[349] | 1.42 | 3.34 | -4.85 | 7.58 | 0.07 | 0.05 | 2518.89 | 2598.97 | 1.0 |
Days|Subject[350] | 7.29 | 3.35 | 1.23 | 13.88 | 0.06 | 0.05 | 2815.58 | 2485.05 | 1.0 |
Days|Subject[351] | -2.14 | 3.30 | -8.28 | 4.19 | 0.06 | 0.05 | 2783.30 | 2594.87 | 1.0 |
Days|Subject[352] | 0.29 | 3.43 | -6.13 | 6.57 | 0.07 | 0.05 | 2586.07 | 2778.74 | 1.0 |
Days|Subject[369] | 1.62 | 3.21 | -4.20 | 7.87 | 0.06 | 0.05 | 2691.95 | 2712.13 | 1.0 |
Days|Subject[370] | 4.80 | 3.42 | -1.90 | 11.03 | 0.06 | 0.05 | 2963.62 | 3184.74 | 1.0 |
Days|Subject[371] | 0.14 | 3.22 | -5.98 | 6.08 | 0.06 | 0.05 | 2931.63 | 2898.19 | 1.0 |
Days|Subject[372] | 0.96 | 3.23 | -5.15 | 7.04 | 0.06 | 0.05 | 2512.45 | 2784.32 | 1.0 |
Consideriamo il soggetto 309. Per questo soggetto l’intercetta è
245.25 + -43.14
202.11
e la pendenza della retta di regressione è
11.33 + -8.14
3.1899999999999995
Si noti che questi valori sono diversi da quelli ottenuti con la procedura di no-pooling. Entrambi i modelli di no pooling e il modello gerarchico di partial pooling riconoscono che ci possono essere differenze tra i diversi gruppi (o soggetti) nel dataset, ma gestiscono queste differenze in modi diversi.
Nel modello di no pooling, ogni gruppo viene trattato in modo completamente indipendente dagli altri. Ogni intercetta e pendenza viene stimata separatamente per ogni gruppo, senza fare riferimento agli altri gruppi. In altre parole, si adatta una regressione lineare separata per ciascun gruppo. Ciò significa che se si hanno molti gruppi, ci saranno molti parametri da stimare.
Questo approccio può catturare le differenze tra i gruppi molto accuratamente se ci sono molte osservazioni in ogni gruppo, ma può essere problematico se ci sono poche osservazioni per gruppo. Inoltre, non sfrutta le informazioni comuni tra i gruppi e può portare a stime molto variabili.
Il modello gerarchico di partial pooling, invece, riconosce che anche se ci sono differenze tra i gruppi, questi potrebbero condividere alcune caratteristiche comuni. Invece di stimare le intercette e pendenze completamente separatamente per ogni gruppo, stima una media comune e una varianza comune per l’intercetta e la pendenza, e poi permette a ciascun gruppo di variare attorno a questi valori comuni.
Questo porta al concetto di “shrinkage”. Le stime delle intercette e pendenze per ciascun gruppo tendono a essere “retratte” o “compresso” verso i valori medi. Se un gruppo ha poche osservazioni, la sua stima sarà più fortemente influenzata dalla media comune. Se ha molte osservazioni, la sua stima sarà meno influenzata dalla media comune. In questo modo, il modello riesce a bilanciare tra catturare le differenze tra i gruppi e sfruttare le informazioni comuni.
In sintesi, la differenza principale tra il modello di no pooling e il modello gerarchico di partial pooling sta nel modo in cui gestiscono le intercette e pendenze individuali:
Il modello di no pooling tratta ogni gruppo separatamente, stimando le intercette e pendenze individuali senza considerare gli altri gruppi.
Il modello gerarchico di partial pooling stima le intercette e pendenze comuni e permette a ciascun gruppo di variare attorno a questi valori comuni, utilizzando il concetto di shrinkage.
Il modello di no pooling può essere più adatto se i gruppi sono veramente indipendenti e molto diversi tra loro, mentre il modello gerarchico può essere più efficace quando ci sono somiglianze tra i gruppi che possono essere sfruttate per ottenere stime più precise e robuste.
Modello gerarchico e la distribuzione dei coefficienti#
Più speficicatamente, nel modello gerarchico di partial pooling, gli effetti casuali (come le intercette e le pendenze per ogni gruppo o soggetto) sono considerati come realizzazioni di variabili casuali. Questa è una differenza fondamentale rispetto al modello di no pooling, dove ogni coefficiente è trattato come un parametro fisso.
Nel modello gerarchico, si suppone che gli effetti casuali seguano una distribuzione normale. Questo significa che ogni coefficiente individuale (ad esempio, l’intercetta per un particolare soggetto) è considerato come un campione estratto da una popolazione normale. La popolazione ha una media e una varianza che sono comuni a tutti i gruppi, e queste vengono stimate dai dati.
Ad esempio, le intercette individuali \(\alpha_i\) possono essere modellate come:
dove \(\alpha\) è l’intercetta media per tutti i soggetti e \(\tau_\alpha^2\) è la varianza delle intercette tra i soggetti. Analogamente, le pendenze individuali \(\beta_i\) possono essere modellate come:
dove \(\beta\) è la pendenza media e \(\tau_\beta^2\) è la varianza delle pendenze.
Implicazioni#
Questa struttura ha diverse implicazioni importanti:
Shrinkage: Come discusso in precedenza, le stime dei coefficienti individuali tendono a essere “retratte” verso i valori medi. Questo aiuta a stabilizzare le stime, specialmente quando ci sono poche osservazioni per gruppo.
Scambio di informazioni tra i gruppi: Poiché i coefficienti individuali sono considerati come estratti dalla stessa distribuzione, ciò permette uno scambio di informazioni tra i gruppi. Se un gruppo ha molte osservazioni, può aiutare a informare le stime per un gruppo con poche osservazioni.
Interpretazione gerarchica: Il modello riconosce una struttura gerarchica nei dati, con osservazioni raggruppate all’interno di gruppi, e gruppi che condividono caratteristiche comuni. Questa struttura può riflettere una realtà sottostante nella quale gli individui o i gruppi non sono completamente indipendenti l’uno dall’altro.
In conclusione, il modello gerarchico di partial pooling offre un quadro flessibile e potente per analizzare dati raggruppati o clusterizzati, riconoscendo sia le somiglianze che le differenze tra i gruppi e utilizzando una struttura probabilistica per modellare le relazioni tra di loro.
Interpretazione#
Iniziamo considerando le stime a posteriori degli effetti fissi.
az.summary(results_partial_pooling, var_names=["Intercept", "Days"], round_to=2)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
Intercept | 245.56 | 9.49 | 226.92 | 262.34 | 0.19 | 0.13 | 2605.88 | 2973.88 | 1.0 |
Days | 11.38 | 1.87 | 7.66 | 14.76 | 0.05 | 0.03 | 1542.81 | 1711.13 | 1.0 |
In media, il tempo di reazione medio delle persone all’inizio dello studio è compreso tra 227 e 264 millisecondi. Con ogni giorno aggiuntivo di privazione del sonno, i tempi di reazione medi aumentano, in media, tra 7.9 e 15.1 millisecondi.
L’interpretazione degli effetti fissi è semplice. Ma quest’analisi sarebbe incompleta e fuorviante se non valutiamo i termini specifici per i singoli soggetti che abbiamo aggiunto al modello. Questi termini ci dicono quanto i soggetti differiscono l’uno dall’altro in termini di tempo di reazione iniziale e dell’associazione tra giorni di privazione del sonno e tempi di reazione.
Di seguito, utilizziamo ArviZ per ottenere un traceplot delle intercetti specifiche per i soggetti 1|Subject
e delle pendenze Days|Subject
. Questo traceplot contiene due colonne. A sinistra, abbiamo le distribuzioni posteriori e a destra abbiamo i trace-plots. L’aspetto casuale stazionario, o l’apparenza di rumore bianco, ci dice che il campionatore ha raggiunto la convergenza e le catene sono ben mescolate.
az.plot_trace(
results_partial_pooling, combined=True, var_names=["1|Subject", "Days|Subject"]
)
plt.tight_layout()
![../_images/6646f53a9d810c17023e9f0b493b61d091ccb2355f620e4fcd17e9ac87a92a2a.png](../_images/6646f53a9d810c17023e9f0b493b61d091ccb2355f620e4fcd17e9ac87a92a2a.png)
Dall’ampiezza delle distribuzioni a posteriori delle intercette per i singoli soggetti possiamo vedere che il tempo di reazione medio iniziale per un determinato soggetto può differire notevolmente dalla media generale che abbiamo visto nella tabella precedente. C’è anche una grande differenza nelle pendenze. Alcuni soggetti vedono aumentare rapidamente i loro tempi di reazione quando vengono deprivati del sonno, mentre altri hanno una tolleranza migliore e peggiorano più lentamente.
Una rappresentazione grafica della stima a posteriore dei parametri e dei dati si ottiene con az.plot_forest()
.
az.plot_forest(data=results_partial_pooling, r_hat=False, combined=True, textsize=8);
![../_images/b72b55d98c5a808b8faf2acd7ff5080a16bd86f768f203fe319a7ecb6029184d.png](../_images/b72b55d98c5a808b8faf2acd7ff5080a16bd86f768f203fe319a7ecb6029184d.png)
In sintesi, il modello gerarchico cattura il comportamento che abbiamo visto nella fase di esplorazione dei dati. Le persone differiscono sia nei tempi di reazione iniziali che nel modo in cui questi tempi di reazione sono influenzati dai giorni di deprivazione del sonno. Possiamo dunque giungere alle seguenti conclusioni:
Il tempo di reazione medio delle persone aumenta quando sono deprivate del sonno.
I soggetti hanno tempi di reazione diversi all’inizio dello studio.
Alcuni soggetti sono più colpiti dalla privazione del sonno rispetto ad altri.
Ma c’è un’altra domanda a cui non abbiamo ancora risposto: I tempi di reazione iniziali sono associati a quanto la deprivazione del sonno influisce sull’evoluzione dei tempi di reazione?
Creiamo un diagramma a dispersione per visualizzare le stime a posteriori congiunte delle intercette e delle pendenze specifiche per i soggetti. Questo grafico usa colori diversi per i soggetti. Se guardiamo il quadro generale, cioè trascurando i ragruppamenti dei dati in base ai soggetti, possiamo concludere che non c’è associazione tra l’intercetta e la pendenza. In altre parole, avere tempi di reazione iniziali più bassi o più alti non dice nulla su quanto la deprivazione del sonno influisca sul tempo di reazione medio di un determinato soggetto.
D’altra parte, se guardiamo la distribuzione a posteriori congiunta per un determinato individuo, possiamo vedere una correlazione negativa tra l’intercetta e la pendenza. Questo indica che, condizionalmente a un determinato soggetto, le stime a posteriori dell’intercetta e della pendenza non sono indipendenti.
# extract a subsample from the posterior and stack the chain and draw dims
posterior = az.extract(results_partial_pooling, num_samples=500)
_, ax = plt.subplots()
results_partial_pooling.posterior.plot.scatter(
x="1|Subject", y="Days|Subject",
hue="Subject__factor_dim",
add_colorbar=False,
add_legend=False,
cmap="viridis",
edgecolors=None,
)
ax.axhline(c="0.25", ls="--")
ax.axvline(c="0.25", ls="--")
ax.set_xlabel("Subject-specific intercept")
ax.set_ylabel("Subject-specific slope");
![../_images/88a880a6a182c7053ecbcf9201471310a28702f8aece2cb2f89d22943f076cb2.png](../_images/88a880a6a182c7053ecbcf9201471310a28702f8aece2cb2f89d22943f076cb2.png)
Confronto dei modelli#
Un aspetto finale e cruciale del nostro studio riguarda il confronto tra i diversi modelli che abbiamo esaminato. La nostra intenzione è determinare quale modello fornisce una rappresentazione migliore dei dati, trovando un equilibrio appropriato tra l’accuratezza del modello e la sua complessità, cioè la parsimonia.
Per raggiungere questo scopo, faremo uso della metrica ELPD (Expected Log Predictive Density), che abbiamo introdotto in precedenza. ELPD ci consente di valutare un modello in termini di adattamento ai dati, considerando sia l’accuratezza delle previsioni che la complessità del modello.
Utilizzo di az.compare()
#
In Python, possiamo sfruttare la funzione az.compare()
per confrontare direttamente modelli bayesiani. Questa funzione accetta un dizionario contenente gli oggetti InferenceData
, risultanti dalla funzione Model.fit()
, e restituisce un dataframe. I modelli vengono ordinati dal migliore al peggiore in base ai criteri selezionati, e di default, ArviZ usa il criterio di convalida incrociata “leave one out” (LOO).
Convalida Incrociata “Leave One Out” (LOO)#
LOO è una tecnica di convalida che addestra il modello su tutti i dati disponibili tranne uno, utilizzando il singolo punto escluso come dati di test. Questo processo viene ripetuto per ogni punto dati nel set, e la media delle misure di errore fornisce una stima accurata dell’errore di generalizzazione del modello. Anche se computazionalmente impegnativa, LOO fornisce una valutazione affidabile delle prestazioni del modello. In ArviZ, la funzione loo
implementa questo metodo seguendo un approccio bayesiano.
Widely Applicable Information Criterion (WAIC)#
Oltre a LOO, possiamo anche utilizzare il criterio WAIC (Widely Applicable Information Criterion). Il WAIC è uno strumento per la selezione del modello che mira a trovare il modello ottimale in un insieme di candidati, equilibrando l’adattamento ai dati e la complessità del modello, evitando così il sovradattamento. WAIC è particolarmente utile nel contesto bayesiano, poiché tiene conto dell’incertezza associata ai parametri del modello.
Sia LOO che WAIC possono essere visti come stime empiriche dell’ELPD, fornendo un quadro comprensivo delle prestazioni dei modelli.
Conclusione#
Utilizzando la funzione az.compare()
, siamo in grado di effettuare una comparazione rapida ed efficace tra i diversi modelli, valutandoli secondo i criteri LOO e WAIC. Nel nostro caso specifico, il modello di “partial pooling” emerge come il migliore, presentando il valore ELPD stimato più alto. Questo risultato conferma la validità del modello nel rappresentare la struttura dei dati, tenendo conto delle differenze individuali all’interno dei cluster, e fornendo una stima coerente e informativa dell’effetto della deprivazione del sonno sul tempo di reazione.
models_dict = {
"pooling": results_pooling,
"no_pooling": results_no_pooling,
"partial_pooling": results_partial_pooling
}
df_compare = az.compare(models_dict)
df_compare
rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale | |
---|---|---|---|---|---|---|---|---|---|
partial_pooling | 0 | -691.873575 | 30.136117 | 0.000000 | 0.788699 | 21.269008 | 0.000000 | True | log |
no_pooling | 1 | -694.878338 | 36.507771 | 3.004763 | 0.161502 | 22.002219 | 3.591068 | True | log |
pooling | 2 | -772.160430 | 3.033118 | 80.286855 | 0.049800 | 8.991967 | 19.907327 | False | log |
È importante sottolineare che, per ottenere una stima dell’ELPD (Expected Log Predictive Density), è necessario includere l’opzione idata_kwargs={"log_likelihood": True}
all’interno della funzione responsabile dell’esecuzione del campionamento MCMC.
La figura che segue illustra visivamente le informazioni rilevanti per il confronto tra i diversi modelli. In grigio è indicata l’incertezza nella stima della differenza tra i valori ELPD dei diversi modelli.
az.plot_compare(df_compare, insample_dev=False);
![../_images/5c7ae55d00063101130e833b8501454cc1cf7ed3537c4d19a9529a640e0f198b.png](../_images/5c7ae55d00063101130e833b8501454cc1cf7ed3537c4d19a9529a640e0f198b.png)
Il confronto tra i modelli guida il processo di selezione. In particolare, la comparazione tra il modello di partial-pooling e il modello completo di pooling è resa chiara dall’elpd_diff
di 80.17 e dal suo errore standard di 19.97. Questi valori indicano inequivocabilmente che il modello di partial-pooling è superiore.
La situazione diventa più sfumata quando confrontiamo il modello di partial-pooling con il modello di no-pooling. In questo caso, le stime dell’ELPD mostrano una grande sovrapposizione, suggerendo che non c’è una differenza netta tra i due modelli in termini di adattamento ai dati.
Tuttavia, nonostante la vicinanza dei valori di ELPD, il modello di partial-pooling è da preferire. La ragione risiede nelle sue proprietà: esso fornisce stime più robuste e conservative delle differenze individuali. A differenza del modello di no-pooling, che può essere troppo sensibile alle variazioni all’interno dei cluster, il modello di partial-pooling incorpora un equilibrio tra la condivisione delle informazioni all’interno del gruppo e il riconoscimento delle differenze tra i gruppi. Questo lo rende più resistente alle fluttuazioni nei dati e offre una rappresentazione più affidabile delle relazioni sottostanti, rendendolo la scelta preferibile in questo contesto.
PPC plots#
Per affrontare il tema della selezione di modelli, Johnson et al. [JOD22] usano anche il metodo dei posterior predictive checks. Creiamo dunque i PPC plots per i tre modelli.
model_pooling_fitted = model_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
model_pooling.predict(model_pooling_fitted, kind="pps")
Show code cell output
Compiling...
Compilation time = 0:00:00.871858
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 952.81it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 953.33it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 954.00it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 954.64it/s]
Sampling time = 0:00:02.215110
Transforming variables...
Transformation time = 0:00:00.061310
Computing Log Likelihood...
Log Likelihood time = 0:00:00.145963
az.plot_ppc(model_pooling_fitted, num_pp_samples=50);
![../_images/f7d70632ee92fbe028066ba7e22bbb29d4d1984ccb51cce9800820116af4de4c.png](../_images/f7d70632ee92fbe028066ba7e22bbb29d4d1984ccb51cce9800820116af4de4c.png)
model_no_pooling_fitted = model_no_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
model_no_pooling.predict(model_no_pooling_fitted, kind="pps");
Show code cell output
Compiling...
Compilation time = 0:00:01.412526
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:02<?, ?it/s]
Running chain 2: 5%|██▋ | 100/2000 [00:03<00:06, 294.44it/s]
Running chain 1: 5%|██▋ | 100/2000 [00:03<00:06, 286.16it/s]
Running chain 0: 5%|██▋ | 100/2000 [00:03<00:06, 278.07it/s]
Running chain 3: 5%|██▋ | 100/2000 [00:03<00:07, 259.93it/s]
Running chain 2: 30%|███████████████▉ | 600/2000 [00:03<00:00, 1657.63it/s]
Running chain 1: 30%|███████████████▉ | 600/2000 [00:03<00:00, 1633.38it/s]
Running chain 0: 30%|███████████████▉ | 600/2000 [00:03<00:00, 1616.54it/s]
Running chain 3: 35%|██████████████████▌ | 700/2000 [00:03<00:00, 1748.46it/s]
Running chain 1: 60%|███████████████████████████████▏ | 1200/2000 [00:03<00:00, 2873.80it/s]
Running chain 2: 65%|█████████████████████████████████▊ | 1300/2000 [00:03<00:00, 3105.23it/s]
Running chain 0: 65%|█████████████████████████████████▊ | 1300/2000 [00:03<00:00, 3044.70it/s]
Running chain 3: 65%|█████████████████████████████████▊ | 1300/2000 [00:03<00:00, 2896.40it/s]
Running chain 1: 90%|██████████████████████████████████████████████▊ | 1800/2000 [00:03<00:00, 3740.82it/s]
Running chain 2: 95%|█████████████████████████████████████████████████▍ | 1900/2000 [00:03<00:00, 3902.35it/s]
Running chain 0: 95%|█████████████████████████████████████████████████▍ | 1900/2000 [00:03<00:00, 3874.60it/s]
Running chain 3: 100%|████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 3888.70it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 585.95it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 586.14it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 586.45it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 586.72it/s]
Sampling time = 0:00:03.558055
Transforming variables...
Transformation time = 0:00:00.105202
Computing Log Likelihood...
Log Likelihood time = 0:00:00.192238
az.plot_ppc(model_no_pooling_fitted, num_pp_samples=50);
![../_images/f1f54ae9b2809cd909b0c4b9b7f050c7cba327257c88c95b7926c3a7e25e9407.png](../_images/f1f54ae9b2809cd909b0c4b9b7f050c7cba327257c88c95b7926c3a7e25e9407.png)
model_partial_pooling_fitted = model_partial_pooling.fit(
method="nuts_numpyro", idata_kwargs={"log_likelihood": True}
)
model_partial_pooling.predict(model_partial_pooling_fitted, kind="pps");
Show code cell output
Compiling...
Compilation time = 0:00:01.785339
Sampling...
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
Running chain 0: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:03<?, ?it/s]
Running chain 0: 20%|██████████▌ | 400/2000 [00:03<00:00, 3916.78it/s]
Running chain 3: 20%|██████████▌ | 400/2000 [00:03<00:00, 3866.07it/s]
Running chain 1: 15%|███████▉ | 300/2000 [00:03<00:00, 2889.50it/s]
Running chain 2: 20%|██████████▌ | 400/2000 [00:03<00:00, 3561.73it/s]
Running chain 0: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4963.98it/s]
Running chain 3: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4961.08it/s]
Running chain 1: 45%|███████████████████████▊ | 900/2000 [00:03<00:00, 4491.22it/s]
Running chain 2: 50%|██████████████████████████ | 1000/2000 [00:03<00:00, 4689.35it/s]
Running chain 0: 80%|█████████████████████████████████████████▌ | 1600/2000 [00:03<00:00, 5424.84it/s]
Running chain 1: 75%|███████████████████████████████████████ | 1500/2000 [00:03<00:00, 5007.98it/s]
Running chain 3: 85%|████████████████████████████████████████████▏ | 1700/2000 [00:03<00:00, 5525.50it/s]
Running chain 2: 80%|█████████████████████████████████████████▌ | 1600/2000 [00:03<00:00, 4940.61it/s]
Running chain 0: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 568.51it/s]
Running chain 1: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 568.70it/s]
Running chain 2: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 568.96it/s]
Running chain 3: 100%|█████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 569.18it/s]
Sampling time = 0:00:03.699627
Transforming variables...
Transformation time = 0:00:00.135023
Computing Log Likelihood...
Log Likelihood time = 0:00:00.240652
az.plot_ppc(model_partial_pooling_fitted, num_pp_samples=50);
![../_images/00f854bad2764b25211ebb71095cc1399c50955fb72eab252b04e2fccb6cea53.png](../_images/00f854bad2764b25211ebb71095cc1399c50955fb72eab252b04e2fccb6cea53.png)
In questo contesto specifico, l’analisi tramite i PPC (Posterior Predictive Checks) plots non rivela differenze evidenti tra i tre modelli in esame: tutti sembrano egualmente adeguati nell’adattarsi ai dati. Di conseguenza, i PPC plots non forniscono ulteriori chiarimenti o conferme alle conclusioni già raggiunte attraverso il confronto tra modelli basato sulla differenza ELPD (Expected Log Predictive Density). In altre parole, l’analisi visiva tramite i PPC plots non aggiunge valore o informazioni supplementari a quanto già dedotto dalle metriche di confronto.
Commenti conclusivi#
In questo capitolo, abbiamo confrontato i modelli di pooling, no pooling, e partial pooling utilizzando i dati del “sleepstudy.” Ogni modello ha presentato aspetti distintivi: il pooling per la sua struttura comune, il no pooling per la sua indipendenza tra i gruppi, e il partial pooling come un compromesso equilibrato.
L’analisi basata sulla differenza Expected Log Predictive Density (ELPD) è stata cruciale nella selezione del modello più appropriato. Benché ciascun modello abbia avuto i propri vantaggi, la valutazione tramite ELPD ha fornito una misura oggettiva della qualità di adattamento, guidando la scelta verso il modello che meglio rappresenta la struttura sottostante dei dati.
In conclusione, l’approccio combinato di comprendere le caratteristiche dei modelli e applicare metodi quantitativi come l’ELPD ha permesso una selezione di modelli informata ed efficace.
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Fri Jan 26 2024
Python implementation: CPython
Python version : 3.11.7
IPython version : 8.19.0
pandas : 2.1.4
pingouin : 0.5.3
pymc : 5.10.3
bambi : 0.13.0
numpy : 1.26.2
matplotlib: 3.8.2
arviz : 0.17.0
xarray : 2023.12.0
seaborn : 0.13.0
Watermark: 2.4.3