Last active
March 12, 2020 17:45
-
-
Save mstimberg/838fef5fc8807be8ffb82632b7439729 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sympy | |
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str | |
from brian2 import * | |
from brian2.equations.equations import (DIFFERENTIAL_EQUATION, SingleEquation, | |
PARAMETER, SUBEXPRESSION) | |
def evaluate_rhs(eqs, values, namespace=None, level=0):
    """
    Compute the right-hand sides of all differential equations in ``eqs``
    for one given set of state variable values.

    External constants can be provided via ``namespace``; otherwise they are
    taken from the local namespace of the caller.

    A typical use is the search for a resting state of the system, i.e. a
    fixed point where the RHS of every equation is approximately 0.

    Parameters
    ----------
    eqs : `Equations`
        The equations.
    values : dict-like
        Values for each of the state variables (differential equations and
        parameters).
    namespace : dict-like, optional
        Namespace used to resolve external constants. If ``None`` (the
        default), it is obtained with ``get_local_namespace(level+1)``.
    level : int, optional
        Only used when ``namespace`` is ``None``; forwarded (plus one) to
        ``get_local_namespace``. Defaults to 0.

    Returns
    -------
    rhs : dict
        A dictionary mapping ``'RHS_<name>'``, for every variable ``<name>``
        defined by a differential equation, to the value of the RHS of that
        equation.
    """
    # Rewrite the model: every differential equation becomes a plain
    # parameter plus a subexpression holding its RHS, everything else is
    # kept unchanged. E.g. 'dv/dt = -v / tau : volt' turns into:
    #     '''v : volt
    #        RHS_v = -v / tau : volt'''
    # (the RHS has the dimensions of the variable divided by time).
    rewritten = []
    for single_eq in eqs.values():
        if single_eq.type != DIFFERENTIAL_EQUATION:
            rewritten.append(single_eq)
            continue
        state_param = SingleEquation(PARAMETER, single_eq.varname,
                                     dimensions=single_eq.dim,
                                     var_type=single_eq.var_type)
        rhs_subexpr = SingleEquation(SUBEXPRESSION,
                                     f'RHS_{single_eq.varname}',
                                     dimensions=single_eq.dim/second.dim,
                                     var_type=single_eq.var_type,
                                     expr=single_eq.expr)
        rewritten += [state_param, rhs_subexpr]

    if namespace is None:
        namespace = get_local_namespace(level+1)

    # TODO: Hide this from standalone mode
    helper_group = NeuronGroup(1, model=Equations(rewritten),
                               codeobj_class=NumpyCodeObject,
                               namespace=namespace)

    # Plug in the requested state variable/parameter values ...
    helper_group.set_states(values)

    # ... and read back all the RHS_... subexpressions
    rhs_names = [f'RHS_{name}' for name in eqs.diff_eq_names]
    return helper_group.get_states(rhs_names)
def evaluate_jacobian(eqs, values, namespace=None, level=0):
    """
    Evaluates the Jacobian of a system of differential equations for given
    state variable values.

    The Jacobian is derived symbolically (with sympy) from the RHS of the
    differential equations, after subexpressions have been substituted, and
    is then evaluated numerically for ``values``.

    Parameters
    ----------
    eqs : `Equations`
        The equations.
    values : dict-like
        Values for each of the state variables (differential equations and
        parameters).
    namespace : dict-like, optional
        Namespace used to resolve external constants. If ``None`` (the
        default), it is obtained with ``get_local_namespace(level+1)``.
    level : int, optional
        Only used when ``namespace`` is ``None``; forwarded (plus one) to
        ``get_local_namespace``. Defaults to 0.

    Returns
    -------
    jac_matrix : `numpy.ndarray`
        Square matrix of plain floats (the values are requested with
        ``units=False``). Entry ``[i, j]`` is the derivative of the RHS of
        the i-th differential equation with respect to the j-th variable,
        with variables ordered as returned by
        ``eqs.get_substituted_expressions``.
    """
    if namespace is None:
        namespace = get_local_namespace(level+1)

    # Temporary group: only needed for its ``variables``, which are used to
    # substitute the subexpressions into the differential equations.
    group = NeuronGroup(1, model=eqs,
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    diff_eqs = eqs.get_substituted_expressions(group.variables)
    diff_eq_names = [name for name, _ in diff_eqs]

    # Symbolic Jacobian of the RHS vector with respect to the state variables
    system = sympy.Matrix([str_to_sympy(expr.code) for _, expr in diff_eqs])
    J = system.jacobian([str_to_sympy(name) for name in diff_eq_names])

    # Build a model without dynamics: parameters stay parameters, state
    # variables become parameters, and each Jacobian entry becomes a
    # subexpression named J_<variable>_<differentiation variable>.
    new_eqs = []
    for eq in eqs.values():
        # Keep the model's own parameters: Jacobian entries may refer to
        # them, and ``values`` may provide their values.
        if eq.type == PARAMETER:
            new_eqs.append(eq)
    for diff_eq_name in diff_eq_names:
        new_eqs.append(SingleEquation(PARAMETER, diff_eq_name,
                                      dimensions=eqs[diff_eq_name].dim,
                                      var_type=eqs[diff_eq_name].var_type))
    for var_idx, diff_eq_var in enumerate(diff_eq_names):
        for diff_idx, diff_eq_diff in enumerate(diff_eq_names):
            dimensions = (eqs[diff_eq_var].dim / second.dim
                          / eqs[diff_eq_diff].dim)
            expr = sympy_to_str(J[var_idx, diff_idx])
            if expr == '0':
                # A bare 0 has no units; multiply it with the expected
                # dimensions so that the subexpression is consistent.
                expr = f'0*{dimensions!r}'
            new_eqs.append(SingleEquation(SUBEXPRESSION,
                                          f'J_{diff_eq_var}_{diff_eq_diff}',
                                          dimensions=dimensions,
                                          expr=Expression(expr)))

    group = NeuronGroup(1, model=Equations(new_eqs),
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    group.set_states(values)

    # Create a matrix from the individual J_..._... values
    states = group.get_states(
        [f'J_{var}_{diff_var}'
         for var in diff_eq_names
         for diff_var in diff_eq_names],
        units=False)
    # ``.item()`` extracts the single value of each size-1 array; calling
    # ``float`` on arrays with ndim > 0 is deprecated in recent NumPy.
    jac_matrix = np.array(
        [[np.asarray(states[f'J_{var}_{diff_var}']).item()
          for diff_var in diff_eq_names]
         for var in diff_eq_names])
    return jac_matrix
if __name__ == '__main__':
    # Demo: find the resting state of a Hodgkin-Huxley type neuron, check its
    # stability via the eigenvalues of the Jacobian, and compare it to a
    # simulation.
    #
    # NOTE: this code deliberately stays at module level (not in a ``main()``
    # function): no explicit namespace is passed to evaluate_rhs /
    # evaluate_jacobian / NeuronGroup below, so the constants defined here
    # are looked up from the calling frames.
    from scipy.optimize import root

    # Parameters
    area = 20000 * umetre ** 2
    Cm = 1 * ufarad * cm ** -2 * area
    gl = 5e-5 * siemens * cm ** -2 * area
    El = -65 * mV
    EK = -90 * mV
    ENa = 50 * mV
    g_na = 100 * msiemens * cm ** -2 * area
    g_kd = 30 * msiemens * cm ** -2 * area
    VT = -63 * mV
    I = 0.01*nA
    eqs = Equations('''
    dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
    dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
        (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
        (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
    dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
        (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
    dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
    ''')

    # Find the resting state of this model
    def wrapper(args):
        # Adapter for scipy: plain float vector (v in volt, m, n, h) in,
        # plain float vector of the corresponding RHS values out.
        rhs = evaluate_rhs(eqs, {'v': args[0]*volt,
                                 'm': args[1],
                                 'n': args[2],
                                 'h': args[3]})
        # ``.item()`` extracts the single value of each size-1 array;
        # calling ``float`` on such arrays is deprecated in recent NumPy.
        return [np.asarray(rhs['RHS_v']).item(),
                np.asarray(rhs['RHS_m']).item(),
                np.asarray(rhs['RHS_n']).item(),
                np.asarray(rhs['RHS_h']).item()]

    result = root(wrapper, x0=np.array([float(-70*mV), 0, 0, 0]))
    if not result.success:
        # Do not silently present a non-converged solution as resting state
        print(f'Warning: root finding did not converge ({result.message}); '
              'the calculated resting state may be inaccurate.')

    # Evaluate the Jacobian at the (potential) resting value
    jacobian = evaluate_jacobian(eqs, {'v': result.x[0]*volt,
                                       'm': result.x[1],
                                       'n': result.x[2],
                                       'h': result.x[3]})
    jac_eig = np.linalg.eigvals(jacobian)
    # Stable if all eigenvalues have a negative real part
    if np.all(np.real(jac_eig) < 0):
        print('Found a stable equilibrium point.')
    else:
        print('Equilibrium point is not stable.')

    # Simulate neuron and compare resting state to calculated resting state
    duration = 200*ms
    group = NeuronGroup(1, eqs, method='exponential_euler')
    group.v = -70*mV
    mon = StateMonitor(group, ['v', 'm', 'n', 'h'], record=0)
    run(duration)

    fig, axes = plt.subplots(2, 2, sharex='all')
    # One panel per variable: (axis, simulated values, calculated resting
    # value, axis label). The membrane potential is shown in mV.
    panels = [
        (axes[0, 0], mon[0].v/mV, float(result.x[0]*volt/mV), 'v'),
        (axes[0, 1], mon[0].m, result.x[1], 'm'),
        (axes[1, 0], mon[0].n, result.x[2], 'n'),
        (axes[1, 1], mon[0].h, result.x[3], 'h'),
    ]
    for panel_ax, sim_values, rest_value, var_label in panels:
        panel_ax.plot(mon.t/ms, sim_values, label='simulation')
        # Mark the calculated resting state at the end of the simulation
        panel_ax.plot(float(duration/ms), rest_value, 'rx',
                      label='resting state')
        panel_ax.set(ylabel=var_label, xlabel='time (ms)')
    # The labels are identical in all panels, one legend is enough
    axes[0, 0].legend()
    plt.tight_layout()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment