Skip to content

Instantly share code, notes, and snippets.

@smiranda
Created July 20, 2021 23:27
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smiranda/88601542d3012b05f48c8a440e6ea7a9 to your computer and use it in GitHub Desktop.
Save smiranda/88601542d3012b05f48c8a440e6ea7a9 to your computer and use it in GitHub Desktop.
A script to map one color palette onto another using the CIEDE2000 color difference (delta_e_cie2000)
import sys
import itertools
import networkx as nx
from networkx.algorithms import bipartite
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
# References
# (1) About comparing colors with colormath: https://dev.to/tejeshreddy/color-difference-between-2-colours-using-python-182b
# (2) About networkx for matching bipartite graphs https://towardsdatascience.com/matching-of-bipartite-graphs-using-networkx-6d355b164567
# Usage
# python .\map-pallete.py .\pallete-to-be-mapped.hex .\reference-pallete.hex
# NOTE: Only tested with palettes of the same size. Adapting it to palettes of different sizes should be easy, since the underlying bipartite matching problem supports that.
def openHex(path):
    """Read a palette file containing one hex color code per line.

    Each line is stripped of surrounding whitespace first. Lines whose
    stripped text is shorter than six characters (blank or whitespace-only
    lines, for example) are skipped. Every other line is parsed with
    ``sRGBColor.new_from_rgb_hex`` and converted to a ``LabColor``.

    Args:
        path: Path of the palette file (e.g. a ``.hex`` palette export).

    Returns:
        A ``(lab_colors, hex_strings)`` tuple of two parallel lists, in file
        order. The first holds the ``LabColor`` of each kept line; the second
        holds that line's stripped text.
    """
    ls = []
    lo = []
    with open(path, encoding="utf-8") as f0:
        for line in f0:
            # Measure the stripped text, not the raw line: the raw line still
            # carries its "\n", which used to let whitespace-only lines (and
            # 5-character codes, unless they were on the last line) through
            # to the hex parser.
            code = line.strip()
            if len(code) < 6:
                continue
            ls.append(convert_color(sRGBColor.new_from_rgb_hex(code), LabColor))
            lo.append(code)
    return (ls, lo)
def map_pallete(f0n, f1n, out_path="out.hex"):
    """Reorder one palette so its colors line up with a reference palette.

    Every color to be mapped is connected to every reference color by an
    edge weighted with their CIEDE2000 difference. A minimum-weight full
    matching of that bipartite graph then pairs each color to be mapped with
    a distinct reference color. The colors to be mapped are written to
    ``out_path``, one hex code per line, ordered by the position of their
    matched reference color.

    Args:
        f0n: Path of the palette to be mapped (one hex color per line).
        f1n: Path of the reference palette (same format).
        out_path: File the reordered palette is written to. Defaults to
            ``"out.hex"`` in the current working directory.

    Raises:
        ValueError: If the palette to be mapped has more colors than the
            reference palette, so not every color can be matched.
    """
    f0ls, f0os = openHex(f0n)
    f1ls, _ = openHex(f1n)
    if len(f0ls) > len(f1ls):
        raise ValueError(
            f"cannot map {len(f0ls)} colors onto a reference palette "
            f"of only {len(f1ls)} colors"
        )

    top_nodes = [f"T{i}" for i in range(len(f0ls))]
    bottom_nodes = [f"B{j}" for j in range(len(f1ls))]
    bottom_pos = {bn: j for j, bn in enumerate(bottom_nodes)}

    B = nx.Graph()
    B.add_nodes_from(top_nodes, bipartite=0)
    B.add_nodes_from(bottom_nodes, bipartite=1)
    # Complete bipartite graph: any color to be mapped may pair with any
    # reference color; the edge weight is their perceptual distance.
    # NOTE(review): colormath 3.0.0's delta_e_cie2000 is reported to call
    # numpy.asscalar, which newer NumPy releases removed -- confirm the
    # installed versions if this raises AttributeError.
    for (i, lab0), (j, lab1) in itertools.product(enumerate(f0ls), enumerate(f1ls)):
        B.add_edge(top_nodes[i], bottom_nodes[j], weight=delta_e_cie2000(lab0, lab1))

    best_pallete_match = bipartite.minimum_weight_full_matching(B, top_nodes, "weight")
    # Key each color to be mapped by the reference position of its partner.
    # The matching pairs each top node with a distinct bottom node, so these
    # keys are unique.
    by_ref_pos = {
        bottom_pos[best_pallete_match[tn]]: f0os[i] for i, tn in enumerate(top_nodes)
    }
    ordered = [by_ref_pos[pos] for pos in sorted(by_ref_pos)]

    with open(out_path, "w", encoding="utf-8") as fo:
        fo.write("\n".join(ordered) + "\n")
if __name__ == "__main__":
    # Command-line entry point: map the palette in argv[1] onto the reference
    # palette in argv[2]. Missing arguments get a usage line instead of an
    # IndexError traceback; extra arguments are ignored.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} PALETTE_TO_MAP.hex REFERENCE_PALETTE.hex")
    map_pallete(sys.argv[1], sys.argv[2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment