Created
October 18, 2016 16:40
-
-
Save maxentile/2a46be1e8445867da342500c7d77f330 to your computer and use it in GitHub Desktop.
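Quick check that `hmm.sample_by_observation_probabilities` isn't the culprit, i.e. that it never returns a (trajectory, frame) index outside the trajectories it was estimated on.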
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import numpy.random as npr
import pyemma

# Quick check that `hmm.sample_by_observation_probabilities` never returns
# out-of-bounds (traj_ind, frame_ind) pairs.

# Fixed seed so a failing run can be reproduced exactly. Change it (or set it
# to None to seed from OS entropy) to explore other random trajectories.
SEED = 0
npr.seed(SEED)

# Generate random discrete trajectories, each with a random length in
# [100, 1000). Because the trajectories have different lengths, an
# out-of-bounds frame index is likely to show up if `hmm` somehow scrambles
# the trajectory indices. Each group draws its states from a disjoint range
# of 100 microstates.
dtrajs_A = [npr.randint(0, 100, x) for x in npr.randint(100, 1000, 8)]
dtrajs_B = [npr.randint(100, 200, x) for x in npr.randint(100, 1000, 9)]
dtrajs_C = [npr.randint(200, 300, x) for x in npr.randint(100, 1000, 10)]
dtrajs_D = [npr.randint(300, 400, x) for x in npr.randint(100, 1000, 11)]
dtrajs = dtrajs_A + dtrajs_B + dtrajs_C + dtrajs_D

# maxit=1 presumably caps estimation at a single iteration to keep this check
# fast. Only the sampled indices are under test here, not model quality.
hmm = pyemma.msm.estimate_hidden_markov_model(dtrajs, nstates=10, lag=10, maxit=1)

# Draw (traj_ind, frame_ind) pairs from the HMM. The result is a list of index
# arrays; np.vstack combines them into one array with one (traj_ind, frame_ind)
# pair per row.
indices = hmm.sample_by_observation_probabilities(10000)
samples = np.vstack(indices)
traj_inds = samples[:, 0]
frame_inds = samples[:, 1]

# Validate trajectory indices first, so they are safe to use for looking up
# trajectory lengths below. A negative index would otherwise silently wrap.
n_trajs = len(dtrajs)
bad_traj = (traj_inds < 0) | (traj_inds >= n_trajs)
if bad_traj.any():
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    raise AssertionError(
        "trajectory index out of bounds for %d trajectories (first offenders): %s"
        % (n_trajs, samples[bad_traj][:10].tolist())
    )

# Every frame index must lie inside the trajectory it refers to.
traj_lengths = np.array([len(dtraj) for dtraj in dtrajs])
bad_frame = (frame_inds < 0) | (frame_inds >= traj_lengths[traj_inds])
if bad_frame.any():
    raise AssertionError(
        "frame index out of bounds (first offenders): %s"
        % (samples[bad_frame][:10].tolist(),)
    )

print("all %d sampled (traj_ind, frame_ind) pairs are in bounds" % len(samples))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment