Pure k-means in PyTorch
""" By Matthew Baas (rf5.github.io) """ | |
import torch | |
import torch.nn.functional as F | |

def kmeans_pp_init(X, k, dist_func, tol=1e-9):
    """
    `X` is (d, N), `k` is int;
    uses k-means++ init from https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
    """
    means = torch.empty(X.shape[0], k, dtype=X.dtype, device=X.device)
    # pick the first mean uniformly at random from the samples
    means[:, 0] = X[:, torch.randint(0, X.shape[1], (1, ))[0]]
    for i in range(1, k):
        # distance from each sample to its nearest existing mean, shape (N,)
        D = dist_func(X, means[:, :i]).min(dim=-1).values
        D = torch.clamp(D, min=tol)
        # naive way of doing this:
        # probs = D / D.sum(dim=0)
        # smarter way of doing this to prevent numerical errors: sample in log space.
        # Categorical normalizes its logits, so index n is drawn with probability D[n] / D.sum().
        logp = D.log() - D.sum(dim=0).log()
        pmf = torch.distributions.Categorical(logits=logp, validate_args=True)
        ind = pmf.sample()
        means[:, i] = X[:, ind]
    return means
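
# Worked sketch of the sampling step above (values assumed purely for illustration):
# with D = torch.tensor([1., 3.]), logp = D.log() - D.sum().log() = log([0.25, 0.75]),
# so Categorical(logits=logp) draws index 1 with probability 0.75, i.e. exactly D / D.sum().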

def euclid_dist(X, means):
    """ `X` is (d, N), `means` is (d, K), returns dist matrix of shape (N, K) """
    # broadcast (d, N, 1) against (d, 1, K), then sum squared differences over the feature dim
    dist = ((X[..., None] - means[:, None])**2).sum(dim=0)
    return dist

def cosine_dist(X, means):
    """ `X` is (d, N), `means` is (d, K), returns dist matrix of shape (N, K) """
    dist = 1 - F.cosine_similarity(X[..., None], means[:, None], dim=0)
    return dist
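
# Illustrative shape check (toy sizes assumed, not part of the original gist):
# both distance functions map (d, N) samples and (d, K) means to an (N, K)
# distance matrix via broadcasting, e.g.
#   euclid_dist(torch.rand(4, 10), torch.rand(4, 3)).shape  # torch.Size([10, 3])
#   cosine_dist(torch.rand(4, 10), torch.rand(4, 3)).shape  # torch.Size([10, 3])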

def k_means(X: torch.Tensor, k: int, tol=1e-9, times=50, dist='euclid', init='kmeanspp', verbose=True):
    """
    k-means for `X` (d, N) and `k` classes, where d is the vector dimension and N is the number of vectors.
    Tries to fit a k-means model `times` times, returning the results of the best run.
    Uses `dist` (either 'euclid' or 'cosine') as the distance function
    and `init` (either 'kmeanspp' or 'random') for cluster initialization.
    Returns (means, cluster assignments, best loss).
    """
    dist_func = euclid_dist if dist == 'euclid' else cosine_dist
    best_loss = torch.tensor(float('inf'), dtype=torch.float, device=X.device)
    best_means = None
    best_t_jn = None
    for t in range(times):
        if init == 'kmeanspp': means = kmeans_pp_init(X, k, dist_func)
        else: means = X[:, torch.randperm(X.shape[-1], device=X.device)[:k]] # (d, k)
        # previous means, initialized so the convergence check passes on the first iteration
        old_means = torch.full_like(means, float('inf'))
        while ((means - old_means)**2).sum() > tol:
            old_means = means
            # E step: assign each sample to its nearest mean
            dists = dist_func(X, means) # (N, k)
            assigned_classes = dists.argmin(dim=-1)
            del dists
            # one-hot assignment matrix t_jn of shape (N, k)
            t_jn = torch.zeros((X.shape[-1], k), device=X.device)
            t_jn[torch.arange(t_jn.shape[0], device=X.device), assigned_classes] = 1
            # M step: recompute each mean from its assigned samples. clone() so the
            # convergence check compares against the previous means rather than an
            # alias of the tensor being updated in place.
            means = old_means.clone()
            for i in range(k):
                class_i_samples = X[:, assigned_classes == i]
                # only update the mean if a sample is assigned to this cluster.
                if class_i_samples.shape[-1] > 0: means[:, i] = class_i_samples.mean(dim=-1)
        # class means (d, k); loss is the total distance from samples to their assigned means
        loss = (t_jn[None] * dist_func(X, means)).sum() # (1, N, k) -> scalar
        if loss < best_loss:
            if verbose: print(f"Run {t:4d}: found new best loss: {loss:7f}")
            best_loss = loss
            best_means = means
            best_t_jn = t_jn
    cluster_assignments = best_t_jn.argmax(dim=-1)
    return best_means, cluster_assignments, best_loss

if __name__ == '__main__':
    print("Running verification tests")
    N = 1000
    k = 50
    d = 25
    # random data on the GPU; drop the .cuda() call to run on CPU
    X = torch.rand(d, N).cuda()
    c, assignments, loss = k_means(X, k, times=1, dist='euclid', init='kmeanspp')
    print(c.shape, assignments.shape, loss)
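    # A further check (not in the original gist; a sketch exercising the other
    # code paths): cosine-distance k-means with random init on unit-normalized
    # columns, where cosine distance is the natural choice.
    Xn = F.normalize(X, dim=0)
    c2, assignments2, loss2 = k_means(Xn, k, times=2, dist='cosine', init='random', verbose=False)
    print(c2.shape, assignments2.shape, loss2)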