Skip to content

Instantly share code, notes, and snippets.

@ianhi
Created May 28, 2021 04:14
Show Gist options
  • Save ianhi/3cd151068f338249b0547b285adc2aa7 to your computer and use it in GitHub Desktop.
Save ianhi/3cd151068f338249b0547b285adc2aa7 to your computer and use it in GitHub Desktop.
Interactively explore watershed segmentation parameters.
%matplotlib widget
from functools import lru_cache
import matplotlib.pyplot as plt
import numpy as np
from mpl_interactions import ipyplot as iplt
from scipy import ndimage as ndi
from skimage.data import binary_blobs
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
def interactive_watershed(
    image, min_distances=(1, 100), mask_alpha=(1, 0), figsize=(6, 6), scatter_color="r"
):
    """
    Interactively explore how ``min_distance`` affects watershed segmentation.

    Shows ``image`` in grayscale and overlays two layers:

    - the peaks of its Euclidean distance transform, found with ``peak_local_max``,
      as a scatter plot;
    - the watershed labels seeded from those peaks.

    The same mpl-interactions ``min_distance`` control drives both overlays.

    Parameters
    ----------
    image : (M, N) array-like of bool
        A thresholded image.
    min_distances : (int, int) or (int, int, int) or array-like of int
        The possible values of ``min_distance`` for peak finding. A tuple is
        shorthand for ``np.arange(*min_distances)``. Anything else is passed
        directly to mpl-interactions as the set of control values.
    mask_alpha : (float, float) or float
        The alpha of the watershed label overlay. A tuple is handed to
        mpl-interactions, which turns it into a slider. A single value is fixed.
    figsize : (float, float)
        The figure size.
    scatter_color : color
        The color of the peak markers.

    Returns
    -------
    None
        The figure and its interactive controls are created as a side effect.
    """
    from matplotlib.colors import ListedColormap

    image = np.asarray(image)
    _, ax = plt.subplots(figsize=figsize)
    distance = ndi.distance_transform_edt(image)

    @lru_cache
    def get_peaks(min_distance):
        """Return the distance-transform peaks as (col, row) pairs, i.e. (x, y) for plotting."""
        # The copy works around https://github.com/scikit-image/scikit-image/issues/5235
        # in scikit-image releases that predate its fix.
        coords = peak_local_max(distance.copy(), min_distance=min_distance, labels=image)
        return coords[:, ::-1]

    @lru_cache
    def get_mask(min_distance):
        """Return the watershed label image seeded by the peaks for ``min_distance``."""
        # Flip back from (x, y) to (row, col) for array indexing.
        coords = get_peaks(min_distance)[:, ::-1]
        seeds = np.zeros(distance.shape, dtype=bool)
        seeds[tuple(coords.T)] = True
        markers, _ = ndi.label(seeds)
        # Flood from the markers over the inverted distance map, restricted to the image.
        return watershed(-distance, markers, mask=image)

    if isinstance(min_distances, tuple):
        min_distances = np.arange(*min_distances)

    # Build a discrete colormap of random colors, one per segment. The smallest
    # min_distance typically yields the most peaks, so size the colormap from it.
    # At least one color is needed for a valid colormap even when no peaks exist.
    n_segments = max(len(get_peaks(np.min(min_distances))), 1)
    cmap = ListedColormap(np.random.random((n_segments, 3)), name="segments")
    # Label 0 is background. With vmin=0.1 below it falls under the color range
    # and is drawn fully transparent.
    cmap.set_under(alpha=0)

    ax.imshow(image, cmap="gray")
    ctrls = iplt.scatter(
        get_peaks, min_distance=min_distances, parametric=True, c=scatter_color
    )
    iplt.imshow(get_mask, alpha=mask_alpha, controls=ctrls, cmap=cmap, vmin=0.1)
# Demo: a synthetic binary image of random blobs with a volume fraction of 0.2,
# explored with the interactive watershed viewer defined above.
image = binary_blobs(volume_fraction=0.2)
interactive_watershed(image=image)
@ianhi
Copy link
Author

ianhi commented May 28, 2021

Peek 2021-05-28 00-18

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment