Skip to content

Instantly share code, notes, and snippets.

@ismael-elatifi
Last active March 24, 2021 23:38
Show Gist options
  • Save ismael-elatifi/228397a4d1383b6d07e691071f0bc4a1 to your computer and use it in GitHub Desktop.
Save ismael-elatifi/228397a4d1383b6d07e691071f0bc4a1 to your computer and use it in GitHub Desktop.
Compute minimum spanning tree in binary image (nodes are white pixels) using Prim's algorithm
import cv2
import numpy as np
from heapq import heappush, heappop
def find_pixels_in_window(img, x, y, window_size):
    """
    Locate the non-null pixels of img inside a square window centred on (x, y).

    The window reaches window_size // 2 pixels on each side of (x, y) and is
    clipped at the top/left image border (the bottom/right border is clipped
    by the slice itself).

    Returns a pair (ax, ay) of index arrays holding the x and y coordinates,
    expressed in full-image coordinates, of every pixel > 0 in the window.
    """
    half = window_size // 2
    top = max(y - half, 0)
    left = max(x - half, 0)
    # Coordinates found here are relative to the window's top-left corner.
    rows, cols = np.nonzero(img[top:y + half + 1, left:x + half + 1] > 0)
    # Shift back into full-image coordinates.
    return cols + left, rows + top
def get_neighbors_and_distance(img, x, y, window_size):
    """
    List the non-null pixels around (x, y) together with their distance to it.

    Candidates are the pixels returned by find_pixels_in_window for the given
    window_size; the pixel (x, y) itself is excluded. The distance is the
    Manhattan distance |x - x2| + |y - y2|.

    Returns a list of ((x2, y2), distance) tuples.
    """
    xs, ys = find_pixels_in_window(img, x, y, window_size=window_size)
    neighbors = []
    for nx, ny in zip(xs, ys):
        if nx == x and ny == y:
            # The centre pixel is not its own neighbor.
            continue
        manhattan = abs(x - nx) + abs(y - ny)
        neighbors.append(((nx, ny), manhattan))
    return neighbors
def compute_minimum_spanning_tree(img, window_size):
    """
    Compute a minimum spanning forest of a binary image with Prim's algorithm.

    Nodes are the pixels of img that are > 0. The neighbors of a node, and the
    weight of the edge to each of them, are those reported by
    get_neighbors_and_distance for a (window_size, window_size) window around
    the node. When the nodes cannot all be reached from one another through
    such edges, one tree is grown per connected group of nodes.

    Args:
        img: 2-D array; every element > 0 is a node.
        window_size: side length of the square window searched for neighbors.

    Returns:
        A tuple (dict_tree, total_distance, n_trees) where
        - dict_tree maps every node (x, y) to its father (x, y); the root of
          each tree maps to None,
        - total_distance is the sum of the weights of all tree edges,
        - n_trees is the number of trees (0 when img has no node).
    """
    ay, ax = np.where(img > 0)  # nodes are non null pixels
    if len(ay) == 0:
        # No node at all (also covers an empty array, where img.max() would raise).
        return dict(), 0, 0

    dict_tree = {}  # dict node -> father
    total_distance = 0
    n_trees = 0
    heap = []  # candidate edges as (distance, father, node)

    # Every node not yet reached by a previous tree becomes the root of a new one.
    for root in zip(ax, ay):
        if root in dict_tree:
            continue
        # Register the root immediately so that it is part of the result even
        # when isolated, and so that it can never be re-attached as a child.
        dict_tree[root] = None
        n_trees += 1
        cur = root
        while True:
            for x_y_adj, dist in get_neighbors_and_distance(img, *cur, window_size=window_size):
                if x_y_adj not in dict_tree:
                    heappush(heap, (dist, cur, x_y_adj))
            # Drop stale candidates whose target joined the tree in the meantime.
            while heap and heap[0][2] in dict_tree:
                heappop(heap)
            if not heap:
                # Current tree is complete; the heap is empty for the next root.
                break
            dist, father, cur = heappop(heap)
            dict_tree[cur] = father
            total_distance += dist
    return dict_tree, total_distance, n_trees
def draw_tree(img, dict_tree, show_terminal_edge=False):
    """
    Render the edges of a tree on a black 3-channel canvas.

    img is only used as a template: the canvas is three stacked zero-filled
    copies of it (assumes img is a 2-D single-channel image -- TODO confirm).
    dict_tree maps each node (x, y) to its father, or to None for a root;
    one line is drawn from every non-root node to its father.

    When show_terminal_edge is true, edges whose child node is never a father
    are drawn with color (0, 0, 255) and thickness 2; every other edge is
    drawn with color (255, 255, 255) and thickness 1.

    Returns the canvas.
    """
    blank = np.zeros_like(img)
    canvas = np.dstack([blank, blank, blank])
    fathers = set(dict_tree.values())
    for node, father in dict_tree.items():
        if father is None:
            # Roots have no incoming edge to draw.
            continue
        is_terminal = show_terminal_edge and node not in fathers
        color, thickness = ((0, 0, 255), 2) if is_terminal else ((255, 255, 255), 1)
        cv2.line(canvas, node, father, color, thickness=thickness)
    return canvas
def make_binary_image_random(height, width, proba_black_pixel):
    """
    Make a (height, width) uint8 binary image with white pixels at random positions.

    Each pixel is drawn independently from np.random.rand and set to 255 when
    the draw exceeds proba_black_pixel, otherwise to 0.
    """
    is_white = np.random.rand(height, width) > proba_black_pixel
    return np.where(is_white, 255, 0).astype("uint8")
@ismael-elatifi
Copy link
Author

ismael-elatifi commented Mar 24, 2021

Usage example:

img = make_binary_image_random(height=200, width=400, proba_black_pixel=0.99)
img_orig_rgb = np.dstack(3*[img])

# Compute the minimum spanning tree connecting the white pixels, first with a
# small window and then with a larger one (more edges so fewer trees).
# A sample run printed 413 trees for window_size=15 and 1 for window_size=40.
for window_size in (15, 40):
    dict_tree, total_distance, n_trees = compute_minimum_spanning_tree(img, window_size=window_size)
    print("Number of spanning trees :", n_trees)

    # Show the nodes and the resulting tree side by side.
    img_tree_rgb = draw_tree(img, dict_tree)
    img2 = np.hstack((img_orig_rgb, img_tree_rgb))
    cv2.imshow("nodes (left) | tree (right) | window_size=" + str(window_size), img2)

# Block until a key is pressed so both windows stay visible.
cv2.waitKey()

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment