Created
December 28, 2022 14:53
-
-
Save tomicapretto/b938614811774b8ab4d9127bdb92e138 to your computer and use it in GitHub Desktop.
Split top level terms in a model formula
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This will be useful when we work with non-linear terms in Bambi.
# This is the first step.
# 1. Split top-level terms. DONE
# 2. Detect terms where non-linear parameters play a role. TODO
# 3. Manipulate the expression in terms where non-linear parameters play a role
#    so that only the 'predictor' is passed to `formulae`, while the full
#    expression is kept to be used later. TODO
import ast | |
from functools import singledispatch | |
@singledispatch
def accept(node):
    """Render an AST node back to source text as a single term.

    This is the fallback for every node type without a more specific
    registration: the node is unparsed as-is, i.e. treated as indivisible.
    """
    return ast.unparse(node)


@accept.register
def _(node: ast.BinOp):
    """Split a binary ``+`` into its operands; keep any other operator whole.

    For an addition, both operands are visited recursively and returned as a
    two-element list whose items are strings or further such lists. Because
    ``+`` is left-associative, ``a + b + c`` yields ``[["a", "b"], "c"]``.
    Any other binary operator (``*``, ``-``, ``/``, ...) is unparsed back to
    a single source string.
    """
    if not isinstance(node.op, ast.Add):
        return ast.unparse(node)
    left_terms = accept(node.left)
    right_terms = accept(node.right)
    return [left_terms, right_terms]
def flatten_list(x):
    """Return a new flat list with the items of ``x``, expanding nested lists.

    Only ``list`` instances are expanded (at any depth); every other item,
    including tuples and strings, is copied over unchanged and in order.
    """

    def _walk(items):
        # Depth-first, left-to-right traversal so the output order matches
        # the reading order of the nested structure.
        for item in items:
            if isinstance(item, list):
                yield from _walk(item)
            else:
                yield item

    return list(_walk(x))
def split_top_level_terms(formula):
    """Split a model formula into its top-level additive terms.

    Splitting happens only along the chain of ``+`` operators that starts at
    the root of the parsed expression. Any other node (a call, a product, a
    unary minus, ...) is kept as one term, re-rendered with ``ast.unparse``.
    Parentheses are not represented in the AST, so ``(a + b) + c`` is split
    exactly like ``a + b + c``.

    Parameters
    ----------
    formula : str
        A formula that is a valid Python expression, e.g.
        ``"z + f(x + g(y + z)) + 5 * np.exp(1)"``.

    Returns
    -------
    list of str
        The top-level terms, left to right. A formula with no top-level
        ``+`` yields a one-element list.

    Raises
    ------
    SyntaxError
        If ``formula`` is not a single Python expression (for example, it is
        empty or contains a statement such as an assignment).
    """
    # ``mode="eval"`` guarantees the tree is a single expression, so empty
    # input or a statement fails with a clear SyntaxError instead of an opaque
    # IndexError/AttributeError, or silently using only the first statement.
    # Surrounding whitespace is stripped because leading spaces would
    # otherwise be rejected by the parser.
    tree = ast.parse(formula.strip(), mode="eval")
    terms = accept(tree.body)
    # When the root is not an addition, ``accept`` returns a plain string;
    # passing it to ``flatten_list`` would split it into single characters.
    if isinstance(terms, str):
        return [terms]
    return flatten_list(terms)
if __name__ == "__main__":
    # Demonstration: only the outermost ``+`` chain is split; the sums nested
    # inside ``f(...)`` and ``g(...)`` stay within their enclosing terms.
    print(split_top_level_terms("z + f(x + g(y + z)) + 5 * np.exp(1) + a * np.exp(x * y)"))
    # ['z', 'f(x + g(y + z))', '5 * np.exp(1)', 'a * np.exp(x * y)']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment