Created
August 21, 2021 15:09
-
-
Save peterroelants/f9304e8d619af49b752ae012e5320bc9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "2361a2b9-1a2a-4089-9172-f330ad38a73e", | |
"metadata": {}, | |
"source": [ | |
"# Normalized vs Non-normalized log_prob: Difference in speed?\n", | |
"\n", | |
"Notebook comparing NumPyro inference speed (and the error of the inferred rate) for a Poisson likelihood implemented with:\n", | |
"- Normalized Poisson `log_prob`\n", | |
"- Non-Normalized Poisson `log_prob`\n", | |
"\n", | |
"Related discourse thread: https://forum.pyro.ai/t/unnormalized-densities/3251" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d0fb3a55-598d-4e60-8cf4-6d908ed7e6ac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Imports\n", | |
"%matplotlib inline\n", | |
"%config InlineBackend.figure_format = 'svg'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "99a4c37f-0d42-43c2-b072-e3063bcdbf57", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"import warnings\n", | |
"import time\n", | |
"\n", | |
"import numpy as np\n", | |
"\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"\n", | |
"import numpyro\n", | |
"from numpyro.infer import MCMC, NUTS, Predictive\n", | |
"import numpyro.distributions as dist\n", | |
"from numpyro.distributions import constraints\n", | |
"from numpyro.distributions.util import validate_sample\n", | |
"\n", | |
"import matplotlib\n", | |
"import matplotlib.pyplot as plt\n", | |
"from matplotlib import cm # Colormaps\n", | |
"import seaborn as sns\n", | |
"import arviz as az\n", | |
"\n", | |
"import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "59f82b8e-b8c4-48aa-b583-f0bb78579d32", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sns.set_style('darkgrid')\n", | |
"az.rcParams['stats.hdi_prob'] = 0.90\n", | |
"az.style.use(\"arviz-darkgrid\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "cd12f436-971f-48d9-a2dd-33bf56914a37", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"numpyro.set_platform('cpu')\n", | |
"numpyro.set_host_device_count(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9b4e48c7-3732-4c36-abe3-d39399b5d233", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(42)\n", | |
"rng_key = jax.random.PRNGKey(42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bf5a8a6f-4362-49eb-9af9-eed5c9be77a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"warnings.filterwarnings('ignore')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4f1f896b-422b-43e0-b411-ac0a319d38d0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"k = 5" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "597a1e78-97e1-4883-a49d-fd47131d55d3", | |
"metadata": {}, | |
"source": [ | |
"## Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "78eae1e8-af9e-4fb4-8283-b66a7cdbf64b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n = 50_000\n", | |
"true_rate = 25.34\n", | |
"\n", | |
"observations = np.random.poisson(lam=true_rate, size=n)\n", | |
"print(observations.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b781d5af-e6d1-42ae-8a83-b03f760fdda2", | |
"metadata": {}, | |
"source": [ | |
"### Rate inference: Normalized" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d6c2c981-a7bc-4229-9276-7190d785b41a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model_poisson(obs=None):\n", | |
"    \"\"\"Poisson rate model using NumPyro's built-in Poisson distribution.\n", | |
"\n", | |
"    A Normal(0, 10) prior is placed on the log-rate. Its exponential is recorded\n", | |
"    as the deterministic site \"rate\" and used as the rate of the 'obs' site.\n", | |
"\n", | |
"    Args:\n", | |
"        obs: Value forwarded as ``obs=`` to the 'obs' sample site (default None).\n", | |
"    \"\"\"\n", | |
"    log_rate_prior = dist.Normal(loc=0.0, scale=10.0)\n", | |
"    log_rate = numpyro.sample(\"log_rate\", log_rate_prior)\n", | |
"    rate = numpyro.deterministic(\"rate\", jnp.exp(log_rate))\n", | |
"    likelihood = dist.Poisson(rate=rate)\n", | |
"    numpyro.sample('obs', likelihood, obs=obs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5fd72025-2cf6-43ce-b781-0886663439bd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rng_key = jax.random.PRNGKey(42)\n", | |
"\n", | |
"num_warmup, num_samples = 1000, 10000\n", | |
"\n", | |
"# Run NUTS on `observations` with the built-in (normalized) Poisson likelihood.\n", | |
"# NOTE(review): the host device count is set to 1 earlier in this notebook, so\n", | |
"# the 4 chains presumably run sequentially rather than in parallel - confirm.\n", | |
"kernel_poisson = NUTS(model_poisson)\n", | |
"mcmc_poisson = MCMC(\n", | |
"    kernel_poisson,\n", | |
"    num_warmup=num_warmup,\n", | |
"    num_samples=num_samples,\n", | |
"    num_chains=4,\n", | |
"    progress_bar=True,\n", | |
")\n", | |
"mcmc_poisson.run(rng_key, obs=observations)\n", | |
"\n", | |
"# Convert the posterior once and reuse it for both the summary and the trace plot.\n", | |
"inference_data_poisson = az.from_numpyro(\n", | |
"    posterior=mcmc_poisson,\n", | |
")\n", | |
"\n", | |
"# Show summary table and trace (the \"log_rate\" site is excluded from both).\n", | |
"display(az.summary(inference_data_poisson, var_names=[\"~log_rate\"], round_to=2))\n", | |
"\n", | |
"az.plot_trace(\n", | |
"    inference_data_poisson,\n", | |
"    compact=True,\n", | |
"    var_names=[\"~log_rate\"],\n", | |
"    lines=[\n", | |
"        (\"rate\", {}, true_rate),\n", | |
"    ],\n", | |
")\n", | |
"plt.suptitle('Trace plots', fontsize=18)\n", | |
"plt.show()\n" | |
] | |
}, | |
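{ | |
"cell_type": "markdown", | |
"id": "9b9c1aeb-2da6-4f1a-9ceb-76b36b30ebe7", | |
"metadata": {}, | |
"source": [ | |
"### Rate inference: Un-Normalized\n", | |
"\n", | |
"The Poisson log-likelihood is $\\log p(k \\mid \\lambda) = k \\log \\lambda - \\lambda - \\log k!$. The last term does not depend on the rate $\\lambda$, so the un-normalized version used here drops it.\n", | |
"\n", | |
"Background on the Poisson likelihood:\n", | |
"- http://sherrytowers.com/2014/07/10/poisson-likelihood/" | |
] | |
}, | |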
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a9a8d91e-d5d8-40bb-b632-1622ada9ee45", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class PoissonUN(dist.Distribution):\n", | |
"    \"\"\"Poisson distribution whose ``log_prob`` is left un-normalized.\n", | |
"\n", | |
"    ``log_prob`` returns ``value * log(rate) - rate``, i.e. the Poisson log-pmf\n", | |
"    without the ``-log(value!)`` term. That term does not depend on ``rate``.\n", | |
"\n", | |
"    Args:\n", | |
"        rate: Rate parameter (constrained positive); its shape is passed to\n", | |
"            ``Distribution.__init__`` as the batch shape.\n", | |
"        validate_args: Forwarded to ``Distribution.__init__``.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    arg_constraints = {\"rate\": constraints.positive}\n", | |
"    support = constraints.nonnegative_integer\n", | |
"\n", | |
"    def __init__(self, rate, *, validate_args=None):\n", | |
"        self.rate = rate\n", | |
"        super().__init__(jnp.shape(rate), validate_args=validate_args)\n", | |
"\n", | |
"    def sample(self, key, sample_shape=()):\n", | |
"        \"\"\"Draw Poisson samples of shape ``sample_shape + batch_shape``.\"\"\"\n", | |
"        # Go through the `jax` namespace, which is the only PRNG module in scope.\n", | |
"        return jax.random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)\n", | |
"\n", | |
"    @validate_sample\n", | |
"    def log_prob(self, value):\n", | |
"        \"\"\"Return ``value * log(rate) - rate`` (log-pmf without ``-log(value!)``).\"\"\"\n", | |
"        if self._validate_args:\n", | |
"            self._validate_sample(value)\n", | |
"        value = jax.device_get(value)\n", | |
"        return (jnp.log(self.rate) * value) - self.rate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3b859618-fefd-45f4-9d38-145b869f96ff", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model_poisson_unnormalized(obs=None):\n", | |
"    \"\"\"Poisson rate model using the custom ``PoissonUN`` likelihood.\n", | |
"\n", | |
"    A Normal(0, 10) prior is placed on the log-rate. Its exponential is recorded\n", | |
"    as the deterministic site \"rate\" and used as the rate of the 'obs' site.\n", | |
"\n", | |
"    Args:\n", | |
"        obs: Value forwarded as ``obs=`` to the 'obs' sample site (default None).\n", | |
"    \"\"\"\n", | |
"    log_rate_prior = dist.Normal(loc=0.0, scale=10.0)\n", | |
"    log_rate = numpyro.sample(\"log_rate\", log_rate_prior)\n", | |
"    rate = numpyro.deterministic(\"rate\", jnp.exp(log_rate))\n", | |
"    likelihood = PoissonUN(rate=rate)\n", | |
"    numpyro.sample('obs', likelihood, obs=obs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7eeb3a7e-dcff-4d10-86a2-79fcc74c8892", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rng_key = jax.random.PRNGKey(42)\n", | |
"\n", | |
"num_warmup, num_samples = 1000, 10000\n", | |
"\n", | |
"# Run NUTS on `observations` with the un-normalized Poisson likelihood.\n", | |
"# NOTE(review): the host device count is set to 1 earlier in this notebook, so\n", | |
"# the 4 chains presumably run sequentially rather than in parallel - confirm.\n", | |
"kernel_poisson_un = NUTS(model_poisson_unnormalized)\n", | |
"mcmc_poisson_un = MCMC(\n", | |
"    kernel_poisson_un,\n", | |
"    num_warmup=num_warmup,\n", | |
"    num_samples=num_samples,\n", | |
"    num_chains=4,\n", | |
"    progress_bar=True,\n", | |
")\n", | |
"mcmc_poisson_un.run(rng_key, obs=observations)\n", | |
"\n", | |
"# Convert the posterior once and reuse it for both the summary and the trace plot.\n", | |
"inference_data_poisson_un = az.from_numpyro(\n", | |
"    posterior=mcmc_poisson_un,\n", | |
")\n", | |
"\n", | |
"# Show summary table and trace (the \"log_rate\" site is excluded from both).\n", | |
"display(az.summary(inference_data_poisson_un, var_names=[\"~log_rate\"], round_to=2))\n", | |
"\n", | |
"az.plot_trace(\n", | |
"    inference_data_poisson_un,\n", | |
"    compact=True,\n", | |
"    var_names=[\"~log_rate\"],\n", | |
"    lines=[\n", | |
"        (\"rate\", {}, true_rate),\n", | |
"    ],\n", | |
")\n", | |
"plt.suptitle('Trace plots', fontsize=18)\n", | |
"plt.show()\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4d3ddc09-acdf-436a-a90e-970fd827a4f4", | |
"metadata": {}, | |
"source": [ | |
"## Comparison" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "8eab63c5-5250-44e5-8fbe-3c90abb58202", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"true_rate = 25.34\n", | |
"\n", | |
"data_sizes = [2, 5, 10, 50, 100, 500, 1000, 5000, 10_000, 25_000]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "969e7f90-db33-4d70-acb0-f97d5afe541f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def run_with_timing(model_fn, n_runs, n_warmup, n_samples, obs, true_rate):\n", | |
"    \"\"\"Time single-chain NUTS inference for ``model_fn`` and score the \"rate\" site.\n", | |
"\n", | |
"    The sampler is run once untimed (the compile run), then ``n_runs`` more\n", | |
"    times with the same PRNG key while timing each run.\n", | |
"\n", | |
"    Args:\n", | |
"        model_fn: NumPyro model accepting an ``obs`` keyword argument and\n", | |
"            producing a \"rate\" entry in the posterior samples.\n", | |
"        n_runs: Number of timed runs; must be at least 1.\n", | |
"        n_warmup: Number of warmup steps per run.\n", | |
"        n_samples: Number of posterior samples per run.\n", | |
"        obs: Observations passed to the model.\n", | |
"        true_rate: Reference rate used to compute the absolute error.\n", | |
"\n", | |
"    Returns:\n", | |
"        ``((median_time, mad_time), (mean_rate_error, std_rate_error))``: the\n", | |
"        median and median absolute deviation of the run times in seconds, and\n", | |
"        the mean and standard deviation of ``|rate - true_rate|`` over the\n", | |
"        posterior samples of the last timed run.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: If ``n_runs`` is smaller than 1.\n", | |
"    \"\"\"\n", | |
"    if n_runs < 1:\n", | |
"        # Without a timed run there are neither timings nor posterior samples.\n", | |
"        raise ValueError(f\"n_runs must be at least 1, got {n_runs}\")\n", | |
"    rng_key = jax.random.PRNGKey(42)\n", | |
"    # Run NUTS.\n", | |
"    kernel = NUTS(model_fn)\n", | |
"    mcmc = MCMC(\n", | |
"        kernel,\n", | |
"        num_warmup=n_warmup,\n", | |
"        num_samples=n_samples,\n", | |
"        num_chains=1,\n", | |
"        progress_bar=False,\n", | |
"        jit_model_args=True,\n", | |
"    )\n", | |
"    # Run once to compile, so compilation is excluded from the timings.\n", | |
"    mcmc.run(rng_key, obs=obs)\n", | |
"    # Timed runs.\n", | |
"    times = []\n", | |
"    for _ in range(n_runs):\n", | |
"        # perf_counter is the clock intended for measuring short durations.\n", | |
"        start_time = time.perf_counter()\n", | |
"        mcmc.run(rng_key, obs=obs)\n", | |
"        posterior_samples = mcmc.get_samples()\n", | |
"        # Wait for the result to be ready before stopping the clock.\n", | |
"        posterior_samples[\"rate\"].block_until_ready()\n", | |
"        stop_time = time.perf_counter()\n", | |
"        times.append(stop_time - start_time)\n", | |
"    times = np.array(times)\n", | |
"    median_time = np.median(times)\n", | |
"    mad_time = np.median(np.abs(times - median_time))\n", | |
"    rate_error = np.abs(posterior_samples[\"rate\"] - true_rate)\n", | |
"    mean_rate_error = np.mean(rate_error)\n", | |
"    std_rate_error = np.std(rate_error)\n", | |
"    return (median_time, mad_time), (mean_rate_error, std_rate_error)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "75f6419f-8054-41d2-8813-9b8d49e304e5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One float32 result slot per entry of `data_sizes`, for each likelihood variant.\n", | |
"n_sizes = len(data_sizes)\n", | |
"\n", | |
"median_times_normalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"mad_times_normalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"mean_rate_error_normalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"std_rate_error_normalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"\n", | |
"median_times_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"mad_times_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"mean_rate_error_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"std_rate_error_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n", | |
"\n", | |
"n_runs = 5\n", | |
"n_warmup = 500\n", | |
"n_samples = 5000\n", | |
"\n", | |
"# (model, median time, MAD time, mean error, std error) output arrays;\n", | |
"# the normalized model is benchmarked first for every data size.\n", | |
"benchmark_targets = [\n", | |
"    (\n", | |
"        model_poisson,\n", | |
"        median_times_normalized,\n", | |
"        mad_times_normalized,\n", | |
"        mean_rate_error_normalized,\n", | |
"        std_rate_error_normalized,\n", | |
"    ),\n", | |
"    (\n", | |
"        model_poisson_unnormalized,\n", | |
"        median_times_unnormalized,\n", | |
"        mad_times_unnormalized,\n", | |
"        mean_rate_error_unnormalized,\n", | |
"        std_rate_error_unnormalized,\n", | |
"    ),\n", | |
"]\n", | |
"\n", | |
"pbar = tqdm.tqdm(data_sizes)\n", | |
"for i, n in enumerate(pbar):\n", | |
"    pbar.set_description(f\"#samples = {n}\")\n", | |
"    # Both variants are benchmarked on the same freshly drawn observations.\n", | |
"    observations = np.random.poisson(lam=true_rate, size=n)\n", | |
"    for model_fn, median_out, mad_out, mean_error_out, std_error_out in benchmark_targets:\n", | |
"        (median_time, mad_time), (mean_rate_error, std_rate_error) = run_with_timing(\n", | |
"            model_fn=model_fn,\n", | |
"            n_runs=n_runs,\n", | |
"            n_warmup=n_warmup,\n", | |
"            n_samples=n_samples,\n", | |
"            obs=observations,\n", | |
"            true_rate=true_rate\n", | |
"        )\n", | |
"        median_out[i] = median_time\n", | |
"        mad_out[i] = mad_time\n", | |
"        mean_error_out[i] = mean_rate_error\n", | |
"        std_error_out[i] = std_rate_error" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "624e7014-562f-4162-b9b5-748d49741cc6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))\n", | |
"\n", | |
"# (label, color, time center, time spread, error center, error spread)\n", | |
"plot_series = [\n", | |
"    (\n", | |
"        \"Normalized\", \"blue\",\n", | |
"        median_times_normalized, mad_times_normalized,\n", | |
"        mean_rate_error_normalized, std_rate_error_normalized,\n", | |
"    ),\n", | |
"    (\n", | |
"        \"Non-Normalized\", \"red\",\n", | |
"        median_times_unnormalized, mad_times_unnormalized,\n", | |
"        mean_rate_error_unnormalized, std_rate_error_unnormalized,\n", | |
"    ),\n", | |
"]\n", | |
"\n", | |
"# Top panel: median time +/- MAD. Bottom panel: mean error +/- std.\n", | |
"for label, color, time_mid, time_dev, error_mid, error_dev in plot_series:\n", | |
"    for ax, center, spread in ((ax1, time_mid, time_dev), (ax2, error_mid, error_dev)):\n", | |
"        ax.plot(data_sizes, center, \"o-\", color=color, label=label)\n", | |
"        ax.fill_between(\n", | |
"            data_sizes, center-spread, center+spread,\n", | |
"            color=color, alpha=0.15)\n", | |
"\n", | |
"ax1.set_xscale(\"log\", base=10)\n", | |
"ax1.set_yscale(\"log\", base=10)\n", | |
"ax1.set_xlabel(\"#samples\")\n", | |
"ax1.set_ylabel(\"time (seconds)\")\n", | |
"ax1.set_title(\"Inference time\")\n", | |
"ax1.legend()\n", | |
"\n", | |
"ax2.set_xscale(\"log\", base=10)\n", | |
"\n", | |
"ax2.set_xlabel(\"#samples\")\n", | |
"ax2.set_ylabel(\"Error\")\n", | |
"ax2.set_title(\"Inference error on \\\"Rate\\\"\")\n", | |
"ax2.legend()\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "238b8bda-b95f-402f-a79c-dee0c1ec4a68", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment