Skip to content

Instantly share code, notes, and snippets.

@yaoyaoding
Created January 27, 2022 04:53
Show Gist options
  • Save yaoyaoding/6c4991e362d60d9562689905a4baebba to your computer and use it in GitHub Desktop.
Save yaoyaoding/6c4991e362d60d9562689905a4baebba to your computer and use it in GitHub Desktop.
#include <cassert>
#include <cstdio>
extern "C" {
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_c_init_warp(float out[64]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = 0.0;
out[1] = 0.0;
out[2] = 0.0;
out[3] = 0.0;
out[4] = 0.0;
out[5] = 0.0;
out[6] = 0.0;
out[7] = 0.0;
out[8] = 0.0;
out[9] = 0.0;
out[10] = 0.0;
out[11] = 0.0;
out[12] = 0.0;
out[13] = 0.0;
out[14] = 0.0;
out[15] = 0.0;
out[16] = 0.0;
out[17] = 0.0;
out[18] = 0.0;
out[19] = 0.0;
out[20] = 0.0;
out[21] = 0.0;
out[22] = 0.0;
out[23] = 0.0;
out[24] = 0.0;
out[25] = 0.0;
out[26] = 0.0;
out[27] = 0.0;
out[28] = 0.0;
out[29] = 0.0;
out[30] = 0.0;
out[31] = 0.0;
out[32] = 0.0;
out[33] = 0.0;
out[34] = 0.0;
out[35] = 0.0;
out[36] = 0.0;
out[37] = 0.0;
out[38] = 0.0;
out[39] = 0.0;
out[40] = 0.0;
out[41] = 0.0;
out[42] = 0.0;
out[43] = 0.0;
out[44] = 0.0;
out[45] = 0.0;
out[46] = 0.0;
out[47] = 0.0;
out[48] = 0.0;
out[49] = 0.0;
out[50] = 0.0;
out[51] = 0.0;
out[52] = 0.0;
out[53] = 0.0;
out[54] = 0.0;
out[55] = 0.0;
out[56] = 0.0;
out[57] = 0.0;
out[58] = 0.0;
out[59] = 0.0;
out[60] = 0.0;
out[61] = 0.0;
out[62] = 0.0;
out[63] = 0.0;
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(float in[3833856], float out[4]) {
out[0] = in[((((threadIdx.x / 8) * 4) * 2304) + (threadIdx.x % 8))];
out[1] = in[(((((threadIdx.x / 8) * 4) + 1) * 2304) + (threadIdx.x % 8))];
out[2] = in[(((((threadIdx.x / 8) * 4) + 2) * 2304) + (threadIdx.x % 8))];
out[3] = in[(((((threadIdx.x / 8) * 4) + 3) * 2304) + (threadIdx.x % 8))];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(float in[4], __shared__ float out[1024]) {
out[(((threadIdx.x / 8) * 4) + ((threadIdx.x % 8) * 128))] = in[0];
out[((((threadIdx.x / 8) * 4) + 1) + ((threadIdx.x % 8) * 128))] = in[1];
out[((((threadIdx.x / 8) * 4) + 2) + ((threadIdx.x % 8) * 128))] = in[2];
out[((((threadIdx.x / 8) * 4) + 3) + ((threadIdx.x % 8) * 128))] = in[3];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(float in[1769472], float out[4]) {
out[0] = in[(((threadIdx.x / 32) * 768) + (threadIdx.x % 32))];
out[1] = in[(((threadIdx.x / 32) * 768) + (32 + (threadIdx.x % 32)))];
out[2] = in[(((threadIdx.x / 32) * 768) + (64 + (threadIdx.x % 32)))];
out[3] = in[(((threadIdx.x / 32) * 768) + (96 + (threadIdx.x % 32)))];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(float in[4], __shared__ float out[1024]) {
out[(((threadIdx.x / 32) * 128) + (threadIdx.x % 32))] = in[0];
out[(((threadIdx.x / 32) * 128) + (32 + (threadIdx.x % 32)))] = in[1];
out[(((threadIdx.x / 32) * 128) + (64 + (threadIdx.x % 32)))] = in[2];
out[(((threadIdx.x / 32) * 128) + (96 + (threadIdx.x % 32)))] = in[3];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(__shared__ float in[1024], float out[8]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = in[(((lane_id / 16) * 8) + ((lane_id % 2) * 4))];
out[1] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1)];
out[2] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2)];
out[3] = in[((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3)];
out[4] = in[((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4))];
out[5] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1)];
out[6] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2)];
out[7] = in[(((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3)];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(__shared__ float in[1024], float out[8]) {
int32_t lane_id = (threadIdx.x % 32);
out[0] = in[(((lane_id % 16) / 2) * 4)];
out[1] = in[((((lane_id % 16) / 2) * 4) + 1)];
out[2] = in[((((lane_id % 16) / 2) * 4) + 2)];
out[3] = in[((((lane_id % 16) / 2) * 4) + 3)];
out[4] = in[(32 + (((lane_id % 16) / 2) * 4))];
out[5] = in[((32 + (((lane_id % 16) / 2) * 4)) + 1)];
out[6] = in[((32 + (((lane_id % 16) / 2) * 4)) + 2)];
out[7] = in[((32 + (((lane_id % 16) / 2) * 4)) + 3)];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_compute_warp(float A[8], float B[8], float C[64]) {
int32_t lane_id = (threadIdx.x % 32);
C[0] = (C[0] + (A[0] * B[0]));
C[1] = (C[1] + (A[0] * B[1]));
C[2] = (C[2] + (A[0] * B[2]));
C[3] = (C[3] + (A[0] * B[3]));
C[4] = (C[4] + (A[0] * B[4]));
C[5] = (C[5] + (A[0] * B[5]));
C[6] = (C[6] + (A[0] * B[6]));
C[7] = (C[7] + (A[0] * B[7]));
C[8] = (C[8] + (A[1] * B[0]));
C[9] = (C[9] + (A[1] * B[1]));
C[10] = (C[10] + (A[1] * B[2]));
C[11] = (C[11] + (A[1] * B[3]));
C[12] = (C[12] + (A[1] * B[4]));
C[13] = (C[13] + (A[1] * B[5]));
C[14] = (C[14] + (A[1] * B[6]));
C[15] = (C[15] + (A[1] * B[7]));
C[16] = (C[16] + (A[2] * B[0]));
C[17] = (C[17] + (A[2] * B[1]));
C[18] = (C[18] + (A[2] * B[2]));
C[19] = (C[19] + (A[2] * B[3]));
C[20] = (C[20] + (A[2] * B[4]));
C[21] = (C[21] + (A[2] * B[5]));
C[22] = (C[22] + (A[2] * B[6]));
C[23] = (C[23] + (A[2] * B[7]));
C[24] = (C[24] + (A[3] * B[0]));
C[25] = (C[25] + (A[3] * B[1]));
C[26] = (C[26] + (A[3] * B[2]));
C[27] = (C[27] + (A[3] * B[3]));
C[28] = (C[28] + (A[3] * B[4]));
C[29] = (C[29] + (A[3] * B[5]));
C[30] = (C[30] + (A[3] * B[6]));
C[31] = (C[31] + (A[3] * B[7]));
C[32] = (C[32] + (A[4] * B[0]));
C[33] = (C[33] + (A[4] * B[1]));
C[34] = (C[34] + (A[4] * B[2]));
C[35] = (C[35] + (A[4] * B[3]));
C[36] = (C[36] + (A[4] * B[4]));
C[37] = (C[37] + (A[4] * B[5]));
C[38] = (C[38] + (A[4] * B[6]));
C[39] = (C[39] + (A[4] * B[7]));
C[40] = (C[40] + (A[5] * B[0]));
C[41] = (C[41] + (A[5] * B[1]));
C[42] = (C[42] + (A[5] * B[2]));
C[43] = (C[43] + (A[5] * B[3]));
C[44] = (C[44] + (A[5] * B[4]));
C[45] = (C[45] + (A[5] * B[5]));
C[46] = (C[46] + (A[5] * B[6]));
C[47] = (C[47] + (A[5] * B[7]));
C[48] = (C[48] + (A[6] * B[0]));
C[49] = (C[49] + (A[6] * B[1]));
C[50] = (C[50] + (A[6] * B[2]));
C[51] = (C[51] + (A[6] * B[3]));
C[52] = (C[52] + (A[6] * B[4]));
C[53] = (C[53] + (A[6] * B[5]));
C[54] = (C[54] + (A[6] * B[6]));
C[55] = (C[55] + (A[6] * B[7]));
C[56] = (C[56] + (A[7] * B[0]));
C[57] = (C[57] + (A[7] * B[1]));
C[58] = (C[58] + (A[7] * B[2]));
C[59] = (C[59] + (A[7] * B[3]));
C[60] = (C[60] + (A[7] * B[4]));
C[61] = (C[61] + (A[7] * B[5]));
C[62] = (C[62] + (A[7] * B[6]));
C[63] = (C[63] + (A[7] * B[7]));
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_r2g_warp(float in[64], float out[1277952]) {
int32_t lane_id = (threadIdx.x % 32);
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + (((lane_id % 16) / 2) * 4))] = in[0];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[1];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[2];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[3];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[4];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[5];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[6];
out[(((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[7];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + (((lane_id % 16) / 2) * 4))] = in[8];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[9];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[10];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[11];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[12];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[13];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[14];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[15];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + (((lane_id % 16) / 2) * 4))] = in[16];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[17];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[18];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[19];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[20];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[21];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[22];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[23];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + (((lane_id % 16) / 2) * 4))] = in[24];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[25];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[26];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[27];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[28];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[29];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[30];
out[((((((lane_id / 16) * 8) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[31];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + (((lane_id % 16) / 2) * 4))] = in[32];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[33];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[34];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[35];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[36];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[37];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[38];
out[((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[39];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + (((lane_id % 16) / 2) * 4))] = in[40];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[41];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[42];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[43];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[44];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[45];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[46];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 1) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[47];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + (((lane_id % 16) / 2) * 4))] = in[48];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[49];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[50];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[51];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[52];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[53];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[54];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 2) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[55];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + (((lane_id % 16) / 2) * 4))] = in[56];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 1))] = in[57];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 2))] = in[58];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((((lane_id % 16) / 2) * 4) + 3))] = in[59];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + (32 + (((lane_id % 16) / 2) * 4)))] = in[60];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 1))] = in[61];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 2))] = in[62];
out[(((((16 + ((lane_id / 16) * 8)) + ((lane_id % 2) * 4)) + 3) * 768) + ((32 + (((lane_id % 16) / 2) * 4)) + 3))] = in[63];
}
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block(float A[3833856], float B[1769472], float C[1277952]) {
// label: matmul128x128x8
__shared__ float smem_A[2048];
__shared__ float smem_B[2048];
float regs_A[16];
float regs_B[16];
float regs_C[64];
float regs_A_ldg[4];
float regs_B_ldg[4];
matmul_bt128x128_bsz256_s128x128_block_c_init_warp(regs_C);
int32_t warp_id = (threadIdx.x / 32);
matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(&A[0], regs_A_ldg);
matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(regs_A_ldg, smem_A);
matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(&B[0], regs_B_ldg);
matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(regs_B_ldg, smem_B);
__syncthreads();
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((warp_id / 2) * 32)], &regs_A[0]);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((warp_id % 2) * 64)], &regs_B[0]);
__syncthreads();
for (int32_t block_k_tile = 0; (block_k_tile < 287); block_k_tile = (block_k_tile + 1)) {
#pragma unroll
for (int32_t warp_k_tile = 0; (warp_k_tile < 8); warp_k_tile = (warp_k_tile + 1)) {
if (warp_k_tile == 0) {
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile % 2) * 1024) + ((warp_id / 2) * 32)) + ((warp_k_tile + 1) * 128))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile % 2) * 1024) + ((warp_k_tile + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(&A[((block_k_tile + 1) * 8)], regs_A_ldg);
matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(&B[(((block_k_tile + 1) * 8) * 768)], regs_B_ldg);
matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
} else {
if (warp_k_tile < 7) {
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile % 2) * 1024) + ((warp_id / 2) * 32)) + ((warp_k_tile + 1) * 128))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile % 2) * 1024) + ((warp_k_tile + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
} else {
matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(regs_A_ldg, &smem_A[(((block_k_tile + 1) % 2) * 1024)]);
matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(regs_B_ldg, &smem_B[(((block_k_tile + 1) % 2) * 1024)]);
__syncthreads();
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile + 1) % 2) * 1024) + ((warp_id / 2) * 32))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile + 1) % 2) * 1024) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
}
}
}
}
#pragma unroll
for (int32_t warp_k_tile_1 = 0; (warp_k_tile_1 < 8); warp_k_tile_1 = (warp_k_tile_1 + 1)) {
if (warp_k_tile_1 < 7) {
matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((1024 + ((warp_id / 2) * 32)) + ((warp_k_tile_1 + 1) * 128))], &regs_A[(((warp_k_tile_1 + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((1024 + ((warp_k_tile_1 + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile_1 + 1) % 2) * 8)]);
matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile_1 % 2) * 8)], &regs_B[((warp_k_tile_1 % 2) * 8)], regs_C);
} else {
matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile_1 % 2) * 8)], &regs_B[((warp_k_tile_1 % 2) * 8)], regs_C);
}
}
__syncthreads();
matmul_bt128x128_bsz256_s128x128_block_r2g_warp(regs_C, &C[((((warp_id / 2) * 32) * 768) + ((warp_id % 2) * 64))]);
}
__global__ void __launch_bounds__(256, 2) matmul_grid(float A[3833856], float B[1769472], float C[1277952]) {
// label: block_task-128x128-block_size-256
int32_t n_block_idx = (blockIdx.x / 6);
int32_t m_block_idx = (blockIdx.x % 6);
matmul_bt128x128_bsz256_s128x128_block(&A[((n_block_idx * 128) * 2304)], &B[(m_block_idx * 128)], &C[(((n_block_idx * 128) * 768) + (m_block_idx * 128))]);
}
__host__ void matmul(int32_t num_args, int32_t *arg_types, void* *args) {
assert(((void)"expect 3 args", (num_args == 3)));
assert(((void)"The 0 th arg should be TensorType(ScalarType(float32), [1664, 2304], global)", (arg_types[0] == 3)));
assert(((void)"The 1 th arg should be TensorType(ScalarType(float32), [2304, 768], global)", (arg_types[1] == 3)));
assert(((void)"The 2 th arg should be TensorType(ScalarType(float32), [1664, 768], global)", (arg_types[2] == 3)));
matmul_grid<<<78,256>>>((float*)args[0], (float*)args[1], (float*)args[2]);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment