@secondspass
Created March 29, 2024 16:02
TIL Corner Turning example code
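Three standalone HIP programs demonstrating corner turning in tiled matrix multiplication: a row-major baseline with boundary checks (mmtiled_boundary.cpp, the name the comments below refer to), a corner-turning version where N is stored in column-major order, and a column-major version without corner turning for timing comparison. Each builds on its own, e.g. (assuming hipcc is on your PATH): hipcc mmtiled_boundary.cpp -o mmtiled && ./mmtiled 1024, where the single argument is the matrix width.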
#include <hip/hip_runtime.h>
#include <hip/nvidia_detail/nvidia_hip_runtime_api.h>
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define gpuErrorCheck(call)                                                   \
  do {                                                                        \
    hipError_t gpuErr = call;                                                 \
    if (hipSuccess != gpuErr) {                                               \
      printf("HIP Error - %s:%d: '%s'\n", __FILE__, __LINE__,                 \
             hipGetErrorString(gpuErr));                                      \
      exit(1);                                                                \
    }                                                                         \
  } while (0)
#define TILE_WIDTH 32
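// Baseline tiled matrix multiply: M, N, and P are width x width and
// row-major. Each block computes one TILE_WIDTH x TILE_WIDTH tile of P,
// staging the matching tiles of M and N through shared memory one phase at
// a time.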
__global__ void MatrixMulKernel(float *M, float *N, float *P, int width) {
  __shared__ float Mds[TILE_WIDTH][TILE_WIDTH];
  __shared__ float Nds[TILE_WIDTH][TILE_WIDTH];
  int bdx = blockIdx.x;
  int bdy = blockIdx.y;
  int tdx = threadIdx.x;
  int tdy = threadIdx.y;
  int row = bdy * TILE_WIDTH + tdy;
  int col = bdx * TILE_WIDTH + tdx;
  float PValue = 0;
  for (int ph = 0; ph < ceil(width / (float)TILE_WIDTH); ph++) {
    // Load one element of the M tile, guarding the matrix boundary.
    if ((row < width) && ((ph * TILE_WIDTH + tdx) < width))
      Mds[tdy][tdx] = M[row * width + ph * TILE_WIDTH + tdx];
    else
      Mds[tdy][tdx] = 0.0f;
    // Load one element of the N tile (N is row-major here).
    if ((col < width) && ((ph * TILE_WIDTH + tdy) < width))
      Nds[tdy][tdx] = N[(ph * TILE_WIDTH + tdy) * width + col];
    else
      Nds[tdy][tdx] = 0.0f;
    __syncthreads();
    for (int k = 0; k < TILE_WIDTH; k++) {
      PValue += Mds[tdy][k] * Nds[k][tdx];
    }
    __syncthreads();
  }
  if ((row < width) && (col < width))
    P[row * width + col] = PValue;
}
int main(int argc, char **argv) {
  if (argc < 2) {
    printf("Usage: %s <matrix width>\n", argv[0]);
    return 1;
  }
  int width = atoi(argv[1]);
  float *M;
  float *N;
  float *P;
  float *P_verify;
  float *M_d;
  float *N_d;
  float *P_d;
  size_t size = width * width * sizeof(float);
  M = (float *)malloc(size);
  N = (float *)malloc(size);
  P = (float *)malloc(size);
  P_verify = (float *)malloc(size);
  gpuErrorCheck(hipMalloc((void **)&M_d, size));
  gpuErrorCheck(hipMalloc((void **)&N_d, size));
  gpuErrorCheck(hipMalloc((void **)&P_d, size));
  srand(1234);
  for (int i = 0; i < width * width; i++) {
    M[i] = (float)rand() / (float)RAND_MAX;
    N[i] = (float)rand() / (float)RAND_MAX;
  }
  gpuErrorCheck(hipMemcpy(M_d, M, size, hipMemcpyHostToDevice));
  gpuErrorCheck(hipMemcpy(N_d, N, size, hipMemcpyHostToDevice));
  int griddim = ceil(width / (float)TILE_WIDTH);
  dim3 dimGrid(griddim, griddim, 1);
  // remember that typically a tile is the same as the block size (since
  // each thread in a block loads one element in the tile into shared memory)
  dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1);
  MatrixMulKernel<<<dimGrid, dimBlock>>>(M_d, N_d, P_d, width);
  gpuErrorCheck(hipDeviceSynchronize());
  gpuErrorCheck(hipMemcpy(P, P_d, size, hipMemcpyDeviceToHost));
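  // P_verify is allocated above but never filled in the original gist. A
  // minimal CPU-side check (a sketch, not part of the original) could use it
  // like this, with a small tolerance for float rounding:
  for (int i = 0; i < width; i++) {
    for (int j = 0; j < width; j++) {
      float acc = 0.0f;
      for (int k = 0; k < width; k++) {
        acc += M[i * width + k] * N[k * width + j];
      }
      P_verify[i * width + j] = acc;
    }
  }
  for (int i = 0; i < width * width; i++) {
    if (fabsf(P[i] - P_verify[i]) > 1e-3f * fabsf(P_verify[i]) + 1e-3f) {
      printf("Mismatch at %d: %f vs %f\n", i, P[i], P_verify[i]);
      break;
    }
  }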
}
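// ===========================================================================
// Second file: the corner-turning version. N is stored in column-major
// order, and the tile load is rearranged so that adjacent threads in a
// block still read adjacent addresses (coalesced loads).
// ===========================================================================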
#include <hip/hip_runtime.h>
#include <hip/nvidia_detail/nvidia_hip_runtime_api.h>
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define gpuErrorCheck(call)                                                   \
  do {                                                                        \
    hipError_t gpuErr = call;                                                 \
    if (hipSuccess != gpuErr) {                                               \
      printf("HIP Error - %s:%d: '%s'\n", __FILE__, __LINE__,                 \
             hipGetErrorString(gpuErr));                                      \
      exit(1);                                                                \
    }                                                                         \
  } while (0)
#define TILE_WIDTH 32
__global__ void MatrixMulKernel(float *M, float *N, float *P, int width) {
  __shared__ float Mds[TILE_WIDTH][TILE_WIDTH];
  __shared__ float Nds[TILE_WIDTH][TILE_WIDTH];
  int bdx = blockIdx.x;
  int bdy = blockIdx.y;
  int tdx = threadIdx.x;
  int tdy = threadIdx.y;
  int row = bdy * TILE_WIDTH + tdy;
  int col = bdx * TILE_WIDTH + tdx;
  float PValue = 0;
  for (int ph = 0; ph < ceil(width / (float)TILE_WIDTH); ph++) {
    if ((row < width) && ((ph * TILE_WIDTH + tdx) < width))
      Mds[tdy][tdx] = M[row * width + ph * TILE_WIDTH + tdx];
    else
      Mds[tdy][tdx] = 0.0f;
    // N is stored in column-major order (each column's values are placed in
    // sequence, and the next column starts where the previous one ends), so
    // we corner turn when loading its tile.
    // Index breakdown: bdx * TILE_WIDTH selects this block's set of columns,
    // bdx * TILE_WIDTH + tdy picks a column within that set, and
    // (bdx * TILE_WIDTH + tdy) * width jumps to the start of that column in
    // the flat column-major array. ph * TILE_WIDTH selects the current phase,
    // and + tdx picks the row within the column.
    // This resembles how we traverse M, but with a twist: if we traversed N
    // exactly like M, we would write N[col * width + ph * TILE_WIDTH + tdy],
    // and adjacent threads in the block would then load values from different
    // columns, so their accesses would not coalesce (i.e. no corner turning).
    // For example (writing t_yx for the thread at (tdy, tdx)), t_00 and t_01
    // sit next to each other in the warp but would load from different
    // columns. The load below instead makes t_00 and t_01 read from the same
    // column, while still filling the N tile correctly (this really needs a
    // full blog post with diagrams).
    // In short, tdx and tdy swap positions relative to the formula we would
    // use for column-major N without corner turning: the no-corner-turning
    // version writes Nds[tdy][tdx] = N[col * width + ph * TILE_WIDTH + tdy];
    // expanding col to (bdx * TILE_WIDTH + tdx) and swapping tdx with tdy in
    // that formula gives the corner-turned load.
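    // A small worked example (hypothetical numbers, TILE_WIDTH = 4,
    // width = 8, ph = 0, block (bdx, bdy) = (1, 0)): the column index is
    // bdx * TILE_WIDTH + tdy = 4 + tdy, so
    //   t_00 (tdx=0, tdy=0) loads N[(4 + 0) * 8 + 0 + 0] = N[32]
    //   t_01 (tdx=1, tdy=0) loads N[(4 + 0) * 8 + 0 + 1] = N[33]
    //   t_02 (tdx=2, tdy=0) loads N[(4 + 0) * 8 + 0 + 2] = N[34]
    // Adjacent threads in x read adjacent addresses (walking down one column
    // of N), so the loads coalesce; storing into Nds[tdx][tdy] transposes
    // the tile so the inner-product loop (Nds[k][tdx]) still reads it
    // correctly.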
    if ((bdx * TILE_WIDTH + tdy < width) && ((ph * TILE_WIDTH + tdx) < width))
      Nds[tdx][tdy] =
          N[(bdx * TILE_WIDTH + tdy) * width + ph * TILE_WIDTH + tdx];
    else
      Nds[tdx][tdy] = 0.0f;
    __syncthreads();
    for (int k = 0; k < TILE_WIDTH; k++) {
      PValue += Mds[tdy][k] * Nds[k][tdx];
    }
    __syncthreads();
  }
  if ((row < width) && (col < width))
    P[row * width + col] = PValue;
}
int main(int argc, char **argv) {
  if (argc < 2) {
    printf("Usage: %s <matrix width>\n", argv[0]);
    return 1;
  }
  int width = atoi(argv[1]);
  float *M;
  float *N;
  float *P;
  float *P_verify;
  float *M_d;
  float *N_d;
  float *P_d;
  size_t size = width * width * sizeof(float);
  M = (float *)malloc(size);
  N = (float *)malloc(size);
  P = (float *)malloc(size);
  P_verify = (float *)malloc(size);
  gpuErrorCheck(hipMalloc((void **)&M_d, size));
  gpuErrorCheck(hipMalloc((void **)&N_d, size));
  gpuErrorCheck(hipMalloc((void **)&P_d, size));
  srand(1234);
  for (int i = 0; i < width * width; i++) {
    M[i] = (float)rand() / (float)RAND_MAX;
    N[i] = (float)rand() / (float)RAND_MAX;
  }
  // Transpose N to simulate N being stored in column-major order.
  float *N_transpose = (float *)malloc(size);
  for (int i = 0; i < width; i++) {
    for (int j = 0; j < width; j++) {
      N_transpose[j * width + i] = N[i * width + j];
    }
  }
  free(N);
  N = N_transpose;
  gpuErrorCheck(hipMemcpy(M_d, M, size, hipMemcpyHostToDevice));
  gpuErrorCheck(hipMemcpy(N_d, N, size, hipMemcpyHostToDevice));
  int griddim = ceil(width / (float)TILE_WIDTH);
  dim3 dimGrid(griddim, griddim, 1);
  // remember that typically a tile is the same as the block size (since
  // each thread in a block loads one element in the tile into shared memory)
  dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1);
  // benchmarking the kernel
  // warmup loop
  for (int i = 0; i < 10; i++) {
    MatrixMulKernel<<<dimGrid, dimBlock>>>(M_d, N_d, P_d, width);
  }
  gpuErrorCheck(hipMemcpy(P, P_d, size, hipMemcpyDeviceToHost));
  gpuErrorCheck(hipDeviceSynchronize());
  float milliseconds[150];
  hipEvent_t start, stop;
  gpuErrorCheck(hipEventCreate(&start));
  gpuErrorCheck(hipEventCreate(&stop));
  for (int i = 0; i < 150; i++) {
    gpuErrorCheck(hipEventRecord(start));
    MatrixMulKernel<<<dimGrid, dimBlock>>>(M_d, N_d, P_d, width);
    gpuErrorCheck(hipEventRecord(stop));
    gpuErrorCheck(hipMemcpy(P, P_d, size, hipMemcpyDeviceToHost));
    gpuErrorCheck(hipEventSynchronize(stop));
    gpuErrorCheck(hipEventElapsedTime(&milliseconds[i], start, stop));
  }
  gpuErrorCheck(hipDeviceSynchronize());
  float sum = 0;
  std::cout << "Sample elapsed times (ms):" << std::endl;
  for (int i = 0; i < 150; i++) {
    if (i % 10 == 0)
      std::cout << milliseconds[i] << std::endl;
    sum += milliseconds[i];
  }
  std::cout << "Mean time (ms): " << sum / 150 << std::endl;
}
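// ===========================================================================
// Third file: the comparison version. N is again stored in column-major
// order, but the tile load does NOT corner turn, so adjacent threads read
// addresses a full column stride apart.
// ===========================================================================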
#include <hip/hip_runtime.h>
#include <hip/nvidia_detail/nvidia_hip_runtime_api.h>
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define gpuErrorCheck(call)                                                   \
  do {                                                                        \
    hipError_t gpuErr = call;                                                 \
    if (hipSuccess != gpuErr) {                                               \
      printf("HIP Error - %s:%d: '%s'\n", __FILE__, __LINE__,                 \
             hipGetErrorString(gpuErr));                                      \
      exit(1);                                                                \
    }                                                                         \
  } while (0)
#define TILE_WIDTH 32
__global__ void MatrixMulKernel(float *M, float *N, float *P, int width) {
  __shared__ float Mds[TILE_WIDTH][TILE_WIDTH];
  __shared__ float Nds[TILE_WIDTH][TILE_WIDTH];
  int bdx = blockIdx.x;
  int bdy = blockIdx.y;
  int tdx = threadIdx.x;
  int tdy = threadIdx.y;
  int row = bdy * TILE_WIDTH + tdy;
  int col = bdx * TILE_WIDTH + tdx;
  float PValue = 0;
  for (int ph = 0; ph < ceil(width / (float)TILE_WIDTH); ph++) {
    if ((row < width) && ((ph * TILE_WIDTH + tdx) < width))
      Mds[tdy][tdx] = M[row * width + ph * TILE_WIDTH + tdx];
    else
      Mds[tdy][tdx] = 0.0f;
    // No corner turning here: adjacent threads load elements that are far
    // apart in memory rather than next to each other. The contents of the
    // tile are still correct; the accesses just aren't coalesced. The
    // elements are stored in column-major order, so the formula below is the
    // correct way to read along a column (if N were row-major we would write
    // Nds[tdy][tdx] = N[(ph * TILE_WIDTH + tdy) * width + col], which is
    // what the mmtiled_boundary.cpp example does).
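    // With the same small worked example as the corner-turning version
    // (TILE_WIDTH = 4, width = 8, ph = 0, block (1, 0)): col = 4 + tdx, so
    // t_00 (tdx=0) reads N[4 * 8 + 0] = N[32] while t_01 (tdx=1) reads
    // N[5 * 8 + 0] = N[40]. Adjacent threads touch addresses width floats
    // apart (width * 4 bytes), so each warp's load is scattered across many
    // memory segments instead of one contiguous run.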
    if ((col < width) && ((ph * TILE_WIDTH + tdy) < width))
      Nds[tdy][tdx] = N[col * width + ph * TILE_WIDTH + tdy];
    else
      Nds[tdy][tdx] = 0.0f;
    __syncthreads();
    for (int k = 0; k < TILE_WIDTH; k++) {
      PValue += Mds[tdy][k] * Nds[k][tdx];
    }
    __syncthreads();
  }
  if ((row < width) && (col < width))
    P[row * width + col] = PValue;
}
int main(int argc, char **argv) {
  if (argc < 2) {
    printf("Usage: %s <matrix width>\n", argv[0]);
    return 1;
  }
  int width = atoi(argv[1]);
  float *M;
  float *N;
  float *P;
  float *P_verify;
  float *M_d;
  float *N_d;
  float *P_d;
  size_t size = width * width * sizeof(float);
  M = (float *)malloc(size);
  N = (float *)malloc(size);
  P = (float *)malloc(size);
  P_verify = (float *)malloc(size);
  gpuErrorCheck(hipMalloc((void **)&M_d, size));
  gpuErrorCheck(hipMalloc((void **)&N_d, size));
  gpuErrorCheck(hipMalloc((void **)&P_d, size));
  srand(1234);
  for (int i = 0; i < width * width; i++) {
    M[i] = (float)rand() / (float)RAND_MAX;
    N[i] = (float)rand() / (float)RAND_MAX;
  }
  // Transpose N to simulate N being stored in column-major order.
  float *N_transpose = (float *)malloc(size);
  for (int i = 0; i < width; i++) {
    for (int j = 0; j < width; j++) {
      N_transpose[j * width + i] = N[i * width + j];
    }
  }
  free(N);
  N = N_transpose;
  gpuErrorCheck(hipMemcpy(M_d, M, size, hipMemcpyHostToDevice));
  gpuErrorCheck(hipMemcpy(N_d, N, size, hipMemcpyHostToDevice));
  int griddim = ceil(width / (float)TILE_WIDTH);
  dim3 dimGrid(griddim, griddim, 1);
  // remember that typically a tile is the same as the block size (since
  // each thread in a block loads one element in the tile into shared memory)
  dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1);
  // benchmarking the kernel
  // warmup loop
  for (int i = 0; i < 10; i++) {
    MatrixMulKernel<<<dimGrid, dimBlock>>>(M_d, N_d, P_d, width);
  }
  gpuErrorCheck(hipMemcpy(P, P_d, size, hipMemcpyDeviceToHost));
  gpuErrorCheck(hipDeviceSynchronize());
  float milliseconds[150];
  hipEvent_t start, stop;
  gpuErrorCheck(hipEventCreate(&start));
  gpuErrorCheck(hipEventCreate(&stop));
  for (int i = 0; i < 150; i++) {
    gpuErrorCheck(hipEventRecord(start));
    MatrixMulKernel<<<dimGrid, dimBlock>>>(M_d, N_d, P_d, width);
    gpuErrorCheck(hipEventRecord(stop));
    gpuErrorCheck(hipMemcpy(P, P_d, size, hipMemcpyDeviceToHost));
    gpuErrorCheck(hipEventSynchronize(stop));
    gpuErrorCheck(hipEventElapsedTime(&milliseconds[i], start, stop));
  }
  gpuErrorCheck(hipDeviceSynchronize());
  float sum = 0;
  std::cout << "Sample elapsed times (ms):" << std::endl;
  for (int i = 0; i < 150; i++) {
    if (i % 10 == 0)
      std::cout << milliseconds[i] << std::endl;
    sum += milliseconds[i];
  }
  std::cout << "Mean time (ms): " << sum / 150 << std::endl;
}