Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created September 27, 2022 06:27
Show Gist options
  • Save llandsmeer/3ae8f460e3904bdb636e472d9e569744 to your computer and use it in GitHub Desktop.
Save llandsmeer/3ae8f460e3904bdb636e472d9e569744 to your computer and use it in GitHub Desktop.
import sympy as sm
import lark
# Lark grammar for the subset of NMODL understood by this module.
#
# COMMENT blocks and ':' / '?' line comments are normally stripped in NMODL
# before parsing; the two %ignore patterns below are only a fallback.
#
# Expression precedence, loosest to tightest: == !=, + -, * /, unary -, ^.
# '^' binds tighter than unary minus and is right-associative, so -x^2 parses
# as -(x^2) and a^b^c as a^(b^c).  Numbers inside expressions are unsigned, so
# a leading '-' is always the unary operator (no NUMBER(-1) vs neg(1)
# ambiguity); literal values in declarations (PARAMETER defaults, VALENCE)
# accept a signed number instead.
parser = lark.Lark(r'''
%import common.NUMBER
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
%ignore /\?.*\n/
%ignore /:.*\n/

NAME: /[A-Za-z_][A-Za-z_0-9]*/
_NL: /\n/
_UNIT: /\([^()]+\)/
TITLE: /[^\n\r]+/

program: block*

?block: breakpoint_block
      | derivative_block
      | initial_block
      | state_block
      | neuron_block
      | parameter_block
      | title_block
      | units_block
      | assigned_block
      | procedure_block

neuron_block: "NEURON" neuron_body
breakpoint_block: "BREAKPOINT" simple_body
initial_block: "INITIAL" simple_body
state_block: "STATE" decl_body
derivative_block: "DERIVATIVE" NAME simple_body
parameter_block: "PARAMETER" param_body
title_block: "TITLE" TITLE _NL
units_block: "UNITS" units_body
assigned_block: "ASSIGNED" assigned_body
procedure_block: "PROCEDURE" NAME "(" NAME ")" simple_body

simple_body: "{" (stmt _NL)* stmt? "}"
?stmt: stmt_assign
     | stmt_deriv
     | stmt_if
     | stmt_method
     | stmt_local
     | stmt_call
stmt_if: "if" "(" expr ")" simple_body
stmt_assign: NAME "=" expr
stmt_deriv: NAME "'" "=" expr
stmt_method: "SOLVE" NAME "METHOD" NAME
stmt_local: "LOCAL" (NAME ",")* NAME
stmt_call: NAME "(" NAME ")"

param_body: "{" (param_decl _NL)* param_decl? "}"
param_decl: NAME ("=" SIGNED_NUMBER)? _UNIT?
units_body: "{" (_UNIT "=" _UNIT)* "}"
assigned_body: "{" (NAME _UNIT?)* "}"
decl_body: "{" decl* "}"
?decl: NAME _UNIT?

?neuron_body: "{" (neuron_stmt _NL+)* neuron_stmt? "}"
?neuron_stmt: "SUFFIX" NAME -> nrn_suffix
            | "NONSPECIFIC_CURRENT" NAME -> nrn_nsc
            | "USEION" NAME ("READ" NAME)? ("WRITE" NAME)? ("VALENCE" SIGNED_NUMBER)? -> nrn_useion
            | "RANGE" (NAME ",")* NAME? -> nrn_range

?expr: eq
?eq: sum
   | eq "==" sum -> eq
   | eq "!=" sum -> neq
?sum: product
    | sum "+" product -> add
    | sum "-" product -> sub
?product: unary
        | product "*" unary -> mul
        | product "/" unary -> div
?unary: power
      | "-" unary -> neg
?power: atom
      | atom "^" unary -> pow
?atom: NUMBER -> number
     | NAME "(" sum ")" -> fun
     | NAME -> var
     | "(" sum ")"
''', start='program')
class NMODLTransformer(lark.Transformer):
    """Convert the lark parse tree into sympy expressions and plain data.

    Expression rules become sympy objects.  STATE and ASSIGNED blocks become
    lists of names, PARAMETER becomes a ``name -> float | None`` dict, and the
    NEURON block becomes the list of names it declares as globals.  Statement
    rules have no handler here, so they stay lark ``Tree`` nodes and are
    interpreted later by ``BodyParser``.
    """

    # NMODL built-in functions that can be translated to sympy.
    _FUNCTIONS = {
        'exp': sm.exp,
        'log': sm.log,
        'log10': lambda x: sm.log(x, 10),
        'sqrt': sm.sqrt,
        'fabs': sm.Abs,
        'sin': sm.sin,
        'cos': sm.cos,
        'tan': sm.tan,
        'tanh': sm.tanh,
    }

    # --- expressions ------------------------------------------------------

    def number(self, n):
        assert len(n) == 1
        return sm.Float(n[0])

    def var(self, e):
        assert len(e) == 1
        return sm.Symbol(e[0].value)

    def neg(self, e):
        assert len(e) == 1
        return -e[0]

    def fun(self, e):
        """Apply a supported built-in function; raise for anything else."""
        name = e[0].value
        try:
            func = self._FUNCTIONS[name]
        except KeyError:
            raise Exception(f'Unknown function {e[0]}') from None
        return func(e[1])

    def mul(self, e): return e[0] * e[1]
    def div(self, e): return e[0] / e[1]
    def add(self, e): return e[0] + e[1]
    def sub(self, e): return e[0] - e[1]

    def neq(self, e):
        # Python `!=` / `==` on sympy objects compare structurally and return
        # a plain bool; Ne / Eq build symbolic relations that can serve as
        # Piecewise conditions.
        return sm.Ne(e[0], e[1])

    def eq(self, e):
        return sm.Eq(e[0], e[1])

    def pow(self, e):
        """Fold exponents right-to-left so a ^ b ^ c means a ^ (b ^ c)."""
        result = e[-1]
        for base in reversed(e[:-1]):
            result = base ** result
        return result

    # --- blocks and declarations -----------------------------------------

    def simple_body(self, e):
        # Keep the statement list; statements are interpreted by BodyParser.
        return e

    def neuron_body(self, e):
        # Each NEURON statement yields a list of global names; flatten them.
        return [name for names in e for name in names]

    def assigned_body(self, e):
        return [tok.value for tok in e]

    def decl_body(self, e):
        return [tok.value for tok in e]

    def param_decl(self, e):
        """Return ``(name, default)``; default is None when no value is given."""
        if len(e) == 2:
            return e[0].value, float(e[1])
        return e[0].value, None

    def param_body(self, e):
        return dict(e)

    def neuron_stmt(self, e): return e
    def nrn_suffix(self, e): return []
    def nrn_nsc(self, e): return [e[0].value]

    def nrn_useion(self, e):
        # Keep only the READ / WRITE names after the ion name; filtering by
        # token type skips a VALENCE number when READ or WRITE is absent.
        return [tok.value for tok in e[1:] if tok.type == 'NAME']

    def nrn_range(self, e): return []
class BodyParser:
    """Symbolically execute the statements of one NMODL block.

    Assignments to names in ``global_names`` are collected in ``global_env``
    (sympy Symbol -> expression).  Assignments to LOCAL variables go to
    ``local_env`` and are substituted into every later expression and
    condition, so the results never mention locals.  Assignments inside an
    ``if`` become ``sympy.Piecewise`` expressions.

    Args:
        body: iterable of statement nodes (objects with ``.data`` and
            ``.children``, as produced by the lark parser).
        global_names: names that may be assigned without a LOCAL declaration.
        procs: mapping (or iterable of pairs) procedure name -> statement list.
            A call inlines the procedure body with no separate scope.

    Raises:
        Exception: on an unknown statement kind, an assignment to an
            undeclared local, or an unsupported / unknown procedure call.
    """

    def __init__(self, body, global_names=(), procs=()):
        self.procs = dict(procs)
        self.global_names = global_names
        self.local_names = []
        self.method = None   # (target, method) pair from a SOLVE statement
        self.local_env = {}
        self.global_env = {}
        self.cond = None     # conjunction of the enclosing if-conditions
        for stmt in body:
            self.parse_stmt(stmt)

    def parse_stmt(self, stmt):
        """Dispatch one statement node to its handler."""
        kind = stmt.data
        if kind == 'stmt_local':
            self.local_names.extend(tok.value for tok in stmt.children)
        elif kind == 'stmt_method':
            self.method = stmt.children[0], stmt.children[1]
        elif kind in ('stmt_assign', 'stmt_deriv'):
            self._parse_assign(stmt)
        elif kind == 'stmt_if':
            self._parse_if(stmt)
        elif kind == 'stmt_call':
            self._parse_call(stmt)
        else:
            raise Exception(stmt.data)

    def _parse_assign(self, stmt):
        """Record ``name = expr`` (or ``name' = expr``) in the right environment."""
        name, expr = stmt.children
        if stmt.data == 'stmt_deriv':
            name = name + "'"
        is_global_assign = name in self.global_names
        if not is_global_assign and name not in self.local_names:
            raise Exception(f'Reference to undefined local {name}')
        env = self.global_env if is_global_assign else self.local_env
        symbol = sm.Symbol(name)
        expr = expr.subs(self.local_env)
        if self.cond is not None:
            # When the condition is false the variable keeps the value it had
            # so far in this block, or its own symbol if not yet assigned.
            previous = env.get(symbol, symbol)
            expr = sm.Piecewise((expr, self.cond), (previous, True))
        env[symbol] = expr

    def _parse_if(self, stmt):
        """Run the body of an ``if`` with its condition ANDed onto ``self.cond``."""
        # sympify also accepts a plain bool condition; locals are substituted
        # so the condition, like every expression, never mentions them.
        cond = sm.sympify(stmt.children[0]).subs(self.local_env)
        outer = self.cond
        self.cond = cond if outer is None else sm.And(outer, cond)
        try:
            for inner in stmt.children[1]:
                self.parse_stmt(inner)
        finally:
            self.cond = outer

    def _parse_call(self, stmt):
        """Inline a procedure call; only the single argument ``v`` is supported."""
        if stmt.children[1] != 'v':
            raise Exception('Sorry proc support is very dumb for the moment')
        name = stmt.children[0].value
        # No scoping: the procedure's statements run directly in this parser.
        if name not in self.procs:
            raise Exception(f'Unknown procedure {name}')
        for inner in self.procs[name]:
            self.parse_stmt(inner)
class NMODL:
    """Parse NMODL source text into sympy initial-value and update equations.

    Attributes set by the constructor:
        params: PARAMETER name -> default value (float, or None if absent).
        init: Symbol -> simplified expression for each global assigned in
            INITIAL.
        update: Symbol -> simplified expression for each global assigned in
            BREAKPOINT, plus each derivative (keys like ``m'``) assigned in the
            DERIVATIVE block that BREAKPOINT SOLVEs, if any.
        neuron_names, state_names, assigned_names, global_names: name lists.
    """

    def __init__(self, text):
        self.init = {}
        self.update = {}
        self.params = {}

        ast = parser.parse(self._strip_comments(text))
        tree = NMODLTransformer().transform(ast)
        blocks, blocks_deriv, blocks_proc = self._group_blocks(tree)

        self.neuron_names = blocks.get('neuron_block', ())
        self.state_names = blocks.get('state_block', ())
        self.assigned_names = blocks.get('assigned_block', ())
        self.params.update(blocks.get('parameter_block', {}))

        # Names a block may assign without a LOCAL declaration; for each
        # STATE variable its derivative name (e.g. m') is global as well.
        self.global_names = global_names = []
        global_names.extend(self.neuron_names)
        global_names.extend(self.params)
        global_names.extend(self.state_names)
        global_names.extend(x + "'" for x in self.state_names)
        global_names.extend(self.assigned_names)

        if 'initial_block' in blocks:
            init = BodyParser(blocks['initial_block'], global_names, blocks_proc)
            for k, v in init.global_env.items():
                assert str(k)[-1] != "'"  # INITIAL must not assign derivatives
                self.init[k] = v.simplify()

        solve_target = None
        if 'breakpoint_block' in blocks:
            bp = BodyParser(blocks['breakpoint_block'], global_names, blocks_proc)
            if bp.method is not None:
                solve_target = bp.method[0]
            for k, v in bp.global_env.items():
                self.update[k] = v.simplify()

        if solve_target is not None:
            if solve_target not in blocks_deriv:
                raise Exception(
                    f'SOLVE target {solve_target} is not a DERIVATIVE block')
            deriv = BodyParser(blocks_deriv[solve_target], global_names, blocks_proc)
            for k, v in deriv.global_env.items():
                self.update[k] = v.simplify()

    @staticmethod
    def _strip_comments(text):
        """Drop COMMENT ... ENDCOMMENT blocks and ':' / '?' line comments.

        Raises:
            ValueError: if a COMMENT block is never closed.
        """
        out = []
        in_comment = False
        for line in text.splitlines():
            line = line.rstrip()
            keyword = line.strip()  # tolerate indented COMMENT / ENDCOMMENT
            if keyword == 'COMMENT':
                in_comment = True
            elif keyword == 'ENDCOMMENT':
                in_comment = False
            elif not in_comment:
                line = line.partition(':')[0].rstrip()
                line = line.partition('?')[0].rstrip()
                out.append(line)
        if in_comment:
            raise ValueError('Unterminated COMMENT block')
        return '\n'.join(out)

    @staticmethod
    def _group_blocks(tree):
        """Split transformed top-level blocks into (single, DERIVATIVE, PROCEDURE) dicts."""
        blocks, blocks_deriv, blocks_proc = {}, {}, {}
        for block in tree.children:
            c = block.children
            if block.data == 'derivative_block':
                blocks_deriv[c[0]] = c[1]
            elif block.data == 'procedure_block':
                assert c[1] == 'v'  # only PROCEDURE name(v) is supported
                blocks_proc[c[0]] = c[2]
            else:
                assert len(c) == 1
                blocks[block.data] = c[0]
        return blocks, blocks_deriv, blocks_proc
def validate_hh(path='hh.mod'):
    """Parse an NMODL file (``hh.mod`` by default) and print its equations.

    Prints one ``name = expression`` line per INITIAL assignment, then one
    per update equation.
    """
    with open(path) as f:
        text = f.read()
    mod = NMODL(text)
    for k, v in mod.init.items():
        print(k, '=', v)
    for k, v in mod.update.items():
        print(k, '=', v)
import glob
import os

# Batch check: parse every .mod file below `root` and print the derived
# equations.  There is no per-file error handling, so the first file that
# fails to parse aborts the run.
root = '/home/llandsmeer/repos/thorstenhater/nmlcc/experiments'
pattern = os.path.join(root, '**/*.mod')
# glob's result order depends on the filesystem; sort for reproducible output.
mods = sorted(glob.glob(pattern, recursive=True))
for fn in mods:
    print(fn)
    with open(fn) as f:
        mod = NMODL(f.read())
    print('---')
    for k, v in mod.init.items():
        print(k, '=', v)
    for k, v in mod.update.items():
        print(k, '=', v)
@llandsmeer
Copy link
Author

Known bugs:

  • If statements are always true
  • If statements involving an undefined variable
  • `else if` and `else` are not supported
  • The precedence of `^` relative to unary `-` might be wrong
  • Procedures are only supported with the single argument `v`

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment