Skip to content

Instantly share code, notes, and snippets.

@Adam-Vandervorst
Created March 4, 2020 19:35
Show Gist options
  • Save Adam-Vandervorst/6e6c26e710bd8eea28612d2502fa72e7 to your computer and use it in GitHub Desktop.
Save Adam-Vandervorst/6e6c26e710bd8eea28612d2502fa72e7 to your computer and use it in GitHub Desktop.
Expression recursion schemes
def ana(coalgebra):
    """Build an anamorphism (unfold) driven by *coalgebra*.

    The returned function turns a seed into one functor layer via
    ``coalgebra`` and then unfolds every child seed of that layer through
    the layer's ``map`` method, recursively.
    """
    def unfold(seed, alg=coalgebra):
        layer = alg(seed)
        return layer.map(unfold)

    return unfold
def apo(coalgebra):
    """Build an apomorphism: an unfold that can stop early.

    ``coalgebra(seed)`` must return a pair ``(stop, layer)``. When ``stop``
    is true, ``layer`` is returned as-is (it is treated as finished
    structure). Otherwise each child seed of ``layer`` is unfolded
    recursively via ``layer.map``.
    """
    def unfold(seed, alg=coalgebra):
        finished, layer = alg(seed)
        if finished:
            return layer
        return layer.map(unfold)

    return unfold
def cata(algebra):
    """Build a catamorphism (fold) driven by *algebra*.

    The returned function first folds every child of ``data`` (via
    ``data.map``) and then hands the resulting layer, whose children are
    already folded values, to ``algebra``.
    """
    def fold(data, alg=algebra):
        folded_layer = data.map(fold)
        return alg(folded_layer)

    return fold
def para(algebra):
    """Build a paramorphism: a fold whose algebra also sees the original node.

    ``algebra(data, folded)`` receives the untouched node ``data`` together
    with ``folded``, the same node after ``data.map`` replaced each child by
    its recursively folded result.
    """
    def fold(data, alg=algebra):
        folded = data.map(fold)
        return alg(data, folded)

    return fold
def hylo(algebra, coalgebra):
    """Build a hylomorphism: unfold with *coalgebra*, then fold with *algebra*.

    Each seed is expanded into one layer by ``coalgebra``, its children are
    processed recursively through ``layer.map``, and the resulting layer is
    collapsed by ``algebra``. No intermediate structure is kept.
    """
    def refold(seed):
        layer = coalgebra(seed)
        return algebra(layer.map(refold))

    return refold
def histo(algebra):
    """Build this file's ``histo`` scheme from *algebra*.

    The returned function walks the structure top-down. At each node it
    calls ``alg(node, *history)``, where ``history`` holds the results
    already computed for the node's ancestors, nearest first. It returns
    ``node.map(...)`` of the recursive walk, i.e. a structure shaped like
    the input. The algebra's results are only threaded downward as
    history; they are never part of the return value.

    NOTE(review): this differs from the textbook histomorphism (a bottom-up
    fold with access to the results for all sub-structures). Confirm the
    intended semantics before relying on it.
    """
    def walk(data, hist=(), alg=algebra):
        history = (alg(data, *hist),) + hist
        return data.map(lambda child: walk(child, history))

    return walk
class Insert:
    """Marker wrapping a ready-made functor layer for ``futu``.

    A futumorphism coalgebra may place ``Insert(layer)`` where it would
    otherwise place a seed. ``futu`` then descends into ``layer`` directly
    instead of calling the coalgebra on it.
    """

    def __init__(self, ins):
        # The wrapped layer; its children are seeds or further Inserts.
        self.ins = ins
def futu(coalgebra):
    """Build a futumorphism: an unfold that may emit several layers at once.

    ``coalgebra(seed)`` returns a layer whose children are either new seeds
    or ``Insert`` wrappers around an already-built layer. Seeds are unfolded
    recursively. An ``Insert`` is unwrapped and its own children are handled
    the same way, without calling the coalgebra.
    """
    def unfold(seed):
        return coalgebra(seed).map(expand)

    def expand(child):
        # Insert(layer): descend into the pre-built layer instead of unfolding.
        if isinstance(child, Insert):
            return child.ins.map(expand)
        return unfold(child)

    return unfold
def chrono(algebra, coalgebra):
    """Build this file's ``chrono`` scheme from *algebra* and *coalgebra*.

    ``coalgebra(data, *history)`` returns ``(stop, layer)``, where
    ``history`` holds the layers produced for the ancestors, nearest first.
    Unless ``stop`` is true, each child of ``layer`` is processed
    recursively via ``layer.map``. The (possibly processed) layer is then
    passed to ``algebra(layer, *history)``. That history starts with this
    node's own unprocessed layer.
    """
    def run(data, hist=()):
        stop, layer = coalgebra(data, *hist)
        history = (layer,) + hist  # captured before children are replaced
        if not stop:
            layer = layer.map(lambda child: run(child, history))
        return algebra(layer, *history)

    return run
