Skip to content

Instantly share code, notes, and snippets.

@yaoyaoding
Created January 27, 2022 04:52
Show Gist options
  • Save yaoyaoding/f886ff8442579306e9fc4d0f6d5e2706 to your computer and use it in GitHub Desktop.
Save yaoyaoding/f886ff8442579306e9fc4d0f6d5e2706 to your computer and use it in GitHub Desktop.
#include <cassert>
#include <cstdio>
extern "C" {
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_c_init_warp(float out[64]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = 0.0;
out[1] = 0.0;
out[2] = 0.0;
out[3] = 0.0;
out[4] = 0.0;
out[5] = 0.0;
out[6] = 0.0;
out[7] = 0.0;
out[8] = 0.0;
out[9] = 0.0;
out[10] = 0.0;
out[11] = 0.0;
out[12] = 0.0;
out[13] = 0.0;
out[14] = 0.0;
out[15] = 0.0;
out[16] = 0.0;
out[17] = 0.0;
out[18] = 0.0;
out[19] = 0.0;
out[20] = 0.0;
out[21] = 0.0;
out[22] = 0.0;
out[23] = 0.0;
out[24] = 0.0;
out[25] = 0.0;
out[26] = 0.0;
out[27] = 0.0;
out[28] = 0.0;
out[29] = 0.0;
out[30] = 0.0;
out[31] = 0.0;
out[32] = 0.0;
out[33] = 0.0;
out[34] = 0.0;
out[35] = 0.0;
out[36] = 0.0;
out[37] = 0.0;
out[38] = 0.0;
out[39] = 0.0;
out[40] = 0.0;
out[41] = 0.0;
out[42] = 0.0;
out[43] = 0.0;
out[44] = 0.0;
out[45] = 0.0;
out[46] = 0.0;
out[47] = 0.0;
out[48] = 0.0;
out[49] = 0.0;
out[50] = 0.0;
out[51] = 0.0;
out[52] = 0.0;
out[53] = 0.0;
out[54] = 0.0;
out[55] = 0.0;
out[56] = 0.0;
out[57] = 0.0;
out[58] = 0.0;
out[59] = 0.0;
out[60] = 0.0;
out[61] = 0.0;
out[62] = 0.0;
out[63] = 0.0;
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_g2s_block(float in[3833856], __shared__ float out[1024]) {
out[(((threadIdx.x / 8) * 4) + ((threadIdx.x % 8) * 128))] = in[((((threadIdx.x / 8) * 4) * 2304) + (threadIdx.x % 8))];
out[((((threadIdx.x / 8) * 4) + 1) + ((threadIdx.x % 8) * 128))] = in[(((((threadIdx.x / 8) * 4) + 1) * 2304) + (threadIdx.x % 8))];
out[((((threadIdx.x / 8) * 4) + 2) + ((threadIdx.x % 8) * 128))] = in[(((((threadIdx.x / 8) * 4) + 2) * 2304) + (threadIdx.x % 8))];
out[((((threadIdx.x / 8) * 4) + 3) + ((threadIdx.x % 8) * 128))] = in[(((((threadIdx.x / 8) * 4) + 3) * 2304) + (threadIdx.x % 8))];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_g2s_block(float in[1769472], __shared__ float out[1024]) {
out[(((threadIdx.x / 32) * 128) + (threadIdx.x % 32))] = in[(((threadIdx.x / 32) * 768) + (threadIdx.x % 32))];
out[(((threadIdx.x / 32) * 128) + (32 + (threadIdx.x % 32)))] = in[(((threadIdx.x / 32) * 768) + (32 + (threadIdx.x % 32)))];
out[(((threadIdx.x / 32) * 128) + (64 + (threadIdx.x % 32)))] = in[(((threadIdx.x / 32) * 768) + (64 + (threadIdx.x % 32)))];
out[(((threadIdx.x / 32) * 128) + (96 + (threadIdx.x % 32)))] = in[(((threadIdx.x / 32) * 768) + (96 + (threadIdx.x % 32)))];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(__shared__ float in[1024], float out[8]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = in[(((lane_id / 16) * 8) + ((lane_id % 2) * 4))];
out[1] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1)];
out[2] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2)];
out[3] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3)];
out[4] = in[((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4))];
out[5] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1)];
out[6] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2)];
out[7] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3)];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(__shared__ float in[1024], float out[8]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = in[(((lane_id % 16) / 2) * 4)];
out[1] = in[((((lane_id % 16) / 2) * 4) + 1)];
out[2] = in[((((lane_id % 16) / 2) * 4) + 2)];
out[3] = in[((((lane_id % 16) / 2) * 4) + 3)];
out[4] = in[(32 + (((lane_id % 16) / 2) * 4))];
out[5] = in[((32 + (((lane_id % 16) / 2) * 4)) + 1)];
out[6] = in[((32 + (((lane_id % 16) / 2) * 4)) + 2)];
out[7] = in[((32 + (((lane_id % 16) / 2) * 4)) + 3)];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_compute_warp(float A[8], float B[8], float C[64]) {
int32_t lane_id = (threadIdx.x % 32);
C[0] = (C[0] + (A[0] * B[0]));
C[1] = (C[1] + (A[0] * B[1]));
C[2] = (C[2] + (A[0] * B[2]));
C[3] = (C[3] + (A[0] * B[3]));
C[4] = (C[4] + (A[0] * B[4]));
C[5] = (C[5] + (A[0] * B[5]));
C[6] = (C[6] + (A[0] * B[6]));
C[7] = (C[7] + (A[0] * B[7]));
C[8] = (C[8] + (A[1] * B[0]));
C[9] = (C[9] + (A[1] * B[1]));
C[10] = (C[10] + (A[1] * B[2]));
C[11] = (C[11] + (A[1] * B[3]));
C[12] = (C[12] + (A[1] * B[4]));
C[13] = (C[13] + (A[1] * B[5]));
C[14] = (C[14] + (A[1] * B[6]));
C[15] = (C[15] + (A[1] * B[7]));
C[16] = (C[16] + (A[2] * B[0]));
C[17] = (C[17] + (A[2] * B[1]));
C[18] = (C[18] + (A[2] * B[2]));
C[19] = (C[19] + (A[2] * B[3]));
C[20] = (C[20] + (A[2] * B[4]));
C[21] = (C[21] + (A[2] * B[5]));
C[22] = (C[22] + (A[2] * B[6]));
C[23] = (C[23] + (A[2] * B[7]));
C[24] = (C[24] + (A[3] * B[0]));
C[25] = (C[25] + (A[3] * B[1]));
C[26] = (C[26] + (A[3] * B[2]));
C[27] = (C[27] + (A[3] * B[3]));
C[28] = (C[28] + (A[3] * B[4]));
C[29] = (C[29] + (A[3] * B[5]));
C[30] = (C[30] + (A[3] * B[6]));
C[31] = (C[31] + (A[3] * B[7]));
C[32] = (C[32] + (A[4] * B[0]));
C[33] = (C[33] + (A[4] * B[1]));
C[34] = (C[34] + (A[4] * B[2]));
C[35] = (C[35] + (A[4] * B[3]));
C[36] = (C[36] + (A[4] * B[4]));
C[37] = (C[37] + (A[4] * B[5]));
C[38] = (C[38] + (A[4] * B[6]));
C[39] = (C[39] + (A[4] * B[7]));
C[40] = (C[40] + (A[5] * B[0]));
C[41] = (C[41] + (A[5] * B[1]));
C[42] = (C[42] + (A[5] * B[2]));
C[43] = (C[43] + (A[5] * B[3]));
C[44] = (C[44] + (A[5] * B[4]));
C[45] = (C[45] + (A[5] * B[5]));
C[46] = (C[46] + (A[5] * B[6]));
C[47] = (C[47] + (A[5] * B[7]));
C[48] = (C[48] + (A[6] * B[0]));
C[49] = (C[49] + (A[6] * B[1]));
C[50] = (C[50] + (A[6] * B[2]));
C[51] = (C[51] + (A[6] * B[3]));
C[52] = (C[52] + (A[6] * B[4]));
C[53] = (C[53] + (A[6] * B[5]));
C[54] = (C[54] + (A[6] * B[6]));
C[55] = (C[55] + (A[6] * B[7]));
C[56] = (C[56] + (A[7] * B[0]));
C[57] = (C[57] + (A[7] * B[1]));
C[58] = (C[58] + (A[7] * B[2]));
C[59] = (C[59] + (A[7] * B[3]));
C[60] = (C[60] + (A[7] * B[4]));
C[61] = (C[61] + (A[7] * B[5]));
C[62] = (C[62] + (A[7] * B[6]));
C[63] = (C[63] + (A[7] * B[7]));
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_r2g_warp(float in[64], float out[1277952]) {
int32_t lane_id = (threadIdx.x % 32);
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + (((lane_id % 16) / 2) * 4))] = in[0];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[1];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[2];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[3];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[4];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[5];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[6];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[7];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + (((lane_id % 16) / 2) * 4))] = in[8];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[9];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[10];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[11];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[12];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[13];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[14];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[15];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + (((lane_id % 16) / 2) * 4))] = in[16];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[17];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[18];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[19];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[20];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[21];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[22];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[23];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + (((lane_id % 16) / 2) * 4))] = in[24];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[25];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[26];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[27];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[28];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[29];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[30];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[31];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + (((lane_id % 16) / 2) * 4))] = in[32];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[33];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[34];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[35];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[36];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[37];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[38];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[39];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + (((lane_id % 16) / 2) * 4))] = in[40];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[41];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[42];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[43];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[44];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[45];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[46];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[47];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + (((lane_id % 16) / 2) * 4))] = in[48];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[49];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[50];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[51];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[52];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[53];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[54];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[55];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + (((lane_id % 16) / 2) * 4))] = in[56];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[57];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[58];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[59];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[60];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[61];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[62];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[63];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block(float A[3833856], float B[1769472], float C[1277952]) {
// label: matmul128x128x8
__shared__ float smem_A[1024];
__shared__ float smem_B[1024];
float regs_A[8];
float regs_B[8];
float regs_C[64];
int32_t warp_id = (threadIdx.x / 32);
matmul_bt128x128_bsz256_s128x128_block_c_init_warp(regs_C);
for (int32_t block_k_tile = 0; (block_k_tile < 288); block_k_tile = (block_k_tile + 1)) {
matmul_bt128x128_bsz256_s128x128_block_a_g2s_block(&A[(block_k_tile * 8)], &smem_A[0]);
matmul_bt128x128_bsz256_s128x128_block_b_g2s_block(&B[((block_k_tile * 8) * 768)], &smem_B[0]);
__syncthreads();
#pragma unroll
for (int32_t warp_k_tile = 0; (warp_k_tile < 8); warp_k_tile = (warp_k_tile + 1)) {
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[(((warp_id / 2) * 32) + (warp_k_tile * 128))], regs_A);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((warp_k_tile * 128) + ((warp_id % 2) * 64))], regs_B);
matmul_bt128x128_bsz256_s128x128_block_compute_warp(regs_A, regs_B, regs_C);
}
__syncthreads();
}
matmul_bt128x128_bsz256_s128x128_block_r2g_warp(regs_C, &C[((((warp_id / 2) * 32) * 768) + ((warp_id % 2) * 64))]);
}
__global__ void __launch_bounds__(256, 2) matmul_grid(float A[3833856], float B[1769472], float C[1277952]) {
// label: block_task-128x128-block_size-256
int32_t n_block_idx = (blockIdx.x / 6);
int32_t m_block_idx = (blockIdx.x % 6);
matmul_bt128x128_bsz256_s128x128_block(&A[((n_block_idx * 128) * 2304)], &B[(m_block_idx * 128)], &C[(((n_block_idx * 128) * 768) + (m_block_idx * 128))]);
}
__host__ void matmul(int32_t num_args, int32_t *arg_types, void* *args) {
assert(((void)"expect 3 args", (num_args == 3)));
assert(((void)"The 0 th arg should be TensorType(ScalarType(float32), [1664, 2304], global)", (arg_types[0] == 3)));
assert(((void)"The 1 th arg should be TensorType(ScalarType(float32), [2304, 768], global)", (arg_types[1] == 3)));
assert(((void)"The 2 th arg should be TensorType(ScalarType(float32), [1664, 768], global)", (arg_types[2] == 3)));
matmul_grid<<<78,256>>>((float*)args[0], (float*)args[1], (float*)args[2]);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment