Created
May 4, 2021 17:11
-
-
Save danielmk/9adc7409f40a076ffec0cdf85dea4519 to your computer and use it in GitHub Desktop.
Python implementation of a balanced spiking neural network.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
NumPy implementation of a balanced spiking neural network. Inspired by MATLAB | |
code from Nicola & Clopath (2017): https://doi.org/10.1038/s41467-017-01827-3 | |
""" | |
import numpy as np | |
from numpy.random import rand, randn | |
import matplotlib.pyplot as plt | |
def _centered_weights(n, p, g):
    """Draw a sparse random weight matrix and zero the mean of each row.

    Draws ``randn(n, n)`` and then ``rand(n, n)`` from NumPy's global random
    state, so seed it beforehand for reproducible weights.

    Parameters
    ----------
    n : int
        The number of neurons; the matrix has shape ``(n, n)``.
    p : float
        Connection probability between neurons.
    g : float
        Scaling factor of synaptic strength.

    Returns
    -------
    ndarray
        ``(n, n)`` weights. The mean of the nonzero entries of each row has
        been subtracted from those entries, so every row sums to (numerically
        almost) zero; absent connections stay exactly zero.
    """
    w = g * randn(n, n) * (rand(n, n) < p) / (np.sqrt(n) * p)
    connected = w != 0
    # Mean over existing connections only. A row without connections has a
    # sum of 0; clamping its count to 1 avoids a "Mean of empty slice"
    # warning / NaN, and it has no entries to shift anyway.
    counts = np.maximum(connected.sum(axis=1, keepdims=True), 1)
    row_means = w.sum(axis=1, where=connected, keepdims=True) / counts
    # (n, 1) means broadcast across each row; only real connections shift.
    return np.where(connected, w - row_means, w)


def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000,
                             td=0.02, tr=0.002, p=0.1,
                             offset=-40.00, g=0.04, seed=100,
                             nrec=10, vspike=30.0):
    """Simulate a balanced network of leaky integrate-and-fire neurons.

    Neurons are coupled through a sparse random weight matrix whose rows are
    mean-centered, with double-exponential synapses, and integrated with
    forward Euler.

    Parameters
    ----------
    dt : float
        Integration time step. All time quantities share one unit (the
        defaults suggest seconds).
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons. Because last-spike times start at 0,
        every neuron is also held refractory for the first ``tref``.
    tm : float
        Membrane time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike; also the lower bound
        of the uniform initial voltages.
    vpeak : float
        Spike threshold: a neuron spikes when its voltage reaches ``vpeak``.
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant (must be nonzero).
    tr : float
        Synaptic rise time constant (must be nonzero).
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength.
    seed : int or None
        Passed to ``np.random.seed``; note this reseeds NumPy's *global*
        random state.
    nrec : int
        The number of neurons to record (neurons ``0 .. nrec-1``).
        Must satisfy ``0 <= nrec <= n``.
    vspike : float, optional
        Voltage written into the recording at the time step of a spike; also
        the upper bound of the uniform initial voltages. Default 30.0.

    Returns
    -------
    ndarray
        A 2D array of recorded voltages. Rows are time points, columns are
        the recorded neurons. Shape: ``(round(T / dt), nrec)``.

    Raises
    ------
    ValueError
        If ``nrec`` is negative or larger than ``n``.
    """
    if not 0 <= nrec <= n:
        raise ValueError(f"nrec must be in [0, n={n}], got {nrec}")

    np.random.seed(seed)  # Seeding randomness for reproducibility (global)

    # Weight matrix (drawn before the initial voltages; order matters for
    # reproducibility of a given seed).
    w = _centered_weights(n, p, g)

    nt = round(T / dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    ipsc = np.zeros(n)   # Postsynaptic current
    hm = np.zeros(n)     # Auxiliary synaptic variable feeding ipsc
    tlast = np.zeros(n)  # Time of each neuron's last spike
    # Uniform in [vreset, vspike): neurons starting at or above vpeak fire on
    # the first step.
    v = vreset + rand(n) * (vspike - vreset)

    # Loop invariants of the synapse update.
    decay_r = np.exp(-dt / tr)
    decay_d = np.exp(-dt / td)
    rise_decay = tr * td

    for i in range(nt):
        t = dt * i
        inp = ipsc + offset  # Total input current
        # Forward Euler step; neurons inside their refractory window are frozen.
        dv = (t > tlast + tref) * (-v + inp) / tm
        v = v + dt * dv

        # One spike mask per step, reused for stamping, peak and reset.
        spiked = v >= vpeak
        spikers = np.flatnonzero(spiked)

        # Stamp spike times. Kept in this arithmetic form (rather than
        # np.where) so existing seeds reproduce the same trajectories.
        tlast = tlast + (t - tlast) * spiked

        # Double-exponential synapse: ipsc is advanced with the previous hm.
        ipsc = ipsc * decay_r + hm * dt
        hm = hm * decay_d
        if spikers.size:
            # Neuron k receives the sum of w[k, j] over all spiking neurons j.
            hm = hm + w[:, spikers].sum(axis=1) / rise_decay

        # Record spikers at the spike peak, then reset them. Using the same
        # mask for both guarantees the reset even when vspike < vpeak.
        v = v + (vspike - v) * spiked
        rec[i] = v[:nrec]  # Record the first nrec neurons
        v = v + (vreset - v) * spiked

    return rec
if __name__ == '__main__':
    # Run the default simulation and show three recorded voltage traces.
    dt = 0.00005  # Passed explicitly so the time axis below matches the run
    rec = balanced_spiking_network(dt=dt)

    t = np.arange(rec.shape[0]) * dt
    fig, ax = plt.subplots(1)
    # Traces are shifted by -100 / 0 / +100 so they do not overlap.
    ax.plot(t, rec[:, 0] - 100.0, label='neuron 0 (shifted -100)')
    ax.plot(t, rec[:, 1], label='neuron 1')
    ax.plot(t, rec[:, 2] + 100.0, label='neuron 2 (shifted +100)')
    ax.set_xlabel('Time')
    ax.set_ylabel('Voltage (offset for clarity)')
    ax.legend(loc='upper right')
    # Without this call, running the file as a script never displays the figure.
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment