Skip to content

Instantly share code, notes, and snippets.

@amifalk
Last active March 9, 2024 00:38
Show Gist options
  • Save amifalk/950439c10063f0a75ac99fa2d277825f to your computer and use it in GitHub Desktop.
Save amifalk/950439c10063f0a75ac99fa2d277825f to your computer and use it in GitHub Desktop.
Functional NumPyro MCMC Wrapper
from functools import partial
from operator import attrgetter
import jax
import jax.numpy as jnp
import jax.random as random
@partial(jax.jit, static_argnames=['field_names'])
def collect_fields(state, field_names: tuple):
    """
    Collect fields from a state (i.e. a `namedtuple`).

    Each name is resolved with `operator.attrgetter`, so a dotted name such
    as ``'outer.inner'`` reads a nested attribute.

    Args:
        state: Object whose attributes are read.
        field_names (tuple): Names of the attributes to collect. It is a
            static jit argument, so it must be hashable. May be empty, in
            which case an empty dict is returned.

    Returns:
        dict: {*field_names: *collected_fields}
    """
    # Resolve one name at a time: `attrgetter(*names)` raises TypeError for
    # zero names and returns a bare value (not a tuple) for a single name,
    # which would otherwise need special-casing.
    return {name: attrgetter(name)(state) for name in field_names}
@partial(jax.jit, static_argnames=['transform', 'sites_subset'])
def transform_and_subset_sample_sites(z, transform, sites_subset):
    """
    Apply `transform` to `z` and optionally keep only selected sites.

    Args:
        z: Input passed to `transform`.
        transform (callable): Function applied to `z`. It is a static jit
            argument, so it must be hashable.
        sites_subset (tuple or None): Keys to keep from the transformed
            result. It is a static jit argument, so it must be hashable.
            Any falsy value (``None`` or an empty tuple) keeps everything.

    Returns:
        The full transformed result when `sites_subset` is falsy, otherwise
        a dict mapping each name in `sites_subset` to its transformed value.
    """
    transformed = transform(z)
    # Guard clause: nothing to filter, hand back the whole result.
    if not sites_subset:
        return transformed
    return {name: transformed[name] for name in sites_subset}
def run_mcmc(rng_key,
             kernel,
             model_args,
             model_kwargs,
             *,
             num_warmup: int,
             num_samples: int,
             extra_fields=(),
             sites_subset=None,
             return_warmup=False,
             ):
    """
    A thin functional wrapper around NumPyro `MCMCKernel`s that allows for vectorization/parallelism
    over model_args and model_kwargs and optionally avoids storing nuisance sample sites. The API and
    implementation draw heavily on work in the [BlackJAX](https://github.com/blackjax-devs/blackjax)
    and [NumPyro](https://github.com/pyro-ppl/numpyro) repositories.

    Note that all ``num_warmup + num_samples`` draws are materialized by the scan; when
    `return_warmup` is False the warmup draws are sliced off afterwards.

    Args:
        rng_key (PRNGKey): JAX PRNGKey. For `EnsembleSampler` kernels, the number of keys must match the
            number of desired chains.
        kernel (MCMCKernel): MCMCKernel to use for inference.
        model_args (tuple): Tuple containing model arguments.
        model_kwargs (dict): Dictionary containing model keyword arguments.
        num_warmup (int): Number of warmup steps to take.
        num_samples (int): Number of samples to take.
        extra_fields (tuple, optional): Tuple containing the names of fields from the kernel state
            to collect. If empty, will collect the kernel's default fields. Any iterable of
            names is accepted. Defaults to ().
        sites_subset (tuple, optional): Tuple containing the names of sample sites to collect. If None,
            collects all sites. Any iterable of names is accepted. Defaults to None.
        return_warmup (bool, optional): Whether or not to return warmup samples. Defaults to False.

    Returns:
        (samples, other_fields): Tuple containing samples and collected fields from the state object.
    """
    init_state = kernel.init(rng_key,
                             num_warmup=num_warmup,
                             model_args=model_args,
                             model_kwargs=model_kwargs)
    transform = kernel.postprocess_fn(model_args, model_kwargs)

    # Deduplicate while preserving first-seen order. A `set` would give a
    # hash-randomised order, and this tuple is used as a static jit argument.
    to_collect = tuple(dict.fromkeys(
        (kernel.sample_field, *kernel.default_fields, *extra_fields)))

    # `sites_subset` is passed on as a static (hashable) jit argument, so
    # normalise lists and other iterables to a tuple.
    if sites_subset is not None:
        sites_subset = tuple(sites_subset)

    def step(state, _):
        # One kernel transition; emit the (transformed) sample and the
        # remaining collected fields for this iteration.
        state = kernel.sample(state, model_args, model_kwargs)
        collected_fields = collect_fields(state, to_collect)
        sample_sites = collected_fields.pop(kernel.sample_field)
        sample = transform_and_subset_sample_sites(sample_sites, transform, sites_subset)
        return state, (sample, collected_fields)

    # The step function ignores its per-iteration input, so scan over
    # nothing for a fixed number of iterations.
    _, (samples, other_fields) = jax.lax.scan(
        step, init_state, None, length=num_warmup + num_samples)

    if not return_warmup:
        # Drop the leading warmup iterations from every collected array.
        samples = jax.tree_util.tree_map(lambda x: x[num_warmup:], samples)
        other_fields = jax.tree_util.tree_map(lambda x: x[num_warmup:], other_fields)
    return samples, other_fields
if __name__ == '__main__':
    # Demo: fit the same model to four independent datasets in parallel, one
    # chain per (virtual) CPU device, and check the posterior mean of `mu`.
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import NUTS

    num_chains = 4

    numpyro.set_platform('cpu')
    # Takes an integer device count; called here before any JAX computation
    # has run.
    numpyro.set_host_device_count(num_chains)

    def model(data):
        mu = numpyro.sample('mu', dist.Normal(0, 2))
        sigma = numpyro.sample('sigma', dist.HalfCauchy(1))
        with numpyro.plate('n_obs', len(data)):
            numpyro.sample('data', dist.Normal(mu, sigma), obs=data)

    # Never consume the same PRNG key twice: use separate keys for data
    # generation and for seeding the chains.
    data_key, mcmc_key = random.split(random.PRNGKey(0))

    ground_truth_mu = jnp.array([1, 2, 3, 4])[:, None]
    data = random.normal(data_key, (num_chains, 200)) * 0.1 + ground_truth_mu

    model_args = (data,)
    model_kwargs = {}

    keys = random.split(mcmc_key, num_chains)
    batch_run_mcmc = partial(run_mcmc, num_warmup=5000, num_samples=5000, sites_subset=('mu',))

    # Map over the chain keys and the leading axis of each model argument;
    # the kernel is a static argument and model_kwargs is broadcast.
    samps, other_fields = jax.pmap(batch_run_mcmc,
                                   in_axes=(0, None, 0, None),
                                   static_broadcasted_argnums=1)(keys,
                                                                 NUTS(model),
                                                                 model_args,
                                                                 model_kwargs)

    print(samps['mu'].shape)

    # The mean of 200 draws with sigma = 0.1 has a standard error of about
    # 0.007, so a 0.01 tolerance against the ground truth is only ~1.4
    # standard errors and depends on the seed. Use a tolerance of several
    # standard errors instead.
    assert jnp.allclose(jnp.mean(samps['mu'], axis=1), ground_truth_mu.squeeze(),
                        atol=.05)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment