Created
December 14, 2017 14:52
-
-
Save DavidSanf0rd/af63e4cb0a2bf350f7b51db07a01aefa to your computer and use it in GitHub Desktop.
IC: Kmeans
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
import numpy as npy | |
import pandas as pd | |
from matplotlib import pyplot as plt | |
class KMeans:
    """Cluster rows of a CSV file with Lloyd's k-means algorithm and plot them.

    Parameters
    ----------
    k : int
        Number of clusters.
    file_name : str
        Path of the CSV file read by :meth:`load_dataset`.
    columns : sequence of str, optional
        Feature columns to cluster on. Defaults to ``('V1', 'V2')``, the
        columns this class originally hard-coded. :meth:`run` plots the
        first two of them, so plotting needs at least two columns.
    max_iter : int, optional
        Upper bound on Lloyd iterations, so a run can never loop forever.
    random_state : int or None, optional
        Seed for choosing the initial centroids; ``None`` seeds from the OS.
    """

    # Plot colours for the clusters; cycled when k exceeds their count.
    COLORS = ('b', 'r', 'g', 'y', 'c', 'm')

    def __init__(self, k, file_name, columns=('V1', 'V2'), max_iter=300,
                 random_state=None):
        self.k = k
        self.file_name = file_name
        self.columns = tuple(columns)
        self.max_iter = max_iter
        self.random_state = random_state
        self.dataset = None  # filled by load_dataset()

    def load_dataset(self):
        """Load the CSV file at ``self.file_name`` into ``self.dataset``."""
        self.dataset = pd.read_csv(self.file_name)

    @staticmethod
    def distance(a, b, ax=1):
        """Return the Euclidean norm of ``a - b`` taken along axis ``ax``.

        With ``ax=None`` the norm covers the whole flattened difference,
        which is how the convergence check measures total centroid movement.
        """
        return npy.linalg.norm(a - b, axis=ax)

    def _features(self):
        """Return the selected columns as an (n_samples, n_features) float array."""
        if self.dataset is None:
            raise RuntimeError('call load_dataset() before running k-means')
        return self.dataset[list(self.columns)].values.astype(float)

    def _initial_centroids(self, X):
        """Pick k distinct data rows at random as starting centroids.

        Drawing from distinct rows (rather than from a bounding box) places
        every starting centroid on real data, and avoids duplicate starting
        centroids, which would immediately leave one cluster empty.

        Raises
        ------
        ValueError
            If k is not between 1 and the number of distinct rows.
        """
        candidates = npy.unique(X, axis=0)
        if not 1 <= self.k <= len(candidates):
            raise ValueError(
                'k must be between 1 and the number of distinct samples '
                '({}), got {}'.format(len(candidates), self.k))
        rng = npy.random.RandomState(self.random_state)
        picked = rng.choice(len(candidates), size=self.k, replace=False)
        return candidates[picked]

    def _lloyd(self, X, centroids):
        """Refine ``centroids`` with Lloyd iterations on ``X``.

        Stops when no centroid moves or after ``self.max_iter`` iterations.
        The input ``centroids`` is not modified.

        Returns
        -------
        (centroids, labels) : float array of shape (k, n_features) and the
            integer index of the centroid assigned to each row of ``X``.
        """
        centroids = npy.array(centroids, dtype=float)
        labels = npy.zeros(len(X), dtype=int)
        for _ in range(self.max_iter):
            # (n_samples, k) matrix of point-to-centroid distances.
            dists = self.distance(X[:, npy.newaxis, :],
                                  centroids[npy.newaxis, :, :], ax=2)
            labels = npy.argmin(dists, axis=1)
            previous = centroids.copy()
            for i in range(len(centroids)):
                members = X[labels == i]
                # An empty cluster keeps its previous centroid; the mean of
                # no points is NaN and would poison every later iteration.
                if len(members):
                    centroids[i] = members.mean(axis=0)
            if self.distance(centroids, previous, None) == 0:
                break
        return centroids, labels

    def fit(self):
        """Cluster the loaded dataset without plotting.

        Returns
        -------
        (centroids, labels) : final centroids, shape (k, n_features), and the
            cluster index of every dataset row.
        """
        X = self._features()
        return self._lloyd(X, self._initial_centroids(X))

    def run(self):
        """Run k-means on the loaded dataset and plot the results.

        The first figure shows all points (black) with the initial centroids
        (green stars). The second shows each cluster in its own colour with
        the final centroids. Returns ``(centroids, labels)`` like :meth:`fit`.
        """
        X = self._features()
        initial = self._initial_centroids(X)
        plt.scatter(X[:, 0], X[:, 1], c='black', s=7)  # Plot all dataset points
        plt.scatter(initial[:, 0], initial[:, 1], marker='*', s=200, c='g')

        centroids, labels = self._lloyd(X, initial)

        fig, ax = plt.subplots()
        for i in range(self.k):
            # Boolean masking keeps the array 2-D even for an empty cluster,
            # so the column indexing below cannot fail.
            members = X[labels == i]
            ax.scatter(members[:, 0], members[:, 1], s=7,
                       c=self.COLORS[i % len(self.COLORS)])
        ax.scatter(centroids[:, 0], centroids[:, 1], marker='*', s=200,
                   c='#050505')
        plt.show()
        return centroids, labels
if __name__ == '__main__':
    # Demo entry point: only read the CSV and open plot windows when this
    # file is executed as a script, not when the module is imported.
    k_means = KMeans(k=3, file_name='xclara.csv')
    k_means.load_dataset()
    k_means.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment