Skip to content

Instantly share code, notes, and snippets.

@will-moore
Created November 13, 2020 16:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save will-moore/819ade3c4e46864d9405555a1bf4933c to your computer and use it in GitHub Desktop.
Save will-moore/819ade3c4e46864d9405555a1bf4933c to your computer and use it in GitHub Desktop.
Performance test for concatenating tiles into a multi-dimensional pyramid with dask.array.stack vs. map_blocks
"""
Display a 5D dask multiscale pyramid
"""
from dask import array as da
from dask import delayed
import datetime
import numpy as np
import napari
from math import ceil
# Selects the assembly strategy being benchmarked:
#   True  -> stitch_planes / stack_planes build arrays with da.map_blocks
#            (via read_data / read_data_with_extra_dimension);
#   False -> they simply call da.concatenate / da.stack.
USE_MAP_BLOCKS = True
def stitch_planes(planes, axis=0):
    """Join dask arrays end-to-end along ``axis`` into one lazy array.

    With ``USE_MAP_BLOCKS`` False this is plain ``da.concatenate``. Otherwise
    the result is built with ``da.map_blocks``: it has one chunk per input
    piece along ``axis`` and ``read_data`` fills chunk ``i`` with
    ``planes[i]``.

    NOTE(review): ``planes`` reaches ``read_data`` through a ``map_blocks``
    keyword argument. Dask appears to compute collections passed this way in
    full and hand them to every block, which would make each output block
    depend on every input piece. That is a likely contributor to the
    map_blocks slowdown recorded in this script's example output; confirm
    against the dask docs.

    Parameters
    ----------
    planes : list of dask arrays
        Pieces to join. They must agree on every axis except ``axis``. Their
        sizes along ``axis`` may differ (e.g. edge tiles clipped to the plane
        boundary by ``get_lazy_plane``).
    axis : int
        Axis to join along (negative values index from the end).

    Returns
    -------
    dask array whose size along ``axis`` is the sum of the pieces' sizes.

    Raises
    ------
    ValueError
        If ``planes`` is empty (both code paths).
    """
    if not USE_MAP_BLOCKS:
        return da.concatenate(planes, axis=axis)
    if not planes:
        # Match da.concatenate instead of failing with IndexError on planes[0].
        raise ValueError("Need array(s) to stitch")
    arrayfunc = np.asanyarray
    first = planes[0]
    # One chunk per piece. The sizes along `axis` must come from each piece,
    # not from planes[0]: the last tile of a row/column is clipped to the plane
    # edge, so assuming uniform sizes would declare the wrong stitched shape
    # and a chunk that its block's data does not fill.
    chunk_sizes = [(size,) for size in first.shape]
    chunk_sizes[axis] = tuple(plane.shape[axis] for plane in planes)
    stitched_shape = tuple(sum(sizes) for sizes in chunk_sizes)
    chunks = da.core.normalize_chunks(tuple(chunk_sizes), stitched_shape)
    return da.map_blocks(
        read_data,
        chunks=chunks,
        planes=planes,
        axis=axis,
        arrayfunc=arrayfunc,
        meta=arrayfunc([]).astype(first.dtype),  # meta overwrites `dtype` argument
    )
def read_data(planes, axis, block_info=None, **kwargs):
    """Return the input piece that backs the current ``map_blocks`` block.

    ``stitch_planes`` gives each piece its own chunk along ``axis``, so the
    output block's chunk index along that axis is the index of the piece.

    Parameters
    ----------
    planes : sequence
        Pieces supplied by ``stitch_planes`` through ``map_blocks``.
    axis : int
        Axis the pieces were joined along.
    block_info : dict
        Supplied by ``map_blocks``; only
        ``block_info[None]['chunk-location']`` is read.
    **kwargs
        Ignored (e.g. ``arrayfunc``).
    """
    location = block_info[None]['chunk-location']
    piece_index = location[axis]
    return planes[piece_index]
def stack_planes(planes):
    """Stack equally-shaped dask arrays along a new leading axis.

    With ``USE_MAP_BLOCKS`` False this is plain ``da.stack``. Otherwise the
    result is built with ``da.map_blocks``. It has one size-1 chunk per input
    array along the new axis 0, and each chunk spans the whole of the other
    axes. ``read_data_with_extra_dimension`` fills chunk ``i`` with
    ``planes[i]``.

    NOTE(review): as in ``stitch_planes``, ``planes`` is passed as a
    ``map_blocks`` keyword argument. Dask appears to compute such collections
    in full for every block; confirm against the dask docs.

    Parameters
    ----------
    planes : list of dask arrays
        Arrays to stack; all must have the same shape.

    Returns
    -------
    dask array of shape ``(len(planes),) + planes[0].shape``.

    Raises
    ------
    ValueError
        If ``planes`` is empty or its arrays differ in shape. Without these
        checks the map_blocks path would hit an IndexError on an empty list,
        or silently declare a shape that some blocks do not match.
    """
    if not USE_MAP_BLOCKS:
        return da.stack(planes)
    if not planes:
        raise ValueError("Need array(s) to stack")
    plane_shape = planes[0].shape
    if any(plane.shape != plane_shape for plane in planes):
        raise ValueError("stack_planes: all arrays must have the same shape")
    arrayfunc = np.asanyarray
    shape = (len(planes),) + plane_shape
    dtype = planes[0].dtype
    # Size-1 chunks on axis 0 so that output block i corresponds to planes[i].
    chunks = da.core.normalize_chunks((1,) + plane_shape, shape)
    return da.map_blocks(
        read_data_with_extra_dimension,
        chunks=chunks,
        planes=planes,
        arrayfunc=arrayfunc,
        meta=arrayfunc([]).astype(dtype),  # meta overwrites `dtype` argument
    )
def read_data_with_extra_dimension(planes, block_info=None, **kwargs):
    """Return the current block's plane with a length-1 leading axis added.

    ``stack_planes`` requests size-1 chunks along axis 0, so the start of the
    block's axis-0 range is the index of the plane it should hold.

    Parameters
    ----------
    planes : sequence of array-like
        Planes supplied by ``stack_planes`` through ``map_blocks``.
    block_info : dict
        Supplied by ``map_blocks``; only
        ``block_info[None]['array-location']`` is read.
    **kwargs
        Ignored (e.g. ``arrayfunc``).
    """
    axis0_range = block_info[None]['array-location'][0]
    start, _stop = axis0_range
    plane = planes[start]
    return np.expand_dims(plane, 0)
def get_tile(tile_name):
    """Synthesise the int16 tile described by a comma-separated tile name.

    ``tile_name`` is ``"level,t,c,z,y,x,w,h"`` (all integers), as built by
    ``get_lazy_plane``. The result has shape ``(h, w)``. Pixel values depend
    on level, t, c, z and the pixel's row/column inside the tile, but not on
    the tile's ``y``/``x`` offset:

    * odd ``c``:  ``col + 2*t + 2*z``
    * even ``c``: ``(row + (level % 2) * col) // 2``

    Raises
    ------
    ValueError
        If the name does not hold exactly eight integer fields.
    """
    print('get_tile level, t, c, z, y, x, w, h', tile_name)
    fields = [int(part) for part in tile_name.split(",")]
    level, t, c, z, _y, _x, w, h = fields

    def pattern(row, col):
        # np.fromfunction calls this once with the (row, col) index arrays
        # of an (h, w) grid.
        if c % 2 != 1:
            return (row + (level % 2) * col) // 2
        return col + 2 * t + 2 * z

    return np.fromfunction(pattern, (h, w), dtype=np.int16)
# Delayed wrapper around get_tile: lazy_reader(tile_name) records a task
# instead of generating the tile immediately; get_lazy_plane wraps each such
# call in da.from_delayed.
lazy_reader = delayed(get_tile)
def get_lazy_plane(level, t, c, z, plane_y, plane_x, tile_shape):
    """Assemble one lazy 2D plane from a grid of delayed synthetic tiles.

    The plane is split into ``ceil(plane_y / tile_h)`` rows by
    ``ceil(plane_x / tile_w)`` columns of tiles. Tiles in the last row/column
    are clipped to the plane edge. Each tile is an int16
    ``da.from_delayed`` wrapper around ``lazy_reader``. The tiles of a row are
    joined with ``stitch_planes(..., axis=1)``, and the rows with
    ``stitch_planes(..., axis=0)``.

    Parameters
    ----------
    level, t, c, z : int
        Pyramid level and plane coordinates, encoded into every tile name.
    plane_y, plane_x : int
        Height and width of the plane in pixels.
    tile_shape : tuple
        Unpacked as ``(tile_w, tile_h)``.
    """
    print('get_lazy_plane: level, t, c, z, plane_y, plane_x',
          level, t, c, z, plane_y, plane_x)
    tile_w, tile_h = tile_shape
    n_rows = ceil(plane_y / tile_h)
    n_cols = ceil(plane_x / tile_w)
    print('rows', n_rows, 'cols', n_cols)

    def make_tile(y, x):
        # Clip edge tiles so the grid covers the plane exactly.
        w = min(tile_w, plane_x - x)
        h = min(tile_h, plane_y - y)
        name = ",".join(str(value) for value in (level, t, c, z, y, x, w, h))
        return da.from_delayed(lazy_reader(name), shape=(h, w), dtype=np.int16)

    stitched_rows = []
    for row in range(n_rows):
        y = row * tile_h
        tiles = [make_tile(y, col * tile_w) for col in range(n_cols)]
        row_array = stitch_planes(tiles, axis=1)
        print('lazy_row.shape', row_array.shape)
        stitched_rows.append(row_array)
    return stitch_planes(stitched_rows, axis=0)
def get_pyramid_lazy(shape, tile_shape, levels):
    """Build a multiscale list of lazy 5D dask arrays from synthetic tiles.

    Every 2D plane comes from ``get_lazy_plane`` (tiles synthesised by
    ``get_tile``). Planes are stacked over z, then c, then t with
    ``stack_planes``, so each level is ordered (t, c, z, y, x). Level 0 uses
    the full ``size_y`` by ``size_x`` plane. Each following level floor-halves
    the previous level's plane height and width.

    Parameters
    ----------
    shape : tuple
        ``(size_t, size_c, size_z, size_y, size_x)`` of level 0.
    tile_shape : tuple
        Passed through unchanged to ``get_lazy_plane``.
    levels : int
        Number of pyramid levels to build.

    Returns
    -------
    list of ``levels`` dask arrays, largest first. Each level's shape is
    printed before returning.
    """
    size_t, size_c, size_z, size_y, size_x = shape
    plane_y, plane_x = size_y, size_x
    pyramid = []
    for level in range(levels):
        print('level', level)
        per_timepoint = []
        for t in range(size_t):
            per_channel = [
                stack_planes([
                    get_lazy_plane(level, t, c, z, plane_y, plane_x, tile_shape)
                    for z in range(size_z)
                ])
                for c in range(size_c)
            ]
            per_timepoint.append(stack_planes(per_channel))
        pyramid.append(stack_planes(per_timepoint))
        # Next level: halve both plane dimensions.
        plane_y //= 2
        plane_x //= 2
    print('pyramid...')
    for level_array in pyramid:
        print(level_array.shape)
    return pyramid
# Benchmark configuration: (t, c, z, y, x) size of the full-resolution image,
# tile size (unpacked by get_lazy_plane as (tile_w, tile_h)), and the number
# of pyramid levels.
shape = (10, 2, 5, 3000, 5000)
tile_shape = (256, 256)
levels = 4

# Opt-in: also time a full .compute() of every level, smallest level first.
# This is the harness that produced the example output below. It is off by
# default because computing the largest level took tens of seconds there.
RUN_COMPUTE_BENCHMARK = False

start = datetime.datetime.now()
pyramid = get_pyramid_lazy(shape, tile_shape, levels)
lazy_timer = (datetime.datetime.now() - start).total_seconds()
print('lazy pyramid timer', lazy_timer)

if RUN_COMPUTE_BENCHMARK:
    times = []
    for level in range(levels, 0, -1):
        start = datetime.datetime.now()
        pyramid[level - 1].compute()
        timer = (datetime.datetime.now() - start).total_seconds()
        times.append(timer)
        print(f'Level {level - 1} compute timer', timer)
    print('shape', shape, 'tile_shape', tile_shape)
    print('lazy_pyramid creation', lazy_timer)
    print('compute times', times)

# Example output
# shape (10, 2, 5, 3000, 5000) tile_shape (256, 256)
# With USE_MAP_BLOCKS = False
# lazy_pyramid creation 8.401882
# compute times [0.403972, 1.156574, 5.493262, 26.310839]
# With USE_MAP_BLOCKS = True (should be faster, but is slower!)
# lazy_pyramid creation 14.68406
# compute times [1.341739, 2.146021, 8.515899, 36.635117]

# napari.gui_qt() is deprecated in favour of napari.run() and is gone from
# newer napari releases; keep the old context manager only as a fallback for
# versions that predate napari.run().
if hasattr(napari, 'run'):
    viewer = napari.Viewer()
    viewer.add_image(pyramid, channel_axis=1)
    napari.run()
else:
    with napari.gui_qt():
        viewer = napari.view_image(pyramid, channel_axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment