Last active
August 29, 2015 14:17
-
-
Save douglasgoodwin/8845beda75c1ca73c7ae to your computer and use it in GitHub Desktop.
Using python and k-means to find the dominant colors in images. "forked" from charlesleifer.com
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
from: http://charlesleifer.com/blog/using-python-and-k-means-to-find-the-dominant-colors-in-images/
Start with a bunch of data points. For simplicity let's say they're numbers on a number-line.
You want to group the numbers into "k" clusters, so pick "k" points randomly from the data to
use as your "clusters".
Now loop over every point in the data and calculate its distance to each of the "k" clusters.
Find the nearest cluster and associate that point with the cluster. When you've looped over
all the points they should all be assigned to one of the "k" clusters. Now, for each cluster
recalculate its center by averaging the coordinates of all the associated points and start over.
When the centers stop moving very much you can stop looping. You will end up with something
like this -- the points are colored based on what "cluster" they are in and the dark-black
circles indicate the centers of each cluster.
""" | |
from collections import namedtuple | |
from math import sqrt | |
import random | |
from scipy.cluster.vq import kmeans,vq | |
try: | |
import Image | |
except ImportError: | |
from PIL import Image | |
# A weighted sample. ``coords`` holds the coordinate values, ``n`` is how many
# of those coordinates the distance/averaging helpers look at, and ``ct`` is
# the weight of the sample (get_points stores the pixel count here).
Point = namedtuple('Point', ['coords', 'n', 'ct'])

# One k-means cluster: the member ``points``, the ``center`` (itself a Point),
# and ``n``, the dimension passed on when the center is recomputed.
Cluster = namedtuple('Cluster', ['points', 'center', 'n'])
def get_points(img):
    """Turn the distinct colors of *img* into a list of weighted points.

    ``img.getcolors`` is called with ``width * height`` as its argument and
    is expected to yield ``(count, color)`` pairs. Each pair becomes
    ``Point(color, 3, count)``: the color is the coordinate vector, the
    dimension is fixed at 3, and the count is the point's weight.

    Args:
        img: An image object exposing ``size`` (a ``(width, height)`` pair)
            and ``getcolors``.

    Returns:
        A list with one ``Point`` per pair returned by ``getcolors``.
    """
    width, height = img.size
    return [
        Point(color, 3, count)
        for count, color in img.getcolors(width * height)
    ]
def rtoh(rgb):
    """Format an iterable of integers as a ``#``-prefixed hex string.

    Every value is rendered as lowercase hexadecimal, zero-padded to at
    least two digits, and the pieces are concatenated in order, e.g.
    ``(255, 0, 16)`` becomes ``'#ff0010'``.
    """
    hex_pairs = ['%02x' % channel for channel in rgb]
    return '#%s' % ''.join(hex_pairs)
def colorz(filename, n=3):
    """Return the *n* dominant colors of an image as hex strings.

    The image is shrunk in place to fit within 200x200 pixels, its distinct
    colors are collected as weighted points, and those points are grouped
    with ``kmeans`` (using a ``min_diff`` of 1). The center of each
    resulting cluster is truncated to integers and formatted with ``rtoh``.

    Args:
        filename: Path (or file object) handed to ``Image.open``.
        n: Number of clusters, i.e. number of colors to return.

    Returns:
        A list of ``n`` strings such as ``'#a1b2c3'``, one per cluster.
    """
    img = Image.open(filename)
    img.thumbnail((200, 200))
    # Normalise to RGB so every color reported by getcolors() is a 3-tuple;
    # palette and greyscale images otherwise report plain integers, which
    # the coordinate indexing in the clustering code cannot handle.
    img = img.convert('RGB')

    points = get_points(img)
    clusters = kmeans(points, n, 1)

    # Build a real list (not a lazy, single-use ``map`` object) so the result
    # can be indexed, measured and iterated repeatedly on Python 3 as well.
    return [rtoh([int(value) for value in c.center.coords]) for c in clusters]
def euclidean(p1, p2):
    """Return the Euclidean distance between *p1* and *p2*.

    Only the first ``p1.n`` coordinates of each point take part; both
    points must therefore have at least that many entries in ``coords``.
    """
    total = 0
    for axis in range(p1.n):
        delta = p1.coords[axis] - p2.coords[axis]
        total += delta ** 2
    return sqrt(total)
def calculate_center(points, n):
    """Return the weighted mean of *points* as a new ``Point``.

    Each point contributes its first *n* coordinates multiplied by its
    weight ``ct``; the per-axis totals are then divided by the summed
    weights. The resulting ``Point`` has dimension *n* and a weight of 1.

    Raises:
        ZeroDivisionError: If the summed weight is zero (for example when
            *points* is empty) while *n* is greater than zero.
    """
    totals = [0.0] * n
    weight = 0
    for point in points:
        weight += point.ct
        for axis in range(n):
            totals[axis] += point.coords[axis] * point.ct
    return Point([total / weight for total in totals], n, 1)
def kmeans(points, k, min_diff):
    """Group *points* into *k* clusters by iterative k-means refinement.

    The initial centers are *k* points drawn with ``random.sample``. Each
    round assigns every point to the cluster whose center is nearest
    (ties go to the lowest cluster index), then recomputes every center
    with ``calculate_center``. Iteration stops once the largest distance
    any center moved in a round is strictly below *min_diff*.

    Args:
        points: Sequence of ``Point`` objects to cluster.
        k: Number of clusters to produce.
        min_diff: Convergence threshold. It must be positive: the loop only
            exits when the largest center movement is ``< min_diff``, so a
            value of 0 or less never terminates.

    Returns:
        A list of *k* ``Cluster`` tuples holding the final assignment and
        centers. A cluster's ``points`` list may be empty.

    Raises:
        ValueError: Propagated from ``random.sample`` when *k* exceeds
            ``len(points)`` (or is negative).
    """
    clusters = [Cluster([p], p, p.n) for p in random.sample(points, k)]

    while True:
        # Assignment step: bucket every point under its nearest center.
        plists = [[] for _ in range(k)]
        for p in points:
            distances = [euclidean(p, cluster.center) for cluster in clusters]
            # list.index returns the first match, so ties resolve to the
            # lowest cluster index.
            nearest = distances.index(min(distances))
            plists[nearest].append(p)

        # Update step: move each center and track the largest movement.
        diff = 0
        for i in range(k):
            old = clusters[i]
            if plists[i]:
                center = calculate_center(plists[i], old.n)
            else:
                # No point chose this cluster in this round. Averaging an
                # empty list would divide by zero, so leave the center
                # where it was instead of crashing.
                center = old.center
            clusters[i] = Cluster(plists[i], center, old.n)
            diff = max(diff, euclidean(old.center, center))

        if diff < min_diff:
            break

    return clusters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment