Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Created October 28, 2020 14:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ckrapu/8be4da91a70763ee62f889c6cd98f700 to your computer and use it in GitHub Desktop.
Save ckrapu/8be4da91a70763ee62f889c6cd98f700 to your computer and use it in GitHub Desktop.
conjugate-sampling-custom-step
from pymc3.step_methods.arraystep import BlockedStep
from pymc3.distributions.transforms import stick_breaking
from pymc3.model import modelcontext
import pymc3 as pm
import numpy as np
def sample_dirichlet(c):
    """Draw Dirichlet samples by normalizing independent gamma variates.

    One gamma variate is drawn per entry of ``c`` (used as the gamma shape
    parameter), and the draws are divided by their sum along the last axis,
    so every slice along that axis sums to one.

    Parameters
    ----------
    c : array of concentration parameters; the last axis indexes the
        categories of each Dirichlet vector.

    Returns
    -------
    Array with the same shape as ``c`` holding the normalized draws.
    """
    draws = np.random.gamma(c)
    # keepdims lets the row totals broadcast back against the draws.
    totals = draws.sum(axis=-1, keepdims=True)
    return draws / totals
class CDUpdate(BlockedStep):
    """Gibbs step that resamples Dirichlet probabilities from their conjugate posterior.

    Each call to ``step`` reads the current concentration from the point,
    adds the observed counts to it, draws new probability vectors with
    ``sample_dirichlet`` and stores them, mapped through the stick-breaking
    forward transform, under this step's variable name.

    Parameters
    ----------
    var : the variable this step updates; its ``name`` is the key written
        into the point. The value stored is stick-breaking transformed, so
        this is expected to be the transformed variable (e.g.
        ``p.transformed``).
    counts : array-like of observed category counts; must broadcast against
        the concentration value.
    concentration : model variable exposing ``.transformed``. Its transformed
        value in the point is exponentiated, i.e. a log transform is assumed.
    model : optional model; resolved through ``modelcontext`` when omitted.
    """

    def __init__(self, var, counts, concentration, model=None):
        model = modelcontext(model)
        self.m = model
        self.vars = [var]
        # Convert once so every step adds plain ndarrays.
        self.counts = np.asarray(counts)
        self.name = var.name
        self.conc = concentration
        # Resolve the point key once, so a concentration without a transform
        # fails here at construction instead of on the first sampling step.
        self._conc_name = concentration.transformed.name

    def step(self, point):
        """Return a copy of ``point`` holding a fresh conjugate draw.

        The incoming ``point`` is left unmodified; the returned dict shares
        every other entry with it.
        """
        # The point stores the concentration on the transformed (log) scale.
        alpha = np.exp(point[self._conc_name]) + self.counts
        new_p = sample_dirichlet(alpha)

        # Work on a shallow copy so the caller's point is not mutated.
        updated = dict(point)
        updated[self.name] = stick_breaking.forward_val(new_p)
        return updated
# Problem size: N probability vectors over J categories, each observed
# through ncounts multinomial trials.
J = 10
N = 500
ncounts = 20

# Simulate the "true" probabilities from a symmetric Dirichlet(0.5).
alpha = 0.5 * np.ones((N, J))
p_true = sample_dirichlet(alpha)

# Draw one multinomial count vector per row of p_true.
counts = np.zeros((N, J))
for i, probs in enumerate(p_true):
    counts[i] = np.random.multinomial(ncounts, probs)

# True: update p with the custom CDUpdate step.
# False: add an explicit Multinomial likelihood and pass no custom step.
use_conjugate = True

with pm.Model() as model:
    tau = pm.Exponential("tau", lam=1, testval=1.0)
    alpha = pm.Deterministic("alpha", tau * np.ones((N, J)))
    p = pm.Dirichlet("p", a=alpha)

    step = []
    if use_conjugate:
        # The counts reach the sampler through CDUpdate; no observed
        # Multinomial node is added to the model in this branch.
        step.append(CDUpdate(p.transformed, counts, tau, model=model))
    else:
        x = pm.Multinomial("x", n=counts.sum(axis=-1), p=p, observed=counts)

    trace = pm.sample(step=step, chains=2, cores=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment