Open In Colab

47. Posterior Predictive Checks#

I controlli predittivi a posteriori (PPC) forniscno un ottimo metodo per convalidare un modello. L’idea è quella di generare dei dati dal modello utilizzando i parametri della distribuzione a posteriori. Questo argomento è stato trattato nel capitolo La predizione bayesiana. Ricordiamo che i PPC si pongono il problema di quantificare il grado in cui i dati generati dal modello si discostano dalla vera distribuzione \(Y\). Quindi, la domanda è quella di capire se la distribuzione a posteriori che è stata ottenuta è simile alla distribuzione teorica. In questo capitolo vedremo come usare degli strumenti grafici per affrontare questo problema.

47.1. Simulazione#

Iniziamo a simulare dei dati.

import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pymc as pm
from pymc import HalfNormal, Model, Normal, sample
import xarray as xr
from statsmodels.nonparametric.smoothers_lowess import lowess as sm_lowess
import warnings

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")

print(f"Runing on PyMC v{pm.__version__}")
Runing on PyMC v5.5.0
# Initialize random number generator
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

sns.set_theme(palette="colorblind")

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%config InlineBackend.figure_format = "svg"

def standardize(series):
    """Standardize a pandas series"""
    return (series - series.mean()) / series.std()

I dati verranno simulati in modo tale da avere una relazione non lineare tra \(X\) e \(Y\).

# Size of dataset
size = 100

x = 2*np.random.random(size)
y = np.sin(x) * 12*np.exp(-x) + np.random.normal(0, 0.2, size)

zx = standardize(x)
zy = standardize(y)
sm_x, sm_y = sm_lowess(zy, zx,  frac=1./5., it=5, return_sorted=True).T
plt.plot(sm_x, sm_y, color='tomato')
plt.plot(zx, zy, 'k.')
[<matplotlib.lines.Line2D at 0x11dd09790>]
_images/536ab6747c4ca90ea2a5c8af4b253cfd017f982b60c4f5de114514f3bfb050aa.svg

È ovvio che, in un caso come questo, un modello lineare come quello che abbiamo esaminato in precedenza, non è adeguato. Ma adattiamo comunque un modello lineare ai dati proprio per verificare se i PPC saranno in grado di mettere in luce il fatto che il modello è sbagliato. Implemento qui lo stesso modello che abbiamo usato in precedenza. I dati sono standardizzati, per cui le seguenti distribuzioni a priori per i parametri sono adeguate.

with pm.Model() as model_0:
    a = pm.Normal("a", 0.0, 2.0)
    b = pm.Normal("b", 0.0, 2.0)

    mu = a + b * zx
    sigma = pm.Normal("sigma", sigma=5.0)

    pm.Normal("obs", mu=mu, sigma=sigma, observed=zy)
    idata_0 = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b, sigma]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[5], line 9
      6 sigma = pm.Normal("sigma", sigma=5.0)
      8 pm.Normal("obs", mu=mu, sigma=sigma, observed=zy)
----> 9 idata_0 = pm.sample()

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:766, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    764 _print_step_hierarchy(step)
    765 try:
--> 766     _mp_sample(**sample_args, **parallel_args)
    767 except pickle.PickleError:
    768     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:1141, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, traces, model, callback, mp_ctx, **kwargs)
   1138 # We did draws += tune in pm.sample
   1139 draws -= tune
-> 1141 sampler = ps.ParallelSampler(
   1142     draws=draws,
   1143     tune=tune,
   1144     chains=chains,
   1145     cores=cores,
   1146     seeds=random_seed,
   1147     start_points=start,
   1148     step_method=step,
   1149     progressbar=progressbar,
   1150     mp_ctx=mp_ctx,
   1151 )
   1152 try:
   1153     try:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/parallel.py:402, in ParallelSampler.__init__(self, draws, tune, chains, cores, seeds, start_points, step_method, progressbar, mp_ctx)
    399 if mp_ctx.get_start_method() != "fork":
    400     step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
--> 402 self._samplers = [
    403     ProcessAdapter(
    404         draws,
    405         tune,
    406         step_method,
    407         step_method_pickled,
    408         chain,
    409         seed,
    410         start,
    411         mp_ctx,
    412     )
    413     for chain, seed, start in zip(range(chains), seeds, start_points)
    414 ]
    416 self._inactive = self._samplers.copy()
    417 self._finished: List[ProcessAdapter] = []

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/parallel.py:403, in <listcomp>(.0)
    399 if mp_ctx.get_start_method() != "fork":
    400     step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
    402 self._samplers = [
--> 403     ProcessAdapter(
    404         draws,
    405         tune,
    406         step_method,
    407         step_method_pickled,
    408         chain,
    409         seed,
    410         start,
    411         mp_ctx,
    412     )
    413     for chain, seed, start in zip(range(chains), seeds, start_points)
    414 ]
    416 self._inactive = self._samplers.copy()
    417 self._finished: List[ProcessAdapter] = []

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/parallel.py:259, in ProcessAdapter.__init__(self, draws, tune, step_method, step_method_pickled, chain, seed, start, mp_ctx)
    242     step_method_send = step_method
    244 self._process = mp_ctx.Process(
    245     daemon=True,
    246     name=process_name,
   (...)
    257     ),
    258 )
--> 259 self._process.start()
    260 # Close the remote pipe, so that we get notified if the other
    261 # end is closed.
    262 remote_conn.close()

File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/context.py:300, in ForkServerProcess._Popen(process_obj)
    297 @staticmethod
    298 def _Popen(process_obj):
    299     from .popen_forkserver import Popen
--> 300     return Popen(process_obj)

File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/popen_forkserver.py:35, in Popen.__init__(self, process_obj)
     33 def __init__(self, process_obj):
     34     self._fds = []
---> 35     super().__init__(process_obj)

File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     17 self.returncode = None
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/popen_forkserver.py:58, in Popen._launch(self, process_obj)
     55 self.finalizer = util.Finalize(self, util.close_fds,
     56                                (_parent_w, self.sentinel))
     57 with open(w, 'wb', closefd=True) as f:
---> 58     f.write(buf.getbuffer())
     59 self.pid = forkserver.read_signed(self.sentinel)

KeyboardInterrupt: 

Anche se il modello è sbagliato, i trace plot non mettono in evidenza alcuna anomalia di rilievo.

_ = az.plot_trace(idata_0, combined=True)
plt.tight_layout()
_images/4f44673df853664c59fa02385916d72a43b9d7a66a65420c49118fff33b0c98f.svg

Generiamo ora i dati necessari per i controlli predittivi a posteriori. A tal fine, useremo una funzione PyMC dedicata a questo scopo per campionare i dati dalla distribuzione a posteriori. Questa funzione estrarrà casualmente 4000 campioni di parametri dalla traccia. Quindi, per ogni campione, estrarrà 100 numeri casuali da una distribuzione normale specificata dai valori di mu e sigma in quel campione (si veda il capitolo La predizione bayesiana.):

with model_0:
    pm.sample_posterior_predictive(
        idata_0, extend_inferencedata=True, random_seed=rng);

Ora, l’oggetto posterior_predictive in idata_0 contiene 4000 set di dati simulati, contenenti 100 campioni ciascuno, ognuno dei quali è stato calcolato usando un valore del parametro preso a caso dalla distribuzione a posteriori:

idata_0
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
      Data variables:
          a        (chain, draw) float64 0.1012 -0.01572 -0.01572 ... -0.1204 0.2089
          b        (chain, draw) float64 0.04482 0.06067 0.06067 ... 0.1886 0.0903
          sigma    (chain, draw) float64 1.013 1.008 1.008 1.029 ... 1.083 1.044 1.026
      Attributes:
          created_at:                 2023-05-09T06:27:40.229954
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.3.0
          sampling_time:              33.94731283187866
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 1000, obs_dim_2: 100)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
        * obs_dim_2  (obs_dim_2) int64 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          obs        (chain, draw, obs_dim_2) float64 2.144 0.4018 ... 1.331 -0.7943
      Attributes:
          created_at:                 2023-05-09T06:28:26.359177
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.3.0

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          acceptance_rate        (chain, draw) float64 0.7641 0.9911 ... 0.8792 0.9259
          perf_counter_start     (chain, draw) float64 2.071e+05 ... 2.071e+05
          diverging              (chain, draw) bool False False False ... False False
          max_energy_error       (chain, draw) float64 0.4123 -0.1461 ... 0.2515
          smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan
          n_steps                (chain, draw) float64 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0
          ...                     ...
          index_in_trajectory    (chain, draw) int64 -3 -1 0 -1 3 2 ... 2 2 -2 -3 2 2
          lp                     (chain, draw) float64 -147.9 -147.4 ... -148.7 -149.5
          reached_max_treedepth  (chain, draw) bool False False False ... False False
          tree_depth             (chain, draw) int64 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 2 2
          perf_counter_diff      (chain, draw) float64 0.0004933 ... 0.0004311
          step_size              (chain, draw) float64 0.895 0.895 ... 0.9756 0.9756
      Attributes:
          created_at:                 2023-05-09T06:27:40.249921
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.3.0
          sampling_time:              33.94731283187866
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (obs_dim_0: 100)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          obs        (obs_dim_0) float64 0.9185 -1.643 0.8393 ... -1.756 1.176 0.7894
      Attributes:
          created_at:                 2023-05-09T06:27:40.254781
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.3.0

A questo punto possiamo utilizzare la funzione az.plot_ppc() per determinare se il modello è almeno in grado di riprodurre i dati osservati:

_ = az.plot_ppc(idata_0, num_pp_samples=100)
_images/d7effd68728c8ce9e5c12d39a3fe56355347d452e83ef62155e6ff67ff0c6046.svg

Il PPC plot mostra come vi sia pochissima corrispondenza tra la distribuzione dei dati osservati (in nero) e la distribuzione dei dati predetti dal modello. Questo era atteso, dato che abbiamo adattato un modello del tutto sbagliato per i dati a disposizione.

Proseguiamo con la simulazione. I dati sono stati generati con una funzione non lineare. In generale, quella funzione è sconosciuta. Qualunque funzione non lineare, però, può essere approssimata da un modello polinomiale. Un modello di regressione polinomiale si ottiene inserendo nel modello dei nuovi predittori, ciascuno ottenuto elevando la variabile originale a una potenza:

\[ \mathbb{E}(Y) = \alpha + \beta_1 x + \beta_2 x^2 + \beta_3 x^3 + \beta_4 x^4 + \dots \]

Questo modello, anche se non descrive una relazione lineare tra \(X\) e \(Y\), è funzione lineare dei coefficienti ignoti \(\alpha, \beta_1, \beta_2, \dots\). Il grado del polinomio dipende dal tipo di curva che vogliamo approssimare. Proviamo qui con un modello polinomiale di ordine 4.

zx2 = zx**2
zx3 = zx**3
zx4 = zx**4

with pm.Model() as model_1:
    a = pm.Normal("a", 0.0, 3.0)
    b1 = pm.Normal("b1", 0.0, 3.0)
    b2 = pm.Normal("b2", 0.0, 3.0)
    b3 = pm.Normal("b3", 0.0, 3.0)
    b4 = pm.Normal("b4", 0.0, 3.0)
    
    mu = a + b1 * zx + b2 * zx2 + b3 * zx3 + b4 * zx4
    sigma = pm.HalfNormal("sigma", sigma=5.0)

    pm.Normal("obs", mu=mu, sigma=sigma, observed=zy)
    idata_1 = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b1, b2, b3, b4, sigma]
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 37 seconds.
_ = az.plot_trace(idata_1, combined=True)
plt.tight_layout()
_images/caef5f9f0e387e7e527f20d9ba3149741309fcfecd4b56fad47efcdc9df12e7b.svg
az.summary(idata_1)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
a 0.960 0.038 0.888 1.029 0.001 0.001 1653.0 2003.0 1.0
b1 -0.672 0.049 -0.768 -0.581 0.001 0.001 1813.0 2014.0 1.0
b2 -0.837 0.073 -0.967 -0.698 0.002 0.001 1428.0 1560.0 1.0
b3 0.423 0.025 0.376 0.467 0.001 0.000 1894.0 2310.0 1.0
b4 -0.070 0.026 -0.119 -0.020 0.001 0.000 1540.0 1869.0 1.0
sigma 0.202 0.015 0.176 0.229 0.000 0.000 2943.0 2493.0 1.0

Anche in questo caso, campioniamo dalla distribuzione a posteriori dei parametri e simuliamo i dati possibili futuri. L’esame del PPC plot, in questo secondo caso, mostra una buona corrispondenza tra le predizioni del modello e i dati osservati.

with model_1:
    pm.sample_posterior_predictive(
        idata_1, extend_inferencedata=True, random_seed=rng)

_ = az.plot_ppc(idata_1, num_pp_samples=100)
_images/d736dc91d589502a20caf6ae598383ed091f09e9f79e15c653853b7fc9734bcf.svg

Si noti che la curva nera rappresenta la distribuzione della variabile dipendente, come risulta anche dal seguente KDE plot.

sns.kdeplot(zy)
<Axes: ylabel='Density'>
_images/1604fefc8d6084ab4a59a7f6c8106d449393e8304085c17ba86b3d68420b0fb5.svg

Poniamoci ora il problema di generare un grafico che mostra la relazione prevista tra il predittore (\(x\)) e i valori della \(y\) previsti dal modello.

with model_1:
    pm.sample_posterior_predictive(
        idata_1, extend_inferencedata=True, random_seed=rng)

post = idata_1.posterior
mu_pp = post["a"] + post["b1"] * xr.DataArray(zx, dims=["obs_id"]) + \
    post["b2"] * xr.DataArray(zx2, dims=["obs_id"]) + \
    post["b3"] * xr.DataArray(zx3, dims=["obs_id"]) + \
    post["b4"] * xr.DataArray(zx4, dims=["obs_id"]) 

_, ax = plt.subplots()

ax.plot(
    zx, mu_pp.mean(("chain", "draw")), '.', label="Mean outcome", color="C1", alpha=0.6
)
ax.scatter(zx, idata_1.observed_data["obs"])
az.plot_hdi(zx, idata_1.posterior_predictive["obs"])

ax.set_xlabel("Predictor (stdz)")
ax.set_ylabel("Outcome (stdz)")
Text(0, 0.5, 'Outcome (stdz)')
_images/a02cb622d2421b75732c1c8f7cca0f41748b1059b006164a280967ff9b9d4150.svg

per questo secondo modello vediamo come i dati predetti (la banda evidenziata) approssimano da vicino la relazione tra i dati osservati e la \(x\).

Ripetiamo questo procedimento usando i dati del modello sbagliato.

with model_0:
    pm.sample_posterior_predictive(
        idata_0, extend_inferencedata=True, random_seed=rng)

post = idata_0.posterior
mu_pp = post["a"] + post["b"] * xr.DataArray(zx, dims=["obs_id"]) 

_, ax = plt.subplots()

ax.plot(
    zx, mu_pp.mean(("chain", "draw")), '.', label="Mean outcome", color="C1", alpha=0.6
)
ax.scatter(zx, idata_0.observed_data["obs"])
az.plot_hdi(zx, idata_0.posterior_predictive["obs"])

ax.set_xlabel("Predictor (stdz)")
ax.set_ylabel("Outcome (stdz)")
Text(0, 0.5, 'Outcome (stdz)')
_images/e530bd08204ad832314286caf5690b50a5165e77edc288831d56d286a3b374b4.svg

Si noti come, in questo caso, vi è un’enorme incertezza nella predizone del modello. E, ovviamente, i dati predetti non rendono conto, in nessun modo, della relazione tra i valori osservati \(x, y\).

In conlusione, dunque, i PPC plot sono un’utile strumento per il confronto tra modelli.

47.2. Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w