Skip to content

Instantly share code, notes, and snippets.

@rawkintrevo
Created March 11, 2015 01:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save rawkintrevo/50f46ac7396b3d6f3a21 to your computer and use it in GitHub Desktop.
Save rawkintrevo/50f46ac7396b3d6f3a21 to your computer and use it in GitHub Desktop.
A Clinician's Tool for Analyzing Non-Compliance -Chickering and Pearl (1997)
"""
A Clinician's Tool for Analyzing Non-Compliance
-Chickering and Pearl (1997)
X:
In this paper they use v to represent probability (instead of p, it's fun to be creative).
Z~ Treatment Assignment : Bern(vZ) 1 - Pt. was prescribed treatment
D~ Treatment Received : Bern(vD) 1 - Pt. was administered treatment
Y~ Outcome : Bern(vY) 1 - positive response
U~ Latent Factors : Given by C and R below
Data: a series of triples of form (z,d,y) one for each subject.
We seek average change in Y due to treatment: ACE( D -> Y )
(Average Causal Effect)
ACE: 'Equivalently, ACE( D -> Y ) is the difference between
the fraction of subjects who are helped by the treatment
(R = r1) and the fraction of subjects who are
hurt by the treatment (R = r2).'
Model:
This model will give a distribution of Pr( ACE( D->Y ) | Data )
Introducing two 4 variable values C and R (Invokes Pearl 1994)
C - Compliance of Subject
0: never-taker
1: complier
2: defier
3: always-taker
R - Response Behavior of Subject
0: never-recover
1: helped
2: hurt
3: always-recover
N_c / N_r - The paper had one variable N_CR, which we have split into two.
This makes everything a lot cleaner overall. It represents the counts
of R=r / C=c which is the prior for a Dirichlet distribution v_c / v_r, which
is the probability distribution for assigning a given observation to a category
for R and C.
"""
import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
import pandas as pd
""" A Wrapper function for quickly generating the data sets found in the paper. """
def make_dataset(counts):
    """Expand per-cell counts into a flat list of (z, d, y) triples.

    counts: a sequence of 8 numbers, one per (z, d, y) cell in binary
        order -- counts[0] is the number of (0,0,0) subjects, counts[7]
        the number of (1,1,1) subjects.  Float counts (e.g. from true
        division) are truncated to ints.
    returns: a list of (z, d, y) 0/1 tuples suitable for model_factory.
    """
    dataset = []
    for j, n in enumerate(counts):
        # format(j, '03b') is the zero-padded 3-bit binary expansion of j,
        # e.g. 5 -> '101' -> (1, 0, 1).
        triple = tuple(int(bit) for bit in format(j, '03b'))
        # int(n) keeps list replication valid even when callers compute
        # counts with true division (e.g. 40/8 == 5.0 on Python 3).
        dataset.extend([triple] * int(n))
    return dataset
""" BEWARE: Deep Magic Ahead... *sparkle*, *sparkle, *sparkle* """
def model_factory(list_data):
    """Build the Chickering & Pearl (1997) non-compliance model.

    list_data: a list of (z, d, y) 0/1 triples, one per subject
        (e.g. the output of make_dataset).
    returns: locals() -- a dict of every model variable, suitable for
        passing directly to pm.MCMC.
    """
    df_data = pd.DataFrame(list_data, columns=['z','d','y'])
    N = df_data.shape[0]
    # Z: observed, randomized treatment assignment (fair coin prior).
    Z = pm.Bernoulli('Z', p=.5, value=df_data['z'].values, observed=True)
    # Per-subject latent compliance (C) and response (R) category variables.
    C = np.empty(N, dtype=object)
    R = np.empty(N, dtype=object)
    # Random category counts serve as the Dirichlet parameters -- the
    # paper's single N_CR split into independent C and R parts.
    N_r = pm.Multinomial('N_r', n=N, p=[.25,.25,.25,.25])
    N_c = pm.Multinomial('N_c', n=N, p=[.25,.25,.25,.25])
    v_c_i = pm.Dirichlet('v_c_i', N_c)
    v_r_i = pm.Dirichlet('v_r_i', N_r)
    # BUG FIX: the completed simplexes were crossed in the original
    # (v_c built from v_r_i and v_r from v_c_i); it went unnoticed only
    # because the two priors are identical.
    v_c = pm.CompletedDirichlet('v_c', v_c_i)
    v_r = pm.CompletedDirichlet('v_r', v_r_i)
    for m in range(N):
        C[m] = pm.Categorical('c%i' % m, p=v_c)
        R[m] = pm.Categorical('r%i' % m, p=v_r)
    """ Equation (2): treatment received is (nearly) deterministic in
    assignment Z and compliance type C.  Probabilities .9999/.0001 stand
    in for exact 1/0 to keep the likelihood well defined. """
    @pm.deterministic
    def d(Z=Z, C=C):
        # takes treatment iff: always-taker, or defier when NOT assigned,
        # or complier when assigned.
        return np.where( ( (C==3) | ((Z == False) & (C==2)) | ((Z== True) & (C==1)) ) , .9999, .0001 )
    """ Equation (3): outcome as a function of treatment received and
    response type R, mirroring Equation (2). """
    @pm.deterministic
    def y(d=d, R=R):
        return np.where( ( (R==3) | ((d == .0001) & (R==2)) | ((d== .9999) & (R==1)) ) , .9999, .0001 )
    # Observed treatment received and outcome, tied to the model via the
    # deterministic probabilities above.
    D = pm.Bernoulli('D', p=d, value=df_data['d'].values.astype(bool), observed=True)
    Y = pm.Bernoulli('Y', p=y, value=df_data['y'].values.astype(bool), observed=True)
    # Fractions of subjects currently classified as helped (R==1) and
    # hurt (R==2) by the treatment.
    @pm.deterministic
    def v_r_1(R=R):
        return float( sum(np.where( (R==1) , 1, 0 ) )) / N
    @pm.deterministic
    def v_r_2(R=R):
        return float( sum(np.where( (R==2) , 1, 0 ) )) / N
    """ Equation (4): ACE(D->Y) = P(helped) - P(hurt). """
    @pm.deterministic
    def ACE(v_r_1=v_r_1, v_r_2=v_r_2):
        return (v_r_1 - v_r_2)
    return locals()
# NOTE(review): `arti_model` is never defined anywhere in this file -- an
# "artificial dataset" model (presumably arti_model = pm.MCMC(model_factory(...)))
# must be built before this diagnostic section can run.  Verify against the
# original gist/notebook this was extracted from.
pm.Matplot.plot(arti_model.ACE)
plt.savefig('Artifical Dataset Diagnostics.png')
# Sum the "helped" component across the response simplex samples.
r_total = 0
for vr in arti_model.v_r:
    r_total += vr.value[0,1]
# BUG FIX: the original loop iterated arti_model.v_r under the name `cr`
# but accumulated the stale `vr` left over from the loop above; it clearly
# meant the compliance simplex v_c and the loop's own variable.
c_total = 0
for cr in arti_model.v_c:
    c_total += cr.value[0,1]
""" This doesn't work because there are two stochastic variables (c and r) for each observation,
it makes a graph that is totally unreadable, but you can make the test case only, say 3
observations and get the gist of it """
#graph= pm.graph.graph(model)
#graph.write_png('Pearl_97.png')
""" A very simple mode for testing """
test_counts = [(40)/8]*8
test_ds = make_dataset(test_counts)
test_model = pm.MCMC(model_factory(test_ds))
test_model.sample(300,100)
binwidth=.025
plot_data= test_model.trace('ACE')[:]
plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show()
# Shared MCMC settings for the synthetic-dataset runs below.
samples = 900
burn = 100
""" Chickering and Pearl's 10 sample synthetic dataset """
# NOTE(review): these counts sum to 14, not 10 -- verify against the paper.
syn10_counts = [3, 0, 2, 0, 2, 0, 0, 7]
# BUG FIX: the original referenced syn100_counts / syn100_ds here, which
# are only defined further down, raising a NameError on this line.
syn10_ds = make_dataset(syn10_counts)
syn10_model = pm.MCMC(model_factory(syn10_ds))
syn10_model.sample(samples, burn)
#binwidth=.01
#plot_data= syn10_model.trace('ACE')[:]
#plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show()
""" Chickering and Pearl's 100 sample synthetic dataset """
syn100_counts = [27, 0, 23, 0, 23, 0, 0, 27]
syn100_ds = make_dataset(syn100_counts)
syn100_model = pm.MCMC(model_factory(syn100_ds))
syn100_model.sample(samples, burn)
binwidth=.01
#plot_data= syn100_model.trace('ACE')[:]
#plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth))
#plt.show()
""" Chickering and Pearl's 200 sample synthetic dataset """
syn200_counts = [55, 0, 45, 0, 45, 0, 0, 55]
syn200_ds = make_dataset(syn200_counts)
syn200_model = pm.MCMC(model_factory(syn200_ds))
syn200_model.sample(samples, burn)
binwidth=.005
""" Chickering and Pearl's 500 sample synthetic dataset """
syn500_counts = [138, 0, 112, 0, 112, 0, 0, 138]
syn500_ds = make_dataset(syn500_counts)
syn500_model = pm.MCMC(model_factory(syn500_ds))
syn500_model.sample(samples, burn)
binwidth=.005
sp = 500
# Overlay the ACE posteriors of the four synthetic models on one figure.
# The first histogram keeps matplotlib's default colour (blue); the rest
# are drawn green, red and cyan respectively, exactly as before.
_series = [
    (syn10_model, {}),
    (syn100_model, {'color': 'g'}),
    (syn200_model, {'color': 'r'}),
    (syn500_model, {'color': 'c'}),
]
for _mdl, _kw in _series:
    plot_data = _mdl.trace('ACE')[sp:]
    _bins = np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)
    plt.hist(plot_data, bins=_bins, **_kw)
plt.show()
""" Chickering and Pearl's 1000 sample synthetic dataset """
syn1000_counts = [275, 0, 225, 0, 225, 0, 0, 275]
syn1000_ds = make_dataset(syn1000_counts)
syn1000_model = pm.MCMC(model_factory(syn1000_ds))
syn1000_model.sample(150,50)
binwidth=.01
plot_data= syn1000_model.trace('ACE')[:]
plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show()
""" Real Data Example: Effect of Cholestyramine on Reduced Cholesterol Data Set """
lipid_counts = [158, 14, 0, 0, 52, 12, 23, 78]
lipid_ds = make_dataset(lipid_counts)
lipid_model = pm.MCMC(model_factory(lipid_ds))
lipid_model.sample(100)
binwidth=.01
plot_data= lipid_model.trace('ACE')[:]
plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show()
""" Real Data Example: Effect of Vitamin A Supplements on Child Mortality """
vita_counts = [11514, 74, 0, 0, 2385, 34, 9663, 12]
vita_ds = make_dataset(vita_counts)
vita_model = pm.MCMC(model_factory(vita_ds))
vita_model.sample(500,150)
binwidth=.01
plot_data= vita_model.trace('ACE')[:]
plt.hist(plot_data, bins=np.arange(min(plot_data), max(plot_data) + binwidth, binwidth)); plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment