Skip to content

Instantly share code, notes, and snippets.

@bblfish
Last active May 1, 2019 10:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bblfish/e04b81ab55c1e4d128256305afd0bc57 to your computer and use it in GitHub Desktop.
Save bblfish/e04b81ab55c1e4d128256305afd0bc57 to your computer and use it in GitHub Desktop.
Code from blog post on the State Monad
// Code from blog post on the State Monad
// http://blog.tmorris.net/posts/memoisation-with-state-using-scala/index.html
type Memo = Map[Long, BigInt]
object FibNaïve {
  /** Textbook doubly-recursive Fibonacci with no memoisation (exponential time).
    *
    * Inputs `<= 1` are returned unchanged, so negative inputs yield themselves.
    */
  def fibnaïve(n: Long): BigInt =
    n match {
      case small if small <= 1 => BigInt(small)
      case _                   => fibnaïve(n - 1) + fibnaïve(n - 2)
    }
}
case class State[S, A](run: S => (A, S)) {
  /** Applies `f` to the produced value; the state transition is left unchanged. */
  def map[B](f: A => B): State[S, B] =
    State { s0 =>
      run(s0) match {
        case (value, s1) => (f(value), s1)
      }
    }

  /** Sequences `f` after this action.
    *
    * The value produced here selects the next action, which is then run on the
    * state this action left behind.
    */
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s0 =>
      run(s0) match {
        case (value, s1) => f(value).run(s1)
      }
    }

  /** Runs the action from `s` and keeps only the produced value, dropping the final state. */
  def eval(s: S): A = {
    val (value, _) = run(s)
    value
  }
}
object State {
  /** Lifts a plain value into `State` without reading or changing the state. */
  def insert[S, A](a: A): State[S, A] =
    State(current => (a, current))

  /** Produces a value by reading the current state with `f`; the state is kept as is. */
  def get[S, A](f: S => A): State[S, A] =
    State { current =>
      val read = f(current)
      (read, current)
    }

  /** Replaces the state with `f` applied to it, producing `()`. */
  def mod[S](f: S => S): State[S, Unit] =
    State { current =>
      val next = f(current)
      ((), next)
    }
}
// Fixes the memoisation bug: newly computed values are added to the memo.
// Also adds a small debugging option to show how the state evolves, and
// returns the memoisation map so one can see that it has been altered.
object FibMemo1 {
  /** Memoised Fibonacci that threads the memo explicitly through the recursion.
    *
    * Inputs `<= 1` are returned unchanged (negative inputs yield themselves) and
    * leave the memo untouched. For larger inputs a memo hit is returned directly;
    * otherwise both predecessors are computed and the sum is added to the memo.
    *
    * NOTE(review): `debug` is part of the implicit parameter list, so any implicit
    * `Boolean` in scope at the call site will be picked up for it.
    *
    * @param n     the index to compute
    * @param memo  previously computed values to reuse (empty by default)
    * @param debug when true, prints `z->value` for every newly computed entry
    * @return the value for `n` together with the updated memo
    */
  def fibmemo1(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    // Not tail recursive: on a cold memo the recursion goes about `n` levels deep.
    def fibmemoR(z: Long, memo: Memo): (BigInt, Memo) =
      if (z <= 1)
        (BigInt(z), memo)
      else memo get z match {
        case None =>
          val (a, memo0) = fibmemoR(z - 1, memo)
          val (b, memo1) = fibmemoR(z - 2, memo0)
          val res = a + b
          if (debug) println(s"$z->$res")
          // `Pair` was removed from Predef in Scala 2.12; the arrow builds the same tuple.
          (res, memo1 + (z -> res)) // <- add the result to the memo
        case Some(v) => (v, memo)
      }
    fibmemoR(n, memo)
  }
}
// Fixes FibMemo2 from
// http://blog.tmorris.net/posts/memoisation-with-state-using-scala/index.html
// Again, the results were not being memoised. (Also adds a debug option to see
// what gets calculated.)
object FibMemo2 {
  /** Memoised Fibonacci where each step is a `State[Memo, BigInt]`.
    *
    * The steps are still sequenced by calling `run` by hand, threading the memo
    * from one recursive call to the next. Inputs `<= 1` are returned unchanged
    * (negative inputs yield themselves). Every newly computed value is added to
    * the memo.
    *
    * NOTE(review): `debug` is part of the implicit parameter list, so any implicit
    * `Boolean` in scope at the call site will be picked up for it.
    *
    * @param n     the index to compute
    * @param memo  previously computed values to reuse (empty by default)
    * @param debug when true, prints `z->value` for every newly computed entry
    * @return the value for `n` together with the updated memo
    */
  def fibmemo2(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    // The lambda parameter is named `m` so it does not shadow the outer `memo`.
    def fibmemoR(z: Long): State[Memo, BigInt] =
      State(m =>
        if (z <= 1)
          (BigInt(z), m)
        else m get z match {
          case None =>
            val (a, m0) = fibmemoR(z - 1).run(m)
            val (b, m1) = fibmemoR(z - 2).run(m0)
            val res = a + b
            if (debug) println(s"$z->$res")
            // `Pair` was removed from Predef in Scala 2.12; the arrow builds the same tuple.
            (res, m1 + (z -> res)) // <- add the result to the memo
          case Some(v) => (v, m)
        })
    fibmemoR(n).run(memo)
  }
}
object FibMemo3 {
  /** Memoised Fibonacci built from `State` combinators; `flatMap` threads the memo.
    *
    * Inputs `<= 1` are returned as-is without touching the memo. For larger inputs
    * a memo hit is returned directly. On a miss both predecessors are computed,
    * the sum is stored under the input, and (with `debug`) it is printed as
    * `z->value`.
    *
    * @param n     the index to compute
    * @param memo  previously computed values to start from (empty by default)
    * @param debug print each newly computed entry
    * @return the value for `n` and the memo after the computation
    */
  def fibmemo3(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    // Cache miss: combine the two predecessors of `k` and record the sum.
    def computeAndStore(k: Long): State[Memo, BigInt] =
      fibmemoR(k - 1).flatMap { prev1 =>
        fibmemoR(k - 2).flatMap { prev2 =>
          val total = prev1 + prev2
          if (debug) println(s"$k->$total")
          State.mod((m: Memo) => m + (k -> total)).map(_ => total)
        }
      }

    def fibmemoR(z: Long): State[Memo, BigInt] =
      if (z > 1)
        for {
          cached <- State.get((m: Memo) => m get z)
          // `fold` evaluates the miss branch only when the memo has no entry for z.
          value <- cached.fold(computeAndStore(z))(State.insert[Memo, BigInt])
        } yield value
      else
        State.insert(BigInt(z))

    fibmemoR(n).run(memo)
  }
}
object FibMemo4 {
  /** Memoised Fibonacci written entirely with `State` for-comprehensions.
    *
    * Inputs `<= 1` are returned as-is without touching the memo. For larger inputs
    * a memo hit is returned directly. On a miss both predecessors are computed
    * and the sum is stored under the input. With `debug`, the entry is printed as
    * `z->value` once it has been stored.
    *
    * @param n     the index to compute
    * @param memo  previously computed values to start from (empty by default)
    * @param debug print each newly computed entry
    * @return the value for `n` and the memo after the computation
    */
  def fibmemo4(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    // Cache miss: sum the predecessors of `k`, store the sum, then report it.
    def computeAndStore(k: Long): State[Memo, BigInt] =
      for {
        prev1 <- fibmemoR(k - 1)
        prev2 <- fibmemoR(k - 2)
        total = prev1 + prev2
        _     <- State.mod((m: Memo) => m + (k -> total))
      } yield {
        if (debug) println(s"$k->$total")
        total
      }

    def fibmemoR(z: Long): State[Memo, BigInt] =
      if (z > 1)
        for {
          cached <- State.get((m: Memo) => m get z)
          value <- cached match {
            case Some(hit) => State.insert[Memo, BigInt](hit)
            case None      => computeAndStore(z)
          }
        } yield value
      else
        State.insert(BigInt(z))

    fibmemoR(n).run(memo)
  }
}
// Explicit types are required here. Un-annotated, `Map()` infers Map[Nothing, Nothing],
// which can never satisfy an implicit `Memo` parameter because Map is invariant in its key.
implicit val memo: Memo = Map()
// NOTE(review): the name looks like a typo for `debug`. Implicit lookup is by type, so
// every call that leaves `debug` implicit (e.g. `FibMemo1.fibmemo1(10)`) gets `true`.
implicit val debut: Boolean = true
import scala.util.Try
/** Feeds each number in `nums` to `fib`, threading the memo from one call into the next.
  *
  * The memo starts empty and debug output is always off. `[n]` is printed before
  * each call. Any exception `Try` considers non-fatal — which includes
  * StackOverflowError — ends the run as a `Failure`; otherwise the final memo is
  * returned.
  */
def test(nums: List[Long], fib: (Long,Memo,Boolean) => (BigInt,Memo)): Try[Memo] = Try {
  val start: Memo = Map()
  nums.foldLeft(start) { (acc, n) =>
    print(s"[$n]")
    val (_, updated) = fib(n, acc, false)
    updated
  }
}
// On a cold memo the memoised versions recurse roughly `n` levels deep, so they
// overflow the default JVM stack somewhere around n = 2000–4000 (machine dependent).
// `test` wraps each run in `Try`, which treats StackOverflowError as non-fatal, so the
// overflow comes back as a `Failure` value rather than being thrown.
val nums = List(10L, 100L, 1000L, 4000L, 8000L)

/** Runs `test` over `nums` with `fib` and prints the outcome, so failures are not lost. */
def report(label: String, fib: (Long, Memo, Boolean) => (BigInt, Memo)): Unit = {
  println(label + "=")
  val outcome = test(nums, fib)
  println()
  // Print only the memo size: a full memo for n = 8000 holds thousands of huge BigInts.
  println(outcome.map(_.size))
}

// fibnaïve takes no memo and runs in exponential time, so it would not finish for these nums.
//report("naive", (n, m, _) => (FibNaïve.fibnaïve(n), m))
report("test1", FibMemo1.fibmemo1(_)(_, _))
//report("test2", FibMemo2.fibmemo2(_)(_, _))
//report("test3", FibMemo3.fibmemo3(_)(_, _))
//report("test4", FibMemo4.fibmemo4(_)(_, _))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment