Skip to content

Instantly share code, notes, and snippets.

@ktbarrett
Created October 14, 2019 06:52
Show Gist options
  • Save ktbarrett/733b73d57bb7536c8a07ae948ab946c1 to your computer and use it in GitHub Desktop.
Save ktbarrett/733b73d57bb7536c8a07ae948ab946c1 to your computer and use it in GitHub Desktop.
monad stuff
from monad import Monad
class Computation(Monad):
    """Category of computations that can continue, finish early, or fail.

    The private ``__category`` attribute marks this class as the category
    (see the ``Monad`` docstring); ``Continue``, ``Finish`` and ``Failure``
    are its resident types.
    """

    __category = True

    @staticmethod
    def pure(value):
        """Wrap ``value`` in a ``Continue`` so more steps can be mapped over it."""
        return Continue(value)

    def map(self, f):
        """Apply ``f`` to the wrapped value unless the computation has ended.

        Args:
            f: Callable taking the wrapped value. It may return a plain value
                or a ``Computation``.

        Returns:
            ``self`` unchanged when this is a ``Finish`` or ``Failure``;
            a ``Failure`` holding the exception when ``f`` raises;
            the result of ``f`` itself when it is already a ``Computation``;
            otherwise the result wrapped in ``Continue``.
        """
        # A finished or failed computation short-circuits every later step.
        if isinstance(self, (Finish, Failure)):
            return self
        try:
            res = f(self.value)
        except Exception as e:
            # Capture the error as a Failure so that reading ``.value`` later
            # re-raises it; wrapping it in Finish would hand the exception
            # back as if it were an ordinary result.
            return Failure(e)
        if isinstance(res, Computation):
            return res
        return Continue(res)
class Continue(Computation):
    """Computation holding a value that further steps may still be mapped over."""

    def __init__(self, value):
        """Store ``value`` as the wrapped result."""
        self._value = value

    @property
    def value(self):
        """The wrapped value, read-only."""
        return self._value
class Finish(Continue):
    """Computation that ended early; ``Computation.map`` returns it unchanged."""
class Failure(Finish):
    """Finished computation whose stored value is an exception."""

    @property
    def value(self):
        """Raise the stored exception instead of returning a value.

        ``from None`` suppresses the implicit exception context.
        """
        error = self._value
        raise error from None
if __name__ == "__main__":

    def pipeline(x):
        """Run ``x`` through four chained ``map`` steps and return the result."""
        steps = (
            lambda v: v + 2,
            lambda v: v * 2,
            # Values below 5 end the pipeline early.
            lambda v: Finish(v) if v < 5 else v // 5,
            # Adding None raises TypeError for values above 10.
            lambda v: v + (None if v > 10 else v),
        )
        outcome = Computation(x)
        for step in steps:
            outcome = outcome.map(step)
        return outcome

    b = pipeline(-1)
    print(b.value)  # 2

    c = pipeline(3)
    print(c.value)  # 4

    d = pipeline(26)
    try:
        d.value
    except Exception as err:
        print(err)  # intended to show the TypeError from the last step
from abc import ABCMeta, abstractmethod
class MonadMeta(type):
    """Metaclass that routes calls on a category class to its ``pure``.

    A class is a category when its own body defines a private ``__category``
    attribute. Calling a category (``Category(x)``) returns
    ``Category.pure(x)``; calling any other class constructs an instance in
    the normal way. Subclasses of a category do not count as categories,
    because the mangled attribute name embeds the defining class's name.
    """

    # Unique sentinel meaning "no argument supplied". A fresh object() is used
    # because CPython shares a single empty-tuple object, so a ``()`` sentinel
    # cannot be told apart from a caller explicitly passing ``()``.
    _noarg = object()

    def __call__(cls, value=_noarg):
        """Dispatch ``cls(...)`` to ``cls.pure`` or to normal construction.

        Args:
            value: Optional single argument forwarded to ``pure`` or to the
                regular constructor. When omitted, the target is called with
                no arguments.

        Returns:
            Whatever ``cls.pure`` returns for a category class, otherwise a
            new instance of ``cls``.
        """
        args = () if value is MonadMeta._noarg else (value,)
        # Mirror Python's private-name mangling, which drops leading
        # underscores from the class name before prefixing the attribute.
        owner = cls.__name__.lstrip("_")
        if hasattr(cls, f"_{owner}__category"):
            return cls.pure(*args)
        return super().__call__(*args)
class ABCMonadMeta(MonadMeta, ABCMeta):
    """Metaclass that inherits from both ``MonadMeta`` and ``ABCMeta``."""
class Monad(metaclass=ABCMonadMeta):
    """
    Abstract interface every monad implements.

    To define a monad, subclass this class and implement the methods below.
    Mark that subclass as a category by giving it a `__category` attribute,
    then subclass the category once for each resident type.
    """

    @staticmethod
    @abstractmethod
    def pure(value):
        """
        Lift a plain value into one of the category's resident types.
        """
        raise NotImplementedError

    @abstractmethod
    def map(self, f):
        """
        Apply a function to a monadic value; serves as both map and flat_map.

        Python has no static types to tell the two apart, so implementations
        must inspect what `f` returns: a result that is already monadic is
        returned as is, anything else is wrapped into the category.
        """
        raise NotImplementedError
# Example Monad
class Maybe(Monad):
    """Category of optional values; resident types are ``Just`` and ``Nothing``."""

    __category = True

    @staticmethod
    def pure(value):
        """Return ``Nothing()`` for ``None`` and ``Just(value)`` for anything else."""
        return Just(value) if value is not None else Nothing()

    def map(self, f):
        """Apply ``f`` to the wrapped value; a ``Nothing`` is returned unchanged.

        A result that is already a ``Maybe`` is returned as is; any other
        result is lifted through ``pure``.
        """
        if isinstance(self, Nothing):
            return self
        outcome = f(self.value)
        return outcome if isinstance(outcome, Maybe) else self.pure(outcome)
class Just(Maybe):
    """A ``Maybe`` that holds a value."""

    def __init__(self, value):
        """Store ``value`` as the wrapped value."""
        self._value = value

    @property
    def value(self):
        """The wrapped value, read-only."""
        return self._value

    def __repr__(self):
        """Return ``Just(<repr of value>)``.

        The wrapped value is rendered with ``repr`` so that, for example,
        ``Just('1')`` and ``Just(1)`` are distinguishable in output.
        """
        return f"Just({self.value!r})"
class Nothing(Maybe):
    """A ``Maybe`` that holds no value."""

    def __repr__(self):
        """Return the fixed text ``Nothing()``."""
        return "Nothing()"

    @property
    def value(self):
        """Always ``None``."""
        return None
# Example Program
if __name__ == "__main__":

    def lower5(x):
        """Return ``Nothing()`` for values below 5, otherwise ``x`` itself."""
        return Nothing() if x < 5 else x

    a = Just(1)
    print(a)  # Just(1)

    a = a.map(lambda n: n + 2)
    print(a)  # Just(3)

    b = a.map(lower5)
    print(b)  # Nothing()

    # The mapped result is discarded; b itself is printed again.
    b.map(lambda n: n * 3)
    print(b)  # Nothing()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment