Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Last active November 30, 2022 14:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save llandsmeer/64e009f94b4127e7371c4633739df12f to your computer and use it in GitHub Desktop.
Save llandsmeer/64e009f94b4127e7371c4633739df12f to your computer and use it in GitHub Desktop.
Arbor as an arbitrary ODE solver
import arbor
import re
import subprocess
import tempfile
import pathlib
import matplotlib.pyplot as plt
# NMODL source for the Lorenz system, packaged as an Arbor density mechanism
# (SUFFIX node) so Arbor's mechanism machinery integrates a generic ODE:
#   x' = sigma * (y - x)
#   y' = x * (rho - z) - y
#   z' = x * y - beta * z
# with sigma = 10, rho = 28, beta ~ 8/3 and initial state x = y = z = 1.
# The script below extracts the SUFFIX name and the STATE variable list from
# this text with regexes, so keep those two declarations in their simple
# "SUFFIX <name>" / "STATE { a b c }" form.
nmodl = '''
NEURON {
SUFFIX node
}
PARAMETER {
sigma = 10
rho = 28
beta = 2.66666666666666666666666666666666666666666666666666
x0 = 1
y0 = 1
z0 = 1
}
STATE {
x y z
}
INITIAL {
x = x0
y = y0
z = z0
}
DERIVATIVE dstate {
x' = sigma * (y - x)
y' = x*(rho - z) - y
z' = x*y - beta * z
}
BREAKPOINT {
SOLVE dstate METHOD sparse
}
'''
# Build the NMODL mechanism above into a throw-away Arbor catalogue, attach it
# to a single one-segment cable cell, simulate, and plot every STATE variable
# over time -- i.e. use Arbor purely as an ODE integrator.
with tempfile.TemporaryDirectory() as tmp:
    workdir = pathlib.Path(tmp)
    cat_root = workdir / 'cat'
    cat_root.mkdir()
    print(cat_root)

    # Derive the mechanism name and its state variables from the NMODL text so
    # the probes below always match the model. Fail with a clear message rather
    # than an AttributeError on None when either declaration is missing.
    suffix_match = re.search(r'SUFFIX\s+(\S+)', nmodl)
    state_match = re.search(r'STATE\s*\{([^}]*)\}', nmodl)
    if suffix_match is None or state_match is None:
        raise ValueError('NMODL source must contain a SUFFIX and a STATE block')
    name = suffix_match.group(1)
    states = state_match.group(1).split()

    (cat_root / (name + '.mod')).write_text(nmodl)

    # arbor-build-catalogue writes '<catalogue>-catalogue.so' into its current
    # working directory. Running it inside the temp dir keeps the caller's cwd
    # clean and guarantees we load the library just built, never a stale one
    # from an earlier run. check=True raises CalledProcessError on failure,
    # unlike an assert, which `python -O` strips.
    subprocess.run(
        ['arbor-build-catalogue', 'cat', str(cat_root)],
        cwd=workdir,
        check=True,
    )
    catalogue_path = workdir / 'cat-catalogue.so'

    # Minimal morphology: one cylindrical segment (tag 1) carrying the mechanism.
    tree = arbor.segment_tree()
    tree.append(arbor.mnpos, arbor.mpoint(0, 0, 0, 1), arbor.mpoint(1, 0, 0, 1), tag=1)
    labels = arbor.label_dict()
    decor = arbor.decor().paint('(all)', arbor.density(name))
    cell = arbor.cable_cell(tree, decor, labels)

    props = arbor.neuron_cable_properties()
    # Empty prefix: mechanisms are referenced by their bare SUFFIX name.
    props.catalogue.extend(arbor.load_catalogue(str(catalogue_path)), '')

    # One density-state probe per STATE variable, in the same order as `states`.
    density_probes = [
        arbor.cable_probe_density_state('(root)', name, state)
        for state in states
    ]

    class Recipe(arbor.recipe):
        """Single-cell recipe exposing `cell` with one probe per state variable."""

        def probes(self, _):
            return density_probes

        def num_cells(self):
            return 1

        def cell_kind(self, _):
            return arbor.cell_kind.cable

        def cell_description(self, _):
            return cell

        def global_properties(self, _):
            return props

    recipe = Recipe()
    sim = arbor.simulation(recipe)
    # Probe (gid 0, index i) corresponds to states[i]; keep the handle that
    # sim.sample returns and use it (not the index) to fetch the samples.
    handles = [
        sim.sample((0, i), arbor.regular_schedule(0.01))
        for i in range(len(density_probes))
    ]
    sim.run(tfinal=100, dt=0.001)

    traces = []
    for state, handle in zip(states, handles):
        # Each probe sits at a single location ('(root)'), so exactly one
        # (data, meta) pair comes back; data's columns are (time, value).
        (data, _meta), = sim.samples(handle)
        time, value = data.T
        traces.append((state, time, value))

# The temp dir is only needed for building/simulating; plot the collected traces.
for state, time, value in traces:
    plt.plot(time, value, label=state)
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment