Created
November 13, 2015 16:51
-
-
Save ales-erjavec/35f22c88e3e06bcfb14c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tail call 'optimization'
------------------------

A simple tail call 'optimization' implemented with a trampoline.

Usage
-----

Decorate the function with `tailrec`, and replace each tail-recursive call
(the call site) with ``return function_name.tc(...)`` (``.tailcall(...)`` is
a longer alias).  Instead of growing the Python call stack, the tail call is
then performed by the decorator's loop.

Example
-------

>>> # a tail-recursive Fibonacci
>>> def fibn(n):
...     @tailrec
...     def fibn_acc(i, x1, x2):
...         if i >= n:
...             return x1
...         else:
...             return fibn_acc.tc(i + 1, x1 + x2, x1)
...     return fibn_acc(0, 0, 1)
...
>>> fibn(10)
55
>>> fibn(1100)  # doctest: +ELLIPSIS
34428592852410...
>>> # 10 ** 5 chained calls run in constant stack depth.  The result has
>>> # 20,899 digits (it begins 25974069347221...), more than Python 3.11+
>>> # converts to str by default, so check its size instead of printing it.
>>> fibn(10 ** 5).bit_length()
69424

>>> # mutual tail recursion
>>> @tailrec
... def ping(i):
...     if i > 0:
...         print("ping")
...         return pong.tc(i - 1)
...
>>> @tailrec
... def pong(i):
...     if i > 0:
...         print("pong")
...         return ping.tc(i - 1)
...
>>> ping(10000)  # doctest: +ELLIPSIS
ping
pong...
"""
from collections import namedtuple | |
from functools import partial, wraps | |
tailcall = namedtuple("tailcall", ["func", "args", "kwargs"])


class tailcall(tailcall):
    """
    A pending call ``func(*args, **kwargs)`` that a `tailrec` function returns
    instead of making the call itself.

    Construct it as ``tailcall(func, *args, **kwargs)``.  The extra positional
    and keyword arguments are collected into the `args` tuple and the `kwargs`
    dict fields of the underlying namedtuple.
    """
    # Keep instances as light as the namedtuple base.  Without an empty
    # __slots__ each instance would also carry a __dict__, and a trampoline
    # allocates one instance per tail call.
    __slots__ = ()

    def __new__(*args, **kwargs):
        # `cls` and `func` are pulled out of *args rather than declared as
        # named parameters, so the tail-called function may itself accept
        # keyword arguments named `cls` or `func`.
        if len(args) < 2:
            raise TypeError(
                "tailcall() missing required positional argument: 'func'")
        cls, func, args = args[0], args[1], args[2:]
        return super(tailcall, cls).__new__(cls, func, args, kwargs)
def tailrec(func):
    """
    Decorate `func` so that the tail calls it returns are run by a trampoline.

    Inside the decorated function, write a tail call as
    ``return some_tailrec_function.tc(...)`` (or ``.tailcall(...)``).  This
    returns a `tailcall` record instead of calling, and the wrapper's loop
    performs the call.

    Parameters
    ----------
    func : callable
        The function to decorate.

    Returns
    -------
    callable
        A wrapper that calls `func`, then keeps executing every `tailcall`
        record it gets back until some other value is produced, and returns
        that value.  The wrapper also exposes `tc` and `tailcall` helpers,
        which build a `tailcall` targeting the undecorated `func`.
    """
    @wraps(func)
    def trampoline_wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        # Each iteration performs one pending tail call from this frame, so
        # the Python call stack stays flat however long the chain gets.
        while isinstance(outcome, tailcall):
            target = outcome.func
            if is_tailrec(target):
                # A decorated wrapper was named as the target: call its
                # undecorated body so that any tailcall it returns comes back
                # to this loop rather than starting a nested trampoline.
                target = target.__wrapped__
            outcome = target(*outcome.args, **outcome.kwargs)
        return outcome

    # Marker attribute that `is_tailrec` looks for.
    trampoline_wrapper.__tailrec = True
    # Set explicitly as well: `functools.wraps` only adds `__wrapped__`
    # on Python 3.2+.
    trampoline_wrapper.__wrapped__ = func
    trampoline_wrapper.tailcall = partial(tailcall, func)
    trampoline_wrapper.tc = partial(tailcall, func)
    return trampoline_wrapper
def is_tailrec(func):
    """
    Report whether `func` is a `tailrec` wrapper.

    Reads the ``__tailrec`` marker attribute that `tailrec` sets on its
    wrapper.  Returns ``False`` when the attribute is absent; otherwise
    returns the marker's value unchanged.
    """
    # This is a module-level function rather than a method inside a class
    # body, so `func.__tailrec` is not name-mangled.
    try:
        return func.__tailrec
    except AttributeError:
        return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment