# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
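    # cast: elementwise float16 -> float32 cast over a dynamic-length (1, n, 2560)
    # activation; the fused n * 2560 index space is tiled over 1024-thread blocks,
    # with T.where guarding the tail when n * 2560 is not a multiple of 1024.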
    @T.prim_func(private=True)
    def cast(var_A: T.handle, var_compute: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, n, 2560), "float16")
        compute = T.match_buffer(var_compute, (1, n, 2560))
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("compute"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2560)
                    v1 = T.axis.spatial(2560, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2560)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * 2560)
                    T.reads(A[0, v0, v1])
                    T.writes(compute[0, v0, v1])
                    compute[0, v0, v1] = T.Cast("float32", A[0, v0, v1])
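    # cast6 / cast7: single-token (1, 1, 2560) variants of the cast above; cast6
    # degenerates to an identity copy (fp32 -> fp32), cast7 casts fp16 -> fp32.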
    @T.prim_func(private=True)
    def cast6(A: T.Buffer((1, 1, 2560), "float32"), compute: T.Buffer((1, 1, 2560), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("compute"):
                    v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560)
                    T.reads(A[0, 0, v0])
                    T.writes(compute[0, 0, v0])
                    compute[0, 0, v0] = A[0, 0, v0]
    @T.prim_func(private=True)
    def cast7(A: T.Buffer((1, 1, 2560), "float16"), compute: T.Buffer((1, 1, 2560), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("compute"):
                    v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560)
                    T.reads(A[0, 0, v0])
                    T.writes(compute[0, 0, v0])
                    compute[0, 0, v0] = T.Cast("float32", A[0, 0, v0])
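    # divide1: elementwise division of a (1, 1, 50432) tensor by the scalar B;
    # judging by the shape, 50432 is likely the vocabulary size (logits scaling).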
    @T.prim_func(private=True)
    def divide1(A: T.Buffer((1, 1, 50432), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((1, 1, 50432), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(50, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_divide"):
                    v0 = T.axis.spatial(50432, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 50432)
                    T.reads(A[0, 0, v0], B[()])
                    T.writes(T_divide[0, 0, v0])
                    T_divide[0, 0, v0] = A[0, 0, v0] / B[()]
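    # extend_te: widens a (1, 1, n, n) fp16 input to (1, 1, n, m), filling the new
    # leading columns with float16(65504) (fp16 max) and shifting the rest by n - m;
    # this appears to be attention-mask extension for a grown KV cache.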
    @T.prim_func(private=True)
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        A = T.match_buffer(var_A, (1, 1, n, n), "float16")
        m = T.int32()
        concat_te = T.match_buffer(var_concat_te, (1, 1, n, m), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * m + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("concat_te"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (m * n) // m)
                    v1 = T.axis.spatial(m, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % m)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * m)
                    T.reads(A[0, 0, v0, v1 + (n - m)])
                    T.writes(concat_te[0, 0, v0, v1])
                    concat_te[0, 0, v0, v1] = T.if_then_else(v1 < m - n, T.float16(65504), A[0, 0, v0, v1 + (n - m)])
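    # full: fills a (1, 1, 1, n) buffer with float16(65504), fp16's largest finite
    # value; presumably an all-visible mask row for single-token decode.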
    @T.prim_func(private=True)
    def full(var_T_full: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        T_full = T.match_buffer(var_T_full, (1, 1, 1, n), "float16")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding((n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_full"):
                    v0 = T.axis.spatial(n, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < n)
                    T.reads()
                    T.writes(T_full[0, 0, 0, v0])
                    T_full[0, 0, 0, v0] = T.float16(65504)
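    # fused_NT_matmul1_divide_maximum_minimum_cast2: batched NT matmul
    # (lv30 @ lv31^T over the 80-wide head dimension) in 32x32 shared-memory tiles
    # padded to multiples of 32, fused with scaling by ~0.1118 (close to
    # 1 / sqrt(80), so likely the attention score scale), clamping at -65504,
    # elementwise min against the lv5 mask, and a cast to fp32.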
    @T.prim_func(private=True)
    def fused_NT_matmul1_divide_maximum_minimum_cast2(p_lv30: T.handle, p_lv31: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv30 = T.match_buffer(p_lv30, (1, 32, n, 80), "float16")
        m = T.int32()
        lv31 = T.match_buffer(p_lv31, (1, 32, m, 80), "float16")
        lv5 = T.match_buffer(p_lv5, (1, 1, n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 31) // 32 * 32), "float16", scope="local")
        lv30_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, 80), "float16", scope="shared")
        lv31_reindex_pad_shared = T.alloc_buffer((32, (m + 31) // 32 * 32, 80), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding((m + 31) // 32 * 32, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 31) // 32))
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial((m + 31) // 32 * 32, ax0_ax2_0_fused % ((m + 31) // 32) * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = T.float16(0)
                                for ax3_0 in range(10):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv30_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 31) // 32))
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(80, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv30[0, v0, v1, v2])
                                                        T.writes(lv30_reindex_pad_shared[v0, v1, v2])
                                                        lv30_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv30[0, v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv31_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 31) // 32))
                                                        v1 = T.axis.spatial((m + 31) // 32 * 32, ax0_ax2_0_fused % ((m + 31) // 32) * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(80, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv31[0, v0, v1, v2])
                                                        T.writes(lv31_reindex_pad_shared[v0, v1, v2])
                                                        lv31_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, lv31[0, v0, v1, v2], T.float16(0))
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 31) // 32))
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial((m + 31) // 32 * 32, ax0_ax2_0_fused % ((m + 31) // 32) * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(80, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv30_reindex_pad_shared[v0, v1, v3], lv31_reindex_pad_shared[v0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv30_reindex_pad_shared[v0, v1, v3] * lv31_reindex_pad_shared[v0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 31) // 32) + ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial((m + 31) // 32 * 32, ax0_ax2_0_fused % ((m + 31) // 32) * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv5[0, 0, v1, v2])
                                            T.writes(var_compute_intermediate[0, v0, v1, v2])
                                            if v1 < n and v2 < m:
                                                var_compute_intermediate[0, v0, v1, v2] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.float16(0.11179039301310044), T.float16(-65504)), lv5[0, 0, v1, v2]))
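    # fused_NT_matmul7_divide2_maximum1_minimum1_cast9: single-query (decode-time)
    # variant of the kernel above; a GEMV over the 80-wide reduction axis, split
    # across 64 threadIdx.x lanes through two rfactor stages, followed by the same
    # scale / max / min-with-mask / fp32-cast epilogue.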
    @T.prim_func(private=True)
    def fused_NT_matmul7_divide2_maximum1_minimum1_cast9(lv1896: T.Buffer((1, 32, 1, 80), "float16"), p_lv1897: T.handle, p_lv1871: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv1897 = T.match_buffer(p_lv1897, (1, 32, n, 80), "float16")
        lv1871 = T.match_buffer(p_lv1871, (1, 1, 1, n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n))
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local")
        lv1897_local = T.alloc_buffer((1, 32, n, 80), "float16", scope="local")
        lv1896_shared = T.alloc_buffer((1, 32, 1, 80), "float16", scope="shared")
        for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, thread="blockIdx.x"):
            for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax0, ax1, ax2 in T.grid(1, 1, 1):
                        for ax3_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax3_1 in T.thread_binding(1, thread="threadIdx.y"):
                                for ax3_2 in T.thread_binding(64, thread="threadIdx.x"):
                                    for ax3_3 in T.vectorized(1):
                                        with T.block("lv1896_shared"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
                                            v2 = T.axis.spatial(1, ax2)
                                            v3 = T.axis.spatial(80, ax3_0 * 64 + ax3_1 * 64 + ax3_2 + ax3_3)
                                            T.where((ax3_0 + ax3_1) * 64 + ax3_2 + ax3_3 < 80)
                                            T.reads(lv1896[v0, v1, v2, v3])
                                            T.writes(lv1896_shared[v0, v1, v2, v3])
                                            lv1896_shared[v0, v1, v2, v3] = lv1896[v0, v1, v2, v3]
                    for ax0_fused_ax1_fused_fused_2_init in range(1):
                        for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(64, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
                                v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n)
                                v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
                                var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0)
                    for ax2_fused_u_fused_0 in T.serial(2, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 1):
                            for ax2_1 in T.vectorized(1):
                                with T.block("lv1897_local"):
                                    v0 = T.axis.spatial(1, ax0)
                                    v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
                                    v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
                                    v3 = T.axis.spatial(80, ax2_fused_u_fused_0 * 64 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax3)
                                    T.where(ax2_fused_u_fused_0 * 64 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 < 80)
                                    T.reads(lv1897[v0, v1, v2, v3])
                                    T.writes(lv1897_local[v0, v1, v2, v3])
                                    lv1897_local[v0, v1, v2, v3] = lv1897[v0, v1, v2, v3]
                        for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1):
                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(1):
                                with T.block("NT_matmul_rf_update"):
                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(64, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
                                    v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n)
                                    v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n)
                                    vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
                                    T.where(ax2_fused_u_fused_0 * 64 + (ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + ax2_fused_u_fused_2 < 80)
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1], lv1896_shared[0, v0, 0, vax2_fused_u_fused_0 * 64 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused + vax2_fused_u_fused_2], lv1897_local[0, v0, v1, vax2_fused_u_fused_0 * 64 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused + vax2_fused_u_fused_2])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
                                    var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] + lv1896_shared[0, v0, 0, vax2_fused_u_fused_0 * 64 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused + vax2_fused_u_fused_2] * lv1897_local[0, v0, v1, vax2_fused_u_fused_0 * 64 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused + vax2_fused_u_fused_2]
            for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                    for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_ax3_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0)
                                v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                                v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                                var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0)
                            for ax1 in range(1):
                                with T.block("NT_matmul_rf_update"):
                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                                    v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1], var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                                    var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]
            for ax1_ax2_fused_1 in range(1):
                for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(64, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0)
                            v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                            v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
                            T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]
            for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
                for ax0_ax1_fused_1 in range(1):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
                        v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                        T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1871[0, 0, 0, v1])
                        T.writes(var_compute_intermediate[0, v0, 0, v1])
                        var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * T.float16(0.11179039301310044), T.float16(-65504)), lv1871[0, 0, 0, v1]))
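    # fused_fused_decode1_take / take1: embedding lookup fused with 4-bit
    # dequantization: each uint32 in lv packs eight 4-bit values (shift, mask with
    # 0xF, subtract 7), scaled per 32-element group by lv1; take1 is the
    # single-index decode-time variant.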
    @T.prim_func(private=True)
    def fused_fused_decode1_take(lv: T.Buffer((50432, 320), "uint32"), lv1: T.Buffer((50432, 80), "float16"), p_lv: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv_1 = T.match_buffer(p_lv, (n,), "int32")
        var_T_take_intermediate = T.match_buffer(p_output0, (n, 2560), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_take"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2560)
                    v1 = T.axis.spatial(2560, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2560)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * 2560)
                    T.reads(lv[lv_1[v0], v1 // 8], lv_1[v0], lv1[lv_1[v0], v1 // 32])
                    T.writes(var_T_take_intermediate[v0, v1])
                    var_T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv[lv_1[v0], v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1[lv_1[v0], v1 // 32]
    @T.prim_func(private=True)
    def fused_fused_decode1_take1(lv711: T.Buffer((50432, 320), "uint32"), lv712: T.Buffer((50432, 80), "float16"), lv1868: T.Buffer((1,), "int32"), var_T_take_intermediate: T.Buffer((1, 2560), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_take"):
                    v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560)
                    T.reads(lv711[lv1868[0], v0 // 8], lv1868[0], lv712[lv1868[0], v0 // 32])
                    T.writes(var_T_take_intermediate[0, v0])
                    var_T_take_intermediate[0, v0] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv711[lv1868[0], v0 // 8], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv712[lv1868[0], v0 // 32]
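    # fused_fused_decode2_fused_NT_matmul6_add5: single-token GEMV against a
    # 4-bit-quantized (7680, 2560) weight, decoded on the fly from lv715 (packed
    # nibbles) and lv716 (group scales), plus bias lv524; the 2560-wide reduction
    # is split over 8 threadIdx.x lanes with two rfactor accumulation stages.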
    @T.prim_func(private=True)
    def fused_fused_decode2_fused_NT_matmul6_add5(lv715: T.Buffer((7680, 320), "uint32"), lv716: T.Buffer((7680, 80), "float16"), lv1875: T.Buffer((1, 1, 2560), "float16"), lv524: T.Buffer((7680,), "float16"), p_output0_intermediate: T.Buffer((1, 1, 7680), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 7680), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 7680), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 7680), "float16", scope="local")
        lv715_local = T.alloc_buffer((7680, 320), "uint32", scope="local")
        lv1875_shared = T.alloc_buffer((1, 1, 2560), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(240, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(5, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(2):
                                        with T.block("lv1875_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2560, ax2_0 * 512 + ax2_1 * 16 + ax2_2 * 2 + ax2_3)
                                            T.reads(lv1875[v0, v1, v2])
                                            T.writes(lv1875_shared[v0, v1, v2])
                                            lv1875_shared[v0, v1, v2] = lv1875[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(40, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv715_local"):
                                    v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(320, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv715[v0, v1])
                                    T.writes(lv715_local[v0, v1])
                                    lv715_local[v0, v1] = lv715[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1875_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv715_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv716[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1875_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv715_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv716[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(7680, u_fused_ax0_fused_fused_0 * 32 + ax0_fused_0 + ax0_fused_1)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0], lv524[v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + lv524[v0]
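    # fused_fused_decode2_fused_NT_matmul_add: prefill GEMM counterpart of the
    # projection above; rows are padded to a multiple of 32, the quantized weight
    # tile is dequantized into shared memory before the inner-product update, and
    # the bias lv6_1 is added in the epilogue.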
    @T.prim_func(private=True)
    def fused_fused_decode2_fused_NT_matmul_add(lv5: T.Buffer((7680, 320), "uint32"), lv6: T.Buffer((7680, 80), "float16"), p_lv9: T.handle, lv6_1: T.Buffer((7680,), "float16"), p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv9 = T.match_buffer(p_lv9, (1, n, 2560), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 7680), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 7680), "float16", scope="local")
        lv9_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 7680, 2560), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(240, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial(7680, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(320):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv9_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv9[v0, v1, v2])
                                                        T.writes(lv9_reindex_pad_shared[v0, v1, v2])
                                                        lv9_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv9[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(7680, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv5[v1, v2 // 8], lv6[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv5[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv6[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial(7680, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(2560, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv9_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv9_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(7680, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv6_1[v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv6_1[v2]
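    # fused_fused_decode3_fused_NT_matmul2_add1_add2: (2560, 2560) projection GEMM
    # with on-the-fly 4-bit dequantization, fused bias add (lv9) and residual add
    # (lv2) in the write-back.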
    @T.prim_func(private=True)
    def fused_fused_decode3_fused_NT_matmul2_add1_add2(lv13: T.Buffer((2560, 320), "uint32"), lv14: T.Buffer((2560, 80), "float16"), p_lv43: T.handle, lv9: T.Buffer((2560,), "float16"), p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv43 = T.match_buffer(p_lv43, (1, n, 2560), "float16")
        lv2 = T.match_buffer(p_lv2, (1, n, 2560), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 2560), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), "float16", scope="local")
        lv43_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 2560, 2560), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(80, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
                                for ax3_0 in range(320):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv43_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv43[v0, v1, v2])
                                                        T.writes(lv43_reindex_pad_shared[v0, v1, v2])
                                                        lv43_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv43[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv13[v1, v2 // 8], lv14[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(2560, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv43_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv43_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv9[v2], lv2[0, v1, v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv9[v2] + lv2[0, v1, v2]
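    # fused_fused_decode3_fused_NT_matmul8_add6_add7: single-token GEMV counterpart
    # of the projection above, with bias (lv527) and residual (lv1870) adds fused
    # into the epilogue.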
    @T.prim_func(private=True)
    def fused_fused_decode3_fused_NT_matmul8_add6_add7(lv725: T.Buffer((2560, 320), "uint32"), lv726: T.Buffer((2560, 80), "float16"), lv724: T.Buffer((1, 1, 2560), "float16"), lv527: T.Buffer((2560,), "float16"), lv1870: T.Buffer((1, 1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 1, 2560), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 2560), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 2560), "float16", scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 2560), "float16", scope="local")
        lv725_local = T.alloc_buffer((2560, 320), "uint32", scope="local")
        lv724_shared = T.alloc_buffer((1, 1, 2560), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(80, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(5, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(2):
                                        with T.block("lv724_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2560, ax2_0 * 512 + ax2_1 * 16 + ax2_2 * 2 + ax2_3)
                                            T.reads(lv724[v0, v1, v2])
                                            T.writes(lv724_shared[v0, v1, v2])
                                            lv724_shared[v0, v1, v2] = lv724[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(40, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv725_local"):
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(320, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv725[v0, v1])
                                    T.writes(lv725_local[v0, v1])
                                    lv725_local[v0, v1] = lv725[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv724_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv725_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv726[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv724_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv725_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv726[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("T_add_1"):
                        v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax0_fused_0 + ax0_fused_1)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0], lv527[v0], lv1870[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + lv527[v0] + lv1870[0, 0, v0]
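    # fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4: MLP up-projection to
    # 10240 with fp32 accumulation, fused bias add and exact GELU
    # (0.5 * x * (1 + erf(x / sqrt(2))), here with the 0.5 factored differently),
    # then a cast back to fp16.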
    @T.prim_func(private=True)
    def fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4(lv18: T.Buffer((10240, 320), "uint32"), lv19: T.Buffer((10240, 80), "float16"), p_lv51: T.handle, lv14: T.Buffer((10240,), "float32"), p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv51 = T.match_buffer(p_lv51, (1, n, 2560), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 10240), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 10240), scope="local")
        lv51_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 10240, 2560), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(320, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial(10240, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float32(0)
                                for ax3_0 in range(320):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv51_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv51[v0, v1, v2])
                                                        T.writes(lv51_reindex_pad_shared[v0, v1, v2])
                                                        lv51_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv51[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(10240, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(2560, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv18[v1, v2 // 8], lv19[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv18[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv19[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial(10240, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(2560, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv51_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + T.Cast("float32", lv51_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", p_output0_intermediate_reindex_shared[0, v2, v3])
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(10240, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv14[v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = T.Cast("float16", (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv14[v2]) * (T.float32(0.5) + T.erf((var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv14[v2]) * T.float32(0.70710678118654757)) * T.float32(0.5)))
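    # fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11: single-token GEMV
    # counterpart of the GELU up-projection above, with fp32 rfactor accumulators
    # over the dequantized (10240, 2560) weight.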
    @T.prim_func(private=True)
    def fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11(lv730: T.Buffer((10240, 320), "uint32"), lv731: T.Buffer((10240, 80), "float16"), lv1917: T.Buffer((1, 1, 2560), "float16"), lv532: T.Buffer((10240,), "float32"), p_output0_intermediate: T.Buffer((1, 1, 10240), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 10240), scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 10240), scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 10240), scope="local")
        lv730_local = T.alloc_buffer((10240, 320), "uint32", scope="local")
        lv1917_shared = T.alloc_buffer((1, 1, 2560), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(320, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(5, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(2):
                                        with T.block("lv1917_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2560, ax2_0 * 512 + ax2_1 * 16 + ax2_2 * 2 + ax2_3)
                                            T.reads(lv1917[v0, v1, v2])
                                            T.writes(lv1917_shared[v0, v1, v2])
                                            lv1917_shared[v0, v1, v2] = lv1917[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float32(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(40, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv730_local"):
                                    v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(320, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv730[v0, v1])
                                    T.writes(lv730_local[v0, v1])
                                    lv730_local[v0, v1] = lv730[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1917_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv730_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv731[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + T.Cast("float32", lv1917_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(lv730_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv731[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float32(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float32(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("compute_1"):
                        v0 = T.axis.spatial(10240, u_fused_ax0_fused_fused_0 * 32 + ax0_fused_0 + ax0_fused_1)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0], lv532[v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = T.Cast("float16", (var_NT_matmul_intermediate_local[0, 0, v0] + lv532[v0]) * (T.float32(0.5) + T.erf((var_NT_matmul_intermediate_local[0, 0, v0] + lv532[v0]) * T.float32(0.70710678118654757)) * T.float32(0.5)))
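    # Editor's note (illustrative sketch, not part of the generated module): the
    # "NT_matmul_rf_update" block above dequantizes 4-bit weights on the fly.
    # Each uint32 word of lv730 packs eight 4-bit values, and lv731 holds one
    # float16 scale per group of 32 weights. With hypothetical NumPy arrays
    # `packed` (uint32) and `scale` (float16), element (i, j) decodes roughly as:
    #     q = (packed[i, j // 8] >> np.uint32((j % 8) * 4)) & np.uint32(0xF)
    #     w = (np.float16(q) - np.float16(7)) * scale[i, j // 32]
    # i.e. a symmetric 4-bit scheme with zero point 7.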
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7(lv734: T.Buffer((2560, 1280), "uint32"), lv735: T.Buffer((2560, 320), "float16"), lv1923: T.Buffer((1, 1, 10240), "float16"), lv535: T.Buffer((2560,), "float32"), lv728: T.Buffer((1, 1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 1, 2560), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 2560), scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 2560), scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 2560), scope="local")
        lv734_local = T.alloc_buffer((2560, 1280), "uint32", scope="local")
        lv1923_shared = T.alloc_buffer((1, 1, 10240), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(80, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(10, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv1923_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(10240, ax2_0 * 1024 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv1923[v0, v1, v2])
                                            T.writes(lv1923_shared[v0, v1, v2])
                                            lv1923_shared[v0, v1, v2] = lv1923[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float32(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(160, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv734_local"):
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(1280, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv734[v0, v1])
                                    T.writes(lv734_local[v0, v1])
                                    lv734_local[v0, v1] = lv734[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1923_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv734_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv735[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + T.Cast("float32", lv1923_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(lv734_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv735[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float32(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float32(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("T_add_1"):
                        v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax0_fused_0 + ax0_fused_1)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0], lv535[v0], lv728[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = T.Cast("float16", var_NT_matmul_intermediate_local[0, 0, v0] + lv535[v0]) + lv728[0, 0, v0]
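    # Editor's note: this kernel (and its siblings) reduces over k in three
    # stages, which matches TVM's rfactor + cross-thread-reduction lowering: a
    # 16-way rfactor accumulates partial dot products per (threadIdx.x, vector
    # lane), the two vector lanes are then folded into 8 partials per thread,
    # and the final "NT_matmul" block reduces those 8 partials across
    # threadIdx.x before the bias/residual epilogue runs.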
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7_cast7(lv1478: T.Buffer((2560, 1280), "uint32"), lv1479: T.Buffer((2560, 320), "float16"), lv3721: T.Buffer((1, 1, 10240), "float16"), lv1031: T.Buffer((2560,), "float32"), lv1472: T.Buffer((1, 1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 1, 2560), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 2560), scope="local")
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 2560), scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 2560), scope="local")
        lv1478_local = T.alloc_buffer((2560, 1280), "uint32", scope="local")
        lv3721_shared = T.alloc_buffer((1, 1, 10240), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(80, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(10, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(4):
                                        with T.block("lv3721_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(10240, ax2_0 * 1024 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                            T.reads(lv3721[v0, v1, v2])
                                            T.writes(lv3721_shared[v0, v1, v2])
                                            lv3721_shared[v0, v1, v2] = lv3721[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float32(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(160, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv1478_local"):
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(1280, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv1478[v0, v1])
                                    T.writes(lv1478_local[v0, v1])
                                    lv1478_local[v0, v1] = lv1478[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv3721_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv1478_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv1479[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + T.Cast("float32", lv3721_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(lv1478_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1479[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float32(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate_local[0, 0, v0] = T.float32(0)
                            var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
            for ax0_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0_fused_1 in range(1):
                    with T.block("compute_2"):
                        v0 = T.axis.spatial(2560, u_fused_ax0_fused_fused_0 * 32 + ax0_fused_0 + ax0_fused_1)
                        T.reads(var_NT_matmul_intermediate_local[0, 0, v0], lv1031[v0], lv1472[0, 0, v0])
                        T.writes(p_output0_intermediate[0, 0, v0])
                        p_output0_intermediate[0, 0, v0] = T.Cast("float32", T.Cast("float16", var_NT_matmul_intermediate_local[0, 0, v0] + lv1031[v0]) + lv1472[0, 0, v0])
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2(lv22: T.Buffer((2560, 1280), "uint32"), lv23: T.Buffer((2560, 320), "float16"), p_lv57: T.handle, lv17: T.Buffer((2560,), "float32"), p_lv16: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv57 = T.match_buffer(p_lv57, (1, n, 10240), "float16")
        lv16 = T.match_buffer(p_lv16, (1, n, 2560), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 2560), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), scope="local")
        lv57_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 10240), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 2560, 10240), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(80, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float32(0)
                                for ax3_0 in range(1280):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv57_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(10240, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv57[v0, v1, v2])
                                                        T.writes(lv57_reindex_pad_shared[v0, v1, v2])
                                                        lv57_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv57[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(10240, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv22[v1, v2 // 8], lv23[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv22[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv23[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(10240, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv57_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + T.Cast("float32", lv57_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", p_output0_intermediate_reindex_shared[0, v2, v3])
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv17[v2], lv16[0, v1, v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = T.Cast("float16", var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv17[v2]) + lv16[0, v1, v2]
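    # Editor's note: for dynamic sequence length n, the row dimension is padded
    # to (n + 31) // 32 * 32 so each blockIdx.x tile covers a full 32 rows.
    # Loads fill the padding with T.if_then_else(v1 < n, ..., T.float16(0)),
    # and the epilogue's `if v1 < n` guard keeps padded rows from being
    # written back.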
    @T.prim_func(private=True)
    def fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2_cast(lv704: T.Buffer((2560, 1280), "uint32"), lv705: T.Buffer((2560, 320), "float16"), p_lv1855: T.handle, lv513: T.Buffer((2560,), "float32"), p_lv698: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv1855 = T.match_buffer(p_lv1855, (1, n, 10240), "float16")
        lv698 = T.match_buffer(p_lv698, (1, n, 2560), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 2560))
        # with T.block("root"):
        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 2560), scope="local")
        lv1855_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 10240), "float16", scope="shared")
        p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 2560, 10240), "float16", scope="shared")
        for ax0_ax2_0_fused in T.thread_binding(80, thread="blockIdx.y"):
            for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
                        for ax2_2 in T.thread_binding(8, thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_0_init in T.grid(4, 4):
                                    for ax1_3_1_init in T.vectorized(1):
                                        with T.block("NT_matmul_init"):
                                            v0 = T.axis.spatial(1, 0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init)
                                            T.reads()
                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                            var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float32(0)
                                for ax3_0 in range(1280):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("lv1855_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(10240, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv1855[v0, v1, v2])
                                                        T.writes(lv1855_reindex_pad_shared[v0, v1, v2])
                                                        lv1855_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv1855[v0, v1, v2], T.float16(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(4):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(1):
                                                    with T.block("p_output0_intermediate_reindex_shared"):
                                                        v0 = T.axis.spatial(1, 0)
                                                        v1 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8)
                                                        v2 = T.axis.spatial(10240, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8)
                                                        T.reads(lv704[v1, v2 // 8], lv705[v1, v2 // 32])
                                                        T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
                                                        p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv704[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv705[v1, v2 // 32]
                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4):
                                        for ax1_3_1 in T.vectorized(1):
                                            with T.block("NT_matmul_update"):
                                                v0 = T.axis.spatial(1, 0)
                                                v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1)
                                                v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
                                                v3 = T.axis.reduce(10240, ax3_0 * 8 + ax3_1)
                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv1855_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
                                                var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + T.Cast("float32", lv1855_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", p_output0_intermediate_reindex_shared[0, v2, v3])
                                for ax0, ax1, ax2_0 in T.grid(1, 4, 4):
                                    for ax2_1_1 in T.vectorized(1):
                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                            v2 = T.axis.spatial(2560, ax0_ax2_0_fused * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1)
                                            T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv513[v2], lv698[0, v1, v2])
                                            T.writes(p_output0_intermediate[0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[0, v1, v2] = T.Cast("float32", T.Cast("float16", var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv513[v2]) + lv698[0, v1, v2])
    @T.prim_func(private=True)
    def fused_fused_decode6_NT_matmul5(lv1483: T.Buffer((50432, 320), "uint32"), lv1484: T.Buffer((50432, 80), "float32"), lv1482: T.Buffer((1, 1, 2560), "float32"), var_NT_matmul_intermediate: T.Buffer((1, 1, 50432), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 50432), scope="local")
        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 50432), scope="local")
        lv1483_local = T.alloc_buffer((50432, 320), "uint32", scope="local")
        lv1482_shared = T.alloc_buffer((1, 1, 2560), scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(1576, thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"):
                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(1, 1):
                        for ax2_0 in T.serial(5, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(32, thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(2):
                                        with T.block("lv1482_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(2560, ax2_0 * 512 + ax2_1 * 16 + ax2_2 * 2 + ax2_3)
                                            T.reads(lv1482[v0, v1, v2])
                                            T.writes(lv1482_shared[v0, v1, v2])
                                            lv1482_shared[v0, v1, v2] = lv1482[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(1):
                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
                                v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float32(0)
                    for ax1_0_fused_ax1_1_fused_0 in T.serial(40, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_0, ax1 in T.grid(1, 1):
                            for ax0_1 in T.vectorized(1):
                                with T.block("lv1483_local"):
                                    v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
                                    v1 = T.axis.spatial(320, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                    T.reads(lv1483[v0, v1])
                                    T.writes(lv1483_local[v0, v1])
                                    lv1483_local[v0, v1] = lv1483[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
                                    v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1482_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv1483_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv1484[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1482_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * (T.Cast("float32", T.Cast("float16", T.bitwise_and(T.shift_right(lv1483_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1484[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                    for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_1_1 in T.vectorized(1):
                            with T.block("NT_matmul_rf_init"):
                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
                                v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                T.reads()
                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float32(0)
                            for ax1 in range(2):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
            for ax1_fused_1 in range(1):
                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
                            v0 = T.axis.spatial(50432, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                            T.writes(var_NT_matmul_intermediate[0, 0, v0])
                            with T.init():
                                var_NT_matmul_intermediate[0, 0, v0] = T.float32(0)
                            var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
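    # Editor's note: unlike the kernels above, this GEMV keeps everything in
    # float32 and projects 2560 features onto 50432 outputs, which looks like
    # the final vocabulary/logits projection (an inference from the shapes,
    # not stated in the module itself).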
    @T.prim_func(private=True)
    def fused_layer_norm1_cast8(lv1872: T.Buffer((1, 1, 2560), "float32"), lv520: T.Buffer((2560,), "float32"), lv521: T.Buffer((2560,), "float32"), var_compute_intermediate: T.Buffer((1, 1, 2560), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        A_red_temp_v0_shared = T.alloc_buffer((1, 1), scope="shared")
        A_red_temp_v1_shared = T.alloc_buffer((1, 1), scope="shared")
        for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1_fused_0 in T.grid(1, 40):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("A_red_temp"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.reduce(2560, ax1_fused_0 * 64 + ax1_fused_1)
                        T.reads(lv1872[0, 0, v1])
                        T.writes(A_red_temp_v0_shared[0, 0], A_red_temp_v1_shared[0, 0])
                        with T.init():
                            A_red_temp_v0_shared[0, 0] = T.float32(0)
                            A_red_temp_v1_shared[0, 0] = T.float32(0)
                        v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[0, 0] + lv1872[0, 0, v1]
                        v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[0, 0] + lv1872[0, 0, v1] * lv1872[0, 0, v1]
                        A_red_temp_v0_shared[0, 0] = v_A_red_temp_v0
                        A_red_temp_v1_shared[0, 0] = v_A_red_temp_v1
            for ax1_0 in range(40):
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(1, 0)
                        v1 = T.axis.spatial(2560, ax1_0 * 64 + ax1_1)
                        T.reads(lv1872[0, 0, v1], A_red_temp_v0_shared[0, 0], A_red_temp_v1_shared[0, 0], lv520[v1], lv521[v1])
                        T.writes(var_compute_intermediate[0, 0, v1])
                        var_compute_intermediate[0, 0, v1] = T.Cast("float16", (lv1872[0, 0, v1] - A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[0, 0] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * lv520[v1] + lv521[v1])
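    # Editor's note: 0.00039062500000000002 is 1/2560, so the reduction above
    # accumulates s1 = sum(x) and s2 = sum(x*x) over the 2560 channels, and the
    # normalization uses mean = s1/2560 and var = s2/2560 - mean**2 with
    # eps = 1e-05. A NumPy reference for the fused op (hypothetical names):
    #     y = ((x - x.mean(-1, keepdims=True))
    #          / np.sqrt(x.var(-1, keepdims=True) + 1e-5) * gamma + beta).astype(np.float16)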
    @T.prim_func(private=True)
    def fused_layer_norm_cast1(p_lv6: T.handle, lv2: T.Buffer((2560,), "float32"), lv3: T.Buffer((2560,), "float32"), p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv6 = T.match_buffer(p_lv6, (1, n, 2560))
        var_compute_intermediate = T.match_buffer(p_output0, (1, n, 2560), "float16")
        # with T.block("root"):
        A_red_temp_v0_shared = T.alloc_buffer((1, n), scope="shared")
        A_red_temp_v1_shared = T.alloc_buffer((1, n), scope="shared")
        for ax0_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1_fused_0 in T.grid(1, 40):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("A_red_temp"):
                        v0 = T.axis.spatial(n, ax0_fused + ax0)
                        v1 = T.axis.reduce(2560, ax1_fused_0 * 64 + ax1_fused_1)
                        T.reads(lv6[0, v0, v1])
                        T.writes(A_red_temp_v0_shared[0, v0], A_red_temp_v1_shared[0, v0])
                        with T.init():
                            A_red_temp_v0_shared[0, v0] = T.float32(0)
                            A_red_temp_v1_shared[0, v0] = T.float32(0)
                        v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[0, v0] + lv6[0, v0, v1]
                        v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[0, v0] + lv6[0, v0, v1] * lv6[0, v0, v1]
                        A_red_temp_v0_shared[0, v0] = v_A_red_temp_v0
                        A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1
            for ax1_0 in range(40):
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(n, ax0_fused)
                        v1 = T.axis.spatial(2560, ax1_0 * 64 + ax1_1)
                        T.reads(lv6[0, v0, v1], A_red_temp_v0_shared[0, v0], A_red_temp_v1_shared[0, v0], lv2[v1], lv3[v1])
                        T.writes(var_compute_intermediate[0, v0, v1])
                        var_compute_intermediate[0, v0, v1] = T.Cast("float16", (lv6[0, v0, v1] - A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[0, v0] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * lv2[v1] + lv3[v1])
    @T.prim_func(private=True)
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int32):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (1, 1, n, n), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((n * n + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_broadcast_to"):
                    v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // n)
                    v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * n)
                    T.reads()
                    T.writes(var_T_broadcast_to_intermediate[0, 0, v0, v1])
                    var_T_broadcast_to_intermediate[0, 0, v0, v1] = T.Select(v0 < v1, T.float16(-65504), T.float16(65504))
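    # Editor's note: +/-65504 is the largest finite float16, so the block above
    # materializes a causal mask: entries strictly above the diagonal (v0 < v1)
    # get -65504 and the rest +65504, presumably combined with the attention
    # scores via an element-wise min elsewhere in the pipeline.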
    @T.prim_func(private=True)
    def fused_reshape7_split1(lv1878: T.Buffer((1, 1, 7680), "float16"), var_T_split_sections_intermediate: T.Buffer((1, 1, 32, 80), "float16"), var_T_split_sections_intermediate_1: T.Buffer((1, 1, 32, 80), "float16"), var_T_split_sections_intermediate_2: T.Buffer((1, 1, 32, 80), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_split_sections"):
                    v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80)
                    v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560)
                    T.reads(lv1878[0, 0, v0 * 240 + v1])
                    T.writes(var_T_split_sections_intermediate[0, 0, v0, v1])
                    var_T_split_sections_intermediate[0, 0, v0, v1] = lv1878[0, 0, v0 * 240 + v1]
        for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_split_sections_1"):
                    v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80)
                    v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560)
                    T.reads(lv1878[0, 0, v0 * 240 + (v1 + 80)])
                    T.writes(var_T_split_sections_intermediate_1[0, 0, v0, v1])
                    var_T_split_sections_intermediate_1[0, 0, v0, v1] = lv1878[0, 0, v0 * 240 + (v1 + 80)]
        for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_split_sections_2"):
                    v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80)
                    v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80)
                    T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560)
                    T.reads(lv1878[0, 0, v0 * 240 + (v1 + 160)])
                    T.writes(var_T_split_sections_intermediate_2[0, 0, v0, v1])
                    var_T_split_sections_intermediate_2[0, 0, v0, v1] = lv1878[0, 0, v0 * 240 + (v1 + 160)]
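    # Editor's note: the 7680-wide input is a fused QKV projection laid out as
    # 32 heads x 240, where each head's 240 values are its 80 query, 80 key and
    # 80 value features; the three blocks above slice offsets 0, 80 and 160
    # within each head into separate (1, 1, 32, 80) tensors.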
    @T.prim_func(private=True)
    def fused_slice1_cast6(lv3729: T.Buffer((1, 1, 2560), "float32"), var_compute_intermediate: T.Buffer((1, 1, 2560), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("compute"):
                    v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1)
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560)
                    T.reads(lv3729[0, 0, v0])
                    T.writes(var_compute_intermediate[0, 0, v0])
                    var_compute_intermediate[0, 0, v0] = lv3729[0, 0, v0]
    @T.prim_func(private=True)
    def fused_softmax2_cast10(p_lv1904: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv1904 = T.match_buffer(p_lv1904, (1, 32, 1, n))
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n), "float16")
        # with T.block("root"):
        T_softmax_maxelem_shared = T.alloc_buffer((1, 32, 1), scope="shared")
        T_softmax_expsum_shared = T.alloc_buffer((1, 32, 1), scope="shared")
        for ax0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        v0 = T.axis.spatial(32, ax0_fused + ax0)
                        v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
                        T.reads(lv1904[0, v0, 0, v1])
                        T.writes(T_softmax_maxelem_shared[0, v0, 0])
                        with T.init():
                            T_softmax_maxelem_shared[0, v0, 0] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[0, v0, 0] = T.max(T_softmax_maxelem_shared[0, v0, 0], lv1904[0, v0, 0, v1])
            for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
                for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_expsum"):
                        v0 = T.axis.spatial(32, ax0_fused + ax0)
                        v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
                        T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
                        T.reads(lv1904[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0])
                        T.writes(T_softmax_expsum_shared[0, v0, 0])
                        with T.init():
                            T_softmax_expsum_shared[0, v0, 0] = T.float32(0)
                        T_softmax_expsum_shared[0, v0, 0] = T_softmax_expsum_shared[0, v0, 0] + T.exp(lv1904[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0])
            for ax1_0 in range((n + 63) // 64):
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_fused)
                        v1 = T.axis.spatial(n, ax1_0 * 64 + ax1_1)
                        T.where(ax1_0 * 64 + ax1_1 < n)
                        T.reads(lv1904[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0], T_softmax_expsum_shared[0, v0, 0])
                        T.writes(var_compute_intermediate[0, v0, 0, v1])
                        var_compute_intermediate[0, v0, 0, v1] = T.Cast("float16", T.exp(lv1904[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0]) / T_softmax_expsum_shared[0, v0, 0])
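    # Editor's note: this is the numerically stable softmax: the row maximum is
    # subtracted before exponentiation, so softmax(x)_i = exp(x_i - max(x)) /
    # sum_j exp(x_j - max(x)), computed in float32 and only cast to float16 on
    # the final write.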
    @T.prim_func(private=True)
    def fused_softmax_cast3(p_lv38: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        lv38 = T.match_buffer(p_lv38, (1, 32, n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem_shared = T.alloc_buffer((1, 32, n), scope="shared")
        T_softmax_expsum_shared = T.alloc_buffer((1, 32, n), scope="shared")
        for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
            for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                        v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
                        T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
                        T.reads(lv38[0, v0, v1, v2])
                        T.writes(T_softmax_maxelem_shared[0, v0, v1])
                        with T.init():
                            T_softmax_maxelem_shared[0, v0, v1] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[0, v0, v1] = T.max(T_softmax_maxelem_shared[0, v0, v1], lv38[0, v0, v1, v2])
            for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
                for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_expsum"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                        v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
                        T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
                        T.reads(lv38[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1])
                        T.writes(T_softmax_expsum_shared[0, v0, v1])
                        with T.init():
                            T_softmax_expsum_shared[0, v0, v1] = T.float32(0)
                        T_softmax_expsum_shared[0, v0, v1] = T_softmax_expsum_shared[0, v0, v1] + T.exp(lv38[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1])
            for ax2_0 in range((m + 63) // 64):
                for ax2_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("compute"):
                        v0 = T.axis.spatial(32, ax0_ax1_fused // n)
                        v1 = T.axis.spatial(n, ax0_ax1_fused % n)
                        v2 = T.axis.spatial(m, ax2_0 * 64 + ax2_1)
                        T.where(ax2_0 * 64 + ax2_1 < m)
                        T.reads(lv38[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1], T_softmax_expsum_shared[0, v0, v1])
                        T.writes(var_compute_intermediate[0, v0, v1, v2])
                        var_compute_intermediate[0, v0, v1, v2] = T.Cast("float16", T.exp(lv38[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1]) / T_softmax_expsum_shared[0, v0, v1])
@T.prim_func(private=True) | |
def fused_squeeze(p_lv14_2: T.handle, p_output0: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
lv14_2 = T.match_buffer(p_lv14_2, (1, n, 32, 80), "float16") | |
var_T_squeeze_intermediate = T.match_buffer(p_output0, (n, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_squeeze"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(lv14_2[0, v0, v1, v2]) | |
T.writes(var_T_squeeze_intermediate[v0, v1, v2]) | |
var_T_squeeze_intermediate[v0, v1, v2] = lv14_2[0, v0, v1, v2] | |
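# fused_squeeze above is a plain layout copy dropping the unit batch axis,
# (1, n, 32, 80) -> (n, 32, 80), matching the (2048, 32, 80) cache layout
# created in create_kv_cache below. It is structurally identical to
# `squeeze` further down; the duplicate apparently survives from operator
# fusion emitting its own private copy of the kernel.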
@T.prim_func(private=True) | |
def fused_squeeze1(lv1880_2: T.Buffer((1, 1, 32, 80), "float16"), var_T_squeeze_intermediate: T.Buffer((1, 32, 80), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_squeeze"): | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560) | |
T.reads(lv1880_2[0, 0, v0, v1]) | |
T.writes(var_T_squeeze_intermediate[0, v0, v1]) | |
var_T_squeeze_intermediate[0, v0, v1] = lv1880_2[0, 0, v0, v1] | |
@T.prim_func(private=True) | |
def fused_transpose8_reshape8(lv1907: T.Buffer((1, 32, 1, 80), "float16"), var_T_reshape_intermediate: T.Buffer((1, 1, 2560), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1) | |
T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560) | |
T.reads(lv1907[0, v0 // 80, 0, v0 % 80]) | |
T.writes(var_T_reshape_intermediate[0, 0, v0]) | |
var_T_reshape_intermediate[0, 0, v0] = lv1907[0, v0 // 80, 0, v0 % 80] | |
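# fused_transpose8_reshape8 above collapses a transpose plus reshape,
# (1, 32, 1, 80) -> (1, 1, 32, 80) -> (1, 1, 2560), into a single gather:
#   out[0, 0, v0] = in[0, v0 // 80, 0, v0 % 80]
# so the attention output for one decoded token is re-packed into the
# hidden-size-2560 row without materializing the intermediate.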
@T.prim_func(private=True) | |
def layer_norm(var_A: T.handle, B: T.Buffer((2560,), "float32"), C: T.Buffer((2560,), "float32"), var_T_layer_norm: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 2560)) | |
T_layer_norm = T.match_buffer(var_T_layer_norm, (1, n, 2560)) | |
# with T.block("root"): | |
A_red_temp_v0_shared = T.alloc_buffer((1, n), scope="shared") | |
A_red_temp_v1_shared = T.alloc_buffer((1, n), scope="shared") | |
for ax0_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): | |
for ax0, ax1_fused_0 in T.grid(1, 40): | |
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("A_red_temp"): | |
v0 = T.axis.spatial(n, ax0_fused + ax0) | |
v1 = T.axis.reduce(2560, ax1_fused_0 * 64 + ax1_fused_1) | |
T.reads(A[0, v0, v1]) | |
T.writes(A_red_temp_v0_shared[0, v0], A_red_temp_v1_shared[0, v0]) | |
with T.init(): | |
A_red_temp_v0_shared[0, v0] = T.float32(0) | |
A_red_temp_v1_shared[0, v0] = T.float32(0) | |
v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[0, v0] + A[0, v0, v1] | |
v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[0, v0] + A[0, v0, v1] * A[0, v0, v1] | |
A_red_temp_v0_shared[0, v0] = v_A_red_temp_v0 | |
A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1 | |
for ax1_0 in range(40): | |
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("T_layer_norm"): | |
v0 = T.axis.spatial(n, ax0_fused) | |
v1 = T.axis.spatial(2560, ax1_0 * 64 + ax1_1) | |
T.reads(A[0, v0, v1], A_red_temp_v0_shared[0, v0], A_red_temp_v1_shared[0, v0], B[v1], C[v1]) | |
T.writes(T_layer_norm[0, v0, v1]) | |
T_layer_norm[0, v0, v1] = (A[0, v0, v1] - A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[0, v0] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[0, v0] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v1] + C[v1] | |
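# layer_norm above normalizes each of the n rows over the 2560 hidden
# channels. A single shared-memory pass accumulates both the sum and the
# sum of squares, with the division by 2560 folded into the constants:
#   0.000390625 == 1 / 2560,   eps == 1e-5
# Reference semantics (a minimal NumPy-style sketch, not part of the module):
#   mean = sum(A) / 2560
#   var  = sum(A * A) / 2560 - mean * mean
#   out  = (A - mean) * rsqrt(var + eps) * B + C
# layer_norm1 below is the same kernel specialized to a single row (n == 1).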
@T.prim_func(private=True) | |
def layer_norm1(A: T.Buffer((1, 1, 2560), "float32"), B: T.Buffer((2560,), "float32"), C: T.Buffer((2560,), "float32"), T_layer_norm: T.Buffer((1, 1, 2560), "float32")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
A_red_temp_v0_shared = T.alloc_buffer((1, 1), scope="shared") | |
A_red_temp_v1_shared = T.alloc_buffer((1, 1), scope="shared") | |
for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): | |
for ax0, ax1_fused_0 in T.grid(1, 40): | |
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("A_red_temp"): | |
v0 = T.axis.spatial(1, ax0) | |
v1 = T.axis.reduce(2560, ax1_fused_0 * 64 + ax1_fused_1) | |
T.reads(A[0, 0, v1]) | |
T.writes(A_red_temp_v0_shared[0, 0], A_red_temp_v1_shared[0, 0]) | |
with T.init(): | |
A_red_temp_v0_shared[0, 0] = T.float32(0) | |
A_red_temp_v1_shared[0, 0] = T.float32(0) | |
v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[0, 0] + A[0, 0, v1] | |
v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[0, 0] + A[0, 0, v1] * A[0, 0, v1] | |
A_red_temp_v0_shared[0, 0] = v_A_red_temp_v0 | |
A_red_temp_v1_shared[0, 0] = v_A_red_temp_v1 | |
for ax1_0 in range(40): | |
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("T_layer_norm"): | |
v0 = T.axis.spatial(1, 0) | |
v1 = T.axis.spatial(2560, ax1_0 * 64 + ax1_1) | |
T.reads(A[0, 0, v1], A_red_temp_v0_shared[0, 0], A_red_temp_v1_shared[0, 0], B[v1], C[v1]) | |
T.writes(T_layer_norm[0, 0, v1]) | |
T_layer_norm[0, 0, v1] = (A[0, 0, v1] - A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[0, 0] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[0, 0] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v1] + C[v1] | |
@T.prim_func(private=True) | |
def matmul8(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n, m = T.int32(), T.int32() | |
A = T.match_buffer(var_A, (1, 32, n, m), "float16") | |
B = T.match_buffer(var_B, (1, 32, m, 80), "float16") | |
matmul = T.match_buffer(var_matmul, (1, 32, n, 80), "float16") | |
# with T.block("root"): | |
matmul_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, 96), "float16", scope="local") | |
A_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 7) // 8 * 8), "float16", scope="shared") | |
B_reindex_pad_shared = T.alloc_buffer((32, 96, (m + 7) // 8 * 8), "float16", scope="shared") | |
for ax0_ax2_0_fused in T.thread_binding(96, thread="blockIdx.y"): | |
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): | |
for ax2_1 in T.thread_binding(1, thread="vthread.y"): | |
for ax1_1 in T.thread_binding(1, thread="vthread.x"): | |
for ax2_2 in T.thread_binding(8, thread="threadIdx.y"): | |
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | |
for ax2_3_init, ax1_3_0_init in T.grid(4, 4): | |
for ax1_3_1_init in T.vectorized(1): | |
with T.block("matmul_init"): | |
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 3) | |
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init + ax1_3_1_init) | |
v2 = T.axis.spatial(96, ax0_ax2_0_fused % 3 * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3_init) | |
T.reads() | |
T.writes(matmul_reindex_pad_local[v0, v1, v2]) | |
matmul_reindex_pad_local[v0, v1, v2] = T.float16(0) | |
for ax3_0 in range((m + 7) // 8): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): | |
for ax0_ax1_ax2_fused_2 in range(4): | |
for ax0_ax1_ax2_fused_3 in T.vectorized(1): | |
with T.block("A_reindex_pad_shared"): | |
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 3) | |
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8) | |
v2 = T.axis.spatial((m + 7) // 8 * 8, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8) | |
T.reads(A[0, v0, v1, v2]) | |
T.writes(A_reindex_pad_shared[v0, v1, v2]) | |
A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, A[0, v0, v1, v2], T.float16(0)) | |
for ax0_ax1_ax2_fused_0 in T.thread_binding(8, thread="threadIdx.y"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): | |
for ax0_ax1_ax2_fused_2 in range(4): | |
for ax0_ax1_ax2_fused_3 in T.vectorized(1): | |
with T.block("B_reindex_pad_shared"): | |
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 3) | |
v1 = T.axis.spatial(96, ax0_ax2_0_fused % 3 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) // 8) | |
v2 = T.axis.spatial((m + 7) // 8 * 8, ax3_0 * 8 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 + ax0_ax1_ax2_fused_3) % 8) | |
T.reads(B[0, v0, v2, v1]) | |
T.writes(B_reindex_pad_shared[v0, v1, v2]) | |
B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 80 and v2 < m, B[0, v0, v2, v1], T.float16(0)) | |
for ax3_1, ax2_3, ax1_3_0 in T.grid(8, 4, 4): | |
for ax1_3_1 in T.vectorized(1): | |
with T.block("matmul_update"): | |
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 3) | |
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 + ax1_3_1) | |
v2 = T.axis.spatial(96, ax0_ax2_0_fused % 3 * 32 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) | |
v3 = T.axis.reduce((m + 7) // 8 * 8, ax3_0 * 8 + ax3_1) | |
T.reads(matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3]) | |
T.writes(matmul_reindex_pad_local[v0, v1, v2]) | |
matmul_reindex_pad_local[v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3] | |
for ax0, ax1, ax2_0 in T.grid(1, 4, 4): | |
for ax2_1_1 in T.vectorized(1): | |
with T.block("matmul_reindex_pad_local"): | |
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 3 + ax0) | |
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) | |
v2 = T.axis.spatial(96, ax0_ax2_0_fused % 3 * 32 + ax2_2 * 4 + ax2_0 + ax2_1_1) | |
T.reads(matmul_reindex_pad_local[v0, v1, v2]) | |
T.writes(matmul[0, v0, v1, v2]) | |
if v1 < n and v2 < 80: | |
matmul[0, v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] | |
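# matmul8 above is the tiled batched GEMM, evidently the scores-times-values
# product of the prefill attention step:
#   matmul[0, h, i, j] = sum_k A[0, h, i, k] * B[0, h, k, j]
# with dynamic n (query length) and m (key length). Scheduling summary:
#   - blockIdx.y covers 96 = 32 heads x 3 column tiles of 32 (the 80 output
#     columns padded up to 96); blockIdx.x tiles the rows by 32
#   - each 8x8 thread block computes a 4x4 register tile per thread
#   - A and B are staged through shared memory in k-panels of 8, with
#     T.if_then_else padding out-of-range reads with zeros
#   - the epilogue writes back only where v1 < n and v2 < 80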
@T.prim_func(private=True) | |
def matmul9(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((1, 32, 1, 80), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, 32, 1, n), "float16") | |
B = T.match_buffer(var_B, (1, 32, n, 80), "float16") | |
# with T.block("root"): | |
matmul_rf_local = T.alloc_buffer((16, 1, 32, 1, 80), "float16", scope="local") | |
for ax0_ax1_fused_0 in T.thread_binding(160, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | |
for ax2_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | |
with T.block("matmul_rf_init"): | |
vax2_fused_1 = T.axis.spatial(16, ax2_fused_1) | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 80) | |
T.reads() | |
T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) | |
matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0) | |
for ax2_fused_0, u in T.grid((n + 15) // 16, 1): | |
with T.block("matmul_rf_update"): | |
vax2_fused_1 = T.axis.spatial(16, ax2_fused_1) | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 80) | |
vax2_fused_0 = T.axis.reduce((n + 15) // 16, ax2_fused_0) | |
T.where(ax2_fused_0 * 16 + ax2_fused_1 < n) | |
T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1], A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1], B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1]) | |
T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) | |
matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] + A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1] * B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1] | |
for ax1_ax2_fused in T.thread_binding(16, thread="threadIdx.x"): | |
for ax0 in T.thread_binding(16, thread="threadIdx.y"): | |
with T.block("matmul"): | |
vax2_fused_1 = T.axis.reduce(16, ax0) | |
v0 = T.axis.spatial(32, ax0_ax1_fused_0 // 5) | |
v1 = T.axis.spatial(80, ax0_ax1_fused_0 % 5 * 16 + ax1_ax2_fused) | |
T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) | |
T.writes(matmul[0, v0, 0, v1]) | |
with T.init(): | |
matmul[0, v0, 0, v1] = T.float16(0) | |
matmul[0, v0, 0, v1] = matmul[0, v0, 0, v1] + matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] | |
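# matmul9 above is the decode-time counterpart: a batched GEMV with a single
# query row, (1, 32, 1, n) x (1, 32, n, 80) -> (1, 32, 1, 80), scheduled
# with an rfactor-style split of the reduction:
#   - 160 blocks x 16 threadIdx.x lanes enumerate the 32 * 80 = 2560 outputs
#   - 16 threadIdx.y lanes each accumulate a strided partial sum over
#     k = ax2_fused_0 * 16 + lane into matmul_rf_local
#   - the final "matmul" block reduces the 16 partials across threadIdx.y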
@T.prim_func(private=True) | |
def reshape(var_A: T.handle, var_T_reshape: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n), "int32") | |
T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") | |
# with T.block("root"): | |
for ax0_fused_0 in T.thread_binding((n + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(n, ax0_fused_0 * 1024 + ax0_fused_1) | |
T.where(ax0_fused_0 * 1024 + ax0_fused_1 < n) | |
T.reads(A[0, v0]) | |
T.writes(T_reshape[v0]) | |
T_reshape[v0] = A[0, v0] | |
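# reshape above flattens the (1, n) token-id matrix to (n,). The
# reshape1..reshape6 variants below follow the same recipe: fuse all output
# axes, run 1024 threads per block, guard the tail with T.where, and recover
# the multi-dimensional indices by div/mod -- e.g. reshape2 produces
# (1, n, 32, 240) by reading A[0, v0, v1 * 240 + v2] from the (1, n, 7680)
# projection output.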
@T.prim_func(private=True) | |
def reshape1(var_A: T.handle, var_T_reshape: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (n, 2560), "float16") | |
T_reshape = T.match_buffer(var_T_reshape, (1, n, 2560), "float16") | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2560) | |
v1 = T.axis.spatial(2560, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2560) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * 2560) | |
T.reads(A[v0, v1]) | |
T.writes(T_reshape[0, v0, v1]) | |
T_reshape[0, v0, v1] = A[v0, v1] | |
@T.prim_func(private=True) | |
def reshape2(var_A: T.handle, var_T_reshape: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 7680), "float16") | |
T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 240), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 7680 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 7680) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 7680 // 240) | |
v2 = T.axis.spatial(240, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 240) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 7680) | |
T.reads(A[0, v0, v1 * 240 + v2]) | |
T.writes(T_reshape[0, v0, v1, v2]) | |
T_reshape[0, v0, v1, v2] = A[0, v0, v1 * 240 + v2] | |
@T.prim_func(private=True) | |
def reshape3(var_A: T.handle, var_T_reshape: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
m = T.int32() | |
A = T.match_buffer(var_A, (m, 32, 80), "float16") | |
T_reshape = T.match_buffer(var_T_reshape, (1, m, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((m * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(m, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < m * 2560) | |
T.reads(A[v0, v1, v2]) | |
T.writes(T_reshape[0, v0, v1, v2]) | |
T_reshape[0, v0, v1, v2] = A[v0, v1, v2] | |
@T.prim_func(private=True) | |
def reshape4(var_A: T.handle, var_T_reshape: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 32, 80), "float16") | |
T_reshape = T.match_buffer(var_T_reshape, (1, n, 2560), "float16") | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 2560) | |
v1 = T.axis.spatial(2560, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 2560) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < n * 2560) | |
T.reads(A[0, v0, v1 // 80, v1 % 80]) | |
T.writes(T_reshape[0, v0, v1]) | |
T_reshape[0, v0, v1] = A[0, v0, v1 // 80, v1 % 80] | |
@T.prim_func(private=True) | |
def reshape5(A: T.Buffer((1, 1), "int32"), T_reshape: T.Buffer((1,), "int32")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): | |
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(1, 0) | |
T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 1) | |
T.reads(A[0, 0]) | |
T.writes(T_reshape[0]) | |
T_reshape[0] = A[0, 0] | |
@T.prim_func(private=True) | |
def reshape6(A: T.Buffer((1, 2560), "float16"), T_reshape: T.Buffer((1, 1, 2560), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_reshape"): | |
v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1) | |
T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560) | |
T.reads(A[0, v0]) | |
T.writes(T_reshape[0, 0, v0]) | |
T_reshape[0, 0, v0] = A[0, v0] | |
@T.prim_func(private=True) | |
def rotary_embedding(var_A: T.handle, B: T.Buffer((2048, 80), "float16"), C: T.Buffer((2048, 80), "float16"), var_rotary: T.handle, m: T.int32): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 32, 80), "float16") | |
rotary = T.match_buffer(var_rotary, (1, n, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("rotary"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(B[v0 + (m - n), v2], A[0, v0, v1, v2 + -40:v2 + -40 + 81], C[v0 + (m - n), v2]) | |
T.writes(rotary[0, v0, v1, v2]) | |
rotary[0, v0, v1, v2] = T.Select(v2 < 80, B[v0 + (m - n), v2] * A[0, v0, v1, v2] + C[v0 + (m - n), v2] * T.Select(v2 < 40, A[0, v0, v1, v2 + 40] * T.float16(-1), A[0, v0, v1, v2 + -40]), A[0, v0, v1, v2]) | |
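# rotary_embedding above applies rotate-half rotary position embedding to
# the (1, n, 32, 80) query/key tensor. With pos = v0 + (m - n) -- m being
# the total sequence length including the cache, n the freshly fed tokens --
# the update is
#   out[..., d] = B[pos, d] * x[..., d] + C[pos, d] * rh(x)[..., d]
#   rh(x)[..., d] = -x[..., d + 40]  for d < 40,  else  x[..., d - 40]
# which matches out = cos * x + sin * rotate_half(x) if B and C hold
# precomputed cos/sin tables (consistent with their (2048, 80) shapes).
# The outer T.Select(v2 < 80, ...) is vacuously true here: the rotary
# dimension equals the full head dimension of 80. rotary_embedding1 below
# is the single-token decode variant with pos = n - 1.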
@T.prim_func(private=True) | |
def rotary_embedding1(A: T.Buffer((1, 1, 32, 80), "float16"), B: T.Buffer((2048, 80), "float16"), C: T.Buffer((2048, 80), "float16"), rotary: T.Buffer((1, 1, 32, 80), "float16"), n: T.int32): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("rotary"): | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560) | |
T.reads(B[n - 1, v1], A[0, 0, v0, v1 + -40:v1 + -40 + 81], C[n - 1, v1]) | |
T.writes(rotary[0, 0, v0, v1]) | |
rotary[0, 0, v0, v1] = T.Select(v1 < 80, B[n - 1, v1] * A[0, 0, v0, v1] + C[n - 1, v1] * T.Select(v1 < 40, A[0, 0, v0, v1 + 40] * T.float16(-1), A[0, 0, v0, v1 + -40]), A[0, 0, v0, v1]) | |
@T.prim_func(private=True) | |
def slice(var_A: T.handle, slice: T.Buffer((1, 1, 2560), "float32")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 2560)) | |
# with T.block("root"): | |
for ax0_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("slice"): | |
v0 = T.axis.spatial(2560, ax0_fused_0 * 1024 + ax0_fused_1) | |
T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 2560) | |
T.reads(A[0, n - 1, v0]) | |
T.writes(slice[0, 0, v0]) | |
slice[0, 0, v0] = A[0, n - 1, v0] | |
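# slice above extracts the hidden state of the last token, A[0, n - 1, :],
# into a (1, 1, 2560) buffer -- only the final position is needed to
# produce the next-token logits.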
@T.prim_func(private=True) | |
def softmax1(A: T.Buffer((1, 1, 50432), "float32"), T_softmax_norm: T.Buffer((1, 1, 50432), "float32")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
T_softmax_maxelem_shared = T.alloc_buffer((1, 1), scope="shared") | |
T_softmax_expsum_shared = T.alloc_buffer((1, 1), scope="shared") | |
for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): | |
for ax0, ax1_fused_0 in T.grid(1, 788): | |
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("T_softmax_maxelem"): | |
v0 = T.axis.spatial(1, ax0) | |
v1 = T.axis.reduce(50432, ax1_fused_0 * 64 + ax1_fused_1) | |
T.reads(A[0, 0, v1]) | |
T.writes(T_softmax_maxelem_shared[0, 0]) | |
with T.init(): | |
T_softmax_maxelem_shared[0, 0] = T.float32(-3.4028234663852886e+38) | |
T_softmax_maxelem_shared[0, 0] = T.max(T_softmax_maxelem_shared[0, 0], A[0, 0, v1]) | |
for ax0, ax1_fused_0 in T.grid(1, 788): | |
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("T_softmax_expsum"): | |
v0 = T.axis.spatial(1, ax0) | |
v1 = T.axis.reduce(50432, ax1_fused_0 * 64 + ax1_fused_1) | |
T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0]) | |
T.writes(T_softmax_expsum_shared[0, 0]) | |
with T.init(): | |
T_softmax_expsum_shared[0, 0] = T.float32(0) | |
T_softmax_expsum_shared[0, 0] = T_softmax_expsum_shared[0, 0] + T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) | |
for ax1_0 in range(788): | |
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): | |
with T.block("T_softmax_norm"): | |
v0 = T.axis.spatial(1, 0) | |
v1 = T.axis.spatial(50432, ax1_0 * 64 + ax1_1) | |
T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0], T_softmax_expsum_shared[0, 0]) | |
T.writes(T_softmax_norm[0, 0, v1]) | |
T.block_attr({"axis": 2}) | |
T_softmax_norm[0, 0, v1] = T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) / T_softmax_expsum_shared[0, 0] | |
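# softmax1 above is the same three-phase softmax applied to the final
# (1, 1, 50432) vocabulary logits. Since 788 * 64 == 50432 exactly, the two
# reduction loops need no T.where tail guard.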
@T.prim_func(private=True) | |
def split(var_A: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 32, 240), "float16") | |
T_split_sections = T.match_buffer(var_T_split_sections, (1, n, 32, 80), "float16") | |
T_split_sections_1 = T.match_buffer(var_T_split_sections_1, (1, n, 32, 80), "float16") | |
T_split_sections_2 = T.match_buffer(var_T_split_sections_2, (1, n, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_split_sections"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v0, v1, v2]) | |
T.writes(T_split_sections[0, v0, v1, v2]) | |
T_split_sections[0, v0, v1, v2] = A[0, v0, v1, v2] | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_split_sections_1"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v0, v1, v2 + 80]) | |
T.writes(T_split_sections_1[0, v0, v1, v2]) | |
T_split_sections_1[0, v0, v1, v2] = A[0, v0, v1, v2 + 80] | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_split_sections_2"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v0, v1, v2 + 160]) | |
T.writes(T_split_sections_2[0, v0, v1, v2]) | |
T_split_sections_2[0, v0, v1, v2] = A[0, v0, v1, v2 + 160] | |
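# split above slices the fused projection output (1, n, 32, 240) into three
# (1, n, 32, 80) tensors by offsetting the last axis at 0, 80 and 160 --
# per head, 240 = 3 x 80, presumably packing query, key and value
# contiguously.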
@T.prim_func(private=True) | |
def squeeze(var_A: T.handle, var_T_squeeze: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 32, 80), "float16") | |
T_squeeze = T.match_buffer(var_T_squeeze, (n, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_squeeze"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v0, v1, v2]) | |
T.writes(T_squeeze[v0, v1, v2]) | |
T_squeeze[v0, v1, v2] = A[0, v0, v1, v2] | |
@T.prim_func(private=True) | |
def squeeze1(A: T.Buffer((1, 1, 32, 80), "float16"), T_squeeze: T.Buffer((1, 32, 80), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_squeeze"): | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560) | |
T.reads(A[0, 0, v0, v1]) | |
T.writes(T_squeeze[0, v0, v1]) | |
T_squeeze[0, v0, v1] = A[0, 0, v0, v1] | |
@T.prim_func(private=True) | |
def transpose5(var_A: T.handle, var_T_transpose: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, n, 32, 80), "float16") | |
T_transpose = T.match_buffer(var_T_transpose, (1, 32, n, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_transpose"): | |
v0 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // (80 * n)) | |
v1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (80 * n) // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v1, v0, v2]) | |
T.writes(T_transpose[0, v0, v1, v2]) | |
T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2] | |
@T.prim_func(private=True) | |
def transpose6(var_A: T.handle, var_T_transpose: T.handle): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
n = T.int32() | |
A = T.match_buffer(var_A, (1, 32, n, 80), "float16") | |
T_transpose = T.match_buffer(var_T_transpose, (1, n, 32, 80), "float16") | |
# with T.block("root"): | |
for ax0_ax1_ax2_fused_0 in T.thread_binding((n * 2560 + 1023) // 1024, thread="blockIdx.x"): | |
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_transpose"): | |
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2560) | |
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2560 // 80) | |
v2 = T.axis.spatial(80, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 80) | |
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < n * 2560) | |
T.reads(A[0, v1, v0, v2]) | |
T.writes(T_transpose[0, v0, v1, v2]) | |
T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2] | |
@T.prim_func(private=True) | |
def transpose7(A: T.Buffer((1, 1, 32, 80), "float16"), T_transpose: T.Buffer((1, 32, 1, 80), "float16")): | |
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) | |
# with T.block("root"): | |
for ax0_ax1_fused_0 in T.thread_binding(3, thread="blockIdx.x"): | |
for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | |
with T.block("T_transpose"): | |
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) // 80) | |
v1 = T.axis.spatial(80, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % 80) | |
T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 2560) | |
T.reads(A[0, 0, v0, v1]) | |
T.writes(T_transpose[0, v0, 0, v1]) | |
T_transpose[0, v0, 0, v1] = A[0, 0, v0, v1] | |
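# transpose5/6/7 above shuffle between the token-major (1, n, 32, 80) and
# head-major (1, 32, n, 80) layouts: transpose5 moves heads outermost (the
# layout the batched matmuls expect), transpose6 is its inverse, and
# transpose7 is the n == 1 decode case. Note transpose5 decodes its fused
# index head-first (v0 = idx // (80 * n)) while transpose6 decodes it
# token-first (v0 = idx // 2560), matching their respective output layouts.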
@R.function | |
def create_kv_cache() -> R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object): | |
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}}) | |
with R.dataflow(): | |
lv3735: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3736: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3737: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3738: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3739: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3740: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3741: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3742: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3743: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3744: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3745: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3746: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3747: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3748: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3749: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3750: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3751: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3752: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3753: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3754: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3755: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3756: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3757: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3758: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3759: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3760: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3761: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3762: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3763: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3764: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3765: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3766: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3767: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3768: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3769: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3770: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3771: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3772: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3773: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3774: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3775: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3776: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3777: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3778: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3779: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3781: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3783: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3784: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3785: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3786: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3787: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3788: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3789: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3790: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3791: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3792: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3793: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3794: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3795: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3796: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3797: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
lv3798: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 80]), R.prim_value(0), sinfo_args=(R.Object,)) | |
gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object) = lv3735, lv3736, lv3737, lv3738, lv3739, lv3740, lv3741, lv3742, lv3743, lv3744, lv3745, lv3746, lv3747, lv3748, lv3749, lv3750, lv3751, lv3752, lv3753, lv3754, lv3755, lv3756, lv3757, lv3758, lv3759, lv3760, lv3761, lv3762, lv3763, lv3764, lv3765, lv3766, lv3767, lv3768, lv3769, lv3770, lv3771, lv3772, lv3773, lv3774, lv3775, lv3776, lv3777, lv3778, lv3779, lv3780, lv3781, lv3782, lv3783, lv3784, lv3785, lv3786, lv3787, lv3788, lv3789, lv3790, lv3791, lv3792, lv3793, lv3794, lv3795, lv3796, lv3797, lv3798 | |
R.output(gv2) | |
return gv2 | |
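# create_kv_cache above allocates 64 cache objects -- presumably one key
# cache and one value cache for each of 32 layers -- via the packed builtin
# vm.builtin.attention_kv_cache_create, called with an initialization
# constant from the module metadata, the reserved shape, and what appears
# to be an initial fill length of 0. Each cache reserves (2048, 32, 80)
# slots (max context x heads x head dim), matching the (n, 32, 80) float16
# K/V tensors produced by the squeeze kernels; the "tir_var_upper_bound"
# attribute caps the dynamic lengths m and n at the same 2048.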
@R.function | |
def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), model_params: R.Tuple(R.Tensor((50432, 320), dtype="uint32"), R.Tensor((50432, 80), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), 
dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), 
dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), 
dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), 
R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), 
dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((50432, 320), dtype="uint32"), R.Tensor((50432, 80), dtype="float32"))) -> R.Tuple(R.Tensor((1, 1, 50432), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)): | |
n = T.int64() | |
R.func_attr({"num_input": 3, "tir_var_upper_bound": {"m": 2048, "n": 2048}}) | |
cls = Module | |
with R.dataflow(): | |
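# Prologue: reshape5 flattens input_ids1 into a single token index, | |
# fused_fused_decode1_take1 dequantizes just that row of the embedding | |
# table (model_params[0:2]), reshape6 restores the (1, 1, 2560) hidden | |
# state, and `full` materializes the (1, 1, 1, n) tensor later used to | |
# clamp the attention scores. | |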
lv1868 = R.call_tir(cls.reshape5, (input_ids1,), out_sinfo=R.Tensor((1,), dtype="int32")) | |
lv711: R.Tensor((50432, 320), dtype="uint32") = model_params[0] | |
lv712: R.Tensor((50432, 80), dtype="float16") = model_params[1] | |
lv = R.call_tir(cls.fused_fused_decode1_take1, (lv711, lv712, lv1868), out_sinfo=R.Tensor((1, 2560), dtype="float16")) | |
lv1870 = R.call_tir(cls.reshape6, (lv,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1871 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((1, 1, 1, n), dtype="float16")) | |
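# ---- decoder layer 0 (params[2:18], kv_cache[0:2]) ---- | |
# Every layer below repeats the same pattern, so it is annotated once here: | |
# layer norm computed in float32, result cast back to float16; fused | |
# dequantize + QKV projection (2560 -> 7680); split into 32 heads of 80 | |
# dims for Q/K/V; rotary embedding on Q and K; append the new K and V to | |
# the per-layer caches and read back the full (n, 32, 80) history; | |
# attention scores scaled, clamped, and softmaxed in float32, then cast to | |
# float16; output projection plus residual add; second layer norm; GELU | |
# MLP (2560 -> 10240 -> 2560); second residual add. | |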
lv1872 = R.call_tir(cls.cast7, (lv1870,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv520: R.Tensor((2560,), dtype="float32") = model_params[2] | |
lv521: R.Tensor((2560,), dtype="float32") = model_params[3] | |
lv714 = R.call_tir(cls.fused_layer_norm1_cast8, (lv1872, lv520, lv521), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1875: R.Tensor((1, 1, 2560), dtype="float16") = lv714 | |
lv715: R.Tensor((7680, 320), dtype="uint32") = model_params[6] | |
lv716: R.Tensor((7680, 80), dtype="float16") = model_params[7] | |
lv524: R.Tensor((7680,), dtype="float16") = model_params[8] | |
lv_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv715, lv716, lv1875, lv524), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv719 = R.call_tir(cls.fused_reshape7_split1, (lv_1,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv1881: R.Tensor((1, 1, 32, 80), dtype="float16") = lv719[0] | |
lv1882 = R.call_tir(cls.rotary_embedding1, (lv1881, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv1883: R.Tensor((1, 1, 32, 80), dtype="float16") = lv719[1] | |
lv1884 = R.call_tir(cls.rotary_embedding1, (lv1883, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv1885: R.Object = kv_cache[0] | |
lv1886 = R.call_tir(cls.squeeze1, (lv1884,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv1887: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1885, lv1886, sinfo_args=(R.Object,)) | |
lv1888: R.Object = kv_cache[1] | |
lv720: R.Tensor((1, 1, 32, 80), dtype="float16") = lv719[2] | |
lv721 = R.call_tir(cls.fused_squeeze1, (lv720,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv1891: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1888, lv721, sinfo_args=(R.Object,)) | |
lv1892: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1887, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv1893: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1891, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv1894 = R.call_tir(cls.reshape3, (lv1892,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1895 = R.call_tir(cls.reshape3, (lv1893,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1896 = R.call_tir(cls.transpose7, (lv1882,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1897 = R.call_tir(cls.transpose5, (lv1894,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1898 = R.call_tir(cls.transpose5, (lv1895,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv722 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv1896, lv1897, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv723 = R.call_tir(cls.fused_softmax2_cast10, (lv722,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv1907 = R.call_tir(cls.matmul9, (lv723, lv1898), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv724 = R.call_tir(cls.fused_transpose8_reshape8, (lv1907,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv725: R.Tensor((2560, 320), dtype="uint32") = model_params[9] | |
lv726: R.Tensor((2560, 80), dtype="float16") = model_params[10] | |
lv527: R.Tensor((2560,), dtype="float16") = model_params[11] | |
lv_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv725, lv726, lv724, lv527, lv1870), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1914 = R.call_tir(cls.cast7, (lv_2,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv528: R.Tensor((2560,), dtype="float32") = model_params[4] | |
lv529: R.Tensor((2560,), dtype="float32") = model_params[5] | |
lv729 = R.call_tir(cls.fused_layer_norm1_cast8, (lv1914, lv528, lv529), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1917: R.Tensor((1, 1, 2560), dtype="float16") = lv729 | |
lv730: R.Tensor((10240, 320), dtype="uint32") = model_params[12] | |
lv731: R.Tensor((10240, 80), dtype="float16") = model_params[13] | |
lv532: R.Tensor((10240,), dtype="float32") = model_params[14] | |
lv1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv730, lv731, lv1917, lv532), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv1923: R.Tensor((1, 1, 10240), dtype="float16") = lv1 | |
lv734: R.Tensor((2560, 1280), dtype="uint32") = model_params[15] | |
lv735: R.Tensor((2560, 320), dtype="float16") = model_params[16] | |
lv535: R.Tensor((2560,), dtype="float32") = model_params[17] | |
lv1_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv734, lv735, lv1923, lv535, lv_2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
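# ---- decoder layer 1 (params[18:34], kv_cache[2:4]) ---- | |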
lv1930 = R.call_tir(cls.cast7, (lv1_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv536: R.Tensor((2560,), dtype="float32") = model_params[18] | |
lv537: R.Tensor((2560,), dtype="float32") = model_params[19] | |
lv738 = R.call_tir(cls.fused_layer_norm1_cast8, (lv1930, lv536, lv537), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1933: R.Tensor((1, 1, 2560), dtype="float16") = lv738 | |
lv739: R.Tensor((7680, 320), dtype="uint32") = model_params[22] | |
lv740: R.Tensor((7680, 80), dtype="float16") = model_params[23] | |
lv540: R.Tensor((7680,), dtype="float16") = model_params[24] | |
lv2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv739, lv740, lv1933, lv540), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv743 = R.call_tir(cls.fused_reshape7_split1, (lv2,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv1939: R.Tensor((1, 1, 32, 80), dtype="float16") = lv743[0] | |
lv1940 = R.call_tir(cls.rotary_embedding1, (lv1939, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv1941: R.Tensor((1, 1, 32, 80), dtype="float16") = lv743[1] | |
lv1942 = R.call_tir(cls.rotary_embedding1, (lv1941, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv1943: R.Object = kv_cache[2] | |
lv1944 = R.call_tir(cls.squeeze1, (lv1942,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv1945: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1943, lv1944, sinfo_args=(R.Object,)) | |
lv1946: R.Object = kv_cache[3] | |
lv744: R.Tensor((1, 1, 32, 80), dtype="float16") = lv743[2] | |
lv745 = R.call_tir(cls.fused_squeeze1, (lv744,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv1949: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1946, lv745, sinfo_args=(R.Object,)) | |
lv1950: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1945, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv1951: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1949, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv1952 = R.call_tir(cls.reshape3, (lv1950,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1953 = R.call_tir(cls.reshape3, (lv1951,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1954 = R.call_tir(cls.transpose7, (lv1940,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1955 = R.call_tir(cls.transpose5, (lv1952,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1956 = R.call_tir(cls.transpose5, (lv1953,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv746 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv1954, lv1955, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv747 = R.call_tir(cls.fused_softmax2_cast10, (lv746,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv1965 = R.call_tir(cls.matmul9, (lv747, lv1956), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv748 = R.call_tir(cls.fused_transpose8_reshape8, (lv1965,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv749: R.Tensor((2560, 320), dtype="uint32") = model_params[25] | |
lv750: R.Tensor((2560, 80), dtype="float16") = model_params[26] | |
lv543: R.Tensor((2560,), dtype="float16") = model_params[27] | |
lv2_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv749, lv750, lv748, lv543, lv1_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1972 = R.call_tir(cls.cast7, (lv2_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv544: R.Tensor((2560,), dtype="float32") = model_params[20] | |
lv545: R.Tensor((2560,), dtype="float32") = model_params[21] | |
lv753 = R.call_tir(cls.fused_layer_norm1_cast8, (lv1972, lv544, lv545), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1975: R.Tensor((1, 1, 2560), dtype="float16") = lv753 | |
lv754: R.Tensor((10240, 320), dtype="uint32") = model_params[28] | |
lv755: R.Tensor((10240, 80), dtype="float16") = model_params[29] | |
lv548: R.Tensor((10240,), dtype="float32") = model_params[30] | |
lv3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv754, lv755, lv1975, lv548), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv1981: R.Tensor((1, 1, 10240), dtype="float16") = lv3 | |
lv758: R.Tensor((2560, 1280), dtype="uint32") = model_params[31] | |
lv759: R.Tensor((2560, 320), dtype="float16") = model_params[32] | |
lv551: R.Tensor((2560,), dtype="float32") = model_params[33] | |
lv3_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv758, lv759, lv1981, lv551, lv2_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
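# ---- decoder layer 2 (params[34:50], kv_cache[4:6]) ---- | |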
lv1988 = R.call_tir(cls.cast7, (lv3_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv552: R.Tensor((2560,), dtype="float32") = model_params[34] | |
lv553: R.Tensor((2560,), dtype="float32") = model_params[35] | |
lv762 = R.call_tir(cls.fused_layer_norm1_cast8, (lv1988, lv552, lv553), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1991: R.Tensor((1, 1, 2560), dtype="float16") = lv762 | |
lv763: R.Tensor((7680, 320), dtype="uint32") = model_params[38] | |
lv764: R.Tensor((7680, 80), dtype="float16") = model_params[39] | |
lv556: R.Tensor((7680,), dtype="float16") = model_params[40] | |
lv4 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv763, lv764, lv1991, lv556), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv767 = R.call_tir(cls.fused_reshape7_split1, (lv4,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv1997: R.Tensor((1, 1, 32, 80), dtype="float16") = lv767[0] | |
lv1998 = R.call_tir(cls.rotary_embedding1, (lv1997, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv1999: R.Tensor((1, 1, 32, 80), dtype="float16") = lv767[1] | |
lv2000 = R.call_tir(cls.rotary_embedding1, (lv1999, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2001: R.Object = kv_cache[4] | |
lv2002 = R.call_tir(cls.squeeze1, (lv2000,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2003: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2001, lv2002, sinfo_args=(R.Object,)) | |
lv2004: R.Object = kv_cache[5] | |
lv768: R.Tensor((1, 1, 32, 80), dtype="float16") = lv767[2] | |
lv769 = R.call_tir(cls.fused_squeeze1, (lv768,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2007: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2004, lv769, sinfo_args=(R.Object,)) | |
lv2008: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2003, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2009: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2007, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2010 = R.call_tir(cls.reshape3, (lv2008,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2011 = R.call_tir(cls.reshape3, (lv2009,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2012 = R.call_tir(cls.transpose7, (lv1998,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2013 = R.call_tir(cls.transpose5, (lv2010,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2014 = R.call_tir(cls.transpose5, (lv2011,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv770 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2012, lv2013, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv771 = R.call_tir(cls.fused_softmax2_cast10, (lv770,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2023 = R.call_tir(cls.matmul9, (lv771, lv2014), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv772 = R.call_tir(cls.fused_transpose8_reshape8, (lv2023,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv773: R.Tensor((2560, 320), dtype="uint32") = model_params[41] | |
lv774: R.Tensor((2560, 80), dtype="float16") = model_params[42] | |
lv559: R.Tensor((2560,), dtype="float16") = model_params[43] | |
lv4_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv773, lv774, lv772, lv559, lv3_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2030 = R.call_tir(cls.cast7, (lv4_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv560: R.Tensor((2560,), dtype="float32") = model_params[36] | |
lv561: R.Tensor((2560,), dtype="float32") = model_params[37] | |
lv777 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2030, lv560, lv561), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2033: R.Tensor((1, 1, 2560), dtype="float16") = lv777 | |
lv778: R.Tensor((10240, 320), dtype="uint32") = model_params[44] | |
lv779: R.Tensor((10240, 80), dtype="float16") = model_params[45] | |
lv564: R.Tensor((10240,), dtype="float32") = model_params[46] | |
lv5 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv778, lv779, lv2033, lv564), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2039: R.Tensor((1, 1, 10240), dtype="float16") = lv5 | |
lv782: R.Tensor((2560, 1280), dtype="uint32") = model_params[47] | |
lv783: R.Tensor((2560, 320), dtype="float16") = model_params[48] | |
lv567: R.Tensor((2560,), dtype="float32") = model_params[49] | |
lv5_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv782, lv783, lv2039, lv567, lv4_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
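# ---- decoder layer 3 (params[50:66], kv_cache[6:8]) ---- | |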
lv2046 = R.call_tir(cls.cast7, (lv5_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv568: R.Tensor((2560,), dtype="float32") = model_params[50] | |
lv569: R.Tensor((2560,), dtype="float32") = model_params[51] | |
lv786 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2046, lv568, lv569), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2049: R.Tensor((1, 1, 2560), dtype="float16") = lv786 | |
lv787: R.Tensor((7680, 320), dtype="uint32") = model_params[54] | |
lv788: R.Tensor((7680, 80), dtype="float16") = model_params[55] | |
lv572: R.Tensor((7680,), dtype="float16") = model_params[56] | |
lv6 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv787, lv788, lv2049, lv572), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv791 = R.call_tir(cls.fused_reshape7_split1, (lv6,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2055: R.Tensor((1, 1, 32, 80), dtype="float16") = lv791[0] | |
lv2056 = R.call_tir(cls.rotary_embedding1, (lv2055, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2057: R.Tensor((1, 1, 32, 80), dtype="float16") = lv791[1] | |
lv2058 = R.call_tir(cls.rotary_embedding1, (lv2057, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2059: R.Object = kv_cache[6] | |
lv2060 = R.call_tir(cls.squeeze1, (lv2058,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2061: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2059, lv2060, sinfo_args=(R.Object,)) | |
lv2062: R.Object = kv_cache[7] | |
lv792: R.Tensor((1, 1, 32, 80), dtype="float16") = lv791[2] | |
lv793 = R.call_tir(cls.fused_squeeze1, (lv792,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2065: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2062, lv793, sinfo_args=(R.Object,)) | |
lv2066: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2061, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2067: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2065, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2068 = R.call_tir(cls.reshape3, (lv2066,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2069 = R.call_tir(cls.reshape3, (lv2067,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2070 = R.call_tir(cls.transpose7, (lv2056,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2071 = R.call_tir(cls.transpose5, (lv2068,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2072 = R.call_tir(cls.transpose5, (lv2069,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv794 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2070, lv2071, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv795 = R.call_tir(cls.fused_softmax2_cast10, (lv794,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2081 = R.call_tir(cls.matmul9, (lv795, lv2072), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv796 = R.call_tir(cls.fused_transpose8_reshape8, (lv2081,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv797: R.Tensor((2560, 320), dtype="uint32") = model_params[57] | |
lv798: R.Tensor((2560, 80), dtype="float16") = model_params[58] | |
lv575: R.Tensor((2560,), dtype="float16") = model_params[59] | |
lv6_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv797, lv798, lv796, lv575, lv5_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2088 = R.call_tir(cls.cast7, (lv6_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv576: R.Tensor((2560,), dtype="float32") = model_params[52] | |
lv577: R.Tensor((2560,), dtype="float32") = model_params[53] | |
lv801 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2088, lv576, lv577), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2091: R.Tensor((1, 1, 2560), dtype="float16") = lv801 | |
lv802: R.Tensor((10240, 320), dtype="uint32") = model_params[60] | |
lv803: R.Tensor((10240, 80), dtype="float16") = model_params[61] | |
lv580: R.Tensor((10240,), dtype="float32") = model_params[62] | |
lv7 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv802, lv803, lv2091, lv580), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2097: R.Tensor((1, 1, 10240), dtype="float16") = lv7 | |
lv806: R.Tensor((2560, 1280), dtype="uint32") = model_params[63] | |
lv807: R.Tensor((2560, 320), dtype="float16") = model_params[64] | |
lv583: R.Tensor((2560,), dtype="float32") = model_params[65] | |
lv7_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv806, lv807, lv2097, lv583, lv6_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
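# ---- decoder layer 4 (params[66:82], kv_cache[8:10]) ---- | |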
lv2104 = R.call_tir(cls.cast7, (lv7_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv584: R.Tensor((2560,), dtype="float32") = model_params[66] | |
lv585: R.Tensor((2560,), dtype="float32") = model_params[67] | |
lv810 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2104, lv584, lv585), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2107: R.Tensor((1, 1, 2560), dtype="float16") = lv810 | |
lv811: R.Tensor((7680, 320), dtype="uint32") = model_params[70] | |
lv812: R.Tensor((7680, 80), dtype="float16") = model_params[71] | |
lv588: R.Tensor((7680,), dtype="float16") = model_params[72] | |
lv8 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv811, lv812, lv2107, lv588), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv815 = R.call_tir(cls.fused_reshape7_split1, (lv8,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2113: R.Tensor((1, 1, 32, 80), dtype="float16") = lv815[0] | |
lv2114 = R.call_tir(cls.rotary_embedding1, (lv2113, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2115: R.Tensor((1, 1, 32, 80), dtype="float16") = lv815[1] | |
lv2116 = R.call_tir(cls.rotary_embedding1, (lv2115, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2117: R.Object = kv_cache[8] | |
lv2118 = R.call_tir(cls.squeeze1, (lv2116,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2119: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2117, lv2118, sinfo_args=(R.Object,)) | |
lv2120: R.Object = kv_cache[9] | |
lv816: R.Tensor((1, 1, 32, 80), dtype="float16") = lv815[2] | |
lv817 = R.call_tir(cls.fused_squeeze1, (lv816,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2123: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2120, lv817, sinfo_args=(R.Object,)) | |
lv2124: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2119, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2125: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2123, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2126 = R.call_tir(cls.reshape3, (lv2124,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2127 = R.call_tir(cls.reshape3, (lv2125,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2128 = R.call_tir(cls.transpose7, (lv2114,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2129 = R.call_tir(cls.transpose5, (lv2126,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2130 = R.call_tir(cls.transpose5, (lv2127,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv818 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2128, lv2129, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv819 = R.call_tir(cls.fused_softmax2_cast10, (lv818,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2139 = R.call_tir(cls.matmul9, (lv819, lv2130), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv820 = R.call_tir(cls.fused_transpose8_reshape8, (lv2139,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv821: R.Tensor((2560, 320), dtype="uint32") = model_params[73] | |
lv822: R.Tensor((2560, 80), dtype="float16") = model_params[74] | |
lv591: R.Tensor((2560,), dtype="float16") = model_params[75] | |
lv8_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv821, lv822, lv820, lv591, lv7_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2146 = R.call_tir(cls.cast7, (lv8_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv592: R.Tensor((2560,), dtype="float32") = model_params[68] | |
lv593: R.Tensor((2560,), dtype="float32") = model_params[69] | |
lv825 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2146, lv592, lv593), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2149: R.Tensor((1, 1, 2560), dtype="float16") = lv825 | |
lv826: R.Tensor((10240, 320), dtype="uint32") = model_params[76] | |
lv827: R.Tensor((10240, 80), dtype="float16") = model_params[77] | |
lv596: R.Tensor((10240,), dtype="float32") = model_params[78] | |
lv9 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv826, lv827, lv2149, lv596), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2155: R.Tensor((1, 1, 10240), dtype="float16") = lv9 | |
lv830: R.Tensor((2560, 1280), dtype="uint32") = model_params[79] | |
lv831: R.Tensor((2560, 320), dtype="float16") = model_params[80] | |
lv599: R.Tensor((2560,), dtype="float32") = model_params[81] | |
lv9_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv830, lv831, lv2155, lv599, lv8_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
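# ---- decoder layer 5 (params[82:98], kv_cache[10:12]) ---- | |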
lv2162 = R.call_tir(cls.cast7, (lv9_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv600: R.Tensor((2560,), dtype="float32") = model_params[82] | |
lv601: R.Tensor((2560,), dtype="float32") = model_params[83] | |
lv834 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2162, lv600, lv601), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2165: R.Tensor((1, 1, 2560), dtype="float16") = lv834 | |
lv835: R.Tensor((7680, 320), dtype="uint32") = model_params[86] | |
lv836: R.Tensor((7680, 80), dtype="float16") = model_params[87] | |
lv604: R.Tensor((7680,), dtype="float16") = model_params[88] | |
lv10 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv835, lv836, lv2165, lv604), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv839 = R.call_tir(cls.fused_reshape7_split1, (lv10,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2171: R.Tensor((1, 1, 32, 80), dtype="float16") = lv839[0] | |
lv2172 = R.call_tir(cls.rotary_embedding1, (lv2171, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2173: R.Tensor((1, 1, 32, 80), dtype="float16") = lv839[1] | |
lv2174 = R.call_tir(cls.rotary_embedding1, (lv2173, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2175: R.Object = kv_cache[10] | |
lv2176 = R.call_tir(cls.squeeze1, (lv2174,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2177: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2175, lv2176, sinfo_args=(R.Object,)) | |
lv2178: R.Object = kv_cache[11] | |
lv840: R.Tensor((1, 1, 32, 80), dtype="float16") = lv839[2] | |
lv841 = R.call_tir(cls.fused_squeeze1, (lv840,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2181: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2178, lv841, sinfo_args=(R.Object,)) | |
lv2182: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2177, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2183: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2181, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2184 = R.call_tir(cls.reshape3, (lv2182,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2185 = R.call_tir(cls.reshape3, (lv2183,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2186 = R.call_tir(cls.transpose7, (lv2172,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2187 = R.call_tir(cls.transpose5, (lv2184,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2188 = R.call_tir(cls.transpose5, (lv2185,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv842 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2186, lv2187, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv843 = R.call_tir(cls.fused_softmax2_cast10, (lv842,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2197 = R.call_tir(cls.matmul9, (lv843, lv2188), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv844 = R.call_tir(cls.fused_transpose8_reshape8, (lv2197,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv845: R.Tensor((2560, 320), dtype="uint32") = model_params[89] | |
lv846: R.Tensor((2560, 80), dtype="float16") = model_params[90] | |
lv607: R.Tensor((2560,), dtype="float16") = model_params[91] | |
lv10_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv845, lv846, lv844, lv607, lv9_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2204 = R.call_tir(cls.cast7, (lv10_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv608: R.Tensor((2560,), dtype="float32") = model_params[84] | |
lv609: R.Tensor((2560,), dtype="float32") = model_params[85] | |
lv849 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2204, lv608, lv609), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2207: R.Tensor((1, 1, 2560), dtype="float16") = lv849 | |
lv850: R.Tensor((10240, 320), dtype="uint32") = model_params[92] | |
lv851: R.Tensor((10240, 80), dtype="float16") = model_params[93] | |
lv612: R.Tensor((10240,), dtype="float32") = model_params[94] | |
lv11 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv850, lv851, lv2207, lv612), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2213: R.Tensor((1, 1, 10240), dtype="float16") = lv11 | |
lv854: R.Tensor((2560, 1280), dtype="uint32") = model_params[95] | |
lv855: R.Tensor((2560, 320), dtype="float16") = model_params[96] | |
lv615: R.Tensor((2560,), dtype="float32") = model_params[97] | |
lv11_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv854, lv855, lv2213, lv615, lv10_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
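# ---- decoder layer 6 (params[98:114], kv_cache[12:14]) ---- | |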
lv2220 = R.call_tir(cls.cast7, (lv11_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv616: R.Tensor((2560,), dtype="float32") = model_params[98] | |
lv617: R.Tensor((2560,), dtype="float32") = model_params[99] | |
lv858 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2220, lv616, lv617), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2223: R.Tensor((1, 1, 2560), dtype="float16") = lv858 | |
lv859: R.Tensor((7680, 320), dtype="uint32") = model_params[102] | |
lv860: R.Tensor((7680, 80), dtype="float16") = model_params[103] | |
lv620: R.Tensor((7680,), dtype="float16") = model_params[104] | |
lv12 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv859, lv860, lv2223, lv620), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv863 = R.call_tir(cls.fused_reshape7_split1, (lv12,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2229: R.Tensor((1, 1, 32, 80), dtype="float16") = lv863[0] | |
lv2230 = R.call_tir(cls.rotary_embedding1, (lv2229, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2231: R.Tensor((1, 1, 32, 80), dtype="float16") = lv863[1] | |
lv2232 = R.call_tir(cls.rotary_embedding1, (lv2231, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2233: R.Object = kv_cache[12] | |
lv2234 = R.call_tir(cls.squeeze1, (lv2232,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2235: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2233, lv2234, sinfo_args=(R.Object,)) | |
lv2236: R.Object = kv_cache[13] | |
lv864: R.Tensor((1, 1, 32, 80), dtype="float16") = lv863[2] | |
lv865 = R.call_tir(cls.fused_squeeze1, (lv864,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2239: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2236, lv865, sinfo_args=(R.Object,)) | |
lv2240: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2235, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2241: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2239, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2242 = R.call_tir(cls.reshape3, (lv2240,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2243 = R.call_tir(cls.reshape3, (lv2241,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2244 = R.call_tir(cls.transpose7, (lv2230,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2245 = R.call_tir(cls.transpose5, (lv2242,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2246 = R.call_tir(cls.transpose5, (lv2243,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv866 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2244, lv2245, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv867 = R.call_tir(cls.fused_softmax2_cast10, (lv866,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2255 = R.call_tir(cls.matmul9, (lv867, lv2246), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv868 = R.call_tir(cls.fused_transpose8_reshape8, (lv2255,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv869: R.Tensor((2560, 320), dtype="uint32") = model_params[105] | |
lv870: R.Tensor((2560, 80), dtype="float16") = model_params[106] | |
lv623: R.Tensor((2560,), dtype="float16") = model_params[107] | |
lv12_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv869, lv870, lv868, lv623, lv11_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2262 = R.call_tir(cls.cast7, (lv12_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv624: R.Tensor((2560,), dtype="float32") = model_params[100] | |
lv625: R.Tensor((2560,), dtype="float32") = model_params[101] | |
lv873 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2262, lv624, lv625), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2265: R.Tensor((1, 1, 2560), dtype="float16") = lv873 | |
lv874: R.Tensor((10240, 320), dtype="uint32") = model_params[108] | |
lv875: R.Tensor((10240, 80), dtype="float16") = model_params[109] | |
lv628: R.Tensor((10240,), dtype="float32") = model_params[110] | |
lv13 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv874, lv875, lv2265, lv628), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2271: R.Tensor((1, 1, 10240), dtype="float16") = lv13 | |
lv878: R.Tensor((2560, 1280), dtype="uint32") = model_params[111] | |
lv879: R.Tensor((2560, 320), dtype="float16") = model_params[112] | |
lv631: R.Tensor((2560,), dtype="float32") = model_params[113] | |
lv13_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv878, lv879, lv2271, lv631, lv12_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
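# ---- decoder layer 7 (params[114:130], kv_cache[14:16]) ---- | |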
lv2278 = R.call_tir(cls.cast7, (lv13_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv632: R.Tensor((2560,), dtype="float32") = model_params[114] | |
lv633: R.Tensor((2560,), dtype="float32") = model_params[115] | |
lv882 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2278, lv632, lv633), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2281: R.Tensor((1, 1, 2560), dtype="float16") = lv882 | |
lv883: R.Tensor((7680, 320), dtype="uint32") = model_params[118] | |
lv884: R.Tensor((7680, 80), dtype="float16") = model_params[119] | |
lv636: R.Tensor((7680,), dtype="float16") = model_params[120] | |
lv14 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv883, lv884, lv2281, lv636), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv887 = R.call_tir(cls.fused_reshape7_split1, (lv14,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2287: R.Tensor((1, 1, 32, 80), dtype="float16") = lv887[0] | |
lv2288 = R.call_tir(cls.rotary_embedding1, (lv2287, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2289: R.Tensor((1, 1, 32, 80), dtype="float16") = lv887[1] | |
lv2290 = R.call_tir(cls.rotary_embedding1, (lv2289, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2291: R.Object = kv_cache[14] | |
lv2292 = R.call_tir(cls.squeeze1, (lv2290,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2293: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2291, lv2292, sinfo_args=(R.Object,)) | |
lv2294: R.Object = kv_cache[15] | |
lv888: R.Tensor((1, 1, 32, 80), dtype="float16") = lv887[2] | |
lv889 = R.call_tir(cls.fused_squeeze1, (lv888,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2297: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2294, lv889, sinfo_args=(R.Object,)) | |
lv2298: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2293, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2299: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2297, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2300 = R.call_tir(cls.reshape3, (lv2298,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2301 = R.call_tir(cls.reshape3, (lv2299,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2302 = R.call_tir(cls.transpose7, (lv2288,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2303 = R.call_tir(cls.transpose5, (lv2300,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2304 = R.call_tir(cls.transpose5, (lv2301,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv890 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2302, lv2303, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv891 = R.call_tir(cls.fused_softmax2_cast10, (lv890,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2313 = R.call_tir(cls.matmul9, (lv891, lv2304), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv892 = R.call_tir(cls.fused_transpose8_reshape8, (lv2313,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv893: R.Tensor((2560, 320), dtype="uint32") = model_params[121] | |
lv894: R.Tensor((2560, 80), dtype="float16") = model_params[122] | |
lv639: R.Tensor((2560,), dtype="float16") = model_params[123] | |
lv14_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv893, lv894, lv892, lv639, lv13_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2320 = R.call_tir(cls.cast7, (lv14_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv640: R.Tensor((2560,), dtype="float32") = model_params[116] | |
lv641: R.Tensor((2560,), dtype="float32") = model_params[117] | |
lv897 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2320, lv640, lv641), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2323: R.Tensor((1, 1, 2560), dtype="float16") = lv897 | |
lv898: R.Tensor((10240, 320), dtype="uint32") = model_params[124] | |
lv899: R.Tensor((10240, 80), dtype="float16") = model_params[125] | |
lv644: R.Tensor((10240,), dtype="float32") = model_params[126] | |
lv15 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv898, lv899, lv2323, lv644), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2329: R.Tensor((1, 1, 10240), dtype="float16") = lv15 | |
lv902: R.Tensor((2560, 1280), dtype="uint32") = model_params[127] | |
lv903: R.Tensor((2560, 320), dtype="float16") = model_params[128] | |
lv647: R.Tensor((2560,), dtype="float32") = model_params[129] | |
lv15_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv902, lv903, lv2329, lv647, lv14_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
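# ---- decoder layer 8 (params[130:146], kv_cache[16:18]) ---- | |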
lv2336 = R.call_tir(cls.cast7, (lv15_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv648: R.Tensor((2560,), dtype="float32") = model_params[130] | |
lv649: R.Tensor((2560,), dtype="float32") = model_params[131] | |
lv906 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2336, lv648, lv649), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2339: R.Tensor((1, 1, 2560), dtype="float16") = lv906 | |
lv907: R.Tensor((7680, 320), dtype="uint32") = model_params[134] | |
lv908: R.Tensor((7680, 80), dtype="float16") = model_params[135] | |
lv652: R.Tensor((7680,), dtype="float16") = model_params[136] | |
lv16 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv907, lv908, lv2339, lv652), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv911 = R.call_tir(cls.fused_reshape7_split1, (lv16,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2345: R.Tensor((1, 1, 32, 80), dtype="float16") = lv911[0] | |
lv2346 = R.call_tir(cls.rotary_embedding1, (lv2345, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2347: R.Tensor((1, 1, 32, 80), dtype="float16") = lv911[1] | |
lv2348 = R.call_tir(cls.rotary_embedding1, (lv2347, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2349: R.Object = kv_cache[16] | |
lv2350 = R.call_tir(cls.squeeze1, (lv2348,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2351: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2349, lv2350, sinfo_args=(R.Object,)) | |
lv2352: R.Object = kv_cache[17] | |
lv912: R.Tensor((1, 1, 32, 80), dtype="float16") = lv911[2] | |
lv913 = R.call_tir(cls.fused_squeeze1, (lv912,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2355: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2352, lv913, sinfo_args=(R.Object,)) | |
lv2356: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2351, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2357: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2355, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2358 = R.call_tir(cls.reshape3, (lv2356,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2359 = R.call_tir(cls.reshape3, (lv2357,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2360 = R.call_tir(cls.transpose7, (lv2346,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2361 = R.call_tir(cls.transpose5, (lv2358,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2362 = R.call_tir(cls.transpose5, (lv2359,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv914 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2360, lv2361, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv915 = R.call_tir(cls.fused_softmax2_cast10, (lv914,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2371 = R.call_tir(cls.matmul9, (lv915, lv2362), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv916 = R.call_tir(cls.fused_transpose8_reshape8, (lv2371,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv917: R.Tensor((2560, 320), dtype="uint32") = model_params[137] | |
lv918: R.Tensor((2560, 80), dtype="float16") = model_params[138] | |
lv655: R.Tensor((2560,), dtype="float16") = model_params[139] | |
lv16_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv917, lv918, lv916, lv655, lv15_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2378 = R.call_tir(cls.cast7, (lv16_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv656: R.Tensor((2560,), dtype="float32") = model_params[132] | |
lv657: R.Tensor((2560,), dtype="float32") = model_params[133] | |
lv921 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2378, lv656, lv657), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2381: R.Tensor((1, 1, 2560), dtype="float16") = lv921 | |
lv922: R.Tensor((10240, 320), dtype="uint32") = model_params[140] | |
lv923: R.Tensor((10240, 80), dtype="float16") = model_params[141] | |
lv660: R.Tensor((10240,), dtype="float32") = model_params[142] | |
lv17 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv922, lv923, lv2381, lv660), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2387: R.Tensor((1, 1, 10240), dtype="float16") = lv17 | |
lv926: R.Tensor((2560, 1280), dtype="uint32") = model_params[143] | |
lv927: R.Tensor((2560, 320), dtype="float16") = model_params[144] | |
lv663: R.Tensor((2560,), dtype="float32") = model_params[145] | |
lv17_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv926, lv927, lv2387, lv663, lv16_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
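# ---- decoder layer 9 (params[146:162], kv_cache[18:20]) ---- | |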
lv2394 = R.call_tir(cls.cast7, (lv17_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv664: R.Tensor((2560,), dtype="float32") = model_params[146] | |
lv665: R.Tensor((2560,), dtype="float32") = model_params[147] | |
lv930 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2394, lv664, lv665), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2397: R.Tensor((1, 1, 2560), dtype="float16") = lv930 | |
lv931: R.Tensor((7680, 320), dtype="uint32") = model_params[150] | |
lv932: R.Tensor((7680, 80), dtype="float16") = model_params[151] | |
lv668: R.Tensor((7680,), dtype="float16") = model_params[152] | |
lv18 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv931, lv932, lv2397, lv668), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv935 = R.call_tir(cls.fused_reshape7_split1, (lv18,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2403: R.Tensor((1, 1, 32, 80), dtype="float16") = lv935[0] | |
lv2404 = R.call_tir(cls.rotary_embedding1, (lv2403, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2405: R.Tensor((1, 1, 32, 80), dtype="float16") = lv935[1] | |
lv2406 = R.call_tir(cls.rotary_embedding1, (lv2405, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2407: R.Object = kv_cache[18] | |
lv2408 = R.call_tir(cls.squeeze1, (lv2406,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2409: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2407, lv2408, sinfo_args=(R.Object,)) | |
lv2410: R.Object = kv_cache[19] | |
lv936: R.Tensor((1, 1, 32, 80), dtype="float16") = lv935[2] | |
lv937 = R.call_tir(cls.fused_squeeze1, (lv936,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2413: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2410, lv937, sinfo_args=(R.Object,)) | |
lv2414: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2409, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2415: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2413, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2416 = R.call_tir(cls.reshape3, (lv2414,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2417 = R.call_tir(cls.reshape3, (lv2415,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2418 = R.call_tir(cls.transpose7, (lv2404,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2419 = R.call_tir(cls.transpose5, (lv2416,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2420 = R.call_tir(cls.transpose5, (lv2417,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv938 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2418, lv2419, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv939 = R.call_tir(cls.fused_softmax2_cast10, (lv938,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2429 = R.call_tir(cls.matmul9, (lv939, lv2420), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv940 = R.call_tir(cls.fused_transpose8_reshape8, (lv2429,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv941: R.Tensor((2560, 320), dtype="uint32") = model_params[153] | |
lv942: R.Tensor((2560, 80), dtype="float16") = model_params[154] | |
lv671: R.Tensor((2560,), dtype="float16") = model_params[155] | |
lv18_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv941, lv942, lv940, lv671, lv17_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2436 = R.call_tir(cls.cast7, (lv18_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv672: R.Tensor((2560,), dtype="float32") = model_params[148] | |
lv673: R.Tensor((2560,), dtype="float32") = model_params[149] | |
lv945 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2436, lv672, lv673), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2439: R.Tensor((1, 1, 2560), dtype="float16") = lv945 | |
lv946: R.Tensor((10240, 320), dtype="uint32") = model_params[156] | |
lv947: R.Tensor((10240, 80), dtype="float16") = model_params[157] | |
lv676: R.Tensor((10240,), dtype="float32") = model_params[158] | |
lv19 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv946, lv947, lv2439, lv676), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2445: R.Tensor((1, 1, 10240), dtype="float16") = lv19 | |
lv950: R.Tensor((2560, 1280), dtype="uint32") = model_params[159] | |
lv951: R.Tensor((2560, 320), dtype="float16") = model_params[160] | |
lv679: R.Tensor((2560,), dtype="float32") = model_params[161] | |
lv19_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv950, lv951, lv2445, lv679, lv18_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2452 = R.call_tir(cls.cast7, (lv19_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv680: R.Tensor((2560,), dtype="float32") = model_params[162] | |
lv681: R.Tensor((2560,), dtype="float32") = model_params[163] | |
lv954 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2452, lv680, lv681), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
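# ===== decoder layer boundary: kv_cache slots 20/21 =====
# Annotation: each repeated block in this function is one decode-step transformer
# layer (the block ending above uses kv_cache slots 18/19; the one starting below
# uses 20/21). Assuming two consecutive K/V slots per layer, slots 20/21 would be
# layer 10 — an inference from the slot numbering, not stated in the IR.
# The per-layer dataflow is:
#   1. fused QKV projection of the normed hidden state, (1, 1, 2560) -> (1, 1, 7680),
#      where 7680 = 3 * 32 heads * 80 head dim. The (7680, 320) uint32 / (7680, 80)
#      float16 parameter pairs are consistent with 4-bit, group-32 quantized weights
#      that the fused_decode* stage dequantizes on the fly (inferred from the shapes
#      and kernel names).
#   2. fused_reshape7_split1 splits the projection into Q/K/V of shape (1, 1, 32, 80);
#      rotary_embedding1 is applied to Q and K at position n (tir_vars=R.shape([n])).
#   3. K and V are squeezed to (1, 32, 80) and appended to the cache via
#      vm.builtin.attention_kv_cache_append; attention_kv_cache_view then reads back
#      all n cached positions as (n, 32, 80).
#   4. fused_NT_matmul7_divide2_maximum1_minimum1_cast9 computes the scaled Q.K^T
#      scores in float32, clamped (maximum/minimum) against the mask operand lv1871;
#      softmax plus a cast back to float16 and matmul9 against V give the
#      (1, 32, 1, 80) per-head context.
#   5. fused_transpose8_reshape8 merges heads back to (1, 1, 2560); the output
#      projection adds the residual, then LayerNorm -> GELU MLP (2560 -> 10240 ->
#      2560) -> second residual add -> LayerNorm feeding the next layer.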
lv2455: R.Tensor((1, 1, 2560), dtype="float16") = lv954 | |
lv955: R.Tensor((7680, 320), dtype="uint32") = model_params[166] | |
lv956: R.Tensor((7680, 80), dtype="float16") = model_params[167] | |
lv684: R.Tensor((7680,), dtype="float16") = model_params[168] | |
lv20 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv955, lv956, lv2455, lv684), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv959 = R.call_tir(cls.fused_reshape7_split1, (lv20,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2461: R.Tensor((1, 1, 32, 80), dtype="float16") = lv959[0] | |
lv2462 = R.call_tir(cls.rotary_embedding1, (lv2461, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2463: R.Tensor((1, 1, 32, 80), dtype="float16") = lv959[1] | |
lv2464 = R.call_tir(cls.rotary_embedding1, (lv2463, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2465: R.Object = kv_cache[20] | |
lv2466 = R.call_tir(cls.squeeze1, (lv2464,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2467: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2465, lv2466, sinfo_args=(R.Object,)) | |
lv2468: R.Object = kv_cache[21] | |
lv960: R.Tensor((1, 1, 32, 80), dtype="float16") = lv959[2] | |
lv961 = R.call_tir(cls.fused_squeeze1, (lv960,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2468, lv961, sinfo_args=(R.Object,)) | |
lv2472: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2467, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2473: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2471, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2474 = R.call_tir(cls.reshape3, (lv2472,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2475 = R.call_tir(cls.reshape3, (lv2473,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2476 = R.call_tir(cls.transpose7, (lv2462,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2477 = R.call_tir(cls.transpose5, (lv2474,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2478 = R.call_tir(cls.transpose5, (lv2475,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv962 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2476, lv2477, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv963 = R.call_tir(cls.fused_softmax2_cast10, (lv962,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2487 = R.call_tir(cls.matmul9, (lv963, lv2478), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv964 = R.call_tir(cls.fused_transpose8_reshape8, (lv2487,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv965: R.Tensor((2560, 320), dtype="uint32") = model_params[169] | |
lv966: R.Tensor((2560, 80), dtype="float16") = model_params[170] | |
lv687: R.Tensor((2560,), dtype="float16") = model_params[171] | |
lv20_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv965, lv966, lv964, lv687, lv19_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2494 = R.call_tir(cls.cast7, (lv20_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv688: R.Tensor((2560,), dtype="float32") = model_params[164] | |
lv689: R.Tensor((2560,), dtype="float32") = model_params[165] | |
lv969 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2494, lv688, lv689), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2497: R.Tensor((1, 1, 2560), dtype="float16") = lv969 | |
lv970: R.Tensor((10240, 320), dtype="uint32") = model_params[172] | |
lv971: R.Tensor((10240, 80), dtype="float16") = model_params[173] | |
lv692: R.Tensor((10240,), dtype="float32") = model_params[174] | |
lv21 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv970, lv971, lv2497, lv692), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2503: R.Tensor((1, 1, 10240), dtype="float16") = lv21 | |
lv974: R.Tensor((2560, 1280), dtype="uint32") = model_params[175] | |
lv975: R.Tensor((2560, 320), dtype="float16") = model_params[176] | |
lv695: R.Tensor((2560,), dtype="float32") = model_params[177] | |
lv21_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv974, lv975, lv2503, lv695, lv20_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2510 = R.call_tir(cls.cast7, (lv21_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv696: R.Tensor((2560,), dtype="float32") = model_params[178] | |
lv697: R.Tensor((2560,), dtype="float32") = model_params[179] | |
lv978 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2510, lv696, lv697), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
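# ----- decoder layer boundary: kv_cache slots 22/23; same decode-step pattern as annotated above -----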
lv2513: R.Tensor((1, 1, 2560), dtype="float16") = lv978 | |
lv979: R.Tensor((7680, 320), dtype="uint32") = model_params[182] | |
lv980: R.Tensor((7680, 80), dtype="float16") = model_params[183] | |
lv700: R.Tensor((7680,), dtype="float16") = model_params[184] | |
lv22 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv979, lv980, lv2513, lv700), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv983 = R.call_tir(cls.fused_reshape7_split1, (lv22,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2519: R.Tensor((1, 1, 32, 80), dtype="float16") = lv983[0] | |
lv2520 = R.call_tir(cls.rotary_embedding1, (lv2519, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2521: R.Tensor((1, 1, 32, 80), dtype="float16") = lv983[1] | |
lv2522 = R.call_tir(cls.rotary_embedding1, (lv2521, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2523: R.Object = kv_cache[22] | |
lv2524 = R.call_tir(cls.squeeze1, (lv2522,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2525: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2523, lv2524, sinfo_args=(R.Object,)) | |
lv2526: R.Object = kv_cache[23] | |
lv984: R.Tensor((1, 1, 32, 80), dtype="float16") = lv983[2] | |
lv985 = R.call_tir(cls.fused_squeeze1, (lv984,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2529: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2526, lv985, sinfo_args=(R.Object,)) | |
lv2530: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2525, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2531: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2529, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2532 = R.call_tir(cls.reshape3, (lv2530,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2533 = R.call_tir(cls.reshape3, (lv2531,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2534 = R.call_tir(cls.transpose7, (lv2520,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2535 = R.call_tir(cls.transpose5, (lv2532,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2536 = R.call_tir(cls.transpose5, (lv2533,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv986 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2534, lv2535, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv987 = R.call_tir(cls.fused_softmax2_cast10, (lv986,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2545 = R.call_tir(cls.matmul9, (lv987, lv2536), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv988 = R.call_tir(cls.fused_transpose8_reshape8, (lv2545,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv989: R.Tensor((2560, 320), dtype="uint32") = model_params[185] | |
lv990: R.Tensor((2560, 80), dtype="float16") = model_params[186] | |
lv703: R.Tensor((2560,), dtype="float16") = model_params[187] | |
lv22_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv989, lv990, lv988, lv703, lv21_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2552 = R.call_tir(cls.cast7, (lv22_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv704: R.Tensor((2560,), dtype="float32") = model_params[180] | |
lv705: R.Tensor((2560,), dtype="float32") = model_params[181] | |
lv993 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2552, lv704, lv705), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2555: R.Tensor((1, 1, 2560), dtype="float16") = lv993 | |
lv994: R.Tensor((10240, 320), dtype="uint32") = model_params[188] | |
lv995: R.Tensor((10240, 80), dtype="float16") = model_params[189] | |
lv708: R.Tensor((10240,), dtype="float32") = model_params[190] | |
lv23 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv994, lv995, lv2555, lv708), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2561: R.Tensor((1, 1, 10240), dtype="float16") = lv23 | |
lv998: R.Tensor((2560, 1280), dtype="uint32") = model_params[191] | |
lv999: R.Tensor((2560, 320), dtype="float16") = model_params[192] | |
lv711_1: R.Tensor((2560,), dtype="float32") = model_params[193] | |
lv23_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv998, lv999, lv2561, lv711_1, lv22_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2568 = R.call_tir(cls.cast7, (lv23_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv712_1: R.Tensor((2560,), dtype="float32") = model_params[194] | |
lv713: R.Tensor((2560,), dtype="float32") = model_params[195] | |
lv1002 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2568, lv712_1, lv713), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
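# ----- decoder layer boundary: kv_cache slots 24/25; same decode-step pattern as annotated above -----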
lv2571: R.Tensor((1, 1, 2560), dtype="float16") = lv1002 | |
lv1003: R.Tensor((7680, 320), dtype="uint32") = model_params[198] | |
lv1004: R.Tensor((7680, 80), dtype="float16") = model_params[199] | |
lv716_1: R.Tensor((7680,), dtype="float16") = model_params[200] | |
lv24 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1003, lv1004, lv2571, lv716_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1007 = R.call_tir(cls.fused_reshape7_split1, (lv24,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2577: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1007[0] | |
lv2578 = R.call_tir(cls.rotary_embedding1, (lv2577, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2579: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1007[1] | |
lv2580 = R.call_tir(cls.rotary_embedding1, (lv2579, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2581: R.Object = kv_cache[24] | |
lv2582 = R.call_tir(cls.squeeze1, (lv2580,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2583: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2581, lv2582, sinfo_args=(R.Object,)) | |
lv2584: R.Object = kv_cache[25] | |
lv1008: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1007[2] | |
lv1009 = R.call_tir(cls.fused_squeeze1, (lv1008,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2587: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2584, lv1009, sinfo_args=(R.Object,)) | |
lv2588: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2583, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2589: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2587, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2590 = R.call_tir(cls.reshape3, (lv2588,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2591 = R.call_tir(cls.reshape3, (lv2589,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2592 = R.call_tir(cls.transpose7, (lv2578,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2593 = R.call_tir(cls.transpose5, (lv2590,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2594 = R.call_tir(cls.transpose5, (lv2591,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1010 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2592, lv2593, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1011 = R.call_tir(cls.fused_softmax2_cast10, (lv1010,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2603 = R.call_tir(cls.matmul9, (lv1011, lv2594), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1012 = R.call_tir(cls.fused_transpose8_reshape8, (lv2603,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1013: R.Tensor((2560, 320), dtype="uint32") = model_params[201] | |
lv1014: R.Tensor((2560, 80), dtype="float16") = model_params[202] | |
lv719_1: R.Tensor((2560,), dtype="float16") = model_params[203] | |
lv24_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1013, lv1014, lv1012, lv719_1, lv23_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2610 = R.call_tir(cls.cast7, (lv24_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv720_1: R.Tensor((2560,), dtype="float32") = model_params[196] | |
lv721_1: R.Tensor((2560,), dtype="float32") = model_params[197] | |
lv1017 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2610, lv720_1, lv721_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2613: R.Tensor((1, 1, 2560), dtype="float16") = lv1017 | |
lv1018: R.Tensor((10240, 320), dtype="uint32") = model_params[204] | |
lv1019: R.Tensor((10240, 80), dtype="float16") = model_params[205] | |
lv724_1: R.Tensor((10240,), dtype="float32") = model_params[206] | |
lv25 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1018, lv1019, lv2613, lv724_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2619: R.Tensor((1, 1, 10240), dtype="float16") = lv25 | |
lv1022: R.Tensor((2560, 1280), dtype="uint32") = model_params[207] | |
lv1023: R.Tensor((2560, 320), dtype="float16") = model_params[208] | |
lv727: R.Tensor((2560,), dtype="float32") = model_params[209] | |
lv25_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1022, lv1023, lv2619, lv727, lv24_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2626 = R.call_tir(cls.cast7, (lv25_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv728: R.Tensor((2560,), dtype="float32") = model_params[210] | |
lv729_1: R.Tensor((2560,), dtype="float32") = model_params[211] | |
lv1026 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2626, lv728, lv729_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
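# ----- decoder layer boundary: kv_cache slots 26/27; same decode-step pattern as annotated above -----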
lv2629: R.Tensor((1, 1, 2560), dtype="float16") = lv1026 | |
lv1027: R.Tensor((7680, 320), dtype="uint32") = model_params[214] | |
lv1028: R.Tensor((7680, 80), dtype="float16") = model_params[215] | |
lv732: R.Tensor((7680,), dtype="float16") = model_params[216] | |
lv26 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1027, lv1028, lv2629, lv732), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1031 = R.call_tir(cls.fused_reshape7_split1, (lv26,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2635: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1031[0] | |
lv2636 = R.call_tir(cls.rotary_embedding1, (lv2635, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2637: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1031[1] | |
lv2638 = R.call_tir(cls.rotary_embedding1, (lv2637, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2639: R.Object = kv_cache[26] | |
lv2640 = R.call_tir(cls.squeeze1, (lv2638,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2641: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2639, lv2640, sinfo_args=(R.Object,)) | |
lv2642: R.Object = kv_cache[27] | |
lv1032: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1031[2] | |
lv1033 = R.call_tir(cls.fused_squeeze1, (lv1032,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2645: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2642, lv1033, sinfo_args=(R.Object,)) | |
lv2646: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2641, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2647: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2645, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2648 = R.call_tir(cls.reshape3, (lv2646,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2649 = R.call_tir(cls.reshape3, (lv2647,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2650 = R.call_tir(cls.transpose7, (lv2636,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2651 = R.call_tir(cls.transpose5, (lv2648,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2652 = R.call_tir(cls.transpose5, (lv2649,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1034 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2650, lv2651, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1035 = R.call_tir(cls.fused_softmax2_cast10, (lv1034,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2661 = R.call_tir(cls.matmul9, (lv1035, lv2652), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1036 = R.call_tir(cls.fused_transpose8_reshape8, (lv2661,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1037: R.Tensor((2560, 320), dtype="uint32") = model_params[217] | |
lv1038: R.Tensor((2560, 80), dtype="float16") = model_params[218] | |
lv735_1: R.Tensor((2560,), dtype="float16") = model_params[219] | |
lv26_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1037, lv1038, lv1036, lv735_1, lv25_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2668 = R.call_tir(cls.cast7, (lv26_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv736: R.Tensor((2560,), dtype="float32") = model_params[212] | |
lv737: R.Tensor((2560,), dtype="float32") = model_params[213] | |
lv1041 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2668, lv736, lv737), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2671: R.Tensor((1, 1, 2560), dtype="float16") = lv1041 | |
lv1042: R.Tensor((10240, 320), dtype="uint32") = model_params[220] | |
lv1043: R.Tensor((10240, 80), dtype="float16") = model_params[221] | |
lv740_1: R.Tensor((10240,), dtype="float32") = model_params[222] | |
lv27 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1042, lv1043, lv2671, lv740_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2677: R.Tensor((1, 1, 10240), dtype="float16") = lv27 | |
lv1046: R.Tensor((2560, 1280), dtype="uint32") = model_params[223] | |
lv1047: R.Tensor((2560, 320), dtype="float16") = model_params[224] | |
lv743_1: R.Tensor((2560,), dtype="float32") = model_params[225] | |
lv27_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1046, lv1047, lv2677, lv743_1, lv26_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2684 = R.call_tir(cls.cast7, (lv27_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv744_1: R.Tensor((2560,), dtype="float32") = model_params[226] | |
lv745_1: R.Tensor((2560,), dtype="float32") = model_params[227] | |
lv1050 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2684, lv744_1, lv745_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
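# ----- decoder layer boundary: kv_cache slots 28/29; same decode-step pattern as annotated above -----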
lv2687: R.Tensor((1, 1, 2560), dtype="float16") = lv1050 | |
lv1051: R.Tensor((7680, 320), dtype="uint32") = model_params[230] | |
lv1052: R.Tensor((7680, 80), dtype="float16") = model_params[231] | |
lv748_1: R.Tensor((7680,), dtype="float16") = model_params[232] | |
lv28 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1051, lv1052, lv2687, lv748_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1055 = R.call_tir(cls.fused_reshape7_split1, (lv28,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2693: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1055[0] | |
lv2694 = R.call_tir(cls.rotary_embedding1, (lv2693, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2695: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1055[1] | |
lv2696 = R.call_tir(cls.rotary_embedding1, (lv2695, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2697: R.Object = kv_cache[28] | |
lv2698 = R.call_tir(cls.squeeze1, (lv2696,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2699: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2697, lv2698, sinfo_args=(R.Object,)) | |
lv2700: R.Object = kv_cache[29] | |
lv1056: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1055[2] | |
lv1057 = R.call_tir(cls.fused_squeeze1, (lv1056,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2703: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2700, lv1057, sinfo_args=(R.Object,)) | |
lv2704: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2699, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2705: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2703, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2706 = R.call_tir(cls.reshape3, (lv2704,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2707 = R.call_tir(cls.reshape3, (lv2705,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2708 = R.call_tir(cls.transpose7, (lv2694,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2709 = R.call_tir(cls.transpose5, (lv2706,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2710 = R.call_tir(cls.transpose5, (lv2707,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1058 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2708, lv2709, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1059 = R.call_tir(cls.fused_softmax2_cast10, (lv1058,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2719 = R.call_tir(cls.matmul9, (lv1059, lv2710), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1060 = R.call_tir(cls.fused_transpose8_reshape8, (lv2719,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1061: R.Tensor((2560, 320), dtype="uint32") = model_params[233] | |
lv1062: R.Tensor((2560, 80), dtype="float16") = model_params[234] | |
lv751: R.Tensor((2560,), dtype="float16") = model_params[235] | |
lv28_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1061, lv1062, lv1060, lv751, lv27_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2726 = R.call_tir(cls.cast7, (lv28_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv752: R.Tensor((2560,), dtype="float32") = model_params[228] | |
lv753_1: R.Tensor((2560,), dtype="float32") = model_params[229] | |
lv1065 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2726, lv752, lv753_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2729: R.Tensor((1, 1, 2560), dtype="float16") = lv1065 | |
lv1066: R.Tensor((10240, 320), dtype="uint32") = model_params[236] | |
lv1067: R.Tensor((10240, 80), dtype="float16") = model_params[237] | |
lv756: R.Tensor((10240,), dtype="float32") = model_params[238] | |
lv29 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1066, lv1067, lv2729, lv756), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2735: R.Tensor((1, 1, 10240), dtype="float16") = lv29 | |
lv1070: R.Tensor((2560, 1280), dtype="uint32") = model_params[239] | |
lv1071: R.Tensor((2560, 320), dtype="float16") = model_params[240] | |
lv759_1: R.Tensor((2560,), dtype="float32") = model_params[241] | |
lv29_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1070, lv1071, lv2735, lv759_1, lv28_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2742 = R.call_tir(cls.cast7, (lv29_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv760: R.Tensor((2560,), dtype="float32") = model_params[242] | |
lv761: R.Tensor((2560,), dtype="float32") = model_params[243] | |
lv1074 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2742, lv760, lv761), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
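# ----- decoder layer boundary: kv_cache slots 30/31; same decode-step pattern as annotated above -----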
lv2745: R.Tensor((1, 1, 2560), dtype="float16") = lv1074 | |
lv1075: R.Tensor((7680, 320), dtype="uint32") = model_params[246] | |
lv1076: R.Tensor((7680, 80), dtype="float16") = model_params[247] | |
lv764_1: R.Tensor((7680,), dtype="float16") = model_params[248] | |
lv30 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1075, lv1076, lv2745, lv764_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1079 = R.call_tir(cls.fused_reshape7_split1, (lv30,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2751: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1079[0] | |
lv2752 = R.call_tir(cls.rotary_embedding1, (lv2751, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2753: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1079[1] | |
lv2754 = R.call_tir(cls.rotary_embedding1, (lv2753, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2755: R.Object = kv_cache[30] | |
lv2756 = R.call_tir(cls.squeeze1, (lv2754,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2757: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2755, lv2756, sinfo_args=(R.Object,)) | |
lv2758: R.Object = kv_cache[31] | |
lv1080: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1079[2] | |
lv1081 = R.call_tir(cls.fused_squeeze1, (lv1080,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2761: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2758, lv1081, sinfo_args=(R.Object,)) | |
lv2762: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2757, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2763: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2761, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2764 = R.call_tir(cls.reshape3, (lv2762,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2765 = R.call_tir(cls.reshape3, (lv2763,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2766 = R.call_tir(cls.transpose7, (lv2752,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2767 = R.call_tir(cls.transpose5, (lv2764,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2768 = R.call_tir(cls.transpose5, (lv2765,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1082 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2766, lv2767, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1083 = R.call_tir(cls.fused_softmax2_cast10, (lv1082,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2777 = R.call_tir(cls.matmul9, (lv1083, lv2768), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1084 = R.call_tir(cls.fused_transpose8_reshape8, (lv2777,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1085: R.Tensor((2560, 320), dtype="uint32") = model_params[249] | |
lv1086: R.Tensor((2560, 80), dtype="float16") = model_params[250] | |
lv767_1: R.Tensor((2560,), dtype="float16") = model_params[251] | |
lv30_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1085, lv1086, lv1084, lv767_1, lv29_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2784 = R.call_tir(cls.cast7, (lv30_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv768_1: R.Tensor((2560,), dtype="float32") = model_params[244] | |
lv769_1: R.Tensor((2560,), dtype="float32") = model_params[245] | |
lv1089 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2784, lv768_1, lv769_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2787: R.Tensor((1, 1, 2560), dtype="float16") = lv1089 | |
lv1090: R.Tensor((10240, 320), dtype="uint32") = model_params[252] | |
lv1091: R.Tensor((10240, 80), dtype="float16") = model_params[253] | |
lv772_1: R.Tensor((10240,), dtype="float32") = model_params[254] | |
lv31 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1090, lv1091, lv2787, lv772_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2793: R.Tensor((1, 1, 10240), dtype="float16") = lv31 | |
lv1094: R.Tensor((2560, 1280), dtype="uint32") = model_params[255] | |
lv1095: R.Tensor((2560, 320), dtype="float16") = model_params[256] | |
lv775: R.Tensor((2560,), dtype="float32") = model_params[257] | |
lv31_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1094, lv1095, lv2793, lv775, lv30_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2800 = R.call_tir(cls.cast7, (lv31_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv776: R.Tensor((2560,), dtype="float32") = model_params[258] | |
lv777_1: R.Tensor((2560,), dtype="float32") = model_params[259] | |
lv1098 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2800, lv776, lv777_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
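# ----- decoder layer boundary: kv_cache slots 32/33; same decode-step pattern as annotated above -----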
lv2803: R.Tensor((1, 1, 2560), dtype="float16") = lv1098 | |
lv1099: R.Tensor((7680, 320), dtype="uint32") = model_params[262] | |
lv1100: R.Tensor((7680, 80), dtype="float16") = model_params[263] | |
lv780: R.Tensor((7680,), dtype="float16") = model_params[264] | |
lv32 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1099, lv1100, lv2803, lv780), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1103 = R.call_tir(cls.fused_reshape7_split1, (lv32,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2809: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1103[0] | |
lv2810 = R.call_tir(cls.rotary_embedding1, (lv2809, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2811: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1103[1] | |
lv2812 = R.call_tir(cls.rotary_embedding1, (lv2811, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2813: R.Object = kv_cache[32] | |
lv2814 = R.call_tir(cls.squeeze1, (lv2812,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2815: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2813, lv2814, sinfo_args=(R.Object,)) | |
lv2816: R.Object = kv_cache[33] | |
lv1104: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1103[2] | |
lv1105 = R.call_tir(cls.fused_squeeze1, (lv1104,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2819: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2816, lv1105, sinfo_args=(R.Object,)) | |
lv2820: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2815, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2821: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2819, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2822 = R.call_tir(cls.reshape3, (lv2820,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2823 = R.call_tir(cls.reshape3, (lv2821,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2824 = R.call_tir(cls.transpose7, (lv2810,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2825 = R.call_tir(cls.transpose5, (lv2822,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2826 = R.call_tir(cls.transpose5, (lv2823,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1106 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2824, lv2825, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1107 = R.call_tir(cls.fused_softmax2_cast10, (lv1106,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2835 = R.call_tir(cls.matmul9, (lv1107, lv2826), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1108 = R.call_tir(cls.fused_transpose8_reshape8, (lv2835,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1109: R.Tensor((2560, 320), dtype="uint32") = model_params[265] | |
lv1110: R.Tensor((2560, 80), dtype="float16") = model_params[266] | |
lv783_1: R.Tensor((2560,), dtype="float16") = model_params[267] | |
lv32_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1109, lv1110, lv1108, lv783_1, lv31_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2842 = R.call_tir(cls.cast7, (lv32_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv784: R.Tensor((2560,), dtype="float32") = model_params[260] | |
lv785: R.Tensor((2560,), dtype="float32") = model_params[261] | |
lv1113 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2842, lv784, lv785), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2845: R.Tensor((1, 1, 2560), dtype="float16") = lv1113 | |
lv1114: R.Tensor((10240, 320), dtype="uint32") = model_params[268] | |
lv1115: R.Tensor((10240, 80), dtype="float16") = model_params[269] | |
lv788_1: R.Tensor((10240,), dtype="float32") = model_params[270] | |
lv33 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1114, lv1115, lv2845, lv788_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2851: R.Tensor((1, 1, 10240), dtype="float16") = lv33 | |
lv1118: R.Tensor((2560, 1280), dtype="uint32") = model_params[271] | |
lv1119: R.Tensor((2560, 320), dtype="float16") = model_params[272] | |
lv791_1: R.Tensor((2560,), dtype="float32") = model_params[273] | |
lv33_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1118, lv1119, lv2851, lv791_1, lv32_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2858 = R.call_tir(cls.cast7, (lv33_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv792_1: R.Tensor((2560,), dtype="float32") = model_params[274] | |
lv793_1: R.Tensor((2560,), dtype="float32") = model_params[275] | |
lv1122 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2858, lv792_1, lv793_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
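# ----- decoder layer boundary: kv_cache slots 34/35; same decode-step pattern as annotated above -----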
lv2861: R.Tensor((1, 1, 2560), dtype="float16") = lv1122 | |
lv1123: R.Tensor((7680, 320), dtype="uint32") = model_params[278] | |
lv1124: R.Tensor((7680, 80), dtype="float16") = model_params[279] | |
lv796_1: R.Tensor((7680,), dtype="float16") = model_params[280] | |
lv34 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1123, lv1124, lv2861, lv796_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1127 = R.call_tir(cls.fused_reshape7_split1, (lv34,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2867: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1127[0] | |
lv2868 = R.call_tir(cls.rotary_embedding1, (lv2867, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2869: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1127[1] | |
lv2870 = R.call_tir(cls.rotary_embedding1, (lv2869, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2871: R.Object = kv_cache[34] | |
lv2872 = R.call_tir(cls.squeeze1, (lv2870,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2873: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2871, lv2872, sinfo_args=(R.Object,)) | |
lv2874: R.Object = kv_cache[35] | |
lv1128: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1127[2] | |
lv1129 = R.call_tir(cls.fused_squeeze1, (lv1128,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2877: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2874, lv1129, sinfo_args=(R.Object,)) | |
lv2878: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2873, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2879: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2877, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2880 = R.call_tir(cls.reshape3, (lv2878,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2881 = R.call_tir(cls.reshape3, (lv2879,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2882 = R.call_tir(cls.transpose7, (lv2868,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2883 = R.call_tir(cls.transpose5, (lv2880,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2884 = R.call_tir(cls.transpose5, (lv2881,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1130 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2882, lv2883, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1131 = R.call_tir(cls.fused_softmax2_cast10, (lv1130,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2893 = R.call_tir(cls.matmul9, (lv1131, lv2884), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1132 = R.call_tir(cls.fused_transpose8_reshape8, (lv2893,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1133: R.Tensor((2560, 320), dtype="uint32") = model_params[281] | |
lv1134: R.Tensor((2560, 80), dtype="float16") = model_params[282] | |
lv799: R.Tensor((2560,), dtype="float16") = model_params[283] | |
lv34_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1133, lv1134, lv1132, lv799, lv33_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2900 = R.call_tir(cls.cast7, (lv34_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv800: R.Tensor((2560,), dtype="float32") = model_params[276] | |
lv801_1: R.Tensor((2560,), dtype="float32") = model_params[277] | |
lv1137 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2900, lv800, lv801_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2903: R.Tensor((1, 1, 2560), dtype="float16") = lv1137 | |
lv1138: R.Tensor((10240, 320), dtype="uint32") = model_params[284] | |
lv1139: R.Tensor((10240, 80), dtype="float16") = model_params[285] | |
lv804: R.Tensor((10240,), dtype="float32") = model_params[286] | |
lv35 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1138, lv1139, lv2903, lv804), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2909: R.Tensor((1, 1, 10240), dtype="float16") = lv35 | |
lv1142: R.Tensor((2560, 1280), dtype="uint32") = model_params[287] | |
lv1143: R.Tensor((2560, 320), dtype="float16") = model_params[288] | |
lv807_1: R.Tensor((2560,), dtype="float32") = model_params[289] | |
lv35_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1142, lv1143, lv2909, lv807_1, lv34_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2916 = R.call_tir(cls.cast7, (lv35_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv808: R.Tensor((2560,), dtype="float32") = model_params[290] | |
lv809: R.Tensor((2560,), dtype="float32") = model_params[291] | |
lv1146 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2916, lv808, lv809), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
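# ----- decoder layer boundary: kv_cache slots 36/37; same decode-step pattern as annotated above -----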
lv2919: R.Tensor((1, 1, 2560), dtype="float16") = lv1146 | |
lv1147: R.Tensor((7680, 320), dtype="uint32") = model_params[294] | |
lv1148: R.Tensor((7680, 80), dtype="float16") = model_params[295] | |
lv812_1: R.Tensor((7680,), dtype="float16") = model_params[296] | |
lv36 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1147, lv1148, lv2919, lv812_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1151 = R.call_tir(cls.fused_reshape7_split1, (lv36,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2925: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1151[0] | |
lv2926 = R.call_tir(cls.rotary_embedding1, (lv2925, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2927: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1151[1] | |
lv2928 = R.call_tir(cls.rotary_embedding1, (lv2927, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2929: R.Object = kv_cache[36] | |
lv2930 = R.call_tir(cls.squeeze1, (lv2928,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2931: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2929, lv2930, sinfo_args=(R.Object,)) | |
lv2932: R.Object = kv_cache[37] | |
lv1152: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1151[2] | |
lv1153 = R.call_tir(cls.fused_squeeze1, (lv1152,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2935: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2932, lv1153, sinfo_args=(R.Object,)) | |
lv2936: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2931, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2937: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2935, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2938 = R.call_tir(cls.reshape3, (lv2936,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2939 = R.call_tir(cls.reshape3, (lv2937,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2940 = R.call_tir(cls.transpose7, (lv2926,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2941 = R.call_tir(cls.transpose5, (lv2938,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv2942 = R.call_tir(cls.transpose5, (lv2939,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1154 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2940, lv2941, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1155 = R.call_tir(cls.fused_softmax2_cast10, (lv1154,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv2951 = R.call_tir(cls.matmul9, (lv1155, lv2942), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1156 = R.call_tir(cls.fused_transpose8_reshape8, (lv2951,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1157: R.Tensor((2560, 320), dtype="uint32") = model_params[297] | |
lv1158: R.Tensor((2560, 80), dtype="float16") = model_params[298] | |
lv815_1: R.Tensor((2560,), dtype="float16") = model_params[299] | |
lv36_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1157, lv1158, lv1156, lv815_1, lv35_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2958 = R.call_tir(cls.cast7, (lv36_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv816_1: R.Tensor((2560,), dtype="float32") = model_params[292] | |
lv817_1: R.Tensor((2560,), dtype="float32") = model_params[293] | |
lv1161 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2958, lv816_1, lv817_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2961: R.Tensor((1, 1, 2560), dtype="float16") = lv1161 | |
lv1162: R.Tensor((10240, 320), dtype="uint32") = model_params[300] | |
lv1163: R.Tensor((10240, 80), dtype="float16") = model_params[301] | |
lv820_1: R.Tensor((10240,), dtype="float32") = model_params[302] | |
lv37 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1162, lv1163, lv2961, lv820_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv2967: R.Tensor((1, 1, 10240), dtype="float16") = lv37 | |
lv1166: R.Tensor((2560, 1280), dtype="uint32") = model_params[303] | |
lv1167: R.Tensor((2560, 320), dtype="float16") = model_params[304] | |
lv823: R.Tensor((2560,), dtype="float32") = model_params[305] | |
lv37_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1166, lv1167, lv2967, lv823, lv36_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv2974 = R.call_tir(cls.cast7, (lv37_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv824: R.Tensor((2560,), dtype="float32") = model_params[306] | |
lv825_1: R.Tensor((2560,), dtype="float32") = model_params[307] | |
lv1170 = R.call_tir(cls.fused_layer_norm1_cast8, (lv2974, lv824, lv825_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
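# ----- decoder layer boundary: kv_cache slots 38/39; same decode-step pattern as annotated above -----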
lv2977: R.Tensor((1, 1, 2560), dtype="float16") = lv1170 | |
lv1171: R.Tensor((7680, 320), dtype="uint32") = model_params[310] | |
lv1172: R.Tensor((7680, 80), dtype="float16") = model_params[311] | |
lv828: R.Tensor((7680,), dtype="float16") = model_params[312] | |
lv38 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1171, lv1172, lv2977, lv828), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1175 = R.call_tir(cls.fused_reshape7_split1, (lv38,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv2983: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1175[0] | |
lv2984 = R.call_tir(cls.rotary_embedding1, (lv2983, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2985: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1175[1] | |
lv2986 = R.call_tir(cls.rotary_embedding1, (lv2985, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv2987: R.Object = kv_cache[38] | |
lv2988 = R.call_tir(cls.squeeze1, (lv2986,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2989: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2987, lv2988, sinfo_args=(R.Object,)) | |
lv2990: R.Object = kv_cache[39] | |
lv1176: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1175[2] | |
lv1177 = R.call_tir(cls.fused_squeeze1, (lv1176,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv2993: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2990, lv1177, sinfo_args=(R.Object,)) | |
lv2994: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2989, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2995: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2993, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv2996 = R.call_tir(cls.reshape3, (lv2994,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2997 = R.call_tir(cls.reshape3, (lv2995,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv2998 = R.call_tir(cls.transpose7, (lv2984,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv2999 = R.call_tir(cls.transpose5, (lv2996,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv3000 = R.call_tir(cls.transpose5, (lv2997,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1178 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv2998, lv2999, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1179 = R.call_tir(cls.fused_softmax2_cast10, (lv1178,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv3009 = R.call_tir(cls.matmul9, (lv1179, lv3000), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1180 = R.call_tir(cls.fused_transpose8_reshape8, (lv3009,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1181: R.Tensor((2560, 320), dtype="uint32") = model_params[313] | |
lv1182: R.Tensor((2560, 80), dtype="float16") = model_params[314] | |
lv831_1: R.Tensor((2560,), dtype="float16") = model_params[315] | |
lv38_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1181, lv1182, lv1180, lv831_1, lv37_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv3016 = R.call_tir(cls.cast7, (lv38_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv832: R.Tensor((2560,), dtype="float32") = model_params[308] | |
lv833: R.Tensor((2560,), dtype="float32") = model_params[309] | |
lv1185 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3016, lv832, lv833), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv3019: R.Tensor((1, 1, 2560), dtype="float16") = lv1185 | |
lv1186: R.Tensor((10240, 320), dtype="uint32") = model_params[316] | |
lv1187: R.Tensor((10240, 80), dtype="float16") = model_params[317] | |
lv836_1: R.Tensor((10240,), dtype="float32") = model_params[318] | |
lv39 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1186, lv1187, lv3019, lv836_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16")) | |
lv3025: R.Tensor((1, 1, 10240), dtype="float16") = lv39 | |
lv1190: R.Tensor((2560, 1280), dtype="uint32") = model_params[319] | |
lv1191: R.Tensor((2560, 320), dtype="float16") = model_params[320] | |
lv839_1: R.Tensor((2560,), dtype="float32") = model_params[321] | |
lv39_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1190, lv1191, lv3025, lv839_1, lv38_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv3032 = R.call_tir(cls.cast7, (lv39_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv840_1: R.Tensor((2560,), dtype="float32") = model_params[322] | |
lv841_1: R.Tensor((2560,), dtype="float32") = model_params[323] | |
lv1194 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3032, lv840_1, lv841_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
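# ----- decoder layer boundary: kv_cache slots 40/41; same decode-step pattern as annotated above -----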
lv3035: R.Tensor((1, 1, 2560), dtype="float16") = lv1194 | |
lv1195: R.Tensor((7680, 320), dtype="uint32") = model_params[326] | |
lv1196: R.Tensor((7680, 80), dtype="float16") = model_params[327] | |
lv844_1: R.Tensor((7680,), dtype="float16") = model_params[328] | |
lv40 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1195, lv1196, lv3035, lv844_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16")) | |
lv1199 = R.call_tir(cls.fused_reshape7_split1, (lv40,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")]) | |
lv3041: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1199[0] | |
lv3042 = R.call_tir(cls.rotary_embedding1, (lv3041, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv3043: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1199[1] | |
lv3044 = R.call_tir(cls.rotary_embedding1, (lv3043, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n])) | |
lv3045: R.Object = kv_cache[40] | |
lv3046 = R.call_tir(cls.squeeze1, (lv3044,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv3047: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3045, lv3046, sinfo_args=(R.Object,)) | |
lv3048: R.Object = kv_cache[41] | |
lv1200: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1199[2] | |
lv1201 = R.call_tir(cls.fused_squeeze1, (lv1200,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16")) | |
lv3051: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3048, lv1201, sinfo_args=(R.Object,)) | |
lv3052: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3047, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv3053: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3051, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),)) | |
lv3054 = R.call_tir(cls.reshape3, (lv3052,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv3055 = R.call_tir(cls.reshape3, (lv3053,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv3056 = R.call_tir(cls.transpose7, (lv3042,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv3057 = R.call_tir(cls.transpose5, (lv3054,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv3058 = R.call_tir(cls.transpose5, (lv3055,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1202 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3056, lv3057, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) | |
lv1203 = R.call_tir(cls.fused_softmax2_cast10, (lv1202,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) | |
lv3067 = R.call_tir(cls.matmul9, (lv1203, lv3058), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16")) | |
lv1204 = R.call_tir(cls.fused_transpose8_reshape8, (lv3067,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv1205: R.Tensor((2560, 320), dtype="uint32") = model_params[329] | |
lv1206: R.Tensor((2560, 80), dtype="float16") = model_params[330] | |
lv847: R.Tensor((2560,), dtype="float16") = model_params[331] | |
lv40_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1205, lv1206, lv1204, lv847, lv39_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16")) | |
lv3074 = R.call_tir(cls.cast7, (lv40_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv848: R.Tensor((2560,), dtype="float32") = model_params[324]
lv849_1: R.Tensor((2560,), dtype="float32") = model_params[325]
lv1209 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3074, lv848, lv849_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3077: R.Tensor((1, 1, 2560), dtype="float16") = lv1209
lv1210: R.Tensor((10240, 320), dtype="uint32") = model_params[332]
lv1211: R.Tensor((10240, 80), dtype="float16") = model_params[333]
lv852: R.Tensor((10240,), dtype="float32") = model_params[334]
lv41 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1210, lv1211, lv3077, lv852), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3083: R.Tensor((1, 1, 10240), dtype="float16") = lv41
lv1214: R.Tensor((2560, 1280), dtype="uint32") = model_params[335]
lv1215: R.Tensor((2560, 320), dtype="float16") = model_params[336]
lv855_1: R.Tensor((2560,), dtype="float32") = model_params[337]
lv41_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1214, lv1215, lv3083, lv855_1, lv40_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
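# NOTE: this is the feed-forward half of the layer. cast7 lifts the hidden state to
# float32 for layer norm (fused_layer_norm1_cast8 casts back to float16), the up
# projection expands 2560 -> 10240 with a fused GELU, and the down projection
# returns to 2560, with cast8/cast12 round-tripping through float32 for the residual
# add onto the attention output (lv40_1). The weight layout suggests group
# quantization: each (N, 320) uint32 tensor packs eight 4-bit values per word
# (320 * 8 = 2560 inputs) and the matching (N, 80) tensor holds one scale per
# 32-element group, dequantized on the fly by the fused_decode* prologue of each
# kernel. A rough NumPy sketch of that assumed decode (the zero-point of 7 is a
# guess; the dump does not spell it out):
#     q = np.stack([(w >> (4 * i)) & 0xF for i in range(8)], -1).reshape(N, 2560)
#     dequant = (q.astype(np.float16) - 7) * scales.repeat(32, axis=1)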
lv3090 = R.call_tir(cls.cast7, (lv41_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv856: R.Tensor((2560,), dtype="float32") = model_params[338]
lv857: R.Tensor((2560,), dtype="float32") = model_params[339]
lv1218 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3090, lv856, lv857), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3093: R.Tensor((1, 1, 2560), dtype="float16") = lv1218
lv1219: R.Tensor((7680, 320), dtype="uint32") = model_params[342]
lv1220: R.Tensor((7680, 80), dtype="float16") = model_params[343]
lv860_1: R.Tensor((7680,), dtype="float16") = model_params[344]
lv42 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1219, lv1220, lv3093, lv860_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1223 = R.call_tir(cls.fused_reshape7_split1, (lv42,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3099: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1223[0]
lv3100 = R.call_tir(cls.rotary_embedding1, (lv3099, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3101: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1223[1]
lv3102 = R.call_tir(cls.rotary_embedding1, (lv3101, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3103: R.Object = kv_cache[42]
lv3104 = R.call_tir(cls.squeeze1, (lv3102,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3105: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3103, lv3104, sinfo_args=(R.Object,))
lv3106: R.Object = kv_cache[43]
lv1224: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1223[2]
lv1225 = R.call_tir(cls.fused_squeeze1, (lv1224,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3109: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3106, lv1225, sinfo_args=(R.Object,))
lv3110: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3105, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3111: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3109, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3112 = R.call_tir(cls.reshape3, (lv3110,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3113 = R.call_tir(cls.reshape3, (lv3111,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3114 = R.call_tir(cls.transpose7, (lv3100,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3115 = R.call_tir(cls.transpose5, (lv3112,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3116 = R.call_tir(cls.transpose5, (lv3113,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1226 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3114, lv3115, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1227 = R.call_tir(cls.fused_softmax2_cast10, (lv1226,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3125 = R.call_tir(cls.matmul9, (lv1227, lv3116), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1228 = R.call_tir(cls.fused_transpose8_reshape8, (lv3125,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1229: R.Tensor((2560, 320), dtype="uint32") = model_params[345]
lv1230: R.Tensor((2560, 80), dtype="float16") = model_params[346]
lv863_1: R.Tensor((2560,), dtype="float16") = model_params[347]
lv42_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1229, lv1230, lv1228, lv863_1, lv41_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3132 = R.call_tir(cls.cast7, (lv42_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv864_1: R.Tensor((2560,), dtype="float32") = model_params[340]
lv865_1: R.Tensor((2560,), dtype="float32") = model_params[341]
lv1233 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3132, lv864_1, lv865_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3135: R.Tensor((1, 1, 2560), dtype="float16") = lv1233
lv1234: R.Tensor((10240, 320), dtype="uint32") = model_params[348]
lv1235: R.Tensor((10240, 80), dtype="float16") = model_params[349]
lv868_1: R.Tensor((10240,), dtype="float32") = model_params[350]
lv43 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1234, lv1235, lv3135, lv868_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3141: R.Tensor((1, 1, 10240), dtype="float16") = lv43
lv1238: R.Tensor((2560, 1280), dtype="uint32") = model_params[351]
lv1239: R.Tensor((2560, 320), dtype="float16") = model_params[352]
lv871: R.Tensor((2560,), dtype="float32") = model_params[353]
lv43_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1238, lv1239, lv3141, lv871, lv42_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3148 = R.call_tir(cls.cast7, (lv43_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv872: R.Tensor((2560,), dtype="float32") = model_params[354]
lv873_1: R.Tensor((2560,), dtype="float32") = model_params[355]
lv1242 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3148, lv872, lv873_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3151: R.Tensor((1, 1, 2560), dtype="float16") = lv1242
lv1243: R.Tensor((7680, 320), dtype="uint32") = model_params[358]
lv1244: R.Tensor((7680, 80), dtype="float16") = model_params[359]
lv876: R.Tensor((7680,), dtype="float16") = model_params[360]
lv44 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1243, lv1244, lv3151, lv876), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1247 = R.call_tir(cls.fused_reshape7_split1, (lv44,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3157: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1247[0]
lv3158 = R.call_tir(cls.rotary_embedding1, (lv3157, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3159: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1247[1]
lv3160 = R.call_tir(cls.rotary_embedding1, (lv3159, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3161: R.Object = kv_cache[44]
lv3162 = R.call_tir(cls.squeeze1, (lv3160,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3163: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3161, lv3162, sinfo_args=(R.Object,))
lv3164: R.Object = kv_cache[45]
lv1248: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1247[2]
lv1249 = R.call_tir(cls.fused_squeeze1, (lv1248,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3167: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3164, lv1249, sinfo_args=(R.Object,))
lv3168: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3163, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3169: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3167, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3170 = R.call_tir(cls.reshape3, (lv3168,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3171 = R.call_tir(cls.reshape3, (lv3169,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3172 = R.call_tir(cls.transpose7, (lv3158,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3173 = R.call_tir(cls.transpose5, (lv3170,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3174 = R.call_tir(cls.transpose5, (lv3171,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1250 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3172, lv3173, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1251 = R.call_tir(cls.fused_softmax2_cast10, (lv1250,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3183 = R.call_tir(cls.matmul9, (lv1251, lv3174), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1252 = R.call_tir(cls.fused_transpose8_reshape8, (lv3183,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1253: R.Tensor((2560, 320), dtype="uint32") = model_params[361]
lv1254: R.Tensor((2560, 80), dtype="float16") = model_params[362]
lv879_1: R.Tensor((2560,), dtype="float16") = model_params[363]
lv44_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1253, lv1254, lv1252, lv879_1, lv43_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3190 = R.call_tir(cls.cast7, (lv44_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv880: R.Tensor((2560,), dtype="float32") = model_params[356]
lv881: R.Tensor((2560,), dtype="float32") = model_params[357]
lv1257 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3190, lv880, lv881), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3193: R.Tensor((1, 1, 2560), dtype="float16") = lv1257
lv1258: R.Tensor((10240, 320), dtype="uint32") = model_params[364]
lv1259: R.Tensor((10240, 80), dtype="float16") = model_params[365]
lv884_1: R.Tensor((10240,), dtype="float32") = model_params[366]
lv45 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1258, lv1259, lv3193, lv884_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3199: R.Tensor((1, 1, 10240), dtype="float16") = lv45
lv1262: R.Tensor((2560, 1280), dtype="uint32") = model_params[367]
lv1263: R.Tensor((2560, 320), dtype="float16") = model_params[368]
lv887_1: R.Tensor((2560,), dtype="float32") = model_params[369]
lv45_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1262, lv1263, lv3199, lv887_1, lv44_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3206 = R.call_tir(cls.cast7, (lv45_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv888_1: R.Tensor((2560,), dtype="float32") = model_params[370]
lv889_1: R.Tensor((2560,), dtype="float32") = model_params[371]
lv1266 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3206, lv888_1, lv889_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3209: R.Tensor((1, 1, 2560), dtype="float16") = lv1266
lv1267: R.Tensor((7680, 320), dtype="uint32") = model_params[374]
lv1268: R.Tensor((7680, 80), dtype="float16") = model_params[375]
lv892_1: R.Tensor((7680,), dtype="float16") = model_params[376]
lv46 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1267, lv1268, lv3209, lv892_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1271 = R.call_tir(cls.fused_reshape7_split1, (lv46,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3215: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1271[0]
lv3216 = R.call_tir(cls.rotary_embedding1, (lv3215, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3217: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1271[1]
lv3218 = R.call_tir(cls.rotary_embedding1, (lv3217, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3219: R.Object = kv_cache[46]
lv3220 = R.call_tir(cls.squeeze1, (lv3218,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3219, lv3220, sinfo_args=(R.Object,))
lv3222: R.Object = kv_cache[47]
lv1272: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1271[2]
lv1273 = R.call_tir(cls.fused_squeeze1, (lv1272,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3225: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3222, lv1273, sinfo_args=(R.Object,))
lv3226: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3221, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3227: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3225, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3228 = R.call_tir(cls.reshape3, (lv3226,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3229 = R.call_tir(cls.reshape3, (lv3227,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3230 = R.call_tir(cls.transpose7, (lv3216,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3231 = R.call_tir(cls.transpose5, (lv3228,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3232 = R.call_tir(cls.transpose5, (lv3229,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1274 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3230, lv3231, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1275 = R.call_tir(cls.fused_softmax2_cast10, (lv1274,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3241 = R.call_tir(cls.matmul9, (lv1275, lv3232), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1276 = R.call_tir(cls.fused_transpose8_reshape8, (lv3241,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1277: R.Tensor((2560, 320), dtype="uint32") = model_params[377]
lv1278: R.Tensor((2560, 80), dtype="float16") = model_params[378]
lv895: R.Tensor((2560,), dtype="float16") = model_params[379]
lv46_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1277, lv1278, lv1276, lv895, lv45_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3248 = R.call_tir(cls.cast7, (lv46_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv896: R.Tensor((2560,), dtype="float32") = model_params[372]
lv897_1: R.Tensor((2560,), dtype="float32") = model_params[373]
lv1281 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3248, lv896, lv897_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3251: R.Tensor((1, 1, 2560), dtype="float16") = lv1281
lv1282: R.Tensor((10240, 320), dtype="uint32") = model_params[380]
lv1283: R.Tensor((10240, 80), dtype="float16") = model_params[381]
lv900: R.Tensor((10240,), dtype="float32") = model_params[382]
lv47 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1282, lv1283, lv3251, lv900), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3257: R.Tensor((1, 1, 10240), dtype="float16") = lv47
lv1286: R.Tensor((2560, 1280), dtype="uint32") = model_params[383]
lv1287: R.Tensor((2560, 320), dtype="float16") = model_params[384]
lv903_1: R.Tensor((2560,), dtype="float32") = model_params[385]
lv47_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1286, lv1287, lv3257, lv903_1, lv46_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3264 = R.call_tir(cls.cast7, (lv47_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv904: R.Tensor((2560,), dtype="float32") = model_params[386]
lv905: R.Tensor((2560,), dtype="float32") = model_params[387]
lv1290 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3264, lv904, lv905), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3267: R.Tensor((1, 1, 2560), dtype="float16") = lv1290
lv1291: R.Tensor((7680, 320), dtype="uint32") = model_params[390]
lv1292: R.Tensor((7680, 80), dtype="float16") = model_params[391]
lv908_1: R.Tensor((7680,), dtype="float16") = model_params[392]
lv48 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1291, lv1292, lv3267, lv908_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1295 = R.call_tir(cls.fused_reshape7_split1, (lv48,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3273: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1295[0]
lv3274 = R.call_tir(cls.rotary_embedding1, (lv3273, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3275: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1295[1]
lv3276 = R.call_tir(cls.rotary_embedding1, (lv3275, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3277: R.Object = kv_cache[48]
lv3278 = R.call_tir(cls.squeeze1, (lv3276,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3279: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3277, lv3278, sinfo_args=(R.Object,))
lv3280: R.Object = kv_cache[49]
lv1296: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1295[2]
lv1297 = R.call_tir(cls.fused_squeeze1, (lv1296,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3283: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3280, lv1297, sinfo_args=(R.Object,))
lv3284: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3279, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3285: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3283, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3286 = R.call_tir(cls.reshape3, (lv3284,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3287 = R.call_tir(cls.reshape3, (lv3285,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3288 = R.call_tir(cls.transpose7, (lv3274,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3289 = R.call_tir(cls.transpose5, (lv3286,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3290 = R.call_tir(cls.transpose5, (lv3287,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1298 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3288, lv3289, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1299 = R.call_tir(cls.fused_softmax2_cast10, (lv1298,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3299 = R.call_tir(cls.matmul9, (lv1299, lv3290), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1300 = R.call_tir(cls.fused_transpose8_reshape8, (lv3299,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1301: R.Tensor((2560, 320), dtype="uint32") = model_params[393]
lv1302: R.Tensor((2560, 80), dtype="float16") = model_params[394]
lv911_1: R.Tensor((2560,), dtype="float16") = model_params[395]
lv48_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1301, lv1302, lv1300, lv911_1, lv47_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3306 = R.call_tir(cls.cast7, (lv48_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv912_1: R.Tensor((2560,), dtype="float32") = model_params[388]
lv913_1: R.Tensor((2560,), dtype="float32") = model_params[389]
lv1305 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3306, lv912_1, lv913_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3309: R.Tensor((1, 1, 2560), dtype="float16") = lv1305
lv1306: R.Tensor((10240, 320), dtype="uint32") = model_params[396]
lv1307: R.Tensor((10240, 80), dtype="float16") = model_params[397]
lv916_1: R.Tensor((10240,), dtype="float32") = model_params[398]
lv49 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1306, lv1307, lv3309, lv916_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3315: R.Tensor((1, 1, 10240), dtype="float16") = lv49
lv1310: R.Tensor((2560, 1280), dtype="uint32") = model_params[399]
lv1311: R.Tensor((2560, 320), dtype="float16") = model_params[400]
lv919: R.Tensor((2560,), dtype="float32") = model_params[401]
lv49_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1310, lv1311, lv3315, lv919, lv48_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3322 = R.call_tir(cls.cast7, (lv49_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv920: R.Tensor((2560,), dtype="float32") = model_params[402]
lv921_1: R.Tensor((2560,), dtype="float32") = model_params[403]
lv1314 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3322, lv920, lv921_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3325: R.Tensor((1, 1, 2560), dtype="float16") = lv1314
lv1315: R.Tensor((7680, 320), dtype="uint32") = model_params[406]
lv1316: R.Tensor((7680, 80), dtype="float16") = model_params[407]
lv924: R.Tensor((7680,), dtype="float16") = model_params[408]
lv50 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1315, lv1316, lv3325, lv924), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1319 = R.call_tir(cls.fused_reshape7_split1, (lv50,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3331: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1319[0]
lv3332 = R.call_tir(cls.rotary_embedding1, (lv3331, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3333: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1319[1]
lv3334 = R.call_tir(cls.rotary_embedding1, (lv3333, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3335: R.Object = kv_cache[50]
lv3336 = R.call_tir(cls.squeeze1, (lv3334,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3337: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3335, lv3336, sinfo_args=(R.Object,))
lv3338: R.Object = kv_cache[51]
lv1320: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1319[2]
lv1321 = R.call_tir(cls.fused_squeeze1, (lv1320,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3341: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3338, lv1321, sinfo_args=(R.Object,))
lv3342: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3337, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3343: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3341, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3344 = R.call_tir(cls.reshape3, (lv3342,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3345 = R.call_tir(cls.reshape3, (lv3343,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3346 = R.call_tir(cls.transpose7, (lv3332,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3347 = R.call_tir(cls.transpose5, (lv3344,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3348 = R.call_tir(cls.transpose5, (lv3345,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1322 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3346, lv3347, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1323 = R.call_tir(cls.fused_softmax2_cast10, (lv1322,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3357 = R.call_tir(cls.matmul9, (lv1323, lv3348), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1324 = R.call_tir(cls.fused_transpose8_reshape8, (lv3357,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1325: R.Tensor((2560, 320), dtype="uint32") = model_params[409]
lv1326: R.Tensor((2560, 80), dtype="float16") = model_params[410]
lv927_1: R.Tensor((2560,), dtype="float16") = model_params[411]
lv50_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1325, lv1326, lv1324, lv927_1, lv49_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3364 = R.call_tir(cls.cast7, (lv50_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv928: R.Tensor((2560,), dtype="float32") = model_params[404]
lv929: R.Tensor((2560,), dtype="float32") = model_params[405]
lv1329 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3364, lv928, lv929), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3367: R.Tensor((1, 1, 2560), dtype="float16") = lv1329
lv1330: R.Tensor((10240, 320), dtype="uint32") = model_params[412]
lv1331: R.Tensor((10240, 80), dtype="float16") = model_params[413]
lv932_1: R.Tensor((10240,), dtype="float32") = model_params[414]
lv51 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1330, lv1331, lv3367, lv932_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3373: R.Tensor((1, 1, 10240), dtype="float16") = lv51
lv1334: R.Tensor((2560, 1280), dtype="uint32") = model_params[415]
lv1335: R.Tensor((2560, 320), dtype="float16") = model_params[416]
lv935_1: R.Tensor((2560,), dtype="float32") = model_params[417]
lv51_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1334, lv1335, lv3373, lv935_1, lv50_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3380 = R.call_tir(cls.cast7, (lv51_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv936_1: R.Tensor((2560,), dtype="float32") = model_params[418]
lv937_1: R.Tensor((2560,), dtype="float32") = model_params[419]
lv1338 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3380, lv936_1, lv937_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3383: R.Tensor((1, 1, 2560), dtype="float16") = lv1338
lv1339: R.Tensor((7680, 320), dtype="uint32") = model_params[422]
lv1340: R.Tensor((7680, 80), dtype="float16") = model_params[423]
lv940_1: R.Tensor((7680,), dtype="float16") = model_params[424]
lv52 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1339, lv1340, lv3383, lv940_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1343 = R.call_tir(cls.fused_reshape7_split1, (lv52,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3389: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1343[0]
lv3390 = R.call_tir(cls.rotary_embedding1, (lv3389, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3391: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1343[1]
lv3392 = R.call_tir(cls.rotary_embedding1, (lv3391, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3393: R.Object = kv_cache[52]
lv3394 = R.call_tir(cls.squeeze1, (lv3392,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3395: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3393, lv3394, sinfo_args=(R.Object,))
lv3396: R.Object = kv_cache[53]
lv1344: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1343[2]
lv1345 = R.call_tir(cls.fused_squeeze1, (lv1344,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3399: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3396, lv1345, sinfo_args=(R.Object,))
lv3400: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3395, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3401: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3399, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3402 = R.call_tir(cls.reshape3, (lv3400,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3403 = R.call_tir(cls.reshape3, (lv3401,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3404 = R.call_tir(cls.transpose7, (lv3390,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3405 = R.call_tir(cls.transpose5, (lv3402,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3406 = R.call_tir(cls.transpose5, (lv3403,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1346 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3404, lv3405, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1347 = R.call_tir(cls.fused_softmax2_cast10, (lv1346,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3415 = R.call_tir(cls.matmul9, (lv1347, lv3406), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1348 = R.call_tir(cls.fused_transpose8_reshape8, (lv3415,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1349: R.Tensor((2560, 320), dtype="uint32") = model_params[425]
lv1350: R.Tensor((2560, 80), dtype="float16") = model_params[426]
lv943: R.Tensor((2560,), dtype="float16") = model_params[427]
lv52_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1349, lv1350, lv1348, lv943, lv51_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3422 = R.call_tir(cls.cast7, (lv52_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv944: R.Tensor((2560,), dtype="float32") = model_params[420]
lv945_1: R.Tensor((2560,), dtype="float32") = model_params[421]
lv1353 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3422, lv944, lv945_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3425: R.Tensor((1, 1, 2560), dtype="float16") = lv1353
lv1354: R.Tensor((10240, 320), dtype="uint32") = model_params[428]
lv1355: R.Tensor((10240, 80), dtype="float16") = model_params[429]
lv948: R.Tensor((10240,), dtype="float32") = model_params[430]
lv53 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1354, lv1355, lv3425, lv948), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3431: R.Tensor((1, 1, 10240), dtype="float16") = lv53
lv1358: R.Tensor((2560, 1280), dtype="uint32") = model_params[431]
lv1359: R.Tensor((2560, 320), dtype="float16") = model_params[432]
lv951_1: R.Tensor((2560,), dtype="float32") = model_params[433]
lv53_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1358, lv1359, lv3431, lv951_1, lv52_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3438 = R.call_tir(cls.cast7, (lv53_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv952: R.Tensor((2560,), dtype="float32") = model_params[434]
lv953: R.Tensor((2560,), dtype="float32") = model_params[435]
lv1362 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3438, lv952, lv953), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3441: R.Tensor((1, 1, 2560), dtype="float16") = lv1362
lv1363: R.Tensor((7680, 320), dtype="uint32") = model_params[438]
lv1364: R.Tensor((7680, 80), dtype="float16") = model_params[439]
lv956_1: R.Tensor((7680,), dtype="float16") = model_params[440]
lv54 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1363, lv1364, lv3441, lv956_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1367 = R.call_tir(cls.fused_reshape7_split1, (lv54,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3447: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1367[0]
lv3448 = R.call_tir(cls.rotary_embedding1, (lv3447, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3449: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1367[1]
lv3450 = R.call_tir(cls.rotary_embedding1, (lv3449, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3451: R.Object = kv_cache[54]
lv3452 = R.call_tir(cls.squeeze1, (lv3450,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3453: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3451, lv3452, sinfo_args=(R.Object,))
lv3454: R.Object = kv_cache[55]
lv1368: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1367[2]
lv1369 = R.call_tir(cls.fused_squeeze1, (lv1368,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3457: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3454, lv1369, sinfo_args=(R.Object,))
lv3458: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3453, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3459: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3457, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3460 = R.call_tir(cls.reshape3, (lv3458,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3461 = R.call_tir(cls.reshape3, (lv3459,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3462 = R.call_tir(cls.transpose7, (lv3448,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3463 = R.call_tir(cls.transpose5, (lv3460,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3464 = R.call_tir(cls.transpose5, (lv3461,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1370 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3462, lv3463, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1371 = R.call_tir(cls.fused_softmax2_cast10, (lv1370,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3473 = R.call_tir(cls.matmul9, (lv1371, lv3464), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1372 = R.call_tir(cls.fused_transpose8_reshape8, (lv3473,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1373: R.Tensor((2560, 320), dtype="uint32") = model_params[441]
lv1374: R.Tensor((2560, 80), dtype="float16") = model_params[442]
lv959_1: R.Tensor((2560,), dtype="float16") = model_params[443]
lv54_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1373, lv1374, lv1372, lv959_1, lv53_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3480 = R.call_tir(cls.cast7, (lv54_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv960_1: R.Tensor((2560,), dtype="float32") = model_params[436]
lv961_1: R.Tensor((2560,), dtype="float32") = model_params[437]
lv1377 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3480, lv960_1, lv961_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3483: R.Tensor((1, 1, 2560), dtype="float16") = lv1377
lv1378: R.Tensor((10240, 320), dtype="uint32") = model_params[444]
lv1379: R.Tensor((10240, 80), dtype="float16") = model_params[445]
lv964_1: R.Tensor((10240,), dtype="float32") = model_params[446]
lv55 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1378, lv1379, lv3483, lv964_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3489: R.Tensor((1, 1, 10240), dtype="float16") = lv55
lv1382: R.Tensor((2560, 1280), dtype="uint32") = model_params[447]
lv1383: R.Tensor((2560, 320), dtype="float16") = model_params[448]
lv967: R.Tensor((2560,), dtype="float32") = model_params[449]
lv55_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1382, lv1383, lv3489, lv967, lv54_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3496 = R.call_tir(cls.cast7, (lv55_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv968: R.Tensor((2560,), dtype="float32") = model_params[450]
lv969_1: R.Tensor((2560,), dtype="float32") = model_params[451]
lv1386 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3496, lv968, lv969_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3499: R.Tensor((1, 1, 2560), dtype="float16") = lv1386
lv1387: R.Tensor((7680, 320), dtype="uint32") = model_params[454]
lv1388: R.Tensor((7680, 80), dtype="float16") = model_params[455]
lv972: R.Tensor((7680,), dtype="float16") = model_params[456]
lv56 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1387, lv1388, lv3499, lv972), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1391 = R.call_tir(cls.fused_reshape7_split1, (lv56,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3505: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1391[0]
lv3506 = R.call_tir(cls.rotary_embedding1, (lv3505, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3507: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1391[1]
lv3508 = R.call_tir(cls.rotary_embedding1, (lv3507, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3509: R.Object = kv_cache[56]
lv3510 = R.call_tir(cls.squeeze1, (lv3508,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3511: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3509, lv3510, sinfo_args=(R.Object,))
lv3512: R.Object = kv_cache[57]
lv1392: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1391[2]
lv1393 = R.call_tir(cls.fused_squeeze1, (lv1392,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3515: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3512, lv1393, sinfo_args=(R.Object,))
lv3516: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3511, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3517: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3515, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3518 = R.call_tir(cls.reshape3, (lv3516,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3519 = R.call_tir(cls.reshape3, (lv3517,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3520 = R.call_tir(cls.transpose7, (lv3506,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3521 = R.call_tir(cls.transpose5, (lv3518,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3522 = R.call_tir(cls.transpose5, (lv3519,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1394 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3520, lv3521, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1395 = R.call_tir(cls.fused_softmax2_cast10, (lv1394,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3531 = R.call_tir(cls.matmul9, (lv1395, lv3522), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1396 = R.call_tir(cls.fused_transpose8_reshape8, (lv3531,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1397: R.Tensor((2560, 320), dtype="uint32") = model_params[457]
lv1398: R.Tensor((2560, 80), dtype="float16") = model_params[458]
lv975_1: R.Tensor((2560,), dtype="float16") = model_params[459]
lv56_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1397, lv1398, lv1396, lv975_1, lv55_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3538 = R.call_tir(cls.cast7, (lv56_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv976: R.Tensor((2560,), dtype="float32") = model_params[452]
lv977: R.Tensor((2560,), dtype="float32") = model_params[453]
lv1401 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3538, lv976, lv977), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3541: R.Tensor((1, 1, 2560), dtype="float16") = lv1401
lv1402: R.Tensor((10240, 320), dtype="uint32") = model_params[460]
lv1403: R.Tensor((10240, 80), dtype="float16") = model_params[461]
lv980_1: R.Tensor((10240,), dtype="float32") = model_params[462]
lv57 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1402, lv1403, lv3541, lv980_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3547: R.Tensor((1, 1, 10240), dtype="float16") = lv57
lv1406: R.Tensor((2560, 1280), dtype="uint32") = model_params[463]
lv1407: R.Tensor((2560, 320), dtype="float16") = model_params[464]
lv983_1: R.Tensor((2560,), dtype="float32") = model_params[465]
lv57_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1406, lv1407, lv3547, lv983_1, lv56_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3554 = R.call_tir(cls.cast7, (lv57_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv984_1: R.Tensor((2560,), dtype="float32") = model_params[466]
lv985_1: R.Tensor((2560,), dtype="float32") = model_params[467]
lv1410 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3554, lv984_1, lv985_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3557: R.Tensor((1, 1, 2560), dtype="float16") = lv1410
lv1411: R.Tensor((7680, 320), dtype="uint32") = model_params[470]
lv1412: R.Tensor((7680, 80), dtype="float16") = model_params[471]
lv988_1: R.Tensor((7680,), dtype="float16") = model_params[472]
lv58 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1411, lv1412, lv3557, lv988_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1415 = R.call_tir(cls.fused_reshape7_split1, (lv58,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3563: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1415[0]
lv3564 = R.call_tir(cls.rotary_embedding1, (lv3563, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3565: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1415[1]
lv3566 = R.call_tir(cls.rotary_embedding1, (lv3565, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3567: R.Object = kv_cache[58]
lv3568 = R.call_tir(cls.squeeze1, (lv3566,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3569: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3567, lv3568, sinfo_args=(R.Object,))
lv3570: R.Object = kv_cache[59]
lv1416: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1415[2]
lv1417 = R.call_tir(cls.fused_squeeze1, (lv1416,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3570, lv1417, sinfo_args=(R.Object,))
lv3574: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3569, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3575: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3573, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3576 = R.call_tir(cls.reshape3, (lv3574,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3577 = R.call_tir(cls.reshape3, (lv3575,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3578 = R.call_tir(cls.transpose7, (lv3564,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3579 = R.call_tir(cls.transpose5, (lv3576,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3580 = R.call_tir(cls.transpose5, (lv3577,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1418 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3578, lv3579, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1419 = R.call_tir(cls.fused_softmax2_cast10, (lv1418,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3589 = R.call_tir(cls.matmul9, (lv1419, lv3580), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1420 = R.call_tir(cls.fused_transpose8_reshape8, (lv3589,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1421: R.Tensor((2560, 320), dtype="uint32") = model_params[473]
lv1422: R.Tensor((2560, 80), dtype="float16") = model_params[474]
lv991: R.Tensor((2560,), dtype="float16") = model_params[475]
lv58_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1421, lv1422, lv1420, lv991, lv57_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3596 = R.call_tir(cls.cast7, (lv58_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv992: R.Tensor((2560,), dtype="float32") = model_params[468]
lv993_1: R.Tensor((2560,), dtype="float32") = model_params[469]
lv1425 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3596, lv992, lv993_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3599: R.Tensor((1, 1, 2560), dtype="float16") = lv1425
lv1426: R.Tensor((10240, 320), dtype="uint32") = model_params[476]
lv1427: R.Tensor((10240, 80), dtype="float16") = model_params[477]
lv996: R.Tensor((10240,), dtype="float32") = model_params[478]
lv59 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1426, lv1427, lv3599, lv996), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3605: R.Tensor((1, 1, 10240), dtype="float16") = lv59
lv1430: R.Tensor((2560, 1280), dtype="uint32") = model_params[479]
lv1431: R.Tensor((2560, 320), dtype="float16") = model_params[480]
lv999_1: R.Tensor((2560,), dtype="float32") = model_params[481]
lv59_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1430, lv1431, lv3605, lv999_1, lv58_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3612 = R.call_tir(cls.cast7, (lv59_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv1000: R.Tensor((2560,), dtype="float32") = model_params[482]
lv1001: R.Tensor((2560,), dtype="float32") = model_params[483]
lv1434 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3612, lv1000, lv1001), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3615: R.Tensor((1, 1, 2560), dtype="float16") = lv1434
lv1435: R.Tensor((7680, 320), dtype="uint32") = model_params[486]
lv1436: R.Tensor((7680, 80), dtype="float16") = model_params[487]
lv1004_1: R.Tensor((7680,), dtype="float16") = model_params[488]
lv60 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1435, lv1436, lv3615, lv1004_1), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1439 = R.call_tir(cls.fused_reshape7_split1, (lv60,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3621: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1439[0]
lv3622 = R.call_tir(cls.rotary_embedding1, (lv3621, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3623: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1439[1]
lv3624 = R.call_tir(cls.rotary_embedding1, (lv3623, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3625: R.Object = kv_cache[60]
lv3626 = R.call_tir(cls.squeeze1, (lv3624,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3627: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3625, lv3626, sinfo_args=(R.Object,))
lv3628: R.Object = kv_cache[61]
lv1440: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1439[2]
lv1441 = R.call_tir(cls.fused_squeeze1, (lv1440,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3631: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3628, lv1441, sinfo_args=(R.Object,))
lv3632: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3627, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3633: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3631, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3634 = R.call_tir(cls.reshape3, (lv3632,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3635 = R.call_tir(cls.reshape3, (lv3633,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3636 = R.call_tir(cls.transpose7, (lv3622,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3637 = R.call_tir(cls.transpose5, (lv3634,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3638 = R.call_tir(cls.transpose5, (lv3635,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1442 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3636, lv3637, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1443 = R.call_tir(cls.fused_softmax2_cast10, (lv1442,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3647 = R.call_tir(cls.matmul9, (lv1443, lv3638), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1444 = R.call_tir(cls.fused_transpose8_reshape8, (lv3647,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1445: R.Tensor((2560, 320), dtype="uint32") = model_params[489]
lv1446: R.Tensor((2560, 80), dtype="float16") = model_params[490]
lv1007_1: R.Tensor((2560,), dtype="float16") = model_params[491]
lv60_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1445, lv1446, lv1444, lv1007_1, lv59_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3654 = R.call_tir(cls.cast7, (lv60_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv1008_1: R.Tensor((2560,), dtype="float32") = model_params[484]
lv1009_1: R.Tensor((2560,), dtype="float32") = model_params[485]
lv1449 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3654, lv1008_1, lv1009_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3657: R.Tensor((1, 1, 2560), dtype="float16") = lv1449
lv1450: R.Tensor((10240, 320), dtype="uint32") = model_params[492]
lv1451: R.Tensor((10240, 80), dtype="float16") = model_params[493]
lv1012_1: R.Tensor((10240,), dtype="float32") = model_params[494]
lv61 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1450, lv1451, lv3657, lv1012_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3663: R.Tensor((1, 1, 10240), dtype="float16") = lv61
lv1454: R.Tensor((2560, 1280), dtype="uint32") = model_params[495]
lv1455: R.Tensor((2560, 320), dtype="float16") = model_params[496]
lv1015: R.Tensor((2560,), dtype="float32") = model_params[497]
lv61_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7, (lv1454, lv1455, lv3663, lv1015, lv60_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3670 = R.call_tir(cls.cast7, (lv61_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv1016: R.Tensor((2560,), dtype="float32") = model_params[498]
lv1017_1: R.Tensor((2560,), dtype="float32") = model_params[499]
lv1458 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3670, lv1016, lv1017_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3673: R.Tensor((1, 1, 2560), dtype="float16") = lv1458
lv1459: R.Tensor((7680, 320), dtype="uint32") = model_params[502]
lv1460: R.Tensor((7680, 80), dtype="float16") = model_params[503]
lv1020: R.Tensor((7680,), dtype="float16") = model_params[504]
lv62 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul6_add5, (lv1459, lv1460, lv3673, lv1020), out_sinfo=R.Tensor((1, 1, 7680), dtype="float16"))
lv1463 = R.call_tir(cls.fused_reshape7_split1, (lv62,), out_sinfo=[R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16"), R.Tensor((1, 1, 32, 80), dtype="float16")])
lv3679: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1463[0]
lv3680 = R.call_tir(cls.rotary_embedding1, (lv3679, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3681: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1463[1]
lv3682 = R.call_tir(cls.rotary_embedding1, (lv3681, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 1, 32, 80), dtype="float16"), tir_vars=R.shape([n]))
lv3683: R.Object = kv_cache[62]
lv3684 = R.call_tir(cls.squeeze1, (lv3682,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3685: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3683, lv3684, sinfo_args=(R.Object,))
lv3686: R.Object = kv_cache[63]
lv1464: R.Tensor((1, 1, 32, 80), dtype="float16") = lv1463[2]
lv1465 = R.call_tir(cls.fused_squeeze1, (lv1464,), out_sinfo=R.Tensor((1, 32, 80), dtype="float16"))
lv3689: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3686, lv1465, sinfo_args=(R.Object,))
lv3690: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3685, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3691: R.Tensor((n, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3689, R.shape([n, 32, 80]), sinfo_args=(R.Tensor((n, 32, 80), dtype="float16"),))
lv3692 = R.call_tir(cls.reshape3, (lv3690,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3693 = R.call_tir(cls.reshape3, (lv3691,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
lv3694 = R.call_tir(cls.transpose7, (lv3680,), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv3695 = R.call_tir(cls.transpose5, (lv3692,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv3696 = R.call_tir(cls.transpose5, (lv3693,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
lv1466 = R.call_tir(cls.fused_NT_matmul7_divide2_maximum1_minimum1_cast9, (lv3694, lv3695, lv1871), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv1467 = R.call_tir(cls.fused_softmax2_cast10, (lv1466,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3705 = R.call_tir(cls.matmul9, (lv1467, lv3696), out_sinfo=R.Tensor((1, 32, 1, 80), dtype="float16"))
lv1468 = R.call_tir(cls.fused_transpose8_reshape8, (lv3705,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv1469: R.Tensor((2560, 320), dtype="uint32") = model_params[505]
lv1470: R.Tensor((2560, 80), dtype="float16") = model_params[506]
lv1023_1: R.Tensor((2560,), dtype="float16") = model_params[507]
lv62_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add6_add7, (lv1469, lv1470, lv1468, lv1023_1, lv61_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3712 = R.call_tir(cls.cast7, (lv62_1,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
lv1024: R.Tensor((2560,), dtype="float32") = model_params[500]
lv1025: R.Tensor((2560,), dtype="float32") = model_params[501]
lv1473 = R.call_tir(cls.fused_layer_norm1_cast8, (lv3712, lv1024, lv1025), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
lv3715: R.Tensor((1, 1, 2560), dtype="float16") = lv1473
lv1474: R.Tensor((10240, 320), dtype="uint32") = model_params[508]
lv1475: R.Tensor((10240, 80), dtype="float16") = model_params[509]
lv1028_1: R.Tensor((10240,), dtype="float32") = model_params[510]
lv63 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul9_add8_gelu1_cast11, (lv1474, lv1475, lv3715, lv1028_1), out_sinfo=R.Tensor((1, 1, 10240), dtype="float16"))
lv3721: R.Tensor((1, 1, 10240), dtype="float16") = lv63
lv1478: R.Tensor((2560, 1280), dtype="uint32") = model_params[511]
lv1479: R.Tensor((2560, 320), dtype="float16") = model_params[512]
lv1031_1: R.Tensor((2560,), dtype="float32") = model_params[513]
lv63_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add9_cast8_cast12_add7_cast7, (lv1478, lv1479, lv3721, lv1031_1, lv62_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
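# NOTE: the last decoder layer uses the ..._cast7 variant of the down projection, so
# lv63_1 is already float32; everything past this point (the final layer norm, the
# slice, and the lm_head matmul) runs in float32.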
lv1032_1: R.Tensor((2560,), dtype="float32") = model_params[514] | |
lv1033_1: R.Tensor((2560,), dtype="float32") = model_params[515] | |
lv3729 = R.call_tir(cls.layer_norm1, (lv63_1, lv1032_1, lv1033_1), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv1482 = R.call_tir(cls.fused_slice1_cast6, (lv3729,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32")) | |
lv1483: R.Tensor((50432, 320), dtype="uint32") = model_params[516] | |
lv1484: R.Tensor((50432, 80), dtype="float32") = model_params[517] | |
lv_3 = R.call_tir(cls.fused_fused_decode6_NT_matmul5, (lv1483, lv1484, lv1482), out_sinfo=R.Tensor((1, 1, 50432), dtype="float32")) | |
gv1: R.Tuple(R.Tensor((1, 1, 50432), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv_3, (lv1887, lv1891, lv1945, lv1949, lv2003, lv2007, lv2061, lv2065, lv2119, lv2123, lv2177, lv2181, lv2235, lv2239, lv2293, lv2297, lv2351, lv2355, lv2409, lv2413, lv2467, lv2471, lv2525, lv2529, lv2583, lv2587, lv2641, lv2645, lv2699, lv2703, lv2757, lv2761, lv2815, lv2819, lv2873, lv2877, lv2931, lv2935, lv2989, lv2993, lv3047, lv3051, lv3105, lv3109, lv3163, lv3167, lv3221, lv3225, lv3279, lv3283, lv3337, lv3341, lv3395, lv3399, lv3453, lv3457, lv3511, lv3515, lv3569, lv3573, lv3627, lv3631, lv3685, lv3689) | |
R.output(gv1)
return gv1
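# get_metadata() returns the runtime configuration consumed by the MLC-LLM
# runtime as a JSON string.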
@R.function
def get_metadata() -> R.Object:
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
return R.str("{\"model_name\": \"RedPajama-INCITE-Chat-3B-v1\", \"max_window_size\": 2048, \"stop_tokens\": [0], \"add_prefix_space\": false}")
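# prefill() pushes the whole prompt (n new tokens) through all 32 layers.
# The generated signature below is one very long wrapped line: input_ids is
# (1, n) int32, all_seq_len m is the total length including any cached
# prefix, kv_cache carries the 64 per-layer key/value handles, and
# model_params is the packed-weight tuple addressed by index below (the
# lm-head weights sit at indices 516/517).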
@R.function | |
def prefill(input_ids: R.Tensor((1, "n"), dtype="int32"), all_seq_len: R.Shape(["m"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), model_params: R.Tuple(R.Tensor((50432, 320), dtype="uint32"), R.Tensor((50432, 80), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), 
dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), 
dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), 
dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), 
R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), 
dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((7680, 320), dtype="uint32"), R.Tensor((7680, 80), dtype="float16"), R.Tensor((7680,), dtype="float16"), R.Tensor((2560, 320), dtype="uint32"), R.Tensor((2560, 80), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((10240, 320), dtype="uint32"), R.Tensor((10240, 80), dtype="float16"), R.Tensor((10240,), dtype="float32"), R.Tensor((2560, 1280), dtype="uint32"), R.Tensor((2560, 320), dtype="float16"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((2560,), dtype="float32"), R.Tensor((50432, 320), dtype="uint32"), R.Tensor((50432, 80), dtype="float32"))) -> R.Tuple(R.Tensor((1, 1, 50432), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)): | |
n = T.int64()
m = T.int64()
R.func_attr({"num_input": 3, "tir_var_upper_bound": {"m": 2048, "n": 2048}})
cls = Module
with R.dataflow(): | |
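# Token embedding: dequantize just the selected rows of the packed
# (50432, 2560) embedding table via the fused decode + take, then reshape to
# (1, n, 2560).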
lv = R.call_tir(cls.reshape, (input_ids,), out_sinfo=R.Tensor((n,), dtype="int32")) | |
lv_1: R.Tensor((50432, 320), dtype="uint32") = model_params[0] | |
lv1: R.Tensor((50432, 80), dtype="float16") = model_params[1] | |
lv1_1 = R.call_tir(cls.fused_fused_decode1_take, (lv_1, lv1, lv), out_sinfo=R.Tensor((n, 2560), dtype="float16")) | |
lv2 = R.call_tir(cls.reshape1, (lv1_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
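# Causal mask: build an n x n triu mask of fp16 min/max values, then extend
# it to (1, 1, n, m) so the m - n cached positions remain attendable.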
lv3 = R.call_tir(cls.fused_min_max_triu_te_broadcast_to, R.tuple(), out_sinfo=R.Tensor((1, 1, n, n), dtype="float16"), tir_vars=R.shape([n])) | |
lv5 = R.call_tir(cls.extend_te, (lv3,), out_sinfo=R.Tensor((1, 1, n, m), dtype="float16")) | |
lv6 = R.call_tir(cls.cast, (lv2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv2_1: R.Tensor((2560,), dtype="float32") = model_params[2] | |
lv3_1: R.Tensor((2560,), dtype="float32") = model_params[3] | |
lv4 = R.call_tir(cls.fused_layer_norm_cast1, (lv6, lv2_1, lv3_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv9: R.Tensor((1, n, 2560), dtype="float16") = lv4 | |
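# Layer 0 attention. A single fused QKV projection (params 6-8) produces
# (1, n, 7680), i.e. Q, K and V stacked; it is reshaped to (1, n, 32, 240),
# split into three (1, n, 32, 80) tensors, and rotary position embeddings are
# applied to Q and K using constant tables from the module metadata.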
lv5_1: R.Tensor((7680, 320), dtype="uint32") = model_params[6] | |
lv6_1: R.Tensor((7680, 80), dtype="float16") = model_params[7] | |
lv6_2: R.Tensor((7680,), dtype="float16") = model_params[8] | |
lv64 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv5_1, lv6_1, lv9, lv6_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv13 = R.call_tir(cls.reshape2, (lv64,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv14 = R.call_tir(cls.split, (lv13,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv15: R.Tensor((1, n, 32, 80), dtype="float16") = lv14[0] | |
lv16 = R.call_tir(cls.rotary_embedding, (lv15, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv17: R.Tensor((1, n, 32, 80), dtype="float16") = lv14[1] | |
lv18 = R.call_tir(cls.rotary_embedding, (lv17, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
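# Append this chunk's K and V (squeezed to (n, 32, 80)) to the layer's cache
# slots kv_cache[0] and kv_cache[1]; the returned handles refer to the grown
# caches.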
lv19: R.Object = kv_cache[0] | |
lv20 = R.call_tir(cls.squeeze, (lv18,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv21: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv19, lv20, sinfo_args=(R.Object,)) | |
lv22: R.Object = kv_cache[1] | |
lv9_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv14[2] | |
lv10 = R.call_tir(cls.fused_squeeze, (lv9_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv25: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv22, lv10, sinfo_args=(R.Object,)) | |
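# View both caches at the new total length m, reshape to (1, m, 32, 80), and
# transpose Q/K/V into the (1, 32, seq, 80) layout used by attention.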
lv26: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv21, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv27: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv25, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv28 = R.call_tir(cls.reshape3, (lv26,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv29 = R.call_tir(cls.reshape3, (lv27,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv30 = R.call_tir(cls.transpose5, (lv16,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv31 = R.call_tir(cls.transpose5, (lv28,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv32 = R.call_tir(cls.transpose5, (lv29,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
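# Prefill attention: (1, 32, n, m) scores = scaled Q @ K^T masked with lv5,
# softmaxed in fp32 and cast back to fp16, then multiplied with V and
# flattened back to (1, n, 2560).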
lv11 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv30, lv31, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv12 = R.call_tir(cls.fused_softmax_cast3, (lv11,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv41 = R.call_tir(cls.matmul8, (lv12, lv32), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv42 = R.call_tir(cls.transpose6, (lv41,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv43 = R.call_tir(cls.reshape4, (lv42,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
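# Output projection (params 9-11) with bias; the second add re-applies the
# embedding residual lv2. The sum then goes through the post-attention layer
# norm (params 4/5).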
lv13_1: R.Tensor((2560, 320), dtype="uint32") = model_params[9] | |
lv14_1: R.Tensor((2560, 80), dtype="float16") = model_params[10] | |
lv9_2: R.Tensor((2560,), dtype="float16") = model_params[11] | |
lv64_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv13_1, lv14_1, lv43, lv9_2, lv2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv48 = R.call_tir(cls.cast, (lv64_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv10_1: R.Tensor((2560,), dtype="float32") = model_params[4] | |
lv11_1: R.Tensor((2560,), dtype="float32") = model_params[5] | |
lv17_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv48, lv10_1, lv11_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv51: R.Tensor((1, n, 2560), dtype="float16") = lv17_1 | |
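# Layer 0 MLP: up-project 2560 -> 10240 with bias and GELU (params 12-14),
# then down-project back to 2560 (params 15-17), adding lv64_1 as residual.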
lv18_1: R.Tensor((10240, 320), dtype="uint32") = model_params[12] | |
lv19_1: R.Tensor((10240, 80), dtype="float16") = model_params[13] | |
lv14_2: R.Tensor((10240,), dtype="float32") = model_params[14] | |
lv65 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv18_1, lv19_1, lv51, lv14_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv57: R.Tensor((1, n, 10240), dtype="float16") = lv65 | |
lv22_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[15] | |
lv23: R.Tensor((2560, 320), dtype="float16") = model_params[16] | |
lv17_2: R.Tensor((2560,), dtype="float32") = model_params[17] | |
lv65_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv22_1, lv23, lv57, lv17_2, lv64_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv64_2 = R.call_tir(cls.cast, (lv65_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv18_2: R.Tensor((2560,), dtype="float32") = model_params[18] | |
lv19_2: R.Tensor((2560,), dtype="float32") = model_params[19] | |
lv26_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv64_2, lv18_2, lv19_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv67: R.Tensor((1, n, 2560), dtype="float16") = lv26_1 | |
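# Layer 1 begins here (params 18-33, kv_cache[2]/[3]); layers 1-31 repeat the
# same attention + MLP pattern, advancing the parameter indices by 16 and the
# kv_cache indices by 2 per layer.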
lv27_1: R.Tensor((7680, 320), dtype="uint32") = model_params[22] | |
lv28_1: R.Tensor((7680, 80), dtype="float16") = model_params[23] | |
lv22_2: R.Tensor((7680,), dtype="float16") = model_params[24] | |
lv66 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv27_1, lv28_1, lv67, lv22_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv71 = R.call_tir(cls.reshape2, (lv66,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv72 = R.call_tir(cls.split, (lv71,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv73: R.Tensor((1, n, 32, 80), dtype="float16") = lv72[0] | |
lv74 = R.call_tir(cls.rotary_embedding, (lv73, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv75: R.Tensor((1, n, 32, 80), dtype="float16") = lv72[1] | |
lv76 = R.call_tir(cls.rotary_embedding, (lv75, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv77: R.Object = kv_cache[2] | |
lv78 = R.call_tir(cls.squeeze, (lv76,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv79: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv77, lv78, sinfo_args=(R.Object,)) | |
lv80: R.Object = kv_cache[3] | |
lv31_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv72[2] | |
lv32_1 = R.call_tir(cls.fused_squeeze, (lv31_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv83: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv80, lv32_1, sinfo_args=(R.Object,)) | |
lv84: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv79, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv85: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv83, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv86 = R.call_tir(cls.reshape3, (lv84,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv87 = R.call_tir(cls.reshape3, (lv85,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv88 = R.call_tir(cls.transpose5, (lv74,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv89 = R.call_tir(cls.transpose5, (lv86,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv90 = R.call_tir(cls.transpose5, (lv87,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv33 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv88, lv89, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv34 = R.call_tir(cls.fused_softmax_cast3, (lv33,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv99 = R.call_tir(cls.matmul8, (lv34, lv90), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv100 = R.call_tir(cls.transpose6, (lv99,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv101 = R.call_tir(cls.reshape4, (lv100,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv35: R.Tensor((2560, 320), dtype="uint32") = model_params[25] | |
lv36: R.Tensor((2560, 80), dtype="float16") = model_params[26] | |
lv25_1: R.Tensor((2560,), dtype="float16") = model_params[27] | |
lv66_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv35, lv36, lv101, lv25_1, lv65_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv106 = R.call_tir(cls.cast, (lv66_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv26_2: R.Tensor((2560,), dtype="float32") = model_params[20] | |
lv27_2: R.Tensor((2560,), dtype="float32") = model_params[21] | |
lv39 = R.call_tir(cls.fused_layer_norm_cast1, (lv106, lv26_2, lv27_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv109: R.Tensor((1, n, 2560), dtype="float16") = lv39 | |
lv40: R.Tensor((10240, 320), dtype="uint32") = model_params[28] | |
lv41_1: R.Tensor((10240, 80), dtype="float16") = model_params[29] | |
lv30_1: R.Tensor((10240,), dtype="float32") = model_params[30] | |
lv67_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv40, lv41_1, lv109, lv30_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv115: R.Tensor((1, n, 10240), dtype="float16") = lv67_1 | |
lv44: R.Tensor((2560, 1280), dtype="uint32") = model_params[31] | |
lv45: R.Tensor((2560, 320), dtype="float16") = model_params[32] | |
lv33_1: R.Tensor((2560,), dtype="float32") = model_params[33] | |
lv67_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv44, lv45, lv115, lv33_1, lv66_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv122 = R.call_tir(cls.cast, (lv67_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv34_1: R.Tensor((2560,), dtype="float32") = model_params[34] | |
lv35_1: R.Tensor((2560,), dtype="float32") = model_params[35] | |
lv48_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv122, lv34_1, lv35_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv125: R.Tensor((1, n, 2560), dtype="float16") = lv48_1 | |
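# Layer 2 (params 34-49, kv_cache[4]/[5]): same pattern.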
lv49: R.Tensor((7680, 320), dtype="uint32") = model_params[38] | |
lv50: R.Tensor((7680, 80), dtype="float16") = model_params[39] | |
lv38: R.Tensor((7680,), dtype="float16") = model_params[40] | |
lv68 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv49, lv50, lv125, lv38), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv129 = R.call_tir(cls.reshape2, (lv68,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv130 = R.call_tir(cls.split, (lv129,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv131: R.Tensor((1, n, 32, 80), dtype="float16") = lv130[0] | |
lv132 = R.call_tir(cls.rotary_embedding, (lv131, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv133: R.Tensor((1, n, 32, 80), dtype="float16") = lv130[1] | |
lv134 = R.call_tir(cls.rotary_embedding, (lv133, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv135: R.Object = kv_cache[4] | |
lv136 = R.call_tir(cls.squeeze, (lv134,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv137: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv135, lv136, sinfo_args=(R.Object,)) | |
lv138: R.Object = kv_cache[5] | |
lv53: R.Tensor((1, n, 32, 80), dtype="float16") = lv130[2] | |
lv54 = R.call_tir(cls.fused_squeeze, (lv53,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv141: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv138, lv54, sinfo_args=(R.Object,)) | |
lv142: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv137, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv143: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv141, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv144 = R.call_tir(cls.reshape3, (lv142,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv145 = R.call_tir(cls.reshape3, (lv143,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv146 = R.call_tir(cls.transpose5, (lv132,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv147 = R.call_tir(cls.transpose5, (lv144,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv148 = R.call_tir(cls.transpose5, (lv145,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv55 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv146, lv147, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv56 = R.call_tir(cls.fused_softmax_cast3, (lv55,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv157 = R.call_tir(cls.matmul8, (lv56, lv148), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv158 = R.call_tir(cls.transpose6, (lv157,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv159 = R.call_tir(cls.reshape4, (lv158,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv57_1: R.Tensor((2560, 320), dtype="uint32") = model_params[41] | |
lv58: R.Tensor((2560, 80), dtype="float16") = model_params[42] | |
lv41_2: R.Tensor((2560,), dtype="float16") = model_params[43] | |
lv68_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv57_1, lv58, lv159, lv41_2, lv67_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv164 = R.call_tir(cls.cast, (lv68_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv42_1: R.Tensor((2560,), dtype="float32") = model_params[36] | |
lv43_1: R.Tensor((2560,), dtype="float32") = model_params[37] | |
lv61 = R.call_tir(cls.fused_layer_norm_cast1, (lv164, lv42_1, lv43_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv167: R.Tensor((1, n, 2560), dtype="float16") = lv61 | |
lv62: R.Tensor((10240, 320), dtype="uint32") = model_params[44] | |
lv63: R.Tensor((10240, 80), dtype="float16") = model_params[45] | |
lv46: R.Tensor((10240,), dtype="float32") = model_params[46] | |
lv69 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv62, lv63, lv167, lv46), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv173: R.Tensor((1, n, 10240), dtype="float16") = lv69 | |
lv66_2: R.Tensor((2560, 1280), dtype="uint32") = model_params[47] | |
lv67_3: R.Tensor((2560, 320), dtype="float16") = model_params[48] | |
lv49_1: R.Tensor((2560,), dtype="float32") = model_params[49] | |
lv69_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv66_2, lv67_3, lv173, lv49_1, lv68_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv180 = R.call_tir(cls.cast, (lv69_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv50_1: R.Tensor((2560,), dtype="float32") = model_params[50] | |
lv51_1: R.Tensor((2560,), dtype="float32") = model_params[51] | |
lv70 = R.call_tir(cls.fused_layer_norm_cast1, (lv180, lv50_1, lv51_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv183: R.Tensor((1, n, 2560), dtype="float16") = lv70 | |
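# Layer 3 (params 50-65, kv_cache[6]/[7]): same pattern.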
lv71_1: R.Tensor((7680, 320), dtype="uint32") = model_params[54] | |
lv72_1: R.Tensor((7680, 80), dtype="float16") = model_params[55] | |
lv54_1: R.Tensor((7680,), dtype="float16") = model_params[56] | |
lv70_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv71_1, lv72_1, lv183, lv54_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv187 = R.call_tir(cls.reshape2, (lv70_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv188 = R.call_tir(cls.split, (lv187,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv189: R.Tensor((1, n, 32, 80), dtype="float16") = lv188[0] | |
lv190 = R.call_tir(cls.rotary_embedding, (lv189, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv191: R.Tensor((1, n, 32, 80), dtype="float16") = lv188[1] | |
lv192 = R.call_tir(cls.rotary_embedding, (lv191, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv193: R.Object = kv_cache[6] | |
lv194 = R.call_tir(cls.squeeze, (lv192,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv195: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv193, lv194, sinfo_args=(R.Object,)) | |
lv196: R.Object = kv_cache[7] | |
lv75_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv188[2] | |
lv76_1 = R.call_tir(cls.fused_squeeze, (lv75_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv199: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv196, lv76_1, sinfo_args=(R.Object,)) | |
lv200: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv195, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv201: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv199, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv202 = R.call_tir(cls.reshape3, (lv200,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv203 = R.call_tir(cls.reshape3, (lv201,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv204 = R.call_tir(cls.transpose5, (lv190,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv205 = R.call_tir(cls.transpose5, (lv202,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv206 = R.call_tir(cls.transpose5, (lv203,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv77_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv204, lv205, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv78_1 = R.call_tir(cls.fused_softmax_cast3, (lv77_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv215 = R.call_tir(cls.matmul8, (lv78_1, lv206), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv216 = R.call_tir(cls.transpose6, (lv215,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv217 = R.call_tir(cls.reshape4, (lv216,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv79_1: R.Tensor((2560, 320), dtype="uint32") = model_params[57] | |
lv80_1: R.Tensor((2560, 80), dtype="float16") = model_params[58] | |
lv57_2: R.Tensor((2560,), dtype="float16") = model_params[59] | |
lv70_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv79_1, lv80_1, lv217, lv57_2, lv69_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv222 = R.call_tir(cls.cast, (lv70_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv58_1: R.Tensor((2560,), dtype="float32") = model_params[52] | |
lv59: R.Tensor((2560,), dtype="float32") = model_params[53] | |
lv83_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv222, lv58_1, lv59), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv225: R.Tensor((1, n, 2560), dtype="float16") = lv83_1 | |
lv84_1: R.Tensor((10240, 320), dtype="uint32") = model_params[60] | |
lv85_1: R.Tensor((10240, 80), dtype="float16") = model_params[61] | |
lv62_1: R.Tensor((10240,), dtype="float32") = model_params[62] | |
lv71_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv84_1, lv85_1, lv225, lv62_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv231: R.Tensor((1, n, 10240), dtype="float16") = lv71_2 | |
lv88_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[63] | |
lv89_1: R.Tensor((2560, 320), dtype="float16") = model_params[64] | |
lv65_2: R.Tensor((2560,), dtype="float32") = model_params[65] | |
lv71_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv88_1, lv89_1, lv231, lv65_2, lv70_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv238 = R.call_tir(cls.cast, (lv71_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv66_3: R.Tensor((2560,), dtype="float32") = model_params[66] | |
lv67_4: R.Tensor((2560,), dtype="float32") = model_params[67] | |
lv92 = R.call_tir(cls.fused_layer_norm_cast1, (lv238, lv66_3, lv67_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv241: R.Tensor((1, n, 2560), dtype="float16") = lv92 | |
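# Layer 4 (params 66-81, kv_cache[8]/[9]): same pattern.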
lv93: R.Tensor((7680, 320), dtype="uint32") = model_params[70] | |
lv94: R.Tensor((7680, 80), dtype="float16") = model_params[71] | |
lv70_3: R.Tensor((7680,), dtype="float16") = model_params[72] | |
lv72_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv93, lv94, lv241, lv70_3), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv245 = R.call_tir(cls.reshape2, (lv72_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv246 = R.call_tir(cls.split, (lv245,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv247: R.Tensor((1, n, 32, 80), dtype="float16") = lv246[0] | |
lv248 = R.call_tir(cls.rotary_embedding, (lv247, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv249: R.Tensor((1, n, 32, 80), dtype="float16") = lv246[1] | |
lv250 = R.call_tir(cls.rotary_embedding, (lv249, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv251: R.Object = kv_cache[8] | |
lv252 = R.call_tir(cls.squeeze, (lv250,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv253: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv251, lv252, sinfo_args=(R.Object,)) | |
lv254: R.Object = kv_cache[9] | |
lv97: R.Tensor((1, n, 32, 80), dtype="float16") = lv246[2] | |
lv98 = R.call_tir(cls.fused_squeeze, (lv97,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv257: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv254, lv98, sinfo_args=(R.Object,)) | |
lv258: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv253, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv259: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv257, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv260 = R.call_tir(cls.reshape3, (lv258,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv261 = R.call_tir(cls.reshape3, (lv259,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv262 = R.call_tir(cls.transpose5, (lv248,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv263 = R.call_tir(cls.transpose5, (lv260,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv264 = R.call_tir(cls.transpose5, (lv261,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv99_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv262, lv263, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv100_1 = R.call_tir(cls.fused_softmax_cast3, (lv99_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv273 = R.call_tir(cls.matmul8, (lv100_1, lv264), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv274 = R.call_tir(cls.transpose6, (lv273,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv275 = R.call_tir(cls.reshape4, (lv274,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv101_1: R.Tensor((2560, 320), dtype="uint32") = model_params[73] | |
lv102: R.Tensor((2560, 80), dtype="float16") = model_params[74] | |
lv73_1: R.Tensor((2560,), dtype="float16") = model_params[75] | |
lv72_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv101_1, lv102, lv275, lv73_1, lv71_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv280 = R.call_tir(cls.cast, (lv72_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv74_1: R.Tensor((2560,), dtype="float32") = model_params[68] | |
lv75_2: R.Tensor((2560,), dtype="float32") = model_params[69] | |
lv105 = R.call_tir(cls.fused_layer_norm_cast1, (lv280, lv74_1, lv75_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv283: R.Tensor((1, n, 2560), dtype="float16") = lv105 | |
lv106_1: R.Tensor((10240, 320), dtype="uint32") = model_params[76] | |
lv107: R.Tensor((10240, 80), dtype="float16") = model_params[77] | |
lv78_2: R.Tensor((10240,), dtype="float32") = model_params[78] | |
lv73_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv106_1, lv107, lv283, lv78_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv289: R.Tensor((1, n, 10240), dtype="float16") = lv73_2 | |
lv110: R.Tensor((2560, 1280), dtype="uint32") = model_params[79] | |
lv111: R.Tensor((2560, 320), dtype="float16") = model_params[80] | |
lv81: R.Tensor((2560,), dtype="float32") = model_params[81] | |
lv73_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv110, lv111, lv289, lv81, lv72_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv296 = R.call_tir(cls.cast, (lv73_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv82: R.Tensor((2560,), dtype="float32") = model_params[82] | |
lv83_2: R.Tensor((2560,), dtype="float32") = model_params[83] | |
lv114 = R.call_tir(cls.fused_layer_norm_cast1, (lv296, lv82, lv83_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv299: R.Tensor((1, n, 2560), dtype="float16") = lv114 | |
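# Layer 5 (params 82-97, kv_cache[10]/[11]): same pattern.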
lv115_1: R.Tensor((7680, 320), dtype="uint32") = model_params[86] | |
lv116: R.Tensor((7680, 80), dtype="float16") = model_params[87] | |
lv86_1: R.Tensor((7680,), dtype="float16") = model_params[88] | |
lv74_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv115_1, lv116, lv299, lv86_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv303 = R.call_tir(cls.reshape2, (lv74_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv304 = R.call_tir(cls.split, (lv303,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv305: R.Tensor((1, n, 32, 80), dtype="float16") = lv304[0] | |
lv306 = R.call_tir(cls.rotary_embedding, (lv305, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv307: R.Tensor((1, n, 32, 80), dtype="float16") = lv304[1] | |
lv308 = R.call_tir(cls.rotary_embedding, (lv307, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv309: R.Object = kv_cache[10] | |
lv310 = R.call_tir(cls.squeeze, (lv308,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv311: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv309, lv310, sinfo_args=(R.Object,)) | |
lv312: R.Object = kv_cache[11] | |
lv119: R.Tensor((1, n, 32, 80), dtype="float16") = lv304[2] | |
lv120 = R.call_tir(cls.fused_squeeze, (lv119,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv315: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv312, lv120, sinfo_args=(R.Object,)) | |
lv316: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv311, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv317: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv315, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv318 = R.call_tir(cls.reshape3, (lv316,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv319 = R.call_tir(cls.reshape3, (lv317,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv320 = R.call_tir(cls.transpose5, (lv306,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv321 = R.call_tir(cls.transpose5, (lv318,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv322 = R.call_tir(cls.transpose5, (lv319,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv121 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv320, lv321, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv122_1 = R.call_tir(cls.fused_softmax_cast3, (lv121,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv331 = R.call_tir(cls.matmul8, (lv122_1, lv322), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv332 = R.call_tir(cls.transpose6, (lv331,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv333 = R.call_tir(cls.reshape4, (lv332,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv123: R.Tensor((2560, 320), dtype="uint32") = model_params[89] | |
lv124: R.Tensor((2560, 80), dtype="float16") = model_params[90] | |
lv89_2: R.Tensor((2560,), dtype="float16") = model_params[91] | |
lv74_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv123, lv124, lv333, lv89_2, lv73_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv338 = R.call_tir(cls.cast, (lv74_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv90_1: R.Tensor((2560,), dtype="float32") = model_params[84] | |
lv91: R.Tensor((2560,), dtype="float32") = model_params[85] | |
lv127 = R.call_tir(cls.fused_layer_norm_cast1, (lv338, lv90_1, lv91), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv341: R.Tensor((1, n, 2560), dtype="float16") = lv127 | |
lv128: R.Tensor((10240, 320), dtype="uint32") = model_params[92] | |
lv129_1: R.Tensor((10240, 80), dtype="float16") = model_params[93] | |
lv94_1: R.Tensor((10240,), dtype="float32") = model_params[94] | |
lv75_3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv128, lv129_1, lv341, lv94_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv347: R.Tensor((1, n, 10240), dtype="float16") = lv75_3 | |
lv132_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[95] | |
lv133_1: R.Tensor((2560, 320), dtype="float16") = model_params[96] | |
lv97_1: R.Tensor((2560,), dtype="float32") = model_params[97] | |
lv75_4 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv132_1, lv133_1, lv347, lv97_1, lv74_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv354 = R.call_tir(cls.cast, (lv75_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv98_1: R.Tensor((2560,), dtype="float32") = model_params[98] | |
lv99_2: R.Tensor((2560,), dtype="float32") = model_params[99] | |
lv136_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv354, lv98_1, lv99_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv357: R.Tensor((1, n, 2560), dtype="float16") = lv136_1 | |
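# Layer 6 (params 98-113, kv_cache[12]/[13]): same pattern.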
lv137_1: R.Tensor((7680, 320), dtype="uint32") = model_params[102] | |
lv138_1: R.Tensor((7680, 80), dtype="float16") = model_params[103] | |
lv102_1: R.Tensor((7680,), dtype="float16") = model_params[104] | |
lv76_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv137_1, lv138_1, lv357, lv102_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv361 = R.call_tir(cls.reshape2, (lv76_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv362 = R.call_tir(cls.split, (lv361,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv363: R.Tensor((1, n, 32, 80), dtype="float16") = lv362[0] | |
lv364 = R.call_tir(cls.rotary_embedding, (lv363, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv365: R.Tensor((1, n, 32, 80), dtype="float16") = lv362[1] | |
lv366 = R.call_tir(cls.rotary_embedding, (lv365, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv367: R.Object = kv_cache[12] | |
lv368 = R.call_tir(cls.squeeze, (lv366,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv369: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv367, lv368, sinfo_args=(R.Object,)) | |
lv370: R.Object = kv_cache[13] | |
lv141_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv362[2] | |
lv142_1 = R.call_tir(cls.fused_squeeze, (lv141_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv370, lv142_1, sinfo_args=(R.Object,)) | |
lv374: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv369, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv375: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv373, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv376 = R.call_tir(cls.reshape3, (lv374,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv377 = R.call_tir(cls.reshape3, (lv375,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv378 = R.call_tir(cls.transpose5, (lv364,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv379 = R.call_tir(cls.transpose5, (lv376,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv380 = R.call_tir(cls.transpose5, (lv377,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
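# lv5 (defined earlier in this function) is presumably the precomputed causal
# attention bias/mask: judging by the fused kernel's name, it computes Q @ K^T,
# scales it, clamps the scores with maximum/minimum against that operand, and
# casts to fp32 for the softmax in the next step. This reading is inferred from
# the kernel name; the dump itself does not spell it out.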
lv143_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv378, lv379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv144_1 = R.call_tir(cls.fused_softmax_cast3, (lv143_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv389 = R.call_tir(cls.matmul8, (lv144_1, lv380), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv390 = R.call_tir(cls.transpose6, (lv389,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv391 = R.call_tir(cls.reshape4, (lv390,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv145_1: R.Tensor((2560, 320), dtype="uint32") = model_params[105] | |
lv146_1: R.Tensor((2560, 80), dtype="float16") = model_params[106] | |
lv105_1: R.Tensor((2560,), dtype="float16") = model_params[107] | |
lv76_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv145_1, lv146_1, lv391, lv105_1, lv75_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv396 = R.call_tir(cls.cast, (lv76_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv106_2: R.Tensor((2560,), dtype="float32") = model_params[100] | |
lv107_1: R.Tensor((2560,), dtype="float32") = model_params[101] | |
lv149 = R.call_tir(cls.fused_layer_norm_cast1, (lv396, lv106_2, lv107_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv399: R.Tensor((1, n, 2560), dtype="float16") = lv149 | |
lv150: R.Tensor((10240, 320), dtype="uint32") = model_params[108] | |
lv151: R.Tensor((10240, 80), dtype="float16") = model_params[109] | |
lv110_1: R.Tensor((10240,), dtype="float32") = model_params[110] | |
lv77_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv150, lv151, lv399, lv110_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv405: R.Tensor((1, n, 10240), dtype="float16") = lv77_2 | |
lv154: R.Tensor((2560, 1280), dtype="uint32") = model_params[111] | |
lv155: R.Tensor((2560, 320), dtype="float16") = model_params[112] | |
lv113: R.Tensor((2560,), dtype="float32") = model_params[113] | |
lv77_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv154, lv155, lv405, lv113, lv76_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
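# Decoder block writing KV-cache slots 14/15 (same pattern as above).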
lv412 = R.call_tir(cls.cast, (lv77_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv114_1: R.Tensor((2560,), dtype="float32") = model_params[114] | |
lv115_2: R.Tensor((2560,), dtype="float32") = model_params[115] | |
lv158_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv412, lv114_1, lv115_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv415: R.Tensor((1, n, 2560), dtype="float16") = lv158_1 | |
lv159_1: R.Tensor((7680, 320), dtype="uint32") = model_params[118] | |
lv160: R.Tensor((7680, 80), dtype="float16") = model_params[119] | |
lv118: R.Tensor((7680,), dtype="float16") = model_params[120] | |
lv78_3 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv159_1, lv160, lv415, lv118), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv419 = R.call_tir(cls.reshape2, (lv78_3,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv420 = R.call_tir(cls.split, (lv419,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv421: R.Tensor((1, n, 32, 80), dtype="float16") = lv420[0] | |
lv422 = R.call_tir(cls.rotary_embedding, (lv421, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv423: R.Tensor((1, n, 32, 80), dtype="float16") = lv420[1] | |
lv424 = R.call_tir(cls.rotary_embedding, (lv423, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv425: R.Object = kv_cache[14] | |
lv426 = R.call_tir(cls.squeeze, (lv424,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv427: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv425, lv426, sinfo_args=(R.Object,)) | |
lv428: R.Object = kv_cache[15] | |
lv163: R.Tensor((1, n, 32, 80), dtype="float16") = lv420[2] | |
lv164_1 = R.call_tir(cls.fused_squeeze, (lv163,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv431: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv428, lv164_1, sinfo_args=(R.Object,)) | |
lv432: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv427, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv433: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv431, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv434 = R.call_tir(cls.reshape3, (lv432,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv435 = R.call_tir(cls.reshape3, (lv433,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv436 = R.call_tir(cls.transpose5, (lv422,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv437 = R.call_tir(cls.transpose5, (lv434,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv438 = R.call_tir(cls.transpose5, (lv435,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv165 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv436, lv437, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv166 = R.call_tir(cls.fused_softmax_cast3, (lv165,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv447 = R.call_tir(cls.matmul8, (lv166, lv438), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv448 = R.call_tir(cls.transpose6, (lv447,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv449 = R.call_tir(cls.reshape4, (lv448,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv167_1: R.Tensor((2560, 320), dtype="uint32") = model_params[121] | |
lv168: R.Tensor((2560, 80), dtype="float16") = model_params[122] | |
lv121_1: R.Tensor((2560,), dtype="float16") = model_params[123] | |
lv78_4 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv167_1, lv168, lv449, lv121_1, lv77_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv454 = R.call_tir(cls.cast, (lv78_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv122_2: R.Tensor((2560,), dtype="float32") = model_params[116] | |
lv123_1: R.Tensor((2560,), dtype="float32") = model_params[117] | |
lv171 = R.call_tir(cls.fused_layer_norm_cast1, (lv454, lv122_2, lv123_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv457: R.Tensor((1, n, 2560), dtype="float16") = lv171 | |
lv172: R.Tensor((10240, 320), dtype="uint32") = model_params[124] | |
lv173_1: R.Tensor((10240, 80), dtype="float16") = model_params[125] | |
lv126: R.Tensor((10240,), dtype="float32") = model_params[126] | |
lv79_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv172, lv173_1, lv457, lv126), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv463: R.Tensor((1, n, 10240), dtype="float16") = lv79_2 | |
lv176: R.Tensor((2560, 1280), dtype="uint32") = model_params[127] | |
lv177: R.Tensor((2560, 320), dtype="float16") = model_params[128] | |
lv129_2: R.Tensor((2560,), dtype="float32") = model_params[129] | |
lv79_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv176, lv177, lv463, lv129_2, lv78_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
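# Decoder block writing KV-cache slots 16/17 (same pattern as above).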
lv470 = R.call_tir(cls.cast, (lv79_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv130_1: R.Tensor((2560,), dtype="float32") = model_params[130] | |
lv131_1: R.Tensor((2560,), dtype="float32") = model_params[131] | |
lv180_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv470, lv130_1, lv131_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv473: R.Tensor((1, n, 2560), dtype="float16") = lv180_1 | |
lv181: R.Tensor((7680, 320), dtype="uint32") = model_params[134] | |
lv182: R.Tensor((7680, 80), dtype="float16") = model_params[135] | |
lv134_1: R.Tensor((7680,), dtype="float16") = model_params[136] | |
lv80_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv181, lv182, lv473, lv134_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv477 = R.call_tir(cls.reshape2, (lv80_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv478 = R.call_tir(cls.split, (lv477,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv479: R.Tensor((1, n, 32, 80), dtype="float16") = lv478[0] | |
lv480 = R.call_tir(cls.rotary_embedding, (lv479, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv481: R.Tensor((1, n, 32, 80), dtype="float16") = lv478[1] | |
lv482 = R.call_tir(cls.rotary_embedding, (lv481, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv483: R.Object = kv_cache[16] | |
lv484 = R.call_tir(cls.squeeze, (lv482,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv485: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv483, lv484, sinfo_args=(R.Object,)) | |
lv486: R.Object = kv_cache[17] | |
lv185: R.Tensor((1, n, 32, 80), dtype="float16") = lv478[2] | |
lv186 = R.call_tir(cls.fused_squeeze, (lv185,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv489: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv486, lv186, sinfo_args=(R.Object,)) | |
lv490: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv485, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv491: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv489, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv492 = R.call_tir(cls.reshape3, (lv490,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv493 = R.call_tir(cls.reshape3, (lv491,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv494 = R.call_tir(cls.transpose5, (lv480,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv495 = R.call_tir(cls.transpose5, (lv492,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv496 = R.call_tir(cls.transpose5, (lv493,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv187_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv494, lv495, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv188_1 = R.call_tir(cls.fused_softmax_cast3, (lv187_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv505 = R.call_tir(cls.matmul8, (lv188_1, lv496), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv506 = R.call_tir(cls.transpose6, (lv505,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv507 = R.call_tir(cls.reshape4, (lv506,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv189_1: R.Tensor((2560, 320), dtype="uint32") = model_params[137] | |
lv190_1: R.Tensor((2560, 80), dtype="float16") = model_params[138] | |
lv137_2: R.Tensor((2560,), dtype="float16") = model_params[139] | |
lv80_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv189_1, lv190_1, lv507, lv137_2, lv79_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv512 = R.call_tir(cls.cast, (lv80_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv138_2: R.Tensor((2560,), dtype="float32") = model_params[132] | |
lv139: R.Tensor((2560,), dtype="float32") = model_params[133] | |
lv193_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv512, lv138_2, lv139), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv515: R.Tensor((1, n, 2560), dtype="float16") = lv193_1 | |
lv194_1: R.Tensor((10240, 320), dtype="uint32") = model_params[140] | |
lv195_1: R.Tensor((10240, 80), dtype="float16") = model_params[141] | |
lv142_2: R.Tensor((10240,), dtype="float32") = model_params[142] | |
lv81_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv194_1, lv195_1, lv515, lv142_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv521: R.Tensor((1, n, 10240), dtype="float16") = lv81_1 | |
lv198: R.Tensor((2560, 1280), dtype="uint32") = model_params[143] | |
lv199_1: R.Tensor((2560, 320), dtype="float16") = model_params[144] | |
lv145_2: R.Tensor((2560,), dtype="float32") = model_params[145] | |
lv81_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv198, lv199_1, lv521, lv145_2, lv80_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
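# Decoder block writing KV-cache slots 18/19 (same pattern as above).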
lv528 = R.call_tir(cls.cast, (lv81_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv146_2: R.Tensor((2560,), dtype="float32") = model_params[146] | |
lv147_1: R.Tensor((2560,), dtype="float32") = model_params[147] | |
lv202_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv528, lv146_2, lv147_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv531: R.Tensor((1, n, 2560), dtype="float16") = lv202_1 | |
lv203_1: R.Tensor((7680, 320), dtype="uint32") = model_params[150] | |
lv204_1: R.Tensor((7680, 80), dtype="float16") = model_params[151] | |
lv150_1: R.Tensor((7680,), dtype="float16") = model_params[152] | |
lv82_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv203_1, lv204_1, lv531, lv150_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv535 = R.call_tir(cls.reshape2, (lv82_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv536 = R.call_tir(cls.split, (lv535,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv537: R.Tensor((1, n, 32, 80), dtype="float16") = lv536[0] | |
lv538 = R.call_tir(cls.rotary_embedding, (lv537, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv539: R.Tensor((1, n, 32, 80), dtype="float16") = lv536[1] | |
lv540 = R.call_tir(cls.rotary_embedding, (lv539, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv541: R.Object = kv_cache[18] | |
lv542 = R.call_tir(cls.squeeze, (lv540,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv543: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv541, lv542, sinfo_args=(R.Object,)) | |
lv544: R.Object = kv_cache[19] | |
lv207: R.Tensor((1, n, 32, 80), dtype="float16") = lv536[2] | |
lv208 = R.call_tir(cls.fused_squeeze, (lv207,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv547: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv544, lv208, sinfo_args=(R.Object,)) | |
lv548: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv543, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv549: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv547, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv550 = R.call_tir(cls.reshape3, (lv548,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv551 = R.call_tir(cls.reshape3, (lv549,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv552 = R.call_tir(cls.transpose5, (lv538,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv553 = R.call_tir(cls.transpose5, (lv550,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv554 = R.call_tir(cls.transpose5, (lv551,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv209 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv552, lv553, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv210 = R.call_tir(cls.fused_softmax_cast3, (lv209,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv563 = R.call_tir(cls.matmul8, (lv210, lv554), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv564 = R.call_tir(cls.transpose6, (lv563,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv565 = R.call_tir(cls.reshape4, (lv564,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv211: R.Tensor((2560, 320), dtype="uint32") = model_params[153] | |
lv212: R.Tensor((2560, 80), dtype="float16") = model_params[154] | |
lv153: R.Tensor((2560,), dtype="float16") = model_params[155] | |
lv82_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv211, lv212, lv565, lv153, lv81_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv570 = R.call_tir(cls.cast, (lv82_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv154_1: R.Tensor((2560,), dtype="float32") = model_params[148] | |
lv155_1: R.Tensor((2560,), dtype="float32") = model_params[149] | |
lv215_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv570, lv154_1, lv155_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv573: R.Tensor((1, n, 2560), dtype="float16") = lv215_1 | |
lv216_1: R.Tensor((10240, 320), dtype="uint32") = model_params[156] | |
lv217_1: R.Tensor((10240, 80), dtype="float16") = model_params[157] | |
lv158_2: R.Tensor((10240,), dtype="float32") = model_params[158] | |
lv83_3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv216_1, lv217_1, lv573, lv158_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv579: R.Tensor((1, n, 10240), dtype="float16") = lv83_3 | |
lv220: R.Tensor((2560, 1280), dtype="uint32") = model_params[159] | |
lv221: R.Tensor((2560, 320), dtype="float16") = model_params[160] | |
lv161: R.Tensor((2560,), dtype="float32") = model_params[161] | |
lv83_4 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv220, lv221, lv579, lv161, lv82_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
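# Decoder block writing KV-cache slots 20/21 (same pattern as above).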
lv586 = R.call_tir(cls.cast, (lv83_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv162: R.Tensor((2560,), dtype="float32") = model_params[162] | |
lv163_1: R.Tensor((2560,), dtype="float32") = model_params[163] | |
lv224 = R.call_tir(cls.fused_layer_norm_cast1, (lv586, lv162, lv163_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv589: R.Tensor((1, n, 2560), dtype="float16") = lv224 | |
lv225_1: R.Tensor((7680, 320), dtype="uint32") = model_params[166] | |
lv226: R.Tensor((7680, 80), dtype="float16") = model_params[167] | |
lv166_1: R.Tensor((7680,), dtype="float16") = model_params[168] | |
lv84_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv225_1, lv226, lv589, lv166_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv593 = R.call_tir(cls.reshape2, (lv84_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv594 = R.call_tir(cls.split, (lv593,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv595: R.Tensor((1, n, 32, 80), dtype="float16") = lv594[0] | |
lv596 = R.call_tir(cls.rotary_embedding, (lv595, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv597: R.Tensor((1, n, 32, 80), dtype="float16") = lv594[1] | |
lv598 = R.call_tir(cls.rotary_embedding, (lv597, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv599: R.Object = kv_cache[20] | |
lv600 = R.call_tir(cls.squeeze, (lv598,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv601: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv599, lv600, sinfo_args=(R.Object,)) | |
lv602: R.Object = kv_cache[21] | |
lv229: R.Tensor((1, n, 32, 80), dtype="float16") = lv594[2] | |
lv230 = R.call_tir(cls.fused_squeeze, (lv229,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv605: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv602, lv230, sinfo_args=(R.Object,)) | |
lv606: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv601, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv607: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv605, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv608 = R.call_tir(cls.reshape3, (lv606,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv609 = R.call_tir(cls.reshape3, (lv607,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv610 = R.call_tir(cls.transpose5, (lv596,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv611 = R.call_tir(cls.transpose5, (lv608,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv612 = R.call_tir(cls.transpose5, (lv609,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv231_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv610, lv611, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv232 = R.call_tir(cls.fused_softmax_cast3, (lv231_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv621 = R.call_tir(cls.matmul8, (lv232, lv612), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv622 = R.call_tir(cls.transpose6, (lv621,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv623 = R.call_tir(cls.reshape4, (lv622,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv233: R.Tensor((2560, 320), dtype="uint32") = model_params[169] | |
lv234: R.Tensor((2560, 80), dtype="float16") = model_params[170] | |
lv169: R.Tensor((2560,), dtype="float16") = model_params[171] | |
lv84_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv233, lv234, lv623, lv169, lv83_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv628 = R.call_tir(cls.cast, (lv84_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv170: R.Tensor((2560,), dtype="float32") = model_params[164] | |
lv171_1: R.Tensor((2560,), dtype="float32") = model_params[165] | |
lv237 = R.call_tir(cls.fused_layer_norm_cast1, (lv628, lv170, lv171_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv631: R.Tensor((1, n, 2560), dtype="float16") = lv237 | |
lv238_1: R.Tensor((10240, 320), dtype="uint32") = model_params[172] | |
lv239: R.Tensor((10240, 80), dtype="float16") = model_params[173] | |
lv174: R.Tensor((10240,), dtype="float32") = model_params[174] | |
lv85_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv238_1, lv239, lv631, lv174), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv637: R.Tensor((1, n, 10240), dtype="float16") = lv85_2 | |
lv242: R.Tensor((2560, 1280), dtype="uint32") = model_params[175] | |
lv243: R.Tensor((2560, 320), dtype="float16") = model_params[176] | |
lv177_1: R.Tensor((2560,), dtype="float32") = model_params[177] | |
lv85_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv242, lv243, lv637, lv177_1, lv84_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
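# Decoder block writing KV-cache slots 22/23 (same pattern as above).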
lv644 = R.call_tir(cls.cast, (lv85_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv178: R.Tensor((2560,), dtype="float32") = model_params[178] | |
lv179: R.Tensor((2560,), dtype="float32") = model_params[179] | |
lv246_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv644, lv178, lv179), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv647: R.Tensor((1, n, 2560), dtype="float16") = lv246_1 | |
lv247_1: R.Tensor((7680, 320), dtype="uint32") = model_params[182] | |
lv248_1: R.Tensor((7680, 80), dtype="float16") = model_params[183] | |
lv182_1: R.Tensor((7680,), dtype="float16") = model_params[184] | |
lv86_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv247_1, lv248_1, lv647, lv182_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv651 = R.call_tir(cls.reshape2, (lv86_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv652 = R.call_tir(cls.split, (lv651,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv653: R.Tensor((1, n, 32, 80), dtype="float16") = lv652[0] | |
lv654 = R.call_tir(cls.rotary_embedding, (lv653, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv655: R.Tensor((1, n, 32, 80), dtype="float16") = lv652[1] | |
lv656 = R.call_tir(cls.rotary_embedding, (lv655, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv657: R.Object = kv_cache[22] | |
lv658 = R.call_tir(cls.squeeze, (lv656,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv659: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv657, lv658, sinfo_args=(R.Object,)) | |
lv660: R.Object = kv_cache[23] | |
lv251_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv652[2] | |
lv252_1 = R.call_tir(cls.fused_squeeze, (lv251_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv663: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv660, lv252_1, sinfo_args=(R.Object,)) | |
lv664: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv659, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv665: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv663, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv666 = R.call_tir(cls.reshape3, (lv664,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv667 = R.call_tir(cls.reshape3, (lv665,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv668 = R.call_tir(cls.transpose5, (lv654,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv669 = R.call_tir(cls.transpose5, (lv666,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv670 = R.call_tir(cls.transpose5, (lv667,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv253_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv668, lv669, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv254_1 = R.call_tir(cls.fused_softmax_cast3, (lv253_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv679 = R.call_tir(cls.matmul8, (lv254_1, lv670), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv680 = R.call_tir(cls.transpose6, (lv679,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv681 = R.call_tir(cls.reshape4, (lv680,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv255: R.Tensor((2560, 320), dtype="uint32") = model_params[185] | |
lv256: R.Tensor((2560, 80), dtype="float16") = model_params[186] | |
lv185_1: R.Tensor((2560,), dtype="float16") = model_params[187] | |
lv86_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv255, lv256, lv681, lv185_1, lv85_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv686 = R.call_tir(cls.cast, (lv86_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv186_1: R.Tensor((2560,), dtype="float32") = model_params[180] | |
lv187_2: R.Tensor((2560,), dtype="float32") = model_params[181] | |
lv259_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv686, lv186_1, lv187_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv689: R.Tensor((1, n, 2560), dtype="float16") = lv259_1 | |
lv260_1: R.Tensor((10240, 320), dtype="uint32") = model_params[188] | |
lv261_1: R.Tensor((10240, 80), dtype="float16") = model_params[189] | |
lv190_2: R.Tensor((10240,), dtype="float32") = model_params[190] | |
lv87_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv260_1, lv261_1, lv689, lv190_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv695: R.Tensor((1, n, 10240), dtype="float16") = lv87_1 | |
lv264_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[191] | |
lv265: R.Tensor((2560, 320), dtype="float16") = model_params[192] | |
lv193_2: R.Tensor((2560,), dtype="float32") = model_params[193] | |
lv87_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv264_1, lv265, lv695, lv193_2, lv86_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
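# Decoder block writing KV-cache slots 24/25 (same pattern as above).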
lv702 = R.call_tir(cls.cast, (lv87_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv194_2: R.Tensor((2560,), dtype="float32") = model_params[194] | |
lv195_2: R.Tensor((2560,), dtype="float32") = model_params[195] | |
lv268 = R.call_tir(cls.fused_layer_norm_cast1, (lv702, lv194_2, lv195_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv705: R.Tensor((1, n, 2560), dtype="float16") = lv268 | |
lv269: R.Tensor((7680, 320), dtype="uint32") = model_params[198] | |
lv270: R.Tensor((7680, 80), dtype="float16") = model_params[199] | |
lv198_1: R.Tensor((7680,), dtype="float16") = model_params[200] | |
lv88_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv269, lv270, lv705, lv198_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv709 = R.call_tir(cls.reshape2, (lv88_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv710 = R.call_tir(cls.split, (lv709,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv711: R.Tensor((1, n, 32, 80), dtype="float16") = lv710[0] | |
lv712 = R.call_tir(cls.rotary_embedding, (lv711, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv713: R.Tensor((1, n, 32, 80), dtype="float16") = lv710[1] | |
lv714 = R.call_tir(cls.rotary_embedding, (lv713, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv715: R.Object = kv_cache[24] | |
lv716 = R.call_tir(cls.squeeze, (lv714,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv717: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv715, lv716, sinfo_args=(R.Object,)) | |
lv718: R.Object = kv_cache[25] | |
lv273_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv710[2] | |
lv274_1 = R.call_tir(cls.fused_squeeze, (lv273_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv721: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv718, lv274_1, sinfo_args=(R.Object,)) | |
lv722: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv717, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv723: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv721, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv724 = R.call_tir(cls.reshape3, (lv722,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv725 = R.call_tir(cls.reshape3, (lv723,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv726 = R.call_tir(cls.transpose5, (lv712,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv727 = R.call_tir(cls.transpose5, (lv724,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv728 = R.call_tir(cls.transpose5, (lv725,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv275_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv726, lv727, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv276 = R.call_tir(cls.fused_softmax_cast3, (lv275_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv737 = R.call_tir(cls.matmul8, (lv276, lv728), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv738 = R.call_tir(cls.transpose6, (lv737,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv739 = R.call_tir(cls.reshape4, (lv738,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv277: R.Tensor((2560, 320), dtype="uint32") = model_params[201] | |
lv278: R.Tensor((2560, 80), dtype="float16") = model_params[202] | |
lv201_1: R.Tensor((2560,), dtype="float16") = model_params[203] | |
lv88_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv277, lv278, lv739, lv201_1, lv87_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv744 = R.call_tir(cls.cast, (lv88_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv202_2: R.Tensor((2560,), dtype="float32") = model_params[196] | |
lv203_2: R.Tensor((2560,), dtype="float32") = model_params[197] | |
lv281 = R.call_tir(cls.fused_layer_norm_cast1, (lv744, lv202_2, lv203_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv747: R.Tensor((1, n, 2560), dtype="float16") = lv281 | |
lv282: R.Tensor((10240, 320), dtype="uint32") = model_params[204] | |
lv283_1: R.Tensor((10240, 80), dtype="float16") = model_params[205] | |
lv206_1: R.Tensor((10240,), dtype="float32") = model_params[206] | |
lv89_3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv282, lv283_1, lv747, lv206_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv753: R.Tensor((1, n, 10240), dtype="float16") = lv89_3 | |
lv286: R.Tensor((2560, 1280), dtype="uint32") = model_params[207] | |
lv287: R.Tensor((2560, 320), dtype="float16") = model_params[208] | |
lv209_1: R.Tensor((2560,), dtype="float32") = model_params[209] | |
lv89_4 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv286, lv287, lv753, lv209_1, lv88_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
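# Decoder block writing KV-cache slots 26/27 (same pattern as above).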
lv760 = R.call_tir(cls.cast, (lv89_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv210_1: R.Tensor((2560,), dtype="float32") = model_params[210] | |
lv211_1: R.Tensor((2560,), dtype="float32") = model_params[211] | |
lv290 = R.call_tir(cls.fused_layer_norm_cast1, (lv760, lv210_1, lv211_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv763: R.Tensor((1, n, 2560), dtype="float16") = lv290 | |
lv291: R.Tensor((7680, 320), dtype="uint32") = model_params[214] | |
lv292: R.Tensor((7680, 80), dtype="float16") = model_params[215] | |
lv214: R.Tensor((7680,), dtype="float16") = model_params[216] | |
lv90_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv291, lv292, lv763, lv214), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv767 = R.call_tir(cls.reshape2, (lv90_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv768 = R.call_tir(cls.split, (lv767,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv769: R.Tensor((1, n, 32, 80), dtype="float16") = lv768[0] | |
lv770 = R.call_tir(cls.rotary_embedding, (lv769, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv771: R.Tensor((1, n, 32, 80), dtype="float16") = lv768[1] | |
lv772 = R.call_tir(cls.rotary_embedding, (lv771, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv773: R.Object = kv_cache[26] | |
lv774 = R.call_tir(cls.squeeze, (lv772,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv775: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv773, lv774, sinfo_args=(R.Object,)) | |
lv776: R.Object = kv_cache[27] | |
lv295: R.Tensor((1, n, 32, 80), dtype="float16") = lv768[2] | |
lv296_1 = R.call_tir(cls.fused_squeeze, (lv295,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv779: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv776, lv296_1, sinfo_args=(R.Object,)) | |
lv780: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv775, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv781: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv779, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv782 = R.call_tir(cls.reshape3, (lv780,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv783 = R.call_tir(cls.reshape3, (lv781,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv784 = R.call_tir(cls.transpose5, (lv770,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv785 = R.call_tir(cls.transpose5, (lv782,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv786 = R.call_tir(cls.transpose5, (lv783,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv297 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv784, lv785, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv298 = R.call_tir(cls.fused_softmax_cast3, (lv297,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv795 = R.call_tir(cls.matmul8, (lv298, lv786), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv796 = R.call_tir(cls.transpose6, (lv795,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv797 = R.call_tir(cls.reshape4, (lv796,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv299_1: R.Tensor((2560, 320), dtype="uint32") = model_params[217] | |
lv300: R.Tensor((2560, 80), dtype="float16") = model_params[218] | |
lv217_2: R.Tensor((2560,), dtype="float16") = model_params[219] | |
lv90_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv299_1, lv300, lv797, lv217_2, lv89_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv802 = R.call_tir(cls.cast, (lv90_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv218: R.Tensor((2560,), dtype="float32") = model_params[212] | |
lv219: R.Tensor((2560,), dtype="float32") = model_params[213] | |
lv303_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv802, lv218, lv219), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv805: R.Tensor((1, n, 2560), dtype="float16") = lv303_1 | |
lv304_1: R.Tensor((10240, 320), dtype="uint32") = model_params[220] | |
lv305_1: R.Tensor((10240, 80), dtype="float16") = model_params[221] | |
lv222_1: R.Tensor((10240,), dtype="float32") = model_params[222] | |
lv91_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv304_1, lv305_1, lv805, lv222_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv811: R.Tensor((1, n, 10240), dtype="float16") = lv91_1 | |
lv308_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[223] | |
lv309_1: R.Tensor((2560, 320), dtype="float16") = model_params[224] | |
lv225_2: R.Tensor((2560,), dtype="float32") = model_params[225] | |
lv91_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv308_1, lv309_1, lv811, lv225_2, lv90_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
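# Decoder block writing KV-cache slots 28/29 (same pattern as above).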
lv818 = R.call_tir(cls.cast, (lv91_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv226_1: R.Tensor((2560,), dtype="float32") = model_params[226] | |
lv227: R.Tensor((2560,), dtype="float32") = model_params[227] | |
lv312_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv818, lv226_1, lv227), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv821: R.Tensor((1, n, 2560), dtype="float16") = lv312_1 | |
lv313: R.Tensor((7680, 320), dtype="uint32") = model_params[230] | |
lv314: R.Tensor((7680, 80), dtype="float16") = model_params[231] | |
lv230_1: R.Tensor((7680,), dtype="float16") = model_params[232] | |
lv92_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv313, lv314, lv821, lv230_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv825 = R.call_tir(cls.reshape2, (lv92_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv826 = R.call_tir(cls.split, (lv825,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv827: R.Tensor((1, n, 32, 80), dtype="float16") = lv826[0] | |
lv828 = R.call_tir(cls.rotary_embedding, (lv827, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv829: R.Tensor((1, n, 32, 80), dtype="float16") = lv826[1] | |
lv830 = R.call_tir(cls.rotary_embedding, (lv829, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv831: R.Object = kv_cache[28] | |
lv832 = R.call_tir(cls.squeeze, (lv830,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv833: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv831, lv832, sinfo_args=(R.Object,)) | |
lv834: R.Object = kv_cache[29] | |
lv317_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv826[2] | |
lv318_1 = R.call_tir(cls.fused_squeeze, (lv317_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv837: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv834, lv318_1, sinfo_args=(R.Object,)) | |
lv838: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv833, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv839: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv837, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv840 = R.call_tir(cls.reshape3, (lv838,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv841 = R.call_tir(cls.reshape3, (lv839,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv842 = R.call_tir(cls.transpose5, (lv828,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv843 = R.call_tir(cls.transpose5, (lv840,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv844 = R.call_tir(cls.transpose5, (lv841,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv319_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv842, lv843, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv320_1 = R.call_tir(cls.fused_softmax_cast3, (lv319_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv853 = R.call_tir(cls.matmul8, (lv320_1, lv844), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv854 = R.call_tir(cls.transpose6, (lv853,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv855 = R.call_tir(cls.reshape4, (lv854,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv321_1: R.Tensor((2560, 320), dtype="uint32") = model_params[233] | |
lv322_1: R.Tensor((2560, 80), dtype="float16") = model_params[234] | |
lv233_1: R.Tensor((2560,), dtype="float16") = model_params[235] | |
lv92_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv321_1, lv322_1, lv855, lv233_1, lv91_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv860 = R.call_tir(cls.cast, (lv92_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv234_1: R.Tensor((2560,), dtype="float32") = model_params[228] | |
lv235: R.Tensor((2560,), dtype="float32") = model_params[229] | |
lv325 = R.call_tir(cls.fused_layer_norm_cast1, (lv860, lv234_1, lv235), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv863: R.Tensor((1, n, 2560), dtype="float16") = lv325 | |
lv326: R.Tensor((10240, 320), dtype="uint32") = model_params[236] | |
lv327: R.Tensor((10240, 80), dtype="float16") = model_params[237] | |
lv238_2: R.Tensor((10240,), dtype="float32") = model_params[238] | |
lv93_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv326, lv327, lv863, lv238_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv869: R.Tensor((1, n, 10240), dtype="float16") = lv93_1 | |
lv330: R.Tensor((2560, 1280), dtype="uint32") = model_params[239] | |
lv331_1: R.Tensor((2560, 320), dtype="float16") = model_params[240] | |
lv241_1: R.Tensor((2560,), dtype="float32") = model_params[241] | |
lv93_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv330, lv331_1, lv869, lv241_1, lv92_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
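# Decoder block writing KV-cache slots 30/31 (same pattern as above).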
lv876 = R.call_tir(cls.cast, (lv93_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv242_1: R.Tensor((2560,), dtype="float32") = model_params[242] | |
lv243_1: R.Tensor((2560,), dtype="float32") = model_params[243] | |
lv334 = R.call_tir(cls.fused_layer_norm_cast1, (lv876, lv242_1, lv243_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv879: R.Tensor((1, n, 2560), dtype="float16") = lv334 | |
lv335: R.Tensor((7680, 320), dtype="uint32") = model_params[246] | |
lv336: R.Tensor((7680, 80), dtype="float16") = model_params[247] | |
lv246_2: R.Tensor((7680,), dtype="float16") = model_params[248] | |
lv94_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv335, lv336, lv879, lv246_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv883 = R.call_tir(cls.reshape2, (lv94_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv884 = R.call_tir(cls.split, (lv883,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv885: R.Tensor((1, n, 32, 80), dtype="float16") = lv884[0] | |
lv886 = R.call_tir(cls.rotary_embedding, (lv885, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv887: R.Tensor((1, n, 32, 80), dtype="float16") = lv884[1] | |
lv888 = R.call_tir(cls.rotary_embedding, (lv887, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv889: R.Object = kv_cache[30] | |
lv890 = R.call_tir(cls.squeeze, (lv888,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv891: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv889, lv890, sinfo_args=(R.Object,)) | |
lv892: R.Object = kv_cache[31] | |
lv339: R.Tensor((1, n, 32, 80), dtype="float16") = lv884[2] | |
lv340 = R.call_tir(cls.fused_squeeze, (lv339,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv895: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv892, lv340, sinfo_args=(R.Object,)) | |
lv896: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv891, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv897: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv895, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv898 = R.call_tir(cls.reshape3, (lv896,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv899 = R.call_tir(cls.reshape3, (lv897,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv900 = R.call_tir(cls.transpose5, (lv886,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv901 = R.call_tir(cls.transpose5, (lv898,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv902 = R.call_tir(cls.transpose5, (lv899,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv341_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv900, lv901, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv342 = R.call_tir(cls.fused_softmax_cast3, (lv341_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv911 = R.call_tir(cls.matmul8, (lv342, lv902), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv912 = R.call_tir(cls.transpose6, (lv911,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv913 = R.call_tir(cls.reshape4, (lv912,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv343: R.Tensor((2560, 320), dtype="uint32") = model_params[249] | |
lv344: R.Tensor((2560, 80), dtype="float16") = model_params[250] | |
lv249_1: R.Tensor((2560,), dtype="float16") = model_params[251] | |
lv94_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv343, lv344, lv913, lv249_1, lv93_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv918 = R.call_tir(cls.cast, (lv94_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv250_1: R.Tensor((2560,), dtype="float32") = model_params[244] | |
lv251_2: R.Tensor((2560,), dtype="float32") = model_params[245] | |
lv347_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv918, lv250_1, lv251_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv921: R.Tensor((1, n, 2560), dtype="float16") = lv347_1 | |
lv348: R.Tensor((10240, 320), dtype="uint32") = model_params[252] | |
lv349: R.Tensor((10240, 80), dtype="float16") = model_params[253] | |
lv254_2: R.Tensor((10240,), dtype="float32") = model_params[254] | |
lv95 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv348, lv349, lv921, lv254_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv927: R.Tensor((1, n, 10240), dtype="float16") = lv95 | |
lv352: R.Tensor((2560, 1280), dtype="uint32") = model_params[255] | |
lv353: R.Tensor((2560, 320), dtype="float16") = model_params[256] | |
lv257_1: R.Tensor((2560,), dtype="float32") = model_params[257] | |
lv95_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv352, lv353, lv927, lv257_1, lv94_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
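# Decoder block writing KV-cache slots 32/33 (same pattern as above).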
lv934 = R.call_tir(cls.cast, (lv95_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv258_1: R.Tensor((2560,), dtype="float32") = model_params[258] | |
lv259_2: R.Tensor((2560,), dtype="float32") = model_params[259] | |
lv356 = R.call_tir(cls.fused_layer_norm_cast1, (lv934, lv258_1, lv259_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv937: R.Tensor((1, n, 2560), dtype="float16") = lv356 | |
lv357_1: R.Tensor((7680, 320), dtype="uint32") = model_params[262] | |
lv358: R.Tensor((7680, 80), dtype="float16") = model_params[263] | |
lv262_1: R.Tensor((7680,), dtype="float16") = model_params[264] | |
lv96 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv357_1, lv358, lv937, lv262_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv941 = R.call_tir(cls.reshape2, (lv96,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv942 = R.call_tir(cls.split, (lv941,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv943: R.Tensor((1, n, 32, 80), dtype="float16") = lv942[0] | |
lv944 = R.call_tir(cls.rotary_embedding, (lv943, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv945: R.Tensor((1, n, 32, 80), dtype="float16") = lv942[1] | |
lv946 = R.call_tir(cls.rotary_embedding, (lv945, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv947: R.Object = kv_cache[32] | |
lv948 = R.call_tir(cls.squeeze, (lv946,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv949: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv947, lv948, sinfo_args=(R.Object,)) | |
lv950: R.Object = kv_cache[33] | |
lv361_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv942[2] | |
lv362_1 = R.call_tir(cls.fused_squeeze, (lv361_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv953: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv950, lv362_1, sinfo_args=(R.Object,)) | |
lv954: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv949, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv955: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv953, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv956 = R.call_tir(cls.reshape3, (lv954,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv957 = R.call_tir(cls.reshape3, (lv955,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv958 = R.call_tir(cls.transpose5, (lv944,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv959 = R.call_tir(cls.transpose5, (lv956,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv960 = R.call_tir(cls.transpose5, (lv957,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv363_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv958, lv959, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv364_1 = R.call_tir(cls.fused_softmax_cast3, (lv363_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv969 = R.call_tir(cls.matmul8, (lv364_1, lv960), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv970 = R.call_tir(cls.transpose6, (lv969,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv971 = R.call_tir(cls.reshape4, (lv970,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv365_1: R.Tensor((2560, 320), dtype="uint32") = model_params[265] | |
lv366_1: R.Tensor((2560, 80), dtype="float16") = model_params[266] | |
lv265_1: R.Tensor((2560,), dtype="float16") = model_params[267] | |
lv96_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv365_1, lv366_1, lv971, lv265_1, lv95_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv976 = R.call_tir(cls.cast, (lv96_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv266: R.Tensor((2560,), dtype="float32") = model_params[260] | |
lv267: R.Tensor((2560,), dtype="float32") = model_params[261] | |
lv369_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv976, lv266, lv267), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv979: R.Tensor((1, n, 2560), dtype="float16") = lv369_1 | |
lv370_1: R.Tensor((10240, 320), dtype="uint32") = model_params[268] | |
lv371: R.Tensor((10240, 80), dtype="float16") = model_params[269] | |
lv270_1: R.Tensor((10240,), dtype="float32") = model_params[270] | |
lv97_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv370_1, lv371, lv979, lv270_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv985: R.Tensor((1, n, 10240), dtype="float16") = lv97_2 | |
lv374_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[271] | |
lv375_1: R.Tensor((2560, 320), dtype="float16") = model_params[272] | |
lv273_2: R.Tensor((2560,), dtype="float32") = model_params[273] | |
lv97_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv374_1, lv375_1, lv985, lv273_2, lv96_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
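# Decoder block writing KV-cache slots 34/35 (same pattern as above).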
lv992 = R.call_tir(cls.cast, (lv97_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv274_2: R.Tensor((2560,), dtype="float32") = model_params[274] | |
lv275_2: R.Tensor((2560,), dtype="float32") = model_params[275] | |
lv378_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv992, lv274_2, lv275_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv995: R.Tensor((1, n, 2560), dtype="float16") = lv378_1 | |
lv379_1: R.Tensor((7680, 320), dtype="uint32") = model_params[278] | |
lv380_1: R.Tensor((7680, 80), dtype="float16") = model_params[279] | |
lv278_1: R.Tensor((7680,), dtype="float16") = model_params[280] | |
lv98_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv379_1, lv380_1, lv995, lv278_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv999 = R.call_tir(cls.reshape2, (lv98_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1000 = R.call_tir(cls.split, (lv999,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1001: R.Tensor((1, n, 32, 80), dtype="float16") = lv1000[0] | |
lv1002 = R.call_tir(cls.rotary_embedding, (lv1001, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1003: R.Tensor((1, n, 32, 80), dtype="float16") = lv1000[1] | |
lv1004 = R.call_tir(cls.rotary_embedding, (lv1003, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1005: R.Object = kv_cache[34] | |
lv1006 = R.call_tir(cls.squeeze, (lv1004,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1007: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1005, lv1006, sinfo_args=(R.Object,)) | |
lv1008: R.Object = kv_cache[35] | |
lv383: R.Tensor((1, n, 32, 80), dtype="float16") = lv1000[2] | |
lv384 = R.call_tir(cls.fused_squeeze, (lv383,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1011: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1008, lv384, sinfo_args=(R.Object,)) | |
lv1012: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1007, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1013: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1011, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1014 = R.call_tir(cls.reshape3, (lv1012,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1015 = R.call_tir(cls.reshape3, (lv1013,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1016 = R.call_tir(cls.transpose5, (lv1002,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1017 = R.call_tir(cls.transpose5, (lv1014,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1018 = R.call_tir(cls.transpose5, (lv1015,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv385 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1016, lv1017, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv386 = R.call_tir(cls.fused_softmax_cast3, (lv385,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1027 = R.call_tir(cls.matmul8, (lv386, lv1018), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1028 = R.call_tir(cls.transpose6, (lv1027,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1029 = R.call_tir(cls.reshape4, (lv1028,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv387: R.Tensor((2560, 320), dtype="uint32") = model_params[281] | |
lv388: R.Tensor((2560, 80), dtype="float16") = model_params[282] | |
lv281_1: R.Tensor((2560,), dtype="float16") = model_params[283] | |
lv98_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv387, lv388, lv1029, lv281_1, lv97_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1034 = R.call_tir(cls.cast, (lv98_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv282_1: R.Tensor((2560,), dtype="float32") = model_params[276] | |
lv283_2: R.Tensor((2560,), dtype="float32") = model_params[277] | |
lv391_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1034, lv282_1, lv283_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1037: R.Tensor((1, n, 2560), dtype="float16") = lv391_1 | |
lv392: R.Tensor((10240, 320), dtype="uint32") = model_params[284] | |
lv393: R.Tensor((10240, 80), dtype="float16") = model_params[285] | |
lv286_1: R.Tensor((10240,), dtype="float32") = model_params[286] | |
lv99_3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv392, lv393, lv1037, lv286_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1043: R.Tensor((1, n, 10240), dtype="float16") = lv99_3 | |
lv396_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[287] | |
lv397: R.Tensor((2560, 320), dtype="float16") = model_params[288] | |
lv289_1: R.Tensor((2560,), dtype="float32") = model_params[289] | |
lv99_4 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv396_1, lv397, lv1043, lv289_1, lv98_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1050 = R.call_tir(cls.cast, (lv99_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv290_1: R.Tensor((2560,), dtype="float32") = model_params[290] | |
lv291_1: R.Tensor((2560,), dtype="float32") = model_params[291] | |
lv400 = R.call_tir(cls.fused_layer_norm_cast1, (lv1050, lv290_1, lv291_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1053: R.Tensor((1, n, 2560), dtype="float16") = lv400 | |
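# --- decoder layer using kv_cache slots 36/37; same QKV -> rotary -> cached attention -> MLP pattern annotated above ---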
lv401: R.Tensor((7680, 320), dtype="uint32") = model_params[294] | |
lv402: R.Tensor((7680, 80), dtype="float16") = model_params[295] | |
lv294: R.Tensor((7680,), dtype="float16") = model_params[296] | |
lv100_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv401, lv402, lv1053, lv294), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1057 = R.call_tir(cls.reshape2, (lv100_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1058 = R.call_tir(cls.split, (lv1057,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1059: R.Tensor((1, n, 32, 80), dtype="float16") = lv1058[0] | |
lv1060 = R.call_tir(cls.rotary_embedding, (lv1059, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1061: R.Tensor((1, n, 32, 80), dtype="float16") = lv1058[1] | |
lv1062 = R.call_tir(cls.rotary_embedding, (lv1061, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1063: R.Object = kv_cache[36] | |
lv1064 = R.call_tir(cls.squeeze, (lv1062,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1065: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1063, lv1064, sinfo_args=(R.Object,)) | |
lv1066: R.Object = kv_cache[37] | |
lv405_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1058[2] | |
lv406 = R.call_tir(cls.fused_squeeze, (lv405_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1069: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1066, lv406, sinfo_args=(R.Object,)) | |
lv1070: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1065, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1071: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1069, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1072 = R.call_tir(cls.reshape3, (lv1070,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1073 = R.call_tir(cls.reshape3, (lv1071,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1074 = R.call_tir(cls.transpose5, (lv1060,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1075 = R.call_tir(cls.transpose5, (lv1072,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1076 = R.call_tir(cls.transpose5, (lv1073,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv407 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1074, lv1075, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv408 = R.call_tir(cls.fused_softmax_cast3, (lv407,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1085 = R.call_tir(cls.matmul8, (lv408, lv1076), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1086 = R.call_tir(cls.transpose6, (lv1085,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1087 = R.call_tir(cls.reshape4, (lv1086,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv409: R.Tensor((2560, 320), dtype="uint32") = model_params[297] | |
lv410: R.Tensor((2560, 80), dtype="float16") = model_params[298] | |
lv297_1: R.Tensor((2560,), dtype="float16") = model_params[299] | |
lv100_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv409, lv410, lv1087, lv297_1, lv99_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1092 = R.call_tir(cls.cast, (lv100_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv298_1: R.Tensor((2560,), dtype="float32") = model_params[292] | |
lv299_2: R.Tensor((2560,), dtype="float32") = model_params[293] | |
lv413 = R.call_tir(cls.fused_layer_norm_cast1, (lv1092, lv298_1, lv299_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1095: R.Tensor((1, n, 2560), dtype="float16") = lv413 | |
lv414: R.Tensor((10240, 320), dtype="uint32") = model_params[300] | |
lv415_1: R.Tensor((10240, 80), dtype="float16") = model_params[301] | |
lv302: R.Tensor((10240,), dtype="float32") = model_params[302] | |
lv101_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv414, lv415_1, lv1095, lv302), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1101: R.Tensor((1, n, 10240), dtype="float16") = lv101_2 | |
lv418: R.Tensor((2560, 1280), dtype="uint32") = model_params[303] | |
lv419_1: R.Tensor((2560, 320), dtype="float16") = model_params[304] | |
lv305_2: R.Tensor((2560,), dtype="float32") = model_params[305] | |
lv101_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv418, lv419_1, lv1101, lv305_2, lv100_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1108 = R.call_tir(cls.cast, (lv101_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv306_1: R.Tensor((2560,), dtype="float32") = model_params[306] | |
lv307_1: R.Tensor((2560,), dtype="float32") = model_params[307] | |
lv422_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1108, lv306_1, lv307_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1111: R.Tensor((1, n, 2560), dtype="float16") = lv422_1 | |
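# --- decoder layer using kv_cache slots 38/39; same pattern annotated above ---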
lv423_1: R.Tensor((7680, 320), dtype="uint32") = model_params[310] | |
lv424_1: R.Tensor((7680, 80), dtype="float16") = model_params[311] | |
lv310_1: R.Tensor((7680,), dtype="float16") = model_params[312] | |
lv102_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv423_1, lv424_1, lv1111, lv310_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1115 = R.call_tir(cls.reshape2, (lv102_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1116 = R.call_tir(cls.split, (lv1115,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1117: R.Tensor((1, n, 32, 80), dtype="float16") = lv1116[0] | |
lv1118 = R.call_tir(cls.rotary_embedding, (lv1117, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1119: R.Tensor((1, n, 32, 80), dtype="float16") = lv1116[1] | |
lv1120 = R.call_tir(cls.rotary_embedding, (lv1119, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1121: R.Object = kv_cache[38] | |
lv1122 = R.call_tir(cls.squeeze, (lv1120,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1123: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1121, lv1122, sinfo_args=(R.Object,)) | |
lv1124: R.Object = kv_cache[39] | |
lv427_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1116[2] | |
lv428_1 = R.call_tir(cls.fused_squeeze, (lv427_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1127: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1124, lv428_1, sinfo_args=(R.Object,)) | |
lv1128: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1123, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1129: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1127, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1130 = R.call_tir(cls.reshape3, (lv1128,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1131 = R.call_tir(cls.reshape3, (lv1129,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1132 = R.call_tir(cls.transpose5, (lv1118,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1133 = R.call_tir(cls.transpose5, (lv1130,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1134 = R.call_tir(cls.transpose5, (lv1131,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv429 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1132, lv1133, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv430 = R.call_tir(cls.fused_softmax_cast3, (lv429,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1143 = R.call_tir(cls.matmul8, (lv430, lv1134), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1144 = R.call_tir(cls.transpose6, (lv1143,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1145 = R.call_tir(cls.reshape4, (lv1144,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv431_1: R.Tensor((2560, 320), dtype="uint32") = model_params[313] | |
lv432_1: R.Tensor((2560, 80), dtype="float16") = model_params[314] | |
lv313_1: R.Tensor((2560,), dtype="float16") = model_params[315] | |
lv102_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv431_1, lv432_1, lv1145, lv313_1, lv101_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1150 = R.call_tir(cls.cast, (lv102_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv314_1: R.Tensor((2560,), dtype="float32") = model_params[308] | |
lv315_1: R.Tensor((2560,), dtype="float32") = model_params[309] | |
lv435_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1150, lv314_1, lv315_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1153: R.Tensor((1, n, 2560), dtype="float16") = lv435_1 | |
lv436_1: R.Tensor((10240, 320), dtype="uint32") = model_params[316] | |
lv437_1: R.Tensor((10240, 80), dtype="float16") = model_params[317] | |
lv318_2: R.Tensor((10240,), dtype="float32") = model_params[318] | |
lv103 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv436_1, lv437_1, lv1153, lv318_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1159: R.Tensor((1, n, 10240), dtype="float16") = lv103 | |
lv440: R.Tensor((2560, 1280), dtype="uint32") = model_params[319] | |
lv441: R.Tensor((2560, 320), dtype="float16") = model_params[320] | |
lv321_2: R.Tensor((2560,), dtype="float32") = model_params[321] | |
lv103_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv440, lv441, lv1159, lv321_2, lv102_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1166 = R.call_tir(cls.cast, (lv103_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv322_2: R.Tensor((2560,), dtype="float32") = model_params[322] | |
lv323: R.Tensor((2560,), dtype="float32") = model_params[323] | |
lv444 = R.call_tir(cls.fused_layer_norm_cast1, (lv1166, lv322_2, lv323), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1169: R.Tensor((1, n, 2560), dtype="float16") = lv444 | |
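# --- decoder layer using kv_cache slots 40/41; same pattern annotated above ---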
lv445: R.Tensor((7680, 320), dtype="uint32") = model_params[326] | |
lv446: R.Tensor((7680, 80), dtype="float16") = model_params[327] | |
lv326_1: R.Tensor((7680,), dtype="float16") = model_params[328] | |
lv104 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv445, lv446, lv1169, lv326_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1173 = R.call_tir(cls.reshape2, (lv104,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1174 = R.call_tir(cls.split, (lv1173,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1175: R.Tensor((1, n, 32, 80), dtype="float16") = lv1174[0] | |
lv1176 = R.call_tir(cls.rotary_embedding, (lv1175, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1177: R.Tensor((1, n, 32, 80), dtype="float16") = lv1174[1] | |
lv1178 = R.call_tir(cls.rotary_embedding, (lv1177, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1179: R.Object = kv_cache[40] | |
lv1180 = R.call_tir(cls.squeeze, (lv1178,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1181: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1179, lv1180, sinfo_args=(R.Object,)) | |
lv1182: R.Object = kv_cache[41] | |
lv449_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1174[2] | |
lv450 = R.call_tir(cls.fused_squeeze, (lv449_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1185: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1182, lv450, sinfo_args=(R.Object,)) | |
lv1186: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1181, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1187: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1185, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1188 = R.call_tir(cls.reshape3, (lv1186,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1189 = R.call_tir(cls.reshape3, (lv1187,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1190 = R.call_tir(cls.transpose5, (lv1176,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1191 = R.call_tir(cls.transpose5, (lv1188,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1192 = R.call_tir(cls.transpose5, (lv1189,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv451 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1190, lv1191, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv452 = R.call_tir(cls.fused_softmax_cast3, (lv451,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1201 = R.call_tir(cls.matmul8, (lv452, lv1192), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1202 = R.call_tir(cls.transpose6, (lv1201,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1203 = R.call_tir(cls.reshape4, (lv1202,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv453: R.Tensor((2560, 320), dtype="uint32") = model_params[329] | |
lv454_1: R.Tensor((2560, 80), dtype="float16") = model_params[330] | |
lv329: R.Tensor((2560,), dtype="float16") = model_params[331] | |
lv104_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv453, lv454_1, lv1203, lv329, lv103_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1208 = R.call_tir(cls.cast, (lv104_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv330_1: R.Tensor((2560,), dtype="float32") = model_params[324] | |
lv331_2: R.Tensor((2560,), dtype="float32") = model_params[325] | |
lv457_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1208, lv330_1, lv331_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1211: R.Tensor((1, n, 2560), dtype="float16") = lv457_1 | |
lv458: R.Tensor((10240, 320), dtype="uint32") = model_params[332] | |
lv459: R.Tensor((10240, 80), dtype="float16") = model_params[333] | |
lv334_1: R.Tensor((10240,), dtype="float32") = model_params[334] | |
lv105_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv458, lv459, lv1211, lv334_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1217: R.Tensor((1, n, 10240), dtype="float16") = lv105_2 | |
lv462: R.Tensor((2560, 1280), dtype="uint32") = model_params[335] | |
lv463_1: R.Tensor((2560, 320), dtype="float16") = model_params[336] | |
lv337: R.Tensor((2560,), dtype="float32") = model_params[337] | |
lv105_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv462, lv463_1, lv1217, lv337, lv104_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1224 = R.call_tir(cls.cast, (lv105_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv338_1: R.Tensor((2560,), dtype="float32") = model_params[338] | |
lv339_1: R.Tensor((2560,), dtype="float32") = model_params[339] | |
lv466 = R.call_tir(cls.fused_layer_norm_cast1, (lv1224, lv338_1, lv339_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1227: R.Tensor((1, n, 2560), dtype="float16") = lv466 | |
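# --- decoder layer using kv_cache slots 42/43; same pattern annotated above ---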
lv467: R.Tensor((7680, 320), dtype="uint32") = model_params[342] | |
lv468: R.Tensor((7680, 80), dtype="float16") = model_params[343] | |
lv342_1: R.Tensor((7680,), dtype="float16") = model_params[344] | |
lv106_3 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv467, lv468, lv1227, lv342_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1231 = R.call_tir(cls.reshape2, (lv106_3,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1232 = R.call_tir(cls.split, (lv1231,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1233: R.Tensor((1, n, 32, 80), dtype="float16") = lv1232[0] | |
lv1234 = R.call_tir(cls.rotary_embedding, (lv1233, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1235: R.Tensor((1, n, 32, 80), dtype="float16") = lv1232[1] | |
lv1236 = R.call_tir(cls.rotary_embedding, (lv1235, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1237: R.Object = kv_cache[42] | |
lv1238 = R.call_tir(cls.squeeze, (lv1236,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1239: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1237, lv1238, sinfo_args=(R.Object,)) | |
lv1240: R.Object = kv_cache[43] | |
lv471: R.Tensor((1, n, 32, 80), dtype="float16") = lv1232[2] | |
lv472 = R.call_tir(cls.fused_squeeze, (lv471,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1243: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1240, lv472, sinfo_args=(R.Object,)) | |
lv1244: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1239, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1245: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1243, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1246 = R.call_tir(cls.reshape3, (lv1244,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1247 = R.call_tir(cls.reshape3, (lv1245,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1248 = R.call_tir(cls.transpose5, (lv1234,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1249 = R.call_tir(cls.transpose5, (lv1246,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1250 = R.call_tir(cls.transpose5, (lv1247,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv473_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1248, lv1249, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv474 = R.call_tir(cls.fused_softmax_cast3, (lv473_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1259 = R.call_tir(cls.matmul8, (lv474, lv1250), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1260 = R.call_tir(cls.transpose6, (lv1259,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1261 = R.call_tir(cls.reshape4, (lv1260,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv475: R.Tensor((2560, 320), dtype="uint32") = model_params[345] | |
lv476: R.Tensor((2560, 80), dtype="float16") = model_params[346] | |
lv345: R.Tensor((2560,), dtype="float16") = model_params[347] | |
lv106_4 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv475, lv476, lv1261, lv345, lv105_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1266 = R.call_tir(cls.cast, (lv106_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv346: R.Tensor((2560,), dtype="float32") = model_params[340] | |
lv347_2: R.Tensor((2560,), dtype="float32") = model_params[341] | |
lv479_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1266, lv346, lv347_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1269: R.Tensor((1, n, 2560), dtype="float16") = lv479_1 | |
lv480_1: R.Tensor((10240, 320), dtype="uint32") = model_params[348] | |
lv481_1: R.Tensor((10240, 80), dtype="float16") = model_params[349] | |
lv350: R.Tensor((10240,), dtype="float32") = model_params[350] | |
lv107_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv480_1, lv481_1, lv1269, lv350), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1275: R.Tensor((1, n, 10240), dtype="float16") = lv107_2 | |
lv484_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[351] | |
lv485_1: R.Tensor((2560, 320), dtype="float16") = model_params[352] | |
lv353_1: R.Tensor((2560,), dtype="float32") = model_params[353] | |
lv107_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv484_1, lv485_1, lv1275, lv353_1, lv106_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1282 = R.call_tir(cls.cast, (lv107_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv354_1: R.Tensor((2560,), dtype="float32") = model_params[354] | |
lv355: R.Tensor((2560,), dtype="float32") = model_params[355] | |
lv488 = R.call_tir(cls.fused_layer_norm_cast1, (lv1282, lv354_1, lv355), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1285: R.Tensor((1, n, 2560), dtype="float16") = lv488 | |
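# --- decoder layer using kv_cache slots 44/45; same pattern annotated above ---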
lv489_1: R.Tensor((7680, 320), dtype="uint32") = model_params[358] | |
lv490_1: R.Tensor((7680, 80), dtype="float16") = model_params[359] | |
lv358_1: R.Tensor((7680,), dtype="float16") = model_params[360] | |
lv108 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv489_1, lv490_1, lv1285, lv358_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1289 = R.call_tir(cls.reshape2, (lv108,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1290 = R.call_tir(cls.split, (lv1289,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1291: R.Tensor((1, n, 32, 80), dtype="float16") = lv1290[0] | |
lv1292 = R.call_tir(cls.rotary_embedding, (lv1291, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1293: R.Tensor((1, n, 32, 80), dtype="float16") = lv1290[1] | |
lv1294 = R.call_tir(cls.rotary_embedding, (lv1293, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1295: R.Object = kv_cache[44] | |
lv1296 = R.call_tir(cls.squeeze, (lv1294,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1297: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1295, lv1296, sinfo_args=(R.Object,)) | |
lv1298: R.Object = kv_cache[45] | |
lv493_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1290[2] | |
lv494_1 = R.call_tir(cls.fused_squeeze, (lv493_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1301: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1298, lv494_1, sinfo_args=(R.Object,)) | |
lv1302: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1297, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1303: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1301, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1304 = R.call_tir(cls.reshape3, (lv1302,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1305 = R.call_tir(cls.reshape3, (lv1303,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1306 = R.call_tir(cls.transpose5, (lv1292,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1307 = R.call_tir(cls.transpose5, (lv1304,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1308 = R.call_tir(cls.transpose5, (lv1305,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv495_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1306, lv1307, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv496_1 = R.call_tir(cls.fused_softmax_cast3, (lv495_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1317 = R.call_tir(cls.matmul8, (lv496_1, lv1308), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1318 = R.call_tir(cls.transpose6, (lv1317,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1319 = R.call_tir(cls.reshape4, (lv1318,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv497: R.Tensor((2560, 320), dtype="uint32") = model_params[361] | |
lv498: R.Tensor((2560, 80), dtype="float16") = model_params[362] | |
lv361_2: R.Tensor((2560,), dtype="float16") = model_params[363] | |
lv108_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv497, lv498, lv1319, lv361_2, lv107_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1324 = R.call_tir(cls.cast, (lv108_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv362_2: R.Tensor((2560,), dtype="float32") = model_params[356] | |
lv363_2: R.Tensor((2560,), dtype="float32") = model_params[357] | |
lv501 = R.call_tir(cls.fused_layer_norm_cast1, (lv1324, lv362_2, lv363_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1327: R.Tensor((1, n, 2560), dtype="float16") = lv501 | |
lv502: R.Tensor((10240, 320), dtype="uint32") = model_params[364] | |
lv503: R.Tensor((10240, 80), dtype="float16") = model_params[365] | |
lv366_2: R.Tensor((10240,), dtype="float32") = model_params[366] | |
lv109_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv502, lv503, lv1327, lv366_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1333: R.Tensor((1, n, 10240), dtype="float16") = lv109_1 | |
lv506_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[367] | |
lv507_1: R.Tensor((2560, 320), dtype="float16") = model_params[368] | |
lv369_2: R.Tensor((2560,), dtype="float32") = model_params[369] | |
lv109_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv506_1, lv507_1, lv1333, lv369_2, lv108_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1340 = R.call_tir(cls.cast, (lv109_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv370_2: R.Tensor((2560,), dtype="float32") = model_params[370] | |
lv371_1: R.Tensor((2560,), dtype="float32") = model_params[371] | |
lv510 = R.call_tir(cls.fused_layer_norm_cast1, (lv1340, lv370_2, lv371_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1343: R.Tensor((1, n, 2560), dtype="float16") = lv510 | |
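# --- decoder layer using kv_cache slots 46/47; same pattern annotated above ---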
lv511: R.Tensor((7680, 320), dtype="uint32") = model_params[374] | |
lv512_1: R.Tensor((7680, 80), dtype="float16") = model_params[375] | |
lv374_2: R.Tensor((7680,), dtype="float16") = model_params[376] | |
lv110_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv511, lv512_1, lv1343, lv374_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1347 = R.call_tir(cls.reshape2, (lv110_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1348 = R.call_tir(cls.split, (lv1347,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1349: R.Tensor((1, n, 32, 80), dtype="float16") = lv1348[0] | |
lv1350 = R.call_tir(cls.rotary_embedding, (lv1349, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1351: R.Tensor((1, n, 32, 80), dtype="float16") = lv1348[1] | |
lv1352 = R.call_tir(cls.rotary_embedding, (lv1351, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1353: R.Object = kv_cache[46] | |
lv1354 = R.call_tir(cls.squeeze, (lv1352,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1355: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1353, lv1354, sinfo_args=(R.Object,)) | |
lv1356: R.Object = kv_cache[47] | |
lv515_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1348[2] | |
lv516 = R.call_tir(cls.fused_squeeze, (lv515_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1359: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1356, lv516, sinfo_args=(R.Object,)) | |
lv1360: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1355, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1361: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1359, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1362 = R.call_tir(cls.reshape3, (lv1360,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1363 = R.call_tir(cls.reshape3, (lv1361,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1364 = R.call_tir(cls.transpose5, (lv1350,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1365 = R.call_tir(cls.transpose5, (lv1362,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1366 = R.call_tir(cls.transpose5, (lv1363,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv517 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1364, lv1365, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv518 = R.call_tir(cls.fused_softmax_cast3, (lv517,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1375 = R.call_tir(cls.matmul8, (lv518, lv1366), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1376 = R.call_tir(cls.transpose6, (lv1375,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1377 = R.call_tir(cls.reshape4, (lv1376,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv519: R.Tensor((2560, 320), dtype="uint32") = model_params[377] | |
lv520: R.Tensor((2560, 80), dtype="float16") = model_params[378] | |
lv377_1: R.Tensor((2560,), dtype="float16") = model_params[379] | |
lv110_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv519, lv520, lv1377, lv377_1, lv109_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1382 = R.call_tir(cls.cast, (lv110_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv378_2: R.Tensor((2560,), dtype="float32") = model_params[372] | |
lv379_2: R.Tensor((2560,), dtype="float32") = model_params[373] | |
lv523 = R.call_tir(cls.fused_layer_norm_cast1, (lv1382, lv378_2, lv379_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1385: R.Tensor((1, n, 2560), dtype="float16") = lv523 | |
lv524: R.Tensor((10240, 320), dtype="uint32") = model_params[380] | |
lv525: R.Tensor((10240, 80), dtype="float16") = model_params[381] | |
lv382: R.Tensor((10240,), dtype="float32") = model_params[382] | |
lv111_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv524, lv525, lv1385, lv382), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1391: R.Tensor((1, n, 10240), dtype="float16") = lv111_1 | |
lv528_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[383] | |
lv529: R.Tensor((2560, 320), dtype="float16") = model_params[384] | |
lv385_1: R.Tensor((2560,), dtype="float32") = model_params[385] | |
lv111_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv528_1, lv529, lv1391, lv385_1, lv110_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1398 = R.call_tir(cls.cast, (lv111_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv386_1: R.Tensor((2560,), dtype="float32") = model_params[386] | |
lv387_1: R.Tensor((2560,), dtype="float32") = model_params[387] | |
lv532 = R.call_tir(cls.fused_layer_norm_cast1, (lv1398, lv386_1, lv387_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1401: R.Tensor((1, n, 2560), dtype="float16") = lv532 | |
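# --- decoder layer using kv_cache slots 48/49; same pattern annotated above ---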
lv533: R.Tensor((7680, 320), dtype="uint32") = model_params[390] | |
lv534: R.Tensor((7680, 80), dtype="float16") = model_params[391] | |
lv390_1: R.Tensor((7680,), dtype="float16") = model_params[392] | |
lv112 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv533, lv534, lv1401, lv390_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1405 = R.call_tir(cls.reshape2, (lv112,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1406 = R.call_tir(cls.split, (lv1405,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1407: R.Tensor((1, n, 32, 80), dtype="float16") = lv1406[0] | |
lv1408 = R.call_tir(cls.rotary_embedding, (lv1407, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1409: R.Tensor((1, n, 32, 80), dtype="float16") = lv1406[1] | |
lv1410 = R.call_tir(cls.rotary_embedding, (lv1409, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1411: R.Object = kv_cache[48] | |
lv1412 = R.call_tir(cls.squeeze, (lv1410,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1413: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1411, lv1412, sinfo_args=(R.Object,)) | |
lv1414: R.Object = kv_cache[49] | |
lv537_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1406[2] | |
lv538_1 = R.call_tir(cls.fused_squeeze, (lv537_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1417: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1414, lv538_1, sinfo_args=(R.Object,)) | |
lv1418: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1413, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1419: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1417, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1420 = R.call_tir(cls.reshape3, (lv1418,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1421 = R.call_tir(cls.reshape3, (lv1419,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1422 = R.call_tir(cls.transpose5, (lv1408,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1423 = R.call_tir(cls.transpose5, (lv1420,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1424 = R.call_tir(cls.transpose5, (lv1421,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv539_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1422, lv1423, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv540_1 = R.call_tir(cls.fused_softmax_cast3, (lv539_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1433 = R.call_tir(cls.matmul8, (lv540_1, lv1424), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1434 = R.call_tir(cls.transpose6, (lv1433,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1435 = R.call_tir(cls.reshape4, (lv1434,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv541_1: R.Tensor((2560, 320), dtype="uint32") = model_params[393] | |
lv542_1: R.Tensor((2560, 80), dtype="float16") = model_params[394] | |
lv393_1: R.Tensor((2560,), dtype="float16") = model_params[395] | |
lv112_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv541_1, lv542_1, lv1435, lv393_1, lv111_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1440 = R.call_tir(cls.cast, (lv112_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv394: R.Tensor((2560,), dtype="float32") = model_params[388] | |
lv395: R.Tensor((2560,), dtype="float32") = model_params[389] | |
lv545 = R.call_tir(cls.fused_layer_norm_cast1, (lv1440, lv394, lv395), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1443: R.Tensor((1, n, 2560), dtype="float16") = lv545 | |
lv546: R.Tensor((10240, 320), dtype="uint32") = model_params[396] | |
lv547_1: R.Tensor((10240, 80), dtype="float16") = model_params[397] | |
lv398: R.Tensor((10240,), dtype="float32") = model_params[398] | |
lv113_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv546, lv547_1, lv1443, lv398), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1449: R.Tensor((1, n, 10240), dtype="float16") = lv113_1 | |
lv550_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[399] | |
lv551_1: R.Tensor((2560, 320), dtype="float16") = model_params[400] | |
lv401_1: R.Tensor((2560,), dtype="float32") = model_params[401] | |
lv113_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv550_1, lv551_1, lv1449, lv401_1, lv112_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1456 = R.call_tir(cls.cast, (lv113_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv402_1: R.Tensor((2560,), dtype="float32") = model_params[402] | |
lv403: R.Tensor((2560,), dtype="float32") = model_params[403] | |
lv554_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1456, lv402_1, lv403), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1459: R.Tensor((1, n, 2560), dtype="float16") = lv554_1 | |
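# --- decoder layer using kv_cache slots 50/51; same pattern annotated above ---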
lv555: R.Tensor((7680, 320), dtype="uint32") = model_params[406] | |
lv556: R.Tensor((7680, 80), dtype="float16") = model_params[407] | |
lv406_1: R.Tensor((7680,), dtype="float16") = model_params[408] | |
lv114_2 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv555, lv556, lv1459, lv406_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1463 = R.call_tir(cls.reshape2, (lv114_2,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1464 = R.call_tir(cls.split, (lv1463,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1465: R.Tensor((1, n, 32, 80), dtype="float16") = lv1464[0] | |
lv1466 = R.call_tir(cls.rotary_embedding, (lv1465, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1467: R.Tensor((1, n, 32, 80), dtype="float16") = lv1464[1] | |
lv1468 = R.call_tir(cls.rotary_embedding, (lv1467, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1469: R.Object = kv_cache[50] | |
lv1470 = R.call_tir(cls.squeeze, (lv1468,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1469, lv1470, sinfo_args=(R.Object,)) | |
lv1472: R.Object = kv_cache[51] | |
lv559: R.Tensor((1, n, 32, 80), dtype="float16") = lv1464[2] | |
lv560 = R.call_tir(cls.fused_squeeze, (lv559,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1475: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1472, lv560, sinfo_args=(R.Object,)) | |
lv1476: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1471, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1477: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1475, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1478 = R.call_tir(cls.reshape3, (lv1476,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1479 = R.call_tir(cls.reshape3, (lv1477,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1480 = R.call_tir(cls.transpose5, (lv1466,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1481 = R.call_tir(cls.transpose5, (lv1478,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1482 = R.call_tir(cls.transpose5, (lv1479,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv561 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1480, lv1481, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv562 = R.call_tir(cls.fused_softmax_cast3, (lv561,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1491 = R.call_tir(cls.matmul8, (lv562, lv1482), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1492 = R.call_tir(cls.transpose6, (lv1491,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1493 = R.call_tir(cls.reshape4, (lv1492,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv563_1: R.Tensor((2560, 320), dtype="uint32") = model_params[409] | |
lv564_1: R.Tensor((2560, 80), dtype="float16") = model_params[410] | |
lv409_1: R.Tensor((2560,), dtype="float16") = model_params[411] | |
lv114_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv563_1, lv564_1, lv1493, lv409_1, lv113_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1498 = R.call_tir(cls.cast, (lv114_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv410_1: R.Tensor((2560,), dtype="float32") = model_params[404] | |
lv411: R.Tensor((2560,), dtype="float32") = model_params[405] | |
lv567 = R.call_tir(cls.fused_layer_norm_cast1, (lv1498, lv410_1, lv411), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1501: R.Tensor((1, n, 2560), dtype="float16") = lv567 | |
lv568: R.Tensor((10240, 320), dtype="uint32") = model_params[412] | |
lv569: R.Tensor((10240, 80), dtype="float16") = model_params[413] | |
lv414_1: R.Tensor((10240,), dtype="float32") = model_params[414] | |
lv115_3 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv568, lv569, lv1501, lv414_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1507: R.Tensor((1, n, 10240), dtype="float16") = lv115_3 | |
lv572: R.Tensor((2560, 1280), dtype="uint32") = model_params[415] | |
lv573_1: R.Tensor((2560, 320), dtype="float16") = model_params[416] | |
lv417: R.Tensor((2560,), dtype="float32") = model_params[417] | |
lv115_4 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv572, lv573_1, lv1507, lv417, lv114_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1514 = R.call_tir(cls.cast, (lv115_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv418_1: R.Tensor((2560,), dtype="float32") = model_params[418] | |
lv419_2: R.Tensor((2560,), dtype="float32") = model_params[419] | |
lv576 = R.call_tir(cls.fused_layer_norm_cast1, (lv1514, lv418_1, lv419_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1517: R.Tensor((1, n, 2560), dtype="float16") = lv576 | |
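# --- decoder layer using kv_cache slots 52/53; same pattern annotated above ---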
lv577: R.Tensor((7680, 320), dtype="uint32") = model_params[422] | |
lv578: R.Tensor((7680, 80), dtype="float16") = model_params[423] | |
lv422_2: R.Tensor((7680,), dtype="float16") = model_params[424] | |
lv116_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv577, lv578, lv1517, lv422_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1521 = R.call_tir(cls.reshape2, (lv116_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1522 = R.call_tir(cls.split, (lv1521,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1523: R.Tensor((1, n, 32, 80), dtype="float16") = lv1522[0] | |
lv1524 = R.call_tir(cls.rotary_embedding, (lv1523, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1525: R.Tensor((1, n, 32, 80), dtype="float16") = lv1522[1] | |
lv1526 = R.call_tir(cls.rotary_embedding, (lv1525, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1527: R.Object = kv_cache[52] | |
lv1528 = R.call_tir(cls.squeeze, (lv1526,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1529: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1527, lv1528, sinfo_args=(R.Object,)) | |
lv1530: R.Object = kv_cache[53] | |
lv581: R.Tensor((1, n, 32, 80), dtype="float16") = lv1522[2] | |
lv582 = R.call_tir(cls.fused_squeeze, (lv581,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1533: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1530, lv582, sinfo_args=(R.Object,)) | |
lv1534: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1529, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1535: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1533, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1536 = R.call_tir(cls.reshape3, (lv1534,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1537 = R.call_tir(cls.reshape3, (lv1535,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1538 = R.call_tir(cls.transpose5, (lv1524,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1539 = R.call_tir(cls.transpose5, (lv1536,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1540 = R.call_tir(cls.transpose5, (lv1537,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv583 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1538, lv1539, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv584 = R.call_tir(cls.fused_softmax_cast3, (lv583,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1549 = R.call_tir(cls.matmul8, (lv584, lv1540), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1550 = R.call_tir(cls.transpose6, (lv1549,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1551 = R.call_tir(cls.reshape4, (lv1550,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv585: R.Tensor((2560, 320), dtype="uint32") = model_params[425] | |
lv586_1: R.Tensor((2560, 80), dtype="float16") = model_params[426] | |
lv425_1: R.Tensor((2560,), dtype="float16") = model_params[427] | |
lv116_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv585, lv586_1, lv1551, lv425_1, lv115_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1556 = R.call_tir(cls.cast, (lv116_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv426_1: R.Tensor((2560,), dtype="float32") = model_params[420] | |
lv427_2: R.Tensor((2560,), dtype="float32") = model_params[421] | |
lv589_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1556, lv426_1, lv427_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1559: R.Tensor((1, n, 2560), dtype="float16") = lv589_1 | |
lv590: R.Tensor((10240, 320), dtype="uint32") = model_params[428] | |
lv591: R.Tensor((10240, 80), dtype="float16") = model_params[429] | |
lv430_1: R.Tensor((10240,), dtype="float32") = model_params[430] | |
lv117 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv590, lv591, lv1559, lv430_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1565: R.Tensor((1, n, 10240), dtype="float16") = lv117 | |
lv594_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[431] | |
lv595_1: R.Tensor((2560, 320), dtype="float16") = model_params[432] | |
lv433_1: R.Tensor((2560,), dtype="float32") = model_params[433] | |
lv117_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv594_1, lv595_1, lv1565, lv433_1, lv116_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1572 = R.call_tir(cls.cast, (lv117_1,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv434_1: R.Tensor((2560,), dtype="float32") = model_params[434] | |
lv435_2: R.Tensor((2560,), dtype="float32") = model_params[435] | |
lv598_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1572, lv434_1, lv435_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1575: R.Tensor((1, n, 2560), dtype="float16") = lv598_1 | |
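# --- decoder layer using kv_cache slots 54/55; same pattern annotated above ---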
lv599_1: R.Tensor((7680, 320), dtype="uint32") = model_params[438] | |
lv600_1: R.Tensor((7680, 80), dtype="float16") = model_params[439] | |
lv438_1: R.Tensor((7680,), dtype="float16") = model_params[440] | |
lv118_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv599_1, lv600_1, lv1575, lv438_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16")) | |
lv1579 = R.call_tir(cls.reshape2, (lv118_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16")) | |
lv1580 = R.call_tir(cls.split, (lv1579,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")]) | |
lv1581: R.Tensor((1, n, 32, 80), dtype="float16") = lv1580[0] | |
lv1582 = R.call_tir(cls.rotary_embedding, (lv1581, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1583: R.Tensor((1, n, 32, 80), dtype="float16") = lv1580[1] | |
lv1584 = R.call_tir(cls.rotary_embedding, (lv1583, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m])) | |
lv1585: R.Object = kv_cache[54] | |
lv1586 = R.call_tir(cls.squeeze, (lv1584,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1587: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1585, lv1586, sinfo_args=(R.Object,)) | |
lv1588: R.Object = kv_cache[55] | |
lv603: R.Tensor((1, n, 32, 80), dtype="float16") = lv1580[2] | |
lv604 = R.call_tir(cls.fused_squeeze, (lv603,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16")) | |
lv1591: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1588, lv604, sinfo_args=(R.Object,)) | |
lv1592: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1587, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1593: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1591, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),)) | |
lv1594 = R.call_tir(cls.reshape3, (lv1592,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1595 = R.call_tir(cls.reshape3, (lv1593,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16")) | |
lv1596 = R.call_tir(cls.transpose5, (lv1582,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1597 = R.call_tir(cls.transpose5, (lv1594,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv1598 = R.call_tir(cls.transpose5, (lv1595,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16")) | |
lv605_1 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1596, lv1597, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) | |
lv606_1 = R.call_tir(cls.fused_softmax_cast3, (lv605_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) | |
lv1607 = R.call_tir(cls.matmul8, (lv606_1, lv1598), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16")) | |
lv1608 = R.call_tir(cls.transpose6, (lv1607,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16")) | |
lv1609 = R.call_tir(cls.reshape4, (lv1608,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv607_1: R.Tensor((2560, 320), dtype="uint32") = model_params[441] | |
lv608_1: R.Tensor((2560, 80), dtype="float16") = model_params[442] | |
lv441_1: R.Tensor((2560,), dtype="float16") = model_params[443] | |
lv118_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv607_1, lv608_1, lv1609, lv441_1, lv117_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1614 = R.call_tir(cls.cast, (lv118_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
lv442: R.Tensor((2560,), dtype="float32") = model_params[436] | |
lv443: R.Tensor((2560,), dtype="float32") = model_params[437] | |
lv611_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1614, lv442, lv443), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1617: R.Tensor((1, n, 2560), dtype="float16") = lv611_1 | |
lv612_1: R.Tensor((10240, 320), dtype="uint32") = model_params[444] | |
lv613: R.Tensor((10240, 80), dtype="float16") = model_params[445] | |
lv446_1: R.Tensor((10240,), dtype="float32") = model_params[446] | |
lv119_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv612_1, lv613, lv1617, lv446_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16")) | |
lv1623: R.Tensor((1, n, 10240), dtype="float16") = lv119_1 | |
lv616: R.Tensor((2560, 1280), dtype="uint32") = model_params[447] | |
lv617: R.Tensor((2560, 320), dtype="float16") = model_params[448] | |
lv449_2: R.Tensor((2560,), dtype="float32") = model_params[449] | |
lv119_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv616, lv617, lv1623, lv449_2, lv118_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16")) | |
lv1630 = R.call_tir(cls.cast, (lv119_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32")) | |
            lv450_1: R.Tensor((2560,), dtype="float32") = model_params[450]
            lv451_1: R.Tensor((2560,), dtype="float32") = model_params[451]
            lv620 = R.call_tir(cls.fused_layer_norm_cast1, (lv1630, lv450_1, lv451_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1633: R.Tensor((1, n, 2560), dtype="float16") = lv620
            lv621_1: R.Tensor((7680, 320), dtype="uint32") = model_params[454]
            lv622_1: R.Tensor((7680, 80), dtype="float16") = model_params[455]
            lv454_2: R.Tensor((7680,), dtype="float16") = model_params[456]
            lv120_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv621_1, lv622_1, lv1633, lv454_2), out_sinfo=R.Tensor((1, n, 7680), dtype="float16"))
            lv1637 = R.call_tir(cls.reshape2, (lv120_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16"))
            lv1638 = R.call_tir(cls.split, (lv1637,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")])
            lv1639: R.Tensor((1, n, 32, 80), dtype="float16") = lv1638[0]
            lv1640 = R.call_tir(cls.rotary_embedding, (lv1639, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1641: R.Tensor((1, n, 32, 80), dtype="float16") = lv1638[1]
            lv1642 = R.call_tir(cls.rotary_embedding, (lv1641, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1643: R.Object = kv_cache[56]
            lv1644 = R.call_tir(cls.squeeze, (lv1642,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1645: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1643, lv1644, sinfo_args=(R.Object,))
            lv1646: R.Object = kv_cache[57]
            lv625: R.Tensor((1, n, 32, 80), dtype="float16") = lv1638[2]
            lv626 = R.call_tir(cls.fused_squeeze, (lv625,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1649: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1646, lv626, sinfo_args=(R.Object,))
            lv1650: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1645, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1651: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1649, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1652 = R.call_tir(cls.reshape3, (lv1650,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1653 = R.call_tir(cls.reshape3, (lv1651,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1654 = R.call_tir(cls.transpose5, (lv1640,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1655 = R.call_tir(cls.transpose5, (lv1652,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv1656 = R.call_tir(cls.transpose5, (lv1653,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv627 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1654, lv1655, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv628_1 = R.call_tir(cls.fused_softmax_cast3, (lv627,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1665 = R.call_tir(cls.matmul8, (lv628_1, lv1656), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1666 = R.call_tir(cls.transpose6, (lv1665,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
            lv1667 = R.call_tir(cls.reshape4, (lv1666,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv629: R.Tensor((2560, 320), dtype="uint32") = model_params[457]
            lv630: R.Tensor((2560, 80), dtype="float16") = model_params[458]
            lv457_2: R.Tensor((2560,), dtype="float16") = model_params[459]
            lv120_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv629, lv630, lv1667, lv457_2, lv119_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1672 = R.call_tir(cls.cast, (lv120_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv458_1: R.Tensor((2560,), dtype="float32") = model_params[452]
            lv459_1: R.Tensor((2560,), dtype="float32") = model_params[453]
            lv633 = R.call_tir(cls.fused_layer_norm_cast1, (lv1672, lv458_1, lv459_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1675: R.Tensor((1, n, 2560), dtype="float16") = lv633
            lv634: R.Tensor((10240, 320), dtype="uint32") = model_params[460]
            lv635: R.Tensor((10240, 80), dtype="float16") = model_params[461]
            lv462_1: R.Tensor((10240,), dtype="float32") = model_params[462]
            lv121_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv634, lv635, lv1675, lv462_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16"))
            lv1681: R.Tensor((1, n, 10240), dtype="float16") = lv121_2
            lv638: R.Tensor((2560, 1280), dtype="uint32") = model_params[463]
            lv639: R.Tensor((2560, 320), dtype="float16") = model_params[464]
            lv465: R.Tensor((2560,), dtype="float32") = model_params[465]
            lv121_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv638, lv639, lv1681, lv465, lv120_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1688 = R.call_tir(cls.cast, (lv121_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
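            # ---- next decoder layer (kv_cache slots 58/59); same structure as above ----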
            lv466_1: R.Tensor((2560,), dtype="float32") = model_params[466]
            lv467_1: R.Tensor((2560,), dtype="float32") = model_params[467]
            lv642 = R.call_tir(cls.fused_layer_norm_cast1, (lv1688, lv466_1, lv467_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1691: R.Tensor((1, n, 2560), dtype="float16") = lv642
            lv643: R.Tensor((7680, 320), dtype="uint32") = model_params[470]
            lv644_1: R.Tensor((7680, 80), dtype="float16") = model_params[471]
            lv470_1: R.Tensor((7680,), dtype="float16") = model_params[472]
            lv122_3 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv643, lv644_1, lv1691, lv470_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16"))
            lv1695 = R.call_tir(cls.reshape2, (lv122_3,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16"))
            lv1696 = R.call_tir(cls.split, (lv1695,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")])
            lv1697: R.Tensor((1, n, 32, 80), dtype="float16") = lv1696[0]
            lv1698 = R.call_tir(cls.rotary_embedding, (lv1697, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1699: R.Tensor((1, n, 32, 80), dtype="float16") = lv1696[1]
            lv1700 = R.call_tir(cls.rotary_embedding, (lv1699, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1701: R.Object = kv_cache[58]
            lv1702 = R.call_tir(cls.squeeze, (lv1700,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1703: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1701, lv1702, sinfo_args=(R.Object,))
            lv1704: R.Object = kv_cache[59]
            lv647_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1696[2]
            lv648 = R.call_tir(cls.fused_squeeze, (lv647_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1707: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1704, lv648, sinfo_args=(R.Object,))
            lv1708: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1703, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1709: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1707, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1710 = R.call_tir(cls.reshape3, (lv1708,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1711 = R.call_tir(cls.reshape3, (lv1709,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1712 = R.call_tir(cls.transpose5, (lv1698,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1713 = R.call_tir(cls.transpose5, (lv1710,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv1714 = R.call_tir(cls.transpose5, (lv1711,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv649 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1712, lv1713, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv650 = R.call_tir(cls.fused_softmax_cast3, (lv649,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1723 = R.call_tir(cls.matmul8, (lv650, lv1714), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1724 = R.call_tir(cls.transpose6, (lv1723,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
            lv1725 = R.call_tir(cls.reshape4, (lv1724,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv651_1: R.Tensor((2560, 320), dtype="uint32") = model_params[473]
            lv652_1: R.Tensor((2560, 80), dtype="float16") = model_params[474]
            lv473_2: R.Tensor((2560,), dtype="float16") = model_params[475]
            lv122_4 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv651_1, lv652_1, lv1725, lv473_2, lv121_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1730 = R.call_tir(cls.cast, (lv122_4,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv474_1: R.Tensor((2560,), dtype="float32") = model_params[468]
            lv475_1: R.Tensor((2560,), dtype="float32") = model_params[469]
            lv655_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1730, lv474_1, lv475_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1733: R.Tensor((1, n, 2560), dtype="float16") = lv655_1
            lv656_1: R.Tensor((10240, 320), dtype="uint32") = model_params[476]
            lv657_1: R.Tensor((10240, 80), dtype="float16") = model_params[477]
            lv478_1: R.Tensor((10240,), dtype="float32") = model_params[478]
            lv123_2 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv656_1, lv657_1, lv1733, lv478_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16"))
            lv1739: R.Tensor((1, n, 10240), dtype="float16") = lv123_2
            lv660_1: R.Tensor((2560, 1280), dtype="uint32") = model_params[479]
            lv661: R.Tensor((2560, 320), dtype="float16") = model_params[480]
            lv481_2: R.Tensor((2560,), dtype="float32") = model_params[481]
            lv123_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv660_1, lv661, lv1739, lv481_2, lv122_4), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1746 = R.call_tir(cls.cast, (lv123_3,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
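            # ---- next decoder layer (kv_cache slots 60/61); same structure as above ----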
            lv482_1: R.Tensor((2560,), dtype="float32") = model_params[482]
            lv483_1: R.Tensor((2560,), dtype="float32") = model_params[483]
            lv664_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1746, lv482_1, lv483_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1749: R.Tensor((1, n, 2560), dtype="float16") = lv664_1
            lv665_1: R.Tensor((7680, 320), dtype="uint32") = model_params[486]
            lv666_1: R.Tensor((7680, 80), dtype="float16") = model_params[487]
            lv486_1: R.Tensor((7680,), dtype="float16") = model_params[488]
            lv124_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv665_1, lv666_1, lv1749, lv486_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16"))
            lv1753 = R.call_tir(cls.reshape2, (lv124_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16"))
            lv1754 = R.call_tir(cls.split, (lv1753,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")])
            lv1755: R.Tensor((1, n, 32, 80), dtype="float16") = lv1754[0]
            lv1756 = R.call_tir(cls.rotary_embedding, (lv1755, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1757: R.Tensor((1, n, 32, 80), dtype="float16") = lv1754[1]
            lv1758 = R.call_tir(cls.rotary_embedding, (lv1757, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1759: R.Object = kv_cache[60]
            lv1760 = R.call_tir(cls.squeeze, (lv1758,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1761: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1759, lv1760, sinfo_args=(R.Object,))
            lv1762: R.Object = kv_cache[61]
            lv669_1: R.Tensor((1, n, 32, 80), dtype="float16") = lv1754[2]
            lv670_1 = R.call_tir(cls.fused_squeeze, (lv669_1,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1765: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1762, lv670_1, sinfo_args=(R.Object,))
            lv1766: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1761, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1767: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1765, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1768 = R.call_tir(cls.reshape3, (lv1766,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1769 = R.call_tir(cls.reshape3, (lv1767,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1770 = R.call_tir(cls.transpose5, (lv1756,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1771 = R.call_tir(cls.transpose5, (lv1768,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv1772 = R.call_tir(cls.transpose5, (lv1769,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv671 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1770, lv1771, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv672 = R.call_tir(cls.fused_softmax_cast3, (lv671,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1781 = R.call_tir(cls.matmul8, (lv672, lv1772), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1782 = R.call_tir(cls.transpose6, (lv1781,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
            lv1783 = R.call_tir(cls.reshape4, (lv1782,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv673: R.Tensor((2560, 320), dtype="uint32") = model_params[489]
            lv674: R.Tensor((2560, 80), dtype="float16") = model_params[490]
            lv489_2: R.Tensor((2560,), dtype="float16") = model_params[491]
            lv124_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv673, lv674, lv1783, lv489_2, lv123_3), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1788 = R.call_tir(cls.cast, (lv124_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv490_2: R.Tensor((2560,), dtype="float32") = model_params[484]
            lv491_1: R.Tensor((2560,), dtype="float32") = model_params[485]
            lv677 = R.call_tir(cls.fused_layer_norm_cast1, (lv1788, lv490_2, lv491_1), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1791: R.Tensor((1, n, 2560), dtype="float16") = lv677
            lv678: R.Tensor((10240, 320), dtype="uint32") = model_params[492]
            lv679_1: R.Tensor((10240, 80), dtype="float16") = model_params[493]
            lv494_2: R.Tensor((10240,), dtype="float32") = model_params[494]
            lv125_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv678, lv679_1, lv1791, lv494_2), out_sinfo=R.Tensor((1, n, 10240), dtype="float16"))
            lv1797: R.Tensor((1, n, 10240), dtype="float16") = lv125_1
            lv682: R.Tensor((2560, 1280), dtype="uint32") = model_params[495]
            lv683: R.Tensor((2560, 320), dtype="float16") = model_params[496]
            lv497_1: R.Tensor((2560,), dtype="float32") = model_params[497]
            lv125_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2, (lv682, lv683, lv1797, lv497_1, lv124_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1804 = R.call_tir(cls.cast, (lv125_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
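            # ---- final decoder layer (kv_cache slots 62/63); same structure as above ----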
            lv498_1: R.Tensor((2560,), dtype="float32") = model_params[498]
            lv499: R.Tensor((2560,), dtype="float32") = model_params[499]
            lv686_1 = R.call_tir(cls.fused_layer_norm_cast1, (lv1804, lv498_1, lv499), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1807: R.Tensor((1, n, 2560), dtype="float16") = lv686_1
            lv687: R.Tensor((7680, 320), dtype="uint32") = model_params[502]
            lv688: R.Tensor((7680, 80), dtype="float16") = model_params[503]
            lv502_1: R.Tensor((7680,), dtype="float16") = model_params[504]
            lv126_1 = R.call_tir(cls.fused_fused_decode2_fused_NT_matmul_add, (lv687, lv688, lv1807, lv502_1), out_sinfo=R.Tensor((1, n, 7680), dtype="float16"))
            lv1811 = R.call_tir(cls.reshape2, (lv126_1,), out_sinfo=R.Tensor((1, n, 32, 240), dtype="float16"))
            lv1812 = R.call_tir(cls.split, (lv1811,), out_sinfo=[R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16"), R.Tensor((1, n, 32, 80), dtype="float16")])
            lv1813: R.Tensor((1, n, 32, 80), dtype="float16") = lv1812[0]
            lv1814 = R.call_tir(cls.rotary_embedding, (lv1813, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1815: R.Tensor((1, n, 32, 80), dtype="float16") = lv1812[1]
            lv1816 = R.call_tir(cls.rotary_embedding, (lv1815, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4]), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"), tir_vars=R.shape([m]))
            lv1817: R.Object = kv_cache[62]
            lv1818 = R.call_tir(cls.squeeze, (lv1816,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1819: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1817, lv1818, sinfo_args=(R.Object,))
            lv1820: R.Object = kv_cache[63]
            lv691: R.Tensor((1, n, 32, 80), dtype="float16") = lv1812[2]
            lv692 = R.call_tir(cls.fused_squeeze, (lv691,), out_sinfo=R.Tensor((n, 32, 80), dtype="float16"))
            lv1823: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1820, lv692, sinfo_args=(R.Object,))
            lv1824: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1819, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1825: R.Tensor((m, 32, 80), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1823, R.shape([m, 32, 80]), sinfo_args=(R.Tensor((m, 32, 80), dtype="float16"),))
            lv1826 = R.call_tir(cls.reshape3, (lv1824,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1827 = R.call_tir(cls.reshape3, (lv1825,), out_sinfo=R.Tensor((1, m, 32, 80), dtype="float16"))
            lv1828 = R.call_tir(cls.transpose5, (lv1814,), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1829 = R.call_tir(cls.transpose5, (lv1826,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv1830 = R.call_tir(cls.transpose5, (lv1827,), out_sinfo=R.Tensor((1, 32, m, 80), dtype="float16"))
            lv693 = R.call_tir(cls.fused_NT_matmul1_divide_maximum_minimum_cast2, (lv1828, lv1829, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
            lv694 = R.call_tir(cls.fused_softmax_cast3, (lv693,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
            lv1839 = R.call_tir(cls.matmul8, (lv694, lv1830), out_sinfo=R.Tensor((1, 32, n, 80), dtype="float16"))
            lv1840 = R.call_tir(cls.transpose6, (lv1839,), out_sinfo=R.Tensor((1, n, 32, 80), dtype="float16"))
            lv1841 = R.call_tir(cls.reshape4, (lv1840,), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv695_1: R.Tensor((2560, 320), dtype="uint32") = model_params[505]
            lv696: R.Tensor((2560, 80), dtype="float16") = model_params[506]
            lv505_1: R.Tensor((2560,), dtype="float16") = model_params[507]
            lv126_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1_add2, (lv695_1, lv696, lv1841, lv505_1, lv125_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1846 = R.call_tir(cls.cast, (lv126_2,), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv506_2: R.Tensor((2560,), dtype="float32") = model_params[500]
            lv507_2: R.Tensor((2560,), dtype="float32") = model_params[501]
            lv699 = R.call_tir(cls.fused_layer_norm_cast1, (lv1846, lv506_2, lv507_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float16"))
            lv1849: R.Tensor((1, n, 2560), dtype="float16") = lv699
            lv700: R.Tensor((10240, 320), dtype="uint32") = model_params[508]
            lv701: R.Tensor((10240, 80), dtype="float16") = model_params[509]
            lv510_1: R.Tensor((10240,), dtype="float32") = model_params[510]
            lv127_1 = R.call_tir(cls.fused_fused_decode4_fused_NT_matmul3_add3_gelu_cast4, (lv700, lv701, lv1849, lv510_1), out_sinfo=R.Tensor((1, n, 10240), dtype="float16"))
            lv1855: R.Tensor((1, n, 10240), dtype="float16") = lv127_1
            lv704: R.Tensor((2560, 1280), dtype="uint32") = model_params[511]
            lv705_1: R.Tensor((2560, 320), dtype="float16") = model_params[512]
            lv513: R.Tensor((2560,), dtype="float32") = model_params[513]
            lv127_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add4_cast1_cast5_add2_cast, (lv704, lv705_1, lv1855, lv513, lv126_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv514: R.Tensor((2560,), dtype="float32") = model_params[514]
            lv515_2: R.Tensor((2560,), dtype="float32") = model_params[515]
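            # Epilogue: final layer norm in fp32, then slice keeps a (1, 1, 2560) window
            # (presumably the last position's hidden state) for next-token prediction,
            # and the quantized LM head projects 2560 -> 50432 vocabulary logits.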
            lv1863 = R.call_tir(cls.layer_norm, (lv127_2, lv514, lv515_2), out_sinfo=R.Tensor((1, n, 2560), dtype="float32"))
            lv1864 = R.call_tir(cls.slice, (lv1863,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
            lv1865 = R.call_tir(cls.cast6, (lv1864,), out_sinfo=R.Tensor((1, 1, 2560), dtype="float32"))
            lv708: R.Tensor((50432, 320), dtype="uint32") = model_params[516]
            lv709_1: R.Tensor((50432, 80), dtype="float32") = model_params[517]
            lv1_2 = R.call_tir(cls.fused_fused_decode6_NT_matmul5, (lv708, lv709_1, lv1865), out_sinfo=R.Tensor((1, 1, 50432), dtype="float32"))
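            # The function returns the logits together with all 64 updated KV-cache
            # handles (32 layers, one K and one V cache each).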
            gv: R.Tuple(R.Tensor((1, 1, 50432), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv1_2, (lv21, lv25, lv79, lv83, lv137, lv141, lv195, lv199, lv253, lv257, lv311, lv315, lv369, lv373, lv427, lv431, lv485, lv489, lv543, lv547, lv601, lv605, lv659, lv663, lv717, lv721, lv775, lv779, lv833, lv837, lv891, lv895, lv949, lv953, lv1007, lv1011, lv1065, lv1069, lv1123, lv1127, lv1181, lv1185, lv1239, lv1243, lv1297, lv1301, lv1355, lv1359, lv1413, lv1417, lv1471, lv1475, lv1529, lv1533, lv1587, lv1591, lv1645, lv1649, lv1703, lv1707, lv1761, lv1765, lv1819, lv1823)
            R.output(gv)
        return gv
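    # Sampling helper: scales the logits by a scalar temperature (divide1), then
    # normalizes them with a plain softmax (softmax1).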
    @R.function
    def softmax_with_temperature(logits: R.Tensor((1, 1, 50432), dtype="float32"), temperature: R.Tensor((), dtype="float32")) -> R.Tensor((1, 1, 50432), dtype="float32"):
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        cls = Module
        with R.dataflow():
            lv3799 = R.call_tir(cls.divide1, (logits, temperature), out_sinfo=R.Tensor((1, 1, 50432), dtype="float32"))
            lv3800 = R.call_tir(cls.softmax1, (lv3799,), out_sinfo=R.Tensor((1, 1, 50432), dtype="float32"))
            gv3: R.Tensor((1, 1, 50432), dtype="float32") = lv3800
            R.output(gv3)
        return gv3
# Metadata omitted. Use show_meta=True in script() method to show it.
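
# Usage sketch (not part of the original dump; the entry-point name, argument list
# and target are assumptions, since this excerpt does not show the forward
# function's signature and the constants referenced via metadata[...] are omitted):
#
#   ex = tvm.relax.build(Module, target="cuda")
#   vm = tvm.relax.VirtualMachine(ex, tvm.cuda())
#   logits, kv_cache = vm["prefill"](input_ids, seq_len_shape, kv_cache, model_params)
#   probs = vm["softmax_with_temperature"](logits, temperature)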