@chuckwondo
Last active February 5, 2019 14:58
Python Tail-Call Optimization

CPython does not optimize tail calls, so deep recursion eventually exceeds the interpreter's recursion limit and raises a RecursionError. To avoid this, do the following:

  1. Ensure each recursive call is a tail call.
  2. Decorate the recursive function(s) with thunk (for mutually recursive functions, decorate each one). Calling a decorated function then returns a Thunk object, which exposes the trampoline method used in the next step.
  3. Call trampoline only on the result of the initial call to each decorated function. Do not call trampoline on the result of any recursive call to a thunk-decorated function.
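The three steps can be seen in miniature in the following sketch. It is a simplified variant, not this gist's thunk module: as an assumption for brevity, each tail call is deferred by returning a zero-argument lambda, and a hypothetical trampoline() function (a free function, unlike the trampoline method on Thunk) keeps calling the result while it is callable. The countdown function is likewise illustrative.

```python
def trampoline(result):
    # Hypothetical helper: keep invoking deferred calls in a loop until a
    # non-callable value appears. The Python stack does not grow.
    while callable(result):
        result = result()
    return result


def countdown(n):
    # Step 1: the recursive call is in tail position.
    # Step 2 (simplified): return a deferred call instead of recursing directly.
    return "done" if n == 0 else (lambda: countdown(n - 1))


# Step 3: trampoline only the initial call.
print(trampoline(countdown(100_000)))  # done
```

The callable() check is the weak point of this variant: a function whose genuine result is itself callable cannot be trampolined this way, which is one reason to wrap deferred calls in a dedicated class, as the thunk module does.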

For example, suppose we have the following naive recursive implementation of the factorial function:

def factorial(n):
    return n if n <= 1 else n * factorial(n - 1)

With CPython's default recursion limit of 1000 (see sys.getrecursionlimit()), this implementation exceeds the limit and raises a RecursionError for values of n of roughly 1000 and larger. To avoid this limitation, let's follow the steps outlined above.
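A quick way to observe the limit, assuming CPython (the overflows helper below is illustrative and not part of the gist):

```python
import sys


def factorial(n):
    # Naive recursion: the multiplication happens after the recursive call
    # returns, so every pending call keeps its stack frame alive.
    return n if n <= 1 else n * factorial(n - 1)


def overflows(n):
    """Return True if factorial(n) raises RecursionError (illustrative helper)."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False


print(sys.getrecursionlimit())  # CPython's default is 1000
print(overflows(10))  # False
print(overflows(sys.getrecursionlimit() + 100))  # True
```

Frames already on the stack (for example, the caller's) also count against the limit, which is why the threshold is only "roughly" 1000.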

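First, let's modify the function so that the recursive call is a tail call, that is, the last operation the function performs (in the naive version, the multiplication by n happens after the recursive call returns). To do this, we'll introduce an accumulator argument (product, in this case) that carries the running result: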

def factorial(n, product=1):
    # 1. Convert recursive call to tail call with the help of an accumulator argument
    return product if n <= 1 else factorial(n - 1, product * n)
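Converting to a tail call is necessary but not sufficient: CPython does not eliminate tail calls, so this version still uses one stack frame per call and still fails for deep inputs. A sketch demonstrating this, assuming CPython (the overflows helper is illustrative):

```python
import math
import sys


def factorial(n, product=1):
    # Tail-recursive: the recursive call is the last thing evaluated.
    return product if n <= 1 else factorial(n - 1, product * n)


def overflows(n):
    """Return True if factorial(n) raises RecursionError (illustrative helper)."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False


# Small inputs agree with the standard library...
print(factorial(20) == math.factorial(20))  # True
# ...but deep inputs still exceed the recursion limit, because CPython
# keeps one stack frame per call even for tail calls.
print(overflows(sys.getrecursionlimit() + 100))  # True
```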

Now, let's adjust this so that the product argument isn't publicly exposed, and our public signature is the same as the original:

def factorial(n):
    return _factorial(n)

def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)
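Because the recursive call in _factorial is in tail position, all of its state lives in its arguments, so it could be rewritten mechanically as a loop. The hand-converted version below (the name _factorial_loop is hypothetical) is not part of the thunk approach; it is shown only to clarify what the trampoline will effectively do for us.

```python
import math


def factorial(n):
    return _factorial_loop(n)


def _factorial_loop(n, product=1):
    # Each iteration corresponds to one tail call
    # _factorial(n - 1, product * n): the arguments are simply rebound.
    while n > 1:
        n, product = n - 1, product * n
    return product


print(factorial(5))  # 120
print(factorial(5000) == math.factorial(5000))  # True, with no recursion at all
```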

Our next step is to decorate our tail-recursive function with thunk:

def factorial(n):
    return _factorial(n)

# 2. Decorate tail-recursive function with `thunk`
@thunk
def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)

Finally, we call trampoline on the result of the initial call to our thunked function:

def factorial(n):
    # 3. Append call to `trampoline` to only the initial call to the "thunked" function
    return _factorial(n).trampoline()

@thunk
def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)

Now we can call factorial with large values of n without exceeding the recursion limit; the remaining constraints are only time and memory.
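Putting the three steps together: the sketch below inlines a copy of this gist's Thunk class and thunk decorator so it runs without the thunk module, and checks the result against math.factorial for an input ten times CPython's default recursion limit.

```python
import math


class Thunk(object):
    def __init__(self, f, *args, **kwargs):
        self._thunk = lambda: f(*args, **kwargs)

    def __call__(self):
        return self._thunk()

    def trampoline(self):
        # Run deferred steps in a loop; the Python stack never grows.
        result = self()
        while isinstance(result, Thunk):
            result = result()
        return result


def thunk(f):
    def wrapper(*args, **kwargs):
        return Thunk(f, *args, **kwargs)
    return wrapper


def factorial(n):
    # Step 3: trampoline only the initial call.
    return _factorial(n).trampoline()


@thunk  # Step 2: decorate the tail-recursive function.
def _factorial(n, product=1):
    # Step 1: the recursive call is a tail call.
    return product if n <= 1 else _factorial(n - 1, product * n)


print(factorial(10_000) == math.factorial(10_000))  # True
```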

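The thunk module:

import functools


class Thunk(object):
    """A deferred call to f(*args, **kwargs)."""

    def __init__(self, f, *args, **kwargs):
        self._thunk = lambda: f(*args, **kwargs)

    def __call__(self):
        # Perform one step of the deferred computation.
        return self._thunk()

    def trampoline(self):
        # Evaluate thunks in a loop (so the call stack never grows)
        # until a non-Thunk result is produced.
        result = self()
        while isinstance(result, Thunk):
            result = result()
        return result


def thunk(f):
    """Decorator: calling the decorated function returns a Thunk instead of running f."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return Thunk(f, *args, **kwargs)
    return wrapper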
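To see why trampoline avoids the recursion limit, one can record the interpreter's stack depth at every step. The sketch below repeats the Thunk class and thunk decorator so it runs standalone, and measures depth with the standard-library traceback.extract_stack(); the countdown function and the depths list are hypothetical instrumentation for illustration only.

```python
import traceback


class Thunk(object):
    def __init__(self, f, *args, **kwargs):
        self._thunk = lambda: f(*args, **kwargs)

    def __call__(self):
        return self._thunk()

    def trampoline(self):
        result = self()
        while isinstance(result, Thunk):
            result = result()
        return result


def thunk(f):
    def wrapper(*args, **kwargs):
        return Thunk(f, *args, **kwargs)
    return wrapper


depths = []  # stack depth observed at each step (illustrative instrumentation)


@thunk
def countdown(n):
    # Every step runs from the same trampoline loop, so the depth is constant.
    depths.append(len(traceback.extract_stack()))
    return n if n == 0 else countdown(n - 1)


countdown(50).trampoline()
print(len(depths), len(set(depths)))  # 51 1: fifty-one steps, one stack depth
```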
Example usage, with doctests (running this module directly executes them):

import math

from thunk import thunk


def factorial(n):
    """
    >>> factorial(5)
    120
    >>> factorial(2000) == math.factorial(2000)  # 2000! is too long to print under Python 3.11+'s int-to-str limit
    True
    """
    return _factorial(n).trampoline()


@thunk
def _factorial(n, product=1):
    return product if n <= 1 else _factorial(n - 1, product * n)


def even(n):
    """
    >>> even(-2)
    True
    >>> even(-1)
    False
    >>> even(0)
    True
    >>> even(1)
    False
    >>> even(2)
    True
    >>> even(10001)
    False
    """
    return _even(n).trampoline()


def odd(n):
    """
    >>> odd(0)
    False
    >>> odd(1)
    True
    >>> odd(9999)
    True
    """
    return _odd(n).trampoline()


# Mutually recursive: both functions are decorated with thunk.
@thunk
def _even(n):
    return _even(-n) if n < 0 else n == 0 or _odd(n - 1)


@thunk
def _odd(n):
    return _odd(-n) if n < 0 else n != 0 and _even(n - 1)


if __name__ == '__main__':
    import doctest
    doctest.testmod(verbose=True)
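For comparison, here is a sketch of the same mutually recursive even/odd pair written without thunk. The names plain_even, plain_odd, and overflows are hypothetical. Each call adds a stack frame, so a large input exceeds CPython's recursion limit, whereas the trampolined versions above handle inputs such as 10001.

```python
import sys


def plain_even(n):
    # Same logic as _even, but calling plain_odd directly instead of via a Thunk.
    return plain_even(-n) if n < 0 else n == 0 or plain_odd(n - 1)


def plain_odd(n):
    return plain_odd(-n) if n < 0 else n != 0 and plain_even(n - 1)


def overflows(f, n):
    """Return True if f(n) raises RecursionError (illustrative helper)."""
    try:
        f(n)
    except RecursionError:
        return True
    return False


print(plain_even(10), plain_odd(7))  # True True
print(overflows(plain_even, 10 * sys.getrecursionlimit()))  # True
```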