Skip to content

Instantly share code, notes, and snippets.

@davipatti
Created November 21, 2022 20:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save davipatti/734bed25d07eaa1de797f5e0a4a95d02 to your computer and use it in GitHub Desktop.
Save davipatti/734bed25d07eaa1de797f5e0a4a95d02 to your computer and use it in GitHub Desktop.
[Model index not mixing] PyMC example code illustrating a model-index parameter that fails to mix #pymc #bayesian
#!/usr/bin/env python
import pymc as pm
import numpy as np
import arviz as az
import matplotlib.pyplot as plt
# Record which library versions produced the figures written below.
for lib in (pm, np, az):
    print(f"{lib.__name__} {lib.__version__}")
def sigmoid(x, a, b, c, d):
    """Evaluate the four-parameter logistic curve ``c + d / (1 + exp(-b * (x - a)))``.

    At ``x == a`` the curve is at ``c + d / 2`` (its midpoint); ``b`` sets the
    steepness, ``c`` the lower level and ``d`` the rise above it, so for
    ``b > 0`` the value tends to ``c + d`` as ``x`` grows.

    All operations are element-wise, so NumPy arrays broadcast against each
    other. This script also passes PyMC random variables for ``a``, ``b`` and
    ``d`` when building the model's mean.
    """
    exponent = -b * (x - a)
    denominator = 1 + np.exp(exponent)
    return c + d / denominator
# --- Simulated data ----------------------------------------------------------
n_ind = 2  # number of individuals measured
n_reps = 10  # each individual is measured in 10 independent repeats
n_timepoints = 12  # each repeat is measured at 12 time points (t = 0..11)
t = np.arange(n_timepoints)

# Legacy global seeding, kept (rather than np.random.default_rng) so the
# simulated data stay identical to earlier runs of this script.
np.random.seed(42)


def _simulate_individual(a_low, a_high):
    """Return one individual's noise-free curves, shape (n_timepoints, n_reps).

    Each repeat gets its own midpoint a ~ U(a_low, a_high), steepness
    b ~ U(1, 2) and height d ~ U(3, 4); the lower level c is fixed at 0.
    The draws happen in the order a, b, d, exactly as the original inline
    keyword arguments did, so the global random stream is consumed identically.
    """
    return sigmoid(
        x=t[:, np.newaxis],  # column vector: broadcasts against the n_reps draws
        a=np.random.uniform(a_low, a_high, n_reps),
        b=np.random.uniform(1, 2, n_reps),
        c=0,
        d=np.random.uniform(3, 4, n_reps),
    )


# Responses for individual 0 (early midpoint) and individual 1 (late midpoint).
y0 = _simulate_individual(3, 4)
y1 = _simulate_individual(6, 7)

plt.plot(t, y0, c="blue")
plt.plot(t, y1, c="red")

# Flatten to long format: one observation per (individual, repeat, time).
# y0 / y1 are (n_timepoints, n_reps); transposing before ravel makes each
# repeat's time series contiguous, which matches the x / i layout below.
y = np.hstack([y0.T.ravel(), y1.T.ravel()])
x = np.tile(t, n_reps * n_ind)  # time point of each observation
i = np.repeat(np.arange(n_ind), n_reps * n_timepoints)  # individual of each observation

plt.scatter(x, y, c=i, cmap="bwr")
plt.savefig("example-data.png")
# Fit both candidate models inside one PyMC model. A Bernoulli index m decides,
# for the whole dataset at once, whether the mean comes from one growth curve
# shared by both individuals (model 1) or from a separate curve per
# individual (model 2).
with pm.Model():
    # Model 1: a single growth curve shared by both individuals
    a_1 = pm.Normal("a_1", 1, 1)  # midpoint
    b_1 = pm.Normal("b_1", 1, 1)  # steepness
    # NOTE(review): prior centred at 1 while the simulated heights are drawn
    # from U(3, 4) -- confirm this prior is intended.
    d_1 = pm.Normal("d_1", 1, 0.5)  # height (lower level c fixed at 0)
    mu_1 = sigmoid(x, a_1, b_1, 0, d_1)
    # Model 2: independent growth curves, one per individual.
    # shape=2 gives one parameter per individual; indexing with i copies each
    # individual's value onto all of that individual's observations.
    a_2 = pm.Normal("a_2", 1, 1, shape=2)
    b_2 = pm.Normal("b_2", 1, 1, shape=2)
    d_2 = pm.Normal("d_2", 1, 0.5, shape=2)
    mu_2 = sigmoid(x, a_2[i], b_2[i], 0, d_2[i])
    # Choose between mu_1 and mu_2
    p_mu = pm.Beta("p_mu", 1, 1)  # probability of model 2; Beta(1, 1) is uniform
    m = pm.Bernoulli("m", p_mu)  # model index parameter: 0 -> mu_1, 1 -> mu_2
    # Likelihood
    sigma = pm.Exponential("sigma", 1)
    # With m in {0, 1} this mean is exactly mu_1 or exactly mu_2, so the
    # parameters of the unselected model receive no information from y.
    # NOTE(review): that is likely why a proposed switch of m is rarely
    # accepted -- the non-mixing this example demonstrates. A common remedy is
    # to marginalise m out (log-sum-exp of the two likelihoods weighted by
    # 1 - p_mu and p_mu); confirm before changing the model.
    pm.Normal("lik", mu_1 * (1 - m) + mu_2 * m, sigma, observed=y)
    # Draw posterior samples; the seed makes the run reproducible.
    trace = pm.sample(random_seed=42)
# Trace plot of every sampled variable. plt.savefig writes the current pyplot
# figure -- presumably the one az.plot_trace just created; verify if the
# plotting backend changes.
az.plot_trace(trace)
plt.savefig("example-trace.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment