Skip to content

Instantly share code, notes, and snippets.

@TStesco
Last active August 6, 2018 12:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TStesco/87d035475c1f04063bfd68e54c3d6e28 to your computer and use it in GitHub Desktop.
Save TStesco/87d035475c1f04063bfd68e54c3d6e28 to your computer and use it in GitHub Desktop.
Demo of a PyMC3 problem with online updating: using pm.Interpolated() to change the model prior at each iteration.
# -*- coding: utf-8 -*-
#!/usr/bin/env python
import pymc3 as pm
import numpy as np
import theano
import pickle
import os
import copy
from scipy import stats
import tempfile
import shutil
def posterior_grid(trace_samples, num_points=1000, width_factor=2.0, threshold=1e-07):
    """Evaluate a Gaussian KDE of *trace_samples* on a padded, trimmed grid.

    The grid has ``num_points`` evenly spaced points spanning
    ``[min - width_factor * width, max + width_factor * width]``, where
    ``width = max - min`` of the samples. Grid points whose estimated density
    is not strictly above ``threshold`` are dropped.

    Parameters
    ----------
    trace_samples : array_like
        Posterior samples of a single scalar parameter (e.g. the output of
        ``trace.get_values(name)``).
    num_points : int
        Number of grid points before trimming.
    width_factor : float
        Padding added on each side of the sample range, as a multiple of it.
    threshold : float
        Grid points with density ``<= threshold`` are discarded.

    Returns
    -------
    x, y : numpy.ndarray
        Strictly increasing grid points and the KDE density at each of them.

    Raises
    ------
    ValueError
        If there are no samples, all samples are identical (a Gaussian KDE
        cannot be fitted to zero spread), or fewer than two grid points
        survive the threshold (an interpolated density needs at least two).
    """
    samples = np.asarray(trace_samples)
    if samples.size == 0:
        raise ValueError("trace_samples is empty; cannot build a posterior KDE")

    smin, smax = np.min(samples), np.max(samples)
    width = smax - smin
    if width == 0:
        # gaussian_kde would fail on a singular covariance here; say why.
        raise ValueError(
            "all trace samples are identical ({}); a Gaussian KDE needs "
            "samples with non-zero spread".format(smin))

    pad = width_factor * width
    x_full = np.linspace(smin - pad, smax + pad, num_points)
    y_full = stats.gaussian_kde(samples)(x_full)

    # Drop the far tails where the KDE is effectively zero.
    keep = y_full > threshold
    if np.count_nonzero(keep) < 2:
        raise ValueError(
            "fewer than two grid points have density above threshold={}; "
            "lower the threshold or check the sample scale".format(threshold))
    return x_full[keep], y_full[keep]


def from_posterior(model_param, trace_samples, num_points=1000, width_factor=2.0,
                   threshold=1e-07):
    """Build a ``pm.Interpolated`` prior named *model_param* from trace samples.

    The density is a Gaussian KDE of *trace_samples* evaluated on the grid
    produced by ``posterior_grid`` (see there for the meaning of
    ``num_points``, ``width_factor`` and ``threshold``; the defaults reproduce
    the original fixed behaviour).

    ``pm.Interpolated`` with a name registers a random variable on the current
    model, so call this inside a ``with pm.Model():`` block.

    Returns
    -------
    The ``pm.Interpolated`` random variable.

    Raises
    ------
    ValueError
        Propagated from ``posterior_grid`` for empty, constant, or
        fully-thresholded samples.
    """
    x, y = posterior_grid(trace_samples, num_points, width_factor, threshold)
    return pm.Interpolated(model_param, x, y)
if __name__ == "__main__":
    # Online-updating demo: fit eta on an initial batch of data, then refit it
    # one new (x, y) observation at a time, each time turning the previous
    # trace into the next prior via from_posterior() (a KDE-backed
    # pm.Interpolated).
    np.random.seed(8)  # ensure reproducible results
    num_samples = 100
    print(pm.__version__)

    tmpdir = tempfile.mkdtemp()
    try:
        prior_trace_fname = os.path.join(tmpdir, "prior.pkl")

        # Initial batch: each y is x times an independent draw from
        # normal(loc=1, scale=4).
        x_0 = np.random.normal(5, 5, size=num_samples * 3)
        y_0 = np.array([x * np.random.normal(1, 4) for x in x_0])
        Y = theano.shared(y_0)
        X = theano.shared(x_0)

        # tmpdir was freshly created above, so this branch always runs here.
        if not os.path.exists(prior_trace_fname):
            print("Generating prior model")
            with pm.Model() as prior_model:
                eta = pm.Normal('eta', mu=0, sd=3)
                theta = X * eta
                obs = pm.Normal('obs', theta, observed=Y)
                trace_h = pm.sample(
                    draws=1000,
                    chains=1,
                    cores=1,
                    step=pm.Metropolis()
                )
            with open(prior_trace_fname, 'wb') as f:
                pickle.dump(trace_h, f)

        # This pickle was written by this script into its own temp dir, so it
        # is trusted; never unpickle traces from an external source.
        with open(prior_trace_fname, 'rb') as f:
            print("Loading prior model from: {}".format(prior_trace_fname))
            prior_trace_full = pickle.load(f)

        # Resample the prior trace with replacement so it has exactly
        # num_samples draws.
        # HACK: relies on pymc3 backend internals (_straces, samples, draw_idx).
        if len(prior_trace_full) != num_samples:
            # Slice each chain from its end to get empty traces carrying the
            # same variables (with chains=1, len(prior_trace_full) equals the
            # chain length), wrapped in a fresh MultiTrace.
            prior_trace = pm.backends.base.MultiTrace(
                [t[len(prior_trace_full):] for t in prior_trace_full._straces.values()])
            # Append num_samples randomly chosen draws; assumes the single
            # chain (chains=1 above) is keyed 0.
            for r in np.random.randint(low=0, high=len(prior_trace_full), size=num_samples):
                for varname, val in prior_trace_full[r].items():
                    prior_trace._straces[0].samples[varname] = np.append(
                        prior_trace._straces[0].samples[varname], val)
                # Advance the backend's draw counter to match the appended draw.
                prior_trace._straces[0].draw_idx += 1
        else:
            prior_trace = copy.deepcopy(prior_trace_full)
        online_trace = copy.deepcopy(prior_trace)

        # New observations, consumed one at a time. y_1 must be built from x_1
        # (not x_0) so that the two arrays have the same length and pair up.
        x_1 = np.random.normal(4, 4, size=num_samples // 10)
        y_1 = np.array([x * np.random.normal(0.5, 2) for x in x_1])

        # online updating
        for update_num, (x_i, y_i) in enumerate(zip(x_1, y_1), start=1):
            Y_1 = theano.shared(y_i)
            X_1 = theano.shared(x_i)
            print("online update: {}".format(update_num))
            with pm.Model() as update_model:
                # using the pm.Interpolated distribution causes problem here:
                # it adds '_interval__' variables which are not in previous trace
                eta = from_posterior('eta', online_trace.get_values('eta'))
                theta = X_1 * eta
                obs = pm.Normal('obs', theta, observed=Y_1)
                online_trace = pm.sample(
                    draws=num_samples,
                    start=online_trace,
                    chains=num_samples,
                    cores=1,
                    step=pm.SMC()
                )
    finally:
        # Remove the scratch directory even if sampling fails part-way.
        shutil.rmtree(tmpdir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment