Last active
May 10, 2024 07:27
-
-
Save Routhleck/f37c18283c169ed3148f2ab0ac6a1a08 to your computer and use it in GitHub Desktop.
BrainPy Issue #667
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import brainpy as bp | |
import brainpy.math as bm | |
import os | |
from scipy.sparse import coo_matrix, csr_matrix | |
class EINet(bp.DynSysGroup):
    """Excitatory/inhibitory spiking network assembled from injected parts.

    The neuron classes, projection (synapse) classes, their keyword
    parameters and the four communication layers are all passed in, so the
    same E/I wiring can be reused with different model choices.
    """

    def __init__(self, E_neuron, I_neuron, E_params, I_params,
                 E2E_synapse, E2I_synapse, I2E_synapse, I2I_synapse,
                 E2E_synapse_params, E2I_synapse_params, I2E_synapse_params, I2I_synapse_params,
                 E2E_comm, E2I_comm, I2E_comm, I2I_comm):
        super().__init__()

        # Keep shallow copies of every parameter dict so that later edits to
        # the caller's dicts do not change what this network stored.
        self.E_params = E_params.copy()
        self.I_params = I_params.copy()
        self.E2E_synapse_params = E2E_synapse_params.copy()
        self.E2I_synapse_params = E2I_synapse_params.copy()
        self.I2E_synapse_params = I2E_synapse_params.copy()
        self.I2I_synapse_params = I2I_synapse_params.copy()

        # Populations, constructed from their keyword parameters.
        self.E = E_neuron(**self.E_params)
        self.I = I_neuron(**self.I_params)

        # Projections: synapse_cls(pre=..., post=..., comm=..., **params),
        # created in the order E->E, E->I, I->E, I->I.
        self.E2E = E2E_synapse(pre=self.E, post=self.E, comm=E2E_comm, **self.E2E_synapse_params)
        self.E2I = E2I_synapse(pre=self.E, post=self.I, comm=E2I_comm, **self.E2I_synapse_params)
        self.I2E = I2E_synapse(pre=self.I, post=self.E, comm=I2E_comm, **self.I2E_synapse_params)
        self.I2I = I2I_synapse(pre=self.I, post=self.I, comm=I2I_comm, **self.I2I_synapse_params)

    def update(self, E_inp, I_inp):
        """Advance the network by one step.

        All four projections are called first; then each population is
        called with its external input.

        Args:
            E_inp: external input passed to the excitatory population.
            I_inp: external input passed to the inhibitory population.

        Returns:
            Tuple ``(E.spike, I.spike, E.V, I.V)`` for monitoring.
        """
        for projection in (self.E2E, self.E2I, self.I2E, self.I2I):
            projection()
        self.E(E_inp)
        self.I(I_inp)
        return self.E.spike, self.I.spike, self.E.V, self.I.V
def get_run_func(EI_net, E_inp_kwargs, I_inp_kwargs, E_size, I_size, input_type='constant'):
    """Build a per-step function that advances ``EI_net`` with external input.

    Args:
        EI_net: object exposing ``step_run(i, E_inp, I_inp)``.
        E_inp_kwargs: dict for the excitatory input. ``'mean'`` is always read;
            ``'std'`` is read only when ``input_type == 'wiener'``.
        I_inp_kwargs: same as ``E_inp_kwargs``, for the inhibitory input.
        E_size: length of the excitatory noise vector (``'wiener'`` only).
        I_size: length of the inhibitory noise vector (``'wiener'`` only).
        input_type: ``'constant'`` passes the means unchanged.
            ``'wiener'`` adds fresh Gaussian noise, scaled by ``'std'``, to
            the means at every call.

    Returns:
        A callable ``run_func(i)`` returning whatever ``EI_net.step_run`` returns.

    Raises:
        ValueError: if ``input_type`` is not ``'constant'`` or ``'wiener'``.
    """
    if input_type == 'constant':
        def run_func(i):
            return EI_net.step_run(i, E_inp_kwargs['mean'], I_inp_kwargs['mean'])
        return run_func

    if input_type == 'wiener':
        def run_func(i):
            # Sample with brainpy's stateful RNG, not NumPy. This function is
            # driven by bm.for_loop, which traces its body, so np.random would
            # be evaluated once at trace time and the same sample replayed
            # at every step.
            local_E_inp = bm.random.randn(E_size) * E_inp_kwargs['std'] + E_inp_kwargs['mean']
            local_I_inp = bm.random.randn(I_size) * I_inp_kwargs['std'] + I_inp_kwargs['mean']
            return EI_net.step_run(i, local_E_inp, local_I_inp)
        return run_func

    # Previously an unknown input_type fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported input_type {input_type!r}; expected 'constant' or 'wiener'.")
# Build the SNN with BrainPy.
# Only the E->E projection carries a non-zero weight.
E2E_weight = 1
E2I_weight = 0
I2E_weight = 0
I2I_weight = 0

E_size = 2
I_size = 2

# Both populations share the same neuron parameters (passed to bp.dyn.LifRef below).
E_params = dict(size=E_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)
I_params = dict(size=I_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)

# Zero delay on every projection; each projection gets its own dict.
(E2E_synapse_params, E2I_synapse_params,
 I2E_synapse_params, I2I_synapse_params) = [{'delay': 0} for _ in range(4)]

# Inject current only into the first excitatory neuron.
E_inp_mean = np.zeros(E_size)
E_inp_mean[0] = 30
E_inp_std = 0
I_inp_mean = 0
I_inp_std = 0

dt = 1.
bm.set_dt(dt)

# Build the same E->E wiring (neuron 0 -> neuron 1, listed twice) in two ways
# and plot the resulting voltages, to compare the two construction paths.
for mode in ('csr', 'ij'):
    if mode == 'csr':
        # Create the conn from a CSR matrix.
        # NOTE(review): scipy's csr_matrix((data, (row, col))) sums duplicate
        # (row, col) pairs, so this matrix likely stores a single (0, 1) entry
        # with value 2. How SparseMatConn treats that value is unverified here.
        row_indices = np.array([0, 0])
        col_indices = np.array([1, 1])
        E2E_csr = csr_matrix((np.ones_like(row_indices), (row_indices, col_indices)),
                             shape=(E_size, E_size))
        E2E_conn = bp.connect.SparseMatConn(E2E_csr)
        print(f"csr:\n conn_mat: {E2E_conn.require('conn_mat')}\n pre ids: {E2E_conn.require('pre_ids')}\n post ids: {E2E_conn.require('post_ids')}")
    elif mode == 'ij':
        # Create the conn from explicit pre/post index lists.
        pre_list = np.array([0, 0])
        post_list = np.array([1, 1])
        E2E_conn = bp.conn.IJConn(i=pre_list, j=post_list)(pre_size=E_size, post_size=E_size)
        print(f"ij:\n conn_mat: {E2E_conn.require('conn_mat')}\n pre ids: {E2E_conn.require('pre_ids')}\n post ids: {E2E_conn.require('post_ids')}")

    # Create the comm layers from the conn. The other three projections reuse
    # the E->E connectivity with weight 0, so their connections are effectively
    # ignored.
    E2E_comm = bp.dnn.EventCSRLinear(conn=E2E_conn, weight=E2E_weight)
    E2I_conn = I2E_conn = I2I_conn = E2E_conn
    E2I_comm = bp.dnn.EventCSRLinear(conn=E2I_conn, weight=E2I_weight)
    I2E_comm = bp.dnn.EventCSRLinear(conn=I2E_conn, weight=I2E_weight)
    I2I_comm = bp.dnn.EventCSRLinear(conn=I2I_conn, weight=I2I_weight)

    EI_net = EINet(
        E_neuron=bp.dyn.LifRef,
        I_neuron=bp.dyn.LifRef,
        E_params=E_params,
        I_params=I_params,
        E2E_synapse=bp.dyn.FullProjDelta,
        E2I_synapse=bp.dyn.FullProjDelta,
        I2E_synapse=bp.dyn.FullProjDelta,
        I2I_synapse=bp.dyn.FullProjDelta,
        E2E_synapse_params=E2E_synapse_params,
        E2I_synapse_params=E2I_synapse_params,
        I2E_synapse_params=I2E_synapse_params,
        I2I_synapse_params=I2I_synapse_params,
        E2E_comm=E2E_comm,
        E2I_comm=E2I_comm,
        I2E_comm=I2E_comm,
        I2I_comm=I2I_comm,
    )

    run_func = get_run_func(EI_net,
                            dict(mean=E_inp_mean, std=E_inp_std),
                            dict(mean=I_inp_mean, std=I_inp_std),
                            E_size, I_size, input_type='wiener')

    # Simulate 100 steps; ts holds the matching simulation times.
    indices = np.arange(100)
    ts = indices * bm.get_dt()
    E_spikes, I_spikes, E_V, I_V = bm.for_loop(run_func, indices, progress_bar=True)

    # Plot both excitatory voltages against one and two multiples of the E->E weight.
    fig, ax = plt.subplots()
    ax.plot(ts, E_V[:, 0], label='E neuron 0', color='blue')
    ax.plot(ts, E_V[:, 1], label='E neuron 1', color='red')
    ax.axhline(E2E_weight, label='E2E weight', color='black')
    ax.axhline(E2E_weight * 2, label='2*E2E weight', color='black', linestyle='--')
    ax.legend()
    ax.set_title(f'Connection mode: {mode}')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment