Created
March 7, 2013 13:26
-
-
Save mattjj/5108053 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import numpy as np | |
np.seterr(divide='ignore') | |
from matplotlib import pyplot as plt | |
plt.ioff() | |
import pyhsmm | |
pyhsmm.internals.states.use_eigen() | |
from pyhsmm.util.text import progprint_xrange | |
########## | |
# util # | |
########## | |
def changepoint_indicators_from_stateseq(stateseq):
    '''
    Return a binary (0/1) integer array the same size as stateseq, where
    ret[i] = 1 iff there is a changepoint between times i-1 and i.
    ret[0] is always 0 (there is no time step before the first).
    '''
    stateseq = np.asarray(stateseq)
    if stateseq.size == 0:
        # Edge case: without this guard, concatenating the leading (0,) with
        # np.diff(empty) would return a length-1 array, breaking the
        # "same size as stateseq" contract.
        return np.zeros(0, dtype=int)
    # np.diff compares each state label with its predecessor; a nonzero
    # difference marks a state switch. Cast explicitly to int so the result
    # is a true 0/1 array regardless of numpy promotion rules.
    return np.concatenate(( (0,), (np.diff(stateseq) != 0).astype(int) ))
########### | |
# setup # | |
########### | |
# Load the 1-D observation sequence: column index 3 of the BED file.
# NOTE(review): assumes the file is whitespace-delimited with a numeric
# 4th column — confirm against the actual data file.
data = np.loadtxt('small_data.bed',usecols=(3,))

# Truncation level: maximum number of HMM states instantiated for the
# weak-limit approximation to the HDP.
Nmax = 50

obs_hypparams = dict(
    mu_0=0., tausq_0=30000.**2,  # mean and variance for the prior for means
    sigmasq_0=2000.**2, nu_0=2,  # mean and number of pseudo-observations for variances
                                 # setting a large sigmasq_0 and increasing
                                 # nu_0 will encourage less segmentation
    )

# Sticky HDP-HMM with scalar Gaussian emissions (non-conjugate NIX prior),
# fixed concentration parameters, and Nmax candidate states.
model = pyhsmm.models.StickyHMM(
    obs_distns=[pyhsmm.distributions.ScalarGaussianNonconjNIX(**obs_hypparams)
                for idx in range(Nmax)],
    alpha=10.,gamma=10.,kappa=25.,  # the number of segments is sensitive to these, esp kappa!
    init_state_concentration=5.)

# the HDP can be sensitive to the alpha and gamma concentration parameters and
# the sticky prior can also be very sensitive to kappa (meaning those choices
# can significantly affect the number of states or segments inferred). instead
# of passing alpha and gamma directly, we can put priors over them and sample
# them as well, which is much more flexible and data-driven but much slower (a
# lot of it is Python interpreter overhead). the results should be much less
# sensitive to the settings of these hyperparameters. (maybe slice sampling over
# the concentrations would be more effective than gibbs sampling...)
# model = pyhsmm.models.StickyHMM(
#         obs_distns=[pyhsmm.distributions.ScalarGaussianNonconjNIX(**obs_hypparams)
#             for idx in range(Nmax)],
#         alphakappa_a_0=50.,alphakappa_b_0=1.,
#         rho_a_0=50.,rho_b_0=1., # higher rho_a_0 means stickier
#         gamma_a_0=1.,gamma_b_0=3.,
#         init_state_concentration=5.
#         )

# Register the observation sequence with the model; presumably this creates
# model.states_list[0], which the sampling loop below reads — verify against
# the pyhsmm version in use.
model.add_data(data)
###################
#  sampling loop  #
###################

# Run Gibbs sampling; after a burn-in period, accumulate per-timestep
# changepoint indicators so we can report the posterior changepoint
# probability at each position.
niter = 500
burn_in = 50

plt.figure(figsize=(12,3))
dmin,dmax = data.min(),data.max()
changepoint_counts = np.zeros(data.shape[0])
for itr in progprint_xrange(niter):
    model.resample_model()
    # BUGFIX: was `itr > burn_in`, which collects only niter - burn_in - 1
    # samples while the average below divides by niter - burn_in, biasing
    # every estimated probability low. `>=` collects exactly
    # niter - burn_in post-burn-in samples.
    if (itr >= burn_in):
        changepoint_counts += changepoint_indicators_from_stateseq(model.states_list[0].stateseq)
    # Periodically snapshot the data with the current state segmentation
    # overlaid. NOTE(review): assumes a 'figs/' directory already exists.
    if (itr % 10 == 0):
        plt.clf()
        plt.plot(data,'k-')
        model.states_list[0].plot(colors_dict=model._get_colors(),vertical_extent=(dmin,dmax),alpha=0.5)
        plt.savefig('figs/%d.png' % itr)

# Average the accumulated indicators into per-position changepoint
# probability estimates and save the summary plot.
plt.figure(figsize=(12,3))
plt.plot(changepoint_counts/(niter - burn_in))
plt.xlim(0,data.shape[0]-1)
plt.title('estimated changepoints')
plt.savefig('figs/estimated_changepoints.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment