@vanAmsterdam · Created May 1, 2020
'''
define and run a latent variable model, everything gaussian, with posterior predictive on new data
for a description of the model see twitter thread: https://twitter.com/WvanAmsterdam/status/1251214875394740226?s=20
DAG:
W1 <- U -> W2 # latent confounder with 2 proxies
U -> tx
U -> y
tx -> y
'''
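# the generative structure spelled out (this just restates the model code below):
#   U  ~ Normal(0, 1)                          latent confounder
#   W1 ~ Normal(b_U_W1 * U, s_W1)              proxy 1
#   W2 ~ Normal(b_U_W2 * U, s_W2)              proxy 2
#   tx ~ Normal(b_U_tx * U, s_tx)              treatment
#   y  ~ Normal(b_U_y * U + b_tx_y * tx, s_y)  outcome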
import numpyro
from jax import numpy as np, random
from numpyro import distributions as dist
from numpyro.handlers import trace, seed
from numpyro.infer.mcmc import NUTS, MCMC
from numpyro import diagnostics
import arviz as az
import pandas as pd

# make 4 CPU devices available so the 4 chains can run in parallel
numpyro.set_host_device_count(4)
def latentconfoundermodel1d(
        control={'N': 500, 'sample_priors': True},
        data={'tx': None, 'W1': None, 'W2': None, 'y': None},
        priors=None, prm_vals=None):
    '''
    a probabilistic model for linear regression with a latent confounder

    :param data dict: a dictionary with np.ndarrays for all observed data to condition on, and None for data that is unobserved or to be marginalized out
    :param priors dict: priors keyed by parameter name
    :param control dict: control arguments, like N
    '''
    # get global parameters: either sample them from their priors (inference)
    # or plug in fixed values (simulation / posterior prediction)
    if control['sample_priors']:
        prms = {prm_name: numpyro.sample(prm_name, prior) for prm_name, prior in priors.items()}
    else:
        prms = prm_vals
    # data plate
    with numpyro.plate('obs', control['N']):
        # latent confounder, one per observation
        Uhat = numpyro.sample('Uhat', dist.Normal(0., 1.))
        # U -> W measurement model for the two proxies
        muhat_W1 = Uhat * prms['b_U_W1']
        muhat_W2 = Uhat * prms['b_U_W2']
        numpyro.sample('W1', dist.Normal(muhat_W1, prms['s_W1']), obs=data['W1'])
        numpyro.sample('W2', dist.Normal(muhat_W2, prms['s_W2']), obs=data['W2'])
        # U -> tx treatment model
        muhat_tx = Uhat * prms['b_U_tx']
        tx = numpyro.sample('tx', dist.Normal(muhat_tx, prms['s_tx']), obs=data['tx'])
        # outcome model for the linear predictor
        muhat_y = Uhat * prms['b_U_y'] + tx * prms['b_tx_y']
        # sample outcome
        return numpyro.sample('y', dist.Normal(muhat_y, prms['s_y']), obs=data['y'])
## sample data
prm_vals = dict(
    b_U_W1=0.5,
    b_U_W2=0.5,
    s_W1=0.1,
    s_W2=0.2,
    b_U_tx=0.75,
    s_tx=0.2,
    b_tx_y=1.0,
    b_U_y=0.75,
    s_y=0.2,
)
prm_priors = {k: dist.Normal(0., 5.) for k in prm_vals.keys()}
# fix the effects of the latent confounder to positive values for identification;
# the scale parameters need positive support as well, so give all of these half-normal priors:
pos_prms = ['b_U_W1', 'b_U_W2', 'b_U_tx', 'b_U_y', 's_W1', 's_W2', 's_tx', 's_y']
for prm in pos_prms:
    prm_priors[prm] = dist.HalfNormal(2.5)
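# quick prior sanity check (a sketch, not part of the original workflow): draw once
# from each prior under a seed handler to confirm the positivity constraints took effect
with seed(rng_seed=0):
    prior_draw = {k: numpyro.sample(k, d) for k, d in prm_priors.items()}
print({k: float(v) for k, v in prior_draw.items()})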
def sim_from_model(rng_key, model, prm_vals, nsim=500):
    control = dict(N=nsim, sample_priors=False)
    # run model forward with fixed parameter values
    tr = trace(seed(model, rng_key)).get_trace(control=control, prm_vals=prm_vals)
    # collect the values of all sample sites into a dictionary (skip plate sites)
    data = {k: v['value'] for k, v in tr.items() if v['type'] == 'sample'}
    return data
nsim = 1000
sim_keys = random.split(random.PRNGKey(1224), 2)
simdata = sim_from_model(sim_keys[0], latentconfoundermodel1d, prm_vals, nsim)
# create test data
simdata2 = sim_from_model(sim_keys[1], latentconfoundermodel1d, prm_vals, nsim)
testdata = simdata2.copy()
testdata['y'] = None # test data should not contain y
control = {'N': nsim, 'sample_priors': True}
num_samples = 3000
num_warmup = 1500
num_chains = 4
## do mcmc
def run_mcmc(key, model, *args, **kwargs):
    kernel = NUTS(model, target_accept_prob=0.95)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, progress_bar=True)
    mcmc.run(key, *args, **kwargs)
    return mcmc
mcmc_keys = random.split(random.PRNGKey(1225), 2)
mcmc = run_mcmc(mcmc_keys[0], latentconfoundermodel1d, control=control, data=simdata, priors=prm_priors)
smps = mcmc.get_samples(group_by_chain=True)
allvars = list(smps.keys())
prmvars = [k for k in allvars if k not in ['Uhat']]
sum_ = pd.DataFrame(diagnostics.summary(smps)).T
print(sum_.loc[prmvars])
divergences = mcmc.get_extra_fields()['diverging']
print(f"num divergences: {divergences.sum()}")
## now do posterior prediction on a new dataset
# create a model that closes over the fixed arguments (control, data, priors),
# so that repeated mcmc runs with jit_model_args=True can reuse the compiled model
def make_jittable(model, control, data, priors):
    def newmodel(*args, **kwargs):
        return model(control, data, priors, *args, **kwargs)
    return newmodel
smps2 = mcmc.get_samples(group_by_chain=False)
pp_control = control.copy()
pp_control['sample_priors'] = False
# script for getting posterior predictive samples
def get_postpred_smps(key, model, control, data, priors, postsamples, num_draws=10, num_warmup=100, num_samples=100):
    jittable_model = make_jittable(model, control, data, priors)
    mcmc = MCMC(NUTS(jittable_model), num_warmup=num_warmup, num_samples=num_samples,
                num_chains=1, jit_model_args=True, progress_bar=False)
    keys = random.split(key, num_draws)
    draws = []
    for i in range(num_draws):
        print(i, end='')
        # plug in the i-th posterior draw as the fixed parameter values
        smp = {k: v[i] for k, v in postsamples.items()}
        mcmc.run(keys[i], prm_vals=smp)
        postsmps = mcmc.get_samples()
        postsmp = {k: v[-1] for k, v in postsmps.items()}  # grab the last sample of each run
        draws.append(postsmp)
    return draws
pp_smps = get_postpred_smps(mcmc_keys[1],
                            latentconfoundermodel1d,
                            pp_control,
                            testdata,
                            prm_priors,
                            smps2,
                            num_draws=100)
## convert to a dict of type: {varname: (num_posterior_draws, N)}
def list_of_dicts_to_dict_of_lists(LD):
    return {k: [dic[k] for dic in LD] for k in LD[0]}

pp_smps = list_of_dicts_to_dict_of_lists(pp_smps)
pp_smps = {k: np.stack(v, axis=0) for k, v in pp_smps.items()}
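# evaluate the posterior predictive draws against the held-out outcomes
# (a minimal sketch; simdata2 still holds the true y that was masked out in testdata)
y_pred_mean = pp_smps['y'].mean(axis=0)  # average over posterior draws, shape (N,)
rmse = np.sqrt(np.mean((y_pred_mean - simdata2['y']) ** 2))
print(f"\nposterior predictive RMSE on test data: {rmse:.3f}")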