-
-
Save aaronchall/eb562d0916617f8ada9636ec934c5adc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from abc import ABC, abstractmethod | |
from typing import Callable, Any | |
# Placeholder names for "some type a" / "some type b" in the Haskell-style
# signatures below; both are plain aliases of typing.Any.
AnyA = Any
AnyB = Any
class Functor(ABC):
    """Abstract base for mappable containers (e.g. lists, Maybe, IO...).

    Subclasses must provide ``fmap`` and ``__eq__``.
    """

    __slots__ = ()

    @abstractmethod
    def fmap(self: Functor, fn: Callable[[AnyA], AnyB]) -> Functor:
        """Map ``fn`` over the contents and return the resulting functor.

        Haskell spells this ``fmap :: (a -> b) -> f a -> f b``; here the
        functor is the receiver, so the first two arguments are swapped.
        """
        raise NotImplementedError('requires method for fmap')

    def __lt__(self, a):  # `<` stands in for Haskell's `<$`
        """Replace every location in the functor with the value ``a``.

        Optional convenience, mirroring Haskell's::

            (<$) :: a -> f b -> f a
            (<$) = fmap . const
        """
        def constant(_ignored):
            return a

        return self.fmap(constant)

    def _validate(self):
        """Assert both functor laws on this instance, then return ``self``.

        The composition check maps ``x * 2`` and ``x + 1`` over the
        contents, so the wrapped value(s) must support those operations.
        """
        def identity(value):
            return value

        def add_one(value):
            return value + 1

        def double(value):
            return value * 2

        # Law 1: mapping the identity function changes nothing.
        assert self.fmap(identity) == self

        # Law 2: mapping a composition equals composing the two maps.
        fused = lambda value: add_one(double(value))
        assert self.fmap(fused) == self.fmap(double).fmap(add_one)
        return self

    @abstractmethod
    def __eq__(self, other):
        """Equality is required so the laws in ``_validate`` can be checked."""
        raise NotImplementedError('must implement __eq__')
# Annotation alias for the return type of the function handed to Monad.bind.
# The name suggests "a value of type b wrapped in the same monad"; at runtime
# it is just another name for AnyB.
SameMonadAnyB = AnyB
class Monad(ABC):
    """A monad is an abstract data type.

    (https://www.haskell.org/onlinereport/haskell2010/haskellch7.html)

    Subclasses must implement ``bind`` (Haskell ``>>=``) and the classmethod
    ``return_`` (Haskell ``return``). Haskell's ``fail`` is being removed;
    do not use or implement it.
    """

    @classmethod
    @abstractmethod
    def return_(cls, a):
        """Wrap the plain value ``a`` in this monad (Haskell ``return``)."""
        raise NotImplementedError('must implement return_')

    @abstractmethod
    def bind(self, fn: Callable[[AnyA], SameMonadAnyB]) -> Monad:
        """Feed the wrapped value to ``fn``, which returns a monad (``>>=``)."""
        raise NotImplementedError('must implement bind')

    def __ge__(self, k):  # `>=` stands in for Haskell's `>>=`
        """Operator form of ``bind``: ``m >= k`` is ``m.bind(k)``."""
        return self.bind(k)

    def __gt__(self, x):  # `>` stands in for Haskell's `>>` (i.e. `*>`)
        """Sequence two monadic values, discarding this one's result."""
        return self.bind(lambda _: x)

    def _validate(self):
        r"""Assert the three monad laws and return ``self``.

        Laws checked::

            return a >>= k            =  k a
            m >>= return              =  m
            m >>= (\x -> k x >>= h)   =  (m >>= k) >>= h

        The probe functions compute ``a + a`` and ``a * a``, so the value
        wrapped by ``self`` must support both operations.
        """
        unit = type(self).return_

        def k(a) -> 'Monad[b]':
            return unit(a + a)  # assumes string or number or list

        def h(a) -> 'Monad[b]':
            return unit(a * a)  # assumes number type

        a = 1
        # Left identity.
        assert unit(a).bind(k) == k(a)
        # Right identity.
        assert self.bind(unit) == self
        # Associativity: the inner bind must happen *inside* the function
        # passed to the outer bind, otherwise both sides are the same
        # expression and the check proves nothing.
        assert self.bind(lambda x: k(x).bind(h)) == self.bind(k).bind(h)

        # Cooperate with another base class later in the MRO that also
        # defines _validate, but do not fail when there is none.
        parent_validate = getattr(super(), '_validate', None)
        if parent_validate is not None:
            parent_validate()
        return self
class Maybe(Functor, Monad):
    """a.k.a. Just, unless Nothing, which is a class member and instance of Maybe as well.

    Wraps a single value, stored on the instance as ``x``.
    """

    def __init__(self, x):
        """Store ``x`` as the wrapped value."""
        self.x = x

    @classmethod
    def return_(cls, x):
        """Wrap ``x`` in a new instance (Haskell ``return``).

        Implemented as a classmethod, matching the abstract declaration on
        ``Monad``, so that ``type(m).return_(value)`` works.
        """
        return cls(x)

    def fmap(self, f):
        """Apply ``f`` to the wrapped value and re-wrap the result."""
        return type(self)(f(self.x))

    def bind(self, k):
        """Pass the wrapped value to ``k`` and return what ``k`` returns.

        (Just x, see http://hackage.haskell.org/package/base-4.12.0.0/docs/src/GHC.Base.html#line-854)

        Raises:
            TypeError: if ``k`` does not return an instance of this type.
        """
        result = k(self.x)
        if not isinstance(result, type(self)):
            raise TypeError(f'k must return a value of type {type(self)}')
        return result

    def __eq__(self, other):
        """Equal when both are the same concrete type wrapping equal values."""
        if self is other:
            return True
        if not isinstance(other, Maybe):
            # Let Python fall back to its default handling instead of
            # failing on a missing ``x`` attribute.
            return NotImplemented
        # Comparing the concrete type keeps a subclass instance from
        # comparing equal to a plain Maybe that happens to wrap the same
        # value.
        return type(self) is type(other) and self.x == other.x

    def __gt__(self, other):  # >
        """
        Just _m1 *> m2 = m2
        Nothing *> _m2 = Nothing
        (>>) = (*>)
        """
        return other

    def __repr__(self):
        """Return e.g. ``Maybe(1)``, using the concrete class name."""
        return f'{type(self).__name__}({self.x!r})'
class Nothing(Maybe):
    """The empty Maybe: mapping, binding and sequencing all return it unchanged."""

    def _return_self(self, _):
        """Ignore the argument and hand back this same instance."""
        return self

    # Every operation below short-circuits to the instance itself.
    fmap = _return_self
    bind = _return_self
    __ge__ = _return_self
    __gt__ = _return_self
# Attach one Nothing instance (wrapping None) to Maybe as a class attribute,
# so it can be reached as ``Maybe.Nothing``.
Maybe.Nothing = Nothing(None)
# NOTE(review): the two expressions below discard their results; they look
# like leftover interactive checks -- confirm whether they can be removed.
Maybe(1)
isinstance(Maybe.Nothing, Maybe)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment