Skip to content

Instantly share code, notes, and snippets.

@tomicapretto
Created December 28, 2022 14:53
Show Gist options
  • Save tomicapretto/b938614811774b8ab4d9127bdb92e138 to your computer and use it in GitHub Desktop.
Save tomicapretto/b938614811774b8ab4d9127bdb92e138 to your computer and use it in GitHub Desktop.
Split top level terms in a model formula
# This will be useful when we work with non-linear terms in Bambi.
# This is the first step of the plan:
# 1. Split top-level terms. (DONE)
# 2. Detect terms in which non-linear parameters play a role. (TODO)
# 3. Manipulate the expression of terms in which non-linear parameters play a role,
#    so that only the 'predictor' is passed to `formulae`, while the full
#    expression is kept for later use. (TODO)
import ast
from functools import singledispatch
@singledispatch
def accept(node):
    """Render an AST node back to source text as one indivisible term.

    This is the fallback of a ``functools.singledispatch`` generic function:
    every node type without a registered overload (anything that is not an
    ``ast.BinOp``) is turned back into a string with ``ast.unparse``.
    """
    return ast.unparse(node)


@accept.register
def _accept_binop(node: ast.BinOp):
    """Split an addition into its operands; render any other operation whole.

    For ``left + right`` both operands are visited recursively, producing a
    two-element list whose items are strings or further nested lists.
    Every other binary operator (``*``, ``-``, ``/``, ...) yields the whole
    operation as a single string.
    """
    is_addition = isinstance(node.op, ast.Add)
    if not is_addition:
        return ast.unparse(node)
    left_terms = accept(node.left)
    right_terms = accept(node.right)
    return [left_terms, right_terms]
def flatten_list(x):
    """Flatten arbitrarily nested lists into one flat list.

    Only ``list`` instances are expanded; every other element (strings,
    tuples, numbers, ...) is kept as-is, preserving left-to-right order.

    Parameters
    ----------
    x : iterable
        Items, some of which may themselves be (nested) lists.

    Returns
    -------
    list
        A new list holding the non-list leaves of ``x``.
    """
    def _leaves(items):
        for item in items:
            if not isinstance(item, list):
                yield item
                continue
            yield from _leaves(item)

    return list(_leaves(x))
def split_top_level_terms(formula: str) -> list[str]:
    """Split a model formula into its top-level additive terms.

    Only a top-level ``+`` separates terms. A ``+`` nested inside a function
    call, or inside a non-additive operation such as ``(a + b) * c``, stays
    part of a single term. Each term is regenerated with ``ast.unparse``, so
    spacing in the output is normalised.

    Note that ``-`` is not treated as a separator: ``"a + b - c"`` is
    returned as the single term ``'a + b - c'``.

    Parameters
    ----------
    formula : str
        A single Python-parsable expression, e.g. the right-hand side of a
        model formula.

    Returns
    -------
    list of str
        The top-level terms in left-to-right order. A blank ``formula``
        yields an empty list.

    Raises
    ------
    SyntaxError
        If ``formula`` is not exactly one valid Python expression.

    Examples
    --------
    >>> split_top_level_terms("z + f(x + g(y + z)) + 5 * np.exp(1)")
    ['z', 'f(x + g(y + z))', '5 * np.exp(1)']
    """
    # Surrounding whitespace would otherwise trip the parser
    # (a leading space raised IndentationError in exec mode).
    formula = formula.strip()
    if not formula:
        return []
    # mode="eval" accepts exactly one expression, so assignments or several
    # statements raise SyntaxError instead of being silently truncated to
    # (part of) the first statement.
    expression = ast.parse(formula, mode="eval").body
    return flatten_list(accept(expression))
# Demo: only runs when this file is executed as a script, so importing the
# module has no side effects, and the result is actually shown.
if __name__ == "__main__":
    print(
        split_top_level_terms(
            "z + f(x + g(y + z)) + 5 * np.exp(1) + a * np.exp(x * y)"
        )
    )
    # ['z', 'f(x + g(y + z))', '5 * np.exp(1)', 'a * np.exp(x * y)']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment