Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Last active December 28, 2020 13:14
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/b4a5a80b245600b99467a0264be05fd5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"import copy\n",
"\n",
"from jax import device_put, lax, random\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.handlers import substitute, trace, seed\n",
"from numpyro.infer import MCMC, NUTS, log_likelihood\n",
"from numpyro.infer.mcmc import MCMCKernel\n",
"from numpyro.util import identity"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"HMC_ECS_State = namedtuple(\"HMC_ECS_State\", \"uz, hmc_state, accept_prob, rng_key\")\n",
"\"\"\"\n",
" - **uz** - a dict of current subsample indices and the current latent values\n",
" - **hmc_state** - current hmc_state\n",
" - **accept_prob** - acceptance probability of the proposal subsample indices\n",
" - **rng_key** - random key to generate new subsample indices\n",
"\"\"\"\n",
"\n",
"def _wrap_model(model):\n",
" def fn(*args, **kwargs):\n",
" subsample_values = kwargs.pop(\"_subsample_sites\", {})\n",
" with substitute(data=subsample_values):\n",
" model(*args, **kwargs)\n",
"\n",
" return fn\n",
"\n",
"\n",
"class HMC_ECS(MCMCKernel):\n",
" sample_field = \"uz\"\n",
"\n",
" def __init__(self, inner_kernel):\n",
" self.inner_kernel = copy.copy(inner_kernel)\n",
" self.inner_kernel._model = _wrap_model(inner_kernel.model)\n",
" self._plate_sizes = None\n",
"\n",
" @property\n",
" def model(self):\n",
" return self.inner_kernel._model\n",
"\n",
" def postprocess_fn(self, args, kwargs):\n",
" def fn(uz):\n",
" z = {k: v for k, v in uz.items() if k not in self._plate_sizes}\n",
" return self.inner_kernel.postprocess_fn(args, kwargs)(z)\n",
"\n",
" return fn\n",
"\n",
" def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):\n",
" model_kwargs = {} if model_kwargs is None else model_kwargs.copy()\n",
" rng_key, key_u, key_z = random.split(rng_key, 3)\n",
" prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs)\n",
" u = {name: site[\"value\"] for name, site in prototype_trace.items()\n",
" if site[\"type\"] == \"plate\" and site[\"args\"][0] > site[\"args\"][1]}\n",
" self._plate_sizes = {name: prototype_trace[name][\"args\"] for name in u}\n",
" model_kwargs[\"_subsample_sites\"] = u\n",
" hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params,\n",
" model_args, model_kwargs)\n",
" uz = {**u, **hmc_state.z}\n",
" return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key))\n",
"\n",
" def sample(self, state, model_args, model_kwargs):\n",
" model_kwargs = {} if model_kwargs is None else model_kwargs.copy()\n",
" rng_key, key_u = random.split(state.rng_key)\n",
" u = {k: v for k, v in state.uz.items() if k in self._plate_sizes}\n",
" u_new = {}\n",
" for name, (size, subsample_size) in self._plate_sizes.items():\n",
" key_u, subkey = random.split(key_u)\n",
" u_new[name] = random.choice(subkey, size, (subsample_size,), replace=False)\n",
" sample = self.postprocess_fn(model_args, model_kwargs)(state.hmc_state.z)\n",
" u_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0,\n",
" **model_kwargs, _subsample_sites=u)\n",
" u_loglik = sum(v.sum() for v in u_loglik.values())\n",
" u_new_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0,\n",
" **model_kwargs, _subsample_sites=u_new)\n",
" u_new_loglik = sum(v.sum() for v in u_new_loglik.values())\n",
" accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0)\n",
" u = lax.cond(random.bernoulli(key_u, accept_prob), u_new, identity, u, identity)\n",
" model_kwargs[\"_subsample_sites\"] = u\n",
" hmc_state = self.inner_kernel.sample(state.hmc_state, model_args, model_kwargs)\n",
" uz = {**u, **hmc_state.z}\n",
" return HMC_ECS_State(uz, hmc_state, accept_prob, rng_key)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 1000/1000 [00:11<00:00, 90.54it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" mean std median 5.0% 95.0% n_eff r_hat\n",
" x 1.01 0.01 1.01 0.99 1.02 180.80 1.02\n",
"\n"
]
}
],
"source": [
"def model(data):\n",
" x = numpyro.sample(\"x\", dist.Normal(0, 1))\n",
" with numpyro.plate(\"N\", data.shape[0], subsample_size=100):\n",
" batch = numpyro.subsample(data, event_dim=0)\n",
" numpyro.sample(\"obs\", dist.Normal(x, 1), obs=batch)\n",
"\n",
"kernel = HMC_ECS(NUTS(model))\n",
"mcmc = MCMC(kernel, 500, 500)\n",
"data = random.normal(random.PRNGKey(1), (10000,)) + 1\n",
"mcmc.run(random.PRNGKey(0), data, extra_fields=(\"accept_prob\",))\n",
"# there is a bug when exclude_deterministic=True, which will be fixed upstream\n",
"mcmc.print_summary(exclude_deterministic=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment