Skip to content

Instantly share code, notes, and snippets.

@tcbegley
Created December 19, 2020 22:17
Show Gist options
  • Save tcbegley/48c409813d2c21e2708657c6ee08a32c to your computer and use it in GitHub Desktop.
Save tcbegley/48c409813d2c21e2708657c6ee08a32c to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import requests
# data source
URL = "http://www.stat.columbia.edu/~gelman/book/data/light.asc"
# simulation hyperparams
N_SIM = 200 # number of simulations
def load_data(url=None, timeout=30):
    """Download the ``light.asc`` dataset and return its measurements.

    Parameters
    ----------
    url : str, optional
        Address of the data file. ``None`` (the default) uses the module-level
        ``URL``.
    timeout : float, optional
        Seconds ``requests.get`` waits for the server. Without a timeout a
        stalled server would hang the script indefinitely.

    Returns
    -------
    numpy.ndarray
        1-d integer array of the values, in file order.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    ValueError
        If the body does not have the "header, blank line, integers" layout.
    """
    resp = requests.get(URL if url is None else url, timeout=timeout)
    # Refuse to parse an HTML error page as if it were data.
    resp.raise_for_status()
    return parse_light_data(resp.text)


def parse_light_data(text):
    """Parse text made of a header paragraph, a blank line, then integers.

    The header (everything before the first blank line) is discarded. The
    remaining text is split on any whitespace, so runs of spaces, tabs and
    Windows line endings are all accepted.

    Returns
    -------
    numpy.ndarray
        1-d integer array of the parsed values.

    Raises
    ------
    ValueError
        If no blank line separates header from data, or a data token is not
        an integer.
    """
    # Normalise Windows line endings so the blank-line separator is found.
    text = text.replace("\r\n", "\n").strip()
    _header, sep, data = text.partition("\n\n")
    if not sep:
        raise ValueError("expected a header paragraph followed by a blank line")
    # str.split() with no argument collapses any whitespace run, unlike
    # split(" "), which yields empty tokens that int() rejects.
    return np.array([int(tok) for tok in data.split()], dtype=int)
def order_stat(y, theta, lower=5, upper=60):
    """Order-statistic discrepancy T(y, theta) = |y_(u) - theta| - |y_(l) - theta|.

    With the defaults (0-based positions 5 and 60) this is the 6th and 61st
    smallest values of each dataset, the test quantity plotted in Figure 6.4.

    Parameters
    ----------
    y : array_like
        One dataset per row; each row is sorted independently. A 1-d input is
        treated as a single row. A single row broadcasts against a vector
        ``theta``, giving one value per theta.
    theta : float or array_like
        Location parameter(s); must broadcast against the number of rows.
    lower, upper : int, optional
        0-based positions in each sorted row of the two order statistics.
        Defaults reproduce the original hard-coded 5 and 60.

    Returns
    -------
    numpy.ndarray
        Discrepancy values, shape ``broadcast(n_rows, theta.shape)``.
    """
    y_sorted = np.sort(np.atleast_2d(y), axis=1)
    return np.abs(y_sorted[:, upper] - theta) - np.abs(y_sorted[:, lower] - theta)
def simulate_posterior_predictive(y, n_sim, rng):
    """Draw posterior samples of (mu, sigma^2) and one replicated dataset each.

    The draws follow
        sigma^2 | y          ~ (n - 1) s^2 / chi^2_{n-1}   (scaled inverse-chi^2)
        mu | sigma^2, y      ~ N(mean(y), sigma^2 / n)
        y_rep | mu, sigma^2  ~ n i.i.d. draws from N(mu, sigma^2)
    i.e. the normal-model posterior under the prior p(mu, sigma^2) ∝ 1/sigma^2.

    Parameters
    ----------
    y : numpy.ndarray
        1-d array of observations.
    n_sim : int
        Number of posterior draws.
    rng : numpy.random.Generator
        Source of randomness. Draws are taken in the order chi-square, mu,
        y_rep, so a seeded generator gives reproducible output.

    Returns
    -------
    mu : numpy.ndarray, shape (n_sim,)
    sigma2 : numpy.ndarray, shape (n_sim,)
    y_rep : numpy.ndarray, shape (n_sim, y.size)
    """
    n_obs = y.size
    s2 = y.var(ddof=1)  # unbiased sample variance
    sigma2 = (n_obs - 1) * s2 / rng.chisquare(n_obs - 1, size=n_sim)
    mu = rng.normal(y.mean(), np.sqrt(sigma2 / n_obs))
    y_rep = rng.normal(
        mu[:, None], np.sqrt(sigma2[:, None]), size=(n_sim, n_obs)
    )
    return mu, sigma2, y_rep


def main():
    """Load the data, simulate, and replicate Gelman Figures 6.2-6.4."""
    rng = np.random.default_rng()
    y = load_data()
    s2 = y.var(ddof=1)  # observed sample variance
    mu, _, y_post = simulate_posterior_predictive(y, N_SIM, rng)

    n_show = 20  # replicated datasets drawn in the 4 x 5 grid of Figure 6.2

    # replicate Figure 6.2: histograms of the first n_show replicated datasets
    f, axes = plt.subplots(figsize=(10, 8), nrows=4, ncols=5)
    for ax, y_rep in zip(axes.flat, y_post[:n_show]):
        ax.hist(y_rep, bins=10)
    f.suptitle("Replication of Gelman Figure 6.2")
    plt.show(block=True)

    # replicate Figure 6.3: minima of the same n_show datasets vs observed min
    f, ax = plt.subplots(figsize=(7, 5))
    ax.hist(y_post[:n_show].min(axis=1))
    ax.axvline(y.min(), color="k")
    f.suptitle("Replication of Gelman Figure 6.3")
    plt.show(block=True)

    # replicate Figure 6.4
    f, (ax_var, ax_t) = plt.subplots(figsize=(10, 5), ncols=2)
    # (a) Apply the SAME statistic (ddof=1 sample variance) to replicated and
    # observed data; np.var defaults to ddof=0, which would bias the check.
    ax_var.hist(y_post.var(axis=1, ddof=1))
    ax_var.axvline(s2, color="k")
    # (b) T(y, theta) depends on theta, so it is evaluated once per draw.
    # y[np.newaxis, :] is a single row that broadcasts against mu, giving one
    # value per draw without hard-coding the number of simulations.
    t = order_stat(y[np.newaxis, :], mu)
    t_rep = order_stat(y_post, mu)
    ax_t.scatter(t, t_rep, alpha=0.5)
    ax_t.set_xlim((-15, 15))
    ax_t.set_ylim((-15, 15))
    ax_t.plot([-15, 15], [-15, 15], color="k")
    # Posterior predictive p-value Pr(T(y_rep, theta) >= T(y, theta) | y):
    # the fraction of points on or above the diagonal.
    ax_t.set_title(f"p-value: {(t_rep >= t).mean():.2f}")
    f.suptitle("Replication of Gelman Figure 6.4")
    plt.show(block=True)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment