Skip to content

Instantly share code, notes, and snippets.

@chvandorp
Last active November 25, 2023 15:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chvandorp/d02082731fc24cec78ea3107ecc14c03 to your computer and use it in GitHub Desktop.
Save chvandorp/d02082731fc24cec78ea3107ecc14c03 to your computer and use it in GitHub Desktop.
Python script to plot Dirichlet-multinomial samples on a 2-D simplex
import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sts
def simplex_transform(X):
    """
    Map three-category count vectors to 2-D points inside a triangle.

    Each row of ``X`` is normalized to proportions by dividing it by its own
    total. An affine map then sends the pure compositions (1, 0, 0),
    (0, 1, 0) and (0, 0, 1) to the corners of the equilateral triangle that
    ``plot_trian`` draws: (-sqrt(1/2), -0.5), (sqrt(1/2), -0.5) and
    (0, sqrt(3/2) - 0.5).

    This function only computes coordinates; it does not plot anything.

    Parameters
    ----------
    X : array_like, shape (n, 3)
        One count (or weight) vector per row, e.g. Dirichlet-multinomial
        draws. Every row must have a positive total, otherwise the
        normalization divides by zero. Rows may have different totals.

    Returns
    -------
    numpy.ndarray, shape (n, 2)
        Planar coordinates, one row per input row.
    """
    X = np.asarray(X, dtype=float)
    # Normalize each row by its own total. Dividing every row by the first
    # row's total is correct only when all rows share the same N.
    p = X / X.sum(axis=1, keepdims=True)
    # The third proportion is implied (1 - p0 - p1), so two coordinates suffice.
    p = p[:, :2]
    theta = 3 * np.pi / 4
    v, u = np.sin(theta), np.cos(theta)
    # Rotation about (0.5, 0.5) that makes the edge between the first two
    # simplex corners horizontal.
    R = np.array([[u, -v], [v, u]])
    # Stretch the vertical axis by sqrt(3) so the triangle becomes equilateral.
    S = np.array([[1, 0], [0, np.sqrt(3)]])
    # Centre on (0.5, 0.5), rotate and stretch, then shift down so the base
    # of the triangle sits at y = -0.5.
    return (p - np.array([0.5, 0.5])) @ R @ S - np.array([0, 0.5])
def plot_trian(ax, **kwargs):
    """
    Draw the closed outline of the equilateral triangle used as the simplex frame.

    The corners are (-sqrt(1/2), -0.5), (sqrt(1/2), -0.5) and
    (0, sqrt(3/2) - 0.5). These are the points onto which
    ``simplex_transform`` maps the pure compositions. Extra keyword
    arguments are forwarded unchanged to ``ax.plot``.
    """
    half_base = np.sqrt(1/2)
    apex_y = np.sqrt(3/2) - 0.5
    corners = [(-half_base, -0.5), (half_base, -0.5), (0, apex_y)]
    # Repeat the first corner so the polyline closes back on itself.
    loop = corners + corners[:1]
    xs, ys = (list(axis_values) for axis_values in zip(*loop))
    ax.plot(xs, ys, **kwargs)
# Compile the Stan program that generates the Dirichlet-multinomial draws.
sm = cmdstanpy.CmdStanModel(stan_file="dirmult_rng.stan")

# Concentration vectors (K = 3 categories each), one per figure panel.
alphas = [
    np.array([2.0, 4.5, 7.0]),
    np.array([40.0, 60.0, 15.0]),
    np.array([0.25, 0.25, 0.25]),
    np.array([0.75, 5.0, 10.0]),
]


def _draw_counts(concentration):
    """Run the fixed-parameter Stan model for one alpha and return its ``X`` draws.

    NOTE(review): X is presumably an (iter_sampling, K) array of counts whose
    rows each sum to N; confirm against dirmult_rng.stan.
    """
    fit = sm.sample(
        data={"K" : 3, "N" : 1000, "alpha" : concentration},
        iter_sampling=10000,
        fixed_param=True,
        show_progress=False,
    )
    return fit.stan_variable("X")


Xs = [_draw_counts(alpha) for alpha in alphas]
# One panel per alpha: the draws are projected onto the triangle and each
# point is coloured by a KDE density estimate.
fig, axs = plt.subplots(2, 2, figsize=(7,7))
panels = axs.flatten()
for k, X in enumerate(Xs):
    coords = simplex_transform(X)
    ax = panels[k]
    ax.axis('equal')
    # Fit the KDE on the first 5000 projected points only (cheaper), then
    # evaluate it at every point to get the colour values.
    kde = sts.gaussian_kde(coords[:5000].T)
    density = kde(coords.T)
    xs, ys = coords[:, 0], coords[:, 1]
    ax.scatter(xs, ys, s=0.5, c=density, linewidths=0)
    ax.axis('off')
    plot_trian(ax, color='k')
    a1, a2, a3 = alphas[k]
    ax.set_title(f"$\\alpha = ({a1}, {a2}, {a3})'$")
    # Mean counts divided by the first draw's total.
    row_total = np.sum(X[0,:])
    print("scaled means:", np.mean(X, axis=0) / row_total)
fig.tight_layout()
fig.savefig("dirmult_simplices.png", dpi=300, bbox_inches="tight")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment