Skip to content

Instantly share code, notes, and snippets.

@sweeneyde
Created March 22, 2021 10:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sweeneyde/fb6734d7b9f7d17e132c28af9ecb6270 to your computer and use it in GitHub Desktop.
Save sweeneyde/fb6734d7b9f7d17e132c28af9ecb6270 to your computer and use it in GitHub Desktop.
A less-greedy implementation of itertools.product: consume iterators lazily.
class LazyProductObject:
    """Lazy equivalent of ``itertools.product(*iterables, repeat=repeat)``.

    ``itertools.product`` converts every input to a tuple before producing
    anything.  This class instead pulls items from the input iterators only
    when they are needed for the next output tuple, so for example
    ``LazyProductObject(count(0), "ab")`` can be iterated for as many tuples
    as are requested.  The output order matches ``itertools.product``.

    Internal layout -- assume there are n >= 1 iterators, with repeat = r:

    * ``len(pools) == n`` and ``len(indices) == len(result) == n*r``.
    * ``pools[0:num_incomplete_pools]`` are lists of items that will still
      get appended to from their iterators, while
      ``pools[num_incomplete_pools:n]`` are finished and are never modified
      again.
    * If r >= 2, the pools fill up while we iterate for the first time over
      all possible ``result[(r-1)*n:r*n]`` (the last repetition).
    * If r == 1, the most significant iterator's items are not stored at
      all (``pools[0] is None``): each of them is used for exactly one pass
      over the remaining positions.
    """

    # Values of ``state``.  They stay plain ints so the tuple returned by
    # __getstate__ keeps the same format.
    _UNINITIALIZED = 0  # no tuple produced yet
    _RUNNING = 1        # at least one tuple produced, more may follow
    _STOPPED = 2        # exhausted; internal buffers released
    _EMPTY_PRODUCT = 3  # no iterables, or repeat == 0: yield () once

    __slots__ = ("iterators", "repeat", "pools", "indices",
                 "num_incomplete_pools", "result", "state")

    def __init__(self, *iterables, repeat=1):
        """Create the lazy product; ``iter()`` is called once per iterable.

        Raises TypeError if *repeat* is not an integer and ValueError if it
        is negative (both mirroring ``itertools.product``).
        """
        import operator
        # Reject non-integer repeat values (e.g. 2.0) here rather than
        # failing later with an unrelated error on the first next().
        repeat = operator.index(repeat)
        if repeat < 0:
            raise ValueError("repeat argument cannot be negative")
        self.iterators = [iter(x) for x in iterables]
        self.repeat = repeat
        self.pools = self.indices = self.result = self.num_incomplete_pools = None
        self.state = self._UNINITIALIZED
        if not iterables or repeat == 0:
            # The product of zero sequences is a single empty tuple.
            self.iterators = None
            self.state = self._EMPTY_PRODUCT

    def __iter__(self):
        return self

    def _stop(self):
        """Release all buffered data, mark as exhausted, raise StopIteration."""
        # free up my memory
        self.pools = self.indices = self.result = self.num_incomplete_pools = None
        self.iterators = self.repeat = None
        self.state = self._STOPPED
        raise StopIteration()

    def __getstate__(self):
        # The state is every slot value, in __slots__ order.
        return (self.iterators, self.repeat, self.pools, self.indices,
                self.num_incomplete_pools, self.result, self.state)

    def __setstate__(self, state):
        (self.iterators, self.repeat, self.pools, self.indices,
         self.num_incomplete_pools, self.result, self.state) = state

    def __next__(self):
        """Return the next tuple of the product, consuming input lazily."""
        if self.state >= self._STOPPED:
            if self.state == self._EMPTY_PRODUCT:
                # list(product()) == [()]
                self.state = self._STOPPED
                return ()
            raise StopIteration()

        sentinel = object()
        n = len(self.iterators)

        if self.state == self._UNINITIALIZED:
            # Pull one item from every iterator to form the first tuple; an
            # empty iterator means the whole product is empty.
            first_items = []
            for it in self.iterators:
                first = next(it, sentinel)
                if first is sentinel:
                    return self._stop()
                first_items.append(first)
            self.pools = [[x] for x in first_items]
            if self.repeat == 1:
                # Memory optimization: with a single repetition the most
                # significant iterator's items are never revisited, so they
                # are not stored.
                self.pools[0] = None
            self.result = first_items * self.repeat
            self.indices = [0] * (n * self.repeat)
            self.num_incomplete_pools = n
            self.state = self._RUNNING
            return tuple(self.result)

        indices = self.indices
        result = self.result
        pools = self.pools
        iterators = self.iterators
        # Positions [offset, offset + n) form the last (least significant)
        # repetition; only these positions ever draw from the iterators.
        offset = n * (self.repeat - 1)

        # Phase 1: odometer-style increment over the trailing positions
        # whose pools are already complete, while some remain unfilled.
        i = len(indices) - 1
        while i >= offset + self.num_incomplete_pools:
            assert i % n == i - offset
            indices[i] += 1
            if indices[i] < len(pools[i - offset]):
                result[i] = pools[i - offset][indices[i]]
                return tuple(result)
            # Otherwise, "carry"
            indices[i] = 0
            result[i] = pools[i - offset][0]
            i -= 1

        # Phase 2: the carry reached a position whose pool is still being
        # filled, so collect one new item from that position's iterator.
        # Positions with incomplete pools do not use ``indices``; their
        # current value is the most recently fetched item.  With repeat == 1,
        # position 0 is excluded because its pool is not stored.
        lowest_pool_index = offset if offset else 1
        while i >= lowest_pool_index:
            assert i % n == i - offset
            new = next(iterators[i - offset], sentinel)
            if new is not sentinel:
                pools[i - offset].append(new)
                result[i] = new
                return tuple(result)
            # Otherwise, pools[i - offset] is done.
            self.num_incomplete_pools -= 1
            assert indices[i] == 0
            result[i] = pools[i - offset][0]
            i -= 1

        if offset == 0:
            # repeat == 1: advance the (unstored) most significant iterator.
            assert i == 0
            assert pools[0] is None
            new = next(iterators[0], sentinel)
            if new is not sentinel:
                result[0] = new
                return tuple(result)
            return self._stop()

        # Phase 3 (repeat >= 2): every pool is complete now, so the carry
        # continues as a plain odometer over the earlier repetitions.
        assert self.num_incomplete_pools == 0
        while i >= 0:
            indices[i] += 1
            if indices[i] < len(pools[i % n]):
                result[i] = pools[i % n][indices[i]]
                return tuple(result)
            # Otherwise, "carry"
            indices[i] = 0
            result[i] = pools[i % n][0]
            i -= 1
        # The most significant repetition overflowed.
        return self._stop()
if __name__ == "__main__":
    import pickle
    import random
    from itertools import count
    from math import prod

    product = LazyProductObject

    # New tests ######################################################
    assert list(product("ab", "ab")) == [("a", "a"), ("a", "b"),
                                         ("b", "a"), ("b", "b")]
    assert list(product("ab", repeat=2)) == [("a", "a"), ("a", "b"),
                                             ("b", "a"), ("b", "b")]

    # Don't call iter() multiple times.
    class OnceIterable:
        def __init__(self):
            self.made_iter = False

        def __iter__(self):
            if self.made_iter:
                raise RuntimeError()
            self.made_iter = True
            return iter("python")

    assert list(product(OnceIterable())) == [(c,) for c in "python"]
    assert list(product(OnceIterable(), repeat=3)) == \
        [(a, b, c) for a in "python" for b in "python" for c in "python"]
    arr = list(product(OnceIterable(), OnceIterable(), OnceIterable(),
                       repeat=2))
    assert len(arr) == len("python") ** 6

    # Infinite iterators are only consumed as far as needed.
    it = product(count(0), count(0), "ab")
    assert next(it) == (0, 0, "a")
    assert next(it) == (0, 0, "b")
    assert next(it) == (0, 1, "a")
    assert next(it) == (0, 1, "b")
    assert next(it) == (0, 2, "a")
    assert next(it) == (0, 2, "b")

    # A negative repeat is rejected, like itertools.product.
    try:
        product("ab", repeat=-1)
    except ValueError:
        pass
    else:
        raise AssertionError("negative repeat should raise ValueError")

    # Pickling mid-iteration (via __getstate__/__setstate__) yields an
    # independent copy that resumes from the same position.
    it = product("abc", range(3), repeat=2)
    head = [next(it) for _ in range(5)]
    clone = pickle.loads(pickle.dumps(it))
    rest = list(it)
    assert list(clone) == rest
    assert head + rest == list(product("abc", range(3), repeat=2))

    # Tests copied from test_itertools.py ############################
    for args, result in [
        ([], [()]),                                   # zero iterables
        (['ab'], [('a',), ('b',)]),                   # one iterable
        ([range(2), range(3)],
         [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]),  # two iterables
        ([range(0), range(2), range(3)], []),  # first iterable with zero length
        ([range(2), range(0), range(3)], []),  # middle iterable with zero length
        ([range(2), range(3), range(0)], []),  # last iterable with zero length
    ]:
        assert list(product(*args)) == result
        for r in range(4):
            assert list(product(*(args * r))) == \
                list(product(*args, repeat=r))
    assert len(list(product(*[range(7)] * 6))) == 7**6
    try:
        product(range(6), None)
    except TypeError:
        pass
    else:
        raise AssertionError

    def product1(*args, **kwds):
        pools = list(map(tuple, args)) * kwds.get('repeat', 1)
        n = len(pools)
        if n == 0:
            yield ()
            return
        if any(len(pool) == 0 for pool in pools):
            return
        indices = [0] * n
        yield tuple(pool[i] for pool, i in zip(pools, indices))
        while True:
            for i in reversed(range(n)):  # right to left
                if indices[i] == len(pools[i]) - 1:
                    continue
                indices[i] += 1
                for j in range(i + 1, n):
                    indices[j] = 0
                yield tuple(pool[i] for pool, i in zip(pools, indices))
                break
            else:
                return

    def product2(*args, **kwds):
        'Pure python version used in docs'
        pools = list(map(tuple, args)) * kwds.get('repeat', 1)
        result = [[]]
        for pool in pools:
            result = [x + [y] for x in result for y in pool]
        for combo in result:
            yield tuple(combo)

    argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
                set('abcdefg'), range(11), tuple(range(13))]
    # Use a dedicated, explicitly seeded generator and report the seed and
    # the chosen arguments on failure, so a failing case can be reproduced.
    seed = random.randrange(2**32)
    rng = random.Random(seed)
    for _ in range(100):
        args = [rng.choice(argtypes) for _ in range(rng.randrange(5))]
        context = (seed, args)
        expected_len = prod(map(len, args))
        assert len(list(product(*args))) == expected_len, context
        assert list(product(*args)) == list(product1(*args)), context
        assert list(product(*args)) == list(product2(*args)), context
        assert len(list(product(*map(iter, args)))) == expected_len, context
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment