Created
May 10, 2018 17:13
-
-
Save koshian2/74715968c10b0cdaec2329b9cc0f0542 to your computer and use it in GitHub Desktop.
Coursera Machine LearningをPythonで実装 - [Week8]k-Means, 主成分分析(PCA)(2)k-Means、組み込み
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.cluster import KMeans | |
from scipy.io import loadmat | |
## 1.2次元のクラスタリング | |
# data2の読み込み | |
X = np.array(loadmat("ex7data2.mat")['X']) | |
# K=3でクラスタリング | |
# initで初期値のとり方を選ぶ。デフォルトのk-means++(randomも選べる) | |
# n_initは異なる初期値で回す回数 | |
clst = KMeans(n_clusters = 3, init="k-means++", n_init=10) | |
# fitで計算 | |
clst.fit(X) | |
# cluster_centers_で重心の座標 | |
print("cluster_centers_ : ") | |
print(clst.cluster_centers_) | |
print() | |
# labels_でどの重心かのインデックスを表示 | |
print("labels_ (Take 10) : ") | |
print(clst.labels_[:10]) | |
# 可視化 | |
palette = np.linspace(0, 1, clst.n_clusters + 1) | |
color = palette[np.array(clst.labels_, dtype=int)] | |
plt.scatter(X[:, 0], X[:, 1], 15, c=plt.cm.hsv(color)) | |
plt.plot(clst.cluster_centers_[:, 0], clst.cluster_centers_[:, 1], "x", color="k", markersize=10, linewidth=3) | |
plt.show() | |
## 2.画像の減色 | |
from PIL import Image | |
# 画像の読み込み | |
img = np.asarray(Image.open("bird_small.png")) #(128,128,3)のテンソル | |
img_array = img.reshape(img.shape[0]*img.shape[1], img.shape[2]) / 255 #(128*128, 3)の行列に変形 | |
# k-meansで減色 | |
# n_clusters:減色数 | |
clst = KMeans(n_clusters=16) | |
label = clst.fit_predict(img_array) | |
# 画像として復元 | |
img_array_quant = clst.cluster_centers_[label] #(128*128, 3)の行列 | |
img_quant = img_array_quant.reshape(img.shape) | |
# プロットに表示(imshowで上手く表示するために[0,1]のfloatにしておくこと) | |
fig = plt.figure() | |
ax = fig.add_subplot(1, 2, 1, xticks=[], yticks=[]) | |
ax.imshow(img) | |
ax.set_title("Original") | |
ax = fig.add_subplot(1, 2, 2, xticks=[], yticks=[]) | |
ax.imshow(img_quant) | |
ax.set_title("k-Means Color Quantization (K = 16)") | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment