Skip to content

Instantly share code, notes, and snippets.

@fcharras
Last active July 20, 2023 11:20
Show Gist options
  • Save fcharras/82772cf7651e087b3b91b99105a860dd to your computer and use it in GitHub Desktop.
Save fcharras/82772cf7651e087b3b91b99105a860dd to your computer and use it in GitHub Desktop.
import math
import numpy as np
import torch
def kneighbors(
# NB: best performance might depend on the layout for `X` and `centroids`
# TODO: benchmark and warns or error out if the layout is not adapted
query, # (n_queries, n_features)
data, # (n_samples, n_features)
n_neighbors, # int
metric="euclidean", # str
max_compute_buffer_bytes=1073741824, # int (default 1 GiB)
):
n_queries, n_features = query.shape
n_samples = data.shape[0]
compute_dtype = query[0,0].cpu().numpy().dtype.type
compute_dtype_itemsize = np.dtype(compute_dtype).itemsize
# The computation will be batched and the size of each batch is set so that the
# size of the buffer of pairwise distances computed for this batch do not exceed
# `maximum_comnpute_buffer_size`
(
batch_size,
n_batches,
n_full_batches,
last_batch_size
) = _get_batch_properties(
expected_bytes_per_sample = n_samples * compute_dtype_itemsize,
max_compute_buffer_bytes = max_compute_buffer_bytes,
dataset_n_samples = n_samples
)
if batch_size < 1:
raise RuntimeError("Buffer size is too small")
result = torch.empty(n_queries, n_neighbors, dtype=query.dtype, device=query.device)
idx = torch.empty(n_queries, n_neighbors, dtype=torch.int64, device=query.device)
batch_start_idx = batch_end_idx = 0
# TODO: investigate if it's possible to fuse pairwise distance computation and topk
# search. (seems there's no profitable way to do it on gpu)
for batch_idx in range(n_batches):
if batch_idx == n_full_batches:
batch_end_idx += last_batch_size
else:
batch_end_idx += batch_size
batch_slice = slice(batch_start_idx, batch_end_idx)
pairwise_distance = torch.cdist(query[batch_slice], data)
# ???: should we pass `sorted=False` ?
torch.topk(
pairwise_distance,
n_neighbors,
largest=False,
sorted=True,
out=(
result[batch_slice], idx[batch_slice]
)
)
del pairwise_distance
batch_start_idx += batch_size
# HACK: force synchronization to avoid memory overflow similar to
# torch.cuda.synchronize(X.device) but with device interoperability for a
# negligible cost.
result[-1, -1].item()
return result, idx
def _get_batch_properties(
expected_bytes_per_sample,
max_compute_buffer_bytes,
dataset_n_samples
):
batch_size = (
max_compute_buffer_bytes /
expected_bytes_per_sample
)
if batch_size < 1:
raise RuntimeError("Buffer size is too small")
batch_size = min(math.floor(batch_size), dataset_n_samples)
n_batches = math.ceil(dataset_n_samples / batch_size)
n_full_batches = n_batches - 1
last_batch_size = ((dataset_n_samples - 1) % batch_size) + 1
return batch_size, n_batches, n_full_batches, last_batch_size
if __name__ == "__main__":
    # Demo: nearest neighbors of random queries among random samples.
    n_samples = 1000000  # common sizes: 10000, 100000, 1000000
    n_features = 100
    n_queries = 1000
    n_neighbors = 100
    # Fall back to the CPU so the demo still runs without CUDA. It is much
    # slower at these sizes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32
    # Same seed value the generator was already using; the previously unused
    # `seed = 123` was dead code.
    seed = 543212345
    rng = torch.Generator(device=device).manual_seed(seed)
    data = torch.rand(
        n_samples, n_features, generator=rng, dtype=dtype, device=device
    )
    query = torch.rand(
        n_queries, n_features, generator=rng, dtype=dtype, device=device
    )
    kneighbors(
        query,
        data,
        n_neighbors,
        metric="euclidean",
        max_compute_buffer_bytes=1073741824,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment