Skip to content

Instantly share code, notes, and snippets.

@ColCarroll
Created September 30, 2018 20:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ColCarroll/ebbe5f59299b97dd447b5472cc0d8672 to your computer and use it in GitHub Desktop.
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
import torch
def pyro_centered_model(sigma, num_schools=8):
    """Centered parameterization of the eight-schools hierarchical model.

    Samples a shared location ``mu`` ~ Normal(0, 10) and a shared scale
    ``tau`` ~ HalfCauchy(scale=25). It then samples one effect per school,
    ``theta`` ~ Normal(mu, tau). Finally it samples the ``"obs"`` site from
    Normal(theta, sigma).

    Args:
        sigma: Per-school observation scales, passed straight to
            ``dist.Normal`` and broadcast against ``theta``. Presumably a
            1-D tensor of length ``num_schools`` -- TODO confirm with callers.
        num_schools: Number of group-level effects ``theta`` to draw.
            Defaults to 8, the value previously hard-coded. Pass e.g.
            ``data['J']`` / ``len(sigma)`` for a dataset of a different size.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))
    # mu and tau come from shape-(1,) distributions; multiplying by a ones
    # vector broadcasts them to one location/scale per school.
    per_school = torch.ones(num_schools)
    theta = pyro.sample('theta',
                        dist.Normal(
                            mu * per_school,
                            tau * per_school))
    return pyro.sample("obs", dist.Normal(theta, sigma))
def pyro_conditioned_model(model, sigma, y):
    """Call ``model(sigma)`` with its ``"obs"`` sample site pinned to ``y``.

    Args:
        model: Pyro model callable that takes ``sigma`` and samples a site
            named ``"obs"`` (e.g. ``pyro_centered_model``).
        sigma: Forwarded unchanged as the model's single argument.
        y: Observed values to condition the ``"obs"`` site on.

    Returns:
        Whatever the conditioned model returns.
    """
    observed = {"obs": y}
    conditioned = pyro.poutine.condition(model, data=observed)
    return conditioned(sigma)
def pyro_centered_schools(data, draws, chains, warmup_steps=500):
    """Fit the centered eight-schools model with NUTS.

    Args:
        data: Mapping with array-like entries ``'y'`` and ``'sigma'``. Other
            keys (e.g. ``'J'``) are not read.
        draws: Number of samples to keep; passed as ``num_samples``.
        chains: Accepted for interface compatibility but ignored. A single
            ``MCMC`` run is performed regardless of its value.
        warmup_steps: Number of NUTS warmup iterations. Defaults to 500, the
            value previously hard-coded.

    Returns:
        The object returned by ``MCMC(...).run(...)``.
    """
    del chains  # unused: only one MCMC run is performed
    # Convert with the default floating dtype. This is the same result the
    # legacy ``torch.Tensor(x).type(torch.Tensor)`` produced, but it uses the
    # supported ``torch.tensor`` factory, which always copies its input.
    float_dtype = torch.get_default_dtype()
    y = torch.tensor(data['y'], dtype=float_dtype)
    sigma = torch.tensor(data['sigma'], dtype=float_dtype)
    nuts_kernel = NUTS(pyro_conditioned_model, adapt_step_size=True)
    # The kernel's model is pyro_conditioned_model(model, sigma, y). The
    # positional arguments given to run() line up with that signature: the
    # unconditioned model followed by the data.
    posterior = MCMC(
        nuts_kernel,
        num_samples=draws,
        warmup_steps=warmup_steps,
    ).run(pyro_centered_model, sigma, y)
    return posterior
if __name__ == '__main__':
    # Eight-schools inputs. pyro_centered_schools reads only 'y' and 'sigma'.
    data = dict(
        J=8,
        y=np.array([28., 8., -3., 7., -1., 1., 18., 12.]),
        sigma=np.array([15., 10., 16., 11., 9., 11., 10., 18.]),
    )
    # `chains` is discarded by pyro_centered_schools; it is passed only to
    # satisfy the signature.
    posterior = pyro_centered_schools(data, draws=500, chains=500)
    # TODO: persisting `posterior` is unsolved. pickle, dill and cloudpickle
    # all fail on the `weakref`s it holds. pyro.get_param_store() is empty
    # here and does not reload.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment