Skip to content

Instantly share code, notes, and snippets.

@Semnodime
Forked from iandanforth/kmeansExample.py
Last active May 13, 2024 14:02
Show Gist options
  • Save Semnodime/0f90a0c2c1cba55141cd5e5e630f3b8c to your computer and use it in GitHub Desktop.
Save Semnodime/0f90a0c2c1cba55141cd5e5e630f3b8c to your computer and use it in GitHub Desktop.
A pure python3 compatible implementation of K-Means clustering. Optional cluster visualization using plot.ly.
"""
This is a pure Python3 implementation of the K-means Clustering algorithm.
It is based on a GitHub Gist which uses Python2:
https://gist.github.com/iandanforth/5862470
I have refactored the code to ensure that it obeys the Python Enhancement Proposal (PEP) style rules.
After reading through this code you should understand clearly how K-means works.
This script specifically avoids using numpy or other more obscure libraries.
It is meant to be *clear* not fast.
I have also added integration with the plot.ly plotting library.
So you can see the clusters found by this algorithm. To install plotly run:
```
pip install plotly
```
This script uses an offline plotting mode and will store and open plots locally.
To store and share plots online sign up for a plotly API key at https://plot.ly.
"""
import math
import random
try:
import plotly
from plotly.graph_objs import Scatter, Scatter3d, Layout
except ImportError:
plotly = Scatter = Scatter3d = Layout = None
print('INFO: Plotly is not installed, plots will not be generated.')
def main():
    """Generate random 2-D points, run K-means on them, and plot the result."""
    # Size of the demo dataset.
    num_points = 20
    # Dimensionality of each point.
    # Note: Plotting will only work in two or three dimensions.
    dimensions = 2
    # Bounds for every coordinate value.
    lower = 0
    upper = 200
    # The K in K-means: the number of clusters we assume exist.
    # Must be less than num_points.
    num_clusters = 3
    # Convergence threshold: stop once no centroid moves farther than this.
    cutoff = 0.2
    # Random data to cluster.
    # Note: If you want to use your own data, set points equal to it here.
    points = [make_random_point(dimensions, lower, upper)
              for _ in range(num_points)]
    # Run K-means repeatedly and keep the best clustering found.
    iteration_count = 20
    best_clusters = iterative_kmeans(points, num_clusters, cutoff,
                                     iteration_count)
    # Report which cluster each point ended up in.
    for index, cluster in enumerate(best_clusters):
        for point in cluster.points:
            print(' Cluster: ', index, '\t Point :', point)
    # Visualize via plotly when the data is 2-D or 3-D and plotly imported.
    if dimensions in (2, 3) and plotly:
        print('Plotting points, launching browser ...')
        plot_clusters(best_clusters, dimensions)
#############################################################################
# K-means Methods
def iterative_kmeans(points, num_clusters, cutoff, iteration_count):
    """
    Run kmeans() *iteration_count* times and return the best clustering.

    K-means isn't guaranteed to find the best answer on a single run; it can
    get stuck in a "local minimum."  Restarting it several times increases
    the chance of a good result.  "Best" is the candidate with the lowest
    distortion error as computed by calculate_error().
    """
    print('Running K-means %d times to find best clusters ...' % iteration_count)
    candidate_clusters = []
    errors = []
    for _ in range(iteration_count):
        clusters = kmeans(points, num_clusters, cutoff)
        candidate_clusters.append(clusters)
        errors.append(calculate_error(clusters))
    print('Lowest error found: %.2f (Highest: %.2f)' % (
        min(errors),
        max(errors)
    ))
    # First index of the lowest error, matching list.index() semantics.
    best_index = min(range(len(errors)), key=errors.__getitem__)
    return candidate_clusters[best_index]
def kmeans(points, k, cutoff):
    """
    Cluster *points* into at most *k* clusters (Lloyd's algorithm).

    Alternates between assigning every point to its nearest centroid and
    recomputing centroids, until no centroid moves farther than *cutoff*
    in a single pass.  Clusters that end up empty are dropped, so fewer
    than k clusters may be returned.
    """
    # Seed each cluster with one randomly sampled point.
    clusters = [Cluster([seed]) for seed in random.sample(points, k)]
    loop_counter = 0
    while True:
        loop_counter += 1
        # Assignment step: collect each point under its nearest centroid.
        assignments = [[] for _ in clusters]
        for point in points:
            distances = [get_distance(point, cluster.centroid)
                         for cluster in clusters]
            # index() of the minimum keeps the original first-wins
            # tie-breaking behavior.
            nearest = distances.index(min(distances))
            assignments[nearest].append(point)
        # Update step: recompute each centroid, tracking the largest move.
        biggest_shift = 0.0
        for cluster, members in zip(clusters, assignments):
            shift = cluster.update(members)
            biggest_shift = max(biggest_shift, shift)
        # Drop clusters that received no points this pass.
        clusters = [cluster for cluster in clusters if cluster.points]
        # Converged once every centroid has settled down.
        if biggest_shift < cutoff:
            print('Converged after %s iterations' % loop_counter)
            break
    return clusters
#############################################################################
# Classes
class Point(object):
    """
    A point in n dimensional space
    """
    def __init__(self, coords):
        """
        coords - A list of values, one per dimension
        """
        # The coordinate values, one per dimension.
        self.coords = coords
        # Dimensionality, cached so distance checks don't re-call len().
        self.n = len(coords)

    def __repr__(self):
        # Display a point as its raw coordinate list, e.g. [1.0, 2.0].
        return str(self.coords)
class Cluster(object):
    """
    A group of Point objects together with their current centroid.
    """

    def __init__(self, points):
        """
        points - A non-empty list of Point objects of equal dimensionality.

        Raises Exception if the list is empty or the dimensions disagree.
        """
        if not points:
            raise Exception('ERROR: empty cluster')
        # The member points of this cluster.
        self.points = points
        # Dimensionality, taken from the first point.
        self.n = points[0].n
        # Every point must share that dimensionality.
        if any(point.n != self.n for point in points):
            raise Exception('ERROR: inconsistent dimensions')
        # Initial centroid (usually derived from a single seed point).
        self.centroid = self.calculate_centroid()

    def __repr__(self):
        """Represent the cluster as the list of its member points."""
        return str(self.points)

    def update(self, points):
        """
        Replace the member points, recompute the centroid, and return the
        distance the centroid moved.

        Note: Initially we expect centroids to shift around a lot and then
        gradually settle down.  An empty *points* list returns 0 and leaves
        the centroid untouched; the caller removes such clusters.
        """
        previous_centroid = self.centroid
        self.points = points
        if not self.points:
            # Empty cluster: it will be cleaned up (removed) by the caller.
            return 0
        self.centroid = self.calculate_centroid()
        return get_distance(previous_centroid, self.centroid)

    def calculate_centroid(self):
        """
        Return a new Point at the mean position of the member points.
        """
        count = len(self.points)
        # Transpose the coordinate lists so each row holds one dimension's
        # values (all x's together, all y's together, etc.).
        per_dimension = zip(*[point.coords for point in self.points])
        # math.fsum keeps the floating-point sum accurate per dimension.
        return Point([math.fsum(values) / count for values in per_dimension])

    def get_total_distance(self):
        """
        Return the sum of all squared Euclidean distances between each point
        in the cluster and the cluster's centroid.
        """
        return sum((get_distance(point, self.centroid)
                    for point in self.points), 0.0)
#############################################################################
# Helper Methods
def get_distance(a, b):
    """
    Squared Euclidean distance between two n-dimensional points.
    https://en.wikipedia.org/wiki/Euclidean_distance#n_dimensions

    Note: This can be very slow and does not scale well.
    Raises Exception when the points have different dimensionality.
    """
    if a.n != b.n:
        raise Exception('ERROR: non comparable points')
    # Sum the squared per-dimension differences, left to right.
    return sum((x - y) ** 2 for x, y in zip(a.coords, b.coords))
def make_random_point(n, lower, upper):
    """
    Return a Point with *n* coordinates, each drawn uniformly from the
    interval [lower, upper].
    """
    return Point([random.uniform(lower, upper) for _ in range(n)])
def calculate_error(clusters):
    """
    Return the average squared distance between each point and its cluster
    centroid.

    This is also known as the "distortion cost."
    """
    total_points = sum(len(cluster.points) for cluster in clusters)
    total_distance = sum(cluster.get_total_distance() for cluster in clusters)
    return total_distance / total_points
def plot_clusters(data, dimensions):
    """
    Plot the clusters with plotly's offline mode, which writes a local HTML
    file and should open your default web browser.

    data       - A list of Cluster objects.
    dimensions - 2 or 3; chooses Scatter vs Scatter3d traces.

    Raises ImportError when plotly could not be imported, and Exception for
    any other dimensionality.
    """
    # When the top-of-file import failed, all four names were set to None.
    if plotly is None:
        raise ImportError('Plotly is not installed or could not be imported correctly.')
    if dimensions not in (2, 3):
        raise Exception('Plots are only available for 2 and 3 dimensional data')
    # Marker symbols for 3-D traces (2-D traces use integer symbol codes).
    symbols = [
        'circle',
        'square',
        'diamond',
        'circle-open',
        'square-open',
        'diamond-open',
        'cross',
        'x',
    ]
    symbol_count = len(symbols)
    # Convert data into plotly format: one trace per cluster plus one
    # single-point trace per centroid.
    trace_list = []
    for i, c in enumerate(data):
        # Coordinate lists for the points in this cluster.
        cluster_data = [point.coords for point in c.points]
        trace = {}
        centroid = {}
        if dimensions == 2:
            # Convert our list of x,y's into an x list and a y list.
            trace['x'], trace['y'] = zip(*cluster_data)
            trace['mode'] = 'markers'
            trace['marker'] = {'symbol': i, 'size': 12}
            trace['name'] = 'Cluster ' + str(i)
            trace_list.append(Scatter(**trace))
            # Centroid (A trace of length 1)
            centroid['x'] = [c.centroid.coords[0]]
            centroid['y'] = [c.centroid.coords[1]]
            centroid['mode'] = 'markers'
            centroid['marker'] = {'symbol': i, 'color': 'rgb(200,10,10)'}
            centroid['name'] = 'Centroid ' + str(i)
            trace_list.append(Scatter(**centroid))
        else:
            # BUGFIX: the original checked `i > symbol_count`, which never
            # fired before `symbols[i]` raised IndexError at
            # i == symbol_count.  Warn at the right boundary and wrap
            # around so large cluster counts still plot.
            if i >= symbol_count:
                print('Warning: Not enough marker symbols to go around')
            symbol = symbols[i % symbol_count]
            # Convert our list of x,y,z's into separate lists.
            trace['x'], trace['y'], trace['z'] = zip(*cluster_data)
            trace['mode'] = 'markers'
            trace['marker'] = {'symbol': symbol, 'size': 12}
            trace['name'] = 'Cluster ' + str(i)
            trace_list.append(Scatter3d(**trace))
            # Centroid (A trace of length 1)
            centroid['x'] = [c.centroid.coords[0]]
            centroid['y'] = [c.centroid.coords[1]]
            centroid['z'] = [c.centroid.coords[2]]
            centroid['mode'] = 'markers'
            centroid['marker'] = {'symbol': symbol, 'color': 'rgb(200,10,10)'}
            centroid['name'] = 'Centroid ' + str(i)
            trace_list.append(Scatter3d(**centroid))
    title = 'K-means clustering with %s clusters' % str(len(data))
    plotly.offline.plot(dict(data=trace_list, layout=Layout(title=title)))
# Script entry point: run the clustering demo only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
@trainerpraveen
Copy link

Superb

@hemulin
Copy link

hemulin commented Jul 18, 2018

@Semnodime, you are my hero 🥇

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment