notebooks/GLM-hierarchical.ipynb
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Using JAX for faster sampling\n\n(c) Thomas Wiecki, 2020\n\n*Note: These samplers are still experimental.*\n\nUsing the new Theano JAX linker that Brandon Willard has developed, we can compile PyMC3 models to JAX without any change to the PyMC3 code base or any user-level code changes. The way this works is that we take our Theano graph built by PyMC3 and then translate it to JAX primitives. \n\nUsing our Python samplers, this is still a bit slower than the C-code generated by default Theano.\n\nHowever, things get really interesting when we also express our samplers in JAX. Here we have used the JAX samplers by NumPyro or TFP. This combining of the samplers was done by [Junpeng Lao](https://twitter.com/junpenglao). \n\nThe reason this is so much faster is that while before in PyMC3, only the logp evaluation was compiled while the samplers where still coded in Python, so for every loop we went back from C to Python. With this approach, the model *and* the sampler are JIT-compiled by JAX and there is no more Python overhead during the whole sampling run. This way we also get sampling on GPUs or TPUs for free.\n\nThis NB requires the master of [Theano-PyMC](https://github.com/pymc-devs/Theano-PyMC), the [pymc3jax branch of PyMC3](https://github.com/pymc-devs/pymc3/tree/pymc3jax), as well as JAX, TFP-nightly and numpyro.\n\nThis is all still highly experimental but extremely promising and just plain amazing.\n\nAs an example we'll use the classic Radon hierarchical model. Note that this model is still very small, I would expect much more massive speed-ups with larger models."
},
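{
"metadata": {},
"cell_type": "markdown",
"source": "To see why JIT-compiling the whole computation helps, here is a minimal, illustrative sketch (not part of the Theano-to-JAX linker itself) of how JAX compiles a toy log-density gradient once with XLA, after which repeated evaluations incur no Python overhead:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import jax\nimport jax.numpy as jnp\n\n\n# Toy standard-normal log-density; stands in for a model's logp graph\ndef logp(x):\n    return -0.5 * jnp.sum(x ** 2)\n\n\n# jax.jit compiles the gradient function once; subsequent calls run\n# entirely inside XLA without returning to the Python interpreter\ngrad_logp = jax.jit(jax.grad(logp))\nprint(grad_logp(jnp.ones(3)))",
"execution_count": null,
"outputs": []
},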
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import arviz as az\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nimport pymc3 as pm\nimport pymc3.sampling_jax\nimport theano\n\nprint(f\"Running on PyMC3 v{pm.__version__}\")",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "Running on PyMC3 v3.9.3\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py:22: UserWarning: This module is experimental.\n warnings.warn(\"This module is experimental.\")\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%config InlineBackend.figure_format = 'retina'\naz.style.use(\"arviz-darkgrid\")",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = pd.read_csv(pm.get_data(\"radon.csv\"))\ndata[\"log_radon\"] = data[\"log_radon\"].astype(theano.config.floatX)\ncounty_names = data.county.unique()\ncounty_idx = data.county_code.values.astype(\"int32\")\n\nn_counties = len(data.county.unique())",
"execution_count": 4,
"outputs": []
},
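{
"metadata": {},
"cell_type": "markdown",
"source": "`county_idx` maps every measurement to its county, so per-county parameter vectors can be looked up for all observations in one vectorized operation. A minimal NumPy sketch of this indexing pattern (with made-up values, for illustration only):"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\n\nintercepts = np.array([1.0, 2.0, 3.0])  # one parameter per county\nobs_county = np.array([0, 0, 2, 1])  # county index per observation\n\n# Fancy indexing broadcasts the per-county parameters to observations\nprint(intercepts[obs_county])",
"execution_count": null,
"outputs": []
},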
{
"metadata": {},
"cell_type": "markdown",
"source": "Unchanged PyMC3 model specification:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "with pm.Model() as hierarchical_model:\n # Hyperpriors for group nodes\n mu_a = pm.Normal(\"mu_a\", mu=0.0, sigma=100.0)\n sigma_a = pm.HalfNormal(\"sigma_a\", 5.0)\n mu_b = pm.Normal(\"mu_b\", mu=0.0, sigma=100.0)\n sigma_b = pm.HalfNormal(\"sigma_b\", 5.0)\n\n # Intercept for each county, distributed around group mean mu_a\n # Above we just set mu and sd to a fixed value while here we\n # plug in a common group distribution for all a and b (which are\n # vectors of length n_counties).\n a = pm.Normal(\"a\", mu=mu_a, sigma=sigma_a, shape=n_counties)\n # Intercept for each county, distributed around group mean mu_a\n b = pm.Normal(\"b\", mu=mu_b, sigma=sigma_b, shape=n_counties)\n\n # Model error\n eps = pm.HalfCauchy(\"eps\", 5.0)\n\n radon_est = a[county_idx] + b[county_idx] * data.floor.values\n\n # Data likelihood\n radon_like = pm.Normal(\"radon_like\", mu=radon_est, sigma=eps, observed=data.log_radon)",
"execution_count": 5,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Sampling using our old Python NUTS sampler"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\nwith hierarchical_model:\n hierarchical_trace = pm.sample(\n 2000, tune=2000, target_accept=0.9, compute_convergence_checks=False\n )",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "Auto-assigning NUTS sampler...\nInitializing NUTS using jitter+adapt_diag...\nMultiprocess sampling (2 chains in 2 jobs)\nNUTS: [eps, b, a, sigma_b, mu_b, sigma_a, mu_a]\n",
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "\n <div>\n <style>\n /* Turns off some styling */\n progress {\n /* gets rid of default border in Firefox and Opera. */\n border: none;\n /* Needs to be in here for Safari polyfill so background images work as expected. */\n background-size: auto;\n }\n .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n background: #F44336;\n }\n </style>\n <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n 100.00% [8000/8000 01:07<00:00 Sampling 2 chains, 13 divergences]\n </div>\n "
},
"metadata": {}
},
{
"output_type": "stream",
"text": "Sampling 2 chains for 2_000 tune and 2_000 draw iterations (4_000 + 4_000 draws total) took 94 seconds.\nThere were 4 divergences after tuning. Increase `target_accept` or reparameterize.\nThere were 9 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "CPU times: user 9.72 s, sys: 1.34 s, total: 11.1 s\nWall time: 1min 47s\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Sampling using JAX TFP NUTS sampler"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\n# Inference button (TM)!\nwith hierarchical_model:\n hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "Compiling...\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "/Users/twiecki/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py:624: UserWarning: The jitted function _sample includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926.\n warn(f\"The jitted function {fun.__name__} includes a pmap. Using \"\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "Compilation + sampling time = 0 days 00:00:28.724421\nCPU times: user 30.5 s, sys: 1.82 s, total: 32.3 s\nWall time: 29.1 s\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\n# Inference button (TM)!\nwith hierarchical_model:\n hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)",
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": "Compiling...\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "/Users/twiecki/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py:624: UserWarning: The jitted function _sample includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926.\n warn(f\"The jitted function {fun.__name__} includes a pmap. Using \"\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "Compilation + sampling time = 0 days 00:00:26.573344\nCPU times: user 28.9 s, sys: 1.44 s, total: 30.3 s\nWall time: 26.7 s\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Sampling using JAX TFP NUTS sampler"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\n# Inference button (TM)!\nwith hierarchical_model:\n hierarchical_trace_tfp = pm.sampling_jax.sample_tfp_nuts(2000, tune=2000, target_accept=0.9)",
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": "Compiling...\nCompilation + sampling time = 0 days 00:01:20.362842\nCPU times: user 1min 13s, sys: 4.15 s, total: 1min 17s\nWall time: 1min 20s\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\n# Inference button (TM)!\nwith hierarchical_model:\n hierarchical_trace_tfp = pm.sampling_jax.sample_tfp_nuts(2000, tune=2000, target_accept=0.9)",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": "Compiling...\nCompilation + sampling time = 0 days 00:01:21.231797\nCPU times: user 1min 13s, sys: 4.22 s, total: 1min 17s\nWall time: 1min 21s\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "print(f\"Speed-up = {180 / 24}x\")",
"execution_count": 7,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Speed-up = 7.5x\n"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "pm.traceplot(\n hierarchical_trace_jax,\n var_names=[\"mu_a\", \"mu_b\", \"sigma_a_log__\", \"sigma_b_log__\", \"eps_log__\"],\n);",
"execution_count": 8,
"outputs": [
{
"data": {
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment