# Source: GitHub gist slinderman/1555e77aff0b2c7dc44d (author @slinderman,
# last active December 23, 2015).
from __future__ import division
import numpy as np
np.seterr(divide='ignore') # these warnings are usually harmless for this code
from matplotlib import pyplot as plt
import matplotlib
import os
matplotlib.rcParams['font.size'] = 8
import pyhsmm
from pyhsmm.util.text import progprint_xrange
#########################
#  posterior inference  #
#########################

# Weak-limit truncation level: every model below is handed one observation
# distribution per index up to Nmax.
Nmax = 100

# Observation dimensionality, followed by the keyword arguments that are
# splatted into the pyhsmm observation-distribution constructors.
obs_dim = 1
gauss_hypparams = dict(
    mu_0=np.zeros(obs_dim),
    sigma_0=np.eye(obs_dim),
    kappa_0=0.25,
    nu_0=obs_dim + 2,
)
# Only consumed by the commented-out Poisson observation alternative.
poiss_hypparams = dict(alpha_0=1.0, beta_0=1.0)
### HDP-HMM to generate true data
# One observation distribution per index up to the truncation level.
# `range` is used instead of the Python-2-only `xrange` so the script also
# runs on Python 3; for Nmax entries the result is the same on Python 2.
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for _ in range(Nmax)]
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for _ in range(Nmax)]
model = pyhsmm.models.WeakLimitHDPHMM(
    alpha=6., gamma=6., init_state_concentration=1.,
    obs_distns=obs_distns)
# Draw a synthetic sequence with T=1000; Z_true is presumably the matching
# hidden state sequence (confirm against pyhsmm's generate()).
data, Z_true = model.generate(T=1000)
# Number of states the generating model reports as used; plotted later as the
# ground-truth reference.
K_true = len(model.used_states)
# Fit with a HDPHMM with concentration parameter resampling
# Fresh observation distributions so this model does not share parameter
# objects with the generating model.  `range` replaces the Python-2-only
# `xrange` so the script also runs on Python 3.
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for _ in range(Nmax)]
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for _ in range(Nmax)]
# alpha and gamma are specified through *_a_0 / *_b_0 hyperparameters rather
# than fixed values (presumably Gamma prior parameters -- confirm in pyhsmm).
posteriormodel = pyhsmm.models.WeakLimitHDPHMM(
    alpha_a_0=6.0, alpha_b_0=1.0,
    gamma_a_0=6., gamma_b_0=1.0,
    init_state_concentration=1.,
    obs_distns=obs_distns)
posteriormodel.add_data(data)
# Fit with a HDPHMM *without* concentration parameter resampling
# Fresh observation distributions again, so the two fitted models are
# independent.  `range` replaces the Python-2-only `xrange` so the script also
# runs on Python 3.
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for _ in range(Nmax)]
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for _ in range(Nmax)]
# alpha and gamma are passed as fixed values, matching the generating model.
posteriormodel_concfixed = pyhsmm.models.WeakLimitHDPHMM(
    alpha=6.0, gamma=6.,
    init_state_concentration=1.,
    obs_distns=obs_distns)
posteriormodel_concfixed.add_data(data)
# Number of used states recorded after every sampler sweep, one trace per model.
Ks_concresample = []
Ks_concfixed = []
for _ in progprint_xrange(100):
    # Both models are resampled inside the same loop so the two traces have
    # one entry per iteration and can be plotted against a shared x-axis.
    posteriormodel.resample_model()
    Ks_concresample.append(len(posteriormodel.used_states))
    posteriormodel_concfixed.resample_model()
    Ks_concfixed.append(len(posteriormodel_concfixed.used_states))
# Plot the number of used states per iteration for both samplers, with the
# generating model's state count as a dotted reference line.
plt.figure()
plt.plot(np.array(Ks_concresample), "-b", label="conc resampling")
plt.plot(np.array(Ks_concfixed), "-r", label="conc fixed")
# Span the reference line over the recorded iterations instead of a
# hard-coded 100, so it stays aligned if the number of sweeps changes.
plt.plot([0, len(Ks_concresample)], K_true * np.ones(2), ':k', label="True")
plt.legend(loc="upper left")
# The state count cannot exceed the number of observation distributions, so
# the truncation level is the natural upper y-limit.
plt.ylim(0, Nmax)
plt.title("Obs distn: %s" % obs_distns[0].__class__)
plt.xlabel("Iteration")
plt.ylabel("Number of States")
# Save before show(): the figure is written to disk before the interactive
# window takes over.
plt.savefig("num_states.png")
plt.show()
# End of gist.