Skip to content

Instantly share code, notes, and snippets.

@kpalin
Created November 4, 2022 15:12
Show Gist options
  • Save kpalin/5069b6de8eb0f9fb7efb2d5861ade53c to your computer and use it in GitHub Desktop.
Save kpalin/5069b6de8eb0f9fb7efb2d5861ade53c to your computer and use it in GitHub Desktop.
Test/demo model for nested priors in bambi
# %%
import pandas as pd
import numpy as np
import bambi as bmb
import arviz as az
import matplotlib.pyplot as plt
#%% [markdown]
#
# Generate synthetic data in which only one of the `D` predictors (`d1`) has a nonzero coefficient
#
#%%
# Reproducible synthetic regression data: N observations of D standard-normal
# predictors, of which only the positions listed in `relevant_dim` get a
# nonzero true coefficient; the response adds unit-variance Gaussian noise.
np.random.seed(0)
N, D = 50, 10
d = 1  # number of relevant predictors; must equal len(relevant_dim)
relevant_dim = [1]
predictor_names = [f"d{i}" for i in range(D)]
X = np.random.randn(N, D)
true_coef = pd.Series(np.zeros(D), index=predictor_names)
# `relevant_dim` holds integer *positions*, so index with `.iloc`. Plain `[]`
# on this string-labelled Series relies on pandas' deprecated positional
# fallback; once integer keys are treated as labels it would append a new
# label `1` instead of setting `d1`.
true_coef.iloc[relevant_dim] = np.random.randn(d) / np.sqrt(d)
formula = "response~1+" + "+".join(predictor_names)
true_response = X.dot(true_coef)
true_response_var = (true_coef ** 2).sum()
y = true_response + np.random.randn(N)
data = pd.DataFrame(X, columns=predictor_names)
data["response"] = y
#%% [markdown]
#
# The expected model, written directly in PyMC, with a Dirichlet prior on the mixture weights:
#
# %%
import pymc as pm

# Reference model written directly in PyMC. Each of the D coefficients is drawn
# from a two-component zero-mean normal mixture: a narrow component
# (sigma=0.1) and a wide one (sigma=4.5). A single Dirichlet-distributed weight
# vector is shared by all coefficients.
# NOTE: unlike the bambi formula ("response~1+..."), this model has no intercept.
# Every statement must be indented under `with` so the variables (and
# `pm.sample`) attach to this model's context.
with pm.Model():
    w_prior = pm.Dirichlet("w", a=[1, 1])
    beta = pm.NormalMixture("beta", w=w_prior, mu=[0, 0], sigma=[0.1, 4.5], shape=D)
    sigma = pm.HalfStudentT("sigma", nu=4, sigma=2.0322)
    pm.Normal("y", mu=pm.math.dot(X, beta), sigma=sigma, observed=y)
    tr = pm.sample()
# %%
az.plot_trace(tr)
plt.tight_layout()
s = az.summary(tr)
s["true_value"] = np.nan
# Attach the simulated coefficients to the "beta[i]" rows by label. The former
# positional `s["true_value"].iloc[:D] = ...` had two problems:
# - It was chained assignment, which under pandas copy-on-write writes to a
#   temporary and leaves `s` unchanged.
# - It assumed beta fills the first D rows, but the summary also holds the `w`
#   and `sigma` rows.
s.loc[[f"beta[{i}]" for i in range(D)], "true_value"] = true_coef.values
s
#%% [markdown]
#
# This bambi model works fine with a fixed ratio of small and large coefficients.
#
# %%
# Two-component zero-mean normal mixture prior with fixed (non-random) weights:
# `w` goes to the narrow sigma=0.1 component and `1 - w` to the sigma=4.5 one.
w = 0.5
subject_prior = bmb.Prior(
    "NormalMixture",
    w=[w, 1 - w],
    mu=[0, 0],
    sigma=[0.1, 4.5],
)
# Apply that prior to the model's common terms and display the model.
b_model = bmb.Model(formula, data)
b_model.set_priors(common=subject_prior)
b_model
#%%
b_model.build()
r = b_model.fit()  # presumably an arviz InferenceData; consumed by az.* below
#%%
# Per-coefficient reference markers in arviz's {var: [{"ref_val": value}]}
# form, so each posterior panel shows the simulated true coefficient.
rv = {
    name: [{"ref_val": value}]
    for name, value in true_coef.items()
}
az.plot_posterior(r, ref_val=rv)
plt.tight_layout()
az.summary(r)
#%% [markdown]
#
# The same model fails when the mixture weights are given a Dirichlet prior
#
# %%
# Nested prior: the NormalMixture weights are themselves a Dirichlet prior
# rather than fixed numbers.
# NOTE(review): the wide component uses sigma=3 here, while the PyMC reference
# model and the fixed-weight bambi model use 4.5; confirm this is intended.
w_prior = bmb.Prior("Dirichlet", a=np.array([1.0, 1.0]))
subject_prior = bmb.Prior("NormalMixture", w=w_prior, mu=[0, 0], sigma=[0.1, 3])
b_model = bmb.Model(formula, data)
b_model.set_priors(common=subject_prior)
b_model
#%%
# This cell demonstrates the failure: bambi is expected to error here because
# the mixture weights are a prior, at least in the bambi version this demo
# targeted.
b_model.build()
r = b_model.fit()
#%%
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment