Created March 6, 2024 19:41
CUDA: matmul good
// Computes out = left @ right, where `@` is matrix multiplication.
// Dimensions (all matrices are stored row-major):
//   left:  a x b
//   right: b x c
//   out:   a x c
// Monolithic kernel: one thread per output element of `out`.
// (Each thread computes the dot product of one row of `left` and one col of `right`.)
__global__ void matMul(float *left, float *right, float *out, int a, int b, int c) {
  // Use y to index rows and x to index cols (just to match the typical visualization).
  // `row` indexes into `left`, `col` indexes into `right`.
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Guard against threads that fall outside `out` when the grid overshoots the matrix.
  if (row < a && col < c) {
    float sum = 0.0f;
    // Each row of `left` and each col of `right` has `b` elements.
    for (int i = 0; i < b; i++) {
      // 1. If `row` is the same for every thread in the warp (threadIdx.y is not changing),
      //    all of them read the same element of `left` on each iteration. This reduces to a
      //    single broadcast read per iteration, and later reads will likely hit in cache
      //    because consecutive elements of the row are read on consecutive iterations.
      // 2. If threadIdx.x varies across the warp, its 32 threads read 32 consecutive elements
      //    from row `i` of `right` on each iteration. That access is coalesced, so the warp
      //    needs as little as one memory transaction per iteration.
      //    Over the entire loop, the warp reads 32 cols of `right`.
      sum += left[row * b + i] * right[i * c + col];
    }
    // 3. The write is coalesced under the same assumption: `row` is fixed across the warp
    //    while `col` varies, so adjacent threads write to adjacent elements of `out`.
    out[row * c + col] = sum;
  }
}
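The comments in the kernel assume that `row` stays fixed across a warp while `col` varies. That depends on how the kernel is launched. The following is a minimal host-side sketch of one launch configuration that satisfies the assumption. It is not part of the original gist: the `CUDA_CHECK` macro, the `matMulOnDevice` wrapper, the 32 x 8 block shape, and the test dimensions are all hypothetical choices for illustration. The kernel is repeated without its comments so the sketch compiles on its own.

```cuda
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

// Same kernel as above, repeated without comments so this file is self-contained.
__global__ void matMul(float *left, float *right, float *out, int a, int b, int c) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < a && col < c) {
    float sum = 0.0f;
    for (int i = 0; i < b; i++) {
      sum += left[row * b + i] * right[i * c + col];
    }
    out[row * c + col] = sum;
  }
}

// Hypothetical helper: abort with a message if a CUDA runtime call fails.
#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err = (call);                                             \
    if (err != cudaSuccess) {                                             \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));  \
      std::exit(1);                                                       \
    }                                                                     \
  } while (0)

// Hypothetical wrapper: copies inputs to the device, launches matMul, copies the result back.
std::vector<float> matMulOnDevice(const std::vector<float> &left,
                                  const std::vector<float> &right,
                                  int a, int b, int c) {
  float *dLeft = nullptr, *dRight = nullptr, *dOut = nullptr;
  size_t outBytes = static_cast<size_t>(a) * c * sizeof(float);
  CUDA_CHECK(cudaMalloc(&dLeft, left.size() * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&dRight, right.size() * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&dOut, outBytes));
  CUDA_CHECK(cudaMemcpy(dLeft, left.data(), left.size() * sizeof(float), cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(dRight, right.data(), right.size() * sizeof(float), cudaMemcpyHostToDevice));

  // Assumed block shape: blockDim.x == 32, so each warp covers one value of
  // threadIdx.y (one `row`) and 32 consecutive values of threadIdx.x (`col`).
  dim3 block(32, 8);
  // Round up so the grid covers all c columns (x) and all a rows (y).
  dim3 grid((c + block.x - 1) / block.x, (a + block.y - 1) / block.y);
  matMul<<<grid, block>>>(dLeft, dRight, dOut, a, b, c);
  CUDA_CHECK(cudaGetLastError());

  std::vector<float> out(static_cast<size_t>(a) * c);
  // This blocking copy also waits for the kernel to finish.
  CUDA_CHECK(cudaMemcpy(out.data(), dOut, outBytes, cudaMemcpyDeviceToHost));
  CUDA_CHECK(cudaFree(dLeft));
  CUDA_CHECK(cudaFree(dRight));
  CUDA_CHECK(cudaFree(dOut));
  return out;
}

int main() {
  // Dimensions chosen so they are not multiples of the block shape,
  // which exercises the bounds check in the kernel.
  const int a = 37, b = 19, c = 45;
  std::vector<float> left(a * b), right(b * c);
  for (int i = 0; i < a * b; i++) left[i] = (i % 7) * 0.5f;
  for (int i = 0; i < b * c; i++) right[i] = (i % 5) * 0.5f;

  std::vector<float> out = matMulOnDevice(left, right, a, b, c);

  // Compare against a straightforward CPU triple loop.
  float maxErr = 0.0f;
  for (int row = 0; row < a; row++) {
    for (int col = 0; col < c; col++) {
      float sum = 0.0f;
      for (int i = 0; i < b; i++) sum += left[row * b + i] * right[i * c + col];
      maxErr = std::fmax(maxErr, std::fabs(sum - out[row * c + col]));
    }
  }
  std::printf("max abs error vs CPU: %g\n", maxErr);
  return maxErr < 1e-4f ? 0 : 1;
}
```

With `blockDim.x` set to 32, threads are assigned to warps with `threadIdx.x` varying fastest, so each warp sees one `row` and 32 consecutive `col` values, which is the situation comments 1 to 3 describe. A block shape such as 8 x 32 would instead spread each warp over four rows, and the reads from `left` would no longer collapse to a single broadcast.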