Skip to content

Instantly share code, notes, and snippets.

@jsundram
Last active February 18, 2021 19:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jsundram/e78bc63cb131fb66cb7c811bdd9a3e0e to your computer and use it in GitHub Desktop.
Save jsundram/e78bc63cb131fb66cb7c811bdd9a3e0e to your computer and use it in GitHub Desktop.
SuperPixel Segmentation in Julia
# src: https://github.com/johnnychen94/SuperPixels.jl/blob/analyze/src/analyze.jl
"""
    SLIC(img::AbstractArray{<:Colorant,2}; kwargs...)

Convert `img` to `Lab`, run `_slic` on the converted image (forwarding `kwargs`),
and wrap the resulting label matrix together with the original `img` in a
`SegmentedImage`.
"""
function SLIC(img::AbstractArray{<:Colorant, 2}; kwargs...)
    lab_img = Lab.(img)
    label_map = _slic(lab_img; kwargs...)
    return SegmentedImage(img, label_map)
end
# Numeric (intensity) images: interpret the values as `Gray` and convert to `Lab`,
# since the `_slic` method in this file only accepts `Lab` arrays (passing the `Gray`
# array directly ends in a `MethodError`). Unlike the `Colorant` method, this one
# returns the raw label matrix produced by `_slic`, not a `SegmentedImage`.
SLIC(img::AbstractArray{<:Number, 2}; kwargs...) = _slic(Lab.(Gray.(img)); kwargs...)
"""
    _slic(img::AbstractArray{<:Lab,2}; n_segments=100, compactness=10.0,
          max_iter=10, enforce_connectivity=true) -> Matrix{Int}

Simple Linear Iterative Clustering (SLIC) algorithm segments image using
an `O(N)` complexity version of k-means clustering in Color-(x,y) space.
The implementation details are described in [1].

This is the worker behind `SLIC`, which is the usual entry point:

    img = Images.load("/path/to/my/image.jpg")
    seg = SLIC(img; n_segments=500, compactness=40.0, enforce_connectivity=true)

# Arguments
* `img` : a two-dimensional image in `Lab` colorspace

# Keywords
## `n_segments::Int`
The _approximate_ number of labels in the segmented output image. The default
value is `100`.

## `compactness::Float64`
Balances color proximity and space proximity. Higher values give more weight to
space proximity, making superpixel shapes more square/cubic. The default value
is `10.0`.

!!! tip
    We recommend exploring possible values on a log scale, e.g., `0.01`, `0.1`,
    `1`, `10`, `100`, before refining around a chosen value.

## `max_iter::Int`
Maximum number of iterations of k-means. The default value is `10`.

## `enforce_connectivity::Bool`
Whether the generated segments are connected or not. The default value is `true`.

# Returns
A matrix of `Int` labels with the same size as `img`. Labels start at `1`. With
`enforce_connectivity=true`, a pixel keeps label `0` only when its undersized
component had no already-relabelled neighbor to be merged into.

# Example
```julia
using Images
img = Images.load("./test.jpg");
seg = SLIC(img; n_segments=100, compactness=40, enforce_connectivity=true)
```

# References
[1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels, _EPFL Technical Report_ no. 149300, June 2010.
[2] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to State-of-the-art Superpixel Methods, _IEEE Transactions on Pattern Analysis and Machine Intelligence_, vol. 34, num. 11, p. 2274 – 2282, May 2012.
[3] EPFL (2018, Oct 24). SLIC Superpixels. Retrieved from https://ivrl.epfl.ch/research-2/research-current/research-superpixels/, Sep 29, 2019.
"""
function _slic(img::AbstractArray{<:Lab, 2}; n_segments::Integer=100,
               compactness::Real=10.0,
               max_iter::Integer=10,
               enforce_connectivity::Bool=true)
    height, width = size(img)
    # Grid step: one center per S-by-S cell gives roughly `n_segments` centers.
    S = ceil(Int, sqrt(length(img) / n_segments))
    spatial_weight = compactness / S
    raw_img = permutedims(channelview(img), (2, 3, 1)) # [x y c] shape

    # Initialize cluster centers with a grid of step S
    xs = range(1, step=S, stop=height)
    ys = range(1, step=S, stop=width)
    initial_centers = reshape(CartesianIndex.(Iterators.product(xs, ys)), :)
    n_centers = length(initial_centers) # _actual_ number of segments

    # xylab array of centers: columns 1:2 hold the position, 3:5 the color
    segments = zeros(Float32, (n_centers, 5))
    segments[:, 1] = map(c -> c[1], initial_centers)
    segments[:, 2] = map(c -> c[2], initial_centers)
    segments[:, 3:5] = raw_img[initial_centers, :]

    # Number of pixels in each segment
    n_segment_elems = zeros(Int, n_centers)
    # nearest_distance[x, y] is the distance between pixel (x, y) and its current
    # assigned center
    nearest_distance = fill(Inf, size(img))
    # nearest_segments[x, y] is the label of its assigned center, each center is
    # labeled from 1 to n_centers
    nearest_segments = zeros(Int, size(img))

    # Typed zero so the color accumulator does not change type inside the hot loop.
    color_zero = zero(promote_type(eltype(raw_img), eltype(segments)))

    for _ in 1:max_iter
        changed = false
        for k in 1:n_centers
            cx = segments[k, 1]
            cy = segments[k, 2]
            # Only pixels in a 2S window around the center are candidates.
            x_min = floor(Int, max(cx - 2S, 1))
            x_max = ceil(Int, min(cx + 2S, height))
            y_min = floor(Int, max(cy - 2S, 1))
            y_max = ceil(Int, min(cy + 2S, width))
            # Break distance computation into nested for-loop to reuse `dy2`
            for y in y_min:y_max
                dy2 = abs2(cy - y)
                for x in x_min:x_max
                    dx2 = abs2(cx - x)
                    dist_color = color_zero
                    for c in 3:5
                        dist_color += abs2(raw_img[x, y, c - 2] - segments[k, c])
                    end
                    dist_center = sqrt(dy2 + dx2) * spatial_weight + sqrt(dist_color)
                    if dist_center < nearest_distance[x, y]
                        nearest_distance[x, y] = dist_center
                        nearest_segments[x, y] = k
                        changed = true
                    end
                end
            end
        end
        changed || break

        # Recompute segment centers as the mean position and color of their pixels.
        # The whole table is reset: leaving the color columns untouched would add
        # the previous center color into the new sum and bias the mean.
        fill!(n_segment_elems, 0)
        fill!(segments, 0)
        for I in CartesianIndices(img)
            k = nearest_segments[I]
            n_segment_elems[k] += 1
            segments[k, 1] += I[1]
            segments[k, 2] += I[2]
            for c in 1:3
                segments[k, c + 2] += raw_img[I, c]
            end
        end
        # divide by number of elements per segment to obtain mean; empty segments
        # divide by 1 to avoid 0/0
        replace!(n_segment_elems, 0 => 1)
        segments ./= n_segment_elems # broadcast: (n_centers,) -> (n_centers, 5)
    end

    if enforce_connectivity
        nearest_segments = _slic_enforce_connectivity(nearest_segments, n_centers)
    end

    # The raw label matrix is returned instead of region means: computing means in
    # Lab would need `+`, `/` and `zero` for the Lab struct, and the means would be
    # in Lab rather than in the colorspace of the original image.
    return nearest_segments
end

# Relabel `labels` so that every output label is a 4-connected region, merging regions
# smaller than half the average segment size into a previously relabelled neighbor.
# Reference: https://github.com/scikit-image/scikit-image/blob/7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/segmentation/_slic.pyx#L240-L348
function _slic_enforce_connectivity(labels::AbstractMatrix{<:Integer}, n_segments::Integer)
    height, width = size(labels)
    segment_size = height * width / n_segments
    min_size = round(Int, 0.5 * segment_size)
    max_size = round(Int, 3.0 * segment_size)

    # 4-neighborhood offsets
    step_y = (1, -1, 0, 0)
    step_x = (0, 0, 1, -1)

    start_label = 1
    mask_label = start_label - 1 # indicates the label of this pixel has not been assigned
    relabelled = fill(mask_label, height, width)
    current_new_label = start_label

    # BFS queue of the pixels collected for the current region (row = pixel, columns = y, x)
    coord_list = fill(0, max_size, 2)

    for x in 1:width
        for y in 1:height
            labels[y, x] == mask_label && continue
            relabelled[y, x] > mask_label && continue

            adjacent = 0
            label = labels[y, x]
            relabelled[y, x] = current_new_label
            current_segment_size = 1
            bfs_visited = 0
            coord_list[1, 1] = y
            coord_list[1, 2] = x

            # Perform BFS to find size of superpixel with the same label. The region
            # is capped at `max_size` pixels so the write into `coord_list` below can
            # never run past its last row; the remainder of an oversized region is
            # picked up later as a new region.
            while bfs_visited < current_segment_size < max_size
                for i in 1:4
                    yy = coord_list[bfs_visited + 1, 1] + step_y[i]
                    xx = coord_list[bfs_visited + 1, 2] + step_x[i]
                    (1 <= yy <= height && 1 <= xx <= width) || continue
                    if labels[yy, xx] == label && relabelled[yy, xx] == mask_label
                        relabelled[yy, xx] = current_new_label
                        current_segment_size += 1
                        coord_list[current_segment_size, 1] = yy
                        coord_list[current_segment_size, 2] = xx
                        current_segment_size >= max_size && break
                    elseif relabelled[yy, xx] > mask_label &&
                           relabelled[yy, xx] != current_new_label
                        # remember a neighboring region to merge into if this one is small
                        adjacent = relabelled[yy, xx]
                    end
                end
                bfs_visited += 1
            end

            # Merge the superpixel with its neighbor if it is too small
            if current_segment_size < min_size
                for i in 1:current_segment_size
                    relabelled[coord_list[i, 1], coord_list[i, 2]] = adjacent
                end
            else
                current_new_label += 1
            end
        end
    end
    return relabelled
end
"""
    SegmentedImage(img, labelled_px)

Takes the output of SLIC (`labelled_px`, a matrix of integer labels) and the source
image `img` it was run on, and returns a [`SegmentedImage`](@ref).

For every label that occurs in `labelled_px`, the mean of the pixels of `img` carrying
that label and the number of such pixels are recorded. Labels are listed in ascending
order; labels that do not occur in `labelled_px` are not listed.
"""
function SegmentedImage(img, labelled_px::Array{Int,2})
    # Group pixel positions by label in one pass over the label matrix instead of
    # rescanning the whole matrix once per label. Positions are collected in the
    # matrix's iteration order, i.e. the order `findall` would report them.
    members = Dict{Int, Vector{CartesianIndex{2}}}()
    for I in CartesianIndices(labelled_px)
        push!(get!(() -> CartesianIndex{2}[], members, labelled_px[I]), I)
    end

    # Only labels that are actually present: an absent label has no pixels to average.
    labels = sort!(collect(keys(members)))
    means, px_counts = Dict{Int, Colorant}(), Dict{Int, Int}()
    for i in labels
        ix = members[i]
        means[i] = mean(img[ix])
        px_counts[i] = length(ix)
    end
    return SegmentedImage(labelled_px, labels, means, px_counts)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment