Created
August 24, 2016 11:30
-
-
Save Spaak/f544375d169392e1d0934c57e450b05a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# imports
import numpy as np

import theano
import theano.tensor as T
from theano.ifelse import ifelse

import pymc3 as mc
# --- synthetic test data -------------------------------------------------
# Not drawn from any meaningful generative model; only here to exercise the
# machinery below.
ncell = 4
ntim = 10

# Binary spike trains: one row per cell, one column per time bin.
y_obs = np.random.choice((0,1), p=(0.7,0.3), size=(ncell,ntim)).astype('float32')

# "True" parameter values used when evaluating the compiled likelihood:
# theta: per-cell baseline firing rate
theta_true = np.random.random(ncell).astype('float32')
# beta: cell X cell connectivity matrix
beta_true = np.random.random((ncell,ncell)).astype('float32')

# theano.scan()/map() are not compatible with test-value debugging, so relax
# the setting to 'warn'.  Note that 'off' raises
#   AttributeError: 'scratchpad' object has no attribute 'test_value'
# inside pymc3.distributions.distribution.
theano.config.compute_test_value = 'warn'
# The rate term for one cell i and one time point t.
#
# All arguments are symbolic theano variables:
#   i, t  : scalar indices (cell, time bin)
#   y     : ncell x ntim binary spike matrix
#   theta : per-cell baseline rates (length ncell)
#   beta  : ncell x ncell connectivity matrix (column i weights inputs to cell i)
def rate(i,t,y,theta,beta):
    # determine previous spike time of cell i
    # known as tau_it in model of Rigat
    last_time = T.flatnonzero(y[i,:t]).astype('uint16')
    # if last_time is not empty (cell i has fired at least once)
    # then use last firing time, otherwise use 0
    last_time = ifelse(
        T.gt(last_time.shape[0], 0),
        # the next line is where the error happens in pymc3 (but not in pure Theano)
        last_time[-1],
        T.constant(0, dtype='uint16')
    )
    # if t == 0 then rate term is simply theta[i]
    # otherwise it's theta[i] plus a weighted sum of all cells' firing
    # history since the last time cell i spiked (time-averaged over the
    # window [last_time, t), weighted by column i of beta)
    return ifelse(
        T.eq(t, 0),
        theta[i],
        theta[i] + T.sum( beta[:,i] * \
            (T.sum(y[:,last_time:t], axis=-1) / (t-last_time)), axis=0)
    )
# Symbolic placeholders for the observed data and the model parameters.
y_theano = T.matrix('y')
theta_theano = T.vector('theta')
beta_theano = T.matrix('beta')

# Flattened (cell, time) index grids; rate() is evaluated once per pair.
cgrid, tgrid = T.mgrid[0:ncell, 0:ntim]
cgrid = cgrid.flatten().astype('uint16')
tgrid = tgrid.flatten().astype('uint16')

# Evaluate rate() for every cell / time-point combination via theano.map().
rateresult, updates = theano.map(
    rate,
    sequences=[cgrid, tgrid],
    non_sequences=[y_theano, theta_theano, beta_theano],
)

# Restore the ncell x ntim layout (not strictly needed, but was useful
# while debugging).
rateresult = T.reshape(rateresult, (ncell, ntim))

# Bernoulli log-likelihood expressed in terms of the (logit) rate term.
logp = T.sum(y_theano * rateresult - T.log1p(T.exp(rateresult)))

# Compile the likelihood, then evaluate it on the synthetic data and print.
fun = theano.function(
    inputs=[y_theano, theta_theano, beta_theano],
    outputs=logp,
)
result = fun(y_obs, theta_true, beta_true)
print(result)
class SpikeLikelihood(mc.Discrete):
    """
    The likelihood for spiking as defined in eq. 4 of Rigat et al.
    Redefined this to not depend explicitly on an intermediate computation
    of tau_it.

    Parameters
    ----------
    theta : symbolic vector, length ncell — per-cell baseline rates
    beta  : symbolic matrix, ncell x ncell — connection strengths
    ncell : int — number of cells
    ntim  : int — number of time bins
    """
    def __init__(self, theta, beta, ncell, ntim, *args, **kwargs):
        super(SpikeLikelihood, self).__init__(*args, **kwargs)
        self.theta = theta
        # BUGFIX: was "self.beta = theta", which silently discarded the
        # connectivity parameter and made the likelihood independent of beta.
        self.beta = beta
        # BUGFIX: logp() reads self.ncell / self.ntim, but they were never
        # stored — store them here so reshape() below can use them.
        self.ncell = ncell
        self.ntim = ntim
        # pre-initialize the flattened (cell, time) index grids used to
        # implement the loop for computing the rate term in logp()
        self._cgrid, self._tgrid = T.mgrid[0:ncell, 0:ntim]
        self._cgrid = self._cgrid.flatten().astype('uint16')
        self._tgrid = self._tgrid.flatten().astype('uint16')
        self._cgrid.name = 'SpikeLikelihood._cgrid'
        self._tgrid.name = 'SpikeLikelihood._tgrid'

    def logp(self, Y):
        # Y will be ncell X time.
        # The rate can only be computed iteratively per cell and per time
        # point, so we use theano's scan functionality (accessed here
        # through map()) to loop over all (cell, time) pairs.
        # rate() here is the same module-level theano function as above.
        rateresult, updates = theano.map(rate,
            sequences=[self._cgrid, self._tgrid],
            non_sequences=[Y, self.theta, self.beta],
            name='SpikeLikelihood._rate_term')
        # map() returns a flattened output; restore the ncell x ntim layout
        rateresult = T.reshape(rateresult, (self.ncell, self.ntim))
        # Bernoulli log-likelihood in terms of the logit rate term
        return T.sum(Y * rateresult - T.log1p(T.exp(rateresult)))
# Assemble the hierarchical pymc3 model over connectivity and baseline rates.
model = mc.Model()
with model:
    # prior probability of connection on/off, same for all cell pairs
    # exponential with lambda=1 means a connection is a priori unlikely
    # NOTE(review): an Exponential draw has support (0, inf) but Bernoulli's
    # p must lie in [0, 1] — confirm this prior is intended
    alpha_0 = mc.Exponential('alpha_0', lam=1)
    # connections on/off, per cell pair
    nu_ij = mc.Bernoulli('nu_ij', p=alpha_0, shape=(ncell, ncell))
    # s.d. for connection strength, depends on sigma same for all cells and
    # nu_ij per cell. epsilon is shrinkage factor, see Rigat eq. 5.
    epsilon = 0.05
    # NOTE(review): a Normal draw can be negative, which makes the sqrt()
    # below NaN — a positive-support prior may be intended; confirm
    sigma_squared = mc.Normal('sigma_squared', mu=20, sd=15)
    # wrap in mc.Deterministic to give it a name for use in examining the trace etc.
    sigma_ij = mc.Deterministic('sigma_ij', T.sqrt(sigma_squared) * \
        (nu_ij + epsilon * (1 - nu_ij)))
    # actual connection strength
    beta_ij = mc.Normal('beta_ij', mu=0, sd=sigma_ij)
    # baseline rate per cell
    theta_i = mc.Normal('theta_i', mu=20, sd=15, shape=ncell)
    # observed spike trains under the custom likelihood defined above
    y_modelled = SpikeLikelihood('y_modelled', theta=theta_i, beta=beta_ij,
        ncell=ncell, ntim=ntim, observed=y_obs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment