Skip to content

Instantly share code, notes, and snippets.

@tyler-romero
Created January 4, 2025 01:24
An efficient PyTorch-based implementation of K-means++ seeding that supports adding new data points to the pool of not-yet-sampled points on the fly.
import torch
class IncrementalKMeansPP:
    """Incremental K-means++ seeding over a pool of points that can grow at any time.

    Points are selected one at a time. The first is chosen uniformly at random.
    Each later one is chosen with probability proportional to its squared
    Euclidean distance to the nearest already-selected point (K-means++ "D^2"
    weighting).

    New points may be added with ``add_data`` between calls to
    ``generate_ordering``, or while its generator is suspended. Their distances
    to the already-selected centers are computed on insertion.

    Attributes:
        device: Device all tensors are moved to (CUDA when available, else CPU).
        data: Pool of points, expected to be 2-D ``(n_points, n_features)``,
            or ``None`` until data is provided.
        min_squared_distances: 1-D tensor with, per point, the squared distance
            to the nearest selected center. It is ``inf`` while no center exists
            and ``0`` for selected points. It is ``None`` while there is no data.
        selected_indices: Indices into ``data`` in the order they were selected.
    """

    def __init__(self, initial_data: torch.Tensor = None):
        """Create the sampler, optionally seeded with an initial pool of points.

        Args:
            initial_data: Optional initial points (see ``add_data``).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = None
        self.min_squared_distances = None
        self.selected_indices = []
        if initial_data is not None:
            # Go through add_data so the distance table is created together with
            # the data; otherwise it would stay None and sampling would fail.
            self.add_data(initial_data)

    @torch.no_grad()
    def add_data(self, new_data: torch.Tensor):
        """Append points to the pool of candidates.

        Args:
            new_data: Points to add, with the same feature width as the existing
                pool. They are moved to ``self.device``.
        """
        new_data = new_data.to(self.device)
        n_new = new_data.shape[0]
        if self.data is not None and n_new == 0:
            # Nothing to append; leave the existing distance table untouched.
            return

        if self.selected_indices:
            # New points must start with their true distance to every existing
            # center. Otherwise they would only be compared against centers
            # chosen after their insertion.
            centers = self.data[torch.tensor(self.selected_indices, device=self.device)]
            new_min = torch.cdist(new_data, centers, p=2).square().min(dim=1).values
        else:
            new_min = torch.full((n_new,), float("inf"), device=self.device)

        if self.data is None:
            self.data = new_data
            self.min_squared_distances = new_min
        else:
            self.data = torch.cat((self.data, new_data), dim=0)
            self.min_squared_distances = torch.cat(
                (self.min_squared_distances, new_min.to(self.min_squared_distances.dtype)),
                dim=0,
            )

    @torch.no_grad()
    def generate_ordering(self, k: int = -1):
        """Yield point indices in K-means++ priority order, one at a time.

        Selection state lives on the instance. Calling this again (for example
        after ``add_data``) continues the same ordering instead of restarting it.

        Args:
            k: Total number of points that should be selected once the generator
                is exhausted, counting points selected by earlier calls.
                - ``-1`` (the default) orders every point, including points added
                  while the generator is suspended.
                - Values larger than the pool size are capped at the pool size.

        Yields:
            int: Index into ``self.data`` of the next selected point.
        """
        if self.data is None:
            return

        def target() -> int:
            # Re-evaluated every step so points added mid-iteration are honoured.
            n_points = self.data.shape[0]
            return n_points if k == -1 else min(k, n_points)

        if not self.selected_indices and target() > 0:
            first_index = torch.randint(self.data.shape[0], (1,), device=self.device).item()
            self.selected_indices.append(first_index)
            self.min_squared_distances[first_index] = 0
            yield first_index

        while len(self.selected_indices) < target():
            # Fold the most recently selected center into the distance table.
            # This runs at the top of the loop, so if the caller stops right after
            # a yield, the next call still applies that center's distances.
            last_center = self.data[self.selected_indices[-1]].unsqueeze(0)
            new_squared_distances = torch.cdist(self.data, last_center, p=2).square()[:, 0]
            self.min_squared_distances = torch.minimum(
                self.min_squared_distances, new_squared_distances
            )

            next_index = self._sample_next_index()
            self.selected_indices.append(next_index)
            # Zero weight guarantees the point is never drawn again.
            self.min_squared_distances[next_index] = 0
            yield next_index

    def _sample_next_index(self) -> int:
        """Draw one not-yet-selected index with probability proportional to its weight.

        The weights are ``min_squared_distances``. When every weight is zero (all
        remaining points coincide with selected centers), fall back to a uniform
        draw over the not-yet-selected points.
        """
        weights = self.min_squared_distances
        cumulative = torch.cumsum(weights, dim=0)
        total = cumulative[-1]
        if float(total) > 0:
            # Draw from [0, total) with the same dtype as the boundaries.
            # side=right returns the first j with cumulative[j] > r, which always
            # has strictly positive weight (so it is never a selected point).
            r = torch.rand(1, device=self.device, dtype=cumulative.dtype) * total
            index = int(torch.searchsorted(cumulative, r, right=True).item())
            if index >= cumulative.shape[0]:
                # r rounded up to the total. The first index reaching the maximum
                # is the last point that still contributes weight.
                index = int(torch.argmax(cumulative).item())
            return index

        unselected = torch.ones(weights.shape[0], dtype=torch.bool, device=self.device)
        unselected[torch.tensor(self.selected_indices, device=self.device)] = False
        candidates = unselected.nonzero().squeeze(1)
        pick = torch.randint(candidates.shape[0], (1,), device=self.device)
        return int(candidates[pick].item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment