Skip to content

Instantly share code, notes, and snippets.

@banditkings
Created December 22, 2022 20:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save banditkings/6c0345f00c10c2b5acd8f07ebd8e576e to your computer and use it in GitHub Desktop.
Save banditkings/6c0345f00c10c2b5acd8f07ebd8e576e to your computer and use it in GitHub Desktop.
Custom Color Palette from an Animated GIF in Julia
"""
Example:
```julia
# Include this file
include("palette.jl")
filepath = "data/image001.gif"
# Set a seed for reproducibility
Random.seed!(42)
# Get the custom color palette
top_colors = make_palette(filepath, 0.4)
# Create a summary plot of the results
fig = create_summary(filepath, top_colors)
```
"""
using Images, Clustering, Random, Distances
using Images:load, channelview, HSV, colorview
using Statistics:mean
using DataFramesMeta:Not
using CairoMakie
"""
    reshape_channelview(X)

Flatten a channel-first array into a `(C, n)` matrix with one column per
sample, as needed by the distance and K-means computations in this file.

`X` must have the channel dimension first, e.g. `(C, H, W)` for a still image
or `(C, H, W, N)` for an `N`-frame animation (the layout `channelview` returns
for color arrays). Any rank is accepted: the first dimension is kept and all
trailing dimensions are merged in column-major order, so `n = prod(size(X)[2:end])`.

The result is a `reshape` of `X` and therefore shares its memory.
"""
function reshape_channelview(X)
    n_channels = size(X, 1)
    # Merge every trailing dimension into one "sample" axis. For 3-D/4-D input
    # this is exactly the (C, H*W) / (C, H*W*N) reshape; other ranks now work too.
    n_samples = prod(size(X)[2:end])
    return reshape(X, n_channels, n_samples)
end
"""
    get_actual_colors(img, clusters)

For each cluster center, find the pixel of `img` that is closest to it
(Euclidean distance) and return that pixel's color. This makes the palette
consist of colors that actually occur in the image rather than K-means averages.

# Arguments
- `img`: the image/animation; `channelview(img)` must be channel-first,
  e.g. `(3, H, W)` or `(3, H, W, N)` for RGB.
- `clusters`: a `(C, K)` matrix with one cluster center per column; `C` must
  match the number of channels of `img`.

# Returns
A `(C, K)` `Matrix{Float64}` whose column `k` is the pixel nearest to
`clusters[:, k]`. On ties the first such pixel (column-major order) wins.
"""
function get_actual_colors(img, clusters)
    # One column per pixel (across all frames) so distances are column-wise.
    X = reshape_channelview(channelview(img))
    n_channels = size(X, 1)
    n_clusters = size(clusters, 2)
    results = zeros(n_channels, n_clusters)
    for k in 1:n_clusters
        center = view(clusters, :, k)
        # Distances from this one center to every pixel. Going one center at a
        # time keeps memory at O(pixels) instead of materializing the full
        # K x pixels distance matrix (gigabytes for a long animated GIF).
        dists = [euclidean(view(X, :, j), center) for j in axes(X, 2)]
        results[:, k] = X[:, argmin(dists)]
    end
    return results
end
"""
    get_centers(dat::Array; n_clusters=30, use_actual_colors=true)

Run K-means on the pixels of the color image/animation `dat` and return the
cluster centers sorted by how many pixels were assigned to each (most common
first).

# Arguments
- `dat`: an array of colors, e.g. the `(H, W, N)` array of `RGB` values built
  by `make_palette`.

# Keywords
- `n_clusters`: number of K-means clusters.
- `use_actual_colors`: if `true`, replace each center by the nearest pixel
  color that actually occurs in `dat` (via `get_actual_colors`).

# Returns
A `(C, n_clusters)` matrix (`C` = number of color channels, 3 for RGB) whose
columns are the sorted cluster centers.
"""
function get_centers(dat::Array; n_clusters::Int=30,
                     use_actual_colors::Bool=true)
    # `kmeans` only accepts a real-valued (features x samples) matrix, not an
    # array of color structs, so flatten to (channels, pixels) and convert the
    # (typically fixed-point) channel values to floating point first.
    X = float.(reshape_channelview(channelview(dat)))
    kmeans_model = kmeans(X, n_clusters; maxiter=200, display=:none)
    centers = kmeans_model.centers
    # Number of pixels assigned to each cluster
    value_counts = counts(kmeans_model)
    # Most populated clusters first
    sorted_centers = centers[:, sortperm(value_counts, rev=true)]
    if use_actual_colors
        # get_actual_colors works on the original color array, not on X.
        sorted_centers = get_actual_colors(dat, sorted_centers)
    end
    return sorted_centers
end
"""
    filter_centers(sorted_centers::Matrix, threshold::AbstractFloat)

Greedily select a subset of cluster centers that are visually distinct from
each other.

The first column of `sorted_centers` (the most common color when the matrix
comes from `get_centers`) is always kept. The center farthest from it is also
kept, regardless of `threshold`. Then, repeatedly:
1. every remaining center whose Euclidean distance to *any* already-selected
   color is below `threshold` is discarded;
2. of the survivors, the one with the largest mean distance to the selected
   colors is added.

This stops when no candidates remain.

# Arguments
- `sorted_centers`: `(C, K)` matrix with one cluster center per column,
  ordered by cluster size (most common first), e.g. the output of `get_centers`.
- `threshold`: minimum Euclidean distance a candidate must have from every
  already-selected color, in the same units as the matrix entries.

# Returns
- `top_colors`: a `(C, M)` matrix with `M <= K`. For RGB input each row is the
  red, green, or blue channel and each column is one palette color, in the
  order the colors were selected.
"""
function filter_centers(sorted_centers::Matrix, threshold::AbstractFloat)
    n = size(sorted_centers, 2)
    # Zero or one candidate: nothing to choose between.
    n <= 1 && return copy(sorted_centers)

    distmat = pairwise(Euclidean(), sorted_centers, dims=2)
    most_distant_color = argmax(distmat[1, :])
    # Seed with the modal color and the color farthest from it. If every center
    # coincides with the modal one, argmax returns 1; seed with it alone so the
    # palette does not contain a duplicated column.
    seeds = most_distant_color == 1 ? [1] : [1, most_distant_color]
    top_colors = sorted_centers[:, seeds]
    remaining = sorted_centers[:, setdiff(1:n, seeds)]

    while size(remaining, 2) > 0
        dists = pairwise(Euclidean(), top_colors, remaining, dims=2)
        # Discard candidates that are too close to any selected color.
        keep = vec(minimum(dists, dims=1)) .>= threshold
        remaining = remaining[:, keep]
        size(remaining, 2) == 0 && break
        # top_colors has not changed, so the distances to the survivors are
        # just the kept columns; no need to recompute them.
        avg_distance = vec(mean(dists[:, keep], dims=1))
        next_index = argmax(avg_distance)
        top_colors = hcat(top_colors, remaining[:, next_index])
        remaining = remaining[:, axes(remaining, 2) .!= next_index]
    end
    return top_colors
end
"""
    make_palette(filepath, threshold; n_clusters=30, use_actual_colors=true)

Build a color palette from the image or animated GIF at `filepath`.

Pipeline:
1. load the file and coerce every pixel to `RGB` (e.g. from `RGBA`);
2. cluster the pixels with K-means and sort the centers by popularity
   (`get_centers`);
3. greedily keep a subset of centers that are mutually at least `threshold`
   apart (`filter_centers`).

# Arguments
- `filepath`: path to the image/GIF.
- `threshold`: minimum Euclidean distance between palette colors in RGB
  channel space. For the usual normalized color types the channels lie in
  `[0, 1]`, so distances range up to `sqrt(3)`. Any real number is accepted.

# Keywords
- `n_clusters`: number of K-means clusters to start from.
- `use_actual_colors`: snap each center to the nearest color that actually
  occurs in the image.

# Returns
A `(3, K)` matrix with one palette color (red, green, blue rows) per column.
"""
function make_palette(filepath::String, threshold::Real;
                      n_clusters::Int=30,
                      use_actual_colors::Bool=true)
    # For an animated GIF, `load` returns an (H, W, N) array of colors, N being
    # the number of frames (e.g. 266x480x94).
    dat = load(filepath)
    # Coerce RGBA or other color types to plain RGB (three channels per pixel).
    dat = RGB.(dat)
    sorted_centers = get_centers(dat;
                                 n_clusters=n_clusters,
                                 use_actual_colors=use_actual_colors)
    # filter_centers is declared for AbstractFloat; convert so that integer or
    # rational thresholds (e.g. `1`) work as well.
    top_colors = filter_centers(sorted_centers, float(threshold))
    return top_colors
end
"""
    create_summary(filepath, top_colors)

Create a three-row summary figure:
1. a frame of the original image (a random frame for an animated GIF);
2. the palette as a strip of 50x50 swatches labelled with hex codes;
3. a sample line plot styled with palette colors.

# Arguments
- `filepath`: path to the image/GIF the palette was built from.
- `top_colors`: `(3, K)` matrix of RGB channel values with one palette color
  per column (e.g. the output of `make_palette`); `K` must be at least 1.

# Returns
The Makie `Figure`.

If the palette has fewer than five colors, the sample plot reuses palette
colors cyclically instead of failing with a `BoundsError`.

# Throws
- `ArgumentError` if `top_colors` has no columns.
"""
function create_summary(filepath::String, top_colors::Matrix)
    colors = colorview(RGB, top_colors)
    isempty(colors) &&
        throw(ArgumentError("top_colors must contain at least one color"))
    ncolors = length(colors)
    # The sample plot uses up to five color roles (background, accent, three
    # lines); wrap around when the palette is shorter than that.
    palette_color(k) = colors[mod1(k, ncolors)]

    img = load(filepath)
    f = Figure(resolution=(800,600))

    # Row 1: a snapshot of the original image
    ax1 = CairoMakie.Axis(f[1,1], aspect=DataAspect())
    hidedecorations!(ax1)
    # An animated gif loads as (H, W, N); show one randomly chosen frame.
    if ndims(img) == 3
        frame = img[:, :, rand(1:size(img, 3))]
    else
        frame = img
    end
    image!(ax1, rotr90(frame))

    # Row 2: the color palette with hex code labels
    ax2 = CairoMakie.Axis(f[2,1], aspect=DataAspect())
    hidedecorations!(ax2)
    dim = 50
    # Each color becomes a dim x dim square: repeating the 1 x K row of colors
    # `dim` times along both axes yields a dim x (dim*K) strip.
    mat = repeat(reshape(colors, 1, :); inner=(dim, dim))
    image!(ax2, rotr90(mat))
    # Hex codes with a leading '#', centered underneath each swatch
    hexcolors = ["#" * hex(c) for c in colors]
    for (i, label) in enumerate(hexcolors)
        text!(ax2, label,
              position=Point2f((dim*i)-(dim/2),-9),
              align=(:center, :bottom), fontsize=16)
    end

    # Row 3: a sample plot using palette colors
    accent = palette_color(2)
    ax3 = CairoMakie.Axis(f[3,1]; backgroundcolor=palette_color(1),
                          topspinevisible = false,
                          rightspinevisible = false,
                          xgridcolor = accent,
                          ygridcolor = accent,
                          bottomspinecolor = accent,
                          leftspinecolor = accent,
                          ytickcolor = accent,
                          xtickcolor = accent,
                          title = "Sample Plot")
    x = rand(10)
    y = rand(10)
    z = rand(10)
    scatterlines!(ax3, x; linewidth=3, color=palette_color(3))
    scatterlines!(ax3, y; linewidth=3, color=palette_color(4))
    scatterlines!(ax3, z; linewidth=3, color=palette_color(5))
    return f
end
@banditkings
Copy link
Author

Usage:

# Include this file
include("palette.jl")
filepath = "data/image001.gif"
# Set a seed for reproducibility
Random.seed!(42)
# Get the custom color palette
top_colors = make_palette(filepath, 0.4)
# Create a summary plot of the results
fig = create_summary(filepath, top_colors)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment