Skip to content

Instantly share code, notes, and snippets.

@jcorbin
Created March 14, 2016 18:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jcorbin/3173695ec83b20386f36 to your computer and use it in GitHub Desktop.
Dask Scalar Zip
diff --git a/dask/bag/core.py b/dask/bag/core.py
index dc8a04d..47535c1 100644
--- a/dask/bag/core.py
+++ b/dask/bag/core.py
@@ -1338,7 +1338,11 @@ def bag_range(n, npartitions):
     return Bag(dsk, name, npartitions)
 
 
-def bag_zip(*bags):
+def bag_zip(*parts):
     """ Partition-wise bag zip
 
+    Items (e.g. ``b.sum()``) may be passed alongside bags; each Item's
+    value is broadcast, i.e. paired with every element of every
+    partition.  At least one Bag is required.
+
     All passed bags must have the same number of partitions.
@@ -1374,16 +1378,29 @@ def bag_zip(*bags):
-    [(0, 0), (3, None), (None, 5), (6, None), (None 10), (9, None),
+    [(0, 0), (3, None), (None, 5), (6, None), (None, 10), (9, None),
     (12, None), (15, 15), (18, None), (None, 20), (None, 25), (None, 30)]
     """
-    npartitions = bags[0].npartitions
-    assert all(bag.npartitions == npartitions for bag in bags)
+    if not all(isinstance(part, (Bag, Item)) for part in parts):
+        raise TypeError("bag_zip arguments must be Bags or Items")
+    bags = [part for part in parts if isinstance(part, Bag)]
+    if not bags:
+        raise ValueError("bag_zip requires at least one Bag")
+    # Items carry no partitions, so take npartitions from the first Bag.
+    npartitions = bags[0].npartitions
+    if any(bag.npartitions != npartitions for bag in bags):
+        raise ValueError("All bags must have the same number of partitions")
     # TODO: do more checks
 
-    name = 'zip-' + tokenize(*bags)
+    name = 'zip-' + tokenize(*parts)
+    # Bags contribute their i-th partition.  Items are wrapped in a nested
+    # ``itertools.repeat`` task, so an Item's value is broadcast to every
+    # element whatever its type (a str/list value is never zipped item-wise).
     dsk = dict(
-        ((name, i), (reify, (zip,) + tuple((bag.name, i) for bag in bags)))
+        ((name, i), (reify, (zip,) + tuple(
+            (part.name, i) if isinstance(part, Bag)
+            else (itertools.repeat, part.key)
+            for part in parts)))
         for i in range(npartitions))
-    bags_dsk = merge(*(bag.dask for bag in bags))
-    return Bag(merge(bags_dsk, dsk), name, npartitions)
+    parts_dsk = merge(*(part.dask for part in parts))
+    return Bag(merge(parts_dsk, dsk), name, npartitions)
 
 
 def _reduce(binop, sequence, initial=no_default):
@jcorbin
Copy link
Author

jcorbin commented Mar 14, 2016

Example use case (normalizing each element by the bag's total):

# Requires a dask.bag whose zip accepts Items (scalars) alongside Bags.
import dask.bag as db

b = db.from_sequence(range(100), npartitions=10)
# Pair every element with the bag's total; b.sum() is an Item that zip
# broadcasts to every element.
b_with_total = db.zip(b, b.sum())
# float() avoids Python 2 integer division; compare with a tolerance because
# summing 100 rounded fractions need not give exactly 1.0.
result = b_with_total.map(lambda pair: pair[0] / float(pair[1])).sum().compute()
assert abs(result - 1.0) < 1e-9
result = b_with_total.map_partitions(
    lambda z: [x / float(t) for x, t in z]).sum().compute()
assert abs(result - 1.0) < 1e-9

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment