Created February 28, 2020 17:52
-
-
Save MinaGabriel/2593eab805c18937b31d05320c85e37b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy | |
import os | |
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' | |
numpy.set_printoptions(threshold=sys.maxsize) | |
import torch | |
from torch.distributions import normal | |
import seaborn as sns; | |
sns.set(style="white", color_codes=True) | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from sklearn.cluster import KMeans | |
def random_data():
    """Build a synthetic 2-D dataset made of three Gaussian blobs.

    Rows 0-399 come from Normal(25, 6) and are then right-multiplied by a
    fixed 2x2 matrix. Rows 400-699 come from Normal(10, 6). Rows 700-1199
    come from Normal(50, 5). Every coordinate is drawn independently.

    Returns:
        torch.Tensor of shape (1200, 2), the blobs stacked in that order.
    """
    # (mean, std, number of rows), in the order the samples are drawn
    blob_specs = ((25, 6, 400), (10, 6, 300), (50, 5, 500))
    blobs = [
        normal.Normal(mean, std).sample((rows, 2))
        for mean, std, rows in blob_specs
    ]
    # Skew the first blob so it is not axis-aligned
    skew = torch.tensor([[0.3754, 0.9254], [0.5452, 0.6729]])
    blobs[0] = torch.mm(blobs[0], skew)
    return torch.cat(blobs, 0)
def random_centroids(data_points, k):
    """Draw ``k`` random initial centroids inside the data's bounding box.

    Args:
        data_points: array-like of shape (n, d).
        k: number of centroids to draw.

    Returns:
        numpy.ndarray of dtype float32 and shape (k, d). Column ``j`` is
        drawn uniformly between ``min(data_points[:, j])`` and
        ``max(data_points[:, j])``.
    """
    points = np.asarray(data_points)
    # Sample each coordinate from that axis's own range. The earlier version
    # drew integers from [0, max over *all* values). That never placed a
    # centroid at a negative coordinate, ignored per-axis ranges, only
    # produced 2-D centroids, and raised ValueError when the max was <= 0.
    low = points.min(axis=0)
    high = points.max(axis=0)
    center = np.random.uniform(low, high, size=(k, points.shape[1]))
    return center.astype(np.float32)
def get_distance(data_points, centroids, show_data=False):
    """Return the Euclidean distance from every point to every centroid.

    Args:
        data_points: array-like of shape (n, d).
        centroids: array-like of shape (k, d).
        show_data: if True, print a DataFrame. Its first d columns are the
            point coordinates and its remaining k columns are the distances.

    Returns:
        numpy.ndarray of dtype float64 and shape (n, k).
        ``distances[i, j]`` is the distance from point ``i`` to centroid ``j``.
    """
    points = np.asarray(data_points)
    centers = np.asarray(centroids)
    # Broadcast (n, 1, d) against (1, k, d) to get every pairwise difference
    # in one vectorized step. The previous per-pair np.append loop copied
    # the whole array on each append, which is quadratic in n*k.
    diffs = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
    # float64 matches the dtype the previous np.append-based version returned
    distances = np.linalg.norm(diffs, axis=2).astype(np.float64)
    if show_data:
        print(pd.DataFrame(data=np.concatenate((points, distances), axis=1)))
    return distances
def get_labels(distances):
    """Assign each row to the column holding its smallest distance.

    Args:
        distances: 2-D array-like whose rows are points and whose columns
            are centroids.

    Returns:
        Array of column indices, one per row.
    """
    nearest_centroid = np.argmin(distances, axis=1)
    return nearest_centroid
def plotting(data_points, centroids, labels, sklearn=False):
    """Scatter-plot the points coloured by cluster label, with centroids as stars.

    Args:
        data_points: array-like of shape (n, 2).
        centroids: array-like of shape (k, 2).
        labels: length-n sequence. ``labels[j]`` is the cluster index of
            point ``j``.
        sklearn: if True, also fit scikit-learn's KMeans with ``k`` clusters
            on ``data_points``. Its centres are overlaid as lime '+' markers
            for comparison.

    Calls ``plt.show()`` before returning.
    """
    points = np.asarray(data_points)
    label_arr = np.asarray(labels)
    centers = np.asarray(centroids)
    _, ax = plt.subplots()
    colors = ['r', 'g', 'b', 'y', 'c', 'm', 'k', 'gold', 'blue', 'lime', 'purple', 'darkgreen',
              'ivory', 'magenta', 'plum', 'salmon', 'orange', 'navy']
    for i in range(len(centers)):
        # Boolean mask replaces the per-point Python loop
        members = points[label_arr == i]
        if len(members) != 0:
            ax.scatter(members[:, 0], members[:, 1], s=10, c=colors[i % len(colors)])
    ax.scatter(centers[:, 0], centers[:, 1], marker='*', s=25, c='#050505')
    if sklearn:
        # Fit on the points being plotted. The previous version read the
        # module-level global X, so the reference centres could belong to a
        # different dataset. It also overwrote the `labels` argument with an
        # unused prediction.
        kmeans = KMeans(n_clusters=len(centers)).fit(points)
        reference = kmeans.cluster_centers_
        ax.scatter(reference[:, 0], reference[:, 1], marker='+', c='lime', s=100)
    plt.show()
def orphan_centroid(data_points):
    """Return the fallback position for a centroid with no assigned points.

    The fallback is the column-wise mean of all ``data_points``.
    """
    column_means = np.mean(data_points, axis=0)
    return column_means
def update_centroids(data_points, centroids):
    """Perform one k-means (Lloyd) step.

    First assign every point to its nearest centroid. Then move each
    centroid to the mean of the points assigned to it.

    Args:
        data_points: array-like of shape (n, d).
        centroids: array-like of shape (k, d).

    Returns:
        numpy.ndarray of shape (k, d) holding the updated centroids. A
        centroid that received no points is replaced by
        ``orphan_centroid(data_points)``.
    """
    points = np.asarray(data_points)
    # Use the arguments. The previous version called get_distance(X, c) on
    # the module-level globals and silently ignored its inputs.
    labels = get_labels(get_distance(points, centroids))
    new_centroids = []
    for i in range(len(centroids)):
        members = points[labels == i]
        if len(members) != 0:
            new_centroids.append(members.mean(axis=0))
        else:
            # Orphan centroid: no point is nearest to it, so its mean is
            # undefined. Fall back to orphan_centroid (mean of all points).
            new_centroids.append(orphan_centroid(points))
    return np.asarray(new_centroids)
def get_error(old_centroids, new_centroids):
    """Measure how far the centroids moved between two iterations.

    Returns:
        The Frobenius norm (Euclidean norm over all entries) of
        ``new_centroids - old_centroids``. It is 0 exactly when nothing moved.
    """
    shift = new_centroids - old_centroids
    return np.linalg.norm(shift, axis=None)
# Data drawn from three Gaussian blobs (see random_data)
X = random_data().data.numpy()
k = 25
# Alternative: nonsense data, i.e. uniformly random integer points with no
# cluster structure.
# X = np.random.randint(19, size=(50, 2))

# Upper bound on Lloyd iterations. The loop normally stops earlier, as soon
# as the centroids stop moving.
MAX_ITERATIONS = 300

c = random_centroids(X, k)
dist = get_distance(X, c)
labels = get_labels(dist)
plotting(X, c, labels, False)

# Centroid displacement per iteration. The previous version seeded this
# list with the distance between two unrelated random centroid sets (only
# used to enter the loop) and looped without any iteration cap.
errors = []
for _ in range(MAX_ITERATIONS):
    new = update_centroids(X, c)
    dist = get_distance(X, new)
    labels = get_labels(dist)
    plotting(X, new, labels, False)
    error = get_error(c, new)
    c = new
    errors.append(error)
    if error == 0:
        # Converged: recomputing the means reproduces the same centroids
        break
print(errors)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.