Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from jax import pmap\n",
"from jax.config import config; config.update('jax_platform_name', 'cpu')\n",
"from jax.lib import xla_bridge\n",
"from jax.tree_util import tree_multimap\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.handlers import sample\n",
"from numpyro.mcmc import hmc, mcmc\n",
"from numpyro.util import fori_collect"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ShardedDeviceArray([3, 3, 3, 3], dtype=int32)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: pmap a constant function over an input whose leading axis has length 4.\n",
"const_fn = pmap(lambda x: 3)\n",
"xs = np.arange(4)\n",
"const_out = const_fn(xs)\n",
"const_out"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"warmup: 100%|██████████| 10/10 [00:05<00:00, 1.92it/s, 1 steps of size 1.49e+00. acc. prob=0.72]\n"
]
}
],
"source": [
"# Scalar Normal(true_mean, true_std) target; warmup and sample counts kept tiny for a quick run.\n",
"true_mean = 1.\n",
"true_std = 2.\n",
"warmup_steps = 10\n",
"num_samples = 2\n",
"\n",
"def potential_fn(z):\n",
"    # Negative log-density of Normal(true_mean, true_std), dropping additive constants.\n",
"    standardized = (z - true_mean) / true_std\n",
"    return 0.5 * np.sum(standardized ** 2)\n",
"\n",
"init_kernel, sample_kernel = hmc(potential_fn)\n",
"init_params = np.array(0.)\n",
"hmc_state = init_kernel(\n",
"    init_params,\n",
"    trajectory_length=9,\n",
"    num_warmup=warmup_steps,\n",
")\n",
"# Second starting state derived from hmc_state with `z` set explicitly.\n",
"# NOTE(review): this is the same value as init_params, so both chains start at z = 0 -- intentional?\n",
"hmc_state1 = hmc_state.update(z=np.array(0.))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def _collect_chain(state):\n",
"    # Run fori_collect over num_samples iterations of sample_kernel starting from `state`;\n",
"    # the transform keeps only the `z` field of each collected state.\n",
"    return fori_collect(num_samples, sample_kernel, state, transform=lambda s: s.z)\n",
"\n",
"collect_pmap = pmap(_collect_chain)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Pair up matching leaves of the two states and stack each pair along a new leading axis.\n",
"init_states = tree_multimap(lambda a, b: np.stack([a, b]), hmc_state, hmc_state1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2/2 [00:00<00:00, 3.70it/s]\n"
]
}
],
"source": [
"# Run one chain per leading-axis slot of init_states.\n",
"samples = collect_pmap(init_states)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ShardedDeviceArray([[ 2.33032942, -1.0394417],\n",
" [-0.4470557, -3.37890959]], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"samples"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (pydata)",
"language": "python",
"name": "pydata"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.