First attempt at distributed watershed
{ | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"metadata": { | |
"language_info": { | |
"name": "python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
} | |
}, | |
"orig_nbformat": 2, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"npconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": 3 | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Distributed Watershed\n", | |
"\n", | |
"Implementation of a distributed watershed function using Dask. Implementation inspired by Juan Nunez-Iglesias [here](https://github.com/dask/dask-image/pull/99). The overall method relies on a two-pass watershed model. The first pass watershed is used to generate and share information about markers across chunk boundaries. The second pass then propagates that information.\n", | |
"\n", | |
"This implementation differs from v1. Whereas v1 shares marker information about the first pass watershed directly, this shares information about through a specially-labeled chunk-boundary basin." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Base Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import dask\n", | |
"import dask.array as da\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Helper Functions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def compute_mem_mb(shape):\n", | |
" '''Determines memory consumption of an array with shape in MB'''\n", | |
" from functools import reduce\n", | |
" from operator import mul\n", | |
" count = reduce(mul, shape, 1)\n", | |
" return count * 8 / ( 1024 ** 2 )\n", | |
"\n", | |
"def display(image):\n", | |
" '''Shows an image in the Jupyter notebook.'''\n", | |
" from skimage import io\n", | |
" io.imshow(image)\n", | |
" io.show()\n", | |
"\n", | |
"def frac_to_zscore(frac):\n", | |
" from scipy.stats import norm\n", | |
" return norm.ppf(frac)\n", | |
"\n", | |
"def mask_overlap(chunk, depth, label):\n", | |
" overlap = np.zeros_like(chunk)\n", | |
" for i in range(0, chunk.ndim):\n", | |
" idx = [slice(None)]*chunk.ndim\n", | |
" r = np.array(range(0, chunk.shape[i]))\n", | |
" idx[i] = (r[:depth], r[-depth:])\n", | |
" overlap[tuple(idx)] = label\n", | |
" return overlap\n", | |
"\n", | |
"def build_full_markers(labels, depth, label):\n", | |
" markers = mask_overlap(labels, depth, label)\n", | |
" markers[labels > 0] = labels[labels > 0]\n", | |
" return markers\n", | |
"\n", | |
"def create_random_salt_image(fraction_salt, size):\n", | |
" zscore = frac_to_zscore(1 - fraction_salt)\n", | |
" salt = np.random.normal(0.0, 1.0, size) > zscore\n", | |
" return salt\n", | |
"\n", | |
"def remove_label(chunk, label):\n", | |
" out = chunk.copy()\n", | |
" out[out == label] = 0\n", | |
" return out\n", | |
"\n", | |
"def compose(fp_chunk, sp_chunk, mask_chunk):\n", | |
" out = fp_chunk.copy()\n", | |
" out[mask_chunk] = sp_chunk[mask_chunk]\n", | |
" return out" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Image Geometry Definitions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ndim = 2\n", | |
"size_len = 200\n", | |
"size = ndim * [size_len]\n", | |
"mem_mb = compute_mem_mb(size)\n", | |
"print(\"2D array size (MB): {:.2f}\".format(mem_mb))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"chunk_len = 400\n", | |
"chunks = ndim * [chunk_len]\n", | |
"chunk_mem_mb = compute_mem_mb(chunks)\n", | |
"print(\"Chunk size (MB): {:.2f}\".format(chunk_mem_mb))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Random Seed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"seed = 1\n", | |
"np.random.seed(seed=seed)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Create Random \"Salting\" Image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"fraction_salt = 1e-2\n", | |
"salt = create_random_salt_image(fraction_salt, size)\n", | |
"#salt = np.flip(salt, axis=0)\n", | |
"display(salt)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Determine EDT" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from scipy.ndimage.morphology import distance_transform_edt\n", | |
"edt = distance_transform_edt(~salt)\n", | |
"max_edt = edt.max()\n", | |
"edt = edt / max_edt\n", | |
"display(edt)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Filter EDT Using H-Max/H-Dome" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from skimage.morphology import reconstruction\n", | |
"h = 1 / max_edt\n", | |
"h_seed = edt - h\n", | |
"hmax = reconstruction(h_seed, edt, method='dilation')\n", | |
"display(hmax)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Prepare Marker Image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from skimage.measure import label\n", | |
"ws_markers = label(salt)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Watershed Transform" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from skimage.morphology import watershed\n", | |
"ws = watershed(hmax, markers=ws_markers)\n", | |
"display(ws)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Prepare Dask Client" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from dask.distributed import Client\n", | |
"c = Client()\n", | |
"port = c.scheduler_info()['services']['dashboard']\n", | |
"print(\"Type `http://localhost:{port}` into the URL bar of your favorite browser to watch the following code in action on your machine in real time.\".format(port=port))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Create Dask Arrays\n", | |
"\n", | |
"Here we assume future users have access to distributed versions of h-max, EDT, and connected component labeling." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"depth = 1\n", | |
"hmax_da = da.from_array(hmax, chunks=chunks)\n", | |
"hmax_op = da.overlap.overlap(hmax_da, depth=depth, boundary='nearest')\n", | |
"\n", | |
"boundary_label = ws_label_max + 1\n", | |
"ws_markers_da = da.from_array(ws_markers, chunks=chunks)\n", | |
"ws_markers_op = da.overlap.overlap(ws_markers_da, depth=depth, boundary=boundary_label)\n", | |
"ws_markers_op = ws_markers_op.map_blocks(lambda x: build_full_markers(x, depth, boundary_label), dtype=ws_markers_op.dtype)\n", | |
"display(ws_markers_op)\n", | |
"\n", | |
"# fp = first pass\n", | |
"ws_fp = hmax_op.map_blocks(lambda x, y: watershed(x, markers=y), ws_markers_op, dtype=hmax_op.dtype)\n", | |
"display(ws_fp)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Propagate Boundary Basin" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ws_markers_sp = ws_fp.map_blocks(lambda x: remove_label(x, boundary_label))\n", | |
"ws_markers_sp = da.overlap.trim_overlap(ws_markers_sp, depth=depth)\n", | |
"ws_markers_sp = da.overlap.overlap(ws_markers_sp, depth=depth, boundary='nearest')\n", | |
"display(ws_markers_sp)\n", | |
"\n", | |
"ws_mask = ws_fp == boundary_label\n", | |
"display(ws_mask)\n", | |
"\n", | |
"ws_sp = hmax_op.map_blocks(lambda x, y, z: watershed(x, markers=y, mask=z), ws_markers_sp, ws_mask, dtype=ws_fp.dtype)\n", | |
"display(ws_sp)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Compose First and Second Passes Using Mask" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ws_final = ws_fp.map_blocks(lambda x, y, z: compose(x, y, z), ws_sp, ws_mask, dtype=ws_fp.dtype)\n", | |
"ws_final = da.overlap.trim_overlap(ws_final, depth=depth)\n", | |
"display(ws_final)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### Validation\n", | |
"\n", | |
"Note there are some different basin assignments between the methods. Pay particular attention to the error-free strip along the bottom 1/4 of the image. That strip seems to stay through seed changes, and even flipping the salt image along the vertical axis." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"error = ~(ws == ws_final).compute()\n", | |
"display(error)\n", | |
"\n", | |
"error_count = error.sum()\n", | |
"print(\"Error count: {:d}\".format(error_count))\n", | |
"print(\"Error fraction: {:.3%}\".format(error_count / error.size))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
] | |
} |
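
The boundary-basin idea used in the notebook above can be illustrated on a single small chunk without running a watershed at all. The following is a minimal sketch, not the notebook's own code: `boundary_markers` is a hypothetical slice-based stand-in for `build_full_markers`, the 5x5 chunk and its marker positions are made up for illustration, and the toy `first_pass`/`second_pass` arrays are hand-written rather than produced by `watershed`.

```python
import numpy as np


def boundary_markers(labels, depth, boundary_label):
    """Mark the outer `depth` pixels of a chunk with `boundary_label`.

    Real markers (labels > 0) take precedence over the boundary label.
    Hypothetical helper, assumed equivalent in intent to build_full_markers.
    """
    markers = np.zeros_like(labels)
    for axis in range(labels.ndim):
        lo = [slice(None)] * labels.ndim
        hi = [slice(None)] * labels.ndim
        lo[axis] = slice(0, depth)
        hi[axis] = slice(labels.shape[axis] - depth, None)
        markers[tuple(lo)] = boundary_label
        markers[tuple(hi)] = boundary_label
    real = labels > 0
    markers[real] = labels[real]
    return markers


def compose(first_pass, second_pass, mask):
    """Keep first-pass labels, except where `mask` selects second-pass labels."""
    out = first_pass.copy()
    out[mask] = second_pass[mask]
    return out


# A made-up 5x5 chunk with one interior marker and one marker on the edge.
labels = np.zeros((5, 5), dtype=np.int64)
labels[2, 2] = 1
labels[0, 3] = 2

# The boundary basin gets a label that no real marker uses.
boundary_label = labels.max() + 1
markers = boundary_markers(labels, depth=1, boundary_label=boundary_label)

# Hand-written stand-ins for the two watershed passes on a 2x3 region.
first_pass = np.array([[1, 1, 3], [1, 3, 3]])
second_pass = np.array([[0, 0, 2], [0, 1, 2]])
mask = first_pass == 3  # pixels claimed by the boundary basin in pass one
merged = compose(first_pass, second_pass, mask)
```

Here every edge pixel of `markers` carries the boundary label except the one that already held a real marker, and `merged` keeps first-pass labels everywhere except inside the boundary basin, where the second-pass labels win. Any label outside the range of real markers would serve as the boundary label; `labels.max() + 1` is simply the choice the notebook makes.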