Skip to content

Instantly share code, notes, and snippets.

@mstimberg
Last active March 12, 2020 17:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mstimberg/838fef5fc8807be8ffb82632b7439729 to your computer and use it in GitHub Desktop.
Save mstimberg/838fef5fc8807be8ffb82632b7439729 to your computer and use it in GitHub Desktop.
import sympy
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2 import *
from brian2.equations.equations import (DIFFERENTIAL_EQUATION, SingleEquation,
PARAMETER, SUBEXPRESSION)
def evaluate_rhs(eqs, values, namespace=None, level=0):
    """
    Evaluates the RHS of a system of differential equations for given state
    variable values. External constants can be provided via the namespace or
    will be taken from the local namespace.

    This function could be used for example to find a resting state of the
    system, i.e. a fixed point where the RHS of all equations are approximately
    0.

    Parameters
    ----------
    eqs : `Equations`
        The equations.
    values : dict-like
        Values for each of the state variables (differential equations and
        parameters).
    namespace : dict, optional
        Namespace used to resolve external constants. If not given, the
        namespace of the calling frame is used (see ``level``).
    level : int, optional
        Number of additional stack frames to go up when collecting the calling
        namespace. Only used when ``namespace`` is ``None``.

    Returns
    -------
    rhs : dict
        One entry per variable defined by a differential equation. The key is
        ``'RHS_'`` followed by the variable name (e.g. ``'RHS_v'``). The value
        is the RHS evaluated at ``values``, as returned by
        ``NeuronGroup.get_states``: an array of length 1 (the temporary group
        has a single neuron), with units.
    """
    # Make a new set of equations, where differential equations are replaced
    # by parameters, and a new subexpression defines their RHS.
    # E.g. for 'dv/dt = -v / tau : volt' use:
    # '''v : volt
    #    RHS_v = -v / tau : volt/second'''
    new_equations = []
    for eq in eqs.values():
        if eq.type != DIFFERENTIAL_EQUATION:
            # Parameters and subexpressions are taken over unchanged
            new_equations.append(eq)
            continue
        new_equations.extend([
            SingleEquation(PARAMETER, eq.varname,
                           dimensions=eq.dim,
                           var_type=eq.var_type),
            # The RHS of a differential equation has the variable's unit
            # divided by time
            SingleEquation(SUBEXPRESSION, 'RHS_' + eq.varname,
                           dimensions=eq.dim/second.dim,
                           var_type=eq.var_type,
                           expr=eq.expr),
        ])
    if namespace is None:
        # +1 skips this function's own frame, so that names are looked up
        # in the caller's namespace
        namespace = get_local_namespace(level+1)
    # TODO: Hide this from standalone mode
    # (the group is forced to use the runtime NumpyCodeObject so that the
    # subexpressions can be evaluated immediately via get_states)
    group = NeuronGroup(1, model=Equations(new_equations),
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    # Set the values of the state variables/parameters
    group.set_states(values)
    # Get the values of all RHS_... subexpressions
    return group.get_states(['RHS_' + name for name in eqs.diff_eq_names])
def evaluate_jacobian(eqs, values, namespace=None, level=0):
    """
    Evaluates the Jacobian matrix of a system of differential equations for
    given state variable values. External constants can be provided via the
    namespace or will be taken from the local namespace.

    Evaluated at a fixed point (e.g. one found with `evaluate_rhs`), the
    eigenvalues of this matrix can be used to check the point's stability.

    Parameters
    ----------
    eqs : `Equations`
        The equations.
    values : dict-like
        Values for each of the state variables (differential equations and
        parameters).
    namespace : dict, optional
        Namespace used to resolve external constants. If not given, the
        namespace of the calling frame is used (see ``level``).
    level : int, optional
        Number of additional stack frames to go up when collecting the calling
        namespace. Only used when ``namespace`` is ``None``.

    Returns
    -------
    jac_matrix : `numpy.ndarray`
        Square matrix of unitless floats (values requested with
        ``units=False``). Entry ``[i, j]`` is the partial derivative of the RHS
        of the i-th differential equation with respect to the j-th
        differentially-defined variable. Variables are ordered as returned by
        ``eqs.get_substituted_expressions``.
    """
    if namespace is None:
        # +1 skips this function's own frame, so that names are looked up
        # in the caller's namespace
        namespace = get_local_namespace(level+1)
    # This first group is only used to provide the variables needed to
    # substitute subexpressions into the differential equations
    group = NeuronGroup(1, model=eqs,
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    diff_eqs = eqs.get_substituted_expressions(group.variables)
    diff_eq_names = [name for name, _ in diff_eqs]
    # Symbolic Jacobian: J[i, j] = d(RHS_i) / d(var_j)
    system = sympy.Matrix([str_to_sympy(expr.code) for _, expr in diff_eqs])
    J = system.jacobian([str_to_sympy(name) for name in diff_eq_names])

    # Differential equations become parameters so that their values can be
    # set directly. Parameters of the original model are kept (as in
    # evaluate_rhs) so that Jacobian entries and `values` can refer to them.
    new_eqs = [SingleEquation(PARAMETER, name,
                              dimensions=eqs[name].dim,
                              var_type=eqs[name].var_type)
               for name in diff_eq_names]
    new_eqs.extend(eq for eq in eqs.values() if eq.type == PARAMETER)

    # Each non-zero entry becomes a subexpression J_<var>_<diff_var>.
    # Entries that are symbolically zero are not evaluated at all (this also
    # avoids having to write a unit-carrying zero as an expression).
    jac_names = {}
    for var_idx, var in enumerate(diff_eq_names):
        for diff_idx, diff_var in enumerate(diff_eq_names):
            entry = J[var_idx, diff_idx]
            if entry == 0:
                continue
            name = f'J_{var}_{diff_var}'
            jac_names[var_idx, diff_idx] = name
            dimensions = eqs[var].dim/second.dim/eqs[diff_var].dim
            new_eqs.append(SingleEquation(SUBEXPRESSION, name,
                                          dimensions=dimensions,
                                          expr=Expression(sympy_to_str(entry))))

    n_vars = len(diff_eq_names)
    jac_matrix = np.zeros((n_vars, n_vars))
    if not jac_names:
        # Jacobian is identically zero, nothing to evaluate
        return jac_matrix
    group = NeuronGroup(1, model=Equations(new_eqs),
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    group.set_states(values)
    states = group.get_states(list(jac_names.values()), units=False)
    for (var_idx, diff_idx), name in jac_names.items():
        # Each state is a length-1 array (single neuron): index it explicitly
        # instead of calling float() on an ndim > 0 array (deprecated in NumPy)
        jac_matrix[var_idx, diff_idx] = states[name][0]
    return jac_matrix
if __name__ == '__main__':
    from scipy.optimize import root

    # Parameters
    area = 20000 * umetre ** 2
    Cm = 1 * ufarad * cm ** -2 * area
    gl = 5e-5 * siemens * cm ** -2 * area
    El = -65 * mV
    EK = -90 * mV
    ENa = 50 * mV
    g_na = 100 * msiemens * cm ** -2 * area
    g_kd = 30 * msiemens * cm ** -2 * area
    VT = -63 * mV
    I = 0.01*nA
    eqs = Equations('''
    dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
    dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
            (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
            (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
    dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
            (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
    dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
    ''')
    # Order of the state variables in the root finder's vector
    state_names = ['v', 'm', 'n', 'h']

    # Find the resting state of this model
    def wrapper(args):
        """Return the RHS of all differential equations as plain floats for
        the vector ``args = [v (in volt), m, n, h]``."""
        rhs = evaluate_rhs(eqs, {'v': args[0]*volt,
                                 'm': args[1],
                                 'n': args[2],
                                 'h': args[3]})
        # Each RHS is a length-1 array: index it explicitly, since float() of
        # an ndim > 0 array is deprecated in NumPy
        return [float(rhs[f'RHS_{name}'][0]) for name in state_names]

    result = root(wrapper, x0=np.array([float(-70*mV), 0, 0, 0]))
    if not result.success:
        # The "resting state" below may be meaningless in this case
        print(f'Root finding did not converge: {result.message}')

    # Evaluate the Jacobian at the (potential) resting value
    jacobian = evaluate_jacobian(eqs, {'v': result.x[0]*volt,
                                       'm': result.x[1],
                                       'n': result.x[2],
                                       'h': result.x[3]})
    jac_eig = np.linalg.eigvals(jacobian)
    if np.all(np.real(jac_eig) < 0):
        print('Found a stable equilibrium point.')
    else:
        print('Equilibrium point is not stable.')

    # Simulate neuron and compare resting state to calculated resting state
    duration = 200*ms
    group = NeuronGroup(1, eqs, method='exponential_euler')
    group.v = -70*mV
    mon = StateMonitor(group, state_names, record=0)
    run(duration)

    fig, axes = plt.subplots(2, 2, sharex='all')
    for ax, var_name, resting_value in zip(axes.flat, state_names, result.x):
        if var_name == 'v':
            # The root finder works in volt; plot the membrane potential in mV
            trace = mon[0].v/mV
            resting_value = resting_value*volt/mV
        else:
            trace = getattr(mon[0], var_name)
        ax.plot(mon.t/ms, trace, label='simulation')
        # Mark the calculated resting state at the end of the simulation
        ax.plot(duration/ms, resting_value, 'rx', label='resting state')
        ax.set(ylabel=var_name, xlabel='time (ms)')
    # All panels use the same two labels, so a single legend suffices
    axes[0, 0].legend()
    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment