Created
December 22, 2022 20:53
-
-
Save banditkings/6c0345f00c10c2b5acd8f07ebd8e576e to your computer and use it in GitHub Desktop.
Custom Color Palette from an Animated GIF in Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Example: | |
```julia | |
# Include this file | |
include("palette.jl") | |
filepath = "data/image001.gif" | |
# Set a seed for reproducibility | |
Random.seed!(42) | |
# Get the custom color palette | |
top_colors = make_palette(filepath, 0.4) | |
# Create a summary plot of the results | |
fig = create_summary(filepath, top_colors) | |
``` | |
""" | |
using Images, Clustering, Random, Distances | |
using Images:load, channelview, HSV, colorview | |
using Statistics:mean | |
using DataFramesMeta:Not | |
using CairoMakie | |
""" | |
Helper function to reshape a channel-first image array into a 2D
(channels x pixels) matrix for our distance calculations
""" | |
function reshape_channelview(X)
    # Flatten a channel-first array into a 2D matrix.
    #
    # The first dimension of `X` (the channel axis) is kept and every remaining
    # dimension is collapsed into one, so a (C, H, W) array becomes (C, H*W) and
    # a (C, H, W, N) array becomes (C, H*W*N). Using `:` lets `reshape` infer the
    # second dimension, so any rank works, including an array that is already
    # 2D, instead of only ranks 3 and 4.
    #
    # Returns a reshaped array over the same elements as `X`, in the same
    # column-major order.
    return reshape(X, size(X, 1), :)
end
""" | |
Helper function to find the closest actual colors | |
to the cluster centers | |
""" | |
function get_actual_colors(img, clusters)
    # Replace each cluster center with the pixel of `img` that lies closest to it.
    #
    # img      : image whose channel view is flattened to a (channels, pixels) matrix
    # clusters : matrix holding one cluster center per column
    #
    # Returns a (channels, n_clusters) matrix of zeros-initialised floats where
    # column k is the pixel nearest (Euclidean distance) to cluster center k.

    # Flatten the image so every column is one pixel
    pixels = reshape_channelview(channelview(img))

    # Row k of `dists` holds the distance from center k to every pixel
    dists = pairwise(Euclidean(), clusters, pixels, dims=2)

    # One CartesianIndex per center: (center index, index of its closest pixel)
    closest = argmin(dists, dims=2)

    actual = zeros(size(pixels)[1], size(clusters)[2])
    for idx in closest
        # Split the CartesianIndex into the center (row) and the pixel (column)
        center, pixel = Tuple(idx)
        actual[:, center] = pixels[:, pixel]
    end
    return actual
end
""" | |
Run K-means and sort the resultant cluster centers based | |
on how common they are in the original image | |
""" | |
function get_centers(dat::Array; n_clusters::Int=30,
                     use_actual_colors::Bool=true)
    # Cluster the pixels of `dat` with K-means and return the cluster centers
    # ordered from the most to the least populated cluster.
    #
    # dat               : image array (e.g. the H x W x N array of RGB values
    #                     built in `make_palette`)
    # n_clusters        : number of K-means clusters
    # use_actual_colors : when true, every center is replaced by the closest
    #                     pixel that really occurs in `dat`
    #
    # Returns a matrix with one center (or actual color) per column.

    # K-means needs a real-valued matrix with one sample per column, so the
    # image is split into channels and flattened to (channels, pixels) first.
    # `float` turns fixed-point channel values into floating point so the
    # centers can be averaged.
    # NOTE(review): a plain numeric matrix is assumed to pass through
    # `channelview` unchanged - confirm against ImageCore.
    channels = channelview(dat)
    pixels = float.(reshape(channels, size(channels, 1), :))

    kmeans_model = kmeans(pixels, n_clusters; maxiter=200, display=:none)

    # Order the centers by how many pixels were assigned to each cluster
    order = sortperm(counts(kmeans_model), rev=true)
    sorted_centers = kmeans_model.centers[:, order]

    if use_actual_colors
        return get_actual_colors(dat, sorted_centers)
    end
    return sorted_centers
end
""" | |
Given a matrix of K cluster centers from K-means (one center per column,
sorted from most to least common) and a distance threshold, find a subset
of cluster centers that are visually distant from each other
Parameters
----------
sorted_centers : Matrix
    Matrix of cluster centers, one center per column, e.g. the output of
    `get_centers`
threshold : AbstractFloat
    Minimum Euclidean distance between colors. Any remaining center that is
    closer than this to an already selected color is dropped
Returns
-------
top_colors : Matrix
    A 3xM matrix where each row represents red, green, or blue, and each of the M
    columns represents a selected cluster center
""" | |
function filter_centers(sorted_centers::AbstractMatrix, threshold::Real)
    # Greedily pick a subset of the columns of `sorted_centers` that are far
    # apart from each other.
    #
    # sorted_centers : matrix with one color per column; column 1 is the seed
    # threshold      : candidates whose Euclidean distance to any already
    #                  selected color is below this value are discarded
    #
    # Returns a matrix whose columns are the selected colors, in selection order.
    n_centers = size(sorted_centers, 2)

    # Nothing to select from: return an empty matrix with the same row count
    n_centers == 0 && return sorted_centers[:, Int[]]

    # All column-to-column distances, computed once and reused below
    distmat = pairwise(Euclidean(), sorted_centers, dims=2)

    # Start with the first column and the color most different from it. When
    # every color equals the first one, `argmax` returns 1 and only the seed is
    # kept, so the same column is never selected twice.
    most_distant = argmax(view(distmat, 1, :))
    selected = most_distant == 1 ? [1] : [1, most_distant]
    candidates = setdiff(1:n_centers, selected)

    while !isempty(candidates)
        # Drop every candidate that is too close to any selected color
        filter!(j -> !(minimum(view(distmat, selected, j)) < threshold), candidates)
        isempty(candidates) && break

        # Add the candidate with the highest average distance to the selection
        avg_distance = [mean(view(distmat, selected, j)) for j in candidates]
        best = argmax(avg_distance)
        push!(selected, candidates[best])
        deleteat!(candidates, best)
    end
    return sorted_centers[:, selected]
end
"""Run the pipeline""" | |
function make_palette(filepath::String, threshold::AbstractFloat;
                      n_clusters::Int=30,
                      use_actual_colors::Bool=true)
    # End-to-end pipeline: load the image, cluster its colors and keep only
    # the cluster centers that pass the distance filter.
    #
    # A gif loads as an (H, W, N) array with N frames where every element is a
    # color (e.g. 266x480x94). Broadcasting `RGB` coerces RGBA or any other
    # color format to plain RGB.
    pixels = RGB.(load(filepath))

    ranked_centers = get_centers(pixels;
                                 n_clusters=n_clusters,
                                 use_actual_colors=use_actual_colors)

    return filter_centers(ranked_centers, threshold)
end
""" | |
Create a single plot that shows the original image, | |
the final color palette with hex codes, and | |
an example plot | |
""" | |
function create_summary(filepath::String, top_colors::Matrix)
    # Build a three-row figure: a frame of the source image, the palette with
    # hex labels, and a sample plot styled with the palette.
    #
    # filepath   : path of the image the palette was built from
    # top_colors : matrix with one RGB color per column
    #
    # Returns the Makie `Figure`. Throws an ArgumentError for an empty palette.
    img = load(filepath)
    f = Figure(resolution=(800,600))

    # Row 1: Show a snapshot of the original image
    ax1 = CairoMakie.Axis(f[1,1], aspect=DataAspect())
    hidedecorations!(ax1)
    # A 3D image is an animated gif: display one randomly chosen frame
    frame = ndims(img) == 3 ? img[:, :, rand(1:size(img, 3))] : img
    image!(ax1, rotr90(frame))

    # Row 2: Show the color palette with hex code labels
    ax2 = CairoMakie.Axis(f[2,1], aspect=DataAspect())
    hidedecorations!(ax2)
    colors = colorview(RGB, top_colors)
    isempty(colors) &&
        throw(ArgumentError("top_colors must contain at least one color"))
    # Every color becomes a dim x dim square; the squares sit side by side
    dim = 50
    mat = reduce(hcat, [fill(color, dim, dim) for color in colors])
    image!(ax2, rotr90(mat))
    # Hex code (with a leading '#') centered underneath each square
    for (i, color) in enumerate(colors)
        text!(ax2, "#" * hex(color),
              position=Point2f((dim*i)-(dim/2),-9),
              align=(:center, :bottom), fontsize=16)
    end

    # Row 3: Make a sample plot with some of the colors.
    # The palette may hold fewer than five colors, so the decoration color and
    # the line colors are chosen from whatever is available instead of
    # indexing positions 2 to 5 unconditionally.
    accent = colors[min(2, length(colors))]
    line_colors = length(colors) >= 3 ? colors[3:end] : [accent]
    ax3 = CairoMakie.Axis(f[3,1]; backgroundcolor=colors[1],
                          topspinevisible = false,
                          rightspinevisible = false,
                          xgridcolor = accent,
                          ygridcolor = accent,
                          bottomspinecolor = accent,
                          leftspinecolor = accent,
                          ytickcolor = accent,
                          xtickcolor = accent,
                          title = "Sample Plot")
    # Draw all random series before plotting, keeping the RNG call order stable
    samples = [rand(10) for _ in 1:3]
    for (k, sample) in enumerate(samples)
        scatterlines!(ax3, sample; linewidth=3,
                      color=line_colors[mod1(k, length(line_colors))])
    end
    return f
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Usage: