Last active
March 2, 2020 03:58
-
-
Save jakirkham/362ec5316922c19543bec84c7b7b7072 to your computer and use it in GitHub Desktop.
Slicing Dask Arrays with Index Arrays
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import numbers | |
import operator | |
import numpy as np | |
import dask.array as da | |
try: | |
from functools import reduce | |
except ImportError: | |
pass | |
def _take_points(arr, *coords):
    """Per-block kernel: pick ``arr[coords[0][j], coords[1][j], ...]`` for each j.

    ``arr`` is a NumPy block and ``coords`` holds one 1-D integer array per
    axis of ``arr``. Coordinates outside ``[0, arr.shape[axis])`` raise
    ``ValueError``; negative values are rejected rather than wrapped around.
    """
    for axis, (pts, size) in enumerate(zip(coords, arr.shape)):
        if ((pts < 0) | (pts >= size)).any():
            raise ValueError(
                "`idx` contains coordinates outside the bounds of `a` "
                "along axis %d" % axis
            )

    # Indexing with a tuple of equally long 1-D integer arrays selects one
    # element per position (NumPy pointwise/advanced indexing), so no
    # intermediate mask over the whole of `arr` is required.
    return arr[coords]


def slice_index_array(a, idx):
    """Lazily select individual points of ``a`` using coordinate arrays.

    Element ``j`` of the result is ``a[idx[0][j], idx[1][j], ...]``.

    Parameters
    ----------
    a : array-like
        Array to take values from. Converted with ``da.asarray``.
    idx : tuple of array-like
        One 1-D integer array per dimension of ``a``; all must hold the same
        number of points. Each is converted with ``da.asarray``.

    Returns
    -------
    dask.array.Array
        1-D array with the dtype of ``a`` and one element per point, chunked
        along the points like the (aligned) arrays in ``idx``.

    Raises
    ------
    TypeError
        If ``idx`` is not a ``tuple`` or any of its arrays is not of an
        integral dtype.
    ValueError
        If ``idx`` is empty, contains arrays that are not 1-D, contains arrays
        of differing lengths, or has a length different from ``a.ndim``.
        At compute time, also raised when a coordinate lies outside ``a``.

    Notes
    -----
    Every dimension of ``a`` is contracted with ``concatenate=True``, so each
    task receives the whole of ``a`` as a single NumPy array.
    """
    # Validate this is not a single array.
    if not isinstance(idx, tuple):
        raise TypeError("`idx` must be a `tuple` of 1-D arrays")
    # An empty tuple would otherwise fail below with an opaque IndexError.
    if not idx:
        raise ValueError("`idx` must contain at least one 1-D array")

    # Cast as Dask Arrays.
    a = da.asarray(a)
    idx = tuple(da.asarray(e) for e in idx)

    # Validate the dimensionality and shape of `idx` and `a`.
    if not all(issubclass(e.dtype.type, numbers.Integral) for e in idx):
        raise TypeError("`idx` must be of integral type")
    if not all(e.ndim == 1 for e in idx):
        raise ValueError("`idx` must contain only 1-D arrays")
    # `equal_nan=True` lets lengths that are both NaN compare as equal.
    lengths = [e.shape[0] for e in idx]
    if not np.allclose(lengths[0], lengths, equal_nan=True):
        raise ValueError("`idx` must contain the same number of points")
    if a.ndim != len(idx):
        raise ValueError("`idx` and `a` must share the same dimensionality")

    # `da.atop` was renamed to `da.blockwise`; prefer the new name and fall
    # back to the old one on Dask releases that predate the rename.
    blockwise = getattr(da, "blockwise", None) or da.atop

    # Every coordinate array is laid out along output index 0 (the points).
    point_args = []
    for coord in idx:
        point_args.extend((coord, (0,)))

    # Indices 1..ndim belong only to `a` and are absent from the output, so
    # they are contracted; `concatenate=True` hands the kernel a single array.
    return blockwise(
        _take_points,
        (0,),
        a, tuple(range(1, 1 + a.ndim)),
        *point_args,
        dtype=a.dtype,
        concatenate=True
    )
My guess is that there have already been changes in Dask that add support for this in a reasonable way. I would look there instead of using this Gist.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The memory usage is quite substantial unless the chunks are small.