Skip to content

Instantly share code, notes, and snippets.

@badjano
Last active December 14, 2018 22:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save badjano/4d64c5b11e2e4d17f49564b1a410d63f to your computer and use it in GitHub Desktop.
Save badjano/4d64c5b11e2e4d17f49564b1a410d63f to your computer and use it in GitHub Desktop.
Finding a color palette in an image using k-means clustering
import random
from PIL import ImageDraw, Image
from statistics import median
# Largest squared RGB distance with unit channel weights: all three channels differ by 255.
max_dist = 3 * 255 ** 2
# Per-channel weights applied inside Item.dist (get_palette overwrites this with (1, 1, 1)).
global_power = (2,) * 3
def avg(arr):
    """Return the arithmetic mean of the numbers in *arr*.

    *arr* must be a non-empty sized collection; an empty one raises
    ZeroDivisionError.
    """
    count = len(arr)
    total = sum(arr)
    return total / count
class Item:
    """A color sample that k-means assigns to its nearest centroid.

    Attributes:
        color: the color, indexed as color[0], color[1], color[2]
            (presumably R, G, B tuples from PIL ``getdata()`` — confirm for
            other image modes).
        group: index of the centroid this item was last assigned to; -1
            until the first call to find_group.
        distance: weighted squared distance to that centroid; set by
            find_group.
    """

    def __init__(self, color):
        self.color = color
        self.group = -1

    def find_group(self, centroids):
        """Assign this item to the nearest of *centroids* (a sequence of colors).

        Ties go to the lowest index. Returns True when ``group`` changed.
        """
        previous_group = self.group
        # Seed with +inf rather than max_dist: with channel weights > 1 (the
        # default global_power is (2, 2, 2)) distances can exceed max_dist,
        # and a distance exactly equal to max_dist fails the strict "<" —
        # either way the item would silently keep its old group.
        self.distance = float("inf")
        for index, centroid in enumerate(centroids):
            candidate = self.dist(centroid)
            if candidate < self.distance:
                self.distance = candidate
                self.group = index
        return previous_group != self.group

    def dist(self, color):
        """Return the squared distance to *color*, each channel scaled by global_power."""
        a, b, w = self.color, color, global_power
        # Squaring already discards the sign, so no abs() is needed.
        return (((a[0] - b[0]) * w[0]) ** 2
                + ((a[1] - b[1]) * w[1]) ** 2
                + ((a[2] - b[2]) * w[2]) ** 2)

    def __str__(self):
        # Format only the first three channels so RGBA tuples and lists work too.
        return "r:%d, g:%d, b:%d" % tuple(self.color[:3])

    def __repr__(self):
        return "Item(color=%r, group=%d)" % (self.color, self.group)
def _centroid_color(colors, reduce_channel):
    """Reduce each of the first three channels of *colors* to one int."""
    return tuple(int(reduce_channel([c[ch] for c in colors])) for ch in range(3))


def get_main_colors(all: "list[Item]", cents=None, max_iterations=-1, method=None, count=6):
    """Run k-means over *all* and return the resulting centroid Items.

    Each round assigns every item to its nearest centroid (Item.find_group)
    and, if any assignment changed, replaces each centroid with the
    per-channel reduction of its members' colors. Iteration stops when no
    item changes group, or after *max_iterations* rounds when that is positive.

    :param all: Item samples to cluster. (Shadows the builtin ``all``; the
        name is kept so keyword callers keep working.)
    :param cents: initial centroid Items. When falsy, up to *count* distinct
        items are sampled at random from *all*.
    :param max_iterations: round limit; values <= 0 mean "until convergence".
    :param method: per-channel reducer such as ``statistics.median``; a falsy
        value falls back to ``avg`` (arithmetic mean).
    :param count: number of random initial centroids when *cents* is falsy.
    :return: list of centroid Items, one per initial centroid.
    """
    reduce_channel = method if method else avg
    if not cents:
        print("generating random centroids")
        # Distinct items in first-seen order (Item defines no __eq__/__hash__,
        # so distinctness is by identity).
        pool = list(dict.fromkeys(all))
        # Cap at the pool size: drawing until *count* unique picks exist never
        # terminates when *all* holds fewer than *count* distinct items.
        cents = random.sample(pool, min(count, len(pool)))
    iterations = 0
    while True:
        print("iterations: ", iterations)
        iterations += 1
        print("grouping %d" % len(all))
        # The centroid colors are the same for every item this round.
        centroid_colors = [c.color for c in cents]
        change_count = 0
        for item in all:
            if item.find_group(centroid_colors):
                change_count += 1
        if not change_count:
            break
        print("calculating centroids, %d changed" % change_count)
        # Bucket member colors by group in one pass instead of one pass per centroid.
        members = {}
        for item in all:
            members.setdefault(item.group, []).append(item.color)
        new_cents = []
        for index, old_centroid in enumerate(cents):
            colors = members.get(index)
            if colors:
                new_cents.append(Item(_centroid_color(colors, reduce_channel)))
            else:
                # An empty cluster keeps its own previous centroid, so positions
                # stay stable and no other centroid gets duplicated.
                new_cents.append(old_centroid)
        cents = new_cents
        print("new centroids: %d" % len(cents))
        if 0 < max_iterations <= iterations:
            break
    return cents
def get_palette(filename, method, levels=1):
    """Show *filename* downscaled, with its k-means palette drawn along the bottom.

    The image is resized so its shorter side is 100 px and its distinct pixel
    colors become Items. get_main_colors then runs *levels* times with
    6**levels, 6**(levels-1), ..., 6 centroids, each round clustering the
    previous round's centroids. The final centroids, sorted by channel sum
    (darkest first), are drawn as equal-width swatches, and the result is
    displayed with ``img.show()``.

    Side effect: sets the module-level ``global_power`` to (1, 1, 1), so
    Item.dist weights all channels equally from then on.

    :param filename: path of the image to open.
    :param method: per-channel reducer forwarded to get_main_colors
        (e.g. ``statistics.median``).
    :param levels: number of successive clustering rounds.
    """
    global global_power
    global_power = (1, 1, 1)
    size = 100
    # The with-block closes the file handle; resize() returns an independent image.
    with Image.open(filename) as source:
        # Modes such as "L" or "P" yield plain ints from getdata(), which
        # Item.dist cannot index; normalize those to RGB (RGB/RGBA untouched).
        rgb = source if source.mode in ("RGB", "RGBA") else source.convert("RGB")
        scale = size / min(rgb.size)
        img = rgb.resize((int(rgb.size[0] * scale), int(rgb.size[1] * scale)))
    data = img.getdata()
    unique_colors = list(set(data))
    print("Optimizing data, was %d pixels, now %d pixels" % (len(data), len(unique_colors)))
    items = [Item(color) for color in unique_colors]
    for level in range(1, levels + 1):
        centroids = get_main_colors(items, None, -1, method, 6 ** (levels - level + 1))
        items = sorted(centroids, key=lambda c: c.color[0] + c.color[1] + c.color[2])
    width, height = img.size
    swatch_width = width // len(items)
    # Strip height is 100 * scale, i.e. 100 px of the original image mapped
    # into resized pixels.
    strip_top = height - 100 * scale
    draw = ImageDraw.Draw(img)
    for index, centroid in enumerate(items):
        draw.rectangle(
            (index * swatch_width, strip_top, (index + 1) * swatch_width, height),
            centroid.color,
        )
    del draw
    img.show()
if __name__ == "__main__":
    # Demo: for each sample image (currently only img/art_06.jpg), show the
    # median-based palette from one clustering level, then from three levels.
    for i in range(6, 7):
        path = "img/art_%02d.jpg" % i
        get_palette(path, median, 1)
        get_palette(path, median, 3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment