Skip to content

Instantly share code, notes, and snippets.

@m-albert
Created October 9, 2020 10:59
Show Gist options
  • Save m-albert/a2aa957bf0e96b665d8d53c723506e41 to your computer and use it in GitHub Desktop.
Save m-albert/a2aa957bf0e96b665d8d53c723506e41 to your computer and use it in GitHub Desktop.
Affine transformation for dask arrays: Wrapper around `ndimage.affine_transform`
#!/usr/bin/env python
"""
Affine transformation for dask arrays: Wrapper around `ndimage.affine_transform`
"""
__author__ = "Marvin Albert"
__email__ = "marvin.albert@gmail.com"
import numpy as np
import dask.array as da
from scipy import ndimage
def affine_transform_dask(
        input,
        matrix,
        offset=0.0,
        output_shape=None,
        output_chunks=None,
        **kwargs
):
    """
    Wraps `ndimage.affine_transform` for dask arrays.

    The output is built chunk by chunk. For every output chunk the corners
    of the chunk are mapped into input coordinates, and only the bounding
    slice of the input (plus a margin for the interpolation footprint) is
    passed on to `ndimage.affine_transform`.

    The API mirrors `ndimage.affine_transform`, except for `output_chunks`.

    To do:
    - optionally use cupyx.scipy.ndimage.affine_transform

    NOTE(review): for `order > 1` scipy prefilters each cropped slice on its
    own, so results can presumably differ slightly from a non-chunked call
    near crop borders -- confirm the tolerance is acceptable.
    NOTE(review): the crop is clipped to the input extent, so boundary modes
    that read data from the opposite side of the array (e.g. 'wrap') are
    presumably not reproduced exactly -- TODO confirm.

    :param input: N-D numpy or dask array
    :param matrix: inverse coordinate transformation (output -> input).
        Accepted shapes: (N, N), (N,) for a diagonal matrix, or homogeneous
        (N + 1, N + 1) / (N, N + 1), in which case `offset` is taken from
        the last column and the `offset` argument is ignored.
    :param offset: scalar or length-N sequence added after applying `matrix`
    :param output_shape: shape of the result; defaults to `input.shape`
    :param output_chunks: chunking of the result, in any form accepted by
        `dask.array.zeros(chunks=...)`; defaults to "auto"
    :param kwargs: forwarded unchanged to `ndimage.affine_transform`
        (e.g. `order`, `mode`, `cval`, `prefilter`)
    :return: dask array of shape `output_shape` and dtype `input.dtype`
    """
    n_dims = input.ndim

    matrix = np.asarray(matrix, dtype=np.float64)
    # Homogeneous matrices carry the offset in their last column.
    if (matrix.ndim == 2 and matrix.shape[1] == n_dims + 1
            and matrix.shape[0] in (n_dims, n_dims + 1)):
        offset = matrix[:n_dims, n_dims]
        matrix = matrix[:n_dims, :n_dims]
    offset = np.asarray(offset, dtype=np.float64)

    # A spline of order n reads at most n // 2 + 1 pixels on either side of
    # floor(coordinate). Never go below the 2 pixels used historically.
    order = kwargs.get('order', 3)
    margin = max(2, order // 2 + 1)

    def resample_chunk(chunk, matrix, offset, kwargs, block_info=None):
        """Compute one output chunk from the relevant crop of `input`."""
        N = chunk.ndim
        chunk_shape = chunk.shape
        chunk_offset = np.array(
            [loc[0] for loc in block_info[0]['array-location']])

        # All 2**N corners of the output chunk, in output pixel coordinates
        # (the upper corners lie one past the last pixel of the chunk).
        corners = np.array(list(np.ndindex((2,) * N)))
        chunk_edges = corners * np.array(chunk_shape) + chunk_offset

        # Map the corners into input coordinates.
        if matrix.ndim == 1:
            rel_input_edges = matrix * chunk_edges + offset
            mapped_chunk_offset = matrix * chunk_offset
        else:
            rel_input_edges = np.dot(matrix, chunk_edges.T).T + offset
            mapped_chunk_offset = np.dot(matrix, chunk_offset)

        # Bounding box of the mapped corners, padded by the interpolation
        # margin. `upper` is an exclusive stop index.
        lower = np.floor(np.min(rel_input_edges, 0)) - margin
        upper = np.ceil(np.max(rel_input_edges, 0)) + margin + 1

        # Clip to the input extent while keeping at least one pixel per
        # axis, so that chunks mapping outside the input never produce an
        # empty crop.
        extent = np.array(input.shape)
        start = np.clip(lower, 0, np.maximum(extent - 1, 0))
        stop = np.minimum(np.maximum(upper, start + 1), extent)
        start = start.astype(np.int64)
        stop = stop.astype(np.int64)

        rel_input_slice = tuple(slice(int(lo), int(hi))
                                for lo, hi in zip(start, stop))
        # For a dask input this materialises only the cropped region.
        rel_input = np.asarray(input[rel_input_slice])

        # modify offset to point into cropped input
        # y = Mx + o
        # coordinate substitution:
        # y' = y - y0(min_coord_px)
        # x' = x - x0(chunk_offset)
        # then
        # y' = Mx' + o + Mx0 - y0
        # M' = M
        # o' = o + Mx0 - y0
        offset_prime = offset + mapped_chunk_offset - start

        return ndimage.affine_transform(rel_input,
                                        matrix,
                                        offset_prime,
                                        output_shape=chunk_shape,
                                        **kwargs)

    if output_shape is None:
        output_shape = input.shape
    if output_chunks is None:
        # `da.zeros` rejects `chunks=None`; let dask pick the chunking.
        output_chunks = 'auto'

    # Empty template array whose only purpose is to define the output
    # chunk grid that `resample_chunk` is mapped over.
    transformed = da.zeros(tuple(output_shape),
                           dtype=input.dtype,
                           chunks=output_chunks)
    transformed = transformed.map_blocks(resample_chunk,
                                         dtype=input.dtype,
                                         matrix=matrix,
                                         offset=offset,
                                         kwargs=kwargs,
                                         )
    return transformed
if __name__ == "__main__":
    # Demo: run the same random affine transform once with plain scipy and
    # once through `affine_transform_dask`, time both and show the results
    # side by side.
    from timeit import default_timer as timer
    from matplotlib import pyplot
    import tifffile

    # --- synthetic test image: smooth random blobs, uint16 in [0, 1000] ---
    N = 3
    a = 100
    np.random.seed(0)
    coarse = np.random.random([a // 20] * N)
    smooth = ndimage.zoom(coarse, 20, order=1)
    im = (smooth / smooth.max() * 1000).astype(np.uint16)

    # --- dask view of the same image ---
    chunksize = [32] * N
    im_dask = da.from_array(im, chunks=chunksize)

    # --- random transformation close to identity ---
    matrix = np.eye(N) + (np.random.random((N, N)) - 0.5) / 5.
    offset = (np.random.random(N) - 0.5) / 5. * np.array(im.shape)
    print('matrix\n', matrix)
    print('offset\n', offset)

    # --- resampling options (use `im.shape` to resample the full extent) ---
    output_shape = [a // 4] * N
    output_chunks = [32] * N
    interp_order = 3
    # 'threads' is the multi-threaded alternative
    scheduler = 'single-threaded'

    # --- reference: plain scipy ---
    t_start = timer()
    im_t_nodask = ndimage.affine_transform(
        im, matrix, offset,
        output_shape=output_shape,
        order=interp_order,
    )
    t_stop = timer()
    print('Timing without dask: %s seconds' % (t_stop - t_start))

    # --- chunked: build the dask graph and compute it ---
    t_start = timer()
    im_t_dask = affine_transform_dask(
        im_dask, matrix, offset,
        output_shape=output_shape,
        output_chunks=output_chunks,
        order=interp_order,
    )
    im_t_dask_computed = im_t_dask.compute(scheduler=scheduler)
    t_stop = timer()
    print('Timing with dask: %s seconds' % (t_stop - t_start))

    # To inspect the chunk flow, the dask graph can be written out with
    #   im_t_dask.visualize(filename='affine_transformation_dask.png')
    # which needs python-graphviz, see
    # https://docs.dask.org/en/latest/graphviz.html

    # --- show both results stacked along a new leading axis ---
    comparison = np.stack([im_t_nodask, im_t_dask_computed])
    tifffile.imshow(comparison, vmin=0, vmax=1000)
    pyplot.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment