Skip to content

@huynhjl /continuations.scala
Last active

Embed URL

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
Continuation Monad
package worksheets
//http://blog.tmorris.net/continuation-monad-in-scala/
/**
 * Scala IDE worksheet exercising the `Continuation` monad and the `Shift`
 * encoding defined later in this file. It translates examples from the
 * Haskell `Control.Monad.Cont` docs, the Haskell wikibook on CPS, and the
 * ICFP'09 continuations paper.
 *
 * Every statement below runs, in order, when the object is initialised.
 * The trailing `//>` / `//|` comments record what the worksheet printed or
 * evaluated for that line.
 */
object continuations {
  println("Welcome to the Scala worksheet")        //> Welcome to the Scala worksheet
  import Continuation._

  // http://hackage.haskell.org/packages/archive/mtl/2.0.1.0/doc/html/Control-Monad-Cont.html
  val list = List(1, 2, 3)                          //> list : List[Int] = List(1, 2, 3)
  def calcLength(l: List[_]) = point[Unit](l.size)  //> calcLength: (l: List[_])worksheets.Continuation[Unit,Int]
  def double(n: Int) = point[Unit](2*n)             //> double: (n: Int)worksheets.Continuation[Unit,Int]
  // Same as `double`, but written directly in CPS and generic in the answer type.
  def double2[R](n: Int) = continuation[R, Int](k => k(2*n))
                                                    //> double2: [R](n: Int)worksheets.Continuation[R,Int]
  calcLength(list) runCont (println)                //> 3
  calcLength(list) flatMap double runCont println   //> 6
  calcLength(list) flatMap double2 runCont println  //> 6

  // Escapes through `exit` with an error message when the name is missing;
  // otherwise continues with `()`.
  def validateName(name: String, exit: String => Continuation[String, Unit]) =
    if (name == null || name.isEmpty()) exit("missing name")
    else point[String]( () )                        //> validateName: (name: String, exit: String => worksheets.Continuation[String,
                                                    //| Unit])worksheets.Continuation[String,Unit]
  // callcc gives validateName an escape route: on failure the whole
  // computation answers "missing name" and the welcome is never produced.
  def whatsYourName(name: String): String = {
    val welcome = point[String]("Welcome " + name + "!")
    val welcomeX = callcc[String, String, Unit]{ exit =>
      for {
        _ <- validateName(name, exit)
        w <- welcome
      } yield w
    }
    val cont = for {
      response <- welcomeX
    } yield response
    val res = cont runCont identity
    res
  }                                                 //> whatsYourName: (name: String)String
  whatsYourName("JLH")                              //> res0: String = Welcome JLH!
  whatsYourName("")                                 //> res1: String = missing name

  // ===============================================================
  // http://en.wikibooks.org/wiki/Haskell/Continuation_passing_style
  // ===============================================================
  def square[R](x: Int) = point[R](x * x)           //> square: [R](x: Int)worksheets.Continuation[R,Int]
  def pythagoras[R](x: Int, y: Int): Continuation[R, Int] = for {
    x2 <- square(x)
    y2 <- square(y)
    sum <- point(x2 + y2)
  } yield sum                                       //> pythagoras: [R](x: Int, y: Int)worksheets.Continuation[R,Int]
  pythagoras(3, 4) runCont (println)                //> 25
  // square1 lifts its result with point; square2 hands it to the current
  // continuation via callcc. Both produce the same value.
  def square1[R](n: Int) = point[R](n*n)            //> square1: [R](n: Int)worksheets.Continuation[R,Int]
  def square2[R](n: Int) = callcc[R, Int, Int](k => k (n*n))
                                                    //> square2: [R](n: Int)worksheets.Continuation[R,Int]
  square1(2) runCont println                        //> 4
  square2(2) runCont println                        //> 4
  // Escapes with "over twenty" when n*n + 3 > 20; otherwise finishes
  // normally with (n*n + 3 - 4).toString.
  def foo[R](n: Int): Continuation[R, String] = callcc[R, String, String] { k =>
    for {
      m <- point(n * n + 3)
      _ <- if (m > 20) k("over twenty") else point[R](())
    } yield {
      (m - 4).toString
    }
  }                                                 //> foo: [R](n: Int)worksheets.Continuation[R,String]
  foo(5) runCont println                            //> over twenty
  foo(4) runCont println                            //> 15
  // Answers the length of the message; "hello" escapes early with a
  // shorter message.
  def bar[R](c: Char, s: List[Char]): Continuation[R, Int] = for {
    msg <- callcc[R, String, Unit] { k =>
      for {
        s1 <- point(c :: s)
        _ <- if (s1 == "hello".toList) k("They say hello.") else point[R](())
        s2 = s1.mkString
      } yield ("They appear to be saying " + s2)
    }
  } yield (msg.length)                              //> bar: [R](c: Char, s: List[Char])worksheets.Continuation[R,Int]
  bar('h', "ello".toList) runCont println           //> 15
  bar('b', "ello".toList) runCont println           //> 30
  // k(n) escapes unconditionally, so the yield body (and its println)
  // never runs and the answer is 5.
  def bar2[R](): Continuation[R, Int] =
    callcc[R, Int, Unit] { k =>
      for {
        n <- point[R](5)
        _ <- k(n)
      } yield ({ println("evaled 25"); 25 })
    }                                               //> bar2: [R]()worksheets.Continuation[R,Int]
  bar2() runCont println                            //> 5
  /* We use the continuation monad to perform "escapes" from code blocks.
     This function implements a complicated control structure to process
     numbers:
     Input (n)       Output                     List Shown
     =========       ======                     ==========
     0-9             n                          none
     10-199          number of digits in (n/2)  digits of (n/2)
     200-19999       n                          digits of (n/2)
     20000-1999999   (n/2) backwards            none
     >= 2000000      sum of digits of (n/2)     digits of (n/2) */
  def fun(n: Int): String = resetc[String] {
    for {
      str <- callcc[String, String, Unit](exit1 => for {
        _ <- when(n < 10) { exit1(n.toString) }
        ns = (n / 2).toString.map(Character.digit(_, 10))
        n1 <- callcc[String, Int, Unit](exit2 => for {
          _ <- when(ns.size < 3) { exit2(ns.size) }
          _ <- when(ns.size < 5) { exit2(n) }
          _ <- when(ns.size < 7) {
            val ns1 = ns.reverse.map(Character.forDigit(_, 10))
            exit1(ns1.mkString.dropWhile(_ == '0'))
          }
        } yield (ns.sum))
      } yield ("(ns = " + ns.mkString + ") " + n1))
    } yield ("Answer: " + str)
  }                                                 //> fun: (n: Int)String
  fun(9)                                            //> res2: String = Answer: 9
  fun(10)                                           //> res3: String = Answer: (ns = 5) 1
  fun(199)                                          //> res4: String = Answer: (ns = 99) 2
  fun(200)                                          //> res5: String = Answer: (ns = 100) 200
  fun(19999)                                        //> res6: String = Answer: (ns = 9999) 19999
  fun(20000)                                        //> res7: String = Answer: 1
  fun(24680)                                        //> res8: String = Answer: 4321
  fun(1999999)                                      //> res9: String = Answer: 999999
  fun(2000026)                                      //> res10: String = Answer: (ns = 1000013) 5
  // An exception-throwing div: `ok` returns the quotient directly, while
  // `notOk` jumps to the handler with an error message. Because notOk
  // escapes first, `num / denom` is never evaluated when denom == 0.
  def divExcpt[R](num: Int, denom: Int, handler: String => Continuation[R, Int]): Continuation[R, Int] = {
    callcc[R, Int, Unit](ok => for {
      err <- callcc[R, String, Unit](notOk => for {
        _ <- when(denom == 0)(notOk("Denominator 0"))
        _ <- ok(num / denom)
      } yield (""))
      _ <- handler(err)
    } yield (0))
  }                                                 //> divExcpt: [R](num: Int, denom: Int, handler: String => worksheets.Continuat
                                                    //| ion[R,Int])worksheets.Continuation[R,Int]
  resetc[Int](divExcpt[Int](4, 2, msg => point[Int](0)))
                                                    //> res11: Int = 2
  resetc[Int](divExcpt[Int](4, 0, msg => point[Int](0)))
                                                    //> res12: Int = 0
  // Same as divExcpt, written with a plain if/else and an explicit flatMap.
  def divExcpt2[R](num: Int, denom: Int, handler: String => Continuation[R, Int]): Continuation[R, Int] =
    callcc[R, Int, String] { ok =>
      val err = callcc[R, String, String](notOk =>
        if (denom == 0) notOk("Denominator 0")
        else ok(num / denom))
      val res = err flatMap handler
      res
    }                                               //> divExcpt2: [R](num: Int, denom: Int, handler: String => worksheets.Continua
                                                    //| tion[R,Int])worksheets.Continuation[R,Int]
  resetc[Int](divExcpt2[Int](4, 2, msg => point[Int](0)))
                                                    //> res13: Int = 2
  resetc[Int](divExcpt2[Int](4, 0, msg => point[Int](0)))
                                                    //> res14: Int = 0
  // ===============================================================
  // What happens when two continuations in a row each call their
  // continuation twice? The first one bound (oneModif) is the outermost
  // frame: each call to its k runs the rest of the comprehension,
  // including twoModif, so twoModif's output is nested inside oneModif's
  // and the final yield runs 2 * 2 = 4 times.
  // ===============================================================
  def oneModif(n: Int) = continuation[String, Int](k => "oneModif[k+10=%s k+20=%s]".format(k(n+10), k(n+20)))
                                                    //> oneModif: (n: Int)worksheets.Continuation[String,Int]
  def twoModif(n: Int) = continuation[String, Int](l => "twoModif[l+100=%s l+200=%s]".format(l(n+100), l(n+200)))
                                                    //> twoModif: (n: Int)worksheets.Continuation[String,Int]
  val modifs = for {
    n <- point[String](1)
    m <- oneModif(1)
    o <- twoModif(1)
  } yield {
    "(modifs: n=" + n + " m=" + m + " o=" + o + ")"
  }                                                 //> modifs : worksheets.Continuation[String,java.lang.String] = worksheets.Con
                                                    //| tinuation$$anon$1@2955220e
  resetc[String](modifs)                            //> res15: String = oneModif[k+10=twoModif[l+100=(modifs: n=1 m=11 o=101) l+200=
                                                    //| (modifs: n=1 m=11 o=201)] k+20=twoModif[l+100=(modifs: n=1 m=21 o=101) l+20
                                                    //| 0=(modifs: n=1 m=21 o=201)]]
  // ===============================================================
  // http://lampwww.epfl.ch/~rompf/continuations-icfp09.pdf page 4.2
  // ===============================================================
  import Shift._
  // note: Continuation[R, A] is Shift[A, R, R] (Shift[A,B,C] for fun: (A => B) => C)
  //       Continuation[R, A] is A @cps[R,R]
  val ctx = shift[Int, Int, Int](f => f(f(f(7)))).map(x => x + 1)
                                                    //> ctx : worksheets.Shift[Int,Int,Int] = worksheets.Shift$$anon$2@13edc485
  reset(ctx)                                        //> res16: Int = 10
  // type-safe sprintf: each `format` turns the answer type into a function
  // awaiting one more argument of the formatted type.
  val int = (x: Int) => x.toString                  //> int : Int => java.lang.String = <function1>
  val str = (x: String) => x                        //> str : String => String = <function1>
  def format[A, B](toStr: A => String) =
    shift[String, B, A => B](k => (a: A) => k(toStr(a)))
                                                    //> format: [A, B](toStr: A => String)worksheets.Shift[String,B,A => B]
  def sprintf[A](str: => Shift[String, String, A]) = {
    reset(str)
  }                                                 //> sprintf: [A](str: => worksheets.Shift[String,String,A])A
  val f1 = sprintf[String](pure("Hello World!"))    //> f1 : String = Hello World!
  // scalac can actually infer the type parameters
  // but I leave it here to help me...
  val f2 = sprintf[String => String](
    for {
      fmt <- format[String, String](str)
    } yield {
      "Hello " + fmt + "!"
    }
  )                                                 //> f2 : String => String = <function1>
  f2("World")                                       //> res17: String = Hello World!
  // scalac can actually infer the type parameters
  val f3 = sprintf[String => Int => String](
    for {
      fmtstr <- format[String, Int => String](str)
      fmtint <- format[Int, String](int)
    } yield {
      "The value of " + fmtstr + " is " + fmtint + "."
    }
  )                                                 //> f3 : String => (Int => String) = <function1>
  f3("x")(3)                                        //> res18: String = The value of x is 3.
}
/**
 * Continuation monad. It models a computation that produces an `A` and,
 * instead of returning it, hands it to a continuation `A => R`, answering
 * with that continuation's result.
 *
 * @tparam R the answer type of the whole computation
 * @tparam A the type of the value passed to the continuation
 */
sealed trait Continuation[R, +A] {
  import Continuation.continuation

  /** Runs this computation, feeding its value to `k`. */
  def runCont(k: A => R): R

  /**
   * Transforms the produced value with `f` before it reaches the
   * continuation. For `k: B => R`, `k compose f` is the `A => R` this
   * computation expects.
   */
  def map[B](f: A => B): Continuation[R, B] =
    continuation[R, B](k => runCont(k compose f))

  /**
   * Sequences this computation with `f`. The value produced here selects
   * the next computation, which is then run with the outer continuation.
   *
   * See http://lampwww.epfl.ch/~rompf/continuations-icfp09.pdf page 4
   */
  def flatMap[B](f: A => Continuation[R, B]): Continuation[R, B] =
    continuation[R, B](k => runCont(a => f(a).runCont(k)))
}

/** Constructors and combinators for [[Continuation]]. */
object Continuation {
  /** Wraps a raw CPS function `(A => R) => R` as a [[Continuation]]. */
  def continuation[R, A](g: (A => R) => R): Continuation[R, A] = new Continuation[R, A] {
    def runCont(f: A => R): R = g(f)
  }

  /**
   * Returned by [[point]] so that callers can fix the answer type `R`
   * explicitly while the value type `A` is still inferred: `point[R](a)`.
   * A named class avoids the reflective calls (and `reflectiveCalls`
   * feature warning) that a structural `new { def apply ... }` incurs.
   */
  final class PointBuilder[R] {
    /** Lifts `a` into a continuation that simply passes it on. */
    def apply[A](a: A): Continuation[R, A] = continuation[R, A](f => f(a))
  }

  /** Monadic unit (Haskell's `return`), used as `point[R](a)`. */
  def point[R]: PointBuilder[R] = new PointBuilder[R]

  /**
   * The standard idiom used with callCC is to provide a lambda-expression
   * to name the continuation. Then calling the named continuation anywhere
   * within its scope will escape from the computation, even if it is many
   * layers deep within nested computations.
   */
  // great explanation here: http://stackoverflow.com/a/9050907/257449
  def callcc[R, A, B](f: (A => Continuation[R, B]) => Continuation[R, A]): Continuation[R, A] =
    continuation[R, A] { k =>
      // The escape function ignores the continuation it is later run with
      // and jumps straight to `k`, the continuation of this callcc.
      f(a => continuation[R, B](_ => k(a))).runCont(k)
    }

  /**
   * Haskell's `when`: runs `k` if `pred` holds, otherwise continues with
   * `()`. `k` is by-name, so the branch is only built when `pred` is true.
   */
  def when[R, B](pred: Boolean)(k: => Continuation[R, B]): Continuation[R, Any] =
    if (pred) k else point[R](())

  /** Runs `cont` with the identity continuation, returning its answer. */
  def resetc[A](cont: Continuation[A, A]): A = cont.runCont((a: A) => a)
}
/**
 * A @cps[-B, +C]: a computation producing an `A` that hands it to a
 * continuation `A => B` and answers with a `C`. Unlike `Continuation`,
 * the continuation's result type `B` may differ from the final answer
 * type `C`.
 */
sealed trait Shift[+A, -B, +C] {
  import Shift._

  /** Runs this computation, feeding its value to `k`. */
  def runCont(k: A => B): C

  /** Transforms the produced value with `f` before it reaches the continuation. */
  def map[A1](f: A => A1): Shift[A1, B, C] =
    shift[A1, B, C](k => runCont(k compose f))

  /**
   * Sequences this computation with `f`. The inner computation's answer
   * type `C1` must be what this computation's continuation returns
   * (`C1 <: B`); the combined computation keeps this one's answer type `C`.
   */
  def flatMap[A1, B1, C1 <: B](f: A => Shift[A1, B1, C1]): Shift[A1, B1, C] =
    shift[A1, B1, C](k => runCont((a: A) => f(a).runCont(k)))
}

/** Constructors for [[Shift]]. */
object Shift {
  /** Wraps a raw CPS function `(A => B) => C` as a [[Shift]]. */
  def shift[A, B, C](f: (A => B) => C): Shift[A, B, C] = new Shift[A, B, C] {
    def runCont(k: A => B): C = f(k)
  }

  /** Runs `c` with the identity continuation, returning its answer. */
  def reset[A, C](c: Shift[A, A, C]): C = c.runCont(identity)

  /**
   * Returned by [[pure]] so callers can fix `C` while `A` is inferred:
   * `pure[C](a)`. A named class avoids the reflective calls of a
   * structural `new { def apply ... }`.
   */
  final class PureBuilder[C] {
    /** Lifts `a` into a computation that simply passes it on. */
    def apply[A](a: A): Shift[A, C, C] = shift[A, C, C](k => k(a))
  }

  /** Monadic unit, used as `pure[C](a)`. */
  def pure[C]: PureBuilder[C] = new PureBuilder[C]
}
// http://community.schemewiki.org/?composable-continuations-tutorial
$ scala -P:continuations:enable
scala> import util.continuations._
scala> reset { shiftUnit[Int, Int, Int](1) }
res3: Int = 1
scala> 1 + reset{ 2 + shift[Int, Int, Int] { k => 3 } }
res6: Int = 4
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: Nil }) })
res8: List[Int] = List(1, 3)
scala> 1 + reset[Int, Int] { 2 + shift[Int, Int, Int] { k => k(4) + 3 } }
res10: Int = 10
scala> 1 + reset[Int, Int] { 2 + shift[Int, Int, Int] { k => k(5) + k(1) + 3 } }
res11: Int = 14
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: k(4 :: Nil) }) })
res12: List[Int] = List(1, 3, 2, 4)
scala> 1 :: (reset{ 2 :: (shift[List[Int], List[Int], List[Int]]{ k => 3 :: k(k(4 :: Nil)) }) })
res13: List[Int] = List(1, 3, 2, 2, 4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.