from typing import Any
from dataclasses import dataclass, field
class AutoFunctor(type):
    """Metaclass deriving a functor-style ``map`` method for each subclass.

    For every class created with this metaclass that has at least one base,
    the fields whose annotation is exactly the first base class are the
    recursive positions. ``obj.map(f)`` returns a new instance with ``f``
    applied to each recursive field. A class with no recursive fields gets
    an identity ``map`` that returns ``obj`` itself. The root class (no
    bases) receives no ``map``.

    Only annotations that evaluate to the base class object are matched.
    String annotations (e.g. under ``from __future__ import annotations``)
    are not recognised.
    """

    def __init__(cls, name, bases, dct):
        if bases:
            from dataclasses import is_dataclass, replace

            base = bases[0]
            annotations = getattr(cls, '__annotations__', None)
            if annotations is not None:
                rec_fields = [fld for fld, tp in annotations.items() if tp is base]
            else:
                rec_fields = None

            def _map(self, f):
                new_values = {fld: f(getattr(self, fld)) for fld in rec_fields}
                if is_dataclass(self):
                    # Keep non-recursive fields (e.g. a label beside the
                    # children). Rebuilding from the recursive values alone
                    # would drop them or shift values into the wrong slots.
                    return replace(self, **new_values)
                return cls(*new_values.values())

            def _id(self, f):
                return self

            cls.map = _map if rec_fields else _id
        super().__init__(name, bases, dct)
class SimpleExpr(metaclass=AutoFunctor):
    """Root of a tiny arithmetic language of numbers and products."""


@dataclass
class Num(SimpleExpr):
    """Numeric literal; it has no recursive fields, so its ``map`` is identity."""
    value: float


@dataclass
class Mul(SimpleExpr):
    """Product node; both fields are recursive ``SimpleExpr`` children."""
    l: SimpleExpr
    r: SimpleExpr
class GraphBase(metaclass=AutoFunctor):
    """Root of algebraic graph expressions; ``a + b`` overlays, ``a * b`` connects."""

    def __add__(self, other):
        return Overlay(self, other)

    def __mul__(self, other):
        return Connect(self, other)


class Empty(GraphBase):
    """The empty graph (a plain class, not a dataclass)."""


@dataclass
class Vertex(GraphBase):
    """A single vertex carrying an arbitrary payload in ``data``."""
    data: Any


@dataclass
class Overlay(GraphBase):
    """Overlay of two sub-graphs, as built by ``x + y``."""
    x: GraphBase
    y: GraphBase


@dataclass
class Connect(GraphBase):
    """Connection of two sub-graphs, as built by ``x * y``."""
    x: GraphBase
    y: GraphBase
class BiTree(metaclass=AutoFunctor):
    """Root of binary trees (the intermediate structure of the merge sort below)."""


class TreeEmpty(BiTree):
    """Tree with no elements (a plain class, not a dataclass)."""


@dataclass
class Leaf(BiTree):
    """Tree holding exactly one element ``x``."""
    x: Any


@dataclass
class Node(BiTree):
    """Inner node; ``l`` and ``r`` are the recursive sub-trees."""
    l: BiTree
    r: BiTree
class Expr(metaclass=AutoFunctor):
    """Root of the expression language over a single variable ``x``."""


@dataclass
class Var(Expr):
    """The variable ``x``."""


@dataclass
class Zero(Expr):
    """The constant 0."""


@dataclass
class One(Expr):
    """The constant 1."""


@dataclass
class Neg(Expr):
    """Negation ``-x``."""
    x: Expr


@dataclass
class Exp(Expr):
    """Exponential ``e^x``."""
    x: Expr


@dataclass
class Add(Expr):
    """Sum ``x + y``."""
    x: Expr
    y: Expr


@dataclass
class Prod(Expr):
    """Product ``x * y``."""
    x: Expr
    y: Expr
def to_string_alg(fa):
    """Catamorphism algebra rendering one ``SimpleExpr`` layer as text.

    When this runs, the children of a ``Mul`` are already strings. Any other
    node type yields ``None``.
    """
    if isinstance(fa, Num):
        return f"Num({fa.value})"
    if isinstance(fa, Mul):
        return f"Mul({fa.l}, {fa.r})"
    return None
def graph_size_alg(fa):
    """Catamorphism algebra counting ``Vertex`` occurrences in a graph expression.

    Repeated vertices are counted each time they appear. The children of
    ``Overlay``/``Connect`` are already counts. Unknown node types yield
    ``None``.
    """
    if isinstance(fa, Empty):
        return 0
    if isinstance(fa, Vertex):
        return 1
    if isinstance(fa, (Overlay, Connect)):
        return fa.x + fa.y
    return None
def graph_flatmap(f):
    """Return a catamorphism algebra replacing each ``Vertex`` node ``v`` with ``f(v)``.

    ``f`` receives the ``Vertex`` node itself, not its ``data``.
    ``Empty``/``Overlay``/``Connect`` layers are returned unchanged; their
    children have already been transformed. Unknown node types yield
    ``None``.
    """
    def graph_flatmap_alg(fa):
        if isinstance(fa, Vertex):
            return f(fa)
        if isinstance(fa, (Empty, Overlay, Connect)):
            return fa
        return None

    return graph_flatmap_alg
def graph_vertexset_alg(fa):
    """Catamorphism algebra collecting the set of vertex payloads of a graph.

    ``Empty`` contributes no vertices and a ``Vertex`` contributes its
    ``data``. For ``Overlay``/``Connect`` the children are already sets and
    are unioned. Unknown node types yield ``None``.
    """
    if isinstance(fa, Empty):
        # set(), not {}: the literal {} is an empty dict, and dict | set
        # raises TypeError in the unions below.
        return set()
    elif isinstance(fa, Vertex):
        return {fa.data}
    elif isinstance(fa, Overlay):
        return fa.x | fa.y
    elif isinstance(fa, Connect):
        return fa.x | fa.y
def mergesort_coalg(seed):
    """Coalgebra splitting a sequence into one ``BiTree`` layer for merge sort.

    Returns ``TreeEmpty()`` for an empty sequence and ``Leaf(element)`` for
    a one-element sequence. Otherwise it returns a ``Node`` whose two halves
    are the new seeds. Works for any sized, sliceable sequence (list, str,
    tuple, ...).
    """
    size = len(seed)
    # Test the length rather than ``seed == []``: an empty str/tuple is not
    # equal to [] and would otherwise be split into Node(empty, empty) forever.
    if size == 0:
        return TreeEmpty()
    if size == 1:
        return Leaf(seed[0])
    middle = size // 2
    return Node(seed[:middle], seed[middle:])
def merge(l, r):
    """Merge two ascending sequences into one ascending list.

    If either input is empty, the other one is returned unchanged (the same
    object). On ties the element from ``l`` comes first, so the merge is
    stable. The loop is iterative and linear-time, so long inputs neither
    hit the recursion limit nor pay for repeated list copies.
    """
    if not r:
        return l
    if not l:
        return r
    merged = []
    i = j = 0
    while i < len(l) and j < len(r):
        # Take from r only when strictly smaller, so equal items keep their
        # left-before-right order.
        if r[j] < l[i]:
            merged.append(r[j])
            j += 1
        else:
            merged.append(l[i])
            i += 1
    # At most one of these tails is non-empty.
    merged.extend(l[i:])
    merged.extend(r[j:])
    return merged
def mergesort_alg(fa):
    """Algebra reassembling a merge-sorted list from one ``BiTree`` layer.

    ``TreeEmpty`` becomes ``[]`` and ``Leaf`` becomes a one-element list.
    A ``Node``'s children are already sorted lists and are merged. Unknown
    node types yield ``None``.
    """
    if isinstance(fa, Node):
        return merge(fa.l, fa.r)
    if isinstance(fa, Leaf):
        return [fa.x]
    if isinstance(fa, TreeEmpty):
        return []
    return None
def str_expr_alg(fa):
    """Catamorphism algebra rendering one ``Expr`` layer as a string.

    Children are already rendered strings when this runs. Unknown node
    types yield ``None``.
    """
    renderers = (
        (Var, lambda e: 'x'),
        (Zero, lambda e: '0'),
        (One, lambda e: '1'),
        (Neg, lambda e: f'-{e.x}'),
        (Exp, lambda e: f'e^({e.x})'),
        (Add, lambda e: f'({e.x} + {e.y})'),
        (Prod, lambda e: f'({e.x}*{e.y})'),
    )
    for node_type, render in renderers:
        if isinstance(fa, node_type):
            return render(fa)
    return None
def eval_expr(x):
    """Return a catamorphism algebra that evaluates an ``Expr`` with ``Var() = x``.

    The children of ``Neg``/``Exp``/``Add``/``Prod`` are already numbers
    when the algebra runs. Unknown node types yield ``None``.
    """
    import math

    def eval_expr_alg(fa):
        if isinstance(fa, Var):
            return x
        elif isinstance(fa, Zero):
            return 0
        elif isinstance(fa, One):
            return 1
        elif isinstance(fa, Neg):
            return -fa.x
        elif isinstance(fa, Exp):
            # Full-precision e. Using ``**`` keeps every exponent type that
            # power accepts (int, float, complex, ...) working.
            return math.e ** fa.x
        elif isinstance(fa, Add):
            return fa.x + fa.y
        elif isinstance(fa, Prod):
            return fa.x * fa.y

    return eval_expr_alg
def diff_expr_alg(fa, dfa):
    """Paramorphism algebra computing d/dx of one ``Expr`` layer.

    ``fa`` is the original node, whose children are the original
    sub-expressions. ``dfa`` is the same node with each child replaced by
    that child's derivative. Unknown node types yield ``None``.
    """
    if isinstance(fa, Var):
        return One()
    if isinstance(fa, (Zero, One)):
        return Zero()
    if isinstance(fa, Neg):
        return Neg(dfa.x)
    if isinstance(fa, Exp):
        # Chain rule: (e^u)' = e^u * u'
        return Prod(Exp(fa.x), dfa.x)
    if isinstance(fa, Add):
        return Add(dfa.x, dfa.y)
    if isinstance(fa, Prod):
        # Product rule: (u*v)' = u*v' + u'*v
        return Add(Prod(fa.x, dfa.y), Prod(dfa.x, fa.y))
    return None
def simplify_expr_alg(fa):
    """Catamorphism algebra applying local rewrite rules to one ``Expr`` layer.

    Rules are checked in this order and the first match wins:
      --a -> a;  e^0 -> 1;  0 + a -> a;  a + 0 -> a;
      0*a -> 0;  a*0 -> 0;  1*a -> a;  a*1 -> a;  (-1)*a -> -a;  a*(-1) -> -a.
    A layer that matches no rule is returned unchanged.
    """
    if isinstance(fa, Neg):
        if isinstance(fa.x, Neg):
            return fa.x.x
    elif isinstance(fa, Exp):
        if isinstance(fa.x, Zero):
            return One()
    elif isinstance(fa, Add):
        if isinstance(fa.x, Zero):
            return fa.y
        elif isinstance(fa.y, Zero):
            return fa.x
    elif isinstance(fa, Prod):
        if isinstance(fa.x, Zero):
            return fa.x
        elif isinstance(fa.y, Zero):
            return fa.y
        elif isinstance(fa.x, One):
            return fa.y
        elif isinstance(fa.y, One):
            return fa.x
        elif isinstance(fa.x, Neg) and isinstance(fa.x.x, One):
            return Neg(fa.y)
        elif isinstance(fa.y, Neg) and isinstance(fa.y.x, One):
            return Neg(fa.x)
    return fa
if __name__ == '__main__':
    # NOTE(review): cata/hylo/para/apo are also defined earlier in this
    # listing; this import presumably targets a separate recursion_schemes
    # module shipped alongside it. Confirm before removing either copy.
    from recursion_schemes import cata, hylo, para, apo

    # f(x) = -x^2 + -e^(1 - x)
    expr = Add(Neg(Prod(Var(), Var())), Neg(Exp(Add(One(), Neg(Var())))))
    print(cata(eval_expr(1.2))(expr))  # f(1.2)

    # f'(x) = e^(1 - x) - 2x. Unsimplified it prints as
    # (-((x*1) + (1*x)) + -(e^((1 + -x))*(0 + -1)))
    derivative = para(diff_expr_alg)(expr)
    print(cata(str_expr_alg)(derivative))
    print(cata(str_expr_alg)(cata(simplify_expr_alg)(derivative)))

    # Merge sort as a hylomorphism: split with the coalgebra, merge with the algebra.
    print(hylo(mergesort_alg, mergesort_coalg)("ADEBC"))
    print(cata(to_string_alg)(Mul(Num(1), Mul(Num(2), Num(2)))))

    # Algebraic graphs
    g = Vertex('x') * (Vertex('y') + Vertex('z'))
    print(cata(graph_size_alg)(g))
    print(cata(graph_flatmap(lambda x: x.data * 2))(g))

    def graph_induce(p):
        """Return an algebra keeping vertices satisfying *p*; the rest become Empty()."""
        return graph_flatmap(lambda x: x if p(x) else Empty())

    print(cata(graph_induce(lambda v: v.data != 'x'))(g))
    print(cata(graph_vertexset_alg)(g))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment