Skip to content

Instantly share code, notes, and snippets.

@felko
Last active April 1, 2020 18:26
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save felko/f831f59c014feed4e8228582ea6ad36e to your computer and use it in GitHub Desktop.
Save felko/f831f59c014feed4e8228582ea6ad36e to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3.7
# coding: utf-8
import types
import functools
import abc
import typing
from contextlib import contextmanager
from collections import OrderedDict
def _to_tuple(val):
    """Wrap ``val`` in a 1-tuple unless it already is a tuple.

    Lets callers pass either a single pattern or a tuple of patterns.
    """
    return val if isinstance(val, tuple) else (val,)
def _from_tuple(val):
    """Unwrap a one-element sequence to its element; return anything else unchanged."""
    return val[0] if len(val) == 1 else val
class PatternMatchError(ValueError):
    """Raised when a value does not match a pattern."""
class Pattern:
    """Base class of all patterns.

    Subclasses implement ``match(val, env)``. On success it may add variable
    bindings to the dict ``env``; on failure it raises ``PatternMatchError``.

    Operators build compound patterns:

    * ``'name' @ pat`` -> ``As('name', pat)``: bind the whole value too
    * ``p * q``        -> ``Tup(p, q)``: tuple pattern
    * ``p & q``        -> ``And(p, q)``: every sub-pattern must match
    * ``p | q``        -> ``Alt(p, q)``: first matching alternative wins
    """

    def __rmatmul__(self, name):
        return As(name, self)

    def __mul__(self, pat):
        return Tup(self, pat)

    def __and__(self, pat):
        return And(self, pat)

    def __or__(self, pat):
        # ``Alt`` is this module's disjunction pattern (there is no ``Or``).
        return Alt(self, pat)

    def match(self, val, env):
        raise NotImplementedError
class Val(Pattern):
    """Match only values that compare equal to a fixed expected value."""

    __slots__ = ['expected_val']

    def __init__(self, val):
        self.expected_val = val

    def match(self, val, env):
        wanted = self.expected_val
        if val != wanted:
            raise PatternMatchError(f"Expected {val} to be {wanted}")
class Is(Pattern):
    """Match any value that is an instance of ``typ`` (``isinstance`` semantics)."""

    __slots__ = ['typ']

    def __init__(self, typ):
        self.typ = typ

    def match(self, val, env):
        if isinstance(val, self.typ):
            return
        raise PatternMatchError(f"Expected {self.typ}, got {val} of type {type(val)}")
class Alt(Pattern):
    """Disjunction: try each sub-pattern in order and keep the first success.

    Each alternative matches against a scratch copy of ``env``. A failed
    alternative therefore leaves no partial bindings behind. Only the winning
    alternative's bindings are merged back into ``env``.
    """

    __slots__ = ['pats']

    def __init__(self, *pats):
        self.pats = list(pats)

    def __ror__(self, lhs):
        return Alt(lhs, *self.pats)

    def __or__(self, rhs):
        return Alt(*self.pats, rhs)

    def match(self, val, env):
        for candidate in self.pats:
            trial = env.copy()
            try:
                candidate.match(val, trial)
            except PatternMatchError:
                continue
            env.update(trial)
            return
        raise PatternMatchError(f"Non exhaustive pattern match, missing case: {val}")
class Nil(Alt):
    """The empty disjunction: a pattern with no alternatives, so it never matches."""

    def __init__(self):
        Alt.__init__(self)
class And(Pattern):
    """Conjunction: every sub-pattern must match the same value.

    Bindings accumulate directly in ``env``. The first failing sub-pattern's
    ``PatternMatchError`` propagates, and bindings made by earlier
    sub-patterns are not rolled back.
    """

    def __init__(self, *pats):
        self.pats = list(pats)

    def __rand__(self, lhs):
        return And(lhs, *self.pats)

    def __and__(self, rhs):
        return And(*self.pats, rhs)

    def match(self, val, env):
        for sub in self.pats:
            sub.match(val, env)
class Any(Pattern):
    """Wildcard: matches every value and binds nothing."""

    def match(self, val, env):
        return None
class Var(Pattern):
    """Bind the matched value to ``name`` in the environment.

    If ``name`` is already bound (the same variable appears twice in one
    pattern), the new value must compare equal to the existing binding.
    Otherwise ``PatternMatchError`` is raised.
    """

    def __init__(self, name):
        self.name = name

    def match(self, val, env):
        if self.name not in env:
            env[self.name] = val
        elif env[self.name] != val:
            # Equality, not identity. Two equal but separately created objects
            # (e.g. large ints) must satisfy a repeated variable. This matches
            # ``Val`` and the ``!=`` in the error message.
            raise PatternMatchError(
                f"Variable pattern failed to satisfy {self.name} = {env[self.name]} != {val}")
class As(Var):
    """``'name' @ pat``: bind the whole value to ``name`` and also match ``pat`` against it."""

    def __init__(self, name, pat):
        super().__init__(name)
        self.pat = pat

    def match(self, val, env):
        # Bind the whole value first (with Var's repeated-name check) ...
        super().match(val, env)
        # ... then match the wrapped sub-pattern against the same value and env.
        self.pat.match(val, env)
class View(Pattern):
    """View pattern: apply ``f`` to the value, then match ``pat`` against the result."""

    def __init__(self, f, pat):
        self.f = f
        self.pat = pat

    def match(self, val, env):
        viewed = self.f(val)
        self.pat.match(viewed, env)
class Tup(Pattern):
    """Match a ``tuple`` of exactly ``len(pats)`` elements, element by element."""

    def __init__(self, *pats):
        self.pats = list(pats)

    def __rmul__(self, lhs):
        return Tup(lhs, *self.pats)

    def __mul__(self, rhs):
        return Tup(*self.pats, rhs)

    def match(self, vals, env):
        if not isinstance(vals, tuple) or len(vals) != len(self.pats):
            raise PatternMatchError(f"Cannot match tuple {vals}, expected {len(self.pats)} elements")
        for sub, item in zip(self.pats, vals):
            sub.match(item, env)
class Constr(Pattern):
    """Constructor pattern, written ``prism << args``.

    ``prism.preview`` extracts the constructor arguments from a value, and
    they are matched positionally against ``args``. When the prism is a
    ``Case`` of an ADT, the value must also be an instance of that ADT that
    was built with exactly this case.
    """

    def __init__(self, prism, args):
        self.prism = prism
        self.args = _to_tuple(args)

    def match(self, val, env):
        prism = self.prism
        if isinstance(prism, Case):
            # Check the case *before* calling preview. ``Case.preview`` reads
            # ``val._args``, which raises AttributeError (not
            # PatternMatchError) on non-ADT values. That would abort
            # ``Alt``/``ADT.case`` instead of letting them try the next branch.
            if not (isinstance(val, prism.cls) and val._case is prism):
                raise PatternMatchError(f"Expected {prism.name}, got {val}")
        vals = prism.preview(val)
        if not (isinstance(vals, tuple) and len(vals) == len(self.args)):
            raise PatternMatchError(f"Expected {prism.review.__name__}, got {val}")
        for pat, item in zip(self.args, vals):
            pat.match(item, env)
class Prism:
    """A constructor/destructor pair.

    ``review`` builds a value from positional arguments; calling the prism
    does the same. ``preview`` extracts the argument tuple back from a value,
    and ``Constr`` patterns (``prism << args``) use it when matching.
    """

    def __init__(self, review, preview=None):
        self.review = review
        self.preview = preview

    def __lshift__(self, args):
        return Constr(self, _to_tuple(args))

    def __call__(self, *args):
        return self.review(*args)

    def unwrap(self, p):
        """Install ``p`` as the preview function and return it.

        Returning ``p`` lets this be used as a decorator (``@prism.unwrap``)
        without rebinding the decorated name to ``None``.
        """
        self.preview = p
        return p
class Case(Prism):
    """A named constructor of an ``ADT`` subclass.

    ``check`` is called with the constructor arguments as a validation hook,
    and its return value is ignored. The value is then built as
    ``self.cls(self, args)``. ``cls`` starts as ``None`` and is filled in by
    ``ADTMeta`` when the enclosing ADT class is created. ``preview`` returns
    the stored argument tuple, so ``case << pats`` destructures it.

    NOTE: keyword arguments reach ``check`` but are not stored in the value;
    only the positional ``args`` are kept.
    """

    def __init__(self, check, name=None):
        @functools.wraps(check)
        def _review_wrapper(*args, **kwargs):
            check(*args, **kwargs)
            return self.cls(self, args)
        super().__init__(_review_wrapper, lambda obj: obj._args)
        self.cls = None
        self.name = name or check.__name__
        self.check = check

    def match(self, val, env):
        """Match any value built with this case, whatever its arguments.

        Binds nothing. Raises ``PatternMatchError`` if ``val`` is not an
        instance of the ADT, or if it was built with a different case.
        """
        # ``Prism`` has no ``match`` to delegate to, so do the check here.
        if not isinstance(val, self.cls):
            raise PatternMatchError(f"Expected value of type {self.cls}, got {type(val)}")
        if val._case is not self:
            raise PatternMatchError(f"Expected {self.name}, got {val}")
class ADTMeta(abc.ABCMeta):
    """Metaclass for algebraic data types (``ADT`` and its subclasses).

    What the code below does:

    * The first class named ``ADT`` becomes the shared root (``_ADTBase``).
    * Every other ADT class is first created with ``_ADTBase`` as its only
      base. Any ADT it lists as a base is then *rewired* to inherit from the
      new class, so the extending ADT becomes a superclass of the extended
      one. Non-ADT bases (the root, traits) are kept as ordinary bases.
    * ``Case`` attributes become the class's constructors (``__constrs__``),
      and their ``cls`` is pointed at the new class.
    * The ``renaming`` class keyword maps inherited case names to new names.
    """
    # The root ``ADT`` class, captured the first time a class named 'ADT' is built.
    _ADTBase = None

    def __new__(mcs, name, bases, attrs, renaming=()):
        if name == 'ADT' and mcs._ADTBase is None:
            # Bootstrap: the root class is created with no bases.
            mcs._ADTBase = super().__new__(mcs, name, (), attrs)
            return mcs._ADTBase
        new_bases = []
        # Create the class under the root first; its real bases are assigned below.
        cls = super().__new__(mcs, name, (mcs._ADTBase,), attrs)
        for b in bases:
            if isinstance(b, ADTMeta) and b is not mcs._ADTBase:
                # "Extending" an existing ADT: make that ADT a subclass of the
                # new class, so its values are also instances of the new class.
                if b.__bases__ == (mcs._ADTBase,):
                    b.__bases__ = (cls,)
                else:
                    # NOTE(review): leftover debug output; prints whenever this branch runs.
                    print(cls, b)
                    b.__bases__ += (cls,)
            else:
                new_bases.append(b)
        # Keep the non-ADT bases (root, traits, ...); fall back to the root if none.
        cls.__bases__ = tuple(new_bases) or (mcs._ADTBase,)
        return cls

    def __prepare__(mcs, bases, renaming=()):
        # NOTE(review): not declared @classmethod. Python calls
        # ``ADTMeta.__prepare__(name, bases, **kwds)``, so the class *name*
        # lands in ``mcs`` and the bases in ``bases``. This only works
        # because ``mcs`` is never used below.
        renaming = dict(renaming)
        env = {}
        # Pre-populate the class namespace with the cases and case methods of
        # parent ADTs, so the class body can refer to and extend them.
        for b in bases:
            if isinstance(b, ADTMeta):
                for name, attr in b.__dict__.items():
                    if isinstance(attr, Case):
                        # Expose the same Case object under its new name (if
                        # renamed) and under its original name.
                        n = renaming.get(attr.name, attr.name)
                        env[n] = attr
                        env[attr.name] = attr
                    elif isinstance(attr, _CaseMethod):
                        # Fresh copy (new cases dict), so branches added in the
                        # subclass body do not leak into the parent's method.
                        env[name] = _CaseMethod(attr._default, cases=attr._cases)
        return env

    def __init__(cls, name, bases, attrs, renaming=()):
        # Case renaming used by ADT.__repr__ (e.g. Left -> This).
        cls.__renaming__ = dict(renaming)
        if name == 'ADT':
            # Root class: no traits, supers or constructors of its own.
            cls.__traits__ = ()
            cls.__supers__ = ()
            cls.__constrs__ = ()
            super().__init__(name, (), attrs)
            return
        constrs = []
        # Claim the body's cases and case methods for this class. Cases copied
        # in by __prepare__ are the parent's own objects, so their ``cls`` is
        # re-pointed to this class as well.
        for attr in attrs.values():
            if isinstance(attr, Case):
                attr.cls = cls
                constrs.append(attr)
            elif isinstance(attr, _CaseMethod):
                attr._cls = cls
        # NOTE(review): a renamed inherited case sits under two names in attrs,
        # so it is appended to __constrs__ twice.
        cls.__constrs__ = tuple(constrs)
        traits = []
        # Traits come from ADT bases' __traits__ plus any direct Trait bases.
        for b in bases:
            if isinstance(b, ADTMeta):
                traits.extend(b.__traits__)
            elif issubclass(b, Trait):
                traits.append(b)
        cls.__traits__ = tuple(traits)
        cls.__supers__ = tuple(b for b in bases if isinstance(b, ADTMeta))
        super().__init__(name, cls.__bases__, attrs)

    def _propagate_new_superclass(cls, new):
        # NOTE(review): not called anywhere in the visible code. ``bs.remove``
        # raises ValueError when ``cls._ADTBase`` is not among ``__supers__``
        # (e.g. a class whose only ADT base is another non-root ADT).
        bs = list(cls.__supers__)
        bs.remove(cls._ADTBase)
        for base in bs:
            if isinstance(base, ADTMeta):
                base._propagate_new_superclass(cls)
        cls.__bases__ = tuple(bs) + (new,)
class ADT(metaclass=ADTMeta):
    """Root of all algebraic data types.

    An instance stores the ``Case`` it was built with (``_case``) and the
    positional constructor arguments (``_args``).
    """

    def __init__(self, *args):
        count = len(args)
        if count == 0:
            raise ValueError("Expected value or case/arguments pair")
        if count == 1:
            # Copy constructor: adopt the case and arguments of another value.
            (other,) = args
            if not isinstance(other, type(self)):
                raise TypeError(f"Expected a subtype of {type(self)}, got {type(other)}")
            self._case, self._args = other._case, other._args
            return
        if count == 2:
            case, case_args = args
            self._case = case
            self._args = tuple(case_args)
            return
        raise TypeError("Too many arguments, expected 1 or 2")

    def __repr__(self):
        case_name = self._case.name
        shown = type(self).__renaming__.get(case_name, case_name)
        if not self._args:
            return f"<{shown}>"
        rendered = ' '.join(repr(arg) for arg in self._args)
        return f"<{shown} {rendered}>"

    def __getattribute__(self, attr):
        # Case methods in the class's own namespace are returned bound to self.
        found = type(self).__dict__.get(attr)
        if isinstance(found, _CaseMethod):
            return _BoundCaseMethod(found, self)
        return super(ADT, self).__getattribute__(attr)

    def case(self, cases):
        """Dispatch on ``self`` using a ``{pattern: branch}`` mapping.

        The first pattern that matches calls its branch with the bound
        variables as keyword arguments. Raises ``PatternMatchError`` if none
        matches.
        """
        for pattern, branch in OrderedDict(cases).items():
            bindings = {}
            try:
                pattern.match(self, bindings)
            except PatternMatchError:
                continue
            return branch(**bindings)
        raise PatternMatchError("No match")
def Wrapper(f):
    """Build a single-case ADT class named after ``f``.

    The class has one ``Case`` with the same name as ``f``; ``f`` is used as
    that case's validation hook.
    """
    # Case's signature is ``Case(check, name=None)``: the function comes first.
    return ADTMeta(f.__name__, (ADT,), {f.__name__: Case(f, f.__name__)})
@contextmanager
def match(val, pat, exc=None):
    """Match ``val`` against ``pat`` and yield the bound values.

    Yields the value itself when exactly one variable was bound; otherwise
    yields a tuple of the bound values in binding order. On failure it
    re-raises the ``PatternMatchError``. If ``exc`` is given, it raises
    ``exc`` instead, with the original context suppressed.
    """
    bindings = OrderedDict()
    try:
        pat.match(val, bindings)
    except PatternMatchError:
        if exc is None:
            raise
        raise exc from None
    yield _from_tuple(tuple(bindings.values()))
class _CaseFunction(typing.Callable):
    """A function defined by pattern matching on its positional arguments.

    Cases are registered with ``@fn.case((pat1, pat2, ...))``. Each case's
    patterns are matched left to right against the call's arguments. The
    first case that matches is called with the bound variables as keyword
    arguments. If no case matches, the default function ``f`` is called with
    the original arguments.

    Raises ``TypeError`` when the number of arguments differs from the
    arity of a registered case.
    """

    def __init__(self, f, cases=()):
        self._default = f
        self._cases = OrderedDict(cases)

    def __call__(self, *args):
        for pats, branch in self._cases.items():
            if len(args) != len(pats):
                raise TypeError(f"Expected {len(pats)} arguments, got {len(args)}")
            env = {}
            try:
                for pat, val in zip(pats, args):
                    pat.match(val, env)
            except PatternMatchError:
                continue
            return branch(**env)
        # No registered case matched: fall back to the default implementation.
        return self._default(*args)

    def case(self, pats):
        """Return a decorator that registers ``f`` as the branch for ``pats``.

        The decorator returns ``f`` unchanged, so the decorated name stays
        bound to the function instead of becoming ``None``.
        """
        def _decorator_wrapper(f):
            self._cases[pats] = f
            return f
        return _decorator_wrapper
class _CaseMethod(_CaseFunction):
    """An ADT method defined by pattern matching on the receiver.

    ``@meth.case(pat)`` registers a branch for one pattern. The first branch
    whose pattern matches the receiver is called as
    ``branch(obj, *bound_values, *args)``: the bound values come in binding
    order, followed by the call's own arguments. If none matches, the
    default is called as ``default(obj, *args)``. ``_cls`` is assigned by
    ``ADTMeta``, and receivers that are not instances of it are rejected
    with ``TypeError``.
    """

    def __init__(self, f, cases=()):
        super().__init__(f, cases)
        self._cls = None

    def __call__(self, obj, *args):
        if not isinstance(obj, self._cls):
            raise TypeError(f"Expected {self._cls} instance, got {type(obj)}")
        for pat, branch in self._cases.items():
            env = OrderedDict()
            try:
                pat.match(obj, env)
            except PatternMatchError:
                continue
            return branch(obj, *env.values(), *args)
        # Non-exhaustive match: fall back to the default method body.
        return self._default(obj, *args)
class _BoundCaseMethod(typing.Callable):
    """A case method bound to one receiver; this is what ``obj.meth`` returns on an ADT.

    Calling it forwards to ``method(instance, *args)``. It also exposes the
    underlying method's ``_cases`` and ``_default``.
    """

    def __init__(self, method, instance):
        owner = method._cls
        if not isinstance(instance, owner):
            raise TypeError(f"Cannot bind case method of datatype {owner} to {type(instance)} object")

        def _method_wrapper(*args):
            return method(instance, *args)

        self._method = functools.wraps(method)(_method_wrapper)
        self._cases = method._cases
        self._default = method._default

    def __call__(self, *args):
        return self._method(*args)
# Public decorator names for pattern-matched functions and ADT methods.
casefunc, casemethod = _CaseFunction, _CaseMethod
class Trait(metaclass=abc.ABCMeta):
    """Marker base for interface classes (Functor, Monad, ...).

    ``ADTMeta`` collects ``Trait`` subclasses found among an ADT's bases
    into its ``__traits__``.
    """
class Functor(Trait):
    """Interface for structures whose contents can be mapped over."""

    @abc.abstractmethod
    def map(self, f):
        """Apply ``f`` to the contained value(s); implemented by subclasses."""
        raise NotImplementedError
class Applicative(Functor):
    """Functors that can lift a plain value (``pure``) and apply wrapped functions (``app``)."""

    @classmethod
    @abc.abstractmethod
    def pure(cls, x):
        """Wrap the plain value ``x``; implemented by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def app(self, xs):
        """Apply the function(s) held by ``self`` to ``xs``; implemented by subclasses."""
        raise NotImplementedError
class Monad(Applicative):
    """Applicatives with ``bind``; ``join`` and ``app`` are derived from it."""

    @abc.abstractmethod
    def bind(self, f):
        """Feed the contained value(s) to ``f``; implemented by subclasses."""
        raise NotImplementedError

    def join(self):
        """Flatten one level of nesting by binding with the identity function."""
        return self.bind(lambda inner: inner)

    def app(self, xs):
        """Apply the wrapped function to ``xs`` using ``bind`` then ``xs.map``."""
        return self.bind(lambda fn: xs.map(fn))
class Foldable(Trait):
    """Interface for structures that can be reduced to a single value."""

    @abc.abstractmethod
    def foldr(self, f, i):
        """Right fold calling ``f(element, accumulator)``, starting from ``i``."""
        raise NotImplementedError

    def foldl(self, f, i):
        """Left fold calling ``f(accumulator, element)``, derived from ``foldr``.

        ``foldr`` builds a chain of continuations, which is then applied to ``i``.
        """
        def step(elem, continue_with):
            return lambda acc: continue_with(f(acc, elem))
        return self.foldr(step, lambda acc: acc)(i)
class Traversable(Foldable, Functor):
    """Foldable functors that additionally support ``traverse(f)``."""

    @abc.abstractmethod
    def traverse(self, f):
        """Traverse the structure with ``f``; implemented by subclasses."""
        raise NotImplementedError
class Maybe(ADT, Monad, Traversable):
    """Optional value: either ``Just(x)`` or ``Nothing()``.

    ``map``, ``bind``, ``foldr`` and ``traverse`` are case methods. Their
    per-case bodies are registered below with ``@<method>.case(<pattern>)``.
    Variables bound by the pattern are passed positionally after ``self``,
    followed by the caller's arguments.
    """
    # Constructors; the ``pass`` bodies perform no validation.
    @Case
    def Just(x): pass
    @Case
    def Nothing(): pass

    # Default bodies, reached only when no registered case matches. A bare
    # ``raise`` with no active exception raises RuntimeError.
    @casemethod
    def map(self, *_): raise

    @classmethod
    def pure(cls, x):
        return cls.Just(x)

    @casemethod
    def bind(self, f): raise
    @casemethod
    def foldr(self, f, i): raise
    @casemethod
    def traverse(self, f): raise

    @map.case(Just << Var('x'))
    def map_just(self, x, f):
        return self.Just(f(x))

    @map.case(Nothing << ())
    def map_nothing(self, f):
        return self

    @bind.case(Just << Var('x'))
    def bind_just(self, x, f):
        return f(x)

    # Branch for Nothing (despite its name).
    @bind.case(Nothing << ())
    def bind_maybe(self, f):
        return self

    @foldr.case(Just << Var('x'))
    def foldr_just(self, x, f, i):
        return f(x, i)

    @foldr.case(Nothing << ())
    def foldr_nothing(self, f, i):
        return i

    @traverse.case(Just << Var('x'))
    def traverse_just(self, x, f):
        fb = f(x)
        # Stash the result's type on ``f`` so a later Nothing traversal can call ``pure``.
        f._applicative = type(fb)
        return fb.map(self.Just)

    # NOTE(review): relies on ``f._applicative`` having been set by an earlier
    # traverse of a Just with the same ``f``; otherwise this raises AttributeError.
    @traverse.case(Nothing << ())
    def traverse_nothing(self, f):
        return f._applicative.pure(self)
class Either(ADT):
    """One of two alternatives: ``Left(x)`` or ``Right(x)``.

    ``map``, ``bind``, ``foldr`` and ``traverse`` act on the ``Right``
    payload and pass ``Left`` values through. They are case methods: the
    branches below receive the pattern's bound variables positionally after
    ``self``, followed by the caller's arguments.
    """

    @Case
    def Left(x): pass

    @Case
    def Right(x): pass

    # Default bodies, reached only when no case matches. A bare ``raise``
    # with no active exception raises RuntimeError.
    @casemethod
    def map(self, f): raise

    @classmethod
    def pure(cls, x):
        return cls.Right(x)

    @casemethod
    def bind(self, f): raise

    @casemethod
    def foldr(self, f, i): raise

    @casemethod
    def traverse(self, f): raise

    @map.case(Left << Var('x'))
    def map_left(self, x, f):
        return self.Left(x)

    @map.case(Right << Var('x'))
    def map_right(self, x, f):
        return self.Right(f(x))

    @bind.case(Left << Any())
    def bind_left(self, f):
        return self

    @bind.case(Right << Var('x'))
    def bind_right(self, x, f):
        # ``bind`` feeds the payload to ``f``, which returns the next Either.
        return f(x)

    @foldr.case(Left << Any())
    def foldr_left(self, f, i):
        return i

    @foldr.case(Right << Var('x'))
    def foldr_right(self, x, f, i):
        return f(x, i)

    @traverse.case(Left << Any())
    def traverse_nothing(self, f):
        # Needs ``f._applicative``, recorded by an earlier traversal of a Right.
        return f._applicative.pure(self)

    @traverse.case(Right << Var('x'))
    def traverse_just(self, x, f):
        fb = f(x)
        # Record the applicative (as Maybe does) so a later Left traversal can call ``pure``.
        f._applicative = type(fb)
        return fb.map(self.Right)
class These(Either, renaming={'Left': 'This', 'Right': 'That'}):
    """Extends ``Either`` with a third case holding both values.

    The ``renaming`` keyword exposes the inherited ``Left``/``Right`` cases as
    ``This``/``That``. Those names are used in the class body below and in
    ``repr``. Because ``Either``'s only base is the ADT root, ``ADTMeta``
    rewires ``Either`` to become a subclass of ``These``.

    NOTE(review): only ``map`` and ``from_these`` get a branch for the
    ``These`` case. ``bind``/``foldr``/``traverse`` on a ``These(x, y)``
    value reach the inherited placeholder default.
    """
    @Case
    def These(x, y): pass

    # Inside this body, ``These`` is the case above, not the class.
    @map.case(These << (Var('x'), Var('y')))
    def map_these(self, x, y, f):
        return self.These(x, f(y))

    # Default body (returns None); all three cases are handled below.
    @casemethod
    def from_these(self, a, b): pass

    # Pair the payload with a default: (x, b) for This, (a, x) for That, (x, y) for These.
    @from_these.case(This << Var('x'))
    def from_this(self, x, a, b):
        return x, b

    @from_these.case(That << Var('x'))
    def from_that(self, x, a, b):
        return a, x

    @from_these.case(These << (Var('x'), Var('y')))
    def from_these_(self, x, y, a, b):
        return x, y
class List(ADT, Functor):
    """Singly linked cons list: ``Nil()`` or ``Cons(x, xs)``."""

    @Case
    def Nil(): pass

    @Case
    def Cons(x, xs): pass

    def map(self, f):
        """Return a new list with ``f`` applied to every element (recursive)."""
        return self.case({
            self.Nil << (): lambda: self.Nil(),
            self.Cons << (Var('x'), Var('xs')): lambda x, xs: self.Cons(f(x), xs.map(f)),
        })

    def sum(self):
        """Return the sum of the elements; ``0`` for ``Nil`` (recursive)."""
        return self.case({
            self.Nil << (): lambda: 0,
            self.Cons << (Var('x'), Var('xs')): lambda x, xs: x + xs.sum(),
        })

    @classmethod
    def from_list(cls, lst):
        """Build a list from any finite iterable, preserving element order.

        The input is materialised once, so sequences, generators and other
        iterables all work. Both constructors are taken from ``cls``.
        """
        return functools.reduce(lambda xs, x: cls.Cons(x, xs), reversed(list(lst)), cls.Nil())
# Example: lambda calculus interpreter
class Expr(ADT):
    """Untyped lambda calculus with Peano naturals, evaluated by ``eval(env)``.

    ``env`` maps variable names to ``Expr`` values. Application evaluates
    its argument before binding it. ``Zero``/``Succ`` evaluate to ``Val``.
    """
    # Variable reference; the name must be a valid Python identifier.
    @Case
    def Id(name):
        if not name.isidentifier():
            raise ValueError(f"{name!r} is not a valid identifier")

    # Application of ``f`` to ``x``.
    @Case
    def App(f, x): pass

    # Lambda ``x -> e`` with its captured environment ``closure`` (a dict).
    @Case
    def Lam(x, e, closure): pass

    # Peano naturals.
    @Case
    def Zero(): pass
    @Case
    def Succ(n): pass

    # Evaluated number. Inside this class body, ``Val`` is this case, not the
    # module-level ``Val`` pattern.
    @Case
    def Val(x): pass

    # Default body; a bare ``raise`` with no active exception raises RuntimeError.
    @casemethod
    def eval(self, env): raise

    @eval.case(Id << Var('name'))
    def eval_id(self, name, env):
        # Look the binding up and evaluate it in the current env. A KeyError
        # (unbound name) is reported as ValueError carrying this expression.
        try:
            return env[name].eval(env)
        except KeyError:
            raise ValueError(self) from None

    @eval.case(App << (Var('f'), Var('y')))
    def eval_app(self, f, y, env):
        # Evaluate the function position; it must produce a Lam, else ValueError(self).
        pat = View(lambda e: e.eval(env), self.Lam << (Var('x'), Var('e'), Var('c')))
        with match(f, pat, exc=ValueError(self)) as (x, e, closure):
            # Closure bindings override env; the evaluated argument overrides both.
            local_env = {**env, **closure, x: y.eval(env)}
            return e.eval(local_env)

    @eval.case(Lam << (Var('x'), Var('e'), Var('c')))
    def eval_lam(self, x, e, c, env):
        # Capture the current env; existing closure entries take precedence.
        return self.Lam(x, e, {**env, **c})

    @eval.case(Zero << ())
    def eval_zero(self, env):
        return self.Val(0)

    @eval.case(Succ << Var('n'))
    def eval_succ(self, n, env):
        # ``n`` must evaluate to Val(x); otherwise ValueError(n).
        with match(n, View(lambda e: e.eval(env), self.Val << Var('x')), exc=ValueError(n)) as x:
            return self.Val(x + 1)

    @eval.case(Val << Var('x'))
    def eval_val(self, x, env):
        # Already a value.
        return self
if __name__ == '__main__':
    # Maybe: explicit dispatch with ADT.case, then a case method (map).
    m = Maybe.Just(1)
    m.case({
        Maybe.Just << Var('x'): lambda x: print(x),
        Maybe.Nothing << (): lambda: print('nothing')
    })
    print(m.map(lambda x: x+9))
    # List: build from a Python list, sum it, map over it.
    lst = List.from_list([1,2,3])
    print(lst.sum())
    print(lst.map(lambda x: x ** 2))
    # Show the hierarchy after ADTMeta rewired Either beneath These.
    print(These.__mro__)
    print(Either.__mro__)
    # These: renamed inherited cases (This/That) plus the new These case.
    t1 = These.This('test')
    t2 = These.That(3)
    t3 = These.These('test', 3)
    for t in [t1, t2, t3]:
        print(t.map(lambda x: x * 2))
    # Either's Left case object was re-claimed by These (ADTMeta sets Case.cls),
    # so Either values are These instances and support from_these.
    print(isinstance(Either.Left(1), These))
    print(Either.Left(1).from_these('a', 'b'))
    # Church numerals in the Expr interpreter.
    # Nat = forall r. (r -> r) -> r -> r
    # z: zero; s: successor.
    z = Expr.Lam('f', Expr.Lam('x', Expr.Id('x'), {}), {})
    s = Expr.Lam('n', Expr.Lam('f', Expr.Lam('x', Expr.App(Expr.Id('f'), Expr.App(Expr.App(Expr.Id('n'), Expr.Id('f')), Expr.Id('x'))), {}), {}), {})
    # incr/to_int: convert a Church numeral to Val by applying Succ n times to Zero.
    incr = Expr.Lam('n', Expr.Succ(Expr.Id('n')), {})
    to_int = Expr.Lam('n', Expr.App(Expr.App(Expr.Id('n'), incr), Expr.Zero()), {})
    # add/mul: Church addition and multiplication.
    add = Expr.Lam('m', Expr.Lam('n', Expr.Lam('f', Expr.Lam('x', Expr.App(Expr.App(Expr.Id('m'), Expr.Id('f')), Expr.App(Expr.App(Expr.Id('n'), Expr.Id('f')), Expr.Id('x'))), {}), {}), {}), {})
    mul = Expr.Lam('m', Expr.Lam('n', Expr.Lam('f', Expr.Lam('x', Expr.App(Expr.App(Expr.Id('m'), Expr.App(Expr.Id('n'), Expr.Id('f'))), Expr.Id('x')), {}), {}), {}), {})
    # 2 = s(s z); 4 = 2*2; 6 = 2+4; 7 = s 6; 42 = 6*7.
    two = Expr.App(s, Expr.App(s, z))
    four = Expr.App(Expr.App(mul, two), two)
    six = Expr.App(Expr.App(add, two), four)
    seven = Expr.App(s, six)
    fourtytwo = Expr.App(Expr.App(mul, six), seven)
    print(Expr.App(to_int, fourtytwo).eval({}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment