Skip to content

Instantly share code, notes, and snippets.

@sug0
Last active January 25, 2018 04:16
Show Gist options
  • Save sug0/339e516bdebb66f32a3b8879651fa6ea to your computer and use it in GitHub Desktop.
Save sug0/339e516bdebb66f32a3b8879651fa6ea to your computer and use it in GitHub Desktop.
Monads in Python
from functools import reduce as _reduce
from random import random as _random
from sys import version_info as _version, stdout as _stdout, stdin as _stdin
# Choose the line-reading builtin for the running interpreter: Python 2 names
# it raw_input, Python 3 names it input.
_input = raw_input if _version[0] < 3 else input


# Empty marker class; it carries no behaviour of its own.
class Traversable(object):
    pass
# base monad
class Monad(object):
    """Base wrapper around a single value.

    ``bind`` treats a wrapped ``None`` as "no value" and skips the bound
    function in that case.
    """

    def __init__(self, value):
        self._value = value

    def __repr__(self):
        return 'Monad %s' % repr(self._value)

    def __str__(self):
        return repr(self)

    def __iter__(self):
        """Return ``self.traverse()`` for instances that are Traversable.

        Raises:
            ValueError: if this instance is not a ``Traversable``.
        """
        if not isinstance(self, Traversable):
            # Name the concrete class in the message; type(self).__class__
            # would be the metaclass, not the class of this instance.
            raise ValueError('%s is not traversable' % type(self).__name__)
        return self.traverse()

    def unwrap(self):
        """Return the wrapped value."""
        return self._value

    def bind(self, f):
        """Apply ``f`` to the wrapped value and return the monad it produces.

        A wrapped ``None`` short-circuits: ``f`` is not called and ``self``
        is returned.

        Raises:
            ValueError: if ``f`` returns an object whose exact type differs
                from ``type(self)``.
        """
        # Identity test, so a value with a permissive __eq__ is not mistaken
        # for "no value".
        if self._value is None:
            return self
        val = f(self._value)
        if type(val) is not type(self):
            name = type(self).__name__
            raise ValueError(
                'expected %s monad return value from function' % name)
        return val

    def value_type(self):
        """Return the type of the wrapped value."""
        return type(self._value)
# maybe monad
class Maybe(Monad, Traversable):
    """Optional value: a wrapped ``None`` is Nothing, anything else is Just."""

    def __repr__(self):
        if self._value is None:
            return 'Nothing'
        return 'Just %s' % repr(self._value)

    def traverse(self):
        """Yield the wrapped value once; yield nothing when it is ``None``."""
        # The generator simply ends for Nothing. Raising StopIteration inside
        # a generator is turned into RuntimeError by PEP 479 (Python 3.7+).
        if self._value is not None:
            yield self._value
def Just(x):
    """Wrap ``x`` in a Maybe."""
    # A def rather than a named lambda, so the function reports 'Just' as its
    # __name__ instead of '<lambda>'.
    return Maybe(x)


# Module-level Maybe wrapping None.
Nothing = Maybe(None)
# either monad
class Either(Monad, Traversable):
    """Two-sided value: a Right payload or a Left payload.

    ``_value`` is only a flag here: True for a Right, None for a Left. The
    payload is stored in ``_right`` or ``_left``; only the attribute for the
    chosen side is ever set.
    """

    def __init__(self, value, is_left=False):
        self._value = None if is_left else True
        if is_left:
            self._left = value
        else:
            self._right = value

    def __repr__(self):
        if not self._value:
            return 'Left %s' % repr(self._left)
        return 'Right %s' % repr(self._right)

    def traverse(self):
        """Yield the Right payload once; yield nothing for a Left."""
        # The generator simply ends for a Left. Raising StopIteration inside
        # a generator is turned into RuntimeError by PEP 479 (Python 3.7+).
        if self._value is not None:
            yield self._right

    def bind(self, f):
        """Apply ``f`` to the Right payload; a Left is returned unchanged.

        Raises:
            ValueError: if ``f`` returns an object whose exact type differs
                from ``type(self)``.
        """
        if not self._value:
            return self
        val = f(self._right)
        if type(val) is not type(self):
            raise ValueError('expected Either monad return value from function')
        return val

    def unwrap(self):
        """Return the payload of whichever side this instance holds."""
        if not self._value:
            return self._left
        return self._right

    def value_type(self):
        """Return the type of the payload of whichever side is held."""
        if not self._value:
            return type(self._left)
        return type(self._right)

    def is_left(self):
        """Return True for a Left, False for a Right."""
        return not self._value
# Plain defs rather than named lambdas, so each constructor reports its own
# __name__ instead of '<lambda>'.
def Left(x):
    """Build an Either holding ``x`` on the Left side."""
    return Either(x, is_left=True)


def Right(x):
    """Build an Either holding ``x`` on the Right side."""
    return Either(x, is_left=False)
# io monad
class IO(Monad):
    """Holds either a plain value or a callable plus the arguments for it.

    When the stored value is callable, ``unwrap`` calls it with the stored
    arguments; otherwise ``unwrap`` returns the value as is. ``bind`` is not
    overridden, so it passes the stored value itself (for a callable, the
    callable, not its result) to the bound function.
    """

    def __init__(self, value, *args):
        self._value = value
        self._args = args
        self._call = callable(value)

    def __repr__(self):
        if not self._call:
            return 'IO %s' % repr(self._value)
        # Not every callable has __name__ (functools.partial objects and
        # callable instances do not); fall back to repr for those.
        name = getattr(self._value, '__name__', None)
        if name is None:
            name = repr(self._value)
        if not self._args:
            return 'IO %s' % name
        if len(self._args) == 1:
            return 'IO %s(%s)' % (name, repr(self._args[0]))
        return 'IO %s%s' % (name, repr(self._args))

    def unwrap(self):
        """Run the stored callable with its arguments, or return the value."""
        if self._call:
            return self._value(*self._args)
        return self._value
# helper functions: return/lift, sample computations, IO actions, combinators
def mreturn(value=None, type=Monad):
    """Wrap ``value`` by calling ``type`` on it (``Monad`` by default)."""
    wrapped = type(value)
    return wrapped
def liftM(f, type=Monad):
    """Lift the one-argument function ``f`` to take a monad.

    The returned function binds its argument to a step that applies ``f`` to
    the unwrapped value and wraps the result with ``mreturn(..., type)``.
    Its ``__name__`` is ``f``'s name followed by ``_liftM``.
    """
    def new_func(m1):
        def wrap_result(x1):
            return mreturn(f(x1), type)
        return m1.bind(wrap_result)

    new_func.__name__ = '%s_liftM' % f.__name__
    return new_func
def liftM2(f, type=Monad):
    """Lift the two-argument function ``f`` to take two monads.

    The returned function binds ``m1`` then ``m2`` and wraps
    ``f(x1, x2)`` with ``mreturn(..., type)``. Its ``__name__`` is ``f``'s
    name followed by ``_liftM2``.
    """
    def new_func(m1, m2):
        def with_first(x1):
            def with_second(x2):
                return mreturn(f(x1, x2), type)
            return m2.bind(with_second)
        return m1.bind(with_first)

    new_func.__name__ = '%s_liftM2' % f.__name__
    return new_func
def liftM3(f, type=Monad):
    """Lift the three-argument function ``f`` to take three monads.

    The returned function binds ``m1``, ``m2`` and ``m3`` in that order and
    wraps ``f(x1, x2, x3)`` with ``mreturn(..., type)``. Its ``__name__`` is
    ``f``'s name followed by ``_liftM3``.
    """
    def new_func(m1, m2, m3):
        def with_first(x1):
            def with_second(x2):
                def with_third(x3):
                    return mreturn(f(x1, x2, x3), type)
                return m3.bind(with_third)
            return m2.bind(with_second)
        return m1.bind(with_first)

    new_func.__name__ = '%s_liftM3' % f.__name__
    return new_func
def division(x, y):
    """Divide ``x`` by ``y`` as floats; return ``Nothing`` when ``y == 0``."""
    return Nothing if y == 0 else Just(x / float(y))
def division2(x, y):
    """Divide ``x`` by ``y`` as floats; return a Left message when ``y == 0``."""
    return Left("can't divide by 0") if y == 0 else Right(x / float(y))
def mtry(f):
    """Wrap ``f`` so that it returns an Either instead of raising.

    The wrapper returns ``Right(result)`` when the call succeeds and
    ``Left(exception)`` when an ``Exception`` is raised. Positional and
    keyword arguments are both forwarded to ``f``.
    """
    def mtry_fail(*args, **kwargs):
        # Deliberately broad: turning any Exception into a Left is the
        # purpose of this wrapper.
        try:
            return Right(f(*args, **kwargs))
        except Exception as e:
            return Left(e)
    return mtry_fail
def cat_maybes(iter):
    """Return the unwrapped values of ``iter``, dropping the empty ones.

    An element counts as empty when it unwraps to ``None``. Testing the
    unwrapped value, rather than comparing the element with one particular
    empty instance, also drops empty wrappers that were built separately.
    """
    unwrapped = (x.unwrap() for x in iter)
    return [value for value in unwrapped if value is not None]
def _print(x):
    """Print ``x``; a named function that forwards to ``print``."""
    print(x)
def put_ln(x):
    """Return an IO holding ``_print`` and ``x``; nothing is printed here."""
    deferred = IO(_print, x)
    return deferred
def read(size=-1, handle=_stdin):
    """Read from ``handle`` now and return the data wrapped in an IO.

    ``handle.read(size)`` is called when this function is called; the
    returned IO only holds the result.
    """
    data = handle.read(size)
    return IO(data)
def write(x, handle=_stdout):
    """Return an IO holding ``handle.write`` and ``str(x)``.

    Nothing is written here; ``x`` is converted to a string when this
    function is called.
    """
    writer = handle.write
    text = str(x)
    return IO(writer, text)
def io_random():
    """Draw a random number now and return it wrapped in an IO."""
    return IO(_random())
def getline():
    """Read one line via ``_input`` now and return it wrapped in an IO."""
    line = _input()
    return IO(line)
def unsafe_perform_io(io):
    """Return whatever ``io.unwrap()`` returns."""
    outcome = io.unwrap()
    return outcome
def sequence(io):
    """Run every action in ``io`` in order and return an IO of the results.

    The results are collected into a list. ``map`` is lazy on Python 3, so
    it would postpone the actions and leave the IO holding a one-shot
    iterator; the list runs them here on every Python version.
    """
    return IO([unsafe_perform_io(action) for action in io])
def bind(x, f):
    """Function form of ``x.bind(f)``."""
    binder = x.bind
    return binder(f)
def forM(m, f, type=Monad):
    """Apply ``f`` to each item of ``m`` and wrap the result list in ``type``.

    Raises:
        ValueError: if ``m`` is not a ``Monad`` instance.
    """
    if not isinstance(m, Monad):
        raise ValueError('not a monad')
    return type([f(item) for item in m])
def forM_(m, f):
    """Call ``f`` on each item of ``m``, discarding the results.

    Raises:
        ValueError: if ``m`` is not a ``Monad`` instance.
    """
    if isinstance(m, Monad):
        for item in m:
            f(item)
    else:
        raise ValueError('not a monad')
def apply(*args):
    """Capture ``args`` and return a function that calls its argument on them."""
    def some_function(f):
        outcome = f(*args)
        return outcome
    return some_function
def _compose(f, g):
    """Return the function that maps ``x`` to ``f(g(x))``."""
    def __compose(x):
        return f(g(x))
    return __compose


def compose(*funcs):
    """Compose ``funcs`` so the last one is applied first.

    ``compose(f, g, h)(x)`` is ``f(g(h(x)))``. With a single function, that
    function is returned. With no functions, the identity function is
    returned, because reducing an empty sequence without an initial value
    would raise TypeError.
    """
    if not funcs:
        return lambda x: x
    return _reduce(_compose, funcs)
# quick self-test, run only when the module is executed as a script
if __name__ == '__main__':
    # Demo 1: draw ten random numbers and print those above 0.5.
    print('Testing random numbers:')
    # range instead of xrange: xrange does not exist on Python 3.
    printed = (
        sequence([io_random() for _ in range(10)])
        .bind(lambda xs: IO([put_ln(x) for x in xs if x > 0.5]))
        .bind(lambda xs: sequence(xs))
    )
    # Drain the result so the print actions have certainly run, even if
    # sequence() handed back a lazy iterator (map() is lazy on Python 3).
    for _ in unsafe_perform_io(printed):
        pass

    # Demo 2: add Maybe values with a lifted "+"; the Nothing operands
    # short-circuit the reduction.
    print('\nTesting lifting functions:')
    plus = liftM2(lambda x, y: x + y, type=Maybe)
    nums = [Nothing, Just(3), Nothing]
    print(nums)
    print(_reduce(plus, nums))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment