Last active
February 20, 2019 16:34
-
-
Save brandonwillard/a6f03c5afae044c64f83506bf901a795 to your computer and use it in GitHub Desktop.
Automatic Bayesian Hierarchical Model Re-centering/scaling in PyMC3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A demonstration of automatic parameter re-centering/scaling for the example
in https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/
using https://github.com/pymc-devs/symbolic-pymc.
"""
import theano | |
import theano.tensor as tt | |
import numpy as np | |
import pymc3 as pm | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from theano.gof.opt import EquilibriumOptimizer | |
from theano.printing import debugprint as tt_dprint | |
from unification.utils import transitive_get as walk | |
from kanren.goals import goalify | |
from symbolic_pymc.opt import KanrenRelationSub | |
from symbolic_pymc.pymc3 import model_graph, graph_model | |
from symbolic_pymc.utils import (optimize_graph, canonicalize, | |
get_rv_observation) | |
from symbolic_pymc.printing import tt_pprint | |
# Apply seaborn's 'whitegrid' style to the plots produced by this script.
sns.set_style('whitegrid')

# Skip compilation: remember the configured C++ compiler and blank it out
# while the graphs are built and rewritten.  The saved value is put back
# right before the models are sampled.
_cxx_config = theano.config.cxx
theano.config.cxx = ''
theano.config.compute_test_value = 'ignore'

#
# Set up the original data set.
#
data = pd.read_csv('radon.csv')
# Cast the response to Theano's configured float type.
data['log_radon'] = data['log_radon'].astype(theano.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values
# One entry per distinct county; reuse the array computed above instead of
# scanning the column a second time.
n_counties = len(county_names)
# Centered parameterization: the per-county intercepts `a` and slopes `b` are
# drawn directly from normals whose mean and sd are the hyperpriors.
with pm.Model() as model_centered:
    # Hyperpriors for the intercepts.
    mu_a = pm.Normal('mu_a', mu=0.0, sd=100 ** 2)
    sigma_a = pm.HalfCauchy('sigma_a', beta=5)

    # Hyperpriors for the slopes.
    mu_b = pm.Normal('mu_b', mu=0.0, sd=100 ** 2)
    sigma_b = pm.HalfCauchy('sigma_b', beta=5)

    # One intercept and one slope per county.
    a = pm.Normal('a', mu=mu_a, sd=sigma_a, shape=n_counties)
    b = pm.Normal('b', mu=mu_b, sd=sigma_b, shape=n_counties)

    # Observation noise.
    eps = pm.HalfCauchy('eps', beta=5)

    # Linear predictor: each row picks its county's intercept and slope.
    radon_est = a[county_idx] + b[county_idx] * data['floor'].values

    radon_like = pm.Normal(
        'radon_like',
        mu=radon_est,
        sd=eps,
        observed=data['log_radon'],
    )
# Turn the PyMC3 model into a graph via symbolic-pymc's `model_graph`, then
# canonicalize it (`in_place=False` is passed; presumably the input graph is
# left untouched -- confirm against symbolic_pymc.utils.canonicalize).
fgraph = model_graph(model_centered)
fgraph = canonicalize(fgraph, return_graph=True, in_place=False)

# `debugprint` writes to stdout itself and returns None when no `file` is
# given, so it is called directly; wrapping it in `print` emits a stray
# "None" line after the graph dump.
tt_dprint(fgraph)
print(tt_pprint(fgraph))
#
# Create miniKanren goals and relations to shift and scale normals.
#
# `concat` wraps plain string concatenation (''.join over its positional
# arguments) with kanren's `goalify`.  It is used in `recenter_relations` as
# the goal `(concat, [name, "_offset"], new_name)`.
concat = goalify(lambda *args: ''.join(args))
def constant_neq(lvar, val):
    """Assert that a constant graph variable is not equal to a specific value.

    Scalar values are broadcast across arrays.

    Parameters
    ----------
    lvar
        Term that is walked in the current substitution to obtain the graph
        variable under test.
    val
        Value to compare against.  When it is a NumPy array the comparison is
        ``np.array_equal`` (same shape and same elements); any other value is
        compared element-wise against the constant's data.

    Returns
    -------
    callable
        A goal taking a substitution ``s``.  It yields ``s`` unchanged when
        the walked value is not a constant, or is a constant whose data is
        not entirely equal to `val`; otherwise it yields nothing.
    """
    from symbolic_pymc.meta import MetaConstant

    def _goal(s):
        lvar_val = walk(lvar, s)

        # Only constants can be shown to equal `val`; anything else passes.
        if not isinstance(lvar_val, (tt.Constant, MetaConstant)):
            yield s
            return

        const_data = lvar_val.data

        if isinstance(val, np.ndarray):
            is_equal = np.array_equal(const_data, val)
        else:
            # `np.all` reduces over every axis.  The builtin `all` iterates
            # over the rows of a multi-dimensional comparison result and
            # raises "truth value of an array is ambiguous".
            is_equal = bool(np.all(np.asarray(const_data) == val))

        if not is_equal:
            yield s

    return _goal
def recenter_relations(in_expr, out_expr):
    """Relate a normal random variable to its shifted-and-scaled form.

    `in_expr` is unified with ``NormalRV(mean, sd, size=..., rng=...,
    name=...)``.  When both ``constant_neq(sd, 1)`` and
    ``constant_neq(mean, 0)`` succeed, `out_expr` is unified with
    ``mean + sd * NormalRV(0., 1., size=..., rng=..., name=<name>_offset)``,
    where the new name is the matched name with ``"_offset"`` appended.

    Returns the relation as a ``conde`` goal expression (a tuple).
    """
    from unification import var
    from kanren import conde, eq
    from symbolic_pymc.meta import mt

    # Logic variables standing for the parts of the matched normal.
    (norm_name_lv, norm_size_lv, norm_rng_lv,
     mean_lv, sd_lv) = (var() for _ in range(5))

    matched_norm_mt = mt.NormalRV(mean_lv, sd_lv,
                                  size=norm_size_lv,
                                  rng=norm_rng_lv,
                                  name=norm_name_lv)

    # Name of the replacement standard normal; bound by the `concat` goal.
    offset_name_lv = var()
    std_norm_mt = mt.NormalRV(0., 1.,
                              size=norm_size_lv,
                              rng=norm_rng_lv,
                              name=offset_name_lv)
    shifted_scaled_mt = (mt.add, mean_lv, (mt.mul, sd_lv, std_norm_mt))

    # TODO: PyMC3 rescaling issue doesn't allow us to take the more
    # general approach.
    # norm_mean_name_mt = var()
    # rct_norm_mean_mt = (mt.add, mean_lv,
    #                     mt.NormalRV(0., sd_lv,
    #                                 size=norm_size_lv,
    #                                 rng=norm_rng_lv,
    #                                 name=norm_mean_name_mt))
    #
    # norm_sd_name_mt = var()
    # rct_norm_sd_mt = (mt.mul, sd_lv,
    #                   mt.NormalRV(mean_lv, 1.,
    #                               size=norm_size_lv,
    #                               rng=norm_rng_lv,
    #                               name=norm_sd_name_mt))

    # The single active branch: shift and scale at the same time.
    offset_branch = [
        (constant_neq, sd_lv, 1),
        (constant_neq, mean_lv, 0),
        (eq, out_expr, shifted_scaled_mt),
        (concat, [norm_name_lv, "_offset"], offset_name_lv),
    ]
    # Disabled alternative branches (see the TODO above):
    # [(constant_neq, mean_lv, 0),
    #  (eq, out_expr, rct_norm_mean_mt),
    #  (concat, [norm_name_lv, "_rmean"], norm_mean_name_mt)],
    # [(constant_neq, sd_lv, 1),
    #  (eq, out_expr, rct_norm_sd_mt),
    #  (concat, [norm_name_lv, "_rsd"], norm_sd_name_mt)]

    return (conde,
            [(eq, in_expr, matched_norm_mt),
             (conde, offset_branch)])
#
# Run the re-scaling optimization on the Theano graph.
#
# `recenter_relations` is applied through a `KanrenRelationSub` rewrite,
# filtered by `get_rv_observation`, inside an `EquilibriumOptimizer`.
posterior_opt = EquilibriumOptimizer(
    [KanrenRelationSub(recenter_relations,
                       node_filter=get_rv_observation)],
    max_use_ratio=10)

fgraph_opt = optimize_graph(fgraph, posterior_opt, return_graph=True)
fgraph_opt = canonicalize(fgraph_opt, return_graph=True, in_place=False)

# `debugprint` writes to stdout itself and returns None when no `file` is
# given, so it is called directly; wrapping it in `print` emits a stray
# "None" line after the graph dump.
tt_dprint(fgraph_opt)
print(tt_pprint(fgraph_opt))
#
# Create a PyMC3 model from the re-scaled graph.
#
# TODO: Provide arguments that make some variables `Deterministic`s.

# Put back the C++ compiler setting that was saved and blanked out earlier.
theano.config.cxx = _cxx_config

model_recentered = graph_model(fgraph_opt)

np.random.seed(123)

# Both models are sampled with the same settings, and the first 1000 draws of
# each returned trace are sliced off.
_sample_kwargs = dict(draws=5000, tune=1000, njobs=4)

with model_centered:
    centered_trace = pm.sample(**_sample_kwargs)[1000:]

with model_recentered:
    recentered_trace = pm.sample(**_sample_kwargs)[1000:]

pm.traceplot(recentered_trace, varnames=['sigma_b'])
# pm.traceplot(recentered_trace)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment