Last active
December 14, 2018 22:32
-
-
Save badjano/4d64c5b11e2e4d17f49564b1a410d63f to your computer and use it in GitHub Desktop.
Finding palette from image using K-Means Clustering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from PIL import ImageDraw, Image | |
from statistics import median | |
# Squared Euclidean distance between black (0, 0, 0) and white
# (255, 255, 255) in unweighted 8-bit RGB space.
max_dist = 3 * 255 ** 2

# Per-channel (r, g, b) multipliers applied to channel differences when
# measuring colour distance.
global_power = (2, 2, 2)
def avg(arr):
    """Return the arithmetic mean of the numbers in ``arr``.

    Raises ZeroDivisionError when ``arr`` is empty.
    """
    total = sum(arr)
    count = len(arr)
    return total / count
class Item:
    """A colour sample that remembers which centroid it is assigned to."""

    def __init__(self, color):
        # Colour sequence; only the first three channels (r, g, b) are read.
        self.color = color
        # Index of the nearest centroid found by find_group; -1 = unassigned.
        self.group = -1
        # Weighted squared distance to that centroid; infinite until a
        # centroid has been found.
        self.distance = float("inf")

    def find_group(self, centroids, power=None):
        """Assign this item to the nearest centroid colour.

        Args:
            centroids: sequence of colour tuples to compare against.
            power: optional per-channel weights forwarded to ``dist``.

        Returns:
            True when the assigned group index changed, False otherwise.
            With no centroids the current group is left untouched.
        """
        previous_group = self.group
        # Start from infinity instead of a fixed cap: weighted distances can
        # exceed 3 * 255 ** 2, and a capped start would leave such items
        # without any group.
        self.distance = float("inf")
        for index, centroid in enumerate(centroids):
            candidate = self.dist(centroid, power)
            # Strict comparison keeps the first centroid on ties.
            if candidate < self.distance:
                self.distance = candidate
                self.group = index
        return previous_group != self.group

    def dist(self, color, power=None):
        """Return the weighted squared distance between this colour and ``color``.

        Each of the three channel differences is multiplied by its weight
        before squaring.

        Args:
            color: colour sequence with at least three channels.
            power: optional (r, g, b) weights; when None the module-level
                ``global_power`` is used, looked up at call time.
        """
        weights = global_power if power is None else power
        a = self.color
        b = color
        return sum(
            (abs(a[channel] - b[channel]) * weights[channel]) ** 2
            for channel in range(3)
        )

    def __str__(self):
        # Slice to three channels so colours with an alpha channel still format.
        return "r:%d, g:%d, b:%d" % tuple(self.color[:3])
def get_main_colors(all: "list[Item]", cents=None, max_iterations=-1, method=None, count=6):
    """Cluster colour items with k-means and return the final centroids.

    Args:
        all: items to cluster; ``find_group`` is called on each one, which
            updates its group assignment in place.
        cents: optional initial centroids (objects exposing ``color``). When
            empty or None, up to ``count`` distinct items are sampled at
            random from ``all``.
        max_iterations: when positive, stop after this many passes;
            otherwise iterate until no item changes group.
        method: callable reducing a list of channel values to one number
            (for example ``statistics.median``); falls back to ``avg``.
        count: number of centroids to generate when ``cents`` is not given.

    Returns:
        The list of centroids after the last pass. An empty ``all`` with no
        ``cents`` yields an empty list.
    """
    iterations = 0
    if not cents:
        print("generating random centroids")
        # Drop repeated objects so sampling cannot pick the same item twice.
        candidates = list(dict.fromkeys(all))
        # Capping at the number of available items avoids spinning forever
        # when fewer than ``count`` distinct items exist.
        cents = random.sample(candidates, min(count, len(candidates)))

    reducer = method or avg

    while True:
        print("iterations: ", iterations)
        iterations += 1
        print("grouping %d" % len(all))

        # The centroid colours are the same for every item in this pass.
        centroid_colors = [centroid.color for centroid in cents]
        change_count = 0
        for item in all:
            if item.find_group(centroid_colors):
                change_count += 1

        if not change_count:
            # Assignments are stable: the centroids have converged.
            break

        print("calculating centroids, %d changed" % change_count)

        # Bucket member colours per centroid in a single pass. Items whose
        # group is unassigned (-1) or out of range belong to no bucket.
        members = [[] for _ in cents]
        for item in all:
            if 0 <= item.group < len(members):
                members[item.group].append(item.color)

        new_cents = []
        for index, colors in enumerate(members):
            if colors:
                new_cents.append(Item(tuple(
                    int(reducer([color[channel] for color in colors]))
                    for channel in range(3)
                )))
            else:
                # An empty cluster keeps its own centroid rather than
                # duplicating a neighbouring one.
                new_cents.append(cents[index])
        cents = new_cents
        print("new centroids: %d" % len(cents))

        if 0 < max_iterations <= iterations:
            break

    return cents
def get_palette(filename, method, levels=1):
    """Extract a colour palette from an image and display it on the image.

    The image is shrunk so its shorter side is about 100 pixels, its
    distinct colours are clustered with ``get_main_colors``, and the
    resulting centroids are drawn as a row of swatches along the bottom
    edge before the image is shown.

    Args:
        filename: path of the image file to open.
        method: reducer passed to ``get_main_colors`` as its ``method``
            (for example ``statistics.median``).
        levels: number of clustering passes. Pass ``k`` requests
            ``6 ** (levels - k + 1)`` centroids and feeds its result to the
            next pass, so the last pass requests 6.

    Returns:
        None. Side effects: rebinds the module-level ``global_power``,
        prints progress, and calls ``img.show()``.
    """
    global global_power
    # Weight the three channels equally. This rebinds the module-level
    # setting, so it also affects distance computations made afterwards.
    global_power = (1, 1, 1)

    size = 100
    # The context manager closes the file once the pixels are copied out.
    with Image.open(filename) as source:
        scale = size / min(source.size)
        # Normalise to 3-channel RGB so greyscale, palette and RGBA images
        # yield (r, g, b) tuples for the distance and drawing code.
        img = source.convert("RGB").resize((
            int(source.size[0] * scale),
            int(source.size[1] * scale),
        ))

    data = img.getdata()
    # Cluster distinct colours only; repeated pixels are dropped.
    new_data = list(set(data))
    print("Optimizing data, was %d pixels, now %d pixels" % (len(data), len(new_data)))
    items = [Item(color) for color in new_data]

    for level in range(1, levels + 1):
        items = sorted(
            get_main_colors(items, None, -1, method, 6 ** (levels - level + 1)),
            key=lambda x: x.color[0] + x.color[1] + x.color[2],
        )

    if items:
        draw = ImageDraw.Draw(img)
        w = int(img.size[0] / len(items))
        top = img.size[1] - 100 * scale
        # enumerate gives each swatch its own slot even when the same
        # centroid object appears more than once in the palette.
        for index, c in enumerate(items):
            draw.rectangle((index * w, top, (index + 1) * w, img.size[1]), c.color)

    img.show()
# Demo run: build a one-level and a three-level median palette for each
# numbered sample image.
for i in range(6, 7):
    path = "img/art_%02d.jpg" % i
    for depth in (1, 3):
        get_palette(path, median, depth)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment