@fehiepsi
Last active March 23, 2019 06:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"from jax import jit, lax, random\n",
"from numpyro.util import build_tree, velocity_verlet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def kinetic_fn(p):\n",
" return 0.5 * p ** 2\n",
"\n",
"def potential_fn(q):\n",
" return 0.5 * q ** 2"
]
},
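{
"cell_type": "markdown",
"metadata": {},
"source": [
"These two functions define a standard normal target with unit mass, so the Hamiltonian is $H(q, p) = \\tfrac{1}{2} q^2 + \\tfrac{1}{2} p^2$ and the exact dynamics is a rotation in $(q, p)$ space: the trajectory stays on a circle, which makes this a convenient sanity check for the integrator and the tree building."
]
},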
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fehiepsi/jax/jax/lib/xla_bridge.py:128: UserWarning: No GPU found, falling back to CPU.\n",
" warnings.warn('No GPU found, falling back to CPU.')\n"
]
}
],
"source": [
"vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)\n",
"jitted_vv_update = jit(vv_update)\n",
"vv_state = vv_init(0.0, 1.0)\n",
"inverse_mass_matrix = np.array([1.])\n",
"step_size = 0.01\n",
"rng = random.PRNGKey(0)"
]
},
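{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, one velocity-Verlet (leapfrog) step with a unit mass matrix looks like the sketch below; this is the textbook scheme, not necessarily numpyro's exact implementation:\n",
"```python\n",
"def leapfrog_step(q, p, grad_potential, step_size):\n",
"    p = p - 0.5 * step_size * grad_potential(q)  # half-step momentum update\n",
"    q = q + step_size * p                        # full-step position update (unit mass)\n",
"    p = p - 0.5 * step_size * grad_potential(q)  # half-step momentum update\n",
"    return q, p\n",
"```"
]
},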
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.01 s ± 8.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
},
{
"data": {
"text/plain": [
"_TreeInfo(z_left=array(-0.99996996, dtype=float32), r_left=array(-0.00920987, dtype=float32), z_left_grad=array(-0.99996996, dtype=float32), z_right=array(0.96084845, dtype=float32), r_right=array(0.27711588, dtype=float32), z_right_grad=array(0.96084845, dtype=float32), z_proposal=array(0.188862, dtype=float32), z_proposal_pe=array(0.01783443, dtype=float32), z_proposal_grad=array(0.188862, dtype=float32), depth=array(9, dtype=int32), weight=array(5.662955, dtype=float32), r_sum=array(196.21094, dtype=float32), turning=array(True), diverging=array(False), sum_accept_probs=array(286.99835, dtype=float32), num_proposals=array(287, dtype=int32))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def f(vv_state):\n",
" return build_tree(vv_update, kinetic_fn, vv_state,\n",
" inverse_mass_matrix, step_size, rng, iterative_build=False)\n",
"\n",
"%timeit f(vv_state)\n",
"f(vv_state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparing to HMC\n",
"Run the same number of Verlet steps as the trajectory above (287, the `num_proposals` reported by `build_tree`) without the tree-building machinery."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"90.6 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"def g(vv_state):\n",
" for i in range(287):\n",
" vv_state = jitted_vv_update(step_size, vv_state)\n",
" return vv_state\n",
"\n",
"%timeit g(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"44.6 ms ± 418 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"def h(vv_state):\n",
" vv_state = lax.fori_loop(\n",
" 0, 287,\n",
" lambda i, state: jitted_vv_update(step_size, state),\n",
" vv_state\n",
" )\n",
" return vv_state\n",
"\n",
"%timeit h(vv_state)"
]
},
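{
"cell_type": "markdown",
"metadata": {},
"source": [
"Semantically, `lax.fori_loop(lower, upper, body_fun, init_val)` is equivalent to the pure-Python sketch below, except that it traces `body_fun` once and runs the loop inside XLA, avoiding the per-iteration Python dispatch overhead:\n",
"```python\n",
"def fori_loop_reference(lower, upper, body_fun, init_val):\n",
"    val = init_val\n",
"    for i in range(lower, upper):\n",
"        val = body_fun(i, val)\n",
"    return val\n",
"```"
]
},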
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Jitting each HMC sample"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"581 µs ± 246 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"jitted_g = jit(g)\n",
"%timeit jitted_g(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"360 µs ± 7.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"jitted_h = jit(h)\n",
"%timeit jitted_h(vv_state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Is the ~350 µs overhead affected by the integrator?"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"315 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"def g1(vv_state):\n",
" for i in range(1000):\n",
" vv_state = jitted_vv_update(step_size, vv_state)\n",
" return vv_state\n",
"\n",
"%timeit g1(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"355 µs ± 9.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"@jit\n",
"def h1(vv_state):\n",
" vv_state = lax.fori_loop(\n",
" 0, 1000,\n",
" lambda i, state: jitted_vv_update(step_size, state),\n",
" vv_state\n",
" )\n",
" return vv_state\n",
"\n",
"%timeit h1(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"348 µs ± 7.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"@jit\n",
"def k(vv_state):\n",
" vv_state = lax.fori_loop(\n",
" 0, 287,\n",
" lambda i, state: state, # trivial function\n",
" vv_state\n",
" )\n",
" return vv_state\n",
"\n",
"%timeit k(vv_state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Some takeaways**\n",
"+ For HMC, it is best to jit each `sample` call. The ~360 µs overhead per call is small: even for 1000 samples it adds less than 1 s in total, so almost all of the time goes into the Verlet updates, which is exactly what we want. To reduce compilation time for the first trajectory, prefer `lax.fori_loop` over a native Python loop. Overall, JAX seems to work quite well with HMC.\n",
"\n",
"+ For NUTS, we cannot jit each `sample` call yet with the recursive tree builder. If we only jit the Verlet update, NUTS is about 10x slower than HMC. But jitting only the update already carries a lot of overhead even for HMC: in `g1`, 1000 Verlet steps took about 300 ms, hence about 300 s for 1000 samples of that length. So shaving the per-step overhead is not a promising route; the best approach is still to jit each NUTS `sample` call."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Iterative NUTS"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"555 µs ± 98.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
},
{
"data": {
"text/plain": [
"_TreeInfo(z_left=array(-0.99996996, dtype=float32), r_left=array(-0.00920987, dtype=float32), z_left_grad=array(-0.99996996, dtype=float32), z_right=array(0.96084845, dtype=float32), r_right=array(0.27711588, dtype=float32), z_right_grad=array(0.96084845, dtype=float32), z_proposal=array(0.8772135, dtype=float32), z_proposal_pe=array(0.38475174, dtype=float32), z_proposal_grad=array(0.8772135, dtype=float32), depth=array(9, dtype=int32), weight=array(5.662955, dtype=float32), r_sum=array(196.2109, dtype=float32), turning=array(True), diverging=array(False), sum_accept_probs=array(286.9984, dtype=float32), num_proposals=array(287, dtype=int32))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@jit\n",
"def f1(vv_state):\n",
" return build_tree(vv_update, kinetic_fn, vv_state,\n",
" inverse_mass_matrix, step_size, rng, iterative_build=True)\n",
"\n",
"%timeit f1(vv_state)\n",
"f1(vv_state)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}