Skip to content

Instantly share code, notes, and snippets.

@pulsejet
Created April 7, 2019 14:14
Show Gist options
  • Save pulsejet/3706d7411682709928a9ec41ec1b3872 to your computer and use it in GitHub Desktop.
Save pulsejet/3706d7411682709928a9ec41ec1b3872 to your computer and use it in GitHub Desktop.
A ludicrous spaghetti implementation of n-dimensional weighted k-means
"""A ludicrous spaghetti implementation of n-dimensional weighted k-means that can do no prediction.

``kmeans`` only partitions the points it is given; there is no model to
apply to new points.

Data layout is dimension-major: ``xs[k][i]`` is coordinate ``k`` of point
``i``, so ``len(xs)`` is the number of dimensions and ``len(xs[0])`` the
number of points. Weights are per point (``ws[i]``), and clusters are
returned as lists of point indices.
"""
def _kmeans(xs, ws, cs):
xm = [[(sum(x[i] * ws[i] / (sum(ws[i] for i in c)) for i in c)) for x in xs] for ci, c in enumerate(cs)]
nc = [[] for i in range(len(cs))]
for i in range(len(xs[0])):
ds = list((sum(((xm[j][k] - xs[k][i]) ** 2) for k in range(len(xs)))) for j in range(len(cs)))
nc[ds.index(min(ds))].append(i)
return cs, nc
def kmeans(m, xs, ws=None):
    """Partition points into ``m`` clusters with weighted k-means.

    Starts from a deterministic round-robin split of the points and repeats
    ``_kmeans`` steps until an assignment reproduces itself.

    Args:
        m: Number of clusters.
        xs: Coordinates in dimension-major layout: ``xs[k][i]`` is
            coordinate ``k`` of point ``i``; every ``xs[k]`` has one entry
            per point.
        ws: Optional per-point weights (``ws[i]`` for point ``i``). ``None``
            or an empty sequence means every point has weight 1.

    Returns:
        A list of ``m`` clusters, each a list of point indices.
    """
    n = len(xs[0])
    # Test with ``len`` rather than truthiness so array-like weights (e.g. a
    # NumPy array, whose ``bool()`` is ambiguous) are accepted.
    if ws is None or len(ws) == 0:
        ws = [1] * n
    previous = [None] * m
    # Deterministic start: deal the points round-robin into the m clusters.
    clusters = [list(range(n)[i::m]) for i in range(m)]
    # Both lists have length m, so this equals the element-wise comparison.
    while clusters != previous:
        previous, clusters = _kmeans(xs, ws, clusters)
    return clusters
if __name__ == '__main__':
    # Ten 2-D sample points, one row per dimension: first-dimension values,
    # then second-dimension values; column i is point i.
    xs_demo = [
        [10, 50, 20, 60, 90, 50, 60, 70, 30, 40],
        [40, 100, 80, 2, 70, 10, 40, 70, 10, 90],
    ]
    # One weight per point, in the same column order as above.
    ws_demo = [0.25, 0.6, 0.75, 0.6, 1.2, 0.5, 0.4, 0.5, 0.35, 0.6]
    clusters = kmeans(3, xs_demo, ws_demo)
    print(clusters)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment