Created
November 15, 2023 23:40
-
-
Save banditkings/8fd36a96c444ec7769a17f940d34e14b to your computer and use it in GitHub Desktop.
numpyro boilerplate MCMC example using a NUTS sampler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from jax import random | |
import numpyro | |
from numpyro.diagnostics import hpdi | |
import numpyro.distributions as dist | |
from numpyro.infer import MCMC, NUTS, Predictive | |
# Simulated data: N samples with D features, plus Bernoulli labels drawn
# from a known linear model (so inference below can recover true_coefs).
N, D = 3000, 3
# FIX: the original reused random.PRNGKey(0) for both the feature draw and
# the label draw — a JAX PRNG antipattern that correlates the two samples.
# Split one seed key into independent subkeys instead.
key_data, key_labels = random.split(random.PRNGKey(0))
data = random.normal(key_data, (N, D))
true_coefs = jnp.arange(1., D + 1.)
# True model: logit(p) = data @ true_coefs + 10
logits = data @ true_coefs + 10
labels = dist.Bernoulli(logits=logits).sample(key_labels)
# Logistic-regression model (a Bernoulli GLM with a logit link)
N, D = 3000, 3

def logistic_regression(data, labels):
    """numpyro model: Bernoulli likelihood with a linear logit.

    Priors: 'coefs' ~ Normal(0, 1) elementwise (length D),
    'intercept' ~ Normal(0, 10).
    Likelihood: 'obs' ~ Bernoulli(logits = data @ coefs + intercept),
    conditioned on `labels` (pass None to sample it instead).
    """
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    # No return value needed: numpyro records the sample sites themselves.
    numpyro.sample('obs', dist.Bernoulli(logits=data @ coefs + intercept), obs=labels)
# Inference: draw posterior samples with the NUTS kernel via MCMC.
# `model` and `mcmc` are reused by the predictive-sampling sections below.
model = logistic_regression
num_warmup, num_samples = 1000, 1000
kernel = NUTS(model=model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(2), data, labels)
mcmc.print_summary()
# Prior predictive sampling: simulate 'obs' from the priors alone.
# FIX 1: `rng_key` was split before ever being defined (NameError); seed it here.
rng_key = random.PRNGKey(1)
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(model, num_samples=1000)
# FIX 2: pass labels=None so the 'obs' site is sampled from the prior
# rather than conditioned on the observed labels (matching the
# posterior-predictive call below).
prior_predictions = prior_predictive(rng_key_, data, labels=None)['obs']
mean_prior_pred = jnp.mean(prior_predictions, axis=0)
hpdi_prior_pred = hpdi(prior_predictions, 0.9)  # 90% highest-density interval
# Posterior predictive sampling: generate fresh 'obs' draws, one per
# posterior sample recovered by MCMC above.
posterior_samples = mcmc.get_samples()
rng_key_, rng_key_2 = random.split(rng_key_)
# labels=None so the 'obs' site is sampled rather than conditioned on data.
posterior_predictive_samples = Predictive(
    model, posterior_samples=posterior_samples
)(rng_key_2, data=data, labels=None)['obs']
mean_posterior_pred = jnp.mean(posterior_predictive_samples, axis=0)
hpdi_posterior_pred = hpdi(posterior_predictive_samples, 0.9)  # 90% HDI
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Consider logistic regression as a basic example of a GLM with a logit link function and a Bernoulli distribution.
In the above example, the `dist.Bernoulli` distribution takes `logits` as an argument instead of specifying the probability `p`. We could have alternatively specified `p` as: