Skip to content

Instantly share code, notes, and snippets.

@junjihashimoto
Last active July 22, 2024 13:06
Show Gist options
  • Save junjihashimoto/3a3020797076f8b5a0b4afcf0b448b93 to your computer and use it in GitHub Desktop.
Save junjihashimoto/3a3020797076f8b5a0b4afcf0b448b93 to your computer and use it in GitHub Desktop.
#include <cstdlib>
#include <iostream>
// Problem size: C (M x N) = A (M x K) * B^T, with B stored as N x K row-major.
#define M 4096
#define K 4096
#define N (2 * 4096)
// Block tile: each thread block computes a BM x BN tile of C,
// marching over K in steps of BK.
#define BM 128
#define BN 128
#define BK 16
// Thread tile: each thread accumulates a TM x TN register sub-tile.
#define TM (BM / BK)
#define TN (BN / BK)
// Threads per block = (BM*BN) results / (TM*TN) results per thread = 256.
#define num_threads (BM * BN / (TM * TN))
// Shared-memory elements each thread loads per tile of A / of B.
#define NUM_TILEA (BM * BK / num_threads)
#define NUM_TILEB (BN * BK / num_threads)
// Tiled matrix multiply: C = A * B^T.
//   a: M x K row-major, b: N x K row-major (B stored transposed), c: M x N row-major.
// Launch config expected: gridDim = (N/BN, M/BM), blockDim.x = num_threads (256),
// static shared memory = (BM*BK + BK*BN) floats.
// Each thread accumulates a TM x TN register sub-tile of the block's BM x BN tile.
__global__ void matMulKernel(float* a, float* b, float* c) {
    // Top-left corner of this thread's TM x TN sub-tile inside the block tile.
    int threadRow = (threadIdx.x / (BN / TN)) * TM;
    int threadCol = (threadIdx.x % (BN / TN)) * TN;
    // Global-memory base offsets of this block's operand tiles.
    int aPtr = blockIdx.y * BM * K;
    int bPtr = blockIdx.x * BN * K;
    int cPtr = blockIdx.y * BM * N + blockIdx.x * BN;
    __shared__ float tileA[BM * BK];  // BM x BK slice of A (row-major, row stride BK)
    __shared__ float tileB[BK * BN];  // BN x BK slice of B^T (row-major, row stride BK)
    float threadResults[TM * TN] = {0};
    // March over the K dimension one BK-wide tile at a time.
    for (int bkidx = 0; bkidx < K; bkidx += BK) {
        // Cooperative load: the whole block copies the current tiles to shared memory.
        for (int idx = 0; idx < NUM_TILEA; idx++) {
            int i = threadIdx.x + idx * blockDim.x;
            tileA[i] = a[aPtr + (i / BK) * K + i % BK];
        }
        for (int idx = 0; idx < NUM_TILEB; idx++) {
            int i = threadIdx.x + idx * blockDim.x;
            tileB[i] = b[bPtr + (i / BK) * K + i % BK];
        }
        __syncthreads();
        // BUG FIX: advance the operand offsets so the next iteration loads the
        // NEXT BK-wide tile. The original never moved aPtr/bPtr (and bkidx was
        // unused in the load index), so it summed K/BK copies of the first tile.
        aPtr += BK;
        bPtr += BK;
        // Accumulate the outer product of one column of tileA and one row of tileB^T.
        for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
            float localM[TM], localN[TN];
            for (int idx = 0; idx < TM; idx++) {
                localM[idx] = tileA[(threadRow + idx) * BK + dotIdx];
            }
            for (int idx = 0; idx < TN; idx++) {
                localN[idx] = tileB[(threadCol + idx) * BK + dotIdx];
            }
            for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
                for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
                    threadResults[resIdxM * TN + resIdxN] += localM[resIdxM] * localN[resIdxN];
                }
            }
        }
        // Keep fast threads from overwriting tiles still being read.
        __syncthreads();
    }
    // Write the finished TM x TN sub-tile back to global memory.
    for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
        for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
            c[cPtr + (threadRow + resIdxM) * N + (threadCol + resIdxN)] = threadResults[resIdxM * TN + resIdxN];
        }
    }
}
// Benchmark driver: allocates device matrices, times `niter` kernel launches
// with CUDA events, and reports average time and GFLOPS.
int main() {
    // Abort with a readable message on any CUDA failure; unchecked errors
    // otherwise surface later as mysterious cascading failures.
    auto check = [](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            std::cerr << what << ": " << cudaGetErrorString(err) << "\n";
            std::exit(EXIT_FAILURE);
        }
    };
    // Allocate device memory.
    float *a, *b, *c;
    check(cudaMalloc(&a, M * K * sizeof(float)), "cudaMalloc a");
    check(cudaMalloc(&b, K * N * sizeof(float)), "cudaMalloc b");
    check(cudaMalloc(&c, M * N * sizeof(float)), "cudaMalloc c");
    // Zero-initialize the inputs so the kernel reads defined values
    // (the original left a/b uninitialized — a correctness hazard).
    check(cudaMemset(a, 0, M * K * sizeof(float)), "cudaMemset a");
    check(cudaMemset(b, 0, K * N * sizeof(float)), "cudaMemset b");
    // Kernel configuration: one thread per TM x TN output sub-tile,
    // one block per BM x BN tile of C.
    dim3 blockDim(BM * BN / (TM * TN));
    dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
    // Create CUDA events for timing.
    cudaEvent_t start, stop;
    const int niter = 5;
    check(cudaEventCreate(&start), "cudaEventCreate start");
    check(cudaEventCreate(&stop), "cudaEventCreate stop");
    // Warm-up launch so one-time initialization cost is excluded from timing.
    matMulKernel<<<gridDim, blockDim>>>(a, b, c);
    check(cudaGetLastError(), "warm-up launch");
    check(cudaDeviceSynchronize(), "warm-up sync");
    // Timed region.
    check(cudaEventRecord(start), "cudaEventRecord start");
    for (int i = 0; i < niter; i++) {
        matMulKernel<<<gridDim, blockDim>>>(a, b, c);
    }
    check(cudaGetLastError(), "kernel launch");
    check(cudaEventRecord(stop), "cudaEventRecord stop");
    check(cudaEventSynchronize(stop), "cudaEventSynchronize");
    float milliseconds = 0;
    check(cudaEventElapsedTime(&milliseconds, start, stop), "cudaEventElapsedTime");
    // FLOP count in double: 2*M*N*K ≈ 2.75e11 exceeds float's exact integer range.
    double flops = 2.0 * M * N * K;
    double gflops = flops / (milliseconds / niter * 1e6);
    std::cout << "Execution time: " << milliseconds << " ms\n";
    std::cout << "Execution time / iterations: " << (milliseconds / niter) << " ms\n";
    std::cout << "GFLOPS: " << gflops << "\n";
    // Copy result back to CPU (blocking memcpy also synchronizes the device).
    float* output = (float*)malloc(M * N * sizeof(float));
    check(cudaMemcpy(output, c, M * N * sizeof(float), cudaMemcpyDeviceToHost),
          "cudaMemcpy c -> host");
    // Check results
    // (Check code here)
    // Free memory.
    check(cudaFree(a), "cudaFree a");
    check(cudaFree(b), "cudaFree b");
    check(cudaFree(c), "cudaFree c");
    free(output);
    // Destroy CUDA events.
    check(cudaEventDestroy(start), "cudaEventDestroy start");
    check(cudaEventDestroy(stop), "cudaEventDestroy stop");
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment