Created
January 27, 2022 04:53
-
-
Save yaoyaoding/6c4991e362d60d9562689905a4baebba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert> | |
#include <cstdio> | |
extern "C" { | |
// Zero this thread's 64 accumulators before the k-loop starts.
// Does not depend on threadIdx, so it is callable from host code as well.
__host__ __device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_c_init_warp(float out[64]) {
#pragma unroll
  for (int32_t i = 0; i < 64; ++i) {
    out[i] = 0.0f;  // float literal: avoid an implicit double
  }
}
// Global -> register load of one 8-column k-tile of A (row stride 2304).
// Thread t reads column t % 8 of the four consecutive rows starting at 4 * (t / 8).
// With the 256-thread launch used by matmul_grid this covers 128 rows x 8 columns.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(float in[3833856], float out[4]) {
  const unsigned first_row = (threadIdx.x / 8) * 4;
  const unsigned column = threadIdx.x % 8;
  const float *src = in + column;
#pragma unroll
  for (int r = 0; r < 4; ++r) {
    out[r] = src[(first_row + r) * 2304];
  }
}
// Register -> shared store of the A k-tile loaded by a_g2r_block, transposed so the
// destination is k-major: element (row m, column k) goes to out[k * 128 + m].
// `out` is a plain pointer: memory-space specifiers such as __shared__ are not allowed
// on formal parameters. Callers pass a pointer into a __shared__ array.
// NOTE(review): the 8 threads of a warp that share threadIdx.x / 8 write addresses that
// are 128 floats apart, i.e. the same shared-memory bank (an 8-way conflict per store).
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(float in[4], float out[1024]) {
  const unsigned first_row = (threadIdx.x / 8) * 4;  // first of this thread's 4 rows of A
  const unsigned k = threadIdx.x % 8;                // this thread's column of A
#pragma unroll
  for (int r = 0; r < 4; ++r) {
    out[k * 128 + first_row + r] = in[r];
  }
}
// Global -> register load of one 8-row k-tile of B (row stride 768).
// Thread t handles row t / 32 and reads columns t % 32 + {0, 32, 64, 96}, so the 32
// lanes of a warp touch 32 consecutive floats on every access.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(float in[1769472], float out[4]) {
  const float *row_ptr = in + (threadIdx.x / 32) * 768 + threadIdx.x % 32;
#pragma unroll
  for (int seg = 0; seg < 4; ++seg) {
    out[seg] = row_ptr[seg * 32];
  }
}
// Register -> shared store of the B k-tile loaded by b_g2r_block, keeping B's row-major
// layout with a row length of 128: element (row t / 32, column c) goes to out[row * 128 + c].
// `out` is a plain pointer: memory-space specifiers such as __shared__ are not allowed
// on formal parameters. Callers pass a pointer into a __shared__ array.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(float in[4], float out[1024]) {
  float *dst = out + (threadIdx.x / 32) * 128 + threadIdx.x % 32;
#pragma unroll
  for (int seg = 0; seg < 4; ++seg) {
    dst[seg * 32] = in[seg];
  }
}
// Shared -> register load of this lane's 8-value A fragment for one k step.
// `in` points at the warp's 32-row band within one k row of the k-major A tile.
// The lane reads 4 consecutive rows at `base` and 4 more at `base + 16`, where
// base = (lane / 16) * 8 + (lane % 2) * 4.
// `in` is a plain pointer: memory-space specifiers such as __shared__ are not allowed
// on formal parameters. Callers pass a pointer into a __shared__ array.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(float in[1024], float out[8]) {
  const int32_t lane_id = (threadIdx.x % 32);
  const int32_t base = (lane_id / 16) * 8 + (lane_id % 2) * 4;
#pragma unroll
  for (int32_t half = 0; half < 2; ++half) {
#pragma unroll
    for (int32_t j = 0; j < 4; ++j) {
      out[half * 4 + j] = in[base + half * 16 + j];
    }
  }
}
// Shared -> register load of this lane's 8-value B fragment for one k step.
// `in` points at the warp's 64-column band within one k row of the B tile.
// The lane reads 4 consecutive columns at `base` and 4 more at `base + 32`, where
// base = ((lane % 16) / 2) * 4.
// `in` is a plain pointer: memory-space specifiers such as __shared__ are not allowed
// on formal parameters. Callers pass a pointer into a __shared__ array.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(float in[1024], float out[8]) {
  const int32_t lane_id = (threadIdx.x % 32);
  const int32_t base = ((lane_id % 16) / 2) * 4;
#pragma unroll
  for (int32_t half = 0; half < 2; ++half) {
#pragma unroll
    for (int32_t j = 0; j < 4; ++j) {
      out[half * 4 + j] = in[base + half * 32 + j];
    }
  }
}
// Rank-1 update of this thread's 8x8 accumulator block: C[i * 8 + j] += A[i] * B[j].
// A and B are the 8-value fragments for one k step, and C holds 64 accumulators
// indexed (i, j) row-major. Does not depend on threadIdx, so it is host-callable too.
__host__ __device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_compute_warp(float A[8], float B[8], float C[64]) {
#pragma unroll
  for (int32_t i = 0; i < 8; ++i) {
#pragma unroll
    for (int32_t j = 0; j < 8; ++j) {
      // Same expression shape as the hand-unrolled original, so FMA contraction
      // (nvcc --fmad) applies identically.
      C[i * 8 + j] = (C[i * 8 + j] + (A[i] * B[j]));
    }
  }
}
// Register -> global store of this lane's 64 accumulators (layout from compute_warp)
// into the warp's sub-tile of C, which has a row stride of 768.
// Accumulator i = half * 32 + r * 8 + cg * 4 + c is written to
//   row    = row0 + half * 16 + r,  row0 = (lane / 16) * 8 + (lane % 2) * 4
//   column = col0 + cg * 32 + c,    col0 = ((lane % 16) / 2) * 4
// matching the rows/columns read by a_s2r_warp / b_s2r_warp.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_r2g_warp(float in[64], float out[1277952]) {
  const int32_t lane = (threadIdx.x % 32);
  const int32_t row0 = (lane / 16) * 8 + (lane % 2) * 4;
  const int32_t col0 = ((lane % 16) / 2) * 4;
#pragma unroll
  for (int32_t half = 0; half < 2; ++half) {
#pragma unroll
    for (int32_t r = 0; r < 4; ++r) {
      float *dst_row = out + (row0 + half * 16 + r) * 768;
#pragma unroll
      for (int32_t cg = 0; cg < 2; ++cg) {
#pragma unroll
        for (int32_t c = 0; c < 4; ++c) {
          dst_row[col0 + cg * 32 + c] = in[half * 32 + r * 8 + cg * 4 + c];
        }
      }
    }
  }
}
// Computes one 128x128 output tile of C = A * B with a 256-thread block.
//   A: start of this block's 128-row slab of A (row stride 2304).
//   B: start of this block's 128-column slab of B (row stride 768).
//   C: top-left element of the 128x128 output tile (row stride 768).
// K = 2304 is consumed in 288 k-tiles of 8. k-tile 0 is staged in the prologue,
// iterations 0..286 of the main loop each compute k-tile block_k_tile while prefetching
// k-tile block_k_tile + 1, and the epilogue computes k-tile 287.
// smem_A / smem_B are double-buffered (buffer b = elements [b * 1024, b * 1024 + 1024)).
// regs_A / regs_B hold two 8-value fragments (slot s = elements [s * 8, s * 8 + 8)):
// slot warp_k_tile % 2 is consumed while slot (warp_k_tile + 1) % 2 is being filled.
// Contains __syncthreads(), so every thread of the block must call it.
__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block(float A[3833856], float B[1769472], float C[1277952]) {
  // label: matmul128x128x8
  __shared__ float smem_A[2048];
  __shared__ float smem_B[2048];
  float regs_A[16];
  float regs_B[16];
  float regs_C[64];
  // Staging registers for the global loads of the next k-tile.
  float regs_A_ldg[4];
  float regs_B_ldg[4];
  matmul_bt128x128_bsz256_s128x128_block_c_init_warp(regs_C);
  // warp_id / 2 selects this warp's 32-row band of the tile; warp_id % 2 selects its 64-column band.
  int32_t warp_id = (threadIdx.x / 32);
  // Prologue: stage k-tile 0 into shared buffer 0 ...
  matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(&A[0], regs_A_ldg);
  matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(regs_A_ldg, smem_A);
  matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(&B[0], regs_B_ldg);
  matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(regs_B_ldg, smem_B);
  __syncthreads();
  // ... and load its first k step into fragment slot 0.
  matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((warp_id / 2) * 32)], &regs_A[0]);
  matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((warp_id % 2) * 64)], &regs_B[0]);
  __syncthreads();
  for (int32_t block_k_tile = 0; (block_k_tile < 287); block_k_tile = (block_k_tile + 1)) {
#pragma unroll
    for (int32_t warp_k_tile = 0; (warp_k_tile < 8); warp_k_tile = (warp_k_tile + 1)) {
      if (warp_k_tile == 0) {
        // Load step 1's fragments. Issue the global loads for the next k-tile now so
        // their latency overlaps the 8 compute steps. Then compute step 0.
        matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile % 2) * 1024) + ((warp_id / 2) * 32)) + ((warp_k_tile + 1) * 128))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
        matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile % 2) * 1024) + ((warp_k_tile + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
        matmul_bt128x128_bsz256_s128x128_block_a_g2r_block(&A[((block_k_tile + 1) * 8)], regs_A_ldg);
        matmul_bt128x128_bsz256_s128x128_block_b_g2r_block(&B[(((block_k_tile + 1) * 8) * 768)], regs_B_ldg);
        matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
      } else {
        if (warp_k_tile < 7) {
          // Steady state: load step warp_k_tile + 1 while computing step warp_k_tile.
          matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile % 2) * 1024) + ((warp_id / 2) * 32)) + ((warp_k_tile + 1) * 128))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
          matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile % 2) * 1024) + ((warp_k_tile + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
          matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
        } else {
          // Last step: commit the prefetched k-tile to the other shared buffer. That buffer's
          // last reads happened before the previous iteration's barrier, so it is free.
          // This branch depends only on the loop index, so every thread reaches the barrier.
          matmul_bt128x128_bsz256_s128x128_block_a_r2s_block(regs_A_ldg, &smem_A[(((block_k_tile + 1) % 2) * 1024)]);
          matmul_bt128x128_bsz256_s128x128_block_b_r2s_block(regs_B_ldg, &smem_B[(((block_k_tile + 1) % 2) * 1024)]);
          __syncthreads();
          // Step 0 of the next k-tile goes into slot (7 + 1) % 2 == 0.
          matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((((block_k_tile + 1) % 2) * 1024) + ((warp_id / 2) * 32))], &regs_A[(((warp_k_tile + 1) % 2) * 8)]);
          matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((((block_k_tile + 1) % 2) * 1024) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile + 1) % 2) * 8)]);
          matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile % 2) * 8)], &regs_B[((warp_k_tile % 2) * 8)], regs_C);
        }
      }
    }
  }
  // Epilogue: k-tile 287 lives in buffer 1 (287 % 2), and its step 0 is already in slot 0.
#pragma unroll
  for (int32_t warp_k_tile_1 = 0; (warp_k_tile_1 < 8); warp_k_tile_1 = (warp_k_tile_1 + 1)) {
    if (warp_k_tile_1 < 7) {
      matmul_bt128x128_bsz256_s128x128_block_a_s2r_warp(&smem_A[((1024 + ((warp_id / 2) * 32)) + ((warp_k_tile_1 + 1) * 128))], &regs_A[(((warp_k_tile_1 + 1) % 2) * 8)]);
      matmul_bt128x128_bsz256_s128x128_block_b_s2r_warp(&smem_B[((1024 + ((warp_k_tile_1 + 1) * 128)) + ((warp_id % 2) * 64))], &regs_B[(((warp_k_tile_1 + 1) % 2) * 8)]);
      matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile_1 % 2) * 8)], &regs_B[((warp_k_tile_1 % 2) * 8)], regs_C);
    } else {
      matmul_bt128x128_bsz256_s128x128_block_compute_warp(&regs_A[((warp_k_tile_1 % 2) * 8)], &regs_B[((warp_k_tile_1 % 2) * 8)], regs_C);
    }
  }
  __syncthreads();
  // Each warp writes its 32x64 sub-tile of C.
  matmul_bt128x128_bsz256_s128x128_block_r2g_warp(regs_C, &C[((((warp_id / 2) * 32) * 768) + ((warp_id % 2) * 64))]);
}
// Kernel: each 256-thread block computes one 128x128 tile of C.
// blockIdx.x is split into a row tile (blockIdx.x / 6) and a column tile (blockIdx.x % 6).
// The host launches 78 blocks, i.e. 13 row tiles x 6 column tiles.
__global__ void __launch_bounds__(256, 2) matmul_grid(float A[3833856], float B[1769472], float C[1277952]) {
  // label: block_task-128x128-block_size-256
  const int32_t tile_row = (blockIdx.x / 6);
  const int32_t tile_col = (blockIdx.x % 6);
  const int32_t row_start = tile_row * 128;
  const int32_t col_start = tile_col * 128;
  float *a_slab = A + row_start * 2304;             // 128 rows of A
  float *b_slab = B + col_start;                    // 128 columns of B
  float *c_tile = C + row_start * 768 + col_start;  // 128x128 tile of C
  matmul_bt128x128_bsz256_s128x128_block(a_slab, b_slab, c_tile);
}
// Host entry point. It validates the argument descriptors and launches matmul_grid to
// compute C[1664, 768] = A[1664, 2304] * B[2304, 768] on float32 tensors.
// arg_types[i] == 3 is the descriptor the assert messages associate with a global float32
// tensor. args[i] must be device pointers. The launch is asynchronous (no sync here).
__host__ void matmul(int32_t num_args, int32_t *arg_types, void* *args) {
  assert(((void)"expect 3 args", (num_args == 3)));
  assert(((void)"The 0 th arg should be TensorType(ScalarType(float32), [1664, 2304], global)", (arg_types[0] == 3)));
  assert(((void)"The 1 th arg should be TensorType(ScalarType(float32), [2304, 768], global)", (arg_types[1] == 3)));
  assert(((void)"The 2 th arg should be TensorType(ScalarType(float32), [1664, 768], global)", (arg_types[2] == 3)));
  // 78 blocks = (1664 / 128) row tiles x (768 / 128) column tiles, 256 threads each.
  matmul_grid<<<78,256>>>((float*)args[0], (float*)args[1], (float*)args[2]);
  // Kernel launches return no status. Surface launch-configuration errors (or an earlier
  // sticky error) here instead of letting the caller proceed silently.
  cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    fprintf(stderr, "matmul: matmul_grid launch failed: %s\n", cudaGetErrorString(launch_err));
  }
  assert(((void)"matmul_grid launch failed", (launch_err == cudaSuccess)));
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment