Skip to content

Instantly share code, notes, and snippets.

@jakirkham
Forked from wwarriner/dask_image_watershed.ipynb
Created November 12, 2019 21:17
Show Gist options
  • Save jakirkham/21eb03f3188b8ffd976cc2dbf2042d7d to your computer and use it in GitHub Desktop.
Save jakirkham/21eb03f3188b8ffd976cc2dbf2042d7d to your computer and use it in GitHub Desktop.
First attempt at distributed watershed
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
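"## Distributed watershed with Dask\n",
"\n",
"First attempt at a chunked watershed. A reference `skimage` watershed is computed on the whole image. The same segmentation is then computed per Dask chunk in two passes: first each overlapped chunk is flooded, then only the pixels that pass 1 assigned to the chunk-border label are re-flooded. The result is compared pixel-by-pixel with the reference.\n"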
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dask\n",
"import dask.array as da\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def compute_mem_mb(shape, itemsize=8):\n",
"    '''Return the memory footprint, in MiB, of an array with the given shape.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    shape : iterable of int\n",
"        Array dimensions; an empty shape counts as a single element.\n",
"    itemsize : int, optional\n",
"        Bytes per element. Defaults to 8 (e.g. float64 / int64), which was\n",
"        previously hard-coded.\n",
"    '''\n",
"    from functools import reduce\n",
"    from operator import mul\n",
"    count = reduce(mul, shape, 1)\n",
"    return count * itemsize / (1024 ** 2)\n",
"\n",
"\n",
"def display(image):\n",
"    '''Show an image inline in the notebook using skimage.io.imshow.\n",
"\n",
"    NOTE: this shadows IPython's rich `display`; later cells call this version.\n",
"    '''\n",
"    from skimage import io\n",
"    io.imshow(image)\n",
"    io.show()\n",
"\n",
"\n",
"def frac_to_zscore(frac):\n",
"    '''Return the standard-normal z-score whose lower-tail probability is `frac`.\n",
"\n",
"    NOTE(review): not called by any cell in this notebook.\n",
"    '''\n",
"    from scipy.stats import norm\n",
"    return norm.ppf(frac)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Square chunks: chunk_len pixels along each of the ndim axes.\n",
"ndim = 2\n",
"chunk_len = 400\n",
"chunks = [chunk_len] * ndim\n",
"\n",
"# compute_mem_mb assumes 8-byte elements (float64, the dtype of hmax).\n",
"chunk_mem_mb = compute_mem_mb(chunks)\n",
"print(\"Chunk size (MB): {:.2f}\".format(chunk_mem_mb))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from skimage.io import imread\n",
"from skimage.color import rgb2gray\n",
"from skimage.filters import threshold_otsu\n",
"\n",
"# Binarize with Otsu's threshold, then shift to {1, 2} so that 0 is free to\n",
"# mark the seed pixels below as the lowest values in the image.\n",
"image_path = \"./double_spiral.jpg\"\n",
"im = rgb2gray(imread(image_path))\n",
"im = (im > threshold_otsu(im)) + 1\n",
"\n",
"# Seed pixels as (row, col). Each one becomes a watershed marker in `salt`\n",
"# and a 0-valued pixel in `im`; a single list keeps the two in sync.\n",
"seed_points = [(600, 650), (700, 650)]\n",
"salt = np.zeros(im.shape, dtype=bool)\n",
"for row, col in seed_points:\n",
"    salt[row, col] = True\n",
"    im[row, col] = 0\n",
"\n",
"# Scale to [0, 1]. NOTE(review): despite the name, `edt` is not a Euclidean\n",
"# distance transform here, only the normalized image.\n",
"max_edt = im.max()\n",
"edt = im / max_edt\n",
"display(edt)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# hmax: reconstruction by dilation of (edt - h) under edt, the standard\n",
"# construction that suppresses regional maxima shallower than h.\n",
"from skimage.morphology import reconstruction\n",
"\n",
"h = 1 / (1 + max_edt)\n",
"h_seed = edt - h\n",
"hmax = reconstruction(seed=h_seed, mask=edt, method='dilation')\n",
"display(hmax)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from skimage.measure import label\n",
"\n",
"# `watershed` moved from skimage.morphology to skimage.segmentation (deprecated\n",
"# in scikit-image 0.17, removed from morphology in 0.19); prefer the new home.\n",
"try:\n",
"    from skimage.segmentation import watershed\n",
"except ImportError:  # older scikit-image\n",
"    from skimage.morphology import watershed\n",
"\n",
"# Reference result: one marker per connected component of `salt`, flooded on hmax.\n",
"ws_markers, ws_label_max = label(salt, return_num=True)\n",
"ws = watershed(hmax, markers=ws_markers)\n",
"display(ws)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from dask.distributed import Client\n",
"\n",
"# With no scheduler address given, Client() starts a local cluster.\n",
"c = Client()\n",
"\n",
"# Port on which the scheduler serves its diagnostics dashboard.\n",
"port = c.scheduler_info()['services']['dashboard']\n",
"print(\n",
"    \"Type `http://localhost:{port}` into the URL bar of your favorite browser \"\n",
"    \"to watch the following code in action on your machine in real time.\".format(port=port)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Width, in pixels, of the overlap shared between neighbouring chunks.\n",
"depth = 1\n",
"\n",
"# One past the largest seed label; used to mark chunk borders and outer padding.\n",
"boundary_label = ws_label_max + 1\n",
"\n",
"# Flooding surface: pad the outer image edges by repeating the edge pixels.\n",
"hmax_da = da.from_array(hmax, chunks=chunks)\n",
"hmax_op = da.overlap.overlap(hmax_da, boundary='nearest', depth=depth)\n",
"\n",
"# Markers: pad the outer image edges with the boundary label.\n",
"ws_markers_da = da.from_array(ws_markers, chunks=chunks)\n",
"ws_markers_op = da.overlap.overlap(ws_markers_da, boundary=boundary_label, depth=depth)\n",
"\n",
"def mask_overlap(chunk, depth, label):\n",
"    '''Return an array shaped like `chunk` that is `label` on a `depth`-wide frame.\n",
"\n",
"    The first and last `depth` entries along every axis are set to `label`; the\n",
"    interior is 0. The result has `chunk`'s dtype. A non-positive `depth`\n",
"    yields an all-zero array.\n",
"    '''\n",
"    frame = np.zeros_like(chunk)\n",
"    if depth <= 0:\n",
"        # Guard: a `-0:` slice would select the whole axis and label everything.\n",
"        return frame\n",
"    for axis in range(chunk.ndim):\n",
"        head = [slice(None)] * chunk.ndim\n",
"        tail = [slice(None)] * chunk.ndim\n",
"        head[axis] = slice(None, depth)\n",
"        tail[axis] = slice(-depth, None)\n",
"        frame[tuple(head)] = label\n",
"        frame[tuple(tail)] = label\n",
"    return frame\n",
"\n",
"\n",
"def build_full_markers(labels, depth, label):\n",
"    '''Return watershed markers for one overlapped chunk.\n",
"\n",
"    The chunk's `depth`-wide frame is filled with `label` via mask_overlap, then\n",
"    every positive seed label in `labels` is copied on top, so seeds lying inside\n",
"    the frame keep their own label.\n",
"    '''\n",
"    markers = mask_overlap(labels, depth, label)\n",
"    seeded = labels > 0\n",
"    markers[seeded] = labels[seeded]\n",
"    return markers\n",
"\n",
"# Pass 1: give each overlapped chunk a boundary-labelled frame plus its seeds and\n",
"# flood it independently. Extra keyword arguments to map_blocks are forwarded to\n",
"# the block function, so no wrapper lambda is needed.\n",
"ws_markers_op = ws_markers_op.map_blocks(build_full_markers, depth=depth, label=boundary_label, dtype=ws_markers_op.dtype)\n",
"display(ws_markers_op)\n",
"\n",
"# watershed returns integer labels (documented as the markers' dtype), not the\n",
"# float dtype of the flooding surface, so declare that for the lazy result.\n",
"ws_fp = hmax_op.map_blocks(lambda x, y: watershed(x, markers=y), ws_markers_op, dtype=ws_markers_op.dtype)\n",
"display(ws_fp)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def remove_label(chunk, label):\n",
"    '''Return a copy of `chunk` with every element equal to `label` set to 0.'''\n",
"    cleared = np.copy(chunk)\n",
"    cleared[chunk == label] = 0\n",
"    return cleared\n",
"\n",
"# Pass 2 markers: clear the boundary label from the pass-1 result, drop the\n",
"# stale overlap, then re-exchange overlaps so each chunk sees its neighbours'\n",
"# pass-1 labels along its edges.\n",
"ws_markers_sp = da.overlap.overlap(\n",
"    da.overlap.trim_overlap(\n",
"        ws_fp.map_blocks(lambda x: remove_label(x, boundary_label)),\n",
"        depth=depth,\n",
"    ),\n",
"    depth=depth,\n",
"    boundary='nearest',\n",
")\n",
"display(ws_markers_sp)\n",
"\n",
"# Only pixels still carrying the boundary label after pass 1 are re-flooded.\n",
"ws_mask = ws_fp == boundary_label\n",
"display(ws_mask)\n",
"\n",
"ws_sp = hmax_op.map_blocks(\n",
"    lambda surface, markers, mask: watershed(surface, markers=markers, mask=mask),\n",
"    ws_markers_sp,\n",
"    ws_mask,\n",
"    dtype=ws_fp.dtype,\n",
")\n",
"display(ws_sp)\n",
"\n",
"def compose(fp_chunk, sp_chunk, mask_chunk):\n",
"    '''Merge the two passes: `sp_chunk` where `mask_chunk` is set, else `fp_chunk`.\n",
"\n",
"    Returns a new array with `fp_chunk`'s dtype; the inputs are not modified.\n",
"    '''\n",
"    merged = np.copy(fp_chunk)\n",
"    merged[mask_chunk] = sp_chunk[mask_chunk]\n",
"    return merged\n",
"\n",
"# Stitch the two passes chunk by chunk, then trim the overlap back off.\n",
"ws_final = da.overlap.trim_overlap(\n",
"    ws_fp.map_blocks(compose, ws_sp, ws_mask, dtype=ws_fp.dtype),\n",
"    depth=depth,\n",
")\n",
"display(ws_final)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Validate: flag every pixel where the chunked result differs from the\n",
"# single-process reference `ws`.\n",
"error = (ws_final != ws).compute()\n",
"display(error)\n",
"\n",
"error_count = error.sum()\n",
"error_fraction = error_count / error.size\n",
"print(\"Error count: {:d}\".format(error_count))\n",
"print(\"Error fraction: {:.3%}\".format(error_fraction))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment