Skip to content

Instantly share code, notes, and snippets.

@dhilst
Last active March 22, 2022 13:26
Show Gist options
  • Save dhilst/4bce4210394e75ae4cc03d081204ae64 to your computer and use it in GitHub Desktop.
Save dhilst/4bce4210394e75ae4cc03d081204ae64 to your computer and use it in GitHub Desktop.
from ast import (
    Call,
    Expression,
    In,
    Lambda,
    Name,
    NodeTransformer,
    arg,
    arguments,
    copy_location,
    dump,
    fix_missing_locations,
    keyword,
    parse,
    unparse,
)
class LetVisitor(NodeTransformer):
    """AST transformer implementing ``let(name=value, ...) in body``.

    Each such comparison is rewritten into an immediately-invoked lambda::

        let(a=1, b=2) in a + b   ->   (lambda a, b: a + b)(a=1, b=2)

    Function definitions additionally have any ``@letdec`` decorator removed.
    Comparisons that are not a ``let`` form are returned unchanged.
    """

    @staticmethod
    def _is_let_call(node):
        """Return True if *node* is a call to the bare name ``let``."""
        return (
            isinstance(node, Call)
            and isinstance(node.func, Name)
            and node.func.id == "let"
        )

    def visit_Compare(self, node):
        """Rewrite ``let(...) in body`` into a lambda call.

        Raises:
            SyntaxError: if a ``let`` form uses positional or ``**``
                arguments, or is part of a chained comparison.
        """
        # Visit children first so nested ``let`` forms (in the body or in the
        # binding values) are already rewritten before the outer one.
        self.generic_visit(node)
        if not (self._is_let_call(node.left) and isinstance(node.ops[0], In)):
            return node
        bindings = node.left.keywords
        if len(node.ops) != 1:
            # ``let(...) in a in b`` would otherwise silently lose ``in b``.
            raise SyntaxError(
                "let(...) in ... cannot be part of a chained comparison"
            )
        if node.left.args or any(k.arg is None for k in bindings):
            # Positional values have no name to bind, and ``**kw`` has no
            # static name to turn into a lambda parameter.
            raise SyntaxError("let(...) only accepts name=value bindings")
        params = arguments(
            posonlyargs=[],
            args=[arg(k.arg) for k in bindings],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        )
        lamb = Lambda(args=params, body=node.comparators[0])
        call = Call(lamb, args=[], keywords=bindings)
        # Give the new call the position of the comparison it replaces, so
        # tracebacks point at the ``let`` expression rather than line 1.
        return fix_missing_locations(copy_location(call, node))

    def visit_FunctionDef(self, node):
        """Remove ``@letdec`` decorators, then transform the function body."""
        # Only bare ``@letdec`` names are dropped; attribute or call
        # decorators (``@mod.deco``, ``@deco()``) are kept untouched.
        node.decorator_list = [
            d
            for d in node.decorator_list
            if not (isinstance(d, Name) and d.id == "letdec")
        ]
        self.generic_visit(node)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef
def letdec(f):
    """Decorator enabling ``let(name=value, ...) in body`` syntax inside *f*.

    The source of *f* is re-parsed and rewritten with ``LetVisitor``, which
    also drops the ``@letdec`` decorator. The result is recompiled and turned
    back into a function that shares *f*'s globals, defaults and qualified
    name. As a side effect, the original and the rewritten source are printed
    to stdout.

    Limitations: the function is recompiled at module scope. Names that *f*
    would capture from an enclosing function (closures, zero-argument
    ``super()``) are therefore looked up as globals instead.

    Raises:
        OSError: if the source of *f* is unavailable (from
            ``inspect.getsource``).
        ValueError: if the compiled module contains no code object with
            *f*'s code name.
    """
    import inspect
    import textwrap
    import types
    from ast import increment_lineno

    # Methods and nested functions come back indented; dedent so they parse.
    source = textwrap.dedent(inspect.getsource(f))
    print("original source\n")
    print(source, "\n\n")
    old_code_obj = f.__code__
    old_ast = parse(source)
    new_ast = fix_missing_locations(LetVisitor().visit(old_ast))
    # getsource starts at co_firstlineno (the first decorator line), so shift
    # the parsed line numbers to match the real file for tracebacks.
    increment_lineno(new_ast, old_code_obj.co_firstlineno - 1)
    new_code_obj = compile(new_ast, old_code_obj.co_filename, "exec")
    print("new source\n")
    print(unparse(new_ast), "\n\n")
    # co_consts[0] is not always the function's code object (constants for
    # default values, for example, can precede it), so look it up by name.
    for const in new_code_obj.co_consts:
        if (
            isinstance(const, types.CodeType)
            and const.co_name == old_code_obj.co_name
        ):
            func_code = const
            break
    else:
        raise ValueError(f"no compiled code found for {f.__qualname__!r}")
    new_f = types.FunctionType(
        func_code, f.__globals__, f.__name__, f.__defaults__
    )
    # Keyword-only defaults and the qualified name are not constructor
    # arguments on every supported version, so copy them explicitly.
    new_f.__kwdefaults__ = f.__kwdefaults__
    new_f.__qualname__ = f.__qualname__
    return new_f
# Demo of the let syntax. @letdec rewrites the body below into nested lambda
# calls before it runs. The parentheses around the inner ``let`` matter:
# without them, ``x in y in z`` would parse as a chained comparison.
@letdec
def foo():
    return let(a=1, b=2) in (let(c=3) in a + b + c)


# Calls the rewritten function, which evaluates 1 + 2 + 3.
print("result", foo())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment