Skip to content

Instantly share code, notes, and snippets.

@palloc
Last active July 7, 2017 04:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save palloc/02356b30c69e9f3c8343f4f867f4880b to your computer and use it in GitHub Desktop.
Save palloc/02356b30c69e9f3c8343f4f867f4880b to your computer and use it in GitHub Desktop.
homework
from math import *
import random
import numpy as np
import collections
class Kmeans:
    """Minimal k-means clustering over integer CSV data.

    Typical use: construct with the number of clusters, call
    ``read_data`` to load samples and pick initial centroids, then call
    ``fit`` to run labeling / centroid-update epochs.

    Attributes set by the methods:
        cluster_num: number of clusters requested at construction.
        counter: number of epochs run so far by ``fit``.
        error: sum of each sample's distance to its nearest centroid,
            as computed by the most recent ``labeling`` call.
        data: 2-D integer array of samples (set by ``read_data``).
        label: list with one cluster index per sample.
        centroid: 2-D array with one row per cluster.
    """

    def __init__(self, cluster_num):
        """Store the cluster count and reset the epoch counter and error."""
        self.cluster_num = cluster_num
        self.counter = 0
        self.error = 0.

    # Load the data set
    def read_data(self, data_path):
        """Load samples from a CSV file and choose initial centroids.

        Each non-blank line is parsed as comma-separated integers. Blank
        lines (for example the one produced by a trailing newline) are
        skipped instead of failing in ``int('')``.

        Args:
            data_path: path of the CSV file to read.

        Raises:
            ValueError: if the file contains no samples, a field is not an
                integer, or the rows do not all have the same length.
        """
        with open(data_path) as file:
            rows = [
                [int(value) for value in line.split(',')]
                for line in file
                if line.strip()
            ]
        if not rows:
            raise ValueError("no samples found in {!r}".format(data_path))
        self.data = np.array(rows, dtype=int)
        if self.data.ndim != 2:
            raise ValueError("rows must be non-empty and of equal length")
        self.label = [0] * len(self.data)
        # randrange excludes the upper bound; randint(0, len) could return
        # len(self.data) and index one past the last sample.
        self.centroid = np.array(
            [self.data[random.randrange(len(self.data))]
             for _ in range(self.cluster_num)]
        )

    # Labeling step
    def labeling(self):
        """Assign every sample to its nearest centroid.

        Updates ``self.label`` with the index of the closest centroid for
        each sample (the lowest index wins ties) and ``self.error`` with
        the sum of those minimum Euclidean distances.
        """
        # Broadcast (n, 1, d) against (1, k, d) to get all pairwise offsets.
        offsets = self.data[:, np.newaxis, :] - self.centroid[np.newaxis, :, :]
        distances = np.linalg.norm(offsets, axis=2)
        self.label = distances.argmin(axis=1).tolist()
        self.error = float(distances.min(axis=1).sum())

    # Centroid update
    def update_centroid(self):
        """Move each centroid to the mean of the samples labeled with it.

        The feature dimension is taken from the current centroids rather
        than being fixed. A cluster with no members keeps its previous
        centroid instead of being reset to the origin.
        """
        labels = np.asarray(self.label)
        # Float copy: the initial centroids are integer rows of the data.
        updated = np.array(self.centroid, dtype=float)
        for cluster in range(self.cluster_num):
            members = self.data[labels == cluster]
            if len(members):
                updated[cluster] = members.mean(axis=0)
        self.centroid = updated

    # Training
    def fit(self, epoch=10):
        """Run ``epoch`` rounds of labeling followed by a centroid update.

        The state is printed at the start of every round, i.e. before that
        round's labeling, so the values shown are those left by the
        previous round (or by ``read_data`` on the first one).

        Args:
            epoch: number of rounds to run.
        """
        for _ in range(epoch):
            print("--------epoch {}---------".format(self.counter))
            print("centroid:")
            print(self.centroid)
            print("label:")
            print(self.label)
            print("error:")
            print(self.error)
            self.labeling()
            self.update_centroid()
            self.counter += 1
if __name__ == '__main__':
    # Script entry point: cluster the samples in data.csv into five groups
    # and run three epochs, printing the state at the start of each.
    model = Kmeans(cluster_num=5)
    model.read_data(data_path='data.csv')
    model.fit(3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment