Skip to content

Instantly share code, notes, and snippets.

@Routhleck
Last active May 10, 2024 07:27
Show Gist options
  • Save Routhleck/f37c18283c169ed3148f2ab0ac6a1a08 to your computer and use it in GitHub Desktop.
Save Routhleck/f37c18283c169ed3148f2ab0ac6a1a08 to your computer and use it in GitHub Desktop.
BrainPy Issue #667
import sys
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
import os
from scipy.sparse import coo_matrix, csr_matrix
class EINet(bp.DynSysGroup):
    """Network of an excitatory (E) and an inhibitory (I) population joined by four projections.

    Every component is injected by the caller:
    - the neuron classes and their keyword arguments;
    - the projection classes and their keyword arguments;
    - the ``comm`` object handed to each projection.

    Each projection is built as ``synapse(pre=..., post=..., comm=..., **params)``.

    Attributes:
        E, I: the excitatory and inhibitory populations.
        E2E, E2I, I2E, I2I: projections, named ``<pre>2<post>``.
        *_params / *_synapse_params: shallow copies of the dicts passed in.
    """

    def __init__(self, E_neuron, I_neuron, E_params, I_params,
                 E2E_synapse, E2I_synapse, I2E_synapse, I2I_synapse,
                 E2E_synapse_params, E2I_synapse_params,
                 I2E_synapse_params, I2I_synapse_params,
                 E2E_comm, E2I_comm, I2E_comm, I2I_comm):
        super().__init__()

        # Keep shallow copies so later edits to the caller's dicts do not leak in.
        self.E_params, self.I_params = E_params.copy(), I_params.copy()
        (self.E2E_synapse_params, self.E2I_synapse_params,
         self.I2E_synapse_params, self.I2I_synapse_params) = (
            E2E_synapse_params.copy(), E2I_synapse_params.copy(),
            I2E_synapse_params.copy(), I2I_synapse_params.copy())

        # Populations.
        self.E = E_neuron(**self.E_params)
        self.I = I_neuron(**self.I_params)

        # Projections, created in the order E2E, E2I, I2E, I2I.
        exc, inh = self.E, self.I
        self.E2E = E2E_synapse(pre=exc, post=exc, comm=E2E_comm, **self.E2E_synapse_params)
        self.E2I = E2I_synapse(pre=exc, post=inh, comm=E2I_comm, **self.E2I_synapse_params)
        self.I2E = I2E_synapse(pre=inh, post=exc, comm=I2E_comm, **self.I2E_synapse_params)
        self.I2I = I2I_synapse(pre=inh, post=inh, comm=I2I_comm, **self.I2I_synapse_params)

    def update(self, E_inp, I_inp):
        """Advance the network one step.

        All four projections are called before the two populations are stepped
        with their external inputs.

        Returns:
            Tuple ``(E.spike, I.spike, E.V, I.V)`` used for monitoring.
        """
        for projection in (self.E2E, self.E2I, self.I2E, self.I2I):
            projection()
        self.E(E_inp)
        self.I(I_inp)
        return self.E.spike, self.I.spike, self.E.V, self.I.V
def get_run_func(EI_net, E_inp_kwargs, I_inp_kwargs, E_size, I_size, input_type='constant'):
    """Build the per-step function that is passed to ``bm.for_loop``.

    Args:
        EI_net: network exposing ``step_run(i, E_inp, I_inp)``.
        E_inp_kwargs: dict with key ``'mean'`` (plus ``'std'`` for ``'wiener'``)
            describing the external input to the E population.
        I_inp_kwargs: same as ``E_inp_kwargs``, for the I population.
        E_size: number of E neurons; shape of the E noise draw.
        I_size: number of I neurons; shape of the I noise draw.
        input_type: ``'constant'`` feeds the means unchanged on every step;
            ``'wiener'`` adds fresh Gaussian noise scaled by ``'std'`` on every step.

    Returns:
        A function ``run_func(i)`` returning whatever ``EI_net.step_run`` returns.

    Raises:
        ValueError: if ``input_type`` is neither ``'constant'`` nor ``'wiener'``.
    """
    if input_type == 'constant':
        def run_func(i):
            return EI_net.step_run(i, E_inp_kwargs['mean'], I_inp_kwargs['mean'])
    elif input_type == 'wiener':
        def run_func(i):
            # Use brainpy's random API, not ``np.random``. ``bm.for_loop`` compiles
            # ``run_func`` by tracing it, so a NumPy draw would be evaluated only
            # once at trace time and the same "noise" sample replayed every step.
            local_E_inp = bm.random.randn(E_size) * E_inp_kwargs['std'] + E_inp_kwargs['mean']
            local_I_inp = bm.random.randn(I_size) * I_inp_kwargs['std'] + I_inp_kwargs['mean']
            return EI_net.step_run(i, local_E_inp, local_I_inp)
    else:
        # Previously this fell through to an UnboundLocalError on ``run_func``.
        raise ValueError(f"Unknown input_type {input_type!r}; expected 'constant' or 'wiener'.")
    return run_func
# ---- Configuration for building the spiking network with BrainPy ----

# Projection weights: only E->E is non-zero; the other three are set to 0.
E2E_weight, E2I_weight, I2E_weight, I2I_weight = 1, 0, 0, 0

# Population sizes.
E_size = 2
I_size = 2

# Neuron parameters; the two populations differ only in 'size'.
E_params = {'size': E_size, 'V_th': 20.0, 'V_reset': -5.0, 'V_rest': 0., 'tau_ref': 5.0, 'R': 1.0, 'tau': 10.0}
I_params = {**E_params, 'size': I_size}

# Projection parameters: zero delay, one independent dict per projection.
E2E_synapse_params, E2I_synapse_params, I2E_synapse_params, I2I_synapse_params = (
    {'delay': 0} for _ in range(4)
)

# External input: mean of 30 into E neuron 0 only; zero std; nothing into I.
E_inp_mean = np.zeros(E_size)
E_inp_mean[0] = 30
E_inp_std, I_inp_mean, I_inp_std = 0, 0, 0

# Simulation time step.
dt = 1.0
bm.set_dt(dt)
# Run the same network twice. The E->E connection is built once from a CSR
# matrix and once from (i, j) index lists, and the E membrane potentials are
# plotted so the two connection modes can be compared.

# E2E_conn is reused below for the E2I / I2E / I2I projections, so its
# pre/post sizes must fit both populations.
if E_size != I_size:
    raise ValueError(
        "E2E_conn is reused for the E2I/I2E/I2I projections, which requires "
        f"E_size == I_size (got E_size={E_size}, I_size={I_size})."
    )

for mode in ['csr', 'ij']:
    # In both modes the (0 -> 1) pair is listed twice.
    if mode == 'csr':
        # Build the connection from a SciPy CSR matrix.
        row_indices = np.array([0, 0])
        col_indices = np.array([1, 1])
        E2E_csr = csr_matrix((np.ones_like(row_indices), (row_indices, col_indices)), shape=(E_size, E_size))
        E2E_conn = bp.connect.SparseMatConn(E2E_csr)
    elif mode == 'ij':
        # Build the connection from explicit pre/post index lists.
        pre_list = np.array([0, 0])
        post_list = np.array([1, 1])
        E2E_conn = bp.conn.IJConn(i=pre_list, j=post_list)
        E2E_conn = E2E_conn(pre_size=E_size, post_size=E_size)
    else:
        raise ValueError(f"Unknown connection mode: {mode!r}")
    print(f"{mode}:\n conn_mat: {E2E_conn.require('conn_mat')}\n pre ids: {E2E_conn.require('pre_ids')}\n post ids: {E2E_conn.require('post_ids')}")

    # Build the communication layers from the connection.
    E2E_comm = bp.dnn.EventCSRLinear(conn=E2E_conn, weight=E2E_weight)
    # The other three projections reuse E2E_conn with E2I/I2E/I2I_weight,
    # which are set to 0 so these projections are effectively ignored.
    E2I_conn = E2E_conn
    E2I_comm = bp.dnn.EventCSRLinear(conn=E2I_conn, weight=E2I_weight)
    I2E_conn = E2E_conn
    I2E_comm = bp.dnn.EventCSRLinear(conn=I2E_conn, weight=I2E_weight)
    I2I_conn = E2E_conn
    I2I_comm = bp.dnn.EventCSRLinear(conn=I2I_conn, weight=I2I_weight)

    EI_net = EINet(
        E_neuron=bp.dyn.LifRef, I_neuron=bp.dyn.LifRef, E_params=E_params, I_params=I_params,
        E2E_synapse=bp.dyn.FullProjDelta, E2I_synapse=bp.dyn.FullProjDelta,
        I2E_synapse=bp.dyn.FullProjDelta, I2I_synapse=bp.dyn.FullProjDelta,
        E2E_synapse_params=E2E_synapse_params, E2I_synapse_params=E2I_synapse_params,
        I2E_synapse_params=I2E_synapse_params, I2I_synapse_params=I2I_synapse_params,
        E2E_comm=E2E_comm, E2I_comm=E2I_comm, I2E_comm=I2E_comm, I2I_comm=I2I_comm,
    )
    run_func = get_run_func(EI_net, {'mean': E_inp_mean, 'std': E_inp_std}, {'mean': I_inp_mean, 'std': I_inp_std},
                            E_size, I_size, input_type='wiener')

    # Simulate 100 steps; column k of E_V is E neuron k's potential over time.
    indices = np.arange(100)
    ts = indices * bm.get_dt()
    E_spikes, I_spikes, E_V, I_V = bm.for_loop(run_func, indices, progress_bar=True)

    # Plot both E neurons with reference lines at one and two times the E2E weight.
    fig, ax = plt.subplots()
    ax.plot(ts, E_V[:, 0], label='E neuron 0', color='blue')
    ax.plot(ts, E_V[:, 1], label='E neuron 1', color='red')
    ax.axhline(E2E_weight, label='E2E weight', color='black')
    ax.axhline(E2E_weight * 2, label='2*E2E weight', color='black', linestyle='--')
    ax.legend()
    ax.set_title(f'Connection mode: {mode}')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment