# Gist pulsejet/3706d7411682709928a9ec41ec1b3872, created April 7, 2019 14:14.
# A ludicrous spaghetti implementation of n-dimensional weighted k-means.
# GitHub notice: this file may contain bidirectional Unicode text that could be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
"""A ludicrous spaghetti implementation of n-dimensional weighted k-means that can do no prediction. This line is long since, after all, it's spaghetti!""" | |
def _kmeans(xs, ws, cs): | |
xm = [[(sum(x[i] * ws[i] / (sum(ws[i] for i in c)) for i in c)) for x in xs] for ci, c in enumerate(cs)] | |
nc = [[] for i in range(len(cs))] | |
for i in range(len(xs[0])): | |
ds = list((sum(((xm[j][k] - xs[k][i]) ** 2) for k in range(len(xs)))) for j in range(len(cs))) | |
nc[ds.index(min(ds))].append(i) | |
return cs, nc | |
def kmeans(m, xs, ws=None, *, max_iter=None):
    """Cluster points into ``m`` groups with weighted k-means.

    Args:
        m: Number of clusters.
        xs: Coordinates stored one dimension per row: ``xs[d][i]`` is the
            ``d``-th coordinate of point ``i``.
        ws: Optional per-point weights.  When omitted or empty, every point
            gets weight 1.
        max_iter: Optional cap on the number of reassignment steps.  ``None``
            (the default) keeps iterating until the clusters stop changing.

    Returns:
        A list of ``m`` clusters (an empty list when ``m`` <= 0), each a
        list of point indices in ascending order.
    """
    n = len(xs[0])
    # Checked explicitly rather than by truthiness so array-like weights
    # (whose truth value is ambiguous) are accepted; ``[]`` still means
    # "use the default", as before.
    if ws is None or len(ws) == 0:
        ws = [1] * n
    # Deterministic start: deal the points out round-robin, so point i
    # begins in cluster i % m.
    nc = [list(range(i, n, m)) for i in range(m)]
    oc = [None] * m  # sentinel that never equals a real cluster
    steps = 0
    while nc != oc:
        # Guard against cycling, which first-index tie-breaking does not
        # strictly rule out.
        if max_iter is not None and steps >= max_iter:
            break
        oc, nc = _kmeans(xs, ws, nc)
        steps += 1
    return nc
if __name__ == '__main__':
    # Demo: ten weighted 2-D points. Row 0 holds every point's first
    # coordinate and row 1 every point's second coordinate.
    first_coords = [10, 50, 20, 60, 90, 50, 60, 70, 30, 40]
    second_coords = [40, 100, 80, 2, 70, 10, 40, 70, 10, 90]
    weights = [0.25, 0.6, 0.75, 0.6, 1.2, 0.5, 0.4, 0.5, 0.35, 0.6]
    clusters = kmeans(3, [first_coords, second_coords], weights)
    print(clusters)
# GitHub: Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.