Created
January 4, 2025 01:24
An efficient PyTorch-based implementation of K-means++ seeding that supports adding new data points to the pool of not-yet-sampled points on the fly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class IncrementalKMeansPP:
    """Incremental K-means++ seeding that accepts new points on the fly.

    Points live in ``self.data``, a 2-D tensor of shape
    ``(num_points, num_features)``. ``generate_ordering`` yields point indices
    in K-means++ order: the first index is drawn uniformly at random and each
    later index is drawn with probability proportional to its squared
    Euclidean distance to the nearest already-selected point. ``add_data`` may
    be called at any time -- including while a generator returned by
    ``generate_ordering`` is suspended -- and the new points become candidates
    for the next selection.

    Attributes:
        device: Device on which all tensors are kept (CUDA when available).
        data: All points added so far, or ``None`` before any data is added.
        min_squared_distances: Per-point squared distance to the nearest
            selected center (``inf`` before any center exists, ``0`` for
            selected points), or ``None`` before any data is added.
        selected_indices: Indices into ``data`` in the order they were selected.
    """

    def __init__(self, initial_data: torch.Tensor = None):
        """
        Incremental K-means++ implementation that supports adding data on the fly.

        Args:
            initial_data: Initial data points (optional), a 2-D tensor of shape
                (num_points, num_features).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = None
        self.min_squared_distances = None
        self.selected_indices = []
        if initial_data is not None:
            # Route through add_data so that min_squared_distances always has
            # exactly one entry per row of data.
            self.add_data(initial_data)

    @torch.no_grad()
    def add_data(self, new_data: torch.Tensor):
        """
        Add new data points to the pool of unsampled points.

        The new points' squared distances to their nearest already-selected
        center are computed immediately, so they can be drawn by the next step
        of any ordering in progress.

        Args:
            new_data: New data points to add, a 2-D tensor of shape
                (num_new_points, num_features) with the same number of features
                as any previously added data.
        """
        new_data = new_data.to(self.device)
        num_new = new_data.shape[0]
        if self.data is None:
            self.data = new_data
            # Store distances in at least default float precision (float32 for
            # half/integer data, float64 for double data) so sums don't overflow.
            distance_dtype = torch.promote_types(new_data.dtype, torch.get_default_dtype())
            self.min_squared_distances = torch.full(
                (num_new,), float("inf"), dtype=distance_dtype, device=self.device
            )
            return
        if num_new == 0:
            # Nothing to append.
            return
        if self.selected_indices:
            centers = self.data[torch.tensor(self.selected_indices, device=self.device)]
            new_squared_distances = torch.cdist(new_data, centers, p=2).square().min(dim=1).values
        else:
            # No center exists yet, so the new points have no finite distance.
            new_squared_distances = torch.full((num_new,), float("inf"), device=self.device)
        self.data = torch.cat((self.data, new_data), dim=0)
        self.min_squared_distances = torch.cat(
            (self.min_squared_distances, new_squared_distances.to(self.min_squared_distances.dtype)),
            dim=0,
        )

    @torch.no_grad()
    def generate_ordering(self, k: int = -1):
        """
        Generate a priority ordering of data points using K-means++ seeding.

        Indices are yielded lazily, one per step, and every selection is
        recorded in ``self.selected_indices``; a later call resumes from the
        existing selection instead of starting over.

        Args:
            k: Total number of points that should be selected once this
                generator is exhausted, counting points selected by earlier
                calls. ``-1`` (default) means every point, including points
                added via ``add_data`` while the generator is suspended.
                Values larger than the number of points are clamped.

        Yields:
            int: The index (into ``self.data``) of the next highest priority point.
        """
        if self.data is None:
            return
        if not self.selected_indices:
            if self._num_to_select(k) == 0:
                return
            # Pick the first center uniformly at random.
            first_index = torch.randint(self.data.shape[0], (1,), device=self.device).item()
            self.selected_indices.append(first_index)
            self.min_squared_distances[first_index] = 0
            yield first_index
        # The target is re-evaluated every step so data added mid-iteration counts.
        while len(self.selected_indices) < self._num_to_select(k):
            # Fold the most recent center into the nearest-center distances.
            last_center = self.data[self.selected_indices[-1]].unsqueeze(0)
            new_squared_distances = torch.cdist(self.data, last_center, p=2).square()[:, 0]
            self.min_squared_distances = torch.minimum(self.min_squared_distances, new_squared_distances)
            next_index = self._sample_next_index()
            self.selected_indices.append(next_index)
            # A zero weight keeps an already-selected point from being drawn again.
            self.min_squared_distances[next_index] = 0
            yield next_index
        # Falling off the end stops the generator; raising StopIteration here
        # would be converted into RuntimeError (PEP 479).

    def _num_to_select(self, k: int) -> int:
        """Target size of ``selected_indices`` for ``k``, clamped to the pool size."""
        num_points = self.data.shape[0]
        return num_points if k == -1 else min(k, num_points)

    def _sample_next_index(self) -> int:
        """Draw an unselected index with probability proportional to its squared distance."""
        weights = self.min_squared_distances
        cumulative = torch.cumsum(weights, dim=0)
        total = cumulative[-1]
        if total > 0:
            # Inverse-CDF sampling on the unnormalised weights. With right=True
            # a zero-weight point (e.g. an already-selected one) occupies an
            # empty interval and cannot be drawn. The min() is a defensive
            # bound in case r rounds up to exactly `total`.
            r = torch.rand(1, dtype=cumulative.dtype, device=self.device) * total
            index = torch.searchsorted(cumulative, r, right=True).item()
            return min(index, weights.shape[0] - 1)
        # All remaining weights are zero: every unselected point coincides with
        # a selected center. Fall back to a uniform draw among unselected points.
        unselected = torch.ones(weights.shape[0], dtype=torch.bool, device=self.device)
        unselected[torch.tensor(self.selected_indices, device=self.device)] = False
        candidates = unselected.nonzero(as_tuple=True)[0]
        return candidates[torch.randint(candidates.shape[0], (1,), device=self.device)].item()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment