Skip to content

Instantly share code, notes, and snippets.

@internetimagery
Last active August 27, 2021 21:48
Show Gist options
  • Save internetimagery/7012246ac8aae8fa5e185f634db60582 to your computer and use it in GitHub Desktop.
Save internetimagery/7012246ac8aae8fa5e185f634db60582 to your computer and use it in GitHub Desktop.
Simple do-notation for Python (implemented by rewriting the AST)
import ast
from inspect import getsource
from types import FunctionType, CodeType
from functools import wraps
from textwrap import dedent
from linecache import getline
def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation.

    Example::

        @do
        def add_two(first, second):
            num1 <<= first
            num2 <<= second
            return num1 + num2

        assert add_two(Just(1), Just(2)) == Just(3)

    The desugared version of the function would look something like the
    following::

        def add_two(first, second):
            def __GENERATED__0(num1):
                def __GENERATED__1(num2):
                    return __LAST_MONAD.pure(num1 + num2)
                __LAST_MONAD = second
                return __LAST_MONAD.flat_map(__GENERATED__1)
            __LAST_MONAD = first
            return __LAST_MONAD.flat_map(__GENERATED__0)

    Due to this, bind notation (``name <<= expression``) is limited to
    statements at the root level of the function. No nested bindings.
    This makes things simpler, and removes unexpected behaviour.
    For example a monad does not have to execute code right away::

        @do
        def my_func(monad):
            with context(...):
                val <<= monad  # This is a nested bind statement

    The generated output (if allowed)::

        def my_func(monad):
            with context(...):
                def __GENERATED__0(val):
                    return __LAST_MONAD.pure(None)
                __LAST_MONAD = monad
                return __LAST_MONAD.flat_map(__GENERATED__0)
            # We can return here without actually having executed the function
            # and the context will have expired.

    The rewritten function keeps the original's positional and keyword-only
    default values. When ``func`` was defined inside another function, it is
    also re-bound to the original closure cells, so variables of the
    enclosing scope stay readable.

    Args:
        func: a plain function whose source ``inspect.getsource`` can read.

    Returns:
        A wrapper (carrying ``func``'s metadata via ``functools.wraps``) that
        forwards all arguments to the rewritten function.

    Raises:
        OSError: if the source of ``func`` cannot be retrieved.
        SyntaxError: if ``ReWrite`` rejects the function (no ``<<=`` bind,
            or a bind in an unsupported position).
    """
    source = dedent(getsource(func))
    # Shift AST line numbers so they refer to lines of the real source file.
    source_ast = ast.increment_lineno(
        ast.parse(source),
        n=func.__code__.co_firstlineno - 1,
    )
    modified_ast = ReWrite(func).visit(source_ast)

    freevars = func.__code__.co_freevars
    if freevars:
        # Compiled at module level, closure references would become global
        # lookups. Nest the function in a scope that binds those names so the
        # compiler treats them as free variables instead.
        modified_ast = _nest_in_closure_scope(modified_ast, freevars)

    new_code = compile(modified_ast, func.__code__.co_filename, "exec")
    closure = None
    if freevars:
        new_code = _find_code_const(new_code, _CLOSURE_SCOPE_NAME)
        new_code = _find_code_const(new_code, func.__code__.co_name)
        # Reuse the original cells (not snapshots of their values) so the new
        # function observes the same live enclosing variables.
        cells = dict(zip(freevars, func.__closure__))
        closure = tuple(cells[name] for name in new_code.co_freevars) or None
    else:
        # Look the code object up by name: default-argument lambdas or
        # comprehensions also compile to code objects in the same co_consts.
        new_code = _find_code_const(new_code, func.__code__.co_name)

    # Defaults live on the function object, not the code object, so they must
    # be carried over explicitly or omitted arguments would raise TypeError.
    new_func = FunctionType(
        new_code,
        func.__globals__,
        func.__name__,
        func.__defaults__,
        closure,
    )
    new_func.__kwdefaults__ = func.__kwdefaults__

    @wraps(func)
    def wrapper(*args, **kwargs):
        return new_func(*args, **kwargs)

    return wrapper


# Name of the synthetic function used to host a closure's free variables.
_CLOSURE_SCOPE_NAME = "__DO_CLOSURE_SCOPE__"


def _find_code_const(code, name):
    """Return the code object named ``name`` stored in ``code.co_consts``."""
    return next(
        const
        for const in code.co_consts
        if isinstance(const, CodeType) and const.co_name == name
    )


def _nest_in_closure_scope(module, names):
    """
    Move the module's statements into a synthetic function that binds
    ``names`` locally, so references to them compile as free variables.
    """
    scope = ast.FunctionDef(
        name=_CLOSURE_SCOPE_NAME,
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=[
            ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store()) for name in names],
                value=ast.Constant(value=None, kind=None),
            ),
            *module.body,
        ],
        decorator_list=[],
        returns=None,
    )
    module.body = [scope]
    # Only the synthetic nodes lack locations; existing nodes keep theirs.
    return ast.fix_missing_locations(module)
class ReWrite(ast.NodeTransformer):
    """
    Desugar ``name <<= expression`` bind statements in one function definition.

    Each root-level bind is replaced by::

        def __GENERATED__<i>(name):
            <all following statements, desugared recursively>
        __LAST_MONAD = expression
        return __LAST_MONAD.flat_map(__GENERATED__<i>)

    A root-level ``return value`` becomes ``return __LAST_MONAD.pure(value)``
    (a bare ``return`` is treated as ``return None``). Every desugared body
    that does not end in a bind gets an implicit trailing
    ``return __LAST_MONAD.pure(None)``.

    A bind nested inside another statement (``for``, ``with``, ``if``, an
    inner function, ...) or targeting anything but a plain name is rejected
    with ``SyntaxError``.
    """

    LAST_MONAD_NAME = "__LAST_MONAD"
    FUNC_NAME = "__GENERATED__{}"

    def __init__(self, func):
        # The original function object. It is only consulted when building
        # error messages (its repr, filename and first line number).
        self.func = func

    def visit_FunctionDef(self, node):
        """
        Clear the decorator list of ``node`` and desugar its body.

        Raises:
            SyntaxError: if the function contains no ``<<=`` anywhere, since
                the monad type then cannot be inferred.
        """
        if not any(map(self._parse_do_expr, ast.walk(node))):
            err = SyntaxError(
                f"Decorated function {self.func} contains no <<= expression, cannot infer monad type."
            )
            err.lineno = self.func.__code__.co_firstlineno
            err.filename = self.func.__code__.co_filename
            err.text = getline(err.filename, err.lineno)
            raise err
        node.decorator_list.clear()
        # A single shared iterator: the function generated for a bind
        # consumes every statement that follows it.
        node.body = list(self._parse_body(enumerate(node.body), anchor=node))
        return node

    def _parse_body(self, statements, anchor=None):
        """
        Yield the desugared form of a run of root-level statements.

        Args:
            statements: iterator of ``(index, stmt)`` pairs, shared with the
                recursive call made for a bind (which consumes the rest).
            anchor: node whose location is given to the implicit trailing
                ``return`` when ``statements`` is empty; without one, that
                return is placed at line 1.

        Raises:
            SyntaxError: for a ``<<=`` nested inside another statement, or a
                bind whose target is not a plain name.
        """
        last = anchor
        for i, stmt in statements:
            last = stmt
            # Checking for:
            # name <<= expression
            do_expr = self._parse_do_expr(stmt)
            if do_expr:
                target, value = do_expr
                if not isinstance(target, ast.Name):
                    raise self._syntax_error(
                        "<<= can only bind to a plain name", target
                    )
                # The generated function absorbs the remaining statements and
                # the flat_map return ends this body, so nothing may follow.
                yield from self._write_flatmap(i, target, value, statements)
                return
            return_ = self._parse_return(stmt)
            if return_ is not None:
                yield self._write_pure(return_)
                continue
            nested_do_expr = next(
                filter(None, map(self._parse_do_expr, ast.walk(stmt))), None
            )
            if nested_do_expr is not None:
                raise self._syntax_error(
                    "Nested <<= operators not supported", nested_do_expr[0]
                )
            yield stmt
        # Falling off the end of a body returns None, wrapped in the monad.
        yield self._write_pure(self._none_constant(last))

    def _write_flatmap(self, i, var, expr, body):
        """
        Yield the three statements that replace ``var <<= expr``.

        ``body`` is the shared ``(index, stmt)`` iterator. Every statement
        still in it becomes the body of the generated function.
        """
        loc = lambda n: ast.copy_location(n, var)
        # Break code into functions
        # >>> def __GENERATED__0(name):
        # >>>     ...
        func_name = self.FUNC_NAME.format(i)
        yield self._build_func(func_name, [var.id], body, loc, anchor=var)
        # Evaluate the monad expression and store it
        # >>> __LAST_MONAD = expr
        yield self._build_assign(self.LAST_MONAD_NAME, expr, loc)
        # Return with a flatmapped expression
        # >>> return __LAST_MONAD.flat_map(__GENERATED__0)
        yield loc(
            ast.Return(
                value=self._build_flatmap(
                    self._build_name(self.LAST_MONAD_NAME, loc),
                    self._build_name(func_name, loc),
                    loc,
                ),
            )
        )

    def _write_pure(self, expr):
        """Build ``return __LAST_MONAD.pure(expr)`` located at ``expr``."""
        loc = lambda n: ast.copy_location(n, expr)
        return loc(
            ast.Return(
                value=self._build_pure(
                    self._build_name(self.LAST_MONAD_NAME, loc),
                    expr,
                    loc,
                )
            )
        )

    def _parse_do_expr(self, node):
        """Return ``(target, value)`` if ``node`` is ``target <<= value``, else None."""
        if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)):
            return None
        return node.target, node.value

    def _parse_return(self, node):
        """
        Return the value expression of a ``return`` statement, or None when
        ``node`` is not a ``return``.

        A bare ``return`` yields a ``None`` constant located at the
        statement, so it is wrapped with ``pure`` like any other return.
        """
        if not isinstance(node, ast.Return):
            return None
        if node.value is None:
            return ast.copy_location(ast.Constant(value=None, kind=None), node)
        return node.value

    def _syntax_error(self, message, node):
        """Build a SyntaxError pointing at ``node`` in the original file."""
        err = SyntaxError(message)
        err.lineno = node.lineno
        err.filename = self.func.__code__.co_filename
        err.text = getline(err.filename, err.lineno)
        return err

    @staticmethod
    def _none_constant(anchor):
        """Build a ``None`` constant located at ``anchor`` (line 1 if None)."""
        node = ast.Constant(value=None, kind=None)
        if anchor is None:
            return ast.fix_missing_locations(node)
        return ast.copy_location(node, anchor)

    @staticmethod
    def _build_name(name, loc):
        """Build a located ``Name`` load of ``name``."""
        return loc(ast.Name(id=name, ctx=ast.Load()))

    @staticmethod
    def _build_assign(name, expr, loc):
        """Build a located ``name = expr`` statement."""
        return loc(
            ast.Assign(
                targets=[loc(ast.Name(id=name, ctx=ast.Store()))],
                value=expr,
            )
        )

    @staticmethod
    def _build_args(args):
        """Build an ``arguments`` node with only plain positional ``args``."""
        return ast.arguments(
            args=args,
            posonlyargs=[],
            kwonlyargs=[],
            kw_defaults=[],
            vararg=None,
            kwarg=None,
            defaults=[],
        )

    def _build_func(self, name, args, body, loc, anchor=None):
        """
        Build a located ``def name(*args)`` whose body is the desugared
        remainder of ``body``. ``anchor`` locates its implicit return when
        ``body`` is already exhausted.
        """
        return loc(
            ast.FunctionDef(
                name=name,
                args=self._build_args([loc(ast.arg(arg=arg)) for arg in args]),
                body=list(self._parse_body(body, anchor=anchor)),
                decorator_list=[],
            )
        )

    @staticmethod
    def _build_flatmap(monad, func, loc):
        """Build a located ``monad.flat_map(func)`` call."""
        return loc(
            ast.Call(
                func=loc(
                    ast.Attribute(
                        value=monad,
                        attr="flat_map",
                        ctx=ast.Load(),
                    )
                ),
                args=[func],
                keywords=[],
            )
        )

    @staticmethod
    def _build_pure(monad, expr, loc):
        """Build a located ``monad.pure(expr)`` call."""
        return loc(
            ast.Call(
                func=loc(
                    ast.Attribute(
                        value=monad,
                        attr="pure",
                        ctx=ast.Load(),
                    )
                ),
                args=[expr],
                keywords=[],
            )
        )
if __name__ == "__main__":
    # Self-test / usage examples, run only when this module is executed
    # directly. Requires the third-party ``attrs`` package (imported as attr).
    import attr
    from typing import Any, TypeVar, Generic, TYPE_CHECKING

    A = TypeVar("A")

    # Minimal Maybe-style "Just": pure wraps a value, flat_map feeds the
    # wrapped value to func and returns whatever func returns.
    @attr.s
    class Just(Generic[A]):
        value = attr.ib()

        @classmethod
        def pure(cls, value):
            return cls(value)

        def flat_map(self, func):
            return func(self.value)

        if TYPE_CHECKING:
            # Never defined at runtime (TYPE_CHECKING is False). For
            # ``name <<= monad`` where ``name`` has no __ilshift__/__lshift__,
            # Python falls back to ``monad.__rlshift__(name)``; presumably
            # declared so type checkers infer ``name`` as ``A``.
            def __rlshift__(self, other: Any) -> A:
                ...

    # Minimal "Nothing": flat_map ignores func and short-circuits by
    # returning itself; pure wraps the value in Just.
    @attr.s
    class Nothing:
        @classmethod
        def pure(cls, value):
            return Just(value)

        def flat_map(self, func):
            return self

        if TYPE_CHECKING:
            # Type-checker-only counterpart of Just.__rlshift__.
            def __rlshift__(self, other: Any) -> Any:
                ...

    @do
    def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]:
        # Type-checker-only initial bindings for the names bound with <<=;
        # this branch never runs.
        if TYPE_CHECKING:
            num1 = num2 = None
        num1 <<= first  # flat_map / bind
        num2 <<= second
        num3 = num1 + num2
        return num3 * 2

    # Any Nothing input short-circuits the whole chain.
    assert add_two_times_two(Just(1), Just(2)) == Just(6)
    assert add_two_times_two(Nothing(), Just(2)) == Nothing()
    assert add_two_times_two(Just(1), Nothing()) == Nothing()
    assert add_two_times_two(Nothing(), Nothing()) == Nothing()

    @do
    def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]:
        if TYPE_CHECKING:
            selection = None
        # If statement in expression
        selection <<= first if select else second
        return selection + 2

    assert select_plus_two(Just(1), Just(11), True) == Just(3)
    assert select_plus_two(Just(1), Just(11), False) == Just(13)

    try:
        # Error if there is no bind <<=
        @do
        def nothing():
            return 123

    except SyntaxError:
        pass
    else:
        # Reached only if @do accepted a function without any bind.
        assert False

    # Implicitly return None
    @do
    def no_return():
        _ <<= Just(1)

    assert no_return() == Just(None)

    try:
        # Error for nested bind operators
        @do
        def loop():
            val <<= Just(0)
            for i in range(10):
                j <<= Just(i)
                val += j
            return val

    except SyntaxError:
        pass
    else:
        # Reached only if @do accepted a bind nested inside the for loop.
        assert False

    # Loops and other nestings are ok, so long as they're contained within a generated function
    @do
    def loop():
        val <<= Just(0)
        for i in range(10):
            val += i
        return val

    assert loop() == Just(45)

    # @do also works underneath @classmethod on a method.
    class Test:
        @classmethod
        @do
        def run(cls):
            val <<= Just(3)
            return val * 2

    assert Test.run() == Just(6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment