Skip to content

Instantly share code, notes, and snippets.

@amueller
Last active April 27, 2016 15:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amueller/d5d3de7630a25ae61cff7af4b29b5970 to your computer and use it in GitHub Desktop.
Save amueller/d5d3de7630a25ae61cff7af4b29b5970 to your computer and use it in GitHub Desktop.
extract colormap from an image
import numpy as np
import scipy.sparse as sp
from colorspacious import cspace_convert
from scipy.sparse.csgraph import minimum_spanning_tree
from sklearn.metrics import euclidean_distances
def _node_degrees(connectivity):
    """Return the degree of every node of a symmetric 0/1 adjacency matrix."""
    # Sparse ``sum`` may return a 2-D matrix; flatten it to a plain 1-D array.
    return np.asarray(connectivity.sum(axis=0)).ravel()


def _prune_to_chain(connectivity):
    """Strip leaves from a tree, round after round, until a chain remains.

    Each round removes every node whose degree is at most one. Pruning stops
    after the first round that removes no more than two nodes, i.e. once the
    graph handed to that round was already a simple chain (or was empty).

    Returns
    -------
    central_nodes : ndarray of int
        Indices (into the original node numbering) of the surviving nodes.
    connectivity : sparse matrix
        Adjacency matrix restricted to the surviving nodes.
    """
    central_nodes = np.arange(connectivity.shape[0])
    while True:
        n_before = connectivity.shape[0]
        not_leaf, = np.where(_node_degrees(connectivity) > 1)
        central_nodes = central_nodes[not_leaf]
        connectivity = connectivity[not_leaf, :][:, not_leaf]
        if connectivity.shape[0] >= n_before - 2:
            return central_nodes, connectivity


def is_heatmap(image, threshold=1, random_seed=None, n_samples=2000,
               min_chain_length=100):
    """Guess whether an image is a heatmap and extract its colormap.

    A random subsample of the pixels is converted to CAM02-UCS, a minimum
    spanning tree is built over the subsample, and the tree is pruned down to
    its central chain. The image is considered a heatmap when the sampled
    colors lie, on average, close to that chain.

    Parameters
    ----------
    image : array-like, shape (H, W), (H, W, 3) or (H, W, 4)
        Pixel data passed to colorspacious as "sRGB255". A 2-D (greyscale)
        image is replicated to three channels; an alpha channel is dropped.
    threshold : float, default=1
        Maximum mean CAM02-UCS distance from a sampled color to its nearest
        chain color for the image to count as a heatmap.
    random_seed : int or None, default=None
        Seed for the pixel subsampling.
    n_samples : int, default=2000
        Number of pixels to subsample; clamped to the number of pixels
        available.
    min_chain_length : int, default=100
        Chains shorter than this are treated as "not a heatmap" regardless of
        the distance (typical for illustrations / vector graphics).

    Returns
    -------
    decision : bool
        True if the image looks like a heatmap.
    ordered_chain : ndarray of int
        Indices into ``subsample`` / ``subsample_rgb`` of the chain colors,
        ordered from one end of the chain to the other. May hold fewer than
        two entries when no chain could be extracted.
    subsample : ndarray, shape (n, 3)
        Subsampled colors in CAM02-UCS.
    subsample_rgb : ndarray, shape (n, 3)
        The same subsampled colors in the input's RGB representation.

    Raises
    ------
    ValueError
        If ``image`` has an unsupported shape, contains no pixels, or
        ``n_samples`` is smaller than one.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        # greyscale: replicate the single channel so it can be treated as RGB
        image = np.dstack([image] * 3)
    elif image.ndim == 3 and image.shape[2] == 4:
        # drop alpha channel
        image = image[:, :, :3]
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            "expected an image of shape (H, W), (H, W, 3) or (H, W, 4), "
            "got shape %r" % (image.shape,))
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1, got %r" % (n_samples,))

    flat = image.reshape(-1, 3)
    if len(flat) == 0:
        raise ValueError("image contains no pixels")

    # subsample for speed; never ask for more pixels than the image has,
    # because sampling without replacement would raise
    state = np.random.RandomState(random_seed)
    indices = state.choice(len(flat), size=min(n_samples, len(flat)),
                           replace=False)
    subsample_rgb = flat[indices]
    # convert to cam space
    subsample = cspace_convert(subsample_rgb, "sRGB255", "CAM02-UCS")

    # compute MST
    # NOTE: csgraph treats zero entries of a dense matrix as missing edges,
    # so identical colors are never linked directly to each other.
    distances = euclidean_distances(subsample)
    mst = sp.csgraph.minimum_spanning_tree(distances)
    # symmetric 0/1 adjacency matrix of the (undirected) tree
    connectivity = ((mst + mst.T) != 0).astype(int)

    # prune leaves until only a chain is left; central nodes contain the chain
    central_nodes, connectivity = _prune_to_chain(connectivity)
    chain = subsample[central_nodes]
    print("length of chain: %d" % len(central_nodes))
    if len(central_nodes) < min_chain_length:
        print("probably an illustration / vectorgraphic")
        # a negative threshold can never be met, forcing a False decision
        threshold = -1

    # now get the order along the chain: find one of the two leaves
    leafs, = np.where(_node_degrees(connectivity) == 1)
    if len(central_nodes) < 2 or len(leafs) == 0:
        # degenerate tree (e.g. a single color): there is no chain to walk
        return False, central_nodes, subsample, subsample_rgb
    chain_order, _ = sp.csgraph.depth_first_order(connectivity, leafs[0])
    # chain order reorders central_nodes. get indices in original subsample array
    ordered_chain = central_nodes[chain_order]

    # check if the chain covers most of the colors: mean distance from each
    # sampled color to its nearest chain color (the median was fooled by a
    # sunset photo, so the mean is used instead)
    distance_to_chain = euclidean_distances(subsample, chain).min(axis=1).mean()
    print(distance_to_chain)
    return bool(distance_to_chain < threshold), ordered_chain, subsample, subsample_rgb
if __name__ == "__main__":
    # Demo: run the heatmap detection on a sample image, then show the
    # sampled colors in CAM02-UCS space and the extracted colormap.
    import matplotlib.pyplot as plt
    import numpy as np
    # Imported only for its side effect: it registers the "3d" projection on
    # older Matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # scipy.misc.imread no longer exists in current SciPy; Matplotlib's reader
    # is used instead (JPEG support requires Pillow).
    image = plt.imread("./UytvhhN.jpg")
    if image.dtype.kind == "f":
        # plt.imread returns PNG data as floats in [0, 1]; is_heatmap converts
        # from "sRGB255", so rescale to the 0-255 range.
        image = np.round(image * 255).astype(np.uint8)

    decision, chain_order, subsample, subsample_rgb = is_heatmap(image, random_seed=5)
    print("is heatmap: %s" % decision)

    # 3d plot in cam space: chain colors (large markers) on top of all
    # sampled colors, each drawn in its own RGB color
    X_ = subsample[chain_order]
    fig = plt.figure(figsize=(8, 6))
    # Axes3D(fig, ...) does not attach itself to the figure on current
    # Matplotlib, so create the axes through the figure instead.
    ax = fig.add_subplot(111, projection="3d")
    ax.view_init(elev=-150, azim=110)
    ax.scatter(X_[:, 0], X_[:, 1], X_[:, 2], s=100)
    ax.scatter(subsample[:, 0], subsample[:, 1], subsample[:, 2], c=subsample_rgb / 255.)

    # extracted colormap, drawn as a strip ten pixels high
    if len(chain_order):
        plt.figure()
        plt.imshow(np.repeat(subsample_rgb[chain_order].reshape(1, -1, 3), repeats=10, axis=0))
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment