Skip to content

Instantly share code, notes, and snippets.

@zxybazh
Last active February 7, 2022 23:17
Show Gist options
  • Save zxybazh/4e6b2ff3dca7ca66254ea1e553103cfd to your computer and use it in GitHub Desktop.
Save zxybazh/4e6b2ff3dca7ca66254ea1e553103cfd to your computer and use it in GitHub Desktop.
Test script that reproduces a failing `compute_at` schedule step on a Winograd conv2d module.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm.script import tir as T
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
# fmt: off
# Winograd conv2d written as a TVMScript IRModule with fixed shapes.
# Buffers:
# - input `placeholder`: (1, 14, 14, 128), indexed [n, h, w, c]
#   (see the "input_tile" and "conv2d_winograd" blocks below);
# - weight `placeholder_1`: (6, 6, 128, 128), indexed [eps, nu, co, ci] in
#   "bgemm", i.e. already in the 6x6 transformed domain;
# - output: (1, 12, 12, 128).
# Stages inside the root block, in order:
#   data_pad -> input_tile -> B -> data_pack -> bgemm -> A -> inverse
#   -> conv2d_winograd
@tvm.script.ir_module
class Conv2d_Winograd:
    @T.prim_func
    def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None:
        # function attr dict; the weight handle is listed under
        # "layout_free_placeholders" (presumably so later passes may change its
        # layout -- TODO confirm against the consuming pass)
        T.func_attr({"layout_free_placeholders": [var_placeholder_1]})
        placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1)
        placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1)
        conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1)
        # body
        with T.block("root"):
            # One intermediate buffer per pipeline stage.
            data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1)
            input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1)
            data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1)
            inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1)
            # Zero-pad the 14x14 input to 16x16: positions with h >= 14 or
            # w >= 14 are 0.0.
            # NOTE(review): "input_tile" only reads rows/cols 0..13
            # (max index 4*2 + 5), so the zero branch is never consumed.
            for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
                with T.block("data_pad"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
                    T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
                    data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32")
            # Gather 9 overlapping 6x6 tiles with stride 4. Tile p starts at
            # row 4 * ((p % 9) // 3) and col 4 * (p % 3); p // 9 is the batch.
            for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
                with T.block("input_tile"):
                    eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
                    T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]])
                    T.writes([input_tile[eps, nu, p, ci]])
                    input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]
            # Constant 6x6 input-transform matrix B, tagged "const_matrix",
            # materialized as a nested Select over (i, j); unlisted entries
            # fall through to 0.
            for i0_3, i1_3 in T.grid(6, 6):
                with T.block("B"):
                    i, j = T.axis.remap("SS", [i0_3, i1_3])
                    T.writes([B[i, j]])
                    T.block_attr({
                        "const_matrix" : True,
                    })
                    B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1),
                        T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
            # Input transform per (p, ci):
            #   data_pack[eps, nu] = sum_{r_a, r_b} B[r_a, eps] * tile[r_a, r_b] * B[r_b, nu]
            # i.e. B^T . tile . B.
            for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
                with T.block("data_pack"):
                    eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5])
                    # The B read region is a conservative min/max bounding box
                    # over the loop variables rather than the exact points read.
                    T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]])
                    T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
                    # Hint attribute naming the axes whose constant-matrix
                    # indices may be simplified; its consumer is not visible
                    # here (presumably an auto-scheduler pass -- unverified).
                    T.block_attr({
                        "auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"],
                    })
                    with T.init():
                        data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
                    data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1]))
            # Batched GEMM, one per (eps, nu):
            #   bgemm[eps, nu, p, co] = sum_ci data_pack[eps, nu, p, ci] * weight[eps, nu, co, ci]
            for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
                with T.block("bgemm"):
                    eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
                    T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]])
                    T.writes([bgemm[eps_2, nu_2, p_2, co]])
                    # "schedule_rule" is set to the string "None" (presumably
                    # opting this block out of a custom schedule rule -- confirm).
                    T.block_attr({
                        "schedule_rule": "None",
                    })
                    with T.init():
                        bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
                    bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2]))
            # Constant 6x4 output-transform matrix A, tagged "const_matrix",
            # materialized as a nested Select over (i, j); unlisted entries
            # fall through to 0.
            for i0_6, i1_6 in T.grid(6, 4):
                with T.block("A"):
                    i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
                    T.writes([A[i_1, j_1]])
                    T.block_attr({
                        "const_matrix" : True,
                    })
                    A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.float32(0)))))))))))))))))))))))))
            # Output transform per (p, co), producing a 4x4 tile:
            #   inverse[vh, vw] = sum_{r_a, r_b} A[r_a, vh] * bgemm[r_a, r_b] * A[r_b, vw]
            # i.e. A^T . M . A.
            for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
                with T.block("inverse"):
                    vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1])
                    # As in "data_pack", the A read region is a min/max bounding box.
                    T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]])
                    T.writes([inverse[vh, vw, p_3, co_1]])
                    T.block_attr({
                        "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"],
                    })
                    with T.init():
                        inverse[vh, vw, p_3, co_1] = T.float32(0)
                    inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw]))
            # Scatter the 4x4 tiles into the 12x12 output: pixel (h, w) comes
            # from tile n*9 + (h // 4)*3 + w // 4 at in-tile offset (h % 4, w % 4).
            for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
                with T.block("conv2d_winograd"):
                    n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
                    T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]])
                    T.writes([conv2d_winograd[n, h, w, co_2]])
                    conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]
# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
if __name__ == "__main__":
    # Replay a fully pinned schedule trace on Conv2d_Winograd. Every sampling
    # instruction carries an explicit `decision`, so the run is deterministic.
    # The statement order below IS the trace order and must be preserved. The
    # final compute_at steps (for "input_tile" and then "data_pad") are the
    # ones the accompanying description reports as failing.
    sch = tvm.tir.Schedule(Conv2d_Winograd)

    def tile(loop, decision):
        """Sample a perfect tiling of ``loop`` pinned to ``decision``, then split by it.

        Returns the loops produced by ``sch.split``, outermost first.
        """
        factors = sch.sample_perfect_tile(
            n=len(decision), loop=loop, max_innermost_factor=64, decision=decision
        )
        return sch.split(loop=loop, factors=list(factors))

    # The constant matrices "A" and "B" are inlined into their consumers.
    for const_name in ("A", "B"):
        sch.compute_inline(block=sch.get_block(name=const_name))

    # "inverse" and then "data_pack" share one loop-nest shape:
    # (tile_row, tile_col, p, channel, r_a, r_b). For each, unroll the two
    # tile axes and the two reduction axes, split p and channel in two, and
    # move the four split loops outermost.
    for block_name, p_decision, ch_decision in (
        ("inverse", [1, 9], [2, 64]),
        ("data_pack", [9, 1], [32, 4]),
    ):
        row, col, p_loop, ch_loop, red_a, red_b = sch.get_loops(
            block=sch.get_block(name=block_name)
        )
        for axis in (row, col, red_a, red_b):
            sch.unroll(loop=axis)
        p_outer, p_inner = tile(p_loop, p_decision)
        ch_outer, ch_inner = tile(ch_loop, ch_decision)
        sch.reorder(p_outer, ch_outer, p_inner, ch_inner, row, col, red_a, red_b)

    # "bgemm": add a global cache stage for its output. Tile the four spatial
    # axes four ways and the ci reduction two ways, and interleave the levels
    # as spatial, spatial, ci_outer, spatial, ci_inner, spatial. Finally,
    # attach the cache stage under the second-level co loop.
    bgemm = sch.get_block(name="bgemm")
    bgemm_cache = sch.cache_write(block=bgemm, write_buffer_index=0, storage_scope="global")
    eps, nu, p_axis, co, ci = sch.get_loops(block=bgemm)
    spatial_tiles = [
        tile(eps, [1, 2, 3, 1]),
        tile(nu, [1, 1, 1, 6]),
        tile(p_axis, [1, 1, 1, 9]),
        tile(co, [2, 1, 16, 4]),
    ]
    ci_outer, ci_inner = tile(ci, [16, 8])
    # levels[k] holds the k-th tile level of (eps, nu, p, co), in that order.
    levels = [list(level) for level in zip(*spatial_tiles)]
    sch.reorder(*levels[0], *levels[1], ci_outer, *levels[2], ci_inner, *levels[3])
    sch.reverse_compute_at(block=bgemm_cache, loop=levels[1][3], preserve_unit_loops=True)

    # Root-level hints for later lowering, including a sampled unroll limit.
    root = sch.get_block(name="root")
    sch.annotate(block_or_loop=root, ann_key="auto_parallel_extent", ann_val=64)
    sch.annotate(block_or_loop=root, ann_key="auto_vectorize_extent", ann_val=32)
    unroll_limit = sch.sample_categorical(
        candidates=[0, 16, 64, 512], probs=[0.25] * 4, decision=1
    )
    sch.annotate(block_or_loop=root, ann_key="auto_unroll_explicit", ann_val=unroll_limit)

    # Pinned compute locations: "input_tile" first, then "data_pad" at the
    # location sampled for its single consumer.
    input_tile = sch.get_block(name="input_tile")
    sch.compute_at(
        block=input_tile,
        loop=sch.sample_compute_location(block=input_tile, decision=4),
        preserve_unit_loops=True,
    )
    data_pad = sch.get_block(name="data_pad")
    (pad_consumer,) = sch.get_consumers(block=data_pad)
    sch.compute_at(
        block=data_pad,
        loop=sch.sample_compute_location(block=pad_consumer, decision=-2),
        preserve_unit_loops=True,
    )
    print(sch.mod.script())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment