Skip to content

Instantly share code, notes, and snippets.

@tcbegley
Last active August 13, 2021 08:14
Show Gist options
  • Save tcbegley/70cb8e34a681c8324533b88cc5d99b1d to your computer and use it in GitHub Desktop.
Save tcbegley/70cb8e34a681c8324533b88cc5d99b1d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "ab842ff0-d8ac-438a-8e69-5a62153eec75",
"metadata": {},
"source": [
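"Install the dependencies needed to run this notebook (the second command installs the `format-trace` branch of NumPyro, which provides the `numpyro.util.format_shapes` function imported below):\n",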
"\n",
"```sh\n",
"pip install -U pyro-ppl funsor\n",
"pip install git+https://github.com/tcbegley/numpyro.git@format-trace\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7492c904-96b5-41d8-a791-85f368df3b5a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"import pyro\n",
"import pyro.distributions as pdist\n",
"import torch\n",
"from numpyro.contrib.funsor import config_enumerate, enum\n",
"from numpyro.util import format_shapes\n",
"from pyro import poutine"
]
},
{
"cell_type": "markdown",
"id": "48ac48f4-337e-4155-86b2-fd5711fc2cee",
"metadata": {},
"source": [
"## Model 1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a720e280-0abc-4b5b-a489-b1d239017a15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
"Sample Sites: \n",
" a dist | \n",
" value | \n",
" log_prob | \n",
" b dist | 2 \n",
" value | 2 \n",
" log_prob | \n",
" c_plate dist | \n",
" value 2 | \n",
" log_prob | \n",
" c dist 2 | \n",
" value 2 | \n",
" log_prob 2 | \n",
" d_plate dist | \n",
" value 3 | \n",
" log_prob | \n",
" d dist 3 | 4 5\n",
" value 3 | 4 5\n",
" log_prob 3 | \n",
" x_axis dist | \n",
" value 3 | \n",
" log_prob | \n",
" y_axis dist | \n",
" value 2 | \n",
" log_prob | \n",
" x dist 3 1 | \n",
" value 3 1 | \n",
" log_prob 3 1 | \n",
" y dist 2 1 1 | \n",
" value 2 1 1 | \n",
" log_prob 2 1 1 | \n",
" xy dist 2 3 1 | \n",
" value 2 3 1 | \n",
" log_prob 2 3 1 | \n",
" z dist 2 3 1 | 5 \n",
" value 2 3 1 | 5 \n",
" log_prob 2 3 1 | \n"
]
}
],
"source": [
"def model1():\n",
"    \"\"\"Sample sites covering the main shape cases: a scalar, an event-shaped\n",
"    vector, two independent plates, and two plates on explicit dims (-2 and -3).\"\"\"\n",
"    a = pyro.sample(\"a\", pdist.Normal(0, 1))\n",
"    b = pyro.sample(\"b\", pdist.Normal(torch.zeros(2), 1).to_event(1))\n",
"    with pyro.plate(\"c_plate\", 2):\n",
"        c = pyro.sample(\"c\", pdist.Normal(torch.zeros(2), 1))\n",
"    # d_plate is a sibling of c_plate (not nested): d's batch shape is just (3,).\n",
"    with pyro.plate(\"d_plate\", 3):\n",
"        d = pyro.sample(\"d\", pdist.Normal(torch.zeros(3, 4, 5), 1).to_event(2))\n",
"\n",
"    x_axis = pyro.plate(\"x_axis\", 3, dim=-2)\n",
"    y_axis = pyro.plate(\"y_axis\", 2, dim=-3)\n",
"    with x_axis:\n",
"        x = pyro.sample(\"x\", pdist.Normal(0, 1))\n",
"    with y_axis:\n",
"        y = pyro.sample(\"y\", pdist.Normal(0, 1))\n",
"    with x_axis, y_axis:\n",
"        xy = pyro.sample(\"xy\", pdist.Normal(0, 1))\n",
"        z = pyro.sample(\"z\", pdist.Normal(0, 1).expand([5]).to_event(1))\n",
"\n",
"\n",
"trace = pyro.poutine.trace(model1).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "31360b9b-dad5-4d26-9644-f32dcec4cf19",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
"Sample Sites: \n",
" a dist | \n",
" value | \n",
" log_prob | \n",
" b dist | 2 \n",
" value | 2 \n",
" log_prob | \n",
"c_plate plate 2 | \n",
" c dist 2 | \n",
" value 2 | \n",
" log_prob 2 | \n",
"d_plate plate 3 | \n",
" d dist 3 | 4 5\n",
" value 3 | 4 5\n",
" log_prob 3 | \n",
" x_axis plate 3 | \n",
" y_axis plate 2 | \n",
" x dist 3 1 | \n",
" value 3 1 | \n",
" log_prob 3 1 | \n",
" y dist 2 1 1 | \n",
" value 2 1 1 | \n",
" log_prob 2 1 1 | \n",
" xy dist 2 3 1 | \n",
" value 2 3 1 | \n",
" log_prob 2 3 1 | \n",
" z dist 2 3 1 | 5 \n",
" value 2 3 1 | 5 \n",
" log_prob 2 3 1 | \n"
]
}
],
"source": [
"def model1():\n",
"    \"\"\"NumPyro version of ``model1``: a scalar, an event-shaped vector, two\n",
"    independent plates, and two plates on explicit dims (-2 and -3).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n",
"    b = numpyro.sample(\"b\", dist.Normal(jnp.zeros(2), 1).to_event(1))\n",
"    with numpyro.plate(\"c_plate\", 2):\n",
"        c = numpyro.sample(\"c\", dist.Normal(jnp.zeros(2), 1))\n",
"    # d_plate is a sibling of c_plate (not nested): d's batch shape is just (3,).\n",
"    with numpyro.plate(\"d_plate\", 3):\n",
"        d = numpyro.sample(\"d\", dist.Normal(jnp.zeros((3, 4, 5)), 1).to_event(2))\n",
"\n",
"    x_axis = numpyro.plate(\"x_axis\", 3, dim=-2)\n",
"    y_axis = numpyro.plate(\"y_axis\", 2, dim=-3)\n",
"    with x_axis:\n",
"        x = numpyro.sample(\"x\", dist.Normal(0, 1))\n",
"    with y_axis:\n",
"        y = numpyro.sample(\"y\", dist.Normal(0, 1))\n",
"    with x_axis, y_axis:\n",
"        xy = numpyro.sample(\"xy\", dist.Normal(0, 1))\n",
"        z = numpyro.sample(\"z\", dist.Normal(0, 1).expand([5]).to_event(1))\n",
"\n",
"\n",
"# `seed` provides the random key for the sample sites; `trace` records them in `t`.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model1()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
]
},
{
"cell_type": "markdown",
"id": "9aea0bdd-f4f6-425b-85ca-68e1fefbbd01",
"metadata": {},
"source": [
"## Model 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ffe79b1e-ca70-4fdb-8f58-dc68062542d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" mean 100\n",
"Sample Sites: \n",
" data dist |\n",
" value 10 |\n",
" log_prob |\n",
" x dist 10 |\n",
" value 10 |\n",
" log_prob 10 |\n"
]
}
],
"source": [
"data = torch.arange(100.0)\n",
"\n",
"\n",
"def model2():\n",
"    \"\"\"Observe a size-10 subsample of `data` under a per-datum `mean` parameter.\"\"\"\n",
"    mean = pyro.param(\"mean\", torch.zeros(len(data)))\n",
"    with pyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]  # Select a minibatch of data.\n",
"        mean_batch = mean[ind]  # Take care to select the relevant per-datum parameters.\n",
"        # Do stuff with batch:\n",
"        x = pyro.sample(\"x\", pdist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"trace = pyro.poutine.trace(model2).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0bc3a05c-c996-4341-aa0f-a53bac14207a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" mean 100\n",
"Sample Sites: \n",
" data plate 10 |\n",
" x dist 10 |\n",
" value 10 |\n",
" log_prob 10 |\n"
]
}
],
"source": [
"data = jnp.arange(100.0)  # float data, since it is observed under a continuous Normal\n",
"\n",
"\n",
"def model2():\n",
"    \"\"\"Observe a size-10 subsample of `data` under a per-datum `mean` parameter.\"\"\"\n",
"    mean = numpyro.param(\"mean\", jnp.zeros(len(data)))\n",
"    with numpyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]  # Select a minibatch of data.\n",
"        mean_batch = mean[ind]  # Take care to select the relevant per-datum parameters.\n",
"        # Do stuff with batch:\n",
"        x = numpyro.sample(\"x\", dist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model2()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
]
},
{
"cell_type": "markdown",
"id": "c79bc290-6ae9-4250-a8de-78c2cfdaa207",
"metadata": {},
"source": [
"## Model 3"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f69a0e8d-d0ef-4222-87de-1133e5d34b05",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" p 6 \n",
" locs 2 \n",
"Sample Sites: \n",
" a dist | \n",
" value 6 1 1 | \n",
" log_prob 6 1 1 | \n",
" b dist 6 1 1 | \n",
" value 2 1 1 1 | \n",
" log_prob 2 6 1 1 | \n",
" c_plate dist | \n",
" value 4 | \n",
" log_prob | \n",
" c dist 4 | \n",
" value 2 1 1 1 1 | \n",
" log_prob 2 1 1 1 4 | \n",
" d_plate dist | \n",
" value 5 | \n",
" log_prob | \n",
" d dist 5 4 | \n",
" value 2 1 1 1 1 1 | \n",
" log_prob 2 1 1 1 5 4 | \n",
" e dist 2 1 1 1 5 4 | 7\n",
" value 2 1 1 1 5 4 | 7\n",
" log_prob 2 1 1 1 5 4 | \n"
]
}
],
"source": [
"@pyro.infer.config_enumerate\n",
"def model3():\n",
"    \"\"\"Discrete sites `a`, `b`, `c`, `d` marked for enumeration, plus a continuous\n",
"    site `e` whose location depends on the enumerated value of `d`.\"\"\"\n",
"    p = pyro.param(\"p\", torch.arange(6.0) / 6)\n",
"    locs = pyro.param(\"locs\", torch.tensor([-1.0, 1.0]))\n",
"\n",
"    a = pyro.sample(\"a\", pdist.Categorical(torch.ones(6) / 6))\n",
"    b = pyro.sample(\"b\", pdist.Bernoulli(p[a]))  # Note this depends on a.\n",
"    with pyro.plate(\"c_plate\", 4):\n",
"        c = pyro.sample(\"c\", pdist.Bernoulli(0.3))\n",
"        with pyro.plate(\"d_plate\", 5):\n",
"            d = pyro.sample(\"d\", pdist.Bernoulli(0.4))\n",
"            # Add a trailing axis so the location broadcasts against the 7 scales in e_scale.\n",
"            e_loc = locs[d.long()].unsqueeze(-1)\n",
"            e_scale = torch.arange(1.0, 8.0)\n",
"            e = pyro.sample(\"e\", pdist.Normal(e_loc, e_scale).to_event(1))\n",
"\n",
"\n",
"# The two plates occupy dims -1 and -2, so enumeration dims start at -3.\n",
"trace = pyro.poutine.trace(\n",
"    pyro.poutine.enum(model3, first_available_dim=-3)\n",
").get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d2d158f8-b87b-451c-b11a-da389234fdf6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" p 6 \n",
" locs 2 \n",
"Sample Sites: \n",
" a dist | \n",
" value 6 1 1 | \n",
" log_prob 6 1 1 | \n",
" b dist 6 1 1 | \n",
" value 2 1 1 1 | \n",
" log_prob 2 6 1 1 | \n",
"c_plate plate 4 | \n",
" c dist 4 | \n",
" value 2 1 1 1 1 | \n",
" log_prob 2 1 1 1 4 | \n",
"d_plate plate 5 | \n",
" d dist 5 4 | \n",
" value 2 1 1 1 1 1 | \n",
" log_prob 2 1 1 1 5 4 | \n",
" e dist 2 1 1 1 5 4 | 7\n",
" value 2 1 1 1 5 4 | 7\n",
" log_prob 2 1 1 1 5 4 | \n"
]
}
],
"source": [
"@config_enumerate\n",
"def model3():\n",
"    \"\"\"NumPyro version of ``model3``: enumerated discrete sites `a`, `b`, `c`, `d`,\n",
"    plus a continuous site `e` whose location depends on the enumerated `d`.\"\"\"\n",
"    p = numpyro.param(\"p\", jnp.arange(6) / 6)\n",
"    locs = numpyro.param(\"locs\", jnp.array([-1.0, 1.0]))\n",
"\n",
"    a = numpyro.sample(\"a\", dist.Categorical(jnp.ones(6) / 6))\n",
"    b = numpyro.sample(\"b\", dist.Bernoulli(p[a]))  # Note this depends on a.\n",
"    with numpyro.plate(\"c_plate\", 4):\n",
"        c = numpyro.sample(\"c\", dist.Bernoulli(0.3))\n",
"        with numpyro.plate(\"d_plate\", 5):\n",
"            d = numpyro.sample(\"d\", dist.Bernoulli(0.4))\n",
"            # Add a trailing axis so the location broadcasts against the 7 scales in e_scale.\n",
"            e_loc = locs[d][..., None]\n",
"            e_scale = jnp.arange(1.0, 8.0)  # float scales, as in the Pyro version\n",
"            e = numpyro.sample(\"e\", dist.Normal(e_loc, e_scale).to_event(1))\n",
"\n",
"\n",
"# The two plates occupy dims -1 and -2, so enumeration dims start at -3.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    enum(model3, first_available_dim=-3)()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
]
},
{
"cell_type": "markdown",
"id": "cb9ed19c-16bd-44c6-ae3f-7087b368cf88",
"metadata": {},
"source": [
"## Model from NumPyro tests"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "61272d0d-c790-47c0-ac51-c62b87111b6b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" mean 100\n",
"Sample Sites: \n",
" data dist |\n",
" value 10 |\n",
" log_prob |\n",
" x dist 10 |\n",
" value 10 |\n",
" log_prob 10 |\n"
]
}
],
"source": [
"data = torch.arange(100.0)\n",
"\n",
"\n",
"def model_test():\n",
"    \"\"\"Subsampled-plate model: observe `data[ind]` under `Normal(mean[ind], 1)`.\"\"\"\n",
"    mean = pyro.param(\"mean\", torch.zeros(len(data)))\n",
"    with pyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]\n",
"        mean_batch = mean[ind]\n",
"        pyro.sample(\"x\", pdist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"trace = pyro.poutine.trace(model_test).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "05a881e5-46b0-4a90-a9a3-c311c6763322",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trace Shapes: \n",
" Param Sites: \n",
" mean 100\n",
"Sample Sites: \n",
" data plate 10 |\n",
" x dist 10 |\n",
" value 10 |\n",
" log_prob 10 |\n"
]
}
],
"source": [
"data = jnp.arange(100.0)  # float data, since it is observed under a continuous Normal\n",
"\n",
"\n",
"def model_test():\n",
"    \"\"\"Subsampled-plate model: observe `data[ind]` under `Normal(mean[ind], 1)`.\"\"\"\n",
"    mean = numpyro.param(\"mean\", jnp.zeros(len(data)))\n",
"    with numpyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]\n",
"        mean_batch = mean[ind]\n",
"        numpyro.sample(\"x\", dist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model_test()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "numpyro",
"language": "python",
"name": "numpyro"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment