Batch version of K-means
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

NoneType = type(None)
@torch.no_grad()
def batch_kmeans(
    embeddings: Tensor,
    discard_mask: Optional[Tensor] = None,
    num_prototypes: int = 50,
    spherical: bool = False,
    nmb_kmeans_iters: int = 300,
    nredo: Optional[int] = None,
    init: str = "kmeans++",
    verbose: bool = False,
    random_state: int = 0,
    can_be_inplace: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Performs k-means on a batch of collections of vectors: the algorithm is applied
    independently to every collection in the batch, resulting in
    batch_size * num_prototypes centroids.

    Args:
        embeddings (Tensor): tokens to cluster, [batch_size, num_tokens, embedding_size].
        discard_mask (Tensor, optional): if passed, indicates for each collection on which tokens
            the clustering is performed (True to keep the token), [batch_size, num_tokens].
        num_prototypes (int): number of centroids (the same for every collection).
        spherical (bool): if True, the centroids are L2-normalized after each iteration.
        nmb_kmeans_iters (int): maximum number of iterations.
        nredo (int): number of times the k-means algorithm is run to select the best centroids.
        init (str): initialization method, either "kmeans++" or "random".
        verbose (bool): if True, a progress bar is displayed.
        random_state (int): random state for the initialization.
        can_be_inplace (bool): whether the embeddings can be overwritten.

    /!\\ This is not a mini-batch k-means function.
    """
    # perform checks
    assert embeddings.ndim == 3  # [batch, vectors, dimension]
    assert init in ["kmeans++", "random"], f"init must be one of [kmeans++, random]; got {init}"
    if discard_mask is not None:
        assert (discard_mask.sum(1) >= num_prototypes).all(), "please reduce the number of centroids"
    else:
        assert embeddings.size(1) >= num_prototypes, "please reduce the number of centroids"

    # options
    nredo = nredo or 1
    random_state += 1

    if not can_be_inplace:
        embeddings = torch.clone(embeddings)

    # center the embeddings; discarded tokens are zeroed so they do not contribute to the mean
    if discard_mask is not None:
        embeddings[~discard_mask] = 0
        X_mean = embeddings.sum(dim=1) / discard_mask.sum(dim=-1, keepdim=True)
        X_mean.unsqueeze_(1)
    else:
        X_mean = embeddings.mean(dim=1, keepdim=True)
    embeddings -= X_mean  # this trick allows better distance computation (from sklearn)

    device = embeddings.device
    dtype = embeddings.dtype
    batch_size, num_tokens, embedding_size = embeddings.shape

    min_inertia = torch.ones(batch_size, device=device) * torch.inf
    best_centroids = torch.empty(batch_size, num_prototypes, embedding_size, device=device, dtype=dtype)
    for n in range(nredo):
        assignments = torch.empty(batch_size, num_tokens, dtype=torch.long, device=device)
        old_assignments = torch.clone(assignments)
        centroids = torch.empty_like(best_centroids)

        # initialization: random centroids, optionally refined with a kmeans++-style scheme
        for i in range(batch_size):
            torch.manual_seed(random_state * (n * batch_size + i))
            if discard_mask is not None:
                idxs = torch.where(discard_mask[i])[0]
                random_idx = idxs[torch.randperm(len(idxs))[:num_prototypes]]
                centroids[i] = embeddings[i][random_idx]
            else:
                random_idx = torch.randperm(num_tokens)[:num_prototypes]
                centroids[i] = embeddings[i][random_idx]

            if init == "kmeans++":
                # re-pick centroids 1..num_prototypes-1 with probability proportional
                # to the squared cosine distance to the closest already chosen centroid
                for j in range(1, num_prototypes):
                    dot_products = embeddings[i] @ centroids[i, :j].permute(1, 0)
                    cosine_similarity, _ = dot_products.max(dim=-1)
                    probabilities = (1 - cosine_similarity) ** 2
                    if discard_mask is None:
                        probabilities /= probabilities.sum()
                    else:
                        probabilities /= probabilities[discard_mask[i]].sum()
                        probabilities[~discard_mask[i]] = 0
                    new_centroid = torch.multinomial(probabilities, 1)
                    centroids[i, j] = embeddings[i][new_centroid]
        for n_iter in tqdm(range(nmb_kmeans_iters + 1), desc=f"Kmeans redo: {n+1}/{nredo}", disable=not verbose):
            # E step
            dot_products = embeddings @ centroids.permute(0, 2, 1)
            cosine_similarity, assignments = dot_products.max(dim=-1)
            if discard_mask is not None:
                assignments[~discard_mask] = 0  # force all discarded tokens to be assigned to the first centroid

            # finish
            if (n_iter == nmb_kmeans_iters) or (old_assignments == assignments).all():
                inertia = (2 * (1 - cosine_similarity)).sum(-1)
                inertia_mask = min_inertia > inertia  # update the best centroids if the inertia is lower
                min_inertia[inertia_mask] = inertia[inertia_mask]
                best_centroids[inertia_mask] = centroids[inertia_mask]
                break

            old_assignments = torch.clone(assignments)

            # M step
            emb_means = torch.zeros_like(centroids)
            counts = torch.zeros(batch_size, num_prototypes, device=device, dtype=torch.long)
            emb_means.scatter_add_(1, assignments.unsqueeze(-1).expand_as(embeddings), embeddings)  # the embeddings can be summed directly as the discarded tokens are already set to 0
            counts.scatter_add_(1, assignments, torch.ones_like(assignments, device=device, dtype=torch.long))  # all the discarded tokens are assigned to the first centroid
            if discard_mask is not None:
                counts[:, 0] -= (~discard_mask).sum(dim=1)  # remove the discarded tokens from the counts of the first centroid
            centroids = emb_means / counts.unsqueeze(-1)

            # empty clusters are re-seeded with a random token
            force_reassign_cluster = (counts == 0)
            if force_reassign_cluster.any():
                for i in range(batch_size):
                    if force_reassign_cluster[i].any():
                        for j in range(num_prototypes):
                            if force_reassign_cluster[i, j]:
                                if discard_mask is not None:
                                    idxs = torch.where(discard_mask[i])[0]
                                    random_idx = idxs[torch.randperm(len(idxs))[:1]]
                                    centroids[i, j] = embeddings[i][random_idx]
                                else:
                                    random_idx = torch.randperm(num_tokens)[:1]
                                    centroids[i, j] = embeddings[i][random_idx]

            if spherical:  # normalize centroids
                centroids = F.normalize(centroids, dim=-1, p=2)
    if nredo > 1:  # if nredo == 1, the assignments are already computed
        dot_products = embeddings @ best_centroids.permute(0, 2, 1)
        assignments = dot_products.max(dim=-1).indices

    best_centroids += X_mean  # add the mean back
    if spherical:  # re-normalize the returned centroids after adding the mean back
        best_centroids = F.normalize(best_centroids, dim=-1, p=2)
    if discard_mask is not None:
        assignments[~discard_mask] = -1  # force all discarded tokens to be assigned to a nonexistent cluster

    return assignments, best_centroids, dot_products
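A minimal usage sketch (the shapes, mask, and hyper-parameters below are illustrative and not part of the gist; it relies on the imports and the batch_kmeans function defined above):

# cluster 8 collections of 196 tokens of dimension 64 into 50 prototypes each
embeddings = torch.randn(8, 196, 64)

# optional mask: only the first 150 tokens of every collection participate in the clustering
discard_mask = torch.zeros(8, 196, dtype=torch.bool)
discard_mask[:, :150] = True

assignments, centroids, dot_products = batch_kmeans(
    embeddings,
    discard_mask=discard_mask,
    num_prototypes=50,
    spherical=True,
    nredo=3,
    verbose=True,
)

print(assignments.shape)  # [8, 196]; discarded tokens are assigned -1
print(centroids.shape)    # [8, 50, 64]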