Batch version of K-means
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

NoneType = type(None)


@torch.no_grad()
def batch_kmeans(
    embeddings: Tensor,
    discard_mask: Optional[Tensor] = None,
    num_prototypes: int = 50,
    spherical: bool = False,
    nmb_kmeans_iters: int = 300,
    nredo: Optional[int] = None,
    init: str = "kmeans++",
    verbose: bool = False,
    random_state: int = 0,
    can_be_inplace: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
This functions performs kmeans on a batch of collections of vectors.
This means that for every collection of the batch, the kmeans algorithm will be applied.
Resulting in batch_size * num_prototypes vectors.
Args:
embeddings (Tensor): tokens to cluster, [batch_size, num_tokens, embedding_size].
discard_mask (Tensor, optional): if passed indicates for each collection on which tokens the clustering is performed
(True to keep the token), [batch_size, num_tokens].
num_prototypes (int): number of centroids (it is the same for every collection).
spherical (bool): if True, the centroids will be normalized after each iteration.
nmb_kmeans_iters (int): max number of iterations.
nredo (int): number of times the kmeans algorithm will be applied to select the best centroids.
init (str): initialization method, either "kmeans++" or "random".
p (int): power for the generalized mean pooling.
eps (float): epsilon for the generalized mean pooling.
verbose (bool): if True, a progress bar will be displayed.
random_state (int): random state for the initialization.
can_be_inplace (bool): whether or not the embeddings can overwritten.
//!\\ This is not a mini-batch kmean function
"""
    # perform checks
    assert embeddings.ndim == 3  # [batch, vectors, dimension]
    assert init in ["kmeans++", "random"], f"init must be one of ['kmeans++', 'random']; got {init}"
    if discard_mask is not None:
        assert (discard_mask.sum(1) >= num_prototypes).all(), "please reduce the number of centroids"
    else:
        assert embeddings.size(1) >= num_prototypes, "please reduce the number of centroids"
    # options
    nredo = nredo or 1
    random_state += 1

    if not can_be_inplace:
        embeddings = torch.clone(embeddings)

    if discard_mask is not None:
        embeddings[~discard_mask] = 0  # zero out discarded tokens so that sums ignore them
        X_mean = embeddings.sum(dim=1) / discard_mask.sum(dim=-1, keepdim=True)
        X_mean.unsqueeze_(1)
    else:
        X_mean = embeddings.mean(dim=1, keepdim=True)
    embeddings -= X_mean  # centering allows more accurate distance computation (trick from sklearn)

    device = embeddings.device
    dtype = embeddings.dtype
    batch_size, num_tokens, embedding_size = embeddings.shape

    min_inertia = torch.ones(batch_size, device=device) * torch.inf
    best_centroids = torch.empty(batch_size, num_prototypes, embedding_size, device=device, dtype=dtype)
    for n in range(nredo):
        assignments = torch.empty(batch_size, num_tokens, dtype=torch.long, device=device)
        old_assignments = torch.clone(assignments)
        centroids = torch.empty_like(best_centroids)

        # initialization: pick random tokens as the starting centroids
        for i in range(batch_size):
            torch.manual_seed(random_state * (n * batch_size + i))
            if discard_mask is not None:
                idxs = torch.where(discard_mask[i])[0]
                random_idx = idxs[torch.randperm(len(idxs))[:num_prototypes]]
                centroids[i] = embeddings[i][random_idx]
            else:
                random_idx = torch.randperm(num_tokens)[:num_prototypes]
                centroids[i] = embeddings[i][random_idx]

            if init == "kmeans++":
                # kmeans++-style seeding: resample centroids 1..K-1 with probability
                # proportional to the squared distance to the closest already-chosen centroid
                for j in range(1, num_prototypes):
                    dot_products = embeddings[i] @ centroids[i, :j].permute(1, 0)
                    cosine_similarity, _ = dot_products.max(dim=-1)
                    probabilities = (1 - cosine_similarity) ** 2
                    if discard_mask is None:
                        probabilities /= probabilities.sum()
                    else:
                        probabilities /= probabilities[discard_mask[i]].sum()
                        probabilities[~discard_mask[i]] = 0
                    new_centroid = torch.multinomial(probabilities, 1)
                    centroids[i, j] = embeddings[i][new_centroid]
        for n_iter in tqdm(range(nmb_kmeans_iters + 1), desc=f"Kmeans redo: {n+1}/{nredo}", disable=not verbose):
            # E step: assign each token to its most similar centroid
            dot_products = embeddings @ centroids.permute(0, 2, 1)
            cosine_similarity, assignments = dot_products.max(dim=-1)
            if discard_mask is not None:
                assignments[~discard_mask] = 0  # force all discarded tokens to be assigned to the first centroid

            # stop when the maximum number of iterations is reached or the assignments no longer change
            if (n_iter == nmb_kmeans_iters) or (old_assignments == assignments).all():
                inertia = (2 * (1 - cosine_similarity)).sum(-1)
                inertia_mask = min_inertia > inertia  # keep the centroids of the redo with the lowest inertia
                min_inertia[inertia_mask] = inertia[inertia_mask]
                best_centroids[inertia_mask] = centroids[inertia_mask]
                break

            old_assignments = torch.clone(assignments)
            # M step: recompute each centroid as the mean of its assigned tokens
            emb_means = torch.zeros_like(centroids)
            counts = torch.zeros(batch_size, num_prototypes, device=device, dtype=torch.long)
            emb_means.scatter_add_(1, assignments.unsqueeze(-1).expand_as(embeddings), embeddings)  # the embeddings can be summed directly as the discarded tokens are already set to 0
            counts.scatter_add_(1, assignments, torch.ones_like(assignments, device=device, dtype=torch.long))  # all the discarded tokens are assigned to the first centroid
            if discard_mask is not None:
                counts[:, 0] -= (~discard_mask).sum(dim=1)  # remove the discarded tokens from the count of the first centroid
            centroids = emb_means / counts.unsqueeze(-1)

            # reassign empty clusters to a random token
            force_reassign_cluster = (counts == 0)
            if force_reassign_cluster.any():
                for i in range(batch_size):
                    if force_reassign_cluster[i].any():
                        for j in range(num_prototypes):
                            if force_reassign_cluster[i, j]:
                                if discard_mask is not None:
                                    idxs = torch.where(discard_mask[i])[0]
                                    random_idx = idxs[torch.randperm(len(idxs))[:1]]
                                    centroids[i, j] = embeddings[i][random_idx]
                                else:
                                    random_idx = torch.randperm(num_tokens)[:1]
                                    centroids[i, j] = embeddings[i][random_idx]

            if spherical:  # normalize centroids
                centroids = F.normalize(centroids, dim=-1, p=2)
    if nredo > 1:  # if nredo == 1, the assignments are already computed
        dot_products = embeddings @ best_centroids.permute(0, 2, 1)
        assignments = dot_products.max(dim=-1).indices

    best_centroids += X_mean  # add the mean back
    if spherical:  # normalize the returned centroids
        best_centroids = F.normalize(best_centroids, dim=-1, p=2)

    if discard_mask is not None:
        assignments[~discard_mask] = -1  # assign all discarded tokens to a nonexistent cluster

    return assignments, best_centroids, dot_products
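
A minimal usage sketch, not part of the original gist: it only illustrates how the function is called and what shapes come back. The tensor sizes, the discard mask, and the choice of num_prototypes=16 below are arbitrary assumptions for the example.

if __name__ == "__main__":
    # Hypothetical sizes, chosen only for the example.
    batch_size, num_tokens, embedding_size = 4, 196, 128
    embeddings = F.normalize(torch.randn(batch_size, num_tokens, embedding_size), dim=-1)

    # Keep all tokens except the last 10 of every collection.
    discard_mask = torch.ones(batch_size, num_tokens, dtype=torch.bool)
    discard_mask[:, -10:] = False

    assignments, centroids, dot_products = batch_kmeans(
        embeddings,
        discard_mask=discard_mask,
        num_prototypes=16,
        spherical=True,
        nredo=3,
        verbose=True,
    )
    print(assignments.shape)   # torch.Size([4, 196]); discarded tokens are assigned -1
    print(centroids.shape)     # torch.Size([4, 16, 128])
    print(dot_products.shape)  # torch.Size([4, 196, 16])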