Skip to content

Instantly share code, notes, and snippets.

@nmoya
Created July 25, 2018 20:26
Show Gist options
  • Save nmoya/8db2db7c282ff68e0df6145fdde6fdb8 to your computer and use it in GitHub Desktop.
Save nmoya/8db2db7c282ff68e0df6145fdde6fdb8 to your computer and use it in GitHub Desktop.
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import pandas as pd
import numpy as np
def color_fn(values):
    """Scale 0-255 channel values down to the 0-1 float range.

    Each item of *values* is an iterable of channel values. The result holds
    one list per item, with every channel divided by 255. In this script the
    result is passed to matplotlib's ``c=`` argument.
    """
    return [[channel / 255. for channel in pixel] for pixel in values]
def organize_by_label(features, labels):
    """Group feature rows by their cluster label.

    Args:
        features: sequence of per-pixel values. Each item is a channel triple
            (list or 1-D array), or anything whose concatenation is a flat
            run of triples.
        labels: indexable sequence where ``labels[i]`` is the label of
            ``features[i]``.

    Returns:
        dict mapping each label, in order of first appearance, to an
        ``(n, 3)`` ndarray of the rows carrying that label.
    """
    grouped = {}
    for i, value in enumerate(features):
        # Collect rows in a list and concatenate once per label below.
        # Concatenating inside the loop would re-copy the accumulated array
        # on every row, which is quadratic in the cluster size.
        grouped.setdefault(labels[i], []).append(value)
    # np.concatenate always returns an ndarray, so a label with a single row
    # is reshaped too instead of being left as a plain list.
    return {
        label: np.concatenate(rows).reshape(-1, 3)
        for label, rows in grouped.items()
    }
def average_colors(labels):
    """Compute the mean color of each cluster.

    Args:
        labels: mapping of cluster label -> the pixels assigned to that
            label. Presumably an ``(n, 3)`` array, as built by
            ``organize_by_label``; verify for other callers.

    Returns:
        dict with the same keys in the same order. Each key maps to
        ``np.average(pixels, axis=0)``, the column-wise mean.
    """
    return {label: np.average(labels[label], axis=0) for label in labels}
def plot(data, kmeans, average_colors, output_path="./color/clustering.png"):
    """Scatter the pixels in 3-D color space and overlay the cluster centers.

    Args:
        data: DataFrame with "r", "g", "b" columns, one row per pixel. Each
            point is drawn in its own color.
        kmeans: fitted estimator exposing ``cluster_centers_`` (one row per
            cluster label).
        average_colors: mapping of cluster label -> mean color of the pixels
            with that label.
        output_path: file the figure is saved to.
    """
    centers = kmeans.cluster_centers_
    # Color each center by the average color of its own cluster. The mapping
    # is ordered by first appearance of a label, not by label index, so it
    # must be looked up by index. A cluster with no entry falls back to its
    # center.
    center_colors = [average_colors.get(k, centers[k]) for k in range(len(centers))]
    pixels = data[["r", "g", "b"]]
    fig = plt.figure(figsize=(12, 12))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pixels["r"], pixels["g"], pixels["b"],
                   marker='o', c=color_fn(pixels.values))
        ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2],
                   marker='o', s=200, c=color_fn(center_colors))
        ax.set_title("Color classification")
        fig.savefig(output_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
def main(image_path="./color/castle.jpg", output_path="./color/output.png", n_clusters=128):
    """Quantize an image to ``n_clusters`` colors with k-means.

    Every pixel is replaced by the average color of its cluster, and the
    result is written to *output_path*. A 3-D scatter of the clustering is
    saved by ``plot``.

    Raises:
        FileNotFoundError: if OpenCV cannot read *image_path*.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    # OpenCV loads channels in BGR order. Convert so the "r", "g", "b" columns
    # and the saved RGB image are not red/blue swapped.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_array = pd.DataFrame(img.reshape(-1, 3), columns=["r", "g", "b"])
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(img_array)
    labels = kmeans.labels_
    # Pass ndarray rows (not .tolist()) so every group is built as an array.
    labels_dict = organize_by_label(img_array.values, labels)
    avg_colors = average_colors(labels_dict)
    # Build a label -> color palette and recolor all pixels in one vectorized
    # lookup. A cluster with no members keeps its k-means center.
    palette = kmeans.cluster_centers_.astype(float)
    for label, color in avg_colors.items():
        palette[label] = color
    outpixels = palette[labels].reshape(img.shape)
    # Round rather than truncate when converting float averages to bytes.
    # Pillow infers RGB mode from an (h, w, 3) uint8 array.
    output = Image.fromarray(np.clip(np.rint(outpixels), 0, 255).astype(np.uint8))
    output.save(output_path)
    plot(img_array, kmeans, avg_colors)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment