@tjyuyao · Last active April 6, 2022
PyCUDA and PyTorch direct interoperation on tensors, with a multidimensional element accessor.
import torch
import cutex

M, N, K = 4, 4, 1  # a is MxK, b is KxN, c is MxN
a = torch.rand((M, K), dtype=torch.float32).cuda()
b = torch.rand((K, N), dtype=torch.float32).cuda()
c = torch.empty((M, N), dtype=torch.float32).cuda()  # output, written by the kernel
kernels = cutex.SourceModule(r"""
__global__ void matmul(Tensor<float, 2> *a, Tensor<float, 2> *b, Tensor<float, 2> *c, int M, int N, int K) {
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    if (m >= M || n >= N) return;  // skip threads outside the output matrix
    float v = 0.f;
    for (int k = 0; k < K; ++k) {
        v += (*a)[m][k] * (*b)[k][n];  // multidimensional element accessor, no manual stride math
    }
    (*c)[m][n] = v;
}
""", float_bits=32)
# Launch enough 32x32 blocks to cover the MxN output; the in-kernel bounds
# check discards the threads that fall outside it.
kernels.matmul(a, b, c, M, N, K, grid=(N // 32 + 1, M // 32 + 1), block=(32, 32, 1))
torch.cuda.synchronize()  # kernel launches are asynchronous; wait before checking
assert torch.allclose(c, torch.mm(a, b))  # verify against PyTorch's own matmul
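
For comparison, the same direct interoperation can be written with raw PyCUDA by passing tensor device pointers into a plain CUDA kernel. The sketch below is illustrative, not part of the original gist: it assumes PyCUDA is installed and attaches to the primary CUDA context that PyTorch uses (via pycuda.autoprimaryctx), and that the tensors are contiguous so data_ptr() can be indexed with flat row-major arithmetic.

import numpy as np
import pycuda.autoprimaryctx  # share PyTorch's primary CUDA context (assumption: PyCUDA >= 2021.1)
from pycuda.compiler import SourceModule

raw = SourceModule(r"""
__global__ void matmul_raw(const float *a, const float *b, float *c, int M, int N, int K) {
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    if (m >= M || n >= N) return;
    float v = 0.f;
    for (int k = 0; k < K; ++k) {
        v += a[m * K + k] * b[k * N + n];  // manual row-major index arithmetic
    }
    c[m * N + n] = v;
}
""")
matmul_raw = raw.get_function("matmul_raw")

c2 = torch.empty((M, N), dtype=torch.float32).cuda()
matmul_raw(
    # raw device pointers, valid only for contiguous tensors
    np.uintp(a.data_ptr()), np.uintp(b.data_ptr()), np.uintp(c2.data_ptr()),
    np.int32(M), np.int32(N), np.int32(K),
    grid=(N // 32 + 1, M // 32 + 1), block=(32, 32, 1),
)
torch.cuda.synchronize()
assert torch.allclose(c2, torch.mm(a, b))

The flat-pointer version has to carry its own index arithmetic and contiguity assumptions; that bookkeeping is exactly what cutex's Tensor<float, 2> accessor hides.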