Skip to content

Instantly share code, notes, and snippets.

@netsatsawat
Created September 21, 2020 14:38
Show Gist options
  • Save netsatsawat/9f6a3a14b80f8425f1012e183c1b0798 to your computer and use it in GitHub Desktop.
Save netsatsawat/9f6a3a14b80f8425f1012e183c1b0798 to your computer and use it in GitHub Desktop.
Python script that fits a GaussianHMM from hmmlearn to model hidden states
# Observation matrix for hmmlearn. np.column_stack always yields a 2-D array
# (a 1-D input would become a single column).
_df = np.column_stack([train_df[FT_COLS].values])

# Gaussian HMM with 3 hidden states and a full covariance matrix per state;
# the seed makes the fit reproducible.
hmm_model = GaussianHMM(
    n_components=3,
    covariance_type="full",
    n_iter=1000,
    random_state=SEED,
)
hmm_model.fit(_df)

# Decode a hidden-state label for every observation in the training data.
hidden_states = hmm_model.predict(_df)

# Report the fitted emission parameters of each state.
print("Means and vars of each hidden state")
per_state_params = zip(hmm_model.means_, hmm_model.covars_)
for state, (state_mean, state_cov) in enumerate(per_state_params):
    print(f'{state}th hidden state')
    print('mean: ', state_mean)
    # Per-feature variances are the diagonal of the full covariance matrix.
    print('var: ', np.diag(state_cov))
    print()
# Seaborn theme for the per-state plots: larger fonts, white background,
# short major ticks and framed legends.
sns.set(font_scale=1.25)
style_kwds = {'xtick.major.size': 3, 'ytick.major.size': 3, 'legend.frameon': True}
sns.set_style('white', style_kwds)

# One stacked panel per hidden state. Both axes are shared so the panels are
# directly comparable.
fig, axs = plt.subplots(hmm_model.n_components, sharex=True, sharey=True, figsize=(12, 9))
# One distinct colour per state. This must use `hmm_model`; the previous
# reference to an undefined `model` raised NameError.
colors = cm.rainbow(np.linspace(0, 1, hmm_model.n_components))
for i, (ax, color) in enumerate(zip(axs, colors)):
    # Boolean mask selecting the observations decoded into state i.
    mask = hidden_states == i
    # ax.plot accepts datetime x-values directly; Axes.plot_date is
    # discouraged/deprecated in recent Matplotlib releases.
    ax.plot(train_df.index.values[mask],
            train_df[COL_].values[mask],
            ".-", c=color)
    ax.set_title(f"{i}th hidden state", fontsize=16, fontweight='demi')
    # Yearly major ticks and monthly minor ticks on the date axis.
    ax.xaxis.set_major_locator(YearLocator())
    ax.xaxis.set_minor_locator(MonthLocator())
# Offset spines on every panel of the figure (one call covers all axes).
sns.despine(fig=fig, offset=10)
plt.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment