Skip to content

Instantly share code, notes, and snippets.

@jimbaker
Created May 26, 2022 03:21
Show Gist options
  • Save jimbaker/619de23e34a28affc14e4fdfb1b99996 to your computer and use it in GitHub Desktop.
Save jimbaker/619de23e34a28affc14e4fdfb1b99996 to your computer and use it in GitHub Desktop.
from types import FunctionType
from typing import *
import textwrap
import dis # temporary debugging
# A Thunk is the 4-tuple (getvalue, raw, conv, formatspec) that a tag function
# such as `rewritten` receives for each interpolation:
#   getvalue   -- zero-argument callable; calling it evaluates the interpolation
#   raw        -- str; presumably the interpolation's source text (TODO confirm)
#   conv       -- str or None; presumably the !r/!s/!a conversion, if any
#   formatspec -- str or None; presumably the :spec text, if any
Thunk = tuple[
    Callable[[], Any],
    str,
    str | None,
    str | None,
]
def param_list(names):
    """Render *names* as a comma-separated parameter list, e.g. ``a, b, c``."""
    separator = ', '
    return separator.join(names)
def name_bindings(names):
    """Return Python source for a dict literal mapping each name to itself.

    For ``['a', 'b']`` the result is the multi-line text::

        {
         'a' : a,
         'b' : b,
        }

    When evaluated, each key is the name as a string and each value is
    whatever that name is bound to in the evaluating scope.
    """
    entries = (f' {name!r} : {name},' for name in names)
    return '\n'.join(['{', *entries, '}'])
def rewrite_thunk(thunk: Thunk, *, debug: bool = True) -> Thunk:
    """Given a thunk, return a rewritten thunk whose getvalue returns used names.

    Calling the rewritten thunk's ``getvalue`` returns a dict that maps every
    name referenced by the original ``getvalue`` code object to the value that
    name is bound to at call time. Names come from ``co_names`` (globals and
    other non-local names) followed by ``co_freevars`` (closure variables).

    This uses the "lambda trick". Source for a nested ``outer``/``inner`` pair
    is generated and ``exec``-ed. ``outer``'s parameters are named after the
    original free variables, so ``inner`` closes over them, and ``inner``'s
    body is the dict literal built by ``name_bindings``. ``inner``'s code
    object is then combined with the original function's globals and closure
    cells.

    NOTE(review): ``co_names`` also holds attribute names (e.g. ``attr`` in
    ``obj.attr``). Those are looked up as globals by the new getvalue and will
    raise NameError unless such a global exists.

    Args:
        thunk: the ``(getvalue, raw, conv, formatspec)`` tuple to rewrite.
        debug: when true (the default, matching the original behaviour), print
            the generated source and disassemble the old and new code objects.

    Returns:
        ``(new_getvalue, wrapped, conv, formatspec)``. The second item is the
        generated wrapper source, not the original ``raw`` text.
    """
    from types import CodeType  # only needed to locate the nested code object

    getvalue, raw, conv, formatspec = thunk
    code = getvalue.__code__
    if debug:
        print(f'Compiling thunk {hash(code)}')
        dis.dis(code)
    all_names = code.co_names + code.co_freevars

    # Implement the "lambda trick": inner() closes over outer()'s parameters,
    # which are named after the original free variables. The backslash
    # continuation lets the multi-line dict literal start at any column.
    wrapped = f"""
def outer({param_list(code.co_freevars)}):
    def inner():
        return \\
{textwrap.indent(name_bindings(all_names), ' ' * 3)}
"""
    if debug:
        print(wrapped)
    capture = {}
    exec(wrapped, getvalue.__globals__, capture)

    # Where inner's code object sits in outer's co_consts is a CPython
    # implementation detail that has changed between versions, so search by
    # type instead of assuming a fixed index.
    new_lambda_code = next(
        const
        for const in capture["outer"].__code__.co_consts
        if isinstance(const, CodeType)
    )
    if debug:
        dis.dis(new_lambda_code)

    # Hand inner() the original closure cells, matched by free-variable name
    # rather than by position, so an ordering difference cannot bind the
    # wrong cell. FunctionType expects None (not ()) when there are no cells.
    cells = dict(zip(code.co_freevars, getvalue.__closure__ or ()))
    closure = tuple(cells[name] for name in new_lambda_code.co_freevars) or None

    new_getvalue = FunctionType(
        new_lambda_code,
        getvalue.__globals__,
        getvalue.__name__,
        getvalue.__defaults__,
        closure)
    return new_getvalue, wrapped, conv, formatspec
def rewritten(*args: str | Thunk):
    """Tag function: rewrite every thunk among *args*, passing strings through.

    Each non-``str`` argument is handed to ``rewrite_thunk``. The results are
    returned as a list in the original argument order.
    """
    return [arg if isinstance(arg, str) else rewrite_thunk(arg) for arg in args]
# Demo: set up some variables at differing levels of nested scope. `a` is a
# module global, while `b`, `c` and `d` are locals of successively nested
# functions, so the thunk below references both a global and free variables.
# NOTE(review): `rewritten"..."` is tag-string syntax, which stock CPython
# does not accept; this presumably needs an interpreter build implementing
# tag strings (TODO confirm which one).
a = 2
def nested1():
    b = 3
    def nested2():
        c = 5
        def nested3():
            d = 7
            # new_args is rewritten such that each thunk's getvalue is a new
            # function/code object that returns the mapping of the
            # variables it uses to their values (namely, for a, b, c, d)
            new_args = rewritten"{d**a + c * c * c * a * b * a + d}"
            # new_args[0] is expected to be the rewritten thunk; its [0] is
            # the new getvalue, so this prints the name -> value dict.
            print(new_args[0][0]())
        nested3()
    nested2()
nested1()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment