Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Last active April 8, 2019 05:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/ab876d1db27c2277dc6c0cf1ab5c8ff0 to your computer and use it in GitHub Desktop.
Save fehiepsi/ab876d1db27c2277dc6c0cf1ab5c8ff0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"N, dim = 3000, 3\n",
"warmup_steps, num_samples = 1000, 20000"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### numpyro"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"import jax.random as random\n",
"from jax.scipy.special import expit\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions.util import validation_disabled\n",
"from numpyro.handlers import sample\n",
"from numpyro.hmc_util import initialize_model\n",
"from numpyro.mcmc import hmc_kernel\n",
"from numpyro.util import tscan"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"rng = random.PRNGKey(0)\n",
"data = random.normal(rng, (N, dim))\n",
"true_coefs = np.arange(1., dim + 1.)\n",
"logits = np.sum(true_coefs * data, axis=-1)\n",
"labels = dist.bernoulli(logits, is_logits=True).rvs(random_state=rng)\n",
"\n",
"def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time for complie and init: 6.054578542709351\n",
"Time to compile: 1.8476543426513672\n",
"Time to compile and generate 20000 samples: 2.8188066482543945\n"
]
}
],
"source": [
"with validation_disabled():\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"HMC\")\n",
" start = time.time()\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" print(\"Time for complie and init:\", time.time() - start)\n",
" start = time.time()\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(1))\n",
" print(\"Time to compile:\", time.time() - start)\n",
" start = time.time()\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))\n",
" print(\"Time to compile and generate 20000 samples:\", time.time() - start)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time for complie and init: 11.951331853866577\n",
"Time to compile: 8.162105798721313\n",
"Time to compile and generate 20000 samples: 11.791397094726562\n"
]
}
],
"source": [
"with validation_disabled():\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"NUTS\")\n",
" start = time.time()\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" print(\"Time for complie and init:\", time.time() - start)\n",
" start = time.time()\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(1))\n",
" print(\"Time to compile:\", time.time() - start)\n",
" start = time.time()\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))\n",
" print(\"Time to compile and generate 20000 samples:\", time.time() - start)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### pyro"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import pyro\n",
"import pyro.distributions as pdist\n",
"from pyro.infer.mcmc import HMC, MCMC, NUTS"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"pyro.set_rng_seed(0)\n",
"data = torch.randn(N, dim)\n",
"true_coefs = torch.arange(1., dim + 1.)\n",
"logits = (true_coefs * data).sum(-1)\n",
"labels = pdist.Bernoulli(logits=logits).sample()\n",
"\n",
"def model(data):\n",
" coefs = pyro.sample('beta', pdist.Normal(torch.zeros(dim), torch.ones(dim)))\n",
" logits = (coefs * data).sum(-1)\n",
" return pyro.sample('y', pdist.Bernoulli(logits=logits), obs=labels)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time for complie and init: 4.401107311248779\n",
"Time for complie, init, and run: 42.86292505264282\n"
]
}
],
"source": [
"start = time.time()\n",
"hmc_kernel = HMC(model, step_size=0.1, num_steps=15, jit_compile=True, ignore_jit_warnings=True)\n",
"mcmc_run = MCMC(hmc_kernel, num_samples=1, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n",
"print(\"Time for complie and init:\", time.time() - start)\n",
"start = time.time()\n",
"hmc_kernel = HMC(model, step_size=0.1, num_steps=15, jit_compile=True, ignore_jit_warnings=True)\n",
"mcmc_run = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n",
"print(\"Time for complie, init, and run:\", time.time() - start)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time for complie and init: 5.244949817657471\n",
"Time for complie, init, and run: 118.32065653800964\n"
]
}
],
"source": [
"start = time.time()\n",
"hmc_kernel = NUTS(model, step_size=0.1, jit_compile=True, ignore_jit_warnings=True)\n",
"mcmc_run = MCMC(hmc_kernel, num_samples=1, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n",
"print(\"Time for complie and init:\", time.time() - start)\n",
"start = time.time()\n",
"hmc_kernel = NUTS(model, step_size=0.1, jit_compile=True, ignore_jit_warnings=True)\n",
"mcmc_run = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n",
"print(\"Time for complie, init, and run:\", time.time() - start)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**So numpyro sampling is 30x faster than pyro.**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**And we see that NUTS takes more than 3s to generate 20000 samples while HMC only takes 1s.**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### let's see if nuts' implementation adds overhead over hmc"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(40000, dtype=int32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
"init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"HMC\")\n",
"hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
"hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples), fields=(3,))\n",
"np.sum(hmc_states.num_steps).copy()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(113032, dtype=int32)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
"init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"NUTS\")\n",
"hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
"hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples), fields=(3,))\n",
"np.sum(hmc_states.num_steps).copy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**So NUTS spends 3x more verlet steps than HMC! -> Iterative NUTS just has a small overhead over HMC.**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sampling is pretty fast. The only issue now is the compiling time. In addition, we compile 2 times: one at init and one at scan."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment