Skip to content

Instantly share code, notes, and snippets.

@royshil
Last active August 11, 2020 10:31
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save royshil/52bd3a0e21e9e7f6a8747bde52a3f86b to your computer and use it in GitHub Desktop.
Save royshil/52bd3a0e21e9e7f6a8747bde52a3f86b to your computer and use it in GitHub Desktop.
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.data import astronaut
from skimage.util import img_as_float
import maxflow
from scipy.spatial import Delaunay
# Calculate the SLIC superpixels, their histograms and neighbors
def superpixels_histograms_neighbors(img):
    """Compute SLIC superpixels, their centers, H-S histograms and neighbors.

    Args:
        img: Color image array. It is cast to float32 before the HSV
            conversion; the histogram ranges below assume H in [0, 360]
            and S in [0, 1].

    Returns:
        A tuple ``(centers, colors_hists, segments, neighbors)``:
        ``centers`` holds the mean (row, col) of each superpixel,
        ``colors_hists`` holds one flattened 20x20 H-S histogram per
        superpixel (float32), ``segments`` is the per-pixel label image with
        contiguous 0-based labels, and ``neighbors`` is the
        ``vertex_neighbor_vertices`` pair ``(indptr, indices)`` of the
        Delaunay triangulation of the centers.
    """
    # SLIC
    raw_segments = slic(img, n_segments=500, compactness=20)

    # Relabel to contiguous 0-based IDs so that a label value is also the row
    # index into `centers` / `colors_hists`. SLIC labels are not guaranteed to
    # start at 0 (newer scikit-image releases start at 1), and callers index
    # the histograms directly with label values.
    unique_labels, inverse = np.unique(raw_segments, return_inverse=True)
    segments = inverse.reshape(raw_segments.shape)
    segments_ids = np.arange(len(unique_labels))

    # Centers: mean pixel coordinate of every superpixel.
    centers = np.array(
        [np.mean(np.nonzero(segments == i), axis=1) for i in segments_ids]
    )

    # H-S histograms for all superpixels.
    # NOTE(review): COLOR_BGR2HSV assumes BGR channel order; if `img` comes
    # from skimage it is presumably RGB -- confirm whether COLOR_RGB2HSV was
    # intended. Left unchanged to keep results identical.
    hsv = cv2.cvtColor(img.astype('float32'), cv2.COLOR_BGR2HSV)
    bins = [20, 20]  # H = S = 20
    ranges = [0, 360, 0, 1]  # H: [0, 360], S: [0, 1]
    colors_hists = np.float32([
        cv2.calcHist([hsv], [0, 1], np.uint8(segments == i), bins, ranges).flatten()
        for i in segments_ids
    ])

    # Neighbors via Delaunay tessellation of the superpixel centers.
    tri = Delaunay(centers)
    return (centers, colors_hists, segments, tri.vertex_neighbor_vertices)
# Get superpixels IDs for FG and BG from marking
def find_superpixels_under_marking(marking, superpixels):
    """Return the superpixel IDs covered by foreground and background marks.

    A pixel counts as a foreground mark when channel 0 of ``marking`` differs
    from 255, and as a background mark when channel 2 differs from 255.

    Args:
        marking: 3-channel marking image, indexed as ``marking[:, :, c]``.
        superpixels: Per-pixel superpixel label image, indexed with the
            boolean masks derived from ``marking``.

    Returns:
        A tuple ``(fg_segments, bg_segments)`` of sorted unique label arrays.
    """
    channel_masks = (marking[:, :, 0] != 255, marking[:, :, 2] != 255)
    fg_segments, bg_segments = (
        np.unique(superpixels[mask]) for mask in channel_masks
    )
    return (fg_segments, bg_segments)
# Sum up the histograms for a given selection of superpixel IDs, normalize
def cumulative_histogram_for_superpixels(ids, histograms):
    """Sum the histograms of the selected superpixels and normalize to sum 1.

    Args:
        ids: Indices (rows of ``histograms``) of the selected superpixels.
        histograms: Array with one histogram per row.

    Returns:
        The element-wise sum of the selected rows divided by its total. When
        the selection is empty or the total is zero, an all-zero histogram is
        returned instead of NaNs from a 0/0 division.
    """
    h = np.sum(histograms[ids], axis=0)
    total = h.sum()
    if total == 0:
        # Nothing to normalize; keep the dtype a true division would produce.
        return np.zeros(h.shape, dtype=np.result_type(h.dtype, np.float32))
    return h / total
# Get a bool mask of the pixels for a given selection of superpixel IDs
def pixels_for_segment_selection(superpixels_labels, selection):
    """Return a bool mask of the pixels whose label is in ``selection``.

    Args:
        superpixels_labels: Per-pixel superpixel label array.
        selection: Label values to select (any array-like accepted by
            ``np.isin``).

    Returns:
        Boolean array with the same shape as ``superpixels_labels``.
    """
    # np.isin already yields a boolean array; no np.where round-trip needed.
    return np.isin(superpixels_labels, selection)
# Get a normalized version of the given histograms (divide by sum)
def normalize_histograms(histograms):
    """Return a float32 copy of ``histograms`` with every row divided by its sum.

    Args:
        histograms: 2-D array-like with one histogram per row.

    Returns:
        float32 array of the same shape. Rows whose sum is zero are returned
        as all zeros instead of NaNs from a 0/0 division.
    """
    hists = np.asarray(histograms)
    if hists.size == 0:
        return np.float32(hists)
    totals = hists.sum(axis=1, keepdims=True)
    normalized = np.zeros(hists.shape, dtype=np.float32)
    # `where` skips the zero-sum rows, which keep their initial zeros.
    np.divide(hists, totals, out=normalized, where=totals != 0)
    return normalized
# Perform graph cut using superpixels histograms
def do_graph_cut(fgbg_hists, fgbg_superpixels, norm_hists, neighbors):
    """Segment superpixels into FG/BG with a graph cut over their histograms.

    Args:
        fgbg_hists: Pair ``(fg_hist, bg_hist)`` of normalized cumulative
            histograms of the marked foreground / background superpixels.
        fgbg_superpixels: Pair ``(fg_ids, bg_ids)`` of superpixel indices
            that were explicitly marked. If an index is in both, the
            foreground marking wins.
        norm_hists: Array with one normalized histogram per superpixel; row
            ``i`` belongs to graph node ``i``.
        neighbors: Pair ``(indptr, indices)`` in the layout of
            ``scipy.spatial.Delaunay.vertex_neighbor_vertices``.

    Returns:
        The result of ``g.get_grid_segments(nodes)``: one boolean per
        superpixel giving the side of the cut it ended up on.
    """
    num_nodes = norm_hists.shape[0]
    # Create a graph of N nodes, and estimate of 5 edges per node
    g = maxflow.Graph[float](num_nodes, num_nodes * 5)
    # Add N nodes
    nodes = g.add_nodes(num_nodes)

    hist_comp_alg = cv2.HISTCMP_KL_DIV
    smoothness_offset = 20  # neighbor capacity = offset - histogram distance
    hard_cost = 1000        # terminal capacity for explicitly marked nodes

    # Smoothness term: cost between neighbors.
    # NOTE(review): a symmetric adjacency lists every pair from both
    # endpoints, so each pair receives its capacities twice; kept as in the
    # original weighting.
    indptr, indices = neighbors
    for i in range(min(len(indptr) - 1, num_nodes)):
        hi = norm_hists[i]  # histogram for center
        for n in indices[indptr[i]:indptr[i + 1]]:
            # Valid node indices are 0 .. num_nodes - 1.
            if n < 0 or n >= num_nodes:
                continue
            hn = norm_hists[n]  # histogram for neighbor
            # Two directed capacities (forwards and backwards) based on
            # histogram matching. The divergence is unbounded, so clamp at 0
            # to never hand a negative capacity to the solver.
            forward = max(0.0, smoothness_offset - cv2.compareHist(hi, hn, hist_comp_alg))
            backward = max(0.0, smoothness_offset - cv2.compareHist(hn, hi, hist_comp_alg))
            g.add_edge(nodes[i], nodes[n], forward, backward)

    # Match term: cost to FG/BG. Sets give O(1) membership tests per node.
    fg_ids = set(np.asarray(fgbg_superpixels[0]).ravel().tolist())
    bg_ids = set(np.asarray(fgbg_superpixels[1]).ravel().tolist())
    for i, h in enumerate(norm_hists):
        if i in fg_ids:
            g.add_tedge(nodes[i], 0, hard_cost)  # FG - set high cost to BG
        elif i in bg_ids:
            g.add_tedge(nodes[i], hard_cost, 0)  # BG - set high cost to FG
        else:
            g.add_tedge(nodes[i],
                        cv2.compareHist(fgbg_hists[0], h, hist_comp_alg),
                        cv2.compareHist(fgbg_hists[1], h, hist_comp_alg))

    g.maxflow()
    return g.get_grid_segments(nodes)
if __name__ == '__main__':
    # Demo: segment the (half-resolution) astronaut image using user markings.
    img = img_as_float(astronaut()[::2, ::2])

    # cv2.imread returns None instead of raising when the file cannot be
    # read, so fail early with a clear message.
    img_marking = cv2.imread("astronaut_marking.png")
    if img_marking is None:
        raise FileNotFoundError("could not read marking image 'astronaut_marking.png'")
    # The marking is used as a per-pixel mask over the image and its labels.
    if img_marking.shape[:2] != img.shape[:2]:
        raise ValueError("marking image size %s does not match image size %s"
                         % (img_marking.shape[:2], img.shape[:2]))

    centers, colors_hists, segments, neighbors = superpixels_histograms_neighbors(img)
    fg_segments, bg_segments = find_superpixels_under_marking(img_marking, segments)

    # get cumulative BG/FG histograms, before normalization
    fg_cumulative_hist = cumulative_histogram_for_superpixels(fg_segments, colors_hists)
    bg_cumulative_hist = cumulative_histogram_for_superpixels(bg_segments, colors_hists)

    norm_hists = normalize_histograms(colors_hists)

    graph_cut = do_graph_cut((fg_cumulative_hist, bg_cumulative_hist),
                             (fg_segments, bg_segments),
                             norm_hists,
                             neighbors)

    # Right panel: the resulting segmentation mask (also written to disk).
    plt.subplot(1, 2, 2)
    plt.xticks([])
    plt.yticks([])
    plt.title('segmentation')
    segmask = pixels_for_segment_selection(segments, np.nonzero(graph_cut))
    cv2.imwrite("output_segmentation.png", np.uint8(segmask * 255))
    plt.imshow(segmask)

    # Left panel: superpixel boundaries with the FG/BG markings painted on.
    plt.subplot(1, 2, 1)
    plt.xticks([])
    plt.yticks([])
    img = mark_boundaries(img, segments)
    img[img_marking[:, :, 0] != 255] = (1, 0, 0)
    img[img_marking[:, :, 2] != 255] = (0, 0, 1)
    plt.imshow(img)
    plt.title("SLIC + markings")

    plt.savefig("segmentation.png", bbox_inches='tight', dpi=96)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment