Last active
August 6, 2018 12:49
-
-
Save TStesco/87d035475c1f04063bfd68e54c3d6e28 to your computer and use it in GitHub Desktop.
Demo of a pymc3 problem with online updating: using pm.Interpolated() to change the model prior at each iteration.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
#!/usr/bin/env python | |
import pymc3 as pm | |
import numpy as np | |
import theano | |
import pickle | |
import os | |
import copy | |
from scipy import stats | |
import tempfile | |
import shutil | |
def from_posterior(model_param, trace_samples):
    """Turn posterior samples into a ``pm.Interpolated`` prior.

    A Gaussian kernel density estimate of ``trace_samples`` is evaluated on
    a 1000-point grid that extends two sample-ranges below the minimum and
    two above the maximum. Grid points whose estimated density is not
    greater than 1e-07 are discarded, and the remaining (x, density) pairs
    define the interpolated distribution.

    Parameters
    ----------
    model_param : name handed to ``pm.Interpolated`` for the new variable.
    trace_samples : samples of that parameter taken from an earlier trace.

    Returns
    -------
    The ``pm.Interpolated`` variable built from the retained grid points.
    """
    lowest = np.min(trace_samples)
    highest = np.max(trace_samples)
    span = highest - lowest

    # Widen the support well past the observed samples before estimating.
    grid = np.linspace(lowest - 2 * span, highest + 2 * span, 1000)
    density = stats.gaussian_kde(trace_samples)(grid)

    # Keep only the part of the grid with non-negligible density.
    keep = density > 1e-07
    return pm.Interpolated(model_param, grid[keep], density[keep])
if __name__ == "__main__":
    # Demo script:
    #   1. fit a "prior" model with Metropolis and pickle its trace,
    #   2. resample that trace down to ``num_samples`` draws,
    #   3. update the 'eta' estimate one new observation at a time, rebuilding
    #      the prior from the previous trace via ``from_posterior``.
    np.random.seed(8)  # ensure reproducible results
    num_samples = 100
    print(pm.__version__)

    tmpdir = tempfile.mkdtemp()
    try:
        prior_trace_fname = os.path.join(tmpdir, "prior.pkl")
        # NOTE: nothing below writes to this path yet.
        online_trace_fname = os.path.join(tmpdir, "online.pkl")

        # Synthetic data for the prior model.
        x_0 = np.random.normal(5, 5, size=num_samples*3)
        y_0 = np.array([x*np.random.normal(1, 4) for x in x_0])
        Y = theano.shared(y_0)
        X = theano.shared(x_0)

        if not os.path.exists(prior_trace_fname):
            print("Generating prior model")
            with pm.Model() as prior_model:
                eta = pm.Normal('eta', mu=0, sd=3)
                theta = X*eta
                obs = pm.Normal('obs', theta, observed=Y)
                trace_h = pm.sample(
                    draws=1000,
                    chains=1,
                    cores=1,
                    step=pm.Metropolis()
                )
            with open(prior_trace_fname, 'wb') as f:
                pickle.dump(trace_h, f)

        # The pickle read here was written by this script into a fresh
        # temporary directory; do not point this at untrusted files.
        with open(prior_trace_fname, 'rb') as f:
            print("Loading prior model from: {}".format(prior_trace_fname))
            prior_trace_full = pickle.load(f)

        # randomly resample trace to get trace length == num_samples
        # super hacky... can someone tell me a better way pls?
        if len(prior_trace_full) != num_samples:
            # create empty trace object: slicing each chain from its own
            # length onward yields a chain with no draws
            prior_trace = pm.backends.base.MultiTrace(
                [t[len(prior_trace_full):]
                 for t in prior_trace_full._straces.values()])
            # sample N particles (with replacement) from the full trace
            for r in np.random.randint(
                    low=0, high=len(prior_trace_full), size=num_samples):
                for varname, val in prior_trace_full[r].items():
                    prior_trace._straces[0].samples[varname] = np.append(
                        prior_trace._straces[0].samples[varname], val)
                # one more draw has been appended to chain 0
                prior_trace._straces[0].draw_idx += 1
        else:
            prior_trace = copy.deepcopy(prior_trace_full)

        online_trace = copy.deepcopy(prior_trace)

        # New observations for the online updates. The responses are built
        # from x_1 (the original iterated x_0), so x_1 and y_1 have the same
        # length and pair up one-to-one in the loop below instead of running
        # past the end of x_1.
        x_1 = np.random.normal(4, 4, size=int(num_samples/10))
        y_1 = np.array([x*np.random.normal(0.5, 2) for x in x_1])

        # online updating
        for update_num, (x_new, y_new) in enumerate(zip(x_1, y_1), start=1):
            Y_1 = theano.shared(y_new)
            X_1 = theano.shared(x_new)
            print("online update: {}".format(update_num))
            with pm.Model() as update_model:
                # using the pm.Interpolated distribution causes problem here
                # it adds '_interval__' variables which are not in previous trace
                eta = from_posterior('eta', online_trace.get_values('eta'))
                theta = X_1*eta
                obs = pm.Normal('obs', theta, observed=Y_1)
                online_trace = pm.sample(
                    draws=num_samples,
                    start=online_trace,
                    chains=num_samples,
                    cores=1,
                    step=pm.SMC()
                )
    finally:
        # Remove the scratch directory even if sampling raised.
        shutil.rmtree(tmpdir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment