Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active April 7, 2021 03:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/d9dcf4d63dd60a0f31a7a8afb78a3eb6 to your computer and use it in GitHub Desktop.
Save xmodar/d9dcf4d63dd60a0f31a7a8afb78a3eb6 to your computer and use it in GitHub Desktop.
import itertools
from typing import Tuple, Optional
from contextlib import contextmanager
import torch
from torch.utils import benchmark
# @torch.jit.script
def nearest_neighbors(
        num_neighbors: int,
        features: torch.Tensor,
        neighbors: Optional[torch.Tensor] = None,
        p_norm: float = 2,
        ordered: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the distances to the k nearest neighbors of features.

    Args:
        num_neighbors: number of neighbors (k) to return per query point
        features: query points whose neighbors are wanted [*, N, F]
        neighbors: candidate points to search (features if None) [*, M, F]
        p_norm: distances are computed with the L_p norm
        ordered: if True, the neighbors of each query are sorted by
            distance in ascending order (closest first); if False, their
            order is whatever `topk` happens to return

    Returns:
        (distances, indices) both of shape [*, N, num_neighbors]; the
        indices refer to positions along the M axis of `neighbors`

    Note:
        When `neighbors` is None the queries are searched against
        `features` itself, so every point is one of its own candidates.
    """
    # The explicit `is None` rebinding (rather than `neighbors or features`)
    # keeps the function compatible with torch.jit.script.
    if neighbors is None:
        neighbors = features
    # features:[*, N, F], neighbors:[*, M, F] -> distances:[*, N, M]
    distances = torch.cdist(features, neighbors, p_norm)
    # largest=False selects the k smallest distances -> [*, N, k]
    return distances.topk(
        num_neighbors, dim=-1, largest=False, sorted=ordered)
@contextmanager
def print_memory(device=None):
    """Print the peak CUDA memory allocated while the `with` body runs.

    The peak-memory counter of `device` is reset on entry, and the value
    reported is the peak reached inside the body minus the reading taken
    right after that reset.

    Args:
        device: CUDA device to inspect; forwarded unchanged to the
            `torch.cuda` memory functions (None selects their default)
    """
    # Setup stays outside the `try`: if it fails (e.g. CUDA is missing) the
    # real error must propagate instead of being masked by the `finally`
    # clause touching a variable that was never assigned.
    torch.cuda.reset_peak_memory_stats(device)
    baseline = torch.cuda.max_memory_allocated(device)
    try:
        yield
    finally:
        # Report even when the body raised, so partial runs are visible.
        memory = torch.cuda.max_memory_allocated(device) - baseline
        print(f'Used {memory / 1024**2:.3f} MB in {device}')
def pow_2s(x):
    """Yield x, then its successive floor-halvings, while they stay positive.

    The values are x, x // 2, x // 4, ...; they are powers of two only
    when x itself is one. Nothing is yielded if x is not positive.
    """
    remaining = x
    while True:
        if not remaining > 0:
            return
        yield remaining
        remaining //= 2
def main():
    """Compare eager and TorchScript `nearest_neighbors` on memory and speed.

    First prints the peak CUDA memory used by one call of each variant,
    then times both variants on every available device for a range of
    thread counts and prints the comparison table. On hosts without CUDA
    the memory report and the CUDA timings are skipped instead of crashing.
    """
    # TODO: use https://pytorch.org/docs/stable/profiler.html
    eager = nearest_neighbors
    script = torch.jit.script(eager)
    methods = [('eager', eager), ('script', script)]

    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if has_cuda else 'cpu')
    x = torch.randn(1024, 256, device=device)
    y = torch.randn(2, 512, 256, device=device)

    # print_memory relies on CUDA allocator statistics, so it only makes
    # sense (and only works) when a CUDA device is present.
    if has_cuda:
        for name, knn in methods:
            with print_memory(device):
                print(name)
                knn(16, x, y)

    bench_devices = ['cpu', 'cuda:0'] if has_cuda else ['cpu']
    max_threads = torch.get_num_threads()
    # Loop variables are named distinctly from `device` / `max_threads`
    # above so the comprehension does not shadow them.
    results = benchmark.Compare([
        benchmark.Timer(
            label='kNN',
            sub_label=str(bench_device),
            description=name,
            stmt='knn(k, x, y)',
            globals={
                'k': 16,
                'x': x.to(bench_device),
                'y': y.to(bench_device),
                'knn': knn,
            },
            num_threads=threads,
        ).blocked_autorange(min_run_time=1)
        for bench_device, (name, knn), threads in itertools.product(
            bench_devices, methods, pow_2s(max_threads))
    ])
    # results.trim_significant_figures()
    results.colorize(rowwise=True)
    results.print()
# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment