Skip to content

Instantly share code, notes, and snippets.

@rawkintrevo
Created March 11, 2015 01:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rawkintrevo/77338f3c25e6bb973d6e to your computer and use it in GitHub Desktop.
Save rawkintrevo/77338f3c25e6bb973d6e to your computer and use it in GitHub Desktop.
Multivariate Linear Hierarchical - pymc
import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
import pandas as pd
def linear_setup(df, ind_cols, dep_col, gb_cols, intercept=True):
    """Build the PyMC nodes for a hierarchical multivariate linear model.

    Every distinct combination of the ``gb_cols`` values found in ``df`` is a
    "terminal" level; every proper prefix of such a combination is a parent
    level.  Each level gets a K-vector of coefficient offsets ``d``; the
    coefficients ``b`` of a terminal level are the sum of the offsets along
    its path from the top level down.

    Dimensions used throughout:
        N: number of observations (rows of ``df``)
        G: number of grouping columns
        K: number of independent variables (including the intercept, if any)
        L: number of levels (terminal levels plus all their prefixes)
        TL: number of terminal levels

    Args:
        df: pandas DataFrame holding the observations.  It is not modified.
        ind_cols: list of independent-variable column names.
        dep_col: name of the dependent-variable column.
        gb_cols: list of grouping column names, outermost group first.
        intercept: when true, a constant column named ``'intercept'`` is
            prepended to the independent variables.

    Returns:
        ``locals()`` -- a dict of every name defined in this function.  It is
        meant to be handed to ``pm.MCMC``; callers in this file also read
        entries such as ``l_key`` and ``d`` back off the model, so the local
        names below are part of the contract.
    """
    N = len(df)
    if intercept:
        # Work on a copy so the caller's frame does not silently gain an
        # 'intercept' column.
        df = df.copy()
        df['intercept'] = np.ones(N)
        ind_cols = ['intercept'] + list(ind_cols)

    ### Split up data using tuples as keys
    arrays = [df[c].values for c in gb_cols]
    midf = df[[dep_col] + ind_cols]
    midf.index = pd.MultiIndex.from_arrays(arrays, names=gb_cols)
    G = len(gb_cols)

    # One tuple of group values per row.  zip() walks the columns once (the
    # previous np.dstack call was re-evaluated for every row) and keeps each
    # column's own dtype instead of coercing them to a common one.
    all_levels = list(zip(*arrays))

    # ll -> levels list: every observed combination plus all of its prefixes.
    level_set = set()
    for combo in set(all_levels):
        for depth in range(1, len(combo) + 1):
            level_set.add(combo[:depth])
    # Shorter (higher class) tuples must come first so that lower class
    # tuples can reference their already-built parent.
    ll = sorted(level_set, key=len)
    l_key = dict(zip(ll, range(len(ll))))  # level tuple -> position in d/x/y
    L = len(ll)
    term_l = [l for l in ll if len(l) == G]
    TL = len(term_l)
    # Each terminal group's rows are available as midf.loc[l] for l in term_l;
    # ll additionally contains the upper class (prefix) groups.
    K = len(ind_cols)

    ### The Stochastics ###############################################################################
    l_err = pm.Uniform("l_err", 0, 500, size=L)
    d = np.empty(L, dtype=object)
    for l in ll:
        idx = l_key[l]
        tag = '_'.join(map(str, l))
        if len(l) == 1:
            # Top level: the mean is itself a Normal built with positional
            # arguments (0, .001); the spread argument comes from l_err.
            d[idx] = pm.Normal('d_%s' % tag,
                               mu=pm.Normal('mu_%s' % tag, 0, .001),
                               tau=l_err[idx],
                               size=K)
        else:
            # Lower levels are parameterized on their parent (the tuple with
            # the last group value dropped).
            d[idx] = pm.Normal('d_%s' % tag,
                               mu=d[l_key[l[:-1]]], tau=.0001, size=K)

    ### The Independent Cols############################################################################
    # Only terminal levels own data; the other slots of x stay unset.
    x = np.empty(L, dtype=object)
    for l in term_l:
        idx = l_key[l]
        tag = '_'.join(map(str, l))
        x[idx] = np.empty(K, dtype=object)
        for k, col in enumerate(ind_cols):
            x[idx][k] = pm.Normal('x_%s_%s' % (k, tag), 0, 1,
                                  value=midf.loc[l][col].values, observed=True)

    ### The deterministics ##########################################################################
    @pm.deterministic
    def b(d=d):
        # Coefficients of a terminal level = sum of the offsets of every
        # level on its path (prefixes of length 1..G).
        b_ = np.empty(L, dtype=object)
        for l in term_l:
            b_[l_key[l]] = sum([d[l_key[l[:i + 1]]] for i in range(G)])
        return b_

    @pm.deterministic
    def y_hat(b=b, x=x):
        # Prediction per terminal level: coefficients dotted with that
        # level's independent variables.
        out = np.empty(L, dtype=object)
        for l in term_l:
            out[l_key[l]] = np.array([b[l_key[l]].dot(x[l_key[l]])])
        return out

    err = pm.Uniform("err", 0, 500)
    y = np.empty(L, dtype=object)
    for l in term_l:
        idx = l_key[l]
        y[idx] = pm.Normal('y_%s' % ('_'.join(map(str, l))),
                           y_hat[idx], err,
                           value=midf.loc[l][dep_col].values, observed=True)

    # Deliberately returns every local: pm.MCMC picks the nodes out of the
    # dict and the remaining entries (l_key, d, ...) stay reachable by name.
    return locals()
#### Synthetic Dataset #######
from random import uniform

size_factor = 32  # obs_per_grp
cats = 10         # distinct values of 'c'
sgs = 5           # distinct values of 'sg'
ssgs = 7          # distinct values of 'ssg'

# One row per (c, sg, ssg, replicate); x is drawn uniformly from [.9, 1.1].
data = pd.DataFrame(
    [(c, sg, ssg, uniform(.9, 1.1))
     for c in range(cats) for sg in range(sgs) for ssg in range(ssgs)
     for _ in range(size_factor)],
    columns=['c', 'sg', 'ssg', 'x'])
#data['x2'] = data.sum(axis=1) + uniform(-1,1)
ind_cols = ['x']  #, 'x2']
dep_col = 'y'
gb_cols = ['c', 'sg', 'ssg']
# By construction the slope of y on x within a group is c + sg + ssg.
data['y'] = data[gb_cols].sum(axis=1) * data['x']

model = pm.MCMC(linear_setup(data, ind_cols, dep_col, gb_cols, intercept=False))
model.sample(100)

### only need this for multi variables
#for c in range(1,5):
# beta =0
# plt_data= np.array([i[beta] for i in model.trace('d_0_0_0', chain=c)[:]])
# #plt_data = model.trace('y_hat', chain=c)[:]
# plt.plot(plt_data)
#plt.plot(model.trace('d_7')[:]);plt.show()
#plt.show()

l_key = model.l_key

## Beta Plotting
cat = (8, 3, 1)
key = l_key[cat]
# Pull the trace once instead of once per sample.
b_trace = model.trace('b')[:]
plt.plot([sample[key] for sample in b_trace])
plt.show()

# Notebook-style check: the offsets along the path (1,) -> (1,1) -> (1,1,1).
# The value is only displayed in an interactive session; as a script it is
# computed and discarded.
model.d[l_key[(1,)]].value + model.d[l_key[(1, 1)]].value + model.d[l_key[(1, 1, 1)]].value

for c in range(0, 1):
    plt.plot(model.trace('d_0', chain=c)[:])
plt.show()

# NOTE(review): the original script ended with
#     plt.plot(model.trace('z0', chain=3)[:]); plt.show()
# but linear_setup creates no node named 'z0', and a single sample() call
# should leave only chain 0, so that lookup fails.  Disabled until the
# intended variable is confirmed.

from datetime import datetime
datetime.now()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment