Created
March 22, 2021 10:29
-
-
Save sweeneyde/fb6734d7b9f7d17e132c28af9ecb6270 to your computer and use it in GitHub Desktop.
A less-greedy implementation of itertools.product: consume iterators lazily.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LazyProductObject:
    """Iterator equivalent to ``itertools.product`` that consumes its inputs lazily.

    ``itertools.product`` turns every argument into a tuple before yielding
    anything.  This class instead calls ``iter()`` exactly once on each
    argument and pulls items only when they are first needed, caching them
    in per-iterable "pools" so they can be revisited later.  This makes it
    usable with very long or infinite iterables, e.g.
    ``LazyProductObject(itertools.count(), "ab")``.

    Tuples are produced in the same order as
    ``itertools.product(*iterables, repeat=repeat)``.

    Raises:
        TypeError: if an argument is not iterable or ``repeat`` is not an
            integer.
        ValueError: if ``repeat`` is negative.
    """

    # Assume there are n >= 1 iterators, with repeat = r.
    # len(pools) == n and len(indices) == len(result) == n*r.
    # pools[0:num_incomplete_pools] are lists of items that will still
    # get added to from iterators, while pools[num_incomplete_pools:n]
    # are finished and do not need to be re-modified.
    # If repeat >= 2:
    #     then the pools fill up while we iterate, for the first time, over
    #     all possible values of result[n*(r-1):n*r] (the least significant
    #     copy of the n factors), before any earlier position changes.
    # If repeat == 1:
    #     then we don't need to store the most significant iterator's
    #     items at all; they only get iterated over once.

    # Values of the ``state`` slot.  They are kept as the ints 0-3 because
    # ``state`` is part of the pickled state.
    _NOT_STARTED = 0     # __next__ has not been called yet
    _RUNNING = 1         # at least one tuple has been produced
    _STOPPED = 2         # exhausted; every further __next__ raises StopIteration
    _EMPTY_PRODUCT = 3   # no iterables or repeat == 0: yield () exactly once

    # Private "no more items" marker for next(iterator, default).
    _SENTINEL = object()

    __slots__ = ("iterators", "repeat", "pools", "indices",
                 "num_incomplete_pools", "result", "state")

    def __init__(self, *iterables, repeat=1):
        """Call ``iter()`` on each of *iterables*; no items are read yet."""
        from operator import index

        # Like itertools.product, insist on a real integer up front.
        # Otherwise repeat=2.0 would only fail inside the first __next__,
        # and repeat=0.0 would silently produce [()].
        repeat = index(repeat)
        if repeat < 0:
            raise ValueError("repeat argument cannot be negative")
        self.iterators = [iter(x) for x in iterables]
        self.repeat = repeat
        self.pools = self.indices = self.result = self.num_incomplete_pools = None
        self.state = self._NOT_STARTED
        if not iterables or repeat == 0:
            # Zero factors in total: the product is the single empty tuple.
            self.iterators = None
            self.state = self._EMPTY_PRODUCT

    def __iter__(self):
        return self

    def _stop(self):
        """Mark the iterator exhausted, drop all cached data, raise StopIteration."""
        # Free up the pools and the source iterators.
        self.pools = self.indices = self.result = self.num_incomplete_pools = None
        self.iterators = self.repeat = None
        self.state = self._STOPPED
        raise StopIteration()

    def __getstate__(self):
        """Return all slot values as a tuple (used by pickle/copy)."""
        return (self.iterators, self.repeat, self.pools, self.indices,
                self.num_incomplete_pools, self.result, self.state)

    def __setstate__(self, state):
        """Restore the slot values produced by ``__getstate__``."""
        (self.iterators, self.repeat, self.pools, self.indices,
         self.num_incomplete_pools, self.result, self.state) = state

    def _start(self):
        """Pull the first item of every iterator and return the first tuple."""
        sentinel = self._SENTINEL
        firsts = []
        for it in self.iterators:
            first = next(it, sentinel)
            if first is sentinel:
                # One factor is empty, so the whole product is empty.
                # The remaining iterators are not advanced.
                return self._stop()
            firsts.append(first)
        n = len(firsts)
        self.pools = [[x] for x in firsts]
        self.result = firsts * self.repeat
        if self.repeat == 1:
            # The most significant factor is walked only once, so its
            # items never need to be revisited; don't store them.
            self.pools[0] = None
        self.indices = [0] * (n * self.repeat)
        self.num_incomplete_pools = n
        self.state = self._RUNNING
        return tuple(self.result)

    def __next__(self):
        """Return the next tuple of the product, or raise StopIteration."""
        if self.state >= self._STOPPED:
            if self.state == self._EMPTY_PRODUCT:
                # list(product()) == [()]
                self.state = self._STOPPED
                return ()
            raise StopIteration()
        if self.state == self._NOT_STARTED:
            return self._start()

        sentinel = self._SENTINEL
        indices = self.indices
        result = self.result
        pools = self.pools
        iterators = self.iterators
        n = len(iterators)
        # result[offset:] is the least significant copy of the n factors;
        # position offset + k takes its values from pools[k].
        offset = n * (self.repeat - 1)

        # Phase 1: odometer step over the trailing positions whose pools
        # are already complete, carrying leftwards on overflow.
        i = len(indices) - 1
        while i >= offset + self.num_incomplete_pools:
            assert i % n == i - offset
            indices[i] += 1
            if indices[i] < len(pools[i - offset]):
                result[i] = pools[i - offset][indices[i]]
                return tuple(result)
            # Otherwise, "carry"
            indices[i] = 0
            result[i] = pools[i - offset][0]
            i -= 1

        # Phase 2: i is now the position of the rightmost incomplete pool
        # (if any).  That position's next value is simply the next item of
        # its iterator.  Appending it keeps every pool in iteration order.
        # With repeat == 1, pools[0] is not stored, so position 0 is left
        # for the special case below.
        lowest_pool_index = offset if offset else 1
        while i >= lowest_pool_index:
            assert i % n == i - offset
            new = next(iterators[i - offset], sentinel)
            if new is not sentinel:
                pools[i - offset].append(new)
                result[i] = new
                return tuple(result)
            # Otherwise, pools[i % n] is done.  Pools complete from right to
            # left, so the incomplete ones always form a prefix of pools.
            self.num_incomplete_pools -= 1
            assert indices[i] == 0
            result[i] = pools[i - offset][0]
            i -= 1

        if offset == 0:
            # repeat == 1 memory optimization: the most significant items
            # are read straight from iterators[0] and never stored.
            assert i == 0
            assert pools[0] is None
            new = next(iterators[0], sentinel)
            if new is not sentinel:
                result[0] = new
                return tuple(result)
            return self._stop()

        # Phase 3 (repeat >= 2 only): every pool is complete by now, so
        # continue the odometer into the more significant copies.
        assert self.num_incomplete_pools == 0
        while i >= 0:
            indices[i] += 1
            if indices[i] < len(pools[i % n]):
                result[i] = pools[i % n][indices[i]]
                return tuple(result)
            # Otherwise, "carry"
            indices[i] = 0
            result[i] = pools[i % n][0]
            i -= 1
        # The most significant repetition overflowed.
        return self._stop()
