@jeffdonahue
Last active November 8, 2023 20:59
try:
    range = xrange  # Python 2
except NameError:
    pass  # Python 3


def lazy_product(*iter_funcs, **kwargs):
    """
    If f1, f2, ..., are functions which have no (required) arguments and
    return iterables, then
        lazy_product(f1, f2, ..., repeat=k)
    is equivalent to
        itertools.product(f1(), f2(), ..., repeat=k);
    but can be much faster to initialize when the iterables are long.
    For example, let f have the following definition:
        def f(n):
            def func():
                return xrange(n)
            return func
    Then, this code:
        p = itertools.product(*[f(N)() for _ in xrange(M)], repeat=K)
        first_element = next(p)
    takes O(NMK) time and memory to execute, whereas
        p = lazy_product(*[f(N) for _ in xrange(M)], repeat=K)
        first_element = next(p)
    is equivalent, and takes just O(MK) time and memory.
    (Of course, iterating over either result takes exactly N^(MK) steps, and
    each step takes O(1) time; the only difference between itertools.product
    and lazy_product is in the initialization of the iterable p, including
    the call to next(p) to get the first element, as shown above.)
    itertools.product's O(N) speed/memory overhead per input results from
    its saving the full result of xrange(N) as a list (or similar data
    structure) in memory. This is necessary because itertools.product takes
    iterables as input, and it is not generally possible to "reset" an
    iterator, so all of its values must instead be stored. The input to
    lazy_product is therefore an iterable of *functions* returning
    iterables, rather than the iterables themselves, allowing for repeated
    iteration over each iterable (by calling iter_func again when we reach
    the end of the iterable that iter_func created on the previous call).

    Inputs:
      - iter_funcs: functions with no (required) arguments that create and
        return an iterable. Each function is assumed to be deterministic,
        i.e., to return an identical iterable on each call. (Otherwise, the
        behavior of lazy_product is undefined.)
      - kwargs: a dict which is either empty or contains only the key
        `repeat`, with an integer value. In Python 3, the function header
        could (much more cleanly) be written as:
            def lazy_product(*iter_funcs, repeat=1):
        and the first two lines of ugly parsing code could be dropped.

    Returns:
      an iterator over the Cartesian product of the iterables returned
      by the elements of iter_funcs; equivalent to:
        return itertools.product(*(f() for f in iter_funcs), **kwargs)
    """
    repeat = kwargs.pop('repeat', 1)
    if kwargs: raise ValueError('unknown kwargs: %s' % list(kwargs))
    # One iterator per (repetition, function) pair, in itertools.product order.
    iters = [iter(f()) for _ in range(repeat) for f in iter_funcs]
    try:
        values = [next(i) for i in iters]
    except StopIteration:
        # At least one iterable is empty, so the product is empty. Returning
        # explicitly avoids the RuntimeError that Python 3.7+ (PEP 479) raises
        # when a StopIteration escapes a generator body.
        return
    while True:
        yield tuple(values)
        # Advance like an odometer: bump the rightmost position; when its
        # iterable is exhausted, re-create it by calling its function again,
        # reset it to its first value, and carry into the position on the left.
        for index in reversed(range(len(iters))):
            try:
                values[index] = next(iters[index])
                break
            except StopIteration:
                iters[index] = iter(iter_funcs[index % len(iter_funcs)]())
                values[index] = next(iters[index])
        else:
            # Every position rolled over: all combinations have been produced.
            return
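

from functools import partial


# Convenience helpers that build the zero-argument functions lazy_product
# expects. lazy_product_func lets one lazy_product be used as an input to
# another.
def lazy_product_func(*a, **k):
    return partial(lazy_product, *a, **k)


def range_func(*a, **k):
    return partial(range, *a, **k)


xrange_func = range_func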


if __name__ == '__main__':
    import itertools

    def test_equivalence(*iter_funcs, **kwargs):
        lazy_result = lazy_product(*iter_funcs, **kwargs)
        iters = (f() for f in iter_funcs)
        itertools_result = itertools.product(*iters, **kwargs)
        return list(lazy_result) == list(itertools_result)

    assert test_equivalence()
    assert test_equivalence(repeat=0)
    assert test_equivalence(repeat=1)
    assert test_equivalence(repeat=2)
    assert test_equivalence(range_func(0))
    assert test_equivalence(range_func(0), repeat=2)
    assert test_equivalence(range_func(2))
    assert test_equivalence(range_func(2), repeat=2)
    assert test_equivalence(range_func(2), range_func(3))
    assert test_equivalence(range_func(2), range_func(0), range_func(3))
    assert test_equivalence(range_func(2), range_func(0), range_func(3),
                            repeat=2)
    assert test_equivalence(range_func(2), range_func(3), repeat=2)
    assert test_equivalence(range_func(3), range_func(2, 7), repeat=0)
    assert test_equivalence(range_func(3), range_func(2, 7), repeat=4)
    print('Test passed!')
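
The docstring notes that in Python 3 the `**kwargs` parsing can be replaced with a keyword-only `repeat=1` argument. Below is a minimal sketch of that variant, assuming Python 3.7+ (where PEP 479 turns a `StopIteration` escaping a generator into a `RuntimeError`, so the empty-input case returns explicitly). The names `lazy_product_py3` and `counting_factory` are hypothetical, chosen so they do not shadow the module's own functions. The usage section also counts how many values are pulled to produce the first element, illustrating the O(MK) claim.

```python
import itertools
from functools import partial


def lazy_product_py3(*iter_funcs, repeat=1):
    """Hypothetical Python 3-only sketch of lazy_product (keyword-only repeat).

    Each element of iter_funcs is a zero-argument function returning an
    iterable; it is called again whenever its iterable is exhausted.
    """
    if repeat < 0:
        raise ValueError('repeat argument cannot be negative')
    # Same order as itertools.product(..., repeat=k): (f1, f2, ..., f1, f2, ...).
    funcs = list(iter_funcs) * repeat
    iters = [iter(f()) for f in funcs]
    try:
        values = [next(it) for it in iters]
    except StopIteration:
        # Some iterable is empty, so the product is empty.
        return
    while True:
        yield tuple(values)
        # Odometer step: advance the rightmost position; on exhaustion,
        # restart that iterable and carry into the position to its left.
        for index in reversed(range(len(iters))):
            try:
                values[index] = next(iters[index])
                break
            except StopIteration:
                iters[index] = iter(funcs[index]())
                values[index] = next(iters[index])
        else:
            # Every position rolled over: all combinations were produced.
            return


def counting_factory(n, counter):
    """Hypothetical helper: a zero-argument function whose iterables
    add 1 to counter[0] for every value they produce."""
    def func():
        for i in range(n):
            counter[0] += 1
            yield i
    return func


if __name__ == '__main__':
    r2, r3 = partial(range, 2), partial(range, 3)
    # Same output as itertools.product on small inputs.
    assert list(lazy_product_py3(r2, r3, repeat=2)) == list(
        itertools.product(range(2), range(3), repeat=2))
    assert list(lazy_product_py3(r2, partial(range, 0))) == []
    assert list(lazy_product_py3()) == [()]

    # First element with M=3 inputs of length N=1000 and K=2 repetitions:
    # only M*K = 6 values are pulled, independent of N.
    counter = [0]
    p = lazy_product_py3(*[counting_factory(1000, counter) for _ in range(3)],
                         repeat=2)
    assert next(p) == (0, 0, 0, 0, 0, 0)
    assert counter[0] == 6
    print('lazy_product_py3 checks passed')
```

Building the repeated function list up front (`list(iter_funcs) * repeat`) replaces the original's `index % len(iter_funcs)` lookup; both map each position to the same function.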

In [1]: import itertools; from lazy_product import *

In [2]: def f(n):
   ...:     def func():
   ...:         return xrange(n)
   ...:     return func
   ...:

In [3]: K=1; M=10; N=1000000;

In [4]: itertools_input = [f(N)() for _ in xrange(M)]

In [5]: lazy_input = [f(N) for _ in xrange(M)]

In [6]: %timeit p = itertools.product(*itertools_input, repeat=K); next(p)
10 loops, best of 3: 155 ms per loop

In [7]: %timeit p = lazy_product(*lazy_input, repeat=K); next(p)
100000 loops, best of 3: 6.21 µs per loop

In [8]: N*=10

In [9]: itertools_input = [f(N)() for _ in xrange(M)]

In [10]: lazy_input = [f(N) for _ in xrange(M)]

In [11]: %timeit p = itertools.product(*itertools_input, repeat=K); next(p)
1 loops, best of 3: 1.03 s per loop

In [12]: %timeit p = lazy_product(*lazy_input, repeat=K); next(p)
100000 loops, best of 3: 6.32 µs per loop
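
The transcript shows the itertools.product setup time growing with N while lazy_product's stays flat. A deterministic way to see the cause, without timing noise, is to count how many input values itertools.product consumes before it yields its first tuple; the Python documentation for itertools.product states that it completely consumes its input iterables before producing results. This is a Python 3 sketch (`range` in place of `xrange`); `counting_iter` and `values_pulled_for_first_element` are hypothetical helpers written for illustration.

```python
import itertools


def counting_iter(n, counter):
    """Hypothetical helper: yields 0..n-1, adding 1 to counter[0] per value."""
    for i in range(n):
        counter[0] += 1
        yield i


def values_pulled_for_first_element(m, n):
    """Count the input values itertools.product consumes to yield its first tuple."""
    counter = [0]
    p = itertools.product(*[counting_iter(n, counter) for _ in range(m)])
    next(p)
    return counter[0]


if __name__ == '__main__':
    # M matches the transcript; N is kept small so this runs instantly.
    print(values_pulled_for_first_element(10, 1000))  # prints 10000
```

The count is M*N, so the cost of creating the itertools.product iterator grows with N. This is consistent with the jump from 155 ms to 1.03 s above when N was multiplied by 10, while lazy_product pulls only one value per position (M*K values) before yielding its first tuple.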