Last active
August 27, 2021 21:48
Revisions
-
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 29 additions and 19 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,9 +1,9 @@ import ast from inspect import getsource from types import FunctionType, CodeType from functools import wraps from textwrap import dedent from linecache import getline def do(func): @@ -55,19 +55,24 @@ def do(func): >>> # and the context will have expired. """ source = dedent(getsource(func)) source_ast = ast.increment_lineno( ast.parse(source), n=func.__code__.co_firstlineno - 1, ) modified_ast = ReWrite(func).visit(source_ast) new_code = next( filter( lambda x: isinstance(x, CodeType), compile(modified_ast, func.__code__.co_filename, "exec").co_consts, ) ) new_func = FunctionType( new_code, func.__globals__, ) @wraps(func) def wrapper(*args, **kwargs): return new_func(*args, **kwargs) @@ -89,7 +94,7 @@ def visit_FunctionDef(self, node): ) err.lineno = self.func.__code__.co_firstlineno err.filename = self.func.__code__.co_filename err.text = getline(err.filename, err.lineno) raise err node.decorator_list.clear() @@ -120,7 +125,7 @@ def _parse_body(self, statements): err = SyntaxError("Nested <<= operators not supported") err.lineno = nested_do_expr[0][0].lineno err.filename = self.func.__code__.co_filename err.text = getline(err.filename, err.lineno) raise err else: yield stmt @@ -274,7 +279,9 @@ def flat_map(self, func): return func(self.value) if TYPE_CHECKING: def __rlshift__(self, other: Any) -> A: ... @attr.s class Nothing: @@ -286,11 +293,15 @@ def flat_map(self, func): return self if TYPE_CHECKING: def __rlshift__(self, other: Any) -> Any: ... @do def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]: if TYPE_CHECKING: num1 = num2 = None num1 <<= first # flat_map / bind num2 <<= second num3 = num1 + num2 @@ -303,7 +314,9 @@ def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]: @do def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]: if TYPE_CHECKING: selection = None # If statement in expression selection <<= first if select else second return selection + 2 @@ -355,13 +368,10 @@ def loop(): assert loop() == Just(45) class Test: @classmethod @do def run(cls): val <<= Just(3) return val * 2 assert Test.run() == Just(6) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 16 additions and 3 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -258,9 +258,12 @@ def _build_pure(monad, expr, loc): if __name__ == "__main__": import attr from typing import Any, TypeVar, Generic, TYPE_CHECKING A = TypeVar("A") @attr.s class Just(Generic[A]): value = attr.ib() @classmethod @@ -270,6 +273,9 @@ def pure(cls, value): def flat_map(self, func): return func(self.value) if TYPE_CHECKING: def __rlshift__(self, other: Any) -> A: ... @attr.s class Nothing: @classmethod @@ -279,8 +285,12 @@ def pure(cls, value): def flat_map(self, func): return self if TYPE_CHECKING: def __rlshift__(self, other: Any) -> Any: ... @do def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]: if TYPE_CHECKING: num1 = num2 = None num1 <<= first # flat_map / bind num2 <<= second num3 = num1 + num2 @@ -292,7 +302,8 @@ def add_two_times_two(first, second): assert add_two_times_two(Nothing(), Nothing()) == Nothing() @do def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]: if TYPE_CHECKING: selection = None # If statement in expression selection <<= first if select else second return selection + 2 @@ -352,3 +363,5 @@ def run(cls): return val * 2 assert Test.run() == Just(6) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 10 additions and 1 deletion.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -58,7 +58,7 @@ def do(func): source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.increment_lineno( ast.parse(source), n=func.__code__.co_firstlineno - 1, ) modified_ast = ReWrite(func).visit(source_ast) @@ -343,3 +343,12 @@ def loop(): assert loop() == Just(45) class Test: @classmethod @do def run(cls): val <<= Just(3) return val * 2 assert Test.run() == Just(6) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 30 additions and 17 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -3,6 +3,7 @@ import types import functools import textwrap import linecache def do(func): @@ -59,7 +60,7 @@ def do(func): ast.parse(source), n=func.__code__.co_firstlineno, ) modified_ast = ReWrite(func).visit(source_ast) new_func = types.FunctionType( compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0], @@ -77,12 +78,19 @@ class ReWrite(ast.NodeTransformer): LAST_MONAD_NAME = "__LAST_MONAD" FUNC_NAME = "__GENERATED__{}" def __init__(self, func): self.func = func def visit_FunctionDef(self, node): """Strip leading decorators""" if not any(map(self._parse_do_expr, ast.walk(node))): err = SyntaxError( f"Decorated function {self.func} contains no <<= expression, cannot infer monad type." ) err.lineno = self.func.__code__.co_firstlineno err.filename = self.func.__code__.co_filename err.text = linecache.getline(err.filename, err.lineno) raise err node.decorator_list.clear() node.body = list(self._parse_body(enumerate(iter(node.body)))) @@ -93,18 +101,27 @@ def _parse_body(self, statements): # Checking for: # name <<= expression do_expr = self._parse_do_expr(stmt) if do_expr: for inner_stmt in self._write_flatmap( i, do_expr[0], do_expr[1], statements ): yield inner_stmt continue return_ = self._parse_return(stmt) if return_: yield self._write_pure(return_) continue nested_do_expr = tuple( filter(None, map(self._parse_do_expr, ast.walk(stmt))) ) if nested_do_expr: err = SyntaxError("Nested <<= operators not supported") err.lineno = nested_do_expr[0][0].lineno err.filename = self.func.__code__.co_filename err.text = linecache.getline(err.filename, err.lineno) raise err else: yield stmt @@ -170,11 +187,6 @@ def _parse_return(self, node): return node.value return None @staticmethod def _build_name(name, loc): return loc(ast.Name(id=name, ctx=ast.Load())) @@ -294,7 +306,7 @@ def select_plus_two(first, second, select): def nothing(): return 123 except SyntaxError: pass else: assert False @@ -316,12 +328,12 @@ def loop(): val += j return val except SyntaxError: pass else: assert False # Loops and other nestings are ok, so long as they're contained within a generated function @do def loop(): val <<= Just(0) @@ -330,3 +342,4 @@ def loop(): return val assert loop() == Just(45) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 70 additions and 68 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -5,6 +5,74 @@ import textwrap def do(func): """ Use to emulate do notation syntax on a function. Rewrites the underlying source code to desugar the notation. Example: >>> @do >>> def add_two(first, second): >>> num1 <<= first >>> num2 <<= second >>> return num1 + num2 >>> assert add_two(Just(1), Just(2)) == Just(3) The desugared version of the function would look something like the following... >>> def add_two(first, second): >>> def __GENERATED__0(num1): >>> def __GENERATED__1(num2): >>> return __LAST_MONAD.pure(num1 + num2) >>> __LAST_MONAD = second >>> return __LAST_MONAD.flat_map(__GENERATED__1) >>> __LAST_MONAD = first >>> return __LAST_MONAD.flat_map(__GENERATED__0) Due to this, bind notation (name <<= expression) is limited to expressions at the root level of the function. No nested bindings. This makes things simpler, and removes unexpected behavour. For example a monad does not have to execute code right away: >>> @do >>> def my_func(monad): >>> with context(...): >>> val <<= monad # This is a nested bind statement The generated output (if allowed): >>> def my_func(monad): >>> with context(...): >>> def __GENERATED__0(val): >>> return __LAST_MONAD.pure(None) >>> __LAST_MONAD = monad >>> return __LAST_MONAD.flat_map(__GENERATED__0) >>> # We can return here without actually having executed the function >>> # and the context will have expired. """ source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.increment_lineno( ast.parse(source), n=func.__code__.co_firstlineno, ) modified_ast = ReWrite().visit(source_ast) new_func = types.FunctionType( compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0], func.__globals__, ) @functools.wraps(func) def wrapper(*args, **kwargs): return new_func(*args, **kwargs) return wrapper class ReWrite(ast.NodeTransformer): LAST_MONAD_NAME = "__LAST_MONAD" FUNC_NAME = "__GENERATED__{}" @@ -175,74 +243,6 @@ def _build_pure(monad, expr, loc): ) if __name__ == "__main__": import attr @@ -315,6 +315,7 @@ def loop(): j <<= Just(i) val += j return val except TypeError: pass else: @@ -327,4 +328,5 @@ def loop(): for i in range(10): val += i return val assert loop() == Just(45) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 24 additions and 2 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -180,8 +180,6 @@ def do(func): Use to emulate do notation syntax on a function. Rewrites the underlying source code to desugar the notation. Example: >>> @do @@ -202,6 +200,29 @@ def do(func): >>> return __LAST_MONAD.flat_map(__GENERATED__1) >>> __LAST_MONAD = first >>> return __LAST_MONAD.flat_map(__GENERATED__0) Due to this, bind notation (name <<= expression) is limited to expressions at the root level of the function. No nested bindings. This makes things simpler, and removes unexpected behavour. For example a monad does not have to execute code right away: >>> @do >>> def my_func(monad): >>> with context(...): >>> val <<= monad # This is a nested bind statement The generated output (if allowed): >>> def my_func(monad): >>> with context(...): >>> def __GENERATED__0(val): >>> return __LAST_MONAD.pure(None) >>> __LAST_MONAD = monad >>> return __LAST_MONAD.flat_map(__GENERATED__0) >>> # We can return here without actually having executed the function >>> # and the context will have expired. """ source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.increment_lineno( @@ -299,6 +320,7 @@ def loop(): else: assert False # Loops are ok, so long as they're contained within a generated function @do def loop(): val <<= Just(0) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 58 additions and 6 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -11,6 +11,11 @@ class ReWrite(ast.NodeTransformer): def visit_FunctionDef(self, node): """Strip leading decorators""" if not any(map(self._parse_do_expr, ast.walk(node))): raise TypeError( "Decorated function contains no <<= expression, cannot infer monad type." ) node.decorator_list.clear() node.body = list(self._parse_body(enumerate(iter(node.body)))) return node @@ -27,12 +32,18 @@ def _parse_body(self, statements): ): yield inner_stmt elif return_: yield self._write_pure(return_) elif self._parse_assignment(stmt): yield stmt elif any(map(self._parse_do_expr, ast.walk(stmt))): raise TypeError("Nested <<= operators not supported") else: yield stmt # Always return None at the end yield self._write_pure( ast.fix_missing_locations(ast.Constant(value=None, kind=None)) ) def _write_flatmap(self, i, var, expr, body): loc = lambda n: ast.copy_location(n, var) @@ -79,7 +90,7 @@ def _write_pure(self, expr): ) ) ) return return_ def _parse_do_expr(self, node): if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)): @@ -230,14 +241,14 @@ def flat_map(self, func): class Nothing: @classmethod def pure(cls, value): return Just(value) def flat_map(self, func): return self @do def add_two_times_two(first, second): num1 <<= first # flat_map / bind num2 <<= second num3 = num1 + num2 return num3 * 2 @@ -249,8 +260,49 @@ def add_two_times_two(first, second): @do def select_plus_two(first, second, select): # If statement in expression selection <<= first if select else second return selection + 2 assert select_plus_two(Just(1), Just(11), True) == Just(3) assert select_plus_two(Just(1), Just(11), False) == Just(13) try: # Error if there is no bind <<= @do def nothing(): return 123 except TypeError: pass else: assert False # Implicitly return None @do def no_return(): _ <<= Just(1) assert no_return() == Just(None) try: # Error for nested bind operators @do def loop(): val <<= Just(0) for i in range(10): j <<= Just(i) val += j return val except TypeError: pass else: assert False @do def loop(): val <<= Just(0) for i in range(10): val += i return val assert loop() == Just(45) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 41 additions and 27 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -7,6 +7,7 @@ class ReWrite(ast.NodeTransformer): LAST_MONAD_NAME = "__LAST_MONAD" FUNC_NAME = "__GENERATED__{}" def visit_FunctionDef(self, node): """Strip leading decorators""" @@ -17,7 +18,7 @@ def visit_FunctionDef(self, node): def _parse_body(self, statements): for i, stmt in statements: # Checking for: # name <<= expression do_expr = self._parse_do_expr(stmt) return_ = self._parse_return(stmt) if do_expr: @@ -28,16 +29,18 @@ def _parse_body(self, statements): elif return_: for inner_stmt in self._write_pure(return_): yield inner_stmt elif self._parse_assignment(stmt): yield stmt else: raise TypeError("Unsupported statements in function") def _write_flatmap(self, i, var, expr, body): loc = lambda n: ast.copy_location(n, var) # Break code into functions # >>> def __GENERATED__0(name): # >>> ... func_name = self.FUNC_NAME.format(i) func = self._build_func( func_name, [var.id], @@ -79,23 +82,20 @@ def _write_pure(self, expr): yield return_ def _parse_do_expr(self, node): if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)): return None return node.target, node.value def _parse_return(self, node): if isinstance(node, ast.Return): return node.value return None def _parse_assignment(self, node): if isinstance(node, ast.Assign): return node return None @staticmethod def _build_name(name, loc): return loc(ast.Name(id=name, ctx=ast.Load())) @@ -167,12 +167,16 @@ def _build_pure(monad, expr, loc): def do(func): """ Use to emulate do notation syntax on a function. Rewrites the underlying source code to desugar the notation. This does result in a limited number of valid statements available. Example: >>> @do >>> def add_two(first, second): >>> num1 <<= first >>> num2 <<= second >>> return num1 + num2 >>> assert add_two(Just(1), Just(2)) == Just(3) @@ -189,7 +193,10 @@ def do(func): >>> return __LAST_MONAD.flat_map(__GENERATED__0) """ source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.increment_lineno( ast.parse(source), n=func.__code__.co_firstlineno, ) modified_ast = ReWrite().visit(source_ast) new_func = types.FunctionType( @@ -229,14 +236,21 @@ def flat_map(self, func): return self @do def add_two_times_two(first, second): num1 <<= first num2 <<= second num3 = num1 + num2 return num3 * 2 assert add_two_times_two(Just(1), Just(2)) == Just(6) assert add_two_times_two(Nothing(), Just(2)) == Nothing() assert add_two_times_two(Just(1), Nothing()) == Nothing() assert add_two_times_two(Nothing(), Nothing()) == Nothing() @do def select_plus_two(first, second, select): selection <<= first if select else second return selection + 2 assert select_plus_two(Just(1), Just(11), True) == Just(3) assert select_plus_two(Just(1), Just(11), False) == Just(13) -
internetimagery revised this gist
Aug 27, 2021 . 1 changed file with 152 additions and 48 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,64 +1,113 @@ import ast import inspect import types import functools import textwrap class ReWrite(ast.NodeTransformer): LAST_MONAD_NAME = "__LAST_MONAD" def visit_FunctionDef(self, node): """Strip leading decorators""" node.decorator_list.clear() node.body = list(self._parse_body(enumerate(iter(node.body)))) return node def _parse_body(self, statements): for i, stmt in statements: # Checking for: # name <- expression do_expr = self._parse_do_expr(stmt) return_ = self._parse_return(stmt) if do_expr: for inner_stmt in self._write_flatmap( i, do_expr[0], do_expr[1], statements ): yield inner_stmt elif return_: for inner_stmt in self._write_pure(return_): yield inner_stmt else: yield stmt def _write_flatmap(self, i, var, expr, body): loc = lambda n: ast.copy_location(n, var) # Break code into functions # >>> def __GENERATED__0(name): # >>> ... func_name = f"__GENERATED__{i}" func = self._build_func( func_name, [var.id], body, loc, ) callee = self._build_name(func_name, loc) yield func # Evaluate the monad expression and store it # >>> __LAST_MONAD = expr last_monad = self._build_assign(self.LAST_MONAD_NAME, expr, loc) yield last_monad # Return with a flatmapped expression # >>> return __LAST_MONAD.flat_map(__GENERATED__0) return_ = loc( ast.Return( value=self._build_flatmap( self._build_name(self.LAST_MONAD_NAME, loc), callee, loc, ), ) ) yield return_ def _write_pure(self, expr): loc = lambda n: ast.copy_location(n, expr) return_ = loc( ast.Return( value=self._build_pure( self._build_name(self.LAST_MONAD_NAME, loc), expr, loc, ) ) ) yield return_ def _parse_do_expr(self, node): if ( isinstance(node, ast.Expr) and isinstance(node.value, ast.Compare) and isinstance(node.value.left, ast.Name) and isinstance(node.value.left.ctx, ast.Load) and isinstance(node.value.ops[0], ast.Lt) and isinstance(node.value.comparators[0], ast.UnaryOp) and isinstance(node.value.comparators[0].op, ast.USub) ): return node.value.left, node.value.comparators[0].operand return None def _parse_return(self, node): if isinstance(node, ast.Return): return node.value return None @staticmethod def _build_name(name, loc): return loc(ast.Name(id=name, ctx=ast.Load())) @staticmethod def _build_assign(name, expr, loc): return loc( ast.Assign( targets=[loc(ast.Name(id=name, ctx=ast.Store()))], value=expr, ) ) @staticmethod def _build_args(args): @@ -72,31 +121,73 @@ def _build_args(args): defaults=[], ) def _build_func(self, name, args, body, loc): return loc( ast.FunctionDef( name=name, args=self._build_args([loc(ast.arg(arg=arg)) for arg in args]), body=list(self._parse_body(body)), decorator_list=[], ) ) @staticmethod def _build_flatmap(monad, func, loc): return loc( ast.Call( func=loc( ast.Attribute( value=monad, attr="flat_map", ctx=ast.Load(), ) ), args=[func], keywords=[], ) ) @staticmethod def _build_pure(monad, expr, loc): return loc( ast.Call( func=loc( ast.Attribute( value=monad, attr="pure", ctx=ast.Load(), ) ), args=[expr], keywords=[], ) ) def do(func): """ Use to emulate do notation syntax on a function. Rewrites the underlying source code to desugar the notation. Example: >>> @do >>> def add_two(first, second): >>> num1 <- first >>> num2 <- second >>> return num1 + num2 >>> assert add_two(Just(1), Just(2)) == Just(3) The desugared version of the function would look something like the following... >>> def add_two(first, second): >>> def __GENERATED__0(num1): >>> def __GENERATED__1(num2): >>> return __LAST_MONAD.pure(num1 + num2) >>> __LAST_MONAD = second >>> return __LAST_MONAD.flat_map(__GENERATED__1) >>> __LAST_MONAD = first >>> return __LAST_MONAD.flat_map(__GENERATED__0) """ source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.parse(source) modified_ast = ReWrite().visit(source_ast) @@ -121,18 +212,31 @@ def wrapper(*args, **kwargs): class Just: value = attr.ib() @classmethod def pure(cls, value): return cls(value) def flat_map(self, func): return func(self.value) @attr.s class Nothing: @classmethod def pure(cls, value): return cls() def flat_map(self, func): return self @do def add_two(first, second): test = 123 num1 <- first num2 <- second return num1 + num2 assert add_two(Just(1), Just(2)) == Just(3) assert add_two(Nothing(), Just(2)) == Nothing() assert add_two(Just(1), Nothing()) == Nothing() assert add_two(Nothing(), Nothing()) == Nothing() -
internetimagery created this gist
Aug 27, 2021 .There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,138 @@ import ast import inspect import copy import types import functools import textwrap class ReWrite(ast.NodeTransformer): def visit_FunctionDef(self, node): """Strip leading decorators""" node.decorator_list.clear() node.body = list(self._parseBody(enumerate(iter(node.body)))) return node def _parseBody(self, statements): for i, stmt in statements: # Checking for: # name <- expression result = self._parseDoExpr(stmt) if result: for inner_stmt in self._write_flatmap( i, result[0], result[1], statements ): yield inner_stmt continue yield stmt def _write_flatmap(self, i, var, expr, body): loc = lambda n: ast.copy_location(n, var) func_name = f"__GENERATED__{i}" func = loc( self._build_func( func_name, [loc(ast.arg(arg=var.id))], body, ), ) callee = loc(ast.Name(id=func_name, ctx=ast.Load())) yield func return_ = loc( ast.Return( value=loc(self._build_flatmap(expr, callee, loc)), ) ) yield return_ def _parseDoExpr(self, node): if ( not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Compare) or not isinstance(node.value.left, ast.Name) or not isinstance(node.value.left.ctx, ast.Load) or not isinstance(node.value.ops[0], ast.Lt) or not isinstance(node.value.comparators[0], ast.UnaryOp) or not isinstance(node.value.comparators[0].op, ast.USub) ): return None return node.value.left, node.value.comparators[0].operand @staticmethod def _build_args(args): return ast.arguments( args=args, posonlyargs=[], kwonlyargs=[], kw_defaults=[], vararg=None, kwarg=None, defaults=[], ) def _build_func(self, name, args, body): return ast.FunctionDef( name=name, args=self._build_args(args), body=list(self._parseBody(body)), decorator_list=[], ) @staticmethod def _build_flatmap(monad, func, loc): return ast.Call( func=loc( ast.Attribute( value=monad, attr="flat_map", ctx=ast.Load(), ) ), args=[func], keywords=[], ) def do(func): source = textwrap.dedent(inspect.getsource(func)) source_ast = ast.parse(source) modified_ast = ReWrite().visit(source_ast) new_func = types.FunctionType( compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0], func.__globals__, ) @functools.wraps(func) def wrapper(*args, **kwargs): return new_func(*args, **kwargs) return wrapper if __name__ == "__main__": import attr @attr.s class Just: value = attr.ib() def map(self, func): return self.__class__(func(self.value)) def flat_map(self, func): return func(self.value) @do def add_two(first, second): num1 <- first num2 <- second #TODO: Still need to handle the final "pure" on the return value. return num1 + num2 result = add_two(Just(1), Just(2)) print("Result:", result)