
@netsatsawat
Created September 21, 2020 12:44
Python script demonstrating how to fit and inspect a Gaussian mixture model with scikit-learn's GaussianMixture
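For reference, this is the model the script fits, in standard notation. K corresponds to N_COMPONENTS. In scikit-learn, the mixing weights, means and covariances are exposed as weights_, means_ and covariances_. The "hidden state" printed and plotted by the script is the hard assignment made by predict(): each row is assigned to the component with the largest posterior responsibility. The variances the script prints are the diagonal entries of each full covariance matrix.

```latex
p(\mathbf{x}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k),
\qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1

\gamma_k(\mathbf{x}) = \frac{\pi_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}
                            {\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)},
\qquad \hat{z}(\mathbf{x}) = \arg\max_{k} \, \gamma_k(\mathbf{x})
```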
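import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import sklearn.mixture as mix
from matplotlib.dates import YearLocator, MonthLocator

# Not defined in this gist; they must be set before running:
#   data_df       pandas DataFrame indexed by date (DatetimeIndex)
#   FT_COLS       list of feature columns used to fit the model
#   COL_          single column plotted for each hidden state
#   N_COMPONENTS  number of mixture components ("hidden states")
#   SEED          random seed for reproducibility

# Chronological split: rows before 2019-01-01 train the model, the rest are
# held out. Boolean masks avoid the inclusive .loc slice putting
# 2019-01-01 into both sets.
train_df = data_df[data_df.index < '2019-01-01'].dropna()
test_df = data_df[data_df.index >= '2019-01-01'].dropna()
X_train = train_df[FT_COLS].values
X_test = test_df[FT_COLS].values

# Full-covariance GMM; n_init=100 runs EM from 100 initialisations and keeps the best fit.
model = mix.GaussianMixture(n_components=N_COMPONENTS,
                            covariance_type="full",
                            n_init=100,
                            random_state=SEED).fit(X_train)

# Assign each training row to its most probable component. Unlike an HMM,
# a GMM treats rows independently, so this is a per-row label rather than
# a decoded state sequence.
hidden_states = model.predict(X_train)

print("Means and variances of each hidden state")
for i in range(model.n_components):
    print(f'{i}th hidden state')
    print('mean: ', model.means_[i])
    # With covariance_type="full", covariances_[i] is a matrix;
    # its diagonal holds the per-feature variances.
    print('var: ', np.diag(model.covariances_[i]))
    print()

sns.set(font_scale=1.25)
style_kwds = {'xtick.major.size': 3, 'ytick.major.size': 3, 'legend.frameon': True}
sns.set_style('white', style_kwds)

# One panel per hidden state; squeeze=False keeps axs 2-D even when N_COMPONENTS == 1.
fig, axs = plt.subplots(model.n_components, sharex=True, sharey=True,
                        figsize=(12, 9), squeeze=False)
colors = cm.rainbow(np.linspace(0, 1, model.n_components))
for i, (ax, color) in enumerate(zip(axs[:, 0], colors)):
    # Boolean mask selects the rows assigned to state i.
    mask = hidden_states == i
    # ax.plot accepts datetime x-values directly (plot_date is discouraged in recent Matplotlib).
    ax.plot(train_df.index.values[mask], train_df[COL_].values[mask], ".-", c=color)
    ax.set_title(f"{i}th hidden state", fontsize=16, fontweight='demi')
    # Major ticks every year, minor ticks every month.
    ax.xaxis.set_major_locator(YearLocator())
    ax.xaxis.set_minor_locator(MonthLocator())
sns.despine(offset=10)
plt.tight_layout()
plt.show()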
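The script leaves data_df, FT_COLS, COL_, N_COMPONENTS and SEED undefined. The sketch below fills them with hypothetical values and a synthetic two-regime dataset so that the split, fit and state assignment can be run end to end. None of these values come from the original script. The sketch also labels the held-out X_test rows, which the script builds but never uses, and it skips plotting.

```python
import numpy as np
import pandas as pd
import sklearn.mixture as mix

# Hypothetical stand-ins for the names the gist leaves undefined.
SEED = 42
N_COMPONENTS = 2
FT_COLS = ["ret", "vol"]  # hypothetical feature columns
COL_ = "ret"              # hypothetical column that would be plotted
SPLIT = pd.Timestamp("2019-01-01")

# Synthetic daily data (illustration only): a calm regime and a volatile
# regime that alternate every 60 days.
rng = np.random.default_rng(SEED)
dates = pd.date_range("2017-01-01", "2020-06-30", freq="D")
n = len(dates)
regime = (np.arange(n) // 60) % 2
ret = np.where(regime == 0, rng.normal(0.001, 0.005, n), rng.normal(-0.002, 0.02, n))
vol = np.where(regime == 0, rng.normal(0.01, 0.001, n), rng.normal(0.04, 0.003, n))
data_df = pd.DataFrame({"ret": ret, "vol": vol}, index=dates)
data_df.iloc[::50] = np.nan  # a few missing rows, removed by dropna() below

# Chronological, non-overlapping split.
train_df = data_df[data_df.index < SPLIT].dropna()
test_df = data_df[data_df.index >= SPLIT].dropna()
X_train = train_df[FT_COLS].values
X_test = test_df[FT_COLS].values

# Fewer restarts than the script's n_init=100, only to keep the sketch fast.
model = mix.GaussianMixture(n_components=N_COMPONENTS, covariance_type="full",
                            n_init=10, random_state=SEED).fit(X_train)

train_states = model.predict(X_train)
test_states = model.predict(X_test)  # held-out rows labelled by the trained model

for i in range(model.n_components):
    print(f"state {i}: weight={model.weights_[i]:.2f}, "
          f"mean={model.means_[i]}, var={np.diag(model.covariances_[i])}")
print("test-period rows per state:", np.bincount(test_states, minlength=N_COMPONENTS))
```

If a criterion for choosing N_COMPONENTS is wanted, scikit-learn's GaussianMixture also provides bic(X) and aic(X) for comparing fits with different component counts.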
@rhettxio

So many variables are not defined: FT_COLS, N_COMPONENTS, COL_.
