-
-
Save Azoay/f170322df555f97e4ceb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.spatial.distance import pdist | |
from scipy.cluster.hierarchy import linkage, dendrogram | |
def hcluster(method='average'):
    """Plot a dendrogram for a fixed five-item example.

    The ten hard-coded values are passed to ``linkage`` as a 1-D array,
    which SciPy treats as a condensed pairwise-distance matrix
    (10 = 5 * 4 / 2 entries, matching the five labels 'a'..'e').

    Args:
        method: linkage method name forwarded to
            ``scipy.cluster.hierarchy.linkage``.
    """
    condensed = np.array([1, 2, 2, 3, 2, 4, 3, 1, 5, 3], dtype=float)
    names = np.array(['a', 'b', 'c', 'd', 'e'])
    dendrogram(
        linkage(condensed, method=method),
        orientation='left',
        labels=names,
    )
    plt.show()
def main():
    """Show one dendrogram per linkage method, in a fixed order."""
    # 'weighted' linkage is currently disabled.
    for linkage_method in ('single', 'complete', 'average'):
        hcluster(linkage_method)
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def wrap_k_means(data, k=3):
    """Run ``k_means`` on *data* starting from ``k`` random centroids.

    Args:
        data: 1-D array-like of numbers (list, tuple or ndarray).
        k: number of initial centroids to draw.

    Returns:
        Whatever ``k_means`` returns for the drawn initial centroids.
    """
    # Normalise first so plain lists work too (they have no ``.max()``).
    data = np.asarray(data, dtype=float)
    # Random initial centroids, drawn uniformly from [0, max(data)).
    cluster = np.sort(np.random.random(k) * data.max())
    return k_means(data, cluster)
def k_means(data, cluster):
    """Cluster 1-D *data* with k-means, printing every iteration.

    Each point is assigned to the nearest centroid (absolute difference,
    the lowest-index centroid wins ties), then every centroid that has at
    least one member is moved to the mean of its members.  A centroid
    without members keeps its current position.  Iteration stops as soon
    as an update leaves all centroids unchanged.

    Args:
        data: 1-D array-like of numbers.
        cluster: 1-D array-like of initial centroids; it is sorted
            ascending before use.

    Returns:
        A tuple ``(labels, centroids)``: ``labels`` is an int array giving
        the centroid index of every element of *data*, ``centroids`` is
        the float array of final centroid positions.
    """
    data = np.array(data, dtype=float)
    cluster = np.sort(np.array(cluster, dtype=float))
    # Cluster index of each element of data (kept as floats so the printed
    # trace looks the same as before).
    prof = np.zeros(len(data))
    count = 0
    while True:
        count += 1
        # Assignment step: nearest centroid per point.  argmin has no
        # "large enough" sentinel, so arbitrarily distant points are
        # still assigned correctly.
        if cluster.size:
            dist = np.abs(data[:, None] - cluster[None, :])
            prof[:] = dist.argmin(axis=1)
        # Update step: snapshot every centroid first so the convergence
        # check below compares like with like, including empty clusters.
        old_cluster = cluster.copy()
        for j in range(len(cluster)):
            members = data[prof == j]
            # Decide on the member count, not on the member sum: a
            # cluster whose members sum to zero must still be updated.
            if members.size:
                cluster[j] = members.mean()
        # Progress report
        print("{}回目".format(count))
        print("data : ", data)
        print("prof : ", prof)
        print("cluster: ", cluster)
        # Convergence check: stop once no centroid moved.
        if np.array_equal(cluster, old_cluster):
            break
    # Final result
    print("result")
    print("data : ", data)
    print("prof : ", prof)
    print("cluster: ", cluster)
    return prof.astype(int), cluster
def main():
    """Run k-means on a fixed sample with three sets of initial centroids."""
    data = np.array([2, 3, 4, 10, 11, 12, 20, 25, 30], dtype=float)
    # Cases (1), (2) and (3), in order.
    initial_sets = ([2, 20], [2, 3, 10], [12, 25, 30])
    for idx, centroids in enumerate(initial_sets):
        if idx:
            # Blank line between consecutive runs.
            print()
        k_means(data, centroids)
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment