How to compute leave-one-out cross validation stats for a numpyro (jax) model with a Gaussian process from celerite2
{
"cells": [
{
"cell_type": "markdown",
"id": "151c0c61-b90b-4f03-b80f-e78fdb4bcdcc",
"metadata": {},
"source": [
"# Compute LOO for models with GPs\n",
"\n",
"The techniques used in this notebook are explained in [this tutorial for stats folk](https://mc-stan.org/loo/articles/loo2-non-factorizable.html)."
]
},
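{
"cell_type": "markdown",
"id": "loo-identity-note",
"metadata": {},
"source": [
"For a non-factorizable normal model $y \\sim \\mathcal{N}(\\mu, \\Sigma)$, the vignette shows that the leave-one-out predictive density can be read directly off the precision matrix. Writing $g_i = \\left[\\Sigma^{-1}(y - \\mu)\\right]_i$ and $c_{ii} = \\left[\\Sigma^{-1}\\right]_{ii}$,\n",
"\n",
"$$\\log p(y_i \\mid y_{-i}) = -\\frac{1}{2}\\log(2\\pi) + \\frac{1}{2}\\log c_{ii} - \\frac{g_i^2}{2 c_{ii}},$$\n",
"\n",
"which is exactly the expression implemented in the pointwise model below."
]
},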
{
"cell_type": "code",
"execution_count": null,
"id": "b4142a99-bc9f-4812-bbe4-c7c5b7fa8664",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import numpyro\n",
"cpu_cores = 8\n",
"numpyro.set_host_device_count(cpu_cores)\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"from jax.config import config\n",
"config.update('jax_enable_x64', True)\n",
"\n",
"from jax import random, numpy as jnp\n",
"\n",
"from celerite2 import terms as terms_py, GaussianProcess as GaussianProcess_py\n",
"from celerite2.jax import terms, GaussianProcess\n",
"\n",
"import arviz\n",
"from arviz.stats.stats_utils import logsumexp as _logsumexp\n",
"from arviz.stats.stats import _ic_matrix\n",
"\n",
"import pandas as pd\n",
"import xarray as xr\n",
"from corner import corner\n",
"from scipy.optimize import minimize\n",
"from scipy import stats as st"
]
},
{
"cell_type": "markdown",
"id": "602ad2e3-4360-40a7-beb2-4506754fa17f",
"metadata": {},
"source": [
"Generate synthetic data with a GP. I'll use the \"python\" (not JAX) module of celerite2 in (only) this cell, so we can produce a random sample:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36d7fc22-c630-43b1-a5f8-c098af1abb26",
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42) \n",
"\n",
"t = np.linspace(0, 100, 1000)\n",
"\n",
"true_sigma = 0.2\n",
"true_rho = 9\n",
"true_tau = 100\n",
"true_mean = 2\n",
"kernel = terms_py.SHOTerm(sigma=true_sigma, rho=true_rho, tau=true_tau)\n",
"gp = GaussianProcess_py(kernel, t=t, mean=true_mean)\n",
"\n",
"yerr = 0.05\n",
"y = gp.sample() + np.random.normal(scale=yerr, size=len(t))\n",
"\n",
"plt.plot(t, y)"
]
},
{
"cell_type": "markdown",
"id": "1d3617fc-76ce-4c83-82cb-82a6cdb6e365",
"metadata": {},
"source": [
"Fit the synthetic data with a GP using numpyro+celerite2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "353e0376-5e74-47a4-a678-c502c9610f9d",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model():\n",
" # this model looks like the tutorial in celerite2 here:\n",
" # https://celerite2.readthedocs.io/en/latest/tutorials/first/#posterior-inference-using-numpyro\n",
" \n",
" mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
" log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
" log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(-1, 1))\n",
" log_rho = numpyro.sample(\"log_rho\", dist.Normal(2, 2))\n",
" log_tau = numpyro.sample(\"log_tau\", dist.Normal(4, 2))\n",
" kernel = terms.UnderdampedSHOTerm(\n",
" sigma=jnp.exp(log_sigma), rho=jnp.exp(log_rho), tau=jnp.exp(log_tau)\n",
" )\n",
"\n",
" gp = GaussianProcess(kernel, mean=mean)\n",
" gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)\n",
"\n",
" numpyro.sample(\"obs\", gp.numpyro_dist(), obs=y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34895e72-113d-42cc-bba1-03119fe427eb",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel = NUTS(numpyro_model, dense_mass=True)\n",
"mcmc = MCMC(\n",
" nuts_kernel,\n",
" num_warmup=1000,\n",
" num_samples=1000,\n",
" num_chains=cpu_cores,\n",
" progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc.run(rng_key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d04dca-beed-42e4-b1f4-bb3c8ef98eae",
"metadata": {},
"outputs": [],
"source": [
"result = arviz.from_numpyro(mcmc)\n",
"\n",
"corner(\n",
" result, \n",
" var_names='log_sigma log_rho log_tau mean'.split(), \n",
" truths=np.log([true_sigma, true_rho, true_tau]).tolist() + [true_mean]\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "9fa9201f-066c-46cc-b651-6b30d9a86ad6",
"metadata": {},
"source": [
"But note that numpyro is tracking the log likelihood as a single number for all datapoints in the timeseries, rather than pointwise:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2386e929-6e06-4a27-b109-740cea79d298",
"metadata": {},
"outputs": [],
"source": [
"result.log_likelihood"
]
},
{
"cell_type": "markdown",
"id": "7cc7764e-ac2a-4375-8f06-e1656e467bbb",
"metadata": {},
"source": [
"Here's how we can modify the numpyro model to (optionally) compute the pointwise log likelihood, which allows us to use Leave-One-Out cross validation: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "740e2f93-81b1-4544-b18f-76aecedfc4e9",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model_pointwise(pointwise=False):\n",
" mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
" log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
" log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(-1, 1))\n",
" log_rho = numpyro.sample(\"log_rho\", dist.Normal(2, 2))\n",
" log_tau = numpyro.sample(\"log_tau\", dist.Normal(4, 2))\n",
" kernel = terms.UnderdampedSHOTerm(\n",
" sigma=jnp.exp(log_sigma), rho=jnp.exp(log_rho), tau=jnp.exp(log_tau)\n",
" )\n",
"\n",
" gp = GaussianProcess(kernel, mean=mean)\n",
" gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)\n",
"\n",
" numpyro.sample(\"obs\", gp.numpyro_dist(), obs=y)\n",
" \n",
" if pointwise:\n",
" \n",
" # if you have a non-uniform mean model (which is independent of the GP),\n",
" # you should assign it to `mean_model` here. In the model above, \n",
" # the mean model is described by a single parameter `mean`:\n",
" mean_model = mean\n",
" \n",
" diag = yerr ** 2 + jnp.exp(log_jitter) ** 2\n",
" K_s = kernel.to_dense(t.flatten(), np.zeros_like(t))\n",
" covariance_matrix = K_s + jnp.eye(*K_s.shape) * diag\n",
" inv_cov = jnp.linalg.inv(covariance_matrix)\n",
"\n",
" # https://mc-stan.org/loo/articles/loo2-non-factorizable.html\n",
" g_i = inv_cov @ (y - mean_model)\n",
" c_ii = jnp.diag(inv_cov)\n",
"\n",
" lnlike = (\n",
" -0.5 * jnp.log(2 * np.pi) + 0.5 * jnp.log(c_ii) -\n",
" 0.5 * (g_i**2 / c_ii)\n",
" )\n",
"\n",
" numpyro.deterministic(\"pointwise\", lnlike)"
]
},
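{
"cell_type": "markdown",
"id": "loo-sanity-check-note",
"metadata": {},
"source": [
"As a quick, optional sanity check of that expression (a minimal sketch: it uses the true hyperparameters, the pure-Python celerite2 module, and an arbitrarily chosen index `i`; names like `kernel_check` and `gp_check` are just illustrative), the cell below compares the matrix-inverse identity against a brute-force leave-one-out prediction. The two log densities should agree to numerical precision:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "loo-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"i = 100  # arbitrary index to leave out\n",
"kernel_check = terms_py.SHOTerm(sigma=true_sigma, rho=true_rho, tau=true_tau)\n",
"\n",
"# full covariance of the observations, including the white-noise diagonal\n",
"K = kernel_check.to_dense(t, yerr**2 * np.ones_like(t))\n",
"inv_K = np.linalg.inv(K)\n",
"g = inv_K @ (y - true_mean)\n",
"c = np.diag(inv_K)\n",
"lnlike_identity = (\n",
"    -0.5 * np.log(2 * np.pi) + 0.5 * np.log(c[i]) - 0.5 * g[i]**2 / c[i]\n",
")\n",
"\n",
"# brute force: condition the GP on every other point, then predict y_i\n",
"mask = np.arange(len(t)) != i\n",
"gp_check = GaussianProcess_py(kernel_check, mean=true_mean)\n",
"gp_check.compute(t[mask], diag=yerr**2 * np.ones(mask.sum()))\n",
"mu, var = gp_check.predict(y[mask], t=t[i:i+1], return_var=True)\n",
"\n",
"# the predictive for the *observation* y_i adds the noise variance\n",
"lnlike_brute = st.norm.logpdf(y[i], loc=mu[0], scale=np.sqrt(var[0] + yerr**2))\n",
"\n",
"lnlike_identity, lnlike_brute"
]
},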
{
"cell_type": "code",
"execution_count": null,
"id": "ca0ff5f2-327a-48c5-ad40-52482f5940d0",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel = NUTS(numpyro_model_pointwise, dense_mass=True)\n",
"mcmc_pointwise = MCMC(\n",
" nuts_kernel,\n",
" num_warmup=1000,\n",
" num_samples=1000,\n",
" num_chains=cpu_cores,\n",
" progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc_pointwise.run(rng_key)\n",
"result_pointwise = arviz.from_numpyro(mcmc_pointwise)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "213d0751-4b1e-4266-bab5-82b3db3444b8",
"metadata": {},
"outputs": [],
"source": [
"corner(\n",
" result_pointwise, \n",
" var_names='log_sigma log_rho log_tau mean'.split(), \n",
" truths=np.log([true_sigma, true_rho, true_tau]).tolist() + [true_mean]\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "15d13fb9-90ea-45a9-98f8-7b3bc8ef24bb",
"metadata": {},
"source": [
"The results in `result_pointwise` are no different than in `result` above, since `pointwise=False` by default. But here's where we can make use of that new logic:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02594df2-03f0-4f47-985c-33fc8845b971",
"metadata": {},
"outputs": [],
"source": [
"def recompute_pointwise_lnlike(result_pointwise, n_draws_recompute=150):\n",
" n_chains = len(result_pointwise.sample_stats.chain)\n",
"\n",
" posterior_sample = {\n",
" k: result_pointwise.posterior[k].data.ravel() \n",
" if len(np.shape(result_pointwise.posterior[k])) == 2 else\n",
" result_pointwise.posterior[k].data.reshape(\n",
" (result_pointwise.posterior[k].data.shape[0] * \n",
" result_pointwise.posterior[k].data.shape[1], -1))\n",
" for k in result_pointwise.posterior.keys()\n",
" }\n",
"\n",
" last_n_samples = - n_draws_recompute * n_chains\n",
"\n",
" pred_kwargs = {key: posterior_sample[key][last_n_samples:] for key in posterior_sample}\n",
" pred = Predictive(\n",
" numpyro_model_pointwise, \n",
" pred_kwargs, \n",
" return_sites=[\"pointwise\"],\n",
" batch_ndims=1\n",
" )\n",
"\n",
" pointwise_logps = pred(rng_key, pointwise=True)['pointwise']\n",
"\n",
" n_draws_total = result_pointwise.sample_stats.draw.shape[0]\n",
" draws = result_pointwise.sample_stats.draw.draw[-n_draws_recompute:]\n",
" posterior = result_pointwise.posterior.sel(draw=draws)\n",
"\n",
" new_shape = (\n",
" n_chains, \n",
" pointwise_logps.shape[0] // n_chains, \n",
" pointwise_logps.shape[-1]\n",
" )\n",
"\n",
" log_likelihood = xr.Dataset(\n",
" {\n",
" 'obs': xr.DataArray(\n",
" pointwise_logps.reshape(new_shape), \n",
" dims=['chain', 'draw', 'obs_dim_0']\n",
" )\n",
" }\n",
" )\n",
"\n",
" return arviz.InferenceData(log_likelihood=log_likelihood, posterior=posterior)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1aed95c-076b-4318-9866-1ccfc1beeec4",
"metadata": {},
"outputs": [],
"source": [
"result_pointwise_subset = recompute_pointwise_lnlike(result_pointwise)"
]
},
{
"cell_type": "markdown",
"id": "417c724d-d405-468d-9250-98d91273f792",
"metadata": {},
"source": [
"Now we have computed the pointwise log likelihood for a subset of samples in `result_pointwise`, which we have stored in `result_pointwise_subset`. If you look at the shape of the log likelihood, you'll see it now has a `time` dimension (which corresponds to the time series dimension in the original data `y`)."
]
},
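{
"cell_type": "code",
"execution_count": null,
"id": "loglike-shape-check",
"metadata": {},
"outputs": [],
"source": [
"result_pointwise_subset.log_likelihood"
]
},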
{
"cell_type": "markdown",
"id": "23e9eecb-091c-48b9-a8c2-9f703c3dab83",
"metadata": {},
"source": [
"Now we can compute the LOO after we compute the effective sample size `reff`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "696538f0-b3a2-4225-98b8-a4ffd9f871e8",
"metadata": {},
"outputs": [],
"source": [
"def loo(result_pointwise_subset):\n",
" n = np.prod(result_pointwise_subset.posterior['mean'].shape)\n",
" reff = arviz.ess(result_pointwise_subset, method='mean').mean() / n\n",
"\n",
" reff = (\n",
" np.hstack([reff[v].values.flatten() for v in reff.data_vars]).mean()\n",
" )\n",
"\n",
" return arviz.loo(result_pointwise_subset, pointwise=True, reff=reff)\n",
"\n",
"loo_result = loo(result_pointwise_subset)\n",
"loo_result"
]
},
{
"cell_type": "markdown",
"id": "b845976c-16f6-41f3-8671-47830051e3d0",
"metadata": {},
"source": [
"Now that we've computed LOO CV results for the GP model, let's compare the results to a simple strictly sinusoidal model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37358117-dc9c-4af1-97c3-e79609374dd0",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model_sinusoid():\n",
" \"\"\"\n",
" this version of the model has a strict sinusoid\n",
" as the mean model and no GP, rather than GP with\n",
" the SHO kernel above.\n",
" \"\"\"\n",
" \n",
" mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
" log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
" amp = numpyro.sample(\"amp\", dist.Uniform(0, 0.5))\n",
" phase = numpyro.sample(\"phase\", dist.Uniform(0, 2*np.pi))\n",
" period = numpyro.sample(\"period\", dist.Uniform(1, 3))\n",
" model = amp * jnp.sin(2*np.pi / period - phase) + mean\n",
"\n",
" numpyro.sample(\n",
" \"obs\", \n",
" dist.Normal(\n",
" loc=model, \n",
" scale=yerr + jnp.exp(log_jitter)\n",
" ), \n",
" obs=y\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f3187a1-f356-4529-8b16-a392dba7f844",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel_sin = NUTS(numpyro_model_sinusoid, dense_mass=True)\n",
"mcmc_pointwise_sinusoid = MCMC(\n",
" nuts_kernel_sin,\n",
" num_warmup=1000,\n",
" num_samples=1000,\n",
" num_chains=cpu_cores,\n",
" progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc_pointwise_sinusoid.run(rng_key)\n",
"result_sinusoid = arviz.from_numpyro(mcmc_pointwise_sinusoid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9d9b8a3-b460-46a5-90ca-4ece9f82788e",
"metadata": {},
"outputs": [],
"source": [
"loo_sinusoid = arviz.loo(result_sinusoid, pointwise=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "137e4048-242c-4da4-9c52-806704fd7d99",
"metadata": {},
"outputs": [],
"source": [
"def compare(\n",
" dataset_dict, ic='loo', method='stacking', \n",
" ascending=False, b_samples=1000, alpha=1, seed=None\n",
"):\n",
" \"\"\"\n",
" This is a modified version of arviz.compare that works \n",
" on the LOO outputs generated above.\n",
" \"\"\"\n",
" scale_value = 1\n",
" np.random.seed(seed)\n",
" if ic != 'loo': \n",
" raise NotImplementedError()\n",
" \n",
" names = list(dataset_dict.keys())\n",
" \n",
" ic_se = f\"{ic}_se\"\n",
" p_ic = f\"p_{ic}\"\n",
" ic_i = f\"{ic}_i\"\n",
" scale_col = f\"{ic}_scale\"\n",
" df_comp = pd.DataFrame(\n",
" index=names,\n",
" columns=[\n",
" \"rank\",\n",
" \"loo\",\n",
" \"p_loo\",\n",
" \"d_loo\",\n",
" \"weight\",\n",
" \"se\",\n",
" \"dse\",\n",
" \"warning\",\n",
" \"loo_scale\",\n",
" ],\n",
" dtype=np.float_,\n",
" )\n",
" \n",
" ics = pd.DataFrame()\n",
" names = []\n",
" for name, dataset in dataset_dict.items():\n",
" names.append(name)\n",
" try:\n",
" # Here is where the IC function is actually computed -- the rest of this\n",
" # function is argument processing and return value formatting\n",
" # ics = ics.append([dataset_dict[name]])\n",
" ics = pd.concat([ics, pd.DataFrame([dataset_dict[name]])], ignore_index=True)\n",
"\n",
" except Exception as e:\n",
" raise e.__class__(f\"Encountered error trying to compute {ic} from model {name}.\") from e\n",
" ics.index = names\n",
" ics.sort_values(by=ic, inplace=True, ascending=ascending)\n",
" ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())\n",
" \n",
" \n",
" if method.lower() == \"stacking\":\n",
" rows, cols, ic_i_val = _ic_matrix(ics, ic_i)\n",
" exp_ic_i = np.exp(ic_i_val / scale_value)\n",
" km1 = cols - 1\n",
"\n",
" def w_fuller(weights):\n",
" return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))\n",
"\n",
" def log_score(weights):\n",
" w_full = w_fuller(weights)\n",
" score = 0.0\n",
" for i in range(rows):\n",
" score += np.log(np.dot(exp_ic_i[i], w_full))\n",
" return -score\n",
"\n",
" def gradient(weights):\n",
" w_full = w_fuller(weights)\n",
" grad = np.zeros(km1)\n",
" for k in range(km1):\n",
" for i in range(rows):\n",
" grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, km1]) / np.dot(exp_ic_i[i], w_full)\n",
" return -grad\n",
"\n",
" theta = np.full(km1, 1.0 / cols)\n",
" bounds = [(0.0, 1.0) for _ in range(km1)]\n",
" constraints = [\n",
" {\"type\": \"ineq\", \"fun\": lambda x: -np.sum(x) + 1.0},\n",
" {\"type\": \"ineq\", \"fun\": np.sum},\n",
" ]\n",
"\n",
" weights = minimize(\n",
" fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints\n",
" )\n",
"\n",
" weights = w_fuller(weights[\"x\"])\n",
" ses = ics[ic_se]\n",
"\n",
" elif method.lower() == \"bb-pseudo-bma\":\n",
" rows, cols, ic_i_val = _ic_matrix(ics, ic_i)\n",
" ic_i_val = ic_i_val * rows\n",
"\n",
" b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)\n",
" weights = np.zeros((b_samples, cols))\n",
" z_bs = np.zeros_like(weights)\n",
" for i in range(b_samples):\n",
" z_b = np.dot(b_weighting[i], ic_i_val)\n",
" u_weights = np.exp((z_b - np.max(z_b)) / scale_value)\n",
" z_bs[i] = z_b # pylint: disable=unsupported-assignment-operation\n",
" weights[i] = u_weights / np.sum(u_weights)\n",
"\n",
" weights = weights.mean(axis=0)\n",
" ses = pd.Series(z_bs.std(axis=0), index=names) # pylint: disable=no-member\n",
"\n",
" elif method.lower() == \"pseudo-bma\":\n",
" min_ic = ics.iloc[0][ic]\n",
" z_rv = np.exp((ics[ic] - min_ic) / scale_value)\n",
" weights = z_rv / np.sum(z_rv)\n",
" ses = ics[ic_se]\n",
"\n",
" if np.any(weights):\n",
" min_ic_i_val = ics[ic_i].iloc[0]\n",
" for idx, val in enumerate(ics.index):\n",
" res = ics.loc[val]\n",
" if scale_value < 0:\n",
" diff = res[ic_i] - min_ic_i_val\n",
" else:\n",
" diff = min_ic_i_val - res[ic_i]\n",
" d_ic = np.sum(diff)\n",
" d_std_err = np.sqrt(len(diff) * np.var(diff))\n",
" std_err = ses.loc[val]\n",
" weight = weights[idx]\n",
" df_comp.loc[val] = (\n",
" idx,\n",
" res[ic],\n",
" res[p_ic],\n",
" d_ic,\n",
" weight,\n",
" std_err,\n",
" d_std_err,\n",
" res[\"warning\"],\n",
" res[scale_col],\n",
" )\n",
"\n",
" df_comp[\"rank\"] = df_comp[\"rank\"].astype(int)\n",
" df_comp[\"warning\"] = df_comp[\"warning\"].astype(bool)\n",
" return df_comp.sort_values(by=ic, ascending=ascending)"
]
},
{
"cell_type": "markdown",
"id": "3c61e651-416e-4d89-81c5-8e69db522f3d",
"metadata": {},
"source": [
"Here's the result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "413bb045-dbbe-462b-abea-243a9bc77b43",
"metadata": {},
"outputs": [],
"source": [
"compare({'sinusoid': loo_sinusoid, 'gp': loo_result})"
]
},
{
"cell_type": "markdown",
"id": "9b707983-5b3f-4704-9a3a-8462a0fe819e",
"metadata": {},
"source": [
"The GP model is preferred!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01ce72d1-2ab4-4471-b13b-fda9ef0af00c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}