Skip to content

Instantly share code, notes, and snippets.

@GenevieveBuckley
Last active August 31, 2020 09:39
Show Gist options
  • Save GenevieveBuckley/ecb81a13c93139ae8dc9cc1e50f1e2ca to your computer and use it in GitHub Desktop.
Save GenevieveBuckley/ecb81a13c93139ae8dc9cc1e50f1e2ca to your computer and use it in GitHub Desktop.
GPU dask-image support for ndmorph
# -*- coding: utf-8 -*-
__author__ = """John Kirkham"""
__email__ = "kirkhamj@janelia.hhmi.org"
import scipy.ndimage
from . import _utils
from . import _ops
from ..dispatch._dispatch_ndmorph import (
dispatch_binary_dilation,
dispatch_binary_erosion)
# Public API of this module: the four binary morphology wrappers
# (closing, dilation, erosion, opening) defined in this file.
__all__ = [
    "binary_closing",
    "binary_dilation",
    "binary_erosion",
    "binary_opening",
]
@_utils._update_wrapper(scipy.ndimage.binary_closing)
def binary_closing(image,
                   structure=None,
                   iterations=1,
                   origin=0,
                   mask=None,
                   border_value=0,
                   brute_force=False):
    """Binary closing: a dilation followed by an erosion.

    Parameters
    ----------
    image : array-like
        Input image; any non-zero value is treated as foreground.
    structure : array-like, optional
        Structuring element, resolved through ``_utils._get_structure``.
    iterations : int, optional
        Number of times each of the dilation and erosion steps is repeated.
    origin : int or sequence of int, optional
        Placement of the structuring element.
    mask : array-like, optional
        Forwarded unchanged to both the dilation and the erosion step.
    border_value : int, optional
        Forwarded unchanged to both steps.
    brute_force : bool, optional
        Forwarded unchanged to both steps.

    ``mask``, ``border_value`` and ``brute_force`` are appended after
    ``origin`` with the same defaults that ``binary_dilation`` and
    ``binary_erosion`` use. Existing positional and keyword callers
    therefore behave exactly as before. The same forwarding is done by
    ``scipy.ndimage.binary_closing``.

    Returns
    -------
    The closed image, as returned by ``binary_erosion``.
    """
    image = (image != 0)

    # Resolve the shared parameters once through the ``_utils`` helpers so
    # that both passes receive identical values.
    structure = _utils._get_structure(image, structure)
    iterations = _utils._get_iterations(iterations)
    origin = _utils._get_origin(structure.shape, origin)

    shared_kwargs = dict(
        structure=structure,
        iterations=iterations,
        origin=origin,
        mask=mask,
        border_value=border_value,
        brute_force=brute_force,
    )

    # Closing = dilation first, then erosion with the same parameters.
    result = binary_dilation(image, **shared_kwargs)
    result = binary_erosion(result, **shared_kwargs)
    return result
@_utils._update_wrapper(scipy.ndimage.binary_dilation)
def binary_dilation(image,
                    structure=None,
                    iterations=1,
                    mask=None,
                    border_value=0,
                    origin=0,
                    brute_force=False):
    """Binary dilation dispatched on the array backend of ``image``.

    ``dispatch_binary_dilation`` selects the dilation function to use:
    ``scipy.ndimage.binary_dilation`` for NumPy-backed input, and
    ``cupyx.scipy.ndimage.binary_dilation`` for CuPy-backed input once the
    lazy CuPy registration has run. The selected function and all
    remaining arguments are handed to ``_ops._binary_op``, which does the
    actual work.
    """
    # The border value is checked before the backend is looked up.
    checked_border = _utils._get_border_value(border_value)
    backend_dilation = dispatch_binary_dilation(image)

    return _ops._binary_op(
        backend_dilation,
        image,
        structure=structure,
        iterations=iterations,
        mask=mask,
        origin=origin,
        brute_force=brute_force,
        border_value=checked_border,
    )
@_utils._update_wrapper(scipy.ndimage.binary_erosion)
def binary_erosion(image,
                   structure=None,
                   iterations=1,
                   mask=None,
                   border_value=0,
                   origin=0,
                   brute_force=False):
    """Binary erosion dispatched on the array backend of ``image``.

    ``dispatch_binary_erosion`` selects the erosion function to use:
    ``scipy.ndimage.binary_erosion`` for NumPy-backed input, and
    ``cupyx.scipy.ndimage.binary_erosion`` for CuPy-backed input once the
    lazy CuPy registration has run. The selected function and all
    remaining arguments are handed to ``_ops._binary_op``, which does the
    actual work.
    """
    # The border value is checked before the backend is looked up.
    checked_border = _utils._get_border_value(border_value)
    backend_erosion = dispatch_binary_erosion(image)

    return _ops._binary_op(
        backend_erosion,
        image,
        structure=structure,
        iterations=iterations,
        mask=mask,
        origin=origin,
        brute_force=brute_force,
        border_value=checked_border,
    )
@_utils._update_wrapper(scipy.ndimage.binary_opening)
def binary_opening(image,
                   structure=None,
                   iterations=1,
                   origin=0,
                   mask=None,
                   border_value=0,
                   brute_force=False):
    """Binary opening: an erosion followed by a dilation.

    Parameters
    ----------
    image : array-like
        Input image; any non-zero value is treated as foreground.
    structure : array-like, optional
        Structuring element, resolved through ``_utils._get_structure``.
    iterations : int, optional
        Number of times each of the erosion and dilation steps is repeated.
    origin : int or sequence of int, optional
        Placement of the structuring element.
    mask : array-like, optional
        Forwarded unchanged to both the erosion and the dilation step.
    border_value : int, optional
        Forwarded unchanged to both steps.
    brute_force : bool, optional
        Forwarded unchanged to both steps.

    ``mask``, ``border_value`` and ``brute_force`` are appended after
    ``origin`` with the same defaults that ``binary_erosion`` and
    ``binary_dilation`` use. Existing positional and keyword callers
    therefore behave exactly as before. The same forwarding is done by
    ``scipy.ndimage.binary_opening``.

    Returns
    -------
    The opened image, as returned by ``binary_dilation``.
    """
    image = (image != 0)

    # Resolve the shared parameters once through the ``_utils`` helpers so
    # that both passes receive identical values.
    structure = _utils._get_structure(image, structure)
    iterations = _utils._get_iterations(iterations)
    origin = _utils._get_origin(structure.shape, origin)

    shared_kwargs = dict(
        structure=structure,
        iterations=iterations,
        origin=origin,
        mask=mask,
        border_value=border_value,
        brute_force=brute_force,
    )

    # Opening = erosion first, then dilation with the same parameters.
    result = binary_erosion(image, **shared_kwargs)
    result = binary_dilation(result, **shared_kwargs)
    return result
# -*- coding: utf-8 -*-
import numpy as np
import scipy.ndimage
from ._dispatcher import Dispatcher
# One Dispatcher per morphology operation; backend-specific implementations
# are registered on them further down (NumPy eagerly, CuPy lazily).
dispatch_binary_dilation = Dispatcher(name="dispatch_binary_dilation")
dispatch_binary_erosion = Dispatcher(name="dispatch_binary_erosion")

# Names exported by this dispatch module.
__all__ = [
    "dispatch_binary_dilation",
    "dispatch_binary_erosion",
]
# ================== binary_dilation ==================
@dispatch_binary_dilation.register(np.ndarray)
def numpy_binary_dilation(*args, **kwargs):
    """Select SciPy's ``binary_dilation`` for NumPy-backed arrays.

    Whatever arguments the dispatcher passes in are ignored; only the
    implementation function itself is returned.
    """
    chosen_impl = scipy.ndimage.binary_dilation
    return chosen_impl
@dispatch_binary_dilation.register_lazy("cupy")
def register_cupy_binary_dilation():
    """Register the CuPy handler for ``dispatch_binary_dilation``.

    Both ``cupy`` and ``cupyx`` are imported inside this function rather
    than at module level. CuPy is therefore only imported when the
    dispatcher runs this lazy registration for the "cupy" package.
    """
    import cupy
    from cupyx.scipy import ndimage as cupyx_ndimage

    @dispatch_binary_dilation.register(cupy.ndarray)
    def cupy_binary_dilation(*args, **kwargs):
        # Arguments are ignored; hand back the GPU implementation.
        return cupyx_ndimage.binary_dilation
# ================== binary_erosion ==================
@dispatch_binary_erosion.register(np.ndarray)
def numpy_binary_erosion(*args, **kwargs):
    """Select SciPy's ``binary_erosion`` for NumPy-backed arrays.

    Whatever arguments the dispatcher passes in are ignored; only the
    implementation function itself is returned.
    """
    chosen_impl = scipy.ndimage.binary_erosion
    return chosen_impl
@dispatch_binary_erosion.register_lazy("cupy")
def register_cupy_binary_erosion():
    """Register the CuPy handler for ``dispatch_binary_erosion``.

    Both ``cupy`` and ``cupyx`` are imported inside this function rather
    than at module level. CuPy is therefore only imported when the
    dispatcher runs this lazy registration for the "cupy" package.
    """
    import cupy
    from cupyx.scipy import ndimage as cupyx_ndimage

    @dispatch_binary_erosion.register(cupy.ndarray)
    def cupy_binary_erosion(*args, **kwargs):
        # Arguments are ignored; hand back the GPU implementation.
        return cupyx_ndimage.binary_erosion
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dask.array as da
import numpy as np
import pytest
from dask_image import ndmorph
cupy = pytest.importorskip("cupy", minversion="7.7.0")
@pytest.fixture
def array():
    """Return a 10x10 float32 CuPy-backed dask array with 5x5 chunks.

    Values run 0..99 in row-major order, so only the first element is zero.
    """
    shape = (10, 10)
    n_elements = int(np.prod(shape))
    gpu_values = cupy.arange(n_elements, dtype=cupy.float32).reshape(shape)
    return da.from_array(gpu_values, chunks=5)
@pytest.mark.cupy
@pytest.mark.parametrize("func", [
    ndmorph.binary_closing,
    ndmorph.binary_dilation,
    ndmorph.binary_erosion,
    ndmorph.binary_opening,
])
def test_cupy_ndmorph(array, func):
    """Test binary morphology functions with cupy input arrays."""
    result = func(array).compute()

    # Computing a CuPy-backed dask array should yield a CuPy array: the
    # dispatcher is expected to pick the cupyx implementation rather than
    # silently falling back to (or converting through) NumPy.
    assert isinstance(result, cupy.ndarray)
    # Morphology operations preserve the input shape.
    assert result.shape == array.shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment