Created
January 9, 2020 18:36
-
-
Save nicoguaro/9d5ede0e1636905f5ade6a32980b4352 to your computer and use it in GitHub Desktop.
Posterization of an image using k-means
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Posterization of an image using k-means | |
It can be used to create color palettes from pictures. | |
@author:Nicolás Guarín-Zapata | |
@date: January 2020 | |
""" | |
import numpy as np | |
from scipy import misc | |
import matplotlib.pyplot as plt | |
from sklearn.cluster import KMeans | |
from webcolors import rgb_to_hex | |
import matplotlib.patches as mpatches | |
def cluster_colors(img, n_colors=8, random_state=None):
    """Cluster the pixel colors of an image with k-means.

    Parameters
    ----------
    img : array_like, shape (m, n, d)
        Image with ``d`` channel values per pixel. Values are divided by 255,
        so an 8-bit image (0-255) is mapped to the range [0, 1].
    n_colors : int, optional
        Number of clusters, i.e. colors in the resulting palette.
    random_state : int, RandomState instance or None, optional
        Forwarded to ``sklearn.cluster.KMeans``. Pass an int to obtain the
        same palette on every run; ``None`` (the default) keeps the previous
        unseeded behavior.

    Returns
    -------
    values : ndarray, shape (n_colors, d)
        Cluster centers, in the same 1/255-scaled units as the input.
    labels : ndarray, shape (m*n,)
        Cluster index assigned to each pixel, in row-major pixel order.
    """
    pixels = np.asarray(img) / 255
    m, n, d = pixels.shape  # requires a 3-D (rows, cols, channels) image
    k_means = KMeans(n_clusters=n_colors, random_state=random_state)
    k_means.fit(pixels.reshape((m*n, d)))
    return k_means.cluster_centers_, k_means.labels_
def compressed_img(values, labels, dims):
    """Rebuild an image in which every pixel takes its cluster's color.

    Parameters
    ----------
    values : array_like, shape (k, d)
        Palette: one row of ``d`` channel values per cluster.
    labels : array_like of int, shape (m*n,)
        Cluster index of each pixel, in row-major pixel order.
    dims : tuple of int
        Target image shape ``(m, n, d)``.

    Returns
    -------
    ndarray, shape (m, n, d)
        A new array; ``values`` is not modified.
    """
    m, n, d = dims
    # Fancy indexing yields a fresh (m*n, d) array; reshape it instead of
    # assigning to ``.shape``, which NumPy discourages.
    return np.asarray(values)[labels].reshape((m, n, d))
def box_array(n_rows, n_cols, colors, ax):
    """Draw a grid of unit squares on *ax*, one square per palette color.

    Squares are filled row by row, left to right. The square in grid cell
    (row, col) has its lower-left corner at (col, row).

    Parameters
    ----------
    n_rows, n_cols : int
        Grid dimensions.
    colors : iterable
        Matplotlib face colors (e.g. hex strings). If there are fewer colors
        than grid cells, only that many squares are drawn instead of raising
        IndexError. Colors beyond ``n_rows*n_cols`` are ignored.
    ax : matplotlib.axes.Axes
        Axes the rectangles are added to.
    """
    cells = ((row, col) for row in range(n_rows) for col in range(n_cols))
    for (row, col), color in zip(cells, colors):
        ax.add_patch(mpatches.Rectangle((col, row), 1, 1, facecolor=color))
#%% Configuration
n_colors = 6
n_cols = 3
n_rows = 2
# ``scipy.misc.face`` is deprecated (SciPy 1.10) in favor of
# ``scipy.datasets.face``. Prefer the new location, and fall back to the old
# one when it (or its optional ``pooch`` downloader) is unavailable.
try:
    from scipy import datasets
    img = datasets.face()
except ImportError:
    img = misc.face()
values, labels = cluster_colors(img, n_colors=n_colors)
img_compressed = compressed_img(values, labels, img.shape)
# Round (do not truncate) to the nearest 8-bit level. Centers computed from
# ``uint8/255`` data can land a hair below an integer after scaling back, and
# truncation would then shift that channel down by one. ``np.int`` was removed
# in NumPy 1.24, so the builtin ``int`` is used as the dtype.
values_int = np.rint(255*values).astype(int)
colors = [rgb_to_hex(tuple(rgb)) for rgb in values_int.tolist()]
print(colors)
#%% Visualization
def _image_panel(position, picture):
    """Show *picture* in subplot *position* with hidden axes; return the Axes."""
    panel = plt.subplot(position)
    plt.imshow(picture)
    # NOTE(review): ``axis("image")`` re-fits these limits to the image, but
    # the explicit increasing ylim appears to leave the y-axis non-inverted.
    # The final ``invert_yaxis`` then shows row 0 at the top.
    plt.ylim(0, 2)
    plt.xlim(0, 4)
    plt.axis("image")
    plt.axis("off")
    panel.invert_yaxis()
    return panel


ax0 = _image_panel(221, img)             # original picture
ax1 = _image_panel(222, img_compressed)  # posterized picture

# Bottom half: the palette squares drawn by box_array.
ax2 = plt.subplot(212)
box_array(n_rows, n_cols, colors, ax2)
plt.ylim(0, 2)
plt.xlim(0, 4)
plt.axis("image")
plt.axis("off")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment