@jakirkham
Forked from wwarriner/dask_image_watershed.ipynb
Created November 12, 2019 21:17
First attempt at distributed watershed
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"cells": [
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Distributed Watershed\n",
"\n",
"Implementation of a distributed watershed function using Dask, inspired by Juan Nunez-Iglesias's work in [dask/dask-image#99](https://github.com/dask/dask-image/pull/99). The method uses a two-pass watershed: the first pass generates marker information and shares it across chunk boundaries, and the second pass propagates that information.\n",
"\n",
"This implementation differs from v1: whereas v1 shares the first-pass marker information directly, this version shares it through a specially labeled chunk-boundary basin."
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Base Imports"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"import dask\n",
"import dask.array as da\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
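"def compute_mem_mb(shape):\n",
"    '''Returns the memory footprint in MB of an array with the given shape, assuming 8-byte (e.g. float64) elements.'''\n",
"    from functools import reduce\n",
"    from operator import mul\n",
"    count = reduce(mul, shape, 1)\n",
"    return count * 8 / (1024 ** 2)\n",
"\n",
"def display(image):\n",
"    '''Shows an image in the Jupyter notebook; Dask arrays are computed first via np.asarray.'''\n",
"    from skimage import io\n",
"    io.imshow(np.asarray(image))\n",
"    io.show()\n",
"\n",
"def frac_to_zscore(frac):\n",
"    '''Returns the standard-normal z-score below which a fraction `frac` of samples fall.'''\n",
"    from scipy.stats import norm\n",
"    return norm.ppf(frac)\n",
"\n",
"def mask_overlap(chunk, depth, label):\n",
"    '''Returns an array shaped like `chunk` holding `label` in its outer `depth`-wide band and 0 elsewhere.'''\n",
"    overlap = np.zeros_like(chunk)\n",
"    for axis in range(chunk.ndim):\n",
"        lo = [slice(None)] * chunk.ndim\n",
"        hi = [slice(None)] * chunk.ndim\n",
"        lo[axis] = slice(None, depth)\n",
"        hi[axis] = slice(chunk.shape[axis] - depth, None)\n",
"        overlap[tuple(lo)] = label\n",
"        overlap[tuple(hi)] = label\n",
"    return overlap\n",
"\n",
"def build_full_markers(labels, depth, label):\n",
"    '''Labels the outer `depth`-wide band of a block with `label`, keeping any existing nonzero markers.'''\n",
"    markers = mask_overlap(labels, depth, label)\n",
"    markers[labels > 0] = labels[labels > 0]\n",
"    return markers\n",
"\n",
"def create_random_salt_image(fraction_salt, size):\n",
"    '''Returns a boolean image in which each pixel is True with probability `fraction_salt`.'''\n",
"    zscore = frac_to_zscore(1 - fraction_salt)\n",
"    salt = np.random.normal(0.0, 1.0, size) > zscore\n",
"    return salt\n",
"\n",
"def remove_label(chunk, label):\n",
"    '''Returns a copy of `chunk` with every pixel equal to `label` set to 0.'''\n",
"    out = chunk.copy()\n",
"    out[out == label] = 0\n",
"    return out\n",
"\n",
"def compose(fp_chunk, sp_chunk, mask_chunk):\n",
"    '''Returns the first-pass labels with the masked pixels replaced by the second-pass labels.'''\n",
"    out = fp_chunk.copy()\n",
"    out[mask_chunk] = sp_chunk[mask_chunk]\n",
"    return out"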
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Image Geometry Definitions"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"ndim = 2\n",
"size_len = 200\n",
"size = ndim * [size_len]\n",
"mem_mb = compute_mem_mb(size)\n",
"print(\"2D array size (MB): {:.2f}\".format(mem_mb))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: if chunk_len >= size_len, the whole image fits in a single chunk.\n",
"chunk_len = 400\n",
"chunks = ndim * [chunk_len]\n",
"chunk_mem_mb = compute_mem_mb(chunks)\n",
"print(\"Chunk size (MB): {:.2f}\".format(chunk_mem_mb))"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Random Seed"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"seed = 1\n",
"np.random.seed(seed=seed)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Create Random \"Salting\" Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fraction_salt = 1e-2\n",
"salt = create_random_salt_image(fraction_salt, size)\n",
"#salt = np.flip(salt, axis=0)\n",
"display(salt)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Determine EDT"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"from scipy.ndimage import distance_transform_edt\n",
"edt = distance_transform_edt(~salt)\n",
"max_edt = edt.max()\n",
"edt = edt / max_edt\n",
"display(edt)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Filter EDT Using H-Max/H-Dome"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"from skimage.morphology import reconstruction\n",
"h = 1 / max_edt  # one pixel of distance, in normalized EDT units\n",
"h_seed = edt - h\n",
"hmax = reconstruction(h_seed, edt, method='dilation')\n",
"display(hmax)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Prepare Marker Image"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"from skimage.measure import label\n",
"ws_markers = label(salt)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Watershed Transform"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from skimage.segmentation import watershed\n",
"ws = watershed(hmax, markers=ws_markers)\n",
"display(ws)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Prepare Dask Client"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import Client\n",
"c = Client()\n",
"port = c.scheduler_info()['services']['dashboard']\n",
"print(\"Type `http://localhost:{port}` into the URL bar of your favorite browser to watch the following code in action on your machine in real time.\".format(port=port))"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Create Dask Arrays\n",
"\n",
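"Only the watershed itself is distributed here. H-max, EDT, and connected-component labeling are computed on the full in-memory image, on the assumption that distributed versions of these operations will be available to future users."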
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"depth = 1\n",
"hmax_da = da.from_array(hmax, chunks=chunks)\n",
"hmax_op = da.overlap.overlap(hmax_da, depth=depth, boundary='nearest')\n",
"\n",
"boundary_label = ws_markers.max() + 1  # a label that no real marker uses\n",
"ws_markers_da = da.from_array(ws_markers, chunks=chunks)\n",
"ws_markers_op = da.overlap.overlap(ws_markers_da, depth=depth, boundary=boundary_label)\n",
"ws_markers_op = ws_markers_op.map_blocks(lambda x: build_full_markers(x, depth, boundary_label), dtype=ws_markers_op.dtype)\n",
"display(ws_markers_op)\n",
"\n",
"# fp = first pass\n",
"ws_fp = hmax_op.map_blocks(lambda x, y: watershed(x, markers=y), ws_markers_op, dtype=ws_markers_op.dtype)\n",
"display(ws_fp)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Propagate Boundary Basin"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"ws_markers_sp = ws_fp.map_blocks(lambda x: remove_label(x, boundary_label))\n",
"ws_markers_sp = da.overlap.trim_overlap(ws_markers_sp, depth=depth)\n",
"ws_markers_sp = da.overlap.overlap(ws_markers_sp, depth=depth, boundary='nearest')\n",
"display(ws_markers_sp)\n",
"\n",
"ws_mask = ws_fp == boundary_label\n",
"display(ws_mask)\n",
"\n",
"ws_sp = hmax_op.map_blocks(lambda x, y, z: watershed(x, markers=y, mask=z), ws_markers_sp, ws_mask, dtype=ws_fp.dtype)\n",
"display(ws_sp)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Compose First and Second Passes Using Mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ws_final = ws_fp.map_blocks(compose, ws_sp, ws_mask, dtype=ws_fp.dtype)\n",
"ws_final = da.overlap.trim_overlap(ws_final, depth=depth)\n",
"display(ws_final)"
]
},
{
"cell_type": "markdown",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Validation\n",
"\n",
"Some pixels are assigned to different basins by the two methods. Note in particular the error-free strip along the bottom quarter of the image: it seems to persist across changes of the random seed, and even when the salt image is flipped along the vertical axis (`np.flip(salt, axis=0)`)."
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"error = (ws != ws_final).compute()\n",
"display(error)\n",
"\n",
"error_count = error.sum()\n",
"print(\"Error count: {:d}\".format(error_count))\n",
"print(\"Error fraction: {:.3%}\".format(error_count / error.size))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}
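The notebook's `create_random_salt_image` thresholds standard-normal noise at `norm.ppf(1 - fraction_salt)`. As a result, each pixel independently becomes salt with probability `fraction_salt`.

The sketch below shows the same idea with two substitutions: the standard library's `statistics.NormalDist` replaces SciPy, and NumPy's `default_rng` replaces the global seed. `salt_threshold` and `make_salt` are hypothetical names, not part of the notebook.

```python
from statistics import NormalDist

import numpy as np


def salt_threshold(fraction_salt):
    """Z-score that a standard-normal sample exceeds with probability fraction_salt."""
    # inv_cdf(p) returns z with P(Z <= z) = p, hence P(Z > z) = 1 - p = fraction_salt.
    return NormalDist().inv_cdf(1 - fraction_salt)


def make_salt(fraction_salt, shape, rng):
    """Boolean image whose pixels are independently True with probability fraction_salt."""
    return rng.standard_normal(shape) > salt_threshold(fraction_salt)


rng = np.random.default_rng(1)
salt = make_salt(1e-2, (200, 200), rng)
print(round(salt_threshold(1e-2), 3))  # 2.326
print(salt.mean())  # close to 0.01
```

Comparing a uniform draw directly with the fraction, `rng.random(shape) < fraction_salt`, would give the same distribution without the z-score step.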
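`distance_transform_edt(~salt)` gives every non-salt pixel its Euclidean distance to the nearest salt pixel, and salt pixels themselves get 0. The notebook then divides the result by its maximum.

For small images the same quantity can be checked by brute force. This sketch makes the following assumptions:
- It uses NumPy only.
- `brute_force_edt` is a hypothetical name.
- The input contains at least one background (False) pixel.

```python
import numpy as np


def brute_force_edt(foreground):
    """Euclidean distance from each True pixel to the nearest False pixel.

    False pixels get 0. Cost is O(pixels * background pixels), so this is only
    meant for checking small images. Assumes at least one False pixel exists.
    """
    rows, cols = np.indices(foreground.shape)
    bg_rows, bg_cols = np.nonzero(~foreground)
    # Squared distance from every pixel (first two axes) to every background pixel (last axis).
    d2 = (rows[..., None] - bg_rows) ** 2 + (cols[..., None] - bg_cols) ** 2
    return np.sqrt(d2.min(axis=-1))


salt = np.zeros((5, 5), dtype=bool)
salt[2, 2] = True  # one salt pixel in the center
edt = brute_force_edt(~salt)
print(edt[0, 0], edt[2, 4])  # sqrt(8) and 2.0
normalized = edt / edt.max()  # same normalization as the notebook's EDT cell
```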
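The h-max cell calls `reconstruction(edt - h, edt, method='dilation')`. Starting from the image lowered by `h`, this repeatedly dilates and clips to the original image until nothing changes. The effect is to suppress regional maxima whose height above the surrounding saddle is at most `h` and to lower the remaining ones by `h`.

Below is a NumPy-only sketch of grayscale reconstruction by dilation for 2D images. It uses a full 3x3 neighborhood, which is scikit-image's default footprint. `dilate3x3`, `reconstruct_by_dilation`, and `hmax` are hypothetical names, and the naive loop is far slower than scikit-image's implementation.

```python
import numpy as np


def dilate3x3(image):
    """Grey-level dilation of a 2D image with a full 3x3 neighborhood."""
    # Edge padding repeats border pixels, so out-of-bounds neighbors never add new maxima.
    padded = np.pad(image, 1, mode="edge")
    out = image.copy()
    rows, cols = image.shape
    for dr in (0, 1, 2):
        for dc in (0, 1, 2):
            out = np.maximum(out, padded[dr:dr + rows, dc:dc + cols])
    return out


def reconstruct_by_dilation(seed, mask):
    """Iterate seed <- min(dilate(seed), mask) until nothing changes."""
    current = np.minimum(seed, mask)
    while True:
        nxt = np.minimum(dilate3x3(current), mask)
        if np.array_equal(nxt, current):
            return current
        current = nxt


def hmax(image, h):
    """H-maxima transform: reconstruction of (image - h) under image."""
    return reconstruct_by_dilation(image - h, image)


profile = np.array([[0.0, 1.0, 3.0, 1.0, 0.0]])
print(hmax(profile, 1.0))  # [[0. 1. 2. 1. 0.]]
```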
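Both watershed passes call scikit-image's marker-controlled `watershed`; the second pass is restricted by `mask` to the boundary basin. To illustrate what each block computes, here is a minimal priority-flood sketch. Pixels are claimed in order of increasing image value by the labeled neighbor that reaches them first, and pixels outside `mask` stay 0.

It assumes 2D input and 4-connectivity, which matches scikit-image's default `connectivity=1`. `flood_watershed` is a hypothetical name and is not scikit-image's implementation; ties on ridges and plateaus may be broken differently.

```python
import heapq

import numpy as np


def flood_watershed(image, markers, mask=None):
    """Marker-controlled watershed by priority flooding (2D, 4-connectivity).

    Unlabeled pixels are claimed, lowest image value first, by the labeled
    neighbor that reaches them first. Pixels outside `mask` stay 0.
    """
    labels = markers.copy()
    if mask is None:
        mask = np.ones(image.shape, dtype=bool)
    labels[~mask] = 0  # markers outside the mask are ignored
    rows, cols = image.shape
    heap = []
    counter = 0  # tie-breaker: equal heights are processed in insertion order
    for r, c in zip(*np.nonzero(labels)):
        heapq.heappush(heap, (image[r, c], counter, r, c))
        counter += 1
    while heap:
        _, _, r, c = heapq.heappop(heap)
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if 0 <= nr < rows and 0 <= nc < cols and mask[nr, nc] and labels[nr, nc] == 0:
                labels[nr, nc] = labels[r, c]
                heapq.heappush(heap, (image[nr, nc], counter, nr, nc))
                counter += 1
    return labels


image = np.array([[0.0, 1.0, 2.0, 5.0, 2.0, 1.0, 0.0]])
markers = np.array([[1, 0, 0, 0, 0, 0, 2]])
print(flood_watershed(image, markers))  # [[1 1 1 1 2 2 2]]

# Restricting the flood with a mask, as in the notebook's second pass.
mask = np.array([[True, True, True, True, True, False, False]])
print(flood_watershed(image, markers, mask=mask))  # [[1 1 1 1 1 0 0]]
```

In the unmasked run the ridge pixel (value 5) goes to whichever basin reaches it first. In the masked run marker 2 lies outside the mask, so marker 1 floods the whole masked region.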
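Before the first pass, `build_full_markers` gives every overlapped block an outer `depth`-wide band labeled `boundary_label`, while pixels that already carry a marker keep their own label. Whatever floods from that band in the first pass becomes the chunk-boundary basin (`ws_fp == boundary_label`).

The same band can be built with `np.pad`. This sketch assumes every block dimension exceeds `2 * depth`; `boundary_band` and `full_markers` are hypothetical names.

```python
import numpy as np


def boundary_band(shape, depth, label, dtype=np.int64):
    """Array of `shape` holding `label` in its outer `depth`-wide band and 0 inside.

    Assumes every dimension of `shape` is larger than 2 * depth.
    """
    inner = np.zeros([n - 2 * depth for n in shape], dtype=dtype)
    return np.pad(inner, depth, mode="constant", constant_values=label)


def full_markers(labels, depth, label):
    """Boundary band plus the block's own markers; existing markers win over the band."""
    markers = boundary_band(labels.shape, depth, label, dtype=labels.dtype)
    markers[labels > 0] = labels[labels > 0]
    return markers


block = np.zeros((5, 6), dtype=np.int64)
block[0, 2] = 7  # a marker inside the band keeps its own label
block[2, 3] = 3  # an interior marker
print(full_markers(block, depth=1, label=99))
```

Letting real markers override the band mirrors `markers[labels > 0] = labels[labels > 0]` in the notebook. A seed that lies in the overlap region therefore still floods with its own label rather than the boundary label.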