Skip to content

Instantly share code, notes, and snippets.

@un1tz3r0
Last active July 27, 2021 04:13
Show Gist options
  • Save un1tz3r0/3c4dae80ed1f9e8b1442b6309d633bd9 to your computer and use it in GitHub Desktop.
Save un1tz3r0/3c4dae80ed1f9e8b1442b6309d633bd9 to your computer and use it in GitHub Desktop.
kmeans in pure python
# -------------------------------------------------------------------------------
# a very basic k-means/elbow clustering
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# "wrote this up one day while sitting on the terlet. i do some of my best
# thinking there" - Author - Victor M. Condino - Wednesday, Feb. 3rd, 2020
#
# this is some code to find the optimal clustering of points in an n-dimensional
# dataset when the number of clusters is not known, using the elbow method.
#
# (I'd like to dedicate this to my mom, Martha, who knows way more about
# statistical analysis than me.)
# -------------------------------------------------------------------------------
def kmeans(pts, k):
    ''' Cluster pts into (at most) k groups using Lloyd's k-means iteration.

    The starting centroids are the first k items of set(pts), so the points must be
    hashable (numbers or tuples). Set iteration order is unspecified. Those seeds are
    then sorted. The loop alternates between two steps until a pass leaves every label
    unchanged:
      - label each point with its nearest centroid;
      - move each centroid to the mean of the points carrying its label.
    There is no iteration cap.

    Returns (labels, centroids):
      - labels[i] is the index into centroids of the cluster pts[i] was assigned to;
      - centroids is the sorted list of cluster means.
    Fewer than k centroids come back if pts has fewer than k distinct points, or if a
    cluster loses all of its points. '''
    # first k seeds in set order, THEN sorted -- not the k smallest distinct points
    seeds = list(set(pts))[:k]
    centroids = sorted(seeds)
    labels = assignlabels(pts, centroids)
    while True:
        centroids = sorted(calccentroids(pts, labels))
        relabelled = assignlabels(pts, centroids)
        if samelabels(relabelled, labels):
            # converged: nothing moved between clusters on this pass
            return relabelled, centroids
        labels = relabelled
def wcss(pts, labels, centroids):
    ''' Within-cluster sum of squares for a labelled clustering.

    For every point, take the squared Euclidean distance to the centroid of the
    cluster it is labelled with, then add these up. labels[i] indexes into centroids
    for pts[i].

    Lower is tighter. Evaluating this for several values of k gives the curve that
    the elbow method inspects. '''
    return sum(sqr(distance(pt, centroids[labels[i]])) for i, pt in enumerate(pts))
def elbow(kwcsspts):
    ''' Find the "elbow" of a curve given as a list of (x, y) pairs.

    Draws the chord from kwcsspts[0] to kwcsspts[-1] and returns the pair lying
    farthest (perpendicular distance) from that chord. Ties go to the earliest pair.
    With x = k and y = WCSS of k-means at k, the returned pair is the elbow-method
    choice of k.

    If the chord has zero length (a single pair, or both endpoints equal), there is
    no line to measure against, so the first pair is returned instead of dividing by
    zero. An empty list raises IndexError. '''
    x1, y1 = kwcsspts[0]
    x2, y2 = kwcsspts[-1]
    dx = x2 - x1
    dy = y2 - y1
    chord = math.sqrt(dx * dx + dy * dy)
    if chord == 0:
        # degenerate chord: every point is "on" it, so the first one wins
        return kwcsspts[0]

    def offset(pt):
        # perpendicular distance from pt to the line through the two endpoints
        x0, y0 = pt
        return abs(dx * (y1 - y0) - (x1 - x0) * dy) / chord

    # max() keeps the first of several equal maxima, like the old minindex(-d) did
    return max(kwcsspts, key=offset)
def optimizek(pts, mink, maxk):
    ''' Pick a cluster count for pts with the elbow method.

    Runs kmeans(pts, k) for every k in range(mink, maxk); maxk itself is NOT tried.
    It records the WCSS of each run and hands the (k, wcss) pairs to elbow().

    Returns (k, wcss_at_k, centroids_at_k, labels_at_k) for the chosen k. The keys
    are the requested k values; kmeans may have produced fewer clusters than
    requested. If the range is empty, elbow() receives an empty list and raises
    IndexError. '''
    errors, centers, assignments = {}, {}, {}
    for k in range(mink, maxk):
        labels, centroids = kmeans(pts, k)
        centers[k] = centroids
        assignments[k] = labels
        errors[k] = wcss(pts, labels, centroids)
    bestk, besterr = elbow(list(errors.items()))
    return bestk, besterr, centers[bestk], assignments[bestk]
# -------------------------------------------
# sequence/array tools
# -------------------------------------------
from collections import abc
from wrapt import decorator
from itertools import permutations
def isstr(val):
    ''' True when val is text or raw bytes (str or bytes, including subclasses). '''
    return isinstance(val, str) or isinstance(val, bytes)
def isseq(val):
    ''' True when val is iterable but not a string/bytes object.

    Strings are iterable, but this module treats them as single values. Note that
    any iterable qualifies (sets, dicts, generators), not only indexable sequences. '''
    if isstr(val):
        return False
    return isinstance(val, abc.Iterable)
def toseq(val):
    ''' Coerce val into something iterable.

    None becomes [], anything isseq() accepts is returned unchanged, and any other
    value (including a str/bytes) is wrapped in a one-element list. '''
    # identity test: `== None` can be overridden or return a non-bool (e.g. arrays)
    if val is None:
        return []
    return val if isseq(val) else [val]
def fromseq(val):
    ''' Unwrap val when it is a trivial sequence.

    - Values isseq() rejects pass through unchanged.
    - An empty sequence becomes None.
    - A one-element sequence is replaced by its element, unwrapped recursively.
    - Longer sequences come back as-is.
    val must support len() and [0] whenever isseq(val) is true. '''
    if not isseq(val):
        return val
    count = len(val)
    if count == 0:
        return None
    if count == 1:
        return fromseq(val[0])
    return val
@decorator
def elementwise(wrapped, instance, args, kwargs):
    ''' Decorate a function so that array arguments are broadcast over.

    Each positional argument that isseq() accepts is expanded into its elements, and
    every other positional argument is treated as a single value. The wrapped
    function is then called once per combination (cartesian product, leftmost
    argument varying slowest). The results are yielded lazily, so the decorated
    function returns a generator. Keyword arguments are passed unchanged to every
    call. `instance` comes from the wrapt decorator protocol and is not used here.

        @elementwise
        def mul(x, y):
            return x * y

        >>> list(mul([1, 2, 3], [0.5, 2.0]))
        [0.5, 2.0, 1.0, 4.0, 1.5, 6.0]

    also works with non-array parameters:

        >>> list(mul(-10, [0, -3, 2.5]))
        [0, 30, -25.0]
    '''
    import itertools  # the module only imports permutations, which is the wrong tool here

    argseqs = [list(arg) if isseq(arg) else [arg] for arg in args]
    # product, not permutations: one call per combination of one element from each arg
    for combo in itertools.product(*argseqs):
        yield wrapped(*combo, **kwargs)
# -------------------------------------------
# math helpers
# -------------------------------------------
from itertools import zip_longest
from math import sqrt
import math
def sqr(x):
    ''' Return x squared, computed as x * x. '''
    squared = x * x
    return squared
def abs(x):
    ''' Absolute value of x.

    This module-level name shadows the builtin, so it simply delegates to it. That
    gives correct results for -0.0 (returns 0.0), complex numbers (magnitude), and
    any type implementing __abs__. '''
    import builtins  # reach the real builtin despite the shadowing name above

    return builtins.abs(x)
def avg(values):
    ''' Average either a flat list of numbers or a list of equal-length rows.

    If the elements of values are not iterable (plain numbers), returns their mean.
    Otherwise returns a list holding the mean of each column, i.e. of the nth element
    of every row. An empty values gives [].

    Raises TypeError if the rows do not all have the same number of columns. '''
    missing = object()  # fill marker for short rows; cannot collide with real data
    try:
        # zip_longest calls iter() on every element immediately, so it raises
        # TypeError here exactly when the elements are scalars rather than rows
        columns = zip_longest(*values, fillvalue=missing)
    except TypeError:
        return sum(values) / len(values)
    means = []
    for col in columns:
        if any(v is missing for v in col):
            raise TypeError("rows must all have the same number of columns")
        means.append(sum(col) / len(col))
    return means
def distance(a, b):
    ''' Euclidean distance between a and b.

    Two scalars give the absolute difference. Two sequences of equal length are
    treated as N-dimensional coordinates and give the straight-line distance
    between them. Any other combination raises TypeError. '''
    a_seq = isseq(a)
    b_seq = isseq(b)
    if not (a_seq or b_seq):
        return (a - b) if a > b else (b - a)
    if not (a_seq and b_seq) or len(a) != len(b):
        raise TypeError("a and b must be sequences of same length")
    deltas = (a[i] - b[i] for i in range(len(a)))
    return math.sqrt(sum(d * d for d in deltas))
def minindex(a):
    ''' Return (index, value) of the smallest element of the sequence a.

    Ties go to the earliest index (only a strictly smaller value replaces the
    current best).

    Raises TypeError if a is not a sequence, and ValueError if it is empty. '''
    if not isseq(a):
        raise TypeError("not a sequence")
    if len(a) < 1:
        # was `LengthError`, which is never defined (NameError at runtime);
        # ValueError matches the builtin min([])
        raise ValueError("cannot find min index of empty sequence")
    mini, minv = 0, a[0]
    for i, value in enumerate(a):
        if value < minv:
            mini, minv = i, value
    return mini, minv
def assignlabels(pts, centroids):
    ''' Label each point with the index of its nearest centroid.

    Uses the Euclidean distance() and minindex(), so ties go to the lowest centroid
    index. Returns a list parallel to pts. '''
    return [minindex([distance(pt, c) for c in centroids])[0] for pt in pts]
def calccentroids(pts, labels):
    ''' Yield the centroid (mean, via avg()) of the points sharing each label.

    labels[i] is the cluster label of pts[i]. One centroid is yielded per distinct
    label, in set(labels) iteration order. Because this is a generator, the length
    check runs on the first next().

    Raises ValueError if pts and labels differ in length. '''
    if len(pts) != len(labels):
        raise ValueError("pts and labels must have the same length")
    # one pass to group points by label, instead of rescanning pts for every label
    members = {}
    for pt, label in zip(pts, labels):
        members.setdefault(label, []).append(pt)
    for label in set(labels):
        yield avg(members[label])
def samelabels(newlabels, oldlabels):
    ''' True when two label arrays agree at every position.

    k-means uses this as its convergence test: no point changed cluster.

    Raises ValueError if the arrays differ in length. '''
    if len(newlabels) != len(oldlabels):
        # was `LengthError`, which is never defined (NameError at runtime)
        raise ValueError("number of elements is inconsistent")
    return not any(new != old for new, old in zip(newlabels, oldlabels))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment