Skip to content

Instantly share code, notes, and snippets.

@stefsmeets
Created August 30, 2023 10:12
Show Gist options
  • Save stefsmeets/c7bf4f5c2fcee9ec9d7ae688e782ce9f to your computer and use it in GitHub Desktop.
Watershed with periodic boundary conditions
from skimage.segmentation import _watershed, _watershed_cy
from skimage.morphology._util import (_validate_connectivity,
_offsets_to_raveled_neighbors)
def watershed_pbc(image, markers=None, connectivity=1, offset=None, mask=None,
                  compactness=0, watershed_line=False):
    """Watershed segmentation with periodic boundary conditions on every axis.

    Adapted from ``skimage.segmentation.watershed``
    (https://github.com/scikit-image/scikit-image/blob/main/skimage/segmentation/_watershed.py);
    parameters have the same meaning as in that function.

    The upstream raveled flood fill (``_watershed_cy.watershed_raveled``)
    reaches neighbours by adding fixed flat offsets to a pixel's raveled
    index, and upstream surrounds the arrays with a masked-out border so
    those offsets never leave the buffer. Dropping that border does not give
    periodicity: the +1 neighbour of the last column becomes column 0 of the
    *next* row, and neighbours of the first/last rows fall outside the
    array.

    Here the inputs are instead tiled periodically (wrap-padded by one full
    period on each side of every axis, i.e. 3**ndim copies), surrounded by
    the usual masked-out guard ring, flooded, and the central tile is
    returned. Every copy of a marker keeps its label, so a basin crossing an
    edge receives the same label on both sides.

    NOTE: this matches a true toroidal flood as long as flooding paths into
    the central tile stay within one period of it; basins larger than the
    image itself are not guaranteed to be exact. Memory/time scale with
    3**ndim.

    Returns
    -------
    ndarray
        Label image with the same shape as ``image``.
    """
    image, markers, mask = _watershed._validate_inputs(image, markers, mask,
                                                       connectivity)
    connectivity, offset = _validate_connectivity(image.ndim, connectivity,
                                                  offset)
    shape = image.shape

    # One full period of wrap-around context before and after every axis.
    wrap_width = [(n, n) for n in shape]
    # Masked-out (zero) guard ring so raveled neighbour offsets stay in bounds.
    guard_width = [(p, p) for p in offset]

    def _tile(arr):
        # Periodic tiling first, then the zero guard ring around it.
        tiled = np.pad(arr, wrap_width, mode='wrap')
        return np.pad(tiled, guard_width, mode='constant')

    image = _tile(image)
    mask = _tile(mask).ravel()
    output = _tile(markers)

    flat_neighborhood = _offsets_to_raveled_neighbors(
        image.shape, connectivity, center=offset)
    marker_locations = np.flatnonzero(output)
    image_strides = np.array(image.strides, dtype=np.intp) // image.itemsize

    # `output` is a fresh C-contiguous array from np.pad, so ravel() is a
    # view and the in-place labelling below is written back into `output`.
    _watershed_cy.watershed_raveled(image.ravel(),
                                    marker_locations, flat_neighborhood,
                                    mask, image_strides, compactness,
                                    output.ravel(),
                                    watershed_line)

    # Crop the central tile (skip the guard ring plus one period per axis);
    # copy so the large tiled buffer can be released.
    center = tuple(slice(p + n, p + 2 * n) for p, n in zip(offset, shape))
    return output[center].copy()
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
# Generate an initial image with two overlapping circles
x, y = np.indices((80, 80))
x1, y1, x2, y2 = 28, 28, 44, 52
r1, r2 = 16, 20
mask_circle1 = (x - x1)**2 + (y - y1)**2 < r1**2
mask_circle2 = (x - x2)**2 + (y - y2)**2 < r2**2
image = np.logical_or(mask_circle1, mask_circle2)
# Now we want to separate the two objects in image
# Generate the markers as local maxima of the distance to the background
distance = ndi.distance_transform_edt(image)
coords = peak_local_max(distance, footprint=np.ones((3, 3)), labels=image)

# Shift everything by half the width along the column axis so the objects
# straddle the left/right edge of the image.
# np.roll needs axis=1 here: without an axis it rolls the *flattened* array,
# which pushes pixels from the right half onto the next row and leaves the
# marker coordinates (shifted by column only) out of sync with the image.
n_cols = image.shape[1]
shift = n_cols // 2
coords[:, 1] = np.mod(coords[:, 1] + shift, n_cols)
image = np.roll(image, shift=shift, axis=1)
distance = np.roll(distance, shift=shift, axis=1)

# One marker label per detected peak.
mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers, _ = ndi.label(mask)
labels = watershed_pbc(-distance, markers, mask=image)

fig, axes = plt.subplots(ncols=3, figsize=(9, 3), sharex=True, sharey=True)
ax = axes.ravel()
ax[0].imshow(image, cmap=plt.cm.gray)
# coords are (row, col); scatter expects (x=col, y=row).
ax[0].scatter(*coords.T[::-1])
ax[0].set_title('Overlapping objects')
ax[1].imshow(-distance, cmap=plt.cm.gray)
ax[1].set_title('Distances')
ax[2].imshow(labels, cmap=plt.cm.nipy_spectral)
ax[2].set_title('Separated objects')
for a in ax:
    a.set_axis_off()
fig.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment