Skip to content

Instantly share code, notes, and snippets.

@jdmonaco
Last active October 9, 2016 18:42
Show Gist options
  • Save jdmonaco/db1f34fe993dffa71d26 to your computer and use it in GitHub Desktop.
Save jdmonaco/db1f34fe993dffa71d26 to your computer and use it in GitHub Desktop.
"""
Spike train cross-correlogram methods.
"""
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy import real
from scipy.fftpack import fft, ifft
def spikes(rate, duration):
    """
    Generate a random spike train as an array of spike times.

    The train contains int(rate * duration) spikes whose times are drawn
    uniformly from [0, duration) and returned in ascending order.

    Arguments:
    rate -- firing rate (spikes per unit time)
    duration -- length of the train, in the same time units as `rate`

    Returns:
    1-D float array of sorted spike times
    """
    # np.random.rand yields values in arbitrary order; sort them so the result
    # is a proper (time-ordered) spike train, as xcorr's windowed search expects.
    return np.sort(np.random.rand(int(rate * duration)) * duration)
def xcorr(a, b, maxlag=1.0, bins=128):
    """
    Compute the cross-correlogram of two spike train arrays.

    For every spike time t in `a`, the lags (s - t) of all spikes s in `b`
    with t - maxlag <= s <= t + maxlag are accumulated into a histogram.

    Arguments:
    a, b -- 1-D arrays (or sequences) of spike times; need not be sorted
    maxlag -- maximum absolute lag to include
    bins -- number of lag bins; an even value is bumped to the next odd
        number so that one bin is centered on zero lag

    Returns:
    H -- float array of pair counts, one entry per lag bin
    centers -- lag at the center of each bin
    """
    if bins % 2 == 0:
        bins += 1
    edges = np.linspace(-maxlag, maxlag, bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2

    # The windowed search relies on sorted spike times. Sorting does not change
    # the correlogram, which is a sum over all spike pairs.
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))

    # For each reference spike t, b[start:end] are exactly the spikes inside
    # [t - maxlag, t + maxlag]. searchsorted never indexes past the end of b,
    # so empty trains and windows beyond the last spike are handled safely.
    starts = np.searchsorted(b, a - maxlag, side='left')
    ends = np.searchsorted(b, a + maxlag, side='right')

    H = np.zeros(bins)
    for t, start, end in zip(a, starts, ends):
        H += np.histogram(b[start:end] - t, bins=edges)[0]
    return H, centers
def xcorrfft(a, b, dt=0.001, maxlag=1.0, bins=128):
"""
Use FFT to compute spike-train cross-correlograms with post-hoc rebinning.
"""
if bins % 2 == 0:
bins += 1
edges = np.linspace(-maxlag, maxlag, bins + 1)
tmin, tmax = np.min([a,b]), np.max([a,b])
dur = tmax - tmin
xedges = np.linspace(tmin, tmax, dur / dt)
xa = np.histogram(a, bins=xedges)[0]
xb = np.histogram(b, bins=xedges)[0]
Fa = fft(xa, overwrite_x=True)
Fb = fft(xb[::-1], overwrite_x=True)
C = real(ifft(Fa * Fb, overwrite_x=True))
Cfull = np.r_[C, C[-2::-1]] # reflect around y-axis
lag = np.linspace(-dur + dt / 2, 0, C.size)
lagfull = np.r_[lag, -1.0 * lag[-2::-1]]
windex = np.abs(lagfull) <= maxlag
lagwin = lagfull[windex]
Cwin = Cfull[windex]
Cbinned = np.histogram(np.repeat(lagwin, Cwin.astype(int)), bins=edges)[0]
centers = (edges[:-1] + edges[1:]) / 2
return Cbinned, centers
# Demo: compare the loop-based and FFT-based correlograms on a random
# 10 Hz, 200 s spike train (autocorrelogram: the train against itself).
np.random.seed(11235)
st = spikes(10.0, 200.0)

# Time each implementation with its own start/stop timestamps.
t0 = time.time()
C, lags = xcorr(st, st)
t1 = time.time()
Cfft, lagsfft = xcorrfft(st, st, dt=0.001)  # change binarization here
t2 = time.time()
dt1 = t1 - t0
dt2 = t2 - t1

# print(...) with a single argument works on both Python 2 and Python 3.
print('Loopy dt = {} seconds'.format(dt1))
print('FFT dt = {} seconds'.format(dt2))

# NOTE: the ioff/ion/draw sequence targets an interactive session (e.g.
# IPython); when running as a plain script, call plt.show() to display.
plt.ioff()
fig = plt.figure(num=10, figsize=(12, 5))
fig.clear()
ax = fig.add_subplot(111)
ax.plot(lags, C, drawstyle='steps-mid', label='loops')
ax.plot(lagsfft, Cfft, ls='-', c='r', drawstyle='steps-mid', label='fft')
ax.set_ylim(bottom=0)
ax.legend(loc=2)
ax.set(xlabel='Lag (s)', ylabel='count')
ax.set_title('Comparing loop and FFT spike cross-correlograms')
plt.ion()
plt.draw()
@vr2262
Copy link

vr2262 commented Oct 9, 2016

I don't understand what the spikes function actually gives you. I thought it would be the timestamp of each spike in a spike train, but that can't be right.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment