Skip to content

Instantly share code, notes, and snippets.

@d-v-b
Created July 21, 2021 17:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d-v-b/f4cf44e42f4e9d9cfa2d109d2ad26500 to your computer and use it in GitHub Desktop.
Save d-v-b/f4cf44e42f4e9d9cfa2d109d2ad26500 to your computer and use it in GitHub Desktop.
Example of "saving" a dask array, together with a series of increasingly downsampled copies of it, chunk by chunk.
import dask.array as da
from dask import delayed
import time
import numpy as np
import distributed
from distributed import LocalCluster, Client, performance_report
from datetime import datetime
def blocks(self, index, key_array):
    """
    Select whole blocks of a dask array, reusing a precomputed key array.

    Chunkwise iteration function that only needs to exist until a PR fixing
    the bad performance of ``dask.array.blocks`` is put together.  The caller
    supplies ``key_array`` so it can be built once and reused across many
    calls (``save_blocks`` builds it once and then calls this per block).

    Parameters
    ----------
    self : dask.array.Array
        The array whose blocks are selected.  This is a free function, so
        ``self`` is simply the first positional argument.
    index : int, slice, list/ndarray, or a tuple of these
        Block-level index: it is normalized against ``self.numblocks``, so it
        addresses blocks, not elements.  At most one list/ndarray entry is
        allowed, and ``None``/``np.newaxis`` is rejected.
    key_array : numpy.ndarray
        Object array of ``self``'s task keys, e.g.
        ``np.array(self.__dask_keys__(), dtype=object)``.

    Returns
    -------
    dask.array.Array
        A new array named ``"blocks-<token>"`` whose chunks are exactly the
        selected blocks of ``self``; integer indices keep their axis with
        length one.

    Raises
    ------
    ValueError
        If more than one list/ndarray index entry is given, or if any entry
        is ``None``.
    """
    from numbers import Number
    from dask.array.slicing import normalize_index
    from dask.base import tokenize
    from itertools import product
    from dask.highlevelgraph import HighLevelGraph
    from dask.array import Array
    # Allow a single bare (non-tuple) index.
    if not isinstance(index, tuple):
        index = (index,)
    if sum(isinstance(ind, (np.ndarray, list)) for ind in index) > 1:
        raise ValueError("Can only slice with a single list")
    if any(ind is None for ind in index):
        raise ValueError("Slicing with np.newaxis or None is not supported")
    # Normalize against the block grid (self.numblocks), not the element shape.
    index = normalize_index(index, self.numblocks)
    # Turn integer indices into length-1 slices so no axis is dropped.
    index = tuple(slice(k, k + 1) if isinstance(k, Number) else k for k in index)
    # Deterministic layer name derived from the inputs.
    name = "blocks-" + tokenize(self, index)
    # Keys of the selected source blocks.  NOTE: np.array(..., dtype=object)
    # on nested lists of key tuples expands each tuple into a trailing axis;
    # that is why each entry is turned back into a key with .tolist() below.
    new_keys = key_array[index]
    # Chunk sizes along each axis, restricted to the selected blocks.
    chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(self.chunks, index))
    # Every block coordinate of the output array.
    keys = product(*(range(len(c)) for c in chunks))
    # Each output block is an alias for the corresponding source task key.
    layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
    # Register the new layer with self's graph as its dependency.
    graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
    return Array(graph, name, chunks, meta=self)
def downscale(array, depth=1, factor=2):
    """
    Build successively coarser versions of ``array`` by windowed-mean pooling.

    Level ``k`` (1-based) is level ``k - 1`` coarsened by ``factor`` along
    every axis with ``da.mean``.  Trailing elements that do not fill a whole
    window are trimmed (``trim_excess=True``).

    Parameters
    ----------
    array : dask.array.Array
        The full-resolution input.  It is not included in the result.
    depth : int, optional
        Number of coarser levels to produce.  Must be a positive integer.
    factor : int, optional
        Coarsening factor applied to every axis at each level (default 2).

    Returns
    -------
    list of dask.array.Array
        ``depth`` arrays, ordered from finest to coarsest.

    Raises
    ------
    ValueError
        If ``depth`` is not a positive integer (this includes 0, negatives,
        non-integral floats and NaN).
    """
    # Validate up front so no graph is built for an invalid depth.
    if not depth >= 1 or depth % 1:
        raise ValueError('Invalid depth')
    results = []
    current = array
    for _ in range(int(depth)):
        # Each level is derived from the previous (already coarsened) level.
        current = da.coarsen(
            da.mean,
            current,
            {axis: factor for axis in range(current.ndim)},
            trim_excess=True,
        )
        results.append(current)
    return results
def data_sink(array, duration=0.0):
    """
    Pretend to store ``array`` by blocking for ``duration`` seconds.

    The array contents are ignored and the function always returns ``None``.
    """
    _ = array  # intentionally unused: stand-in for a real storage write
    time.sleep(duration)
def data_source(array, duration=0.0):
    """
    Return ``array`` unchanged after blocking for ``duration`` seconds.

    The script applies it with ``map_blocks`` to mark each input chunk as
    produced by a (simulated) source.
    """
    delay_seconds = duration
    time.sleep(delay_seconds)
    passthrough = array
    return passthrough
def save_blocks(array, duration=0.0):
    """
    Schedule a (simulated) write of every block of ``array``.

    Parameters
    ----------
    array : dask.array.Array
        Array whose blocks are "saved" one by one.
    duration : float, optional
        Seconds each ``data_sink`` call sleeps.

    Returns
    -------
    list of dask.delayed.Delayed
        One delayed ``data_sink`` call per block, ordered over the block grid
        with the last axis varying fastest (``np.ndindex`` order).
    """
    # Build the key lookup once and share it across every per-block selection.
    task_keys = np.array(array.__dask_keys__(), dtype=object)
    grid_shape = tuple(len(axis_chunks) for axis_chunks in array.chunks)
    return [
        delayed(data_sink)(blocks(array, block_index, task_keys), duration)
        for block_index in np.ndindex(grid_shape)
    ]
def ensure_minimum_chunksize(array, chunksize):
    """
    Rechunk ``array`` so its chunk size along each axis is at least ``chunksize``.

    Only ``array.chunksize`` (the largest chunk along each axis) is compared
    against the minimum.  Axes whose largest chunk already meets it keep that
    size.  When no axis needs enlarging, ``array`` is returned unchanged
    rather than being rechunked.  Rechunking to a uniform size would
    otherwise regularize an irregular chunk layout for no reason.

    Parameters
    ----------
    array : dask.array.Array
        Array to (possibly) rechunk.
    chunksize : int or sequence of int
        Minimum chunk size: either one value for every axis or one value per
        axis.

    Returns
    -------
    dask.array.Array
        ``array`` itself if no axis is below the minimum, otherwise
        ``array.rechunk(...)`` with the enlarged uniform chunk sizes.
    """
    old_chunks = np.asarray(array.chunksize)
    # Broadcasting lets a scalar minimum apply to every axis.
    minimum = np.broadcast_to(np.asarray(chunksize), old_chunks.shape)
    # Keep the integer dtype of the existing chunk sizes, as the previous
    # mask-assignment implementation did.
    new_chunks = np.maximum(old_chunks, minimum).astype(old_chunks.dtype, copy=False)
    if np.array_equal(new_chunks, old_chunks):
        return array
    return array.rechunk(new_chunks.tolist())
shape = (4096, 4096, 8192)
in_chunks = (256,) * len(shape)
out_chunks = (64,) * len(shape)


def main():
    """
    Build a multiscale pyramid of a synthetic array and "save" it block by block.

    The pyramid is the full-resolution array plus 4 coarser levels.  The work
    runs on a local distributed cluster scaled to 10 single-threaded workers,
    and a performance report HTML file is written to the working directory.
    """
    data = da.zeros(shape, chunks=in_chunks, dtype='uint8').map_blocks(data_source)
    # Full resolution plus 4 coarser levels, each passed through
    # ensure_minimum_chunksize(..., out_chunks) before being scheduled.
    multiscale = [ensure_minimum_chunksize(a, out_chunks) for a in [data, *downscale(data, 4)]]
    storage = [save_blocks(a) for a in multiscale]
    # Colon-free timestamp: ':' is not permitted in Windows file names.
    now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    report_fname = f'multiscale_storage_distributed-{distributed.__version__}_{now}.html'
    with Client(threads_per_worker=1, memory_limit="15GB") as cl:
        with performance_report(filename=report_fname):
            cl.cluster.scale(10)
            print(cl.cluster.dashboard_link)
            cl.compute(delayed(storage), sync=True)


# The guard is required because Client() starts worker processes.  Under the
# "spawn" start method those processes re-import this module, and unguarded
# top-level code would try to start a new cluster in each of them.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment