Last active
August 27, 2021 21:48
-
-
Save internetimagery/7012246ac8aae8fa5e185f634db60582 to your computer and use it in GitHub Desktop.
Simple do-notation for Python, implemented by rewriting the function's AST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from inspect import getsource | |
from types import FunctionType, CodeType | |
from functools import wraps | |
from textwrap import dedent | |
from linecache import getline | |
# Name of the throw-away function used to recreate the closure scope of a
# decorated function (see ``_wrap_in_closure_factory``). It is never called.
_CLOSURE_FACTORY_NAME = "__DO_CLOSURE_FACTORY__"


def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation.

    Example:
        >>> @do
        >>> def add_two(first, second):
        >>>     num1 <<= first
        >>>     num2 <<= second
        >>>     return num1 + num2
        >>> assert add_two(Just(1), Just(2)) == Just(3)

    The desugared version of the function would look something like the following...
        >>> def add_two(first, second):
        >>>     def __GENERATED__0(num1):
        >>>         def __GENERATED__1(num2):
        >>>             return __LAST_MONAD.pure(num1 + num2)
        >>>         __LAST_MONAD = second
        >>>         return __LAST_MONAD.flat_map(__GENERATED__1)
        >>>     __LAST_MONAD = first
        >>>     return __LAST_MONAD.flat_map(__GENERATED__0)

    Due to this, bind notation (name <<= expression) is limited to expressions
    at the root level of the function. No nested bindings.
    This makes things simpler, and removes unexpected behaviour.
    For example a monad does not have to execute code right away:
        >>> @do
        >>> def my_func(monad):
        >>>     with context(...):
        >>>         val <<= monad  # This is a nested bind statement

    The generated output (if allowed):
        >>> def my_func(monad):
        >>>     with context(...):
        >>>         def __GENERATED__0(val):
        >>>             return __LAST_MONAD.pure(None)
        >>>         __LAST_MONAD = monad
        >>>         return __LAST_MONAD.flat_map(__GENERATED__0)
        >>>     # We can return here without actually having executed the function
        >>>     # and the context will have expired.

    The rewritten function reuses the original's default argument values
    (positional and keyword-only) and its closure cells, so functions defined
    inside other functions can be decorated as well.

    Raises:
        SyntaxError: if the function contains no ``<<=`` bind, or contains a
            bind that is not at the root level of the function.
        OSError: if the source code of *func* cannot be retrieved
            (``inspect.getsource``).
    """
    code = func.__code__
    source = dedent(getsource(func))
    # Shift line numbers so tracebacks point into the real source file.
    source_ast = ast.increment_lineno(
        ast.parse(source),
        n=code.co_firstlineno - 1,
    )
    modified_ast = ReWrite(func).visit(source_ast)
    func_name = modified_ast.body[0].name

    if code.co_freevars:
        # Compiled at module level, free variables would silently become
        # global lookups (and ``nonlocal`` would not compile at all); nest the
        # function in a scope that binds them so they stay closure variables.
        modified_ast = _wrap_in_closure_factory(modified_ast, code.co_freevars)
    container = compile(modified_ast, code.co_filename, "exec")
    if code.co_freevars:
        container = _find_code(container, _CLOSURE_FACTORY_NAME)
    new_code = _find_code(container, func_name)

    new_func = FunctionType(
        new_code,
        func.__globals__,
        func.__name__,
        # Defaults are evaluated when a ``def`` statement runs, which never
        # happens for the recompiled code, so reuse the original values.
        func.__defaults__,
        _match_closure(func, new_code),
    )
    new_func.__kwdefaults__ = func.__kwdefaults__

    @wraps(func)
    def wrapper(*args, **kwargs):
        return new_func(*args, **kwargs)

    return wrapper


def _find_code(code, name):
    """Return the code object called *name* among the constants of *code*.

    Matching on the name (instead of taking the first code constant) avoids
    picking up e.g. a lambda used as a default argument value.

    Raises:
        StopIteration: if *code* holds no code object with that name.
    """
    return next(
        const
        for const in code.co_consts
        if isinstance(const, CodeType) and const.co_name == name
    )


def _wrap_in_closure_factory(module, free_vars):
    """Nest the single function of *module* inside a factory function.

    The factory takes one parameter per name in *free_vars* and returns the
    nested function, so compiling it turns those names back into free
    variables of the nested function. Returns the modified *module*.
    """
    func_def = module.body[0]
    factory = ast.FunctionDef(
        name=_CLOSURE_FACTORY_NAME,
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=name) for name in free_vars],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[
            func_def,
            ast.Return(value=ast.Name(id=func_def.name, ctx=ast.Load())),
        ],
        decorator_list=[],
    )
    module.body = [ast.copy_location(factory, func_def)]
    return ast.fix_missing_locations(module)


def _match_closure(func, new_code):
    """Build the closure tuple for *new_code* out of *func*'s own cells.

    Cells are matched by variable name, so the order of ``co_freevars`` in
    the recompiled code does not need to match the original's. Returns
    ``None`` when *new_code* has no free variables.
    """
    if not new_code.co_freevars:
        return None
    cells = dict(zip(func.__code__.co_freevars, func.__closure__))
    return tuple(cells[name] for name in new_code.co_freevars)
class ReWrite(ast.NodeTransformer):
    """Desugar ``name <<= expression`` binds in a single function definition.

    Every root-level bind turns the statements after it into a generated
    continuation function ``__GENERATED__<n>(name)`` that is passed to
    ``expression.flat_map``. Root-level ``return`` statements (including a
    bare ``return``) and the implicit ``return None`` at the end of every
    body become ``return __LAST_MONAD.pure(...)``, where ``__LAST_MONAD`` is
    the monad of the bind that created the current scope. A root-level
    ``return`` placed before the first bind therefore fails at runtime, since
    ``__LAST_MONAD`` is not assigned yet.

    ``return`` statements inside nested blocks (``if``, ``for``, ``with`` ...)
    are left untouched and return their value as-is.
    """

    LAST_MONAD_NAME = "__LAST_MONAD"
    FUNC_NAME = "__GENERATED__{}"

    def __init__(self, func):
        # Original function object; only read to attach filename, line number
        # and source text to the SyntaxErrors raised below.
        self.func = func

    def visit_FunctionDef(self, node):
        """Strip the decorators of *node* and desugar its body.

        Raises:
            SyntaxError: if *node* contains no ``<<=`` at all (the monad type
                needed for ``pure`` cannot be inferred), if a bind target is
                not a plain name, or if a bind is nested in another statement.
        """
        if not any(map(self._parse_do_expr, ast.walk(node))):
            raise self._syntax_error(
                f"Decorated function {self.func} contains no <<= expression, cannot infer monad type.",
                self.func.__code__.co_firstlineno,
            )
        node.decorator_list.clear()
        node.body = list(self._parse_body(enumerate(node.body), anchor=node))
        return node

    def _syntax_error(self, message, lineno):
        """Return a SyntaxError located at *lineno* of the decorated function's file."""
        err = SyntaxError(message)
        err.lineno = lineno
        err.filename = self.func.__code__.co_filename
        err.text = getline(err.filename, err.lineno)
        return err

    def _parse_body(self, statements, anchor=None):
        """Yield the desugared form of a body given as ``(index, stmt)`` pairs.

        *statements* is a shared iterator: on a bind, the generated
        continuation consumes everything still left in it. *anchor* is the
        node whose source location the trailing ``return pure(None)`` gets
        when the body has no statements of its own.
        """
        for i, stmt in statements:
            anchor = stmt
            # Checking for: name <<= expression
            do_expr = self._parse_do_expr(stmt)
            if do_expr is not None:
                var, expr = do_expr
                yield from self._write_flatmap(i, var, expr, statements)
                # The rest of the body now lives in the continuation and this
                # body ends with ``return ...flat_map(...)``; nothing may follow.
                return
            if isinstance(stmt, ast.Return):
                value = stmt.value
                if value is None:
                    # A bare ``return`` means ``return None``, exactly like
                    # falling off the end of the body.
                    value = ast.copy_location(ast.Constant(value=None, kind=None), stmt)
                yield self._write_pure(value)
                continue
            nested = next(filter(None, map(self._parse_do_expr, ast.walk(stmt))), None)
            if nested is not None:
                raise self._syntax_error(
                    "Nested <<= operators not supported", nested[0].lineno
                )
            yield stmt
        # Always return None at the end, located at the last statement seen.
        none = ast.Constant(value=None, kind=None)
        if anchor is None:
            ast.fix_missing_locations(none)
        else:
            ast.copy_location(none, anchor)
        yield self._write_pure(none)

    def _write_flatmap(self, i, var, expr, body):
        """Yield the statements replacing ``var <<= expr`` and the rest of *body*.

        >>> def __GENERATED__<i>(var):
        >>>     ...rest of body...
        >>> __LAST_MONAD = expr
        >>> return __LAST_MONAD.flat_map(__GENERATED__<i>)

        Raises:
            SyntaxError: if *var* is not a plain name (e.g. ``obj.a <<= m``).
        """
        if not isinstance(var, ast.Name):
            raise self._syntax_error("<<= bind target must be a plain name", var.lineno)

        def loc(node):
            return ast.copy_location(node, var)

        func_name = self.FUNC_NAME.format(i)
        continuation = list(self._parse_body(body, anchor=var))
        yield self._build_func(func_name, [var.id], continuation, loc)
        # Evaluate the monad expression once and keep it for ``pure`` calls.
        yield self._build_assign(self.LAST_MONAD_NAME, expr, loc)
        yield loc(
            ast.Return(
                value=self._build_flatmap(
                    self._build_name(self.LAST_MONAD_NAME, loc),
                    self._build_name(func_name, loc),
                    loc,
                ),
            )
        )

    def _write_pure(self, expr):
        """Return ``return __LAST_MONAD.pure(expr)`` located at *expr*."""

        def loc(node):
            return ast.copy_location(node, expr)

        return loc(
            ast.Return(
                value=self._build_pure(
                    self._build_name(self.LAST_MONAD_NAME, loc),
                    expr,
                    loc,
                )
            )
        )

    def _parse_do_expr(self, node):
        """Return ``(target, value)`` if *node* is ``target <<= value``, else None."""
        if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)):
            return None
        return node.target, node.value

    @staticmethod
    def _build_name(name, loc):
        return loc(ast.Name(id=name, ctx=ast.Load()))

    @staticmethod
    def _build_assign(name, expr, loc):
        return loc(
            ast.Assign(
                targets=[loc(ast.Name(id=name, ctx=ast.Store()))],
                value=expr,
            )
        )

    @staticmethod
    def _build_args(args):
        return ast.arguments(
            args=args,
            posonlyargs=[],
            kwonlyargs=[],
            kw_defaults=[],
            vararg=None,
            kwarg=None,
            defaults=[],
        )

    def _build_func(self, name, args, body, loc):
        """Build ``def name(*args): body`` from already-desugared *body* statements."""
        return loc(
            ast.FunctionDef(
                name=name,
                args=self._build_args([loc(ast.arg(arg=arg)) for arg in args]),
                body=body,
                decorator_list=[],
            )
        )

    @staticmethod
    def _build_flatmap(monad, func, loc):
        return loc(
            ast.Call(
                func=loc(
                    ast.Attribute(
                        value=monad,
                        attr="flat_map",
                        ctx=ast.Load(),
                    )
                ),
                args=[func],
                keywords=[],
            )
        )

    @staticmethod
    def _build_pure(monad, expr, loc):
        return loc(
            ast.Call(
                func=loc(
                    ast.Attribute(
                        value=monad,
                        attr="pure",
                        ctx=ast.Load(),
                    )
                ),
                args=[expr],
                keywords=[],
            )
        )
if __name__ == "__main__":
    # Self-test / demo of the ``do`` decorator (uses the third-party ``attrs``).
    import attr
    from typing import TYPE_CHECKING, Any, Generic, TypeVar

    A = TypeVar("A")

    @attr.s
    class Just(Generic[A]):
        """Minimal Maybe-style monad wrapping a single value."""

        value = attr.ib()

        def flat_map(self, func):
            return func(self.value)

        @classmethod
        def pure(cls, value):
            return cls(value)

        if TYPE_CHECKING:
            # Presumably lets type checkers type ``name <<= Just[...]`` as the
            # wrapped value; never defined at runtime.
            def __rlshift__(self, other: Any) -> A:
                ...

    @attr.s
    class Nothing:
        """Empty case: ``flat_map`` ignores the continuation and returns itself."""

        def flat_map(self, func):
            return self

        @classmethod
        def pure(cls, value):
            return Just(value)

        if TYPE_CHECKING:
            def __rlshift__(self, other: Any) -> Any:
                ...

    @do
    def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]:
        if TYPE_CHECKING:
            num1 = num2 = None
        num1 <<= first  # flat_map / bind
        num2 <<= second
        num3 = num1 + num2
        return num3 * 2

    for first_arg, second_arg, expected in (
        (Just(1), Just(2), Just(6)),
        (Nothing(), Just(2), Nothing()),
        (Just(1), Nothing(), Nothing()),
        (Nothing(), Nothing(), Nothing()),
    ):
        assert add_two_times_two(first_arg, second_arg) == expected

    @do
    def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]:
        if TYPE_CHECKING:
            selection = None
        # If statement in expression
        selection <<= first if select else second
        return selection + 2

    for flag, expected in ((True, Just(3)), (False, Just(13))):
        assert select_plus_two(Just(1), Just(11), flag) == expected

    raised = False
    try:
        # Error if there is no bind <<=
        @do
        def nothing():
            return 123
    except SyntaxError:
        raised = True
    assert raised

    # Implicitly return None
    @do
    def no_return():
        _ <<= Just(1)

    assert no_return() == Just(None)

    raised = False
    try:
        # Error for nested bind operators
        @do
        def loop():
            val <<= Just(0)
            for i in range(10):
                j <<= Just(i)
                val += j
            return val
    except SyntaxError:
        raised = True
    assert raised

    # Loops and other nestings are ok, so long as they're contained within a generated function
    @do
    def loop():
        val <<= Just(0)
        for i in range(10):
            val += i
        return val

    assert loop() == Just(45)

    class Test:
        @classmethod
        @do
        def run(cls):
            val <<= Just(3)
            return val * 2

    assert Test.run() == Just(6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment