@alperyeg
Last active February 4, 2016 01:08
profiling of elephant.spike_train_correlation.cross_correlation_histogram
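Setup (see the benchmark script further down): two 100 s homogeneous Poisson spike trains at 100 Hz, binned at 1 ms, with each method called 5 times. In summary, the 'memory' implementation takes about 10.5 s in total, almost all of it in the per-spike accumulation loop, while the 'speed' implementation takes about 0.11 s, over 90% of which is the scipy.signal.fftconvolve call.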
Timer unit: 1e-06 s
Total time: 10.4797 s
Function: _cch_memory at line 397
Line # Hits Time Per Hit % Time Line Contents
==============================================================
397 @profile
398 def _cch_memory(st_1, st_2, win, mode, norm, border_corr, binary, kern):
399
400 # Check that the spike trains are binned with the same temporal
401 # resolution
402 5 12 2.4 0.0 if not st1.matrix_rows == 1:
403 raise AssertionError("Spike train must be one dimensional")
404 5 7 1.4 0.0 if not st2.matrix_rows == 1:
405 raise AssertionError("Spike train must be one dimensional")
406 5 1006 201.2 0.0 if not st1.binsize == st2.binsize:
407 raise AssertionError("Bin sizes must be equal")
408
409 # Retrieve unclipped matrix
410 5 26 5.2 0.0 st1_spmat = st_1.to_sparse_array()
411 5 10 2.0 0.0 st2_spmat = st_2.to_sparse_array()
412 5 6 1.2 0.0 binsize = st_1.binsize
413 5 16 3.2 0.0 max_num_bins = max(st_1.num_bins, st_2.num_bins)
414
415 # Set the time window in which the cch is computed
416 5 8 1.6 0.0 if win is not None:
417 # Window parameter given in number of bins (integer)
418 if isinstance(win[0], int) and isinstance(win[1], int):
419 # Check the window parameter values
420 if win[0] >= win[1] or win[0] <= -max_num_bins \
421 or win[1] >= max_num_bins:
422 raise ValueError(
423 "The window exceeds the length of the spike trains")
424 # Assign left and right edges of the cch
425 l, r = win[0], win[1]
426 # Window parameter given in time units
427 else:
428 # Check the window parameter values
429 if win[0].rescale(binsize.units).magnitude % \
430 binsize.magnitude != 0 or win[1].rescale(
431 binsize.units).magnitude % binsize.magnitude != 0:
432 raise ValueError(
433 "The window has to be a multiple of the binsize")
434 if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
435 or win[1] >= max_num_bins * binsize:
436 raise ValueError("The window exceeds the length of the"
437 " spike trains")
438 # Assign left and right edges of the cch
439 l, r = int(win[0].rescale(binsize.units) / binsize), int(
440 win[1].rescale(binsize.units) / binsize)
441 # Case without explicit window parameter
442 else:
443 # cch computed for all the possible entries
444 5 7 1.4 0.0 if mode == 'full':
445 # Assign left and right edges of the cch
446 5 8 1.6 0.0 r = st_2.num_bins - 1
447 5 9 1.8 0.0 l = - st_1.num_bins + 1
448 # cch computed only for the entries that completely overlap
449 elif mode == 'valid':
450 # Assign left and right edges of the cch
451 r = max(st_2.num_bins - st_1.num_bins, 0)
452 l = min(st_2.num_bins - st_1.num_bins, 0)
453 # Check the mode parameter
454 else:
455 raise KeyError(
456 "The possible entries for mode parameter are" +
457 "'full' and 'valid'")
458
459 # For each row, extract the nonzero column indices
460 # and the corresponding data in the matrix (for performance reasons)
461 5 2096 419.2 0.0 st1_bin_idx_unique = st1_spmat.nonzero()[1]
462 5 1413 282.6 0.0 st2_bin_idx_unique = st2_spmat.nonzero()[1]
463
464 # Case with binary entries
465 5 27 5.4 0.0 if binary:
466 st1_bin_counts_unique = np.array(st1_spmat.data > 0, dtype=int)
467 st2_bin_counts_unique = np.array(st2_spmat.data > 0, dtype=int)
468 # Case with all values
469 else:
470 5 7 1.4 0.0 st1_bin_counts_unique = st1_spmat.data
471 5 7 1.4 0.0 st2_bin_counts_unique = st2_spmat.data
472
473 # Initialize the counts to an array of zeroes,
474 # and the bin IDs to integers
475 # spanning the time axis
476 5 926 185.2 0.0 counts = np.zeros(np.abs(l) + np.abs(r) + 1)
477 5 1644 328.8 0.0 bin_ids = np.arange(l, r + 1)
478 # Compute the CCH at lags in l,...,r only
479 47435 81286 1.7 0.8 for idx, i in enumerate(st1_bin_idx_unique):
480 47430 345082 7.3 3.3 timediff = st2_bin_idx_unique - i
481 47430 56917 1.2 0.5 timediff_in_range = np.all(
482 47430 1440376 30.4 13.7 [timediff >= l, timediff <= r], axis=0)
483 47430 660733 13.9 6.3 timediff = (timediff[timediff_in_range]).reshape((-1,))
484 47430 2334617 49.2 22.3 counts[timediff + np.abs(l)] += st1_bin_counts_unique[idx] * \
485 47430 5549904 117.0 53.0 st2_bin_counts_unique[timediff_in_range]
486
487 # Correct the values to account for missing contributions
488 # at the edges
489 5 8 1.6 0.0 if border_corr is True:
490 correction = float(max_num_bins + 1) / np.array(
491 max_num_bins + 1 - abs(
492 np.arange(l, r + 1)), float)
493 counts = counts * correction
494
495 # Define the kern for smoothing as an ndarray
496 5 21 4.2 0.0 if hasattr(kern, '__iter__'):
497 if len(kern) > np.abs(l) + np.abs(r) + 1:
498 raise ValueError(
499 'The length of the kernel cannot be larger than the '
500 'length %d of the resulting CCH.' % (
501 np.abs(l) + np.abs(r) + 1))
502 kern = np.array(kern, dtype=float)
503 kern = 1. * kern / sum(kern)
504 # Check kern parameter
505 5 6 1.2 0.0 elif kern is not None:
506 raise ValueError('Invalid smoothing kernel.')
507
508 # Smooth the cross-correlation histogram with the kern
509 5 5 1.0 0.0 if kern is not None:
510 counts = np.convolve(counts, kern, mode='same')
511
512 # Rescale the histogram so that the central bin has height 1,
513 # if requested
514 5 7 1.4 0.0 if norm and l <= 0 <= r:
515 if counts[np.abs(l)] != 0:
516 counts = counts / counts[np.abs(l)]
517 else:
518 warnings.warn('CCH not normalized because no value for 0 lag')
519
520 # Transform the array count into an AnalogSignalArray
521 5 10 2.0 0.0 cch_result = neo.AnalogSignalArray(
522 5 25 5.0 0.0 signal=counts.reshape(counts.size, 1),
523 5 8 1.6 0.0 units=pq.dimensionless,
524 5 331 66.2 0.0 t_start=(bin_ids[0] - 0.5) * st_1.binsize,
525 5 3142 628.4 0.0 sampling_period=st_1.binsize)
526 # Return only the hist_bins bins and counts before and after the
527 # central one
528 5 8 1.6 0.0 return cch_result, bin_ids
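Nearly all of the 10.5 s above is spent in the loop over the occupied bins of the first train; the lines of that loop account for well over 95% of the total. Purely as an illustration of where the time goes, and not code from this gist, the same per-lag accumulation can be written as a NumPy scatter-add. The helper below is a hypothetical sketch (name and arguments made up); it trades the Python loop for an O(n1*n2) intermediate array, so it would only pay off when the binned trains are sparse.

import numpy as np

def cch_counts_vectorized(st1_idx, st1_cnt, st2_idx, st2_cnt, left, right):
    # st1_idx/st2_idx: bin indices of occupied bins; st1_cnt/st2_cnt: counts there
    # All pairwise lags and the corresponding count products (outer products)
    lags = st2_idx[np.newaxis, :] - st1_idx[:, np.newaxis]
    products = st1_cnt[:, np.newaxis] * st2_cnt[np.newaxis, :]
    # Keep only lags inside the requested window [left, right]
    in_range = (lags >= left) & (lags <= right)
    counts = np.zeros(right - left + 1)
    # np.add.at accumulates correctly even when target indices repeat,
    # unlike an in-place fancy-index +=
    np.add.at(counts, lags[in_range] - left, products[in_range])
    return counts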
Total time: 0.110484 s
Function: _cch_speed at line 530
Line # Hits Time Per Hit % Time Line Contents
==============================================================
530 @profile
531 def _cch_speed(st_1, st_2, win, mode, norm, border_corr, binary, kern):
532
533 # Check that the spike trains are binned with the same temporal
534 # resolution
535 5 13 2.6 0.0 if not st1.matrix_rows == 1:
536 raise AssertionError("Spike train must be one dimensional")
537 5 6 1.2 0.0 if not st2.matrix_rows == 1:
538 raise AssertionError("Spike train must be one dimensional")
539 5 936 187.2 0.8 if not st1.binsize == st2.binsize:
540 raise AssertionError("Bin sizes must be equal")
541
542 # Retrieve the array of the binned spike train
543 5 2240 448.0 2.0 st1_arr = st1.to_array()[0, :]
544 5 2013 402.6 1.8 st2_arr = st2.to_array()[0, :]
545 5 9 1.8 0.0 binsize = st1.binsize
546
547 # Convert to the binary version
548 5 4 0.8 0.0 if binary:
549 st1_arr = np.array(st1_arr > 0, dtype=int)
550 st2_arr = np.array(st2_arr > 0, dtype=int)
551 5 10 2.0 0.0 max_num_bins = max(len(st1_arr), len(st2_arr))
552
553 # Cross correlate the spiketrains
554
555 # Case explicit temporal window
556 5 5 1.0 0.0 if win is not None:
557 # Window parameter given in number of bins (integer)
558 if isinstance(win[0], int) and isinstance(win[1], int):
559 # Check the window parameter values
560 if win[0] >= win[1] or win[0] <= -max_num_bins \
561 or win[1] >= max_num_bins:
562 raise ValueError(
563 "The window exceed the length of the spike trains")
564 # Assign left and right edges of the cch
565 l, r = win[0], win[1]
566 # Window parameter given in time units
567 else:
568 # Check the window parameter values
569 if win[0].rescale(binsize.units).magnitude % \
570 binsize.magnitude != 0 or win[1].rescale(
571 binsize.units).magnitude % binsize.magnitude != 0:
572 raise ValueError(
573 "The window has to be a multiple of the binsize")
574 if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
575 or win[1] >= max_num_bins * binsize:
576 raise ValueError("The window exceed the length of the"
577 " spike trains")
578 # Assign left and right edges of the cch
579 l, r = int(win[0].rescale(binsize.units) / binsize), int(
580 win[1].rescale(binsize.units) / binsize)
581
582 # Cross correlate the spike trains
583 corr = signal.fftconvolve(st2_arr, st1_arr[::-1], mode='full').astype(int)
584 counts = corr[len(st1_arr)+l+1:len(st1_arr)+1+r+1]
585
586 # Case generic
587 else:
588 # Cross correlate the spike trains
589 5 99853 19970.6 90.4 counts = signal.fftconvolve(st2_arr, st1_arr[::-1], mode=mode).astype(int)
590 # Assign the edges of the cch for the different mode parameters
591 5 16 3.2 0.0 if mode == 'full':
592 # Assign left and right edges of the cch
593 5 12 2.4 0.0 r = st_2.num_bins - 1
594 5 11 2.2 0.0 l = - st_1.num_bins + 1
595 # cch computed only for the entries that completely overlap
596 elif mode == 'valid':
597 # Assign left and right edges of the cch
598 r = max(st_2.num_bins - st_1.num_bins, 0)
599 l = min(st_2.num_bins - st_1.num_bins, 0)
600 5 793 158.6 0.7 bin_ids = np.r_[l:r + 1]
601
602 # Correct the values to account for missing contributions
603 # at the edges
604 5 9 1.8 0.0 if border_corr is True:
605 correction = float(max_num_bins + 1) / np.array(
606 max_num_bins + 1 - abs(
607 np.arange(l, r + 1)), float)
608 counts = counts * correction
609
610 # Define the kern for smoothing as an ndarray
611 5 32 6.4 0.0 if hasattr(kern, '__iter__'):
612 if len(kern) > np.abs(l) + np.abs(r) + 1:
613 raise ValueError(
614 'The length of the kernel cannot be larger than the '
615 'length %d of the resulting CCH.' % (
616 np.abs(l) + np.abs(r) + 1))
617 kern = np.array(kern, dtype=float)
618 kern = 1. * kern / sum(kern)
619 # Check kern parameter
620 5 5 1.0 0.0 elif kern is not None:
621 raise ValueError('Invalid smoothing kernel.')
622
623 # Smooth the cross-correlation histogram with the kern
624 5 6 1.2 0.0 if kern is not None:
625 counts = np.convolve(counts, kern, mode='same')
626
627 # Rescale the histogram so that the central bin has height 1,
628 # if requested
629 5 4 0.8 0.0 if norm and l <= 0 <= r:
630 if counts[np.abs(l)] != 0:
631 counts = counts / counts[np.abs(l)]
632 else:
633 warnings.warn('CCH not normalized because no value for 0 lag')
634
635 # Transform the array count into an AnalogSignalArray
636 5 13 2.6 0.0 cch_result = neo.AnalogSignalArray(
637 5 36 7.2 0.0 signal=counts.reshape(counts.size, 1),
638 5 7 1.4 0.0 units=pq.dimensionless,
639 5 363 72.6 0.3 t_start=(bin_ids[0] - 0.5) * st_1.binsize,
640 5 4077 815.4 3.7 sampling_period=st_1.binsize)
641 # Return only the hist_bins bins and counts before and after the
642 # central one
643 5 11 2.2 0.0 return cch_result, bin_ids
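For the 'speed' path, the single scipy.signal.fftconvolve call accounts for about 90% of the 0.11 s; everything else is bookkeeping. The histogram it produces is simply the full cross-correlation of the two binned count arrays, obtained by convolving the second train with the reversed first train. A small self-contained check of that identity on toy arrays (not the gist's data):

import numpy as np
from scipy import signal

a = np.array([0, 1, 0, 2, 1])  # binned counts of the first train
b = np.array([1, 0, 1, 0, 0])  # binned counts of the second train

# Convolving b with the reversed a equals cross-correlating b with a
cch_fft = np.round(signal.fftconvolve(b, a[::-1], mode='full')).astype(int)
cch_direct = np.correlate(b, a, mode='full')
assert np.array_equal(cch_fft, cch_direct)
# Entry k corresponds to lag k - (len(a) - 1), i.e. lags -(len(a)-1) .. len(b)-1
print(cch_fft)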
import timeit
import quantities as pq
import elephant.spike_train_correlation as corr
from elephant.conversion import BinnedSpikeTrain
from functools import partial
from elephant.spike_train_generation import homogeneous_poisson_process as poisson
benchmark = list()
runs = 5
t_stop = 100 * pq.s
rate = 100 * pq.Hz
spikedata = [poisson(rate, t_start=0 * pq.s, t_stop=t_stop) for _ in range(2)]
st1 = BinnedSpikeTrain(spikedata[0], binsize=1 * pq.ms)
st2 = BinnedSpikeTrain(spikedata[1], binsize=1 * pq.ms)
benchmark.append(timeit.Timer(partial(corr.cross_correlation_histogram, st1, st2,
method='speed')).timeit(runs))
benchmark.append(timeit.Timer(partial(corr.cross_correlation_histogram, st1, st2,
method='memory')).timeit(runs))
print "Benchmark results for {0} runs, {1} t_stop, {2} rate".format(runs, t_stop, rate)
print "Speed method: {}".format(benchmark[0])
print "Memory method: {}".format(benchmark[1])
import warnings
import numpy as np
import quantities as pq
import neo
from scipy import signal

def cross_correlation_histogram(
st1, st2, mode='full', window=None, normalize=False,
border_correction=False, binary=False, kernel=None,
chance_corrected=False, method='speed', **kwargs):
"""
Computes the cross-correlation histogram (CCH) between two binned spike
trains st1 and st2.
"""
# @profile
def _cch_memory(st_1, st_2, win, mode, norm, border_corr, binary, kern):
# Check that the spike trains are binned with the same temporal
# resolution
if not st1.matrix_rows == 1:
raise AssertionError("Spike train must be one dimensional")
if not st2.matrix_rows == 1:
raise AssertionError("Spike train must be one dimensional")
if not st1.binsize == st2.binsize:
raise AssertionError("Bin sizes must be equal")
# Retrieve unclipped matrix
st1_spmat = st_1.to_sparse_array()
st2_spmat = st_2.to_sparse_array()
binsize = st_1.binsize
max_num_bins = max(st_1.num_bins, st_2.num_bins)
# Set the time window in which the cch is computed
if win is not None:
# Window parameter given in number of bins (integer)
if isinstance(win[0], int) and isinstance(win[1], int):
# Check the window parameter values
if win[0] >= win[1] or win[0] <= -max_num_bins \
or win[1] >= max_num_bins:
raise ValueError(
"The window exceeds the length of the spike trains")
# Assign left and right edges of the cch
l, r = win[0], win[1]
# Window parameter given in time units
else:
# Check the window parameter values
if win[0].rescale(binsize.units).magnitude % \
binsize.magnitude != 0 or win[1].rescale(
binsize.units).magnitude % binsize.magnitude != 0:
raise ValueError(
"The window has to be a multiple of the binsize")
if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
or win[1] >= max_num_bins * binsize:
raise ValueError("The window exceeds the length of the"
" spike trains")
# Assign left and right edges of the cch
l, r = int(win[0].rescale(binsize.units) / binsize), int(
win[1].rescale(binsize.units) / binsize)
# Case without explicit window parameter
else:
# cch computed for all the possible entries
if mode == 'full':
# Assign left and right edges of the cch
r = st_2.num_bins - 1
l = - st_1.num_bins + 1
# cch computed only for the entries that completely overlap
elif mode == 'valid':
# Assign left and right edges of the cch
r = max(st_2.num_bins - st_1.num_bins, 0)
l = min(st_2.num_bins - st_1.num_bins, 0)
# Check the mode parameter
else:
raise KeyError(
"The possible entries for mode parameter are" +
"'full' and 'valid'")
# For each row, extract the nonzero column indices
# and the corresponding data in the matrix (for performance reasons)
st1_bin_idx_unique = st1_spmat.nonzero()[1]
st2_bin_idx_unique = st2_spmat.nonzero()[1]
# Case with binary entries
if binary:
st1_bin_counts_unique = np.array(st1_spmat.data > 0, dtype=int)
st2_bin_counts_unique = np.array(st2_spmat.data > 0, dtype=int)
# Case with all values
else:
st1_bin_counts_unique = st1_spmat.data
st2_bin_counts_unique = st2_spmat.data
# Initialize the counts to an array of zeroes,
# and the bin IDs to integers
# spanning the time axis
counts = np.zeros(np.abs(l) + np.abs(r) + 1)
bin_ids = np.arange(l, r + 1)
# Compute the CCH at lags in l,...,r only
for idx, i in enumerate(st1_bin_idx_unique):
timediff = st2_bin_idx_unique - i
timediff_in_range = np.all(
[timediff >= l, timediff <= r], axis=0)
timediff = (timediff[timediff_in_range]).reshape((-1,))
counts[timediff + np.abs(l)] += st1_bin_counts_unique[idx] * \
st2_bin_counts_unique[timediff_in_range]
# Correct the values to account for missing contributions
# at the edges
if border_corr is True:
correction = float(max_num_bins + 1) / np.array(
max_num_bins + 1 - abs(
np.arange(l, r + 1)), float)
counts = counts * correction
# Define the kern for smoothing as an ndarray
if hasattr(kern, '__iter__'):
if len(kern) > np.abs(l) + np.abs(r) + 1:
raise ValueError(
'The length of the kernel cannot be larger than the '
'length %d of the resulting CCH.' % (
np.abs(l) + np.abs(r) + 1))
kern = np.array(kern, dtype=float)
kern = 1. * kern / sum(kern)
# Check kern parameter
elif kern is not None:
raise ValueError('Invalid smoothing kernel.')
# Smooth the cross-correlation histogram with the kern
if kern is not None:
counts = np.convolve(counts, kern, mode='same')
# Rescale the histogram so that the central bin has height 1,
# if requested
if norm and l <= 0 <= r:
if counts[np.abs(l)] != 0:
counts = counts / counts[np.abs(l)]
else:
warnings.warn('CCH not normalized because no value for 0 lag')
# Transform the array count into an AnalogSignalArray
cch_result = neo.AnalogSignalArray(
signal=counts.reshape(counts.size, 1),
units=pq.dimensionless,
t_start=(bin_ids[0] - 0.5) * st_1.binsize,
sampling_period=st_1.binsize)
# Return only the hist_bins bins and counts before and after the
# central one
return cch_result, bin_ids
# @profile
def _cch_speed(st_1, st_2, win, mode, norm, border_corr, binary, kern):
# Check that the spike trains are binned with the same temporal
# resolution
if not st1.matrix_rows == 1:
raise AssertionError("Spike train must be one dimensional")
if not st2.matrix_rows == 1:
raise AssertionError("Spike train must be one dimensional")
if not st1.binsize == st2.binsize:
raise AssertionError("Bin sizes must be equal")
# Retrieve the array of the binned spike train
st1_arr = st1.to_array()[0, :]
st2_arr = st2.to_array()[0, :]
binsize = st1.binsize
# Convert to the binary version
if binary:
st1_arr = np.array(st1_arr > 0, dtype=int)
st2_arr = np.array(st2_arr > 0, dtype=int)
max_num_bins = max(len(st1_arr), len(st2_arr))
# Cross correlate the spiketrains
# Case explicit temporal window
if win is not None:
# Window parameter given in number of bins (integer)
if isinstance(win[0], int) and isinstance(win[1], int):
# Check the window parameter values
if win[0] >= win[1] or win[0] <= -max_num_bins \
or win[1] >= max_num_bins:
raise ValueError(
"The window exceed the length of the spike trains")
# Assign left and right edges of the cch
l, r = win[0], win[1]
# Window parameter given in time units
else:
# Check the window parameter values
if win[0].rescale(binsize.units).magnitude % \
binsize.magnitude != 0 or win[1].rescale(
binsize.units).magnitude % binsize.magnitude != 0:
raise ValueError(
"The window has to be a multiple of the binsize")
if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
or win[1] >= max_num_bins * binsize:
raise ValueError("The window exceed the length of the"
" spike trains")
# Assign left and right edges of the cch
l, r = int(win[0].rescale(binsize.units) / binsize), int(
win[1].rescale(binsize.units) / binsize)
# Cross correlate the spike trains
corr = signal.fftconvolve(st2_arr, st1_arr[::-1], mode='full').astype(int)
counts = corr[len(st1_arr)+l+1:len(st1_arr)+1+r+1]
# Case generic
else:
# Cross correlate the spike trains
counts = signal.fftconvolve(st2_arr, st1_arr[::-1], mode=mode).astype(int)
# Assign the edges of the cch for the different mode parameters
if mode == 'full':
# Assign left and right edges of the cch
r = st_2.num_bins - 1
l = - st_1.num_bins + 1
# cch computed only for the entries that completely overlap
elif mode == 'valid':
# Assign left and right edges of the cch
r = max(st_2.num_bins - st_1.num_bins, 0)
l = min(st_2.num_bins - st_1.num_bins, 0)
bin_ids = np.r_[l:r + 1]
# Correct the values to account for missing contributions
# at the edges
if border_corr is True:
correction = float(max_num_bins + 1) / np.array(
max_num_bins + 1 - abs(
np.arange(l, r + 1)), float)
counts = counts * correction
# Define the kern for smoothing as an ndarray
if hasattr(kern, '__iter__'):
if len(kern) > np.abs(l) + np.abs(r) + 1:
raise ValueError(
'The length of the kernel cannot be larger than the '
'length %d of the resulting CCH.' % (
np.abs(l) + np.abs(r) + 1))
kern = np.array(kern, dtype=float)
kern = 1. * kern / sum(kern)
# Check kern parameter
elif kern is not None:
raise ValueError('Invalid smoothing kernel.')
# Smooth the cross-correlation histogram with the kern
if kern is not None:
counts = np.convolve(counts, kern, mode='same')
# Rescale the histogram so that the central bin has height 1,
# if requested
if norm and l <= 0 <= r:
if counts[np.abs(l)] != 0:
counts = counts / counts[np.abs(l)]
else:
warnings.warn('CCH not normalized because no value for 0 lag')
# Transform the array count into an AnalogSignalArray
cch_result = neo.AnalogSignalArray(
signal=counts.reshape(counts.size, 1),
units=pq.dimensionless,
t_start=(bin_ids[0] - 0.5) * st_1.binsize,
sampling_period=st_1.binsize)
# Return only the hist_bins bins and counts before and after the
# central one
return cch_result, bin_ids
if method == "memory":
cch_result, bin_ids = _cch_memory(
st1, st2, window, mode, normalize, border_correction, binary,
kernel)
elif method == "speed":
cch_result, bin_ids = _cch_speed(
st1, st2, window, mode, normalize, border_correction, binary,
kernel)
return cch_result, bin_ids
# Alias for common abbreviation
cch = cross_correlation_histogram
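For reference, a minimal way to exercise both code paths on the same data. This is a sketch assuming the elephant/neo/quantities versions used elsewhere in this gist; the rate and duration are made up for the example.

import quantities as pq
from elephant.conversion import BinnedSpikeTrain
from elephant.spike_train_generation import homogeneous_poisson_process as poisson

trains = [poisson(10 * pq.Hz, t_start=0 * pq.s, t_stop=10 * pq.s) for _ in range(2)]
bst1 = BinnedSpikeTrain(trains[0], binsize=1 * pq.ms)
bst2 = BinnedSpikeTrain(trains[1], binsize=1 * pq.ms)

cch_fast, lags = cross_correlation_histogram(bst1, bst2, method='speed')
cch_slow, _ = cross_correlation_histogram(bst1, bst2, method='memory')
# Both methods should return the same histogram; lags are given in bins
print(lags)
print(cch_fast.magnitude.ravel())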