{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This a demo that shows how dense mass matrix adaptation can help PyMC3's performance in some cases.\n",
"See [this blog post](https://dfm.io/posts/pymc3-mass-matrix/) for a more detailed discussion.\n",
"\n",
"First, let's see what the PyMC3's performance is on an uncorrelated Gaussian:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (2 chains in 2 jobs)\n",
"NUTS: [x]\n",
"Sampling 2 chains, 0 divergences: 100%|██████████| 12000/12000 [00:04<00:00, 2851.12draws/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time per effective sample: 0.31932 ms\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import time\n",
"import pymc3 as pm\n",
"\n",
"ndim = 5\n",
"\n",
"with pm.Model() as simple_model:\n",
" pm.Normal(\"x\", shape=(ndim,))\n",
"\n",
"strt = time.time()\n",
"with simple_model:\n",
" simple_trace = pm.sample(draws=3000, tune=3000, random_seed=42)\n",
" \n",
" # About half the time is spent in tuning so correct for that\n",
" simple_time = 0.5*(time.time() - strt)\n",
" \n",
"stats = pm.summary(simple_trace)\n",
"simple_time_per_eff = simple_time / stats.n_eff.min()\n",
"print(\"time per effective sample: {0:.5f} ms\".format(simple_time_per_eff * 1000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As discussed in the blog post, PyMC3 doesn't do so well if there are correlations.\n",
"But we can use the `QuadPotentialFullAdapt` potential to get nearly the same performance as above:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Multiprocess sampling (2 chains in 2 jobs)\n",
"NUTS: [x]\n",
"Sampling 2 chains, 0 divergences: 100%|██████████| 30000/30000 [00:21<00:00, 1391.33draws/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time per effective sample: 0.30707 ms\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"# Generate a random positive definite matrix\n",
"np.random.seed(42)\n",
"L = np.random.randn(ndim, ndim)\n",
"L[np.diag_indices_from(L)] = 0.1*np.exp(L[np.diag_indices_from(L)])\n",
"L[np.triu_indices_from(L, 1)] = 0.0\n",
"cov = np.dot(L, L.T)\n",
"\n",
"with pm.Model() as model:\n",
" pm.MvNormal(\"x\", mu=np.zeros(ndim), chol=L, shape=(ndim,))\n",
" \n",
" # *** This is the new part ***\n",
" potential = pm.step_methods.hmc.quadpotential.QuadPotentialFullAdapt(\n",
" model.ndim, np.zeros(model.ndim))\n",
" step = pm.NUTS(model=model, potential=potential)\n",
" # *** end new part ***\n",
" \n",
" strt = time.time()\n",
" full_adapt_trace = pm.sample(draws=10000, tune=5000, random_seed=42, step=step)\n",
" full_adapt_time = 0.5 * (time.time() - strt)\n",
"\n",
"stats = pm.summary(full_adapt_trace)\n",
"full_adapt_time_per_eff = full_adapt_time / stats.n_eff.min()\n",
"print(\"time per effective sample: {0:.5f} ms\".format(full_adapt_time_per_eff * 1000))"
]
},
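{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, not part of the original timing demo), we can confirm that the dense-adapted sampler actually recovered the target distribution by comparing the empirical covariance of the samples to the true `cov` built above, and we can compute the cost ratio between the two runs directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Empirical covariance of the posterior samples; since the model is a\n",
"# zero-mean multivariate normal, this should be close to the true `cov`\n",
"emp_cov = np.cov(full_adapt_trace[\"x\"], rowvar=False)\n",
"print(\"max abs covariance error: {0:.4f}\".format(np.max(np.abs(emp_cov - cov))))\n",
"\n",
"# Cost per effective sample relative to the uncorrelated baseline;\n",
"# a ratio near 1 means the dense adaptation closed the gap\n",
"print(\"cost ratio (correlated / simple): {0:.3f}\".format(\n",
" full_adapt_time_per_eff / simple_time_per_eff))"
]
}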
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}