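# Scheduled TensorIR (TVMScript) dump of a tensor-core GEMM with 4-bit weight
# dequantization, in the style produced by TVM dlight/BitBLAS scheduling
# (inferred from the buffer shapes and the "dequantize_info" attribute below):
# D[m, 768] = A[m, 768] (float16) x dequantize(B)[768, 768]^T, where B packs
# two uint4 weights per int8 element (hence its (768, 384) shape) and each
# group of 256 weights along K shares one Scale/Zeros pair (hence (768, 3)).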
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, B: T.Buffer((768, 384), "int8"), Scale: T.Buffer((768, 3), "float16"), Zeros: T.Buffer((768, 3), "float16"), var_D: T.handle):
        T.func_attr({"dequantize_info": {"B_decode": {"decode_block": "B_decode", "fast_decoding": T.bool(False), "group_size": 256, "source_format": {"bits": 4, "format": "uint"}, "storage_dtype": "int8", "target_format": "float16", "with_scaling": T.bool(True), "with_zeros": T.bool(True), "zeros_mode": "rescale"}}, "dlight.tensorcore_prenormlized": T.bool(True), "opt_shapes": {"m": [2, 12]}, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        m = T.int32()
        A = T.match_buffer(var_A, (m, 768), "float16")
        D = T.match_buffer(var_D, (m, 768), "float16")
        # with T.block("root"):
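        # Staging buffers: A and decoded B tiles move through dynamic shared
        # memory into per-warp MMA fragments (32 lanes x 8 halves = one 16x16
        # fp16 tile); the row dimension of A is padded up to a multiple of 128.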
        A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 768), "float16", scope="shared.dyn")
        B_decode_reindex_shared_dyn = T.alloc_buffer((1, 768, 768), "float16", scope="shared.dyn")
        B_decode_reindex_local = T.alloc_buffer((1, 768, 768), "float16", scope="local")
        B_local = T.alloc_buffer((768, 384), "int8", scope="local")
        A_reindex_pad_shared_dyn_warp = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 32, 8), "float16", scope="warp")
        B_decode_reindex_shared_dyn_warp = T.alloc_buffer((1, 48, 48, 32, 8), "float16", scope="warp")
        C_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 16, 16), "float16", scope="shared.dyn")
        C_reindex_pad_shared_dyn_warp = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 32, 8), "float16", scope="warp")
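        # Grid mapping: blockIdx.y tiles 128 padded rows of A, blockIdx.x tiles
        # 128 of the 768 output columns (6 blocks), and threadIdx.y/z split each
        # 128x128 output tile among 2x2 warps, each owning a 64x64 sub-tile.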
        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
            for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, thread="blockIdx.y"):
                for ax1_0_1_ax2_0_1_fused in T.thread_binding(6, thread="blockIdx.x"):
                    for ax1_0_2 in T.thread_binding(2, thread="threadIdx.y"):
                        for ax2_0_2 in T.thread_binding(2, thread="threadIdx.z"):
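                            # Zero-initialize each warp's 4x4 grid of 16x16 fp16
                            # accumulator fragments via mma_fill.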
                            for ax1_0_3_init, ax2_0_3_init in T.grid(4, 4):
                                with T.block("C_o_init"):
                                    v0_o = T.axis.spatial(1, ax0)
                                    v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0_3_init)
                                    v2_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax2_0_3_init)
                                    T.reads()
                                    T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                    with T.block("C_init_o"):
                                        v1_i_init_o = T.axis.spatial(1, 0)
                                        v2_i_init_o = T.axis.spatial(1, 0)
                                        T.reads()
                                        T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                        C_warp = T.match_buffer(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=1)
                                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                                            T.mma_fill("float16", 8, C_warp.data, C_warp.elem_offset)
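                            # Main reduction loop over K = 768 in 24 steps of 32.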
                            for ax3_0_0 in range(24):
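                                # Cooperative copy of a 128x32 A tile from global to
                                # shared memory, 4 fp16 (8 bytes) per thread per step,
                                # zero-padding rows past m; "permuted_layout" swizzles
                                # the shared layout to avoid bank conflicts.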
                                for ax0_ax1_ax2_fused_0 in T.unroll(8):
                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(2, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_2 in T.thread_binding(2, thread="threadIdx.z", annotations={"pragma_unroll_explicit": 0}):
                                            for ax0_ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
                                                for ax0_ax1_ax2_fused_4 in T.vectorized(4):
                                                    with T.block("A_reindex_pad_shared.dyn"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_ax2_fused_0 * 512 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 128 + ax0_ax1_ax2_fused_3 * 4 + ax0_ax1_ax2_fused_4) // 32)
                                                        v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax0_ax1_ax2_fused_0 * 512 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 128 + ax0_ax1_ax2_fused_3 * 4 + ax0_ax1_ax2_fused_4) % 32)
                                                        T.reads(A[v1, v2])
                                                        T.writes(A_reindex_pad_shared_dyn[v0, v1, v2])
                                                        T.block_attr({"permuted_layout": 1})
                                                        A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v1, v2], T.float16(0))
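                                # B pipeline: load packed int8 weights into registers,
                                # unpack and dequantize to fp16 in registers, then stage
                                # the decoded 128x32 tile into shared memory.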
                                for ax0_1, ax1_ax2_0_fused_0 in T.grid(1, 4):
                                    for ax1_ax2_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"):
                                        for ax1_ax2_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"):
                                            for ax1_ax2_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
                                                for ax2_1 in range(1):
                                                    for ax0_2 in range(1):
                                                        for ax1 in T.vectorized(4):
                                                            with T.block("B_local"):
                                                                v0 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4 + ax0_2)
                                                                v1 = T.axis.spatial(384, ax3_0_0 * 16 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 4 + ax1)
                                                                T.reads(B[v0, v1])
                                                                T.writes(B_local[v0, v1])
                                                                B_local[v0, v1] = B[v0, v1]
                                                    for ax0_2, ax1, ax2 in T.grid(1, 1, 8):
                                                        with T.block("B_decode_reindex_local"):
                                                            v0 = T.axis.spatial(1, ax0_2)
                                                            v1 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4 + ax1)
                                                            v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 8 + ax2)
                                                            T.reads(B_local[v1, v2 // 2], Scale[v1, v2 // 256], Zeros[v1, v2 // 256])
                                                            T.writes(B_decode_reindex_local[v0, v1, v2])
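                                                            # Extract the 4-bit nibble (low or high
                                                            # half of the int8, selected by v2 % 2),
                                                            # then apply the per-group "rescale"
                                                            # dequantization: q * scale - zeros.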
                                                            B_decode_reindex_local[v0, v1, v2] = T.Cast("float16", T.bitwise_and(T.shift_right(B_local[v1, v2 // 2], T.Cast("int8", v2 % 2 * 4)), T.int8(15))) * Scale[v1, v2 // 256] - Zeros[v1, v2 // 256]
                                                    for ax2_2 in T.vectorized(8):
                                                        with T.block("B_decode_reindex_shared.dyn"):
                                                            v0 = T.axis.spatial(1, ax0_1)
                                                            v1 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4)
                                                            v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 8 + ax2_1 * 8 + ax2_2)
                                                            T.reads(B_decode_reindex_local[v0, v1, v2])
                                                            T.writes(B_decode_reindex_shared_dyn[v0, v1, v2])
                                                            T.block_attr({"permuted_layout": 1})
                                                            B_decode_reindex_shared_dyn[v0, v1, v2] = B_decode_reindex_local[v0, v1, v2]
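                                # Inner k loop (2 x 16 = the 32-wide shared tile): load A
                                # and B fragments from shared memory with ldmatrix, then
                                # issue the tensor-core updates.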
                                for ax3_0_1 in range(2):
                                    for ax0_1, ax1_0, ax2_0 in T.grid(1, 4, 1):
                                        with T.block("A_reindex_pad_shared.dyn_warp_o"):
                                            v0_o = T.axis.spatial(1, ax0_1)
                                            v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0)
                                            v2_o = T.axis.spatial(48, ax3_0_0 * 2 + ax3_0_1 + ax2_0)
                                            T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
                                            T.writes(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8])
                                            T.block_attr({"permuted_layout": 1})
                                            warp = T.match_buffer(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=16)
                                            shared = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16)
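                                            # ldmatrix.x4 loads four 8x8 .b16 matrices per
                                            # warp (one 16x16 fp16 tile) into the fragment.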
                                            for tx in T.thread_binding(32, thread="threadIdx.x"):
                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 8 * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), shared.strides[0] * (tx % 16) + 8 * (tx // 16))
                                    for ax0_1, ax1_0, ax2_0 in T.grid(1, 4, 1):
                                        with T.block("B_decode_reindex_shared.dyn_warp_o"):
                                            v0_o = T.axis.spatial(1, ax0_1)
                                            v1_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax1_0)
                                            v2_o = T.axis.spatial(48, ax3_0_0 * 2 + ax3_0_1 + ax2_0)
                                            T.reads(B_decode_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
                                            T.writes(B_decode_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8])
                                            T.block_attr({"permuted_layout": 1})
                                            warp = T.match_buffer(B_decode_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=16)
                                            shared = T.match_buffer(B_decode_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16)
                                            for tx in T.thread_binding(32, thread="threadIdx.x"):
                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 8 * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), shared.strides[0] * 8 * (tx // 16) + shared.strides[0] * (tx % 8) + 8 * (tx % 16 // 8))
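                                    # Accumulate: each warp issues a 4x4 grid of 16x16
                                    # tensor-core MMAs, C += A * B^T over this k slice.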
                                    for ax1_0_3, ax2_0_3 in T.grid(4, 4):
                                        with T.block("C_o_update"):
                                            v0_o = T.axis.spatial(1, ax0)
                                            v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0_3)
                                            v2_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax2_0_3)
                                            v3_o = T.axis.reduce(48, ax3_0_0 * 2 + ax3_0_1)
                                            T.reads(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_pad_shared_dyn_warp[0, v1_o, v3_o, 0:32, 0:8], B_decode_reindex_shared_dyn_warp[0, v2_o, v3_o, 0:32, 0:8])
                                            T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                            with T.block("C_o"):
                                                v1_i_o = T.axis.spatial(1, 0)
                                                v2_i_o = T.axis.spatial(1, 0)
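                                                # (gist truncated here; the remainder, presumably
                                                # the tensor-core mma body and the epilogue that
                                                # writes C back to D, is not shown)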