Last active
August 13, 2021 08:14
-
-
Save tcbegley/70cb8e34a681c8324533b88cc5d99b1d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "ab842ff0-d8ac-438a-8e69-5a62153eec75", | |
"metadata": {}, | |
"source": [ | |
"Install dependencies to run\n", | |
"\n", | |
"```sh\n", | |
"pip install -U pyro-ppl funsor\n", | |
"pip install git+https://github.com/tcbegley/numpyro.git@format-trace\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "7492c904-96b5-41d8-a791-85f368df3b5a", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"import jax.numpy as jnp\n", | |
"import numpyro\n", | |
"import numpyro.distributions as dist\n", | |
"import pyro\n", | |
"import pyro.distributions as pdist\n", | |
"import torch\n", | |
"from numpyro.contrib.funsor import config_enumerate, enum\n", | |
"from numpyro.util import format_shapes\n", | |
"from pyro import poutine" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "48ac48f4-337e-4155-86b2-fd5711fc2cee", | |
"metadata": {}, | |
"source": [ | |
"## Model 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "a720e280-0abc-4b5b-a489-b1d239017a15", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
"Sample Sites: \n", | |
" a dist | \n", | |
" value | \n", | |
" log_prob | \n", | |
" b dist | 2 \n", | |
" value | 2 \n", | |
" log_prob | \n", | |
" c_plate dist | \n", | |
" value 2 | \n", | |
" log_prob | \n", | |
" c dist 2 | \n", | |
" value 2 | \n", | |
" log_prob 2 | \n", | |
" d_plate dist | \n", | |
" value 3 | \n", | |
" log_prob | \n", | |
" d dist 3 | 4 5\n", | |
" value 3 | 4 5\n", | |
" log_prob 3 | \n", | |
" x_axis dist | \n", | |
" value 3 | \n", | |
" log_prob | \n", | |
" y_axis dist | \n", | |
" value 2 | \n", | |
" log_prob | \n", | |
" x dist 3 1 | \n", | |
" value 3 1 | \n", | |
" log_prob 3 1 | \n", | |
" y dist 2 1 1 | \n", | |
" value 2 1 1 | \n", | |
" log_prob 2 1 1 | \n", | |
" xy dist 2 3 1 | \n", | |
" value 2 3 1 | \n", | |
" log_prob 2 3 1 | \n", | |
" z dist 2 3 1 | 5 \n", | |
" value 2 3 1 | 5 \n", | |
" log_prob 2 3 1 | \n" | |
] | |
} | |
],
"source": [
"def model1():\n",
"    \"\"\"Pyro model with scalar, event-shaped and plated sample sites.\n",
"\n",
"    It exists only so that its trace can be printed with ``format_shapes``;\n",
"    the sampled values are not used and nothing is returned.\n",
"    \"\"\"\n",
"    a = pyro.sample(\"a\", pdist.Normal(0, 1))\n",
"    # to_event(1) declares the trailing dim of size 2 as an event dim.\n",
"    b = pyro.sample(\"b\", pdist.Normal(torch.zeros(2), 1).to_event(1))\n",
"    with pyro.plate(\"c_plate\", 2):\n",
"        c = pyro.sample(\"c\", pdist.Normal(torch.zeros(2), 1))\n",
"    with pyro.plate(\"d_plate\", 3):\n",
"        d = pyro.sample(\"d\", pdist.Normal(torch.zeros(3, 4, 5), 1).to_event(2))\n",
"\n",
"    # Plates are built once so they can be entered alone and together.\n",
"    x_axis = pyro.plate(\"x_axis\", 3, dim=-2)\n",
"    y_axis = pyro.plate(\"y_axis\", 2, dim=-3)\n",
"    with x_axis:\n",
"        x = pyro.sample(\"x\", pdist.Normal(0, 1))\n",
"    with y_axis:\n",
"        y = pyro.sample(\"y\", pdist.Normal(0, 1))\n",
"    with x_axis, y_axis:\n",
"        xy = pyro.sample(\"xy\", pdist.Normal(0, 1))\n",
"        z = pyro.sample(\"z\", pdist.Normal(0, 1).expand([5]).to_event(1))\n",
"\n",
"\n",
"trace = poutine.trace(model1).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "31360b9b-dad5-4d26-9644-f32dcec4cf19", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
"Sample Sites: \n", | |
" a dist | \n", | |
" value | \n", | |
" log_prob | \n", | |
" b dist | 2 \n", | |
" value | 2 \n", | |
" log_prob | \n", | |
"c_plate plate 2 | \n", | |
" c dist 2 | \n", | |
" value 2 | \n", | |
" log_prob 2 | \n", | |
"d_plate plate 3 | \n", | |
" d dist 3 | 4 5\n", | |
" value 3 | 4 5\n", | |
" log_prob 3 | \n", | |
" x_axis plate 3 | \n", | |
" y_axis plate 2 | \n", | |
" x dist 3 1 | \n", | |
" value 3 1 | \n", | |
" log_prob 3 1 | \n", | |
" y dist 2 1 1 | \n", | |
" value 2 1 1 | \n", | |
" log_prob 2 1 1 | \n", | |
" xy dist 2 3 1 | \n", | |
" value 2 3 1 | \n", | |
" log_prob 2 3 1 | \n", | |
" z dist 2 3 1 | 5 \n", | |
" value 2 3 1 | 5 \n", | |
" log_prob 2 3 1 | \n" | |
] | |
} | |
],
"source": [
"def model1():\n",
"    \"\"\"NumPyro model with scalar, event-shaped and plated sample sites.\n",
"\n",
"    It exists only so that its trace can be printed with ``format_shapes``;\n",
"    the sampled values are not used and nothing is returned.\n",
"    \"\"\"\n",
"    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n",
"    # to_event(1) declares the trailing dim of size 2 as an event dim.\n",
"    b = numpyro.sample(\"b\", dist.Normal(jnp.zeros(2), 1).to_event(1))\n",
"    with numpyro.plate(\"c_plate\", 2):\n",
"        c = numpyro.sample(\"c\", dist.Normal(jnp.zeros(2), 1))\n",
"    with numpyro.plate(\"d_plate\", 3):\n",
"        d = numpyro.sample(\"d\", dist.Normal(jnp.zeros((3, 4, 5)), 1).to_event(2))\n",
"\n",
"    # Plates are built once so they can be entered alone and together.\n",
"    x_axis = numpyro.plate(\"x_axis\", 3, dim=-2)\n",
"    y_axis = numpyro.plate(\"y_axis\", 2, dim=-3)\n",
"    with x_axis:\n",
"        x = numpyro.sample(\"x\", dist.Normal(0, 1))\n",
"    with y_axis:\n",
"        y = numpyro.sample(\"y\", dist.Normal(0, 1))\n",
"    with x_axis, y_axis:\n",
"        xy = numpyro.sample(\"xy\", dist.Normal(0, 1))\n",
"        z = numpyro.sample(\"z\", dist.Normal(0, 1).expand([5]).to_event(1))\n",
"\n",
"\n",
"# seed supplies the PRNG key; trace records every site the model visits.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model1()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9aea0bdd-f4f6-425b-85ca-68e1fefbbd01", | |
"metadata": {}, | |
"source": [ | |
"## Model 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "ffe79b1e-ca70-4fdb-8f58-dc68062542d4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" mean 100\n", | |
"Sample Sites: \n", | |
" data dist |\n", | |
" value 10 |\n", | |
" log_prob |\n", | |
" x dist 10 |\n", | |
" value 10 |\n", | |
" log_prob 10 |\n" | |
] | |
} | |
],
"source": [
"data = torch.arange(100.0)\n",
"\n",
"\n",
"def model2():\n",
"    \"\"\"Pyro model observing a random minibatch of ``data`` under a subsampled plate.\n",
"\n",
"    Reads the module-level ``data`` tensor and registers one ``mean`` parameter\n",
"    entry per datum. Nothing is returned.\n",
"    \"\"\"\n",
"    mean = pyro.param(\"mean\", torch.zeros(len(data)))\n",
"    with pyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]  # Select a minibatch of data.\n",
"        mean_batch = mean[ind]  # Take care to select the relevant per-datum parameters.\n",
"        # Do stuff with batch:\n",
"        x = pyro.sample(\"x\", pdist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"trace = poutine.trace(model2).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "0bc3a05c-c996-4341-aa0f-a53bac14207a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" mean 100\n", | |
"Sample Sites: \n", | |
" data plate 10 |\n", | |
" x dist 10 |\n", | |
" value 10 |\n", | |
" log_prob 10 |\n" | |
] | |
} | |
],
"source": [
"# Float observations, matching the dtype of the Pyro version of this model.\n",
"data = jnp.arange(100.0)\n",
"\n",
"\n",
"def model2():\n",
"    \"\"\"NumPyro model observing a random minibatch of ``data`` under a subsampled plate.\n",
"\n",
"    Reads the module-level ``data`` array and registers one ``mean`` parameter\n",
"    entry per datum. Nothing is returned.\n",
"    \"\"\"\n",
"    mean = numpyro.param(\"mean\", jnp.zeros(len(data)))\n",
"    with numpyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        batch = data[ind]  # Select a minibatch of data.\n",
"        mean_batch = mean[ind]  # Take care to select the relevant per-datum parameters.\n",
"        # Do stuff with batch:\n",
"        x = numpyro.sample(\"x\", dist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"# seed supplies the PRNG key; trace records every site the model visits.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model2()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c79bc290-6ae9-4250-a8de-78c2cfdaa207", | |
"metadata": {}, | |
"source": [ | |
"## Model 3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "f69a0e8d-d0ef-4222-87de-1133e5d34b05", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" p 6 \n", | |
" locs 2 \n", | |
"Sample Sites: \n", | |
" a dist | \n", | |
" value 6 1 1 | \n", | |
" log_prob 6 1 1 | \n", | |
" b dist 6 1 1 | \n", | |
" value 2 1 1 1 | \n", | |
" log_prob 2 6 1 1 | \n", | |
" c_plate dist | \n", | |
" value 4 | \n", | |
" log_prob | \n", | |
" c dist 4 | \n", | |
" value 2 1 1 1 1 | \n", | |
" log_prob 2 1 1 1 4 | \n", | |
" d_plate dist | \n", | |
" value 5 | \n", | |
" log_prob | \n", | |
" d dist 5 4 | \n", | |
" value 2 1 1 1 1 1 | \n", | |
" log_prob 2 1 1 1 5 4 | \n", | |
" e dist 2 1 1 1 5 4 | 7\n", | |
" value 2 1 1 1 5 4 | 7\n", | |
" log_prob 2 1 1 1 5 4 | \n" | |
] | |
} | |
],
"source": [
"@pyro.infer.config_enumerate\n",
"def model3():\n",
"    \"\"\"Pyro model with enumerated discrete sites inside nested plates.\n",
"\n",
"    ``a`` indexes ``p`` to parameterise ``b``, and ``d`` indexes ``locs`` to\n",
"    build the location of ``e``. Nothing is returned.\n",
"    \"\"\"\n",
"    p = pyro.param(\"p\", torch.arange(6.0) / 6)\n",
"    locs = pyro.param(\"locs\", torch.tensor([-1.0, 1.0]))\n",
"\n",
"    a = pyro.sample(\"a\", pdist.Categorical(torch.ones(6) / 6))\n",
"    b = pyro.sample(\"b\", pdist.Bernoulli(p[a]))  # Note this depends on a.\n",
"    with pyro.plate(\"c_plate\", 4):\n",
"        c = pyro.sample(\"c\", pdist.Bernoulli(0.3))\n",
"        with pyro.plate(\"d_plate\", 5):\n",
"            d = pyro.sample(\"d\", pdist.Bernoulli(0.4))\n",
"            # Cast d to an integer tensor so it can index locs, then add a\n",
"            # trailing dim that broadcasts against the 7 scales below.\n",
"            e_loc = locs[d.long()].unsqueeze(-1)\n",
"            e_scale = torch.arange(1.0, 8.0)\n",
"            e = pyro.sample(\"e\", pdist.Normal(e_loc, e_scale).to_event(1))\n",
"\n",
"\n",
"trace = poutine.trace(\n",
"    poutine.enum(model3, first_available_dim=-3)\n",
").get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "d2d158f8-b87b-451c-b11a-da389234fdf6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" p 6 \n", | |
" locs 2 \n", | |
"Sample Sites: \n", | |
" a dist | \n", | |
" value 6 1 1 | \n", | |
" log_prob 6 1 1 | \n", | |
" b dist 6 1 1 | \n", | |
" value 2 1 1 1 | \n", | |
" log_prob 2 6 1 1 | \n", | |
"c_plate plate 4 | \n", | |
" c dist 4 | \n", | |
" value 2 1 1 1 1 | \n", | |
" log_prob 2 1 1 1 4 | \n", | |
"d_plate plate 5 | \n", | |
" d dist 5 4 | \n", | |
" value 2 1 1 1 1 1 | \n", | |
" log_prob 2 1 1 1 5 4 | \n", | |
" e dist 2 1 1 1 5 4 | 7\n", | |
" value 2 1 1 1 5 4 | 7\n", | |
" log_prob 2 1 1 1 5 4 | \n" | |
] | |
} | |
],
"source": [
"@config_enumerate\n",
"def model3():\n",
"    \"\"\"NumPyro model with enumerated discrete sites inside nested plates.\n",
"\n",
"    ``a`` indexes ``p`` to parameterise ``b``, and ``d`` indexes ``locs`` to\n",
"    build the location of ``e``. Nothing is returned.\n",
"    \"\"\"\n",
"    p = numpyro.param(\"p\", jnp.arange(6) / 6)\n",
"    locs = numpyro.param(\"locs\", jnp.array([-1.0, 1.0]))\n",
"\n",
"    a = numpyro.sample(\"a\", dist.Categorical(jnp.ones(6) / 6))\n",
"    b = numpyro.sample(\"b\", dist.Bernoulli(p[a]))  # Note this depends on a.\n",
"    with numpyro.plate(\"c_plate\", 4):\n",
"        c = numpyro.sample(\"c\", dist.Bernoulli(0.3))\n",
"        with numpyro.plate(\"d_plate\", 5):\n",
"            d = numpyro.sample(\"d\", dist.Bernoulli(0.4))\n",
"            # Trailing dim broadcasts against the 7 scales below.\n",
"            e_loc = locs[d][..., None]\n",
"            # Float scales, matching the dtype used by the Pyro version.\n",
"            e_scale = jnp.arange(1.0, 8.0)\n",
"            e = numpyro.sample(\"e\", dist.Normal(e_loc, e_scale).to_event(1))\n",
"\n",
"\n",
"# seed supplies the PRNG key; trace records every site the model visits.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    enum(model3, first_available_dim=-3)()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cb9ed19c-16bd-44c6-ae3f-7087b368cf88", | |
"metadata": {}, | |
"source": [ | |
"## Model from NumPyro tests" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "61272d0d-c790-47c0-ac51-c62b87111b6b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" mean 100\n", | |
"Sample Sites: \n", | |
" data dist |\n", | |
" value 10 |\n", | |
" log_prob |\n", | |
" x dist 10 |\n", | |
" value 10 |\n", | |
" log_prob 10 |\n" | |
] | |
} | |
],
"source": [
"data = torch.arange(100.0)\n",
"\n",
"\n",
"def model_test():\n",
"    \"\"\"Pyro model observing a random minibatch of ``data`` under a subsampled plate.\n",
"\n",
"    Reads the module-level ``data`` tensor and registers one ``mean`` parameter\n",
"    entry per datum. Nothing is returned.\n",
"    \"\"\"\n",
"    mean = pyro.param(\"mean\", torch.zeros(len(data)))\n",
"    with pyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        # Index data and parameters with the same subsample indices.\n",
"        batch = data[ind]\n",
"        mean_batch = mean[ind]\n",
"        pyro.sample(\"x\", pdist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"trace = poutine.trace(model_test).get_trace()\n",
"trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "05a881e5-46b0-4a90-a9a3-c311c6763322", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Trace Shapes: \n", | |
" Param Sites: \n", | |
" mean 100\n", | |
"Sample Sites: \n", | |
" data plate 10 |\n", | |
" x dist 10 |\n", | |
" value 10 |\n", | |
" log_prob 10 |\n" | |
] | |
} | |
],
"source": [
"# Float observations, matching the dtype of the Pyro version of this model.\n",
"data = jnp.arange(100.0)\n",
"\n",
"\n",
"def model_test():\n",
"    \"\"\"NumPyro model observing a random minibatch of ``data`` under a subsampled plate.\n",
"\n",
"    Reads the module-level ``data`` array and registers one ``mean`` parameter\n",
"    entry per datum. Nothing is returned.\n",
"    \"\"\"\n",
"    mean = numpyro.param(\"mean\", jnp.zeros(len(data)))\n",
"    with numpyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
"        # Index data and parameters with the same subsample indices.\n",
"        batch = data[ind]\n",
"        mean_batch = mean[ind]\n",
"        numpyro.sample(\"x\", dist.Normal(mean_batch, 1), obs=batch)\n",
"\n",
"\n",
"# seed supplies the PRNG key; trace records every site the model visits.\n",
"with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:\n",
"    model_test()\n",
"\n",
"print(format_shapes(t, log_prob=True))"
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "numpyro", | |
"language": "python", | |
"name": "numpyro" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment