Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active June 24, 2020 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/7c92ea0bb242a1abc3b3c6bd0f7c6f66 to your computer and use it in GitHub Desktop.
Save brandonwillard/7c92ea0bb242a1abc3b3c6bd0f7c6f66 to your computer and use it in GitHub Desktop.
Symbolic-PyMC Beta-Binomial Conjugate Example
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
"""Symbolic-PyMC Beta-Binomial Conjugate Example
Using the extremely naive model from
https://github.com/zachwill/covid-19/blob/master/covid-19.ipynb as an example,
we'll show how auto-conjugation could save one from needlessly sampling a
posterior known in closed form.
"""
import pandas as pd
import pymc3 as pm
import theano
import theano.tensor as tt
import matplotlib.pyplot as plt
import seaborn as sns
from operator import add
from unification import var
from etuples import etuple
from kanren import run
from kanren.core import eq, lall
from symbolic_pymc.theano.meta import mt
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
from symbolic_pymc.theano.utils import canonicalize
sns.set_style("whitegrid")

# Theano settings for this example. All three are set through `theano.config`
# (the original mixed in `tt.config`, an indirect re-export of the same object).
# An empty `cxx` tells Theano not to use a C++ compiler (Python implementations
# only); FAST_COMPILE keeps graph optimization light.
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
# Do not compute test values while building graphs.
theano.config.compute_test_value = "ignore"
# Per-country case outcomes: number of recovered patients and number of deaths.
df = pd.DataFrame(
    [
        ["Singapore", 72, 0],
        ["Italy", 46, 21],
        ["Japan", 32, 5],
        ["Hong Kong", 30, 2],
        ["Thailand", 28, 0],
        ["South Korea", 24, 16],
        ["Malaysia", 20, 0],
        ["Vietnam", 16, 0],
        ["Germany", 16, 0],
        ["Australia", 15, 0],
        ["France", 11, 2],
        ["UK", 8, 0],
        ["USA", 7, 0],
        ["Macau", 6, 0],
        ["Taiwan", 5, 1],
        ["UAE", 5, 0],
        ["Canada", 3, 0],
        ["Spain", 2, 0],
    ],
    columns=["country", "recovered", "deaths"],
)
# Resolved cases (recovered + deaths); used as the Binomial trial counts below.
df["combined"] = df["recovered"] + df["deaths"]
with pm.Model() as model:
    # Beta(2, 2) prior on the shared death probability `p`: symmetric, with its
    # mode at 1/2. The original author questioned both this fairly specific
    # prior and whether this model suits the data at all.
    p = pm.Beta("p", alpha=2, beta=2)
    # One Binomial observation per country: `deaths` out of `combined`
    # (recovered + deaths) cases, all governed by the same `p`.
    y = pm.Binomial(
        "y",
        n=df["combined"].values,
        p=p,
        observed=df["deaths"].values,
    )
# Translate the PyMC3 model into a symbolic-pymc function graph, then apply the
# standard algebraic simplifications to a copy (in_place=False).
fgraph = canonicalize(model_graph(model), in_place=False)
def betabin_conjugateo(x, y):
    """Relate an observed Beta-Binomial term to its conjugate-posterior form.

    This is a kanren goal. It succeeds when `x` unifies with
    ``observed(obs, Binomial(N, Beta(alpha, beta)))`` and `y` unifies with
    ``Binomial(N, Beta(alpha + sum(obs), beta + (sum(N) - sum(obs))))``.
    The observed Beta-Binomial model is thereby replaced by an unobserved
    model whose Beta term is the closed-form posterior. The new random
    variables reuse the original size/rng inputs and get ``"_post"``
    appended to their names.

    Parameters
    ----------
    x
        Term to match against the observed Beta-Binomial pattern.
    y
        Term to unify with the rewritten, posterior-based Binomial.

    Returns
    -------
    A goal built with ``lall`` over the two unifications.
    """
    observed_lv = var()

    # Prior: Beta(alpha, beta) together with its size/rng/name inputs.
    prior_size_lv, prior_rng_lv, prior_name_lv = var(), var(), var()
    alpha_lv, beta_lv = var(), var()
    prior_lv = mt.BetaRV(
        alpha_lv, beta_lv, size=prior_size_lv, rng=prior_rng_lv, name=prior_name_lv
    )

    # Likelihood: Binomial(N, p) whose probability `p` is the prior above.
    lik_size_lv, lik_rng_lv, lik_name_lv = var(), var(), var()
    n_trials_lv = var()
    likelihood_lv = mt.BinomialRV(
        n_trials_lv, prior_lv, size=lik_size_lv, rng=lik_rng_lv, name=lik_name_lv
    )

    # Conjugate update: successes are sum(obs); failures are sum(N) - sum(obs).
    successes = etuple(mt.sum, observed_lv)
    failures = etuple(mt.sub, etuple(mt.sum, n_trials_lv), successes)
    posterior_lv = etuple(
        mt.BetaRV,
        etuple(mt.add, alpha_lv, successes),
        etuple(mt.add, beta_lv, failures),
        prior_size_lv,
        prior_rng_lv,
        name=etuple(add, prior_name_lv, "_post"),
    )
    rewritten_lv = etuple(
        mt.BinomialRV,
        n_trials_lv,
        posterior_lv,
        lik_size_lv,
        lik_rng_lv,
        name=etuple(add, lik_name_lv, "_post"),
    )

    # TODO: We could also transform non-observed conjugates.
    matches_observed = eq(x, mt.observed(observed_lv, likelihood_lv))
    produces_posterior = eq(y, rewritten_lv)
    return lall(matches_observed, produces_posterior)
# TODO: We could walk the entire graph and find all occurrences of Beta-Binomial
# conjugates, but, since we're only looking for observed Beta-Binomials,
# they'll always be the function graph outputs.
q = var()
res = run(1, q, betabin_conjugateo(fgraph.outputs[0], q))
if not res:
    # No solution: the first output is not an observed Beta-Binomial term.
    # Fail with an explanation instead of an opaque IndexError on `res[0]`.
    raise ValueError(
        "No observed Beta-Binomial conjugate found in fgraph.outputs[0]"
    )
# Evaluate the resulting expression tuple into a graph object, then reify it.
expr_graph = res[0].eval_obj
fgraph_conj = expr_graph.reify()
# Convert the symbolic-pymc graph into a PyMC3 model
model_conjugated = graph_model(fgraph_conj)
# Slow path: draw posterior samples of `p` from the original model via MCMC,
# even though the posterior is known in closed form.
with model:
    brute_trace = pm.sample(draws=6000, tune=2000)

# Fast path: the conjugated model's posterior term can be sampled directly.
# Draw as many samples as MCMC produced so the two sets are comparable.
brute_p = brute_trace["p"]
conj_samples = model_conjugated.p_post.random(size=len(brute_p))

# Plot the MCMC estimate next to the exact conjugate samples; the two
# distributions should essentially coincide.
pm.traceplot({"p_post_est": brute_p, "p_post_conj": conj_samples})
plt.savefig("beta-binomial-samples.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment