Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created April 3, 2019 19:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/b24b6dad5198c50e4923ac209daf93d9 to your computer and use it in GitHub Desktop.
Save fehiepsi/b24b6dad5198c50e4923ac209daf93d9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import jax.numpy as np\n",
"from jax import partial, random\n",
"from jax.flatten_util import ravel_pytree\n",
"from jax.random import PRNGKey\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.hmc_util import IntegratorState, find_reasonable_step_size, velocity_verlet, warmup_adapter\n",
"from numpyro.util import cond, fori_loop, laxtuple\n",
"\n",
"HMCState = laxtuple('HMCState', ['z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',\n",
" 'step_size', 'inverse_mass_matrix', 'rng'])\n",
"\n",
"\n",
"def _get_num_steps(step_size, trajectory_length):\n",
" num_steps = np.array(trajectory_length / step_size, dtype=np.int32)\n",
" return np.where(num_steps < 1, np.array(1, dtype=np.int32), num_steps)\n",
"\n",
"\n",
"def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):\n",
" if inverse_mass_matrix.ndim == 1:\n",
" r = dist.norm(0., np.sqrt(np.reciprocal(inverse_mass_matrix))).rvs(random_state=rng)\n",
" return unpack_fn(r)\n",
" elif inverse_mass_matrix.ndim == 2:\n",
" raise NotImplementedError"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def hmc(potential_fn, kinetic_fn, algo=\"NUTS\"):\n",
" vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)\n",
" momentum_generator = None\n",
" wa_update = None\n",
" trajectory_length = None\n",
" _next = lambda *args: (_nuts_next if algo == "NUTS" else _hmc_next)(*args)\n",
"\n",
" def init_kernel(init_samples,\n",
" num_warmup_steps,\n",
" step_size=1.0,\n",
" num_steps=None,\n",
" adapt_step_size=True,\n",
" adapt_mass_matrix=True,\n",
" diag_mass=True,\n",
" target_accept_prob=0.8,\n",
" run_warmup=True,\n",
" rng=PRNGKey(0)):\n",
" step_size = float(step_size)\n",
" nonlocal trajectory_length, momentum_generator, wa_update\n",
"\n",
" if num_steps is None:\n",
" trajectory_length = 2 * math.pi\n",
" else:\n",
" trajectory_length = num_steps * step_size\n",
"\n",
" z = init_samples\n",
" z_flat, unravel_fn = ravel_pytree(z)\n",
" momentum_generator = partial(_sample_momentum, unravel_fn)\n",
"\n",
" find_reasonable_ss = partial(find_reasonable_step_size,\n",
" potential_fn, kinetic_fn, momentum_generator)\n",
"\n",
" wa_init, wa_update = warmup_adapter(num_warmup_steps,\n",
" find_reasonable_step_size=find_reasonable_ss,\n",
" adapt_step_size=adapt_step_size,\n",
" adapt_mass_matrix=adapt_mass_matrix,\n",
" diag_mass=diag_mass,\n",
" target_accept_prob=target_accept_prob)\n",
"\n",
" rng_hmc, rng_wa = random.split(rng)\n",
" wa_init_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))\n",
" r = momentum_generator(wa_init_state.inverse_mass_matrix, rng)\n",
" vv_state = vv_init(z, r)\n",
" hmc_init_state = HMCState(..., rng_hmc)\n",
"\n",
" if run_warmup:\n",
" hmc_state, wa_state = fori_loop(0, num_warmup_steps, warmup_update,\n",
" (hmc_init_state, wa_init_state))\n",
" return hmc_state\n",
" else:\n",
" return warmup_update, hmc_init_state, wa_init_state\n",
"\n",
" def warmup_update(t, args):\n",
" hmc_state, wa_state = args\n",
" hmc_state = sample_kernel(hmc_state)\n",
" wa_state = wa_update(t, hmc_state.accept_prob, hmc_state.z, wa_state)\n",
" hmc_state = hmc_state.update(step_size=wa_state.step_size,\n",
" inverse_mass_matrix=wa_state.inverse_mass_matrix)\n",
" return hmc_state, wa_state\n",
"\n",
" def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):\n",
" num_steps = _get_num_steps(step_size, trajectory_length)\n",
" vv_state_new = fori_loop(0, num_steps,\n",
" lambda i, val: vv_update(step_size, inverse_mass_matrix, val),\n",
" vv_state)\n",
" energy_old = vv_state.potential_energy + kinetic_fn(vv_state.r, inverse_mass_matrix)\n",
" energy_new = vv_state_new.potential_energy + kinetic_fn(vv_state_new.r, inverse_mass_matrix)\n",
" delta_energy = energy_new - energy_old\n",
" delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)\n",
" accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)\n",
" transition = random.bernoulli(rng, accept_prob)\n",
" vv_state = cond(transition,\n",
" vv_state_new, lambda state: state,\n",
" vv_state, lambda state: state)\n",
" return num_steps, accept_prob, vv_state\n",
"\n",
" def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):\n",
" binary_tree = build_tree(..., rng)\n",
" accept_prob = binary_tree.num_accept_probs / binary_tree.num_proposals\n",
" num_steps = binary_tree.num_proposals\n",
" vv_state_new = ... # binary_tree.z_proposal, vv_state.r, binary_tree.z_proposal_pe\n",
" return num_steps, accept_prob, vv_state_new\n",
"\n",
" def sample_kernel(hmc_state):\n",
" rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)\n",
" r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)\n",
" vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)\n",
" num_steps, accept_prob, vv_state_new = _next(hmc_state.step_size,\n",
" hmc_state.inverse_mass_matrix, vv_state, rng_transition)\n",
" return HMCState(vv_state_new.z, vv_state_new.z_grad, vv_state_new.potential_energy, num_steps, accept_prob,\n",
" hmc_state.step_size, hmc_state.inverse_mass_matrix, rng)\n",
"\n",
" return init_kernel, sample_kernel"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment