Skip to content

Instantly share code, notes, and snippets.

@zxybazh
Created September 20, 2022 21:39
Show Gist options
  • Save zxybazh/6bff29ae4e7cb273d57bb30599790008 to your computer and use it in GitHub Desktop.
Save zxybazh/6bff29ae4e7cb273d57bb30599790008 to your computer and use it in GitHub Desktop.
Script to reproduce an error in the meta-schedule RewriteLayout postprocessor
"""Testing script"""
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir import Schedule
from tvm import meta_schedule as ms
from tempfile import TemporaryDirectory
# fmt: off
# Repro input IRModule (TVMScript). `main` computes a stride-2 NHWC conv2d with
# a 1x1 kernel (p0 * p1 accumulated into conv2d_nhwc; ry/rx have extent 1 and
# p0 is indexed at yy * 2 / xx * 2), then adds p2 broadcast over channels to
# produce T_add. The loops are already tiled, parallelized and vectorized.
# "layout_free_buffers": [1] refers to parameter index 1, i.e. p1. Note that the
# body first copies p1 into p1_global and the convolution reads only that copy.
# NOTE(review): presumably this copy stage is what RewriteLayout trips on;
# confirm against the reported error.
# Keep every IR token unchanged: this exact program is the reproduction input.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 14, 14, 256), "float32"], p1: T.Buffer[(1, 1, 256, 512), "float32"], p2: T.Buffer[(1, 1, 1, 512), "float32"], T_add: T.Buffer[(1, 7, 7, 512), "float32"]) -> None:
        # function attr dict
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        conv2d_nhwc = T.alloc_buffer([1, 7, 7, 512], dtype="float32")
        p1_global = T.alloc_buffer([1, 1, 256, 512], dtype="float32")
        # Element-wise copy of the weight p1 into p1_global.
        for ax0, ax1, ax2, ax3 in T.grid(1, 1, 256, 512):
            with T.block("p1_global"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(p1[v0, v1, v2, v3])
                T.writes(p1_global[v0, v1, v2, v3])
                p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3]
        # 196 = 7 rows (fused % 7) * 4 channel tiles of 128 (fused % 28 // 7)
        #       * 7 columns (fused // 28).
        for i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
            # Each i3_1 iteration owns a 16-channel slice of the 128-channel tile.
            for i2_1, i3_1 in T.grid(1, 8):
                # Zero-initialize the 16-channel output slice (4 x vectorized 4).
                for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 1, 4, 1, 1, 1):
                    for i3_3_fused_init in T.vectorized(4):
                        with T.block("conv2d_nhwc_init"):
                            nn = T.axis.spatial(1, i0_3_init + i0_2_init)
                            yy = T.axis.spatial(7, i1_3_init + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + i1_2_init)
                            xx = T.axis.spatial(7, i2_3_init + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + i2_1 + i2_2_init)
                            ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + i3_2_init * 4 + i3_3_fused_init)
                            T.reads()
                            T.writes(conv2d_nhwc[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            conv2d_nhwc[nn, yy, xx, ff] = T.float32(0)
                # Accumulate over the 256 input channels (rc = i6_0 * 32 + i6_1).
                for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 1, 1, 1, 4, 1, 1, 32, 1, 1, 1):
                    for i3_3_fused in T.vectorized(4):
                        with T.block("conv2d_nhwc_update"):
                            nn = T.axis.spatial(1, i0_3 + i0_2)
                            yy = T.axis.spatial(7, i1_3 + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + i1_2)
                            xx = T.axis.spatial(7, i2_3 + i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + i2_1 + i2_2)
                            ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + i3_2 * 4 + i3_3_fused)
                            ry = T.axis.reduce(1, i4_1 + i4_0)
                            rx = T.axis.reduce(1, i5_0 + i5_1)
                            rc = T.axis.reduce(256, i6_0 * 32 + i6_1)
                            T.reads(conv2d_nhwc[nn, yy, xx, ff], p0[nn, yy * 2 + ry, xx * 2 + rx, rc], p1_global[ry, rx, rc, ff])
                            T.writes(conv2d_nhwc[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                            conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + p0[nn, yy * 2 + ry, xx * 2 + rx, rc] * p1_global[ry, rx, rc, ff]
                # Add p2 to the 16 channels just computed and write T_add.
                for ax0, ax1, ax2 in T.grid(1, 1, 1):
                    for ax3_fused in T.vectorized(16):
                        with T.block("T_add"):
                            ax0_1 = T.axis.spatial(1, ax0)
                            ax1_1 = T.axis.spatial(7, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 7 + ax1)
                            ax2_1 = T.axis.spatial(7, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused // 28 + ax2)
                            ax3 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_i0_1_i1_1_fused % 28 // 7 * 128 + i3_1 * 16 + ax3_fused)
                            T.reads(conv2d_nhwc[ax0_1, ax1_1, ax2_1, ax3], p2[ax0_1, 0, 0, ax3])
                            T.writes(T_add[ax0_1, ax1_1, ax2_1, ax3])
                            T_add[ax0_1, ax1_1, ax2_1, ax3] = conv2d_nhwc[ax0_1, ax1_1, ax2_1, ax3] + p2[ax0_1, 0, 0, ax3]
# fmt: on
if __name__ == "__main__":
    # Apply the RewriteLayout postprocessor directly to a schedule built from
    # the repro module above; no tuning context or target is involved.
    mod = Module
    sch = Schedule(mod)
    pp = ms.postproc.RewriteLayout()
    # Postproc.apply returns a bool success flag. The failure being reproduced
    # may show up either as False here or as an exception raised by TVM, so
    # report the flag explicitly instead of discarding it.
    applied = pp.apply(sch)
    print(f"RewriteLayout applied: {applied}")
    if applied:
        # Show the rewritten program so the new weight layout can be inspected.
        print(sch.mod.script())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment