Skip to content

Instantly share code, notes, and snippets.

@Ediolot
Created January 27, 2023 12:50
Show Gist options
  • Save Ediolot/9784cd43ff90522eb92e0d2625f53f88 to your computer and use it in GitHub Desktop.
Save Ediolot/9784cd43ff90522eb92e0d2625f53f88 to your computer and use it in GitHub Desktop.
Vector addition in CUDA and Torch
#include <torch/extension.h>
#include <cstdint>
/**
 * Element-wise vector addition on the device: c[i] = a[i] + b[i] for i in [0, N).
 *
 * Launch layout: 1-D grid of 1-D blocks (only the x dimension is used).
 * The kernel is a grid-stride loop: each thread starts at its flat global
 * index and advances by the total number of launched threads, so any
 * grid/block size covers all N elements. No shared memory is used.
 *
 * @param N  number of elements to process
 * @param a  first input, readable for at least N floats
 * @param b  second input, readable for at least N floats
 * @param c  output, writable for at least N floats
 */
__global__ void kernel_vector_add(uint32_t N, const float* a, const float* b, float* c) {
    // Do the index math in 64 bits. blockIdx/blockDim/gridDim are 32-bit
    // unsigned, so their product can wrap (a stride that wraps to 0 would
    // never terminate), and a 32-bit `i += stride` can wrap past N.
    const uint64_t stride = static_cast<uint64_t>(gridDim.x) * blockDim.x;
    const uint64_t first = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    for (uint64_t i = first; i < N; i += stride) {
        c[i] = a[i] + b[i];
    }
}
/**
 * Host wrapper that validates the tensors and launches kernel_vector_add,
 * writing a + b element-wise into c.
 *
 * @param blocks   number of thread blocks in the 1-D grid (must be > 0)
 * @param threads  number of threads per block (must be > 0)
 * @param a        first input; contiguous float32 CUDA tensor
 * @param b        second input; same device and element count as a
 * @param c        output; same device and element count as a, overwritten
 *
 * Throws (via TORCH_CHECK) when a precondition is violated or when the
 * CUDA runtime reports a launch/execution error. Blocks the host until the
 * kernel has finished.
 */
void custom_vector_add(
    uint32_t blocks,
    uint32_t threads,
    at::Tensor a,
    at::Tensor b,
    at::Tensor c
) {
    // TORCH_CHECK rather than assert(): assert() is compiled out when NDEBUG
    // is defined, which would let invalid tensors reach the kernel.
    TORCH_CHECK(blocks > 0 && threads > 0,
                "blocks and threads must be positive, got blocks=", blocks,
                ", threads=", threads);
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
                "a, b and c must be CUDA tensors");
    TORCH_CHECK(a.device() == b.device() && a.device() == c.device(),
                "a, b and c must be on the same device");
    TORCH_CHECK(a.scalar_type() == at::kFloat &&
                b.scalar_type() == at::kFloat &&
                c.scalar_type() == at::kFloat,
                "a, b and c must be float32 tensors");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(),
                "a, b and c must be contiguous");
    TORCH_CHECK(a.numel() == b.numel() && a.numel() == c.numel(),
                "a, b and c must have the same number of elements");

    // The kernel takes the element count as uint32_t; reject sizes that would
    // be silently truncated by the conversion.
    const int64_t count = a.numel();
    TORCH_CHECK(count <= static_cast<int64_t>(UINT32_MAX),
                "tensor too large: ", count, " elements exceeds the 32-bit kernel limit");

    if (count == 0) {
        return;  // nothing to add
    }

    // NOTE(review): the launch uses the default stream of the current device;
    // confirm that matches the stream/device the caller expects.
    kernel_vector_add<<<blocks, threads>>>(
        static_cast<uint32_t>(count),
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>()
    );

    // Kernel launches return no status; configuration errors (e.g. too many
    // threads per block) are only visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "kernel_vector_add launch failed: ", cudaGetErrorString(launch_status));

    // Errors raised while the kernel executes surface at the next sync point.
    const cudaError_t sync_status = cudaDeviceSynchronize();
    TORCH_CHECK(sync_status == cudaSuccess,
                "kernel_vector_add execution failed: ", cudaGetErrorString(sync_status));
}
// Python bindings: exposes custom_vector_add as `vector_add` on the extension
// module. Arguments are named so callers may pass them by keyword; purely
// positional calls keep working unchanged.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "vector_add",
        &custom_vector_add,
        "Adds two vectors",
        pybind11::arg("blocks"),
        pybind11::arg("threads"),
        pybind11::arg("a"),
        pybind11::arg("b"),
        pybind11::arg("c"));
}
// main.py — Python driver script (a separate file from cuda/vector_add.cu): builds the extension and benchmarks it
import time
import torch
from torch.utils.cpp_extension import load
def main():
    """Build the CUDA extension and compare it against PyTorch's own addition.

    Times `a + b` on the GPU with PyTorch, the custom `vector_add` kernel, and
    `a + b` on the CPU, printing each duration in milliseconds. Both GPU
    results are verified against a reference sum.
    """
    print('Compiling...')
    sample = load(
        name='sample',
        sources=['cuda/vector_add.cu'])
    print('Done!')

    a = torch.randn(10000, device='cuda')
    b = torch.randn(10000, device='cuda')

    # Reference result. Computing it here also serves as an untimed warm-up of
    # PyTorch's add kernel so the timed run below excludes first-launch cost.
    expected = a + b
    torch.cuda.synchronize()

    # PyTorch Kernel GPU
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    t0.record()
    c = a + b
    t1.record()
    torch.cuda.synchronize()
    print(f'Torch GPU: ({t0.elapsed_time(t1):.2f}ms)')
    assert torch.allclose(c, expected)

    # Custom Kernel
    # Untimed warm-up into a scratch buffer, kept separate from the verified
    # output so the check below only sees what the timed launch wrote.
    scratch = torch.empty_like(a)
    sample.vector_add(4096, 256, a, b, scratch)
    torch.cuda.synchronize()

    # Fill the output with NaN: if the kernel failed to write an element the
    # allclose check fails, instead of passing on stale correct values.
    c = torch.full_like(a, float('nan'))
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    t0.record()
    sample.vector_add(4096, 256, a, b, c)
    t1.record()
    torch.cuda.synchronize()
    print(f'Custom Kernel: ({t0.elapsed_time(t1):.2f}ms)')
    assert torch.allclose(c, expected)

    # PyTorch CPU
    a_host = a.cpu()
    b_host = b.cpu()
    t0 = time.perf_counter()
    c_host = a_host + b_host
    t1 = time.perf_counter()
    print(f'Torch CPU: ({(t1 - t0) * 1000:.2f}ms)')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment