Skip to content

Instantly share code, notes, and snippets.

@joeld42
Created November 15, 2017 06:54
Show Gist options
  • Save joeld42/f872d393ae7d2b35c4826dac349984b9 to your computer and use it in GitHub Desktop.
Save joeld42/f872d393ae7d2b35c4826dac349984b9 to your computer and use it in GitHub Desktop.
Simple brute-force k-means palette quantization
# Slow, bad K-Means image quantization.
# This code is in the public domain.
import os, sys
import random
from PIL import Image, ImageDraw
# Upper bound on k-means refinement passes.  In the author's experience the
# palette doesn't get much better beyond this many passes (YMMV).
MAX_ITER = 10

# Number of cluster centers, i.e. palette entries, to fit.
NUM_MEANS = 16
class ColorGroup:
    """One k-means cluster: its current center plus per-pass accumulators.

    Attributes:
        meanColor: current cluster center, an (r, g, b) tuple.
        count:     pixel weight accumulated into this group in the current pass.
        totColor:  running per-channel (r, g, b) sums for the current pass.
        lastCount: pixel weight carried over from an earlier pass.
    """

    def __init__(self):
        zeroColor = (0.0, 0.0, 0.0)
        self.meanColor = zeroColor
        self.totColor = zeroColor
        self.count = 0
        # NOTE (author's caveat): starting this at 0 is potentially wrong,
        # but it seems to work OK.
        self.lastCount = 0

    def roundColor(self):
        """Return meanColor as a 3-tuple of ints, each channel rounded."""
        mean = self.meanColor
        red, green, blue = mean[0], mean[1], mean[2]
        return (int(round(red)), int(round(green)), int(round(blue)))
def closestColor(colors, target):
    """Return the group in *colors* whose meanColor is nearest to *target*.

    Distance is the squared Euclidean distance over the first three channels,
    so any extra channel in *target* (e.g. alpha) is ignored.  On a tie the
    earliest group wins.  Returns None when *colors* is empty.
    """
    def squaredDistance(group):
        mean = group.meanColor
        return ((target[0] - mean[0])**2
                + (target[1] - mean[1])**2
                + (target[2] - mean[2])**2)

    candidates = list(colors)
    if not candidates:
        return None
    # min() only replaces its current pick on a strictly smaller key, so
    # ties resolve to the first candidate.
    return min(candidates, key=squaredDistance)
def countColors(pix, w, h):
    """Count how often each pixel value occurs in a w x h image.

    pix -- anything indexable as pix[x, y], e.g. the PixelAccess object
           returned by PIL's Image.load().
    Returns a dict mapping pixel value -> number of occurrences.
    """
    counts = {}
    for y in range(h):
        for x in range(w):
            c = pix[x, y]
            counts[c] = counts.get(c, 0) + 1
    return counts


def runKMeans(means, colorCounts, maxIter):
    """Refine the ColorGroups in *means* in place with weighted k-means.

    Each pass assigns every distinct color in *colorCounts* (color -> pixel
    count) to its nearest mean, weighted by that pixel count.  Each non-empty
    group's meanColor then moves to the weighted average of its members.
    Empty groups keep their previous mean.  Stops after *maxIter* passes, or
    earlier once a pass leaves every mean unchanged (the next assignment
    would then be identical).
    """
    for step in range(maxIter):
        print("K-Means iter %d" % step)
        for cg in means:
            cg.totColor = (0.0, 0.0, 0.0)
            cg.count = 0
        for c, weight in colorCounts.items():
            group = closestColor(means, c)
            group.totColor = (group.totColor[0] + c[0] * weight,
                              group.totColor[1] + c[1] * weight,
                              group.totColor[2] + c[2] * weight)
            group.count += weight
        # Converged only when no mean moved.  Equal per-group counts alone
        # are not enough: colors can swap between groups while every group
        # keeps the same total weight.
        converged = True
        for cg in means:
            if cg.count > 0:
                newMean = (cg.totColor[0] / cg.count,
                           cg.totColor[1] / cg.count,
                           cg.totColor[2] / cg.count)
                if newMean != cg.meanColor:
                    converged = False
                cg.meanColor = newMean
            # Kept up to date for any code that still reads it.
            cg.lastCount = cg.count
        if converged:
            print("Converged...")
            break


def palettize(pix, w, h, lookup):
    """Replace every pixel pix[x, y] of a w x h image with lookup[pix[x, y]].

    Covers the full image, including the last row and column.
    """
    for y in range(h):
        for x in range(w):
            pix[x, y] = lookup[pix[x, y]]


def drawPalette(img, means, barHeight=20):
    """Draw one outlined swatch per group across the top of *img*."""
    if not means:
        return
    barSz = float(img.size[0]) / len(means)
    draw = ImageDraw.Draw(img)
    for i, cg in enumerate(means):
        draw.rectangle([int(i * barSz), 0, int((i + 1) * barSz), barHeight],
                       fill=cg.roundColor(), outline=(255, 255, 255))


def main(argv):
    """Quantize argv[1] to an optional argv[2] (default NUM_MEANS) colors.

    The result is saved next to the input as <name>_pal<N><ext>.
    """
    if len(argv) < 2:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit("Usage: k-means <input image> [num colors]")
    numMeans = NUM_MEANS
    if len(argv) > 2:
        try:
            numMeans = int(argv[2])
        except ValueError:
            numMeans = 0
        if numMeans < 1:
            sys.exit("num colors must be a positive integer")

    infile = argv[1]
    outbase, outext = os.path.splitext(infile)
    outfilename = outbase + "_pal" + str(numMeans) + outext

    # Work in RGB so every pixel is an (r, g, b) tuple.  Palette ("P") or
    # grayscale ("L") images would otherwise yield ints here.  The output is
    # an RGB image either way.
    img = Image.open(infile).convert("RGB")
    pix = img.load()
    w, h = img.size

    colorCounts = countColors(pix, w, h)
    if not colorCounts:
        sys.exit("image has no pixels")
    uniqueColors = list(colorCounts.keys())
    random.shuffle(uniqueColors)
    print("%d colors in image" % len(uniqueColors))

    # Seed each group with a distinct random image color.  The image may
    # contain fewer distinct colors than requested, so slice rather than index.
    means = []
    for startCol in uniqueColors[:numMeans]:
        cg = ColorGroup()
        cg.meanColor = startCol
        print(startCol)
        means.append(cg)

    runKMeans(means, colorCounts, MAX_ITER)

    # Resolve each distinct color once, not once per pixel.
    lookup = {c: closestColor(means, c).roundColor() for c in uniqueColors}
    palettize(pix, w, h, lookup)

    drawPalette(img, means)
    img.save(outfilename)


if __name__ == '__main__':
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment