Skip to content

Instantly share code, notes, and snippets.

@zxybazh

zxybazh/repro.py Secret

Created July 14, 2023 17:42
Show Gist options
  • Save zxybazh/39d1bc9722bdbe30769f9efd4075c9cf to your computer and use it in GitHub Desktop.
Save zxybazh/39d1bc9722bdbe30769f9efd4075c9cf to your computer and use it in GitHub Desktop.
Script to reproduce the cross-thread reduction issue with Dlight
from tvm import dlight as dl
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
import tvm
# IRModule holding one already-scheduled TIR function ("tir.is_scheduled": 1)
# whose global symbol is "main".
#
# Buffers (shapes/dtypes as declared in the signature):
#   W: (4096, 512) uint32   - each word is read as eight 4-bit fields
#   S: (4096, 128) float16  - one entry is shared by 32 consecutive v_i2 values
#   V: (1, 1, 4096) float16 - input vector, indexed along the reduction axis
#   C: (1, 1, 4096) float16 - output vector
#
# The reduction over k (4096 = 256 * 16) is split in two stages:
#   1. "matmul_rf_init" / "matmul_rf_update": for every threadIdx.y value
#      (k_fused_1) accumulate a partial sum over k_fused_0 into C_rf_local.
#   2. "matmul": sum the 16 partial results along vk_fused_1, which is a
#      reduce axis bound to threadIdx.y, into C.
#
# NOTE(review): C_rf_local is allocated with scope="local" while stage 2 reads
# it through a reduce axis bound to threadIdx.y; presumably this is the
# cross-thread reduction pattern the script is meant to reproduce - confirm
# against the original issue report.
@I.ir_module
class After:
    @T.prim_func
    def func(
        W: T.Buffer((4096, 512), "uint32"),
        S: T.Buffer((4096, 128), "float16"),
        V: T.Buffer((1, 1, 4096), "float16"),
        C: T.Buffer((1, 1, 4096), "float16"),
    ):
        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Partial sums, one slice per value of the inner reduction index (16 of them).
        C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
        # 32 blocks * 16 threadIdx.x * 8 serial iterations cover the 4096 outputs.
        for i2_0_i0_i1_fused_0 in T.thread_binding(32, thread="blockIdx.x"):
            for i2_0_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
                for k_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                    # Stage 1a: zero the partial sums handled by this thread.
                    for i2_1_init in range(8):
                        with T.block("matmul_rf_init"):
                            vk_fused_1 = T.axis.spatial(16, k_fused_1)
                            v_i2 = T.axis.spatial(
                                4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init
                            )
                            C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
                    # Stage 1b: accumulate over the outer reduction index (256 steps).
                    for k_fused_0, i2_1 in T.grid(256, 8):
                        with T.block("matmul_rf_update"):
                            vk_fused_1 = T.axis.spatial(16, k_fused_1)
                            v_i2 = T.axis.spatial(
                                4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1
                            )
                            vk_fused_0 = T.axis.reduce(256, k_fused_0)
                            # Row k = vk_fused_0 * 16 + vk_fused_1.  The weight is the
                            # 4-bit field number (v_i2 % 8) of W[k, v_i2 // 8]
                            # (shift right by 4 * field, mask with 15), cast to
                            # float16, offset by -7 and multiplied by S[k, v_i2 // 32].
                            C_rf_local[vk_fused_1, 0, 0, v_i2] = C_rf_local[
                                vk_fused_1, 0, 0, v_i2
                            ] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * (
                                (
                                    T.Cast(
                                        "float16",
                                        T.bitwise_and(
                                            T.shift_right(
                                                W[vk_fused_0 * 16 + vk_fused_1, v_i2 // 8],
                                                T.Cast("uint32", v_i2 % 8) * T.uint32(4),
                                            ),
                                            T.uint32(15),
                                        ),
                                    )
                                    - T.float16(7)
                                )
                                * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32]
                            )
            # Stage 2: separate thread-bound loop nest under the same blockIdx.x
            # loop (it only uses i2_0_i0_i1_fused_0 from the nest above).
            for ax1_ax2_ax3_fused_0 in T.thread_binding(16, thread="threadIdx.x"):
                for ax0_fused in T.thread_binding(16, thread="threadIdx.y"):
                    for ax1_ax2_ax3_fused_1 in range(8):
                        with T.block("matmul"):
                            # Reduce axis driven by the threadIdx.y loop variable.
                            vk_fused_1 = T.axis.reduce(16, ax0_fused)
                            v_i2 = T.axis.spatial(
                                4096,
                                i2_0_i0_i1_fused_0 * 128
                                + ax1_ax2_ax3_fused_0 * 8
                                + ax1_ax2_ax3_fused_1,
                            )
                            with T.init():
                                C[0, 0, v_i2] = T.float16(0)
                            C[0, 0, v_i2] = C[0, 0, v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i2]
# Construct the compilation target from the string "nvidia/geforce-rtx-3090-ti".
target = Target("nvidia/geforce-rtx-3090-ti")
# Build the `After` module for that target under the name "after".
# NOTE(review): this is a reproduction script, so presumably the issue shows up
# during this build call - confirm against the original report.
tvm.build(After, target=target, name="after")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment