# Gist by @w32zhong (last active July 23, 2024).
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, B: T.Buffer((768, 384), "int8"), Scale: T.Buffer((768, 3), "float16"), Zeros: T.Buffer((768, 3), "float16"), var_D: T.handle):
        T.func_attr({"dequantize_info": {"B_decode": {"decode_block": "B_decode", "fast_decoding": T.bool(False), "group_size": 256, "source_format": {"bits": 4, "format": "uint"}, "storage_dtype": "int8", "target_format": "float16", "with_scaling": T.bool(True), "with_zeros": T.bool(True), "zeros_mode": "rescale"}}, "dlight.tensorcore_prenormlized": T.bool(True), "opt_shapes": {"m": [2, 12]}, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
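        # Weight layout (from "dequantize_info"): B stores unsigned 4-bit weights packed
        # two per int8 element; every group of 256 values along K shares one float16 Scale
        # and one float16 Zeros entry, applied in "rescale" mode as w = float16(q) * Scale - Zeros.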
        m = T.int32()
        A = T.match_buffer(var_A, (m, 768), "float16")
        D = T.match_buffer(var_D, (m, 768), "float16")
        # with T.block("root"):
        A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 768), "float16", scope="shared.dyn")
        B_decode_reindex_shared_dyn = T.alloc_buffer((1, 768, 768), "float16", scope="shared.dyn")
        B_decode_reindex_local = T.alloc_buffer((1, 768, 768), "float16", scope="local")
        B_local = T.alloc_buffer((768, 384), "int8", scope="local")
        A_reindex_pad_shared_dyn_warp = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 32, 8), "float16", scope="warp")
        B_decode_reindex_shared_dyn_warp = T.alloc_buffer((1, 48, 48, 32, 8), "float16", scope="warp")
        C_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 16, 16), "float16", scope="shared.dyn")
        C_reindex_pad_shared_dyn_warp = T.alloc_buffer((1, (m + 127) // 128 * 8, 48, 32, 8), "float16", scope="warp")
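        # Memory hierarchy: A (padded to a multiple of 128 rows) and the decoded B tiles are
        # staged in dynamic shared memory, then loaded into warp-scope MMA fragments; the raw
        # int8 B tile and its decoded values pass through registers ("local") on the way to
        # shared memory, and the output accumulates in warp-scope C fragments.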
        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
            for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, thread="blockIdx.y"):
                for ax1_0_1_ax2_0_1_fused in T.thread_binding(6, thread="blockIdx.x"):
                    for ax1_0_2 in T.thread_binding(2, thread="threadIdx.y"):
                        for ax2_0_2 in T.thread_binding(2, thread="threadIdx.z"):
                            for ax1_0_3_init, ax2_0_3_init in T.grid(4, 4):
                                with T.block("C_o_init"):
                                    v0_o = T.axis.spatial(1, ax0)
                                    v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0_3_init)
                                    v2_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax2_0_3_init)
                                    T.reads()
                                    T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                    with T.block("C_init_o"):
                                        v1_i_init_o = T.axis.spatial(1, 0)
                                        v2_i_init_o = T.axis.spatial(1, 0)
                                        T.reads()
                                        T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                        C_warp = T.match_buffer(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=1)
                                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                                            T.mma_fill("float16", 8, C_warp.data, C_warp.elem_offset)
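                            # Main reduction loop: K = 768 is covered in 24 steps of a 32-wide K tile.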
                            for ax3_0_0 in range(24):
                                for ax0_ax1_ax2_fused_0 in T.unroll(8):
                                    for ax0_ax1_ax2_fused_1 in T.thread_binding(2, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_2 in T.thread_binding(2, thread="threadIdx.z", annotations={"pragma_unroll_explicit": 0}):
                                            for ax0_ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
                                                for ax0_ax1_ax2_fused_4 in T.vectorized(4):
                                                    with T.block("A_reindex_pad_shared.dyn"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_ax2_fused_0 * 512 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 128 + ax0_ax1_ax2_fused_3 * 4 + ax0_ax1_ax2_fused_4) // 32)
                                                        v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax0_ax1_ax2_fused_0 * 512 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 128 + ax0_ax1_ax2_fused_3 * 4 + ax0_ax1_ax2_fused_4) % 32)
                                                        T.reads(A[v1, v2])
                                                        T.writes(A_reindex_pad_shared_dyn[v0, v1, v2])
                                                        T.block_attr({"permuted_layout": 1})
                                                        A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v1, v2], T.float16(0))
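                                # Stage the quantized B tile: load the packed int8 values into
                                # registers, decode them to float16 in registers, then write the
                                # decoded tile to shared memory (with the "permuted_layout" attribute).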
                                for ax0_1, ax1_ax2_0_fused_0 in T.grid(1, 4):
                                    for ax1_ax2_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"):
                                        for ax1_ax2_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"):
                                            for ax1_ax2_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
                                                for ax2_1 in range(1):
                                                    for ax0_2 in range(1):
                                                        for ax1 in T.vectorized(4):
                                                            with T.block("B_local"):
                                                                v0 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4 + ax0_2)
                                                                v1 = T.axis.spatial(384, ax3_0_0 * 16 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 4 + ax1)
                                                                T.reads(B[v0, v1])
                                                                T.writes(B_local[v0, v1])
                                                                B_local[v0, v1] = B[v0, v1]
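                                                    # Decode each packed int8 into two 4-bit values: shift to select
                                                    # the low or high nibble, mask to 4 bits, cast to float16, then
                                                    # apply the per-group Scale and subtract the per-group Zeros.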
                                                    for ax0_2, ax1, ax2 in T.grid(1, 1, 8):
                                                        with T.block("B_decode_reindex_local"):
                                                            v0 = T.axis.spatial(1, ax0_2)
                                                            v1 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4 + ax1)
                                                            v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 8 + ax2)
                                                            T.reads(B_local[v1, v2 // 2], Scale[v1, v2 // 256], Zeros[v1, v2 // 256])
                                                            T.writes(B_decode_reindex_local[v0, v1, v2])
                                                            B_decode_reindex_local[v0, v1, v2] = T.Cast("float16", T.bitwise_and(T.shift_right(B_local[v1, v2 // 2], T.Cast("int8", v2 % 2 * 4)), T.int8(15))) * Scale[v1, v2 // 256] - Zeros[v1, v2 // 256]
                                                    for ax2_2 in T.vectorized(8):
                                                        with T.block("B_decode_reindex_shared.dyn"):
                                                            v0 = T.axis.spatial(1, ax0_1)
                                                            v1 = T.axis.spatial(768, ax1_0_1_ax2_0_1_fused * 128 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) // 4)
                                                            v2 = T.axis.spatial(768, ax3_0_0 * 32 + (ax1_ax2_0_fused_0 * 128 + ax1_ax2_0_fused_1 * 64 + ax1_ax2_0_fused_2 * 32 + ax1_ax2_0_fused_3) % 4 * 8 + ax2_1 * 8 + ax2_2)
                                                            T.reads(B_decode_reindex_local[v0, v1, v2])
                                                            T.writes(B_decode_reindex_shared_dyn[v0, v1, v2])
                                                            T.block_attr({"permuted_layout": 1})
                                                            B_decode_reindex_shared_dyn[v0, v1, v2] = B_decode_reindex_local[v0, v1, v2]
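                                # Inner K pipeline: for each 16-wide K slice, load 16x16 fragments of A
                                # and of decoded B from shared memory into warp registers with ldmatrix,
                                # then update the C fragments with tensor-core MMAs.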
                                for ax3_0_1 in range(2):
                                    for ax0_1, ax1_0, ax2_0 in T.grid(1, 4, 1):
                                        with T.block("A_reindex_pad_shared.dyn_warp_o"):
                                            v0_o = T.axis.spatial(1, ax0_1)
                                            v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0)
                                            v2_o = T.axis.spatial(48, ax3_0_0 * 2 + ax3_0_1 + ax2_0)
                                            T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
                                            T.writes(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8])
                                            T.block_attr({"permuted_layout": 1})
                                            warp = T.match_buffer(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=16)
                                            shared = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16)
                                            for tx in T.thread_binding(32, thread="threadIdx.x"):
                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 8 * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), shared.strides[0] * (tx % 16) + 8 * (tx // 16))
                                    for ax0_1, ax1_0, ax2_0 in T.grid(1, 4, 1):
                                        with T.block("B_decode_reindex_shared.dyn_warp_o"):
                                            v0_o = T.axis.spatial(1, ax0_1)
                                            v1_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax1_0)
                                            v2_o = T.axis.spatial(48, ax3_0_0 * 2 + ax3_0_1 + ax2_0)
                                            T.reads(B_decode_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
                                            T.writes(B_decode_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8])
                                            T.block_attr({"permuted_layout": 1})
                                            warp = T.match_buffer(B_decode_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "float16", scope="warp", offset_factor=16)
                                            shared = T.match_buffer(B_decode_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16)
                                            for tx in T.thread_binding(32, thread="threadIdx.x"):
                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 8 * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), shared.strides[0] * 8 * (tx // 16) + shared.strides[0] * (tx % 8) + 8 * (tx % 16 // 8))
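                                    # Each warp updates a 4x4 grid of 16x16 C fragments against the current K slice.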
                                    for ax1_0_3, ax2_0_3 in T.grid(4, 4):
                                        with T.block("C_o_update"):
                                            v0_o = T.axis.spatial(1, ax0)
                                            v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax1_0_2 * 4 + ax1_0_3)
                                            v2_o = T.axis.spatial(48, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2 * 4 + ax2_0_3)
                                            v3_o = T.axis.reduce(48, ax3_0_0 * 2 + ax3_0_1)
                                            T.reads(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_pad_shared_dyn_warp[0, v1_o, v3_o, 0:32, 0:8], B_decode_reindex_shared_dyn_warp[0, v2_o, v3_o, 0:32, 0:8])
                                            T.writes(C_reindex_pad_shared_dyn_warp[0, v1_o, v2_o, 0:32, 0:8])
                                            with T.block("C_o"):
                                                v1_i_o = T.axis.spatial(1, 0)
                                                v2_i_o = T.axis.spatial(1, 0)
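
For reference, below is a minimal NumPy sketch of the decode performed in the "B_decode_reindex_local" block above; the function name, random test data, and helper parameter names are illustrative assumptions, not part of the generated kernel. Each int8 of B packs two unsigned 4-bit weights, and each group of 256 values along K shares one float16 scale and one zero offset that is subtracted after scaling.

import numpy as np

def dequantize_uint4_rescale(B_packed, Scale, Zeros, group_size=256):
    # B_packed: (N, K // 2) int8, two 4-bit weights per byte (low nibble = even k).
    # Scale, Zeros: (N, K // group_size) float16; the decoded weight is q * scale - zero.
    N, K_half = B_packed.shape
    K = 2 * K_half
    k = np.arange(K)
    # Pick the low or high nibble of the containing byte and mask to 4 bits.
    q = (B_packed.view(np.uint8)[:, k // 2] >> ((k % 2) * 4)) & 0xF
    g = k // group_size
    return q.astype(np.float16) * Scale[:, g] - Zeros[:, g]

# Example matching the kernel's shapes: N = 768, K = 768, 3 groups of 256.
B_packed = np.random.randint(-128, 128, size=(768, 384), dtype=np.int8)
Scale = np.random.rand(768, 3).astype(np.float16)
Zeros = np.random.rand(768, 3).astype(np.float16)
W = dequantize_uint4_rescale(B_packed, Scale, Zeros)  # (768, 768) float16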