Skip to content

Instantly share code, notes, and snippets.

@YoukaiCat
Last active January 27, 2018 21:53
Show Gist options
  • Star (0): you must be signed in to star a gist.
  • Fork (0): you must be signed in to fork a gist.
  • Save YoukaiCat/24e6d225012bcdc6a8b8516ee33efe59 to your computer and use it in GitHub Desktop.
Monad implementation in Scala
/** Minimal type-class hierarchy (Functor <: Applicative <: Monad) plus operator syntax. */
object MonadInterface {
  import scala.language.higherKinds

  /** Type class for a type constructor `T[_]` whose contents can be transformed with `map`. */
  trait Functor[T[_]] {

    /** Applies `f` to the contents of `x`, producing a `T[B]`. */
    def map[A, B](f: A => B)(x: T[A]): T[B]
  }

  /** A [[Functor]] that can also lift plain values and apply wrapped functions. */
  trait Applicative[T[_]] extends Functor[T] {

    /** Lifts a plain value `v` into `T`. */
    def pure[A](v: A): T[A]

    /** Applies the function(s) wrapped in `f` to the value(s) wrapped in `a`. */
    def apply[A, B](f: T[A => B])(a: T[A]): T[B]
  }

  /** An [[Applicative]] that can sequence dependent computations. */
  trait Monad[T[_]] extends Applicative[T] {

    /** Feeds the contents of `m` into `f`, which itself returns a `T[B]`. */
    def bind[A, B](m: T[A])(f: A => T[B]): T[B]

    /** Collapses one level of nesting: `T[T[A]]` becomes `T[A]`. */
    def join[A](m: T[T[A]]): T[A]
  }

  /** Infix operator syntax over a wrapped value `m: T[A]`.
    *
    * Concrete subclasses (typically an implicit class) fix `T`. Each operator asks only for the
    * weakest type class it actually uses. Because a [[Monad]] is also an [[Applicative]] and a
    * [[Functor]], an implicit `Monad[T]` satisfies all three. The implicit parameter keeps the
    * name `monad` so existing named-argument callers still compile.
    *
    * @param m the wrapped value the operators act on
    */
  abstract class MonadOps[T[_], A](m: T[A]) {

    /** Functor map: `m $ f` applies `f` inside `m`.
      *
      * NOTE(review): the Scala spec reserves `$` in identifiers for compiler-generated names;
      * the name is kept only for source compatibility.
      */
    def $[B](f: A => B)(implicit monad: Functor[T]): T[B] =
      monad.map(f)(m)

    /** Applicative apply: `m * f` applies the wrapped function `f` to `m`. */
    def *[B](f: T[A => B])(implicit monad: Applicative[T]): T[B] =
      monad.apply(f)(m)

    /** Monadic bind: `m >>= f` feeds the contents of `m` into `f`. */
    def >>=[B](f: A => T[B])(implicit monad: Monad[T]): T[B] =
      monad.bind(m)(f)
  }
}
/** A `Maybe` (optional value) data type together with its [[MonadInterface.Monad]] instance. */
object MaybeMonad {
  import MonadInterface._

  /** Either [[Just]] a value or [[Empty]].
    *
    * Covariant in `A`, so `Empty()` (a `Maybe[Nothing]`) conforms to every `Maybe[A]`.
    */
  sealed trait Maybe[+A]

  /** A present value `v`. */
  final case class Just[A](v: A) extends Maybe[A]

  /** The absent case.
    *
    * Kept as a zero-argument case class, not a case object, because callers both construct and
    * match it as `Empty()`.
    */
  final case class Empty() extends Maybe[Nothing]

  // The type is annotated explicitly. An implicit with an inferred type is warned about in
  // Scala 2.13, rejected in Scala 3, and can be skipped by implicit search at earlier use sites.
  implicit val MaybeMonad: Monad[Maybe] = new Monad[Maybe] {

    /** Applies `f` to the value if present; `Empty()` stays `Empty()`. */
    def map[A, B](f: A => B)(x: Maybe[A]): Maybe[B] =
      x match {
        case Just(v) => Just(f(v))
        case Empty() => Empty()
      }

    /** Wraps `v` in [[Just]]. */
    def pure[A](v: A): Maybe[A] = Just(v)

    /** Applies the wrapped function to the wrapped value; `Empty()` if either side is empty. */
    def apply[A, B](f: Maybe[A => B])(a: Maybe[A]): Maybe[B] =
      (f, a) match {
        case (Just(fn), Just(v))           => Just(fn(v))
        case (Empty(), _) | (_, Empty()) => Empty()
      }

    /** Passes the value to `f` if present; otherwise short-circuits to `Empty()`. */
    def bind[A, B](m: Maybe[A])(f: A => Maybe[B]): Maybe[B] =
      m match {
        case Just(v) => f(v)
        case Empty() => Empty()
      }

    /** Removes one level of nesting: `Just(inner)` becomes `inner`; `Empty()` stays `Empty()`. */
    def join[A](m: Maybe[Maybe[A]]): Maybe[A] =
      m match {
        case Just(v) => v
        case Empty() => Empty()
      }
  }

  /** Enables the `$`, `*` and `>>=` operators directly on any `Maybe[A]`. */
  implicit class MaybeMonadOps[A](m: Maybe[A]) extends MonadOps[Maybe, A](m)
}
/** Demonstrates and sanity-checks the `Maybe` functor, applicative and monad operators. */
object Main extends App {
  import MaybeMonad._

  // Functor (`$`): map a plain function over a Maybe; Empty() is propagated unchanged.
  assert((Just(5) $ (x => x * x)) == Just(25))
  assert((Empty() $ ((x: Int) => x + x)) == Empty())

  // Applicative (`*`): apply a curried function argument by argument inside Maybe.
  // `Just(5) $ f` eta-expands the curried method `f` into a wrapped Int => Int => Int => Int.
  def f(x: Int)(y: Int)(z: Int): Int = x * y * z
  assert(Just(5) * (Just(5) * (Just(5) $ f)) == Just(125))
  assert(Just(5) * (Empty() * (Just(5) $ f)) == Empty())

  // Monad (`>>=`): a later step may depend on earlier values; any Empty() short-circuits.
  assert((Just(3) >>= (x =>
    Just(4) >>= (y =>
      Just(x * y)))) == Just(12))
  assert((Just(3) >>= (x =>
    Empty() >>= ((y: Int) =>
      Just(x + y)))) == Empty())

  // Chained maps with ordinary method selection (no postfix-operator feature import needed).
  println(Just(5) $ (_.toString) $ (_.length))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.