Created March 4, 2020 19:35
-
-
Save Adam-Vandervorst/6e6c26e710bd8eea28612d2502fa72e7 to your computer and use it in GitHub Desktop.
Expression recursion schemes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def ana(coalgebra):
    """Anamorphism: unfold a seed into a recursive structure.

    ``coalgebra(seed)`` must return one layer of the structure whose recursive
    positions still hold seeds; each of those is unfolded in turn through the
    layer's ``.map``.  The returned function also accepts ``alg=`` to use a
    different coalgebra for the top layer only -- children are always unfolded
    with ``coalgebra`` because ``.map`` calls the worker with the seed alone.
    """
    def ana_(seed, alg=coalgebra):
        layer = alg(seed)
        return layer.map(ana_)
    return ana_
def apo(coalgebra):
    """Apomorphism: an unfold whose coalgebra may stop early.

    ``coalgebra(seed)`` returns a pair ``(stop, layer)``.  When ``stop`` is
    truthy the layer is returned untouched (its contents are final); otherwise
    its recursive positions are unfolded further via ``.map``.  As with ``ana``,
    ``alg=`` overrides the coalgebra for the top call only.
    """
    def apo_(seed, alg=coalgebra):
        finished, layer = alg(seed)
        if finished:
            return layer
        return layer.map(apo_)
    return apo_
def cata(algebra):
    """Catamorphism: fold a recursive structure bottom-up.

    Every recursive position is folded first (via ``.map``), and ``algebra``
    then combines the resulting layer.  ``alg=`` overrides the algebra for the
    outermost layer only.
    """
    def cata_(data, alg=algebra):
        folded_layer = data.map(cata_)
        return alg(folded_layer)
    return cata_
def para(algebra):
    """Paramorphism: a fold that also sees the original, unfolded node.

    ``algebra(data, folded)`` receives the original node ``data`` (children are
    the original sub-structures) and ``folded`` (the same node with each child
    replaced by its recursive result).
    """
    def para_(data, alg=algebra):
        folded = data.map(para_)
        return alg(data, folded)
    return para_
def hylo(algebra, coalgebra):
    """Hylomorphism: unfold a seed with ``coalgebra`` and fold it with ``algebra``.

    The intermediate structure is never built as a whole: each layer produced
    by ``coalgebra`` has its children processed recursively and is immediately
    collapsed by ``algebra``.
    """
    def hylo_(seed):
        layer = coalgebra(seed)
        return algebra(layer.map(hylo_))
    return hylo_
def histo(algebra):
    """Top-down traversal that feeds each node the values computed for its ancestors.

    ``alg(data, *hist)`` is called on every node, where ``hist`` holds the
    values returned for its ancestors, nearest first.  The result is
    ``data.map(...)`` over the children, i.e. a structure of the same shape as
    the input.

    NOTE(review): the value returned by ``alg`` is only passed down as history
    and is never part of the result.  A conventional histomorphism returns the
    folded value -- confirm this behaviour is intended.
    """
    def histo_(data, hist=(), alg=algebra):
        current = alg(data, *hist)
        child_hist = (current,) + hist
        return data.map(lambda child: histo_(child, child_hist))
    return histo_
class Insert:
    """Marker for a layer that a futumorphism coalgebra has already built.

    A ``futu`` coalgebra may put either a plain seed or ``Insert(layer)`` in a
    recursive position.  ``layer`` is itself a layer whose recursive positions
    again hold seeds or further ``Insert`` values.
    """

    def __init__(self, ins):
        self.ins = ins


def futu(coalgebra):
    """Futumorphism: an unfold whose coalgebra may emit several layers at once."""
    def futu_(seed):
        return coalgebra(seed).map(worker)

    def worker(fa):
        if isinstance(fa, Insert):
            # Pre-built layer: keep it, and keep expanding whatever it contains.
            return fa.ins.map(worker)
        # Plain seed: run the coalgebra on it.
        return futu_(fa)

    return futu_
def chrono(algebra, coalgebra):
    """History-aware unfold followed by a history-aware fold.

    ``coalgebra(seed, *hist)`` returns ``(stop, layer)``, where ``hist`` is the
    tuple of layers produced above this one, nearest first.  If ``stop`` is
    truthy the layer is used as-is; otherwise its children are processed
    recursively.  ``algebra(result_layer, *history)`` then collapses the layer.
    Here ``history`` is ``hist`` with the current (unfolded) layer prepended.
    """
    def chrono_(data, hist=()):
        done, layer = coalgebra(data, *hist)
        history = (layer,) + hist
        if done:
            result_layer = layer
        else:
            result_layer = layer.map(lambda child: chrono_(child, history))
        return algebra(result_layer, *history)
    return chrono_
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
from dataclasses import dataclass, field | |
class AutoFunctor(type):
    """Metaclass that derives a functorial ``map`` method for sum-type nodes.

    The class created with no bases acts as the root of the sum type.  For each
    subclass, annotated fields whose type is exactly the (first) base class are
    treated as recursive positions.  ``node.map(f)`` rebuilds the node,
    applying ``f`` to the recursive fields and passing every other annotated
    field through unchanged.  Nodes without recursive fields map to themselves.

    Assumes the node's constructor accepts all annotated fields positionally in
    annotation order (true for a plain ``@dataclass`` without inherited fields)
    -- verify for hand-written constructors.
    """

    def __init__(cls, name, bases, dct):
        if bases:
            base = bases[0]
            annotations = getattr(cls, '__annotations__', {})
            all_fields = list(annotations)
            rec_fields = {fld for fld, tp in annotations.items() if tp is base}

            def _map(self, f):
                # Rebuild from *all* annotated fields.  Passing only the recursive
                # ones (as before) breaks any node that also has data fields,
                # e.g. `label: Any, kid: Tree`.
                return cls(*[f(getattr(self, fld)) if fld in rec_fields
                             else getattr(self, fld)
                             for fld in all_fields])

            def _id(self, f):
                # Nothing to recurse into.
                return self

            cls.map = _map if rec_fields else _id
        super().__init__(name, bases, dct)
class SimpleExpr(metaclass=AutoFunctor):
    """Root of a tiny arithmetic sum type: numeric literals and products."""


@dataclass
class Num(SimpleExpr):
    """Numeric literal (leaf node)."""
    value: float


@dataclass
class Mul(SimpleExpr):
    """Product of two sub-expressions (both fields are recursive)."""
    l: SimpleExpr
    r: SimpleExpr
class GraphBase(metaclass=AutoFunctor):
    """Root of the algebraic-graph sum type: Empty | Vertex | Overlay | Connect.

    ``g + h`` builds ``Overlay(g, h)`` and ``g * h`` builds ``Connect(g, h)``,
    so graphs can be written as expressions, e.g.
    ``Vertex('x') * (Vertex('y') + Vertex('z'))``.
    """
    def __add__(self, other): return Overlay(self, other)
    def __mul__(self, other): return Connect(self, other)


@dataclass
class Empty(GraphBase):
    """The empty graph.

    Declared as a (field-less) dataclass like the other nodes, so that
    ``Empty() == Empty()`` holds and it prints as ``Empty()`` rather than
    ``<... Empty object at 0x...>``.
    """


@dataclass
class Vertex(GraphBase):
    """A single vertex carrying an arbitrary label."""
    data: Any


@dataclass
class Overlay(GraphBase):
    """Overlay (``x + y``) of two graphs."""
    x: GraphBase
    y: GraphBase


@dataclass
class Connect(GraphBase):
    """Connection (``x * y``) of two graphs."""
    x: GraphBase
    y: GraphBase
class BiTree(metaclass=AutoFunctor):
    """Root of a binary-tree sum type (TreeEmpty | Leaf | Node), used by merge sort."""


@dataclass
class TreeEmpty(BiTree):
    """The empty tree.

    Declared as a (field-less) dataclass like the other nodes, so that
    ``TreeEmpty() == TreeEmpty()`` holds and it has a readable repr.
    """


@dataclass
class Leaf(BiTree):
    """A tree holding a single element."""
    x: Any


@dataclass
class Node(BiTree):
    """Inner node with two subtrees (both fields are recursive)."""
    l: BiTree
    r: BiTree
class Expr(metaclass=AutoFunctor):
    """Root of the single-variable expression sum type."""


@dataclass
class Var(Expr):
    """The variable ``x``."""


@dataclass
class Zero(Expr):
    """The constant 0."""


@dataclass
class One(Expr):
    """The constant 1."""


@dataclass
class Neg(Expr):
    """Negation ``-x``."""
    x: Expr


@dataclass
class Exp(Expr):
    """Exponential ``e^x``."""
    x: Expr


@dataclass
class Add(Expr):
    """Sum ``x + y``."""
    x: Expr
    y: Expr


@dataclass
class Prod(Expr):
    """Product ``x * y``."""
    x: Expr
    y: Expr
def to_string_alg(fa):
    """Catamorphism algebra rendering one SimpleExpr layer as a string.

    Children of ``Mul`` arrive already rendered.  Returns None for any node
    that is neither ``Num`` nor ``Mul``.
    """
    if isinstance(fa, Num):
        return f"Num({fa.value})"
    if isinstance(fa, Mul):
        return f"Mul({fa.l}, {fa.r})"
    return None
def graph_size_alg(fa):
    """Catamorphism algebra counting vertex occurrences in a graph expression.

    Overlay and Connect both add their (already counted) operands, so a vertex
    that appears twice is counted twice.  Returns None for unknown nodes.
    """
    if isinstance(fa, Empty):
        return 0
    if isinstance(fa, Vertex):
        return 1
    if isinstance(fa, (Overlay, Connect)):
        return fa.x + fa.y
    return None
def graph_flatmap(f):
    """Build a catamorphism algebra that replaces every Vertex by ``f(vertex)``.

    ``f`` receives the Vertex node itself (not its ``data``).  Empty, Overlay
    and Connect nodes are kept as they are (with their already-rewritten
    children).  Unknown nodes yield None.
    """
    def graph_flatmap_alg(fa):
        if isinstance(fa, Empty):
            return fa
        if isinstance(fa, Vertex):
            return f(fa)
        if isinstance(fa, (Overlay, Connect)):
            return fa
        return None
    return graph_flatmap_alg
def graph_vertexset_alg(fa):
    """Catamorphism algebra collecting the set of vertex labels in a graph.

    Overlay/Connect children arrive already folded into sets and are combined
    with set union.  Vertex labels must be hashable.  Returns None for unknown
    nodes.
    """
    # `{}` is an empty *dict*, not a set; `dict | set` raises TypeError, so an
    # Empty anywhere in the graph used to crash the fold.
    if isinstance(fa, Empty): return set()
    elif isinstance(fa, Vertex): return {fa.data}
    elif isinstance(fa, Overlay): return fa.x | fa.y
    elif isinstance(fa, Connect): return fa.x | fa.y
def mergesort_coalg(seed):
    """Coalgebra splitting a sequence for merge sort (used with ``hylo``).

    An empty sequence becomes ``TreeEmpty``, a single element becomes ``Leaf``,
    and anything longer becomes a ``Node`` whose children are the two halves
    (still unsorted seeds).  Works for any sliceable sequence (list, str, tuple).
    """
    # `seed == []` only matched lists: an empty str/tuple fell through to
    # Node(seed[:0], seed[0:]), i.e. Node(seed, seed), which recursed forever.
    if len(seed) == 0: return TreeEmpty()
    elif len(seed) == 1: return Leaf(seed[0])
    else:
        middle = len(seed) // 2
        return Node(seed[:middle], seed[middle:])
def merge(l, r):
    """Merge two already-sorted sequences into one sorted list.

    If either argument is empty, the other is returned unchanged (same object).
    Otherwise a new list is returned.  The merge is stable: on ties the element
    from ``l`` comes first.  Only ``<`` is used to compare elements.  Inputs must
    support ``len()`` and indexing.

    Iterative and linear-time.  The previous recursive version copied the
    remaining tail at every step (quadratic) and exceeded the recursion limit
    for inputs of roughly a thousand elements.
    """
    if not r: return l
    if not l: return r
    merged = []
    i = j = 0
    while i < len(l) and j < len(r):
        # Take from the right only when strictly smaller, which keeps ties stable.
        if r[j] < l[i]:
            merged.append(r[j])
            j += 1
        else:
            merged.append(l[i])
            i += 1
    merged.extend(l[i:])
    merged.extend(r[j:])
    return merged
def mergesort_alg(fa):
    """Algebra collapsing a BiTree layer into a sorted list (used with ``hylo``).

    Node children arrive as already-sorted lists and are merged.  Returns None
    for unknown nodes.
    """
    if isinstance(fa, TreeEmpty):
        return []
    if isinstance(fa, Leaf):
        return [fa.x]
    if isinstance(fa, Node):
        return merge(fa.l, fa.r)
    return None
def str_expr_alg(fa):
    """Catamorphism algebra pretty-printing one Expr layer.

    Children arrive already rendered as strings.  Sums and products are fully
    parenthesised.  Returns None for unknown nodes.
    """
    renderers = (
        (Var, lambda e: 'x'),
        (Zero, lambda e: '0'),
        (One, lambda e: '1'),
        (Neg, lambda e: f'-{e.x}'),
        (Exp, lambda e: f'e^({e.x})'),
        (Add, lambda e: f'({e.x} + {e.y})'),
        (Prod, lambda e: f'({e.x}*{e.y})'),
    )
    for node_type, render in renderers:
        if isinstance(fa, node_type):
            return render(fa)
    return None
def eval_expr(x):
    """Build a catamorphism algebra that evaluates an Expr with ``Var`` bound to ``x``.

    Children arrive already evaluated.  ``Exp`` is computed as ``math.e ** v``.
    This is exact e rather than the former hard-coded ``2.71828``, whose error
    grew with the exponent.  Using ``**`` (rather than ``math.exp``) keeps
    complex and array-like values working as before.  Unknown nodes yield None.
    """
    import math

    def eval_expr_alg(fa):
        if isinstance(fa, Var): return x
        elif isinstance(fa, Zero): return 0
        elif isinstance(fa, One): return 1
        elif isinstance(fa, Neg): return -fa.x
        elif isinstance(fa, Exp): return math.e ** fa.x
        elif isinstance(fa, Add): return fa.x + fa.y
        elif isinstance(fa, Prod): return fa.x * fa.y
    return eval_expr_alg
def diff_expr_alg(fa, dfa):
    """Paramorphism algebra computing the symbolic derivative d/dx of one Expr layer.

    ``fa`` is the original node (children are the original sub-expressions).
    ``dfa`` is the same node with each child replaced by its derivative.
    Returns None for unknown nodes.
    """
    if isinstance(fa, Var):
        return One()
    if isinstance(fa, (Zero, One)):
        return Zero()
    if isinstance(fa, Neg):
        return Neg(dfa.x)
    if isinstance(fa, Exp):
        # Chain rule: (e^u)' = e^u * u'
        return Prod(Exp(fa.x), dfa.x)
    if isinstance(fa, Add):
        return Add(dfa.x, dfa.y)
    if isinstance(fa, Prod):
        # Product rule: (u*v)' = u*v' + u'*v
        return Add(Prod(fa.x, dfa.y), Prod(dfa.x, fa.y))
    return None
def simplify_expr_alg(fa):
    """Catamorphism algebra applying local algebraic simplifications to one Expr layer.

    It is applied bottom-up with ``cata`` (as in the ``__main__`` demo), so each
    rule only inspects the node's immediate, already-simplified children:

    * ``--a``  -> ``a``
    * ``e^0``  -> ``1``
    * ``0 + a`` / ``a + 0`` -> ``a``
    * ``0 * a`` / ``a * 0`` -> ``0``
    * ``1 * a`` / ``a * 1`` -> ``a``
    * ``(-1) * a`` / ``a * (-1)`` -> ``-a``

    A node that matches no rule is returned unchanged.
    """
    if isinstance(fa, Neg):
        if isinstance(fa.x, Neg): return fa.x.x
    elif isinstance(fa, Exp):
        if isinstance(fa.x, Zero): return One()
    elif isinstance(fa, Add):
        # (A verbatim duplicate of the `a + 0` rule used to follow here; it was unreachable.)
        if isinstance(fa.x, Zero): return fa.y
        elif isinstance(fa.y, Zero): return fa.x
    elif isinstance(fa, Prod):
        if isinstance(fa.x, Zero): return fa.x
        elif isinstance(fa.y, Zero): return fa.y
        elif isinstance(fa.x, One): return fa.y
        elif isinstance(fa.y, One): return fa.x
        elif isinstance(fa.x, Neg) and isinstance(fa.x.x, One): return Neg(fa.y)
        elif isinstance(fa.y, Neg) and isinstance(fa.y.x, One): return Neg(fa.x)
    return fa
if __name__ == '__main__':
    import sys

    from recursion_schemes import cata, hylo, para, apo

    # -x^2 + -e^(1 - x)
    expr = Add(Neg(Prod(Var(), Var())), Neg(Exp(Add(One(), Neg(Var())))))
    print(cata(eval_expr(1.2))(expr))
    print(cata(str_expr_alg)(para(diff_expr_alg)(expr)))  # e^(1 - x) - 2x
    # (-((x*1) + (1*x)) + -(e^((1 + -x))*(0 + -1)))
    print(cata(str_expr_alg)(cata(simplify_expr_alg)(para(diff_expr_alg)(expr))))
    # Stop here; the demos below do not run.  sys.exit() replaces quit(), which
    # is an interactive-shell helper added by the `site` module and is not meant
    # for programs (it does not exist under `python -S`).
    sys.exit()
    print(hylo(mergesort_alg, mergesort_coalg)("ADEBC"))
    print(cata(to_string_alg)(Mul(Num(1), Mul(Num(2), Num(2)))))
    # Algebraic graphs
    g = Vertex('x') * (Vertex('y') + Vertex('z'))
    print(cata(graph_size_alg)(g))
    print(cata(graph_flatmap(lambda x: x.data * 2))(g))

    def graph_induce(p):
        """Keep only vertices satisfying ``p``; others become Empty."""
        return graph_flatmap(lambda x: x if p(x) else Empty())

    print(cata(graph_induce(lambda v: v.data != 'x'))(g))
    print(cata(graph_vertexset_alg)(g))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.