-
-
Save masahi/01a80b86062122ad57b9b1fd785fb960 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
from tvm import tir | |
from tvm.script import tir as T | |
from tvm.tir.tensor_intrin.cuda import * | |
from tvm.tir import Schedule | |
import tvm.meta_schedule as ms | |
from tvm.target import Target | |
@tvm.script.ir_module
class Module:
    # Meta-schedule-generated TIR for a quantized 1x1 conv2d in NHWC layout:
    # p0 (16,56,56,64) int8 activations x p1 (256,1,1,64) int8 weights, viewed
    # as a [50176, 64] x [256, 64]^T GEMM and tensorized with 16x16x16
    # s8s8s32 WMMA (tensor core) intrinsics. A per-channel requantize/clamp
    # epilogue writes compute_3, and a final block ("compute_4") adds a
    # rescaled residual (p9) with another clamp into `compute`.
    # NOTE(review): p2..p8 look like quantization parameters (zero points,
    # per-channel multipliers/shifts, output offset) — confirm against caller.
    @T.prim_func
    def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[256, "int32"], p5: T.Buffer[256, "int32"], p6: T.Buffer[256, "int32"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "int32"]):
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.unroll_explicit":1024})
            # Conv + requantize intermediate; consumed by the trailing "compute_4" block.
            compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32")
            # GEMM-view staging buffers: shared-memory tiles plus WMMA register
            # fragments (wmma.accumulator / wmma.matrix_a / wmma.matrix_b scopes).
            conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared")
            conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator")
            pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared")
            p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared")
            pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a")
            p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b")
            # Launch shape: blockIdx.y fuses the outermost row/column macro-tiles
            # (// 4 selects a 6272-row slab, % 4 a 64-wide channel slab);
            # blockIdx.x covers 196 row tiles of 32 rows; threadIdx.y holds 4 warps.
            for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"):
                for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"):
                    for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"):
                        # ax4_0_0 splits the 64-deep reduction (input channels) into two 32-wide stages.
                        for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2):
                            # Cooperative fetch: 32x32 activation tile -> shared memory.
                            # v0 is the flattened (n, h, w) row index of the GEMM view.
                            for ax0_ax1_fused in T.serial(1024):
                                with T.block("pad_temp_reindex_shared"):
                                    v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32)
                                    v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
                                    T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
                                    T.writes(pad_temp_reindex_shared[v0, v1])
                                    T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4})
                                    pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]
                            # Cooperative fetch: 64x32 weight tile -> shared memory.
                            # p1 is read with output channel first (p1[v2, v0, v1, v3]),
                            # i.e. the B operand is staged transposed.
                            for ax0_ax1_ax2_ax3_fused in T.serial(2048):
                                with T.block("p1_reindex_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(1, 0)
                                    v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32)
                                    v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
                                    T.reads(p1[v2, v0, v1, v3])
                                    T.writes(p1_reindex_shared[v0, v1, v2, v3])
                                    T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3})
                                    p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3]
                            # Inner reduction split: ax4_0_1 walks the two 16-deep k-fragments
                            # of the current 32-wide reduction stage.
                            for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2):
                                # Load one 16x16 activation fragment into wmma.matrix_a.
                                for ax0_0_1, ax1_0_1 in T.grid(1, 1):
                                    with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"):
                                        v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
                                        v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
                                        T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                        T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"})
                                        for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                            with T.block("pad_temp_reindex_shared_wmma.matrix_a"):
                                                v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1])
                                                T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                                T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                                pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                                # Load two 16x16 weight fragments (transposed B) into wmma.matrix_b.
                                for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1):
                                    with T.block("p1_reindex_shared_wmma.matrix_b_o"):
                                        v0 = T.axis.spatial(1, 0)
                                        v1 = T.axis.spatial(1, 0)
                                        v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0)
                                        v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
                                        T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"})
                                        for ax2_1, ax3_1 in T.grid(16, 16):
                                            with T.block("p1_reindex_shared_wmma.matrix_b"):
                                                v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1])
                                                T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
                                # Tensorized MMA: 16x16x16 int8 multiply-accumulate into the
                                # int32 accumulator; T.init() zero-fills the accumulator on
                                # the first reduction step (auto_tensorize_init = wmma_fill).
                                for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2):
                                    with T.block("conv2d_nhwc_o"):
                                        v0 = T.axis.reduce(1, 0)
                                        v1 = T.axis.reduce(1, 0)
                                        v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4)
                                        v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4)
                                        v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2)
                                        T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16])
                                        T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
                                        with T.init():
                                            for ax2_1, ax3_1 in T.grid(16, 16):
                                                with T.block("conv2d_nhwc_init"):
                                                    v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
                                                    T.reads()
                                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init])
                                                    conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0
                                        for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
                                            with T.block("conv2d_nhwc"):
                                                v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
                                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i])
                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32")
                        # Copy this warp's two accumulator fragments back to shared memory.
                        for ax0_0, ax1_0 in T.grid(1, 2):
                            with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
                                v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
                                v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0)
                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"})
                                for ax0_1, ax1_1 in T.grid(16, 16):
                                    with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
                                        v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                        T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                        conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                    # Fused epilogue over the block's 32x64 shared-memory tile:
                    # subtract/add per-channel offsets p2/p3, per-channel fixed-point
                    # rescale via q_multiply_shift_per_axis(p4, p5, p6), add p7,
                    # clamp to [0, 255], subtract p8, then a scalar fixed-point
                    # rescale (multiplier 1457846997, shift 31) into compute_3.
                    # The T.where predicate keeps the fused inner index within the
                    # 64-wide channel tile (loop extents over-cover: 4*32*2 = 256).
                    for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2):
                        with T.block("conv2d_nhwc_reindex_shared"):
                            T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64)
                            v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0)
                            v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3))
                            T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0])
                            T.writes(compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
                            compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32")
            # Residual add: rescale p9 by a fixed-point multiply and add it to the
            # requantized conv result, clamping to [0, 255]. This is the block the
            # driver script below inlines via reverse_compute_inline.
            for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256):
                with T.block("compute_4"):
                    i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12])
                    T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13])
                    T.writes(compute[i0_13, i1_13, i2_13, i3_13])
                    compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(compute_3[i0_13, i1_13, i2_13, i3_13] + T.q_multiply_shift(p9[i0_13, i1_13, i2_13, i3_13], 2101000910, 31, 0, dtype="int32"), 255), 0)
# Schedule the module and fold the trailing element-wise block "compute_4"
# (residual add + clamp) back into its producer, then print the resulting TIR.
schedule = Schedule(Module)
epilogue = schedule.get_block("compute_4")
schedule.reverse_compute_inline(epilogue)
print(schedule.mod.script())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment