Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Last active April 6, 2020 01:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vxgmichel/590f3e0dbc1a3a841251881686313346 to your computer and use it in GitHub Desktop.
Save vxgmichel/590f3e0dbc1a3a841251881686313346 to your computer and use it in GitHub Desktop.
Stackless recursion using generators
"""
A decorator that turns a generator function into a stackless recursive
function, memoized per top-level call by default (disable with
``@stackless(cached=False)``).

Inside the decorated generator, ``yield`` sends the arguments of a recursive
call (as a tuple) and evaluates to that call's return value, e.g.:

    @stackless
    def fibonacci(n):
        if n < 2:
            return n
        return (yield (n - 1,)) + (yield (n - 2,))
"""
def _is_hashable(value):
    """Return True if *value* can be used as a dict key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


def stackless(func=None, *, cached=True):
    """Turn a generator function into a stackless recursive function.

    Inside the decorated generator, ``yield args`` performs the recursive call
    ``func(*args)`` and evaluates to its return value. Calls are driven by an
    explicit list used as a stack rather than by the Python call stack, so the
    recursion depth is not bounded by the interpreter's recursion limit.

    Usable bare (``@stackless``) or with options (``@stackless(cached=False)``).

    Args:
        func: The generator function to decorate. When omitted, a decorator
            is returned so keyword options can be supplied.
        cached: If true, results are memoized by argument tuple for the
            duration of one top-level call (a fresh cache is created for every
            top-level call). Calls whose arguments are unhashable are simply
            not cached instead of failing.

    Returns:
        The wrapped function, or a decorator when *func* is None.

    Exceptions (``Exception`` subclasses) raised by a recursive call are
    re-raised inside the calling generator at its pending ``yield``, so a
    ``try``/``except`` around a recursive call behaves as with ordinary
    recursion. An exception no frame handles propagates out of the top-level
    call.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args):
            cache = {}
            # Each frame is (call arguments, running generator).
            stack = [(args, func(*args))]
            result = None  # value to send into the top frame
            error = None  # exception to throw into the top frame instead
            while stack:
                frame_args, gen = stack[-1]
                try:
                    if error is None:
                        call_args = gen.send(result)
                    else:
                        pending, error = error, None
                        call_args = gen.throw(pending)
                except StopIteration as stop:
                    # The frame returned: its value becomes the value of the
                    # caller's pending ``yield`` expression.
                    stack.pop()
                    result = stop.value
                    if cached and _is_hashable(frame_args):
                        cache[frame_args] = result
                    continue
                except Exception as exc:
                    # The frame raised: unwind it and re-raise the exception
                    # inside its caller, as ordinary recursion would.
                    stack.pop()
                    if not stack:
                        raise
                    error = exc
                    continue
                if cached and _is_hashable(call_args) and call_args in cache:
                    result = cache[call_args]
                    continue
                # A freshly created generator must be primed with send(None).
                result = None
                stack.append((call_args, func(*call_args)))
            return result

        return wrapper

    return decorator if func is None else decorator(func)
# Testing
import pytest
@stackless
def stackless_fib(n):
    """Return the n-th Fibonacci number via cached stackless recursion."""
    if n < 2:
        return n
    # Each yield is a recursive call; the left one is evaluated first.
    minus_one = yield (n - 1,)
    minus_two = yield (n - 2,)
    return minus_one + minus_two
@stackless(cached=False)
def slow_stackless_fib(n):
    """Return the n-th Fibonacci number via stackless recursion, no memoization.

    Without the cache every subproblem is recomputed, so the number of calls
    grows exponentially with n.
    """
    if n < 2:
        return n
    earlier = yield (n - 1,)
    earliest = yield (n - 2,)
    return earlier + earliest
def recursive_fib(n):
    """Return the n-th Fibonacci number using plain, uncached recursion."""
    return n if n < 2 else recursive_fib(n - 1) + recursive_fib(n - 2)
def iterative_fib(n):
    """Return the n-th Fibonacci number with a simple loop (reference answer)."""
    current, following = 0, 1
    for _step in range(n):
        current, following = following, current + following
    return current
@stackless
def stackless_sum(n):
    """Return n + (n - 1) + ... + 1 using stackless recursion.

    NOTE(review): the only base case is n == 0, so a negative n never
    terminates; callers are expected to pass a non-negative integer.
    """
    if n == 0:
        return 0
    rest = yield (n - 1,)
    return n + rest
def recursive_sum(n):
    """Return n + (n - 1) + ... + 1 using ordinary recursion (one frame per level)."""
    return 0 if n == 0 else n + recursive_sum(n - 1)
@pytest.mark.parametrize(
    "fib",
    [recursive_fib, stackless_fib, slow_stackless_fib],
)
def test_fib(fib):
    """Each Fibonacci implementation agrees with the iterative reference."""
    n = 28
    actual = fib(n)
    assert actual == iterative_fib(n)
def test_sum():
    """stackless_sum copes with a depth that overflows ordinary recursion."""
    depth = 10 ** 5
    assert stackless_sum(depth) == depth * (depth + 1) // 2
    # The plain recursive version exceeds the interpreter's recursion limit.
    with pytest.raises(RecursionError):
        recursive_sum(depth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment