@banditkings
Created November 15, 2023 23:40
numpyro boilerplate MCMC example using a NUTS sampler
import jax.numpy as jnp
from jax import random
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
# Simulated Data
# 3000 samples, 3 coefficients
N, D = 3000, 3
data = random.normal(random.PRNGKey(0), (N, D))
true_coefs = jnp.arange(1., D+1.)
# logits = jnp.sum(true_coefs * data, axis=-1) + 10  # equivalent elementwise form
logits = data @ true_coefs + 10
# Use a fresh key so the labels aren't drawn from the same randomness as the data
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
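# Optional sanity check (a sketch, not in the original gist): with an
# intercept of 10, most logits sit well above zero, so the simulated
# labels are heavily imbalanced toward 1
print(f"fraction of positive labels: {labels.mean():.3f}")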
# Logistic Regression example
def logistic_regression(data, labels):
    # Priors
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    logits = data @ coefs + intercept
    # The likelihood; the model doesn't need to return anything
    numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
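# Optional (a sketch; requires the graphviz package): visualize the
# model's structure before running inference
# numpyro.render_model(logistic_regression, model_args=(data, labels))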
# Inference
model = logistic_regression
num_warmup, num_samples = 1000, 1000
mcmc = MCMC(NUTS(model=model),
            num_warmup=num_warmup,
            num_samples=num_samples)
mcmc.run(random.PRNGKey(2), data, labels)
mcmc.print_summary()
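# Optional (a sketch, assumes arviz is installed): convert the run to
# ArviZ for richer diagnostics and plotting
# import arviz as az
# idata = az.from_numpyro(mcmc)
# az.plot_trace(idata)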
# Prior Predictive Sampling
rng_key, rng_key_ = random.split(random.PRNGKey(3))
prior_predictive = Predictive(model, num_samples=1000)
# Pass labels=None so the 'obs' site is sampled rather than fixed to the data
prior_predictions = prior_predictive(rng_key_, data, labels=None)['obs']
mean_prior_pred = jnp.mean(prior_predictions, axis=0)
hpdi_prior_pred = hpdi(prior_predictions, 0.9)
# Posterior Predictive Sampling
## get representative sample of posterior
posterior_samples = mcmc.get_samples()
## Posterior Predictive Sampling
rng_key_, rng_key_2 = random.split(rng_key_)
posterior_predictive = Predictive(model, posterior_samples=posterior_samples)
# Set labels to None so the 'obs' site is resampled instead of conditioned on the data
posterior_predictive_samples = posterior_predictive(rng_key_2, data=data, labels=None)['obs']
mean_posterior_pred = jnp.mean(posterior_predictive_samples, axis=0)
hpdi_posterior_pred = hpdi(posterior_predictive_samples, 0.9)
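# Quick sanity check (a sketch, not in the original gist): the posterior
# predictive mean estimates P(label=1) per observation, so thresholding
# at 0.5 gives a rough in-sample accuracy against the simulated labels
accuracy = jnp.mean((mean_posterior_pred > 0.5) == labels)
print(f"in-sample accuracy: {accuracy:.3f}")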
@banditkings commented Nov 15, 2023
Consider logistic regression as a basic example of a GLM with a logit link function and a Bernoulli distribution.

$$
\begin{aligned}
y_i &\sim \text{Bernoulli}(p_i)\\
\text{logit}(p_i) &= x_i^\top \beta\\
&\text{or equivalently,}\\
p_i &= \text{logit}^{-1}(x_i^\top \beta) = \frac{\exp(x_i^\top \beta)}{1 + \exp(x_i^\top \beta)}
\end{aligned}
$$

In the example above, `dist.Bernoulli` takes `logits` as an argument rather than the probability `p` (numpyro's `probs` argument). We could alternatively have specified `p` directly:

from jax.scipy.special import expit

def logistic_regression(data, labels):
    # Priors
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    
    # Use inverse logit or `expit` function:
    p = expit(data @ coefs + intercept)
    # The likelihood, parameterized by probs instead of logits
    numpyro.sample('obs', dist.Bernoulli(probs=p), obs=labels)
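As a quick sanity check (a sketch, not part of the original gist), `expit` agrees with the closed-form inverse logit above up to floating point:

import jax.numpy as jnp
from jax.scipy.special import expit

x = jnp.linspace(-5., 5., 11)
# expit(x) computes exp(x) / (1 + exp(x)), the inverse of the logit link
assert jnp.allclose(expit(x), jnp.exp(x) / (1 + jnp.exp(x)))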
