Created
March 14, 2016 18:29
-
-
Save jcorbin/3173695ec83b20386f36 to your computer and use it in GitHub Desktop.
Dask Scalar Zip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/dask/bag/core.py b/dask/bag/core.py | |
index dc8a04d..47535c1 100644 | |
--- a/dask/bag/core.py | |
+++ b/dask/bag/core.py | |
@@ -6,6 +6,7 @@ import math | |
import bz2 | |
import os | |
import uuid | |
+from collections import Iterable | |
from fnmatch import fnmatchcase | |
from glob import glob | |
from collections import Iterable, Iterator, defaultdict | |
@@ -1338,7 +1339,15 @@ def bag_range(n, npartitions): | |
return Bag(dsk, name, npartitions) | |
-def bag_zip(*bags): | |
+def scalar_zip(*iterable_or_scalars):
+    """ Zip iterables together, repeating scalars alongside them
+
+    Every argument that is an ``Iterable`` is zipped as-is; every other
+    argument is treated as a scalar and repeated in each output tuple.
+    Note that strings are ``Iterable`` and so are zipped element-wise.
+
+    Raises ``ValueError`` if no argument is iterable, since zipping only
+    repeated scalars would never terminate.
+
+    >>> list(scalar_zip([1, 2, 3], 10))
+    [(1, 10), (2, 10), (3, 10)]
+    """
+    # Imported locally so this helper does not rely on a module-level
+    # ``import itertools``.
+    from itertools import repeat
+
+    # Classify once so each argument is inspected a single time.
+    is_iterable = [isinstance(part, Iterable) for part in iterable_or_scalars]
+    if not any(is_iterable):
+        # Explicit raise rather than ``assert``: asserts are stripped under
+        # ``python -O``, which would leave an endless zip of repeats.
+        raise ValueError("scalar_zip requires at least one iterable argument")
+    return zip(*(part if iterable else repeat(part)
+                 for part, iterable in zip(iterable_or_scalars, is_iterable)))
+ | |
+ | |
+def bag_zip(*parts): | |
""" Partition-wise bag zip | |
All passed bags must have the same number of partitions. | |
@@ -1374,16 +1383,22 @@ def bag_zip(*bags): | |
[(0, 0), (3, None), (None, 5), (6, None), (None, 10), (9, None), | |
(12, None), (15, 15), (18, None), (None, 20), (None, 25), (None, 30)] | |
""" | |
- npartitions = bags[0].npartitions | |
- assert all(bag.npartitions == npartitions for bag in bags) | |
+ npartitions = parts[0].npartitions | |
+ assert all(isinstance(part, Bag) or isinstance(part, Item) | |
+ for part in parts) | |
+ assert all(part.npartitions == npartitions | |
+ for part in parts | |
+ if isinstance(part, Bag)) | |
# TODO: do more checks | |
- name = 'zip-' + tokenize(*bags) | |
+ name = 'zip-' + tokenize(*parts) | |
dsk = dict( | |
- ((name, i), (reify, (zip,) + tuple((bag.name, i) for bag in bags))) | |
+ ((name, i), (reify, (scalar_zip,) + tuple( | |
+ (part.name, i) if isinstance(part, Bag) else part.key | |
+ for part in parts))) | |
for i in range(npartitions)) | |
- bags_dsk = merge(*(bag.dask for bag in bags)) | |
- return Bag(merge(bags_dsk, dsk), name, npartitions) | |
+ parts_dsk = merge(*(bag.dask for bag in parts)) | |
+ return Bag(merge(parts_dsk, dsk), name, npartitions) | |
def _reduce(binop, sequence, initial=no_default): |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example use case: