Skip to content

Instantly share code, notes, and snippets.

@SegFaultAX
Created December 17, 2018 08:48
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SegFaultAX/1196f1522b672959debc882bf7e290df to your computer and use it in GitHub Desktop.
Save SegFaultAX/1196f1522b672959debc882bf7e290df to your computer and use it in GitHub Desktop.
Simple Free Monad [Python]
import dataclasses as dc
import typing as ty
import inspect
import functools
# Type variables used by the Free ADT: S is the instruction functor,
# A and B are result types.
S, A, B = ty.TypeVar("S"), ty.TypeVar("A"), ty.TypeVar("B")
@dc.dataclass(frozen=True)
class Monad:
    """Immutable bundle of the three operations that define a monad instance."""

    # fmap(fn, m): map a plain function over a monadic value.
    fmap: ty.Callable
    # pure(a): lift a plain value into the monad.
    pure: ty.Callable
    # bind(m, fn): sequence m with a continuation returning a monadic value.
    bind: ty.Callable
@dc.dataclass(frozen=True)
class Free(ty.Generic[S, A]):
    """Abstract base of a free-monad program over instructions S producing A.

    Concrete nodes are ``Pure``, ``Suspend`` and ``FlatMap``.
    """
@dc.dataclass(frozen=True)
class Pure(Free[S, A]):
    """Program node that has already finished with result ``a``."""

    # The final result value.
    a: A
@dc.dataclass(frozen=True)
class Suspend(Free[S, A]):
    """Program node wrapping a single instruction ``k`` for an interpreter."""

    # Instruction value, conceptually of type S[Free[S, A]]; Python typing
    # cannot express the higher-kinded application, so it is annotated as S.
    k: S
@dc.dataclass(frozen=True)
class FlatMap(Free[S, A]):
    """Program node: run ``v``, then pass its result to the continuation ``f``."""

    # The program whose result is handed to ``f``.
    v: Free[S, A]
    # Continuation building the rest of the program from that result.
    # NOTE(review): the node as a whole yields B, yet the class is declared as
    # Free[S, A]; this only affects annotations, not runtime behaviour.
    f: ty.Callable[[A], Free[S, B]]
def fmap(fn, free):
    """Map the plain function ``fn`` over the eventual result of ``free``.

    Nothing is evaluated here; a ``FlatMap`` node is returned whose
    continuation applies ``fn`` and lifts the result back into the free monad.
    """
    # The continuation of a FlatMap must return a Free program, so the mapped
    # value is wrapped with ``pure`` (the previous ``puref`` was undefined and
    # raised NameError as soon as the continuation ran).
    return FlatMap(free, lambda x: pure(fn(x)))
def pure(a):
    """Lift the plain value ``a`` into the free monad as a finished program."""
    return Pure(a=a)
def suspend(k):
    """Wrap the instruction ``k`` as a one-step program for an interpreter."""
    return Suspend(k=k)
def bind(free, fn):
    """Sequence ``free`` with continuation ``fn`` (lazily, as a FlatMap node)."""
    return FlatMap(v=free, f=fn)
# Monad instance for Free: every operation only builds program nodes
# (Pure / FlatMap); nothing is executed until the program is interpreted.
MonadFree = Monad(fmap=fmap, pure=pure, bind=bind)
def match(free, if_pure, if_suspend, if_flatmap):
    """Dispatch on the node type of ``free`` and return the chosen handler's result.

    ``if_pure`` handles ``Pure`` and ``if_suspend`` handles ``Suspend``. Every
    other value, including anything that is not a ``Free`` at all, goes to
    ``if_flatmap``. The selected handler receives the node itself.
    """
    if isinstance(free, Pure):
        branch = if_pure
    elif isinstance(free, Suspend):
        branch = if_suspend
    else:
        branch = if_flatmap
    return branch(free)
def step(free):
    """Normalise ``free`` until its outermost node can be interpreted directly.

    While the current node is a ``FlatMap``:
      * ``FlatMap(Pure(a), f)`` is reduced to ``f(a)``;
      * ``FlatMap(FlatMap(m, g), h)`` is re-associated into
        ``FlatMap(m, lambda x: FlatMap(g(x), h))`` so binds nest to the right.

    Returns the first node that is not a ``FlatMap``, or a ``FlatMap`` whose
    inner program is neither ``Pure`` nor ``FlatMap``.
    """
    root = free
    while isinstance(root, FlatMap):
        inner = root.v
        if isinstance(inner, Pure):
            root = root.f(inner.a)
        elif isinstance(inner, FlatMap):
            # Capture g and h *now* via default arguments. A plain closure
            # over ``inner``/``root`` would read those names after they are
            # reassigned (``root`` on this very line), making the continuation
            # call itself instead of the outer ``h``.
            root = bind(inner.v, lambda x, g=inner.f, h=root.f: bind(g(x), h))
        else:
            break
    return root
def foldmap(free, natural, monad, tailrec):
    """Interpret the program ``free`` into a target monad.

    Args:
        free: the ``Free`` program to run.
        natural: maps one instruction (``Suspend.k``) to a value of the
            target monad.
        monad: the target ``Monad``; its ``pure`` and ``bind`` combine
            results, so its short-circuiting rules (e.g. ``MonadNullable``
            stopping on None) are respected.
        tailrec: loop driver, called as ``tailrec(free, run1)``. ``run1``
            returns ``(done, value)``: when ``done`` is true, ``value`` is the
            final target-monad value; otherwise ``value`` is the target-monad
            value holding the next program to run (the contract that
            ``tailrec_nullable`` implements).

    Returns:
        Whatever ``tailrec`` returns.

    Note: the inner program of each ``FlatMap`` is interpreted by a nested
    ``foldmap`` call, so deeply left-nested binds consume Python stack.
    """
    def run1(node):
        return match(
            node,
            lambda p: (True, monad.pure(p.a)),
            lambda s: (True, natural(s.k)),
            # Go through the target monad's bind rather than calling the
            # continuation directly, so e.g. a None result stops the program
            # under MonadNullable instead of being fed to the continuation.
            lambda fm: (
                False,
                monad.bind(foldmap(fm.v, natural, monad, tailrec), fm.f),
            ),
        )

    return tailrec(free, run1)
def do(monad, inst=lambda e: True):
    """Build a decorator giving generator functions "do-notation" over ``monad``.

    Each value the decorated generator yields is sequenced with
    ``monad.bind``; the value the continuation receives is sent back into the
    generator. When the generator returns, its return value is used as-is if
    ``inst(value)`` is true, otherwise it is wrapped with ``monad.pure``.

    Args:
        monad: object exposing ``bind(m, fn)`` and ``pure(a)``.
        inst: predicate telling whether a return value is already monadic.
            The default treats every return value as monadic.

    Returns:
        A decorator. A decorated function that does not return a generator
        has its return value passed through unchanged.
    """
    def binder(gen):
        def step(value):
            try:
                result = gen.send(value)
            except StopIteration as e:
                return e.value if inst(e.value) else monad.pure(e.value)
            # Bind outside the ``try``: a StopIteration escaping from
            # ``monad.bind`` must not be mistaken for this generator finishing.
            return monad.bind(result, step)
        return step

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            gen = fn(*args, **kwargs)
            if not inspect.isgenerator(gen):
                return gen
            return binder(gen)(None)
        return wrapper

    return decorator
def free(fn):
    """Decorator: turn a generator function into a builder of Free programs.

    Yielded values are sequenced with ``MonadFree``. A returned value that is
    already a ``Free`` is used as-is; any other return value is wrapped with
    ``pure``. The decorated function keeps ``fn``'s name and docstring.
    """
    def _returns_free(value):
        return isinstance(value, Free)

    decorated = do(MonadFree, _returns_free)(fn)
    return functools.wraps(fn)(decorated)
### Example ###
@dc.dataclass(frozen=True)
class ReadLine:
    """Console instruction: show ``prompt`` and read a line of input."""

    # Text shown to the user before reading.
    prompt: str
def readln(prompt):
    """Build a one-step program holding a ``ReadLine(prompt)`` instruction."""
    instruction = ReadLine(prompt)
    return suspend(instruction)
@dc.dataclass(frozen=True)
class PrintLine:
    """Console instruction: output ``line``."""

    # Text to be printed.
    line: str
def println(line):
    """Build a one-step program holding a ``PrintLine(line)`` instruction."""
    instruction = PrintLine(line)
    return suspend(instruction)
# "Nullable" monad over plain Python values: None means "no result" and
# short-circuits fmap/bind; any other value is passed to the function.
MonadNullable = Monad(
    fmap=lambda f, e: None if e is None else f(e),
    pure=lambda e: e,
    bind=lambda e, f: None if e is None else f(e),
)
def tailrec_nullable(val, step):
    """Iterate ``step`` from ``val`` until it reports completion.

    ``step(x)`` must return ``(done, next_value)``. It is always called at
    least once. As soon as a step produces ``None`` the loop stops and returns
    None; otherwise the value from the first step with ``done`` true is
    returned.
    """
    current = val
    while True:
        finished, current = step(current)
        if current is None:
            return None
        if finished:
            return current
def handler(cmd):
    """Interpret one console instruction with real I/O.

    ``ReadLine`` prompts via ``input`` and returns the line read.
    ``PrintLine`` prints its text and returns ``()``.
    Any other value yields None.
    """
    if isinstance(cmd, PrintLine):
        print(cmd.line)
        return ()
    if isinstance(cmd, ReadLine):
        return input(cmd.prompt)
    return None
@free
def program1():
    """Ask for a name and an age, echo them back, and return ``(name, age)``.

    Calling this only builds a Free program of ReadLine/PrintLine
    instructions; an interpreter such as ``foldmap`` performs the I/O.
    """
    name = yield readln("What is your name? ")
    age = yield readln("What is your age? ")
    _ = yield println(f"Your name is {name} and you are {age} years old!")
    answer = (name, age)
    return answer
if __name__ == "__main__":
    # Run the demo with real console I/O only when executed as a script, so
    # importing this module does not block waiting on input().
    print(foldmap(program1(), handler, MonadNullable, tailrec_nullable))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment