Skip to content

Instantly share code, notes, and snippets.

@ahennequ
Created September 15, 2022 18:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahennequ/5bc4ebde0caa2541fc8fc40546652ddd to your computer and use it in GitHub Desktop.
Save ahennequ/5bc4ebde0caa2541fc8fc40546652ddd to your computer and use it in GitHub Desktop.
Use this program to find out the tensor core accumulator's warp register layout (which lane and register hold each element of a WMMA accumulator tile).
#include <stdio.h>
#include <stdlib.h>
// Check the tensor core accumulator's warp register layout.
// nvcc -arch=sm_75 tensorcore_mapping.cu -o mapping
// ./mapping
// Error checking for CUDA runtime calls.
//
// cudaErrCheck(stat) reports a failed call together with the file and line of
// the call site. It is wrapped in do { } while (0) so that it behaves as a
// single statement, e.g. inside an unbraced if/else.
#define cudaErrCheck(stat) do { cudaErrCheck_((stat), __FILE__, __LINE__); } while (0)

// Print `stat` with its call site and terminate the program if it is not
// cudaSuccess.
//   stat - status returned by a CUDA runtime call
//   file - source file of the call (normally __FILE__)
//   line - source line of the call (normally __LINE__)
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
        // Continuing after a failed call would print uninitialized or stale
        // buffers as if they were real results, so stop here.
        exit(EXIT_FAILURE);
    }
}
#include <mma.h>
using namespace nvcuda;
// Predicted tile row held by fragment element `i` of the calling lane.
// This is the hypothesis under test: wmma_example stores this value into the
// `row` matrix so it can be compared against the actual layout.
// NOTE(review): lane is threadIdx.x % 32, so this assumes the caller's warp
// starts at a multiple of 32 within the block.
__device__ int getWarpRow(int i) {
    const int lane = threadIdx.x % 32;
    // Four consecutive lanes share a row.
    const int laneGroup = lane / 4;
    // (i / 2) % 2 is 0 for i = 0,1,4,5 and 1 for i = 2,3,6,7, which selects the
    // upper or lower 8-row half.
    const int rowHalf = (i / 2) % 2;
    return laneGroup + 8 * rowHalf;
}
// Predicted tile column held by fragment element `i` of the calling lane.
// This is the hypothesis under test: wmma_example stores this value into the
// `col` matrix so it can be compared against the actual layout.
// NOTE(review): lane is threadIdx.x % 32, so this assumes the caller's warp
// starts at a multiple of 32 within the block.
__device__ int getWarpCol(int i) {
    const int lane = threadIdx.x % 32;
    // Each lane within a group of four owns a pair of adjacent columns.
    const int pairStart = 2 * (lane % 4);
    // Even/odd i picks the left/right column of that pair.
    const int withinPair = i % 2;
    // i / 4 is 0 for i = 0..3 and 1 for i = 4..7, which selects the 8-column half.
    const int colHalf = i / 4;
    return pairStart + withinPair + 8 * colHalf;
}
// Write four 16x16 views of one m16n16k16 float accumulator fragment to device
// memory. Each view is stored row-major with leading dimension 16, so every
// pointer must reference at least 16 * 16 floats:
//   elem   - index k of the fragment slot (frag.x[k]) that holds each element
//   thread - threadIdx.x of the thread that holds each element
//   row    - getWarpRow(k) prediction for the slot holding each element
//   col    - getWarpCol(k) prediction for the slot holding each element
// If the predictions are right, `row`/`col` show each element's own row/column.
// The WMMA *_sync operations are warp-collective; main launches this kernel with
// a single block of 32 threads (one warp).
__global__ void wmma_example(float *elem, float* thread, float* row, float* col) {
    const unsigned ldm = 16;  // leading dimension of every output tile
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> frag;

    // View 1: the slot index k inside the owning thread's fragment.
    wmma::fill_fragment(frag, 0.0f);
    for (int k = 0; k < frag.num_elements; ++k) {
        frag.x[k] = k;
    }
    wmma::store_matrix_sync(elem, frag, ldm, wmma::mem_row_major);

    // View 2: which thread owns the element.
    wmma::fill_fragment(frag, 0.0f);
    for (int k = 0; k < frag.num_elements; ++k) {
        frag.x[k] = threadIdx.x;
    }
    wmma::store_matrix_sync(thread, frag, ldm, wmma::mem_row_major);

    // View 3: predicted row for each slot.
    wmma::fill_fragment(frag, 0.0f);
    for (int k = 0; k < frag.num_elements; ++k) {
        frag.x[k] = getWarpRow(k);
    }
    wmma::store_matrix_sync(row, frag, ldm, wmma::mem_row_major);

    // View 4: predicted column for each slot.
    wmma::fill_fragment(frag, 0.0f);
    for (int k = 0; k < frag.num_elements; ++k) {
        frag.x[k] = getWarpCol(k);
    }
    wmma::store_matrix_sync(col, frag, ldm, wmma::mem_row_major);
}
// Print a dim x dim row-major float matrix under a "<title>:" header.
// Each value is truncated to int and printed in a 2-character-wide column.
static void printMatrix(const char *title, const float *m, int dim) {
    printf("%s:\n", title);
    for (int i = 0; i < dim; i++) {
        for (int j = 0; j < dim; j++) {
            printf("%2d ", (int) m[i * dim + j]);
        }
        printf("\n");
    }
}

// Launch wmma_example on a single warp and copy back the four 16x16 tiles it
// writes (slot index, thread index, predicted row, predicted column).
// Print each tile. Returns 0 on success, or 1 if a host allocation fails.
int main(int argc, char* argv[]) {
    const int dim = 16;  // WMMA accumulator tile is 16 x 16 (see wmma_example)
    const size_t bytes = dim * dim * sizeof(float);
    const int kNumTiles = 4;
    // Order matches wmma_example's parameters: elem, thread, row, col.
    const char *titles[kNumTiles] = {"Elem", "ThreadIdx", "Row", "Col"};
    float *dev[kNumTiles] = {};
    float *host[kNumTiles] = {};

    for (int t = 0; t < kNumTiles; t++) {
        cudaErrCheck(cudaMalloc((void**)&dev[t], bytes));
        host[t] = (float*)malloc(bytes);
        if (host[t] == NULL) {
            fprintf(stderr, "Host allocation of %zu bytes failed\n", bytes);
            return 1;
        }
    }

    // One block holding exactly one warp: the WMMA fragment operations are
    // warp-collective. Names avoid shadowing the gridDim/blockDim built-ins.
    dim3 grid(1);
    dim3 block(32);
    printf("Running with wmma...\n");
    wmma_example <<< grid, block >>> (dev[0], dev[1], dev[2], dev[3]);
    // Launch errors (e.g. no kernel image for this GPU's architecture) are
    // reported only via cudaGetLastError(). Without this check, the blocking
    // copies below could still succeed and print uninitialized device memory.
    cudaErrCheck(cudaGetLastError());
    cudaErrCheck(cudaDeviceSynchronize());

    printf("\nChecking results...\n");
    for (int t = 0; t < kNumTiles; t++) {
        cudaErrCheck(cudaMemcpy(host[t], dev[t], bytes, cudaMemcpyDeviceToHost));
    }
    for (int t = 0; t < kNumTiles; t++) {
        printMatrix(titles[t], host[t], dim);
    }

    for (int t = 0; t < kNumTiles; t++) {
        cudaErrCheck(cudaFree(dev[t]));
        free(host[t]);
    }
    cudaErrCheck(cudaDeviceReset());
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment