Last active: October 12, 2015 19:38
Save huynhjl/4077185 to your computer and use it in GitHub Desktop.
Continuation Monad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package worksheets | |
// Scala worksheet exercising the Continuation and Shift encodings defined
// after this object. Based on http://blog.tmorris.net/continuation-monad-in-scala/
// The `//>` and `//|` comments are results recorded by the worksheet runner.
object continuations {
  println("Welcome to the Scala worksheet")          //> Welcome to the Scala worksheet
  import Continuation._

  // http://hackage.haskell.org/packages/archive/mtl/2.0.1.0/doc/html/Control-Monad-Cont.html
  val list = List(1, 2, 3)                            //> list  : List[Int] = List(1, 2, 3)
  def calcLength(l: List[_]) = point[Unit](l.size)    //> calcLength: (l: List[_])worksheets.Continuation[Unit,Int]
  def double(n: Int) = point[Unit](2 * n)             //> double: (n: Int)worksheets.Continuation[Unit,Int]
  // Like `double`, but written directly in continuation-passing style and
  // polymorphic in the answer type R.
  def double2[R](n: Int) = continuation[R, Int](k => k(2 * n))
                                                      //> double2: [R](n: Int)worksheets.Continuation[R,Int]
  calcLength(list) runCont (println)                  //> 3
  calcLength(list) flatMap double runCont println     //> 6
  calcLength(list) flatMap double2 runCont println    //> 6

  /** Escapes through `exit` with "missing name" when `name` is null or empty;
    * otherwise continues with `()`. */
  def validateName(name: String, exit: String => Continuation[String, Unit]) =
    if (name == null || name.isEmpty()) exit("missing name")
    else point[String](())                            //> validateName: (name: String, exit: String => worksheets.Continuation[String,
                                                      //| Unit])worksheets.Continuation[String,Unit]

  /** Returns a greeting for `name`, or the validation error if `name` is missing. */
  def whatsYourName(name: String): String = {
    val welcome = point[String]("Welcome " + name + "!")
    // Calling `exit` abandons the rest of the block; its argument becomes the
    // value of the whole callcc expression.
    val welcomeX = callcc[String, String, Unit] { exit =>
      for {
        _ <- validateName(name, exit)
        w <- welcome
      } yield w
    }
    welcomeX runCont identity
  }                                                   //> whatsYourName: (name: String)String
  whatsYourName("JLH")                                //> res0: String = Welcome JLH!
  whatsYourName("")                                   //> res1: String = missing name

  // ===============================================================
  // http://en.wikibooks.org/wiki/Haskell/Continuation_passing_style
  // ===============================================================
  def square[R](x: Int) = point[R](x * x)             //> square: [R](x: Int)worksheets.Continuation[R,Int]
  def pythagoras[R](x: Int, y: Int): Continuation[R, Int] = for {
    x2 <- square(x)
    y2 <- square(y)
    sum <- point(x2 + y2)
  } yield sum                                         //> pythagoras: [R](x: Int, y: Int)worksheets.Continuation[R,Int]
  pythagoras(3, 4) runCont (println)                  //> 25

  def square1[R](n: Int) = point[R](n * n)            //> square1: [R](n: Int)worksheets.Continuation[R,Int]
  def square2[R](n: Int) = callcc[R, Int, Int](k => k(n * n))
                                                      //> square2: [R](n: Int)worksheets.Continuation[R,Int]
  square1(2) runCont println                          //> 4
  square2(2) runCont println                          //> 4

  // Escapes early with "over twenty" when n * n + 3 exceeds 20.
  def foo[R](n: Int): Continuation[R, String] = callcc[R, String, String] { k =>
    for {
      m <- point(n * n + 3)
      _ <- if (m > 20) k("over twenty") else point[R](())
    } yield {
      (m - 4).toString
    }
  }                                                   //> foo: [R](n: Int)worksheets.Continuation[R,String]
  foo(5) runCont println                              //> over twenty
  foo(4) runCont println                              //> 15

  // Escapes with a fixed message when c :: s spells "hello"; the result is
  // the length of whichever message was produced.
  def bar[R](c: Char, s: List[Char]): Continuation[R, Int] = for {
    msg <- callcc[R, String, Unit] { k =>
      for {
        s1 <- point(c :: s)
        _ <- if (s1 == "hello".toList) k("They say hello.") else point[R](())
        s2 = s1.mkString
      } yield ("They appear to be saying " + s2)
    }
  } yield (msg.length)                                //> bar: [R](c: Char, s: List[Char])worksheets.Continuation[R,Int]
  bar('h', "ello".toList) runCont println             //> 15
  bar('b', "ello".toList) runCont println             //> 30

  // The escape via k(n) means the yield body (and its println) never runs.
  def bar2[R](): Continuation[R, Int] =
    callcc[R, Int, Unit] { k =>
      for {
        n <- point[R](5)
        _ <- k(n)
      } yield ({ println("evaled 25"); 25 })
    }                                                 //> bar2: [R]()worksheets.Continuation[R,Int]
  bar2() runCont println                              //> 5

  /* We use the continuation monad to perform "escapes" from code blocks.
     This function implements a complicated control structure to process
     numbers:

     Input (n)       Output                      List Shown
     =========       ======                      ==========
     0-9             n                           none
     10-199          number of digits in (n/2)   digits of (n/2)
     200-19999       n                           digits of (n/2)
     20000-1999999   (n/2) backwards             none
     >= 2000000      sum of digits of (n/2)      digits of (n/2) */
  def fun(n: Int): String = resetc[String] {
    for {
      str <- callcc[String, String, Unit](exit1 => for {
        _ <- when(n < 10) { exit1(n.toString) }
        ns = (n / 2).toString.map(Character.digit(_, 10))
        n1 <- callcc[String, Int, Unit](exit2 => for {
          _ <- when(ns.size < 3) { exit2(ns.size) }
          _ <- when(ns.size < 5) { exit2(n) }
          _ <- when(ns.size < 7) {
            val ns1 = ns.reverse.map(Character.forDigit(_, 10))
            exit1(ns1.mkString.dropWhile(_ == '0'))
          }
        } yield (ns.sum))
      } yield ("(ns = " + ns.mkString + ") " + n1))
    } yield ("Answer: " + str)
  }                                                   //> fun: (n: Int)String
  fun(9)                                              //> res2: String = Answer: 9
  fun(10)                                             //> res3: String = Answer: (ns = 5) 1
  fun(199)                                            //> res4: String = Answer: (ns = 99) 2
  fun(200)                                            //> res5: String = Answer: (ns = 100) 200
  fun(19999)                                          //> res6: String = Answer: (ns = 9999) 19999
  fun(20000)                                          //> res7: String = Answer: 1
  fun(24680)                                          //> res8: String = Answer: 4321
  fun(1999999)                                        //> res9: String = Answer: 999999
  fun(2000026)                                        //> res10: String = Answer: (ns = 1000013) 5

  // An exception-throwing div. On a zero denominator, `notOk` escapes with an
  // error message and the handler's computation supplies the final result.
  def divExcpt[R](num: Int, denom: Int, handler: String => Continuation[R, Int]): Continuation[R, Int] =
    callcc[R, Int, Unit] { ok =>
      for {
        err <- callcc[R, String, Unit] { notOk =>
          for {
            _ <- when(denom == 0)(notOk("Denominator 0"))
            // Only reached when denom != 0: `notOk` never calls its continuation.
            _ <- ok(num / denom)
          } yield ""
        }
        // The handler's value is the value of the whole computation,
        // as in the Haskell original (and divExcpt2 below).
        result <- handler(err)
      } yield result
    }                                                 //> divExcpt: [R](num: Int, denom: Int, handler: String => worksheets.Continuat
                                                      //| ion[R,Int])worksheets.Continuation[R,Int]
  resetc[Int](divExcpt[Int](4, 2, msg => point[Int](0)))
                                                      //> res11: Int = 2
  resetc[Int](divExcpt[Int](4, 0, msg => point[Int](0)))
                                                      //> res12: Int = 0

  // Same as divExcpt, but without the `when` helper.
  def divExcpt2[R](num: Int, denom: Int, handler: String => Continuation[R, Int]): Continuation[R, Int] =
    callcc[R, Int, String] { ok =>
      val err = callcc[R, String, String](notOk =>
        if (denom == 0) notOk("Denominator 0")
        else ok(num / denom))
      val res = err flatMap handler
      res
    }                                                 //> divExcpt2: [R](num: Int, denom: Int, handler: String => worksheets.Continua
                                                      //| tion[R,Int])worksheets.Continuation[R,Int]
  resetc[Int](divExcpt2[Int](4, 2, msg => point[Int](0)))
                                                      //> res13: Int = 2
  resetc[Int](divExcpt2[Int](4, 0, msg => point[Int](0)))
                                                      //> res14: Int = 0

  // ===============================================================
  // What happens when two continuations in one for-comprehension both
  // modify the flow? The earlier one (oneModif) is outermost: it receives a
  // continuation that runs twoModif and the rest of the block, so every
  // twoModif[...] string ends up nested inside oneModif[...], once per call
  // oneModif makes to its continuation.
  // ===============================================================
  def oneModif(n: Int) = continuation[String, Int](k => "oneModif[k+10=%s k+20=%s]".format(k(n + 10), k(n + 20)))
                                                      //> oneModif: (n: Int)worksheets.Continuation[String,Int]
  def twoModif(n: Int) = continuation[String, Int](l => "twoModif[l+100=%s l+200=%s]".format(l(n + 100), l(n + 200)))
                                                      //> twoModif: (n: Int)worksheets.Continuation[String,Int]
  val modifs = for {
    n <- point[String](1)
    m <- oneModif(1)
    o <- twoModif(1)
  } yield {
    "(modifs: n=" + n + " m=" + m + " o=" + o + ")"
  }                                                   //> modifs  : worksheets.Continuation[String,java.lang.String] = worksheets.Con
                                                      //| tinuation$$anon$1@2955220e
  resetc[String](modifs)                              //> res15: String = oneModif[k+10=twoModif[l+100=(modifs: n=1 m=11 o=101) l+200=
                                                      //| (modifs: n=1 m=11 o=201)] k+20=twoModif[l+100=(modifs: n=1 m=21 o=101) l+20
                                                      //| 0=(modifs: n=1 m=21 o=201)]]

  // ===============================================================
  // http://lampwww.epfl.ch/~rompf/continuations-icfp09.pdf page 4.2
  // ===============================================================
  import Shift._
  // Correspondence with the Continuation encoding above:
  //   Continuation[R, A] is Shift[A, R, R]  (Shift[A, B, C] wraps a function (A => B) => C)
  //   Continuation[R, A] is A @cps[R], i.e. A @cpsParam[R, R], in the continuations plugin
  val ctx = shift[Int, Int, Int](f => f(f(f(7)))).map(x => x + 1)
                                                      //> ctx  : worksheets.Shift[Int,Int,Int] = worksheets.Shift$$anon$2@13edc485
  reset(ctx)                                          //> res16: Int = 10

  // type-safe sprintf
  val int = (x: Int) => x.toString                    //> int  : Int => java.lang.String = <function1>
  val str = (x: String) => x                          //> str  : String => String = <function1>
  // A format directive: captures the rest of the string-building computation
  // as `k` and turns it into a function expecting an A.
  def format[A, B](toStr: A => String) =
    shift[String, B, A => B](k => (a: A) => k(toStr(a)))
                                                      //> format: [A, B](toStr: A => String)worksheets.Shift[String,B,A => B]
  def sprintf[A](str: => Shift[String, String, A]) = {
    reset(str)
  }                                                   //> sprintf: [A](str: => worksheets.Shift[String,String,A])A
  val f1 = sprintf[String](pure("Hello World!"))      //> f1  : String = Hello World!
  // scalac can actually infer the type parameters,
  // but they are spelled out here as a reading aid.
  val f2 = sprintf[String => String](
    for {
      fmt <- format[String, String](str)
    } yield {
      "Hello " + fmt + "!"
    }
  )                                                   //> f2  : String => String = <function1>
  f2("World")                                         //> res17: String = Hello World!
  // scalac can actually infer the type parameters
  val f3 = sprintf[String => Int => String](
    for {
      fmtstr <- format[String, Int => String](str)
      fmtint <- format[Int, String](int)
    } yield {
      "The value of " + fmtstr + " is " + fmtint + "."
    }
  )                                                   //> f3  : String => (Int => String) = <function1>
  f3("x")(3)                                          //> res18: String = The value of x is 3.
}
/**
 * The continuation monad: a computation producing an `A` that, given the rest
 * of the program as a function `A => R`, yields the final answer `R`.
 */
sealed trait Continuation[R, +A] {
  import Continuation.continuation

  /** Runs this computation, passing its result to `k`. */
  def runCont(k: A => R): R

  /** Transforms the result by running this computation with `k compose f`. */
  def map[B](f: A => B): Continuation[R, B] =
    continuation[R, B](k => runCont(k compose f)) // k: B => R, so (k compose f): A => R

  /**
   * Sequences `f` after this computation: the value produced here is fed to
   * `f`, and the resulting computation is run with the outer continuation `k`.
   * See http://lampwww.epfl.ch/~rompf/continuations-icfp09.pdf page 4; an
   * alternative formulation is in http://blog.tmorris.net/continuation-monad-in-scala/
   */
  def flatMap[B](f: A => Continuation[R, B]): Continuation[R, B] =
    continuation[R, B] { k =>
      runCont(a => f(a).runCont(k))
    }
}

object Continuation {
  /** Wraps a CPS function `(A => R) => R` as a [[Continuation]]. */
  def continuation[R, A](g: (A => R) => R): Continuation[R, A] = new Continuation[R, A] {
    def runCont(f: A => R): R = g(f)
  }

  /**
   * Result of `point[R]`: fixes the answer type `R` explicitly while `A` is
   * inferred from the argument, as in `point[String](())`.
   * This is a named class rather than an anonymous structural type so that
   * `apply` is an ordinary method call instead of a reflective one.
   */
  final class PointPartiallyApplied[R] {
    def apply[A](a: A): Continuation[R, A] = continuation[R, A](f => f(a))
  }

  /** Lifts a pure value: `point[R](a)` passes `a` straight to its continuation. */
  def point[R]: PointPartiallyApplied[R] = new PointPartiallyApplied[R]

  /**
   * The standard idiom used with callCC is to provide a lambda-expression
   * to name the continuation. Then calling the named continuation anywhere
   * within its scope will escape from the computation, even if it is many
   * layers deep within nested computations.
   */
  // great explanation here: http://stackoverflow.com/a/9050907/257449
  def callcc[R, A, B](f: (A => Continuation[R, B]) => Continuation[R, A]): Continuation[R, A] =
    continuation[R, A] { k =>
      // The escape function ignores the continuation it is given and jumps
      // straight to `k`, the continuation of the whole callcc expression.
      f(a => continuation(_ => k(a))).runCont(k)
    }

  /**
   * Translation of Haskell's `when`, used by the Haskell examples: runs `k`
   * (discarding its result) if `pred` holds, otherwise continues with `()`.
   * `k` is by-name, so it is not even constructed when `pred` is false.
   */
  def when[R, B](pred: Boolean)(k: => Continuation[R, B]): Continuation[R, Unit] =
    if (pred) k.map(_ => ()) else point[R](())

  /** Runs `cont` with the identity continuation and returns its answer. */
  def resetc[A](cont: Continuation[A, A]): A = cont.runCont((a: A) => a)
}
/**
 * A delimited-continuation computation that, given the continuation
 * `k: A => B`, produces a `C`; i.e. a wrapped function `(A => B) => C`.
 * In the Scala continuations plugin this corresponds to `A @cpsParam[B, C]`
 * (declared there as `cpsParam[-B, +C]`, matching the variances below).
 * `Continuation[R, A]` is the special case `Shift[A, R, R]`.
 */
sealed trait Shift[+A, -B, +C] {
  import Shift.shift

  /** Runs this computation with continuation `k`. */
  def runCont(k: A => B): C

  /** Transforms the value passed to the continuation. */
  def map[A1](f: A => A1): Shift[A1, B, C] =
    shift[A1, B, C](k => runCont(k compose f))

  /**
   * Sequences `f` after this computation. The inner computation's answer
   * type `C1` is returned to this computation's continuation, which must
   * produce a `B`; hence the bound `C1 <: B`.
   */
  def flatMap[A1, B1, C1 <: B](f: A => Shift[A1, B1, C1]): Shift[A1, B1, C] =
    shift[A1, B1, C] { k =>
      runCont((a: A) => f(a).runCont(k))
    }
}

object Shift {
  /** Wraps a CPS function `(A => B) => C` as a [[Shift]]. */
  def shift[A, B, C](f: (A => B) => C): Shift[A, B, C] = new Shift[A, B, C] {
    def runCont(k: A => B): C = f(k)
  }

  /** Delimits `c` by running it with the identity continuation. */
  def reset[A, C](c: Shift[A, A, C]): C = c.runCont(identity)

  /**
   * Result of `pure[C]`: fixes the answer type `C` while `A` is inferred from
   * the argument. A named class avoids the reflective call an anonymous
   * structural type would need.
   */
  final class PurePartiallyApplied[C] {
    def apply[A](a: A): Shift[A, C, C] = shift[A, C, C](k => k(a))
  }

  /** Lifts a pure value: `pure[C](a)` passes `a` straight to its continuation. */
  def pure[C]: PurePartiallyApplied[C] = new PurePartiallyApplied[C]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// http://community.schemewiki.org/?composable-continuations-tutorial
$ scala -P:continuations:enable
scala> import util.continuations._
scala> reset { shiftUnit[Int, Int, Int](1) }
res3: Int = 1
scala> 1 + reset{ 2 + shift[Int, Int, Int] { k => 3 } }
res6: Int = 4
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: Nil }) })
res8: List[Int] = List(1, 3)
scala> 1 + reset[Int, Int] { 2 + shift[Int, Int, Int] { k => k(4) + 3 } }
res10: Int = 10
scala> 1 + reset[Int, Int] { 2 + shift[Int, Int, Int] { k => k(5) + k(1) + 3 } }
res11: Int = 14
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: k(4 :: Nil) }) })
res12: List[Int] = List(1, 3, 2, 4)
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: k(k(4 :: Nil)) }) })
res13: List[Int] = List(1, 3, 2, 2, 4)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.