Skip to content

Instantly share code, notes, and snippets.

@colesbury
Created September 20, 2019 00:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save colesbury/11b3308539f5162b178cc229a2aac9c5 to your computer and use it in GitHub Desktop.
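Pinned-memory trick from SpeedTorch: expose a pinned CPU tensor as a CUDA tensor so that indexing and the host-to-device copy are fused into a single GPU gather.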
#include <torch/extension.h>
using namespace torch;
// Expose a pinned (page-locked) CPU tensor as a CUDA tensor without copying.
//
// The returned tensor aliases x's host memory: same data pointer, sizes,
// strides and dtype, but labelled with the current CUDA device. CUDA ops on
// it therefore read host memory directly.
// NOTE(review): this relies on the pinned host pointer being usable as a
// device pointer (unified virtual addressing). Confirm this for the target
// platform.
//
// Throws (via TORCH_CHECK) if x is not a CPU tensor, or if it is non-empty
// and not pinned.
torch::Tensor pinned_as_cuda(Tensor x) {
  TORCH_CHECK(x.device().is_cpu(),
              "pinned_as_cuda: expected a CPU tensor, got ", x.device());
  // Empty tensors may have no backing allocation, so skip the pinned check
  // for them.
  TORCH_CHECK(x.numel() == 0 || x.is_pinned(),
              "pinned_as_cuda: tensor must be in pinned memory "
              "(e.g. created with pin_memory=True)");
  return torch::from_blob(
      x.data_ptr(),
      x.sizes(),
      x.strides(),
      // Capturing x by value keeps a reference to its storage. The host
      // buffer therefore stays alive as long as the returned tensor does.
      [x](void*) { (void)x; },
      // Keep x's dtype and layout, and only switch the device. Passing
      // bare kCUDA would default the dtype to float and reinterpret
      // non-float data.
      x.options().device(torch::kCUDA));
}
// Python bindings for the extension. TORCH_EXTENSION_NAME is supplied by the
// build (here torch.utils.cpp_extension.load with name="pin_cpp"). The only
// exported symbol is pinned_as_cuda(x).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("pinned_as_cuda", &pinned_as_cuda, "Pinned memory as CUDA");
}
import torch
import time
from torch.utils.cpp_extension import load
# JIT-build and import the C++ extension defined in pin.cpp.
pin_cpp = load(name="pin_cpp", sources=["pin.cpp"])
# Host-side table allocated in page-locked (pinned) memory.
data = torch.randn((1000000, 128), pin_memory=True)
# CUDA-labelled alias of the same host memory (no copy is made).
data_fake_gpu = pin_cpp.pinned_as_cuda(data)
# Random row indices to gather; this tensor lives on the CPU.
idx = torch.randint(low=0, high=data.size(0), size=(131072,))
def benchmark(fn, N=100):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    CUDA work is asynchronous, so the device is synchronized before the clock
    starts and again after the last call. Without this, the measurement would
    miss queued GPU work. When CUDA is unavailable the synchronization is
    skipped, so CPU-only callables can still be timed.

    Args:
        fn: Zero-argument callable to time; its return value is discarded.
        N: Number of timed calls; must be positive.

    Returns:
        Average milliseconds per call, as a float.

    Raises:
        ValueError: If ``N`` is not positive.
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(N):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return 1000 * elapsed / N
# "normal" way: gather rows on the CPU, then copy the (pageable) result to the
# GPU. It runs once up front, which serves as a warm-up and keeps a reference
# result.
out = data[idx].cuda()
print("normal", benchmark(lambda: data[idx].cuda()), 'ms')  # reported ~19 ms; hardware-dependent
# "fused" indexing + copy: index the CUDA-labelled alias of the pinned host
# table. The gather then presumably runs on the GPU, reading host memory
# directly and producing the result on the device.
out2 = data_fake_gpu[idx]
print("fused", benchmark(lambda: data_fake_gpu[idx]), 'ms')  # reported ~6 ms; hardware-dependent
# Use an explicit check rather than `assert`, which `python -O` strips.
# torch.equal also requires the shapes to match.
if not torch.equal(out, out2):
    raise RuntimeError("fused gather does not match CPU gather + .cuda() copy")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment