Last active
January 10, 2021 14:36
-
-
Save sjoerdvisscher/a56a286ccfabce40e424 to your computer and use it in GitHub Desktop.
The mother of all monads to the rescue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// See http://blog.sigfpe.com/2008/12/mother-of-all-monads.html | |
// MARK: function composition | |
/// Left-to-right function composition: `(f >>> g)(x)` is `g(f(x))`.
/// Declared left-associative, so `f >>> g >>> h` groups as `(f >>> g) >>> h`.
infix operator >>> { associativity left }
func >>> <A, B, C>(f: A -> B, g: B -> C) -> A -> C {
    return { (input: A) -> C in
        let middle = f(input)
        return g(middle)
    }
}
// MARK: Continuation monad | |
/// A computation in continuation-passing style.
/// `start` takes a continuation that consumes an `A` and yields a final
/// answer of type `R`, and returns that final answer.
struct Cont<R, A> {
    let start: ((A) -> R) -> R
}
/// Wraps a plain value: the resulting computation hands `a` straight to its continuation.
func lift<R, A>(a: A) -> Cont<R, A> {
    return Cont(start: { (k: A -> R) -> R in k(a) })
}
/// Maps `f` over the value produced by `fa` (functor map).
/// The continuation receives `f(a)` instead of `a`.
func lift<R, A, B>(fa: Cont<R, A>, f: A -> B) -> Cont<R, B> {
    return Cont(start: { (k: B -> R) -> R in
        fa.start { (a: A) -> R in k(f(a)) }
    })
}
/// Combines the values of two computations with `f`.
/// `fa` is started first; `fb` is started from inside `fa`'s continuation,
/// and the innermost continuation receives `f(a, b)`.
func lift<R, A, B, C>(fa: Cont<R, A>, fb: Cont<R, B>, f: (A, B) -> C) -> Cont<R, C> {
    return Cont(start: { (k: C -> R) -> R in
        fa.start { (a: A) -> R in
            fb.start { (b: B) -> R in k(f(a, b)) }
        }
    })
}
/// Combines the values of three computations with `f`.
/// The computations are started in the order `fa`, `fb`, `fc`; each is nested
/// inside the previous one's continuation.
func lift<R, A, B, C, D>(fa: Cont<R, A>, fb: Cont<R, B>, fc: Cont<R, C>, f: (A, B, C) -> D) -> Cont<R, D> {
    return Cont(start: { (k: D -> R) -> R in
        fa.start { (a: A) -> R in
            fb.start { (b: B) -> R in
                fc.start { (c: C) -> R in k(f(a, b, c)) }
            }
        }
    })
}
// traverse can't be generic in the monad, so we specialise to the continuation monad | |
// and then emulate the other monads | |
/// A container whose elements can each be replaced by the result of a `Cont`
/// computation, giving a `Cont` that produces the rebuilt container.
/// `El` is the element type.
protocol Traversable {
    typealias El
    func traverse<R>(f: El -> Cont<R, El>) -> Cont<R, Self>
}
extension Array: Traversable {
    /// Applies `f` to each element from first to last. The results are collected,
    /// in the same order, into a new array inside `Cont`.
    /// Each step builds a fresh `prefix + [next]` array, copying the
    /// accumulated prefix every time.
    func traverse<R, S>(f: T -> Cont<R, S>) -> Cont<R, Array<S>> {
        var acc: Cont<R, [S]> = lift([])
        for element in self {
            acc = lift(acc, f(element)) { (prefix: [S], next: S) in prefix + [next] }
        }
        return acc
    }
}
extension String: Traversable {
    /// Applies `f` to each character from first to last. The resulting characters
    /// are concatenated into a new string inside `Cont`.
    func traverse<R>(f: Character -> Cont<R, Character>) -> Cont<R, String> {
        var built: Cont<R, String> = lift("")
        for ch in self {
            built = lift(built, f(ch)) { (prefix: String, next: Character) -> String in
                prefix + String(next)
            }
        }
        return built
    }
}
// MARK: emulate Optional monad | |
/// Embeds an optional into `Cont` to emulate the Optional monad.
/// A present value is passed to the continuation. An absent one short-circuits:
/// the continuation is never called and the final answer is `nil`.
func i<A, B>(ma: A?) -> Cont<B?, A> {
    return Cont { k in
        switch ma {
        case .Some(let a):
            return k(a)
        case .None:
            return nil
        }
    }
}
/// Runs an Optional-emulating computation. The final continuation wraps the
/// produced value as present.
func run<A>(m : Cont<A?, A>) -> A? {
    return m.start { (a: A) -> A? in a }
}
/// Applies `f` to every element.
/// Returns the rebuilt container when `f` yields a value for each element, and
/// `nil` as soon as `f` yields `nil` for any element.
func all<T: Traversable>(t: T, f: T.El -> T.El?) -> T? {
    let traversal = t.traverse { (x: T.El) -> Cont<T?, T.El> in i(f(x)) }
    return run(traversal)
}
// Keep only values below a limit. With limit 3, some element fails, so the
// result is nil. With limit 5, every element passes.
for limit in [3, 5] {
    println(all([1, 2, 3, 4]) { $0 < limit ? $0 : nil })
}
// MARK: emulate List monad | |
/// Embeds an array of alternatives into `Cont` to emulate the List monad.
/// The continuation is invoked once per alternative, and the arrays it returns
/// are concatenated in order.
func i<A, B>(arr: [A]) -> Cont<[B], A> {
    return Cont { k in
        var results: [B] = []
        for a in arr {
            results += k(a)
        }
        return results
    }
}
/// Runs a List-emulating computation. Each final value is wrapped in a
/// one-element array.
func run<A>(m: Cont<[A], A>) -> [A] {
    return m.start { (a: A) -> [A] in [a] }
}
/// Returns every container obtainable by replacing each element with one of
/// the alternatives `f` gives for it.
func vary<T: Traversable>(t: T, f: T.El -> [T.El]) -> [T] {
    let traversal = t.traverse { (x: T.El) -> Cont<[T], T.El> in i(f(x)) }
    return run(traversal)
}
// Each character either becomes "." or stays as it is, giving every such variant of "bla".
println(vary("bla", { (c: Character) -> [Character] in [".", c] }))
// MARK: emulate State monad | |
/// Embeds a state transition `S -> (S, A)` into `Cont` to emulate the State monad.
/// The final answer is a function of the incoming state. It runs `ma` on that
/// state, then passes the produced value and the updated state to the continuation.
func i<A, B, S>(ma: S -> (S, A)) -> Cont<S -> (S, B), A> {
    return Cont { k in
        return { s0 in
            // The destructured pair is never mutated, so bind it with `let`.
            let (s1, a) = ma(s0)
            return k(a)(s1)
        }
    }
}
/// Runs a State-emulating computation from initial state `s0`.
/// Returns the final state paired with the result.
func run<A, S>(m : Cont<S -> (S, A), A>, s0: S) -> (S, A) {
    let fromState = m.start { (a: A) -> (S -> (S, A)) in
        return { (s: S) -> (S, A) in (s, a) }
    }
    return fromState(s0)
}
/// Maps over the elements while threading an accumulator.
/// For each element, `f` receives the current state and the element, and
/// returns the next state and the replacement element. The result pairs the
/// final state with the rebuilt container.
/// NOTE(review): for the Array and String conformances, the state flows
/// first-to-last, which is Haskell's `mapAccumL` order rather than
/// `mapAccumR`'s. For example, the Array demo that follows yields
/// ("321", [11, 12, 13]). The name is kept so callers do not break.
func mapAccumR<T: Traversable, S>(t: T, s0: S, f: (S, T.El) -> (S, T.El)) -> (S, T) {
    let traversal = t.traverse { (x: T.El) -> Cont<S -> (S, T), T.El> in
        i { (s: S) -> (S, T.El) in f(s, x) }
    }
    return run(traversal, s0)
}
// Prepend each number's digits to the accumulator and add 10 to each element.
println(mapAccumR([1, 2, 3], "", { (acc, n) in (n.description + acc, n + 10) }))
// MARK: Monoids | |
/// A type with a binary combining operation spelled `+` (intended to be associative).
protocol Semigroup {
    func +(lhs: Self, rhs: Self) -> Self
}
/// A `Semigroup` with a `zero()` element, expected to act as the identity for `+`.
protocol Monoid : Semigroup {
    // NOTE(review): later Swift versions spell this requirement `static func`.
    class func zero() -> Self
}
/// Integers under addition, with identity 0.
extension Int: Monoid {
    static func zero() -> Int {
        return 0
    }
}
/// Strings under concatenation, with identity "".
extension String: Monoid {
    static func zero() -> String {
        return ""
    }
}
/// Arrays under concatenation, with identity [].
extension Array: Monoid {
    static func zero() -> Array {
        return []
    }
}
// MARK: emulate Writer monad | |
/// Embeds a value `a` tagged with the output `m1` into `Cont` to emulate the Writer monad.
/// `m1` is combined, on the left via `+`, with whatever output the rest of the
/// computation produces.
func i<A, B, M:Monoid>(m1: M, a: A) -> Cont<(M, B), A> {
    return Cont { k in
        // The destructured pair is never mutated, so bind it with `let`.
        let (m2, b) = k(a)
        return (m1 + m2, b)
    }
}
/// Runs a Writer-emulating computation. The final continuation contributes the
/// empty output `M.zero()`.
func run<A, M:Monoid>(m: Cont<(M, A), A>) -> (M, A) {
    return m.start { (a: A) -> (M, A) in (M.zero(), a) }
}
/// Maps every element to a monoid value with `f` and combines the values with
/// `+` in traversal order. Only the accumulated output is returned.
func foldMap<T: Traversable, M: Monoid>(t: T, f: T.El -> M) -> M {
    let written = run(t.traverse { (x: T.El) -> Cont<(M, T), T.El> in i(f(x), x) })
    return written.0
}
/// Combines the elements themselves with their monoid `+`.
func fold<T: Traversable where T.El: Monoid>(t: T) -> T.El {
    return foldMap(t, { (x: T.El) -> T.El in x })
}
// Concatenate the strings, then sum their integer values.
// (The force unwrap is safe here: every literal is a digit string.)
let digitStrings = ["1", "2", "3"]
println(fold(digitStrings))
println(foldMap(digitStrings) { $0.toInt()! })
// MARK: emulate Identity monad | |
/// Embeds a plain value into `Cont` to emulate the Identity monad: the
/// continuation is simply applied to `a`.
/// NOTE(review): this overload accepts an argument of any type, so it competes
/// with every other `i` overload; the surrounding types must rule it out.
func i<A, B>(a: A) -> Cont<B, A> {
    return Cont(start: { (k: A -> B) -> B in k(a) })
}
/// Runs an Identity-emulating computation by finishing with the identity continuation.
func run<A>(m: Cont<A, A>) -> A {
    return m.start { (a: A) -> A in a }
}
/// Replaces each element with `f(element)` and returns a container of the same type.
func monoMap<T: Traversable>(t: T, f: T.El -> T.El) -> T {
    let traversal = t.traverse { (x: T.El) -> Cont<T, T.El> in i(f(x)) }
    return run(traversal)
}
// Use boxing to prevent the compiler from crashing when using enums | |
/// A reference wrapper around a single immutable value. The enum payloads
/// below are stored through it.
class Box<T> {
    let unbox: T
    init(_ v: T) {
        unbox = v
    }
}
// MARK: Binary Tree | |
/// A non-empty binary tree with values of type `A` at its leaves.
/// `Leaf` holds one boxed value. `Branch` holds a boxed left subtree and a boxed
/// right subtree. `A` must be `Printable` so that the tree can describe itself.
enum BinaryTree<A: Printable> {
    case Leaf(Box<A>), Branch(Box<BinaryTree<A>>, Box<BinaryTree<A>>)
}
/// Builds a single-leaf tree holding `a`.
func leaf<A>(a: A) -> BinaryTree<A> {
    let boxed = Box(a)
    return .Leaf(boxed)
}
/// Joins two trees under a new branch node, with `l` as the left subtree and `r` as the right.
func +<A>(l: BinaryTree<A>, r: BinaryTree<A>) -> BinaryTree<A> {
    let left = Box(l)
    let right = Box(r)
    return .Branch(left, right)
}
extension BinaryTree: Semigroup, Traversable, Printable {
    typealias El = A

    /// Folds the tree bottom-up.
    /// Each leaf's value is mapped with `fl`, and each branch combines its left
    /// and right results with `fb`.
    /// A method's second parameter has an external name by default, so every
    /// call below passes `fb:` explicitly.
    func reduce<R>(fl: A -> R, fb: (R, R) -> R) -> R {
        switch self {
        case .Leaf(let a):
            return fl(a.unbox)
        case .Branch(let l, let r):
            return fb(l.unbox.reduce(fl, fb: fb), r.unbox.reduce(fl, fb: fb))
        }
    }

    /// Traverses the leaves. Each leaf's value goes through `f` and is rewrapped
    /// as a leaf. Branches are recombined with `+` via the two-argument `lift`,
    /// with the left subtree as its first argument.
    func traverse<R, B>(f: A -> Cont<R, B>) -> Cont<R, BinaryTree<B>> {
        return self.reduce({ lift(f($0), leaf) }, fb: { lift($0, $1, +) })
    }

    /// Replaces each leaf with the tree that `f` returns for its value; the
    /// branches above the leaves are rebuilt with `+`.
    func flatMap<B>(f: A -> BinaryTree<B>) -> BinaryTree<B> {
        return self.reduce(f, fb: +)
    }

    /// Applies `f` to every leaf value and keeps the tree's shape.
    func map<B>(f: A -> B) -> BinaryTree<B> {
        return self.flatMap(f >>> leaf)
    }

    /// A leaf prints as its value's description; a branch prints as `(left right)`.
    var description: String {
        return self.reduce({ $0.description }, fb: { "(\($0) \($1))" })
    }
}
// BinaryTree is a monad | |
/// Embeds a tree of alternatives into `Cont` (BinaryTree as a monad).
/// Each leaf is replaced, via `flatMap`, by the tree the continuation returns
/// for its value.
func i<A, B>(ma: BinaryTree<A>) -> Cont<BinaryTree<B>, A> {
    return Cont(start: { (k: A -> BinaryTree<B>) -> BinaryTree<B> in ma.flatMap(k) })
}
/// Runs a tree-monad computation by wrapping each final value in a single leaf.
/// NOTE(review): the parameter `a` is never read. Because of it, this overload
/// requires two arguments, so one-argument `run(...)` calls on a tree traversal
/// presumably resolve to a different `run` overload. Confirm the intent.
func run<A>(m: Cont<BinaryTree<A>, A>, a: A) -> BinaryTree<A> {
    return m.start { (x: A) -> BinaryTree<A> in leaf(x) }
}
// Each leaf x of (1 2) branches into (x 10x); the traversal combines every choice.
let smallTree = leaf(1) + leaf(2)
let expandedTree = run(smallTree.traverse { i(leaf($0) + leaf($0*10)) })
println(expandedTree.description)
// MARK: Compose traversables | |
/// Wraps a container of containers, where `get` is an `F` whose elements are `G`s,
/// so that it can be traversed as a single container of `G.El` elements.
struct Compose<F, G where F: Traversable, G: Traversable, F.El == G> {
    let get: F
}
extension Compose: Traversable {
    typealias El = G.El
    /// Traverses the outer container `get`. Each of its elements (a `G`) is in
    /// turn traversed with `f`, and the rebuilt outer container is rewrapped
    /// in `Compose`.
    func traverse<R>(f: El -> Cont<R, El>) -> Cont<R, Compose> {
        let rebuiltOuter = self.get.traverse { (inner: G) -> Cont<R, G> in inner.traverse(f) }
        return lift(rebuiltOuter) { (outer: F) -> Compose in Compose(get: outer) }
    }
}
// Folding the nested array concatenates the inner arrays; folding through
// Compose sums the inner integers.
let nested = [[1, 2], [3, 4]]
println(fold(nested))
println(fold(Compose(get: nested)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment