Skip to content

Instantly share code, notes, and snippets.

@andreasvc
Last active March 29, 2017 02:33
Show Gist options
  • Save andreasvc/5455646 to your computer and use it in GitHub Desktop.
Save andreasvc/5455646 to your computer and use it in GitHub Desktop.
Get the cartesian product of an arbitrary number of iterables, including infinite sequences.
def cartpi(seq):
    """ A depth-first cartesian product for a sequence of iterables;
    i.e., all values of the last iterable are consumed before advancing the
    preceding ones. Like itertools.product(), but supports infinite sequences.

    ``seq`` may be any finite iterable of iterables (list, tuple, generator,
    ...). Returns an iterable of tuples. An empty ``seq`` yields a single
    empty tuple, as itertools.product() does.

    Caveat: every iterable except the first is iterated afresh for each
    combination of the values before it. Those iterables should therefore be
    re-iterable containers (lists, tuples, strings, ...). A one-shot iterator
    in a later position is exhausted after its first pass and silently
    produces fewer results. Also, if the iterable at some position is
    infinite, the values before that position never advance.

    >>> from itertools import islice, count
    >>> list(islice(cartpi([count(), count()]), 9))
    [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8)]
    >>> list(cartpi(iter([(1, 2), 'ab'])))
    [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
    """
    # Materialize the outer sequence so it can be sliced below. This lets
    # callers pass generators or other non-sliceable iterables of iterables.
    seq = tuple(seq)
    if seq:
        # The leftmost ``for`` of a generator expression is evaluated eagerly,
        # so the recursion builds the chain of prefix generators up front.
        # The inner ``for`` re-iterates seq[-1] once for every prefix ``b``.
        return (b + (a,) for b in cartpi(seq[:-1]) for a in seq[-1])
    return ((), )
def bfcartpi(seq):
    """ Breadth-first (diagonal) cartesian product of a sequence of iterables;
    each iterable is advanced in turn in a round-robin fashion. As usual with
    breadth-first, this comes at the cost of memory consumption: every value
    drawn from every iterable is kept in memory.

    Each newly drawn value is combined with all values drawn so far from the
    other iterables, so every combination is produced exactly once. Each
    input iterable is consumed only once, so one-shot iterators are fine.
    If any iterable is empty, nothing is yielded. An empty ``seq`` yields a
    single empty tuple.

    >>> from itertools import islice, count
    >>> list(islice(bfcartpi([count(), count()]), 9))
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    from itertools import product

    # Build a *list* of iterators. On Python 3, map() returns a lazy one-shot
    # iterator: the initial fetch below would exhaust it, and it supports
    # neither len() nor indexing.
    seqit = [iter(a) for a in seq]
    # Fetch initial values. An empty iterable means the product is empty.
    try:
        seqlist = [[next(a)] for a in seqit]
    except StopIteration:
        return
    yield tuple(a[0] for a in seqlist)
    # Bookkeeping of which iterators still have values.
    stopped = [False] * len(seqit)
    n = len(seqit)
    while not all(stopped):
        # Visit the iterators right to left, wrapping around. The loop body
        # only runs when seqit is non-empty, since all([]) is True.
        n = (n - 1) % len(seqit)
        if stopped[n]:
            continue
        try:
            seqlist[n].append(next(seqit[n]))
        except StopIteration:
            stopped[n] = True
            continue
        # Combine only the newly drawn value at position n with everything
        # seen so far at the other positions. All inputs here are finite
        # lists, so itertools.product yields the depth-first order.
        factors = seqlist[:n] + [seqlist[n][-1:]] + seqlist[n + 1:]
        for result in product(*factors):
            yield result
if __name__ == '__main__':
    # Demo: print the first few combinations in both orders for a small menu
    # example that includes one infinite iterable, count(2).
    from itertools import islice, count

    menu = ('eggs', 'sausage', 'spam')
    conjunctions = ('and', 'or')
    demos = (
        ('depth-first:', cartpi, 20),
        ('breadth-first:', bfcartpi, 100),
    )
    for title, product_func, limit in demos:
        # count(2) is a one-shot iterator, so every demo needs a fresh sequence.
        seq = (menu, conjunctions, menu, conjunctions, count(2), menu)
        print(title)
        for val in islice(product_func(seq), limit):
            print(' '.join(map(str, val)))
depth-first:
eggs and eggs and 2 eggs
eggs and eggs and 2 sausage
eggs and eggs and 2 spam
eggs and eggs and 3 eggs
eggs and eggs and 3 sausage
eggs and eggs and 3 spam
eggs and eggs and 4 eggs
eggs and eggs and 4 sausage
eggs and eggs and 4 spam
eggs and eggs and 5 eggs
breadth-first:
eggs and eggs and 2 eggs
eggs and eggs and 2 sausage
eggs and eggs and 3 eggs
eggs and eggs and 3 sausage
eggs and eggs or 2 eggs
eggs and eggs or 2 sausage
eggs and eggs or 3 eggs
eggs and eggs or 3 sausage
eggs and sausage and 2 eggs
eggs and sausage and 2 sausage
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment