Skip to content

Instantly share code, notes, and snippets.

@dslaw
Last active February 26, 2018 00:26
Show Gist options
  • Save dslaw/5f6e0ce1cedc470af27e709fea5d6e3f to your computer and use it in GitHub Desktop.
Save dslaw/5f6e0ce1cedc470af27e709fea5d6e3f to your computer and use it in GitHub Desktop.
Plot mixture of univariate distributions
from scipy.stats import norm
import matplotlib.pyplot as plt
import numpy as np
# Switch matplotlib to its bundled "ggplot" style sheet for the figure drawn below.
plt.style.use("ggplot")
# Simulate from a Univariate Gaussian Mixture Model.
#
# Sampling is two-stage: first choose a component for every draw according to
# the mixing weights, then draw from that component's normal distribution.
rs = np.random.RandomState(13)  # fixed seed so the output is reproducible

# Per-component parameters; position k describes the same component in each.
weights = np.array([.5, .2, .3])  # mixing proportions passed to rs.choice as p
means = np.array([25, 10.8, 32.1])
stds = np.array([1.75, 1, 2.2])

# Derived from the parameters so it cannot drift out of sync with them.
n_components = weights.size
n_draws = 1000

# 0-based component index assigned to each of the n_draws samples.
labels = rs.choice(n_components, size=n_draws, replace=True, p=weights)

# Indexing the parameter arrays by `labels` gives every draw its own
# loc/scale, so a single vectorised call replaces a Python loop of
# n_draws scalar calls.
data = rs.normal(loc=means[labels], scale=stds[labels])
# Plot each component's weighted density against the histogram.
fig, ax = plt.subplots()
# density=True normalises the bars so the histogram of the whole mixture
# integrates to 1.
ax.hist(data, bins=100, density=True, edgecolor="k")

xpts = np.linspace(data.min(), data.max(), 250)
for k, (weight, mean, std) in enumerate(zip(weights, means, stds), start=1):
    # A component contributes only `weight` of the mixture's total mass, so
    # its pdf must be scaled by that weight to sit on the same vertical scale
    # as the normalised histogram; the raw pdf would overshoot by 1 / weight.
    densities = weight * norm.pdf(xpts, loc=mean, scale=std)
    ax.plot(xpts, densities, label=f"k={k}")

ax.legend()
fig.savefig("univariate-mixture.png")
# Close the figure explicitly now that it has been written to disk.
plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment