

@Spaak
Created August 24, 2016 11:30
# imports
import numpy as np
import theano
import theano.tensor as T
from theano.ifelse import ifelse
import pymc3 as mc
# some test data
# note: not actually generated by any sensible model, just for technical testing
ncell = 4
ntim = 10
# spikes as binary variable
y_obs = np.random.choice((0,1), p=(0.7,0.3), size=(ncell,ntim)).astype('float32')
# model 'parameters'
# theta: cell-specific base firing rate
theta_true = np.random.random(ncell).astype('float32')
# beta: cell X cell connectivity matrix
beta_true = np.random.random((ncell,ncell)).astype('float32')
# theano.scan() and map() are not compatible with test value debugging,
# so set it to 'warn'. note that 'off' will give an error:
# AttributeError: 'scratchpad' object has no attribute 'test_value'
# in pymc3.distributions.distribution
theano.config.compute_test_value = 'warn'
# the rate term for one cell i and one time point t
def rate(i, t, y, theta, beta):
    # determine the previous spike time of cell i,
    # known as tau_it in the model of Rigat et al.
    last_time = T.flatnonzero(y[i, :t]).astype('uint16')
    # if last_time is not empty (cell i has fired at least once)
    # then use the last firing time, otherwise use 0
    last_time = ifelse(
        T.gt(last_time.shape[0], 0),
        # the next line is where the error happens in pymc3 (but not in pure Theano)
        last_time[-1],
        T.constant(0, dtype='uint16')
    )
    # if t == 0 then the rate term is simply theta[i];
    # otherwise it's theta[i] plus a weighted sum of all cells' firing
    # history since the last time cell i spiked
    return ifelse(
        T.eq(t, 0),
        theta[i],
        theta[i] + T.sum(beta[:, i] *
                         (T.sum(y[:, last_time:t], axis=-1) / (t - last_time)),
                         axis=0)
    )
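# Quick sanity check (my addition, not part of the original gist; the *_chk
# names are hypothetical): compile rate() for one fixed (cell, time) pair and
# evaluate it on the test data before wiring it into theano.map() below.
y_chk = T.matrix('y_chk')
theta_chk = T.vector('theta_chk')
beta_chk = T.matrix('beta_chk')
rate_chk = theano.function(
    inputs=[y_chk, theta_chk, beta_chk],
    outputs=rate(T.constant(2, dtype='uint16'), T.constant(5, dtype='uint16'),
                 y_chk, theta_chk, beta_chk)
)
print(rate_chk(y_obs, theta_true, beta_true))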
# symbolic variables for observed data and parameters
y_theano = T.matrix('y')
theta_theano = T.vector('theta')
beta_theano = T.matrix('beta')
# create flattened indexing grids over which rate() will loop
cgrid, tgrid = T.mgrid[0:ncell, 0:ntim]
cgrid = cgrid.flatten().astype('uint16')
tgrid = tgrid.flatten().astype('uint16')
# call theano.map() which will call rate() for all cells and all time points
rateresult, updates = theano.map(
    rate,
    sequences=[cgrid, tgrid],
    non_sequences=[y_theano, theta_theano, beta_theano]
)
# reshape back (not strictly needed, but sometimes useful for debugging)
rateresult = T.reshape(rateresult, (ncell, ntim))
# compute log-likelihood based on rate term
logp = T.sum(y_theano * rateresult - T.log1p(T.exp(rateresult)))
# compile function
fun = theano.function(
    inputs=[y_theano, theta_theano, beta_theano],
    outputs=logp
)
# call function and print result
result = fun(y_obs, theta_true, beta_true)
print(result)
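# Side note (my addition, not from the original gist): T.log1p(T.exp(x)) can
# overflow to inf for large x. Theano's T.nnet.softplus computes log(1 + exp(x))
# in a numerically stable way, so an equivalent, more robust logp would be:
logp_stable = T.sum(y_theano * rateresult - T.nnet.softplus(rateresult))
fun_stable = theano.function(
    inputs=[y_theano, theta_theano, beta_theano],
    outputs=logp_stable
)
# should print (approximately) the same value as above
print(fun_stable(y_obs, theta_true, beta_true))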
class SpikeLikelihood(mc.Discrete):
    """
    The likelihood for spiking as defined in eq. 4 of Rigat et al.

    Redefined so as not to depend explicitly on an intermediate computation
    of tau_it.

    theta: ncell, baseline rates
    beta : ncell x ncell, connection strengths
    """
    def __init__(self, theta, beta, ncell, ntim, *args, **kwargs):
        super(SpikeLikelihood, self).__init__(*args, **kwargs)
        self.theta = theta
        self.beta = beta
        # ncell and ntim are needed again in logp() to reshape the output
        self.ncell = ncell
        self.ntim = ntim
        # pre-initialize the indexing grid used to implement the loop for
        # computing the rate term
        self._cgrid, self._tgrid = T.mgrid[0:ncell, 0:ntim]
        self._cgrid = self._cgrid.flatten().astype('uint16')
        self._tgrid = self._tgrid.flatten().astype('uint16')
        self._cgrid.name = 'SpikeLikelihood._cgrid'
        self._tgrid.name = 'SpikeLikelihood._tgrid'

    def logp(self, Y):
        # Y will be ncell X ntim
        # the rate can only be computed iteratively per cell and per time point,
        # so we use theano's scan functionality (accessed here through map())
        # to loop over all time points and cells.
        # rate() here is the same theano function as above
        rateresult, updates = theano.map(
            rate,
            sequences=[self._cgrid, self._tgrid],
            non_sequences=[Y, self.theta, self.beta],
            name='SpikeLikelihood._rate_term'
        )
        # rateresult is a flattened output, so reshape it here
        rateresult = T.reshape(rateresult, (self.ncell, self.ntim))
        return T.sum(Y * rateresult - T.log1p(T.exp(rateresult)))
model = mc.Model()
with model:
    # prior probability of a connection being on/off, same for all cell pairs;
    # exponential with lambda=1 means a connection is a priori unlikely
    alpha_0 = mc.Exponential('alpha_0', lam=1)
    # connections on/off, per cell pair
    nu_ij = mc.Bernoulli('nu_ij', p=alpha_0, shape=(ncell, ncell))
    # s.d. for connection strength, depends on sigma (same for all cells) and
    # on nu_ij (per cell pair). epsilon is the shrinkage factor, see Rigat eq. 5.
    epsilon = 0.05
    sigma_squared = mc.Normal('sigma_squared', mu=20, sd=15)
    # wrap in mc.Deterministic to give it a name for use in examining the trace etc.
    sigma_ij = mc.Deterministic('sigma_ij', T.sqrt(sigma_squared) *
                                (nu_ij + epsilon * (1 - nu_ij)))
    # actual connection strength
    beta_ij = mc.Normal('beta_ij', mu=0, sd=sigma_ij)
    # baseline rate per cell
    theta_i = mc.Normal('theta_i', mu=20, sd=15, shape=ncell)
    y_modelled = SpikeLikelihood('y_modelled', theta=theta_i, beta=beta_ij,
                                 ncell=ncell, ntim=ntim, observed=y_obs)
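# Usage sketch (my addition, not part of the original gist; the draw count is
# an illustrative assumption): once the pymc3 error described above is
# resolved, sampling would proceed along these lines. Left commented out
# because model evaluation currently triggers the AttributeError noted in
# rate().
# with model:
#     trace = mc.sample(1000)
#     mc.summary(trace)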