Skip to content

Instantly share code, notes, and snippets.

@douglasgoodwin
Last active August 29, 2015 14:17
Show Gist options
  • Save douglasgoodwin/8845beda75c1ca73c7ae to your computer and use it in GitHub Desktop.
Save douglasgoodwin/8845beda75c1ca73c7ae to your computer and use it in GitHub Desktop.
Using python and k-means to find the dominant colors in images. "forked" from charlesleifer.com
"""
from: http://charlesleifer.com/blog/using-python-and-k-means-to-find-the-dominant-colors-in-images/
Start with a bunch of data points. For simplicity, let's say they're numbers on a number line.
You want to group the numbers into "k" clusters, so pick "k" points randomly from the data to
use as the initial cluster centers.
Now loop over every point in the data and calculate its distance to each of the "k" cluster centers.
Find the nearest center and associate that point with its cluster. When you've looped over
all the points, each one will be assigned to one of the "k" clusters. Now, for each cluster,
recalculate its center by averaging the positions of all the associated points, and start over.
When the centers stop moving very much, you can stop looping. You will end up with something
like the figure in the original post -- the points are colored based on what "cluster" they
are in, and the dark-black circles indicate the centers of each cluster.
"""
from collections import namedtuple
from math import sqrt
import random
from scipy.cluster.vq import kmeans,vq
try:
import Image
except ImportError:
from PIL import Image
# A weighted sample: `coords` holds the coordinate values, `n` is how many of
# them are used as dimensions, and `ct` is the weight of the sample (for
# pixels: how many times that color occurs in the image).
Point = namedtuple('Point', 'coords n ct')

# A group of points plus its center (itself a Point); `n` is the number of
# dimensions shared by the cluster's points.
Cluster = namedtuple('Cluster', 'points center n')
def get_points(img):
    """Return one weighted Point per distinct color in ``img``.

    Args:
        img: a PIL image.

    Returns:
        A list of ``Point(color, 3, count)`` where ``color`` is an
        ``(r, g, b)`` tuple and ``count`` is the number of pixels that
        have exactly that color.
    """
    # Every Point is declared 3-dimensional below, so the colors must be RGB
    # triples. Palette and grayscale images would otherwise yield plain ints
    # from getcolors(), and CMYK values would be misread as RGB.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    w, h = img.size
    # An image cannot contain more distinct colors than pixels, so passing
    # w * h as the limit makes getcolors() report every color.
    return [Point(color, 3, count) for count, color in img.getcolors(w * h)]
def rtoh(rgb):
    """Format an iterable of integer channel values as a '#'-prefixed hex string.

    Each value is rendered as lowercase hexadecimal, zero-padded to at least
    two digits, e.g. ``(255, 0, 16)`` becomes ``'#ff0010'``.
    """
    hex_pairs = ['%02x' % channel for channel in rgb]
    return '#%s' % ''.join(hex_pairs)
def colorz(filename, n=3):
    """Return the dominant colors of an image as hex strings.

    The image is opened, shrunk in place with ``thumbnail((200, 200))`` to
    keep the number of points small, reduced to its distinct colors, and
    clustered with k-means; each cluster center becomes one result color.

    Args:
        filename: image to analyse; passed straight to ``Image.open``.
        n: number of clusters (dominant colors) to look for.

    Returns:
        A list of ``'#rrggbb'`` strings, one per cluster.
    """
    img = Image.open(filename)
    img.thumbnail((200, 200))
    points = get_points(img)
    clusters = kmeans(points, n, 1)
    # Build real lists: on Python 3 ``map`` is lazy, so the original handed
    # callers a one-shot iterator instead of a list of strings.
    return [
        rtoh([int(channel) for channel in cluster.center.coords])
        for cluster in clusters
    ]
def euclidean(p1, p2):
    """Return the Euclidean distance between two points.

    Only the first ``p1.n`` coordinates of each point take part in the
    calculation.
    """
    deltas = (p1.coords[axis] - p2.coords[axis] for axis in range(p1.n))
    return sqrt(sum(delta ** 2 for delta in deltas))
def calculate_center(points, n):
    """Return the weighted mean of ``points`` as a new Point.

    Each point contributes its first ``n`` coordinates, weighted by its
    ``ct``. The result has a list of ``n`` float coordinates and a weight
    of 1.
    """
    totals = [0.0] * n
    weight = 0
    for point in points:
        for axis in range(n):
            totals[axis] += point.coords[axis] * point.ct
        weight += point.ct
    mean = [total / weight for total in totals]
    return Point(mean, n, 1)
def kmeans(points, k, min_diff):
    """Group weighted points into clusters with iterative k-means.

    Initial centers are drawn at random from ``points``. Each pass assigns
    every point to its nearest center and then recomputes each center from
    its members, until no center moves by ``min_diff`` or more.

    Args:
        points: sequence of Point objects to cluster.
        k: requested number of clusters. Capped at ``len(points)``, because
            there cannot be more clusters than points to seed them with.
        min_diff: convergence threshold on the largest center movement in
            one pass.

    Returns:
        A list of Cluster objects. It holds fewer than ``k`` entries when
        there are fewer than ``k`` points, and is empty when there is
        nothing to cluster.
    """
    # random.sample raises ValueError when asked for more items than exist
    # (e.g. a single-color image with k=3), so cap k instead of crashing.
    k = min(k, len(points))
    if k == 0:
        return []

    clusters = [Cluster([p], p, p.n) for p in random.sample(points, k)]
    while True:
        members = [[] for _ in range(k)]
        for p in points:
            distances = [euclidean(p, c.center) for c in clusters]
            # index(min(...)) picks the first of equally-near centers, the
            # same tie-break as a strict "<" scan.
            members[distances.index(min(distances))].append(p)

        diff = 0
        updated = []
        for old, assigned in zip(clusters, members):
            if assigned:
                center = calculate_center(assigned, old.n)
            else:
                # A cluster can lose all its points; keep its previous
                # center rather than dividing by a zero total weight.
                center = old.center
            updated.append(Cluster(assigned, center, old.n))
            diff = max(diff, euclidean(old.center, center))
        clusters = updated

        if diff < min_diff:
            break
    return clusters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment