Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
slices_from_chunks_overlap
# Proposed slices_from_chunks_overlap function
# Modified from slices_from_chunks in dask.array.core
from itertools import product
from dask.array.slicing import cached_cumsum
def slices_from_chunks_overlap(chunks, array_shape, depth=1):
    """Translate a chunks tuple into overlapping slices, in product order.

    Along every axis, each chunk's slice is widened by ``depth`` elements on
    each side that borders a neighbouring chunk. Sides lying on the array
    boundary are not extended, and every slice is clamped to
    ``[0, array_shape[axis]]``, so no slice ever has a negative start (which
    would wrap around) or runs past the end of the array.

    Parameters
    ----------
    chunks : tuple of tuple of int
        The chunks of the corresponding dask array: one tuple of chunk
        lengths per axis.
    array_shape : tuple of int
        Shape of the corresponding dask array; one entry per axis in
        ``chunks``.
    depth : int or sequence of int, optional
        The number of elements to overlap with neighbouring chunks. A single
        int applies to every axis; a sequence gives one depth per axis.
        Must be non-negative. Default 1.

    Returns
    -------
    list of tuple of slice
        One tuple of slices (one slice per axis) for every chunk, ordered
        as ``itertools.product`` over the per-axis chunk slices.

    Raises
    ------
    ValueError
        If ``array_shape`` or a per-axis ``depth`` does not have one entry
        per axis of ``chunks``, or if any depth is negative.

    Example
    -------
    >>> slices_from_chunks_overlap(((4,), (7, 7)), (4, 14), depth=1) # doctest: +NORMALIZE_WHITESPACE
    [(slice(0, 4, None), slice(0, 8, None)),
     (slice(0, 4, None), slice(6, 14, None))]
    """
    ndim = len(chunks)
    if len(array_shape) != ndim:
        raise ValueError(
            f"array_shape has {len(array_shape)} dimensions "
            f"but chunks has {ndim}"
        )

    try:
        depths = tuple(depth)
    except TypeError:
        # A scalar depth applies uniformly to every axis.
        depths = (depth,) * ndim
    if len(depths) != ndim:
        raise ValueError(
            f"depth has {len(depths)} entries but chunks has {ndim} dimensions"
        )
    if any(d < 0 for d in depths):
        raise ValueError(f"depth must be non-negative, got {depth!r}")

    per_axis_slices = []
    # Pair each axis's chunk lengths with THAT axis's length and depth; the
    # bounds must never be taken from a different axis.
    for axis_chunks, maxshape, axis_depth in zip(chunks, array_shape, depths):
        axis_slices = []
        start = 0  # running cumulative sum of the chunk lengths seen so far
        for size in axis_chunks:
            stop = start + size
            # Extending by depth and clamping to [0, maxshape] grows only the
            # sides that have a neighbour; edge sides are left untouched.
            axis_slices.append(
                slice(max(start - axis_depth, 0), min(stop + axis_depth, maxshape))
            )
            start = stop
        per_axis_slices.append(axis_slices)
    return list(product(*per_axis_slices))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment