Skip to content

Instantly share code, notes, and snippets.

@constantinpape
Created July 11, 2022 18:15
Show Gist options
  • Save constantinpape/66b04be4c6fca146e7d2a1c074081035 to your computer and use it in GitHub Desktop.
Save constantinpape/66b04be4c6fca146e7d2a1c074081035 to your computer and use it in GitHub Desktop.
Merging non-touching objects after Mutex Watershed Segmentation
import elf.segmentation as eseg
import nifty.graph.rag as nrag
import nifty.ufd
import numpy as np
from elf.segmentation.mutex_watershed import mutex_watershed
def merge_non_touching(affinities, seg, offsets, *, merge_threshold=0.5):
    """Merge segments that do not touch but share a high long-range affinity.

    A region adjacency graph (RAG) is built from ``seg``. Features are then
    accumulated from ``affinities`` for the non-local edges, i.e. edges that
    are defined by long-range offsets and connect segments which are not
    neighbours in the RAG. Every such pair whose mean affinity is strictly
    greater than ``merge_threshold`` is merged with a union-find structure,
    and the merged node labels are projected back onto the pixels.

    Args:
        affinities: Affinity map handed to nifty together with ``offsets``.
            NOTE(review): presumably one channel per offset, i.e. shaped
            ``(len(offsets), *seg.shape)`` -- confirm against the caller.
        seg: Initial label image (e.g. a Mutex Watershed result) from which
            the RAG is computed.
        offsets: Pixel offsets, one per affinity channel.
        merge_threshold: Segment pairs with a mean long-range affinity above
            this value are merged. Defaults to ``0.5``, the value that used
            to be hard-coded; it is data dependent and most likely needs to
            be tuned.

    Returns:
        The merged segmentation produced by
        ``eseg.project_node_labels_to_pixels``; expected to have the same
        shape as ``seg``.
    """
    # Region adjacency graph of the initial segmentation.
    rag = eseg.compute_rag(seg)

    # Features for the non-local (long-range) edges; the second return value
    # is not needed here.
    lr_edges, _, lr_features = nrag.computeFeaturesAndNhFromAffinities(
        rag, affinities, offsets
    )
    # Sanity check: exactly one feature row per long-range edge.
    assert len(lr_edges) == len(lr_features)

    # The mean affinity is stored in the first feature column.
    mean_affinity = lr_features[:, 0]

    # Select the node pairs to merge and join them with union-find.
    merge_pairs = lr_edges[mean_affinity > merge_threshold]
    ufd = nifty.ufd.ufd(rag.numberOfNodes)
    ufd.merge(merge_pairs)
    node_labeling = ufd.elementLabeling()

    # Project the merged node labels back to the pixels.
    return eseg.project_node_labels_to_pixels(rag, node_labeling)
def main():
    """Run a toy example: Mutex Watershed followed by non-touching merge.

    Random affinities with six offset channels and a spatial shape of
    128 x 128 are segmented with the Mutex Watershed; afterwards segments
    that do not touch are merged based on the long-range affinities.
    """
    # Two direct-neighbour offsets followed by four long-range offsets.
    offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9]]

    # Random toy data: one affinity channel per offset.
    n_channels = len(offsets)
    affinities = np.random.rand(n_channels, 128, 128)

    # Initial segmentation from the Mutex Watershed.
    initial_seg = mutex_watershed(affinities, offsets, strides=[2, 2])
    print(initial_seg.shape)

    # Merge non-touching segments whose mean long-range affinity is high.
    print("Merge non-touching objects...")
    merged_seg = merge_non_touching(affinities, initial_seg, offsets)

    # Merging must only relabel pixels, never change the image shape.
    assert initial_seg.shape == merged_seg.shape
# Run the toy example only when executed as a script, so that importing this
# module (e.g. to reuse merge_non_touching) does not trigger the demo.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment