-
-
Save junjihashimoto/3a3020797076f8b5a0b4afcf0b448b93 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib>
#include <iostream>
// Problem size used by the benchmark: the kernel indexes A with row stride K,
// B with row stride K, and C with row stride N.
#define M 4096
#define K 4096
#define N (2 * 4096)
// Block tile: each thread block produces a BM x BN tile of C and walks the
// K dimension in slices of BK.
#define BM 128
#define BN 128
#define BK 16
// Per-thread micro-tile: each thread accumulates TM x TN outputs.
#define TM (BM / BK)
#define TN (BN / BK)
// Threads per block needed to cover a BM x BN tile with TM x TN micro-tiles.
#define num_threads (BM * BN / (TM * TN))
// Elements of the A / B shared-memory tiles each thread loads per K slice.
// Both divisions must be exact for the cooperative load to cover the tiles.
#define NUM_TILEA (BM * BK / num_threads)
#define NUM_TILEB (BN * BK / num_threads)
/**
 * Tiled single-precision matrix multiply with a per-thread TM x TN micro-tile.
 *
 * Indexing used by this kernel:
 *   a[row * K + k]   for row in [0, M), k in [0, K)
 *   b[col * K + k]   for col in [0, N), k in [0, K)
 *   c[row * N + col] = sum over k of a[row * K + k] * b[col * K + k]
 *
 * Launch requirements:
 *   - 1-D block of exactly num_threads threads,
 *   - grid of (N / BN, M / BM) blocks,
 *   - M, N, K are multiples of BM, BN, BK respectively (there are no bounds
 *     checks on the global loads or stores).
 *
 * Uses (BM * BK + BK * BN) floats of static shared memory per block.
 *
 * @param a device pointer to the left operand
 * @param b device pointer to the right operand (read with row stride K)
 * @param c device pointer to the output, fully overwritten
 */
__global__ void matMulKernel(float* a, float* b, float* c) {
  // Top-left corner of this thread's TM x TN micro-tile inside the block tile.
  const int threadRow = (threadIdx.x / (BN / TN)) * TM;
  const int threadCol = (threadIdx.x % (BN / TN)) * TN;

  // Offsets of this block's first row of A, first row of B, and output tile.
  const int aPtr = blockIdx.y * BM * K;
  const int bPtr = blockIdx.x * BN * K;
  const int cPtr = blockIdx.y * BM * N + blockIdx.x * BN;

  __shared__ float tileA[BM * BK];
  __shared__ float tileB[BK * BN];

  float threadResults[TM * TN] = {0};

  // Walk the K dimension one BK-wide slice at a time.
  for (int bkidx = 0; bkidx < K; bkidx += BK) {
    // Cooperative load of the current slice. The `bkidx` term selects the
    // slice; without it every iteration would reload columns [0, BK).
    for (int idx = 0; idx < NUM_TILEA; idx++) {
      const int linear = threadIdx.x + idx * num_threads;
      tileA[linear] = a[aPtr + (linear / BK) * K + bkidx + (linear % BK)];
    }
    for (int idx = 0; idx < NUM_TILEB; idx++) {
      const int linear = threadIdx.x + idx * num_threads;
      tileB[linear] = b[bPtr + (linear / BK) * K + bkidx + (linear % BK)];
    }
    // All loads must land before any thread reads the tiles.
    __syncthreads();

    // Accumulate the outer product of one tile column pair per step.
    for (int dotIdx = 0; dotIdx < BK; dotIdx++) {
      float localM[TM], localN[TN];
      for (int idx = 0; idx < TM; idx++) {
        localM[idx] = tileA[(threadRow + idx) * BK + dotIdx];
      }
      for (int idx = 0; idx < TN; idx++) {
        localN[idx] = tileB[(threadCol + idx) * BK + dotIdx];
      }
      for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
        for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
          threadResults[resIdxM * TN + resIdxN] += localM[resIdxM] * localN[resIdxN];
        }
      }
    }
    // Nobody may overwrite the tiles while another thread is still reading.
    __syncthreads();
  }

  // Store this thread's micro-tile to global memory.
  for (int resIdxM = 0; resIdxM < TM; resIdxM++) {
    for (int resIdxN = 0; resIdxN < TN; resIdxN++) {
      c[cPtr + (threadRow + resIdxM) * N + (threadCol + resIdxN)] =
          threadResults[resIdxM * TN + resIdxN];
    }
  }
}
// Aborts the program with a diagnostic if a CUDA runtime call failed.
static void checkCuda(cudaError_t status, const char* what) {
  if (status != cudaSuccess) {
    std::cerr << "CUDA error in " << what << ": " << cudaGetErrorString(status) << "\n";
    std::exit(EXIT_FAILURE);
  }
}

/**
 * Benchmark driver: allocates the three matrices on the device, launches
 * matMulKernel `niter` times between two CUDA events, prints the elapsed time
 * and the derived GFLOPS, and copies the result back to the host.
 *
 * Returns 0 on success; exits with EXIT_FAILURE on any CUDA or allocation error.
 */
int main() {
  // Compute byte counts in size_t so larger problem sizes cannot overflow int.
  const size_t bytesA = static_cast<size_t>(M) * K * sizeof(float);
  const size_t bytesB = static_cast<size_t>(K) * N * sizeof(float);
  const size_t bytesC = static_cast<size_t>(M) * N * sizeof(float);

  // Allocate memory
  float *a = NULL, *b = NULL, *c = NULL;
  checkCuda(cudaMalloc(&a, bytesA), "cudaMalloc(a)");
  checkCuda(cudaMalloc(&b, bytesB), "cudaMalloc(b)");
  checkCuda(cudaMalloc(&c, bytesC), "cudaMalloc(c)");

  // Initialize matrices
  // Placeholder: zero the inputs so the kernel never reads uninitialised
  // device memory. Replace with real data before validating results.
  checkCuda(cudaMemset(a, 0, bytesA), "cudaMemset(a)");
  checkCuda(cudaMemset(b, 0, bytesB), "cudaMemset(b)");

  // Kernel configuration: one 1-D block of num_threads threads per BM x BN
  // output tile.
  dim3 block(num_threads);
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  // Create CUDA events for timing
  cudaEvent_t start, stop;
  int niter = 5;
  checkCuda(cudaEventCreate(&start), "cudaEventCreate(start)");
  checkCuda(cudaEventCreate(&stop), "cudaEventCreate(stop)");

  // Record the start event
  checkCuda(cudaEventRecord(start), "cudaEventRecord(start)");
  for (int i = 0; i < niter; i++) {
    // Launch kernel; launch-configuration errors only show up via
    // cudaGetLastError(), which does not synchronise.
    matMulKernel<<<grid, block>>>(a, b, c);
    checkCuda(cudaGetLastError(), "matMulKernel launch");
  }
  // Record the stop event
  checkCuda(cudaEventRecord(stop), "cudaEventRecord(stop)");
  // Wait for the stop event to complete; asynchronous kernel faults surface here.
  checkCuda(cudaEventSynchronize(stop), "cudaEventSynchronize(stop)");

  // Calculate elapsed time
  float milliseconds = 0;
  checkCuda(cudaEventElapsedTime(&milliseconds, start, stop), "cudaEventElapsedTime");

  // Calculate FLOPS: 2 * M * N * K operations per multiply, ms -> s and
  // FLOPS -> GFLOPS folded into the 1e6 factor.
  float flops = 2.0f * M * N * K;
  float gflops = flops / (milliseconds / niter * 1e6);
  std::cout << "Execution time: " << milliseconds << " ms\n";
  std::cout << "Execution time / iterations: " << (milliseconds / niter) << " ms\n";
  std::cout << "GFLOPS: " << gflops << "\n";

  // Copy result back to CPU
  float* output = (float*)malloc(bytesC);
  if (output == NULL) {
    std::cerr << "Host allocation of " << bytesC << " bytes failed\n";
    std::exit(EXIT_FAILURE);
  }
  checkCuda(cudaMemcpy(output, c, bytesC, cudaMemcpyDeviceToHost), "cudaMemcpy(output)");

  // Check results
  // (Check code here)

  // Free memory
  checkCuda(cudaFree(a), "cudaFree(a)");
  checkCuda(cudaFree(b), "cudaFree(b)");
  checkCuda(cudaFree(c), "cudaFree(c)");
  free(output);

  // Destroy CUDA events
  checkCuda(cudaEventDestroy(start), "cudaEventDestroy(start)");
  checkCuda(cudaEventDestroy(stop), "cudaEventDestroy(stop)");
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment