Created March 22, 2018 12:03
-
-
Save joshdorrington/2913269ec89932f9b68b5fe4c91cdf5c to your computer and use it in GitHub Desktop.
A side-by-side comparison of hmmlearn and pomegranate HMMs on a chaotic dataset.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# run in python 2.7
# Fit a K-state full-covariance Gaussian HMM with hmmlearn, then overlay the
# fitted state means on a 2-D projection (features 0 and 3) of the data.
import numpy as np
import matplotlib.pyplot as plt
from hmmlearn.hmm import GaussianHMM

# IMPORT DATA
N_SAMPLES, N_FEATURES = 200001, 6
# np.fromfile without `sep` reads raw binary float64, despite the ".txt" name.
# NOTE(review): presumably the file was written with ndarray.tofile -- confirm.
data = np.fromfile("np_input_file.txt").reshape([N_SAMPLES, N_FEATURES])

# MODEL PARAMS
K = 3
iter_num = 3000
convergence_tolerance = 0.001
stat_vec_tol = 0.999  # not used below; kept from the original parameter set

# FIT MODEL
# A 2-D (n_samples, n_features) array is fitted by hmmlearn as one sequence.
model = GaussianHMM(n_components=K, covariance_type="full",
                    n_iter=iter_num, tol=convergence_tolerance).fit(data)
hidden_states = model.predict(data)  # encodes a time series of hidden states
means = model.means_  # fitted state means, one row per hidden state

# PLOT MODEL
plt.plot(data.T[0], data.T[3], c='k', lw=0.005)  # the data
# hmmlearn exposes the fitted means as model.means_; the original
# `model.states[i].distribution.mu` is pomegranate API and fails on a
# GaussianHMM. One scatter call gives a single legend entry for all K means.
plt.scatter(means[:, 0], means[:, 3], c='r',
            label="HMM clusters upon convergence")
plt.legend()
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# run in python 3.6
# Fit a K-state pomegranate HMM whose multivariate Gaussian states are
# initialised at k-means centres, then plot the fitted state means (red)
# against the k-means initialisation (green) on features 0 and 3.
import numpy as np
import matplotlib.pyplot as plt
from pomegranate import *
from sklearn.cluster import KMeans

# IMPORT DATA AND RUN KMEANS
K = 3
N_SAMPLES, N_FEATURES = 200001, 6
# np.fromfile without `sep` reads raw binary float64, despite the ".txt" name.
# NOTE(review): presumably the file was written with ndarray.tofile -- confirm.
data = np.fromfile("np_input_file.txt").reshape([N_SAMPLES, N_FEATURES])
cov = np.eye(N_FEATURES)  # identity covariance as every state's starting point
kmodel = KMeans(n_clusters=K).fit(data)
# Snapshot the centres so the "initial clusters" plotted below are guaranteed
# to be the pre-fit values, whatever pomegranate does with the arrays it is given.
mu = kmodel.cluster_centers_.copy()

# BUILD MODEL
# One multivariate Gaussian state per k-means centre. Start and transition
# probabilities are uniform (the original spelled this as 0.33/0.33/0.34).
states = [State(MultivariateGaussianDistribution(mu[k], cov)) for k in range(K)]
model = HiddenMarkovModel()
model.add_states(*states)
for dst in states:
    model.add_transition(model.start, dst, 1.0 / K)
for src in states:
    for dst in states:
        model.add_transition(src, dst, 1.0 / K)
model.bake()

# FIT MODEL
# pomegranate's fit takes a collection of sequences, each (seq_len, n_features)
# for multivariate emissions. The original reshaped to 2-D just before fitting,
# which makes every 6-value row a separate sequence of scalar observations.
# Pass the whole series as ONE sequence instead.
model.fit(data.reshape([1, N_SAMPLES, N_FEATURES]), verbose=True)

# PLOT MODEL
plt.plot(data.T[0], data.T[3], c='k', lw=0.005)  # the data
# Assumes, as the original did, that the first K entries of model.states
# after bake() are the emitting states (silent start/end come later).
for k in range(K):
    first = (k == 0)  # label only once so the legend has one entry per series
    fitted_mu = model.states[k].distribution.mu
    plt.scatter(fitted_mu[0], fitted_mu[3], c='r',
                label="HMM clusters upon convergence" if first else "_nolegend_")
    plt.scatter(mu[k][0], mu[k][3], c='g', zorder=30,
                label="k_means initial clusters" if first else "_nolegend_")
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.