Skip to content

Instantly share code, notes, and snippets.

@hristo-vrigazov
Created July 22, 2016 07:47
Show Gist options
  • Save hristo-vrigazov/747780cc1afa3cd2bf39d4e14b4eaeec to your computer and use it in GitHub Desktop.
Save hristo-vrigazov/747780cc1afa3cd2bf39d4e14b4eaeec to your computer and use it in GitHub Desktop.
Simple k-means implementation to solve Coursera's quiz
import math
import statistics
import copy
# Starting positions of the two cluster centres.
centroids = [(2, 2), (-2, -2)]

# The 2-D observations to be clustered.
data_points = [
    (-1.88, 2.05),
    (-0.71, 0.42),
    (2.41, -0.67),
    (1.85, -3.80),
    (-3.69, -1.33),
]

# belongs_to[i] is the index (into centroids) of the cluster that
# data_points[i] is currently assigned to; every point starts in cluster 0.
belongs_to = [0 for _ in data_points]

# changed[i] tallies how often data_points[i] switched cluster; this is the
# number the Coursera quiz question asks about.
changed = [0 for _ in data_points]
def distance(data_point, centroid):
    """Return the Euclidean distance between two points.

    Args:
        data_point: Sequence of numeric coordinates.
        centroid: Sequence of numeric coordinates with the same length as
            ``data_point``.

    Returns:
        The straight-line distance between the two points as a float.

    Raises:
        ValueError: If the two points have a different number of coordinates.
    """
    # Fail loudly on a dimension mismatch instead of silently ignoring the
    # surplus coordinates of the longer point.
    if len(data_point) != len(centroid):
        raise ValueError("both points must have the same dimension")
    return math.sqrt(sum((p - c) ** 2 for p, c in zip(data_point, centroid)))
def find_closest_centroid_index_to_data_point(i):
    """Return the index of the centroid nearest to ``data_points[i]``.

    Distances are measured with ``distance``. When several centroids are
    equally near, the one with the lowest index wins; 0 is returned when no
    centroid is strictly closer than infinity (including an empty
    ``centroids`` list).
    """
    best_index, best_distance = 0, float("+inf")
    for index, centroid in enumerate(centroids):
        candidate = distance(data_points[i], centroid)
        # Strict comparison keeps the earliest centroid on ties.
        if candidate < best_distance:
            best_index, best_distance = index, candidate
    return best_index
def assign_data_point_to_nearest_centroid(i):
    """Assign ``data_points[i]`` to its nearest centroid.

    If the nearest centroid differs from the current assignment, the change
    counter ``changed[i]`` is incremented and ``belongs_to[i]`` is updated.
    """
    nearest = find_closest_centroid_index_to_data_point(i)
    if belongs_to[i] == nearest:
        # Already in the right cluster: nothing to record.
        return
    changed[i] += 1
    belongs_to[i] = nearest
def assign_data_points_to_nearest_centroid():
    """Run the assignment step: re-assign every data point, in index order."""
    for point_index, _point in enumerate(data_points):
        assign_data_point_to_nearest_centroid(point_index)
# Set to True by adjust_centers() once an update leaves every centroid
# where it was.
converged = False


def adjust_centers():
    """Move each centroid to the mean of the data points assigned to it.

    Reads the module-level ``data_points`` and ``belongs_to`` lists and
    updates ``centroids`` in place. Only the first two coordinates of each
    point are used. A centroid with no assigned points keeps its current
    position, because ``statistics.mean`` raises ``StatisticsError`` on
    empty data.

    Sets the module-level ``converged`` flag to True when no centroid moved
    during this update.
    """
    global converged
    # Entries are replaced, never mutated, so a shallow snapshot is enough.
    previous_centroids = list(centroids)
    for index in range(len(centroids)):
        members = [
            point
            for point, owner in zip(data_points, belongs_to)
            if owner == index
        ]
        if not members:
            # Empty cluster: leave the centroid untouched instead of crashing.
            continue
        centroids[index] = (
            statistics.mean(point[0] for point in members),
            statistics.mean(point[1] for point in members),
        )
    if centroids == previous_centroids:
        converged = True
def k_means():
    """Cluster the module-level data points until the centroids stop moving.

    Alternates the assignment step
    (``assign_data_points_to_nearest_centroid``) and the update step
    (``adjust_centers``) until ``adjust_centers`` sets the ``converged``
    flag.
    """
    global converged
    # Clear a flag left over from an earlier run; otherwise any call after
    # the first would return immediately without clustering.
    converged = False
    while not converged:
        assign_data_points_to_nearest_centroid()
        adjust_centers()
def show_most_changed_data_point():
    """Print which data point switched cluster most often.

    The point is reported by its 1-based position; when several points share
    the highest count, the earliest one is reported.
    """
    winner = 0
    for candidate, count in enumerate(changed):
        # Strictly greater, so the first point with the top count is kept.
        if count > changed[winner]:
            winner = candidate
    message = "Data point {} changed the most; it changed {} times"
    print(message.format(winner + 1, changed[winner]))
def main():
    """Run the clustering, then report the most frequently reassigned point."""
    k_means()
    show_most_changed_data_point()


if __name__ == "__main__":
    main()
@hristo-vrigazov
Copy link
Author

It uses Python 3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment