Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active June 24, 2020 18:00
Show Gist options
  • Save brandonwillard/7c92ea0bb242a1abc3b3c6bd0f7c6f66 to your computer and use it in GitHub Desktop.
Save brandonwillard/7c92ea0bb242a1abc3b3c6bd0f7c6f66 to your computer and use it in GitHub Desktop.
Symbolic-PyMC Beta-Binomial Conjugate Example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Symbolic-PyMC Beta-Binomial Conjugate Example\n",
"\n",
"Using the example model from\n",
"[here](https://github.com/zachwill/covid-19/blob/master/covid-19.ipynb),\n",
"we'll show how auto-conjugation can save one from needlessly sampling a\n",
"posterior known in closed form."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"import pymc3 as pm\n",
"\n",
"import theano\n",
"import theano.tensor as tt\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"from operator import add\n",
"\n",
"from unification import var\n",
"from etuples import etuple\n",
"\n",
"\n",
"from kanren import run\n",
"from kanren.core import eq, lall\n",
"\n",
"from symbolic_pymc.theano.meta import mt\n",
"from symbolic_pymc.theano.pymc3 import model_graph, graph_model\n",
"from symbolic_pymc.theano.utils import canonicalize\n",
"\n",
"sns.set_style(\"whitegrid\")\n",
"\n",
"theano.config.cxx = \"\"\n",
"theano.config.mode = \"FAST_COMPILE\"\n",
"tt.config.compute_test_value = \"ignore\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(\n",
" [\n",
" [\"Singapore\", 72, 0],\n",
" [\"Italy\", 46, 21],\n",
" [\"Japan\", 32, 5],\n",
" [\"Hong Kong\", 30, 2],\n",
" [\"Thailand\", 28, 0],\n",
" [\"South Korea\", 24, 16],\n",
" [\"Malayasia\", 20, 0],\n",
" [\"Vietnam\", 16, 0],\n",
" [\"Germany\", 16, 0],\n",
" [\"Australia\", 15, 0],\n",
" [\"France\", 11, 2],\n",
" [\"UK\", 8, 0],\n",
" [\"USA\", 7, 0],\n",
" [\"Macau\", 6, 0],\n",
" [\"Taiwan\", 5, 1],\n",
" [\"UAE\", 5, 0],\n",
" [\"Canada\", 3, 0],\n",
" [\"Spain\", 2, 0],\n",
" ],\n",
" columns=[\"country\", \"recovered\", \"deaths\"],\n",
")\n",
"\n",
"df[\"combined\"] = df.recovered + df.deaths"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"with pm.Model() as model:\n",
" p = pm.Beta(\"p\", alpha=2, beta=2)\n",
" y = pm.Binomial(\"y\", n=df.combined.values, p=p, observed=df.deaths.values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numeric Estimation (MCMC Sampling)\n",
"\n",
"The naive approach to estimating this model would involve some wasteful MCMC sampling, which we do below."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [p]\n",
"Sampling 4 chains, 0 divergences: 100%|██████████| 32000/32000 [00:48<00:00, 660.25draws/s]\n"
]
}
],
"source": [
"with model:\n",
" brute_trace = pm.sample(draws=6000, tune=2000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Symbolic Optimization\n",
"\n",
"Now, we'll walk through the process of creating a rewrite rule--from scratch--that converts Beta-Binomial models into their closed-form posteriors."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Convert the PyMC3 graph into a symbolic-pymc graph\n",
"fgraph = model_graph(model)\n",
"\n",
"# Convert the graph to a more consistent form \n",
"# (this helps \"normalize\" the patterns that we want to match)\n",
"fgraph = canonicalize(fgraph, in_place=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def betabin_conjugateo(x, y):\n",
" \"\"\"Replace an observed Beta-Binomial model with an unobserved posterior Beta-Binomial model.\"\"\"\n",
" obs_lv = var()\n",
"\n",
" beta_size, beta_rng, beta_name_lv = var(), var(), var()\n",
" alpha_lv, beta_lv = var(), var()\n",
" # Match a generic Beta random variable\n",
" beta_rv_lv = mt.BetaRV(alpha_lv, beta_lv, size=beta_size, rng=beta_rng, name=beta_name_lv)\n",
"\n",
" binom_size, binom_rng, binom_name_lv = var(), var(), var()\n",
" N_lv = var()\n",
" # Match a generic Binomial RV that uses the aforementioned Beta RV\n",
" binom_lv = mt.BinomialRV(N_lv, beta_rv_lv, size=binom_size, rng=binom_rng, name=binom_name_lv)\n",
"\n",
" # Construct the posterior parameters from the prior parameters and observations\n",
" obs_sum = etuple(mt.sum, obs_lv)\n",
" alpha_new = etuple(mt.add, alpha_lv, obs_sum)\n",
" beta_new = etuple(mt.add, beta_lv, etuple(mt.sub, etuple(mt.sum, N_lv), obs_sum))\n",
"\n",
" beta_post_rv_lv = etuple(\n",
" mt.BetaRV, alpha_new, beta_new, beta_size, beta_rng, name=etuple(add, beta_name_lv, \"_post\")\n",
" )\n",
" # Construct the posterior from the posterior parameters\n",
" binom_new_lv = etuple(\n",
" mt.BinomialRV,\n",
" N_lv,\n",
" beta_post_rv_lv,\n",
" binom_size,\n",
" binom_rng,\n",
" # Give it a descriptive name\n",
" name=etuple(add, binom_name_lv, \"_post\"),\n",
" )\n",
"\n",
" # TODO: We could also transform non-observed conjugates.\n",
" return lall(eq(x, mt.observed(obs_lv, binom_lv)), eq(y, binom_new_lv))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# TODO: We could walk the entire graph and find all occurrences of Beta-Binomial\n",
"# conjugates, but, since we're only looking for observed Beta-Binomials,\n",
"# they'll always be the function graph outputs.\n",
"q = var()\n",
"res = run(1, q, betabin_conjugateo(fgraph.outputs[0], q))\n",
"expr_graph = res[0].eval_obj\n",
"fgraph_conj = expr_graph.reify()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model_conjugated = graph_model(fgraph_conj)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# For our conjugate model, we can just sample the posterior term directly\n",
"conj_samples = model_conjugated.p_post.random(size=len(brute_trace[\"p\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparison"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bwillard/apps/anaconda3/envs/symbolic-pymc/lib/python3.7/site-packages/arviz/plots/backends/matplotlib/distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used\n",
" \"Argument backend_kwargs has not effect in matplotlib.plot_dist\"\n",
"/home/bwillard/apps/anaconda3/envs/symbolic-pymc/lib/python3.7/site-packages/arviz/plots/backends/matplotlib/distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used\n",
" \"Argument backend_kwargs has not effect in matplotlib.plot_dist\"\n"
]
},
{
"data": {
"text/plain": [
"array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7f0d137ed860>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7f0d13063b38>],\n",
" [<matplotlib.axes._subplots.AxesSubplot object at 0x7f0d1386b9b0>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7f0d1303f630>]],\n",
" dtype=object)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x288 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pm.traceplot({\"p_post_est\": brute_trace[\"p\"], \"p_post_conj\": conj_samples})\n",
"plt.savefig(\"beta-binomial-samples.png\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
"""Symbolic-PyMC Beta-Binomial Conjugate Example
Using the extremely naive model from
https://github.com/zachwill/covid-19/blob/master/covid-19.ipynb as an example,
we'll show how auto-conjugation could save one from needlessly sampling a
posterior known in closed form.
"""
from operator import add

import pandas as pd
import numpy as np
import pymc3 as pm
import theano
import theano.tensor as tt
import matplotlib.pyplot as plt
import seaborn as sns
from unification import var
from etuples import etuple
from kanren import run
from kanren.core import eq, lall

from symbolic_pymc.theano.meta import mt
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
from symbolic_pymc.theano.utils import canonicalize
# Theano settings for this example (see Theano's ``config`` documentation):
# an empty ``cxx`` means no C++ code is compiled, ``FAST_COMPILE`` selects the
# quick-to-compile mode, and test values are not computed while graphs are built.
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
tt.config.compute_test_value = "ignore"

# Plot styling for the comparison figure produced at the end of the script.
sns.set_style("whitegrid")
# Resolved cases per country: recoveries and deaths.
df = pd.DataFrame(
    [
        ["Singapore", 72, 0],
        ["Italy", 46, 21],
        ["Japan", 32, 5],
        ["Hong Kong", 30, 2],
        ["Thailand", 28, 0],
        ["South Korea", 24, 16],
        ["Malaysia", 20, 0],
        ["Vietnam", 16, 0],
        ["Germany", 16, 0],
        ["Australia", 15, 0],
        ["France", 11, 2],
        ["UK", 8, 0],
        ["USA", 7, 0],
        ["Macau", 6, 0],
        ["Taiwan", 5, 1],
        ["UAE", 5, 0],
        ["Canada", 3, 0],
        ["Spain", 2, 0],
    ],
    columns=["country", "recovered", "deaths"],
)

# Total resolved cases per country. This is the Binomial trial count ``n`` in the
# model below, with ``deaths`` as the observed number of successes.
df["combined"] = df.recovered + df.deaths

# Binomial observations must satisfy 0 <= deaths <= combined.
assert (df.deaths >= 0).all() and (df.deaths <= df.combined).all()
with pm.Model() as model:
    # Beta(2, 2) prior on the death probability ``p`` shared by all countries.
    # Its density, 6 p (1 - p), peaks at p = 1/2, which makes it a fairly
    # informative choice; it is kept as-is to match the model being reproduced.
    p = pm.Beta("p", alpha=2, beta=2)
    # Each country's deaths are Binomial in its resolved cases, with probability ``p``.
    y = pm.Binomial("y", n=df["combined"].values, p=p, observed=df["deaths"].values)

# Translate the PyMC3 model into a symbolic-pymc graph. Then apply the standard
# algebraic simplifications to a copy (not in place), so the rewrite patterns
# below are matched against a normalized form.
fgraph = canonicalize(model_graph(model), in_place=False)
def betabin_conjugateo(x, y):
    """Goal relating an observed Beta-Binomial term to its conjugate-posterior form.

    ``x`` is unified with the pattern

        observed(obs, BinomialRV(N, BetaRV(alpha, beta, ...), ...))

    and ``y`` with the unobserved expression

        BinomialRV(N, BetaRV(alpha + sum(obs), beta + sum(N) - sum(obs), ...), ...)

    That is, the same Binomial, but with its success probability drawn from the
    closed-form Beta posterior. Sizes and RNGs are carried over from the matched
    terms, and the new terms' names get a ``"_post"`` suffix.

    NOTE(review): the update sums over *all* observations, so it presumes a single
    success probability shared by every observation. A Beta term holding one
    probability per observation would need an element-wise update instead.
    The name suffixing uses ``operator.add``, so the matched terms are assumed to
    carry string names.

    Parameters
    ----------
    x
        The term to match, e.g. an output of a symbolic-pymc function graph.
    y
        The term (typically a logic variable) to unify with the rewritten,
        posterior-form expression tuple.

    Returns
    -------
    A kanren goal: the conjunction (``lall``) of the two unifications.
    """
    # Prior: p ~ Beta(prior_alpha, prior_beta), with every other field left free.
    prior_alpha, prior_beta = var(), var()
    prior_size, prior_rng, prior_name = var(), var(), var()
    prior_rv = mt.BetaRV(
        prior_alpha, prior_beta, size=prior_size, rng=prior_rng, name=prior_name
    )

    # Likelihood: Binomial(n_trials, p), where p is the Beta term above.
    n_trials = var()
    lik_size, lik_rng, lik_name = var(), var(), var()
    lik_rv = mt.BinomialRV(n_trials, prior_rv, size=lik_size, rng=lik_rng, name=lik_name)

    # The data attached to the likelihood via ``observed``.
    data = var()

    # Closed-form posterior parameters:
    #   alpha_post = alpha + sum(data)
    #   beta_post  = beta + (sum(N) - sum(data))
    n_successes = etuple(mt.sum, data)
    n_failures = etuple(mt.sub, etuple(mt.sum, n_trials), n_successes)
    post_alpha = etuple(mt.add, prior_alpha, n_successes)
    post_beta = etuple(mt.add, prior_beta, n_failures)

    # Posterior Beta term, reusing the prior's size/RNG and a "_post" name.
    post_rv = etuple(
        mt.BetaRV,
        post_alpha,
        post_beta,
        prior_size,
        prior_rng,
        name=etuple(add, prior_name, "_post"),
    )

    # Unobserved Binomial whose success probability is the posterior Beta term.
    post_lik = etuple(
        mt.BinomialRV,
        n_trials,
        post_rv,
        lik_size,
        lik_rng,
        name=etuple(add, lik_name, "_post"),
    )

    # TODO: We could also transform non-observed conjugates.
    matches_observed = eq(x, mt.observed(data, lik_rv))
    produces_posterior = eq(y, post_lik)
    return lall(matches_observed, produces_posterior)
# Seed for both the MCMC run and the direct posterior draws, so the comparison
# below is reproducible.
RANDOM_SEED = 42

# Search for the conjugate rewrite of the observed Beta-Binomial term.
# TODO: We could walk the entire graph and find all occurrences of Beta-Binomial
# conjugates, but, since we're only looking for observed Beta-Binomials,
# they'll always be the function graph outputs.
q = var()
res = run(1, q, betabin_conjugateo(fgraph.outputs[0], q))
if not res:
    # ``run`` returns no results when the goal cannot be satisfied. Fail with an
    # explicit message instead of an opaque IndexError on ``res[0]``.
    raise ValueError(
        "No observed Beta-Binomial term was found in the model graph's first output."
    )
# Evaluate the matched expression tuple and reify it into a concrete graph object.
expr_graph = res[0].eval_obj
fgraph_conj = expr_graph.reify()

# Convert the symbolic-pymc graph into a PyMC3 model
model_conjugated = graph_model(fgraph_conj)

# Baseline: estimate the posterior of ``p`` by MCMC on the original model.
with model:
    brute_trace = pm.sample(draws=6000, tune=2000, random_seed=RANDOM_SEED)

# For the conjugate model, the posterior term can be sampled directly; no MCMC
# is needed. Draw as many samples as the MCMC trace holds for a like-for-like plot.
# NOTE(review): ``.random`` is assumed to draw from NumPy's global RNG, hence this seed.
np.random.seed(RANDOM_SEED)
conj_samples = model_conjugated.p_post.random(size=len(brute_trace["p"]))

# Compare the MCMC estimate with the closed-form posterior draws; the two
# distributions are expected to agree closely.
pm.traceplot({"p_post_est": brute_trace["p"], "p_post_conj": conj_samples})
plt.savefig("beta-binomial-samples.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment