@cthoyt
Created January 12, 2024 12:52
This is a script I was using a while ago with `torch-max-mem`, relevant for https://github.com/mberr/torch-max-mem/issues/14. It causes crashes on my MPS GPU.
import logging

import torch
import torch.mps
from humanize.filesize import naturalsize
from torch_max_mem import maximize_memory_utilization

logging.basicConfig(level=logging.DEBUG)


@maximize_memory_utilization()
def knn(x, y, batch_size: int, k: int = 3):
    """Return the indices of the k nearest neighbors in y for each row of x, batched over x."""
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )


print(f"Currently allocated: {naturalsize(torch.mps.current_allocated_memory())}")
x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")
print(f"Currently allocated after making test data: {naturalsize(torch.mps.current_allocated_memory())}")
z = knn(x, y, batch_size=x.shape[0])
print(z)
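The decorator's job is to start at the requested `batch_size` and shrink it whenever the computation runs out of memory, so the call above can begin with the full `x.shape[0]`. A minimal, device-free sketch of that retry idea is below; it is not the library's actual implementation, and `halve_batch_size_on_oom` and `fake_computation` are illustrative names.

```python
# Sketch of the retry-on-OOM pattern behind maximize_memory_utilization:
# call the wrapped function, and halve batch_size on each MemoryError
# until the computation fits (or batch_size drops below 1).
import functools


def halve_batch_size_on_oom(func):
    """Retry ``func``, halving ``batch_size`` on MemoryError, until it succeeds."""

    @functools.wraps(func)
    def wrapper(*args, batch_size: int, **kwargs):
        while batch_size >= 1:
            try:
                # Return the result together with the batch size that worked.
                return func(*args, batch_size=batch_size, **kwargs), batch_size
            except MemoryError:
                batch_size //= 2
        raise MemoryError("computation failed even with batch_size=1")

    return wrapper


@halve_batch_size_on_oom
def fake_computation(n: int, batch_size: int):
    """Pretend that any batch size above 16 exhausts memory."""
    if batch_size > 16:
        raise MemoryError
    return [list(range(start, min(start + batch_size, n))) for start in range(0, n, batch_size)]


result, final_batch_size = fake_computation(100, batch_size=128)
print(final_batch_size)  # → 16
```

Starting from 128, the wrapper retries with 64, 32, and finally 16, the first size that "fits"; the real decorator additionally has to catch device-specific out-of-memory errors (e.g. CUDA/MPS), which is what the linked issue is about.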