Created
March 11, 2015 01:53
-
-
Save rawkintrevo/50f46ac7396b3d6f3a21 to your computer and use it in GitHub Desktop.
A Clinician's Tool for Analyzing Non-Compliance -Chickering and Pearl (1997)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A Clinician's Tool for Analyzing Non-Compliance | |
-Chickering and Pearl (1997) | |
X: | |
In this paper they use v to represent probability (instead of p, it's fun to be creative). | |
Z~ Treatment Assignment : Bern(vZ) 1 - Pt. was prescribed treatment | |
D~ Treatment Received : Bern(vD) 1 - Pt. was administered treatment | |
Y~ Outcome : Bern(vY) 1 - positive response | |
U~ Latent Factors : Given by C and R below | |
Data: a series of triples of form (z,d,y) one for each subject. | |
We seek average change in Y due to treatment: ACE( D -> Y ) | |
(Average Causal Effect) | |
ACE: 'Equivalently, ACE( D -> Y ) is the difference between | |
the fraction of subjects who are helped by the treatment | |
(R = r1) and the fraction of subjects who are | |
hurt by the treatment (R = r2).' | |
Model: | |
This model will give a distribution of Pr( ACE( D->Y ) | Data ) | |
Introducing two four-valued variables C and R (invokes Pearl 1994) | |
C - Compliance of Subject | |
0: never-taker | |
1: complier | |
2: defier | |
3: always-taker | |
R - Response Behavior of Subject | |
0: never-recover | |
1: helped | |
2: hurt | |
3: always-recover | |
N_c / N_r - The paper had one variable N_CR, which we have split into two. | |
This makes everything a lot cleaner over all. It represents the counts | |
of R=r / C=c which is the prior for a Dirichlet distribution v_c / v_r, which | |
is the probability distribution for assigning a given observation to a category | |
for R and C. | |
""" | |
import pymc as pm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from pprint import pprint | |
import pandas as pd | |
""" A Wrapper function for quickly generating the data sets found in the paper. """ | |
def make_dataset(counts):
    """Expand eight per-cell counts into a flat list of (z, d, y) observations.

    counts: a list of 8 integers.  counts[k] is the number of subjects whose
            (z, d, y) triple is the 3-bit binary expansion of k, with z the
            most significant bit, so the cells are ordered
            (0,0,0), (0,0,1), (0,1,0), ..., (1,1,1).
    returns: a dataset (list of (z, d, y) tuples, grouped by cell in the
             order above) to be fed to the model_factory
    """
    dataset = []
    for cell in range(8):
        # Peel the three bits of the cell index off directly: z, then d, then y.
        triple = ((cell >> 2) & 1, (cell >> 1) & 1, cell & 1)
        dataset.extend([triple] * counts[cell])
    return dataset
""" BEWARE: Deep Magic Ahead... *sparkle*, *sparkle, *sparkle* """ | |
def model_factory(list_data):
    """Build the PyMC nodes of the non-compliance model for one dataset.

    list_data: a sequence of (z, d, y) triples, one per subject (the format
               produced by make_dataset).
    returns:   the dict of this function's locals().  It holds every node
               created below (plus the helper values df_data, N and the loop
               index m) and is meant to be handed straight to pm.MCMC.  Local
               names are therefore part of the interface: renaming or adding
               a local changes the returned dict.
    """
    df_data = pd.DataFrame(list_data, columns=['z', 'd', 'y'])
    N = df_data.shape[0]

    # Treatment assignment is observed and modelled as a fair coin.
    Z = pm.Bernoulli('Z', p=.5, value=df_data['z'].values, observed=True)

    # One latent compliance (C) and response (R) category per subject.
    C = np.empty(N, dtype=object)
    R = np.empty(N, dtype=object)

    # Category counts that parameterise the Dirichlet priors.
    N_r = pm.Multinomial('N_r', n=N, p=[.25, .25, .25, .25])
    N_c = pm.Multinomial('N_c', n=N, p=[.25, .25, .25, .25])

    v_c_i = pm.Dirichlet('v_c_i', N_c)
    v_r_i = pm.Dirichlet('v_r_i', N_r)

    # FIX: the parents were swapped (v_c completed v_r_i and v_r completed
    # v_c_i), so the compliance probabilities were driven by the response
    # prior and vice versa, and both nodes were mislabelled.
    v_c = pm.CompletedDirichlet('v_c', v_c_i)
    v_r = pm.CompletedDirichlet('v_r', v_r_i)

    for m in range(N):
        C[m] = pm.Categorical('c%i' % m, p=v_c)
        R[m] = pm.Categorical('r%i' % m, p=v_r)

    # Equation (2): treatment is received by always-takers (C==3), by
    # defiers (C==2) when not assigned, and by compliers (C==1) when assigned.
    # .9999/.0001 are used in place of 1/0 -- presumably so an observation
    # that contradicts the current latent state is never strictly impossible.
    @pm.deterministic
    def d(Z=Z, C=C):
        return np.where( ( (C==3) | ((Z == False) & (C==2)) | ((Z== True) & (C==1)) ) , .9999, .0001 )

    # Equation (3): a positive outcome for always-recover (R==3), for "hurt"
    # subjects (R==2) when untreated, and for "helped" subjects (R==1) when
    # treated.  The float comparisons are exact only because d emits these
    # very literals; keep the two in sync.
    @pm.deterministic
    def y(d=d, R=R):
        return np.where( ( (R==3) | ((d == .0001) & (R==2)) | ((d== .9999) & (R==1)) ) , .9999, .0001 )

    D = pm.Bernoulli('D', p=d, value=df_data['d'].values.astype(bool), observed=True)
    Y = pm.Bernoulli('Y', p=y, value=df_data['y'].values.astype(bool), observed=True)

    # Fraction of subjects currently classified as helped (R==1).
    @pm.deterministic
    def v_r_1(R=R):
        return float( sum(np.where( (R==1) , 1, 0 ) )) / N

    # Fraction of subjects currently classified as hurt (R==2).
    @pm.deterministic
    def v_r_2(R=R):
        return float( sum(np.where( (R==2) , 1, 0 ) )) / N

    # Equation (4): ACE(D -> Y) = fraction helped - fraction hurt.
    @pm.deterministic
    def ACE(v_r_1=v_r_1, v_r_2=v_r_2):
        return (v_r_1 - v_r_2)

    return locals()
# NOTE(review): `arti_model` is not defined in the visible part of this file;
# this section assumes an already-sampled pm.MCMC model built by model_factory.
pm.Matplot.plot(arti_model.ACE)
plt.savefig('Artifical Dataset Diagnostics.png')

# Accumulate the [0, 1] entry of each element's value for the response nodes.
r_total = 0
for vr in arti_model.v_r:
    r_total += vr.value[0,1]

# FIX: this loop iterated arti_model.v_r again and added the stale `vr` left
# over from the loop above, so c_total was never computed from v_c.
c_total = 0
for cr in arti_model.v_c:
    c_total += cr.value[0,1]
""" This doesn't work because there are two stochastic variables (c and r) for each observation, | |
it makes a graph that is totally unreadable, but you can make the test case only, say 3 | |
observations and get the gist of it """ | |
#graph= pm.graph.graph(model) | |
#graph.write_png('Pearl_97.png') | |
""" A very simple mode for testing """ | |
# Small smoke-test model: 40 subjects spread evenly over the 8 (z, d, y) cells.
# FIX: use floor division.  Under Python 3 `40/8` is the float 5.0, and
# make_dataset cannot repeat a tuple a float number of times (TypeError).
test_counts = [40 // 8] * 8
test_ds = make_dataset(test_counts)
test_model = pm.MCMC(model_factory(test_ds))
test_model.sample(300, 100)

# Histogram of the sampled ACE values with fixed-width bins.
binwidth = .025
plot_data = test_model.trace('ACE')[:]
plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth))
plt.show()
# Sampler settings shared by the synthetic-dataset runs below.
samples = 900
burn = 100
""" Chickering and Pearl's 10 sample synthetic dataset """
# NOTE(review): these counts sum to 14, not 10 -- confirm against the paper.
syn10_counts = [3, 0, 2, 0, 2, 0, 0, 7]
# FIX: build from the syn10 names.  The original passed syn100_counts and
# syn100_ds, which are not defined until later in the file (NameError) and
# would in any case have fitted the wrong dataset.
syn10_ds = make_dataset(syn10_counts)
syn10_model = pm.MCMC(model_factory(syn10_ds))
syn10_model.sample(samples, burn)
#binwidth=.01 | |
#plot_data= syn10_model.trace('ACE')[:] | |
#plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show() | |
""" Chickering and Pearl's 100 sample synthetic dataset """ | |
syn100_counts = [27, 0, 23, 0, 23, 0, 0, 27]
syn100_ds = make_dataset(syn100_counts)
syn100_model = pm.MCMC(model_factory(syn100_ds))
syn100_model.sample(iter=samples, burn=burn)
binwidth = .01

""" Chickering and Pearl's 200 sample synthetic dataset """
syn200_counts = [55, 0, 45, 0, 45, 0, 0, 55]
syn200_ds = make_dataset(syn200_counts)
syn200_model = pm.MCMC(model_factory(syn200_ds))
syn200_model.sample(iter=samples, burn=burn)
binwidth = .005

""" Chickering and Pearl's 500 sample synthetic dataset """
syn500_counts = [138, 0, 112, 0, 112, 0, 0, 138]
syn500_ds = make_dataset(syn500_counts)
syn500_model = pm.MCMC(model_factory(syn500_ds))
syn500_model.sample(iter=samples, burn=burn)
binwidth = .005

# Overlay the ACE histograms of the four synthetic runs on one figure.
# Only trace entries from index `sp` onwards are plotted, and every
# histogram uses the last binwidth assigned above (.005).
sp = 500
for _model, _hist_kwargs in (
    (syn10_model, {}),               # default colour (Blue)
    (syn100_model, {'color': 'g'}),  # Green
    (syn200_model, {'color': 'r'}),  # Red
    (syn500_model, {'color': 'c'}),  # Cyan
):
    plot_data = _model.trace('ACE')[sp:]
    _bin_edges = np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)
    plt.hist(plot_data, bins=_bin_edges, **_hist_kwargs)
plt.show()
""" Chickering and Pearl's 1000 sample synthetic dataset """ | |
# Build, sample and plot the 1000-subject synthetic dataset.
syn1000_counts = [275, 0, 225, 0, 225, 0, 0, 275]
syn1000_ds = make_dataset(syn1000_counts)
syn1000_model = pm.MCMC(model_factory(syn1000_ds))
syn1000_model.sample(iter=150, burn=50)

# Histogram of the sampled ACE values with fixed-width bins.
binwidth = .01
plot_data = syn1000_model.trace('ACE')[:]
plt.hist(
    plot_data,
    bins=np.arange(np.min(plot_data), np.max(plot_data) + binwidth, binwidth),
)
plt.show()
""" Real Data Example: Effect of Cholestyramine on Reduced Cholesterol Data Set """ | |
# Cell counts in make_dataset order: (z, d, y) = 000, 001, 010, ..., 111.
lipid_counts = [158, 14, 0, 0, 52, 12, 23, 78]
lipid_ds = make_dataset(lipid_counts)
lipid_model = pm.MCMC(model_factory(lipid_ds))
# Short run: 100 iterations and no burn argument.
lipid_model.sample(iter=100)

# Histogram of the sampled ACE values with fixed-width bins.
binwidth = .01
plot_data = lipid_model.trace('ACE')[:]
plt.hist(
    plot_data,
    bins=np.arange(np.min(plot_data), np.max(plot_data) + binwidth, binwidth),
)
plt.show()
""" Real Data Example: Effect of Vitamin A Supplements on Child Mortality """ | |
# Cell counts in make_dataset order: (z, d, y) = 000, 001, 010, ..., 111.
# NOTE(review): the y=0 / y=1 counts here look swapped relative to the
# paper's table if y=1 is meant to be the positive response (survival), as
# the module docstring states -- confirm against the source data.
vita_counts = [11514, 74, 0, 0, 2385, 34, 9663, 12]
vita_ds = make_dataset(vita_counts)
vita_model = pm.MCMC(model_factory(vita_ds))
vita_model.sample(iter=500, burn=150)

# Histogram of the sampled ACE values with fixed-width bins.
binwidth = .01
plot_data = vita_model.trace('ACE')[:]
plt.hist(
    plot_data,
    bins=np.arange(np.min(plot_data), np.max(plot_data) + binwidth, binwidth),
)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment