Skip to content

Instantly share code, notes, and snippets.

@eigenfoo
Created March 26, 2020 01:13
Show Gist options
  • Save eigenfoo/43282c4e69156647d7bb2505f1dbafb2 to your computer and use it in GitHub Desktop.
Save eigenfoo/43282c4e69156647d7bb2505f1dbafb2 to your computer and use it in GitHub Desktop.
class FunctionToGenerator(ast.NodeTransformer):
    """
    Transform the AST of a decorated model specification into a generator.

    This subclass traverses the AST of the user-written, decorated model
    specification and rewrites it into a generator for the model.
    Subclassing ``ast.NodeTransformer`` is the idiomatic way to transform
    an AST.

    Specifically:

    1. Add ``yield`` keywords to the assignments in the model function's
       own body, e.g. ``x = tfd.Normal(0, 1)`` -> ``x = yield tfd.Normal(0, 1)``.
       Assignments whose value is already a ``yield`` expression are left
       as they are, so the yield is never doubled.
    2. Rename the model specification function to
       ``_pm_compiled_model_generator``. This is done out of an abundance
       of caution more than anything.
    3. Remove the decorators (including ``@Model``) from the model
       specification function. Otherwise, we risk running into an
       infinite recursion.

    Functions, async functions and classes defined *inside* the model
    function are left untouched. Renaming them or stripping their
    decorators would break references to them. Adding ``yield`` to their
    assignments would turn them into generators (or, in a class body,
    into a syntax error).

    Nodes are modified in place, and the same node objects are returned.
    """

    # Number of model-function definitions currently being transformed.
    # 0 means the next FunctionDef reached is a model specification.
    # This is a class-level default (rather than set in __init__) so
    # instances stay valid even if a subclass overrides __init__ without
    # calling super().__init__().
    _function_depth = 0

    def visit_Assign(self, node):
        """Turn ``target = value`` into ``target = yield value``."""
        if not isinstance(node.value, ast.Yield):
            # Place the synthesized Yield where the value expression is,
            # then let fix_missing_locations fill any attribute the value
            # did not carry (e.g. if it was itself synthesized).
            node.value = ast.copy_location(ast.Yield(value=node.value), node.value)
            ast.fix_missing_locations(node)
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        """Rename the model function, drop its decorators, and transform its body."""
        if self._function_depth > 0:
            # A helper defined inside the model: keep its name, decorators
            # and body exactly as written.
            return node
        node.name = "_pm_compiled_model_generator"
        # Without this, executing the compiled definition would re-apply
        # the decorator that triggered this compilation.
        node.decorator_list = []
        self._function_depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._function_depth -= 1
        return node

    def _visit_other_scope(self, node):
        """Handle async function and class definitions.

        Inside the model function they are returned unchanged. Outside it
        they are traversed normally, as before.
        """
        if self._function_depth > 0:
            return node
        return self.generic_visit(node)

    visit_AsyncFunctionDef = _visit_other_scope
    visit_ClassDef = _visit_other_scope
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment