Skip to content

Instantly share code, notes, and snippets.

Last active January 9, 2022 23:24
What would you like to do?
from collections import namedtuple
from jax import random
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, util
from numpyro.infer.mcmc import MCMCKernel
# Chain state carried between Metropolis steps: the current position ``z``
# and the PRNG key to be consumed by the next transition.
MetState = namedtuple(
    typename="MetState",
    field_names=("z", "rng_key"),
)
class Metropolis(MCMCKernel):
    """Random-walk Metropolis kernel for a NumPyro model.

    The chain state is a flat vector ``z`` (the raveled parameter pytree).
    Each step proposes ``z' ~ Normal(z, step_size)`` and accepts it with
    probability ``min(1, exp(U(z) - U(z')))``, where ``U`` is the potential
    function returned by ``numpyro.infer.util.initialize_model``.

    Args:
        model: NumPyro model callable.
        step_size: Standard deviation of the Gaussian proposal.
    """

    def __init__(self, model, step_size=0.1):
        self._model = model
        self._step_size = step_size
        # Both are populated by init(); they close over the model's
        # potential/postprocess functions and the pytree unravel function.
        self._potential_fn = None
        self._postprocess_fn = None

    @property
    def sample_field(self):
        """Name of the state field that MCMC collects as the sample."""
        # MCMC reads this as an attribute value (a field name), so it has to
        # be a property rather than a plain method.
        return "z"

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Build the potential function and return the initial chain state.

        Args:
            rng_key: A single (non-batched) PRNG key.
            num_warmup: Unused; this kernel has no adaptation phase.
            init_params: Pytree of initial values for the chain. If ``None``,
                the values produced by ``initialize_model`` are used.
            model_args: Positional arguments forwarded to the model.
            model_kwargs: Keyword arguments forwarded to the model.

        Returns:
            MetState holding the flattened initial position and a PRNG key.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        # Use a separate key for model initialization so the key stored in
        # the chain state has not already been consumed.
        rng_key, init_key = random.split(rng_key)
        # initialize_model's default init strategy is used here.
        param_info, potential_fn, postprocess_fn, _model_trace = util.initialize_model(
            init_key,
            self._model,
            dynamic_args=False,
            model_args=model_args,
            model_kwargs=model_kwargs,
        )
        if init_params is None:
            init_params = param_info.z
        z_flat, unravel_fn = ravel_pytree(init_params)
        # The kernel works on the flat vector; unravel before calling into
        # the functions that expect the original parameter pytree.
        self._potential_fn = lambda z: potential_fn(unravel_fn(z))
        self._postprocess_fn = lambda z: postprocess_fn(unravel_fn(z))
        return MetState(z_flat, rng_key)

    def postprocess_fn(self, model_args, model_kwargs):
        """Return the function that maps a flat ``z`` to model-space values.

        ``model_args``/``model_kwargs`` are ignored because they were already
        bound when ``init`` called ``initialize_model``.
        """
        return self._postprocess_fn

    def sample(self, state, model_args, model_kwargs):
        """Perform one Metropolis transition and return the new state."""
        rng_key, key_proposal, key_accept = random.split(state.rng_key, 3)
        z_proposal = dist.Normal(state.z, self._step_size).sample(key_proposal)
        # The proposal is symmetric, so the acceptance ratio only involves the
        # difference of potentials at the current and proposed points.
        accept_prob = jnp.exp(self._potential_fn(state.z) - self._potential_fn(z_proposal))
        z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, state.z)
        return MetState(z_new, rng_key)
def model():
    """Toy model with a single sample site ``'x'`` drawn from Uniform(0, 1)."""
    unit_interval = dist.Uniform(0, 1)
    numpyro.sample('x', unit_interval)
def my_run(model):
    """Sample ``model`` with the Metropolis kernel and return the draws.

    Runs a single chain of 50,000 steps with no warmup from a fixed seed,
    starting at ``{'x': jnp.ones(10)}``.

    Args:
        model: NumPyro model callable passed to ``Metropolis``.

    Returns:
        The samples collected by ``MCMC.get_samples()``.
    """
    rng_key = random.PRNGKey(12345)
    kernel = Metropolis(model, step_size=1)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=50_000, thinning=1)
    # The chain must actually be run before get_samples() has anything to
    # return; init_params is forwarded to the kernel's init().
    mcmc.run(rng_key, init_params={'x': jnp.ones(10)})
    posterior_samples = mcmc.get_samples()
    return posterior_samples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment