Recursion is beautiful. As an example, consider this perfectly acceptable definition of the functions even and odd in Scala, whose semantics you can guess:
def even(i: Int): Boolean = i match {
  case 0 => true
  case _ => odd(i - 1)
}

def odd(i: Int): Boolean = i match {
  case 0 => false
  case _ => even(i - 1)
}
We've defined even in a very natural way: a value is even when it's zero, or when the value one below it is odd. Odd is defined in a similar way.
Let's ignore the problem this code has with negative integers, and the fact that an O(n) algorithm for computing even and odd is not ideal, and focus on a single problem: this code breaks for large n:
scala> even(5000)
res3: Boolean = true
scala> even(50000)
java.lang.StackOverflowError
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
...
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)
In the remainder of this post we look at various strategies for preventing recursive programs from running out of memory.
The problem manifests itself in the stack trace: even calls odd, then odd calls even, which calls odd again. Each of these invocations causes the creation of a stack frame, eating away the space that's reserved for them.
Keeping all the stack frames would not really be necessary: when even calls odd, that is the very last thing that the even function will ever do! When odd returns, even itself returns immediately with that same value.
When we compute even(500), after a bunch of recursive calls, the 501st stack frame will be created, for the invocation of even(0), which returns true. The 500 method calls that are left on the stack will then return that same value one after another, until the whole stack is unwound.
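The frame count can be checked with an instrumented copy of the two functions. This is only a sketch: DepthDemo, tracked, depth and maxDepth are hypothetical names introduced here, and the counter tracks logical invocations of even and odd rather than real JVM frames (the helper itself adds a few frames per call).

```scala
// Instrumented copies of even/odd that record how deep the call chain gets.
// All names in this object are illustrative, not part of the original code.
object DepthDemo {
  private var depth = 0
  var maxDepth = 0

  // Run `body` one level deeper and remember the deepest level seen so far.
  private def tracked(body: => Boolean): Boolean = {
    depth += 1
    if (depth > maxDepth) maxDepth = depth
    val result = body
    depth -= 1
    result
  }

  def even(i: Int): Boolean = tracked {
    i match {
      case 0 => true
      case _ => odd(i - 1)
    }
  }

  def odd(i: Int): Boolean = tracked {
    i match {
      case 0 => false
      case _ => even(i - 1)
    }
  }
}
```

After calling DepthDemo.even(500), maxDepth holds 501: one invocation for each of the values 500 down to 0, all of them alive at the same time.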
Some languages and runtimes are smart enough to notice when a function simply returns whatever the function it calls returns. Such a call is known as a tail call, and for a tail call there is no need to allocate a new stack frame for the callee: it can reuse the stack frame of the caller.
Unfortunately, the JVM does not support this.
And so Scala, as long as it sticks to JVM method calling semantics, cannot support this either.
Luckily, there is one special case of tail calls which Scala treats differently, and that is the case of tail recursion. If a method does a tail call to itself, it's called tail recursion, and Scala will rewrite the recursion into a loop that does not consume stack space.
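To picture that rewrite, here is a tail-recursive method next to a hand-written loop that does the same work. It is a sketch of the idea, not the compiler's actual output, and LoopRewrite, sumTo and sumToLoop are hypothetical names chosen for this illustration.

```scala
object LoopRewrite {
  // Tail-recursive: the self-call is the last thing the method does,
  // so nothing is left to do in this frame once the call returns.
  def sumTo(n: Int, acc: Long = 0L): Long =
    if (n == 0) acc
    else sumTo(n - 1, acc + n)

  // Roughly what the recursion becomes: the parameters turn into
  // mutable variables and the recursive call turns into a jump.
  def sumToLoop(n: Int): Long = {
    var remaining = n
    var acc = 0L
    while (remaining != 0) {
      acc += remaining
      remaining -= 1
    }
    acc
  }
}
```

Both versions run in constant stack space, so LoopRewrite.sumTo(1000000) completes even though it makes a million recursive calls on paper.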
Many recursive algorithms can be rewritten in a tail recursive way. For our even and odd problem, this is a possible solution:
def even(i: Int): Boolean = i match {
  case 0 => true
  case 1 => false
  case _ => even(i - 2)
}

def odd(i: Int): Boolean = !even(i)
We've rewritten even to be tail-recursive, and now Scala will optimize it for us:
scala> even(500000000)
res25: Boolean = true
It no longer blows up for large numbers.
The @tailrec annotation is useful in these cases. It does not change how the function will be compiled, but it does trigger a compile error if the function is not actually tail-recursive. This protects you from accidentally changing a function that you want to be tail-recursive into something that no longer is.
You can try this out by putting the @tailrec annotation on the first and the latest version of even that we've defined.
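A minimal sketch of the annotation in use follows. TailrecCheck and countDown are hypothetical names, and the rejected variant is left commented out so that the block still compiles.

```scala
import scala.annotation.tailrec

object TailrecCheck {
  // Accepted: the self-call is in tail position.
  @tailrec
  def even(i: Int): Boolean = i match {
    case 0 => true
    case 1 => false
    case _ => even(i - 2)
  }

  // Rejected when uncommented: the result of the self-call still has to be
  // incremented, so the call is not the last thing the method does.
  // @tailrec
  // def countDown(i: Int): Int =
  //   if (i == 0) 0 else 1 + countDown(i - 1)
}
```

The annotated even behaves exactly like the unannotated one; the annotation only turns a silent loss of the optimization into a compile error.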
Another way of solving the problem of consuming too much stack space is to move the computation from the stack to the heap.
Instead of letting the even function recursively call odd, we could let it return a recipe to its caller that describes how to continue with the computation. Let's look at the following code:
sealed trait EvenOdd
case class Done(result: Boolean) extends EvenOdd
case class Even(value: Int) extends EvenOdd
case class Odd(value: Int) extends EvenOdd

def even(i: Int): EvenOdd = i match {
  case 0 => Done(true)
  case _ => Odd(i - 1)
}

def odd(i: Int): EvenOdd = i match {
  case 0 => Done(false)
  case _ => Even(i - 1)
}
Here we defined a little language, EvenOdd, that can express three things:
- The computation is Done, and we have the result value
- We need to compute whether a value is Even
- We need to compute whether a value is Odd
Our even and odd functions no longer return a boolean, but return an expression in this little language.
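To make "returning an expression" concrete, the sketch below takes a computation one step at a time. The step helper and the EvenOddSteps wrapper are hypothetical additions for illustration; the data types and the two functions are the ones defined above.

```scala
object EvenOddSteps {
  sealed trait EvenOdd
  case class Done(result: Boolean) extends EvenOdd
  case class Even(value: Int) extends EvenOdd
  case class Odd(value: Int) extends EvenOdd

  def even(i: Int): EvenOdd = i match {
    case 0 => Done(true)
    case _ => Odd(i - 1)
  }

  def odd(i: Int): EvenOdd = i match {
    case 0 => Done(false)
    case _ => Even(i - 1)
  }

  // Advance a computation by exactly one step (illustrative helper).
  // A finished computation is returned unchanged.
  def step(e: EvenOdd): EvenOdd = e match {
    case Done(_) => e
    case Even(v) => even(v)
    case Odd(v)  => odd(v)
  }
}
```

Calling even(2) does no recursion at all: it just returns Odd(1). Stepping that value gives Even(0), and stepping once more gives Done(true). Each call returns before the next one starts, so the stack never grows.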
Finally, we create an interpreter that will recursively evaluate expressions in this language:
import scala.annotation.tailrec

// Using Scala's self recursive tail call optimization
@tailrec
def run(evenOdd: EvenOdd): Boolean = evenOdd match {
  case Done(result) => result
  case Even(value) => run(even(value))
  case Odd(value) => run(odd(value))
}
Now we have expressed even and odd in their natural, mutually recursive way, and we still have a stack safe interpreter:
scala> run(even(5000000))
res1: Boolean = true
The disadvantage of this is that this is significantly slower. Unfortunately, we can't seem to have our cake and eat it too :(
This strategy is sometimes called trampolining, because instead of creating a big stack, we go up to even, then down to run, then up to odd, then down to run, then up to even, down to run, etcetera. The size of our stack keeps growing and shrinking by one frame for every step in the computation. This looks a lot like going up and down on a trampoline :)
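The bouncing is perhaps easiest to see when the interpreter is written as an explicit loop. The following is a sketch under that assumption: runLoop and TrampolineLoop are hypothetical names, and the loop is a hand-written stand-in for the tail-recursive interpreter, not something taken from the benchmarked code.

```scala
object TrampolineLoop {
  sealed trait EvenOdd
  case class Done(result: Boolean) extends EvenOdd
  case class Even(value: Int) extends EvenOdd
  case class Odd(value: Int) extends EvenOdd

  def even(i: Int): EvenOdd = i match {
    case 0 => Done(true)
    case _ => Odd(i - 1)
  }

  def odd(i: Int): EvenOdd = i match {
    case 0 => Done(false)
    case _ => Even(i - 1)
  }

  // Each iteration is one bounce: call up into even or odd,
  // get the next instruction back, and land in the loop again.
  def runLoop(start: EvenOdd): Boolean = {
    var current = start
    var result: Option[Boolean] = None
    while (result.isEmpty) {
      current match {
        case Done(r) => result = Some(r)
        case Even(v) => current = even(v)
        case Odd(v)  => current = odd(v)
      }
    }
    result.get
  }
}
```

The stack is never more than one call deep above runLoop, while the pending work lives on the heap as a single small EvenOdd value per step.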
There is no need to specialize our little language to computing even and odd. We can also make a little language that can express recursion in a general way:
import scala.annotation.tailrec

sealed trait Computation[A]
class Continue[A](n: => Computation[A]) extends Computation[A] {
  lazy val next = n
}
case class Done[A](result: A) extends Computation[A]

def even(i: Int): Computation[Boolean] = i match {
  case 0 => Done(true)
  case _ => new Continue(odd(i - 1))
}

def odd(i: Int): Computation[Boolean] = i match {
  case 0 => Done(false)
  case _ => new Continue(even(i - 1))
}

@tailrec
def run[A](computation: Computation[A]): A = computation match {
  case Done(a) => a
  case c: Continue[A] => run(c.next)
}
Here our even and odd functions don't return domain specific values, but a general value that indicates whether the computation is done or whether more steps are needed. The latter carries the next step as a by-name parameter, which the tail recursive runner function can evaluate.
Note that our run function is no longer tied to computing even and odd; it can run any computation expressed in this way.
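As a sketch of that generality, the same runner can drive a function that has nothing to do with even and odd. The sumTo function and the GeneralTrampolineDemo wrapper are hypothetical examples added here; Computation, Continue, Done and run are repeated from above so the block is self-contained.

```scala
import scala.annotation.tailrec

object GeneralTrampolineDemo {
  sealed trait Computation[A]
  class Continue[A](n: => Computation[A]) extends Computation[A] {
    lazy val next = n
  }
  case class Done[A](result: A) extends Computation[A]

  @tailrec
  def run[A](computation: Computation[A]): A = computation match {
    case Done(a) => a
    case c: Continue[A] => run(c.next)
  }

  // Illustrative example: the sum 1 + 2 + ... + n, carried in an accumulator.
  // The recursive call is wrapped in Continue, so it is not evaluated here.
  def sumTo(n: Int, acc: Long = 0L): Computation[Long] =
    if (n == 0) Done(acc)
    else new Continue(sumTo(n - 1, acc + n))
}
```

One limitation of this design is worth noting: a Continue can only hand over the whole rest of the computation, so it covers calls in tail position. A function that needs to do something with the result of a recursive call does not fit this little language as it stands.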
Something similar in spirit, but with a better implementation, is also available in the Scala standard library:
import scala.util.control.TailCalls.{ TailRec, done, tailcall }

def even(i: Int): TailRec[Boolean] = i match {
  case 0 => done(true)
  case _ => tailcall(odd(i - 1))
}

def odd(i: Int): TailRec[Boolean] = i match {
  case 0 => done(false)
  case _ => tailcall(even(i - 1))
}

even(3000).result
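Part of what the library version offers on top of a hand-rolled runner is that TailRec values can be combined with map and flatMap, which allows recursive calls that are not in tail position. The sketch below shows both uses; TailCallsDemo is a hypothetical wrapper name, and fib is written in the style of the example from the TailCalls documentation.

```scala
import scala.util.control.TailCalls.{ TailRec, done, tailcall }

object TailCallsDemo {
  def even(i: Int): TailRec[Boolean] = i match {
    case 0 => done(true)
    case _ => tailcall(odd(i - 1))
  }

  def odd(i: Int): TailRec[Boolean] = i match {
    case 0 => done(false)
    case _ => tailcall(even(i - 1))
  }

  // Neither recursive call is a tail call: both results are needed for the sum.
  // The for-comprehension chains them with flatMap and map instead.
  def fib(n: Int): TailRec[Int] =
    if (n < 2) done(n)
    else for {
      x <- tailcall(fib(n - 1))
      y <- tailcall(fib(n - 2))
    } yield x + y
}
```

Nothing is computed until result is called: TailCallsDemo.even(1000000).result evaluates the chain without growing the stack, and TailCallsDemo.fib(20).result gives 6765.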
I compared the performance of these solutions with JMH, and these are the results:
[info] Benchmark                                      Mode  Cnt       Score      Error  Units
[info] Trampolines.GeneralTrampolineRunner.bench     thrpt   30   44916.024 ±  388.202  ops/s
[info] Trampolines.ScalaTrampolineRunner.bench       thrpt   30   52106.426 ±  408.242  ops/s
[info] Trampolines.SpecializedTrampolineRunner.bench thrpt   30   94002.234 ± 1584.913  ops/s
[info] Trampolines.StackRunner.bench                 thrpt   30  358382.321 ± 6622.659  ops/s
As expected, the version that runs on the stack is the fastest. But remember that this is the version that breaks for a large number of recursions.
The specialized trampolining version, with the EvenOdd domain specific language and a runner optimized for this particular problem, is about 4 times slower than the stack version.
The general trampoline version that we defined here is about 2 times slower than the specialized version, and about 8 times slower than the stack version.
The TailRec version from the Scala standard library is about 16% faster than our general trampoline, making it about 7 times slower than the stack version.
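The ratios quoted here can be recomputed from the throughput scores in the JMH output. The snippet below is only a convenience for checking that arithmetic; BenchmarkRatios and ratio are hypothetical names, and the four numbers are copied from the table above.

```scala
object BenchmarkRatios {
  // Throughput scores in ops/s, copied from the JMH output.
  val stack = 358382.321
  val specialized = 94002.234
  val general = 44916.024
  val tailRec = 52106.426

  // How many times higher the throughput of `faster` is compared to `slower`.
  def ratio(faster: Double, slower: Double): Double = faster / slower

  def main(args: Array[String]): Unit = {
    println(f"stack vs specialized:   ${ratio(stack, specialized)}%.2f")
    println(f"specialized vs general: ${ratio(specialized, general)}%.2f")
    println(f"stack vs general:       ${ratio(stack, general)}%.2f")
    println(f"TailRec vs general:     ${ratio(tailRec, general)}%.2f")
    println(f"stack vs TailRec:       ${ratio(stack, tailRec)}%.2f")
  }
}
```

The results come out at roughly 3.8, 2.1, 8.0, 1.16 and 6.9, in that order. The error margins in the table are small compared to these gaps, so the ordering of the four versions is not in doubt.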
The source code of the benchmarks (and all the other code) is available at https://github.com/eamelink/scala-trampolines