@fcharras
Last active June 20, 2023 13:55
k-means Lloyd implementation with PyTorch (not fused)
import torch
import math
import numpy as np
def kmeans_single(
    # NB: best performance might depend on the memory layout of `X` and `centroids`
    # TODO: benchmark, and warn or error out if the layout is not suitable
X, # (n_samples, n_features)
sample_weight, # (n_samples,)
centroids, # (n_clusters, n_features)
    # NB: the centroids data is overridden during the compute
max_iter, # int
tol, # float
verbose, # bool
max_compute_buffer_bytes=1073741824, # int (default 1 GiB)
):
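    """Run Lloyd's k-means iterations with (non-fused) PyTorch operations.

    The assignment and update steps are batched so that temporary buffers stay
    under `max_compute_buffer_bytes`. The data backing `centroids` is overridden
    during the compute.

    Returns a tuple `(assignments_idx, inertia, centroids, n_iteration)` where
    `assignments_idx` is the (n_samples,) tensor of nearest-centroid indices,
    `inertia` is the weighted sum of squared distances to the nearest centroids,
    `centroids` is the (n_clusters, n_features) tensor of final centroids, and
    `n_iteration` is the number of iterations that were run.
    """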
n_samples, n_features = X.shape
n_clusters = centroids.shape[0]
compute_dtype = X[-1, -1].cpu().numpy().dtype.type
compute_dtype_itemsize = np.dtype(compute_dtype).itemsize
    # The computation of the nearest centroids is batched, and the size of each
    # batch is set so that the buffer of pairwise distances computed for this
    # batch does not exceed `max_compute_buffer_bytes`.
(
assignment_batch_size,
assignment_n_batches,
assignment_n_full_batches,
assignment_last_batch_size
) = _get_batch_properties(
expected_bytes_per_sample = n_clusters * compute_dtype_itemsize,
max_compute_buffer_bytes = max_compute_buffer_bytes,
dataset_n_samples = n_samples
)
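    # For instance, with the defaults used in the `__main__` block below
    # (n_samples=500000, n_clusters=127, float32, 1 GiB budget): each sample needs
    # 127 * 4 = 508 bytes of pairwise distances, so batch_size = floor(2**30 / 508)
    # is about 2.1e6, which is then capped at n_samples, i.e. the assignment step
    # runs as a single batch.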
# Batching the update of the centroids is also necessary to support non-uniform
# sample weights.
(
update_batch_size,
update_n_batches,
update_n_full_batches,
update_last_batch_size
) = _get_batch_properties(
expected_bytes_per_sample = n_features * compute_dtype_itemsize,
max_compute_buffer_bytes = max_compute_buffer_bytes,
dataset_n_samples = n_samples
)
    # Pre-allocate buffers that will be reused across iterations (rather than
    # re-allocated).
new_centroids = torch.zeros_like(centroids) # TODO: test memory layouts ?
weight_in_clusters = torch.zeros(n_clusters, dtype=X.dtype, device=X.device)
new_weight_in_clusters = torch.zeros_like(weight_in_clusters)
    # The buffers that store the centroid assignment of each sample are
    # over-allocated with `n_clusters` extra values ranging from 0 to
    # `n_clusters - 1`, which are used later on to detect empty clusters with
    # `torch.unique`.
assignments_idx_extended = torch.empty(
(n_samples + n_clusters, 1), dtype=torch.int64, device=X.device
)
assignments_idx_extended[n_samples:] = torch.arange(
n_clusters, dtype=assignments_idx_extended.dtype, device=X.device
).unsqueeze(1)
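    # e.g. with n_clusters = 3 and assignments [0, 0, 2], the extended buffer holds
    # [0, 0, 2, 0, 1, 2]; `torch.unique(..., return_counts=True)` then reports a
    # count of 1 for cluster 1, which flags it as empty (only the sentinel value).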
assignments_idx = assignments_idx_extended[:n_samples]
new_assignments_idx_extended = torch.empty_like(assignments_idx_extended)
new_assignments_idx_extended[n_samples:] = assignments_idx_extended[n_samples:]
new_assignments_idx = new_assignments_idx_extended[:n_samples]
dist_to_nearest_centroid = torch.empty(
(n_samples, 1), dtype=X.dtype, device=X.device
)
dist_to_nearest_centroid_sqz = dist_to_nearest_centroid.squeeze(1)
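    # (n_samples,) view that shares storage with `dist_to_nearest_centroid`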
n_iteration = 0
strict_convergence = False
centroid_shifts_sum = torch.inf
while (n_iteration < max_iter) and (
centroid_shifts_sum > tol
):
        # NB: the current implementation of `_min_over_pairwise_distance` is
        # underwhelming because, for each batch, it materializes the pairwise
        # distance matrix in memory before searching for the closest centroid. The
        # IO from writing to and reading from global memory becomes the bottleneck.
        # It could be about 3 times faster (or more?) if the pairwise distance and
        # the min lookup were fused together so that global memory is not used
        # anymore. That would require a custom low-level implementation (e.g. using
        # triton directly); `torch.compile` doesn't seem to support automatically
        # fusing `torch.cdist` and `torch.min`.
_min_over_pairwise_distance(
X,
centroids,
assignment_n_batches,
assignment_n_full_batches,
assignment_batch_size,
assignment_last_batch_size,
# OUT
dist_to_nearest_centroid,
new_assignments_idx,
)
# ???: should we pass `sorted=False` ?
unique_clusters, counts = torch.unique(
new_assignments_idx_extended, return_counts=True
)
empty_clusters_list = unique_clusters[counts == 1]
new_centroids[:] = 0
new_weight_in_clusters[:] = 0
# relocate empty clusters if such clusters are detected
if (n_empty_clusters := len(empty_clusters_list)) > 0:
print("relocation event")
# ???: should we pass `sorted=False` ?
samples_far_from_center = torch.topk(
dist_to_nearest_centroid_sqz, n_empty_clusters
).indices
            # NB: do not pre-write `X[samples_far_from_center]` into
            # `new_centroids` here: the re-assignment below is enough, because the
            # `scatter_add_`-based update then places each relocated centroid
            # exactly at its single assigned sample, whereas pre-setting the rows
            # would count those samples twice.
new_assignments_idx[
samples_far_from_center
] = empty_clusters_list.unsqueeze(1)
dist_to_nearest_centroid[samples_far_from_center] = 0
if verbose:
inertia = (
sample_weight
* dist_to_nearest_centroid_sqz
* dist_to_nearest_centroid_sqz
).sum().item()
print(f"Iteration {n_iteration}, inertia {inertia:5.3e}")
# update centers
        # NB: (same comment as for `_min_over_pairwise_distance`)
        # Multiplying by the weights and then using `scatter_add_` could be fused
        # together, again for a 2x-3x speedup.
batch_start_idx = batch_end_idx = 0
for batch_idx in range(update_n_batches):
if batch_idx == update_n_full_batches:
batch_end_idx += update_last_batch_size
else:
batch_end_idx += update_batch_size
batch_slice = slice(batch_start_idx, batch_end_idx)
X_weighted = X[batch_slice] * sample_weight[batch_slice].unsqueeze(1)
new_centroids.scatter_add_(
dim=0,
# NB: expand does not allocate memory, it's like a "repeated view"
index=new_assignments_idx[batch_slice].expand(-1, n_features),
src=X_weighted
)
del X_weighted
# HACK: force synchronization to avoid memory overflow
# Similar to torch.cuda.synchronize(X.device) but with device
# interoperability for a negligible cost.
new_centroids[-1, -1].item()
batch_start_idx += update_batch_size
new_weight_in_clusters.scatter_add_(
dim=0, index=new_assignments_idx.squeeze(), src=sample_weight
)
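        # Each new centroid is the weighted mean of its assigned samples: the
        # accumulated sum of weighted samples divided by the accumulated weight.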
new_centroids /= new_weight_in_clusters.unsqueeze(1)
centroids, new_centroids = new_centroids, centroids
assignments_idx, new_assignments_idx = new_assignments_idx, assignments_idx
assignments_idx_extended, new_assignments_idx_extended = (
new_assignments_idx_extended, assignments_idx_extended
)
n_iteration += 1
if (n_iteration > 1) and (
strict_convergence := bool(
(assignments_idx == new_assignments_idx).all())
):
break
new_centroids -= centroids
new_centroids *= new_centroids
centroid_shifts_sum = new_centroids.sum().item()
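        # `centroid_shifts_sum` is the sum over all clusters of the squared
        # displacement of each centroid during this iteration; the `while`
        # condition compares it against `tol`.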
if verbose:
converged_at = n_iteration - 1
# NB: possible if tol = 0
if strict_convergence or (centroid_shifts_sum == 0):
print(f"Converged at iteration {converged_at}: strict convergence.")
elif centroid_shifts_sum <= tol:
print(
f"Converged at iteration {converged_at}: center shift "
f"{centroid_shifts_sum} within tolerance {tol}."
)
    # TODO: skip this recomputation when `strict_convergence` is True
_min_over_pairwise_distance(
X,
centroids,
assignment_n_batches,
assignment_n_full_batches,
assignment_batch_size,
assignment_last_batch_size,
# OUT
dist_to_nearest_centroid,
assignments_idx,
)
inertia = (
sample_weight
* dist_to_nearest_centroid_sqz
* dist_to_nearest_centroid_sqz
).sum().item()
return assignments_idx.squeeze(), inertia, centroids, n_iteration
def _get_batch_properties(
expected_bytes_per_sample,
max_compute_buffer_bytes,
dataset_n_samples
):
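    """Compute the batch size and batch counts such that the per-batch compute
    buffer stays under `max_compute_buffer_bytes`.

    For instance, for `dataset_n_samples=10` and a budget that yields
    `batch_size=4`, this returns `(4, 3, 2, 2)`: 3 batches in total, 2 full
    batches, and a last batch of 2 samples, so that
    `n_full_batches * batch_size + last_batch_size == dataset_n_samples`.
    """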
batch_size = (
max_compute_buffer_bytes /
expected_bytes_per_sample
)
if batch_size < 1:
raise RuntimeError("Buffer size is too small")
batch_size = min(math.floor(batch_size), dataset_n_samples)
n_batches = math.ceil(dataset_n_samples / batch_size)
n_full_batches = n_batches - 1
last_batch_size = ((dataset_n_samples - 1) % batch_size) + 1
return batch_size, n_batches, n_full_batches, last_batch_size
def _min_over_pairwise_distance(
X, # IN (n_samples, n_features)
    centroids, # IN (n_clusters, n_features)
n_batches, # PARAM int
n_full_batches, # PARAM int
batch_size, # PARAM int
last_batch_size, # PARAM int
    dist_to_nearest_centroid, # OUT (n_samples, 1)
    assignments_idx, # OUT (n_samples, 1)
):
"""The result is returned in `dist_to_nearest_centroid` and `assignments_idx`
arrays that are modified inplace"""
# TODO: slice here so that pairwise_distance has a max size of 1GB
batch_start_idx = batch_end_idx = 0
for batch_idx in range(n_batches):
if batch_idx == n_full_batches:
batch_end_idx += last_batch_size
else:
batch_end_idx += batch_size
batch_slice = slice(batch_start_idx, batch_end_idx)
pairwise_distances = torch.cdist(X[batch_slice], centroids)
torch.min(
pairwise_distances,
            dim=1,
            keepdim=True,
out=(
dist_to_nearest_centroid[batch_slice],
assignments_idx[batch_slice]
)
)
del pairwise_distances
# HACK: force synchronization to avoid memory overflow
# Similar to torch.cuda.synchronize(X.device) but with device interoperability
# for a negligible cost.
assignments_idx[-1, -1].item()
batch_start_idx += batch_size
if __name__ == "__main__":
    n_samples = 500000 # common sizes: 50000, 500000, 50000000
n_features = 14
n_clusters = 127
max_iter = 20
tol = 0
verbose = True
device = "cuda"
dtype = torch.float32
    seed = 543212345
    rng = torch.Generator(device=device).manual_seed(seed)
X = torch.rand(
n_samples, n_features, generator=rng, dtype=dtype, device=device
)
centroids = torch.rand(
n_clusters, n_features, generator=rng, dtype=dtype, device=device
)
sample_weight = torch.rand(n_samples, generator=rng, dtype=dtype, device=device)
    assignments, inertia, centroids, n_iteration = kmeans_single(
X,
sample_weight,
centroids,
max_iter,
tol,
verbose,
max_compute_buffer_bytes=1073741824,
)
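    # Minimal inspection of the returned values: `assignments` holds the index of
    # the nearest centroid for each sample, `inertia` is the weighted sum of
    # squared distances to the nearest centroids.
    print(f"Final inertia: {inertia:5.3e} after {n_iteration} iteration(s)")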