Last active
June 24, 2020 18:00
-
-
Save brandonwillard/7c92ea0bb242a1abc3b3c6bd0f7c6f66 to your computer and use it in GitHub Desktop.
Symbolic-PyMC Beta-Binomial Conjugate Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Symbolic-PyMC Beta-Binomial Conjugate Example | |
Using the extremely naive model from | |
https://github.com/zachwill/covid-19/blob/master/covid-19.ipynb as an example, | |
we'll show how auto-conjugation could save one from needlessly sampling a | |
posterior known in closed form. | |
""" | |
from operator import add

import pandas as pd
import pymc3 as pm
import theano
import theano.tensor as tt
import matplotlib.pyplot as plt
import seaborn as sns
from unification import var
from etuples import etuple
from kanren import run
from kanren.core import eq, lall
from symbolic_pymc.theano.meta import mt
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
from symbolic_pymc.theano.utils import canonicalize
# Use seaborn's "whitegrid" style for the figure saved at the end of the script.
sns.set_style("whitegrid")

# Theano settings for this example (see the Theano configuration docs):
# - `cxx`: the C++ compiler command; an empty string means no compiler is set.
# - `mode`: "FAST_COMPILE" selects Theano's quick-to-compile mode.
# - `compute_test_value`: "ignore" is one of the documented test-value settings;
#   it is set here through the `config` attribute reachable from `theano.tensor`.
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
tt.config.compute_test_value = "ignore"
# Per-country outcome counts used as the example data set: one row per country
# with the number of recovered cases and the number of deaths.
df = pd.DataFrame(
    [
        ["Singapore", 72, 0],
        ["Italy", 46, 21],
        ["Japan", 32, 5],
        ["Hong Kong", 30, 2],
        ["Thailand", 28, 0],
        ["South Korea", 24, 16],
        ["Malaysia", 20, 0],
        ["Vietnam", 16, 0],
        ["Germany", 16, 0],
        ["Australia", 15, 0],
        ["France", 11, 2],
        ["UK", 8, 0],
        ["USA", 7, 0],
        ["Macau", 6, 0],
        ["Taiwan", 5, 1],
        ["UAE", 5, 0],
        ["Canada", 3, 0],
        ["Spain", 2, 0],
    ],
    columns=["country", "recovered", "deaths"],
)
# Total resolved cases per country (recovered + deaths); used as the number of
# Binomial trials in the model below.
df["combined"] = df["recovered"] + df["deaths"]

with pm.Model() as model:
    # No idea why one would use such a specific, "informative" prior.  Why would
    # we want to put that much density on 1/2?  Why?!
    # For that matter, I'm not sure why one would use this model with this data
    # at all.
    p = pm.Beta("p", alpha=2, beta=2)
    # Observed deaths out of `combined` cases, with every row of `df` sharing the
    # single probability `p`.
    y = pm.Binomial("y", n=df.combined.values, p=p, observed=df.deaths.values)

# Convert the PyMC3 graph into a symbolic-pymc graph
fgraph = model_graph(model)

# Perform a set of standard algebraic simplifications
fgraph = canonicalize(fgraph, in_place=False)
def betabin_conjugateo(x, y):
    """Replace an observed Beta-Binomial model with an unobserved posterior Beta-Binomial model.

    This builds a goal that relates two terms:

    - `x` is unified with the pattern ``observed(obs, Binomial(N, Beta(alpha, beta)))``,
      where every component (including sizes, RNGs and names) is a fresh logic
      variable.
    - `y` is unified with an expression tuple for
      ``Binomial(N, Beta(alpha + sum(obs), beta + sum(N) - sum(obs)))``, reusing
      the matched sizes and RNGs and appending ``"_post"`` to both variable names.

    Parameters
    ----------
    x
        The term to match against the observed Beta-Binomial pattern.
    y
        The term unified with the posterior form; the caller in this file passes
        a fresh logic variable.

    Returns
    -------
    The conjunction (`lall`) of the two `eq` goals described above.
    """
    obs_lv = var()

    # Pattern for the Beta prior term.
    beta_size, beta_rng, beta_name_lv = var(), var(), var()
    alpha_lv, beta_lv = var(), var()
    beta_rv_lv = mt.BetaRV(alpha_lv, beta_lv, size=beta_size, rng=beta_rng, name=beta_name_lv)

    # Pattern for the Binomial term whose probability argument is that Beta term.
    binom_size, binom_rng, binom_name_lv = var(), var(), var()
    N_lv = var()
    binom_lv = mt.BinomialRV(N_lv, beta_rv_lv, size=binom_size, rng=binom_rng, name=binom_name_lv)

    # Posterior parameters, kept as unevaluated expression tuples:
    #   alpha' = alpha + sum(obs)
    #   beta'  = beta + (sum(N) - sum(obs))
    obs_sum = etuple(mt.sum, obs_lv)
    alpha_new = etuple(mt.add, alpha_lv, obs_sum)
    beta_new = etuple(mt.add, beta_lv, etuple(mt.sub, etuple(mt.sum, N_lv), obs_sum))

    # The new names are themselves expression tuples (`operator.add` applied to
    # the matched name and "_post"), since the names are still logic variables here.
    beta_post_rv_lv = etuple(
        mt.BetaRV, alpha_new, beta_new, beta_size, beta_rng, name=etuple(add, beta_name_lv, "_post")
    )
    binom_new_lv = etuple(
        mt.BinomialRV,
        N_lv,
        beta_post_rv_lv,
        binom_size,
        binom_rng,
        name=etuple(add, binom_name_lv, "_post"),
    )

    # TODO: We could also transform non-observed conjugates.
    return lall(eq(x, mt.observed(obs_lv, binom_lv)), eq(y, binom_new_lv))
# TODO: We could walk the entire graph and find all occurrences of Beta-Binomial
# conjugates, but, since we're only looking for observed Beta-Binomials,
# they'll always be the function graph outputs.
q = var()
res = run(1, q, betabin_conjugateo(fgraph.outputs[0], q))

# Fail with an explicit message, instead of an `IndexError` on `res[0]`, when the
# goal produced no result for this graph.
if not res:
    raise RuntimeError("No observed Beta-Binomial conjugate was found in the model graph.")

# Evaluate the resulting expression tuple, then reify the evaluated object.
expr_graph = res[0].eval_obj
fgraph_conj = expr_graph.reify()

# Convert the symbolic-pymc graph into a PyMC3 model
model_conjugated = graph_model(fgraph_conj)

# Do some wasteful MCMC sampling using the original model
with model:
    brute_trace = pm.sample(draws=6000, tune=2000)
    # Zzzz

# For our conjugate model, we can just sample the posterior term directly;
# draw as many samples as the MCMC trace holds for "p" so the two are comparable.
conj_samples = model_conjugated.p_post.random(size=len(brute_trace["p"]))
# Wow, that was quick!

# Now, let's compare the two...
pm.traceplot({"p_post_est": brute_trace["p"], "p_post_conj": conj_samples})
plt.savefig("beta-binomial-samples.png")
# What do ya know; they're essentially the same!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment