Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Created December 4, 2022 18:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ckrapu/ba25be2b592f4269943dfd8f740de41f to your computer and use it in GitHub Desktop.
Save ckrapu/ba25be2b592f4269943dfd8f740de41f to your computer and use it in GitHub Desktop.
# Second attempt at fixing the hierarchical model example.
import pymc as pm
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import xarray as xr
import aesara
# Simulate data for a hierarchical linear regression and fit it with PyMC.
#
# Every observation belongs to one of `group_num_all` groups; each group has
# its own intercept and its own slope vector (one slope per feature), and all
# of them are drawn from shared hyper-priors.

# Sizes of the synthetic train/test split.
trainSize, testSize = 1000, 400
totalSize = trainSize + testSize

# Number of hierarchical groups in the full data set.
group_num_all = 7
# NOTE(review): not used in this section -- presumably the number of groups
# meant to appear in the training split; confirm against the rest of the file.
group_num_train = 5
# Number of predictor columns (features) per observation.
field_num = 3

# Average number of observations per group (a float; not used below).
group_size = totalSize / group_num_all

# Randomly assign each observation a group index in 0, ..., group_num_all - 1,
# then split the assignments into train and test parts.
group_idx_val = np.random.choice(group_num_all, size=totalSize)
group_idx_train = group_idx_val[:trainSize]
group_idx_test = group_idx_val[trainSize:]

# Uniform(0, 1) predictors and responses. The draw order (x_train, x_test,
# y_train, y_test) is kept so a seeded NumPy stream yields the same arrays.
x_train = np.random.random((trainSize, field_num))
x_test = np.random.random((testSize, field_num))
y_train = np.random.random((trainSize, 1))
y_test = np.random.random((testSize, 1))

with pm.Model() as model:
    # Hyper-priors shared by all groups. The scale hyper-parameters must be
    # positive, so they use a Normal truncated at zero rather than a plain
    # Normal (which gives non-zero prior mass to invalid negative scales).
    slope_mu = pm.Normal("slope_mu", 0.5, 0.1)
    slope_sigma = pm.TruncatedNormal("slope_sigma", mu=0.5, sigma=0.1, lower=0)
    intercept_mu = pm.Normal("intercept_mu", 0.5, 0.1)
    intercept_sigma = pm.TruncatedNormal(
        "intercept_sigma", mu=0.5, sigma=0.1, lower=0
    )

    # Group-level parameters: one slope per (group, feature), one intercept
    # per group.
    slope = pm.Normal(
        "slope", slope_mu, slope_sigma, shape=(group_num_all, field_num)
    )
    intercept = pm.Normal(
        "intercept", intercept_mu, intercept_sigma, shape=group_num_all
    )

    # Mutable containers so the predictors / group indices can be swapped
    # later (e.g. with pm.set_data) without rebuilding the model.
    x = pm.MutableData("x", x_train)
    group_idx = pm.MutableData("group_idx", group_idx_train)

    # Per-observation linear predictor, shape (n_obs,): pick each row's group
    # parameters, then take the row-wise dot product of x with its slopes.
    mu = intercept[group_idx] + pm.math.sum(x * slope[group_idx], axis=1)
    sigma = pm.Exponential("sigma", 1.0)

    # `mu` is 1-D while `y_train` is a column of shape (n_obs, 1). Passing the
    # column directly would make the likelihood broadcast to (n_obs, n_obs)
    # (or fail a shape check), so the observations are flattened to (n_obs,).
    pm.Normal(
        "obs",
        mu=pm.Deterministic("y", mu),
        sigma=sigma,
        observed=y_train[:, 0],
    )

    # Deliberately tiny sample counts: enough to exercise the model, not
    # enough for real inference.
    idata = pm.sample_prior_predictive(random_seed=100, samples=10)
    trace = pm.sample(tune=10, draws=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment