Skip to content

Instantly share code, notes, and snippets.

@masahi
Created November 9, 2022 08:51
Show Gist options
  • Save masahi/01a80b86062122ad57b9b1fd785fb960 to your computer and use it in GitHub Desktop.
import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir.tensor_intrin.cuda import *
from tvm.tir import Schedule
import tvm.meta_schedule as ms
from tvm.target import Target
# TVMScript dump of a meta_schedule-tuned quantized (int8) conv2d with a 1x1
# kernel in NHWC layout, auto-tensorized for CUDA tensor cores via wmma
# 16x16x16 s8s8s32 intrinsics, followed by a fixed-point requantize epilogue.
# NOTE(review): indentation was reconstructed from a whitespace-stripped paste;
# confirm the loop nesting against the original gist before relying on it.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[256, "int32"], p5: T.Buffer[256, "int32"], p6: T.Buffer[256, "int32"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "int32"]):
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.unroll_explicit":1024})
            # conv2d + requantize result, consumed by the final "compute_4" epilogue.
            compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32")
            # Reindexed (N*H*W, C) staging buffers: shared-memory tiles plus wmma
            # matrix_a / matrix_b / accumulator fragments for the tensor-core matmul.
            conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared")
            conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator")
            pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared")
            p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared")
            pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a")
            p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b")
            for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"):
                for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"):
                    for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"):
                        for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2):
                            # Cooperative fetch of the activation tile: global -> shared.
                            for ax0_ax1_fused in T.serial(1024):
                                with T.block("pad_temp_reindex_shared"):
                                    v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32)
                                    v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
                                    T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
                                    T.writes(pad_temp_reindex_shared[v0, v1])
                                    T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4})
                                    pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]
                            # Cooperative fetch of the weight tile: global -> shared.
                            for ax0_ax1_ax2_ax3_fused in T.serial(2048):
                                with T.block("p1_reindex_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(1, 0)
                                    v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32)
                                    v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
                                    T.reads(p1[v2, v0, v1, v3])
                                    T.writes(p1_reindex_shared[v0, v1, v2, v3])
                                    T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3})
                                    p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3]
                            for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2):
                                # Load a 16x16 activation fragment: shared -> wmma.matrix_a.
                                for ax0_0_1, ax1_0_1 in T.grid(1, 1):
                                    with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"):
                                        v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
                                        v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
                                        T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                        T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"})
                                        for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                            with T.block("pad_temp_reindex_shared_wmma.matrix_a"):
                                                v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1])
                                                T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                                T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                                pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                                # Load 16x16 weight fragments: shared -> wmma.matrix_b
                                # (transposed b-layout per "wmma_load_16x16x16_s8_b_trans").
                                for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1):
                                    with T.block("p1_reindex_shared_wmma.matrix_b_o"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial(1, 0)
                                        v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0)
                                        v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
                                        T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"})
                                        for ax2_1, ax3_1 in T.grid(16, 16):
                                            with T.block("p1_reindex_shared_wmma.matrix_b"):
                                                v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1])
                                                T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
                                # Core matmul, auto-tensorized to wmma_sync (s8 x s8 -> s32
                                # accumulate); T.init zero-fills the accumulator fragment.
                                for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2):
                                    with T.block("conv2d_nhwc_o"):
                                        v0 = T.axis.reduce(1, 0)
                                        v1 = T.axis.reduce(1, 0)
                                        v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4)
                                        v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4)
                                        v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2)
                                        T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16])
                                        T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
                                        with T.init():
                                            for ax2_1, ax3_1 in T.grid(16, 16):
                                                with T.block("conv2d_nhwc_init"):
                                                    v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
                                                    T.reads()
                                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init])
                                                    conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0
                                        for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
                                            with T.block("conv2d_nhwc"):
                                                v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
                                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i])
                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32")
                        # Store accumulator fragments back to shared memory (wmma_store).
                        for ax0_0, ax1_0 in T.grid(1, 2):
                            with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
                                v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
                                v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0)
                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"})
                                for ax0_1, ax1_1 in T.grid(16, 16):
                                    with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
                                        v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                        T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                        conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                    # Requantize write-back: per-channel q_multiply_shift of
                    # (conv - p2 + p3), offset by scalar p7, clamp to [0, 255],
                    # subtract p8[0], then a fixed-point rescale into compute_3.
                    for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2):
                        with T.block("conv2d_nhwc_reindex_shared"):
                            T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64)
                            v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0)
                            v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3))
                            T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0])
                            T.writes(compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
                            compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32")
            # Final epilogue: add q_multiply_shift-rescaled p9 and clamp to [0, 255].
            # The driver script targets this "compute_4" block for reverse_compute_inline.
            for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256):
                with T.block("compute_4"):
                    i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12])
                    T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13])
                    T.writes(compute[i0_13, i1_13, i2_13, i3_13])
                    compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(compute_3[i0_13, i1_13, i2_13, i3_13] + T.q_multiply_shift(p9[i0_13, i1_13, i2_13, i3_13], 2101000910, 31, 0, dtype="int32"), 255), 0)
# Build a schedule over the tuned module, inline the trailing "compute_4"
# epilogue block back into its producer, and dump the resulting TIR script.
sch = Schedule(Module)
epilogue_block = sch.get_block("compute_4")
sch.reverse_compute_inline(epilogue_block)
print(sch.mod.script())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment