Skip to content

Instantly share code, notes, and snippets.

@internetimagery
Last active August 27, 2021 21:48

Revisions

  1. internetimagery revised this gist Aug 27, 2021. 1 changed file with 29 additions and 19 deletions.
    48 changes: 29 additions & 19 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -1,9 +1,9 @@
    import ast
    import inspect
    import types
    import functools
    import textwrap
    import linecache
    from inspect import getsource
    from types import FunctionType, CodeType
    from functools import wraps
    from textwrap import dedent
    from linecache import getline


    def do(func):
    @@ -55,19 +55,24 @@ def do(func):
    >>> # and the context will have expired.
    """
    source = textwrap.dedent(inspect.getsource(func))
    source = dedent(getsource(func))
    source_ast = ast.increment_lineno(
    ast.parse(source),
    n=func.__code__.co_firstlineno - 1,
    )
    modified_ast = ReWrite(func).visit(source_ast)

    new_func = types.FunctionType(
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0],
    new_code = next(
    filter(
    lambda x: isinstance(x, CodeType),
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts,
    )
    )
    new_func = FunctionType(
    new_code,
    func.__globals__,
    )

    @functools.wraps(func)
    @wraps(func)
    def wrapper(*args, **kwargs):
    return new_func(*args, **kwargs)

    @@ -89,7 +94,7 @@ def visit_FunctionDef(self, node):
    )
    err.lineno = self.func.__code__.co_firstlineno
    err.filename = self.func.__code__.co_filename
    err.text = linecache.getline(err.filename, err.lineno)
    err.text = getline(err.filename, err.lineno)
    raise err

    node.decorator_list.clear()
    @@ -120,7 +125,7 @@ def _parse_body(self, statements):
    err = SyntaxError("Nested <<= operators not supported")
    err.lineno = nested_do_expr[0][0].lineno
    err.filename = self.func.__code__.co_filename
    err.text = linecache.getline(err.filename, err.lineno)
    err.text = getline(err.filename, err.lineno)
    raise err
    else:
    yield stmt
    @@ -274,7 +279,9 @@ def flat_map(self, func):
    return func(self.value)

    if TYPE_CHECKING:
    def __rlshift__(self, other: Any) -> A: ...

    def __rlshift__(self, other: Any) -> A:
    ...

    @attr.s
    class Nothing:
    @@ -286,11 +293,15 @@ def flat_map(self, func):
    return self

    if TYPE_CHECKING:
    def __rlshift__(self, other: Any) -> Any: ...

    def __rlshift__(self, other: Any) -> Any:
    ...

    @do
    def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]:
    if TYPE_CHECKING: num1 = num2 = None
    if TYPE_CHECKING:
    num1 = num2 = None

    num1 <<= first # flat_map / bind
    num2 <<= second
    num3 = num1 + num2
    @@ -303,7 +314,9 @@ def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]:

    @do
    def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]:
    if TYPE_CHECKING: selection = None
    if TYPE_CHECKING:
    selection = None

    # If statement in expression
    selection <<= first if select else second
    return selection + 2
    @@ -355,13 +368,10 @@ def loop():
    assert loop() == Just(45)

    class Test:

    @classmethod
    @do
    def run(cls):
    val <<= Just(3)
    return val * 2

    assert Test.run() == Just(6)


  2. internetimagery revised this gist Aug 27, 2021. 1 changed file with 16 additions and 3 deletions.
    19 changes: 16 additions & 3 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -258,9 +258,12 @@ def _build_pure(monad, expr, loc):
    if __name__ == "__main__":

    import attr
    from typing import Any, TypeVar, Generic, TYPE_CHECKING

    A = TypeVar("A")

    @attr.s
    class Just:
    class Just(Generic[A]):
    value = attr.ib()

    @classmethod
    @@ -270,6 +273,9 @@ def pure(cls, value):
    def flat_map(self, func):
    return func(self.value)

    if TYPE_CHECKING:
    def __rlshift__(self, other: Any) -> A: ...

    @attr.s
    class Nothing:
    @classmethod
    @@ -279,8 +285,12 @@ def pure(cls, value):
    def flat_map(self, func):
    return self

    if TYPE_CHECKING:
    def __rlshift__(self, other: Any) -> Any: ...

    @do
    def add_two_times_two(first, second):
    def add_two_times_two(first: Just[int], second: Just[int]) -> Just[int]:
    if TYPE_CHECKING: num1 = num2 = None
    num1 <<= first # flat_map / bind
    num2 <<= second
    num3 = num1 + num2
    @@ -292,7 +302,8 @@ def add_two_times_two(first, second):
    assert add_two_times_two(Nothing(), Nothing()) == Nothing()

    @do
    def select_plus_two(first, second, select):
    def select_plus_two(first: Just[int], second: Just[int], select: bool) -> Just[int]:
    if TYPE_CHECKING: selection = None
    # If statement in expression
    selection <<= first if select else second
    return selection + 2
    @@ -352,3 +363,5 @@ def run(cls):
    return val * 2

    assert Test.run() == Just(6)


  3. internetimagery revised this gist Aug 27, 2021. 1 changed file with 10 additions and 1 deletion.
    11 changes: 10 additions & 1 deletion do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -58,7 +58,7 @@ def do(func):
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.increment_lineno(
    ast.parse(source),
    n=func.__code__.co_firstlineno,
    n=func.__code__.co_firstlineno - 1,
    )
    modified_ast = ReWrite(func).visit(source_ast)

    @@ -343,3 +343,12 @@ def loop():

    assert loop() == Just(45)

    class Test:

    @classmethod
    @do
    def run(cls):
    val <<= Just(3)
    return val * 2

    assert Test.run() == Just(6)
  4. internetimagery revised this gist Aug 27, 2021. 1 changed file with 30 additions and 17 deletions.
    47 changes: 30 additions & 17 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -3,6 +3,7 @@
    import types
    import functools
    import textwrap
    import linecache


    def do(func):
    @@ -59,7 +60,7 @@ def do(func):
    ast.parse(source),
    n=func.__code__.co_firstlineno,
    )
    modified_ast = ReWrite().visit(source_ast)
    modified_ast = ReWrite(func).visit(source_ast)

    new_func = types.FunctionType(
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0],
    @@ -77,12 +78,19 @@ class ReWrite(ast.NodeTransformer):
    LAST_MONAD_NAME = "__LAST_MONAD"
    FUNC_NAME = "__GENERATED__{}"

    def __init__(self, func):
    self.func = func

    def visit_FunctionDef(self, node):
    """Strip leading decorators"""
    if not any(map(self._parse_do_expr, ast.walk(node))):
    raise TypeError(
    "Decorated function contains no <<= expression, cannot infer monad type."
    err = SyntaxError(
    f"Decorated function {self.func} contains no <<= expression, cannot infer monad type."
    )
    err.lineno = self.func.__code__.co_firstlineno
    err.filename = self.func.__code__.co_filename
    err.text = linecache.getline(err.filename, err.lineno)
    raise err

    node.decorator_list.clear()
    node.body = list(self._parse_body(enumerate(iter(node.body))))
    @@ -93,18 +101,27 @@ def _parse_body(self, statements):
    # Checking for:
    # name <<= expression
    do_expr = self._parse_do_expr(stmt)
    return_ = self._parse_return(stmt)
    if do_expr:
    for inner_stmt in self._write_flatmap(
    i, do_expr[0], do_expr[1], statements
    ):
    yield inner_stmt
    elif return_:
    continue

    return_ = self._parse_return(stmt)
    if return_:
    yield self._write_pure(return_)
    elif self._parse_assignment(stmt):
    yield stmt
    elif any(map(self._parse_do_expr, ast.walk(stmt))):
    raise TypeError("Nested <<= operators not supported")
    continue

    nested_do_expr = tuple(
    filter(None, map(self._parse_do_expr, ast.walk(stmt)))
    )
    if nested_do_expr:
    err = SyntaxError("Nested <<= operators not supported")
    err.lineno = nested_do_expr[0][0].lineno
    err.filename = self.func.__code__.co_filename
    err.text = linecache.getline(err.filename, err.lineno)
    raise err
    else:
    yield stmt

    @@ -170,11 +187,6 @@ def _parse_return(self, node):
    return node.value
    return None

    def _parse_assignment(self, node):
    if isinstance(node, ast.Assign):
    return node
    return None

    @staticmethod
    def _build_name(name, loc):
    return loc(ast.Name(id=name, ctx=ast.Load()))
    @@ -294,7 +306,7 @@ def select_plus_two(first, second, select):
    def nothing():
    return 123

    except TypeError:
    except SyntaxError:
    pass
    else:
    assert False
    @@ -316,12 +328,12 @@ def loop():
    val += j
    return val

    except TypeError:
    except SyntaxError:
    pass
    else:
    assert False

    # Loops are ok, so long as they're contained within a generated function
    # Loops and other nestings are ok, so long as they're contained within a generated function
    @do
    def loop():
    val <<= Just(0)
    @@ -330,3 +342,4 @@ def loop():
    return val

    assert loop() == Just(45)

  5. internetimagery revised this gist Aug 27, 2021. 1 changed file with 70 additions and 68 deletions.
    138 changes: 70 additions & 68 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -5,6 +5,74 @@
    import textwrap


    def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation.
    Example:
    >>> @do
    >>> def add_two(first, second):
    >>> num1 <<= first
    >>> num2 <<= second
    >>> return num1 + num2
    >>> assert add_two(Just(1), Just(2)) == Just(3)
    The desugared version of the function would look something like the following...
    >>> def add_two(first, second):
    >>> def __GENERATED__0(num1):
    >>> def __GENERATED__1(num2):
    >>> return __LAST_MONAD.pure(num1 + num2)
    >>> __LAST_MONAD = second
    >>> return __LAST_MONAD.flat_map(__GENERATED__1)
    >>> __LAST_MONAD = first
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    Due to this, bind notation (name <<= expression) is limited to expressions
    at the root level of the function. No nested bindings.
    This makes things simpler, and removes unexpected behavour.
    For example a monad does not have to execute code right away:
    >>> @do
    >>> def my_func(monad):
    >>> with context(...):
    >>> val <<= monad # This is a nested bind statement
    The generated output (if allowed):
    >>> def my_func(monad):
    >>> with context(...):
    >>> def __GENERATED__0(val):
    >>> return __LAST_MONAD.pure(None)
    >>> __LAST_MONAD = monad
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    >>> # We can return here without actually having executed the function
    >>> # and the context will have expired.
    """
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.increment_lineno(
    ast.parse(source),
    n=func.__code__.co_firstlineno,
    )
    modified_ast = ReWrite().visit(source_ast)

    new_func = types.FunctionType(
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0],
    func.__globals__,
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
    return new_func(*args, **kwargs)

    return wrapper


    class ReWrite(ast.NodeTransformer):
    LAST_MONAD_NAME = "__LAST_MONAD"
    FUNC_NAME = "__GENERATED__{}"
    @@ -175,74 +243,6 @@ def _build_pure(monad, expr, loc):
    )


    def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation.
    Example:
    >>> @do
    >>> def add_two(first, second):
    >>> num1 <<= first
    >>> num2 <<= second
    >>> return num1 + num2
    >>> assert add_two(Just(1), Just(2)) == Just(3)
    The desugared version of the function would look something like the following...
    >>> def add_two(first, second):
    >>> def __GENERATED__0(num1):
    >>> def __GENERATED__1(num2):
    >>> return __LAST_MONAD.pure(num1 + num2)
    >>> __LAST_MONAD = second
    >>> return __LAST_MONAD.flat_map(__GENERATED__1)
    >>> __LAST_MONAD = first
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    Due to this, bind notation (name <<= expression) is limited to expressions
    at the root level of the function. No nested bindings.
    This makes things simpler, and removes unexpected behavour.
    For example a monad does not have to execute code right away:
    >>> @do
    >>> def my_func(monad):
    >>> with context(...):
    >>> val <<= monad # This is a nested bind statement
    The generated output (if allowed):
    >>> def my_func(monad):
    >>> with context(...):
    >>> def __GENERATED__0(val):
    >>> return __LAST_MONAD.pure(None)
    >>> __LAST_MONAD = monad
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    >>> # We can return here without actually having executed the function
    >>> # and the context will have expired.
    """
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.increment_lineno(
    ast.parse(source),
    n=func.__code__.co_firstlineno,
    )
    modified_ast = ReWrite().visit(source_ast)

    new_func = types.FunctionType(
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0],
    func.__globals__,
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
    return new_func(*args, **kwargs)

    return wrapper


    if __name__ == "__main__":

    import attr
    @@ -315,6 +315,7 @@ def loop():
    j <<= Just(i)
    val += j
    return val

    except TypeError:
    pass
    else:
    @@ -327,4 +328,5 @@ def loop():
    for i in range(10):
    val += i
    return val

    assert loop() == Just(45)
  6. internetimagery revised this gist Aug 27, 2021. 1 changed file with 24 additions and 2 deletions.
    26 changes: 24 additions & 2 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -180,8 +180,6 @@ def do(func):
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation.
    This does result in a limited number of valid statements available.
    Example:
    >>> @do
    @@ -202,6 +200,29 @@ def do(func):
    >>> return __LAST_MONAD.flat_map(__GENERATED__1)
    >>> __LAST_MONAD = first
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    Due to this, bind notation (name <<= expression) is limited to expressions
    at the root level of the function. No nested bindings.
    This makes things simpler, and removes unexpected behavour.
    For example a monad does not have to execute code right away:
    >>> @do
    >>> def my_func(monad):
    >>> with context(...):
    >>> val <<= monad # This is a nested bind statement
    The generated output (if allowed):
    >>> def my_func(monad):
    >>> with context(...):
    >>> def __GENERATED__0(val):
    >>> return __LAST_MONAD.pure(None)
    >>> __LAST_MONAD = monad
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    >>> # We can return here without actually having executed the function
    >>> # and the context will have expired.
    """
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.increment_lineno(
    @@ -299,6 +320,7 @@ def loop():
    else:
    assert False

    # Loops are ok, so long as they're contained within a generated function
    @do
    def loop():
    val <<= Just(0)
  7. internetimagery revised this gist Aug 27, 2021. 1 changed file with 58 additions and 6 deletions.
    64 changes: 58 additions & 6 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -11,6 +11,11 @@ class ReWrite(ast.NodeTransformer):

    def visit_FunctionDef(self, node):
    """Strip leading decorators"""
    if not any(map(self._parse_do_expr, ast.walk(node))):
    raise TypeError(
    "Decorated function contains no <<= expression, cannot infer monad type."
    )

    node.decorator_list.clear()
    node.body = list(self._parse_body(enumerate(iter(node.body))))
    return node
    @@ -27,12 +32,18 @@ def _parse_body(self, statements):
    ):
    yield inner_stmt
    elif return_:
    for inner_stmt in self._write_pure(return_):
    yield inner_stmt
    yield self._write_pure(return_)
    elif self._parse_assignment(stmt):
    yield stmt
    elif any(map(self._parse_do_expr, ast.walk(stmt))):
    raise TypeError("Nested <<= operators not supported")
    else:
    raise TypeError("Unsupported statements in function")
    yield stmt

    # Always return None at the end
    yield self._write_pure(
    ast.fix_missing_locations(ast.Constant(value=None, kind=None))
    )

    def _write_flatmap(self, i, var, expr, body):
    loc = lambda n: ast.copy_location(n, var)
    @@ -79,7 +90,7 @@ def _write_pure(self, expr):
    )
    )
    )
    yield return_
    return return_

    def _parse_do_expr(self, node):
    if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)):
    @@ -230,14 +241,14 @@ def flat_map(self, func):
    class Nothing:
    @classmethod
    def pure(cls, value):
    return cls()
    return Just(value)

    def flat_map(self, func):
    return self

    @do
    def add_two_times_two(first, second):
    num1 <<= first
    num1 <<= first # flat_map / bind
    num2 <<= second
    num3 = num1 + num2
    return num3 * 2
    @@ -249,8 +260,49 @@ def add_two_times_two(first, second):

    @do
    def select_plus_two(first, second, select):
    # If statement in expression
    selection <<= first if select else second
    return selection + 2

    assert select_plus_two(Just(1), Just(11), True) == Just(3)
    assert select_plus_two(Just(1), Just(11), False) == Just(13)

    try:
    # Error if there is no bind <<=
    @do
    def nothing():
    return 123

    except TypeError:
    pass
    else:
    assert False

    # Implicitly return None
    @do
    def no_return():
    _ <<= Just(1)

    assert no_return() == Just(None)

    try:
    # Error for nested bind operators
    @do
    def loop():
    val <<= Just(0)
    for i in range(10):
    j <<= Just(i)
    val += j
    return val
    except TypeError:
    pass
    else:
    assert False

    @do
    def loop():
    val <<= Just(0)
    for i in range(10):
    val += i
    return val
    assert loop() == Just(45)
  8. internetimagery revised this gist Aug 27, 2021. 1 changed file with 41 additions and 27 deletions.
    68 changes: 41 additions & 27 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -7,6 +7,7 @@

    class ReWrite(ast.NodeTransformer):
    LAST_MONAD_NAME = "__LAST_MONAD"
    FUNC_NAME = "__GENERATED__{}"

    def visit_FunctionDef(self, node):
    """Strip leading decorators"""
    @@ -17,7 +18,7 @@ def visit_FunctionDef(self, node):
    def _parse_body(self, statements):
    for i, stmt in statements:
    # Checking for:
    # name <- expression
    # name <<= expression
    do_expr = self._parse_do_expr(stmt)
    return_ = self._parse_return(stmt)
    if do_expr:
    @@ -28,16 +29,18 @@ def _parse_body(self, statements):
    elif return_:
    for inner_stmt in self._write_pure(return_):
    yield inner_stmt
    else:
    elif self._parse_assignment(stmt):
    yield stmt
    else:
    raise TypeError("Unsupported statements in function")

    def _write_flatmap(self, i, var, expr, body):
    loc = lambda n: ast.copy_location(n, var)

    # Break code into functions
    # >>> def __GENERATED__0(name):
    # >>> ...
    func_name = f"__GENERATED__{i}"
    func_name = self.FUNC_NAME.format(i)
    func = self._build_func(
    func_name,
    [var.id],
    @@ -79,23 +82,20 @@ def _write_pure(self, expr):
    yield return_

    def _parse_do_expr(self, node):
    if (
    isinstance(node, ast.Expr)
    and isinstance(node.value, ast.Compare)
    and isinstance(node.value.left, ast.Name)
    and isinstance(node.value.left.ctx, ast.Load)
    and isinstance(node.value.ops[0], ast.Lt)
    and isinstance(node.value.comparators[0], ast.UnaryOp)
    and isinstance(node.value.comparators[0].op, ast.USub)
    ):
    return node.value.left, node.value.comparators[0].operand
    return None
    if not (isinstance(node, ast.AugAssign) and isinstance(node.op, ast.LShift)):
    return None
    return node.target, node.value

    def _parse_return(self, node):
    if isinstance(node, ast.Return):
    return node.value
    return None

    def _parse_assignment(self, node):
    if isinstance(node, ast.Assign):
    return node
    return None

    @staticmethod
    def _build_name(name, loc):
    return loc(ast.Name(id=name, ctx=ast.Load()))
    @@ -167,12 +167,16 @@ def _build_pure(monad, expr, loc):
    def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation. Example:
    source code to desugar the notation.
    This does result in a limited number of valid statements available.
    Example:
    >>> @do
    >>> def add_two(first, second):
    >>> num1 <- first
    >>> num2 <- second
    >>> num1 <<= first
    >>> num2 <<= second
    >>> return num1 + num2
    >>> assert add_two(Just(1), Just(2)) == Just(3)
    @@ -189,7 +193,10 @@ def do(func):
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    """
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.parse(source)
    source_ast = ast.increment_lineno(
    ast.parse(source),
    n=func.__code__.co_firstlineno,
    )
    modified_ast = ReWrite().visit(source_ast)

    new_func = types.FunctionType(
    @@ -229,14 +236,21 @@ def flat_map(self, func):
    return self

    @do
    def add_two(first, second):
    test = 123
    num1 <- first
    num2 <- second
    return num1 + num2
    def add_two_times_two(first, second):
    num1 <<= first
    num2 <<= second
    num3 = num1 + num2
    return num3 * 2

    assert add_two(Just(1), Just(2)) == Just(3)
    assert add_two_times_two(Just(1), Just(2)) == Just(6)
    assert add_two_times_two(Nothing(), Just(2)) == Nothing()
    assert add_two_times_two(Just(1), Nothing()) == Nothing()
    assert add_two_times_two(Nothing(), Nothing()) == Nothing()

    @do
    def select_plus_two(first, second, select):
    selection <<= first if select else second
    return selection + 2

    assert add_two(Nothing(), Just(2)) == Nothing()
    assert add_two(Just(1), Nothing()) == Nothing()
    assert add_two(Nothing(), Nothing()) == Nothing()
    assert select_plus_two(Just(1), Just(11), True) == Just(3)
    assert select_plus_two(Just(1), Just(11), False) == Just(13)
  9. internetimagery revised this gist Aug 27, 2021. 1 changed file with 152 additions and 48 deletions.
    200 changes: 152 additions & 48 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -1,64 +1,113 @@
    import ast
    import inspect
    import copy
    import types
    import functools
    import textwrap


    class ReWrite(ast.NodeTransformer):
    LAST_MONAD_NAME = "__LAST_MONAD"

    def visit_FunctionDef(self, node):
    """Strip leading decorators"""
    node.decorator_list.clear()
    node.body = list(self._parseBody(enumerate(iter(node.body))))
    node.body = list(self._parse_body(enumerate(iter(node.body))))
    return node

    def _parseBody(self, statements):
    def _parse_body(self, statements):
    for i, stmt in statements:
    # Checking for:
    # name <- expression
    result = self._parseDoExpr(stmt)
    if result:
    do_expr = self._parse_do_expr(stmt)
    return_ = self._parse_return(stmt)
    if do_expr:
    for inner_stmt in self._write_flatmap(
    i, result[0], result[1], statements
    i, do_expr[0], do_expr[1], statements
    ):
    yield inner_stmt
    continue
    yield stmt
    elif return_:
    for inner_stmt in self._write_pure(return_):
    yield inner_stmt
    else:
    yield stmt

    def _write_flatmap(self, i, var, expr, body):
    loc = lambda n: ast.copy_location(n, var)

    # Break code into functions
    # >>> def __GENERATED__0(name):
    # >>> ...
    func_name = f"__GENERATED__{i}"
    func = loc(
    self._build_func(
    func_name,
    [loc(ast.arg(arg=var.id))],
    body,
    ),
    func = self._build_func(
    func_name,
    [var.id],
    body,
    loc,
    )
    callee = loc(ast.Name(id=func_name, ctx=ast.Load()))
    callee = self._build_name(func_name, loc)
    yield func

    # Evaluate the monad expression and store it
    # >>> __LAST_MONAD = expr
    last_monad = self._build_assign(self.LAST_MONAD_NAME, expr, loc)
    yield last_monad

    # Return with a flatmapped expression
    # >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    return_ = loc(
    ast.Return(
    value=loc(self._build_flatmap(expr, callee, loc)),
    value=self._build_flatmap(
    self._build_name(self.LAST_MONAD_NAME, loc),
    callee,
    loc,
    ),
    )
    )
    yield return_

    def _parseDoExpr(self, node):
    def _write_pure(self, expr):
    loc = lambda n: ast.copy_location(n, expr)
    return_ = loc(
    ast.Return(
    value=self._build_pure(
    self._build_name(self.LAST_MONAD_NAME, loc),
    expr,
    loc,
    )
    )
    )
    yield return_

    def _parse_do_expr(self, node):
    if (
    not isinstance(node, ast.Expr)
    or not isinstance(node.value, ast.Compare)
    or not isinstance(node.value.left, ast.Name)
    or not isinstance(node.value.left.ctx, ast.Load)
    or not isinstance(node.value.ops[0], ast.Lt)
    or not isinstance(node.value.comparators[0], ast.UnaryOp)
    or not isinstance(node.value.comparators[0].op, ast.USub)
    isinstance(node, ast.Expr)
    and isinstance(node.value, ast.Compare)
    and isinstance(node.value.left, ast.Name)
    and isinstance(node.value.left.ctx, ast.Load)
    and isinstance(node.value.ops[0], ast.Lt)
    and isinstance(node.value.comparators[0], ast.UnaryOp)
    and isinstance(node.value.comparators[0].op, ast.USub)
    ):
    return None
    return node.value.left, node.value.comparators[0].operand
    return node.value.left, node.value.comparators[0].operand
    return None

    def _parse_return(self, node):
    if isinstance(node, ast.Return):
    return node.value
    return None

    @staticmethod
    def _build_name(name, loc):
    return loc(ast.Name(id=name, ctx=ast.Load()))

    @staticmethod
    def _build_assign(name, expr, loc):
    return loc(
    ast.Assign(
    targets=[loc(ast.Name(id=name, ctx=ast.Store()))],
    value=expr,
    )
    )

    @staticmethod
    def _build_args(args):
    @@ -72,31 +121,73 @@ def _build_args(args):
    defaults=[],
    )

    def _build_func(self, name, args, body):
    return ast.FunctionDef(
    name=name,
    args=self._build_args(args),
    body=list(self._parseBody(body)),
    decorator_list=[],
    def _build_func(self, name, args, body, loc):
    return loc(
    ast.FunctionDef(
    name=name,
    args=self._build_args([loc(ast.arg(arg=arg)) for arg in args]),
    body=list(self._parse_body(body)),
    decorator_list=[],
    )
    )

    @staticmethod
    def _build_flatmap(monad, func, loc):
    return ast.Call(
    func=loc(
    ast.Attribute(
    value=monad,
    attr="flat_map",
    ctx=ast.Load(),
    )
    ),
    args=[func],
    keywords=[],
    return loc(
    ast.Call(
    func=loc(
    ast.Attribute(
    value=monad,
    attr="flat_map",
    ctx=ast.Load(),
    )
    ),
    args=[func],
    keywords=[],
    )
    )

    @staticmethod
    def _build_pure(monad, expr, loc):
    return loc(
    ast.Call(
    func=loc(
    ast.Attribute(
    value=monad,
    attr="pure",
    ctx=ast.Load(),
    )
    ),
    args=[expr],
    keywords=[],
    )
    )

    def do(func):

    def do(func):
    """
    Use to emulate do notation syntax on a function. Rewrites the underlying
    source code to desugar the notation. Example:
    >>> @do
    >>> def add_two(first, second):
    >>> num1 <- first
    >>> num2 <- second
    >>> return num1 + num2
    >>> assert add_two(Just(1), Just(2)) == Just(3)
    The desugared version of the function would look something like the following...
    >>> def add_two(first, second):
    >>> def __GENERATED__0(num1):
    >>> def __GENERATED__1(num2):
    >>> return __LAST_MONAD.pure(num1 + num2)
    >>> __LAST_MONAD = second
    >>> return __LAST_MONAD.flat_map(__GENERATED__1)
    >>> __LAST_MONAD = first
    >>> return __LAST_MONAD.flat_map(__GENERATED__0)
    """
    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.parse(source)
    modified_ast = ReWrite().visit(source_ast)
    @@ -121,18 +212,31 @@ def wrapper(*args, **kwargs):
    class Just:
    value = attr.ib()

    def map(self, func):
    return self.__class__(func(self.value))
    @classmethod
    def pure(cls, value):
    return cls(value)

    def flat_map(self, func):
    return func(self.value)

    @attr.s
    class Nothing:
    @classmethod
    def pure(cls, value):
    return cls()

    def flat_map(self, func):
    return self

    @do
    def add_two(first, second):
    test = 123
    num1 <- first
    num2 <- second
    #TODO: Still need to handle the final "pure" on the return value.
    return num1 + num2

    result = add_two(Just(1), Just(2))
    print("Result:", result)
    assert add_two(Just(1), Just(2)) == Just(3)

    assert add_two(Nothing(), Just(2)) == Nothing()
    assert add_two(Just(1), Nothing()) == Nothing()
    assert add_two(Nothing(), Nothing()) == Nothing()
  10. internetimagery created this gist Aug 27, 2021.
    138 changes: 138 additions & 0 deletions do_notation_ast.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,138 @@
    import ast
    import inspect
    import copy
    import types
    import functools
    import textwrap


    class ReWrite(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
    """Strip leading decorators"""
    node.decorator_list.clear()
    node.body = list(self._parseBody(enumerate(iter(node.body))))
    return node

    def _parseBody(self, statements):
    for i, stmt in statements:
    # Checking for:
    # name <- expression
    result = self._parseDoExpr(stmt)
    if result:
    for inner_stmt in self._write_flatmap(
    i, result[0], result[1], statements
    ):
    yield inner_stmt
    continue
    yield stmt

    def _write_flatmap(self, i, var, expr, body):
    loc = lambda n: ast.copy_location(n, var)

    func_name = f"__GENERATED__{i}"
    func = loc(
    self._build_func(
    func_name,
    [loc(ast.arg(arg=var.id))],
    body,
    ),
    )
    callee = loc(ast.Name(id=func_name, ctx=ast.Load()))
    yield func

    return_ = loc(
    ast.Return(
    value=loc(self._build_flatmap(expr, callee, loc)),
    )
    )
    yield return_

    def _parseDoExpr(self, node):
    if (
    not isinstance(node, ast.Expr)
    or not isinstance(node.value, ast.Compare)
    or not isinstance(node.value.left, ast.Name)
    or not isinstance(node.value.left.ctx, ast.Load)
    or not isinstance(node.value.ops[0], ast.Lt)
    or not isinstance(node.value.comparators[0], ast.UnaryOp)
    or not isinstance(node.value.comparators[0].op, ast.USub)
    ):
    return None
    return node.value.left, node.value.comparators[0].operand

    @staticmethod
    def _build_args(args):
    return ast.arguments(
    args=args,
    posonlyargs=[],
    kwonlyargs=[],
    kw_defaults=[],
    vararg=None,
    kwarg=None,
    defaults=[],
    )

    def _build_func(self, name, args, body):
    return ast.FunctionDef(
    name=name,
    args=self._build_args(args),
    body=list(self._parseBody(body)),
    decorator_list=[],
    )

    @staticmethod
    def _build_flatmap(monad, func, loc):
    return ast.Call(
    func=loc(
    ast.Attribute(
    value=monad,
    attr="flat_map",
    ctx=ast.Load(),
    )
    ),
    args=[func],
    keywords=[],
    )


    def do(func):

    source = textwrap.dedent(inspect.getsource(func))
    source_ast = ast.parse(source)
    modified_ast = ReWrite().visit(source_ast)

    new_func = types.FunctionType(
    compile(modified_ast, func.__code__.co_filename, "exec").co_consts[0],
    func.__globals__,
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
    return new_func(*args, **kwargs)

    return wrapper


    if __name__ == "__main__":

    import attr

    @attr.s
    class Just:
    value = attr.ib()

    def map(self, func):
    return self.__class__(func(self.value))

    def flat_map(self, func):
    return func(self.value)

    @do
    def add_two(first, second):
    num1 <- first
    num2 <- second
    #TODO: Still need to handle the final "pure" on the return value.
    return num1 + num2

    result = add_two(Just(1), Just(2))
    print("Result:", result)