-
-
Save zxybazh/6bff29ae4e7cb273d57bb30599790008 to your computer and use it in GitHub Desktop.
Script to Reproduce Error in RewriteLayout PostProcessor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Testing script""" | |
import tvm | |
from tvm.ir.module import IRModule | |
from tvm.script import tir as T | |
from tvm.tir import Schedule | |
from tvm import meta_schedule as ms | |
from tempfile import TemporaryDirectory | |
# fmt: off
# TVMScript IRModule holding a single, already-scheduled PrimFunc `main`.
#
# Buffers (all float32):
#   p0    : (1, 14, 14, 256)  input, read at [nn, yy * 2 + ry, xx * 2 + rx, rc]
#   p1    : (1, 1, 256, 512)  weights, copied element-wise into `p1_global`
#   p2    : (1, 1, 1, 512)    added to the accumulated result along the last axis
#   T_add : (1, 7, 7, 512)    output
#
# The function attribute "layout_free_buffers": [1] is what the RewriteLayout
# postprocessor is exercised against (index 1 presumably refers to `p1`, the
# second parameter -- TODO confirm against the TVM version in use).
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 14, 14, 256), "float32"], p1: T.Buffer[(1, 1, 256, 512), "float32"], p2: T.Buffer[(1, 1, 1, 512), "float32"], T_add: T.Buffer[(1, 7, 7, 512), "float32"]) -> None:
        # function attr dict
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        # Intermediate accumulator for the convolution result.
        conv2d_nhwc = T.alloc_buffer([1, 7, 7, 512], dtype="float32")
        # Staging copy of the weight buffer `p1`.
        p1_global = T.alloc_buffer([1, 1, 256, 512], dtype="float32")
        # Stage 1: element-wise copy p1 -> p1_global.
        for ax0, ax1, ax2, ax3 in T.grid(1, 1, 256, 512):
            with T.block("p1_global"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(p1[v0, v1, v2, v3])
                T.writes(p1_global[v0, v1, v2, v3])
                p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3]
        # Stage 2: fused outer loop (196 iterations, parallel) covering the
        # init, reduction-update and final add for each output tile.
        for i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
            for i2_1, i3_1 in T.grid(1, 8):
                # 2a: zero-initialise the 16 output channels of this tile.
                for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 1, 4, 1, 1, 1):
                    for i3_3_fused_init in T.vectorized(4):
                        with T.block("conv2d_nhwc_init"):
                            nn = T.axis.spatial(1, i0_3_init + i0_2_init)
                            yy = T.axis.spatial(7, i1_3_init + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + i1_2_init)
                            xx = T.axis.spatial(7, i2_3_init + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + i2_1 + i2_2_init)
                            ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + i3_2_init * 4 + i3_3_fused_init)
                            T.reads()
                            T.writes(conv2d_nhwc[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            conv2d_nhwc[nn, yy, xx, ff] = T.float32(0)
                # 2b: accumulate over the reduction axes (ry, rx of extent 1;
                # rc of extent 256 split as 8 * 32).
                for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 1, 1, 1, 4, 1, 1, 32, 1, 1, 1):
                    for i3_3_fused in T.vectorized(4):
                        with T.block("conv2d_nhwc_update"):
                            nn = T.axis.spatial(1, i0_3 + i0_2)
                            yy = T.axis.spatial(7, i1_3 + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + i1_2)
                            xx = T.axis.spatial(7, i2_3 + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + i2_1 + i2_2)
                            ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + i3_2 * 4 + i3_3_fused)
                            ry = T.axis.reduce(1, i4_1 + i4_0)
                            rx = T.axis.reduce(1, i5_0 + i5_1)
                            rc = T.axis.reduce(256, i6_0 * 32 + i6_1)
                            T.reads(conv2d_nhwc[nn, yy, xx, ff], p0[nn, yy * 2 + ry, xx * 2 + rx, rc], p1_global[ry, rx, rc, ff])
                            T.writes(conv2d_nhwc[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + p0[nn, yy * 2 + ry, xx * 2 + rx, rc] * p1_global[ry, rx, rc, ff]
                # 2c: add p2 (indexed only by the last axis) and write the
                # 16-channel tile to the output buffer.
                for ax0, ax1, ax2 in T.grid(1, 1, 1):
                    for ax3_fused in T.vectorized(16):
                        with T.block("T_add"):
                            ax0_1 = T.axis.spatial(1, ax0)
                            ax1_1 = T.axis.spatial(7, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + ax1)
                            ax2_1 = T.axis.spatial(7, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + ax2)
                            ax3 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + ax3_fused)
                            T.reads(conv2d_nhwc[ax0_1, ax1_1, ax2_1, ax3], p2[ax0_1, 0, 0, ax3])
                            T.writes(T_add[ax0_1, ax1_1, ax2_1, ax3])
                            T_add[ax0_1, ax1_1, ax2_1, ax3] = conv2d_nhwc[ax0_1, ax1_1, ax2_1, ax3] + p2[ax0_1, 0, 0, ax3]
# fmt: on
# Reproduction driver: build a Schedule from `Module` and apply the
# RewriteLayout postprocessor directly to it.
if __name__ == "__main__":
    mod = Module
    # NOTE(review): `target` is constructed here but is not passed to the
    # postprocessor or the schedule below -- confirm whether it was meant to
    # be used (e.g. via a tuning context) or is leftover.
    target = tvm.target.Target("llvm --num-cores=12")
    pp = ms.postproc.RewriteLayout()
    sch = Schedule(mod)
    # This call is the step the script exists to exercise.
    pp.apply(sch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment