Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
from collections import namedtuple

import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, util
from numpyro.infer.mcmc import MCMCKernel
# Chain state passed between Metropolis steps: ``z`` holds the current latent
# values (a pytree -- a dict keyed by site name in this file) and ``rng_key``
# is the PRNG key for the next step.  The field name itself is arbitrary
# ("z", "u", ...): it only has to match the string the kernel reports as its
# ``sample_field``.
MetState = namedtuple("MetState", "z rng_key")
class Metropolis(MCMCKernel):
    """Random-walk Metropolis kernel for a NumPyro model.

    Each step adds isotropic Gaussian noise of scale ``step_size`` to every
    latent value and accepts the proposal with probability
    ``min(1, p(z') / p(z))``, where ``log p`` is the model's joint log density
    from ``numpyro.infer.util.log_density``.

    NOTE(review): the random walk is applied directly to the values handed to
    ``log_density`` (no transform to unconstrained space), so this presumably
    requires every latent site to have support on the whole real line -- true
    for a ``Normal`` prior, not for constrained sites in general.

    Args:
        model: NumPyro model callable.
        step_size: standard deviation of the Gaussian proposal.
    """

    def __init__(self, model, step_size=0.1):
        self._model = model
        self._step_size = step_size

    # NumPyro's MCMC reads ``sample_field`` / ``default_fields`` as attributes
    # (they are properties on MCMCKernel).  Plain methods would hand MCMC a
    # bound method instead of the field name, so these must be properties.
    @property
    def sample_field(self):
        """Name of the MetState field that holds the MCMC sample."""
        return "z"

    @property
    def default_fields(self):
        """MetState fields collected for every sample."""
        return ("z",)

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Build the initial chain state.

        Args:
            rng_key: a single (non-batched) JAX PRNG key.
            num_warmup: unused; this kernel has no adaptation phase.
            init_params: pytree (e.g. ``{"x": 0.0}``) of starting latent
                values; required.
            model_args: unused here.
            model_kwargs: unused here.

        Returns:
            A ``MetState`` whose ``z`` leaves are floating-point arrays.

        Raises:
            ValueError: if ``init_params`` is None.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        if init_params is None:
            raise ValueError("Metropolis requires explicit init_params")
        # Proposals are float Normal draws, so integer starting values (e.g. 0)
        # would make ``z`` switch dtype after the first step.  Promote each
        # leaf to a floating type; existing float leaves keep their precision
        # because the Python ``float`` here is a weak type.
        z = tree_map(
            lambda leaf: jnp.asarray(leaf, dtype=jnp.result_type(leaf, float)),
            init_params,
        )
        return MetState(z, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Advance the chain by one Metropolis step.

        Args:
            state: current ``MetState``.
            model_args: positional arguments forwarded to the model.
            model_kwargs: keyword arguments forwarded to the model.

        Returns:
            The next ``MetState`` (proposal if accepted, else the current
            values) with a fresh PRNG key.
        """
        rng_key, key_proposal, key_accept = random.split(state.rng_key, 3)
        # Work on a flat vector so one proposal covers every latent site.
        z_flat, unravel_fn = ravel_pytree(state.z)
        z_proposal = dist.Normal(z_flat, self._step_size).sample(key_proposal)
        # NOTE(review): the current state's log density is recomputed on every
        # step; caching it in the state would halve the model evaluations.
        log_pr_0, _ = util.log_density(self._model, model_args, model_kwargs, state.z)
        log_pr_1, _ = util.log_density(
            self._model, model_args, model_kwargs, unravel_fn(z_proposal)
        )
        # The Gaussian proposal is symmetric, so the acceptance ratio is just
        # the density ratio.  exp() overflowing to inf still compares correctly.
        accept_prob = jnp.exp(log_pr_1 - log_pr_0)
        accepted = dist.Uniform().sample(key_accept) < accept_prob
        z_new = jnp.where(accepted, z_proposal, z_flat)
        return MetState(unravel_fn(z_new), rng_key)
def model():
    """Toy model: a single latent site ``x`` with a standard-normal prior."""
    prior = dist.Normal(loc=0, scale=1)
    numpyro.sample("x", prior)
# Run 200 Metropolis draws (no warmup) on the toy model.
rng_key = random.PRNGKey(12345)
kernel = Metropolis(model, step_size=1)
# The original line had a stray ")" (a SyntaxError) and passed init_params to
# the MCMC constructor.  init_params belongs to MCMC.run, which must be called
# before get_samples().
mcmc = MCMC(kernel, num_warmup=0, num_samples=200, thinning=1)
# Float start value: the random-walk proposals are floats, so an integer start
# would change the dtype of the chain state after the first step.
mcmc.run(rng_key, init_params={"x": 0.0})
posterior_samples = mcmc.get_samples()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment