Created
May 12, 2019 08:41
-
-
Save tyliec/e053157c3e52cecf42e9ed1c1ecc7dce to your computer and use it in GitHub Desktop.
soft k-means
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
def plot_k_means(x, r, k, centers, colors):
    """Scatter-plot the data colored by soft assignment and mark each center.

    NOTE(review): `r` and `k` are accepted but unused here; they are kept so
    the call sites elsewhere in the file keep working unchanged.
    """
    plt.scatter(x[:, 0], x[:, 1], c=colors)
    for center in centers:
        # 'ro' = red circle marker at the center's (x, y) position.
        plt.plot(center[0], center[1], 'ro')
    plt.show()
def initialize_centers(x, num_k):
    """Pick `num_k` distinct rows of `x` uniformly at random as initial centers.

    Parameters
    ----------
    x : ndarray of shape (N, D) -- the data points.
    num_k : int -- number of clusters; must be <= N.

    Returns
    -------
    ndarray of shape (num_k, D), float dtype, each row a (distinct) data point.
    """
    N = x.shape[0]
    # replace=False guarantees distinct indices in one call, replacing the
    # original rejection-sampling while-loop (which could spin for a long
    # time as num_k approaches N).
    idx = np.random.choice(N, size=num_k, replace=False)
    # astype(float) preserves the original's float64 result (it filled an
    # np.zeros array), even if x happens to be integer-typed.
    return x[idx].astype(float)
def update_centers(x, r, K):
    """Recompute each center as the responsibility-weighted mean of the data.

    center[k] = sum_n r[n,k] * x[n] / sum_n r[n,k]
    """
    _, D = x.shape
    new_centers = np.zeros((K, D))
    for cluster in range(K):
        weights = r[:, cluster]
        new_centers[cluster] = weights.dot(x) / weights.sum()
    return new_centers
def square_dist(a, b):
    """Element-wise squared difference (a - b)^2; works on scalars or arrays."""
    diff = a - b
    return diff * diff
def cost_func(x, r, centers, K):
    """Responsibility-weighted sum of sample-to-center Euclidean distances.

    cost = sum_k sum_n r[n,k] * ||x[n] - centers[k]||

    Bug fix: the original called np.linalg.norm(x - centers[k], 2), which for
    a 2-D argument computes the matrix SPECTRAL norm -- a single scalar for
    the whole batch -- not per-sample distances. We now take the Euclidean
    norm along axis=1 so each sample contributes its own distance, weighted
    by its responsibility (matching cluster_responsibilities, which already
    uses axis=1). Callers only use this value for convergence checks.
    """
    cost = 0.0
    for k in range(K):
        dists = np.linalg.norm(x - centers[k], axis=1)  # shape (N,)
        cost += (r[:, k] * dists).sum()
    return cost
def cluster_responsibilities(centers, x, beta):
    """Soft-assignment matrix R with R[n, k] proportional to exp(-beta * ||x[n] - centers[k]||).

    Parameters
    ----------
    centers : ndarray (K, D) -- current cluster centers.
    x : ndarray (N, D) -- data points.
    beta : float -- inverse temperature; larger beta => harder assignments.

    Returns
    -------
    ndarray (N, K) with each row normalized to sum to 1.
    """
    # Broadcasting computes all N*K sample-to-center distances in one shot,
    # replacing the per-sample Python loop (same values, vectorized).
    dists = np.linalg.norm(x[:, None, :] - centers[None, :, :], axis=2)  # (N, K)
    R = np.exp(-beta * dists)
    R /= R.sum(axis=1, keepdims=True)
    return R
def return_responsibilities(R):
    """Hard labels from soft responsibilities, with clusters 0 and 2 relabeled.

    Each row's columns are read in [2, 1, 0] order before taking argmax, so
    the returned label swaps clusters 0 and 2 relative to the raw argmax
    (cluster 1 is unchanged). Presumably this aligns k-means cluster indices
    with the ground-truth label order -- verify against the dataset.
    """
    labels = []
    for row in R:
        # Fancy indexing reads the columns in swapped order without mutating
        # R (the original swapped entries in place and then swapped back).
        labels.append(np.argmax(row[[2, 1, 0]]))
    return labels
def calculate_partial_accuracy(X, labels, R):
    """Percentage of samples whose (relabeled) hard assignment matches `labels`.

    NOTE(review): `X` is accepted but unused; kept for interface compatibility
    with existing call sites.
    """
    predicted = return_responsibilities(R)
    matches = sum(1 for pred, true in zip(predicted, labels) if pred == true)
    # 100.0 first forces float division under both Python 2 and 3.
    return 100.0 * matches / len(predicted)
def soft_k_means(x, labels, K, max_iters=20, beta=1.):
    """Run soft k-means EM on `x`, tracking partial accuracy against `labels`.

    Parameters
    ----------
    x : ndarray (N, D) -- data to cluster.
    labels : ground-truth cluster labels, used only for the accuracy metric.
    K : int -- number of clusters.
    max_iters : int -- maximum EM iterations.
    beta : float -- inverse temperature; larger => harder assignments.
    """
    np.random.seed(5)  # deterministic colors and initialization
    random_colors = np.random.random((K, 3))
    centers = initialize_centers(x, K)
    r = cluster_responsibilities(centers, x, beta)
    # Mixing the cluster colors by responsibility visualizes soft membership.
    colors = r.dot(random_colors)
    # print() function-call form runs under both Python 2 and Python 3.
    print('Initialize Plot')
    max_i = []
    accuracies = []
    prev_cost = 0
    for i in range(max_iters):
        r = cluster_responsibilities(centers, x, beta)
        colors = r.dot(random_colors)
        centers = update_centers(x, r, K)
        cost = cost_func(x, r, centers, K)
        print('Iteration: ' + str(i))
        print(centers)
        # Bug fix: accuracy was computed against the GLOBAL `X` instead of the
        # `x` parameter; it only worked because the script passes the global in.
        acc = calculate_partial_accuracy(x, labels, r)
        accuracies.append(acc)
        max_i.append(i)
        if acc == 100:
            print('Breaking: Fully Accurate')
            break
        if np.abs(cost - prev_cost) < 1e-5:
            # Bug fix: this branch fires when the cost has CONVERGED (change
            # below 1e-5); the original message said "Cost too high".
            print('Breaking: Cost converged')
            break
        prev_cost = cost
    print('Finish Plot')
    plot_k_means(x, r, K, centers, colors)
    r = cluster_responsibilities(centers, x, beta)
    acc = calculate_partial_accuracy(x, labels, r)  # same global-X fix as above
    print(max_i)
    plt.show()
    print('Accuracy: ' + str(acc) + '%')
# Three Gaussian blobs (std 1.5) as a toy clustering problem; kept at module
# level because soft_k_means historically reads the global X.
X, labels = make_blobs(n_samples=100, centers=3, cluster_std=1.5, random_state=1)

# Guard the run so importing this module does not trigger plotting.
if __name__ == '__main__':
    soft_k_means(X, labels, K=3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment