@pchng
Created March 6, 2024 19:41
CUDA: matmul good
// Computes out = left @ right, where `@` is matrix multiplication.
// Dimensions:
//   left:  a x b
//   right: b x c
//   out:   a x c
// Monolithic kernel: one thread per output element in `out`.
// (Each thread computes the dot product between a row of `left` and a column of `right`.)
__global__ void matMul(float *left, float *right, float *out, int a, int b, int c) {
  // Use y to index rows and x to index cols (to match the typical visualization):
  // `row` indexes into `left`, `col` indexes into `right`.
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < a && col < c) {
    float sum = 0.0f;
    // Each row of `left` and each col of `right` has `b` elements.
    for (int i = 0; i < b; i++) {
      // 1. Within a warp, `row` (from threadIdx.y) is typically constant, so every thread reads
      //    the same element of `left`. This reduces to a single broadcast read per iteration,
      //    and the row is likely cached, since its elements are read consecutively across iterations.
      // 2. `col` (from threadIdx.x) varies across the warp, so its 32 threads read 32 consecutive
      //    positions from the same row of `right` on each iteration. These accesses coalesce into
      //    one memory transaction per iteration; across the whole loop the warp reads 32 columns.
      sum += left[row * b + i] * right[i * c + col];
    }
    // 3. The write is coalesced under the same assumption: `row` is constant across the warp
    //    while `col` varies, so adjacent threads write to adjacent elements of `out`.
    out[row * c + col] = sum;
  }
}