@ejconlon
Created January 30, 2011 05:03
Tail-Recursion helper in Python
#!/usr/bin/env python
"""
Tail-Recursion helper in Python.
Inspired by the trampoline function at
http://jasonmbaker.com/tail-recursion-in-python-using-pysistence
Tail-recursive functions return calls to tail-recursive functions directly
(themselves, most of the time), with no work left to do after the call.
For example, this is tail-recursive:
    sum [] acc = acc
    sum (x:xs) acc = sum xs (acc+x)
And this is not:
    fib n | n == 0 || n == 1 = 1
          | otherwise        = (fib (n-1)) + (fib (n-2))
because fib n returns an application of (+), not directly of fib.
Suppose we wanted to write sum in Python like we could in Haskell:
"""
# iterator must have a has_next method...
def nontrampsum(iterator, accumulator):
    if not iterator.has_next():
        return accumulator
    else:
        head = iterator.next()
        accumulator += head
        return nontrampsum(iterator, accumulator)
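"""
For comparison, the Haskell fib above can be written both ways in Python.
This is a hypothetical sketch (fib_nontail and fib_tail are illustrative
names, not part of the original helper): fib_nontail still has the (+) to
do after its two recursive calls return, while fib_tail carries the running
values along as accumulators, so its recursive call is the last thing it
does. Python gives fib_tail no special treatment, though: it still uses one
stack frame per step, which is the problem the rest of this file works around.
"""
def fib_nontail(n):
    # Not tail-recursive: the addition happens after both calls return.
    if n == 0 or n == 1:
        return 1
    return fib_nontail(n - 1) + fib_nontail(n - 2)
def fib_tail(n, a=1, b=1):
    # Tail-recursive: a and b hold two consecutive fib values, and the
    # recursive call's result is returned directly.
    if n == 0:
        return a
    return fib_tail(n - 1, b, a + b)
# e.g. [fib_tail(n) for n in range(6)] gives the same values as fib_nontail.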
"""
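nontrampsum looks elegant, but it would blow up the stack pretty quickly.
Python doesn't eliminate tail calls: each recursive call pushes a new stack
frame, and CPython gives up at its recursion limit (1000 frames by default),
long before a million-element sum is done. Haskell (GHC), by contrast, runs
a tail call like this without growing the stack.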
We'll need some help:
"""
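# Factory for consuming tail-recursive functions that return
# suspensions (zero-argument callables) instead of calling themselves.
# The wrapped function's final result must not be callable, or it
# will be mistaken for another suspension and called.
def trampoline(f):
    def trampolined_f(*args, **kwargs):
        result = f(*args, **kwargs)
        # Keep forcing suspensions until a plain value comes back.
        while callable(result):
            result = result()
        return result
    return trampolined_f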
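"""
Because trampoline treats any callable result as another step, a
trampolined function can't return a function as its answer. One way around
that, sketched here as an assumption rather than the original design, is to
mark suspensions with a dedicated type and only force those. Bounce, run
and make_adder_after are hypothetical names; functools.partial stands in
for the partial helper defined next.
"""
import functools
class Bounce(object):
    # Wraps a pending call; only Bounce instances count as suspensions.
    def __init__(self, f, *args, **kwargs):
        self.thunk = functools.partial(f, *args, **kwargs)
def run(result):
    # Force Bounce objects until any other value (even a function) appears.
    while isinstance(result, Bounce):
        result = result.thunk()
    return result
def make_adder_after(n, k):
    # Counts n down, then returns a function as its final result.
    if n == 0:
        return lambda x: x + k
    return Bounce(make_adder_after, n - 1, k)
adder = run(make_adder_after(10000, 5))
# The cost of this design is an isinstance check and an extra object per step.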
# Creates a 'suspension' of f:
# returns a function of zero arity that applies f to the captured arguments.
# functools.partial does more, though (e.g. it accepts extra arguments at call time).
def partial(f, *args, **kwargs):
    def partial_f():
        return f(*args, **kwargs)
    return partial_f
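"""
To make "does more" concrete, a small sketch with the standard library's
functools.partial (record_call, calls and the other names here are
hypothetical, for illustration only). Like the partial above, it defers the
call until the result is invoked; unlike it, it also accepts extra arguments
at call time and exposes what it captured through .func and .args.
"""
import functools
calls = []
def record_call(x, y=0):
    # The side effect shows exactly when the deferred call happens.
    calls.append((x, y))
    return x + y
suspended = functools.partial(record_call, 40)
pending = list(calls)       # still empty: building the suspension ran nothing
forced = suspended()        # now calls record_call(40)
extended = suspended(y=2)   # call-time keyword: record_call(40, y=2)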
"""
First, instead of calling itself directly, our tail-recursive function
returns a suspension: a zero-argument closure that makes the call when
invoked. Then we wrap it with trampoline, which keeps calling the
suspensions it gets back until the base case of the recursion returns a
plain (non-callable) value.
"""
def trampsum_inner(iterator, acc):
    if not iterator.has_next():
        return acc
    else:
        head = iterator.next()
        acc += head
        # Return a suspension of the recursive call instead of making it.
        return partial(trampsum_inner, iterator, acc)
trampsum = trampoline(trampsum_inner)
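"""
Tail calls don't have to target the same function ("themselves, most of
the time", as the top of this file puts it): mutually recursive functions
can return suspensions of each other, too. A minimal, self-contained
sketch: is_even and is_odd are hypothetical examples, functools.partial
plays the role of partial, and run_trampolined is a stand-in for
trampoline so the snippet runs on its own.
"""
import functools
def run_trampolined(f, *args):
    # Call f once, then keep calling whatever suspensions come back.
    result = f(*args)
    while callable(result):
        result = result()
    return result
def is_even(n):
    # Tail call to is_odd, returned as a suspension instead of made directly.
    if n == 0:
        return True
    return functools.partial(is_odd, n - 1)
def is_odd(n):
    if n == 0:
        return False
    return functools.partial(is_even, n - 1)
# 100000 steps: far deeper than the default recursion limit, but the stack
# never grows because each suspension is called from the while loop.
even_result = run_trampolined(is_even, 100000)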
"""
And a digression: we'll need to define an iterator with a has_next method.
I'd like to be able to pattern-match on iterators like on lists in Haskell:
    sum [] acc = acc
    sum (x:xs) acc = sum xs (acc+x)
A plain Python iterator can't tell us whether it's empty without consuming
an element, so we wrap one and look ahead lazily, caching the peeked element.
"""
try:
    from collections.abc import Iterator  # Python 3.3+
except ImportError:
    from collections import Iterator      # Python 2
class LookAheadIterator(Iterator):
    def __init__(self, wrapped):
        self._wrapped = iter(wrapped)
        self._need_to_advance = True  # True when the cache is stale
        self._has_next = False
        self._cache = None
    def has_next(self):
        if self._need_to_advance:
            self._advance()
        return self._has_next
    def _advance(self):
        # Pull one element into the cache, or note that we're exhausted.
        try:
            self._cache = next(self._wrapped)
            self._has_next = True
        except StopIteration:
            self._has_next = False
        self._need_to_advance = False
    def next(self):
        if self._need_to_advance:
            self._advance()
        if self._has_next:
            self._need_to_advance = True
            return self._cache
        else:
            raise StopIteration()
    def __next__(self):
        # Python 3 iterator protocol; must return the element.
        return self.next()
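"""
An alternative to peeking with has_next, closer to Haskell's (x:xs)
pattern, is to split an iterator into its head and the rest. This is a
hypothetical helper (uncons and _EMPTY are illustrative names) built only
on the built-in next(iterator, default). It needs no cached state, at the
cost of making callers unpack a tuple.
"""
_EMPTY = object()  # sentinel, so None can still be a real element
def uncons(iterable):
    # Returns (head, rest) for a non-empty iterable, or None if it is empty.
    # rest is the same underlying iterator, already advanced past head.
    it = iter(iterable)
    head = next(it, _EMPTY)
    if head is _EMPTY:
        return None
    return head, it
# e.g. head, rest = uncons([1, 2, 3]) gives head == 1 and list(rest) == [2, 3].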
"""
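Let's show (sadly) that it's not the speediest. Both versions read from the
same LookAheadIterator, yet in the recorded output the trampolined sum takes
about 2.7 times as long as the built-in sum (6.139 vs. 2.254 CPU seconds):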
"""
import cProfile
def test(f):
    # Sum 0..999999 through a LookAheadIterator using the given function.
    iterator = LookAheadIterator(range(1000000))
    accumulator = 0
    print(f(iterator, accumulator))
print("Summing with built-in sum")
cProfile.run('test(sum)')
print("Summing with trampolined sum")
cProfile.run('test(trampsum)')
"""
499999500000
         2000009 function calls in 2.254 CPU seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.254    2.254 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 _abcoll.py:66(__iter__)
  1000001    0.909    0.000    1.747    0.000 cool.py:107(next)
        1    0.000    0.000    2.254    2.254 cool.py:125(test)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.838    0.000    0.838    0.000 cool.py:99(_advance)
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.506    0.506    2.254    2.254 {sum}
499999500000
         7000010 function calls in 6.139 CPU seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.139    6.139 <string>:1(<module>)
  1000000    0.588    0.000    0.588    0.000 cool.py:107(next)
        1    0.000    0.000    6.139    6.139 cool.py:125(test)
        1    0.761    0.761    6.139    6.139 cool.py:44(trampolined_f)
  1000000    0.358    0.000    0.358    0.000 cool.py:54(partial)
  1000000    0.784    0.000    5.242    0.000 cool.py:55(partial_f)
  1000001    1.763    0.000    4.457    0.000 cool.py:66(trampsum_inner)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.736    0.000    1.748    0.000 cool.py:94(has_next)
  1000001    1.012    0.000    1.012    0.000 cool.py:99(_advance)
  1000001    0.136    0.000    0.136    0.000 {callable}
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
But it won't topple the stack:
"""
# cProfile.run('test(nontrampsum)')
# RuntimeError: maximum recursion depth exceeded while calling a Python object
# (Python 3.5+ raises RecursionError, a subclass of RuntimeError.)
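"""
The failure doesn't need a million elements: any recursion deeper than the
interpreter's limit (sys.getrecursionlimit(), 1000 by default in CPython)
is enough. A small sketch; count_down is a hypothetical, non-trampolined
function used only to trigger the error, and catching RuntimeError keeps
the sketch working on Python 2 and 3 alike.
"""
import sys
def count_down(n):
    # Plain tail-recursive call: every step still adds a stack frame.
    if n == 0:
        return 0
    return count_down(n - 1)
limit = sys.getrecursionlimit()
try:
    count_down(limit + 100)
    overflowed = False
except RuntimeError:
    overflowed = True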
"""
We could also write a general left fold with it:
"""
def foldl_inner(f, accumulator, iterator):
    if not iterator.has_next():
        return accumulator
    else:
        head = iterator.next()
        accumulator = f(accumulator, head)
        return partial(foldl_inner, f, accumulator, iterator)
foldl = trampoline(foldl_inner)
def add(a, b): return a + b
def foldlsum(iterator, accumulator):
    return foldl(add, accumulator, iterator)
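"""
foldl_inner is a left fold, the same shape as the standard library's
functools.reduce(function, iterable, initializer), except that reduce takes
the iterable before the initial accumulator and runs a plain loop instead
of bouncing through suspensions. A sketch for comparison; the
list-reversing fold is an illustrative example, not from the original.
"""
import functools
import operator
reduce_total = functools.reduce(operator.add, range(10), 0)  # ((0+0)+1)+...+9
# A fold that isn't a sum: prepend each element, which reverses the input.
reduce_reversed = functools.reduce(lambda acc, x: [x] + acc, [1, 2, 3], [])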
print("Summing with trampolined foldl")
cProfile.run('test(foldlsum)')
"""
499999500000
         8000011 function calls in 6.874 CPU seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.874    6.874 <string>:1(<module>)
  1000000    0.601    0.000    0.601    0.000 cool.py:107(next)
        1    0.000    0.000    6.874    6.874 cool.py:125(test)
  1000001    2.159    0.000    5.109    0.000 cool.py:187(foldl_inner)
  1000000    0.206    0.000    0.206    0.000 cool.py:196(add)
        1    0.000    0.000    6.874    6.874 cool.py:198(foldlsum)
        1    0.824    0.824    6.874    6.874 cool.py:44(trampolined_f)
  1000000    0.389    0.000    0.389    0.000 cool.py:54(partial)
  1000000    0.819    0.000    5.927    0.000 cool.py:55(partial_f)
        1    0.000    0.000    0.000    0.000 cool.py:88(__init__)
  1000001    0.750    0.000    1.754    0.000 cool.py:94(has_next)
  1000001    1.004    0.000    1.004    0.000 cool.py:99(_advance)
  1000001    0.122    0.000    0.122    0.000 {callable}
        1    0.000    0.000    0.000    0.000 {iter}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
Use it wisely, I guess (or not at all). Look into functools.partial too:
it can stand in for the hand-rolled partial above.
"""