if __name__ == "__main__":
    # Self-test: compare LazyProductObject against itertools.product
    # semantics and two reference implementations from test_itertools.py.
    import pickle
    import random
    from itertools import count
    from math import prod

    product = LazyProductObject

    # New tests ######################################################
    assert list(product("ab", "ab")) == [("a", "a"), ("a", "b"),
                                         ("b", "a"), ("b", "b")]
    assert list(product("ab", repeat=2)) == [("a", "a"), ("a", "b"),
                                             ("b", "a"), ("b", "b")]

    # Don't call iter() multiple times.
    class OnceIterable:
        def __init__(self):
            self.made_iter = False

        def __iter__(self):
            if self.made_iter:
                raise RuntimeError()
            self.made_iter = True
            return iter("python")

    assert list(product(OnceIterable())) == [(c,) for c in "python"]
    assert list(product(OnceIterable(), repeat=3)) == \
        [(a, b, c) for a in "python" for b in "python" for c in "python"]
    arr = list(product(OnceIterable(), OnceIterable(), OnceIterable(), repeat=2))
    assert len(arr) == len("python") ** 6

    # Infinite inputs work as long as only a finite prefix is consumed.
    it = product(count(0), count(0), "ab")
    assert next(it) == (0, 0, "a")
    assert next(it) == (0, 0, "b")
    assert next(it) == (0, 1, "a")
    assert next(it) == (0, 1, "b")
    assert next(it) == (0, 2, "a")
    assert next(it) == (0, 2, "b")

    # Invalid repeat values are rejected, as by itertools.product.
    try:
        product("ab", repeat=-1)
    except ValueError:
        pass
    else:
        raise AssertionError("negative repeat accepted")
    try:
        list(product("ab", repeat=2.0))
    except TypeError:
        pass
    else:
        raise AssertionError("non-integer repeat accepted")

    # Once exhausted, the iterator stays exhausted.
    it = product("a")
    assert next(it) == ("a",)
    for _ in range(2):
        try:
            next(it)
        except StopIteration:
            pass
        else:
            raise AssertionError("exhausted iterator produced a value")

    # Mid-iteration state survives a pickle round trip
    # (exercises __getstate__ / __setstate__).
    it = product("abc", "de", repeat=2)
    for _ in range(5):
        next(it)
    clone = pickle.loads(pickle.dumps(it))
    assert list(clone) == list(it)

    # Tests copied from test_itertools.py ############################
    for args, result in [
        ([], [()]),                                   # zero iterables
        (['ab'], [('a',), ('b',)]),                   # one iterable
        ([range(2), range(3)],
         [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]),  # two iterables
        ([range(0), range(2), range(3)], []),         # first iterable with zero length
        ([range(2), range(0), range(3)], []),         # middle iterable with zero length
        ([range(2), range(3), range(0)], []),         # last iterable with zero length
    ]:
        assert list(product(*args)) == result
        for r in range(4):
            assert list(product(*(args * r))) == list(product(*args, repeat=r))
    assert len(list(product(*[range(7)] * 6))) == 7 ** 6
    try:
        product(range(6), None)
    except TypeError:
        pass
    else:
        raise AssertionError

    def product1(*args, **kwds):
        pools = list(map(tuple, args)) * kwds.get('repeat', 1)
        n = len(pools)
        if n == 0:
            yield ()
            return
        if any(len(pool) == 0 for pool in pools):
            return
        indices = [0] * n
        yield tuple(pool[i] for pool, i in zip(pools, indices))
        while 1:
            for i in reversed(range(n)):  # right to left
                if indices[i] == len(pools[i]) - 1:
                    continue
                indices[i] += 1
                for j in range(i + 1, n):
                    indices[j] = 0
                yield tuple(pool[i] for pool, i in zip(pools, indices))
                break
            else:
                return

    def product2(*args, **kwds):
        'Pure python version used in docs'
        pools = list(map(tuple, args)) * kwds.get('repeat', 1)
        result = [[]]
        for pool in pools:
            result = [x + [y] for x in result for y in pool]
        # Named "combo" (not "prod") so it doesn't read like math.prod.
        for combo in result:
            yield tuple(combo)

    argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
                set('abcdefg'), range(11), tuple(range(13))]
    # A private, seeded RNG keeps the randomized cases reproducible.
    rng = random.Random(20210322)
    for _ in range(100):
        args = [rng.choice(argtypes) for _ in range(rng.randrange(5))]
        expected_len = prod(map(len, args))
        assert len(list(product(*args))) == expected_len
        assert list(product(*args)) == list(product1(*args))
        assert list(product(*args)) == list(product2(*args))
        args = map(iter, args)
        assert len(list(product(*args))) == expected_len
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment