Last active
February 7, 2022 23:17
-
-
Save zxybazh/4e6b2ff3dca7ca66254ea1e553103cfd to your computer and use it in GitHub Desktop.
Test script that reproduces a failing `compute_at` in a TVM TensorIR schedule.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
# Reproducer script: defines a Winograd conv2d module in TVMScript and replays
# a fixed tvm.tir.Schedule trace on it.
import tvm
from tvm.script import tir as T
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off
# Winograd conv2d workload in TVMScript (TensorIR). Buffer shapes as declared below:
#   input   placeholder      [1, 14, 14, 128]
#   weight  placeholder_1    [6, 6, 128, 128]   (listed in "layout_free_placeholders")
#   output  conv2d_winograd  [1, 12, 12, 128]
# 6x6 input tiles and 4x4 output tiles are consistent with Winograd F(4x4, 3x3).
# The IR is kept exactly as generated: the schedule in this file is a reproducer,
# so it must keep seeing the same blocks, loops and access regions.
@tvm.script.ir_module
class Conv2d_Winograd:
    @T.prim_func
    def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None:
        # function attr dict
        T.func_attr({"layout_free_placeholders": [var_placeholder_1]})
        placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1)
        placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1)
        conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1)
        # body
        with T.block("root"):
            data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1)
            input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1)
            data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
            A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1)
            inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1)
            # Zero-pad 14x14 -> 16x16: positions with a spatial index >= 14 read as 0
            # (padding is only added past the end of each spatial axis).
            for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
                with T.block("data_pad"):
                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
                    T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
                    data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32")
            # Gather overlapping 6x6 tiles with stride 4. p enumerates a 3x3 tile grid:
            # batch = p // 9, tile origin = ((p % 9) // 3 * 4, (p % 3) * 4).
            for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
                with T.block("input_tile"):
                    eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
                    T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]])
                    T.writes([input_tile[eps, nu, p, ci]])
                    input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]
            # Constant 6x6 input-transform matrix, encoded as a Select chain with one
            # matrix row (fixed i, j = 5..0) per physical line.
            for i0_3, i1_3 in T.grid(6, 6):
                with T.block("B"):
                    i, j = T.axis.remap("SS", [i0_3, i1_3])
                    T.writes([B[i, j]])
                    T.block_attr({
                        "const_matrix" : True,
                    })
                    B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0),
                        T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1),
                        T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5),
                        T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2),
                        T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5),
                        T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1),
                        T.float32(0)))))))))))))))))))))))))))))))))))))
            # Input transform: data_pack[eps, nu] = sum_{r_a, r_b} B[r_a, eps] * tile[r_a, r_b] * B[r_b, nu]
            # (B^T d B per tile p and channel ci).
            for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
                with T.block("data_pack"):
                    eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5])
                    T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]])
                    T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
                    T.block_attr({
                        "auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"],
                    })
                    with T.init():
                        data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
                    data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1]))
            # Batched GEMM: for each (eps, nu), reduce over input channels ci.
            # placeholder_1 is indexed [eps, nu, co, ci], so it is presumably the
            # already-transformed kernel -- NOTE(review): confirm with the producer.
            for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
                with T.block("bgemm"):
                    eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
                    T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]])
                    T.writes([bgemm[eps_2, nu_2, p_2, co]])
                    T.block_attr({
                        "schedule_rule": "None",
                    })
                    with T.init():
                        bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
                    bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2]))
            # Constant 6x4 output-transform matrix, one matrix row (fixed i_1,
            # j_1 = 3..0) per physical line.
            for i0_6, i1_6 in T.grid(6, 4):
                with T.block("A"):
                    i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
                    T.writes([A[i_1, j_1]])
                    T.block_attr({
                        "const_matrix" : True,
                    })
                    A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0),
                        T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                        T.float32(0)))))))))))))))))))))))))
            # Output transform: inverse[vh, vw] = sum_{r_a, r_b} A[r_a, vh] * M[r_a, r_b] * A[r_b, vw]
            # (A^T M A), reducing each 6x6 tile to 4x4.
            for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
                with T.block("inverse"):
                    vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1])
                    T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]])
                    T.writes([inverse[vh, vw, p_3, co_1]])
                    T.block_attr({
                        "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"],
                    })
                    with T.init():
                        inverse[vh, vw, p_3, co_1] = T.float32(0)
                    inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw]))
            # Scatter the 4x4 output tiles back into the 12x12 output:
            # tile p = n*9 + (h // 4)*3 + (w // 4), offset (h % 4, w % 4).
            for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
                with T.block("conv2d_winograd"):
                    n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
                    T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]])
                    T.writes([conv2d_winograd[n, h, w, co_2]])
                    conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]
# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
def main() -> None:
    """Replay a fixed schedule trace on ``Conv2d_Winograd`` and print the result.

    Every ``sample_*`` call below is pinned with an explicit ``decision``, so the
    trace is deterministic. The variable names (b1, l4, v10, ...) are kept as
    they appear in the original trace so the script can be diffed against it.

    Per the gist description this script reproduces a failing ``compute_at``.
    NOTE(review): confirm which of the two ``compute_at`` calls at the end raises
    on the TVM revision in use.
    """
    mod = Conv2d_Winograd
    sch = tvm.tir.Schedule(mod)

    # Inline the two constant transform-matrix blocks into their consumers.
    b1 = sch.get_block(name="A")
    sch.compute_inline(block=b1)
    b2 = sch.get_block(name="B")
    sch.compute_inline(block=b2)

    # "inverse" loops are (vh, vw, p, co, r_a, r_b) with extents (4, 4, 9, 128, 6, 6).
    # Unroll vh, vw, r_a, r_b; split p as 1x9 and co as 2x64; then place the
    # tiled p/co loops outermost and the unrolled loops innermost.
    b3 = sch.get_block(name="inverse")
    l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b3)
    sch.unroll(loop=l4)
    sch.unroll(loop=l5)
    sch.unroll(loop=l8)
    sch.unroll(loop=l9)
    v10, v11 = sch.sample_perfect_tile(n=2, loop=l6, max_innermost_factor=64, decision=[1, 9])
    l12, l13 = sch.split(loop=l6, factors=[v10, v11])
    v14, v15 = sch.sample_perfect_tile(n=2, loop=l7, max_innermost_factor=64, decision=[2, 64])
    l16, l17 = sch.split(loop=l7, factors=[v14, v15])
    sch.reorder(l12, l16, l13, l17, l4, l5, l8, l9)

    # "data_pack" loops are (eps, nu, p, ci, r_a, r_b) with extents (6, 6, 9, 128, 6, 6).
    # Same treatment: unroll eps, nu, r_a, r_b; split p as 9x1 and ci as 32x4.
    b18 = sch.get_block(name="data_pack")
    l19, l20, l21, l22, l23, l24 = sch.get_loops(block=b18)
    sch.unroll(loop=l19)
    sch.unroll(loop=l20)
    sch.unroll(loop=l23)
    sch.unroll(loop=l24)
    v25, v26 = sch.sample_perfect_tile(n=2, loop=l21, max_innermost_factor=64, decision=[9, 1])
    l27, l28 = sch.split(loop=l21, factors=[v25, v26])
    v29, v30 = sch.sample_perfect_tile(n=2, loop=l22, max_innermost_factor=64, decision=[32, 4])
    l31, l32 = sch.split(loop=l22, factors=[v29, v30])
    sch.reorder(l27, l31, l28, l32, l19, l20, l23, l24)

    # "bgemm": add a "global" write cache. cache_write returns the new cache-stage
    # block. After the swap, b34 names the original "bgemm" block (tiled below)
    # and b33 names the cache-stage block (reverse_compute_at'd afterwards).
    b33 = sch.get_block(name="bgemm")
    b34 = sch.cache_write(block=b33, write_buffer_index=0, storage_scope="global")
    b33, b34 = b34, b33
    # bgemm loops are (eps, nu, p, co, ci) with extents (6, 6, 9, 128, 128):
    # 4-level split of each spatial loop, 2-level split of the reduction loop.
    l35, l36, l37, l38, l39 = sch.get_loops(block=b34)
    v40, v41, v42, v43 = sch.sample_perfect_tile(
        n=4, loop=l35, max_innermost_factor=64, decision=[1, 2, 3, 1]
    )
    l44, l45, l46, l47 = sch.split(loop=l35, factors=[v40, v41, v42, v43])
    v48, v49, v50, v51 = sch.sample_perfect_tile(
        n=4, loop=l36, max_innermost_factor=64, decision=[1, 1, 1, 6]
    )
    l52, l53, l54, l55 = sch.split(loop=l36, factors=[v48, v49, v50, v51])
    v56, v57, v58, v59 = sch.sample_perfect_tile(
        n=4, loop=l37, max_innermost_factor=64, decision=[1, 1, 1, 9]
    )
    l60, l61, l62, l63 = sch.split(loop=l37, factors=[v56, v57, v58, v59])
    v64, v65, v66, v67 = sch.sample_perfect_tile(
        n=4, loop=l38, max_innermost_factor=64, decision=[2, 1, 16, 4]
    )
    l68, l69, l70, l71 = sch.split(loop=l38, factors=[v64, v65, v66, v67])
    v72, v73 = sch.sample_perfect_tile(n=2, loop=l39, max_innermost_factor=64, decision=[16, 8])
    l74, l75 = sch.split(loop=l39, factors=[v72, v73])
    # SSRSRS-style interleaving: spatial level 0, spatial level 1, reduction outer,
    # spatial level 2, reduction inner, spatial level 3.
    sch.reorder(
        l44, l52, l60, l68, l45, l53, l61, l69, l74, l46, l54, l62, l70, l75, l47, l55, l63, l71
    )
    # Attach the cache-stage block under the second tile level of co.
    sch.reverse_compute_at(block=b33, loop=l69, preserve_unit_loops=True)

    # Root-block hints; the sampled unroll value is candidates[1] == 16.
    b76 = sch.get_block(name="root")
    sch.annotate(block_or_loop=b76, ann_key="auto_parallel_extent", ann_val=64)
    sch.annotate(block_or_loop=b76, ann_key="auto_vectorize_extent", ann_val=32)
    v77 = sch.sample_categorical(
        candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1
    )
    sch.annotate(block_or_loop=b76, ann_key="auto_unroll_explicit", ann_val=v77)

    # Move "input_tile" to the 5th sampled compute location (decision=4).
    b78 = sch.get_block(name="input_tile")
    l80 = sch.sample_compute_location(block=b78, decision=4)
    sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True)

    # Compute "data_pad" at a location sampled for its single consumer.
    # decision=-2 is a negative sentinel rather than a loop index.
    # NOTE(review): confirm whether it denotes "root" or "inline" in this TVM version.
    b81 = sch.get_block(name="data_pad")
    (b82,) = sch.get_consumers(block=b81)
    l83 = sch.sample_compute_location(block=b82, decision=-2)
    sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True)

    print(sch.mod.script())


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment