Skip to content

Instantly share code, notes, and snippets.

@nekketsuuu
Last active June 19, 2018 04:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nekketsuuu/4db3f47ac6841b139c3faf3aa9b3da49 to your computer and use it in GitHub Desktop.
Save nekketsuuu/4db3f47ac6841b139c3faf3aa9b3da49 to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
import numpy as np
def update1(data, labels):
    """Return new labels from one batch k-means step.

    All centroids are computed from ``labels`` first. Then every point's
    distance to every centroid is calculated, and only afterwards is each
    point assigned to its nearest centroid (Euclidean distance).

    Parameters
    ----------
    data : array_like of shape (n_data, dim)
        Points to cluster.
    labels : array_like of shape (n_data,) of non-negative ints
        Current cluster index of each point.

    Returns
    -------
    numpy.ndarray of shape (n_data,)
        Index of the nearest non-empty cluster for each point. Ties go to
        the lowest label. A label in ``0..max(labels)`` with no members is
        never chosen.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if data.shape[0] == 0:
        return np.zeros(0, dtype=np.intp)
    sizes = np.bincount(labels)
    n_label = sizes.shape[0]
    nonempty = sizes > 0
    centroids = np.zeros((n_label, data.shape[1]))
    for label in np.flatnonzero(nonempty):
        centroids[label] = np.sum(data[labels == label], axis=0) / sizes[label]
    # (n_data, 1, dim) - (1, n_label, dim) broadcasts over every point/centroid pair.
    dist = np.linalg.norm(
        data[:, np.newaxis, :] - centroids[np.newaxis, :, :], axis=2
    )
    # An empty cluster has no centroid. Its mean would be 0/0 = NaN, and
    # np.argmin returns the first NaN, so it would capture every point.
    # Masking it with +inf means it is never selected.
    dist[:, ~nonempty] = np.inf
    return np.argmin(dist, axis=1)
def update2(data, old_labels):
    """Return new labels from one sequential k-means pass.

    Points are visited in order. For each point, the centroid of every
    non-empty cluster is recomputed from the labels as they stand at that
    moment, so earlier reassignments in the same pass are visible. The
    point is then moved to its nearest centroid (Euclidean distance) as
    soon as its distances are calculated.

    Parameters
    ----------
    data : array_like of shape (n_data, dim)
        Points to cluster.
    old_labels : array_like of shape (n_data,) of non-negative ints
        Current cluster index of each point. Not modified.

    Returns
    -------
    numpy.ndarray of shape (n_data,)
        New labels (a copy). Ties go to the lowest label. Empty clusters
        are never chosen.
    """
    data = np.asarray(data)
    labels = np.array(old_labels, copy=True)
    if labels.size == 0:
        return labels
    n_label = np.max(labels) + 1
    for i, point in enumerate(data):
        # Track the best candidate in locals. Writing it into ``labels``
        # before the scan finishes would pull point i out of its current
        # cluster and skew the centroids computed later in this scan.
        best_label = labels[i]
        min_dist = np.inf
        for label in range(n_label):
            members = labels == label
            n_elem = members.sum()
            if n_elem == 0:
                # Empty cluster: its centroid is undefined (0/0), so skip it.
                continue
            centroid = np.sum(data[members], axis=0) / n_elem
            tmp_dist = np.linalg.norm(point - centroid)
            if tmp_dist < min_dist:
                min_dist = tmp_dist
                best_label = label
        labels[i] = best_label
    return labels
if __name__ == '__main__':
    # Toy data set: five 2-D points with an initial two-cluster labelling.
    data = np.array([
        [2.0, 3.0],
        [0.0, 0.0],
        [0.0, 2.0],
        [3.0, 2.0],
        [3.0, 0.0],
    ])
    labels = np.array([1, 0, 0, 0, 0])
    # Update rule to demo: update1 (batch) or update2 (sequential).
    update = update1
    max_iter = 10
    print(0, labels)
    for n in range(max_iter):
        new_labels = update(data, labels)
        print(n+1, new_labels)
        if np.array_equal(labels, new_labels):
            # n is the number of update steps that actually changed labels.
            print("INFO: converge in {} times".format(n))
            break
        labels = new_labels
    else:
        # Reached only when the loop never hit ``break``.
        print("WARNING: did not converge in {} times".format(max_iter))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment