# Source: GitHub Gist fuad021/0d3c432338875ca7ffd21860f8d42c16 by @fuad021
# (last active November 28, 2020).
# -*- coding: utf-8 -*-
# =============================================
# Author: Fuad Al Abir
# Date: 22 Feb 2020
# Problem: K-Means
# Course: CSE 3210
# =============================================
def initCenter(_dataList, k):
    """Return the first ``k`` data points as the initial cluster centers.

    Args:
        _dataList: Indexable sequence of data points.
        k: Number of centers to pick.

    Returns:
        A new list holding ``_dataList[0]`` .. ``_dataList[k - 1]``; empty
        when ``k <= 0``.

    Raises:
        IndexError: If ``k`` is larger than the number of data points.
    """
    # Explicit indexing (rather than slicing) keeps the IndexError for an
    # oversized k instead of silently returning fewer centers.
    return [_dataList[i] for i in range(k)]
def updateCenter(_clusterList, dataList=None):
    """Recompute every cluster center as the mean of its member points.

    Args:
        _clusterList: Cluster label of each data point, aligned by index with
            the data.
        dataList: Optional data points to average. When omitted, the
            module-level ``_dataList`` is used, so existing one-argument
            calls keep working.

    Returns:
        A list in which index ``i`` is the mean of the points labelled ``i``,
        covering labels ``0`` .. ``max(_clusterList)``. A label in that range
        that has no members gets a NaN center (it compares false against
        every distance, so it is never the nearest center) instead of
        raising ZeroDivisionError. An empty label list yields ``[]``.
    """
    points = _dataList if dataList is None else dataList
    if not _clusterList:
        return []

    # Group the points by label in a single pass.
    members = {}
    for label, value in zip(_clusterList, points):
        members.setdefault(label, []).append(value)

    centers = []
    for label in range(max(_clusterList) + 1):
        group = members.get(label)
        if group:
            centers.append(sum(group) / len(group))
        else:
            # Keep the slot so that list index == cluster label.
            centers.append(float("nan"))
    return centers
def dist(dataPoint, center):
    """Return the absolute difference between a data point and a center.

    Args:
        dataPoint: Numeric value of the data point.
        center: Numeric value of the cluster center.

    Returns:
        ``|dataPoint - center|``.
    """
    return abs(dataPoint - center)
def distList(_dataList, center):
    """Return the distance from every data point to one center.

    Args:
        _dataList: Sequence of data points.
        center: The center to measure against.

    Returns:
        A list with ``dist(point, center)`` for each point, in input order.
    """
    return [dist(point, center) for point in _dataList]
def dist2DList(_dataList, _centerList):
    """Build the center-by-point distance table.

    Args:
        _dataList: Sequence of data points.
        _centerList: Sequence of cluster centers.

    Returns:
        A list of rows, one per center; row ``j`` is
        ``distList(_dataList, _centerList[j])``.
    """
    return [distList(_dataList, center) for center in _centerList]
def clusterList(_dist2DList, k):
    """Label every data point with the index of its nearest center.

    Args:
        _dist2DList: Distance table; row ``j`` holds the distance from
            center ``j`` to each data point.
        k: Number of centers to consider. Only the rows that actually exist
            are compared, so a table with fewer than ``k`` rows no longer
            raises IndexError.

    Returns:
        A list with one label per column of the first row. Ties go to the
        lowest center index. A point with no comparable distance (no rows
        considered, or every distance infinite/NaN) keeps the out-of-range
        marker ``k + 1``.

    Raises:
        IndexError: If ``_dist2DList`` has no rows.
    """
    point_count = len(_dist2DList[0])
    center_count = min(k, len(_dist2DList))

    labels = []
    for column in range(point_count):
        nearest = k + 1
        # Start from infinity rather than a magic 999 so that data with
        # large distances is still assigned to a real cluster.
        shortest = float("inf")
        for j in range(center_count):
            candidate = _dist2DList[j][column]
            if candidate < shortest:
                shortest = candidate
                nearest = j
        labels.append(nearest)
    return labels
def iteration(_dataList, _clusterList, k):
    """Run one refinement step: recompute centers, then reassign points.

    Args:
        _dataList: Sequence of data points.
        _clusterList: Current cluster label of each data point.
        k: Number of clusters.

    Returns:
        The new list of cluster labels, one per data point.
    """
    # NOTE(review): updateCenter is given only the labels, so the centers are
    # not derived from this function's _dataList argument - TODO confirm both
    # always refer to the same data.
    centers = updateCenter(_clusterList)
    return clusterList(dist2DList(_dataList, centers), k)
def initCluster(_datalist, k):
    """Produce the initial cluster labels from the first ``k`` data points.

    Args:
        _datalist: Sequence of data points. (The lowercase ``l`` in the name
            is kept so keyword callers are unaffected.)
        k: Number of clusters.

    Returns:
        A list of cluster labels, one per data point.
    """
    # Use the argument itself; the body previously referred to the
    # module-level ``_dataList`` and silently ignored what was passed in.
    centers = initCenter(_datalist, k)
    distances = dist2DList(_datalist, centers)
    return clusterList(distances, k)
def kMeans(_dataList, k, plot_each_step=False):
    """Cluster the data into ``k`` groups, iterating until labels stop changing.

    Args:
        _dataList: Sequence of data points.
        k: Number of clusters; must be between 1 and ``len(_dataList)``.
        plot_each_step: When truthy, plot the clustering at every step;
            otherwise plot only the final result.

    Returns:
        A ``(labels, step)`` tuple with the final cluster label of each point
        and the number of steps taken, or ``None`` (after printing an error
        message) when ``k`` is out of range.
    """
    if k < 1:
        # Without this guard a non-positive k failed later with an obscure
        # IndexError; report it the same way as the other range error.
        print("ERROR: k must be at least 1.")
        return None
    if k > len(_dataList):
        print("ERROR: k > size of data.")
        return None

    labels = initCluster(_dataList, k)
    step = 0
    while True:
        step += 1
        if plot_each_step:
            plotClusters(labels, step)
        updated = iteration(_dataList, labels, k)
        if updated == labels:
            # Converged. The final state was already drawn above when
            # plotting every step, so only draw it here otherwise.
            if not plot_each_step:
                plotClusters(labels, step)
            return labels, step
        labels = updated
def predictCluster(_clusterList, point, prediction=True):
    """Predict the cluster of a new point and plot the result.

    The chosen cluster is the one whose members have the smallest total
    distance to ``point``. Members are read from the module-level
    ``_dataList``.

    Args:
        _clusterList: Cluster label of each point in ``_dataList``.
        point: The new value to classify.
        prediction: Unused; kept for backward compatibility.

    Returns:
        The label of the predicted cluster.

    Raises:
        ValueError: If ``_clusterList`` is empty.
    """
    # NOTE(review): a total (not mean) distance favours clusters with fewer
    # members; nearest-centroid may be what is intended - confirm before
    # changing, since it alters predictions.

    # Iterate over the labels that actually occur, so a gap in the label
    # numbering cannot create an empty cluster whose total of 0 always wins.
    labels = sorted(set(_clusterList))
    totals = []
    for label in labels:
        members = [
            _dataList[i]
            for i, assigned in enumerate(_clusterList)
            if assigned == label
        ]
        totals.append(sum(distList(members, point)))

    predCluster = labels[totals.index(min(totals))]
    plotClusters(_clusterList, step=0, point=point, prediction=predCluster)
    return predCluster
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's "whitegrid" style to the figures drawn by plotClusters.
sns.set_style("whitegrid")
# Marker colors for the clusters; plotClusters looks them up by index.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
def plotClusters(_clusterList, step, point=False, prediction=False):
    """Draw the clusters, their means and, optionally, a predicted point.

    Values are read from the module-level ``_dataList`` and are plotted
    against themselves, i.e. each value ``v`` is drawn at ``(v, v)``.

    Args:
        _clusterList: Cluster label of each point in ``_dataList``.
        step: Step number shown in the title.
        point: Value of a newly classified point, or ``False``/``None`` for
            none.
        prediction: Cluster label predicted for ``point``, or
            ``False``/``None`` for none.
    """
    palette_size = len(colors)
    # Work from the labels that actually occur so a gap in the numbering
    # cannot produce an empty cluster (and a division by zero for its mean).
    labels = sorted(set(_clusterList))
    k = len(labels)

    clusters = []
    for label in labels:
        members = [
            _dataList[j]
            for j, assigned in enumerate(_clusterList)
            if assigned == label
        ]
        clusters.append((label, members))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)

    # Colors wrap around the palette when there are more clusters than colors.
    for label, members in clusters:
        ax1.scatter(members, members, color=colors[label % palette_size])
    # Means are drawn in a second pass so they sit on top of the data points.
    for label, members in clusters:
        mean = sum(members) / len(members)
        ax1.scatter(mean, mean, marker="x",
                    color=colors[label % palette_size], s=100)

    title = 'k:' + str(k) + ' Step:' + str(step)

    # Compare against the "not given" sentinels instead of truthiness:
    # cluster 0 and a point at 0 are valid values that must still be drawn.
    has_point = point is not False and point is not None
    has_prediction = prediction is not False and prediction is not None
    if has_point and has_prediction:
        ax1.scatter(point, point, marker=',',
                    color=colors[prediction % palette_size], s=200)
        title = 'Cluster Prediction: ' + str(prediction)

    ax1.set_title(title, fontdict={'fontsize': 15})
    plt.show()
def plotKs(high_k, plot_each_step=True):
    """Run kMeans on the module-level ``_dataList`` for k = 2 .. high_k - 1.

    Args:
        high_k: Exclusive upper bound for the number of clusters tried.
        plot_each_step: Passed through to ``kMeans``.
    """
    for k in range(2, high_k):
        # The result is not needed here. Not unpacking it also avoids a
        # TypeError when kMeans returns None for a k larger than the data.
        kMeans(_dataList, k, plot_each_step)
# Demo: cluster a small 1-D data set into two groups, plotting every step,
# then classify one new value.
# NOTE: _dataList must stay at module level - updateCenter, predictCluster,
# plotClusters and plotKs read it from module scope.
_dataList = [2, 4, 10, 12, 3, 20, 30, 11, 25]
k = 2
_clusterList, step = kMeans(_dataList, k, plot_each_step=True)
# plotKs(10, plot_each_step=False)
# The returned label is discarded; predictCluster also plots the prediction.
_ = predictCluster(_clusterList, 14.99)
# (End of gist; the GitHub sign-in footer was not part of the script.)