Skip to content

Instantly share code, notes, and snippets.

@peterroelants
Created August 21, 2021 15:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peterroelants/f9304e8d619af49b752ae012e5320bc9 to your computer and use it in GitHub Desktop.
Save peterroelants/f9304e8d619af49b752ae012e5320bc9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "2361a2b9-1a2a-4089-9172-f330ad38a73e",
"metadata": {},
"source": [
"# Normalized vs Non-normalized log_prob: Difference in speed?\n",
"\n",
"Notebook to compare NumPyro inference speed when using:\n",
"- Normalized Poisson `log_prob`\n",
"- Non-Normalized Poisson `log_prob`\n",
"\n",
"Related discourse thread: https://forum.pyro.ai/t/unnormalized-densities/3251"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0fb3a55-598d-4e60-8cf4-6d908ed7e6ac",
"metadata": {},
"outputs": [],
"source": [
"# Notebook display settings\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99a4c37f-0d42-43c2-b072-e3063bcdbf57",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import warnings\n",
"import time\n",
"\n",
"import numpy as np\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import numpyro\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions import constraints\n",
"from numpyro.distributions.util import validate_sample\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm # Colormaps\n",
"import seaborn as sns\n",
"import arviz as az\n",
"\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59f82b8e-b8c4-48aa-b583-f0bb78579d32",
"metadata": {},
"outputs": [],
"source": [
"# Plot styling: seaborn grid first, then the ArviZ style sheet on top.\n",
"sns.set_style('darkgrid')\n",
"az.style.use(\"arviz-darkgrid\")\n",
"# Use 90% highest-density intervals as the ArviZ default.\n",
"az.rcParams['stats.hdi_prob'] = 0.90"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd12f436-971f-48d9-a2dd-33bf56914a37",
"metadata": {},
"outputs": [],
"source": [
"# Run on the CPU backend and expose a single host device to JAX.\n",
"numpyro.set_platform('cpu')\n",
"numpyro.set_host_device_count(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b4e48c7-3732-4c36-abe3-d39399b5d233",
"metadata": {},
"outputs": [],
"source": [
"# Fixed seeds: a JAX PRNG key for NumPyro and NumPy's global RNG for data simulation.\n",
"rng_key = jax.random.PRNGKey(42)\n",
"np.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf5a8a6f-4362-49eb-9af9-eed5c9be77a3",
"metadata": {},
"outputs": [],
"source": [
"# Silence every Python warning from here on.\n",
"warnings.filterwarnings(action='ignore')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f1f896b-422b-43e0-b411-ac0a319d38d0",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): `k` does not appear to be referenced by any other visible cell - confirm before relying on it.\n",
"k = 5"
]
},
{
"cell_type": "markdown",
"id": "597a1e78-97e1-4883-a49d-fd47131d55d3",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78eae1e8-af9e-4fb4-8283-b66a7cdbf64b",
"metadata": {},
"outputs": [],
"source": [
"# Simulate `n` Poisson-distributed counts with a known rate.\n",
"true_rate = 25.34\n",
"n = 50_000\n",
"\n",
"observations = np.random.poisson(lam=true_rate, size=n)\n",
"print(observations.shape)"
]
},
{
"cell_type": "markdown",
"id": "b781d5af-e6d1-42ae-8a83-b03f760fdda2",
"metadata": {},
"source": [
"### Rate inference: Normalized"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6c2c981-a7bc-4229-9276-7190d785b41a",
"metadata": {},
"outputs": [],
"source": [
"def model_poisson(obs=None):\n",
"    \"\"\"Poisson model using NumPyro's built-in `dist.Poisson` likelihood.\n",
"\n",
"    Args:\n",
"        obs: Observed counts passed to the 'obs' sample site (default `None`).\n",
"    \"\"\"\n",
"    # The rate is sampled on the log scale; exp(log_rate) is recorded as \"rate\".\n",
"    log_rate_prior = dist.Normal(loc=0.0, scale=10.0)\n",
"    log_rate = numpyro.sample(\"log_rate\", log_rate_prior)\n",
"    rate = numpyro.deterministic(\"rate\", jnp.exp(log_rate))\n",
"    likelihood = dist.Poisson(rate=rate)\n",
"    numpyro.sample('obs', likelihood, obs=obs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fd72025-2cf6-43ce-b781-0886663439bd",
"metadata": {},
"outputs": [],
"source": [
"rng_key = jax.random.PRNGKey(42)\n",
"\n",
"num_warmup, num_samples = 1000, 10000\n",
"\n",
"# Run NUTS with the normalized Poisson likelihood.\n",
"kernel_poisson = NUTS(model_poisson)\n",
"mcmc_poisson = MCMC(\n",
"    kernel_poisson,\n",
"    num_warmup=num_warmup,\n",
"    num_samples=num_samples,\n",
"    num_chains=4,\n",
"    progress_bar=True,\n",
")\n",
"mcmc_poisson.run(rng_key, obs=observations)\n",
"\n",
"# Convert to ArviZ InferenceData once and reuse it for the summary and the trace plot;\n",
"# handing the MCMC object itself to az.summary would trigger a second conversion.\n",
"inference_data_poisson = az.from_numpyro(\n",
"    posterior=mcmc_poisson,\n",
")\n",
"display(az.summary(inference_data_poisson, var_names=[\"~log_rate\"], round_to=2))\n",
"\n",
"# Show trace, with the true rate marked as a reference line.\n",
"az.plot_trace(\n",
"    inference_data_poisson,\n",
"    compact=True,\n",
"    var_names=[\"~log_rate\"],\n",
"    lines=[\n",
"        (\"rate\", {}, true_rate),\n",
"    ],\n",
")\n",
"plt.suptitle('Trace plots', fontsize=18)\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "9b9c1aeb-2da6-4f1a-9ceb-76b36b30ebe7",
"metadata": {},
"source": [
"### Rate inference: Un-Normalized\n",
"\n",
"- Background on the Poisson likelihood: http://sherrytowers.com/2014/07/10/poisson-likelihood/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9a8d91e-d5d8-40bb-b632-1622ada9ee45",
"metadata": {},
"outputs": [],
"source": [
"class PoissonUN(dist.Distribution):\n",
"    \"\"\"Poisson distribution whose `log_prob` leaves out the normalizing term.\n",
"\n",
"    `log_prob` returns `value * log(rate) - rate`, i.e. the Poisson\n",
"    log-probability without the `-log(value!)` term, which does not depend on\n",
"    `rate`.\n",
"    \"\"\"\n",
"\n",
"    arg_constraints = {\"rate\": constraints.positive}\n",
"    support = constraints.nonnegative_integer\n",
"\n",
"    def __init__(self, rate, *, validate_args=None):\n",
"        self.rate = rate\n",
"        super().__init__(jnp.shape(rate), validate_args=validate_args)\n",
"\n",
"    def sample(self, key, sample_shape=()):\n",
"        \"\"\"Draw Poisson samples of shape `sample_shape + batch_shape`.\"\"\"\n",
"        # Local import: the notebook's import cell only brings in\n",
"        # `validate_sample` from this module, and no bare `random` name is\n",
"        # imported, so `jax.random` is referenced explicitly below.\n",
"        from numpyro.distributions.util import is_prng_key\n",
"\n",
"        assert is_prng_key(key)\n",
"        return jax.random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)\n",
"\n",
"    @validate_sample\n",
"    def log_prob(self, value):\n",
"        \"\"\"Return the unnormalized Poisson log-probability of `value`.\"\"\"\n",
"        if self._validate_args:\n",
"            self._validate_sample(value)\n",
"        value = jax.device_get(value)\n",
"        return (jnp.log(self.rate) * value) - self.rate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b859618-fefd-45f4-9d38-145b869f96ff",
"metadata": {},
"outputs": [],
"source": [
"def model_poisson_unnormalized(obs=None):\n",
"    \"\"\"Poisson model using the custom `PoissonUN` likelihood.\n",
"\n",
"    Args:\n",
"        obs: Observed counts passed to the 'obs' sample site (default `None`).\n",
"    \"\"\"\n",
"    # The rate is sampled on the log scale; exp(log_rate) is recorded as \"rate\".\n",
"    log_rate_prior = dist.Normal(loc=0.0, scale=10.0)\n",
"    log_rate = numpyro.sample(\"log_rate\", log_rate_prior)\n",
"    rate = numpyro.deterministic(\"rate\", jnp.exp(log_rate))\n",
"    likelihood = PoissonUN(rate=rate)\n",
"    numpyro.sample('obs', likelihood, obs=obs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eeb3a7e-dcff-4d10-86a2-79fcc74c8892",
"metadata": {},
"outputs": [],
"source": [
"rng_key = jax.random.PRNGKey(42)\n",
"\n",
"num_warmup, num_samples = 1000, 10000\n",
"\n",
"# Run NUTS with the unnormalized Poisson likelihood.\n",
"kernel_poisson_un = NUTS(model_poisson_unnormalized)\n",
"mcmc_poisson_un = MCMC(\n",
"    kernel_poisson_un,\n",
"    num_warmup=num_warmup,\n",
"    num_samples=num_samples,\n",
"    num_chains=4,\n",
"    progress_bar=True,\n",
")\n",
"mcmc_poisson_un.run(rng_key, obs=observations)\n",
"\n",
"# Convert to ArviZ InferenceData once and reuse it for the summary and the trace plot;\n",
"# handing the MCMC object itself to az.summary would trigger a second conversion.\n",
"inference_data_poisson_un = az.from_numpyro(\n",
"    posterior=mcmc_poisson_un,\n",
")\n",
"display(az.summary(inference_data_poisson_un, var_names=[\"~log_rate\"], round_to=2))\n",
"\n",
"# Show trace, with the true rate marked as a reference line.\n",
"az.plot_trace(\n",
"    inference_data_poisson_un,\n",
"    compact=True,\n",
"    var_names=[\"~log_rate\"],\n",
"    lines=[\n",
"        (\"rate\", {}, true_rate),\n",
"    ],\n",
")\n",
"plt.suptitle('Trace plots', fontsize=18)\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "4d3ddc09-acdf-436a-a90e-970fd827a4f4",
"metadata": {},
"source": [
"## Comparison"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eab63c5-5250-44e5-8fbe-3c90abb58202",
"metadata": {},
"outputs": [],
"source": [
"# Rate used to simulate the benchmark data sets.\n",
"true_rate = 25.34\n",
"\n",
"# Data set sizes to benchmark.\n",
"data_sizes = [2, 5, 10, 50, 100, 500, 1_000, 5_000, 10_000, 25_000]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "969e7f90-db33-4d70-acb0-f97d5afe541f",
"metadata": {},
"outputs": [],
"source": [
"def run_with_timing(model_fn, n_runs, n_warmup, n_samples, obs, true_rate):\n",
"    \"\"\"Time repeated single-chain NUTS runs of `model_fn` and score the inferred rate.\n",
"\n",
"    Args:\n",
"        model_fn: NumPyro model taking an `obs` keyword and recording a \"rate\" site.\n",
"        n_runs: Number of timed runs (at least 1); one extra untimed run is done first.\n",
"        n_warmup: Number of warmup steps per run.\n",
"        n_samples: Number of posterior samples per run.\n",
"        obs: Observations passed to the model.\n",
"        true_rate: Reference value the posterior \"rate\" samples are compared with.\n",
"\n",
"    Returns:\n",
"        `((median_time, mad_time), (mean_rate_error, std_rate_error))`. Times are\n",
"        in seconds, with the median absolute deviation taken over the timed runs.\n",
"        The errors are the mean and standard deviation of `|rate - true_rate|`\n",
"        over the posterior samples of the last timed run.\n",
"\n",
"    Raises:\n",
"        ValueError: If `n_runs` is smaller than 1.\n",
"    \"\"\"\n",
"    if n_runs < 1:\n",
"        # Without a timed run there are no timings or samples to summarize.\n",
"        raise ValueError(f\"n_runs must be at least 1, got {n_runs}\")\n",
"    rng_key = jax.random.PRNGKey(42)\n",
"    # Run NUTS.\n",
"    kernel = NUTS(model_fn)\n",
"    mcmc = MCMC(\n",
"        kernel,\n",
"        num_warmup=n_warmup,\n",
"        num_samples=n_samples,\n",
"        num_chains=1,\n",
"        progress_bar=False,\n",
"        jit_model_args=True,\n",
"    )\n",
"    # Run once untimed to compile.\n",
"    mcmc.run(rng_key, obs=obs)\n",
"    # Run n_runs times to time.\n",
"    times = []\n",
"    for _ in range(n_runs):\n",
"        # perf_counter is the stdlib clock intended for measuring short durations.\n",
"        start_time = time.perf_counter()\n",
"        mcmc.run(rng_key, obs=obs)\n",
"        posterior_samples = mcmc.get_samples()\n",
"        # Wait for the result to be ready before stopping the clock.\n",
"        posterior_samples[\"rate\"].block_until_ready()\n",
"        stop_time = time.perf_counter()\n",
"        times.append(stop_time - start_time)\n",
"    times = np.array(times)\n",
"    median_time = np.median(times)\n",
"    mad_time = np.median(np.abs(times - median_time))\n",
"    rate_error = np.abs(posterior_samples[\"rate\"] - true_rate)\n",
"    mean_rate_error = np.mean(rate_error)\n",
"    std_rate_error = np.std(rate_error)\n",
"    return (median_time, mad_time), (mean_rate_error, std_rate_error)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75f6419f-8054-41d2-8813-9b8d49e304e5",
"metadata": {},
"outputs": [],
"source": [
"n_runs = 5\n",
"n_warmup = 500\n",
"n_samples = 5000\n",
"\n",
"# One float32 slot per entry of `data_sizes` for every tracked statistic.\n",
"n_sizes = len(data_sizes)\n",
"\n",
"median_times_normalized = np.zeros(n_sizes, dtype=np.float32)\n",
"mad_times_normalized = np.zeros(n_sizes, dtype=np.float32)\n",
"mean_rate_error_normalized = np.zeros(n_sizes, dtype=np.float32)\n",
"std_rate_error_normalized = np.zeros(n_sizes, dtype=np.float32)\n",
"\n",
"median_times_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n",
"mad_times_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n",
"mean_rate_error_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n",
"std_rate_error_unnormalized = np.zeros(n_sizes, dtype=np.float32)\n",
"\n",
"# Each model paired with the arrays its results are written to, in run order:\n",
"# normalized first, then non-normalized.\n",
"benchmarks = [\n",
"    (\n",
"        model_poisson,\n",
"        (median_times_normalized, mad_times_normalized,\n",
"         mean_rate_error_normalized, std_rate_error_normalized),\n",
"    ),\n",
"    (\n",
"        model_poisson_unnormalized,\n",
"        (median_times_unnormalized, mad_times_unnormalized,\n",
"         mean_rate_error_unnormalized, std_rate_error_unnormalized),\n",
"    ),\n",
"]\n",
"\n",
"pbar = tqdm.tqdm(data_sizes)\n",
"for i, n in enumerate(pbar):\n",
"    pbar.set_description(f\"#samples = {n}\")\n",
"    # Both models are benchmarked on the same simulated data set.\n",
"    observations = np.random.poisson(lam=true_rate, size=n)\n",
"    for model_fn, (median_out, mad_out, mean_error_out, std_error_out) in benchmarks:\n",
"        timing, rate_error = run_with_timing(\n",
"            model_fn=model_fn,\n",
"            n_runs=n_runs,\n",
"            n_warmup=n_warmup,\n",
"            n_samples=n_samples,\n",
"            obs=observations,\n",
"            true_rate=true_rate,\n",
"        )\n",
"        median_out[i], mad_out[i] = timing\n",
"        mean_error_out[i], std_error_out[i] = rate_error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "624e7014-562f-4162-b9b5-748d49741cc6",
"metadata": {},
"outputs": [],
"source": [
"fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))\n",
"\n",
"\n",
"def _plot_with_band(ax, center, spread, color, label):\n",
"    \"\"\"Plot `center` against `data_sizes` with a shaded band of +/- `spread`.\"\"\"\n",
"    ax.plot(data_sizes, center, \"o-\", color=color, label=label)\n",
"    ax.fill_between(\n",
"        data_sizes, center - spread, center + spread,\n",
"        color=color, alpha=0.15)\n",
"\n",
"\n",
"# Top panel: median run time with a +/- MAD band, on log-log axes.\n",
"_plot_with_band(ax1, median_times_normalized, mad_times_normalized, \"blue\", \"Normalized\")\n",
"_plot_with_band(ax1, median_times_unnormalized, mad_times_unnormalized, \"red\", \"Non-Normalized\")\n",
"ax1.set_xscale(\"log\", base=10)\n",
"ax1.set_yscale(\"log\", base=10)\n",
"ax1.set_xlabel(\"#samples\")\n",
"ax1.set_ylabel(\"time (seconds)\")\n",
"ax1.set_title(\"Inference time\")\n",
"ax1.legend()\n",
"\n",
"# Bottom panel: mean absolute rate error with a +/- standard deviation band.\n",
"_plot_with_band(ax2, mean_rate_error_normalized, std_rate_error_normalized, \"blue\", \"Normalized\")\n",
"_plot_with_band(ax2, mean_rate_error_unnormalized, std_rate_error_unnormalized, \"red\", \"Non-Normalized\")\n",
"ax2.set_xscale(\"log\", base=10)\n",
"\n",
"ax2.set_xlabel(\"#samples\")\n",
"ax2.set_ylabel(\"Error\")\n",
"ax2.set_title(\"Inference error on \\\"Rate\\\"\")\n",
"ax2.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "238b8bda-b95f-402f-a79c-dee0c1ec4a68",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment