Created November 28, 2022 13:40
Save llandsmeer/d1219cbe25e1b6e0783b885a69b634fc to your computer and use it in GitHub Desktop.
Single file numpy implementation of IO network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import numba | |
import time | |
# Number of rows in the per-cell state array: 6 soma variables
# (V, k, l, h, n, x), 3 axon variables (V, Na h, K x) and 5 dendrite
# variables (V, Ca2+, CaH r, KCa s, Ih q). Columns index cells.
NUM_STATE_VARS = 14
def main():
    """Run a demo simulation of a 5x5x5 network and plot somatic voltages.

    Somatic and axonal initial voltages are randomised and every cell gets
    its own g_CaL drawn uniformly from [0.5, 1.7). The network is advanced
    by 1000 calls to ``one_ms`` with gap-junction conductance 0.05. The
    wall time of the first call (dominated by JIT compilation) and of the
    second call is printed. Row 0 of the state (V_soma) is recorded after
    every call and plotted at the end.
    """
    side = 5
    ncells = side ** 3
    state = make_initial_neuron_state(ncells, V_soma=None, V_axon=None)
    # rmax=4 gets clamped to nside // 2 inside sample_connections_3d.
    src, tgt = sample_connections_3d(ncells, rmax=4)
    g_CaL = 0.5 + 1.2 * np.random.random(ncells).astype('float32')

    soma_voltages = []
    for step in range(1000):
        t_start = time.perf_counter()
        state = one_ms(state, gj_src=src, gj_tgt=tgt, g_gj=0.05, g_CaL=g_CaL)
        elapsed = time.perf_counter() - t_start
        if step == 0:
            print(f'initial jit compile: {elapsed:.2f}s')
        elif step == 1:
            print(f'other runs: {elapsed:.2f}s')
        soma_voltages.append(state[0, :])

    plt.plot(np.array(soma_voltages))
    plt.show()
@numba.jit(nopython=True, fastmath=True, cache=True)
def one_ms(state, gj_src, gj_tgt, g_gj, g_CaL):
    """Apply ``timestep`` 40 times in a row and return the final state.

    The arguments are passed through unchanged to every sub-step. With
    ``timestep``'s default ``delta`` of 0.025, 40 sub-steps add up to 1.0
    time unit (ms, going by this function's name -- not stated elsewhere).
    """
    n_substeps = 40
    step = 0
    while step < n_substeps:
        state = timestep(state, gj_src=gj_src, gj_tgt=gj_tgt, g_gj=g_gj, g_CaL=g_CaL)
        step += 1
    return state
def make_initial_neuron_state(
        ncells,
        # Soma State
        V_soma = -60.0,
        soma_k = 0.7423159,
        soma_l = 0.0321349,
        soma_h = 0.3596066,
        soma_n = 0.2369847,
        soma_x = 0.1,
        # Axon state
        V_axon = -60.0,
        axon_Sodium_h = 0.9,
        axon_Potassium_x= 0.2369847,
        # Dend state
        V_dend = -60.0,
        dend_Ca2Plus = 3.715,
        dend_Calcium_r = 0.0113,
        dend_Potassium_s= 0.0049291,
        dend_Hcurrent_q = 0.0337836,
        dtype=np.float32):
    """Build the initial state array for ``ncells`` cells.

    Returns an array of shape ``(14, ncells)`` and dtype ``dtype``. There is
    one row per state variable, in the order of the keyword arguments, and
    one column per cell. Each row is filled with the given scalar. Passing
    ``None`` instead randomises that row per cell:

    * membrane voltages (``V_soma``, ``V_axon``, ``V_dend``) are drawn from
      a normal distribution with mean -60 and standard deviation 3;
    * gating variables are drawn uniformly from [0, 1).

    ``dend_Ca2Plus`` has no random initialisation and must be a number.

    Raises:
        ValueError: if ``dend_Ca2Plus`` is None. Previously this silently
            produced a row of NaN that poisoned the whole simulation.
    """
    if dend_Ca2Plus is None:
        raise ValueError('dend_Ca2Plus must be a number; random '
                         'initialisation is not supported for it')

    def _voltage(value):
        # Constant row, or N(-60, 3) samples when value is None.
        if value is None:
            return np.random.normal(-60, 3, ncells)
        return [value] * ncells

    def _gate(value):
        # Constant row, or U[0, 1) samples when value is None.
        if value is None:
            return np.random.random(ncells)
        return [value] * ncells

    # List elements are evaluated left to right, so random draws happen in
    # the same order as the rows (keeps seeded runs reproducible).
    return np.array([
        # Soma state
        _voltage(V_soma),
        _gate(soma_k),
        _gate(soma_l),
        _gate(soma_h),
        _gate(soma_n),
        _gate(soma_x),
        # Axon state
        _voltage(V_axon),
        _gate(axon_Sodium_h),
        _gate(axon_Potassium_x),
        # Dend state
        _voltage(V_dend),
        [dend_Ca2Plus] * ncells,
        _gate(dend_Calcium_r),
        _gate(dend_Potassium_s),
        _gate(dend_Hcurrent_q),
        ], dtype=dtype)
@numba.jit(nopython=True, fastmath=True, cache=True)
def timestep(state, gj_src, gj_tgt, g_gj,
        # Simulation parameters
        delta=0.025,
        # Geometry parameters
        g_int = 0.13,    # Cell internal conductance -- now a parameter (0.13)
        p1 = 0.25,       # Cell surface ratio soma/dendrite
        p2 = 0.15,       # Cell surface ratio axon(hillock)/soma
        # Channel conductance parameters
        g_CaL = 1.1,     # Calcium T - (CaV 3.1) (0.7)
        g_h = 0.12,      # H current (HCN) (0.4996)
        g_K_Ca = 35.0,   # Potassium (KCa v1.1 - BK) (35)
        g_ld = 0.01532,  # Leak dendrite (0.016)
        g_la = 0.016,    # Leak axon (0.016)
        g_ls = 0.016,    # Leak soma (0.016)
        g_Na_s = 150.0,  # Sodium - (Na v1.6 )
        g_Kdr_s = 9.0,   # Potassium - (K v4.3)
        g_K_s = 5.0,     # Potassium - (K v3.4)
        g_CaH = 4.5,     # High-threshold calcium -- Ca V2.1
        g_Na_a = 240.0,  # Sodium
        g_K_a = 240.0,   # Potassium (20)
        # Membrane capacitance
        S = 1.0,         # 1/C_m, cm^2/uF
        # Reversal potential parameters
        V_Na = 55.0,     # Sodium
        V_K = -75.0,     # Potassium
        V_Ca = 120.0,    # Low-threshold calcium channel
        V_h = -43.0,     # H current
        V_l = 10.0,      # Leak
        # Stimulus parameter
        I_app = 0.0,
        ):
    """Advance every cell by one forward-Euler step of length ``delta``.

    ``state`` has one row per state variable (NUM_STATE_VARS rows, in the
    order unpacked below) and one column per cell. Each cell has three
    compartments -- soma, axon hillock and dendrite -- coupled through
    ``g_int``. Dendrites of different cells are coupled by gap junctions.

    Parameters:
        state: 2-D float array, one column per cell. It is not modified.
        gj_src, gj_tgt: equal-length integer index arrays. Junction ``i``
            adds current to dendrite ``gj_tgt[i]`` only, driven by
            ``V_dend[gj_src[i]] - V_dend[gj_tgt[i]]``. A symmetric junction
            therefore has to be listed in both directions.
        g_gj: gap-junction conductance scale.
        delta: Euler step size (units presumably ms -- confirm).
        g_CaL: scalar, or an array with one value per cell.
        remaining keywords: model constants, see the inline comments.

    Returns:
        A new float32 array of the same shape as ``state`` holding
        ``x + dx/dt * delta`` for every state variable.
    """
    assert state.shape[0] == NUM_STATE_VARS

    # Soma state
    V_soma = state[0, :]
    soma_k = state[1, :]
    soma_l = state[2, :]
    soma_h = state[3, :]
    soma_n = state[4, :]
    soma_x = state[5, :]
    # Axon state
    V_axon = state[6, :]
    axon_Sodium_h = state[7, :]
    axon_Potassium_x = state[8, :]
    # Dend state
    V_dend = state[9, :]
    dend_Ca2Plus = state[10, :]
    dend_Calcium_r = state[11, :]
    dend_Potassium_s = state[12, :]
    dend_Hcurrent_q = state[13, :]

    ########## SOMA UPDATE ##########
    # CURRENT: Soma leak current (ls)
    soma_I_leak = g_ls * (V_soma - V_l)
    # CURRENT: Soma interaction current (ds, as)
    I_ds = (g_int / p1) * (V_soma - V_dend)
    I_as = (g_int / (1 - p2)) * (V_soma - V_axon)
    soma_I_interact = I_ds + I_as
    # CHANNEL: Soma Low-threshold calcium (CaL)
    soma_Ical = g_CaL * soma_k * soma_k * soma_k * soma_l * (V_soma - V_Ca)
    soma_k_inf = 1 / (1 + np.exp(-(V_soma + 61)/4.2))
    soma_l_inf = 1 / (1 + np.exp( (V_soma + 85)/8.5))
    soma_tau_l = (20 * np.exp((V_soma + 160)/30) / (1 + np.exp((V_soma + 84) / 7.3))) + 35
    soma_dk_dt = soma_k_inf - soma_k  # k relaxes with an implicit time constant of 1
    soma_dl_dt = (soma_l_inf - soma_l) / soma_tau_l
    # CHANNEL: Soma sodium (Na_s)
    # watch out direct gate: m = m_inf
    soma_m_inf = 1 / (1 + np.exp(-(V_soma + 30)/5.5))
    soma_h_inf = 1 / (1 + np.exp( (V_soma + 70)/5.8))
    soma_Ina = g_Na_s * soma_m_inf**3 * soma_h * (V_soma - V_Na)
    soma_tau_h = 3 * np.exp(-(V_soma + 40)/33)
    soma_dh_dt = (soma_h_inf - soma_h) / soma_tau_h
    # CHANNEL: Soma potassium, slow component (Kdr)
    soma_Ikdr = g_Kdr_s * soma_n**4 * (V_soma - V_K)
    soma_n_inf = 1 / ( 1 + np.exp(-(V_soma + 3)/10))
    soma_tau_n = 5 + (47 * np.exp( (V_soma + 50)/900))
    soma_dn_dt = (soma_n_inf - soma_n) / soma_tau_n
    # CHANNEL: Soma potassium, fast component (K_s)
    soma_Ik = g_K_s * soma_x**4 * (V_soma - V_K)
    # alpha_x = 0.13*u / (1 - exp(-u/10)) with u = V + 25 is 0/0 at u == 0.
    # Near that removable singularity use the Taylor limit
    # 0.13*(10 + u/2) = 1.3 + 0.065*u. Otherwise a NaN appears, which
    # fastmath treats as undefined and gap junctions would spread it.
    soma_u_x = V_soma + 25
    soma_alpha_x = np.where(
        np.abs(soma_u_x) < 1e-3,
        1.3 + 0.065 * soma_u_x,
        0.13 * soma_u_x / (1 - np.exp(-soma_u_x/10)))
    soma_beta_x = 1.69 * np.exp(-(V_soma + 35)/80)
    soma_tau_x_inv = soma_alpha_x + soma_beta_x
    soma_x_inf = soma_alpha_x / soma_tau_x_inv
    soma_dx_dt = (soma_x_inf - soma_x) * soma_tau_x_inv
    # UPDATE: Soma compartment update (V_soma)
    soma_I_Channels = soma_Ik + soma_Ikdr + soma_Ina + soma_Ical
    soma_dv_dt = S * (-(soma_I_leak + soma_I_interact + soma_I_Channels))

    ########## AXON UPDATE ##########
    # CURRENT: Axon leak current (la)
    axon_I_leak = g_la * (V_axon - V_l)
    # CURRENT: Axon interaction current (sa)
    I_sa = (g_int / p2) * (V_axon - V_soma)
    axon_I_interact = I_sa
    # CHANNEL: Axon sodium (Na_a)
    # watch out direct gate: m = m_inf
    axon_m_inf = 1 / (1 + np.exp(-(V_axon+30)/5.5))
    axon_h_inf = 1 / (1 + np.exp( (V_axon+60)/5.8))
    axon_Ina = g_Na_a * axon_m_inf**3 * axon_Sodium_h * (V_axon - V_Na)
    axon_tau_h = 1.5 * np.exp(-(V_axon+40)/33)
    axon_dh_dt = (axon_h_inf - axon_Sodium_h) / axon_tau_h
    # CHANNEL: Axon potassium (K_a)
    axon_Ik = g_K_a * axon_Potassium_x**4 * (V_axon - V_K)
    # Same 0/0 guard as the somatic alpha_x (pole at V_axon == -25).
    axon_u_x = V_axon + 25
    axon_alpha_x = np.where(
        np.abs(axon_u_x) < 1e-3,
        1.3 + 0.065 * axon_u_x,
        0.13 * axon_u_x / (1 - np.exp(-axon_u_x/10)))
    axon_beta_x = 1.69 * np.exp(-(V_axon + 35)/80)
    axon_tau_x_inv = axon_alpha_x + axon_beta_x
    axon_x_inf = axon_alpha_x / axon_tau_x_inv
    axon_dx_dt = (axon_x_inf - axon_Potassium_x) * axon_tau_x_inv
    # UPDATE: Axon hillock compartment update (V_axon)
    axon_I_Channels = axon_Ina + axon_Ik
    axon_dv_dt = S * (-(axon_I_leak + axon_I_interact + axon_I_Channels))

    ########## DEND UPDATE ##########
    # CURRENT: Dend application current (I_app) plus gap-junction current.
    # Per-junction conductance falls from 1.0*g_gj (vdiff = 0) towards
    # 0.2*g_gj as |vdiff| grows.
    vdiff = V_dend[gj_src] - V_dend[gj_tgt]
    cx36_current_per_gj = (0.2 + 0.8 * np.exp(-vdiff*vdiff / 100)) * vdiff * g_gj
    I_gapp = np.zeros_like(V_dend)
    # Explicit loop so that several junctions onto the same target add up.
    for i in range(len(gj_tgt)):
        I_gapp[gj_tgt[i]] += cx36_current_per_gj[i]
    dend_I_application = -I_app - I_gapp
    # CURRENT: Dend leak current (ld)
    dend_I_leak = g_ld * (V_dend - V_l)
    # CURRENT: Dend interaction Current (sd)
    dend_I_interact = (g_int / (1 - p1)) * (V_dend - V_soma)
    # CHANNEL: Dend high-threshold calcium (CaH)
    dend_Icah = g_CaH * dend_Calcium_r * dend_Calcium_r * (V_dend - V_Ca)
    dend_alpha_r = 1.7 / (1 + np.exp(-(V_dend - 5)/13.9))
    # beta_r = 0.02*w / (exp(w/5) - 1) with w = V + 8.5 is 0/0 at w == 0.
    # Taylor limit there: 0.02*(5 - w/2) = 0.1 - 0.01*w.
    dend_u_r = V_dend + 8.5
    dend_beta_r = np.where(
        np.abs(dend_u_r) < 1e-3,
        0.1 - 0.01 * dend_u_r,
        0.02 * dend_u_r / (np.exp(dend_u_r/5) - 1.0))
    dend_tau_r_inv5 = (dend_alpha_r + dend_beta_r)  # tau = 5 / (alpha + beta)
    dend_r_inf = dend_alpha_r / dend_tau_r_inv5
    dend_dr_dt = (dend_r_inf - dend_Calcium_r) * dend_tau_r_inv5 * 0.2
    # CHANNEL: Dend calcium dependent potassium (KCa)
    dend_Ikca = g_K_Ca * dend_Potassium_s * (V_dend - V_K)
    # alpha_s = min(2e-5 * [Ca2+], 0.01)
    dend_alpha_s = np.where(
        0.00002 * dend_Ca2Plus < 0.01,
        0.00002 * dend_Ca2Plus,
        0.01)
    dend_tau_s_inv = dend_alpha_s + 0.015
    dend_s_inf = dend_alpha_s / dend_tau_s_inv
    dend_ds_dt = (dend_s_inf - dend_Potassium_s) * dend_tau_s_inv
    # CHANNEL: Dend proton (h)
    dend_Ih = g_h * dend_Hcurrent_q * (V_dend - V_h)
    q_inf = 1 / (1 + np.exp((V_dend + 80)/4))
    tau_q_inv = np.exp(-0.086*V_dend - 14.6) + np.exp(0.070*V_dend - 1.87)
    dend_dq_dt = (q_inf - dend_Hcurrent_q) * tau_q_inv
    # CONCENTRATION: Dend calcium concentration (CaPlus)
    # Influx proportional to the (inward, negative) CaH current, plus
    # first-order decay with rate 0.075.
    dend_dCa_dt = -3 * dend_Icah - 0.075 * dend_Ca2Plus
    # UPDATE: Dend compartment update (V_dend)
    dend_I_Channels = dend_Icah + dend_Ikca + dend_Ih
    dend_dv_dt = S * (-(dend_I_leak + dend_I_interact + dend_I_application + dend_I_Channels))

    ########## UPDATE ##########
    # Forward Euler; row order must match the unpacking at the top.
    return np.stack((
        # Soma state
        V_soma + soma_dv_dt * delta,
        soma_k + soma_dk_dt * delta,
        soma_l + soma_dl_dt * delta,
        soma_h + soma_dh_dt * delta,
        soma_n + soma_dn_dt * delta,
        soma_x + soma_dx_dt * delta,
        # Axon state
        V_axon + axon_dv_dt * delta,
        axon_Sodium_h + axon_dh_dt * delta,
        axon_Potassium_x + axon_dx_dt * delta,
        # Dend state
        V_dend + dend_dv_dt * delta,
        dend_Ca2Plus + dend_dCa_dt * delta,
        dend_Calcium_r + dend_dr_dt * delta,
        dend_Potassium_s + dend_ds_dt * delta,
        dend_Hcurrent_q + dend_dq_dt * delta,
        ), axis=0).astype(np.float32)
def sample_connections_3d(
        nneurons,
        nconnections=10,
        rmax=2,
        connection_probability=lambda r: np.exp(-(r/4)**2),
        normalize_by_dr=True
        ):
    """Sample symmetric connections between neurons on a periodic cubic grid.

    Neurons sit on an ``nside x nside x nside`` lattice (``nside`` is the
    cube root of ``nneurons``) with wrap-around in every axis. Candidate
    partners are lattice offsets with ``0 < r < rmax``, where ``rmax`` is
    clamped to ``nside // 2``. Each offset is weighted by
    ``connection_probability(r)``. The weight is then divided by the number
    of offsets sharing that ``r`` and, if ``normalize_by_dr`` is set,
    multiplied by the spacing between neighbouring distinct ``r`` values.
    Roughly ``nneurons * nconnections // 2`` connections are drawn
    (per-offset counts are rounded).

    Returns:
        ``(src_idx, tgt_idx)`` integer arrays of equal length. Every sampled
        pair is listed in both directions. Both arrays are empty when the
        lattice has no admissible offset or when no connections are
        requested.
    """
    assert int(round(nneurons**(1/3)))**3 == nneurons
    # we sample half the connections for each neuron
    assert nconnections % 2 == 0
    # we assume a cubic (3d toroid) brain; use round() like the check
    # above so float error in the cube root cannot change the side length
    nside = int(round(nneurons**(1/3)))
    if rmax > nside / 2:
        rmax = nside // 2
    empty = np.zeros(0, dtype=int)
    # we set up a connection probability kernel around each neuron
    dx, dy, dz = np.mgrid[-rmax:rmax+1, -rmax:rmax+1, -rmax:rmax+1]
    dx, dy, dz = dx.flatten(), dy.flatten(), dz.flatten()
    r = np.sqrt(dx*dx + dy*dy + dz*dz)
    # we only sample backwards, as the forward connections
    # are part of the kernel of other neurons
    sample_backwards = \
        (dz < 0) | \
        ((dz == 0) & (dy < 0)) | \
        ((dz == 0) & (dy == 0) & (dx < 0))
    m = (r != 0) & sample_backwards & (r < rmax)
    dx, dy, dz, r = dx[m], dy[m], dz[m], r[m]
    if r.size == 0:
        # lattice too small for any offset with 0 < r < rmax
        return empty, empty.copy()
    P = connection_probability(r)
    # next, there is a ~r^2 increase in point density per r,
    # and very non uniform distribution of those due to
    # the integer grid. let's remove that bias
    ro, r_uniq_idx = np.unique(r, return_inverse=True)
    r_idx_freq = np.bincount(r_uniq_idx)
    r_freq = r_idx_freq[r_uniq_idx]
    P = P / r_freq
    if normalize_by_dr:
        dr = 0.5*np.diff(ro, append=rmax)[r_uniq_idx] + 0.5*np.diff(ro, prepend=0)[r_uniq_idx]
        P = P * dr
    # P must sum up to 1
    P = P / P.sum()
    # a connection connects two neurons
    final_connection_count = nneurons * nconnections // 2
    # instead of sampling using the P array,
    # we sample for each value of the P array,
    # which is much more memory efficient
    counts = (P * final_connection_count + .5).astype(int)
    counts[-1] = max(0, final_connection_count - counts[:-1].sum())
    # sampling without replacement needs at most nneurons per offset
    assert (counts < nneurons).all()
    conn_idx = []
    for draw in range(len(P)):
        if counts[draw] == 0:
            continue
        if counts[draw] == 1:
            draw_idx = np.array([np.random.randint(nneurons)])
        else:
            draw_idx = np.random.choice(nneurons, counts[draw], replace=False)
        # encode (neuron, offset) as a single index
        conn_idx.append(draw + len(P) * draw_idx)
    if not conn_idx:
        # e.g. nconnections == 0: nothing was drawn
        return empty, empty.copy()
    conn_idx = np.concatenate(conn_idx)
    # now we calculate the neuron indices back from the P kernel
    neuron_id1 = conn_idx // len(P)
    x = ( neuron_id1 % nside).astype('int32')
    y = ((neuron_id1 // nside) % nside).astype('int32')
    z = ((neuron_id1 // (nside*nside)) % nside).astype('int32')
    di = conn_idx % len(P)
    # apply the offset with periodic wrap-around
    neuron_id2 = (
        (x + dx[di]) % nside +
        (y + dy[di]) % nside * nside +
        (z + dz[di]) % nside * nside * nside
        ).astype(int)
    # and generate the final index arrays
    # needed for gj calculation (both directions of every pair)
    tgt_idx = np.concatenate([neuron_id1, neuron_id2])
    src_idx = np.concatenate([neuron_id2, neuron_id1])
    return src_idx, tgt_idx
# Run the demo simulation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment