@ckrapu
Created April 1, 2021 04:24
dirichlet-concat-divergence.py
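import numpy as np
import pymc3 as pm
import theano.tensor as tt

## Combined model
# Observed counts over 3 categories: 10 observations from the first group, then 11 from the second.
c_comb = np.asarray([[16, 29, 4],
                     [16, 29, 6],
                     [14, 30, 4],
                     [16, 29, 3],
                     [16, 31, 5],
                     [13, 29, 5],
                     [15, 32, 5],
                     [15, 29, 6],
                     [17, 31, 6],
                     [13, 29, 4],
                     [15, 4, 31],
                     [15, 3, 30],
                     [15, 5, 31],
                     [16, 6, 31],
                     [14, 5, 29],
                     [16, 5, 30],
                     [15, 6, 30],
                     [13, 6, 30],
                     [14, 6, 31],
                     [16, 6, 29],
                     [15, 7, 29]])
# Multinomial totals (the row sums of c_comb), kept as a column so they broadcast against p.
counts_comb = np.asarray([49, 51, 48, 48, 52, 47, 52, 50, 54, 46, 50,
                          48, 51, 53, 48, 51, 51, 49, 51, 51, 51])[:, None]
# Group label of each observation (0 = first group, 1 = second group).
idx = np.asarray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]).astype(int)
# N1 and N2 are not defined in the original snippet; derive them from the group labels.
N1 = int((idx == 0).sum())  # 10
N2 = int((idx == 1).sum())  # 11

with pm.Model() as model_comb:
    # Group-level Dirichlet concentrations with HalfNormal(sigma=10) hyperpriors.
    hyper_param_comb1 = pm.HalfNormal('hyper_param_comb1', sigma=10, shape=3)
    hyper_param_comb2 = pm.HalfNormal('hyper_param_comb2', sigma=10, shape=3)
    # One probability vector per observation, drawn from its group's Dirichlet.
    param_comb1 = pm.Dirichlet('param_comb1', a=hyper_param_comb1, shape=(N1, 3))
    param_comb2 = pm.Dirichlet('param_comb2', a=hyper_param_comb2, shape=(N2, 3))
    # Stack into an (N1 + N2, 3) matrix whose rows line up with the rows of c_comb.
    param_comb = tt.concatenate((param_comb1, param_comb2), axis=0)
    print(param_comb.tag.test_value.shape)
    # param_comb already has one row per observation, so it is passed directly.
    # Indexing it with the 0/1 group labels (param_comb[idx]) would reuse only rows 0 and 1,
    # leaving param_comb2 unconstrained by the data.
    y_comb = pm.Multinomial('y_comb', n=counts_comb, p=param_comb, observed=c_comb)
    trace_comb = pm.sample(target_accept=0.90)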
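A quick NumPy sanity check of the shapes involved may help when debugging this model. It is a sketch only: it copies the data above and uses `numpy.random.default_rng(...).dirichlet` draws with hypothetical all-ones concentrations as stand-ins for the model's Dirichlet variables (the names `p_obs` and `p_group` are illustrative, not from the original). It confirms that each row of `c_comb` sums to its entry in `counts_comb`, as the Multinomial requires, and shows that indexing a per-observation `(21, 3)` array with the 0/1 labels in `idx` only ever selects rows 0 and 1, whereas a per-group `(2, 3)` array indexed the same way gives each observation its group's probabilities.

```python
import numpy as np

# Data copied from the model above: 21 observations x 3 categories.
c_comb = np.asarray([[16, 29, 4], [16, 29, 6], [14, 30, 4], [16, 29, 3], [16, 31, 5],
                     [13, 29, 5], [15, 32, 5], [15, 29, 6], [17, 31, 6], [13, 29, 4],
                     [15, 4, 31], [15, 3, 30], [15, 5, 31], [16, 6, 31], [14, 5, 29],
                     [16, 5, 30], [15, 6, 30], [13, 6, 30], [14, 6, 31], [16, 6, 29],
                     [15, 7, 29]])
counts_comb = np.asarray([49, 51, 48, 48, 52, 47, 52, 50, 54, 46,
                          50, 48, 51, 53, 48, 51, 51, 49, 51, 51, 51])[:, None]
idx = np.asarray([0] * 10 + [1] * 11)
N1, N2 = int((idx == 0).sum()), int((idx == 1).sum())

# Multinomial consistency: every observed row total must equal its n.
assert np.array_equal(c_comb.sum(axis=1), counts_comb[:, 0])

rng = np.random.default_rng(0)

# Hypothetical stand-in for the concatenated per-observation parameters, shape (21, 3).
p_obs = np.concatenate([rng.dirichlet(np.ones(3), size=N1),
                        rng.dirichlet(np.ones(3), size=N2)], axis=0)
picked = p_obs[idx]
# Integer-array indexing with 0/1 labels repeats rows 0 and 1; rows 2..20 are never used.
print(p_obs.shape, picked.shape)  # (21, 3) (21, 3)
print(np.allclose(picked[:N1], p_obs[0]), np.allclose(picked[N1:], p_obs[1]))  # True True

# Hypothetical per-group alternative: one probability vector per group, shape (2, 3).
p_group = rng.dirichlet(np.ones(3), size=2)
per_obs_from_group = p_group[idx]  # each observation gets its own group's row
print(per_obs_from_group.shape)  # (21, 3)
```

Whichever layout is intended, the array passed as `p` to the Multinomial should end up with one row per row of `observed`.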