Skip to content

Instantly share code, notes, and snippets.

@dschobel
Created June 6, 2015 06:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dschobel/acbd360845072e4c3652 to your computer and use it in GitHub Desktop.
Save dschobel/acbd360845072e4c3652 to your computer and use it in GitHub Desktop.
Calculating Fibonacci numbers and manipulating a stack with the State monad.
// No external dependencies: just :paste the entire gist into a Scala 2.11 REPL.
// scala> import State._
// scala> StackExample.computed.run(List.empty)
// #result of push(1), push(2), push(3), pop()
// res0: (List[Int], Option[Int]) = (List(2, 1),Some(3))
// #naive fibonacci impl
// scala> time(FibExample.fib(42))
// took 1204 ms
// res1: Int = 267914296
// #memoized fibonacci
// scala> time(FibExample.fibM(42).run(Map.empty)._2)
// took 7 ms
// res2: Int = 267914296
/** Self-contained State monad demo: a tiny `StateM` implementation plus two
  * examples built on it (an `Int` stack and a memoized Fibonacci).
  */
object State {

  /** Evaluates `f` once, prints the elapsed time as `took <n> ms`, and returns
    * the value produced by `f`.
    *
    * Uses `System.nanoTime`, which is intended for measuring elapsed time;
    * `System.currentTimeMillis` follows the wall clock and can jump when the
    * system clock is adjusted, giving wrong (even negative) durations.
    */
  def time[A](f: => A): A = {
    val start = System.nanoTime
    val res = f
    val durationMs = (System.nanoTime - start) / 1000000L
    println(s"took $durationMs ms")
    res
  }

  /** Constructors for [[StateM]] values that do not depend on an existing computation. */
  object StateM {
    /** A computation that leaves the state unchanged and returns `f(state)`. */
    def units[S, A](f: S => A): StateM[S, A] = StateM((s: S) => (s, f(s)))

    /** A computation that leaves the state unchanged and returns `a`. */
    def unit[S, A](a: A): StateM[S, A] = units((_: S) => a)

    /** A computation that returns the current state without changing it. */
    def get[S]: StateM[S, S] = StateM((s: S) => (s, s))

    /** A computation that replaces whatever state it receives with `newS`. */
    def set[S](newS: S): StateM[S, Unit] = StateM((_: S) => (newS, ()))

    /** A computation that replaces the state `s` with `f(s)`. */
    def modify[S](f: S => S): StateM[S, Unit] = StateM((s: S) => (f(s), ()))
  }

  /** A state-passing computation: `run` takes an input state and returns the
    * next state paired with a result.
    *
    * @tparam S the state type threaded through the computation
    * @tparam A the result type
    */
  case class StateM[S, A](run: S => (S, A)) {

    /** Returns the current state without changing it.
      *
      * NOTE: this ignores the receiver entirely; it is equivalent to `StateM.get[S]`.
      */
    def getState: StateM[S, S] = StateM.get[S]

    /** Runs this computation, then replaces the resulting state with `newS`,
      * keeping this computation's result.
      */
    def setState(newS: S): StateM[S, A] = StateM((s: S) => (newS, run(s)._2))

    /** Replaces the state `s` with `f(s)`.
      *
      * NOTE: this ignores the receiver entirely (it does not run `this`); it is
      * equivalent to `StateM.modify(f)`.
      */
    def modifyState(f: S => S): StateM[S, Unit] = StateM.modify(f)

    /** Runs this computation, feeds its result to `f`, and runs the computation
      * `f` returns on the state this one produced.
      */
    def flatMap[B](f: A => StateM[S, B]): StateM[S, B] = StateM { (s: S) =>
      val (next, a) = run(s)
      f(a).run(next)
    }

    /** Transforms this computation's result with `f`, keeping the state it produced. */
    def map[B](f: A => B): StateM[S, B] = StateM { (s: S) =>
      val (next, a) = run(s)
      (next, f(a))
    }
  }

  /** A stack of `Int`s, represented as a `List[Int]` state whose head is the top. */
  object StackExample {
    type StackState[T] = StateM[List[Int], T]

    /** Removes and returns the top element; on an empty stack returns `None`
      * and leaves the stack empty.
      */
    def pop(): StackState[Option[Int]] = StateM[List[Int], Option[Int]] {
      case x :: xs => (xs, Some(x))
      case Nil     => (Nil, None)
    }

    /** Pushes `x` onto the top of the stack. */
    def push(x: Int): StackState[Unit] =
      StateM[List[Int], Unit]((stack: List[Int]) => (x :: stack, ()))

    /** Runs `s` first (previously it was silently ignored), then pushes 1, 2
      * and 3 and pops the top element.
      */
    def pushThreeAndPopOne[T](s: StackState[T]): StackState[Option[Int]] =
      for {
        _ <- s
        _ <- push(1)
        _ <- push(2)
        _ <- push(3)
        x <- pop()
      } yield x

    val computed: StackState[Option[Int]] =
      pushThreeAndPopOne(StateM.unit[List[Int], Unit](()))
    // Sanity check, evaluated when StackExample is first initialised.
    val (s, a) = computed.run(List.empty)
    assert(s == List(2, 1))
    assert(a == Some(3))
  }

  /** Naive and memoized Fibonacci. Results are `Int`, so they overflow (wrap
    * around) for n > 46.
    */
  object FibExample {
    /** Cache of already-computed Fibonacci numbers, keyed by index. */
    type Memo = Map[Int, Int]

    /** Naive exponential-time Fibonacci; returns `n` itself for any n < 2
      * (including negative n).
      */
    def fib(n: Int): Int =
      if (n < 2) n
      else fib(n - 1) + fib(n - 2)

    /** Memoized Fibonacci that threads a [[Memo]] through the state.
      *
      * For n < 2 it returns `n` without touching the memo. Otherwise it looks
      * up `n` first; on a miss it computes fib(n-1) and then fib(n-2), which
      * hits the memo entry left by the fib(n-1) computation. Finally it records
      * fib(n) under key `n`.
      */
    def fibM(n: Int): StateM[Memo, Int] =
      if (n < 2) StateM.unit(n)
      else
        for {
          cached <- StateM.units((m: Memo) => m.get(n))
          res <- cached match {
            case Some(fibN) => StateM.unit[Memo, Int](fibN)
            case None =>
              for {
                f1 <- fibM(n - 1)
                f2 <- fibM(n - 2)
                _  <- StateM.modify[Memo](_.updated(n, f1 + f2))
              } yield f1 + f2
          }
        } yield res
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment