Created
July 21, 2021 17:49
-
-
Save d-v-b/f4cf44e42f4e9d9cfa2d109d2ad26500 to your computer and use it in GitHub Desktop.
Example of "saving" an increasingly downsampled dask array chunk-by-chunk.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask.array as da | |
from dask import delayed | |
import time | |
import numpy as np | |
import distributed | |
from distributed import LocalCluster, Client, performance_report | |
from datetime import datetime | |
def blocks(self, index, key_array):
    """
    Chunkwise iteration function that only needs to exist until I put together a PR
    fixing the bad performance of dask.array.blocks

    Parameters
    ----------
    self : dask.array.Array
        The array whose blocks are selected (named ``self`` because this
        function mirrors the ``Array.blocks`` accessor it replaces).
    index : int, slice, or tuple of these
        Block-wise index into the chunk grid, one entry per dimension.
    key_array : numpy.ndarray of object
        Pre-built array of the dask keys of ``self`` (from
        ``self.__dask_keys__()``), passed in so callers can compute it once
        and reuse it across many calls.

    Returns
    -------
    dask.array.Array
        A new lazy array whose blocks alias the selected blocks of ``self``.
    """
    from numbers import Number
    from dask.array.slicing import normalize_index
    from dask.base import tokenize
    from itertools import product
    from dask.highlevelgraph import HighLevelGraph
    from dask.array import Array

    # Promote a bare scalar/slice to a 1-tuple so per-dimension handling is uniform.
    if not isinstance(index, tuple):
        index = (index,)
    # Same restrictions as dask's Array.blocks: at most one list-like index,
    # and no np.newaxis insertion.
    if sum(isinstance(ind, (np.ndarray, list)) for ind in index) > 1:
        raise ValueError("Can only slice with a single list")
    if any(ind is None for ind in index):
        raise ValueError("Slicing with np.newaxis or None is not supported")
    # Normalize against the chunk grid (numblocks), not the element shape.
    index = normalize_index(index, self.numblocks)
    # Integer indices become length-1 slices so dimensionality is preserved.
    index = tuple(slice(k, k + 1) if isinstance(k, Number) else k for k in index)
    name = "blocks-" + tokenize(self, index)
    # Select the keys of the requested blocks directly from the key array --
    # this sidesteps the expensive graph slicing that dask.array.blocks does.
    new_keys = key_array[index]
    # Chunk sizes of the result: each dimension's chunk tuple restricted to
    # the selected block indices.
    chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(self.chunks, index))
    keys = product(*(range(len(c)) for c in chunks))
    # Each output block is a direct alias for the corresponding input block.
    layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
    graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
    return Array(graph, name, chunks, meta=self)
def downscale(array, depth=1):
    """Build a list of ``depth`` successively 2x-downsampled versions of ``array``.

    Each level halves every axis of the previous one by taking the mean over
    non-overlapping 2 x ... x 2 windows (``trim_excess=True`` drops trailing
    elements that do not fill a window).

    Raises
    ------
    ValueError
        If ``depth`` ends up below 1 (checked after the first coarsening,
        matching the recursive structure).
    """
    halved = da.coarsen(
        da.mean, array, {axis: 2 for axis in range(array.ndim)}, trim_excess=True
    )
    if depth > 1:
        # Recurse from the freshly halved level to collect the remaining ones.
        return [halved, *downscale(halved, depth - 1)]
    if depth == 1:
        return [halved]
    raise ValueError('Invalid depth')
def data_sink(array, duration=0.0):
    """Pretend to persist ``array``: block for ``duration`` seconds, discard it.

    Stand-in for real chunk storage so scheduler behavior can be profiled
    without touching disk. Always returns ``None``.
    """
    time.sleep(duration)
    return None
def data_source(array, duration=0.0):
    """Pretend to load data: block for ``duration`` seconds, then echo ``array`` back unchanged.

    Stand-in for a real reader so each chunk of the synthetic input incurs a
    configurable "load" cost.
    """
    time.sleep(duration)
    return array
def save_blocks(array, duration=0.0):
    """Create one delayed ``data_sink`` task per chunk of ``array``.

    The key array is built once and shared across all ``blocks`` calls, which
    is the whole point of the custom ``blocks`` helper above.

    Returns
    -------
    list of dask.delayed.Delayed
        One no-output "save" task per block of ``array``.
    """
    keys = np.array(array.__dask_keys__(), dtype=object)
    grid_shape = tuple(len(chunk_sizes) for chunk_sizes in array.chunks)
    sink = delayed(data_sink)
    return [
        sink(blocks(array, idx, keys), duration)
        for idx in np.ndindex(grid_shape)
    ]
def ensure_minimum_chunksize(array, chunksize):
    """Rechunk ``array`` so every axis has chunks at least as large as ``chunksize``.

    Axes already meeting the minimum keep their current chunk size; undersized
    axes are raised to the corresponding ``chunksize`` entry. The array always
    goes through ``rechunk`` (a no-op rechunk when nothing was undersized).

    Parameters
    ----------
    array : dask.array.Array
        Array whose ``chunksize`` is inspected.
    chunksize : sequence of int
        Per-axis minimum chunk size, same length as ``array.chunksize``.
    """
    # Elementwise max of current and minimum chunk sizes, axis by axis.
    target = [max(current, minimum) for current, minimum in zip(array.chunksize, chunksize)]
    return array.rechunk(target)
# --- Benchmark: chunkwise "storage" of a 4-level multiscale pyramid. ---
shape = (4096, 4096, 8192)
in_chunks = tuple(256 for _ in shape)   # chunking of the synthetic source array
out_chunks = tuple(64 for _ in shape)   # minimum chunk size enforced before "saving"

# Synthetic uint8 source; map_blocks(data_source) routes every chunk through
# the simulated reader.
data = da.zeros(shape, chunks=in_chunks, dtype='uint8').map_blocks(data_source)

# Full-resolution array plus 4 downsampled levels, each rechunked so no axis
# has chunks smaller than out_chunks.
pyramid = [data, *downscale(data, 4)]
multiscale = [ensure_minimum_chunksize(level, out_chunks) for level in pyramid]

# One delayed "save" task per chunk, per pyramid level.
storage = [save_blocks(level) for level in multiscale]

# Timestamp without sub-second precision, made filename-safe.
now = str(datetime.now()).replace(' ', '_').split('.')[0]
report_fname = f'multiscale_storage_distributed-{distributed.__version__}_{now}.html'

with Client(threads_per_worker=1, memory_limit="15GB") as cl:
    with performance_report(filename=report_fname):
        cl.cluster.scale(10)  # 10 single-threaded workers
        print(cl.cluster.dashboard_link)
        cl.compute(delayed(storage), sync=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment