Created
September 27, 2022 06:27
-
-
Save llandsmeer/3ae8f460e3904bdb636e472d9e569744 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sympy as sm | |
import lark | |
# Lark grammar for the subset of NMODL (NEURON .mod files) handled here.
# Limitations visible in the rules below: comparisons are only `==` / `!=`,
# `if` has no `else` branch, and PROCEDUREs / procedure calls take exactly
# one NAME argument.  Each block kind is listed once in `?block`; the old
# duplicate `initial_block` alternative was redundant.
parser = lark.Lark(r'''
    %import common.SIGNED_NUMBER -> NUMBER
    %import common.WS
    %ignore WS
    %ignore /\?.*\n/
    %ignore /:.*\n/

    NAME: /[A-Za-z_][A-Za-z_0-9]*/
    _NL: /\n/
    _UNIT: /\([^()]+\)/
    TITLE: /[^\n\r]+/
    ENDCOMMENT: /.*?^ENDCOMMENT$/

    program: block*

    ?block : breakpoint_block
           | derivative_block
           | initial_block
           | state_block
           | neuron_block
           | parameter_block
           | title_block
           | units_block
           | assigned_block
           | procedure_block

    neuron_block: "NEURON" neuron_body
    breakpoint_block: "BREAKPOINT" simple_body
    initial_block: "INITIAL" simple_body
    state_block: "STATE" decl_body
    derivative_block: "DERIVATIVE" NAME simple_body
    parameter_block: "PARAMETER" param_body
    title_block: "TITLE" TITLE _NL
    units_block: "UNITS" units_body
    assigned_block: "ASSIGNED" assigned_body
    procedure_block: "PROCEDURE" NAME "(" NAME ")" simple_body

    simple_body: "{" (stmt _NL)* stmt? "}"
    ?stmt: stmt_assign
         | stmt_deriv
         | stmt_if
         | stmt_method
         | stmt_local
         | stmt_call
    stmt_if: "if" "(" expr ")" simple_body
    stmt_assign: NAME "=" expr
    stmt_deriv: NAME "'" "=" expr
    stmt_method: "SOLVE" NAME "METHOD" NAME
    stmt_local: "LOCAL" (NAME ",")* NAME
    stmt_call: NAME "(" NAME ")"

    param_body: "{" (param_decl _NL)* param_decl? "}"
    param_decl: NAME ("=" NUMBER)? _UNIT?
    units_body: "{" (_UNIT "=" _UNIT)* "}"
    assigned_body: "{" (NAME _UNIT?)* "}"
    decl_body: "{" (decl )* decl? "}"
    ?decl: NAME _UNIT?

    ?neuron_body: "{" (neuron_stmt _NL+)* neuron_stmt? "}"
    ?neuron_stmt: "SUFFIX" NAME -> nrn_suffix
                | "NONSPECIFIC_CURRENT" NAME -> nrn_nsc
                | "USEION" NAME ("READ" NAME)? ("WRITE" NAME)? ("VALENCE" NUMBER)? -> nrn_useion
                | "RANGE" (NAME ",")* NAME? -> nrn_range

    ?expr: eq
    ?eq: sum
       | eq "==" sum -> eq
       | eq "!=" sum -> neq
    ?sum: product
        | sum "+" product -> add
        | sum "-" product -> sub
    ?product: atom
            | product "*" atom -> mul
            | product "/" atom -> div
    ?atom: NUMBER -> number
         | atom ("^" atom)+ -> pow
         | "-" atom -> neg
         | NAME "(" sum ")" -> fun
         | NAME -> var
         | "(" sum ")"
''', start='program')
class NMODLTransformer(lark.Transformer):
    """Convert a parse tree produced by ``parser`` into sympy objects and plain data.

    Expression rules (number, var, add, ...) become sympy expressions;
    NEURON / STATE / ASSIGNED bodies become lists of name strings; a
    PARAMETER body becomes a ``{name: float or None}`` dict; ``simple_body``
    becomes the list of its statement trees.  The ``stmt_*`` trees are left
    untransformed so BodyParser can walk them.
    """

    # Single-argument math functions mapped onto their sympy equivalents.
    # NOTE(review): these are C math library names; confirm each one is
    # accepted by the NMODL dialect being targeted.
    _FUNCTIONS = {
        'exp': sm.exp,
        'log': sm.log,
        'log10': lambda x: sm.log(x, 10),
        'sqrt': sm.sqrt,
        'fabs': sm.Abs,
        'sin': sm.sin,
        'cos': sm.cos,
        'tan': sm.tan,
        'sinh': sm.sinh,
        'cosh': sm.cosh,
        'tanh': sm.tanh,
    }

    def number(self, n):
        assert len(n) == 1
        return sm.Float(n[0])

    def var(self, e):
        assert len(e) == 1
        return sm.Symbol(e[0])

    def neg(self, e):
        assert len(e) == 1
        return -e[0]

    def fun(self, e):
        """Apply a known math function, e.g. ``exp(x)``; raise for anything else."""
        name = e[0].value
        func = self._FUNCTIONS.get(name)
        if func is None:
            raise Exception(f'Unknown function {e[0]}')
        return func(e[1])

    def mul(self, e):
        return e[0] * e[1]

    def div(self, e):
        return e[0] / e[1]

    def add(self, e):
        return e[0] + e[1]

    def sub(self, e):
        return e[0] - e[1]

    def neq(self, e):
        # Python's `!=` on sympy objects is a structural comparison that
        # returns a plain bool; build a symbolic relation instead so it can
        # serve as a Piecewise condition.
        return sm.Ne(e[0], e[1])

    def eq(self, e):
        # Same reasoning as `neq`: `==` would collapse to True/False.
        return sm.Eq(e[0], e[1])

    def pow(self, e):
        """Fold ``a ^ b ^ c`` right-associatively into ``a ** (b ** c)``."""
        *bases, result = e
        for base in reversed(bases):
            result = base ** result
        return result

    def simple_body(self, e):
        return e

    def neuron_body(self, e):
        """Flatten the per-statement name lists into one list of global names."""
        return [name for names in e for name in names]

    def assigned_body(self, e):
        return [tok.value for tok in e]

    def decl_body(self, e):
        return [tok.value for tok in e]

    def param_decl(self, e):
        """Return ``(name, value)``; value is None when no ``= NUMBER`` is given."""
        if len(e) == 2:
            return e[0].value, float(e[1])
        return e[0].value, None

    def param_body(self, e):
        return dict(e)

    def neuron_stmt(self, e):
        return e

    def nrn_suffix(self, e):
        return []

    def nrn_nsc(self, e):
        return [e[0].value]

    def nrn_useion(self, e):
        # e = [ion, (READ name)?, (WRITE name)?, (VALENCE number)?] with absent
        # optional parts simply missing, so positions shift; select the
        # READ/WRITE names by token type rather than by slicing e[1:3]
        # (which picked up the VALENCE number when READ or WRITE was absent).
        return [tok.value for tok in e[1:] if tok.type == 'NAME']

    def nrn_range(self, e):
        return []
class BodyParser:
    """Symbolically execute the statements of one NMODL block body.

    Assignments to names in ``global_names`` are recorded in ``global_env``
    as ``{sympy.Symbol: expression}``.  Assignments to LOCAL variables go to
    ``local_env``; those values are substituted into later assignment
    expressions and ``if`` conditions.  An assignment made inside an ``if``
    becomes a ``sm.Piecewise`` whose fallback is the value the name already
    had in this body, or the bare symbol ("unchanged") if it had none.

    Args:
        body: sequence of statement trees (a transformed ``simple_body``).
        global_names: names treated as globals; derivative targets carry a
            trailing "'".
        procs: mapping (or iterable of pairs) from procedure name to its
            statement list.  Calls are inlined with no separate scope.

    Attributes:
        method: ``(target, method)`` tokens from the last SOLVE statement,
            or None if there was none.
        local_names: names declared with LOCAL.
    """

    def __init__(self, body, global_names=(), procs=()):
        self.procs = dict(procs)
        self.global_names = global_names
        self.local_names = []
        self.method = None
        self.local_env = {}
        self.global_env = {}
        # Conjunction of all enclosing `if` conditions; None at top level.
        self.cond = None
        for stmt in body:
            self.parse_stmt(stmt)

    def parse_stmt(self, stmt):
        """Execute one statement tree, updating the environments in place."""
        kind = stmt.data
        if kind == 'stmt_local':
            self.local_names.extend(tok.value for tok in stmt.children)
        elif kind == 'stmt_method':
            self.method = stmt.children[0], stmt.children[1]
        elif kind in ('stmt_assign', 'stmt_deriv'):
            self._parse_assign(stmt)
        elif kind == 'stmt_if':
            self._parse_if(stmt)
        elif kind == 'stmt_call':
            self._parse_call(stmt)
        else:
            raise Exception(kind)

    def _parse_assign(self, stmt):
        """Handle ``x = expr`` and ``x' = expr``."""
        name, expr = stmt.children
        if stmt.data == 'stmt_deriv':
            name = name + "'"
        is_global_assign = name in self.global_names
        if not is_global_assign and name not in self.local_names:
            raise Exception(f'Reference to undefined local {name}')
        env = self.global_env if is_global_assign else self.local_env
        symbol = sm.Symbol(name)
        value = sm.sympify(expr).subs(self.local_env)
        if self.cond is None:
            env[symbol] = value
        else:
            # When the condition fails, keep whatever value the name had
            # before this `if` in the current body; the bare symbol stands
            # for "not assigned here".
            previous = env.get(symbol, symbol)
            env[symbol] = sm.Piecewise((value, self.cond), (previous, True))

    def _parse_if(self, stmt):
        """Run an ``if`` body under the conjunction of all enclosing conditions."""
        # Substitute locals into the condition, as is done for assignments.
        cond = sm.sympify(stmt.children[0]).subs(self.local_env)
        old_cond = self.cond
        self.cond = cond if old_cond is None else sm.And(old_cond, cond)
        try:
            for inner in stmt.children[1]:
                self.parse_stmt(inner)
        finally:
            self.cond = old_cond

    def _parse_call(self, stmt):
        """Inline a procedure call; only calls of the form ``proc(v)`` are supported."""
        if stmt.children[1] != 'v':
            raise Exception('Sorry proc support is very dumb for the moment')
        name = stmt.children[0].value
        # No scopes: the procedure's statements run in the caller's environment.
        if name not in self.procs:
            raise Exception(f'Unknown procedure {name}')
        for inner in self.procs[name]:
            self.parse_stmt(inner)
class NMODL:
    """Parse NMODL source text into symbolic initial values and update rules.

    Attributes:
        init: ``{sympy.Symbol: expression}`` for globals assigned in INITIAL.
        update: ``{sympy.Symbol: expression}`` for globals assigned in
            BREAKPOINT, plus those assigned in the DERIVATIVE block named by
            BREAKPOINT's SOLVE statement.  Derivatives appear as symbols whose
            name ends in "'".
        params: ``{name: float or None}`` from the PARAMETER block.
        neuron_names, state_names, assigned_names: names collected from the
            NEURON, STATE and ASSIGNED blocks.
        global_names: every name BodyParser treats as a global.
    """

    def __init__(self, text):
        self.init = {}
        self.update = {}
        self.params = {}

        ast = parser.parse(self._strip_comments(text))
        out = NMODLTransformer().transform(ast)
        blocks, blocks_deriv, blocks_proc = self._collect_blocks(out)

        self.neuron_names = blocks.get('neuron_block', ())
        self.state_names = blocks.get('state_block', ())
        self.assigned_names = blocks.get('assigned_block', ())
        self.params.update(dict(blocks.get('parameter_block', ())))
        self.global_names = global_names = [
            *self.neuron_names,
            *self.params,
            *self.state_names,
            *(x + "'" for x in self.state_names),
            *self.assigned_names,
        ]

        if 'initial_block' in blocks:
            init = BodyParser(blocks['initial_block'], global_names, blocks_proc)
            for k, v in init.global_env.items():
                if str(k).endswith("'"):
                    raise ValueError(f'Derivative {k} assigned in INITIAL block')
                self.init[k] = sm.simplify(v)

        solve_target = None
        if 'breakpoint_block' in blocks:
            bp = BodyParser(blocks['breakpoint_block'], global_names, blocks_proc)
            if bp.method is not None:
                solve_target = bp.method[0]
            for k, v in bp.global_env.items():
                self.update[k] = sm.simplify(v)

        if solve_target is not None:
            deriv = BodyParser(blocks_deriv[solve_target], global_names, blocks_proc)
            for k, v in deriv.global_env.items():
                self.update[k] = sm.simplify(v)

    @staticmethod
    def _strip_comments(text):
        """Drop COMMENT ... ENDCOMMENT regions and ':' / '?' line comments.

        Raises:
            ValueError: if a COMMENT region is never closed.
        """
        out = []
        in_comment = False
        for line in text.splitlines():
            line = line.rstrip()
            # Compare stripped text so indented markers are recognized too.
            marker = line.strip()
            if marker == 'COMMENT':
                in_comment = True
                continue
            if marker == 'ENDCOMMENT':
                in_comment = False
                continue
            if in_comment:
                continue
            # Everything after the first ':' or '?' is a line comment.
            out.append(line.partition(':')[0].partition('?')[0].rstrip())
        if in_comment:
            raise ValueError('Unterminated COMMENT block')
        return '\n'.join(out)

    @staticmethod
    def _collect_blocks(tree):
        """Split the transformed program into (blocks, derivatives, procedures).

        ``blocks`` maps a block rule name (e.g. 'initial_block') to its single
        transformed child.  The other two map DERIVATIVE / PROCEDURE names to
        their statement lists.
        """
        blocks, blocks_deriv, blocks_proc = {}, {}, {}
        for block in tree.children:
            c = block.children
            if block.data == 'derivative_block':
                blocks_deriv[c[0]] = c[1]
            elif block.data == 'procedure_block':
                # BodyParser can only inline procedures whose argument is `v`.
                if c[1] != 'v':
                    raise NotImplementedError(
                        f'PROCEDURE {c[0]} must take v as its only argument')
                blocks_proc[c[0]] = c[2]
            else:
                assert len(c) == 1
                blocks[block.data] = c[0]
        return blocks, blocks_deriv, blocks_proc
def validate_hh(path='hh.mod'):
    """Parse the mod file at ``path`` and print its init and update rules.

    Args:
        path: file to read; defaults to 'hh.mod' in the current directory.
    """
    # The original `open('hh.mod').read(` was missing its closing paren and
    # never closed the file.
    with open(path) as f:
        text = f.read()
    mod = NMODL(text)
    for k, v in mod.init.items():
        print(k, '=', v)
    for k, v in mod.update.items():
        print(k, '=', v)
# Smoke test run at import time: parse every .mod file under `root` and
# print the resulting init / update rules.
# NOTE(review): `root` is a hard-coded path on the author's machine.
import os
import glob

root = '/home/llandsmeer/repos/thorstenhater/nmlcc/experiments'
pattern = os.path.join(root, '**/*.mod')
# glob returns files in arbitrary order; sort for reproducible output.
mods = sorted(glob.glob(pattern, recursive=True))
for fn in mods:
    print(fn)
    with open(fn) as f:
        mod = NMODL(f.read())
    print('---')
    for k, v in mod.init.items():
        print(k, '=', v)
    for k, v in mod.update.items():
        print(k, '=', v)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
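Known bugs / unsupported features:
- `else if` (the grammar has no `else` branch at all)
- `and`, `else`
- `v`: procedures and procedure calls only accept `v` as their argument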