Last active
May 1, 2019 10:52
-
-
Save bblfish/e04b81ab55c1e4d128256305afd0bc57 to your computer and use it in GitHub Desktop.
Code from blog post on the State Monad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Code from blog post on the State Monad
// http://blog.tmorris.net/posts/memoisation-with-state-using-scala/index.html
/** Memoisation table: maps a Fibonacci index to the value already computed for it. */
type Memo = Map[Long, BigInt]
/** Baseline implementation: plain doubly-recursive Fibonacci with no caching. */
object FibNaïve {

  /**
   * Computes the `n`-th Fibonacci number by direct recursion.
   *
   * Every call with `n > 1` recurses on both `n - 1` and `n - 2` without
   * sharing any results, so the number of calls grows exponentially with `n`.
   *
   * @param n the index; any value `<= 1` (including negative values) is
   *          returned unchanged, converted to a `BigInt`
   * @return `fib(n - 1) + fib(n - 2)` for `n > 1`, otherwise `BigInt(n)`
   */
  def fibnaïve(n: Long): BigInt =
    if (n <= 1) BigInt(n)
    else fibnaïve(n - 1) + fibnaïve(n - 2)
}
/**
 * A state transition: wraps a function that takes a state of type `S` and
 * returns a result of type `A` together with the next state.
 *
 * @param run the wrapped transition function `S => (A, S)`
 * @tparam S the type of the threaded state
 * @tparam A the type of the produced result
 */
case class State[S, A](run: S => (A, S)) {

  // 1. the map method
  /** Transforms the result with `f`; the state produced by `run` is passed through untouched. */
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, next) = run(s)
      (f(a), next)
    }

  // 2. the flatMap method
  /**
   * Sequences two transitions: runs this one, feeds its result to `f`, and runs
   * the transition `f` returns against the state this one produced.
   */
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }

  // Convenience function to drop the resulting state value
  /** Runs the transition from the initial state `s` and returns only the result. */
  def eval(s: S): A =
    run(s)._1
}

/** Constructors for common [[State]] transitions. */
object State {

  // 3. The insert function
  /** Produces `a` as the result and leaves the state unchanged. */
  def insert[S, A](a: A): State[S, A] =
    State(s => (a, s))

  // Convenience function for taking the current state to a value
  /** Produces `f` applied to the current state as the result; the state is left unchanged. */
  def get[S, A](f: S => A): State[S, A] =
    State(s => (f(s), s))

  // Convenience function for modifying the current state
  /** Replaces the state with `f` applied to it; the result is `()`. */
  def mod[S](f: S => S): State[S, Unit] =
    State(s => ((), f(s)))
}
// Fixes the memoisation bug.
// Also adds a small debugging tool that lets one see how the state evolves,
// and returns the memoisation map, to help see that it has been altered.
/** Memoised Fibonacci that threads the memo table through the recursion by hand. */
object FibMemo1 {

  /**
   * Computes the `n`-th Fibonacci number, consulting and extending a memo table.
   *
   * The recursion is not in tail position, so the stack depth grows with the
   * run of indices below `n` that are not yet in the memo.
   *
   * @param n     the index; values `<= 1` are returned unchanged and are never stored
   * @param memo  previously computed values to start from (implicit, defaults to an empty map)
   * @param debug when `true`, prints `index->value` for every newly computed entry
   *              (implicit, defaults to `false`)
   * @return the result paired with the memo table extended by every value computed on the way
   */
  def fibmemo1(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    // `cache` is the table as of this call; named differently from the outer `memo` to avoid shadowing.
    def fibmemoR(z: Long, cache: Memo): (BigInt, Memo) =
      if (z <= 1)
        (BigInt(z), cache)
      else cache.get(z) match {
        case Some(v) => (v, cache)
        case None =>
          // Thread the table left to right so the second call sees what the first one stored.
          val (r, cache0) = fibmemoR(z - 1, cache)
          val (s, cache1) = fibmemoR(z - 2, cache0)
          val res = r + s
          if (debug) println(s"$z->$res")
          (res, cache1 + (z -> res)) // <- the fix: store the new result in the returned table
      }
    fibmemoR(n, memo)
  }
}
// Fix for FibMemo2 in
// http://blog.tmorris.net/posts/memoisation-with-state-using-scala/index.html
// Again the results were not being memoised. (Added a debug option to see what
// gets calculated.)
/** Memoised Fibonacci where each recursive step is wrapped in a `State[Memo, BigInt]`. */
object FibMemo2 {

  /**
   * Computes the `n`-th Fibonacci number, consulting and extending a memo table.
   *
   * Each step is a `State` whose transition function still threads the table
   * explicitly by calling `run` on the sub-computations.
   *
   * @param n     the index; values `<= 1` are returned unchanged and are never stored
   * @param memo  previously computed values to start from (implicit, defaults to an empty map)
   * @param debug when `true`, prints `index->value` for every newly computed entry
   *              (implicit, defaults to `false`)
   * @return the result paired with the memo table extended by every value computed on the way
   */
  def fibmemo2(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    def fibmemoR(z: Long): State[Memo, BigInt] =
      State { cache =>
        if (z <= 1)
          (BigInt(z), cache)
        else cache.get(z) match {
          case Some(v) => (v, cache)
          case None =>
            // Run the sub-computations in sequence, handing each the table left by the previous one.
            val (r, cache0) = fibmemoR(z - 1).run(cache)
            val (s, cache1) = fibmemoR(z - 2).run(cache0)
            val res = r + s
            if (debug) println(s"$z->$res")
            (res, cache1 + (z -> res)) // <- add the result to the memo
        }
      }
    fibmemoR(n).run(memo)
  }
}
/** Memoised Fibonacci built from `State` combinators with explicit `flatMap`/`map` calls. */
object FibMemo3 {

  /**
   * Computes the `n`-th Fibonacci number, consulting and extending a memo table.
   *
   * @param n     the index; values `<= 1` are returned unchanged and are never stored
   * @param memo  previously computed values to start from (implicit, defaults to an empty map)
   * @param debug when `true`, prints `index->value` for every newly computed entry
   *              (implicit, defaults to `false`)
   * @return the result paired with the memo table extended by every value computed on the way
   */
  def fibmemo3(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    def fibmemoR(z: Long): State[Memo, BigInt] =
      if (z <= 1)
        State.insert(BigInt(z))
      else
        for {
          // Look the index up in the current table.
          u <- State.get((m: Memo) => m.get(z))
          // Dotted calls make the grouping explicit: the whole computation below is the
          // `getOrElse` default. Written infix without braces, the trailing `flatMap`
          // would bind to the result of `getOrElse` instead and give a wrong result.
          v <- u.map(State.insert[Memo, BigInt]).getOrElse {
            fibmemoR(z - 1).flatMap { r =>
              fibmemoR(z - 2).flatMap { s =>
                val t = r + s
                if (debug) println(s"$z->$t")
                // Store the new value, then produce it as the result.
                State.mod((m: Memo) => m + (z -> t)).map(_ => t)
              }
            }
          }
        } yield v
    fibmemoR(n).run(memo)
  }
}
/** Memoised Fibonacci built from `State` combinators using nested for-comprehensions. */
object FibMemo4 {

  /**
   * Computes the `n`-th Fibonacci number, consulting and extending a memo table.
   *
   * @param n     the index; values `<= 1` are returned unchanged and are never stored
   * @param memo  previously computed values to start from (implicit, defaults to an empty map)
   * @param debug when `true`, prints `index->value` for every newly computed entry
   *              (implicit, defaults to `false`)
   * @return the result paired with the memo table extended by every value computed on the way
   */
  def fibmemo4(n: Long)(
    implicit memo: Memo = Map(),
    debug: Boolean = false
  ): (BigInt, Memo) = {
    def fibmemoR(z: Long): State[Memo, BigInt] =
      if (z <= 1)
        State.insert(BigInt(z))
      else
        for {
          // Look the index up in the current table.
          u <- State.get((m: Memo) => m.get(z))
          // On a hit, produce the stored value; on a miss, compute, store and produce it.
          v <- u.map(State.insert[Memo, BigInt]).getOrElse(
            for {
              r <- fibmemoR(z - 1)
              s <- fibmemoR(z - 2)
              t = r + s
              _ <- State.mod((m: Memo) => m + (z -> t))
            } yield {
              if (debug) println(s"$z->$t")
              t
            }
          )
        } yield v
    fibmemoR(n).run(memo)
  }
}
// Implicit defaults for the `(implicit memo: Memo, debug: Boolean)` parameter lists.
// The explicit `Memo` annotation matters: a bare `Map()` is inferred as
// `Map[Nothing, Nothing]`, which can never be selected for an implicit `Memo` parameter.
implicit val memo: Memo = Map()
// NOTE(review): `debut` looks like a typo for `debug`; the name is kept so that any
// existing reference still resolves. Implicit lookup is by type, so any call that
// omits the implicit parameter list picks this value up as its `debug` flag.
implicit val debut: Boolean = true
import scala.util.Try | |
/**
 * Runs `fib` over `nums` in order, feeding the memo table returned for one
 * input into the call for the next, and prints `[n]` before each call.
 *
 * `fib` is always invoked with its debug flag set to `false`.
 *
 * NOTE(review): `Try` only captures non-fatal throwables. On Scala 2.11 and later a
 * `StackOverflowError` is treated as fatal, so it propagates out of this method
 * instead of being returned as a `Failure` - confirm against the Scala version in use.
 *
 * @param nums the indices to compute, in the order they are evaluated
 * @param fib  the implementation under test: `(index, memo, debug) => (result, updated memo)`
 * @return the memo table left after the last input, or a `Failure` holding the
 *         non-fatal exception that interrupted the run
 */
def test(nums: List[Long], fib: (Long, Memo, Boolean) => (BigInt, Memo)): Try[Memo] = Try {
  nums.foldLeft[Memo](Map.empty) { (cache, n) =>
    print(s"[$n]")
    fib(n, cache, false)._2
  }
}
// This throws a StackOverflowError for Fibonacci indices larger than around 2000 to 4000.
val nums = List(10L, 100L, 1000L, 4000L, 8000L)
//println("naive=" + test(nums, FibNaïve.fibnaïve))
// Prints the label, then runs the inputs through FibMemo1. The Try[Memo] returned by
// `test` is discarded, so only the `[n]` progress markers printed by `test` are shown.
println("test1="); test(nums, FibMemo1.fibmemo1(_)(_, _))
//println("test2=" + test(nums, FibMemo2.fibmemo2))
//println("test3=" + test(nums, FibMemo3.fibmemo3))
//println("test4="); test(nums, FibMemo4.fibmemo4(_)(_,_))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment