Skip to content

Instantly share code, notes, and snippets.

@dslaw
Last active February 26, 2018 00:29
Show Gist options
  • Save dslaw/edae9837ade4733ca19861817c66bf6b to your computer and use it in GitHub Desktop.
Save dslaw/edae9837ade4733ca19861817c66bf6b to your computer and use it in GitHub Desktop.
Plot state sequence with generative distributions
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
def save_and_close(fig, filename):
    """Save ``fig`` to ``filename`` at 300 dpi, then close the figure.

    Returns ``None``.
    """
    fig.savefig(filename, dpi=300)
    plt.close(fig)
def remove_ticks(ax, xaxis=False, yaxis=False):
    """Hide tick labels and major tick lines on the selected axes of ``ax``.

    ``xaxis`` / ``yaxis`` choose which axis is affected; with both left
    at ``False`` nothing is changed. The artists are only made
    invisible, not removed. Returns ``ax``.
    """
    if xaxis:
        for artists in (ax.get_xticklabels(), ax.xaxis.get_majorticklines()):
            plt.setp(artists, visible=False)
    if yaxis:
        for artists in (ax.get_yticklabels(), ax.yaxis.get_majorticklines()):
            plt.setp(artists, visible=False)
    return ax
def share_axes(*axes, xaxis=False, yaxis=False):
    """Make every axes in ``axes`` share an axis with the first one.

    Parameters
    ----------
    *axes
        Axes objects to link. With fewer than two there is nothing to
        link and the call is a no-op.
    xaxis, yaxis : bool
        Which axis (or both) to share.

    Returns ``None``.

    ``Axes.sharex`` / ``Axes.sharey`` are used when available. Joining
    the grouper returned by ``get_shared_x_axes()`` directly was
    deprecated in matplotlib 3.6 and later removed, so that path is
    kept only as a fallback for releases that predate ``sharex``.
    NOTE(review): matplotlib appears to reject ``sharex`` on an axes
    that already shares with a different axes - confirm before calling
    this on pre-linked axes.
    """
    if len(axes) < 2:
        return

    anchor, others = axes[0], axes[1:]

    if xaxis:
        if hasattr(anchor, "sharex"):
            for other in others:
                other.sharex(anchor)
        else:
            # Older matplotlib: mutate the shared-axes grouper.
            anchor.get_shared_x_axes().join(*axes)

    if yaxis:
        if hasattr(anchor, "sharey"):
            for other in others:
                other.sharey(anchor)
        else:
            # Older matplotlib: mutate the shared-axes grouper.
            anchor.get_shared_y_axes().join(*axes)

    return
# Simulate from a three-state Hidden Markov Model with Gaussian emissions.
n_components = 3
n_draws = 200
initial_state = 0

# Row i is used as the probability vector for the state that follows
# state i.
transmat = np.array([
    [.75, .25, 0],
    [.25, .5, .25],
    [.05, .2, .75],
])

# Mean and standard deviation of each state's Gaussian emission.
means = np.array([25, 35, 60])
stds = np.array([1.41, 2.23, 3.9])

# Seeded generator so the simulated sequence is reproducible.
rs = np.random.RandomState(13)

# Builtin ``int``/``float`` are used as dtypes: the ``np.int`` and
# ``np.float`` aliases were removed in NumPy 1.24 and raise
# AttributeError there.
labels = np.full(n_draws, initial_state, dtype=int)
data = np.empty(n_draws, dtype=float)

# Draw the whole state path first, each state conditioned on the
# previous one.
for t in range(n_draws - 1):
    i = labels[t]
    labels[t + 1] = rs.choice(n_components, p=transmat[i, :])

# Then draw one observation per time step from that step's state.
# The draws stay in a loop (state path first, emissions second) so the
# sequence of calls on ``rs`` is unchanged.
for t in range(n_draws):
    k = labels[t]
    data[t] = rs.normal(means[k], stds[k])
# Settings shared by the figures below.
n_points = 250  # number of y-values at which each density is evaluated
line_color = (0.5, 0.5, 0.5, 0.25)  # color passed to the full-series line
line_kwds = dict(alpha=0.5, linewidth=1.2)
scatter_kwds = dict(s=5)
# Integer time index, one entry per observation.
time_index = np.arange(data.shape[0])
# Figure 1: the time-series with points labeled by state, next to the
# density of each state's distribution.

# Layout: a wide time-series panel and a narrow density panel that
# shares its y-axis.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# States that actually occur in the simulated sequence.
observed_states = np.unique(labels)

# Time-series panel. Each state is drawn as its own scatter collection
# so that matplotlib cycles through colors from the global palette.
ax_ts.plot(time_index, data, c=line_color)
for k in observed_states:
    mask = labels == k
    ax_ts.scatter(time_index[mask], data[mask], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Density panel, with the data values on the y-axis. The bounds are
# read after the time-series has been drawn.
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)

# Same state order as above: this is a fresh axis, so matplotlib walks
# the color palette in the same order.
for k in observed_states:
    ax_density.plot(norm.pdf(ypts, loc=means[k], scale=stds[k]), ypts)

# The density panel carries no ticks.
remove_ticks(ax_density, xaxis=True, yaxis=True)
save_and_close(fig, "state-sequence-1.png")
# It may be of interest to show each state as its own signal.

# Create figure and subplots to be drawn on: a wide time-series panel
# and a narrow density panel sharing its y-axis.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# Plot time-series: one line plus one scatter collection per state,
# each restricted to the time steps spent in that state.
for k in np.unique(labels):
    mask = labels == k
    ax_ts.plot(time_index[mask], data[mask], **line_kwds)
    ax_ts.scatter(time_index[mask], data[mask], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Plot densities with data on y-axis. The bounds are read once, after
# the time-series has been drawn (the original computed them twice).
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)
for k in np.unique(labels):
    densities = norm.pdf(ypts, loc=means[k], scale=stds[k])
    ax_density.plot(densities, ypts)

# Remove ticks from the density subplot.
remove_ticks(ax_density, xaxis=True, yaxis=True)
save_and_close(fig, "state-sequence-2.png")
# Figure 3: take the idea to its conclusion and give every state its
# own subplot.

# Colors are chosen explicitly from "Dark2"; the original author was
# unsure which palette the ggplot style uses.
cmap = plt.get_cmap("Dark2")

# Sort states by descending mean so the subplot order matches visual
# expectation (readable top-down / bottom-up).
states = np.argsort(means)[::-1]
K = states.size

# Layout: one time-series subplot per state in the left column, and a
# single subplot spanning the right column for all the distributions.
fig = plt.figure()
gs = plt.GridSpec(K, 2, width_ratios=[4, 1])
axes_ts = [fig.add_subplot(gs[row, 0]) for row in range(K)]
share_axes(*axes_ts, xaxis=True, yaxis=True)
ax_density = fig.add_subplot(gs[:, 1])

# Draw each state's time-series on its own subplot.
bottom_ax = axes_ts[-1]
for k, ax in zip(states, axes_ts):
    mask = labels == k
    xs, ys = time_index[mask], data[mask]
    ax.plot(xs, ys, c=cmap(k), **line_kwds)
    ax.scatter(xs, ys, c=cmap(labels[mask]), **scatter_kwds)
    # Only the bottom-most axis in the column keeps its x-ticks.
    if ax is bottom_ax:
        continue
    remove_ticks(ax, xaxis=True, yaxis=False)
bottom_ax.set_xlabel("Time")

# Evaluate every component's density over the full y-range covered by
# the time-series subplots.
ybounds = np.array([ax.get_ybound() for ax in axes_ts])
ymin = ybounds[:, 0].min()
ymax = ybounds[:, 1].max()
ypts = np.linspace(ymin, ymax, n_points)
for k in states:
    ax_density.plot(
        norm.pdf(ypts, loc=means[k], scale=stds[k]), ypts, c=cmap(k)
    )

# Tidy the density plot but keep its y-axis labels, since its ticks no
# longer line up with the time-series subplots.
remove_ticks(ax_density, xaxis=True, yaxis=False)
ax_density.yaxis.tick_right()
save_and_close(fig, "state-sequence-3.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment