Skip to content

Instantly share code, notes, and snippets.

@ttim
Created April 26, 2019 01:46
Show Gist options
Save ttim/2a63c297b6b9f33082422875b4ef806b to your computer and use it in GitHub Desktop.
Recursion again
package com.twitter.scalding
import scala.util.control.TailCalls
import scala.util.control.TailCalls.TailRec
/** Experiments comparing ways of driving a recursive step function with and
  * without `scala.util.control.TailCalls` trampolining, plus a small
  * benchmark harness.
  *
  * A recursive computation is expressed as a `Rec`: given the current input
  * and a callback to use for recursive calls, it returns a `TailRec` result.
  * The `compute*` drivers differ only in how that callback is wired up.
  */
object StacklessRecursion {

  /** One step of a recursive computation from `A` to `B`.
    *
    * The first argument is the current input; the second is the function the
    * step must call to recurse (instead of calling itself directly).
    */
  type Rec[A, B] = (A, A => TailRec[B]) => TailRec[B]

  /** Runs `rec`, routing recursive calls straight back into `compute`.
    *
    * Each recursive call the step makes eagerly adds JVM stack frames, so deep
    * inputs can overflow the stack; only the final `TailRec` value is
    * evaluated through the trampoline by `.result`.
    */
  def computeNonSafe[A, B](rec: Rec[A, B])(value: A): B = {
    def compute(value: A): TailRec[B] = rec(value, compute)
    compute(value).result
  }

  /** Runs `rec`, deferring every recursive call through `TailCalls.tailcall`,
    * so recursion is driven by the `TailRec` trampoline in `.result` rather
    * than by nested JVM stack frames.
    */
  def computeSafe[A, B](rec: Rec[A, B])(value: A): B = {
    def compute(value: A): TailRec[B] =
      rec(value, inner => TailCalls.tailcall(compute(inner)))
    compute(value).result
  }

  /** Hybrid of `computeNonSafe` and `computeSafe`.
    *
    * Recursive calls run directly on the stack until `stackLimit` nested
    * direct calls have been made. Then one call is deferred through
    * `TailCalls.tailcall`, and the depth count restarts for its recursive
    * calls.
    *
    * @param rec        the recursive step
    * @param stackLimit number of direct (non-trampolined) nested calls allowed
    *                   before bouncing through the trampoline; with values
    *                   <= 1 every recursive call is trampolined
    * @param value      the initial input
    * @return the result of running `rec` from `value`
    */
  def computeLimited[A, B](rec: Rec[A, B], stackLimit: Int = 100)(value: A): B = {
    // Carry the current direct-call depth alongside each value.
    val withDepth: Rec[(A, Int), B] = {
      case ((curValue, curDepth), recCall) =>
        if (curDepth < stackLimit) {
          // Under the limit: recurse directly, one level deeper. The tuple is
          // built explicitly rather than relying on argument auto-tupling.
          rec(curValue, recA => recCall((recA, curDepth + 1)))
        } else {
          // Limit reached: return a suspended call, so the current chain of
          // direct frames returns before `.result` evaluates it; its
          // recursive calls start counting again from 1.
          TailCalls.tailcall(rec(curValue, recA => recCall((recA, 1))))
        }
    }
    computeNonSafe(withDepth)((value, 0))
  }

  /** Example `Rec` that returns `n`. It recurses down to 0 and adds one on
    * the way back out of each recursive call.
    *
    * Expects `n >= 0`. A negative `n` only reaches the base case after `Int`
    * wrap-around, so in practice it overflows the stack or runs for a very
    * long time.
    */
  val rec1: Rec[Int, Int] = {
    case (0, _)    => TailCalls.done(0)
    case (n, self) => self(n - 1).map(_ + 1)
  }

  /** Plain JVM-recursive counterpart of the `rec1` step, used as the
    * "native" baseline. It also expects `n >= 0`.
    */
  def rec1(n: Int): Int = n match {
    case 0 => 0
    case _ => rec1(n - 1) + 1
  }

  /** Prints `name` followed by the mean elapsed time, in milliseconds, of one
    * call to `calc`.
    *
    * `calc` is first run `number` times untimed as a warm-up, then timed over
    * another `number` runs. `System.nanoTime` is used because it is intended
    * for elapsed-time measurement and resolves the sub-millisecond per-call
    * costs measured here.
    *
    * NOTE(review): `calc` returns `Unit`, so the JIT may discard the result of
    * pure work; figures for very cheap bodies may be optimistic.
    */
  def measure(name: String, number: Int = 1000)(calc: () => Unit): Unit = {
    (1 to number).foreach(_ => calc()) // warm-up, not timed
    val start = System.nanoTime()
    (1 to number).foreach(_ => calc())
    val elapsedMillis = (System.nanoTime() - start) / 1e6
    println(s"$name ${elapsedMillis / number}")
  }

  /** Benchmarks each strategy on `rec1` with `n = 1000`, repeating the whole
    * round of measurements 10 times.
    */
  def main(args: Array[String]): Unit = {
    val n = 1000
    // println(rec1(n))
    // println(computeSafe(rec1)(n))
    // println(computeNonSafe(rec1)(n))
    // Previously recorded ms per call. The "limited 5" figure was presumably
    // taken while that benchmark mistakenly passed stackLimit = 10.
    // native 0.001
    // safe 0.049
    // non safe 6.18
    // limited 1 0.055
    // limited 5 0.091
    // limited 10 0.104
    // limited 20 0.189
    // limited 100 0.563
    val stackLimits = Seq(1, 5, 10, 20, 100)
    (1 to 10).foreach { _ =>
      measure("native") { () => rec1(n) }
      measure("safe") { () => computeSafe(rec1)(n) }
      // Fewer iterations, presumably because this variant is far slower (see
      // the figures above).
      measure("non safe", number = 50) { () => computeNonSafe(rec1)(n) }
      // Each label is derived from the limit actually passed, so the two
      // cannot drift apart.
      stackLimits.foreach { limit =>
        measure(s"limited $limit") { () => computeLimited(rec1, limit)(n) }
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment