Skip to content

Instantly share code, notes, and snippets.

@DavidSanf0rd
Created December 14, 2017 14:52
Show Gist options
  • Save DavidSanf0rd/af63e4cb0a2bf350f7b51db07a01aefa to your computer and use it in GitHub Desktop.
Save DavidSanf0rd/af63e4cb0a2bf350f7b51db07a01aefa to your computer and use it in GitHub Desktop.
IC: Kmeans
from copy import deepcopy
import numpy as npy
import pandas as pd
from matplotlib import pyplot as plt
class KMeans:
    """K-means clustering of the ``V1`` and ``V2`` columns of a CSV file.

    Typical use is ``KMeans(k, file_name)`` followed by ``run()``, which
    clusters the points and shows the result with matplotlib.
    """

    def __init__(self, k, file_name):
        """Store the number of clusters and the CSV path; nothing is read yet."""
        self.k = k
        self.file_name = file_name
        self.dataset = None

    def load_dataset(self):
        """Load the csv file into the ``dataset`` attribute."""
        self.dataset = pd.read_csv(self.file_name)

    @staticmethod
    def distance(a, b, ax=1):
        """Calculate the euclidean distance between two points (or arrays).

        ``ax`` is forwarded to ``numpy.linalg.norm`` as ``axis``; passing
        ``None`` yields one scalar norm over the whole difference.
        """
        return npy.linalg.norm(a - b, axis=ax)

    def _cluster(self, X, centroids, max_iter=300):
        """Iterate assignment/update steps until the centroids stop moving.

        Args:
            X: ``(n_points, n_features)`` float array.
            centroids: ``(n_clusters, n_features)`` starting centroids; the
                caller's array is not modified.
            max_iter: upper bound on iterations, so that input which can
                never converge (e.g. NaN coordinates) cannot loop forever.

        Returns:
            ``(centroids, labels)`` where ``labels[j]`` is the index of the
            centroid closest to ``X[j]``.
        """
        c = npy.array(centroids, dtype=float)  # private copy
        labels = npy.zeros(len(X), dtype=int)
        for _ in range(max_iter):
            # Broadcast (n, 1, d) against (1, k, d) to get an (n, k) table
            # of point-to-centroid distances in a single call.
            distances = self.distance(X[:, None, :], c[None, :, :], 2)
            labels = npy.argmin(distances, axis=1)
            c_old = c.copy()
            for i in range(len(c)):
                members = X[labels == i]
                # A cluster with no members keeps its previous centroid;
                # taking the mean of nothing would give NaN and the
                # convergence test below could then never succeed.
                if len(members):
                    c[i] = members.mean(axis=0)
            if self.distance(c, c_old, None) == 0:
                break
        return c, labels

    def run(self):
        """Run the algorithm and plot the results.

        Loads the dataset first if ``load_dataset`` has not been called.

        Raises:
            ValueError: if ``k`` is smaller than 1 or larger than the number
                of data points.
        """
        if self.dataset is None:
            self.load_dataset()

        # Get the dataset columns
        f1 = self.dataset['V1'].values
        f2 = self.dataset['V2'].values
        X = npy.column_stack((f1, f2)).astype(float)
        if not 1 <= self.k <= len(X):
            raise ValueError(
                "k must be between 1 and the number of data points "
                "(%d), got %r" % (len(X), self.k)
            )

        plt.scatter(f1, f2, c='black', s=7)  # Plot all dataset points

        # Seed the centroids with k distinct data points so they always lie
        # inside the data, whatever its range or sign.
        seeds = npy.random.choice(len(X), size=self.k, replace=False)
        c = X[seeds]
        # Plot the initial centroids
        plt.scatter(c[:, 0], c[:, 1], marker='*', s=200, c='g')

        c, clusters = self._cluster(X, c)

        # Plotting
        colors = ['b', 'r', 'g', 'y', 'c', 'm']
        _, ax = plt.subplots()
        for i in range(self.k):
            points = X[clusters == i]
            if len(points) == 0:
                continue  # nothing to draw for an empty cluster
            # Cycle through the palette so k > len(colors) still works.
            ax.scatter(points[:, 0], points[:, 1], s=7,
                       c=colors[i % len(colors)])
        ax.scatter(c[:, 0], c[:, 1], marker='*', s=200, c='#050505')
        plt.show()
# Demo driver: cluster the points in 'xclara.csv' into three groups,
# reading the file first and then computing and plotting the clusters.
k_means = KMeans(file_name='xclara.csv', k=3)
k_means.load_dataset()
k_means.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment