Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@jdhooghe
Last active February 26, 2018 08:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jdhooghe/9719a464071bb9fa899ca31a1a1fa6d3 to your computer and use it in GitHub Desktop.
Save jdhooghe/9719a464071bb9fa899ca31a1a1fa6d3 to your computer and use it in GitHub Desktop.
Stick Breaking Example
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
from theano import shared
import theano.tensor as tt
import pandas as pd
# Synthetic piecewise-linear data: a shallow segment on [0, 9.9] (slope 10)
# followed by a steeper segment on [10, 30.9] (slope 40) that continues from
# the last point of the first segment. Noise is uniform on [0, 100) and
# [0, 300) respectively (non-negative, so it shifts points upward). A seeded
# generator keeps the example reproducible.
DATA_SEED = 20180226
rng = np.random.RandomState(DATA_SEED)

x1 = np.linspace(0., 9.9, 10)
x2 = np.linspace(10., 30.9, 30)
x = np.concatenate((x1, x2), axis=0)

y1 = 10. * x1 + 5. + rng.random_sample(x1.size) * 100.
# Start from the final y1 value and rise with slope 40, measured from x2[0].
y2 = y1[-1] + 40. * (x2 - x2[0]) + rng.random_sample(x2.size) * 300.
y = np.concatenate((y1, y2), axis=0)

plt.plot(x1, y1, linestyle='None', marker='.')
plt.plot(x2, y2, linestyle='None', marker='.')
def norm_cdf(z):
    """Standard normal CDF, Phi(z), built as a symbolic Theano expression.

    Uses the identity Phi(z) = (1 + erf(z / sqrt(2))) / 2, applied
    elementwise, so the result lies in [0, 1] and can be used as a
    stick-breaking proportion.

    Parameters
    ----------
    z : Theano tensor (or anything ``tt.erf`` accepts)

    Returns
    -------
    Theano tensor of the same shape as ``z``.
    """
    erf_arg = z / np.sqrt(2)
    return 0.5 * (1 + tt.erf(erf_arg))
def stick_breaking(v):
    """Convert break proportions into truncated stick-breaking weights.

    For every row, component k receives

        w[:, k] = v[:, k] * prod_{j < k} (1 - v[:, j]),

    i.e. the fraction v[:, k] of whatever stick remains after the first
    k breaks.

    Because of truncation, the weights in a row sum to
    1 - prod_k (1 - v[:, k]) rather than exactly 1: the piece left after
    the last break is discarded. The number of columns should be large
    enough that this residual is negligible.

    Parameters
    ----------
    v : Theano tensor, at least 2-D; breaks are taken along axis 1.

    Returns
    -------
    Theano tensor of the same shape as ``v``.
    """
    # Stick remaining before each break: the running product of (1 - v),
    # shifted right by one column, with 1 prepended for the first component.
    leftover = tt.extra_ops.cumprod(1. - v, axis=1)[:, :-1]
    leading_ones = tt.ones_like(v[:, :1])
    return v * tt.concatenate([leading_ones, leftover], axis=1)
# Standardize x and y before modelling. The priors below are on a unit scale
# (N(0, 1) gate coefficients, N(0, 10) intercepts/slopes, Gamma(1, 1)
# precisions), but the raw data are not: x runs up to 30.9 and y reaches
# roughly 1e3. The intercepts, slopes and noise levels implied by the raw
# data would sit several prior standard deviations out, which hampers
# sampling. The moments are kept so results can be mapped back to data
# units, e.g. y_data = y * y_std + y_mean.
x_mean, x_std = x.mean(), x.std()
y_mean, y_std = y.mean(), y.std()
x = ((x - x_mean) / x_std)[:, np.newaxis]
y = ((y - y_mean) / y_std)[:, np.newaxis]

# X is an (n_obs, 1) column marked broadcastable along axis 1, so it combines
# with length-K parameter vectors into (n_obs, K) component quantities.
# Keeping it in a shared variable allows swapping in new inputs later.
X = shared(x, broadcastable=(False, True))
K = 20  # truncation level: number of stick-breaking mixture components
# Covariate-dependent mixture of K linear regressions: the mixture weights are
# produced by probit stick-breaking on a linear function of x.
with pm.Model() as model:
    # --- Mixture weights --------------------------------------------------
    # Normal(name, mu, sd): each break proportion is Phi(alpha + beta * x),
    # so the weights w vary with x (shape (n_obs, K) when X is an
    # (n_obs, 1) column).
    alpha = pm.Normal('alpha', 0., 1., shape=K)
    beta = pm.Normal('beta', 0., 1., shape=K)
    v = norm_cdf(alpha + beta * X)
    w = pm.Deterministic('w', stick_breaking(v))

    # --- Component means ---------------------------------------------------
    # One straight line per component (intercept gamma, slope delta),
    # evaluated at every x.
    gamma = pm.Normal('gamma', 0., 10., shape=K)
    delta = pm.Normal('delta', 0., 10., shape=K)
    mu = pm.Deterministic('mu', gamma + delta * X)

    # --- Noise and likelihood ----------------------------------------------
    # Per-component precisions, then the Normal mixture likelihood of y.
    tau = pm.Gamma('tau', alpha=1., beta=1., shape=K)
    obs = pm.NormalMixture('obs', w=w, mu=mu, tau=tau, observed=y)
# Draw posterior samples from the model. init='advi+adapt_diag' is a PyMC3
# initialization option (ADVI-based start plus diagonal mass-matrix
# adaptation). A fixed random_seed makes the run reproducible.
SAMPLE_SEED = 20180226
with model:
    trace = pm.sample(20000, init='advi+adapt_diag', progressbar=True,
                      random_seed=SAMPLE_SEED)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment