Skip to content

Instantly share code, notes, and snippets.

@nicoguaro
Created January 9, 2020 18:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicoguaro/9d5ede0e1636905f5ade6a32980b4352 to your computer and use it in GitHub Desktop.
Save nicoguaro/9d5ede0e1636905f5ade6a32980b4352 to your computer and use it in GitHub Desktop.
Posterization of an image using k-means
"""
Posterization of an image using k-means
It can be used to create color palettes from pictures.
@author:Nicolás Guarín-Zapata
@date: January 2020
"""
import numpy as np
from scipy import misc
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from webcolors import rgb_to_hex
import matplotlib.patches as mpatches
def cluster_colors(img, n_colors=8, random_state=None):
    """Group the pixel colors of an image into clusters with k-means.

    Parameters
    ----------
    img : ndarray, shape (m, n, d)
        Image with ``d`` channels per pixel. The values are divided by
        255 before clustering, so 8-bit intensities are presumably
        expected -- TODO confirm for other inputs.
    n_colors : int, optional
        Number of clusters (colors) to look for. Defaults to 8.
    random_state : int, RandomState instance or None, optional
        Forwarded to ``KMeans`` to seed the centroid initialization.
        Pass an integer to obtain the same palette on every run. The
        default, None, keeps the previous (unseeded) behavior.

    Returns
    -------
    values : ndarray, shape (n_colors, d)
        Cluster centers, on the rescaled (divided by 255) scale.
    labels : ndarray, shape (m*n,)
        Cluster index assigned to each pixel, in the order given by
        flattening the first two axes of ``img``.
    """
    img = img/255
    m, n, d = img.shape
    # One row per pixel, one column per channel, as KMeans expects
    # a (samples, features) matrix.
    X = img.reshape((m*n, d))
    k_means = KMeans(n_clusters=n_colors, random_state=random_state)
    k_means.fit(X)
    return k_means.cluster_centers_, k_means.labels_
def compressed_img(values, labels, dims):
    """Rebuild an image in which every pixel shows its cluster color.

    Parameters
    ----------
    values : ndarray
        Cluster colors; indexed by cluster label along the first axis.
    labels : ndarray of int
        Cluster index of every pixel, flattened.
    dims : tuple
        Three-element shape ``(m, n, d)`` of the image to build.

    Returns
    -------
    ndarray, shape (m, n, d)
        New array holding, for each pixel, the color of its cluster.
    """
    height, width, channels = dims
    # Indexing with the label array picks one color row per pixel.
    per_pixel_colors = values[labels]
    return per_pixel_colors.reshape((height, width, channels))
def box_array(n_rows, n_cols, colors, ax):
    """Draw a grid of unit squares on ``ax``, one per color.

    The squares are added row by row: the square at column ``col`` and
    row ``row`` has its lower-left corner at ``(col, row)`` and is filled
    with the next entry of ``colors``.

    Parameters
    ----------
    n_rows : int
        Number of rows in the grid.
    n_cols : int
        Number of columns in the grid.
    colors : sequence
        Face colors, indexed from 0 in row-major order. It must hold at
        least ``n_rows*n_cols`` entries.
    ax : matplotlib axes
        Axes that receive the rectangle patches.

    Returns
    -------
    None
    """
    # Flattened (col, row) positions in the same order as two nested
    # loops over rows (outer) and columns (inner).
    cells = ((col, row) for row in range(n_rows) for col in range(n_cols))
    for index, (col, row) in enumerate(cells):
        square = mpatches.Rectangle([col, row], 1, 1,
                                    facecolor=colors[index])
        ax.add_patch(square)
    return None
#%% Configuration
n_colors = 6
n_cols = 3
n_rows = 2

# NOTE(review): recent SciPy releases deprecate `scipy.misc.face` in favor
# of `scipy.datasets.face` -- confirm against the installed version.
img = misc.face()
values, labels = cluster_colors(img, n_colors=n_colors)
img_compressed = compressed_img(values, labels, img.shape)

# The cluster centers come back on the scale produced by cluster_colors
# (image divided by 255), so they are scaled back to 0-255 integers here.
# The `np.int` alias no longer exists in current NumPy, hence the builtin
# `int`; rounding gives the nearest 8-bit value instead of truncating.
values_int = np.rint(255*values).astype(int)
# Plain Python ints are passed so rgb_to_hex does not depend on NumPy
# scalar types.
colors = [rgb_to_hex(tuple(int(c) for c in values_int[k, :]))
          for k in range(n_colors)]
print(colors)

#%% Visualization
# Top-left panel: original image.
# NOTE(review): the ylim/xlim calls on the two image panels look like
# leftovers from the palette panel; together with axis("image") and
# invert_yaxis() they presumably end up showing the picture upright --
# verify before simplifying, the pyplot calls are order dependent.
ax0 = plt.subplot(221)
plt.imshow(img)
plt.ylim(0, 2)
plt.xlim(0, 4)
plt.axis("image")
plt.axis("off")
ax0.invert_yaxis()

# Top-right panel: posterized image.
ax1 = plt.subplot(222)
plt.imshow(img_compressed)
plt.ylim(0, 2)
plt.xlim(0, 4)
plt.axis("image")
plt.axis("off")
ax1.invert_yaxis()

# Bottom panel: palette drawn as an n_rows x n_cols grid of squares.
ax2 = plt.subplot(212)
box_array(n_rows, n_cols, colors, ax2)
plt.ylim(0, 2)
plt.xlim(0, 4)
plt.axis("image")
plt.axis("off")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment