Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active February 20, 2019 16:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/a6f03c5afae044c64f83506bf901a795 to your computer and use it in GitHub Desktop.
Save brandonwillard/a6f03c5afae044c64f83506bf901a795 to your computer and use it in GitHub Desktop.
Automatic Bayesian Hierarchical Model Re-centering/scaling in PyMC3
"""
A demonstration of automatic parameter re-centering/scaling for the example
in https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/
using https://github.com/pymc-devs/symbolic-pymc.
"""
import theano
import theano.tensor as tt
import numpy as np
import pymc3 as pm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from theano.gof.opt import EquilibriumOptimizer
from theano.printing import debugprint as tt_dprint
from unification.utils import transitive_get as walk
from kanren.goals import goalify
from symbolic_pymc.opt import KanrenRelationSub
from symbolic_pymc.pymc3 import model_graph, graph_model
from symbolic_pymc.utils import (optimize_graph, canonicalize,
get_rv_observation)
from symbolic_pymc.printing import tt_pprint
sns.set_style('whitegrid')
# Blank out Theano's C++ compiler setting so that building and rewriting the
# symbolic graph skips compilation; the previous value is kept in
# `_cxx_config` so it can be put back before sampling.
_cxx_config = theano.config.cxx
theano.config.cxx = ''
# NOTE(review): 'ignore' presumably tolerates graph nodes without test
# values -- confirm against Theano's `compute_test_value` config docs.
theano.config.compute_test_value = 'ignore'
#
# Load the radon data set and build the original (centered) hierarchical model.
#
data = pd.read_csv('radon.csv')
data['log_radon'] = data['log_radon'].astype(theano.config.floatX)
county_names = data.county.unique()
# NOTE(review): presumably integer county codes usable as indices into the
# length-`n_counties` parameter vectors below -- confirm against the CSV.
county_idx = data.county_code.values
n_counties = len(county_names)

with pm.Model() as model_centered:
    # Group-level priors for the per-county intercepts `a` and slopes `b`.
    mu_a = pm.Normal('mu_a', mu=0., sd=100**2)
    sigma_a = pm.HalfCauchy('sigma_a', beta=5)
    mu_b = pm.Normal('mu_b', mu=0., sd=100**2)
    sigma_b = pm.HalfCauchy('sigma_b', beta=5)

    # Centered parameterization: county parameters are drawn directly around
    # their group means with the group scales.
    a = pm.Normal('a', mu=mu_a, sd=sigma_a, shape=n_counties)
    b = pm.Normal('b', mu=mu_b, sd=sigma_b, shape=n_counties)

    eps = pm.HalfCauchy('eps', beta=5)

    # Per-observation mean: county intercept plus county slope times `floor`.
    radon_est = a[county_idx] + b[county_idx] * data.floor.values

    radon_like = pm.Normal('radon_like', mu=radon_est, sd=eps,
                           observed=data.log_radon)

# Convert the PyMC3 model to a Theano graph, canonicalize a copy of it and
# show it in both debug and pretty-printed form.
fgraph = canonicalize(model_graph(model_centered),
                      return_graph=True, in_place=False)
for _printer in (tt_dprint, tt_pprint):
    print(_printer(fgraph))
#
# miniKanren goals and relations that shift and scale normal random variables.
#
# `concat(parts, out)`: a goal relating a list of strings to their
# concatenation (used in `recenter_relations` to derive new RV names).
concat = goalify(lambda *pieces: ''.join(pieces))
def constant_neq(lvar, val):
"""Assert that a constant graph variable is not equal to a specific value.
Scalar values are broadcast across arrays.
"""
from symbolic_pymc.meta import MetaConstant
def _goal(s):
lvar_val = walk(lvar, s)
if isinstance(lvar_val, (tt.Constant, MetaConstant)):
data = lvar_val.data
if ((isinstance(val, np.ndarray) and
not np.array_equal(data, val)) or
not all(np.atleast_1d(data) == val)):
yield s
else:
yield s
return _goal
def recenter_relations(in_expr, out_expr):
    """Relate a normal random variable to its re-centered ("offset") form.

    The returned goal unifies `in_expr` with ``NormalRV(mean, sd, ...)``.
    When neither ``sd`` is the constant 1 nor ``mean`` is the constant 0, it
    also unifies `out_expr` with ``mean + sd * NormalRV(0, 1, ...)``.  That
    standard normal reuses the original size and RNG and is named
    ``<original name>_offset``.

    Returns
    -------
    A nested ``conde`` goal tuple (used with ``KanrenRelationSub``).
    """
    from unification import var
    from kanren import conde, eq
    from symbolic_pymc.meta import mt

    # Logic variables shared by the input pattern and the output template.
    name_lv = var()
    size_lv = var()
    rng_lv = var()
    mean_lv = var()
    sd_lv = var()

    # Input pattern: any normal random variable.
    normal_pattern = mt.NormalRV(mean_lv, sd_lv,
                                 size=size_lv, rng=rng_lv, name=name_lv)

    # Output template: mean + sd * N(0, 1), same size and RNG, renamed.
    offset_name_lv = var()
    offset_template = (
        mt.add, mean_lv,
        (mt.mul, sd_lv,
         mt.NormalRV(0., 1., size=size_lv, rng=rng_lv, name=offset_name_lv)))

    # TODO: A PyMC3 rescaling issue prevents the more general approach of
    # separate relations that only shift -- `mean + N(0, sd)`, named
    # `<name>_rmean`, when mean != 0 -- or only scale -- `sd * N(mean, 1)`,
    # named `<name>_rsd`, when sd != 1.
    # NOTE(review): both constant checks must pass, so e.g. N(0, sigma) is
    # left as-is; confirm that is intended.
    offset_goals = [
        (constant_neq, sd_lv, 1),
        (constant_neq, mean_lv, 0),
        (eq, out_expr, offset_template),
        (concat, [name_lv, "_offset"], offset_name_lv),
    ]

    return (conde, [(eq, in_expr, normal_pattern), (conde, offset_goals)])
#
# Rewrite the Theano graph with the re-centering relation, then canonicalize
# the result and show it in debug and pretty-printed form.
#
posterior_opt = EquilibriumOptimizer(
    [KanrenRelationSub(recenter_relations, node_filter=get_rv_observation)],
    max_use_ratio=10,
)
fgraph_opt = canonicalize(
    optimize_graph(fgraph, posterior_opt, return_graph=True),
    return_graph=True,
    in_place=False,
)
for _printer in (tt_dprint, tt_pprint):
    print(_printer(fgraph_opt))
#
# Create a PyMC3 model from the re-scaled graph and sample both models.
#
# TODO: Provide arguments that make some variables `Deterministic`s.
# Restore the C++ compiler setting that was blanked out to skip compilation
# while the graph was being built and rewritten.
theano.config.cxx = _cxx_config
model_recentered = graph_model(fgraph_opt)
np.random.seed(123)
# `cores` replaces the deprecated `njobs` argument of `pm.sample`.
# NOTE(review): `pm.sample` discards the `tune` draws by default, so the
# `[1000:]` slice presumably drops a further 1000 post-tuning draws -- confirm
# this extra burn-in is intended.
with model_centered:
    centered_trace = pm.sample(draws=5000, tune=1000, cores=4)[1000:]
with model_recentered:
    recentered_trace = pm.sample(draws=5000, tune=1000, cores=4)[1000:]
pm.traceplot(recentered_trace, varnames=['sigma_b'])
# pm.traceplot(recentered_trace)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment