To avoid causing a stack overflow during recursion, do the following:

- Ensure that each recursive call is a tail call.
- Decorate the recursive function(s) with `thunk` (in the case of co-recursive functions, decorate both). The value returned by calling a decorated function exposes a `trampoline` method, which is used in the next step.
- Append a call to `trampoline` to only the initial call to each decorated function. Conversely, do not append a call to `trampoline` to any recursive call to a `thunk`-decorated function.
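The steps above can be illustrated with a minimal, hypothetical implementation of `thunk`. This is an assumption about how such a decorator could work, not the library's actual code. Calling a decorated function returns a deferred `Thunk` object instead of running the body. `Thunk.trampoline()` then forces those deferred calls one at a time in a loop, so the Python call stack never grows. The `Thunk` class and the co-recursive `is_even`/`is_odd` pair are illustrative names.

```python
import functools


class Thunk:
    """Hypothetical deferred call: holds a function and its arguments."""

    def __init__(self, func, args, kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def trampoline(self):
        # Force deferred calls one at a time in a loop. Each step runs the
        # undecorated body, which returns either a final value or another
        # Thunk (the deferred tail call), so stack depth stays constant.
        result = self
        while isinstance(result, Thunk):
            result = result.func(*result.args, **result.kwargs)
        return result


def thunk(func):
    """Hypothetical decorator: calls return a Thunk instead of running func."""

    @functools.wraps(func)
    def deferred(*args, **kwargs):
        return Thunk(func, args, kwargs)

    return deferred


# Co-recursive functions: both are decorated, and every recursive call is a
# tail call, so each one produces a Thunk rather than a new stack frame.
@thunk
def is_even(n):
    return True if n == 0 else is_odd(n - 1)


@thunk
def is_odd(n):
    return False if n == 0 else is_even(n - 1)


# Only the initial call gets `.trampoline()`.
print(is_even(100_000).trampoline())  # True
print(is_odd(7).trampoline())         # True
```

Storing the undecorated function inside the `Thunk` is the key design choice. Running one step executes the body exactly once, and any recursive call inside that body goes through the decorator again and comes back as a new `Thunk` for the loop to handle.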
For example, suppose we have the following naive recursive implementation of the factorial function:

```python
def factorial(n):
    return 1 if n <= 1 else n * factorial(n - 1)
```
Of course, with Python's default recursion limit, this `factorial` implementation raises a `RecursionError` for values of `n` of roughly 1000 and larger, because every pending call keeps a frame on the stack. To avoid this limitation, let's follow the steps outlined above.
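To see the limit concretely, the sketch below runs the naive version and catches the `RecursionError` that CPython raises once the call depth exceeds `sys.getrecursionlimit()` (1000 by default). The helper `overflows` is a hypothetical name used only for this illustration.

```python
import sys


def factorial(n):
    # Naive recursion: the multiplication happens after the recursive call
    # returns, so every pending call keeps its own stack frame alive.
    return 1 if n <= 1 else n * factorial(n - 1)


def overflows(n):
    """Hypothetical helper: True if computing factorial(n) hits the limit."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False


limit = sys.getrecursionlimit()
print(limit)                  # 1000 unless changed with sys.setrecursionlimit
print(overflows(100))         # False: depth stays well under the limit
print(overflows(limit + 100)) # True: depth exceeds the limit
```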
First, let's modify the function so that the recursive call is a tail call. To do this, we'll introduce an accumulator argument (`product`, in this case):

```python
def factorial(n, product=1):
    # 1. Convert recursive call to tail call with the help of an accumulator argument
    return product if n <= 1 else factorial(n - 1, product * n)
```
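This step alone does not remove the limit. CPython does not perform tail-call elimination, so the accumulator version still creates one stack frame per call. The sketch below checks that it produces the same results as the naive version and that it still raises `RecursionError` for large `n`.

```python
import sys


def factorial(n, product=1):
    # Tail call: nothing is left to do after the recursive call returns,
    # but CPython still pushes a new frame for it.
    return product if n <= 1 else factorial(n - 1, product * n)


print(factorial(5))  # 120
print(factorial(0))  # 1

try:
    factorial(sys.getrecursionlimit() + 100)
    overflowed = False
except RecursionError:
    overflowed = True

print(overflowed)  # True
```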
Now, let's adjust this so that the `product` argument isn't publicly exposed and our public signature is the same as the original:

```python
def factorial(n):
    return _factorial(n)


def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)
```
Our next step is to decorate our tail-recursive function with `thunk`:

```python
def factorial(n):
    return _factorial(n)


# 2. Decorate tail-recursive function with `thunk`
@thunk
def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)
```
Finally, we must append a call to `trampoline` to the initial call to our thunked function:

```python
def factorial(n):
    # 3. Append call to `trampoline` to only the initial call to the "thunked" function
    return _factorial(n).trampoline()


@thunk
def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)
```
Now we can call `factorial` with arbitrarily large values of `n` without fear of exceeding the recursion limit.
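For comparison, the same mechanism can be hand-rolled without a decorator. This is a swapped-in technique, not the `thunk` API. The tail call returns a `functools.partial` object (a deferred call) instead of recursing, and a small `trampoline` loop keeps calling until it gets a non-callable result. Here `trampoline` is a hypothetical helper, and the sketch assumes the final result is never itself callable.

```python
import math
from functools import partial


def trampoline(result):
    # Hypothetical helper: keep forcing deferred calls until a plain value
    # comes back. Stack depth stays constant regardless of n.
    while callable(result):
        result = result()
    return result


def _factorial(n, product=1):
    # Return the next step as a deferred call instead of making it directly.
    return product if n <= 1 else partial(_factorial, n - 1, product * n)


def factorial(n):
    # Only the initial call goes through the trampoline.
    return trampoline(_factorial(n))


print(factorial(5))                                # 120
print(factorial(10_000) == math.factorial(10_000)) # True
```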