Created
August 16, 2024 21:04
-
-
Save mreso/61d5d384854d1a6ef67c331803116d41 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_inference']
# Module preamble: imports used by the generated wrapper code, followed by
# short aliases for op namespaces and guard/allocation helpers, and the
# AsyncCompile instance that every kernel definition below is registered with.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Op namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard and tensor-construction helpers exposed by torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Shared compile handle; the kernels in this module call async_compile.triton(...).
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_mreso/cp/ccpansfhuhbzfmq6fiemul7defvnmxsdsnwsbejgv3aqviivkozv.py
# Source Nodes: [add, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add => add
# h => embedding
# mean => mean
# mul => mul
# pow_1 => pow_1
# rsqrt => rsqrt
# x_fp32 => convert_element_type
# x_normed => convert_element_type_1
# y => mul_1
#
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
#   in_ptr0  : a single index, read from offset 0; a negative value is wrapped
#              by adding 128256 and the result is asserted to be in [0, 128256).
#   in_ptr1  : table gathered as in_ptr1[r + 4096 * index].
#   in_ptr2  : per-element multiplier in_ptr2[r].
#   out_ptr1 : output, written at out_ptr1[r].
# Pass 1 accumulates the float32 sum of squares of the gathered row.
# Pass 2 re-reads the row, multiplies it by rsqrt(sum / 4096.0 + 1e-05) and by
# in_ptr2[r], and stores the product.
# NOTE(review): per the Source Nodes above this looks like an RMS-norm applied
# to an embedding lookup -- confirm against the originating model code.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp2 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 128256), "index out of bounds: 0 <= tmp5 < 128256")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp8 * tmp8
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tl.where(rmask, tmp12, _tmp11)
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tmp13 = tl.load(in_ptr0 + (0))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp16 = tmp14 + tmp15
        tmp17 = tmp14 < 0
        tmp18 = tl.where(tmp17, tmp16, tmp14)
        tl.device_assert((0 <= tmp18) & (tmp18 < 128256), "index out of bounds: 0 <= tmp18 < 128256")
        tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tmp20.to(tl.float32)
        tmp22 = 4096.0
        tmp23 = tmp11 / tmp22
        tmp24 = 1e-05
        tmp25 = tmp23 + tmp24
        tmp26 = libdevice.rsqrt(tmp25)
        tmp27 = tmp21 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tmp30 = tmp28 * tmp29
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask)
''', device_str='cuda')
# Triton itself plus the launch-grid helpers and raw CUDA stream accessor
# imported for use by the generated wrapper code.
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_mreso/zh/czhm7z6vmgstxp34zha46q2znijjccrf4zpqaatdp4lazkbs7cyh.py
# Source Nodes: [max_1], Original ATen: [aten.max]
# max_1 => max_1
#
# Pointwise kernel with xnumel = 1: loads the single int64 value at
# in_ptr0[0], broadcasts it across the block, and stores it to out_ptr0[0].
triton_poi_fused_max_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(1,), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_max_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/dm/cdmuegqtdrjgliya52h5xlbmwaerh3jvmz3o234f62w2wpha4dsv.py
# Source Nodes: [c], Original ATen: [aten.mm]
# c => mm
#
# Template matmul kernel with hard-coded sizes M = 1, N = 4096, K = 4096 and
# tiles BLOCK_M x BLOCK_N x BLOCK_K = 16 x 64 x 128.
#   arg_A    : left operand, addressed with strides (stride_am, stride_ak) = (4096, 1).
#   arg_B    : right operand, addressed with strides (stride_bk, stride_bn) = (1, 4096).
#   out_ptr0 : output, written at offset idx_n under the mask
#              (idx_m < M) & (idx_n < N).
# Row/column indices are taken modulo M / N before loading, so out-of-range
# tile lanes read valid addresses; only the final store is masked.
# Accumulation is in float32 (ACC_TYPE) with TF32 disabled.
triton_tem_fused_mm_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_2', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 128
    A = arg_A
    B = arg_B
    M = 1
    N = 4096
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (4096*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
import torch._inductor.kernel.mm_common
# Template parameters; the values mirror the constexpr settings hard-coded in
# triton_tem_fused_mm_2 (BLOCK_K = 128 variant).
meta0 = {'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float32', 'B_PROLOGUE_CAST_TYPE': None, 'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 128}
# kernel path: /tmp/torchinductor_mreso/lz/clz2e2d6xj6n3oho4nys4dfgydzhykoegeiet3dzqpcasieoh5up.py
# Source Nodes: [output, setitem, setitem_1], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
# output => _scaled_dot_product_efficient_attention
# setitem => index_put
# setitem_1 => index_put_1
#
# Pointwise kernel over xnumel = 4096 elements, with x0 = xindex % 128 and
# x1 = xindex // 128.
#   in_ptr0  : a single index read from offset 0; a negative value is wrapped
#              by adding 2048 and the result is asserted to be in [0, 2048).
#   in_ptr1  : bf16 input read in (even, odd) pairs at 128 * (x1 // 4).
#   in_ptr2  : fp32 table read in (even, odd) pairs from row `index` (stride 128).
#   in_ptr3  : bf16 input copied through unchanged to out_ptr1.
#   in_ptr4  : bf16 input read in (even, odd) pairs at 128 * x1.
# For each pair (a0, a1) and table pair (c0, c1) the even lane gets
# a0*c0 - a1*c1 and the odd lane gets a1*c0 + a0*c1.
#   out_ptr0 : rotated in_ptr1 values, written at x0 + 128*index + 262144*x1.
#   out_ptr1 : in_ptr3 values, written at the same offset.
#   out_ptr2 : rotated in_ptr4 values, written at x2.
# out_ptr0 and out_ptr1 are listed as mutated arguments in inductor_meta.
# NOTE(review): this looks like a rotary-embedding rotation fused with a
# key/value cache update -- confirm against the originating model code.
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4096],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*fp32', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3', 'mutated_arg_names': ['out_ptr0', 'out_ptr1'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x2 = xindex
    x0 = xindex % 128
    x1 = (xindex // 128)
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp41 = tl.load(in_ptr3 + (x0 + (128*(x1 // 4))), None).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 2048, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 2048), "index out of bounds: 0 <= tmp5 < 2048")
    tmp7 = x2 % 2
    tmp8 = tl.full([1], 0, tl.int64)
    tmp9 = tmp7 >= tmp8
    tmp10 = tl.full([1], 1, tl.int64)
    tmp11 = tmp7 < tmp10
    tmp12 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*(x1 // 4))), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp12.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp5, [XBLOCK])) & (tl.broadcast_to(tmp5, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp5, [XBLOCK]) < 2048")
    tmp15 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp5)), tmp11, eviction_policy='evict_last', other=0.0)
    tmp16 = tmp13 * tmp15
    tmp17 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*(x1 // 4))), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp5)), tmp11, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 * tmp19
    tmp21 = tmp16 - tmp20
    tmp22 = tl.full(tmp21.shape, 0.0, tmp21.dtype)
    tmp23 = tl.where(tmp11, tmp21, tmp22)
    tmp24 = tmp7 >= tmp10
    tmp25 = tl.full([1], 2, tl.int64)
    tmp26 = tmp7 < tmp25
    tmp27 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*(x1 // 4))), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp28 = tmp27.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp5, [XBLOCK])) & (tl.broadcast_to(tmp5, [XBLOCK]) < 2048)) | ~(tmp24), "index out of bounds: 0 <= tl.broadcast_to(tmp5, [XBLOCK]) < 2048")
    tmp30 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp5)), tmp24, eviction_policy='evict_last', other=0.0)
    tmp31 = tmp28 * tmp30
    tmp32 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*(x1 // 4))), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp33 = tmp32.to(tl.float32)
    tmp34 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp5)), tmp24, eviction_policy='evict_last', other=0.0)
    tmp35 = tmp33 * tmp34
    tmp36 = tmp31 + tmp35
    tmp37 = tl.full(tmp36.shape, 0.0, tmp36.dtype)
    tmp38 = tl.where(tmp24, tmp36, tmp37)
    tmp39 = tl.where(tmp11, tmp23, tmp38)
    tmp40 = tmp39.to(tl.float32)
    tmp42 = tl.load(in_ptr4 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp43 = tmp42.to(tl.float32)
    tmp44 = tmp43 * tmp15
    tmp45 = tl.load(in_ptr4 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp46 = tmp45.to(tl.float32)
    tmp47 = tmp46 * tmp19
    tmp48 = tmp44 - tmp47
    tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
    tmp50 = tl.where(tmp11, tmp48, tmp49)
    tmp51 = tl.load(in_ptr4 + (1 + (2*(x0 // 2)) + (128*x1)), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp52 = tmp51.to(tl.float32)
    tmp53 = tmp52 * tmp30
    tmp54 = tl.load(in_ptr4 + ((2*(x0 // 2)) + (128*x1)), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp55 = tmp54.to(tl.float32)
    tmp56 = tmp55 * tmp34
    tmp57 = tmp53 + tmp56
    tmp58 = tl.full(tmp57.shape, 0.0, tmp57.dtype)
    tmp59 = tl.where(tmp24, tmp57, tmp58)
    tmp60 = tl.where(tmp11, tmp50, tmp59)
    tmp61 = tmp60.to(tl.float32)
    tl.store(out_ptr0 + (x0 + (128*tmp5) + (262144*x1)), tmp40, None)
    tl.store(out_ptr1 + (x0 + (128*tmp5) + (262144*x1)), tmp41, None)
    tl.store(out_ptr2 + (x2), tmp61, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/n6/cn6h3fga4dekm5cvy4s4fhrazz4zcn3kegoonkhz2piybh75ca7i.py
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention]
# output => _scaled_dot_product_efficient_attention
#
# Pointwise kernel over xnumel = 65536 elements, with x0 = xindex % 2048.
#   in_ptr0  : a single index read from offset 0; a negative value is wrapped
#              by adding 2048 and the result is asserted to be in [0, 2048).
#   in_ptr1  : boolean table read at in_ptr1[x0 + 2048 * index].
#   out_ptr0 : written at out_ptr0[xindex] with -inf where the loaded boolean
#              equals 0 and 0.0 otherwise.
# NOTE(review): presumably this turns one row of a boolean attention mask into
# an additive bias -- confirm against the caller.
triton_poi_fused__scaled_dot_product_efficient_attention_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[65536],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*i1', 2: '*bf16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 65536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 2048
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = tl.full([XBLOCK], 2048, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 2048), "index out of bounds: 0 <= tmp5 < 2048")
    tmp7 = tl.load(in_ptr1 + (x0 + (2048*tmp5)), None).to(tl.int1)
    tmp8 = tmp7 == 0
    tmp9 = float("-inf")
    tmp10 = 0.0
    tmp11 = tl.where(tmp8, tmp9, tmp10)
    tl.store(out_ptr0 + (x2), tmp11, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/hi/chi7k4nldga4ozz4jralolpcemic42rouob3i2giv2okoaj746yc.py
# Source Nodes: [add_5, h, h_1, mean_1, mul_10, mul_11, pow_2, rsqrt_1, x_fp32_1, x_normed_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_5 => add_4
# h => embedding
# h_1 => add_3
# mean_1 => mean_1
# mul_10 => mul_10
# mul_11 => mul_11
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# x_fp32_1 => convert_element_type_14
# x_normed_1 => convert_element_type_15
#
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
#   in_ptr0  : vector added element-wise to the gathered row, in_ptr0[r].
#   in_ptr1  : a single index, read from offset 0; a negative value is wrapped
#              by adding 128256 and the result is asserted to be in [0, 128256).
#   in_ptr2  : table gathered as in_ptr2[r + 4096 * index].
#   in_ptr3  : per-element multiplier in_ptr3[r].
#   out_ptr1 : output, written at out_ptr1[r].
# Pass 1 accumulates the float32 sum of squares of (in_ptr0[r] + gathered row).
# Pass 2 recomputes that sum per element, multiplies it by
# rsqrt(sum_of_squares / 4096.0 + 1e-05) and by in_ptr3[r], and stores it.
# NOTE(review): per the Source Nodes above this looks like a residual add
# followed by an RMS-norm -- confirm against the originating model code.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*i32', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp4 = tmp2 + tmp3
        tmp5 = tmp2 < 0
        tmp6 = tl.where(tmp5, tmp4, tmp2)
        tl.device_assert((0 <= tmp6) & (tmp6 < 128256), "index out of bounds: 0 <= tmp6 < 128256")
        tmp8 = tl.load(in_ptr2 + (r0 + (4096*tmp6)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp0 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tmp16 = tl.load(in_ptr1 + (0))
    tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp15 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp18 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp19 = tmp17 + tmp18
        tmp20 = tmp17 < 0
        tmp21 = tl.where(tmp20, tmp19, tmp17)
        tl.device_assert((0 <= tmp21) & (tmp21 < 128256), "index out of bounds: 0 <= tmp21 < 128256")
        tmp23 = tl.load(in_ptr2 + (r0 + (4096*tmp21)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp24 = tmp15 + tmp23
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 4096.0
        tmp27 = tmp13 / tmp26
        tmp28 = 1e-05
        tmp29 = tmp27 + tmp28
        tmp30 = libdevice.rsqrt(tmp29)
        tmp31 = tmp25 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp34 = tmp32 * tmp33
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/ur/curh37b7uwqy7x2z7p6suloz7vk3aiutkohhxggzhrcmv6ypkm2u.py
# Source Nodes: [c_4], Original ATen: [aten.mm]
# c_4 => mm_4
#
# Template matmul kernel with hard-coded sizes M = 1, N = 14336, K = 4096 and
# tiles BLOCK_M x BLOCK_N x BLOCK_K = 16 x 64 x 32.
#   arg_A    : left operand, addressed with strides (stride_am, stride_ak) = (4096, 1).
#   arg_B    : right operand, addressed with strides (stride_bk, stride_bn) = (1, 4096).
#   out_ptr0 : output, written at offset idx_n under the mask
#              (idx_m < M) & (idx_n < N).
# Row/column indices are taken modulo M / N before loading, so out-of-range
# tile lanes read valid addresses; only the final store is masked.
# Accumulation is in float32 (ACC_TYPE) with TF32 disabled.
triton_tem_fused_mm_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=4,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_6', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 1
    N = 14336
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (14336*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
# Template parameters; the values mirror the constexpr settings hard-coded in
# triton_tem_fused_mm_6 (BLOCK_K = 32 variant).
meta1 = {'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float32', 'B_PROLOGUE_CAST_TYPE': None, 'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 32}
# kernel path: /tmp/torchinductor_mreso/gg/cggdihc3hz6r76zojlkqkx4mohi7rtngag436gzcigqfbq6qr63m.py
# Source Nodes: [c_5, mul_12, silu], Original ATen: [aten.mm, aten.mul, aten.silu]
# c_5 => mm_5
# mul_12 => mul_13
# silu => convert_element_type_18, convert_element_type_19, mul_12, sigmoid
#
# Matmul template with a fused SiLU-gate epilogue.
#   * Main loop: acc = A @ B accumulated in float32, with M = 1, K = 4096,
#     N = 14336, 16x64 output tiles and K stepped in chunks of 32.
#     A is addressed with strides (4096, 1) and B with strides (1, 4096).
#   * Epilogue: loads in_ptr2 at xindex = idx_n + 14336 * idx_m, computes
#     v * sigmoid(v) in float32, multiplies by acc and stores to out_ptr1.
#   * Stores/loads in the epilogue are masked with (idx_m < M) & (idx_n < N),
#     so only row 0 of each 16-row tile is read/written.
# All four pointer arguments are declared '*bf16' in triton_meta['signature'].
triton_tem_fused_mm_mul_silu_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_mul_silu_7', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, in_ptr2, out_ptr1):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 1
    N = 14336
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (14336*idx_m)
    tmp0 = tl.load(in_ptr2 + (tl.broadcast_to(xindex, acc.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = tmp4 * acc
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp5, mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/6m/c6m3krqefnziqfhazzhlqeozbuyylbyb2cp5ijykpbol2gg2v3qr.py
# Source Nodes: [c_6], Original ATen: [aten.mm]
# c_6 => mm_6
#
# Plain matmul template: acc = A @ B accumulated in float32, with M = 1,
# K = 14336, N = 4096, 16x64 output tiles and K stepped in chunks of 128.
# A is addressed with strides (14336, 1) and B with strides (1, 14336).
# The result is stored to out_ptr0 at offset idx_n only; xindex is computed
# but unused, and the store mask (idx_m < M) & (idx_n < N) limits writes to
# row 0, for which the row term of xindex is zero.
# All three pointer arguments are declared '*bf16' in triton_meta['signature'].
triton_tem_fused_mm_8 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_8', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 128
    A = arg_A
    B = arg_B
    M = 1
    N = 4096
    K = 14336
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 14336
    stride_ak = 1
    stride_bk = 1
    stride_bn = 14336
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (4096*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/fg/cfghtgfe4n4ny6pchk5ahx43xbhxnqllrs5nwoae7yaniw5lxpyv.py
# Source Nodes: [add_7, h, h_1, mean_2, mul_13, out, pow_3, rsqrt_2, x_fp32_2, x_normed_2, y_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_7 => add_6
# h => embedding
# h_1 => add_3
# mean_2 => mean_2
# mul_13 => mul_14
# out => add_5
# pow_3 => pow_3
# rsqrt_2 => rsqrt_2
# x_fp32_2 => convert_element_type_24
# x_normed_2 => convert_element_type_25
# y_1 => mul_15
#
# Two-pass reduction over a single row (xnumel = 1, rnumel = 4096) that fuses
# an embedding-row gather, two adds and a mean-of-squares normalization.
#   * in_ptr1[0] is an i32 row index into in_ptr2 (row stride 4096). Negative
#     values are wrapped by adding 128256 and the result is bounds-checked
#     with tl.device_assert.
#   * x = in_ptr0[r] + in_ptr2[r + 4096 * idx] + in_ptr3[r], computed in float32.
#   * Pass 1 accumulates sum(x * x) over r.
#   * Pass 2 recomputes x and stores
#     x * rsqrt(sum / 4096.0 + 1e-05) * in_ptr4[r] to out_ptr1[r].
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*i32', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp4 = tmp2 + tmp3
        tmp5 = tmp2 < 0
        tmp6 = tl.where(tmp5, tmp4, tmp2)
        tl.device_assert((0 <= tmp6) & (tmp6 < 128256), "index out of bounds: 0 <= tmp6 < 128256")
        tmp8 = tl.load(in_ptr2 + (r0 + (4096*tmp6)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp0 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
        tmp16 = _tmp15 + tmp14
        _tmp15 = tl.where(rmask, tmp16, _tmp15)
    tmp15 = tl.sum(_tmp15, 1)[:, None]
    tmp18 = tl.load(in_ptr1 + (0))
    tmp19 = tl.broadcast_to(tmp18, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp17 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp21 = tmp19 + tmp20
        tmp22 = tmp19 < 0
        tmp23 = tl.where(tmp22, tmp21, tmp19)
        tl.device_assert((0 <= tmp23) & (tmp23 < 128256), "index out of bounds: 0 <= tmp23 < 128256")
        tmp25 = tl.load(in_ptr2 + (r0 + (4096*tmp23)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tmp17 + tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tmp30 = 4096.0
        tmp31 = tmp15 / tmp30
        tmp32 = 1e-05
        tmp33 = tmp31 + tmp32
        tmp34 = libdevice.rsqrt(tmp33)
        tmp35 = tmp29 * tmp34
        tmp36 = tmp35.to(tl.float32)
        tmp38 = tmp36 * tmp37
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/zo/czoyobha6z3774km535v7px7sq2dkfatukjcv4fgmxonhuumqv67.py
# Source Nodes: [add_12, h, h_1, h_2, mean_3, mul_23, mul_24, out, pow_4, rsqrt_3, x_fp32_3, x_normed_3], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_12 => add_10
# h => embedding
# h_1 => add_3
# h_2 => add_9
# mean_3 => mean_3
# mul_23 => mul_24
# mul_24 => mul_25
# out => add_5
# pow_4 => pow_4
# rsqrt_3 => rsqrt_3
# x_fp32_3 => convert_element_type_38
# x_normed_3 => convert_element_type_39
#
# Two-pass reduction over a single row (xnumel = 1, rnumel = 4096) that
# updates in_out_ptr0 in place and writes a normalized copy to out_ptr1.
#   * in_ptr1[0] is an i32 row index into in_ptr2 (row stride 4096). Negative
#     values are wrapped by adding 128256 and the result is bounds-checked
#     with tl.device_assert.
#   * Pass 1: x = in_out_ptr0[r] + (in_ptr0[r] + in_ptr2[r + 4096 * idx]
#     + in_ptr3[r]); x is stored back to in_out_ptr0[r] and sum(x * x) is
#     accumulated in float32.
#   * Pass 2: re-reads x from in_out_ptr0 and stores
#     x * rsqrt(sum / 4096.0 + 1e-05) * in_ptr4[r] to out_ptr1[r].
# inductor_meta lists in_out_ptr0 as the only mutated argument.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*i32', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    _tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp5 = tmp3 + tmp4
        tmp6 = tmp3 < 0
        tmp7 = tl.where(tmp6, tmp5, tmp3)
        tl.device_assert((0 <= tmp7) & (tmp7 < 128256), "index out of bounds: 0 <= tmp7 < 128256")
        tmp9 = tl.load(in_ptr2 + (r0 + (4096*tmp7)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp1 + tmp9
        tmp12 = tmp10 + tmp11
        tmp13 = tmp0 + tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp14 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
        tmp18 = _tmp17 + tmp16
        _tmp17 = tl.where(rmask, tmp18, _tmp17)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask)
    tmp17 = tl.sum(_tmp17, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = 4096.0
        tmp22 = tmp17 / tmp21
        tmp23 = 1e-05
        tmp24 = tmp22 + tmp23
        tmp25 = libdevice.rsqrt(tmp24)
        tmp26 = tmp20 * tmp25
        tmp27 = tmp26.to(tl.float32)
        tmp29 = tmp27 * tmp28
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/t7/ct7gses5bjf3ss3snvk3ainp7pao5q3rukpyjc5fe32c5zvf4q4d.py
# Source Nodes: [add_14, mean_4, mul_26, out_1, pow_5, rsqrt_4, x_fp32_4, x_normed_4, y_2], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_14 => add_12
# mean_4 => mean_4
# mul_26 => mul_28
# out_1 => add_11
# pow_5 => pow_5
# rsqrt_4 => rsqrt_4
# x_fp32_4 => convert_element_type_48
# x_normed_4 => convert_element_type_49
# y_2 => mul_29
#
# Two-pass reduction over a single row (xnumel = 1, rnumel = 4096) fusing one
# add with a mean-of-squares normalization.
#   * x = in_ptr0[r] + in_ptr1[r], computed in float32.
#   * Pass 1 accumulates sum(x * x) over r.
#   * Pass 2 recomputes x and stores
#     x * rsqrt(sum / 4096.0 + 1e-05) * in_ptr2[r] to out_ptr1[r].
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3 * tmp3
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
        tmp7 = _tmp6 + tmp5
        _tmp6 = tl.where(rmask, tmp7, _tmp6)
    tmp6 = tl.sum(_tmp6, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = 4096.0
        tmp13 = tmp6 / tmp12
        tmp14 = 1e-05
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp18 = tmp17.to(tl.float32)
        tmp20 = tmp18 * tmp19
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/b6/cb67w5ewxshtrgdm622b4yogv47npqcmskcea3x3froy22yqknx4.py
# Source Nodes: [add_19, h_3, mean_5, mul_36, mul_37, out_1, pow_6, rsqrt_5, x_fp32_5, x_normed_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_19 => add_16
# h_3 => add_15
# mean_5 => mean_5
# mul_36 => mul_38
# mul_37 => mul_39
# out_1 => add_11
# pow_6 => pow_6
# rsqrt_5 => rsqrt_5
# x_fp32_5 => convert_element_type_62
# x_normed_5 => convert_element_type_63
#
# Two-pass reduction over a single row (xnumel = 1, rnumel = 4096) fusing two
# adds with a mean-of-squares normalization.
#   * x = in_ptr0[r] + (in_ptr1[r] + in_ptr2[r]), computed in float32.
#   * Pass 1 accumulates sum(x * x) over r.
#   * Pass 2 recomputes x and stores
#     x * rsqrt(sum / 4096.0 + 1e-05) * in_ptr3[r] to out_ptr1[r].
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp0 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tmp11 + tmp12
        tmp14 = tmp10 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp8 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/rb/crbbcpdu63ofr6njzrpxbpqsry65k5jtwokgjwbmsrhzlwsiwy53.py
# Source Nodes: [add_21, h_3, mean_6, mul_39, out_1, out_2, pow_7, rsqrt_6, x_fp32_6, x_normed_6, y_3], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_21 => add_18
# h_3 => add_15
# mean_6 => mean_6
# mul_39 => mul_42
# out_1 => add_11
# out_2 => add_17
# pow_7 => pow_7
# rsqrt_6 => rsqrt_6
# x_fp32_6 => convert_element_type_72
# x_normed_6 => convert_element_type_73
# y_3 => mul_43
#
# Two-pass reduction over a single row (xnumel = 1, rnumel = 4096) fusing
# three adds with a mean-of-squares normalization.
#   * x = in_ptr0[r] + (in_ptr1[r] + in_ptr2[r]) + in_ptr3[r], in float32.
#   * Pass 1 accumulates sum(x * x) over r.
#   * Pass 2 recomputes x and stores
#     x * rsqrt(sum / 4096.0 + 1e-05) * in_ptr4[r] to out_ptr1[r].
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp0 + tmp3
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp14 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp13 + tmp14
        tmp16 = tmp12 + tmp15
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = 4096.0
        tmp21 = tmp10 / tmp20
        tmp22 = 1e-05
        tmp23 = tmp21 + tmp22
        tmp24 = libdevice.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp26 = tmp25.to(tl.float32)
        tmp28 = tmp26 * tmp27
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/tu/ctukwfouesobny732yngexqv32bgnsgz5l6q24aqfdzbci2ogg5b.py | |
# Source Nodes: [add_26, h_3, h_4, mean_7, mul_49, mul_50, out_1, out_2, pow_8, rsqrt_7, x_fp32_7, x_normed_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_26 => add_22
# h_3 => add_15
# h_4 => add_21
# mean_7 => mean_7
# mul_49 => mul_52
# mul_50 => mul_53
# out_1 => add_11
# out_2 => add_17
# pow_8 => pow_8
# rsqrt_7 => rsqrt_7
# x_fp32_7 => convert_element_type_86
# x_normed_7 => convert_element_type_87
#
# Two-pass looped reduction over rnumel == 4096 elements (xnumel == 1):
#   pass 1: elementwise sum of in_ptr0 + in_ptr1 + in_out_ptr0 + in_ptr2 + in_ptr3,
#           written back into in_out_ptr0 while the sum of squares is
#           accumulated in fp32;
#   pass 2: reloads in_out_ptr0, scales it by rsqrt(sum_sq / 4096 + 1e-05) and
#           by in_ptr4, and stores the result to out_ptr1.
# All seven pointer arguments are declared '*bf16' in the signature below.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tmp2 + tmp3
        tmp5 = tmp1 + tmp4
        tmp7 = tmp5 + tmp6
        tmp8 = tmp0 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/nh/cnhnggduh5o3eeov7ogifj7cb5tdo3j5ndcivpiqhwnkp3u25u5n.py
# Source Nodes: [add_82, h_11, h_12, mean_23, mul_153, mul_154, out_10, out_9, pow_24, rsqrt_23, x_fp32_23, x_normed_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_82 => add_70
# h_11 => add_63
# h_12 => add_69
# mean_23 => mean_23
# mul_153 => mul_164
# mul_154 => mul_165
# out_10 => add_65
# out_9 => add_59
# pow_24 => pow_24
# rsqrt_23 => rsqrt_23
# x_fp32_23 => convert_element_type_278
# x_normed_23 => convert_element_type_279
#
# Two-pass looped reduction over rnumel == 4096 elements (xnumel == 1):
#   pass 1: computes in_out_ptr0 + (in_ptr0 + (in_ptr1 + in_ptr2) + in_ptr3)
#           elementwise, stores it back into in_out_ptr0 and accumulates the
#           sum of squares in fp32;
#   pass 2: reloads in_out_ptr0, scales it by rsqrt(sum_sq / 4096 + 1e-05) and
#           by in_ptr4, and stores the result to out_ptr1.
# Unlike the _14 kernel, in_out_ptr0 is the outermost addend of the sum.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tmp2 + tmp3
        tmp5 = tmp1 + tmp4
        tmp7 = tmp5 + tmp6
        tmp8 = tmp0 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/dk/cdki6rbltj45344wgybatqntlalewzncir7p36oko7sux4egpdcm.py
# Source Nodes: [c_224, logits_1], Original ATen: [aten.div, aten.mm]
# c_224 => mm_224
# logits_1 => div
#
# Tiled matmul template with a fused epilogue: multiplies A (M=1, K=4096,
# strides 4096/1) by B (K=4096, N=128256, strides 1/4096) in
# BLOCK_M x BLOCK_N x BLOCK_K = 16 x 128 x 128 tiles with an fp32 accumulator,
# then scales the accumulator by 1.6666666666666667 and stores fp32 results to
# out_ptr1 at offset idx_n + 128256 * idx_m.
# NOTE(review): the constant multiply presumably implements the fused
# aten.div (logits_1) as multiplication by a reciprocal; confirm against the
# model source.
triton_tem_fused_div_mm_16 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=4,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_div_mm_16', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr1):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 128
    BLOCK_K : tl.constexpr = 128
    A = arg_A
    B = arg_B
    M = 1
    N = 128256
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (128256*idx_m)
    tmp0 = acc.to(tl.float32)
    tmp1 = 1.6666666666666667
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp2, mask)
''', device_str='cuda')
# Template constants as a plain dict; the values match the tl.constexpr
# definitions at the top of triton_tem_fused_div_mm_16.
meta2 = {'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float32', 'B_PROLOGUE_CAST_TYPE': None, 'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128}
# kernel path: /tmp/torchinductor_mreso/lw/clwphkgp36abizew5ghfdhpfw7vyfjuel6fa5dordkziwvs7k47u.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax
#
# Row-wise masked max: reads in_ptr0 as 16 rows of 8016 fp32 values, replaces
# every value smaller than the scalar in_ptr1[299] with -inf, and writes the
# per-row maximum to out_ptr0[x0] (16 outputs).
# NOTE(review): in_ptr1[299] looks like a top-k cutoff value -- confirm
# against the caller.
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[16, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 8016
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp1 = tl.load(in_ptr1 + (299))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp7 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (8016*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp0 < tmp2
        tmp4 = float("-inf")
        tmp5 = tl.where(tmp3, tmp4, tmp0)
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = triton_helpers.maximum(_tmp7, tmp6)
        _tmp7 = tl.where(rmask & xmask, tmp8, _tmp7)
    tmp7 = triton_helpers.max2(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp7, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/73/c73m5zklquxdkmoswttbn6lvjao4jtaanllldybhblahbrxtvlnu.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax
#
# Persistent reduction: loads 16 fp32 values from in_ptr0 and writes their
# maximum to out_ptr0[0].
# NOTE(review): presumably the second stage of a split max whose 16 partial
# results come from triton_red_fused__softmax_lt_scalar_tensor_where_17 --
# confirm in call().
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[1, 16],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 16
    RBLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), None)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = triton_helpers.max2(tmp1, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/p7/cp7u644vw4j3c6njbq2l3pgixqa34e47y4v7jqlflpndagvldspp.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => exp, sub_64, sum_1
#
# Row-wise masked exp-sum: reads in_ptr0 as 16 rows of 8016 fp32 values,
# replaces every value smaller than the scalar in_ptr1[299] with -inf,
# subtracts the scalar in_ptr2[0], exponentiates, and writes the per-row sum
# to out_ptr0[x0] (16 outputs).
# NOTE(review): in_ptr2[0] is presumably the global max produced by the
# amax kernels (softmax stabilisation) -- confirm in call().
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[16, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 8016
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp1 = tl.load(in_ptr1 + (299))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp6 = tl.load(in_ptr2 + (0))
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (8016*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp0 < tmp2
        tmp4 = float("-inf")
        tmp5 = tl.where(tmp3, tmp4, tmp0)
        tmp8 = tmp5 - tmp7
        tmp9 = tl_math.exp(tmp8)
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tl.where(rmask & xmask, tmp12, _tmp11)
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp11, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/df/cdfdwm76l6z5rzxg3kkcpq6h47bpkmm5n4zgve47fbz36i65dwyu.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => exp, sub_64, sum_1
#
# Persistent reduction: loads 16 fp32 values from in_ptr0 and writes their
# sum to out_ptr0[0].
# NOTE(review): presumably the second stage of a split sum whose 16 partial
# results come from triton_red_fused__softmax_lt_scalar_tensor_where_19 --
# confirm in call().
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[1, 16],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 16
    RBLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), None)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.sum(tmp1, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/zt/cztz4kuxbvivmp6txwflqjx3qpqrnxard4octamzlea35cgxr3fn.py
# Source Nodes: [argmax, logits_2, lt, probs, q_128, to_450, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
# argmax => argmax
# logits_2 => full_default_64, where_32
# lt => lt
# probs => div_1, exp, sub_64
# q_128 => full_default_65, ge, log, mul_450, where_33
# to_450 => convert_element_type_773
# truediv_1 => div_2
#
# Looped argmax over rnumel == 128256 elements (xnumel == 1). Per element r0:
#   p = exp(masked(in_ptr1[r0]) - in_ptr3[0]) / in_ptr4[0], where masked()
#       replaces values smaller than in_ptr2[299] with -inf;
#   u = tl.rand(seed, r0) with the seed read from in_ptr0[load_seed_offset];
#   q = -log(u), except that where u >= 0.9999999403953552 the log is
#       replaced by -5.960464477539063e-08 (keeps q away from zero);
#   score = p / q.
# The index of the largest score is cast to int32 and stored to out_ptr2[0].
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 131072],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 128256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp4 = tl.load(in_ptr2 + (299))
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
    tmp9 = tl.load(in_ptr3 + (0))
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp13 = tl.load(in_ptr4 + (0))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    _tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    _tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp0 = tl.load(in_ptr0 + load_seed_offset)
        tmp1 = r0
        tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
        tmp6 = tmp3 < tmp5
        tmp7 = float("-inf")
        tmp8 = tl.where(tmp6, tmp7, tmp3)
        tmp11 = tmp8 - tmp10
        tmp12 = tl_math.exp(tmp11)
        tmp15 = tmp12 / tmp14
        tmp16 = 0.9999999403953552
        tmp17 = tmp2 >= tmp16
        tmp18 = tl_math.log(tmp2)
        tmp19 = -5.960464477539063e-08
        tmp20 = tl.where(tmp17, tmp19, tmp18)
        tmp21 = -1.0
        tmp22 = tmp20 * tmp21
        tmp23 = tmp15 / tmp22
        tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK])
        _tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index(
            _tmp25, _tmp25_index, tmp24, rindex
        )
        _tmp25 = tl.where(rmask, _tmp25_next, _tmp25)
        _tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index)
    _, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1)
    tmp25 = tmp25_tmp[:, None]
    tmp26 = tmp25.to(tl.int32)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
''', device_str='cuda')
# NOTE(review): presumably blocks until every kernel submitted through
# async_compile.triton(...) has finished compiling and resolves the
# module-level kernel names passed in via globals() -- confirm against
# torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# The compiler handle is not referenced again at module level, so drop it.
del async_compile
def call(args): | |
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, 
arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, 
arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1, arg456_1, arg457_1, arg458_1, arg459_1, arg460_1, arg461_1, arg462_1, arg463_1, arg464_1, arg465_1, arg466_1, arg467_1, arg468_1, arg469_1, arg470_1, arg471_1, arg472_1, arg473_1, arg474_1, arg475_1, arg476_1, arg477_1, arg478_1, arg479_1, arg480_1, arg481_1, arg482_1, arg483_1, arg484_1, arg485_1, arg486_1, arg487_1, arg488_1, arg489_1, arg490_1, arg491_1, arg492_1, arg493_1, arg494_1, arg495_1, arg496_1, arg497_1, arg498_1, arg499_1, arg500_1, arg501_1, arg502_1, arg503_1, arg504_1, arg505_1, arg506_1, arg507_1, arg508_1, arg509_1, arg510_1, arg511_1, arg512_1, arg513_1, arg514_1, arg515_1, arg516_1, arg517_1, arg518_1, arg519_1, arg520_1, arg521_1, arg522_1, arg523_1, arg524_1, arg525_1, arg526_1, arg527_1, arg528_1, arg529_1, arg530_1, arg531_1, arg532_1, arg533_1, arg534_1, arg535_1, arg536_1, arg537_1, arg538_1, arg539_1, arg540_1, arg541_1, arg542_1, arg543_1, arg544_1, arg545_1, arg546_1, arg547_1, arg548_1, arg549_1, arg550_1, arg551_1, arg552_1, arg553_1, arg554_1, arg555_1, arg556_1, arg557_1, arg558_1, arg559_1, arg560_1, arg561_1, arg562_1, arg563_1, arg564_1, arg565_1, arg566_1, arg567_1, arg568_1, arg569_1, arg570_1, arg571_1, arg572_1, arg573_1, arg574_1, arg575_1, arg576_1, arg577_1, arg578_1, arg579_1, arg580_1, arg581_1, arg582_1, arg583_1, arg584_1, arg585_1, arg586_1, arg587_1, arg588_1, arg589_1, arg590_1, arg591_1, arg592_1, arg593_1, arg594_1, arg595_1, arg596_1, arg597_1, arg598_1, arg599_1, arg600_1, arg601_1, arg602_1, arg603_1, arg604_1, arg605_1, arg606_1, arg607_1, arg608_1, arg609_1, arg610_1, 
arg611_1, arg612_1, arg613_1, arg614_1, arg615_1, arg616_1, arg617_1, arg618_1, arg619_1, arg620_1, arg621_1, arg622_1, arg623_1, arg624_1, arg625_1, arg626_1, arg627_1, arg628_1, arg629_1, arg630_1, arg631_1, arg632_1, arg633_1, arg634_1, arg635_1, arg636_1, arg637_1, arg638_1, arg639_1, arg640_1, arg641_1, arg642_1, arg643_1, arg644_1, arg645_1, arg646_1, arg647_1, arg648_1, arg649_1, arg650_1, arg651_1, arg652_1, arg653_1, arg654_1, arg655_1, arg656_1, arg657_1, arg658_1, arg659_1, arg660_1, arg661_1, arg662_1, arg663_1, arg664_1, arg665_1, arg666_1, arg667_1, arg668_1, arg669_1, arg670_1, arg671_1, arg672_1, arg673_1, arg674_1, arg675_1, arg676_1, arg677_1, arg678_1, arg679_1, arg680_1, arg681_1, arg682_1, arg683_1, arg684_1, arg685_1, arg686_1, arg687_1, arg688_1, arg689_1, arg690_1, arg691_1, arg692_1, arg693_1, arg694_1, arg695_1, arg696_1, arg697_1, arg698_1, arg699_1, arg700_1, arg701_1, arg702_1, arg703_1, arg704_1, arg705_1, arg706_1, arg707_1, arg708_1, arg709_1, arg710_1, arg711_1, arg712_1, arg713_1, arg714_1, arg715_1, arg716_1, arg717_1, arg718_1, arg719_1, arg720_1, arg721_1, arg722_1, arg723_1, arg724_1, arg725_1, arg726_1, arg727_1, arg728_1, arg729_1, arg730_1, arg731_1, arg732_1, arg733_1, arg734_1, arg735_1, arg736_1, arg737_1, arg738_1, arg739_1, arg740_1, arg741_1, arg742_1, arg743_1, arg744_1, arg745_1, arg746_1, arg747_1, arg748_1, arg749_1, arg750_1, arg751_1, arg752_1, arg753_1, arg754_1, arg755_1, arg756_1, arg757_1, arg758_1, arg759_1, arg760_1, arg761_1, arg762_1, arg763_1, arg764_1, arg765_1, arg766_1, arg767_1, arg768_1, arg769_1, arg770_1, arg771_1, arg772_1, arg773_1, arg774_1, arg775_1, arg776_1, arg777_1, arg778_1, arg779_1, arg780_1, arg781_1, arg782_1, arg783_1, arg784_1, arg785_1, arg786_1, arg787_1, arg788_1, arg789_1, arg790_1, arg791_1, arg792_1, arg793_1, arg794_1, arg795_1, arg796_1, arg797_1, arg798_1, arg799_1, arg800_1, arg801_1, arg802_1, arg803_1, arg804_1, arg805_1, arg806_1, arg807_1, arg808_1, arg809_1, arg810_1, 
arg811_1, arg812_1, arg813_1, arg814_1, arg815_1, arg816_1, arg817_1, arg818_1, arg819_1, arg820_1, arg821_1, arg822_1, arg823_1, arg824_1, arg825_1, arg826_1, arg827_1, arg828_1, arg829_1, arg830_1, arg831_1, arg832_1, arg833_1, arg834_1, arg835_1, arg836_1, arg837_1, arg838_1, arg839_1 = args | |
args.clear() | |
assert_size_stride(arg0_1, (1, 1), (1, 1)) | |
assert_size_stride(arg1_1, (128256, 4096), (4096, 1)) | |
assert_size_stride(arg2_1, (2048, 2048), (2048, 1)) | |
assert_size_stride(arg3_1, (1, ), (1, )) | |
assert_size_stride(arg4_1, (4096, ), (1, )) | |
assert_size_stride(arg5_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg6_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg7_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg8_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg9_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg10_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg11_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg12_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg13_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg14_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg15_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg16_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg17_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg18_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg19_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg20_1, (4096, ), (1, )) | |
assert_size_stride(arg21_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg22_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg23_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg24_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg25_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg26_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg27_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg28_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg29_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg30_1, (4096, ), (1, )) | |
assert_size_stride(arg31_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg32_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg33_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg34_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg35_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg36_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg37_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg38_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg39_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg40_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg41_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg42_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg43_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg44_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg45_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg46_1, (4096, ), (1, )) | |
assert_size_stride(arg47_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg48_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg49_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg50_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg51_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg52_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg53_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg54_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg55_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg56_1, (4096, ), (1, )) | |
assert_size_stride(arg57_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg58_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg59_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg60_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg61_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg62_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg63_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg64_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg65_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg67_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg68_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg69_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg70_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg71_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg72_1, (4096, ), (1, )) | |
assert_size_stride(arg73_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg74_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg75_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg76_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg77_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg78_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg79_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg80_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg81_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg82_1, (4096, ), (1, )) | |
assert_size_stride(arg83_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg84_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg85_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg86_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg87_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg88_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg89_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg90_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg91_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg92_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg93_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg94_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg95_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg96_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg97_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg98_1, (4096, ), (1, )) | |
assert_size_stride(arg99_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg100_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg101_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg102_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg103_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg104_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg105_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg106_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg107_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg108_1, (4096, ), (1, )) | |
assert_size_stride(arg109_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg110_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg111_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg112_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg113_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg114_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg115_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg116_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg117_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg118_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg119_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg120_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg121_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg122_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg123_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg124_1, (4096, ), (1, )) | |
assert_size_stride(arg125_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg126_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg127_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg128_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg129_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg130_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg131_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg132_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg133_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg134_1, (4096, ), (1, )) | |
assert_size_stride(arg135_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg136_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg137_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg138_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg139_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg140_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg141_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg142_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg143_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg144_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg145_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg146_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg147_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg148_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg149_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg150_1, (4096, ), (1, )) | |
assert_size_stride(arg151_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg152_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg153_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg154_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg155_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg156_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg157_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg158_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg159_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg160_1, (4096, ), (1, )) | |
assert_size_stride(arg161_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg162_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg163_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg164_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg165_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg166_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg167_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg168_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg169_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg170_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg171_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg172_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg173_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg174_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg175_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg176_1, (4096, ), (1, )) | |
assert_size_stride(arg177_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg178_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg179_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg180_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg181_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg182_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg183_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg184_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg185_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg186_1, (4096, ), (1, )) | |
assert_size_stride(arg187_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg188_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg189_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg190_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg191_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg192_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg193_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg194_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg195_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg196_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg197_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg198_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg199_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg200_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg201_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg202_1, (4096, ), (1, )) | |
assert_size_stride(arg203_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg204_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg205_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg206_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg207_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg208_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg209_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg210_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg211_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg212_1, (4096, ), (1, )) | |
assert_size_stride(arg213_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg214_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg215_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg216_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg217_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg218_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg219_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg220_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg221_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg222_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg223_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg224_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg225_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg226_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg227_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg228_1, (4096, ), (1, )) | |
assert_size_stride(arg229_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg230_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg231_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg232_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg233_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg234_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg235_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg236_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg237_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg238_1, (4096, ), (1, )) | |
assert_size_stride(arg239_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg240_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg241_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg242_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg243_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg244_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg245_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg246_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg247_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg248_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg249_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg250_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg251_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg252_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg253_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg254_1, (4096, ), (1, )) | |
assert_size_stride(arg255_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg256_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg257_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg258_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg259_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg260_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg261_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg262_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg263_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg264_1, (4096, ), (1, )) | |
assert_size_stride(arg265_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg266_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg267_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg268_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg269_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg270_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg271_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg272_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg273_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg274_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg275_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg276_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg277_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg278_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg279_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg280_1, (4096, ), (1, )) | |
assert_size_stride(arg281_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg282_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg283_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg284_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg285_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg286_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg287_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg288_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg289_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg290_1, (4096, ), (1, )) | |
assert_size_stride(arg291_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg292_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg293_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg294_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg295_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg296_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg297_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg298_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg299_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg300_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg301_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg302_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg303_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg304_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg305_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg306_1, (4096, ), (1, )) | |
assert_size_stride(arg307_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg308_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg309_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg310_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg311_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg312_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg313_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg314_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg315_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg316_1, (4096, ), (1, )) | |
assert_size_stride(arg317_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg318_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg319_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg320_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg321_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg322_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg323_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg324_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg325_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg326_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg327_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg328_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg329_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg330_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg331_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg332_1, (4096, ), (1, )) | |
assert_size_stride(arg333_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg334_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg335_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg336_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg337_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg338_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg339_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg340_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg341_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg342_1, (4096, ), (1, )) | |
assert_size_stride(arg343_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg344_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg345_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg346_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg347_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg348_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg349_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg350_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg351_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg352_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg353_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg354_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg355_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg356_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg357_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg358_1, (4096, ), (1, )) | |
assert_size_stride(arg359_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg360_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg361_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg362_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg363_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg364_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg365_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg366_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg367_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg368_1, (4096, ), (1, )) | |
assert_size_stride(arg369_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg370_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg371_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg372_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg373_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg374_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg375_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg376_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg377_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg378_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg379_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg380_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg381_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg382_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg383_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg384_1, (4096, ), (1, )) | |
assert_size_stride(arg385_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg386_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg387_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg388_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg389_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg390_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg391_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg392_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg393_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg394_1, (4096, ), (1, )) | |
assert_size_stride(arg395_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg396_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg397_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg398_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg399_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg400_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg401_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg402_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg403_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg404_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg405_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg406_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg407_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg408_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg409_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg410_1, (4096, ), (1, )) | |
assert_size_stride(arg411_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg412_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg413_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg414_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg415_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg416_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg417_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg418_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg419_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg420_1, (4096, ), (1, )) | |
assert_size_stride(arg421_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg422_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg423_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg424_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg425_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg426_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg427_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg428_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg429_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg430_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg431_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg432_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg433_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg434_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg435_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg436_1, (4096, ), (1, )) | |
assert_size_stride(arg437_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg438_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg439_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg440_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg441_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg442_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg443_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg444_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg445_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg446_1, (4096, ), (1, )) | |
assert_size_stride(arg447_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg448_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg449_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg450_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg451_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg452_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg453_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg454_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg455_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg456_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg457_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg458_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg459_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg460_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg461_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg462_1, (4096, ), (1, )) | |
assert_size_stride(arg463_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg464_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg465_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg466_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg467_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg468_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg469_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg470_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg471_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg472_1, (4096, ), (1, )) | |
assert_size_stride(arg473_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg474_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg475_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg476_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg477_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg478_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg479_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg480_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg481_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg482_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg483_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg484_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg485_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg486_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg487_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg488_1, (4096, ), (1, )) | |
assert_size_stride(arg489_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg490_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg491_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg492_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg493_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg494_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg495_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg496_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg497_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg498_1, (4096, ), (1, )) | |
assert_size_stride(arg499_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg500_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg501_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg502_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg503_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg504_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg505_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg506_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg507_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg508_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg509_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg510_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg511_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg512_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg513_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg514_1, (4096, ), (1, )) | |
assert_size_stride(arg515_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg516_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg517_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg518_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg519_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg520_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg521_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg522_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg523_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg524_1, (4096, ), (1, )) | |
assert_size_stride(arg525_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg526_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg527_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg528_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg529_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg530_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg531_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg532_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg533_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg534_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg535_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg536_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg537_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg538_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg539_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg540_1, (4096, ), (1, )) | |
assert_size_stride(arg541_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg542_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg543_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg544_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg545_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg546_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg547_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg548_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg549_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg550_1, (4096, ), (1, )) | |
assert_size_stride(arg551_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg552_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg553_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg554_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg555_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg556_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg557_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg558_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg559_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg560_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg561_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg562_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg563_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg564_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg565_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg566_1, (4096, ), (1, )) | |
assert_size_stride(arg567_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg568_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg569_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg570_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg571_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg572_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg573_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg574_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg575_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg576_1, (4096, ), (1, )) | |
assert_size_stride(arg577_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg578_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg579_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg580_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg581_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg582_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg583_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg584_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg585_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg586_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg587_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg588_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg589_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg590_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg591_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg592_1, (4096, ), (1, )) | |
assert_size_stride(arg593_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg594_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg595_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg596_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg597_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg598_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg599_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg600_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg601_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg602_1, (4096, ), (1, )) | |
assert_size_stride(arg603_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg604_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg605_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg606_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg607_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg608_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg609_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg610_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg611_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg612_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg613_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg614_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg615_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg616_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg617_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg618_1, (4096, ), (1, )) | |
assert_size_stride(arg619_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg620_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg621_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg622_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg623_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg624_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg625_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg626_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg627_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg628_1, (4096, ), (1, )) | |
assert_size_stride(arg629_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg630_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg631_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg632_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg633_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg634_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg635_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg636_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg637_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg638_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg639_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg640_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg641_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg642_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg643_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg644_1, (4096, ), (1, )) | |
assert_size_stride(arg645_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg646_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg647_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg648_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg649_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg650_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg651_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg652_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg653_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg654_1, (4096, ), (1, )) | |
assert_size_stride(arg655_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg656_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg657_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg658_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg659_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg660_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg661_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg662_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg663_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg664_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg665_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg666_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg667_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg668_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg669_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg670_1, (4096, ), (1, )) | |
assert_size_stride(arg671_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg672_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg673_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg674_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg675_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg676_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg677_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg678_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg679_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg680_1, (4096, ), (1, )) | |
assert_size_stride(arg681_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg682_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg683_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg684_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg685_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg686_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg687_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg688_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg689_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg690_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg691_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg692_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg693_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg694_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg695_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg696_1, (4096, ), (1, )) | |
assert_size_stride(arg697_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg698_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg699_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg700_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg701_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg702_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg703_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg704_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg705_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg706_1, (4096, ), (1, )) | |
assert_size_stride(arg707_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg708_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg709_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg710_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg711_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg712_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg713_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg714_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg715_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg716_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg717_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg718_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg719_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg720_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg721_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg722_1, (4096, ), (1, )) | |
assert_size_stride(arg723_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg724_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg725_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg726_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg727_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg728_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg729_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg730_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg731_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg732_1, (4096, ), (1, )) | |
assert_size_stride(arg733_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg734_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg735_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg736_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg737_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg738_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg739_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg740_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg741_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg742_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg743_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg744_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg745_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg746_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg747_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg748_1, (4096, ), (1, )) | |
assert_size_stride(arg749_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg750_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg751_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg752_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg753_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg754_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg755_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg756_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg757_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg758_1, (4096, ), (1, )) | |
assert_size_stride(arg759_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg760_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg761_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg762_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg763_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg764_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg765_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg766_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg767_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg768_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg769_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg770_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg771_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg772_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg773_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg774_1, (4096, ), (1, )) | |
assert_size_stride(arg775_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg776_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg777_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg778_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg779_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg780_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg781_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg782_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg783_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg784_1, (4096, ), (1, )) | |
assert_size_stride(arg785_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg786_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg787_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg788_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg789_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg790_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg791_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg792_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg793_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg794_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg795_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg796_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg797_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg798_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg799_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg800_1, (4096, ), (1, )) | |
assert_size_stride(arg801_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg802_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg803_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg804_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg805_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg806_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg807_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg808_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg809_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg810_1, (4096, ), (1, )) | |
assert_size_stride(arg811_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg812_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg813_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg814_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg815_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg816_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg817_1, (1024, 4096), (4096, 1)) | |
assert_size_stride(arg818_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg819_1, (1024, 16), (16, 1)) | |
assert_size_stride(arg820_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg821_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg822_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1)) | |
assert_size_stride(arg823_1, (4096, 4096), (4096, 1)) | |
assert_size_stride(arg824_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg825_1, (4096, 16), (16, 1)) | |
assert_size_stride(arg826_1, (4096, ), (1, )) | |
assert_size_stride(arg827_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg828_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg829_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg830_1, (14336, 4096), (4096, 1)) | |
assert_size_stride(arg831_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg832_1, (14336, 16), (16, 1)) | |
assert_size_stride(arg833_1, (4096, 14336), (14336, 1)) | |
assert_size_stride(arg834_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg835_1, (4096, 56), (56, 1)) | |
assert_size_stride(arg836_1, (4096, ), (1, )) | |
assert_size_stride(arg837_1, (128256, 4096), (4096, 1)) | |
assert_size_stride(arg838_1, (128256, 16), (16, 1)) | |
assert_size_stride(arg839_1, (128256, 16), (16, 1)) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((1, 1, 4096), (4096, 4096, 1), torch.bfloat16) | |
# Source Nodes: [add, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg0_1, arg1_1, arg4_1, buf1, 1, 4096, grid=grid(1), stream=stream0) | |
del arg4_1 | |
# Source Nodes: [add, choose_qparams_per_token_asymmetric, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt, quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8) | |
buf3 = buf2[0] | |
buf4 = buf2[1] | |
del buf2 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_1], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf5 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8) | |
buf6 = buf5[0] | |
buf7 = buf5[1] | |
del buf5 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_2], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf8 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8) | |
buf9 = buf8[0] | |
buf10 = buf8[1] | |
del buf8 | |
buf11 = empty_strided_cuda((), (), torch.int64) | |
# Source Nodes: [max_1], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf11, 1, grid=grid(1), stream=stream0) | |
u0 = buf11.item() | |
buf12 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_2], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf13 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf3, buf4, -128, 127, torch.int8) | |
buf14 = buf13 | |
del buf13 | |
# Source Nodes: [input_3], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf15 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf14, buf3, buf4, -128, 127, torch.int8, torch.bfloat16) | |
del buf14 | |
del buf3 | |
del buf4 | |
buf16 = buf15 | |
del buf15 | |
# Source Nodes: [w_dq], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf17 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg5_1, arg6_1, arg7_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg5_1 | |
del arg6_1 | |
del arg7_1 | |
buf18 = buf17 | |
del buf17 | |
buf19 = empty_strided_cuda((1, 4096), (4096, 1), torch.bfloat16) | |
# Source Nodes: [c], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf16, buf18, buf19, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf16 | |
del buf18 | |
# Source Nodes: [input_5], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf20 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf6, buf7, -128, 127, torch.int8) | |
buf21 = buf20 | |
del buf20 | |
# Source Nodes: [input_6], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf22 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf21, buf6, buf7, -128, 127, torch.int8, torch.bfloat16) | |
del buf21 | |
del buf6 | |
del buf7 | |
buf23 = buf22 | |
del buf22 | |
# Source Nodes: [w_dq_1], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf24 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg8_1, arg9_1, arg10_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg10_1 | |
del arg8_1 | |
del arg9_1 | |
buf25 = buf24 | |
del buf24 | |
buf26 = empty_strided_cuda((1, 1024), (1024, 1), torch.bfloat16) | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf23, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf25, (4096, 1024), (1, 4096), 0), out=buf26) | |
del buf23 | |
del buf25 | |
# Source Nodes: [input_8], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf28 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf9, buf10, -128, 127, torch.int8) | |
del buf1 | |
buf29 = buf28 | |
del buf28 | |
# Source Nodes: [input_9], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf30 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf29, buf9, buf10, -128, 127, torch.int8, torch.bfloat16) | |
del buf10 | |
del buf29 | |
del buf9 | |
buf31 = buf30 | |
del buf30 | |
# Source Nodes: [w_dq_2], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf32 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg11_1, arg12_1, arg13_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg11_1 | |
del arg12_1 | |
del arg13_1 | |
buf33 = buf32 | |
del buf32 | |
buf34 = empty_strided_cuda((1, 1024), (1024, 1), torch.bfloat16) | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf31, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf33, (4096, 1024), (1, 4096), 0), out=buf34) | |
del buf33 | |
buf36 = reinterpret_tensor(buf31, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf31 # reuse | |
# Source Nodes: [output, setitem, setitem_1], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf26, arg14_1, buf34, buf19, arg15_1, arg16_1, buf36, 4096, grid=grid(4096), stream=stream0) | |
del arg14_1 | |
del buf19 | |
buf37 = empty_strided_cuda((1, 32, 1, 2048), (65536, 2048, 2048, 1), torch.bfloat16) | |
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf37, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf38 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf36, arg15_1, arg16_1, buf37, False) | |
del arg15_1 | |
del arg16_1 | |
del buf36 | |
buf39 = buf38[0] | |
del buf38 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_3], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf43 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf39, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf44 = buf43[0] | |
buf45 = buf43[1] | |
del buf43 | |
# Source Nodes: [input_11], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf46 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf39, (1, 1, 4096), (4096, 4096, 1), 0), buf44, buf45, -128, 127, torch.int8) | |
buf47 = buf46 | |
del buf46 | |
# Source Nodes: [input_12], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf48 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf47, buf44, buf45, -128, 127, torch.int8, torch.bfloat16) | |
del buf44 | |
del buf45 | |
del buf47 | |
buf49 = buf48 | |
del buf48 | |
# Source Nodes: [w_dq_3], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf50 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg17_1, arg18_1, arg19_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg17_1 | |
del arg18_1 | |
del arg19_1 | |
buf51 = buf50 | |
del buf50 | |
buf52 = reinterpret_tensor(buf39, (1, 4096), (4096, 1), 0); del buf39 # reuse | |
# Source Nodes: [c_3], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf49, buf51, buf52, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf51 | |
buf54 = buf49; del buf49 # reuse | |
# Source Nodes: [add_5, h, h_1, mean_1, mul_10, mul_11, pow_2, rsqrt_1, x_fp32_1, x_normed_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5.run(buf52, arg0_1, arg1_1, arg20_1, buf54, 1, 4096, grid=grid(1), stream=stream0) | |
del arg20_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_4], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf55 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf54, torch.int8) | |
buf56 = buf55[0] | |
buf57 = buf55[1] | |
del buf55 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_5], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf58 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf54, torch.int8) | |
buf59 = buf58[0] | |
buf60 = buf58[1] | |
del buf58 | |
# Source Nodes: [input_14], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf61 = torch.ops.quantized_decomposed.quantize_per_token.default(buf54, buf56, buf57, -128, 127, torch.int8) | |
buf62 = buf61 | |
del buf61 | |
# Source Nodes: [input_15], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf63 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf62, buf56, buf57, -128, 127, torch.int8, torch.bfloat16) | |
del buf56 | |
del buf57 | |
del buf62 | |
buf64 = buf63 | |
del buf63 | |
# Source Nodes: [w_dq_4], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf65 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg21_1, arg22_1, arg23_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg21_1 | |
del arg22_1 | |
del arg23_1 | |
buf66 = buf65 | |
del buf65 | |
buf67 = empty_strided_cuda((1, 14336), (14336, 1), torch.bfloat16) | |
# Source Nodes: [c_4], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf64, buf66, buf67, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf66 | |
# Source Nodes: [input_17], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf68 = torch.ops.quantized_decomposed.quantize_per_token.default(buf54, buf59, buf60, -128, 127, torch.int8) | |
buf69 = buf68 | |
del buf68 | |
# Source Nodes: [input_18], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf70 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf69, buf59, buf60, -128, 127, torch.int8, torch.bfloat16) | |
del buf59 | |
del buf60 | |
del buf69 | |
buf71 = buf70 | |
del buf70 | |
# Source Nodes: [w_dq_5], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf72 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg24_1, arg25_1, arg26_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg24_1 | |
del arg25_1 | |
del arg26_1 | |
buf73 = buf72 | |
del buf72 | |
buf75 = empty_strided_cuda((1, 1, 14336), (14336, 14336, 1), torch.bfloat16) | |
# Source Nodes: [c_5, mul_12, silu], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf71, buf73, buf67, buf75, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf67 | |
del buf73 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_6], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf76 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf75, torch.int8) | |
buf77 = buf76[0] | |
buf78 = buf76[1] | |
del buf76 | |
# Source Nodes: [input_20], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf79 = torch.ops.quantized_decomposed.quantize_per_token.default(buf75, buf77, buf78, -128, 127, torch.int8) | |
buf80 = buf79 | |
del buf79 | |
# Source Nodes: [input_21], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf81 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf80, buf77, buf78, -128, 127, torch.int8, torch.bfloat16) | |
del buf77 | |
del buf78 | |
del buf80 | |
buf82 = buf81 | |
del buf81 | |
# Source Nodes: [w_dq_6], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf83 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg27_1, arg28_1, arg29_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg27_1 | |
del arg28_1 | |
del arg29_1 | |
buf84 = buf83 | |
del buf83 | |
buf85 = reinterpret_tensor(buf71, (1, 4096), (4096, 1), 0); del buf71 # reuse | |
# Source Nodes: [c_6], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf82, buf84, buf85, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf84 | |
buf87 = buf54; del buf54 # reuse | |
# Source Nodes: [add_7, h, h_1, mean_2, mul_13, out, pow_3, rsqrt_2, x_fp32_2, x_normed_2, y_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf52, arg0_1, arg1_1, buf85, arg30_1, buf87, 1, 4096, grid=grid(1), stream=stream0) | |
del arg30_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_7], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf88 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8) | |
buf89 = buf88[0] | |
buf90 = buf88[1] | |
del buf88 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_8], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf91 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8) | |
buf92 = buf91[0] | |
buf93 = buf91[1] | |
del buf91 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_9], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf94 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8) | |
buf95 = buf94[0] | |
buf96 = buf94[1] | |
del buf94 | |
buf97 = buf11; del buf11 # reuse | |
# Source Nodes: [max_2], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf97, 1, grid=grid(1), stream=stream0) | |
u1 = buf97.item() | |
buf98 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_23], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf99 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf89, buf90, -128, 127, torch.int8) | |
buf100 = buf99 | |
del buf99 | |
# Source Nodes: [input_24], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf101 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf100, buf89, buf90, -128, 127, torch.int8, torch.bfloat16) | |
del buf100 | |
del buf89 | |
del buf90 | |
buf102 = buf101 | |
del buf101 | |
# Source Nodes: [w_dq_7], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf103 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg31_1, arg32_1, arg33_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg31_1 | |
del arg32_1 | |
del arg33_1 | |
buf104 = buf103 | |
del buf103 | |
buf105 = reinterpret_tensor(buf64, (1, 4096), (4096, 1), 0); del buf64 # reuse | |
# Source Nodes: [c_7], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf102, buf104, buf105, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf102 | |
del buf104 | |
# Source Nodes: [input_26], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf106 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf92, buf93, -128, 127, torch.int8) | |
buf107 = buf106 | |
del buf106 | |
# Source Nodes: [input_27], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf108 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf107, buf92, buf93, -128, 127, torch.int8, torch.bfloat16) | |
del buf107 | |
del buf92 | |
del buf93 | |
buf109 = buf108 | |
del buf108 | |
# Source Nodes: [w_dq_8], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf110 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg34_1, arg35_1, arg36_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg34_1 | |
del arg35_1 | |
del arg36_1 | |
buf111 = buf110 | |
del buf110 | |
buf112 = buf34; del buf34 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf109, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf111, (4096, 1024), (1, 4096), 0), out=buf112) | |
del buf109 | |
del buf111 | |
# Source Nodes: [input_29], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf114 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf95, buf96, -128, 127, torch.int8) | |
del buf87 | |
buf115 = buf114 | |
del buf114 | |
# Source Nodes: [input_30], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf116 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf115, buf95, buf96, -128, 127, torch.int8, torch.bfloat16) | |
del buf115 | |
del buf95 | |
del buf96 | |
buf117 = buf116 | |
del buf116 | |
# Source Nodes: [w_dq_9], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf118 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg37_1, arg38_1, arg39_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg37_1 | |
del arg38_1 | |
del arg39_1 | |
buf119 = buf118 | |
del buf118 | |
buf120 = buf26; del buf26 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf117, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf119, (4096, 1024), (1, 4096), 0), out=buf120) | |
del buf119 | |
buf122 = reinterpret_tensor(buf117, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf117 # reuse | |
# Source Nodes: [output_2, setitem_2, setitem_3], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf112, arg40_1, buf120, buf105, arg41_1, arg42_1, buf122, 4096, grid=grid(4096), stream=stream0) | |
del arg40_1 | |
del buf105 | |
buf123 = buf37; del buf37 # reuse | |
# Source Nodes: [output_2], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf123, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_2], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf124 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf122, arg41_1, arg42_1, buf123, False) | |
del arg41_1 | |
del arg42_1 | |
del buf122 | |
buf125 = buf124[0] | |
del buf124 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_10], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf129 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf125, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf130 = buf129[0] | |
buf131 = buf129[1] | |
del buf129 | |
# Source Nodes: [input_32], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf132 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf125, (1, 1, 4096), (4096, 4096, 1), 0), buf130, buf131, -128, 127, torch.int8) | |
buf133 = buf132 | |
del buf132 | |
# Source Nodes: [input_33], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf134 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf133, buf130, buf131, -128, 127, torch.int8, torch.bfloat16) | |
del buf130 | |
del buf131 | |
del buf133 | |
buf135 = buf134 | |
del buf134 | |
# Source Nodes: [w_dq_10], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf136 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg43_1, arg44_1, arg45_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg43_1 | |
del arg44_1 | |
del arg45_1 | |
buf137 = buf136 | |
del buf136 | |
buf138 = reinterpret_tensor(buf125, (1, 4096), (4096, 1), 0); del buf125 # reuse | |
# Source Nodes: [c_10], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf135, buf137, buf138, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf137 | |
buf139 = reinterpret_tensor(buf138, (1, 1, 4096), (4096, 4096, 1), 0); del buf138 # reuse | |
buf141 = buf135; del buf135 # reuse | |
# Source Nodes: [add_12, h, h_1, h_2, mean_3, mul_23, mul_24, out, pow_4, rsqrt_3, x_fp32_3, x_normed_3], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10.run(buf139, buf52, arg0_1, arg1_1, buf85, arg46_1, buf141, 1, 4096, grid=grid(1), stream=stream0) | |
del arg0_1 | |
del arg1_1 | |
del arg46_1 | |
del buf52 | |
del buf85 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_11], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf142 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf141, torch.int8) | |
buf143 = buf142[0] | |
buf144 = buf142[1] | |
del buf142 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_12], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf145 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf141, torch.int8) | |
buf146 = buf145[0] | |
buf147 = buf145[1] | |
del buf145 | |
# Source Nodes: [input_35], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf148 = torch.ops.quantized_decomposed.quantize_per_token.default(buf141, buf143, buf144, -128, 127, torch.int8) | |
buf149 = buf148 | |
del buf148 | |
# Source Nodes: [input_36], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf150 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf149, buf143, buf144, -128, 127, torch.int8, torch.bfloat16) | |
del buf143 | |
del buf144 | |
del buf149 | |
buf151 = buf150 | |
del buf150 | |
# Source Nodes: [w_dq_11], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf152 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg47_1, arg48_1, arg49_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg47_1 | |
del arg48_1 | |
del arg49_1 | |
buf153 = buf152 | |
del buf152 | |
buf154 = reinterpret_tensor(buf82, (1, 14336), (14336, 1), 0); del buf82 # reuse | |
# Source Nodes: [c_11], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf151, buf153, buf154, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf153 | |
# Source Nodes: [input_38], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf155 = torch.ops.quantized_decomposed.quantize_per_token.default(buf141, buf146, buf147, -128, 127, torch.int8) | |
buf156 = buf155 | |
del buf155 | |
# Source Nodes: [input_39], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf157 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf156, buf146, buf147, -128, 127, torch.int8, torch.bfloat16) | |
del buf146 | |
del buf147 | |
del buf156 | |
buf158 = buf157 | |
del buf157 | |
# Source Nodes: [w_dq_12], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf159 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg50_1, arg51_1, arg52_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg50_1 | |
del arg51_1 | |
del arg52_1 | |
buf160 = buf159 | |
del buf159 | |
buf162 = buf75; del buf75 # reuse | |
# Source Nodes: [c_12, mul_25, silu_1], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf158, buf160, buf154, buf162, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf154 | |
del buf160 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_13], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf163 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf162, torch.int8) | |
buf164 = buf163[0] | |
buf165 = buf163[1] | |
del buf163 | |
# Source Nodes: [input_41], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf166 = torch.ops.quantized_decomposed.quantize_per_token.default(buf162, buf164, buf165, -128, 127, torch.int8) | |
buf167 = buf166 | |
del buf166 | |
# Source Nodes: [input_42], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf168 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf167, buf164, buf165, -128, 127, torch.int8, torch.bfloat16) | |
del buf164 | |
del buf165 | |
del buf167 | |
buf169 = buf168 | |
del buf168 | |
# Source Nodes: [w_dq_13], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf170 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg53_1, arg54_1, arg55_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg53_1 | |
del arg54_1 | |
del arg55_1 | |
buf171 = buf170 | |
del buf170 | |
buf172 = reinterpret_tensor(buf158, (1, 4096), (4096, 1), 0); del buf158 # reuse | |
# Source Nodes: [c_13], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf169, buf171, buf172, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf171 | |
buf174 = buf141; del buf141 # reuse | |
# Source Nodes: [add_14, mean_4, mul_26, out_1, pow_5, rsqrt_4, x_fp32_4, x_normed_4, y_2], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf139, buf172, arg56_1, buf174, 1, 4096, grid=grid(1), stream=stream0) | |
del arg56_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_14], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf175 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8) | |
buf176 = buf175[0] | |
buf177 = buf175[1] | |
del buf175 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_15], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf178 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8) | |
buf179 = buf178[0] | |
buf180 = buf178[1] | |
del buf178 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_16], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf181 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8) | |
buf182 = buf181[0] | |
buf183 = buf181[1] | |
del buf181 | |
buf184 = buf97; del buf97 # reuse | |
# Source Nodes: [max_3], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf184, 1, grid=grid(1), stream=stream0) | |
u2 = buf184.item() | |
buf185 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_44], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf186 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf176, buf177, -128, 127, torch.int8) | |
buf187 = buf186 | |
del buf186 | |
# Source Nodes: [input_45], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf188 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf187, buf176, buf177, -128, 127, torch.int8, torch.bfloat16) | |
del buf176 | |
del buf177 | |
del buf187 | |
buf189 = buf188 | |
del buf188 | |
# Source Nodes: [w_dq_14], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf190 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg57_1, arg58_1, arg59_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg57_1 | |
del arg58_1 | |
del arg59_1 | |
buf191 = buf190 | |
del buf190 | |
buf192 = reinterpret_tensor(buf151, (1, 4096), (4096, 1), 0); del buf151 # reuse | |
# Source Nodes: [c_14], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf189, buf191, buf192, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf189 | |
del buf191 | |
# Source Nodes: [input_47], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf193 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf179, buf180, -128, 127, torch.int8) | |
buf194 = buf193 | |
del buf193 | |
# Source Nodes: [input_48], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf195 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf194, buf179, buf180, -128, 127, torch.int8, torch.bfloat16) | |
del buf179 | |
del buf180 | |
del buf194 | |
buf196 = buf195 | |
del buf195 | |
# Source Nodes: [w_dq_15], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf197 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg60_1, arg61_1, arg62_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg60_1 | |
del arg61_1 | |
del arg62_1 | |
buf198 = buf197 | |
del buf197 | |
buf199 = buf120; del buf120 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf196, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf198, (4096, 1024), (1, 4096), 0), out=buf199) | |
del buf196 | |
del buf198 | |
# Source Nodes: [input_50], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf201 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf182, buf183, -128, 127, torch.int8) | |
del buf174 | |
buf202 = buf201 | |
del buf201 | |
# Source Nodes: [input_51], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf203 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf202, buf182, buf183, -128, 127, torch.int8, torch.bfloat16) | |
del buf182 | |
del buf183 | |
del buf202 | |
buf204 = buf203 | |
del buf203 | |
# Source Nodes: [w_dq_16], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf205 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg63_1, arg64_1, arg65_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg63_1 | |
del arg64_1 | |
del arg65_1 | |
buf206 = buf205 | |
del buf205 | |
buf207 = buf112; del buf112 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf204, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf206, (4096, 1024), (1, 4096), 0), out=buf207) | |
del buf206 | |
buf209 = reinterpret_tensor(buf204, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf204 # reuse | |
# Source Nodes: [output_4, setitem_4, setitem_5], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf199, arg66_1, buf207, buf192, arg67_1, arg68_1, buf209, 4096, grid=grid(4096), stream=stream0) | |
del arg66_1 | |
del buf192 | |
buf210 = buf123; del buf123 # reuse | |
# Source Nodes: [output_4], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf210, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_4], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf211 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf209, arg67_1, arg68_1, buf210, False) | |
del arg67_1 | |
del arg68_1 | |
del buf209 | |
buf212 = buf211[0] | |
del buf211 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_17], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf216 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf212, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf217 = buf216[0] | |
buf218 = buf216[1] | |
del buf216 | |
# Source Nodes: [input_53], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf219 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf212, (1, 1, 4096), (4096, 4096, 1), 0), buf217, buf218, -128, 127, torch.int8) | |
buf220 = buf219 | |
del buf219 | |
# Source Nodes: [input_54], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf221 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf220, buf217, buf218, -128, 127, torch.int8, torch.bfloat16) | |
del buf217 | |
del buf218 | |
del buf220 | |
buf222 = buf221 | |
del buf221 | |
# Source Nodes: [w_dq_17], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf223 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg69_1, arg70_1, arg71_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg69_1 | |
del arg70_1 | |
del arg71_1 | |
buf224 = buf223 | |
del buf223 | |
buf225 = reinterpret_tensor(buf212, (1, 4096), (4096, 1), 0); del buf212 # reuse | |
# Source Nodes: [c_17], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf222, buf224, buf225, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf224 | |
buf227 = buf222; del buf222 # reuse | |
# Source Nodes: [add_19, h_3, mean_5, mul_36, mul_37, out_1, pow_6, rsqrt_5, x_fp32_5, x_normed_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf225, buf139, buf172, arg72_1, buf227, 1, 4096, grid=grid(1), stream=stream0) | |
del arg72_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_18], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf228 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf227, torch.int8) | |
buf229 = buf228[0] | |
buf230 = buf228[1] | |
del buf228 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_19], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf231 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf227, torch.int8) | |
buf232 = buf231[0] | |
buf233 = buf231[1] | |
del buf231 | |
# Source Nodes: [input_56], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf234 = torch.ops.quantized_decomposed.quantize_per_token.default(buf227, buf229, buf230, -128, 127, torch.int8) | |
buf235 = buf234 | |
del buf234 | |
# Source Nodes: [input_57], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf236 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf235, buf229, buf230, -128, 127, torch.int8, torch.bfloat16) | |
del buf229 | |
del buf230 | |
del buf235 | |
buf237 = buf236 | |
del buf236 | |
# Source Nodes: [w_dq_18], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf238 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg73_1, arg74_1, arg75_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg73_1 | |
del arg74_1 | |
del arg75_1 | |
buf239 = buf238 | |
del buf238 | |
buf240 = reinterpret_tensor(buf169, (1, 14336), (14336, 1), 0); del buf169 # reuse | |
# Source Nodes: [c_18], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf237, buf239, buf240, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf239 | |
# Source Nodes: [input_59], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf241 = torch.ops.quantized_decomposed.quantize_per_token.default(buf227, buf232, buf233, -128, 127, torch.int8) | |
buf242 = buf241 | |
del buf241 | |
# Source Nodes: [input_60], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf243 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf242, buf232, buf233, -128, 127, torch.int8, torch.bfloat16) | |
del buf232 | |
del buf233 | |
del buf242 | |
buf244 = buf243 | |
del buf243 | |
# Source Nodes: [w_dq_19], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf245 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg76_1, arg77_1, arg78_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg76_1 | |
del arg77_1 | |
del arg78_1 | |
buf246 = buf245 | |
del buf245 | |
buf248 = buf162; del buf162 # reuse | |
# Source Nodes: [c_19, mul_38, silu_2], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf244, buf246, buf240, buf248, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf240 | |
del buf246 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_20], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf249 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf248, torch.int8) | |
buf250 = buf249[0] | |
buf251 = buf249[1] | |
del buf249 | |
# Source Nodes: [input_62], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf252 = torch.ops.quantized_decomposed.quantize_per_token.default(buf248, buf250, buf251, -128, 127, torch.int8) | |
buf253 = buf252 | |
del buf252 | |
# Source Nodes: [input_63], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf254 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf253, buf250, buf251, -128, 127, torch.int8, torch.bfloat16) | |
del buf250 | |
del buf251 | |
del buf253 | |
buf255 = buf254 | |
del buf254 | |
# Source Nodes: [w_dq_20], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf256 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg79_1, arg80_1, arg81_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg79_1 | |
del arg80_1 | |
del arg81_1 | |
buf257 = buf256 | |
del buf256 | |
buf258 = reinterpret_tensor(buf244, (1, 4096), (4096, 1), 0); del buf244 # reuse | |
# Source Nodes: [c_20], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf255, buf257, buf258, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf257 | |
buf260 = buf227; del buf227 # reuse | |
# Source Nodes: [add_21, h_3, mean_6, mul_39, out_1, out_2, pow_7, rsqrt_6, x_fp32_6, x_normed_6, y_3], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf225, buf139, buf172, buf258, arg82_1, buf260, 1, 4096, grid=grid(1), stream=stream0) | |
del arg82_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_21], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf261 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8) | |
buf262 = buf261[0] | |
buf263 = buf261[1] | |
del buf261 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_22], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf264 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8) | |
buf265 = buf264[0] | |
buf266 = buf264[1] | |
del buf264 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_23], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf267 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8) | |
buf268 = buf267[0] | |
buf269 = buf267[1] | |
del buf267 | |
buf270 = buf184; del buf184 # reuse | |
# Source Nodes: [max_4], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf270, 1, grid=grid(1), stream=stream0) | |
u3 = buf270.item() | |
buf271 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_65], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf272 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf262, buf263, -128, 127, torch.int8) | |
buf273 = buf272 | |
del buf272 | |
# Source Nodes: [input_66], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf274 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf273, buf262, buf263, -128, 127, torch.int8, torch.bfloat16) | |
del buf262 | |
del buf263 | |
del buf273 | |
buf275 = buf274 | |
del buf274 | |
# Source Nodes: [w_dq_21], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf276 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg83_1, arg84_1, arg85_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg83_1 | |
del arg84_1 | |
del arg85_1 | |
buf277 = buf276 | |
del buf276 | |
buf278 = reinterpret_tensor(buf237, (1, 4096), (4096, 1), 0); del buf237 # reuse | |
# Source Nodes: [c_21], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf275, buf277, buf278, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf275 | |
del buf277 | |
# Source Nodes: [input_68], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf279 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf265, buf266, -128, 127, torch.int8) | |
buf280 = buf279 | |
del buf279 | |
# Source Nodes: [input_69], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf281 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf280, buf265, buf266, -128, 127, torch.int8, torch.bfloat16) | |
del buf265 | |
del buf266 | |
del buf280 | |
buf282 = buf281 | |
del buf281 | |
# Source Nodes: [w_dq_22], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf283 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg86_1, arg87_1, arg88_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg86_1 | |
del arg87_1 | |
del arg88_1 | |
buf284 = buf283 | |
del buf283 | |
buf285 = buf207; del buf207 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf282, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf284, (4096, 1024), (1, 4096), 0), out=buf285) | |
del buf282 | |
del buf284 | |
# Source Nodes: [input_71], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf287 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf268, buf269, -128, 127, torch.int8) | |
del buf260 | |
buf288 = buf287 | |
del buf287 | |
# Source Nodes: [input_72], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf289 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf288, buf268, buf269, -128, 127, torch.int8, torch.bfloat16) | |
del buf268 | |
del buf269 | |
del buf288 | |
buf290 = buf289 | |
del buf289 | |
# Source Nodes: [w_dq_23], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf291 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg89_1, arg90_1, arg91_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg89_1 | |
del arg90_1 | |
del arg91_1 | |
buf292 = buf291 | |
del buf291 | |
buf293 = buf199; del buf199 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf290, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf292, (4096, 1024), (1, 4096), 0), out=buf293) | |
del buf292 | |
buf295 = reinterpret_tensor(buf290, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf290 # reuse | |
# Source Nodes: [output_6, setitem_6, setitem_7], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf285, arg92_1, buf293, buf278, arg93_1, arg94_1, buf295, 4096, grid=grid(4096), stream=stream0) | |
del arg92_1 | |
del buf278 | |
buf296 = buf210; del buf210 # reuse | |
# Source Nodes: [output_6], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf296, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_6], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf297 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf295, arg93_1, arg94_1, buf296, False) | |
del arg93_1 | |
del arg94_1 | |
del buf295 | |
buf298 = buf297[0] | |
del buf297 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_24], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf302 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf298, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf303 = buf302[0] | |
buf304 = buf302[1] | |
del buf302 | |
# Source Nodes: [input_74], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf305 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf298, (1, 1, 4096), (4096, 4096, 1), 0), buf303, buf304, -128, 127, torch.int8) | |
buf306 = buf305 | |
del buf305 | |
# Source Nodes: [input_75], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf307 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf306, buf303, buf304, -128, 127, torch.int8, torch.bfloat16) | |
del buf303 | |
del buf304 | |
del buf306 | |
buf308 = buf307 | |
del buf307 | |
# Source Nodes: [w_dq_24], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf309 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg95_1, arg96_1, arg97_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg95_1 | |
del arg96_1 | |
del arg97_1 | |
buf310 = buf309 | |
del buf309 | |
buf311 = reinterpret_tensor(buf298, (1, 4096), (4096, 1), 0); del buf298 # reuse | |
# Source Nodes: [c_24], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf308, buf310, buf311, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf310 | |
buf312 = buf139; del buf139 # reuse | |
buf314 = buf308; del buf308 # reuse | |
# Source Nodes: [add_26, h_3, h_4, mean_7, mul_49, mul_50, out_1, out_2, pow_8, rsqrt_7, x_fp32_7, x_normed_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf312, buf311, buf225, buf172, buf258, arg98_1, buf314, 1, 4096, grid=grid(1), stream=stream0) | |
del arg98_1 | |
del buf172 | |
del buf225 | |
del buf258 | |
del buf311 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_25], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf315 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf314, torch.int8) | |
buf316 = buf315[0] | |
buf317 = buf315[1] | |
del buf315 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_26], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf318 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf314, torch.int8) | |
buf319 = buf318[0] | |
buf320 = buf318[1] | |
del buf318 | |
# Source Nodes: [input_77], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf321 = torch.ops.quantized_decomposed.quantize_per_token.default(buf314, buf316, buf317, -128, 127, torch.int8) | |
buf322 = buf321 | |
del buf321 | |
# Source Nodes: [input_78], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf323 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf322, buf316, buf317, -128, 127, torch.int8, torch.bfloat16) | |
del buf316 | |
del buf317 | |
del buf322 | |
buf324 = buf323 | |
del buf323 | |
# Source Nodes: [w_dq_25], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf325 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg99_1, arg100_1, arg101_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg100_1 | |
del arg101_1 | |
del arg99_1 | |
buf326 = buf325 | |
del buf325 | |
buf327 = reinterpret_tensor(buf255, (1, 14336), (14336, 1), 0); del buf255 # reuse | |
# Source Nodes: [c_25], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf324, buf326, buf327, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf326 | |
# Source Nodes: [input_80], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf328 = torch.ops.quantized_decomposed.quantize_per_token.default(buf314, buf319, buf320, -128, 127, torch.int8) | |
buf329 = buf328 | |
del buf328 | |
# Source Nodes: [input_81], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf330 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf329, buf319, buf320, -128, 127, torch.int8, torch.bfloat16) | |
del buf319 | |
del buf320 | |
del buf329 | |
buf331 = buf330 | |
del buf330 | |
# Source Nodes: [w_dq_26], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf332 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg102_1, arg103_1, arg104_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg102_1 | |
del arg103_1 | |
del arg104_1 | |
buf333 = buf332 | |
del buf332 | |
buf335 = buf248; del buf248 # reuse | |
# Source Nodes: [c_26, mul_51, silu_3], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf331, buf333, buf327, buf335, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf327 | |
del buf333 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_27], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf336 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf335, torch.int8) | |
buf337 = buf336[0] | |
buf338 = buf336[1] | |
del buf336 | |
# Source Nodes: [input_83], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf339 = torch.ops.quantized_decomposed.quantize_per_token.default(buf335, buf337, buf338, -128, 127, torch.int8) | |
buf340 = buf339 | |
del buf339 | |
# Source Nodes: [input_84], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf341 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf340, buf337, buf338, -128, 127, torch.int8, torch.bfloat16) | |
del buf337 | |
del buf338 | |
del buf340 | |
buf342 = buf341 | |
del buf341 | |
# Source Nodes: [w_dq_27], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf343 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg105_1, arg106_1, arg107_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg105_1 | |
del arg106_1 | |
del arg107_1 | |
buf344 = buf343 | |
del buf343 | |
buf345 = reinterpret_tensor(buf331, (1, 4096), (4096, 1), 0); del buf331 # reuse | |
# Source Nodes: [c_27], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf342, buf344, buf345, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf344 | |
buf347 = buf314; del buf314 # reuse | |
# Source Nodes: [add_28, mean_8, mul_52, out_3, pow_9, rsqrt_8, x_fp32_8, x_normed_8, y_4], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf312, buf345, arg108_1, buf347, 1, 4096, grid=grid(1), stream=stream0) | |
del arg108_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_28], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf348 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8) | |
buf349 = buf348[0] | |
buf350 = buf348[1] | |
del buf348 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_29], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf351 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8) | |
buf352 = buf351[0] | |
buf353 = buf351[1] | |
del buf351 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_30], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf354 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8) | |
buf355 = buf354[0] | |
buf356 = buf354[1] | |
del buf354 | |
buf357 = buf270; del buf270 # reuse | |
# Source Nodes: [max_5], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf357, 1, grid=grid(1), stream=stream0) | |
u4 = buf357.item() | |
buf358 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_86], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf359 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf349, buf350, -128, 127, torch.int8) | |
buf360 = buf359 | |
del buf359 | |
# Source Nodes: [input_87], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf361 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf360, buf349, buf350, -128, 127, torch.int8, torch.bfloat16) | |
del buf349 | |
del buf350 | |
del buf360 | |
buf362 = buf361 | |
del buf361 | |
# Source Nodes: [w_dq_28], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf363 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg109_1, arg110_1, arg111_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg109_1 | |
del arg110_1 | |
del arg111_1 | |
buf364 = buf363 | |
del buf363 | |
buf365 = reinterpret_tensor(buf324, (1, 4096), (4096, 1), 0); del buf324 # reuse | |
# Source Nodes: [c_28], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf362, buf364, buf365, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf362 | |
del buf364 | |
# Source Nodes: [input_89], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf366 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf352, buf353, -128, 127, torch.int8) | |
buf367 = buf366 | |
del buf366 | |
# Source Nodes: [input_90], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf368 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf367, buf352, buf353, -128, 127, torch.int8, torch.bfloat16) | |
del buf352 | |
del buf353 | |
del buf367 | |
buf369 = buf368 | |
del buf368 | |
# Source Nodes: [w_dq_29], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf370 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg112_1, arg113_1, arg114_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg112_1 | |
del arg113_1 | |
del arg114_1 | |
buf371 = buf370 | |
del buf370 | |
buf372 = buf293; del buf293 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf369, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf371, (4096, 1024), (1, 4096), 0), out=buf372) | |
del buf369 | |
del buf371 | |
# Source Nodes: [input_92], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf374 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf355, buf356, -128, 127, torch.int8) | |
del buf347 | |
buf375 = buf374 | |
del buf374 | |
# Source Nodes: [input_93], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf376 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf375, buf355, buf356, -128, 127, torch.int8, torch.bfloat16) | |
del buf355 | |
del buf356 | |
del buf375 | |
buf377 = buf376 | |
del buf376 | |
# Source Nodes: [w_dq_30], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf378 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg115_1, arg116_1, arg117_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg115_1 | |
del arg116_1 | |
del arg117_1 | |
buf379 = buf378 | |
del buf378 | |
buf380 = buf285; del buf285 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf377, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf379, (4096, 1024), (1, 4096), 0), out=buf380) | |
del buf379 | |
buf382 = reinterpret_tensor(buf377, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf377 # reuse | |
# Source Nodes: [output_8, setitem_8, setitem_9], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf372, arg118_1, buf380, buf365, arg119_1, arg120_1, buf382, 4096, grid=grid(4096), stream=stream0) | |
del arg118_1 | |
del buf365 | |
buf383 = buf296; del buf296 # reuse | |
# Source Nodes: [output_8], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf383, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_8], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf384 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf382, arg119_1, arg120_1, buf383, False) | |
del arg119_1 | |
del arg120_1 | |
del buf382 | |
buf385 = buf384[0] | |
del buf384 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_31], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf389 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf385, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf390 = buf389[0] | |
buf391 = buf389[1] | |
del buf389 | |
# Source Nodes: [input_95], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf392 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf385, (1, 1, 4096), (4096, 4096, 1), 0), buf390, buf391, -128, 127, torch.int8) | |
buf393 = buf392 | |
del buf392 | |
# Source Nodes: [input_96], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf394 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf393, buf390, buf391, -128, 127, torch.int8, torch.bfloat16) | |
del buf390 | |
del buf391 | |
del buf393 | |
buf395 = buf394 | |
del buf394 | |
# Source Nodes: [w_dq_31], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf396 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg121_1, arg122_1, arg123_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg121_1 | |
del arg122_1 | |
del arg123_1 | |
buf397 = buf396 | |
del buf396 | |
buf398 = reinterpret_tensor(buf385, (1, 4096), (4096, 1), 0); del buf385 # reuse | |
# Source Nodes: [c_31], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf395, buf397, buf398, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf397 | |
buf400 = buf395; del buf395 # reuse | |
# Source Nodes: [add_33, h_5, mean_9, mul_62, mul_63, out_3, pow_10, rsqrt_9, x_fp32_9, x_normed_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf398, buf312, buf345, arg124_1, buf400, 1, 4096, grid=grid(1), stream=stream0) | |
del arg124_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_32], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf401 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf400, torch.int8) | |
buf402 = buf401[0] | |
buf403 = buf401[1] | |
del buf401 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_33], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf404 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf400, torch.int8) | |
buf405 = buf404[0] | |
buf406 = buf404[1] | |
del buf404 | |
# Source Nodes: [input_98], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf407 = torch.ops.quantized_decomposed.quantize_per_token.default(buf400, buf402, buf403, -128, 127, torch.int8) | |
buf408 = buf407 | |
del buf407 | |
# Source Nodes: [input_99], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf409 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf408, buf402, buf403, -128, 127, torch.int8, torch.bfloat16) | |
del buf402 | |
del buf403 | |
del buf408 | |
buf410 = buf409 | |
del buf409 | |
# Source Nodes: [w_dq_32], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf411 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg125_1, arg126_1, arg127_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg125_1 | |
del arg126_1 | |
del arg127_1 | |
buf412 = buf411 | |
del buf411 | |
buf413 = reinterpret_tensor(buf342, (1, 14336), (14336, 1), 0); del buf342 # reuse | |
# Source Nodes: [c_32], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf410, buf412, buf413, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf412 | |
# Source Nodes: [input_101], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf414 = torch.ops.quantized_decomposed.quantize_per_token.default(buf400, buf405, buf406, -128, 127, torch.int8) | |
buf415 = buf414 | |
del buf414 | |
# Source Nodes: [input_102], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf416 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf415, buf405, buf406, -128, 127, torch.int8, torch.bfloat16) | |
del buf405 | |
del buf406 | |
del buf415 | |
buf417 = buf416 | |
del buf416 | |
# Source Nodes: [w_dq_33], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf418 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg128_1, arg129_1, arg130_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg128_1 | |
del arg129_1 | |
del arg130_1 | |
buf419 = buf418 | |
del buf418 | |
buf421 = buf335; del buf335 # reuse | |
# Source Nodes: [c_33, mul_64, silu_4], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf417, buf419, buf413, buf421, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf413 | |
del buf419 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_34], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf422 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf421, torch.int8) | |
buf423 = buf422[0] | |
buf424 = buf422[1] | |
del buf422 | |
# Source Nodes: [input_104], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf425 = torch.ops.quantized_decomposed.quantize_per_token.default(buf421, buf423, buf424, -128, 127, torch.int8) | |
buf426 = buf425 | |
del buf425 | |
# Source Nodes: [input_105], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf427 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf426, buf423, buf424, -128, 127, torch.int8, torch.bfloat16) | |
del buf423 | |
del buf424 | |
del buf426 | |
buf428 = buf427 | |
del buf427 | |
# Source Nodes: [w_dq_34], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf429 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg131_1, arg132_1, arg133_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg131_1 | |
del arg132_1 | |
del arg133_1 | |
buf430 = buf429 | |
del buf429 | |
buf431 = reinterpret_tensor(buf417, (1, 4096), (4096, 1), 0); del buf417 # reuse | |
# Source Nodes: [c_34], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf428, buf430, buf431, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf430 | |
buf433 = buf400; del buf400 # reuse | |
# Source Nodes: [add_35, h_5, mean_10, mul_65, out_3, out_4, pow_11, rsqrt_10, x_fp32_10, x_normed_10, y_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf398, buf312, buf345, buf431, arg134_1, buf433, 1, 4096, grid=grid(1), stream=stream0) | |
del arg134_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_35], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf434 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8) | |
buf435 = buf434[0] | |
buf436 = buf434[1] | |
del buf434 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_36], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf437 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8) | |
buf438 = buf437[0] | |
buf439 = buf437[1] | |
del buf437 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_37], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf440 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8) | |
buf441 = buf440[0] | |
buf442 = buf440[1] | |
del buf440 | |
buf443 = buf357; del buf357 # reuse | |
# Source Nodes: [max_6], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf443, 1, grid=grid(1), stream=stream0) | |
u5 = buf443.item() | |
buf444 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_107], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf445 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf435, buf436, -128, 127, torch.int8) | |
buf446 = buf445 | |
del buf445 | |
# Source Nodes: [input_108], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf447 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf446, buf435, buf436, -128, 127, torch.int8, torch.bfloat16) | |
del buf435 | |
del buf436 | |
del buf446 | |
buf448 = buf447 | |
del buf447 | |
# Source Nodes: [w_dq_35], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf449 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg135_1, arg136_1, arg137_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg135_1 | |
del arg136_1 | |
del arg137_1 | |
buf450 = buf449 | |
del buf449 | |
buf451 = reinterpret_tensor(buf410, (1, 4096), (4096, 1), 0); del buf410 # reuse | |
# Source Nodes: [c_35], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf448, buf450, buf451, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf448 | |
del buf450 | |
# Source Nodes: [input_110], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf452 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf438, buf439, -128, 127, torch.int8) | |
buf453 = buf452 | |
del buf452 | |
# Source Nodes: [input_111], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf454 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf453, buf438, buf439, -128, 127, torch.int8, torch.bfloat16) | |
del buf438 | |
del buf439 | |
del buf453 | |
buf455 = buf454 | |
del buf454 | |
# Source Nodes: [w_dq_36], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf456 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg138_1, arg139_1, arg140_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg138_1 | |
del arg139_1 | |
del arg140_1 | |
buf457 = buf456 | |
del buf456 | |
buf458 = buf380; del buf380 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf455, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf457, (4096, 1024), (1, 4096), 0), out=buf458) | |
del buf455 | |
del buf457 | |
# Source Nodes: [input_113], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf460 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf441, buf442, -128, 127, torch.int8) | |
del buf433 | |
buf461 = buf460 | |
del buf460 | |
# Source Nodes: [input_114], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf462 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf461, buf441, buf442, -128, 127, torch.int8, torch.bfloat16) | |
del buf441 | |
del buf442 | |
del buf461 | |
buf463 = buf462 | |
del buf462 | |
# Source Nodes: [w_dq_37], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf464 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg141_1, arg142_1, arg143_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg141_1 | |
del arg142_1 | |
del arg143_1 | |
buf465 = buf464 | |
del buf464 | |
buf466 = buf372; del buf372 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf463, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf465, (4096, 1024), (1, 4096), 0), out=buf466) | |
del buf465 | |
buf468 = reinterpret_tensor(buf463, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf463 # reuse | |
# Source Nodes: [output_10, setitem_10, setitem_11], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf458, arg144_1, buf466, buf451, arg145_1, arg146_1, buf468, 4096, grid=grid(4096), stream=stream0) | |
del arg144_1 | |
del buf451 | |
buf469 = buf383; del buf383 # reuse | |
# Source Nodes: [output_10], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf469, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_10], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf470 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf468, arg145_1, arg146_1, buf469, False) | |
del arg145_1 | |
del arg146_1 | |
del buf468 | |
buf471 = buf470[0] | |
del buf470 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_38], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf475 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf471, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf476 = buf475[0] | |
buf477 = buf475[1] | |
del buf475 | |
# Source Nodes: [input_116], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf478 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf471, (1, 1, 4096), (4096, 4096, 1), 0), buf476, buf477, -128, 127, torch.int8) | |
buf479 = buf478 | |
del buf478 | |
# Source Nodes: [input_117], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf480 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf479, buf476, buf477, -128, 127, torch.int8, torch.bfloat16) | |
del buf476 | |
del buf477 | |
del buf479 | |
buf481 = buf480 | |
del buf480 | |
# Source Nodes: [w_dq_38], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf482 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg147_1, arg148_1, arg149_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg147_1 | |
del arg148_1 | |
del arg149_1 | |
buf483 = buf482 | |
del buf482 | |
buf484 = reinterpret_tensor(buf471, (1, 4096), (4096, 1), 0); del buf471 # reuse | |
# Source Nodes: [c_38], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf481, buf483, buf484, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf483 | |
buf485 = buf312; del buf312 # reuse | |
buf487 = buf481; del buf481 # reuse | |
# Source Nodes: [add_40, h_5, h_6, mean_11, mul_75, mul_76, out_3, out_4, pow_12, rsqrt_11, x_fp32_11, x_normed_11], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf485, buf484, buf398, buf345, buf431, arg150_1, buf487, 1, 4096, grid=grid(1), stream=stream0) | |
del arg150_1 | |
del buf345 | |
del buf398 | |
del buf431 | |
del buf484 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_39], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf488 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf487, torch.int8) | |
buf489 = buf488[0] | |
buf490 = buf488[1] | |
del buf488 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_40], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf491 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf487, torch.int8) | |
buf492 = buf491[0] | |
buf493 = buf491[1] | |
del buf491 | |
# Source Nodes: [input_119], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf494 = torch.ops.quantized_decomposed.quantize_per_token.default(buf487, buf489, buf490, -128, 127, torch.int8) | |
buf495 = buf494 | |
del buf494 | |
# Source Nodes: [input_120], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf496 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf495, buf489, buf490, -128, 127, torch.int8, torch.bfloat16) | |
del buf489 | |
del buf490 | |
del buf495 | |
buf497 = buf496 | |
del buf496 | |
# Source Nodes: [w_dq_39], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf498 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg151_1, arg152_1, arg153_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg151_1 | |
del arg152_1 | |
del arg153_1 | |
buf499 = buf498 | |
del buf498 | |
buf500 = reinterpret_tensor(buf428, (1, 14336), (14336, 1), 0); del buf428 # reuse | |
# Source Nodes: [c_39], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf497, buf499, buf500, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf499 | |
# Source Nodes: [input_122], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf501 = torch.ops.quantized_decomposed.quantize_per_token.default(buf487, buf492, buf493, -128, 127, torch.int8) | |
buf502 = buf501 | |
del buf501 | |
# Source Nodes: [input_123], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf503 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf502, buf492, buf493, -128, 127, torch.int8, torch.bfloat16) | |
del buf492 | |
del buf493 | |
del buf502 | |
buf504 = buf503 | |
del buf503 | |
# Source Nodes: [w_dq_40], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf505 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg154_1, arg155_1, arg156_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg154_1 | |
del arg155_1 | |
del arg156_1 | |
buf506 = buf505 | |
del buf505 | |
buf508 = buf421; del buf421 # reuse | |
# Source Nodes: [c_40, mul_77, silu_5], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf504, buf506, buf500, buf508, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf500 | |
del buf506 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_41], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf509 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf508, torch.int8) | |
buf510 = buf509[0] | |
buf511 = buf509[1] | |
del buf509 | |
# Source Nodes: [input_125], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf512 = torch.ops.quantized_decomposed.quantize_per_token.default(buf508, buf510, buf511, -128, 127, torch.int8) | |
buf513 = buf512 | |
del buf512 | |
# Source Nodes: [input_126], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf514 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf513, buf510, buf511, -128, 127, torch.int8, torch.bfloat16) | |
del buf510 | |
del buf511 | |
del buf513 | |
buf515 = buf514 | |
del buf514 | |
# Source Nodes: [w_dq_41], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf516 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg157_1, arg158_1, arg159_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg157_1 | |
del arg158_1 | |
del arg159_1 | |
buf517 = buf516 | |
del buf516 | |
buf518 = reinterpret_tensor(buf504, (1, 4096), (4096, 1), 0); del buf504 # reuse | |
# Source Nodes: [c_41], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf515, buf517, buf518, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf517 | |
buf520 = buf487; del buf487 # reuse | |
# Source Nodes: [add_42, mean_12, mul_78, out_5, pow_13, rsqrt_12, x_fp32_12, x_normed_12, y_6], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf485, buf518, arg160_1, buf520, 1, 4096, grid=grid(1), stream=stream0) | |
del arg160_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_42], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf521 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8) | |
buf522 = buf521[0] | |
buf523 = buf521[1] | |
del buf521 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_43], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf524 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8) | |
buf525 = buf524[0] | |
buf526 = buf524[1] | |
del buf524 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_44], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf527 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8) | |
buf528 = buf527[0] | |
buf529 = buf527[1] | |
del buf527 | |
buf530 = buf443; del buf443 # reuse | |
# Source Nodes: [max_7], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf530, 1, grid=grid(1), stream=stream0) | |
u6 = buf530.item() | |
buf531 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_128], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf532 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf522, buf523, -128, 127, torch.int8) | |
buf533 = buf532 | |
del buf532 | |
# Source Nodes: [input_129], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf534 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf533, buf522, buf523, -128, 127, torch.int8, torch.bfloat16) | |
del buf522 | |
del buf523 | |
del buf533 | |
buf535 = buf534 | |
del buf534 | |
# Source Nodes: [w_dq_42], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf536 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg161_1, arg162_1, arg163_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg161_1 | |
del arg162_1 | |
del arg163_1 | |
buf537 = buf536 | |
del buf536 | |
buf538 = reinterpret_tensor(buf497, (1, 4096), (4096, 1), 0); del buf497 # reuse | |
# Source Nodes: [c_42], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf535, buf537, buf538, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf535 | |
del buf537 | |
# Source Nodes: [input_131], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf539 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf525, buf526, -128, 127, torch.int8) | |
buf540 = buf539 | |
del buf539 | |
# Source Nodes: [input_132], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf541 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf540, buf525, buf526, -128, 127, torch.int8, torch.bfloat16) | |
del buf525 | |
del buf526 | |
del buf540 | |
buf542 = buf541 | |
del buf541 | |
# Source Nodes: [w_dq_43], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf543 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg164_1, arg165_1, arg166_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg164_1 | |
del arg165_1 | |
del arg166_1 | |
buf544 = buf543 | |
del buf543 | |
buf545 = buf466; del buf466 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf542, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf544, (4096, 1024), (1, 4096), 0), out=buf545) | |
del buf542 | |
del buf544 | |
# Source Nodes: [input_134], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf547 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf528, buf529, -128, 127, torch.int8) | |
del buf520 | |
buf548 = buf547 | |
del buf547 | |
# Source Nodes: [input_135], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf549 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf548, buf528, buf529, -128, 127, torch.int8, torch.bfloat16) | |
del buf528 | |
del buf529 | |
del buf548 | |
buf550 = buf549 | |
del buf549 | |
# Source Nodes: [w_dq_44], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf551 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg167_1, arg168_1, arg169_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg167_1 | |
del arg168_1 | |
del arg169_1 | |
buf552 = buf551 | |
del buf551 | |
buf553 = buf458; del buf458 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf550, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf552, (4096, 1024), (1, 4096), 0), out=buf553) | |
del buf552 | |
buf555 = reinterpret_tensor(buf550, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf550 # reuse | |
# Source Nodes: [output_12, setitem_12, setitem_13], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf545, arg170_1, buf553, buf538, arg171_1, arg172_1, buf555, 4096, grid=grid(4096), stream=stream0) | |
del arg170_1 | |
del buf538 | |
buf556 = buf469; del buf469 # reuse | |
# Source Nodes: [output_12], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf556, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_12], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf557 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf555, arg171_1, arg172_1, buf556, False) | |
del arg171_1 | |
del arg172_1 | |
del buf555 | |
buf558 = buf557[0] | |
del buf557 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_45], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf562 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf558, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf563 = buf562[0] | |
buf564 = buf562[1] | |
del buf562 | |
# Source Nodes: [input_137], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf565 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf558, (1, 1, 4096), (4096, 4096, 1), 0), buf563, buf564, -128, 127, torch.int8) | |
buf566 = buf565 | |
del buf565 | |
# Source Nodes: [input_138], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf567 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf566, buf563, buf564, -128, 127, torch.int8, torch.bfloat16) | |
del buf563 | |
del buf564 | |
del buf566 | |
buf568 = buf567 | |
del buf567 | |
# Source Nodes: [w_dq_45], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf569 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg173_1, arg174_1, arg175_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg173_1 | |
del arg174_1 | |
del arg175_1 | |
buf570 = buf569 | |
del buf569 | |
buf571 = reinterpret_tensor(buf558, (1, 4096), (4096, 1), 0); del buf558 # reuse | |
# Source Nodes: [c_45], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf568, buf570, buf571, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf570 | |
buf573 = buf568; del buf568 # reuse | |
# Source Nodes: [add_47, h_7, mean_13, mul_88, mul_89, out_5, pow_14, rsqrt_13, x_fp32_13, x_normed_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf571, buf485, buf518, arg176_1, buf573, 1, 4096, grid=grid(1), stream=stream0) | |
del arg176_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_46], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf574 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf573, torch.int8) | |
buf575 = buf574[0] | |
buf576 = buf574[1] | |
del buf574 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_47], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf577 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf573, torch.int8) | |
buf578 = buf577[0] | |
buf579 = buf577[1] | |
del buf577 | |
# Source Nodes: [input_140], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf580 = torch.ops.quantized_decomposed.quantize_per_token.default(buf573, buf575, buf576, -128, 127, torch.int8) | |
buf581 = buf580 | |
del buf580 | |
# Source Nodes: [input_141], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf582 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf581, buf575, buf576, -128, 127, torch.int8, torch.bfloat16) | |
del buf575 | |
del buf576 | |
del buf581 | |
buf583 = buf582 | |
del buf582 | |
# Source Nodes: [w_dq_46], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf584 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg177_1, arg178_1, arg179_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg177_1 | |
del arg178_1 | |
del arg179_1 | |
buf585 = buf584 | |
del buf584 | |
buf586 = reinterpret_tensor(buf515, (1, 14336), (14336, 1), 0); del buf515 # reuse | |
# Source Nodes: [c_46], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf583, buf585, buf586, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf585 | |
# Source Nodes: [input_143], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf587 = torch.ops.quantized_decomposed.quantize_per_token.default(buf573, buf578, buf579, -128, 127, torch.int8) | |
buf588 = buf587 | |
del buf587 | |
# Source Nodes: [input_144], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf589 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf588, buf578, buf579, -128, 127, torch.int8, torch.bfloat16) | |
del buf578 | |
del buf579 | |
del buf588 | |
buf590 = buf589 | |
del buf589 | |
# Source Nodes: [w_dq_47], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf591 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg180_1, arg181_1, arg182_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg180_1 | |
del arg181_1 | |
del arg182_1 | |
buf592 = buf591 | |
del buf591 | |
buf594 = buf508; del buf508 # reuse | |
# Source Nodes: [c_47, mul_90, silu_6], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf590, buf592, buf586, buf594, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf586 | |
del buf592 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_48], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf595 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf594, torch.int8) | |
buf596 = buf595[0] | |
buf597 = buf595[1] | |
del buf595 | |
# Source Nodes: [input_146], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf598 = torch.ops.quantized_decomposed.quantize_per_token.default(buf594, buf596, buf597, -128, 127, torch.int8) | |
buf599 = buf598 | |
del buf598 | |
# Source Nodes: [input_147], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf600 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf599, buf596, buf597, -128, 127, torch.int8, torch.bfloat16) | |
del buf596 | |
del buf597 | |
del buf599 | |
buf601 = buf600 | |
del buf600 | |
# Source Nodes: [w_dq_48], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf602 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg183_1, arg184_1, arg185_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg183_1 | |
del arg184_1 | |
del arg185_1 | |
buf603 = buf602 | |
del buf602 | |
buf604 = reinterpret_tensor(buf590, (1, 4096), (4096, 1), 0); del buf590 # reuse | |
# Source Nodes: [c_48], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf601, buf603, buf604, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf603 | |
buf606 = buf573; del buf573 # reuse | |
# Source Nodes: [add_49, h_7, mean_14, mul_91, out_5, out_6, pow_15, rsqrt_14, x_fp32_14, x_normed_14, y_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf571, buf485, buf518, buf604, arg186_1, buf606, 1, 4096, grid=grid(1), stream=stream0) | |
del arg186_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_49], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf607 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8) | |
buf608 = buf607[0] | |
buf609 = buf607[1] | |
del buf607 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_50], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf610 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8) | |
buf611 = buf610[0] | |
buf612 = buf610[1] | |
del buf610 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_51], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf613 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8) | |
buf614 = buf613[0] | |
buf615 = buf613[1] | |
del buf613 | |
buf616 = buf530; del buf530 # reuse | |
# Source Nodes: [max_8], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf616, 1, grid=grid(1), stream=stream0) | |
u7 = buf616.item() | |
buf617 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_149], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf618 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf608, buf609, -128, 127, torch.int8) | |
buf619 = buf618 | |
del buf618 | |
# Source Nodes: [input_150], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf620 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf619, buf608, buf609, -128, 127, torch.int8, torch.bfloat16) | |
del buf608 | |
del buf609 | |
del buf619 | |
buf621 = buf620 | |
del buf620 | |
# Source Nodes: [w_dq_49], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf622 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg187_1, arg188_1, arg189_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg187_1 | |
del arg188_1 | |
del arg189_1 | |
buf623 = buf622 | |
del buf622 | |
buf624 = reinterpret_tensor(buf583, (1, 4096), (4096, 1), 0); del buf583 # reuse | |
# Source Nodes: [c_49], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf621, buf623, buf624, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf621 | |
del buf623 | |
# Source Nodes: [input_152], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf625 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf611, buf612, -128, 127, torch.int8) | |
buf626 = buf625 | |
del buf625 | |
# Source Nodes: [input_153], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf627 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf626, buf611, buf612, -128, 127, torch.int8, torch.bfloat16) | |
del buf611 | |
del buf612 | |
del buf626 | |
buf628 = buf627 | |
del buf627 | |
# Source Nodes: [w_dq_50], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf629 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg190_1, arg191_1, arg192_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg190_1 | |
del arg191_1 | |
del arg192_1 | |
buf630 = buf629 | |
del buf629 | |
buf631 = buf553; del buf553 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf628, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf630, (4096, 1024), (1, 4096), 0), out=buf631) | |
del buf628 | |
del buf630 | |
# Source Nodes: [input_155], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf633 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf614, buf615, -128, 127, torch.int8) | |
del buf606 | |
buf634 = buf633 | |
del buf633 | |
# Source Nodes: [input_156], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf635 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf634, buf614, buf615, -128, 127, torch.int8, torch.bfloat16) | |
del buf614 | |
del buf615 | |
del buf634 | |
buf636 = buf635 | |
del buf635 | |
# Source Nodes: [w_dq_51], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf637 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg193_1, arg194_1, arg195_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg193_1 | |
del arg194_1 | |
del arg195_1 | |
buf638 = buf637 | |
del buf637 | |
buf639 = buf545; del buf545 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf636, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf638, (4096, 1024), (1, 4096), 0), out=buf639) | |
del buf638 | |
buf641 = reinterpret_tensor(buf636, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf636 # reuse | |
# Source Nodes: [output_14, setitem_14, setitem_15], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf631, arg196_1, buf639, buf624, arg197_1, arg198_1, buf641, 4096, grid=grid(4096), stream=stream0) | |
del arg196_1 | |
del buf624 | |
buf642 = buf556; del buf556 # reuse | |
# Source Nodes: [output_14], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf642, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_14], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf643 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf641, arg197_1, arg198_1, buf642, False) | |
del arg197_1 | |
del arg198_1 | |
del buf641 | |
buf644 = buf643[0] | |
del buf643 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_52], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf648 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf644, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf649 = buf648[0] | |
buf650 = buf648[1] | |
del buf648 | |
# Source Nodes: [input_158], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf651 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf644, (1, 1, 4096), (4096, 4096, 1), 0), buf649, buf650, -128, 127, torch.int8) | |
buf652 = buf651 | |
del buf651 | |
# Source Nodes: [input_159], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf653 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf652, buf649, buf650, -128, 127, torch.int8, torch.bfloat16) | |
del buf649 | |
del buf650 | |
del buf652 | |
buf654 = buf653 | |
del buf653 | |
# Source Nodes: [w_dq_52], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf655 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg199_1, arg200_1, arg201_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg199_1 | |
del arg200_1 | |
del arg201_1 | |
buf656 = buf655 | |
del buf655 | |
buf657 = reinterpret_tensor(buf644, (1, 4096), (4096, 1), 0); del buf644 # reuse | |
# Source Nodes: [c_52], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf654, buf656, buf657, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf656 | |
buf658 = buf485; del buf485 # reuse | |
buf660 = buf654; del buf654 # reuse | |
# Source Nodes: [add_54, h_7, h_8, mean_15, mul_101, mul_102, out_5, out_6, pow_16, rsqrt_15, x_fp32_15, x_normed_15], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf658, buf657, buf571, buf518, buf604, arg202_1, buf660, 1, 4096, grid=grid(1), stream=stream0) | |
del arg202_1 | |
del buf518 | |
del buf571 | |
del buf604 | |
del buf657 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_53], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf661 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf660, torch.int8) | |
buf662 = buf661[0] | |
buf663 = buf661[1] | |
del buf661 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_54], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf664 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf660, torch.int8) | |
buf665 = buf664[0] | |
buf666 = buf664[1] | |
del buf664 | |
# Source Nodes: [input_161], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf667 = torch.ops.quantized_decomposed.quantize_per_token.default(buf660, buf662, buf663, -128, 127, torch.int8) | |
buf668 = buf667 | |
del buf667 | |
# Source Nodes: [input_162], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf669 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf668, buf662, buf663, -128, 127, torch.int8, torch.bfloat16) | |
del buf662 | |
del buf663 | |
del buf668 | |
buf670 = buf669 | |
del buf669 | |
# Source Nodes: [w_dq_53], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf671 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg203_1, arg204_1, arg205_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg203_1 | |
del arg204_1 | |
del arg205_1 | |
buf672 = buf671 | |
del buf671 | |
buf673 = reinterpret_tensor(buf601, (1, 14336), (14336, 1), 0); del buf601 # reuse | |
# Source Nodes: [c_53], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf670, buf672, buf673, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf672 | |
# Source Nodes: [input_164], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf674 = torch.ops.quantized_decomposed.quantize_per_token.default(buf660, buf665, buf666, -128, 127, torch.int8) | |
buf675 = buf674 | |
del buf674 | |
# Source Nodes: [input_165], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf676 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf675, buf665, buf666, -128, 127, torch.int8, torch.bfloat16) | |
del buf665 | |
del buf666 | |
del buf675 | |
buf677 = buf676 | |
del buf676 | |
# Source Nodes: [w_dq_54], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf678 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg206_1, arg207_1, arg208_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg206_1 | |
del arg207_1 | |
del arg208_1 | |
buf679 = buf678 | |
del buf678 | |
buf681 = buf594; del buf594 # reuse | |
# Source Nodes: [c_54, mul_103, silu_7], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf677, buf679, buf673, buf681, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf673 | |
del buf679 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_55], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf682 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf681, torch.int8) | |
buf683 = buf682[0] | |
buf684 = buf682[1] | |
del buf682 | |
# Source Nodes: [input_167], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf685 = torch.ops.quantized_decomposed.quantize_per_token.default(buf681, buf683, buf684, -128, 127, torch.int8) | |
buf686 = buf685 | |
del buf685 | |
# Source Nodes: [input_168], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf687 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf686, buf683, buf684, -128, 127, torch.int8, torch.bfloat16) | |
del buf683 | |
del buf684 | |
del buf686 | |
buf688 = buf687 | |
del buf687 | |
# Source Nodes: [w_dq_55], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf689 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg209_1, arg210_1, arg211_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg209_1 | |
del arg210_1 | |
del arg211_1 | |
buf690 = buf689 | |
del buf689 | |
buf691 = reinterpret_tensor(buf677, (1, 4096), (4096, 1), 0); del buf677 # reuse | |
# Source Nodes: [c_55], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf688, buf690, buf691, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf690 | |
buf693 = buf660; del buf660 # reuse | |
# Source Nodes: [add_56, mean_16, mul_104, out_7, pow_17, rsqrt_16, x_fp32_16, x_normed_16, y_8], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf658, buf691, arg212_1, buf693, 1, 4096, grid=grid(1), stream=stream0) | |
del arg212_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_56], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf694 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8) | |
buf695 = buf694[0] | |
buf696 = buf694[1] | |
del buf694 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_57], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf697 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8) | |
buf698 = buf697[0] | |
buf699 = buf697[1] | |
del buf697 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_58], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf700 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8) | |
buf701 = buf700[0] | |
buf702 = buf700[1] | |
del buf700 | |
buf703 = buf616; del buf616 # reuse | |
# Source Nodes: [max_9], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf703, 1, grid=grid(1), stream=stream0) | |
u8 = buf703.item() | |
buf704 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_170], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf705 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf695, buf696, -128, 127, torch.int8) | |
buf706 = buf705 | |
del buf705 | |
# Source Nodes: [input_171], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf707 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf706, buf695, buf696, -128, 127, torch.int8, torch.bfloat16) | |
del buf695 | |
del buf696 | |
del buf706 | |
buf708 = buf707 | |
del buf707 | |
# Source Nodes: [w_dq_56], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf709 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg213_1, arg214_1, arg215_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg213_1 | |
del arg214_1 | |
del arg215_1 | |
buf710 = buf709 | |
del buf709 | |
buf711 = reinterpret_tensor(buf670, (1, 4096), (4096, 1), 0); del buf670 # reuse | |
# Source Nodes: [c_56], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf708, buf710, buf711, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf708 | |
del buf710 | |
# Source Nodes: [input_173], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf712 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf698, buf699, -128, 127, torch.int8) | |
buf713 = buf712 | |
del buf712 | |
# Source Nodes: [input_174], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf714 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf713, buf698, buf699, -128, 127, torch.int8, torch.bfloat16) | |
del buf698 | |
del buf699 | |
del buf713 | |
buf715 = buf714 | |
del buf714 | |
# Source Nodes: [w_dq_57], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf716 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg216_1, arg217_1, arg218_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg216_1 | |
del arg217_1 | |
del arg218_1 | |
buf717 = buf716 | |
del buf716 | |
buf718 = buf639; del buf639 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf715, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf717, (4096, 1024), (1, 4096), 0), out=buf718) | |
del buf715 | |
del buf717 | |
# Source Nodes: [input_176], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf720 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf701, buf702, -128, 127, torch.int8) | |
del buf693 | |
buf721 = buf720 | |
del buf720 | |
# Source Nodes: [input_177], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf722 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf721, buf701, buf702, -128, 127, torch.int8, torch.bfloat16) | |
del buf701 | |
del buf702 | |
del buf721 | |
buf723 = buf722 | |
del buf722 | |
# Source Nodes: [w_dq_58], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf724 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg219_1, arg220_1, arg221_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg219_1 | |
del arg220_1 | |
del arg221_1 | |
buf725 = buf724 | |
del buf724 | |
buf726 = buf631; del buf631 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf723, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf725, (4096, 1024), (1, 4096), 0), out=buf726) | |
del buf725 | |
buf728 = reinterpret_tensor(buf723, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf723 # reuse | |
# Source Nodes: [output_16, setitem_16, setitem_17], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf718, arg222_1, buf726, buf711, arg223_1, arg224_1, buf728, 4096, grid=grid(4096), stream=stream0) | |
del arg222_1 | |
del buf711 | |
buf729 = buf642; del buf642 # reuse | |
# Source Nodes: [output_16], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf729, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_16], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf730 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf728, arg223_1, arg224_1, buf729, False) | |
del arg223_1 | |
del arg224_1 | |
del buf728 | |
buf731 = buf730[0] | |
del buf730 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_59], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf735 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf731, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf736 = buf735[0] | |
buf737 = buf735[1] | |
del buf735 | |
# Source Nodes: [input_179], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf738 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf731, (1, 1, 4096), (4096, 4096, 1), 0), buf736, buf737, -128, 127, torch.int8) | |
buf739 = buf738 | |
del buf738 | |
# Source Nodes: [input_180], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf740 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf739, buf736, buf737, -128, 127, torch.int8, torch.bfloat16) | |
del buf736 | |
del buf737 | |
del buf739 | |
buf741 = buf740 | |
del buf740 | |
# Source Nodes: [w_dq_59], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf742 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg225_1, arg226_1, arg227_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg225_1 | |
del arg226_1 | |
del arg227_1 | |
buf743 = buf742 | |
del buf742 | |
buf744 = reinterpret_tensor(buf731, (1, 4096), (4096, 1), 0); del buf731 # reuse | |
# Source Nodes: [c_59], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf741, buf743, buf744, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf743 | |
buf746 = buf741; del buf741 # reuse | |
# Source Nodes: [add_61, h_9, mean_17, mul_114, mul_115, out_7, pow_18, rsqrt_17, x_fp32_17, x_normed_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf744, buf658, buf691, arg228_1, buf746, 1, 4096, grid=grid(1), stream=stream0) | |
del arg228_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_60], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf747 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf746, torch.int8) | |
buf748 = buf747[0] | |
buf749 = buf747[1] | |
del buf747 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_61], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf750 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf746, torch.int8) | |
buf751 = buf750[0] | |
buf752 = buf750[1] | |
del buf750 | |
# Source Nodes: [input_182], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf753 = torch.ops.quantized_decomposed.quantize_per_token.default(buf746, buf748, buf749, -128, 127, torch.int8) | |
buf754 = buf753 | |
del buf753 | |
# Source Nodes: [input_183], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf755 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf754, buf748, buf749, -128, 127, torch.int8, torch.bfloat16) | |
del buf748 | |
del buf749 | |
del buf754 | |
buf756 = buf755 | |
del buf755 | |
# Source Nodes: [w_dq_60], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf757 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg229_1, arg230_1, arg231_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg229_1 | |
del arg230_1 | |
del arg231_1 | |
buf758 = buf757 | |
del buf757 | |
buf759 = reinterpret_tensor(buf688, (1, 14336), (14336, 1), 0); del buf688 # reuse | |
# Source Nodes: [c_60], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf756, buf758, buf759, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf758 | |
# Source Nodes: [input_185], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf760 = torch.ops.quantized_decomposed.quantize_per_token.default(buf746, buf751, buf752, -128, 127, torch.int8) | |
buf761 = buf760 | |
del buf760 | |
# Source Nodes: [input_186], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf762 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf761, buf751, buf752, -128, 127, torch.int8, torch.bfloat16) | |
del buf751 | |
del buf752 | |
del buf761 | |
buf763 = buf762 | |
del buf762 | |
# Source Nodes: [w_dq_61], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf764 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg232_1, arg233_1, arg234_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg232_1 | |
del arg233_1 | |
del arg234_1 | |
buf765 = buf764 | |
del buf764 | |
buf767 = buf681; del buf681 # reuse | |
# Source Nodes: [c_61, mul_116, silu_8], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf763, buf765, buf759, buf767, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf759 | |
del buf765 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_62], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf768 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf767, torch.int8) | |
buf769 = buf768[0] | |
buf770 = buf768[1] | |
del buf768 | |
# Source Nodes: [input_188], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf771 = torch.ops.quantized_decomposed.quantize_per_token.default(buf767, buf769, buf770, -128, 127, torch.int8) | |
buf772 = buf771 | |
del buf771 | |
# Source Nodes: [input_189], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf773 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf772, buf769, buf770, -128, 127, torch.int8, torch.bfloat16) | |
del buf769 | |
del buf770 | |
del buf772 | |
buf774 = buf773 | |
del buf773 | |
# Source Nodes: [w_dq_62], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf775 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg235_1, arg236_1, arg237_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg235_1 | |
del arg236_1 | |
del arg237_1 | |
buf776 = buf775 | |
del buf775 | |
buf777 = reinterpret_tensor(buf763, (1, 4096), (4096, 1), 0); del buf763 # reuse | |
# Source Nodes: [c_62], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf774, buf776, buf777, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf776 | |
buf779 = buf746; del buf746 # reuse | |
# Source Nodes: [add_63, h_9, mean_18, mul_117, out_7, out_8, pow_19, rsqrt_18, x_fp32_18, x_normed_18, y_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf744, buf658, buf691, buf777, arg238_1, buf779, 1, 4096, grid=grid(1), stream=stream0) | |
del arg238_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_63], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf780 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8) | |
buf781 = buf780[0] | |
buf782 = buf780[1] | |
del buf780 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_64], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf783 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8) | |
buf784 = buf783[0] | |
buf785 = buf783[1] | |
del buf783 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_65], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf786 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8) | |
buf787 = buf786[0] | |
buf788 = buf786[1] | |
del buf786 | |
buf789 = buf703; del buf703 # reuse | |
# Source Nodes: [max_10], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf789, 1, grid=grid(1), stream=stream0) | |
u9 = buf789.item() | |
buf790 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_191], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf791 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf781, buf782, -128, 127, torch.int8) | |
buf792 = buf791 | |
del buf791 | |
# Source Nodes: [input_192], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf793 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf792, buf781, buf782, -128, 127, torch.int8, torch.bfloat16) | |
del buf781 | |
del buf782 | |
del buf792 | |
buf794 = buf793 | |
del buf793 | |
# Source Nodes: [w_dq_63], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf795 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg239_1, arg240_1, arg241_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg239_1 | |
del arg240_1 | |
del arg241_1 | |
buf796 = buf795 | |
del buf795 | |
buf797 = reinterpret_tensor(buf756, (1, 4096), (4096, 1), 0); del buf756 # reuse | |
# Source Nodes: [c_63], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf794, buf796, buf797, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf794 | |
del buf796 | |
# Source Nodes: [input_194], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf798 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf784, buf785, -128, 127, torch.int8) | |
buf799 = buf798 | |
del buf798 | |
# Source Nodes: [input_195], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf800 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf799, buf784, buf785, -128, 127, torch.int8, torch.bfloat16) | |
del buf784 | |
del buf785 | |
del buf799 | |
buf801 = buf800 | |
del buf800 | |
# Source Nodes: [w_dq_64], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf802 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg242_1, arg243_1, arg244_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg242_1 | |
del arg243_1 | |
del arg244_1 | |
buf803 = buf802 | |
del buf802 | |
buf804 = buf726; del buf726 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf801, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf803, (4096, 1024), (1, 4096), 0), out=buf804) | |
del buf801 | |
del buf803 | |
# Source Nodes: [input_197], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf806 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf787, buf788, -128, 127, torch.int8) | |
del buf779 | |
buf807 = buf806 | |
del buf806 | |
# Source Nodes: [input_198], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf808 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf807, buf787, buf788, -128, 127, torch.int8, torch.bfloat16) | |
del buf787 | |
del buf788 | |
del buf807 | |
buf809 = buf808 | |
del buf808 | |
# Source Nodes: [w_dq_65], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf810 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg245_1, arg246_1, arg247_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg245_1 | |
del arg246_1 | |
del arg247_1 | |
buf811 = buf810 | |
del buf810 | |
buf812 = buf718; del buf718 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf809, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf811, (4096, 1024), (1, 4096), 0), out=buf812) | |
del buf811 | |
buf814 = reinterpret_tensor(buf809, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf809 # reuse | |
# Source Nodes: [output_18, setitem_18, setitem_19], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf804, arg248_1, buf812, buf797, arg249_1, arg250_1, buf814, 4096, grid=grid(4096), stream=stream0) | |
del arg248_1 | |
del buf797 | |
buf815 = buf729; del buf729 # reuse | |
# Source Nodes: [output_18], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf815, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_18], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf816 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf814, arg249_1, arg250_1, buf815, False) | |
del arg249_1 | |
del arg250_1 | |
del buf814 | |
buf817 = buf816[0] | |
del buf816 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_66], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf821 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf817, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf822 = buf821[0] | |
buf823 = buf821[1] | |
del buf821 | |
# Source Nodes: [input_200], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf824 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf817, (1, 1, 4096), (4096, 4096, 1), 0), buf822, buf823, -128, 127, torch.int8) | |
buf825 = buf824 | |
del buf824 | |
# Source Nodes: [input_201], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf826 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf825, buf822, buf823, -128, 127, torch.int8, torch.bfloat16) | |
del buf822 | |
del buf823 | |
del buf825 | |
buf827 = buf826 | |
del buf826 | |
# Source Nodes: [w_dq_66], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf828 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg251_1, arg252_1, arg253_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg251_1 | |
del arg252_1 | |
del arg253_1 | |
buf829 = buf828 | |
del buf828 | |
buf830 = reinterpret_tensor(buf817, (1, 4096), (4096, 1), 0); del buf817 # reuse | |
# Source Nodes: [c_66], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf827, buf829, buf830, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf829 | |
buf831 = buf658; del buf658 # reuse | |
buf833 = buf827; del buf827 # reuse | |
# Source Nodes: [add_68, h_10, h_9, mean_19, mul_127, mul_128, out_7, out_8, pow_20, rsqrt_19, x_fp32_19, x_normed_19], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf831, buf830, buf744, buf691, buf777, arg254_1, buf833, 1, 4096, grid=grid(1), stream=stream0) | |
del arg254_1 | |
del buf691 | |
del buf744 | |
del buf777 | |
del buf830 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_67], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf834 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf833, torch.int8) | |
buf835 = buf834[0] | |
buf836 = buf834[1] | |
del buf834 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_68], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf837 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf833, torch.int8) | |
buf838 = buf837[0] | |
buf839 = buf837[1] | |
del buf837 | |
# Source Nodes: [input_203], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf840 = torch.ops.quantized_decomposed.quantize_per_token.default(buf833, buf835, buf836, -128, 127, torch.int8) | |
buf841 = buf840 | |
del buf840 | |
# Source Nodes: [input_204], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf842 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf841, buf835, buf836, -128, 127, torch.int8, torch.bfloat16) | |
del buf835 | |
del buf836 | |
del buf841 | |
buf843 = buf842 | |
del buf842 | |
# Source Nodes: [w_dq_67], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf844 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg255_1, arg256_1, arg257_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg255_1 | |
del arg256_1 | |
del arg257_1 | |
buf845 = buf844 | |
del buf844 | |
buf846 = reinterpret_tensor(buf774, (1, 14336), (14336, 1), 0); del buf774 # reuse | |
# Source Nodes: [c_67], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf843, buf845, buf846, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf845 | |
# Source Nodes: [input_206], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf847 = torch.ops.quantized_decomposed.quantize_per_token.default(buf833, buf838, buf839, -128, 127, torch.int8) | |
buf848 = buf847 | |
del buf847 | |
# Source Nodes: [input_207], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf849 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf848, buf838, buf839, -128, 127, torch.int8, torch.bfloat16) | |
del buf838 | |
del buf839 | |
del buf848 | |
buf850 = buf849 | |
del buf849 | |
# Source Nodes: [w_dq_68], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf851 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg258_1, arg259_1, arg260_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg258_1 | |
del arg259_1 | |
del arg260_1 | |
buf852 = buf851 | |
del buf851 | |
buf854 = buf767; del buf767 # reuse | |
# Source Nodes: [c_68, mul_129, silu_9], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf850, buf852, buf846, buf854, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf846 | |
del buf852 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_69], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf855 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf854, torch.int8) | |
buf856 = buf855[0] | |
buf857 = buf855[1] | |
del buf855 | |
# Source Nodes: [input_209], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf858 = torch.ops.quantized_decomposed.quantize_per_token.default(buf854, buf856, buf857, -128, 127, torch.int8) | |
buf859 = buf858 | |
del buf858 | |
# Source Nodes: [input_210], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf860 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf859, buf856, buf857, -128, 127, torch.int8, torch.bfloat16) | |
del buf856 | |
del buf857 | |
del buf859 | |
buf861 = buf860 | |
del buf860 | |
# Source Nodes: [w_dq_69], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf862 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg261_1, arg262_1, arg263_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg261_1 | |
del arg262_1 | |
del arg263_1 | |
buf863 = buf862 | |
del buf862 | |
buf864 = reinterpret_tensor(buf850, (1, 4096), (4096, 1), 0); del buf850 # reuse | |
# Source Nodes: [c_69], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf861, buf863, buf864, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf863 | |
buf866 = buf833; del buf833 # reuse | |
# Source Nodes: [add_70, mean_20, mul_130, out_9, pow_21, rsqrt_20, x_fp32_20, x_normed_20, y_10], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf831, buf864, arg264_1, buf866, 1, 4096, grid=grid(1), stream=stream0) | |
del arg264_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_70], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf867 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8) | |
buf868 = buf867[0] | |
buf869 = buf867[1] | |
del buf867 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_71], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf870 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8) | |
buf871 = buf870[0] | |
buf872 = buf870[1] | |
del buf870 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_72], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf873 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8) | |
buf874 = buf873[0] | |
buf875 = buf873[1] | |
del buf873 | |
buf876 = buf789; del buf789 # reuse | |
# Source Nodes: [max_11], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf876, 1, grid=grid(1), stream=stream0) | |
u10 = buf876.item() | |
buf877 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_212], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf878 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf868, buf869, -128, 127, torch.int8) | |
buf879 = buf878 | |
del buf878 | |
# Source Nodes: [input_213], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf880 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf879, buf868, buf869, -128, 127, torch.int8, torch.bfloat16) | |
del buf868 | |
del buf869 | |
del buf879 | |
buf881 = buf880 | |
del buf880 | |
# Source Nodes: [w_dq_70], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf882 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg265_1, arg266_1, arg267_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg265_1 | |
del arg266_1 | |
del arg267_1 | |
buf883 = buf882 | |
del buf882 | |
buf884 = reinterpret_tensor(buf843, (1, 4096), (4096, 1), 0); del buf843 # reuse | |
# Source Nodes: [c_70], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf881, buf883, buf884, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf881 | |
del buf883 | |
# Source Nodes: [input_215], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf885 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf871, buf872, -128, 127, torch.int8) | |
buf886 = buf885 | |
del buf885 | |
# Source Nodes: [input_216], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf887 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf886, buf871, buf872, -128, 127, torch.int8, torch.bfloat16) | |
del buf871 | |
del buf872 | |
del buf886 | |
buf888 = buf887 | |
del buf887 | |
# Source Nodes: [w_dq_71], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf889 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg268_1, arg269_1, arg270_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg268_1 | |
del arg269_1 | |
del arg270_1 | |
buf890 = buf889 | |
del buf889 | |
buf891 = buf812; del buf812 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf888, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf890, (4096, 1024), (1, 4096), 0), out=buf891) | |
del buf888 | |
del buf890 | |
# Source Nodes: [input_218], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf893 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf874, buf875, -128, 127, torch.int8) | |
del buf866 | |
buf894 = buf893 | |
del buf893 | |
# Source Nodes: [input_219], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf895 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf894, buf874, buf875, -128, 127, torch.int8, torch.bfloat16) | |
del buf874 | |
del buf875 | |
del buf894 | |
buf896 = buf895 | |
del buf895 | |
# Source Nodes: [w_dq_72], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf897 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg271_1, arg272_1, arg273_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg271_1 | |
del arg272_1 | |
del arg273_1 | |
buf898 = buf897 | |
del buf897 | |
buf899 = buf804; del buf804 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf896, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf898, (4096, 1024), (1, 4096), 0), out=buf899) | |
del buf898 | |
buf901 = reinterpret_tensor(buf896, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf896 # reuse | |
# Source Nodes: [output_20, setitem_20, setitem_21], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf891, arg274_1, buf899, buf884, arg275_1, arg276_1, buf901, 4096, grid=grid(4096), stream=stream0) | |
del arg274_1 | |
del buf884 | |
buf902 = buf815; del buf815 # reuse | |
# Source Nodes: [output_20], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf902, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_20], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf903 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf901, arg275_1, arg276_1, buf902, False) | |
del arg275_1 | |
del arg276_1 | |
del buf901 | |
buf904 = buf903[0] | |
del buf903 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_73], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf908 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf904, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf909 = buf908[0] | |
buf910 = buf908[1] | |
del buf908 | |
# Source Nodes: [input_221], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf911 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf904, (1, 1, 4096), (4096, 4096, 1), 0), buf909, buf910, -128, 127, torch.int8) | |
buf912 = buf911 | |
del buf911 | |
# Source Nodes: [input_222], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf913 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf912, buf909, buf910, -128, 127, torch.int8, torch.bfloat16) | |
del buf909 | |
del buf910 | |
del buf912 | |
buf914 = buf913 | |
del buf913 | |
# Source Nodes: [w_dq_73], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf915 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg277_1, arg278_1, arg279_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg277_1 | |
del arg278_1 | |
del arg279_1 | |
buf916 = buf915 | |
del buf915 | |
buf917 = reinterpret_tensor(buf904, (1, 4096), (4096, 1), 0); del buf904 # reuse | |
# Source Nodes: [c_73], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf914, buf916, buf917, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf916 | |
buf919 = buf914; del buf914 # reuse | |
# Source Nodes: [add_75, h_11, mean_21, mul_140, mul_141, out_9, pow_22, rsqrt_21, x_fp32_21, x_normed_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf917, buf831, buf864, arg280_1, buf919, 1, 4096, grid=grid(1), stream=stream0) | |
del arg280_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_74], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf920 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf919, torch.int8) | |
buf921 = buf920[0] | |
buf922 = buf920[1] | |
del buf920 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_75], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf923 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf919, torch.int8) | |
buf924 = buf923[0] | |
buf925 = buf923[1] | |
del buf923 | |
# Source Nodes: [input_224], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf926 = torch.ops.quantized_decomposed.quantize_per_token.default(buf919, buf921, buf922, -128, 127, torch.int8) | |
buf927 = buf926 | |
del buf926 | |
# Source Nodes: [input_225], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf928 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf927, buf921, buf922, -128, 127, torch.int8, torch.bfloat16) | |
del buf921 | |
del buf922 | |
del buf927 | |
buf929 = buf928 | |
del buf928 | |
# Source Nodes: [w_dq_74], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf930 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg281_1, arg282_1, arg283_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg281_1 | |
del arg282_1 | |
del arg283_1 | |
buf931 = buf930 | |
del buf930 | |
buf932 = reinterpret_tensor(buf861, (1, 14336), (14336, 1), 0); del buf861 # reuse | |
# Source Nodes: [c_74], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf929, buf931, buf932, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf931 | |
# Source Nodes: [input_227], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf933 = torch.ops.quantized_decomposed.quantize_per_token.default(buf919, buf924, buf925, -128, 127, torch.int8) | |
buf934 = buf933 | |
del buf933 | |
# Source Nodes: [input_228], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf935 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf934, buf924, buf925, -128, 127, torch.int8, torch.bfloat16) | |
del buf924 | |
del buf925 | |
del buf934 | |
buf936 = buf935 | |
del buf935 | |
# Source Nodes: [w_dq_75], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf937 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg284_1, arg285_1, arg286_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg284_1 | |
del arg285_1 | |
del arg286_1 | |
buf938 = buf937 | |
del buf937 | |
buf940 = buf854; del buf854 # reuse | |
# Source Nodes: [c_75, mul_142, silu_10], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf936, buf938, buf932, buf940, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf932 | |
del buf938 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_76], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf941 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf940, torch.int8) | |
buf942 = buf941[0] | |
buf943 = buf941[1] | |
del buf941 | |
# Source Nodes: [input_230], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf944 = torch.ops.quantized_decomposed.quantize_per_token.default(buf940, buf942, buf943, -128, 127, torch.int8) | |
buf945 = buf944 | |
del buf944 | |
# Source Nodes: [input_231], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf946 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf945, buf942, buf943, -128, 127, torch.int8, torch.bfloat16) | |
del buf942 | |
del buf943 | |
del buf945 | |
buf947 = buf946 | |
del buf946 | |
# Source Nodes: [w_dq_76], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf948 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg287_1, arg288_1, arg289_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg287_1 | |
del arg288_1 | |
del arg289_1 | |
buf949 = buf948 | |
del buf948 | |
buf950 = reinterpret_tensor(buf936, (1, 4096), (4096, 1), 0); del buf936 # reuse | |
# Source Nodes: [c_76], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf947, buf949, buf950, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf949 | |
buf952 = buf919; del buf919 # reuse | |
# Source Nodes: [add_77, h_11, mean_22, mul_143, out_10, out_9, pow_23, rsqrt_22, x_fp32_22, x_normed_22, y_11], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf917, buf831, buf864, buf950, arg290_1, buf952, 1, 4096, grid=grid(1), stream=stream0) | |
del arg290_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_77], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf953 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8) | |
buf954 = buf953[0] | |
buf955 = buf953[1] | |
del buf953 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_78], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf956 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8) | |
buf957 = buf956[0] | |
buf958 = buf956[1] | |
del buf956 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_79], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf959 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8) | |
buf960 = buf959[0] | |
buf961 = buf959[1] | |
del buf959 | |
buf962 = buf876; del buf876 # reuse | |
# Source Nodes: [max_12], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf962, 1, grid=grid(1), stream=stream0) | |
u11 = buf962.item() | |
buf963 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_233], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf964 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf954, buf955, -128, 127, torch.int8) | |
buf965 = buf964 | |
del buf964 | |
# Source Nodes: [input_234], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf966 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf965, buf954, buf955, -128, 127, torch.int8, torch.bfloat16) | |
del buf954 | |
del buf955 | |
del buf965 | |
buf967 = buf966 | |
del buf966 | |
# Source Nodes: [w_dq_77], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf968 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg291_1, arg292_1, arg293_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg291_1 | |
del arg292_1 | |
del arg293_1 | |
buf969 = buf968 | |
del buf968 | |
buf970 = reinterpret_tensor(buf929, (1, 4096), (4096, 1), 0); del buf929 # reuse | |
# Source Nodes: [c_77], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf967, buf969, buf970, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf967 | |
del buf969 | |
# Source Nodes: [input_236], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf971 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf957, buf958, -128, 127, torch.int8) | |
buf972 = buf971 | |
del buf971 | |
# Source Nodes: [input_237], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf973 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf972, buf957, buf958, -128, 127, torch.int8, torch.bfloat16) | |
del buf957 | |
del buf958 | |
del buf972 | |
buf974 = buf973 | |
del buf973 | |
# Source Nodes: [w_dq_78], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf975 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg294_1, arg295_1, arg296_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg294_1 | |
del arg295_1 | |
del arg296_1 | |
buf976 = buf975 | |
del buf975 | |
buf977 = buf899; del buf899 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf974, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf976, (4096, 1024), (1, 4096), 0), out=buf977) | |
del buf974 | |
del buf976 | |
# Source Nodes: [input_239], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf979 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf960, buf961, -128, 127, torch.int8) | |
del buf952 | |
buf980 = buf979 | |
del buf979 | |
# Source Nodes: [input_240], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf981 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf980, buf960, buf961, -128, 127, torch.int8, torch.bfloat16) | |
del buf960 | |
del buf961 | |
del buf980 | |
buf982 = buf981 | |
del buf981 | |
# Source Nodes: [w_dq_79], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf983 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg297_1, arg298_1, arg299_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg297_1 | |
del arg298_1 | |
del arg299_1 | |
buf984 = buf983 | |
del buf983 | |
buf985 = buf891; del buf891 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf982, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf984, (4096, 1024), (1, 4096), 0), out=buf985) | |
del buf984 | |
buf987 = reinterpret_tensor(buf982, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf982 # reuse | |
# Source Nodes: [output_22, setitem_22, setitem_23], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf977, arg300_1, buf985, buf970, arg301_1, arg302_1, buf987, 4096, grid=grid(4096), stream=stream0) | |
del arg300_1 | |
del buf970 | |
buf988 = buf902; del buf902 # reuse | |
# Source Nodes: [output_22], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf988, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_22], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf989 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf987, arg301_1, arg302_1, buf988, False) | |
del arg301_1 | |
del arg302_1 | |
del buf987 | |
buf990 = buf989[0] | |
del buf989 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_80], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf994 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf990, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf995 = buf994[0] | |
buf996 = buf994[1] | |
del buf994 | |
# Source Nodes: [input_242], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf997 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf990, (1, 1, 4096), (4096, 4096, 1), 0), buf995, buf996, -128, 127, torch.int8) | |
buf998 = buf997 | |
del buf997 | |
# Source Nodes: [input_243], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf999 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf998, buf995, buf996, -128, 127, torch.int8, torch.bfloat16) | |
del buf995 | |
del buf996 | |
del buf998 | |
buf1000 = buf999 | |
del buf999 | |
# Source Nodes: [w_dq_80], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1001 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg303_1, arg304_1, arg305_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg303_1 | |
del arg304_1 | |
del arg305_1 | |
buf1002 = buf1001 | |
del buf1001 | |
buf1003 = reinterpret_tensor(buf990, (1, 4096), (4096, 1), 0); del buf990 # reuse | |
# Source Nodes: [c_80], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1000, buf1002, buf1003, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1002 | |
buf1004 = reinterpret_tensor(buf1003, (1, 1, 4096), (4096, 4096, 1), 0); del buf1003 # reuse | |
buf1006 = buf1000; del buf1000 # reuse | |
# Source Nodes: [add_82, h_11, h_12, mean_23, mul_153, mul_154, out_10, out_9, pow_24, rsqrt_23, x_fp32_23, x_normed_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15.run(buf1004, buf917, buf831, buf864, buf950, arg306_1, buf1006, 1, 4096, grid=grid(1), stream=stream0) | |
del arg306_1 | |
del buf831 | |
del buf864 | |
del buf917 | |
del buf950 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_81], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1007 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1006, torch.int8) | |
buf1008 = buf1007[0] | |
buf1009 = buf1007[1] | |
del buf1007 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_82], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1010 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1006, torch.int8) | |
buf1011 = buf1010[0] | |
buf1012 = buf1010[1] | |
del buf1010 | |
# Source Nodes: [input_245], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1013 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1006, buf1008, buf1009, -128, 127, torch.int8) | |
buf1014 = buf1013 | |
del buf1013 | |
# Source Nodes: [input_246], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1015 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1014, buf1008, buf1009, -128, 127, torch.int8, torch.bfloat16) | |
del buf1008 | |
del buf1009 | |
del buf1014 | |
buf1016 = buf1015 | |
del buf1015 | |
# Source Nodes: [w_dq_81], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1017 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg307_1, arg308_1, arg309_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg307_1 | |
del arg308_1 | |
del arg309_1 | |
buf1018 = buf1017 | |
del buf1017 | |
buf1019 = reinterpret_tensor(buf947, (1, 14336), (14336, 1), 0); del buf947 # reuse | |
# Source Nodes: [c_81], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1016, buf1018, buf1019, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1018 | |
# Source Nodes: [input_248], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1020 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1006, buf1011, buf1012, -128, 127, torch.int8) | |
buf1021 = buf1020 | |
del buf1020 | |
# Source Nodes: [input_249], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1022 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1021, buf1011, buf1012, -128, 127, torch.int8, torch.bfloat16) | |
del buf1011 | |
del buf1012 | |
del buf1021 | |
buf1023 = buf1022 | |
del buf1022 | |
# Source Nodes: [w_dq_82], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1024 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg310_1, arg311_1, arg312_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg310_1 | |
del arg311_1 | |
del arg312_1 | |
buf1025 = buf1024 | |
del buf1024 | |
buf1027 = buf940; del buf940 # reuse | |
# Source Nodes: [c_82, mul_155, silu_11], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1023, buf1025, buf1019, buf1027, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1019 | |
del buf1025 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_83], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1028 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1027, torch.int8) | |
buf1029 = buf1028[0] | |
buf1030 = buf1028[1] | |
del buf1028 | |
# Source Nodes: [input_251], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1031 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1027, buf1029, buf1030, -128, 127, torch.int8) | |
buf1032 = buf1031 | |
del buf1031 | |
# Source Nodes: [input_252], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1033 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1032, buf1029, buf1030, -128, 127, torch.int8, torch.bfloat16) | |
del buf1029 | |
del buf1030 | |
del buf1032 | |
buf1034 = buf1033 | |
del buf1033 | |
# Source Nodes: [w_dq_83], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1035 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg313_1, arg314_1, arg315_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg313_1 | |
del arg314_1 | |
del arg315_1 | |
buf1036 = buf1035 | |
del buf1035 | |
buf1037 = reinterpret_tensor(buf1023, (1, 4096), (4096, 1), 0); del buf1023 # reuse | |
# Source Nodes: [c_83], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1034, buf1036, buf1037, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1036 | |
buf1039 = buf1006; del buf1006 # reuse | |
# Source Nodes: [add_84, mean_24, mul_156, out_11, pow_25, rsqrt_24, x_fp32_24, x_normed_24, y_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1004, buf1037, arg316_1, buf1039, 1, 4096, grid=grid(1), stream=stream0) | |
del arg316_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_84], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1040 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8) | |
buf1041 = buf1040[0] | |
buf1042 = buf1040[1] | |
del buf1040 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_85], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1043 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8) | |
buf1044 = buf1043[0] | |
buf1045 = buf1043[1] | |
del buf1043 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_86], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1046 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8) | |
buf1047 = buf1046[0] | |
buf1048 = buf1046[1] | |
del buf1046 | |
buf1049 = buf962; del buf962 # reuse | |
# Source Nodes: [max_13], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1049, 1, grid=grid(1), stream=stream0) | |
u12 = buf1049.item() | |
buf1050 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_254], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1051 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1041, buf1042, -128, 127, torch.int8) | |
buf1052 = buf1051 | |
del buf1051 | |
# Source Nodes: [input_255], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1053 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1052, buf1041, buf1042, -128, 127, torch.int8, torch.bfloat16) | |
del buf1041 | |
del buf1042 | |
del buf1052 | |
buf1054 = buf1053 | |
del buf1053 | |
# Source Nodes: [w_dq_84], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1055 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg317_1, arg318_1, arg319_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg317_1 | |
del arg318_1 | |
del arg319_1 | |
buf1056 = buf1055 | |
del buf1055 | |
buf1057 = reinterpret_tensor(buf1016, (1, 4096), (4096, 1), 0); del buf1016 # reuse | |
# Source Nodes: [c_84], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1054, buf1056, buf1057, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1054 | |
del buf1056 | |
# Source Nodes: [input_257], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1058 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1044, buf1045, -128, 127, torch.int8) | |
buf1059 = buf1058 | |
del buf1058 | |
# Source Nodes: [input_258], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1060 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1059, buf1044, buf1045, -128, 127, torch.int8, torch.bfloat16) | |
del buf1044 | |
del buf1045 | |
del buf1059 | |
buf1061 = buf1060 | |
del buf1060 | |
# Source Nodes: [w_dq_85], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1062 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg320_1, arg321_1, arg322_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg320_1 | |
del arg321_1 | |
del arg322_1 | |
buf1063 = buf1062 | |
del buf1062 | |
buf1064 = buf985; del buf985 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1061, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1063, (4096, 1024), (1, 4096), 0), out=buf1064) | |
del buf1061 | |
del buf1063 | |
# Source Nodes: [input_260], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1066 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1047, buf1048, -128, 127, torch.int8) | |
del buf1039 | |
buf1067 = buf1066 | |
del buf1066 | |
# Source Nodes: [input_261], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1068 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1067, buf1047, buf1048, -128, 127, torch.int8, torch.bfloat16) | |
del buf1047 | |
del buf1048 | |
del buf1067 | |
buf1069 = buf1068 | |
del buf1068 | |
# Source Nodes: [w_dq_86], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1070 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg323_1, arg324_1, arg325_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg323_1 | |
del arg324_1 | |
del arg325_1 | |
buf1071 = buf1070 | |
del buf1070 | |
buf1072 = buf977; del buf977 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1069, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1071, (4096, 1024), (1, 4096), 0), out=buf1072) | |
del buf1071 | |
buf1074 = reinterpret_tensor(buf1069, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1069 # reuse | |
# Source Nodes: [output_24, setitem_24, setitem_25], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1064, arg326_1, buf1072, buf1057, arg327_1, arg328_1, buf1074, 4096, grid=grid(4096), stream=stream0) | |
del arg326_1 | |
del buf1057 | |
buf1075 = buf988; del buf988 # reuse | |
# Source Nodes: [output_24], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1075, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_24], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1076 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1074, arg327_1, arg328_1, buf1075, False) | |
del arg327_1 | |
del arg328_1 | |
del buf1074 | |
buf1077 = buf1076[0] | |
del buf1076 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_87], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1081 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1077, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1082 = buf1081[0] | |
buf1083 = buf1081[1] | |
del buf1081 | |
# Source Nodes: [input_263], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1084 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1077, (1, 1, 4096), (4096, 4096, 1), 0), buf1082, buf1083, -128, 127, torch.int8) | |
buf1085 = buf1084 | |
del buf1084 | |
# Source Nodes: [input_264], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1086 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1085, buf1082, buf1083, -128, 127, torch.int8, torch.bfloat16) | |
del buf1082 | |
del buf1083 | |
del buf1085 | |
buf1087 = buf1086 | |
del buf1086 | |
# Source Nodes: [w_dq_87], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1088 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg329_1, arg330_1, arg331_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg329_1 | |
del arg330_1 | |
del arg331_1 | |
buf1089 = buf1088 | |
del buf1088 | |
buf1090 = reinterpret_tensor(buf1077, (1, 4096), (4096, 1), 0); del buf1077 # reuse | |
# Source Nodes: [c_87], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1087, buf1089, buf1090, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1089 | |
buf1092 = buf1087; del buf1087 # reuse | |
# Source Nodes: [add_89, h_13, mean_25, mul_166, mul_167, out_11, pow_26, rsqrt_25, x_fp32_25, x_normed_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1090, buf1004, buf1037, arg332_1, buf1092, 1, 4096, grid=grid(1), stream=stream0) | |
del arg332_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_88], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1093 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1092, torch.int8) | |
buf1094 = buf1093[0] | |
buf1095 = buf1093[1] | |
del buf1093 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_89], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1096 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1092, torch.int8) | |
buf1097 = buf1096[0] | |
buf1098 = buf1096[1] | |
del buf1096 | |
# Source Nodes: [input_266], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1099 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1092, buf1094, buf1095, -128, 127, torch.int8) | |
buf1100 = buf1099 | |
del buf1099 | |
# Source Nodes: [input_267], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1101 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1100, buf1094, buf1095, -128, 127, torch.int8, torch.bfloat16) | |
del buf1094 | |
del buf1095 | |
del buf1100 | |
buf1102 = buf1101 | |
del buf1101 | |
# Source Nodes: [w_dq_88], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1103 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg333_1, arg334_1, arg335_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg333_1 | |
del arg334_1 | |
del arg335_1 | |
buf1104 = buf1103 | |
del buf1103 | |
buf1105 = reinterpret_tensor(buf1034, (1, 14336), (14336, 1), 0); del buf1034 # reuse | |
# Source Nodes: [c_88], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1102, buf1104, buf1105, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1104 | |
# Source Nodes: [input_269], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1106 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1092, buf1097, buf1098, -128, 127, torch.int8) | |
buf1107 = buf1106 | |
del buf1106 | |
# Source Nodes: [input_270], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1108 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1107, buf1097, buf1098, -128, 127, torch.int8, torch.bfloat16) | |
del buf1097 | |
del buf1098 | |
del buf1107 | |
buf1109 = buf1108 | |
del buf1108 | |
# Source Nodes: [w_dq_89], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1110 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg336_1, arg337_1, arg338_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg336_1 | |
del arg337_1 | |
del arg338_1 | |
buf1111 = buf1110 | |
del buf1110 | |
buf1113 = buf1027; del buf1027 # reuse | |
# Source Nodes: [c_89, mul_168, silu_12], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1109, buf1111, buf1105, buf1113, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1105 | |
del buf1111 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_90], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1114 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1113, torch.int8) | |
buf1115 = buf1114[0] | |
buf1116 = buf1114[1] | |
del buf1114 | |
# Source Nodes: [input_272], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1117 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1113, buf1115, buf1116, -128, 127, torch.int8) | |
buf1118 = buf1117 | |
del buf1117 | |
# Source Nodes: [input_273], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1119 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1118, buf1115, buf1116, -128, 127, torch.int8, torch.bfloat16) | |
del buf1115 | |
del buf1116 | |
del buf1118 | |
buf1120 = buf1119 | |
del buf1119 | |
# Source Nodes: [w_dq_90], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1121 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg339_1, arg340_1, arg341_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg339_1 | |
del arg340_1 | |
del arg341_1 | |
buf1122 = buf1121 | |
del buf1121 | |
buf1123 = reinterpret_tensor(buf1109, (1, 4096), (4096, 1), 0); del buf1109 # reuse | |
# Source Nodes: [c_90], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1120, buf1122, buf1123, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1122 | |
buf1125 = buf1092; del buf1092 # reuse | |
# Source Nodes: [add_91, h_13, mean_26, mul_169, out_11, out_12, pow_27, rsqrt_26, x_fp32_26, x_normed_26, y_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1090, buf1004, buf1037, buf1123, arg342_1, buf1125, 1, 4096, grid=grid(1), stream=stream0) | |
del arg342_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_91], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1126 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8) | |
buf1127 = buf1126[0] | |
buf1128 = buf1126[1] | |
del buf1126 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_92], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1129 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8) | |
buf1130 = buf1129[0] | |
buf1131 = buf1129[1] | |
del buf1129 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_93], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1132 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8) | |
buf1133 = buf1132[0] | |
buf1134 = buf1132[1] | |
del buf1132 | |
buf1135 = buf1049; del buf1049 # reuse | |
# Source Nodes: [max_14], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1135, 1, grid=grid(1), stream=stream0) | |
u13 = buf1135.item() | |
buf1136 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_275], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1137 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1127, buf1128, -128, 127, torch.int8) | |
buf1138 = buf1137 | |
del buf1137 | |
# Source Nodes: [input_276], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1139 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1138, buf1127, buf1128, -128, 127, torch.int8, torch.bfloat16) | |
del buf1127 | |
del buf1128 | |
del buf1138 | |
buf1140 = buf1139 | |
del buf1139 | |
# Source Nodes: [w_dq_91], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1141 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg343_1, arg344_1, arg345_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg343_1 | |
del arg344_1 | |
del arg345_1 | |
buf1142 = buf1141 | |
del buf1141 | |
buf1143 = reinterpret_tensor(buf1102, (1, 4096), (4096, 1), 0); del buf1102 # reuse | |
# Source Nodes: [c_91], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1140, buf1142, buf1143, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1140 | |
del buf1142 | |
# Source Nodes: [input_278], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1144 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1130, buf1131, -128, 127, torch.int8) | |
buf1145 = buf1144 | |
del buf1144 | |
# Source Nodes: [input_279], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1146 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1145, buf1130, buf1131, -128, 127, torch.int8, torch.bfloat16) | |
del buf1130 | |
del buf1131 | |
del buf1145 | |
buf1147 = buf1146 | |
del buf1146 | |
# Source Nodes: [w_dq_92], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1148 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg346_1, arg347_1, arg348_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg346_1 | |
del arg347_1 | |
del arg348_1 | |
buf1149 = buf1148 | |
del buf1148 | |
buf1150 = buf1072; del buf1072 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1147, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1149, (4096, 1024), (1, 4096), 0), out=buf1150) | |
del buf1147 | |
del buf1149 | |
# Source Nodes: [input_281], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1152 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1133, buf1134, -128, 127, torch.int8) | |
del buf1125 | |
buf1153 = buf1152 | |
del buf1152 | |
# Source Nodes: [input_282], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1154 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1153, buf1133, buf1134, -128, 127, torch.int8, torch.bfloat16) | |
del buf1133 | |
del buf1134 | |
del buf1153 | |
buf1155 = buf1154 | |
del buf1154 | |
# Source Nodes: [w_dq_93], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1156 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg349_1, arg350_1, arg351_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg349_1 | |
del arg350_1 | |
del arg351_1 | |
buf1157 = buf1156 | |
del buf1156 | |
buf1158 = buf1064; del buf1064 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1155, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1157, (4096, 1024), (1, 4096), 0), out=buf1158) | |
del buf1157 | |
buf1160 = reinterpret_tensor(buf1155, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1155 # reuse | |
# Source Nodes: [output_26, setitem_26, setitem_27], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1150, arg352_1, buf1158, buf1143, arg353_1, arg354_1, buf1160, 4096, grid=grid(4096), stream=stream0) | |
del arg352_1 | |
del buf1143 | |
buf1161 = buf1075; del buf1075 # reuse | |
# Source Nodes: [output_26], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1161, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_26], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1162 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1160, arg353_1, arg354_1, buf1161, False) | |
del arg353_1 | |
del arg354_1 | |
del buf1160 | |
buf1163 = buf1162[0] | |
del buf1162 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_94], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1167 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1163, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1168 = buf1167[0] | |
buf1169 = buf1167[1] | |
del buf1167 | |
# Source Nodes: [input_284], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1170 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1163, (1, 1, 4096), (4096, 4096, 1), 0), buf1168, buf1169, -128, 127, torch.int8) | |
buf1171 = buf1170 | |
del buf1170 | |
# Source Nodes: [input_285], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1172 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1171, buf1168, buf1169, -128, 127, torch.int8, torch.bfloat16) | |
del buf1168 | |
del buf1169 | |
del buf1171 | |
buf1173 = buf1172 | |
del buf1172 | |
# Source Nodes: [w_dq_94], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1174 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg355_1, arg356_1, arg357_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg355_1 | |
del arg356_1 | |
del arg357_1 | |
buf1175 = buf1174 | |
del buf1174 | |
buf1176 = reinterpret_tensor(buf1163, (1, 4096), (4096, 1), 0); del buf1163 # reuse | |
# Source Nodes: [c_94], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1173, buf1175, buf1176, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1175 | |
buf1177 = buf1004; del buf1004 # reuse | |
buf1179 = buf1173; del buf1173 # reuse | |
# Source Nodes: [add_96, h_13, h_14, mean_27, mul_179, mul_180, out_11, out_12, pow_28, rsqrt_27, x_fp32_27, x_normed_27], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1177, buf1176, buf1090, buf1037, buf1123, arg358_1, buf1179, 1, 4096, grid=grid(1), stream=stream0) | |
del arg358_1 | |
del buf1037 | |
del buf1090 | |
del buf1123 | |
del buf1176 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_95], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1180 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1179, torch.int8) | |
buf1181 = buf1180[0] | |
buf1182 = buf1180[1] | |
del buf1180 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_96], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1183 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1179, torch.int8) | |
buf1184 = buf1183[0] | |
buf1185 = buf1183[1] | |
del buf1183 | |
# Source Nodes: [input_287], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1186 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1179, buf1181, buf1182, -128, 127, torch.int8) | |
buf1187 = buf1186 | |
del buf1186 | |
# Source Nodes: [input_288], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1188 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1187, buf1181, buf1182, -128, 127, torch.int8, torch.bfloat16) | |
del buf1181 | |
del buf1182 | |
del buf1187 | |
buf1189 = buf1188 | |
del buf1188 | |
# Source Nodes: [w_dq_95], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1190 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg359_1, arg360_1, arg361_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg359_1 | |
del arg360_1 | |
del arg361_1 | |
buf1191 = buf1190 | |
del buf1190 | |
buf1192 = reinterpret_tensor(buf1120, (1, 14336), (14336, 1), 0); del buf1120 # reuse | |
# Source Nodes: [c_95], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1189, buf1191, buf1192, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1191 | |
# Source Nodes: [input_290], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1193 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1179, buf1184, buf1185, -128, 127, torch.int8) | |
buf1194 = buf1193 | |
del buf1193 | |
# Source Nodes: [input_291], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1195 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1194, buf1184, buf1185, -128, 127, torch.int8, torch.bfloat16) | |
del buf1184 | |
del buf1185 | |
del buf1194 | |
buf1196 = buf1195 | |
del buf1195 | |
# Source Nodes: [w_dq_96], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1197 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg362_1, arg363_1, arg364_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg362_1 | |
del arg363_1 | |
del arg364_1 | |
buf1198 = buf1197 | |
del buf1197 | |
buf1200 = buf1113; del buf1113 # reuse | |
# Source Nodes: [c_96, mul_181, silu_13], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1196, buf1198, buf1192, buf1200, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1192 | |
del buf1198 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_97], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1201 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1200, torch.int8) | |
buf1202 = buf1201[0] | |
buf1203 = buf1201[1] | |
del buf1201 | |
# Source Nodes: [input_293], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1204 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1200, buf1202, buf1203, -128, 127, torch.int8) | |
buf1205 = buf1204 | |
del buf1204 | |
# Source Nodes: [input_294], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1206 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1205, buf1202, buf1203, -128, 127, torch.int8, torch.bfloat16) | |
del buf1202 | |
del buf1203 | |
del buf1205 | |
buf1207 = buf1206 | |
del buf1206 | |
# Source Nodes: [w_dq_97], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1208 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg365_1, arg366_1, arg367_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg365_1 | |
del arg366_1 | |
del arg367_1 | |
buf1209 = buf1208 | |
del buf1208 | |
buf1210 = reinterpret_tensor(buf1196, (1, 4096), (4096, 1), 0); del buf1196 # reuse | |
# Source Nodes: [c_97], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1207, buf1209, buf1210, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1209 | |
buf1212 = buf1179; del buf1179 # reuse | |
# Source Nodes: [add_98, mean_28, mul_182, out_13, pow_29, rsqrt_28, x_fp32_28, x_normed_28, y_14], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1177, buf1210, arg368_1, buf1212, 1, 4096, grid=grid(1), stream=stream0) | |
del arg368_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_98], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1213 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8) | |
buf1214 = buf1213[0] | |
buf1215 = buf1213[1] | |
del buf1213 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_99], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1216 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8) | |
buf1217 = buf1216[0] | |
buf1218 = buf1216[1] | |
del buf1216 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_100], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1219 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8) | |
buf1220 = buf1219[0] | |
buf1221 = buf1219[1] | |
del buf1219 | |
buf1222 = buf1135; del buf1135 # reuse | |
# Source Nodes: [max_15], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1222, 1, grid=grid(1), stream=stream0) | |
u14 = buf1222.item() | |
buf1223 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_296], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1224 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1214, buf1215, -128, 127, torch.int8) | |
buf1225 = buf1224 | |
del buf1224 | |
# Source Nodes: [input_297], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1226 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1225, buf1214, buf1215, -128, 127, torch.int8, torch.bfloat16) | |
del buf1214 | |
del buf1215 | |
del buf1225 | |
buf1227 = buf1226 | |
del buf1226 | |
# Source Nodes: [w_dq_98], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1228 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg369_1, arg370_1, arg371_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg369_1 | |
del arg370_1 | |
del arg371_1 | |
buf1229 = buf1228 | |
del buf1228 | |
buf1230 = reinterpret_tensor(buf1189, (1, 4096), (4096, 1), 0); del buf1189 # reuse | |
# Source Nodes: [c_98], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1227, buf1229, buf1230, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1227 | |
del buf1229 | |
# Source Nodes: [input_299], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1231 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1217, buf1218, -128, 127, torch.int8) | |
buf1232 = buf1231 | |
del buf1231 | |
# Source Nodes: [input_300], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1233 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1232, buf1217, buf1218, -128, 127, torch.int8, torch.bfloat16) | |
del buf1217 | |
del buf1218 | |
del buf1232 | |
buf1234 = buf1233 | |
del buf1233 | |
# Source Nodes: [w_dq_99], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1235 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg372_1, arg373_1, arg374_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg372_1 | |
del arg373_1 | |
del arg374_1 | |
buf1236 = buf1235 | |
del buf1235 | |
buf1237 = buf1158; del buf1158 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1234, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1236, (4096, 1024), (1, 4096), 0), out=buf1237) | |
del buf1234 | |
del buf1236 | |
# Source Nodes: [input_302], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1239 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1220, buf1221, -128, 127, torch.int8) | |
del buf1212 | |
buf1240 = buf1239 | |
del buf1239 | |
# Source Nodes: [input_303], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1241 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1240, buf1220, buf1221, -128, 127, torch.int8, torch.bfloat16) | |
del buf1220 | |
del buf1221 | |
del buf1240 | |
buf1242 = buf1241 | |
del buf1241 | |
# Source Nodes: [w_dq_100], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1243 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg375_1, arg376_1, arg377_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg375_1 | |
del arg376_1 | |
del arg377_1 | |
buf1244 = buf1243 | |
del buf1243 | |
buf1245 = buf1150; del buf1150 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1242, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1244, (4096, 1024), (1, 4096), 0), out=buf1245) | |
del buf1244 | |
buf1247 = reinterpret_tensor(buf1242, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1242 # reuse | |
# Source Nodes: [output_28, setitem_28, setitem_29], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1237, arg378_1, buf1245, buf1230, arg379_1, arg380_1, buf1247, 4096, grid=grid(4096), stream=stream0) | |
del arg378_1 | |
del buf1230 | |
buf1248 = buf1161; del buf1161 # reuse | |
# Source Nodes: [output_28], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1248, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_28], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1249 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1247, arg379_1, arg380_1, buf1248, False) | |
del arg379_1 | |
del arg380_1 | |
del buf1247 | |
buf1250 = buf1249[0] | |
del buf1249 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_101], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1254 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1250, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1255 = buf1254[0] | |
buf1256 = buf1254[1] | |
del buf1254 | |
# Source Nodes: [input_305], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1257 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1250, (1, 1, 4096), (4096, 4096, 1), 0), buf1255, buf1256, -128, 127, torch.int8) | |
buf1258 = buf1257 | |
del buf1257 | |
# Source Nodes: [input_306], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1259 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1258, buf1255, buf1256, -128, 127, torch.int8, torch.bfloat16) | |
del buf1255 | |
del buf1256 | |
del buf1258 | |
buf1260 = buf1259 | |
del buf1259 | |
# Source Nodes: [w_dq_101], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1261 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg381_1, arg382_1, arg383_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg381_1 | |
del arg382_1 | |
del arg383_1 | |
buf1262 = buf1261 | |
del buf1261 | |
buf1263 = reinterpret_tensor(buf1250, (1, 4096), (4096, 1), 0); del buf1250 # reuse | |
# Source Nodes: [c_101], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1260, buf1262, buf1263, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1262 | |
buf1265 = buf1260; del buf1260 # reuse | |
# Source Nodes: [add_103, h_15, mean_29, mul_192, mul_193, out_13, pow_30, rsqrt_29, x_fp32_29, x_normed_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1263, buf1177, buf1210, arg384_1, buf1265, 1, 4096, grid=grid(1), stream=stream0) | |
del arg384_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_102], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1266 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1265, torch.int8) | |
buf1267 = buf1266[0] | |
buf1268 = buf1266[1] | |
del buf1266 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_103], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1269 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1265, torch.int8) | |
buf1270 = buf1269[0] | |
buf1271 = buf1269[1] | |
del buf1269 | |
# Source Nodes: [input_308], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1272 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1265, buf1267, buf1268, -128, 127, torch.int8) | |
buf1273 = buf1272 | |
del buf1272 | |
# Source Nodes: [input_309], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1274 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1273, buf1267, buf1268, -128, 127, torch.int8, torch.bfloat16) | |
del buf1267 | |
del buf1268 | |
del buf1273 | |
buf1275 = buf1274 | |
del buf1274 | |
# Source Nodes: [w_dq_102], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1276 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg385_1, arg386_1, arg387_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg385_1 | |
del arg386_1 | |
del arg387_1 | |
buf1277 = buf1276 | |
del buf1276 | |
buf1278 = reinterpret_tensor(buf1207, (1, 14336), (14336, 1), 0); del buf1207 # reuse | |
# Source Nodes: [c_102], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1275, buf1277, buf1278, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1277 | |
# Source Nodes: [input_311], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1279 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1265, buf1270, buf1271, -128, 127, torch.int8) | |
buf1280 = buf1279 | |
del buf1279 | |
# Source Nodes: [input_312], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1281 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1280, buf1270, buf1271, -128, 127, torch.int8, torch.bfloat16) | |
del buf1270 | |
del buf1271 | |
del buf1280 | |
buf1282 = buf1281 | |
del buf1281 | |
# Source Nodes: [w_dq_103], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1283 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg388_1, arg389_1, arg390_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg388_1 | |
del arg389_1 | |
del arg390_1 | |
buf1284 = buf1283 | |
del buf1283 | |
buf1286 = buf1200; del buf1200 # reuse | |
# Source Nodes: [c_103, mul_194, silu_14], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1282, buf1284, buf1278, buf1286, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1278 | |
del buf1284 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_104], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1287 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1286, torch.int8) | |
buf1288 = buf1287[0] | |
buf1289 = buf1287[1] | |
del buf1287 | |
# Source Nodes: [input_314], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1290 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1286, buf1288, buf1289, -128, 127, torch.int8) | |
buf1291 = buf1290 | |
del buf1290 | |
# Source Nodes: [input_315], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1292 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1291, buf1288, buf1289, -128, 127, torch.int8, torch.bfloat16) | |
del buf1288 | |
del buf1289 | |
del buf1291 | |
buf1293 = buf1292 | |
del buf1292 | |
# Source Nodes: [w_dq_104], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1294 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg391_1, arg392_1, arg393_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg391_1 | |
del arg392_1 | |
del arg393_1 | |
buf1295 = buf1294 | |
del buf1294 | |
buf1296 = reinterpret_tensor(buf1282, (1, 4096), (4096, 1), 0); del buf1282 # reuse | |
# Source Nodes: [c_104], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1293, buf1295, buf1296, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1295 | |
buf1298 = buf1265; del buf1265 # reuse | |
# Source Nodes: [add_105, h_15, mean_30, mul_195, out_13, out_14, pow_31, rsqrt_30, x_fp32_30, x_normed_30, y_15], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1263, buf1177, buf1210, buf1296, arg394_1, buf1298, 1, 4096, grid=grid(1), stream=stream0) | |
del arg394_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_105], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1299 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8) | |
buf1300 = buf1299[0] | |
buf1301 = buf1299[1] | |
del buf1299 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_106], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1302 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8) | |
buf1303 = buf1302[0] | |
buf1304 = buf1302[1] | |
del buf1302 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_107], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1305 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8) | |
buf1306 = buf1305[0] | |
buf1307 = buf1305[1] | |
del buf1305 | |
buf1308 = buf1222; del buf1222 # reuse | |
# Source Nodes: [max_16], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1308, 1, grid=grid(1), stream=stream0) | |
u15 = buf1308.item() | |
buf1309 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_317], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1310 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1300, buf1301, -128, 127, torch.int8) | |
buf1311 = buf1310 | |
del buf1310 | |
# Source Nodes: [input_318], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1312 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1311, buf1300, buf1301, -128, 127, torch.int8, torch.bfloat16) | |
del buf1300 | |
del buf1301 | |
del buf1311 | |
buf1313 = buf1312 | |
del buf1312 | |
# Source Nodes: [w_dq_105], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1314 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg395_1, arg396_1, arg397_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg395_1 | |
del arg396_1 | |
del arg397_1 | |
buf1315 = buf1314 | |
del buf1314 | |
buf1316 = reinterpret_tensor(buf1275, (1, 4096), (4096, 1), 0); del buf1275 # reuse | |
# Source Nodes: [c_105], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1313, buf1315, buf1316, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1313 | |
del buf1315 | |
# Source Nodes: [input_320], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1317 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1303, buf1304, -128, 127, torch.int8) | |
buf1318 = buf1317 | |
del buf1317 | |
# Source Nodes: [input_321], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1319 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1318, buf1303, buf1304, -128, 127, torch.int8, torch.bfloat16) | |
del buf1303 | |
del buf1304 | |
del buf1318 | |
buf1320 = buf1319 | |
del buf1319 | |
# Source Nodes: [w_dq_106], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1321 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg398_1, arg399_1, arg400_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg398_1 | |
del arg399_1 | |
del arg400_1 | |
buf1322 = buf1321 | |
del buf1321 | |
buf1323 = buf1245; del buf1245 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1320, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1322, (4096, 1024), (1, 4096), 0), out=buf1323) | |
del buf1320 | |
del buf1322 | |
# Source Nodes: [input_323], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1325 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1306, buf1307, -128, 127, torch.int8) | |
del buf1298 | |
buf1326 = buf1325 | |
del buf1325 | |
# Source Nodes: [input_324], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1327 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1326, buf1306, buf1307, -128, 127, torch.int8, torch.bfloat16) | |
del buf1306 | |
del buf1307 | |
del buf1326 | |
buf1328 = buf1327 | |
del buf1327 | |
# Source Nodes: [w_dq_107], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1329 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg401_1, arg402_1, arg403_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg401_1 | |
del arg402_1 | |
del arg403_1 | |
buf1330 = buf1329 | |
del buf1329 | |
buf1331 = buf1237; del buf1237 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1328, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1330, (4096, 1024), (1, 4096), 0), out=buf1331) | |
del buf1330 | |
buf1333 = reinterpret_tensor(buf1328, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1328 # reuse | |
# Source Nodes: [output_30, setitem_30, setitem_31], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1323, arg404_1, buf1331, buf1316, arg405_1, arg406_1, buf1333, 4096, grid=grid(4096), stream=stream0) | |
del arg404_1 | |
del buf1316 | |
buf1334 = buf1248; del buf1248 # reuse | |
# Source Nodes: [output_30], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1334, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_30], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1335 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1333, arg405_1, arg406_1, buf1334, False) | |
del arg405_1 | |
del arg406_1 | |
del buf1333 | |
buf1336 = buf1335[0] | |
del buf1335 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_108], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1340 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1336, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1341 = buf1340[0] | |
buf1342 = buf1340[1] | |
del buf1340 | |
# Source Nodes: [input_326], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1343 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1336, (1, 1, 4096), (4096, 4096, 1), 0), buf1341, buf1342, -128, 127, torch.int8) | |
buf1344 = buf1343 | |
del buf1343 | |
# Source Nodes: [input_327], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1345 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1344, buf1341, buf1342, -128, 127, torch.int8, torch.bfloat16) | |
del buf1341 | |
del buf1342 | |
del buf1344 | |
buf1346 = buf1345 | |
del buf1345 | |
# Source Nodes: [w_dq_108], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1347 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg407_1, arg408_1, arg409_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg407_1 | |
del arg408_1 | |
del arg409_1 | |
buf1348 = buf1347 | |
del buf1347 | |
buf1349 = reinterpret_tensor(buf1336, (1, 4096), (4096, 1), 0); del buf1336 # reuse | |
# Source Nodes: [c_108], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1346, buf1348, buf1349, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1348 | |
buf1350 = buf1177; del buf1177 # reuse | |
buf1352 = buf1346; del buf1346 # reuse | |
# Source Nodes: [add_110, h_15, h_16, mean_31, mul_205, mul_206, out_13, out_14, pow_32, rsqrt_31, x_fp32_31, x_normed_31], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1350, buf1349, buf1263, buf1210, buf1296, arg410_1, buf1352, 1, 4096, grid=grid(1), stream=stream0) | |
del arg410_1 | |
del buf1210 | |
del buf1263 | |
del buf1296 | |
del buf1349 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_109], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1353 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1352, torch.int8) | |
buf1354 = buf1353[0] | |
buf1355 = buf1353[1] | |
del buf1353 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_110], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1356 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1352, torch.int8) | |
buf1357 = buf1356[0] | |
buf1358 = buf1356[1] | |
del buf1356 | |
# Source Nodes: [input_329], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1359 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1352, buf1354, buf1355, -128, 127, torch.int8) | |
buf1360 = buf1359 | |
del buf1359 | |
# Source Nodes: [input_330], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1361 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1360, buf1354, buf1355, -128, 127, torch.int8, torch.bfloat16) | |
del buf1354 | |
del buf1355 | |
del buf1360 | |
buf1362 = buf1361 | |
del buf1361 | |
# Source Nodes: [w_dq_109], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1363 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg411_1, arg412_1, arg413_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg411_1 | |
del arg412_1 | |
del arg413_1 | |
buf1364 = buf1363 | |
del buf1363 | |
buf1365 = reinterpret_tensor(buf1293, (1, 14336), (14336, 1), 0); del buf1293 # reuse | |
# Source Nodes: [c_109], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1362, buf1364, buf1365, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1364 | |
# Source Nodes: [input_332], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1366 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1352, buf1357, buf1358, -128, 127, torch.int8) | |
buf1367 = buf1366 | |
del buf1366 | |
# Source Nodes: [input_333], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1368 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1367, buf1357, buf1358, -128, 127, torch.int8, torch.bfloat16) | |
del buf1357 | |
del buf1358 | |
del buf1367 | |
buf1369 = buf1368 | |
del buf1368 | |
# Source Nodes: [w_dq_110], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1370 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg414_1, arg415_1, arg416_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg414_1 | |
del arg415_1 | |
del arg416_1 | |
buf1371 = buf1370 | |
del buf1370 | |
buf1373 = buf1286; del buf1286 # reuse | |
# Source Nodes: [c_110, mul_207, silu_15], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1369, buf1371, buf1365, buf1373, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1365 | |
del buf1371 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_111], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1374 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1373, torch.int8) | |
buf1375 = buf1374[0] | |
buf1376 = buf1374[1] | |
del buf1374 | |
# Source Nodes: [input_335], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1377 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1373, buf1375, buf1376, -128, 127, torch.int8) | |
buf1378 = buf1377 | |
del buf1377 | |
# Source Nodes: [input_336], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1379 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1378, buf1375, buf1376, -128, 127, torch.int8, torch.bfloat16) | |
del buf1375 | |
del buf1376 | |
del buf1378 | |
buf1380 = buf1379 | |
del buf1379 | |
# Source Nodes: [w_dq_111], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1381 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg417_1, arg418_1, arg419_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg417_1 | |
del arg418_1 | |
del arg419_1 | |
buf1382 = buf1381 | |
del buf1381 | |
buf1383 = reinterpret_tensor(buf1369, (1, 4096), (4096, 1), 0); del buf1369 # reuse | |
# Source Nodes: [c_111], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1380, buf1382, buf1383, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1382 | |
buf1385 = buf1352; del buf1352 # reuse | |
# Source Nodes: [add_112, mean_32, mul_208, out_15, pow_33, rsqrt_32, x_fp32_32, x_normed_32, y_16], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1350, buf1383, arg420_1, buf1385, 1, 4096, grid=grid(1), stream=stream0) | |
del arg420_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_112], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1386 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8) | |
buf1387 = buf1386[0] | |
buf1388 = buf1386[1] | |
del buf1386 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_113], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1389 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8) | |
buf1390 = buf1389[0] | |
buf1391 = buf1389[1] | |
del buf1389 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_114], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1392 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8) | |
buf1393 = buf1392[0] | |
buf1394 = buf1392[1] | |
del buf1392 | |
buf1395 = buf1308; del buf1308 # reuse | |
# Source Nodes: [max_17], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1395, 1, grid=grid(1), stream=stream0) | |
u16 = buf1395.item() | |
buf1396 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_338], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1397 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1387, buf1388, -128, 127, torch.int8) | |
buf1398 = buf1397 | |
del buf1397 | |
# Source Nodes: [input_339], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1399 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1398, buf1387, buf1388, -128, 127, torch.int8, torch.bfloat16) | |
del buf1387 | |
del buf1388 | |
del buf1398 | |
buf1400 = buf1399 | |
del buf1399 | |
# Source Nodes: [w_dq_112], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1401 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg421_1, arg422_1, arg423_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg421_1 | |
del arg422_1 | |
del arg423_1 | |
buf1402 = buf1401 | |
del buf1401 | |
buf1403 = reinterpret_tensor(buf1362, (1, 4096), (4096, 1), 0); del buf1362 # reuse | |
# Source Nodes: [c_112], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1400, buf1402, buf1403, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1400 | |
del buf1402 | |
# Source Nodes: [input_341], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1404 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1390, buf1391, -128, 127, torch.int8) | |
buf1405 = buf1404 | |
del buf1404 | |
# Source Nodes: [input_342], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1406 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1405, buf1390, buf1391, -128, 127, torch.int8, torch.bfloat16) | |
del buf1390 | |
del buf1391 | |
del buf1405 | |
buf1407 = buf1406 | |
del buf1406 | |
# Source Nodes: [w_dq_113], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1408 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg424_1, arg425_1, arg426_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg424_1 | |
del arg425_1 | |
del arg426_1 | |
buf1409 = buf1408 | |
del buf1408 | |
buf1410 = buf1331; del buf1331 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1407, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1409, (4096, 1024), (1, 4096), 0), out=buf1410) | |
del buf1407 | |
del buf1409 | |
# Source Nodes: [input_344], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1412 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1393, buf1394, -128, 127, torch.int8) | |
del buf1385 | |
buf1413 = buf1412 | |
del buf1412 | |
# Source Nodes: [input_345], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1414 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1413, buf1393, buf1394, -128, 127, torch.int8, torch.bfloat16) | |
del buf1393 | |
del buf1394 | |
del buf1413 | |
buf1415 = buf1414 | |
del buf1414 | |
# Source Nodes: [w_dq_114], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1416 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg427_1, arg428_1, arg429_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg427_1 | |
del arg428_1 | |
del arg429_1 | |
buf1417 = buf1416 | |
del buf1416 | |
buf1418 = buf1323; del buf1323 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1415, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1417, (4096, 1024), (1, 4096), 0), out=buf1418) | |
del buf1417 | |
buf1420 = reinterpret_tensor(buf1415, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1415 # reuse | |
# Source Nodes: [output_32, setitem_32, setitem_33], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1410, arg430_1, buf1418, buf1403, arg431_1, arg432_1, buf1420, 4096, grid=grid(4096), stream=stream0) | |
del arg430_1 | |
del buf1403 | |
buf1421 = buf1334; del buf1334 # reuse | |
# Source Nodes: [output_32], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1421, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_32], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1422 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1420, arg431_1, arg432_1, buf1421, False) | |
del arg431_1 | |
del arg432_1 | |
del buf1420 | |
buf1423 = buf1422[0] | |
del buf1422 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_115], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1427 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1423, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1428 = buf1427[0] | |
buf1429 = buf1427[1] | |
del buf1427 | |
# Source Nodes: [input_347], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1430 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1423, (1, 1, 4096), (4096, 4096, 1), 0), buf1428, buf1429, -128, 127, torch.int8) | |
buf1431 = buf1430 | |
del buf1430 | |
# Source Nodes: [input_348], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1432 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1431, buf1428, buf1429, -128, 127, torch.int8, torch.bfloat16) | |
del buf1428 | |
del buf1429 | |
del buf1431 | |
buf1433 = buf1432 | |
del buf1432 | |
# Source Nodes: [w_dq_115], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1434 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg433_1, arg434_1, arg435_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg433_1 | |
del arg434_1 | |
del arg435_1 | |
buf1435 = buf1434 | |
del buf1434 | |
buf1436 = reinterpret_tensor(buf1423, (1, 4096), (4096, 1), 0); del buf1423 # reuse | |
# Source Nodes: [c_115], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1433, buf1435, buf1436, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1435 | |
buf1438 = buf1433; del buf1433 # reuse | |
# Source Nodes: [add_117, h_17, mean_33, mul_218, mul_219, out_15, pow_34, rsqrt_33, x_fp32_33, x_normed_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1436, buf1350, buf1383, arg436_1, buf1438, 1, 4096, grid=grid(1), stream=stream0) | |
del arg436_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_116], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1439 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1438, torch.int8) | |
buf1440 = buf1439[0] | |
buf1441 = buf1439[1] | |
del buf1439 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_117], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1442 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1438, torch.int8) | |
buf1443 = buf1442[0] | |
buf1444 = buf1442[1] | |
del buf1442 | |
# Source Nodes: [input_350], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1445 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1438, buf1440, buf1441, -128, 127, torch.int8) | |
buf1446 = buf1445 | |
del buf1445 | |
# Source Nodes: [input_351], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1447 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1446, buf1440, buf1441, -128, 127, torch.int8, torch.bfloat16) | |
del buf1440 | |
del buf1441 | |
del buf1446 | |
buf1448 = buf1447 | |
del buf1447 | |
# Source Nodes: [w_dq_116], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1449 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg437_1, arg438_1, arg439_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg437_1 | |
del arg438_1 | |
del arg439_1 | |
buf1450 = buf1449 | |
del buf1449 | |
buf1451 = reinterpret_tensor(buf1380, (1, 14336), (14336, 1), 0); del buf1380 # reuse | |
# Source Nodes: [c_116], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1448, buf1450, buf1451, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1450 | |
# Source Nodes: [input_353], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1452 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1438, buf1443, buf1444, -128, 127, torch.int8) | |
buf1453 = buf1452 | |
del buf1452 | |
# Source Nodes: [input_354], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1454 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1453, buf1443, buf1444, -128, 127, torch.int8, torch.bfloat16) | |
del buf1443 | |
del buf1444 | |
del buf1453 | |
buf1455 = buf1454 | |
del buf1454 | |
# Source Nodes: [w_dq_117], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1456 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg440_1, arg441_1, arg442_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg440_1 | |
del arg441_1 | |
del arg442_1 | |
buf1457 = buf1456 | |
del buf1456 | |
buf1459 = buf1373; del buf1373 # reuse | |
# Source Nodes: [c_117, mul_220, silu_16], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1455, buf1457, buf1451, buf1459, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1451 | |
del buf1457 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_118], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1460 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1459, torch.int8) | |
buf1461 = buf1460[0] | |
buf1462 = buf1460[1] | |
del buf1460 | |
# Source Nodes: [input_356], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1463 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1459, buf1461, buf1462, -128, 127, torch.int8) | |
buf1464 = buf1463 | |
del buf1463 | |
# Source Nodes: [input_357], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1465 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1464, buf1461, buf1462, -128, 127, torch.int8, torch.bfloat16) | |
del buf1461 | |
del buf1462 | |
del buf1464 | |
buf1466 = buf1465 | |
del buf1465 | |
# Source Nodes: [w_dq_118], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1467 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg443_1, arg444_1, arg445_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg443_1 | |
del arg444_1 | |
del arg445_1 | |
buf1468 = buf1467 | |
del buf1467 | |
buf1469 = reinterpret_tensor(buf1455, (1, 4096), (4096, 1), 0); del buf1455 # reuse | |
# Source Nodes: [c_118], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1466, buf1468, buf1469, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1468 | |
buf1471 = buf1438; del buf1438 # reuse | |
# Source Nodes: [add_119, h_17, mean_34, mul_221, out_15, out_16, pow_35, rsqrt_34, x_fp32_34, x_normed_34, y_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1436, buf1350, buf1383, buf1469, arg446_1, buf1471, 1, 4096, grid=grid(1), stream=stream0) | |
del arg446_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_119], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1472 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8) | |
buf1473 = buf1472[0] | |
buf1474 = buf1472[1] | |
del buf1472 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_120], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1475 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8) | |
buf1476 = buf1475[0] | |
buf1477 = buf1475[1] | |
del buf1475 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_121], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1478 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8) | |
buf1479 = buf1478[0] | |
buf1480 = buf1478[1] | |
del buf1478 | |
buf1481 = buf1395; del buf1395 # reuse | |
# Source Nodes: [max_18], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1481, 1, grid=grid(1), stream=stream0) | |
u17 = buf1481.item() | |
buf1482 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_359], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1483 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1473, buf1474, -128, 127, torch.int8) | |
buf1484 = buf1483 | |
del buf1483 | |
# Source Nodes: [input_360], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1485 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1484, buf1473, buf1474, -128, 127, torch.int8, torch.bfloat16) | |
del buf1473 | |
del buf1474 | |
del buf1484 | |
buf1486 = buf1485 | |
del buf1485 | |
# Source Nodes: [w_dq_119], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1487 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg447_1, arg448_1, arg449_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg447_1 | |
del arg448_1 | |
del arg449_1 | |
buf1488 = buf1487 | |
del buf1487 | |
buf1489 = reinterpret_tensor(buf1448, (1, 4096), (4096, 1), 0); del buf1448 # reuse | |
# Source Nodes: [c_119], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1486, buf1488, buf1489, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1486 | |
del buf1488 | |
# Source Nodes: [input_362], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1490 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1476, buf1477, -128, 127, torch.int8) | |
buf1491 = buf1490 | |
del buf1490 | |
# Source Nodes: [input_363], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1492 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1491, buf1476, buf1477, -128, 127, torch.int8, torch.bfloat16) | |
del buf1476 | |
del buf1477 | |
del buf1491 | |
buf1493 = buf1492 | |
del buf1492 | |
# Source Nodes: [w_dq_120], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1494 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg450_1, arg451_1, arg452_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg450_1 | |
del arg451_1 | |
del arg452_1 | |
buf1495 = buf1494 | |
del buf1494 | |
buf1496 = buf1418; del buf1418 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1493, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1495, (4096, 1024), (1, 4096), 0), out=buf1496) | |
del buf1493 | |
del buf1495 | |
# Source Nodes: [input_365], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1498 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1479, buf1480, -128, 127, torch.int8) | |
del buf1471 | |
buf1499 = buf1498 | |
del buf1498 | |
# Source Nodes: [input_366], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1500 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1499, buf1479, buf1480, -128, 127, torch.int8, torch.bfloat16) | |
del buf1479 | |
del buf1480 | |
del buf1499 | |
buf1501 = buf1500 | |
del buf1500 | |
# Source Nodes: [w_dq_121], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1502 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg453_1, arg454_1, arg455_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg453_1 | |
del arg454_1 | |
del arg455_1 | |
buf1503 = buf1502 | |
del buf1502 | |
buf1504 = buf1410; del buf1410 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1501, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1503, (4096, 1024), (1, 4096), 0), out=buf1504) | |
del buf1503 | |
buf1506 = reinterpret_tensor(buf1501, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1501 # reuse | |
# Source Nodes: [output_34, setitem_34, setitem_35], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1496, arg456_1, buf1504, buf1489, arg457_1, arg458_1, buf1506, 4096, grid=grid(4096), stream=stream0) | |
del arg456_1 | |
del buf1489 | |
buf1507 = buf1421; del buf1421 # reuse | |
# Source Nodes: [output_34], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1507, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_34], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1508 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1506, arg457_1, arg458_1, buf1507, False) | |
del arg457_1 | |
del arg458_1 | |
del buf1506 | |
buf1509 = buf1508[0] | |
del buf1508 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_122], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1513 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1509, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1514 = buf1513[0] | |
buf1515 = buf1513[1] | |
del buf1513 | |
# Source Nodes: [input_368], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1516 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1509, (1, 1, 4096), (4096, 4096, 1), 0), buf1514, buf1515, -128, 127, torch.int8) | |
buf1517 = buf1516 | |
del buf1516 | |
# Source Nodes: [input_369], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1518 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1517, buf1514, buf1515, -128, 127, torch.int8, torch.bfloat16) | |
del buf1514 | |
del buf1515 | |
del buf1517 | |
buf1519 = buf1518 | |
del buf1518 | |
# Source Nodes: [w_dq_122], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1520 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg459_1, arg460_1, arg461_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg459_1 | |
del arg460_1 | |
del arg461_1 | |
buf1521 = buf1520 | |
del buf1520 | |
buf1522 = reinterpret_tensor(buf1509, (1, 4096), (4096, 1), 0); del buf1509 # reuse | |
# Source Nodes: [c_122], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1519, buf1521, buf1522, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1521 | |
buf1523 = buf1350; del buf1350 # reuse | |
buf1525 = buf1519; del buf1519 # reuse | |
# Source Nodes: [add_124, h_17, h_18, mean_35, mul_231, mul_232, out_15, out_16, pow_36, rsqrt_35, x_fp32_35, x_normed_35], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1523, buf1522, buf1436, buf1383, buf1469, arg462_1, buf1525, 1, 4096, grid=grid(1), stream=stream0) | |
del arg462_1 | |
del buf1383 | |
del buf1436 | |
del buf1469 | |
del buf1522 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_123], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1526 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1525, torch.int8) | |
buf1527 = buf1526[0] | |
buf1528 = buf1526[1] | |
del buf1526 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_124], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1529 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1525, torch.int8) | |
buf1530 = buf1529[0] | |
buf1531 = buf1529[1] | |
del buf1529 | |
# Source Nodes: [input_371], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1532 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1525, buf1527, buf1528, -128, 127, torch.int8) | |
buf1533 = buf1532 | |
del buf1532 | |
# Source Nodes: [input_372], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1534 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1533, buf1527, buf1528, -128, 127, torch.int8, torch.bfloat16) | |
del buf1527 | |
del buf1528 | |
del buf1533 | |
buf1535 = buf1534 | |
del buf1534 | |
# Source Nodes: [w_dq_123], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1536 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg463_1, arg464_1, arg465_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg463_1 | |
del arg464_1 | |
del arg465_1 | |
buf1537 = buf1536 | |
del buf1536 | |
buf1538 = reinterpret_tensor(buf1466, (1, 14336), (14336, 1), 0); del buf1466 # reuse | |
# Source Nodes: [c_123], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1535, buf1537, buf1538, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1537 | |
# Source Nodes: [input_374], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1539 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1525, buf1530, buf1531, -128, 127, torch.int8) | |
buf1540 = buf1539 | |
del buf1539 | |
# Source Nodes: [input_375], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1541 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1540, buf1530, buf1531, -128, 127, torch.int8, torch.bfloat16) | |
del buf1530 | |
del buf1531 | |
del buf1540 | |
buf1542 = buf1541 | |
del buf1541 | |
# Source Nodes: [w_dq_124], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1543 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg466_1, arg467_1, arg468_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg466_1 | |
del arg467_1 | |
del arg468_1 | |
buf1544 = buf1543 | |
del buf1543 | |
buf1546 = buf1459; del buf1459 # reuse | |
# Source Nodes: [c_124, mul_233, silu_17], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1542, buf1544, buf1538, buf1546, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1538 | |
del buf1544 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_125], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1547 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1546, torch.int8) | |
buf1548 = buf1547[0] | |
buf1549 = buf1547[1] | |
del buf1547 | |
# Source Nodes: [input_377], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1550 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1546, buf1548, buf1549, -128, 127, torch.int8) | |
buf1551 = buf1550 | |
del buf1550 | |
# Source Nodes: [input_378], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1552 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1551, buf1548, buf1549, -128, 127, torch.int8, torch.bfloat16) | |
del buf1548 | |
del buf1549 | |
del buf1551 | |
buf1553 = buf1552 | |
del buf1552 | |
# Source Nodes: [w_dq_125], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1554 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg469_1, arg470_1, arg471_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg469_1 | |
del arg470_1 | |
del arg471_1 | |
buf1555 = buf1554 | |
del buf1554 | |
buf1556 = reinterpret_tensor(buf1542, (1, 4096), (4096, 1), 0); del buf1542 # reuse | |
# Source Nodes: [c_125], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1553, buf1555, buf1556, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1555 | |
buf1558 = buf1525; del buf1525 # reuse | |
# Source Nodes: [add_126, mean_36, mul_234, out_17, pow_37, rsqrt_36, x_fp32_36, x_normed_36, y_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1523, buf1556, arg472_1, buf1558, 1, 4096, grid=grid(1), stream=stream0) | |
del arg472_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_126], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1559 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8) | |
buf1560 = buf1559[0] | |
buf1561 = buf1559[1] | |
del buf1559 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_127], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1562 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8) | |
buf1563 = buf1562[0] | |
buf1564 = buf1562[1] | |
del buf1562 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_128], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1565 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8) | |
buf1566 = buf1565[0] | |
buf1567 = buf1565[1] | |
del buf1565 | |
buf1568 = buf1481; del buf1481 # reuse | |
# Source Nodes: [max_19], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1568, 1, grid=grid(1), stream=stream0) | |
u18 = buf1568.item() | |
buf1569 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_380], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1570 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1560, buf1561, -128, 127, torch.int8) | |
buf1571 = buf1570 | |
del buf1570 | |
# Source Nodes: [input_381], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1572 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1571, buf1560, buf1561, -128, 127, torch.int8, torch.bfloat16) | |
del buf1560 | |
del buf1561 | |
del buf1571 | |
buf1573 = buf1572 | |
del buf1572 | |
# Source Nodes: [w_dq_126], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1574 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg473_1, arg474_1, arg475_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg473_1 | |
del arg474_1 | |
del arg475_1 | |
buf1575 = buf1574 | |
del buf1574 | |
buf1576 = reinterpret_tensor(buf1535, (1, 4096), (4096, 1), 0); del buf1535 # reuse | |
# Source Nodes: [c_126], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1573, buf1575, buf1576, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1573 | |
del buf1575 | |
# Source Nodes: [input_383], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1577 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1563, buf1564, -128, 127, torch.int8) | |
buf1578 = buf1577 | |
del buf1577 | |
# Source Nodes: [input_384], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1579 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1578, buf1563, buf1564, -128, 127, torch.int8, torch.bfloat16) | |
del buf1563 | |
del buf1564 | |
del buf1578 | |
buf1580 = buf1579 | |
del buf1579 | |
# Source Nodes: [w_dq_127], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1581 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg476_1, arg477_1, arg478_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg476_1 | |
del arg477_1 | |
del arg478_1 | |
buf1582 = buf1581 | |
del buf1581 | |
buf1583 = buf1504; del buf1504 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1580, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1582, (4096, 1024), (1, 4096), 0), out=buf1583) | |
del buf1580 | |
del buf1582 | |
# Source Nodes: [input_386], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1585 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1566, buf1567, -128, 127, torch.int8) | |
del buf1558 | |
buf1586 = buf1585 | |
del buf1585 | |
# Source Nodes: [input_387], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1587 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1586, buf1566, buf1567, -128, 127, torch.int8, torch.bfloat16) | |
del buf1566 | |
del buf1567 | |
del buf1586 | |
buf1588 = buf1587 | |
del buf1587 | |
# Source Nodes: [w_dq_128], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1589 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg479_1, arg480_1, arg481_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg479_1 | |
del arg480_1 | |
del arg481_1 | |
buf1590 = buf1589 | |
del buf1589 | |
buf1591 = buf1496; del buf1496 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1588, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1590, (4096, 1024), (1, 4096), 0), out=buf1591) | |
del buf1590 | |
buf1593 = reinterpret_tensor(buf1588, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1588 # reuse | |
# Source Nodes: [output_36, setitem_36, setitem_37], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1583, arg482_1, buf1591, buf1576, arg483_1, arg484_1, buf1593, 4096, grid=grid(4096), stream=stream0) | |
del arg482_1 | |
del buf1576 | |
buf1594 = buf1507; del buf1507 # reuse | |
# Source Nodes: [output_36], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1594, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_36], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1595 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1593, arg483_1, arg484_1, buf1594, False) | |
del arg483_1 | |
del arg484_1 | |
del buf1593 | |
buf1596 = buf1595[0] | |
del buf1595 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_129], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1600 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1596, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1601 = buf1600[0] | |
buf1602 = buf1600[1] | |
del buf1600 | |
# Source Nodes: [input_389], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1603 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1596, (1, 1, 4096), (4096, 4096, 1), 0), buf1601, buf1602, -128, 127, torch.int8) | |
buf1604 = buf1603 | |
del buf1603 | |
# Source Nodes: [input_390], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1605 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1604, buf1601, buf1602, -128, 127, torch.int8, torch.bfloat16) | |
del buf1601 | |
del buf1602 | |
del buf1604 | |
buf1606 = buf1605 | |
del buf1605 | |
# Source Nodes: [w_dq_129], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1607 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg485_1, arg486_1, arg487_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg485_1 | |
del arg486_1 | |
del arg487_1 | |
buf1608 = buf1607 | |
del buf1607 | |
buf1609 = reinterpret_tensor(buf1596, (1, 4096), (4096, 1), 0); del buf1596 # reuse | |
# Source Nodes: [c_129], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1606, buf1608, buf1609, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1608 | |
buf1611 = buf1606; del buf1606 # reuse | |
# Source Nodes: [add_131, h_19, mean_37, mul_244, mul_245, out_17, pow_38, rsqrt_37, x_fp32_37, x_normed_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1609, buf1523, buf1556, arg488_1, buf1611, 1, 4096, grid=grid(1), stream=stream0) | |
del arg488_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_130], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1612 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1611, torch.int8) | |
buf1613 = buf1612[0] | |
buf1614 = buf1612[1] | |
del buf1612 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_131], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1615 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1611, torch.int8) | |
buf1616 = buf1615[0] | |
buf1617 = buf1615[1] | |
del buf1615 | |
# Source Nodes: [input_392], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1618 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1611, buf1613, buf1614, -128, 127, torch.int8) | |
buf1619 = buf1618 | |
del buf1618 | |
# Source Nodes: [input_393], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1620 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1619, buf1613, buf1614, -128, 127, torch.int8, torch.bfloat16) | |
del buf1613 | |
del buf1614 | |
del buf1619 | |
buf1621 = buf1620 | |
del buf1620 | |
# Source Nodes: [w_dq_130], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1622 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg489_1, arg490_1, arg491_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg489_1 | |
del arg490_1 | |
del arg491_1 | |
buf1623 = buf1622 | |
del buf1622 | |
buf1624 = reinterpret_tensor(buf1553, (1, 14336), (14336, 1), 0); del buf1553 # reuse | |
# Source Nodes: [c_130], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1621, buf1623, buf1624, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1623 | |
# Source Nodes: [input_395], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1625 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1611, buf1616, buf1617, -128, 127, torch.int8) | |
buf1626 = buf1625 | |
del buf1625 | |
# Source Nodes: [input_396], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1627 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1626, buf1616, buf1617, -128, 127, torch.int8, torch.bfloat16) | |
del buf1616 | |
del buf1617 | |
del buf1626 | |
buf1628 = buf1627 | |
del buf1627 | |
# Source Nodes: [w_dq_131], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1629 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg492_1, arg493_1, arg494_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg492_1 | |
del arg493_1 | |
del arg494_1 | |
buf1630 = buf1629 | |
del buf1629 | |
buf1632 = buf1546; del buf1546 # reuse | |
# Source Nodes: [c_131, mul_246, silu_18], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1628, buf1630, buf1624, buf1632, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1624 | |
del buf1630 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_132], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1633 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1632, torch.int8) | |
buf1634 = buf1633[0] | |
buf1635 = buf1633[1] | |
del buf1633 | |
# Source Nodes: [input_398], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1636 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1632, buf1634, buf1635, -128, 127, torch.int8) | |
buf1637 = buf1636 | |
del buf1636 | |
# Source Nodes: [input_399], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1638 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1637, buf1634, buf1635, -128, 127, torch.int8, torch.bfloat16) | |
del buf1634 | |
del buf1635 | |
del buf1637 | |
buf1639 = buf1638 | |
del buf1638 | |
# Source Nodes: [w_dq_132], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1640 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg495_1, arg496_1, arg497_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg495_1 | |
del arg496_1 | |
del arg497_1 | |
buf1641 = buf1640 | |
del buf1640 | |
buf1642 = reinterpret_tensor(buf1628, (1, 4096), (4096, 1), 0); del buf1628 # reuse | |
# Source Nodes: [c_132], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1639, buf1641, buf1642, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1641 | |
buf1644 = buf1611; del buf1611 # reuse | |
# Source Nodes: [add_133, h_19, mean_38, mul_247, out_17, out_18, pow_39, rsqrt_38, x_fp32_38, x_normed_38, y_19], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1609, buf1523, buf1556, buf1642, arg498_1, buf1644, 1, 4096, grid=grid(1), stream=stream0) | |
del arg498_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_133], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1645 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8) | |
buf1646 = buf1645[0] | |
buf1647 = buf1645[1] | |
del buf1645 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_134], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1648 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8) | |
buf1649 = buf1648[0] | |
buf1650 = buf1648[1] | |
del buf1648 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_135], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1651 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8) | |
buf1652 = buf1651[0] | |
buf1653 = buf1651[1] | |
del buf1651 | |
buf1654 = buf1568; del buf1568 # reuse | |
# Source Nodes: [max_20], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1654, 1, grid=grid(1), stream=stream0) | |
u19 = buf1654.item() | |
buf1655 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_401], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1656 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1646, buf1647, -128, 127, torch.int8) | |
buf1657 = buf1656 | |
del buf1656 | |
# Source Nodes: [input_402], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1658 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1657, buf1646, buf1647, -128, 127, torch.int8, torch.bfloat16) | |
del buf1646 | |
del buf1647 | |
del buf1657 | |
buf1659 = buf1658 | |
del buf1658 | |
# Source Nodes: [w_dq_133], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1660 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg499_1, arg500_1, arg501_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg499_1 | |
del arg500_1 | |
del arg501_1 | |
buf1661 = buf1660 | |
del buf1660 | |
buf1662 = reinterpret_tensor(buf1621, (1, 4096), (4096, 1), 0); del buf1621 # reuse | |
# Source Nodes: [c_133], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1659, buf1661, buf1662, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1659 | |
del buf1661 | |
# Source Nodes: [input_404], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1663 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1649, buf1650, -128, 127, torch.int8) | |
buf1664 = buf1663 | |
del buf1663 | |
# Source Nodes: [input_405], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1665 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1664, buf1649, buf1650, -128, 127, torch.int8, torch.bfloat16) | |
del buf1649 | |
del buf1650 | |
del buf1664 | |
buf1666 = buf1665 | |
del buf1665 | |
# Source Nodes: [w_dq_134], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1667 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg502_1, arg503_1, arg504_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg502_1 | |
del arg503_1 | |
del arg504_1 | |
buf1668 = buf1667 | |
del buf1667 | |
buf1669 = buf1591; del buf1591 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1666, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1668, (4096, 1024), (1, 4096), 0), out=buf1669) | |
del buf1666 | |
del buf1668 | |
# Source Nodes: [input_407], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1671 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1652, buf1653, -128, 127, torch.int8) | |
del buf1644 | |
buf1672 = buf1671 | |
del buf1671 | |
# Source Nodes: [input_408], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1673 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1672, buf1652, buf1653, -128, 127, torch.int8, torch.bfloat16) | |
del buf1652 | |
del buf1653 | |
del buf1672 | |
buf1674 = buf1673 | |
del buf1673 | |
# Source Nodes: [w_dq_135], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1675 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg505_1, arg506_1, arg507_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg505_1 | |
del arg506_1 | |
del arg507_1 | |
buf1676 = buf1675 | |
del buf1675 | |
buf1677 = buf1583; del buf1583 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1674, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1676, (4096, 1024), (1, 4096), 0), out=buf1677) | |
del buf1676 | |
buf1679 = reinterpret_tensor(buf1674, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1674 # reuse | |
# Source Nodes: [output_38, setitem_38, setitem_39], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1669, arg508_1, buf1677, buf1662, arg509_1, arg510_1, buf1679, 4096, grid=grid(4096), stream=stream0) | |
del arg508_1 | |
del buf1662 | |
buf1680 = buf1594; del buf1594 # reuse | |
# Source Nodes: [output_38], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1680, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_38], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1681 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1679, arg509_1, arg510_1, buf1680, False) | |
del arg509_1 | |
del arg510_1 | |
del buf1679 | |
buf1682 = buf1681[0] | |
del buf1681 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_136], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1686 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1682, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1687 = buf1686[0] | |
buf1688 = buf1686[1] | |
del buf1686 | |
# Source Nodes: [input_410], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1689 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1682, (1, 1, 4096), (4096, 4096, 1), 0), buf1687, buf1688, -128, 127, torch.int8) | |
buf1690 = buf1689 | |
del buf1689 | |
# Source Nodes: [input_411], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1691 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1690, buf1687, buf1688, -128, 127, torch.int8, torch.bfloat16) | |
del buf1687 | |
del buf1688 | |
del buf1690 | |
buf1692 = buf1691 | |
del buf1691 | |
# Source Nodes: [w_dq_136], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1693 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg511_1, arg512_1, arg513_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg511_1 | |
del arg512_1 | |
del arg513_1 | |
buf1694 = buf1693 | |
del buf1693 | |
buf1695 = reinterpret_tensor(buf1682, (1, 4096), (4096, 1), 0); del buf1682 # reuse | |
# Source Nodes: [c_136], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1692, buf1694, buf1695, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1694 | |
buf1696 = buf1523; del buf1523 # reuse | |
buf1698 = buf1692; del buf1692 # reuse | |
# Source Nodes: [add_138, h_19, h_20, mean_39, mul_257, mul_258, out_17, out_18, pow_40, rsqrt_39, x_fp32_39, x_normed_39], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1696, buf1695, buf1609, buf1556, buf1642, arg514_1, buf1698, 1, 4096, grid=grid(1), stream=stream0) | |
del arg514_1 | |
del buf1556 | |
del buf1609 | |
del buf1642 | |
del buf1695 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_137], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1699 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1698, torch.int8) | |
buf1700 = buf1699[0] | |
buf1701 = buf1699[1] | |
del buf1699 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_138], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1702 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1698, torch.int8) | |
buf1703 = buf1702[0] | |
buf1704 = buf1702[1] | |
del buf1702 | |
# Source Nodes: [input_413], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1705 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1698, buf1700, buf1701, -128, 127, torch.int8) | |
buf1706 = buf1705 | |
del buf1705 | |
# Source Nodes: [input_414], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1707 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1706, buf1700, buf1701, -128, 127, torch.int8, torch.bfloat16) | |
del buf1700 | |
del buf1701 | |
del buf1706 | |
buf1708 = buf1707 | |
del buf1707 | |
# Source Nodes: [w_dq_137], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1709 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg515_1, arg516_1, arg517_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg515_1 | |
del arg516_1 | |
del arg517_1 | |
buf1710 = buf1709 | |
del buf1709 | |
buf1711 = reinterpret_tensor(buf1639, (1, 14336), (14336, 1), 0); del buf1639 # reuse | |
# Source Nodes: [c_137], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1708, buf1710, buf1711, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1710 | |
# Source Nodes: [input_416], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1712 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1698, buf1703, buf1704, -128, 127, torch.int8) | |
buf1713 = buf1712 | |
del buf1712 | |
# Source Nodes: [input_417], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1714 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1713, buf1703, buf1704, -128, 127, torch.int8, torch.bfloat16) | |
del buf1703 | |
del buf1704 | |
del buf1713 | |
buf1715 = buf1714 | |
del buf1714 | |
# Source Nodes: [w_dq_138], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1716 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg518_1, arg519_1, arg520_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg518_1 | |
del arg519_1 | |
del arg520_1 | |
buf1717 = buf1716 | |
del buf1716 | |
buf1719 = buf1632; del buf1632 # reuse | |
# Source Nodes: [c_138, mul_259, silu_19], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1715, buf1717, buf1711, buf1719, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1711 | |
del buf1717 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_139], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1720 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1719, torch.int8) | |
buf1721 = buf1720[0] | |
buf1722 = buf1720[1] | |
del buf1720 | |
# Source Nodes: [input_419], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1723 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1719, buf1721, buf1722, -128, 127, torch.int8) | |
buf1724 = buf1723 | |
del buf1723 | |
# Source Nodes: [input_420], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1725 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1724, buf1721, buf1722, -128, 127, torch.int8, torch.bfloat16) | |
del buf1721 | |
del buf1722 | |
del buf1724 | |
buf1726 = buf1725 | |
del buf1725 | |
# Source Nodes: [w_dq_139], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1727 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg521_1, arg522_1, arg523_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg521_1 | |
del arg522_1 | |
del arg523_1 | |
buf1728 = buf1727 | |
del buf1727 | |
buf1729 = reinterpret_tensor(buf1715, (1, 4096), (4096, 1), 0); del buf1715 # reuse | |
# Source Nodes: [c_139], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1726, buf1728, buf1729, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1728 | |
buf1731 = buf1698; del buf1698 # reuse | |
# Source Nodes: [add_140, mean_40, mul_260, out_19, pow_41, rsqrt_40, x_fp32_40, x_normed_40, y_20], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1696, buf1729, arg524_1, buf1731, 1, 4096, grid=grid(1), stream=stream0) | |
del arg524_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_140], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1732 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8) | |
buf1733 = buf1732[0] | |
buf1734 = buf1732[1] | |
del buf1732 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_141], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1735 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8) | |
buf1736 = buf1735[0] | |
buf1737 = buf1735[1] | |
del buf1735 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_142], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1738 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8) | |
buf1739 = buf1738[0] | |
buf1740 = buf1738[1] | |
del buf1738 | |
buf1741 = buf1654; del buf1654 # reuse | |
# Source Nodes: [max_21], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1741, 1, grid=grid(1), stream=stream0) | |
u20 = buf1741.item() | |
buf1742 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_422], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1743 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1733, buf1734, -128, 127, torch.int8) | |
buf1744 = buf1743 | |
del buf1743 | |
# Source Nodes: [input_423], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1745 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1744, buf1733, buf1734, -128, 127, torch.int8, torch.bfloat16) | |
del buf1733 | |
del buf1734 | |
del buf1744 | |
buf1746 = buf1745 | |
del buf1745 | |
# Source Nodes: [w_dq_140], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1747 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg525_1, arg526_1, arg527_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg525_1 | |
del arg526_1 | |
del arg527_1 | |
buf1748 = buf1747 | |
del buf1747 | |
buf1749 = reinterpret_tensor(buf1708, (1, 4096), (4096, 1), 0); del buf1708 # reuse | |
# Source Nodes: [c_140], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1746, buf1748, buf1749, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1746 | |
del buf1748 | |
# Source Nodes: [input_425], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1750 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1736, buf1737, -128, 127, torch.int8) | |
buf1751 = buf1750 | |
del buf1750 | |
# Source Nodes: [input_426], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1752 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1751, buf1736, buf1737, -128, 127, torch.int8, torch.bfloat16) | |
del buf1736 | |
del buf1737 | |
del buf1751 | |
buf1753 = buf1752 | |
del buf1752 | |
# Source Nodes: [w_dq_141], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1754 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg528_1, arg529_1, arg530_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg528_1 | |
del arg529_1 | |
del arg530_1 | |
buf1755 = buf1754 | |
del buf1754 | |
buf1756 = buf1677; del buf1677 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1753, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1755, (4096, 1024), (1, 4096), 0), out=buf1756) | |
del buf1753 | |
del buf1755 | |
# Source Nodes: [input_428], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1758 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1739, buf1740, -128, 127, torch.int8) | |
del buf1731 | |
buf1759 = buf1758 | |
del buf1758 | |
# Source Nodes: [input_429], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1760 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1759, buf1739, buf1740, -128, 127, torch.int8, torch.bfloat16) | |
del buf1739 | |
del buf1740 | |
del buf1759 | |
buf1761 = buf1760 | |
del buf1760 | |
# Source Nodes: [w_dq_142], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1762 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg531_1, arg532_1, arg533_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg531_1 | |
del arg532_1 | |
del arg533_1 | |
buf1763 = buf1762 | |
del buf1762 | |
buf1764 = buf1669; del buf1669 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1761, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1763, (4096, 1024), (1, 4096), 0), out=buf1764) | |
del buf1763 | |
buf1766 = reinterpret_tensor(buf1761, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1761 # reuse | |
# Source Nodes: [output_40, setitem_40, setitem_41], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1756, arg534_1, buf1764, buf1749, arg535_1, arg536_1, buf1766, 4096, grid=grid(4096), stream=stream0) | |
del arg534_1 | |
del buf1749 | |
buf1767 = buf1680; del buf1680 # reuse | |
# Source Nodes: [output_40], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1767, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_40], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1768 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1766, arg535_1, arg536_1, buf1767, False) | |
del arg535_1 | |
del arg536_1 | |
del buf1766 | |
buf1769 = buf1768[0] | |
del buf1768 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_143], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1773 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1769, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1774 = buf1773[0] | |
buf1775 = buf1773[1] | |
del buf1773 | |
# Source Nodes: [input_431], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1776 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1769, (1, 1, 4096), (4096, 4096, 1), 0), buf1774, buf1775, -128, 127, torch.int8) | |
buf1777 = buf1776 | |
del buf1776 | |
# Source Nodes: [input_432], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1778 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1777, buf1774, buf1775, -128, 127, torch.int8, torch.bfloat16) | |
del buf1774 | |
del buf1775 | |
del buf1777 | |
buf1779 = buf1778 | |
del buf1778 | |
# Source Nodes: [w_dq_143], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1780 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg537_1, arg538_1, arg539_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg537_1 | |
del arg538_1 | |
del arg539_1 | |
buf1781 = buf1780 | |
del buf1780 | |
buf1782 = reinterpret_tensor(buf1769, (1, 4096), (4096, 1), 0); del buf1769 # reuse | |
# Source Nodes: [c_143], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1779, buf1781, buf1782, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1781 | |
buf1784 = buf1779; del buf1779 # reuse | |
# Source Nodes: [add_145, h_21, mean_41, mul_270, mul_271, out_19, pow_42, rsqrt_41, x_fp32_41, x_normed_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1782, buf1696, buf1729, arg540_1, buf1784, 1, 4096, grid=grid(1), stream=stream0) | |
del arg540_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_144], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1785 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1784, torch.int8) | |
buf1786 = buf1785[0] | |
buf1787 = buf1785[1] | |
del buf1785 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_145], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1788 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1784, torch.int8) | |
buf1789 = buf1788[0] | |
buf1790 = buf1788[1] | |
del buf1788 | |
# Source Nodes: [input_434], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1791 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1784, buf1786, buf1787, -128, 127, torch.int8) | |
buf1792 = buf1791 | |
del buf1791 | |
# Source Nodes: [input_435], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1793 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1792, buf1786, buf1787, -128, 127, torch.int8, torch.bfloat16) | |
del buf1786 | |
del buf1787 | |
del buf1792 | |
buf1794 = buf1793 | |
del buf1793 | |
# Source Nodes: [w_dq_144], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1795 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg541_1, arg542_1, arg543_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg541_1 | |
del arg542_1 | |
del arg543_1 | |
buf1796 = buf1795 | |
del buf1795 | |
buf1797 = reinterpret_tensor(buf1726, (1, 14336), (14336, 1), 0); del buf1726 # reuse | |
# Source Nodes: [c_144], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1794, buf1796, buf1797, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1796 | |
# Source Nodes: [input_437], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1798 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1784, buf1789, buf1790, -128, 127, torch.int8) | |
buf1799 = buf1798 | |
del buf1798 | |
# Source Nodes: [input_438], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1800 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1799, buf1789, buf1790, -128, 127, torch.int8, torch.bfloat16) | |
del buf1789 | |
del buf1790 | |
del buf1799 | |
buf1801 = buf1800 | |
del buf1800 | |
# Source Nodes: [w_dq_145], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1802 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg544_1, arg545_1, arg546_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg544_1 | |
del arg545_1 | |
del arg546_1 | |
buf1803 = buf1802 | |
del buf1802 | |
buf1805 = buf1719; del buf1719 # reuse | |
# Source Nodes: [c_145, mul_272, silu_20], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1801, buf1803, buf1797, buf1805, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1797 | |
del buf1803 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_146], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1806 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1805, torch.int8) | |
buf1807 = buf1806[0] | |
buf1808 = buf1806[1] | |
del buf1806 | |
# Source Nodes: [input_440], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1809 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1805, buf1807, buf1808, -128, 127, torch.int8) | |
buf1810 = buf1809 | |
del buf1809 | |
# Source Nodes: [input_441], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1811 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1810, buf1807, buf1808, -128, 127, torch.int8, torch.bfloat16) | |
del buf1807 | |
del buf1808 | |
del buf1810 | |
buf1812 = buf1811 | |
del buf1811 | |
# Source Nodes: [w_dq_146], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1813 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg547_1, arg548_1, arg549_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg547_1 | |
del arg548_1 | |
del arg549_1 | |
buf1814 = buf1813 | |
del buf1813 | |
buf1815 = reinterpret_tensor(buf1801, (1, 4096), (4096, 1), 0); del buf1801 # reuse | |
# Source Nodes: [c_146], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1812, buf1814, buf1815, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1814 | |
buf1817 = buf1784; del buf1784 # reuse | |
# Source Nodes: [add_147, h_21, mean_42, mul_273, out_19, out_20, pow_43, rsqrt_42, x_fp32_42, x_normed_42, y_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1782, buf1696, buf1729, buf1815, arg550_1, buf1817, 1, 4096, grid=grid(1), stream=stream0) | |
del arg550_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_147], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1818 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8) | |
buf1819 = buf1818[0] | |
buf1820 = buf1818[1] | |
del buf1818 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_148], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1821 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8) | |
buf1822 = buf1821[0] | |
buf1823 = buf1821[1] | |
del buf1821 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_149], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1824 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8) | |
buf1825 = buf1824[0] | |
buf1826 = buf1824[1] | |
del buf1824 | |
buf1827 = buf1741; del buf1741 # reuse | |
# Source Nodes: [max_22], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1827, 1, grid=grid(1), stream=stream0) | |
u21 = buf1827.item() | |
buf1828 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_443], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1829 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1819, buf1820, -128, 127, torch.int8) | |
buf1830 = buf1829 | |
del buf1829 | |
# Source Nodes: [input_444], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1831 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1830, buf1819, buf1820, -128, 127, torch.int8, torch.bfloat16) | |
del buf1819 | |
del buf1820 | |
del buf1830 | |
buf1832 = buf1831 | |
del buf1831 | |
# Source Nodes: [w_dq_147], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1833 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg551_1, arg552_1, arg553_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg551_1 | |
del arg552_1 | |
del arg553_1 | |
buf1834 = buf1833 | |
del buf1833 | |
buf1835 = reinterpret_tensor(buf1794, (1, 4096), (4096, 1), 0); del buf1794 # reuse | |
# Source Nodes: [c_147], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1832, buf1834, buf1835, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1832 | |
del buf1834 | |
# Source Nodes: [input_446], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1836 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1822, buf1823, -128, 127, torch.int8) | |
buf1837 = buf1836 | |
del buf1836 | |
# Source Nodes: [input_447], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1838 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1837, buf1822, buf1823, -128, 127, torch.int8, torch.bfloat16) | |
del buf1822 | |
del buf1823 | |
del buf1837 | |
buf1839 = buf1838 | |
del buf1838 | |
# Source Nodes: [w_dq_148], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1840 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg554_1, arg555_1, arg556_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg554_1 | |
del arg555_1 | |
del arg556_1 | |
buf1841 = buf1840 | |
del buf1840 | |
buf1842 = buf1764; del buf1764 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1839, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1841, (4096, 1024), (1, 4096), 0), out=buf1842) | |
del buf1839 | |
del buf1841 | |
# Source Nodes: [input_449], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1844 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1825, buf1826, -128, 127, torch.int8) | |
del buf1817 | |
buf1845 = buf1844 | |
del buf1844 | |
# Source Nodes: [input_450], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1846 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1845, buf1825, buf1826, -128, 127, torch.int8, torch.bfloat16) | |
del buf1825 | |
del buf1826 | |
del buf1845 | |
buf1847 = buf1846 | |
del buf1846 | |
# Source Nodes: [w_dq_149], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1848 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg557_1, arg558_1, arg559_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg557_1 | |
del arg558_1 | |
del arg559_1 | |
buf1849 = buf1848 | |
del buf1848 | |
buf1850 = buf1756; del buf1756 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1847, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1849, (4096, 1024), (1, 4096), 0), out=buf1850) | |
del buf1849 | |
buf1852 = reinterpret_tensor(buf1847, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1847 # reuse | |
# Source Nodes: [output_42, setitem_42, setitem_43], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1842, arg560_1, buf1850, buf1835, arg561_1, arg562_1, buf1852, 4096, grid=grid(4096), stream=stream0) | |
del arg560_1 | |
del buf1835 | |
buf1853 = buf1767; del buf1767 # reuse | |
# Source Nodes: [output_42], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1853, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_42], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1854 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1852, arg561_1, arg562_1, buf1853, False) | |
del arg561_1 | |
del arg562_1 | |
del buf1852 | |
buf1855 = buf1854[0] | |
del buf1854 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_150], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1859 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1855, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1860 = buf1859[0] | |
buf1861 = buf1859[1] | |
del buf1859 | |
# Source Nodes: [input_452], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1862 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1855, (1, 1, 4096), (4096, 4096, 1), 0), buf1860, buf1861, -128, 127, torch.int8) | |
buf1863 = buf1862 | |
del buf1862 | |
# Source Nodes: [input_453], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1864 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1863, buf1860, buf1861, -128, 127, torch.int8, torch.bfloat16) | |
del buf1860 | |
del buf1861 | |
del buf1863 | |
buf1865 = buf1864 | |
del buf1864 | |
# Source Nodes: [w_dq_150], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1866 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg563_1, arg564_1, arg565_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg563_1 | |
del arg564_1 | |
del arg565_1 | |
buf1867 = buf1866 | |
del buf1866 | |
buf1868 = reinterpret_tensor(buf1855, (1, 4096), (4096, 1), 0); del buf1855 # reuse | |
# Source Nodes: [c_150], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1865, buf1867, buf1868, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1867 | |
buf1869 = buf1696; del buf1696 # reuse | |
buf1871 = buf1865; del buf1865 # reuse | |
# Source Nodes: [add_152, h_21, h_22, mean_43, mul_283, mul_284, out_19, out_20, pow_44, rsqrt_43, x_fp32_43, x_normed_43], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1869, buf1868, buf1782, buf1729, buf1815, arg566_1, buf1871, 1, 4096, grid=grid(1), stream=stream0) | |
del arg566_1 | |
del buf1729 | |
del buf1782 | |
del buf1815 | |
del buf1868 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_151], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1872 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1871, torch.int8) | |
buf1873 = buf1872[0] | |
buf1874 = buf1872[1] | |
del buf1872 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_152], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1875 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1871, torch.int8) | |
buf1876 = buf1875[0] | |
buf1877 = buf1875[1] | |
del buf1875 | |
# Source Nodes: [input_455], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1878 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1871, buf1873, buf1874, -128, 127, torch.int8) | |
buf1879 = buf1878 | |
del buf1878 | |
# Source Nodes: [input_456], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1880 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1879, buf1873, buf1874, -128, 127, torch.int8, torch.bfloat16) | |
del buf1873 | |
del buf1874 | |
del buf1879 | |
buf1881 = buf1880 | |
del buf1880 | |
# Source Nodes: [w_dq_151], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1882 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg567_1, arg568_1, arg569_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg567_1 | |
del arg568_1 | |
del arg569_1 | |
buf1883 = buf1882 | |
del buf1882 | |
buf1884 = reinterpret_tensor(buf1812, (1, 14336), (14336, 1), 0); del buf1812 # reuse | |
# Source Nodes: [c_151], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1881, buf1883, buf1884, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1883 | |
# Source Nodes: [input_458], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1885 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1871, buf1876, buf1877, -128, 127, torch.int8) | |
buf1886 = buf1885 | |
del buf1885 | |
# Source Nodes: [input_459], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1887 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1886, buf1876, buf1877, -128, 127, torch.int8, torch.bfloat16) | |
del buf1876 | |
del buf1877 | |
del buf1886 | |
buf1888 = buf1887 | |
del buf1887 | |
# Source Nodes: [w_dq_152], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1889 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg570_1, arg571_1, arg572_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg570_1 | |
del arg571_1 | |
del arg572_1 | |
buf1890 = buf1889 | |
del buf1889 | |
buf1892 = buf1805; del buf1805 # reuse | |
# Source Nodes: [c_152, mul_285, silu_21], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1888, buf1890, buf1884, buf1892, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1884 | |
del buf1890 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_153], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1893 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1892, torch.int8) | |
buf1894 = buf1893[0] | |
buf1895 = buf1893[1] | |
del buf1893 | |
# Source Nodes: [input_461], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1896 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1892, buf1894, buf1895, -128, 127, torch.int8) | |
buf1897 = buf1896 | |
del buf1896 | |
# Source Nodes: [input_462], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1898 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1897, buf1894, buf1895, -128, 127, torch.int8, torch.bfloat16) | |
del buf1894 | |
del buf1895 | |
del buf1897 | |
buf1899 = buf1898 | |
del buf1898 | |
# Source Nodes: [w_dq_153], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1900 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg573_1, arg574_1, arg575_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg573_1 | |
del arg574_1 | |
del arg575_1 | |
buf1901 = buf1900 | |
del buf1900 | |
buf1902 = reinterpret_tensor(buf1888, (1, 4096), (4096, 1), 0); del buf1888 # reuse | |
# Source Nodes: [c_153], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1899, buf1901, buf1902, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1901 | |
buf1904 = buf1871; del buf1871 # reuse | |
# Source Nodes: [add_154, mean_44, mul_286, out_21, pow_45, rsqrt_44, x_fp32_44, x_normed_44, y_22], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1869, buf1902, arg576_1, buf1904, 1, 4096, grid=grid(1), stream=stream0) | |
del arg576_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_154], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1905 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8) | |
buf1906 = buf1905[0] | |
buf1907 = buf1905[1] | |
del buf1905 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_155], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1908 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8) | |
buf1909 = buf1908[0] | |
buf1910 = buf1908[1] | |
del buf1908 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_156], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1911 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8) | |
buf1912 = buf1911[0] | |
buf1913 = buf1911[1] | |
del buf1911 | |
buf1914 = buf1827; del buf1827 # reuse | |
# Source Nodes: [max_23], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf1914, 1, grid=grid(1), stream=stream0) | |
u22 = buf1914.item() | |
buf1915 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_464], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1916 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1906, buf1907, -128, 127, torch.int8) | |
buf1917 = buf1916 | |
del buf1916 | |
# Source Nodes: [input_465], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1918 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1917, buf1906, buf1907, -128, 127, torch.int8, torch.bfloat16) | |
del buf1906 | |
del buf1907 | |
del buf1917 | |
buf1919 = buf1918 | |
del buf1918 | |
# Source Nodes: [w_dq_154], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1920 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg577_1, arg578_1, arg579_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg577_1 | |
del arg578_1 | |
del arg579_1 | |
buf1921 = buf1920 | |
del buf1920 | |
buf1922 = reinterpret_tensor(buf1881, (1, 4096), (4096, 1), 0); del buf1881 # reuse | |
# Source Nodes: [c_154], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1919, buf1921, buf1922, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1919 | |
del buf1921 | |
# Source Nodes: [input_467], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1923 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1909, buf1910, -128, 127, torch.int8) | |
buf1924 = buf1923 | |
del buf1923 | |
# Source Nodes: [input_468], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1925 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1924, buf1909, buf1910, -128, 127, torch.int8, torch.bfloat16) | |
del buf1909 | |
del buf1910 | |
del buf1924 | |
buf1926 = buf1925 | |
del buf1925 | |
# Source Nodes: [w_dq_155], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1927 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg580_1, arg581_1, arg582_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg580_1 | |
del arg581_1 | |
del arg582_1 | |
buf1928 = buf1927 | |
del buf1927 | |
buf1929 = buf1850; del buf1850 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1926, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1928, (4096, 1024), (1, 4096), 0), out=buf1929) | |
del buf1926 | |
del buf1928 | |
# Source Nodes: [input_470], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1931 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1912, buf1913, -128, 127, torch.int8) | |
del buf1904 | |
buf1932 = buf1931 | |
del buf1931 | |
# Source Nodes: [input_471], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1933 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1932, buf1912, buf1913, -128, 127, torch.int8, torch.bfloat16) | |
del buf1912 | |
del buf1913 | |
del buf1932 | |
buf1934 = buf1933 | |
del buf1933 | |
# Source Nodes: [w_dq_156], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1935 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg583_1, arg584_1, arg585_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg583_1 | |
del arg584_1 | |
del arg585_1 | |
buf1936 = buf1935 | |
del buf1935 | |
buf1937 = buf1842; del buf1842 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf1934, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1936, (4096, 1024), (1, 4096), 0), out=buf1937) | |
del buf1936 | |
buf1939 = reinterpret_tensor(buf1934, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1934 # reuse | |
# Source Nodes: [output_44, setitem_44, setitem_45], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1929, arg586_1, buf1937, buf1922, arg587_1, arg588_1, buf1939, 4096, grid=grid(4096), stream=stream0) | |
del arg586_1 | |
del buf1922 | |
buf1940 = buf1853; del buf1853 # reuse | |
# Source Nodes: [output_44], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1940, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_44], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf1941 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1939, arg587_1, arg588_1, buf1940, False) | |
del arg587_1 | |
del arg588_1 | |
del buf1939 | |
buf1942 = buf1941[0] | |
del buf1941 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_157], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1946 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1942, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf1947 = buf1946[0] | |
buf1948 = buf1946[1] | |
del buf1946 | |
# Source Nodes: [input_473], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1949 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1942, (1, 1, 4096), (4096, 4096, 1), 0), buf1947, buf1948, -128, 127, torch.int8) | |
buf1950 = buf1949 | |
del buf1949 | |
# Source Nodes: [input_474], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1951 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1950, buf1947, buf1948, -128, 127, torch.int8, torch.bfloat16) | |
del buf1947 | |
del buf1948 | |
del buf1950 | |
buf1952 = buf1951 | |
del buf1951 | |
# Source Nodes: [w_dq_157], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1953 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg589_1, arg590_1, arg591_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg589_1 | |
del arg590_1 | |
del arg591_1 | |
buf1954 = buf1953 | |
del buf1953 | |
buf1955 = reinterpret_tensor(buf1942, (1, 4096), (4096, 1), 0); del buf1942 # reuse | |
# Source Nodes: [c_157], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf1952, buf1954, buf1955, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1954 | |
buf1957 = buf1952; del buf1952 # reuse | |
# Source Nodes: [add_159, h_23, mean_45, mul_296, mul_297, out_21, pow_46, rsqrt_45, x_fp32_45, x_normed_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1955, buf1869, buf1902, arg592_1, buf1957, 1, 4096, grid=grid(1), stream=stream0) | |
del arg592_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_158], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1958 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1957, torch.int8) | |
buf1959 = buf1958[0] | |
buf1960 = buf1958[1] | |
del buf1958 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_159], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1961 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1957, torch.int8) | |
buf1962 = buf1961[0] | |
buf1963 = buf1961[1] | |
del buf1961 | |
# Source Nodes: [input_476], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1964 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1957, buf1959, buf1960, -128, 127, torch.int8) | |
buf1965 = buf1964 | |
del buf1964 | |
# Source Nodes: [input_477], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1966 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1965, buf1959, buf1960, -128, 127, torch.int8, torch.bfloat16) | |
del buf1959 | |
del buf1960 | |
del buf1965 | |
buf1967 = buf1966 | |
del buf1966 | |
# Source Nodes: [w_dq_158], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1968 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg593_1, arg594_1, arg595_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg593_1 | |
del arg594_1 | |
del arg595_1 | |
buf1969 = buf1968 | |
del buf1968 | |
buf1970 = reinterpret_tensor(buf1899, (1, 14336), (14336, 1), 0); del buf1899 # reuse | |
# Source Nodes: [c_158], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf1967, buf1969, buf1970, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1969 | |
# Source Nodes: [input_479], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1971 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1957, buf1962, buf1963, -128, 127, torch.int8) | |
buf1972 = buf1971 | |
del buf1971 | |
# Source Nodes: [input_480], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1973 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1972, buf1962, buf1963, -128, 127, torch.int8, torch.bfloat16) | |
del buf1962 | |
del buf1963 | |
del buf1972 | |
buf1974 = buf1973 | |
del buf1973 | |
# Source Nodes: [w_dq_159], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1975 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg596_1, arg597_1, arg598_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg596_1 | |
del arg597_1 | |
del arg598_1 | |
buf1976 = buf1975 | |
del buf1975 | |
buf1978 = buf1892; del buf1892 # reuse | |
# Source Nodes: [c_159, mul_298, silu_22], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf1974, buf1976, buf1970, buf1978, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf1970 | |
del buf1976 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_160], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1979 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1978, torch.int8) | |
buf1980 = buf1979[0] | |
buf1981 = buf1979[1] | |
del buf1979 | |
# Source Nodes: [input_482], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf1982 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1978, buf1980, buf1981, -128, 127, torch.int8) | |
buf1983 = buf1982 | |
del buf1982 | |
# Source Nodes: [input_483], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf1984 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1983, buf1980, buf1981, -128, 127, torch.int8, torch.bfloat16) | |
del buf1980 | |
del buf1981 | |
del buf1983 | |
buf1985 = buf1984 | |
del buf1984 | |
# Source Nodes: [w_dq_160], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf1986 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg599_1, arg600_1, arg601_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg599_1 | |
del arg600_1 | |
del arg601_1 | |
buf1987 = buf1986 | |
del buf1986 | |
buf1988 = reinterpret_tensor(buf1974, (1, 4096), (4096, 1), 0); del buf1974 # reuse | |
# Source Nodes: [c_160], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf1985, buf1987, buf1988, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf1987 | |
buf1990 = buf1957; del buf1957 # reuse | |
# Source Nodes: [add_161, h_23, mean_46, mul_299, out_21, out_22, pow_47, rsqrt_46, x_fp32_46, x_normed_46, y_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1955, buf1869, buf1902, buf1988, arg602_1, buf1990, 1, 4096, grid=grid(1), stream=stream0) | |
del arg602_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_161], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1991 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8) | |
buf1992 = buf1991[0] | |
buf1993 = buf1991[1] | |
del buf1991 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_162], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1994 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8) | |
buf1995 = buf1994[0] | |
buf1996 = buf1994[1] | |
del buf1994 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_163], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf1997 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8) | |
buf1998 = buf1997[0] | |
buf1999 = buf1997[1] | |
del buf1997 | |
buf2000 = buf1914; del buf1914 # reuse | |
# Source Nodes: [max_24], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2000, 1, grid=grid(1), stream=stream0) | |
u23 = buf2000.item() | |
buf2001 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_485], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2002 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1992, buf1993, -128, 127, torch.int8) | |
buf2003 = buf2002 | |
del buf2002 | |
# Source Nodes: [input_486], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2004 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2003, buf1992, buf1993, -128, 127, torch.int8, torch.bfloat16) | |
del buf1992 | |
del buf1993 | |
del buf2003 | |
buf2005 = buf2004 | |
del buf2004 | |
# Source Nodes: [w_dq_161], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2006 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg603_1, arg604_1, arg605_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg603_1 | |
del arg604_1 | |
del arg605_1 | |
buf2007 = buf2006 | |
del buf2006 | |
buf2008 = reinterpret_tensor(buf1967, (1, 4096), (4096, 1), 0); del buf1967 # reuse | |
# Source Nodes: [c_161], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2005, buf2007, buf2008, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2005 | |
del buf2007 | |
# Source Nodes: [input_488], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2009 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1995, buf1996, -128, 127, torch.int8) | |
buf2010 = buf2009 | |
del buf2009 | |
# Source Nodes: [input_489], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2011 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2010, buf1995, buf1996, -128, 127, torch.int8, torch.bfloat16) | |
del buf1995 | |
del buf1996 | |
del buf2010 | |
buf2012 = buf2011 | |
del buf2011 | |
# Source Nodes: [w_dq_162], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2013 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg606_1, arg607_1, arg608_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg606_1 | |
del arg607_1 | |
del arg608_1 | |
buf2014 = buf2013 | |
del buf2013 | |
buf2015 = buf1937; del buf1937 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2012, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2014, (4096, 1024), (1, 4096), 0), out=buf2015) | |
del buf2012 | |
del buf2014 | |
# Source Nodes: [input_491], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2017 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1998, buf1999, -128, 127, torch.int8) | |
del buf1990 | |
buf2018 = buf2017 | |
del buf2017 | |
# Source Nodes: [input_492], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2019 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2018, buf1998, buf1999, -128, 127, torch.int8, torch.bfloat16) | |
del buf1998 | |
del buf1999 | |
del buf2018 | |
buf2020 = buf2019 | |
del buf2019 | |
# Source Nodes: [w_dq_163], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2021 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg609_1, arg610_1, arg611_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg609_1 | |
del arg610_1 | |
del arg611_1 | |
buf2022 = buf2021 | |
del buf2021 | |
buf2023 = buf1929; del buf1929 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2020, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2022, (4096, 1024), (1, 4096), 0), out=buf2023) | |
del buf2022 | |
buf2025 = reinterpret_tensor(buf2020, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2020 # reuse | |
# Source Nodes: [output_46, setitem_46, setitem_47], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2015, arg612_1, buf2023, buf2008, arg613_1, arg614_1, buf2025, 4096, grid=grid(4096), stream=stream0) | |
del arg612_1 | |
del buf2008 | |
buf2026 = buf1940; del buf1940 # reuse | |
# Source Nodes: [output_46], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2026, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_46], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2027 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2025, arg613_1, arg614_1, buf2026, False) | |
del arg613_1 | |
del arg614_1 | |
del buf2025 | |
buf2028 = buf2027[0] | |
del buf2027 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_164], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2032 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2028, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2033 = buf2032[0] | |
buf2034 = buf2032[1] | |
del buf2032 | |
# Source Nodes: [input_494], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2035 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2028, (1, 1, 4096), (4096, 4096, 1), 0), buf2033, buf2034, -128, 127, torch.int8) | |
buf2036 = buf2035 | |
del buf2035 | |
# Source Nodes: [input_495], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2037 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2036, buf2033, buf2034, -128, 127, torch.int8, torch.bfloat16) | |
del buf2033 | |
del buf2034 | |
del buf2036 | |
buf2038 = buf2037 | |
del buf2037 | |
# Source Nodes: [w_dq_164], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2039 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg615_1, arg616_1, arg617_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg615_1 | |
del arg616_1 | |
del arg617_1 | |
buf2040 = buf2039 | |
del buf2039 | |
buf2041 = reinterpret_tensor(buf2028, (1, 4096), (4096, 1), 0); del buf2028 # reuse | |
# Source Nodes: [c_164], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2038, buf2040, buf2041, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2040 | |
buf2042 = buf1869; del buf1869 # reuse | |
buf2044 = buf2038; del buf2038 # reuse | |
# Source Nodes: [add_166, h_23, h_24, mean_47, mul_309, mul_310, out_21, out_22, pow_48, rsqrt_47, x_fp32_47, x_normed_47], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2042, buf2041, buf1955, buf1902, buf1988, arg618_1, buf2044, 1, 4096, grid=grid(1), stream=stream0) | |
del arg618_1 | |
del buf1902 | |
del buf1955 | |
del buf1988 | |
del buf2041 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_165], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2045 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2044, torch.int8) | |
buf2046 = buf2045[0] | |
buf2047 = buf2045[1] | |
del buf2045 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_166], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2048 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2044, torch.int8) | |
buf2049 = buf2048[0] | |
buf2050 = buf2048[1] | |
del buf2048 | |
# Source Nodes: [input_497], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2051 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2044, buf2046, buf2047, -128, 127, torch.int8) | |
buf2052 = buf2051 | |
del buf2051 | |
# Source Nodes: [input_498], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2053 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2052, buf2046, buf2047, -128, 127, torch.int8, torch.bfloat16) | |
del buf2046 | |
del buf2047 | |
del buf2052 | |
buf2054 = buf2053 | |
del buf2053 | |
# Source Nodes: [w_dq_165], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2055 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg619_1, arg620_1, arg621_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg619_1 | |
del arg620_1 | |
del arg621_1 | |
buf2056 = buf2055 | |
del buf2055 | |
buf2057 = reinterpret_tensor(buf1985, (1, 14336), (14336, 1), 0); del buf1985 # reuse | |
# Source Nodes: [c_165], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2054, buf2056, buf2057, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2056 | |
# Source Nodes: [input_500], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2058 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2044, buf2049, buf2050, -128, 127, torch.int8) | |
buf2059 = buf2058 | |
del buf2058 | |
# Source Nodes: [input_501], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2060 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2059, buf2049, buf2050, -128, 127, torch.int8, torch.bfloat16) | |
del buf2049 | |
del buf2050 | |
del buf2059 | |
buf2061 = buf2060 | |
del buf2060 | |
# Source Nodes: [w_dq_166], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2062 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg622_1, arg623_1, arg624_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg622_1 | |
del arg623_1 | |
del arg624_1 | |
buf2063 = buf2062 | |
del buf2062 | |
buf2065 = buf1978; del buf1978 # reuse | |
# Source Nodes: [c_166, mul_311, silu_23], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2061, buf2063, buf2057, buf2065, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2057 | |
del buf2063 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_167], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2066 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2065, torch.int8) | |
buf2067 = buf2066[0] | |
buf2068 = buf2066[1] | |
del buf2066 | |
# Source Nodes: [input_503], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2069 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2065, buf2067, buf2068, -128, 127, torch.int8) | |
buf2070 = buf2069 | |
del buf2069 | |
# Source Nodes: [input_504], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2071 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2070, buf2067, buf2068, -128, 127, torch.int8, torch.bfloat16) | |
del buf2067 | |
del buf2068 | |
del buf2070 | |
buf2072 = buf2071 | |
del buf2071 | |
# Source Nodes: [w_dq_167], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2073 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg625_1, arg626_1, arg627_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg625_1 | |
del arg626_1 | |
del arg627_1 | |
buf2074 = buf2073 | |
del buf2073 | |
buf2075 = reinterpret_tensor(buf2061, (1, 4096), (4096, 1), 0); del buf2061 # reuse | |
# Source Nodes: [c_167], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2072, buf2074, buf2075, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2074 | |
buf2077 = buf2044; del buf2044 # reuse | |
# Source Nodes: [add_168, mean_48, mul_312, out_23, pow_49, rsqrt_48, x_fp32_48, x_normed_48, y_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2042, buf2075, arg628_1, buf2077, 1, 4096, grid=grid(1), stream=stream0) | |
del arg628_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_168], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2078 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8) | |
buf2079 = buf2078[0] | |
buf2080 = buf2078[1] | |
del buf2078 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_169], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2081 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8) | |
buf2082 = buf2081[0] | |
buf2083 = buf2081[1] | |
del buf2081 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_170], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2084 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8) | |
buf2085 = buf2084[0] | |
buf2086 = buf2084[1] | |
del buf2084 | |
buf2087 = buf2000; del buf2000 # reuse | |
# Source Nodes: [max_25], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2087, 1, grid=grid(1), stream=stream0) | |
u24 = buf2087.item() | |
buf2088 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_506], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2089 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2079, buf2080, -128, 127, torch.int8) | |
buf2090 = buf2089 | |
del buf2089 | |
# Source Nodes: [input_507], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2091 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2090, buf2079, buf2080, -128, 127, torch.int8, torch.bfloat16) | |
del buf2079 | |
del buf2080 | |
del buf2090 | |
buf2092 = buf2091 | |
del buf2091 | |
# Source Nodes: [w_dq_168], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2093 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg629_1, arg630_1, arg631_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg629_1 | |
del arg630_1 | |
del arg631_1 | |
buf2094 = buf2093 | |
del buf2093 | |
buf2095 = reinterpret_tensor(buf2054, (1, 4096), (4096, 1), 0); del buf2054 # reuse | |
# Source Nodes: [c_168], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2092, buf2094, buf2095, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2092 | |
del buf2094 | |
# Source Nodes: [input_509], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2096 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2082, buf2083, -128, 127, torch.int8) | |
buf2097 = buf2096 | |
del buf2096 | |
# Source Nodes: [input_510], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2098 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2097, buf2082, buf2083, -128, 127, torch.int8, torch.bfloat16) | |
del buf2082 | |
del buf2083 | |
del buf2097 | |
buf2099 = buf2098 | |
del buf2098 | |
# Source Nodes: [w_dq_169], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2100 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg632_1, arg633_1, arg634_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg632_1 | |
del arg633_1 | |
del arg634_1 | |
buf2101 = buf2100 | |
del buf2100 | |
buf2102 = buf2023; del buf2023 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2099, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2101, (4096, 1024), (1, 4096), 0), out=buf2102) | |
del buf2099 | |
del buf2101 | |
# Source Nodes: [input_512], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2104 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2085, buf2086, -128, 127, torch.int8) | |
del buf2077 | |
buf2105 = buf2104 | |
del buf2104 | |
# Source Nodes: [input_513], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2106 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2105, buf2085, buf2086, -128, 127, torch.int8, torch.bfloat16) | |
del buf2085 | |
del buf2086 | |
del buf2105 | |
buf2107 = buf2106 | |
del buf2106 | |
# Source Nodes: [w_dq_170], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2108 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg635_1, arg636_1, arg637_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg635_1 | |
del arg636_1 | |
del arg637_1 | |
buf2109 = buf2108 | |
del buf2108 | |
buf2110 = buf2015; del buf2015 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2107, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2109, (4096, 1024), (1, 4096), 0), out=buf2110) | |
del buf2109 | |
buf2112 = reinterpret_tensor(buf2107, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2107 # reuse | |
# Source Nodes: [output_48, setitem_48, setitem_49], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2102, arg638_1, buf2110, buf2095, arg639_1, arg640_1, buf2112, 4096, grid=grid(4096), stream=stream0) | |
del arg638_1 | |
del buf2095 | |
buf2113 = buf2026; del buf2026 # reuse | |
# Source Nodes: [output_48], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2113, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_48], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2114 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2112, arg639_1, arg640_1, buf2113, False) | |
del arg639_1 | |
del arg640_1 | |
del buf2112 | |
buf2115 = buf2114[0] | |
del buf2114 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_171], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2119 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2115, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2120 = buf2119[0] | |
buf2121 = buf2119[1] | |
del buf2119 | |
# Source Nodes: [input_515], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2122 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2115, (1, 1, 4096), (4096, 4096, 1), 0), buf2120, buf2121, -128, 127, torch.int8) | |
buf2123 = buf2122 | |
del buf2122 | |
# Source Nodes: [input_516], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2124 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2123, buf2120, buf2121, -128, 127, torch.int8, torch.bfloat16) | |
del buf2120 | |
del buf2121 | |
del buf2123 | |
buf2125 = buf2124 | |
del buf2124 | |
# Source Nodes: [w_dq_171], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2126 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg641_1, arg642_1, arg643_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg641_1 | |
del arg642_1 | |
del arg643_1 | |
buf2127 = buf2126 | |
del buf2126 | |
buf2128 = reinterpret_tensor(buf2115, (1, 4096), (4096, 1), 0); del buf2115 # reuse | |
# Source Nodes: [c_171], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2125, buf2127, buf2128, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2127 | |
buf2130 = buf2125; del buf2125 # reuse | |
# Source Nodes: [add_173, h_25, mean_49, mul_322, mul_323, out_23, pow_50, rsqrt_49, x_fp32_49, x_normed_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2128, buf2042, buf2075, arg644_1, buf2130, 1, 4096, grid=grid(1), stream=stream0) | |
del arg644_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_172], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2131 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2130, torch.int8) | |
buf2132 = buf2131[0] | |
buf2133 = buf2131[1] | |
del buf2131 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_173], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2134 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2130, torch.int8) | |
buf2135 = buf2134[0] | |
buf2136 = buf2134[1] | |
del buf2134 | |
# Source Nodes: [input_518], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2137 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2130, buf2132, buf2133, -128, 127, torch.int8) | |
buf2138 = buf2137 | |
del buf2137 | |
# Source Nodes: [input_519], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2139 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2138, buf2132, buf2133, -128, 127, torch.int8, torch.bfloat16) | |
del buf2132 | |
del buf2133 | |
del buf2138 | |
buf2140 = buf2139 | |
del buf2139 | |
# Source Nodes: [w_dq_172], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2141 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg645_1, arg646_1, arg647_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg645_1 | |
del arg646_1 | |
del arg647_1 | |
buf2142 = buf2141 | |
del buf2141 | |
buf2143 = reinterpret_tensor(buf2072, (1, 14336), (14336, 1), 0); del buf2072 # reuse | |
# Source Nodes: [c_172], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2140, buf2142, buf2143, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2142 | |
# Source Nodes: [input_521], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2144 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2130, buf2135, buf2136, -128, 127, torch.int8) | |
buf2145 = buf2144 | |
del buf2144 | |
# Source Nodes: [input_522], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2146 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2145, buf2135, buf2136, -128, 127, torch.int8, torch.bfloat16) | |
del buf2135 | |
del buf2136 | |
del buf2145 | |
buf2147 = buf2146 | |
del buf2146 | |
# Source Nodes: [w_dq_173], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2148 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg648_1, arg649_1, arg650_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg648_1 | |
del arg649_1 | |
del arg650_1 | |
buf2149 = buf2148 | |
del buf2148 | |
buf2151 = buf2065; del buf2065 # reuse | |
# Source Nodes: [c_173, mul_324, silu_24], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2147, buf2149, buf2143, buf2151, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2143 | |
del buf2149 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_174], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2152 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2151, torch.int8) | |
buf2153 = buf2152[0] | |
buf2154 = buf2152[1] | |
del buf2152 | |
# Source Nodes: [input_524], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2155 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2151, buf2153, buf2154, -128, 127, torch.int8) | |
buf2156 = buf2155 | |
del buf2155 | |
# Source Nodes: [input_525], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2157 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2156, buf2153, buf2154, -128, 127, torch.int8, torch.bfloat16) | |
del buf2153 | |
del buf2154 | |
del buf2156 | |
buf2158 = buf2157 | |
del buf2157 | |
# Source Nodes: [w_dq_174], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2159 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg651_1, arg652_1, arg653_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg651_1 | |
del arg652_1 | |
del arg653_1 | |
buf2160 = buf2159 | |
del buf2159 | |
buf2161 = reinterpret_tensor(buf2147, (1, 4096), (4096, 1), 0); del buf2147 # reuse | |
# Source Nodes: [c_174], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2158, buf2160, buf2161, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2160 | |
buf2163 = buf2130; del buf2130 # reuse | |
# Source Nodes: [add_175, h_25, mean_50, mul_325, out_23, out_24, pow_51, rsqrt_50, x_fp32_50, x_normed_50, y_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2128, buf2042, buf2075, buf2161, arg654_1, buf2163, 1, 4096, grid=grid(1), stream=stream0) | |
del arg654_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_175], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2164 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8) | |
buf2165 = buf2164[0] | |
buf2166 = buf2164[1] | |
del buf2164 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_176], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2167 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8) | |
buf2168 = buf2167[0] | |
buf2169 = buf2167[1] | |
del buf2167 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_177], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2170 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8) | |
buf2171 = buf2170[0] | |
buf2172 = buf2170[1] | |
del buf2170 | |
buf2173 = buf2087; del buf2087 # reuse | |
# Source Nodes: [max_26], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2173, 1, grid=grid(1), stream=stream0) | |
u25 = buf2173.item() | |
buf2174 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_527], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2175 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2165, buf2166, -128, 127, torch.int8) | |
buf2176 = buf2175 | |
del buf2175 | |
# Source Nodes: [input_528], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2177 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2176, buf2165, buf2166, -128, 127, torch.int8, torch.bfloat16) | |
del buf2165 | |
del buf2166 | |
del buf2176 | |
buf2178 = buf2177 | |
del buf2177 | |
# Source Nodes: [w_dq_175], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2179 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg655_1, arg656_1, arg657_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg655_1 | |
del arg656_1 | |
del arg657_1 | |
buf2180 = buf2179 | |
del buf2179 | |
buf2181 = reinterpret_tensor(buf2140, (1, 4096), (4096, 1), 0); del buf2140 # reuse | |
# Source Nodes: [c_175], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2178, buf2180, buf2181, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2178 | |
del buf2180 | |
# Source Nodes: [input_530], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2182 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2168, buf2169, -128, 127, torch.int8) | |
buf2183 = buf2182 | |
del buf2182 | |
# Source Nodes: [input_531], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2184 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2183, buf2168, buf2169, -128, 127, torch.int8, torch.bfloat16) | |
del buf2168 | |
del buf2169 | |
del buf2183 | |
buf2185 = buf2184 | |
del buf2184 | |
# Source Nodes: [w_dq_176], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2186 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg658_1, arg659_1, arg660_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg658_1 | |
del arg659_1 | |
del arg660_1 | |
buf2187 = buf2186 | |
del buf2186 | |
buf2188 = buf2110; del buf2110 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2185, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2187, (4096, 1024), (1, 4096), 0), out=buf2188) | |
del buf2185 | |
del buf2187 | |
# Source Nodes: [input_533], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2190 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2171, buf2172, -128, 127, torch.int8) | |
del buf2163 | |
buf2191 = buf2190 | |
del buf2190 | |
# Source Nodes: [input_534], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2192 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2191, buf2171, buf2172, -128, 127, torch.int8, torch.bfloat16) | |
del buf2171 | |
del buf2172 | |
del buf2191 | |
buf2193 = buf2192 | |
del buf2192 | |
# Source Nodes: [w_dq_177], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2194 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg661_1, arg662_1, arg663_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg661_1 | |
del arg662_1 | |
del arg663_1 | |
buf2195 = buf2194 | |
del buf2194 | |
buf2196 = buf2102; del buf2102 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2193, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2195, (4096, 1024), (1, 4096), 0), out=buf2196) | |
del buf2195 | |
buf2198 = reinterpret_tensor(buf2193, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2193 # reuse | |
# Source Nodes: [output_50, setitem_50, setitem_51], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2188, arg664_1, buf2196, buf2181, arg665_1, arg666_1, buf2198, 4096, grid=grid(4096), stream=stream0) | |
del arg664_1 | |
del buf2181 | |
buf2199 = buf2113; del buf2113 # reuse | |
# Source Nodes: [output_50], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2199, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_50], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2200 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2198, arg665_1, arg666_1, buf2199, False) | |
del arg665_1 | |
del arg666_1 | |
del buf2198 | |
buf2201 = buf2200[0] | |
del buf2200 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_178], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2205 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2201, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2206 = buf2205[0] | |
buf2207 = buf2205[1] | |
del buf2205 | |
# Source Nodes: [input_536], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2208 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2201, (1, 1, 4096), (4096, 4096, 1), 0), buf2206, buf2207, -128, 127, torch.int8) | |
buf2209 = buf2208 | |
del buf2208 | |
# Source Nodes: [input_537], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2210 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2209, buf2206, buf2207, -128, 127, torch.int8, torch.bfloat16) | |
del buf2206 | |
del buf2207 | |
del buf2209 | |
buf2211 = buf2210 | |
del buf2210 | |
# Source Nodes: [w_dq_178], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2212 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg667_1, arg668_1, arg669_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg667_1 | |
del arg668_1 | |
del arg669_1 | |
buf2213 = buf2212 | |
del buf2212 | |
buf2214 = reinterpret_tensor(buf2201, (1, 4096), (4096, 1), 0); del buf2201 # reuse | |
# Source Nodes: [c_178], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2211, buf2213, buf2214, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2213 | |
buf2215 = buf2042; del buf2042 # reuse | |
buf2217 = buf2211; del buf2211 # reuse | |
# Source Nodes: [add_180, h_25, h_26, mean_51, mul_335, mul_336, out_23, out_24, pow_52, rsqrt_51, x_fp32_51, x_normed_51], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2215, buf2214, buf2128, buf2075, buf2161, arg670_1, buf2217, 1, 4096, grid=grid(1), stream=stream0) | |
del arg670_1 | |
del buf2075 | |
del buf2128 | |
del buf2161 | |
del buf2214 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_179], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2218 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2217, torch.int8) | |
buf2219 = buf2218[0] | |
buf2220 = buf2218[1] | |
del buf2218 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_180], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2221 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2217, torch.int8) | |
buf2222 = buf2221[0] | |
buf2223 = buf2221[1] | |
del buf2221 | |
# Source Nodes: [input_539], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2224 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2217, buf2219, buf2220, -128, 127, torch.int8) | |
buf2225 = buf2224 | |
del buf2224 | |
# Source Nodes: [input_540], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2226 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2225, buf2219, buf2220, -128, 127, torch.int8, torch.bfloat16) | |
del buf2219 | |
del buf2220 | |
del buf2225 | |
buf2227 = buf2226 | |
del buf2226 | |
# Source Nodes: [w_dq_179], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2228 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg671_1, arg672_1, arg673_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg671_1 | |
del arg672_1 | |
del arg673_1 | |
buf2229 = buf2228 | |
del buf2228 | |
buf2230 = reinterpret_tensor(buf2158, (1, 14336), (14336, 1), 0); del buf2158 # reuse | |
# Source Nodes: [c_179], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2227, buf2229, buf2230, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2229 | |
# Source Nodes: [input_542], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2231 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2217, buf2222, buf2223, -128, 127, torch.int8) | |
buf2232 = buf2231 | |
del buf2231 | |
# Source Nodes: [input_543], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2233 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2232, buf2222, buf2223, -128, 127, torch.int8, torch.bfloat16) | |
del buf2222 | |
del buf2223 | |
del buf2232 | |
buf2234 = buf2233 | |
del buf2233 | |
# Source Nodes: [w_dq_180], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2235 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg674_1, arg675_1, arg676_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg674_1 | |
del arg675_1 | |
del arg676_1 | |
buf2236 = buf2235 | |
del buf2235 | |
buf2238 = buf2151; del buf2151 # reuse | |
# Source Nodes: [c_180, mul_337, silu_25], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2234, buf2236, buf2230, buf2238, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2230 | |
del buf2236 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_181], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2239 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2238, torch.int8) | |
buf2240 = buf2239[0] | |
buf2241 = buf2239[1] | |
del buf2239 | |
# Source Nodes: [input_545], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2242 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2238, buf2240, buf2241, -128, 127, torch.int8) | |
buf2243 = buf2242 | |
del buf2242 | |
# Source Nodes: [input_546], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2244 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2243, buf2240, buf2241, -128, 127, torch.int8, torch.bfloat16) | |
del buf2240 | |
del buf2241 | |
del buf2243 | |
buf2245 = buf2244 | |
del buf2244 | |
# Source Nodes: [w_dq_181], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2246 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg677_1, arg678_1, arg679_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg677_1 | |
del arg678_1 | |
del arg679_1 | |
buf2247 = buf2246 | |
del buf2246 | |
buf2248 = reinterpret_tensor(buf2234, (1, 4096), (4096, 1), 0); del buf2234 # reuse | |
# Source Nodes: [c_181], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2245, buf2247, buf2248, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2247 | |
buf2250 = buf2217; del buf2217 # reuse | |
# Source Nodes: [add_182, mean_52, mul_338, out_25, pow_53, rsqrt_52, x_fp32_52, x_normed_52, y_26], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2215, buf2248, arg680_1, buf2250, 1, 4096, grid=grid(1), stream=stream0) | |
del arg680_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_182], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2251 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8) | |
buf2252 = buf2251[0] | |
buf2253 = buf2251[1] | |
del buf2251 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_183], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2254 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8) | |
buf2255 = buf2254[0] | |
buf2256 = buf2254[1] | |
del buf2254 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_184], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2257 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8) | |
buf2258 = buf2257[0] | |
buf2259 = buf2257[1] | |
del buf2257 | |
buf2260 = buf2173; del buf2173 # reuse | |
# Source Nodes: [max_27], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2260, 1, grid=grid(1), stream=stream0) | |
u26 = buf2260.item() | |
buf2261 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_548], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2262 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2252, buf2253, -128, 127, torch.int8) | |
buf2263 = buf2262 | |
del buf2262 | |
# Source Nodes: [input_549], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2264 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2263, buf2252, buf2253, -128, 127, torch.int8, torch.bfloat16) | |
del buf2252 | |
del buf2253 | |
del buf2263 | |
buf2265 = buf2264 | |
del buf2264 | |
# Source Nodes: [w_dq_182], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2266 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg681_1, arg682_1, arg683_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg681_1 | |
del arg682_1 | |
del arg683_1 | |
buf2267 = buf2266 | |
del buf2266 | |
buf2268 = reinterpret_tensor(buf2227, (1, 4096), (4096, 1), 0); del buf2227 # reuse | |
# Source Nodes: [c_182], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2265, buf2267, buf2268, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2265 | |
del buf2267 | |
# Source Nodes: [input_551], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2269 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2255, buf2256, -128, 127, torch.int8) | |
buf2270 = buf2269 | |
del buf2269 | |
# Source Nodes: [input_552], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2271 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2270, buf2255, buf2256, -128, 127, torch.int8, torch.bfloat16) | |
del buf2255 | |
del buf2256 | |
del buf2270 | |
buf2272 = buf2271 | |
del buf2271 | |
# Source Nodes: [w_dq_183], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2273 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg684_1, arg685_1, arg686_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg684_1 | |
del arg685_1 | |
del arg686_1 | |
buf2274 = buf2273 | |
del buf2273 | |
buf2275 = buf2196; del buf2196 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2272, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2274, (4096, 1024), (1, 4096), 0), out=buf2275) | |
del buf2272 | |
del buf2274 | |
# Source Nodes: [input_554], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2277 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2258, buf2259, -128, 127, torch.int8) | |
del buf2250 | |
buf2278 = buf2277 | |
del buf2277 | |
# Source Nodes: [input_555], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2279 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2278, buf2258, buf2259, -128, 127, torch.int8, torch.bfloat16) | |
del buf2258 | |
del buf2259 | |
del buf2278 | |
buf2280 = buf2279 | |
del buf2279 | |
# Source Nodes: [w_dq_184], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2281 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg687_1, arg688_1, arg689_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg687_1 | |
del arg688_1 | |
del arg689_1 | |
buf2282 = buf2281 | |
del buf2281 | |
buf2283 = buf2188; del buf2188 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2280, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2282, (4096, 1024), (1, 4096), 0), out=buf2283) | |
del buf2282 | |
buf2285 = reinterpret_tensor(buf2280, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2280 # reuse | |
# Source Nodes: [output_52, setitem_52, setitem_53], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2275, arg690_1, buf2283, buf2268, arg691_1, arg692_1, buf2285, 4096, grid=grid(4096), stream=stream0) | |
del arg690_1 | |
del buf2268 | |
buf2286 = buf2199; del buf2199 # reuse | |
# Source Nodes: [output_52], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2286, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_52], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2287 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2285, arg691_1, arg692_1, buf2286, False) | |
del arg691_1 | |
del arg692_1 | |
del buf2285 | |
buf2288 = buf2287[0] | |
del buf2287 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_185], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2292 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2288, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2293 = buf2292[0] | |
buf2294 = buf2292[1] | |
del buf2292 | |
# Source Nodes: [input_557], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2295 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2288, (1, 1, 4096), (4096, 4096, 1), 0), buf2293, buf2294, -128, 127, torch.int8) | |
buf2296 = buf2295 | |
del buf2295 | |
# Source Nodes: [input_558], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2297 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2296, buf2293, buf2294, -128, 127, torch.int8, torch.bfloat16) | |
del buf2293 | |
del buf2294 | |
del buf2296 | |
buf2298 = buf2297 | |
del buf2297 | |
# Source Nodes: [w_dq_185], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2299 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg693_1, arg694_1, arg695_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg693_1 | |
del arg694_1 | |
del arg695_1 | |
buf2300 = buf2299 | |
del buf2299 | |
buf2301 = reinterpret_tensor(buf2288, (1, 4096), (4096, 1), 0); del buf2288 # reuse | |
# Source Nodes: [c_185], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2298, buf2300, buf2301, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2300 | |
buf2303 = buf2298; del buf2298 # reuse | |
# Source Nodes: [add_187, h_27, mean_53, mul_348, mul_349, out_25, pow_54, rsqrt_53, x_fp32_53, x_normed_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2301, buf2215, buf2248, arg696_1, buf2303, 1, 4096, grid=grid(1), stream=stream0) | |
del arg696_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_186], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2304 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2303, torch.int8) | |
buf2305 = buf2304[0] | |
buf2306 = buf2304[1] | |
del buf2304 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_187], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2307 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2303, torch.int8) | |
buf2308 = buf2307[0] | |
buf2309 = buf2307[1] | |
del buf2307 | |
# Source Nodes: [input_560], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2310 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2303, buf2305, buf2306, -128, 127, torch.int8) | |
buf2311 = buf2310 | |
del buf2310 | |
# Source Nodes: [input_561], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2312 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2311, buf2305, buf2306, -128, 127, torch.int8, torch.bfloat16) | |
del buf2305 | |
del buf2306 | |
del buf2311 | |
buf2313 = buf2312 | |
del buf2312 | |
# Source Nodes: [w_dq_186], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2314 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg697_1, arg698_1, arg699_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg697_1 | |
del arg698_1 | |
del arg699_1 | |
buf2315 = buf2314 | |
del buf2314 | |
buf2316 = reinterpret_tensor(buf2245, (1, 14336), (14336, 1), 0); del buf2245 # reuse | |
# Source Nodes: [c_186], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2313, buf2315, buf2316, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2315 | |
# Source Nodes: [input_563], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2317 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2303, buf2308, buf2309, -128, 127, torch.int8) | |
buf2318 = buf2317 | |
del buf2317 | |
# Source Nodes: [input_564], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2319 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2318, buf2308, buf2309, -128, 127, torch.int8, torch.bfloat16) | |
del buf2308 | |
del buf2309 | |
del buf2318 | |
buf2320 = buf2319 | |
del buf2319 | |
# Source Nodes: [w_dq_187], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2321 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg700_1, arg701_1, arg702_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg700_1 | |
del arg701_1 | |
del arg702_1 | |
buf2322 = buf2321 | |
del buf2321 | |
buf2324 = buf2238; del buf2238 # reuse | |
# Source Nodes: [c_187, mul_350, silu_26], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2320, buf2322, buf2316, buf2324, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2316 | |
del buf2322 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_188], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2325 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2324, torch.int8) | |
buf2326 = buf2325[0] | |
buf2327 = buf2325[1] | |
del buf2325 | |
# Source Nodes: [input_566], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2328 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2324, buf2326, buf2327, -128, 127, torch.int8) | |
buf2329 = buf2328 | |
del buf2328 | |
# Source Nodes: [input_567], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2330 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2329, buf2326, buf2327, -128, 127, torch.int8, torch.bfloat16) | |
del buf2326 | |
del buf2327 | |
del buf2329 | |
buf2331 = buf2330 | |
del buf2330 | |
# Source Nodes: [w_dq_188], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2332 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg703_1, arg704_1, arg705_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg703_1 | |
del arg704_1 | |
del arg705_1 | |
buf2333 = buf2332 | |
del buf2332 | |
buf2334 = reinterpret_tensor(buf2320, (1, 4096), (4096, 1), 0); del buf2320 # reuse | |
# Source Nodes: [c_188], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2331, buf2333, buf2334, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2333 | |
buf2336 = buf2303; del buf2303 # reuse | |
# Source Nodes: [add_189, h_27, mean_54, mul_351, out_25, out_26, pow_55, rsqrt_54, x_fp32_54, x_normed_54, y_27], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2301, buf2215, buf2248, buf2334, arg706_1, buf2336, 1, 4096, grid=grid(1), stream=stream0) | |
del arg706_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_189], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2337 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8) | |
buf2338 = buf2337[0] | |
buf2339 = buf2337[1] | |
del buf2337 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_190], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2340 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8) | |
buf2341 = buf2340[0] | |
buf2342 = buf2340[1] | |
del buf2340 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_191], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2343 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8) | |
buf2344 = buf2343[0] | |
buf2345 = buf2343[1] | |
del buf2343 | |
buf2346 = buf2260; del buf2260 # reuse | |
# Source Nodes: [max_28], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2346, 1, grid=grid(1), stream=stream0) | |
u27 = buf2346.item() | |
buf2347 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_569], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2348 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2338, buf2339, -128, 127, torch.int8) | |
buf2349 = buf2348 | |
del buf2348 | |
# Source Nodes: [input_570], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2350 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2349, buf2338, buf2339, -128, 127, torch.int8, torch.bfloat16) | |
del buf2338 | |
del buf2339 | |
del buf2349 | |
buf2351 = buf2350 | |
del buf2350 | |
# Source Nodes: [w_dq_189], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2352 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg707_1, arg708_1, arg709_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg707_1 | |
del arg708_1 | |
del arg709_1 | |
buf2353 = buf2352 | |
del buf2352 | |
buf2354 = reinterpret_tensor(buf2313, (1, 4096), (4096, 1), 0); del buf2313 # reuse | |
# Source Nodes: [c_189], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2351, buf2353, buf2354, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2351 | |
del buf2353 | |
# Source Nodes: [input_572], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2355 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2341, buf2342, -128, 127, torch.int8) | |
buf2356 = buf2355 | |
del buf2355 | |
# Source Nodes: [input_573], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2357 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2356, buf2341, buf2342, -128, 127, torch.int8, torch.bfloat16) | |
del buf2341 | |
del buf2342 | |
del buf2356 | |
buf2358 = buf2357 | |
del buf2357 | |
# Source Nodes: [w_dq_190], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2359 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg710_1, arg711_1, arg712_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg710_1 | |
del arg711_1 | |
del arg712_1 | |
buf2360 = buf2359 | |
del buf2359 | |
buf2361 = buf2283; del buf2283 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2358, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2360, (4096, 1024), (1, 4096), 0), out=buf2361) | |
del buf2358 | |
del buf2360 | |
# Source Nodes: [input_575], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2363 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2344, buf2345, -128, 127, torch.int8) | |
del buf2336 | |
buf2364 = buf2363 | |
del buf2363 | |
# Source Nodes: [input_576], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2365 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2364, buf2344, buf2345, -128, 127, torch.int8, torch.bfloat16) | |
del buf2344 | |
del buf2345 | |
del buf2364 | |
buf2366 = buf2365 | |
del buf2365 | |
# Source Nodes: [w_dq_191], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2367 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg713_1, arg714_1, arg715_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg713_1 | |
del arg714_1 | |
del arg715_1 | |
buf2368 = buf2367 | |
del buf2367 | |
buf2369 = buf2275; del buf2275 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2366, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2368, (4096, 1024), (1, 4096), 0), out=buf2369) | |
del buf2368 | |
buf2371 = reinterpret_tensor(buf2366, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2366 # reuse | |
# Source Nodes: [output_54, setitem_54, setitem_55], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2361, arg716_1, buf2369, buf2354, arg717_1, arg718_1, buf2371, 4096, grid=grid(4096), stream=stream0) | |
del arg716_1 | |
del buf2354 | |
buf2372 = buf2286; del buf2286 # reuse | |
# Source Nodes: [output_54], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2372, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_54], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2373 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2371, arg717_1, arg718_1, buf2372, False) | |
del arg717_1 | |
del arg718_1 | |
del buf2371 | |
buf2374 = buf2373[0] | |
del buf2373 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_192], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2378 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2374, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2379 = buf2378[0] | |
buf2380 = buf2378[1] | |
del buf2378 | |
# Source Nodes: [input_578], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2381 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2374, (1, 1, 4096), (4096, 4096, 1), 0), buf2379, buf2380, -128, 127, torch.int8) | |
buf2382 = buf2381 | |
del buf2381 | |
# Source Nodes: [input_579], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2383 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2382, buf2379, buf2380, -128, 127, torch.int8, torch.bfloat16) | |
del buf2379 | |
del buf2380 | |
del buf2382 | |
buf2384 = buf2383 | |
del buf2383 | |
# Source Nodes: [w_dq_192], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2385 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg719_1, arg720_1, arg721_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg719_1 | |
del arg720_1 | |
del arg721_1 | |
buf2386 = buf2385 | |
del buf2385 | |
buf2387 = reinterpret_tensor(buf2374, (1, 4096), (4096, 1), 0); del buf2374 # reuse | |
# Source Nodes: [c_192], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2384, buf2386, buf2387, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2386 | |
buf2388 = buf2215; del buf2215 # reuse | |
buf2390 = buf2384; del buf2384 # reuse | |
# Source Nodes: [add_194, h_27, h_28, mean_55, mul_361, mul_362, out_25, out_26, pow_56, rsqrt_55, x_fp32_55, x_normed_55], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2388, buf2387, buf2301, buf2248, buf2334, arg722_1, buf2390, 1, 4096, grid=grid(1), stream=stream0) | |
del arg722_1 | |
del buf2248 | |
del buf2301 | |
del buf2334 | |
del buf2387 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_193], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2391 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2390, torch.int8) | |
buf2392 = buf2391[0] | |
buf2393 = buf2391[1] | |
del buf2391 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_194], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2394 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2390, torch.int8) | |
buf2395 = buf2394[0] | |
buf2396 = buf2394[1] | |
del buf2394 | |
# Source Nodes: [input_581], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2397 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2390, buf2392, buf2393, -128, 127, torch.int8) | |
buf2398 = buf2397 | |
del buf2397 | |
# Source Nodes: [input_582], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2399 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2398, buf2392, buf2393, -128, 127, torch.int8, torch.bfloat16) | |
del buf2392 | |
del buf2393 | |
del buf2398 | |
buf2400 = buf2399 | |
del buf2399 | |
# Source Nodes: [w_dq_193], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2401 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg723_1, arg724_1, arg725_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg723_1 | |
del arg724_1 | |
del arg725_1 | |
buf2402 = buf2401 | |
del buf2401 | |
buf2403 = reinterpret_tensor(buf2331, (1, 14336), (14336, 1), 0); del buf2331 # reuse | |
# Source Nodes: [c_193], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2400, buf2402, buf2403, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2402 | |
# Source Nodes: [input_584], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2404 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2390, buf2395, buf2396, -128, 127, torch.int8) | |
buf2405 = buf2404 | |
del buf2404 | |
# Source Nodes: [input_585], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2406 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2405, buf2395, buf2396, -128, 127, torch.int8, torch.bfloat16) | |
del buf2395 | |
del buf2396 | |
del buf2405 | |
buf2407 = buf2406 | |
del buf2406 | |
# Source Nodes: [w_dq_194], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2408 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg726_1, arg727_1, arg728_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg726_1 | |
del arg727_1 | |
del arg728_1 | |
buf2409 = buf2408 | |
del buf2408 | |
buf2411 = buf2324; del buf2324 # reuse | |
# Source Nodes: [c_194, mul_363, silu_27], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2407, buf2409, buf2403, buf2411, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2403 | |
del buf2409 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_195], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2412 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2411, torch.int8) | |
buf2413 = buf2412[0] | |
buf2414 = buf2412[1] | |
del buf2412 | |
# Source Nodes: [input_587], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2415 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2411, buf2413, buf2414, -128, 127, torch.int8) | |
buf2416 = buf2415 | |
del buf2415 | |
# Source Nodes: [input_588], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2417 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2416, buf2413, buf2414, -128, 127, torch.int8, torch.bfloat16) | |
del buf2413 | |
del buf2414 | |
del buf2416 | |
buf2418 = buf2417 | |
del buf2417 | |
# Source Nodes: [w_dq_195], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2419 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg729_1, arg730_1, arg731_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg729_1 | |
del arg730_1 | |
del arg731_1 | |
buf2420 = buf2419 | |
del buf2419 | |
buf2421 = reinterpret_tensor(buf2407, (1, 4096), (4096, 1), 0); del buf2407 # reuse | |
# Source Nodes: [c_195], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2418, buf2420, buf2421, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2420 | |
buf2423 = buf2390; del buf2390 # reuse | |
# Source Nodes: [add_196, mean_56, mul_364, out_27, pow_57, rsqrt_56, x_fp32_56, x_normed_56, y_28], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2388, buf2421, arg732_1, buf2423, 1, 4096, grid=grid(1), stream=stream0) | |
del arg732_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_196], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2424 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8) | |
buf2425 = buf2424[0] | |
buf2426 = buf2424[1] | |
del buf2424 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_197], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2427 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8) | |
buf2428 = buf2427[0] | |
buf2429 = buf2427[1] | |
del buf2427 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_198], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2430 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8) | |
buf2431 = buf2430[0] | |
buf2432 = buf2430[1] | |
del buf2430 | |
buf2433 = buf2346; del buf2346 # reuse | |
# Source Nodes: [max_29], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2433, 1, grid=grid(1), stream=stream0) | |
u28 = buf2433.item() | |
buf2434 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_590], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2435 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2425, buf2426, -128, 127, torch.int8) | |
buf2436 = buf2435 | |
del buf2435 | |
# Source Nodes: [input_591], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2437 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2436, buf2425, buf2426, -128, 127, torch.int8, torch.bfloat16) | |
del buf2425 | |
del buf2426 | |
del buf2436 | |
buf2438 = buf2437 | |
del buf2437 | |
# Source Nodes: [w_dq_196], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2439 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg733_1, arg734_1, arg735_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg733_1 | |
del arg734_1 | |
del arg735_1 | |
buf2440 = buf2439 | |
del buf2439 | |
buf2441 = reinterpret_tensor(buf2400, (1, 4096), (4096, 1), 0); del buf2400 # reuse | |
# Source Nodes: [c_196], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2438, buf2440, buf2441, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2438 | |
del buf2440 | |
# Source Nodes: [input_593], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2442 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2428, buf2429, -128, 127, torch.int8) | |
buf2443 = buf2442 | |
del buf2442 | |
# Source Nodes: [input_594], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2444 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2443, buf2428, buf2429, -128, 127, torch.int8, torch.bfloat16) | |
del buf2428 | |
del buf2429 | |
del buf2443 | |
buf2445 = buf2444 | |
del buf2444 | |
# Source Nodes: [w_dq_197], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2446 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg736_1, arg737_1, arg738_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg736_1 | |
del arg737_1 | |
del arg738_1 | |
buf2447 = buf2446 | |
del buf2446 | |
buf2448 = buf2369; del buf2369 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2445, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2447, (4096, 1024), (1, 4096), 0), out=buf2448) | |
del buf2445 | |
del buf2447 | |
# Source Nodes: [input_596], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2450 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2431, buf2432, -128, 127, torch.int8) | |
del buf2423 | |
buf2451 = buf2450 | |
del buf2450 | |
# Source Nodes: [input_597], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2452 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2451, buf2431, buf2432, -128, 127, torch.int8, torch.bfloat16) | |
del buf2431 | |
del buf2432 | |
del buf2451 | |
buf2453 = buf2452 | |
del buf2452 | |
# Source Nodes: [w_dq_198], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2454 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg739_1, arg740_1, arg741_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg739_1 | |
del arg740_1 | |
del arg741_1 | |
buf2455 = buf2454 | |
del buf2454 | |
buf2456 = buf2361; del buf2361 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2453, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2455, (4096, 1024), (1, 4096), 0), out=buf2456) | |
del buf2455 | |
buf2458 = reinterpret_tensor(buf2453, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2453 # reuse | |
# Source Nodes: [output_56, setitem_56, setitem_57], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2448, arg742_1, buf2456, buf2441, arg743_1, arg744_1, buf2458, 4096, grid=grid(4096), stream=stream0) | |
del arg742_1 | |
del buf2441 | |
buf2459 = buf2372; del buf2372 # reuse | |
# Source Nodes: [output_56], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2459, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_56], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2460 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2458, arg743_1, arg744_1, buf2459, False) | |
del arg743_1 | |
del arg744_1 | |
del buf2458 | |
buf2461 = buf2460[0] | |
del buf2460 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_199], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2465 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2461, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2466 = buf2465[0] | |
buf2467 = buf2465[1] | |
del buf2465 | |
# Source Nodes: [input_599], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2468 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2461, (1, 1, 4096), (4096, 4096, 1), 0), buf2466, buf2467, -128, 127, torch.int8) | |
buf2469 = buf2468 | |
del buf2468 | |
# Source Nodes: [input_600], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2470 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2469, buf2466, buf2467, -128, 127, torch.int8, torch.bfloat16) | |
del buf2466 | |
del buf2467 | |
del buf2469 | |
buf2471 = buf2470 | |
del buf2470 | |
# Source Nodes: [w_dq_199], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2472 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg745_1, arg746_1, arg747_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg745_1 | |
del arg746_1 | |
del arg747_1 | |
buf2473 = buf2472 | |
del buf2472 | |
buf2474 = reinterpret_tensor(buf2461, (1, 4096), (4096, 1), 0); del buf2461 # reuse | |
# Source Nodes: [c_199], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2471, buf2473, buf2474, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2473 | |
buf2476 = buf2471; del buf2471 # reuse | |
# Source Nodes: [add_201, h_29, mean_57, mul_374, mul_375, out_27, pow_58, rsqrt_57, x_fp32_57, x_normed_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2474, buf2388, buf2421, arg748_1, buf2476, 1, 4096, grid=grid(1), stream=stream0) | |
del arg748_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_200], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2477 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2476, torch.int8) | |
buf2478 = buf2477[0] | |
buf2479 = buf2477[1] | |
del buf2477 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_201], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2480 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2476, torch.int8) | |
buf2481 = buf2480[0] | |
buf2482 = buf2480[1] | |
del buf2480 | |
# Source Nodes: [input_602], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2483 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2476, buf2478, buf2479, -128, 127, torch.int8) | |
buf2484 = buf2483 | |
del buf2483 | |
# Source Nodes: [input_603], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2485 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2484, buf2478, buf2479, -128, 127, torch.int8, torch.bfloat16) | |
del buf2478 | |
del buf2479 | |
del buf2484 | |
buf2486 = buf2485 | |
del buf2485 | |
# Source Nodes: [w_dq_200], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2487 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg749_1, arg750_1, arg751_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg749_1 | |
del arg750_1 | |
del arg751_1 | |
buf2488 = buf2487 | |
del buf2487 | |
buf2489 = reinterpret_tensor(buf2418, (1, 14336), (14336, 1), 0); del buf2418 # reuse | |
# Source Nodes: [c_200], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2486, buf2488, buf2489, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2488 | |
# Source Nodes: [input_605], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2490 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2476, buf2481, buf2482, -128, 127, torch.int8) | |
buf2491 = buf2490 | |
del buf2490 | |
# Source Nodes: [input_606], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2492 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2491, buf2481, buf2482, -128, 127, torch.int8, torch.bfloat16) | |
del buf2481 | |
del buf2482 | |
del buf2491 | |
buf2493 = buf2492 | |
del buf2492 | |
# Source Nodes: [w_dq_201], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2494 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg752_1, arg753_1, arg754_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg752_1 | |
del arg753_1 | |
del arg754_1 | |
buf2495 = buf2494 | |
del buf2494 | |
buf2497 = buf2411; del buf2411 # reuse | |
# Source Nodes: [c_201, mul_376, silu_28], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2493, buf2495, buf2489, buf2497, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2489 | |
del buf2495 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_202], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2498 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2497, torch.int8) | |
buf2499 = buf2498[0] | |
buf2500 = buf2498[1] | |
del buf2498 | |
# Source Nodes: [input_608], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2501 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2497, buf2499, buf2500, -128, 127, torch.int8) | |
buf2502 = buf2501 | |
del buf2501 | |
# Source Nodes: [input_609], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2503 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2502, buf2499, buf2500, -128, 127, torch.int8, torch.bfloat16) | |
del buf2499 | |
del buf2500 | |
del buf2502 | |
buf2504 = buf2503 | |
del buf2503 | |
# Source Nodes: [w_dq_202], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2505 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg755_1, arg756_1, arg757_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg755_1 | |
del arg756_1 | |
del arg757_1 | |
buf2506 = buf2505 | |
del buf2505 | |
buf2507 = reinterpret_tensor(buf2493, (1, 4096), (4096, 1), 0); del buf2493 # reuse | |
# Source Nodes: [c_202], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2504, buf2506, buf2507, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2506 | |
buf2509 = buf2476; del buf2476 # reuse | |
# Source Nodes: [add_203, h_29, mean_58, mul_377, out_27, out_28, pow_59, rsqrt_58, x_fp32_58, x_normed_58, y_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2474, buf2388, buf2421, buf2507, arg758_1, buf2509, 1, 4096, grid=grid(1), stream=stream0) | |
del arg758_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_203], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2510 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8) | |
buf2511 = buf2510[0] | |
buf2512 = buf2510[1] | |
del buf2510 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_204], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2513 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8) | |
buf2514 = buf2513[0] | |
buf2515 = buf2513[1] | |
del buf2513 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_205], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2516 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8) | |
buf2517 = buf2516[0] | |
buf2518 = buf2516[1] | |
del buf2516 | |
buf2519 = buf2433; del buf2433 # reuse | |
# Source Nodes: [max_30], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2519, 1, grid=grid(1), stream=stream0) | |
u29 = buf2519.item() | |
buf2520 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_611], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2521 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2511, buf2512, -128, 127, torch.int8) | |
buf2522 = buf2521 | |
del buf2521 | |
# Source Nodes: [input_612], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2523 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2522, buf2511, buf2512, -128, 127, torch.int8, torch.bfloat16) | |
del buf2511 | |
del buf2512 | |
del buf2522 | |
buf2524 = buf2523 | |
del buf2523 | |
# Source Nodes: [w_dq_203], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2525 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg759_1, arg760_1, arg761_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg759_1 | |
del arg760_1 | |
del arg761_1 | |
buf2526 = buf2525 | |
del buf2525 | |
buf2527 = reinterpret_tensor(buf2486, (1, 4096), (4096, 1), 0); del buf2486 # reuse | |
# Source Nodes: [c_203], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2524, buf2526, buf2527, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2524 | |
del buf2526 | |
# Source Nodes: [input_614], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2528 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2514, buf2515, -128, 127, torch.int8) | |
buf2529 = buf2528 | |
del buf2528 | |
# Source Nodes: [input_615], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2530 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2529, buf2514, buf2515, -128, 127, torch.int8, torch.bfloat16) | |
del buf2514 | |
del buf2515 | |
del buf2529 | |
buf2531 = buf2530 | |
del buf2530 | |
# Source Nodes: [w_dq_204], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2532 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg762_1, arg763_1, arg764_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg762_1 | |
del arg763_1 | |
del arg764_1 | |
buf2533 = buf2532 | |
del buf2532 | |
buf2534 = buf2456; del buf2456 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2531, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2533, (4096, 1024), (1, 4096), 0), out=buf2534) | |
del buf2531 | |
del buf2533 | |
# Source Nodes: [input_617], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2536 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2517, buf2518, -128, 127, torch.int8) | |
del buf2509 | |
buf2537 = buf2536 | |
del buf2536 | |
# Source Nodes: [input_618], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2538 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2537, buf2517, buf2518, -128, 127, torch.int8, torch.bfloat16) | |
del buf2517 | |
del buf2518 | |
del buf2537 | |
buf2539 = buf2538 | |
del buf2538 | |
# Source Nodes: [w_dq_205], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2540 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg765_1, arg766_1, arg767_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg765_1 | |
del arg766_1 | |
del arg767_1 | |
buf2541 = buf2540 | |
del buf2540 | |
buf2542 = buf2448; del buf2448 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2539, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2541, (4096, 1024), (1, 4096), 0), out=buf2542) | |
del buf2541 | |
buf2544 = reinterpret_tensor(buf2539, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2539 # reuse | |
# Source Nodes: [output_58, setitem_58, setitem_59], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2534, arg768_1, buf2542, buf2527, arg769_1, arg770_1, buf2544, 4096, grid=grid(4096), stream=stream0) | |
del arg768_1 | |
del buf2527 | |
buf2545 = buf2459; del buf2459 # reuse | |
# Source Nodes: [output_58], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2545, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_58], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2546 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2544, arg769_1, arg770_1, buf2545, False) | |
del arg769_1 | |
del arg770_1 | |
del buf2544 | |
buf2547 = buf2546[0] | |
del buf2546 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_206], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2551 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2547, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2552 = buf2551[0] | |
buf2553 = buf2551[1] | |
del buf2551 | |
# Source Nodes: [input_620], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2554 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2547, (1, 1, 4096), (4096, 4096, 1), 0), buf2552, buf2553, -128, 127, torch.int8) | |
buf2555 = buf2554 | |
del buf2554 | |
# Source Nodes: [input_621], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2556 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2555, buf2552, buf2553, -128, 127, torch.int8, torch.bfloat16) | |
del buf2552 | |
del buf2553 | |
del buf2555 | |
buf2557 = buf2556 | |
del buf2556 | |
# Source Nodes: [w_dq_206], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2558 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg771_1, arg772_1, arg773_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg771_1 | |
del arg772_1 | |
del arg773_1 | |
buf2559 = buf2558 | |
del buf2558 | |
buf2560 = reinterpret_tensor(buf2547, (1, 4096), (4096, 1), 0); del buf2547 # reuse | |
# Source Nodes: [c_206], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2557, buf2559, buf2560, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2559 | |
buf2561 = buf2388; del buf2388 # reuse | |
buf2563 = buf2557; del buf2557 # reuse | |
# Source Nodes: [add_208, h_29, h_30, mean_59, mul_387, mul_388, out_27, out_28, pow_60, rsqrt_59, x_fp32_59, x_normed_59], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2561, buf2560, buf2474, buf2421, buf2507, arg774_1, buf2563, 1, 4096, grid=grid(1), stream=stream0) | |
del arg774_1 | |
del buf2421 | |
del buf2474 | |
del buf2507 | |
del buf2560 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_207], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2564 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2563, torch.int8) | |
buf2565 = buf2564[0] | |
buf2566 = buf2564[1] | |
del buf2564 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_208], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2567 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2563, torch.int8) | |
buf2568 = buf2567[0] | |
buf2569 = buf2567[1] | |
del buf2567 | |
# Source Nodes: [input_623], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2570 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2563, buf2565, buf2566, -128, 127, torch.int8) | |
buf2571 = buf2570 | |
del buf2570 | |
# Source Nodes: [input_624], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2572 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2571, buf2565, buf2566, -128, 127, torch.int8, torch.bfloat16) | |
del buf2565 | |
del buf2566 | |
del buf2571 | |
buf2573 = buf2572 | |
del buf2572 | |
# Source Nodes: [w_dq_207], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2574 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg775_1, arg776_1, arg777_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg775_1 | |
del arg776_1 | |
del arg777_1 | |
buf2575 = buf2574 | |
del buf2574 | |
buf2576 = reinterpret_tensor(buf2504, (1, 14336), (14336, 1), 0); del buf2504 # reuse | |
# Source Nodes: [c_207], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2573, buf2575, buf2576, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2575 | |
# Source Nodes: [input_626], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2577 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2563, buf2568, buf2569, -128, 127, torch.int8) | |
buf2578 = buf2577 | |
del buf2577 | |
# Source Nodes: [input_627], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2579 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2578, buf2568, buf2569, -128, 127, torch.int8, torch.bfloat16) | |
del buf2568 | |
del buf2569 | |
del buf2578 | |
buf2580 = buf2579 | |
del buf2579 | |
# Source Nodes: [w_dq_208], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2581 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg778_1, arg779_1, arg780_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg778_1 | |
del arg779_1 | |
del arg780_1 | |
buf2582 = buf2581 | |
del buf2581 | |
buf2584 = buf2497; del buf2497 # reuse | |
# Source Nodes: [c_208, mul_389, silu_29], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2580, buf2582, buf2576, buf2584, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2576 | |
del buf2582 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_209], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2585 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2584, torch.int8) | |
buf2586 = buf2585[0] | |
buf2587 = buf2585[1] | |
del buf2585 | |
# Source Nodes: [input_629], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2588 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2584, buf2586, buf2587, -128, 127, torch.int8) | |
buf2589 = buf2588 | |
del buf2588 | |
# Source Nodes: [input_630], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2590 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2589, buf2586, buf2587, -128, 127, torch.int8, torch.bfloat16) | |
del buf2586 | |
del buf2587 | |
del buf2589 | |
buf2591 = buf2590 | |
del buf2590 | |
# Source Nodes: [w_dq_209], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2592 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg781_1, arg782_1, arg783_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg781_1 | |
del arg782_1 | |
del arg783_1 | |
buf2593 = buf2592 | |
del buf2592 | |
buf2594 = reinterpret_tensor(buf2580, (1, 4096), (4096, 1), 0); del buf2580 # reuse | |
# Source Nodes: [c_209], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2591, buf2593, buf2594, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2593 | |
buf2596 = buf2563; del buf2563 # reuse | |
# Source Nodes: [add_210, mean_60, mul_390, out_29, pow_61, rsqrt_60, x_fp32_60, x_normed_60, y_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2561, buf2594, arg784_1, buf2596, 1, 4096, grid=grid(1), stream=stream0) | |
del arg784_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_210], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2597 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8) | |
buf2598 = buf2597[0] | |
buf2599 = buf2597[1] | |
del buf2597 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_211], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2600 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8) | |
buf2601 = buf2600[0] | |
buf2602 = buf2600[1] | |
del buf2600 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_212], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2603 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8) | |
buf2604 = buf2603[0] | |
buf2605 = buf2603[1] | |
del buf2603 | |
buf2606 = buf2519; del buf2519 # reuse | |
# Source Nodes: [max_31], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2606, 1, grid=grid(1), stream=stream0) | |
u30 = buf2606.item() | |
buf2607 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_632], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2608 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2598, buf2599, -128, 127, torch.int8) | |
buf2609 = buf2608 | |
del buf2608 | |
# Source Nodes: [input_633], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2610 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2609, buf2598, buf2599, -128, 127, torch.int8, torch.bfloat16) | |
del buf2598 | |
del buf2599 | |
del buf2609 | |
buf2611 = buf2610 | |
del buf2610 | |
# Source Nodes: [w_dq_210], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2612 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg785_1, arg786_1, arg787_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg785_1 | |
del arg786_1 | |
del arg787_1 | |
buf2613 = buf2612 | |
del buf2612 | |
buf2614 = reinterpret_tensor(buf2573, (1, 4096), (4096, 1), 0); del buf2573 # reuse | |
# Source Nodes: [c_210], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2611, buf2613, buf2614, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2611 | |
del buf2613 | |
# Source Nodes: [input_635], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2615 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2601, buf2602, -128, 127, torch.int8) | |
buf2616 = buf2615 | |
del buf2615 | |
# Source Nodes: [input_636], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2617 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2616, buf2601, buf2602, -128, 127, torch.int8, torch.bfloat16) | |
del buf2601 | |
del buf2602 | |
del buf2616 | |
buf2618 = buf2617 | |
del buf2617 | |
# Source Nodes: [w_dq_211], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2619 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg788_1, arg789_1, arg790_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg788_1 | |
del arg789_1 | |
del arg790_1 | |
buf2620 = buf2619 | |
del buf2619 | |
buf2621 = buf2542; del buf2542 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2618, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2620, (4096, 1024), (1, 4096), 0), out=buf2621) | |
del buf2618 | |
del buf2620 | |
# Source Nodes: [input_638], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2623 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2604, buf2605, -128, 127, torch.int8) | |
del buf2596 | |
buf2624 = buf2623 | |
del buf2623 | |
# Source Nodes: [input_639], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2625 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2624, buf2604, buf2605, -128, 127, torch.int8, torch.bfloat16) | |
del buf2604 | |
del buf2605 | |
del buf2624 | |
buf2626 = buf2625 | |
del buf2625 | |
# Source Nodes: [w_dq_212], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2627 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg791_1, arg792_1, arg793_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg791_1 | |
del arg792_1 | |
del arg793_1 | |
buf2628 = buf2627 | |
del buf2627 | |
buf2629 = buf2534; del buf2534 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2626, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2628, (4096, 1024), (1, 4096), 0), out=buf2629) | |
del buf2628 | |
buf2631 = reinterpret_tensor(buf2626, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2626 # reuse | |
# Source Nodes: [output_60, setitem_60, setitem_61], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2621, arg794_1, buf2629, buf2614, arg795_1, arg796_1, buf2631, 4096, grid=grid(4096), stream=stream0) | |
del arg794_1 | |
del buf2614 | |
buf2632 = buf2545; del buf2545 # reuse | |
# Source Nodes: [output_60], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2632, 65536, grid=grid(65536), stream=stream0) | |
# Source Nodes: [output_60], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2633 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2631, arg795_1, arg796_1, buf2632, False) | |
del arg795_1 | |
del arg796_1 | |
del buf2631 | |
buf2634 = buf2633[0] | |
del buf2633 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_213], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2638 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2634, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2639 = buf2638[0] | |
buf2640 = buf2638[1] | |
del buf2638 | |
# Source Nodes: [input_641], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2641 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2634, (1, 1, 4096), (4096, 4096, 1), 0), buf2639, buf2640, -128, 127, torch.int8) | |
buf2642 = buf2641 | |
del buf2641 | |
# Source Nodes: [input_642], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2643 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2642, buf2639, buf2640, -128, 127, torch.int8, torch.bfloat16) | |
del buf2639 | |
del buf2640 | |
del buf2642 | |
buf2644 = buf2643 | |
del buf2643 | |
# Source Nodes: [w_dq_213], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2645 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg797_1, arg798_1, arg799_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg797_1 | |
del arg798_1 | |
del arg799_1 | |
buf2646 = buf2645 | |
del buf2645 | |
buf2647 = reinterpret_tensor(buf2634, (1, 4096), (4096, 1), 0); del buf2634 # reuse | |
# Source Nodes: [c_213], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2644, buf2646, buf2647, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2646 | |
buf2649 = buf2644; del buf2644 # reuse | |
# Source Nodes: [add_215, h_31, mean_61, mul_400, mul_401, out_29, pow_62, rsqrt_61, x_fp32_61, x_normed_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2647, buf2561, buf2594, arg800_1, buf2649, 1, 4096, grid=grid(1), stream=stream0) | |
del arg800_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_214], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2650 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2649, torch.int8) | |
buf2651 = buf2650[0] | |
buf2652 = buf2650[1] | |
del buf2650 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_215], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2653 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2649, torch.int8) | |
buf2654 = buf2653[0] | |
buf2655 = buf2653[1] | |
del buf2653 | |
# Source Nodes: [input_644], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2656 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2649, buf2651, buf2652, -128, 127, torch.int8) | |
buf2657 = buf2656 | |
del buf2656 | |
# Source Nodes: [input_645], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2658 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2657, buf2651, buf2652, -128, 127, torch.int8, torch.bfloat16) | |
del buf2651 | |
del buf2652 | |
del buf2657 | |
buf2659 = buf2658 | |
del buf2658 | |
# Source Nodes: [w_dq_214], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2660 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg801_1, arg802_1, arg803_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg801_1 | |
del arg802_1 | |
del arg803_1 | |
buf2661 = buf2660 | |
del buf2660 | |
buf2662 = reinterpret_tensor(buf2591, (1, 14336), (14336, 1), 0); del buf2591 # reuse | |
# Source Nodes: [c_214], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2659, buf2661, buf2662, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2661 | |
# Source Nodes: [input_647], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2663 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2649, buf2654, buf2655, -128, 127, torch.int8) | |
buf2664 = buf2663 | |
del buf2663 | |
# Source Nodes: [input_648], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2665 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2664, buf2654, buf2655, -128, 127, torch.int8, torch.bfloat16) | |
del buf2654 | |
del buf2655 | |
del buf2664 | |
buf2666 = buf2665 | |
del buf2665 | |
# Source Nodes: [w_dq_215], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2667 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg804_1, arg805_1, arg806_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg804_1 | |
del arg805_1 | |
del arg806_1 | |
buf2668 = buf2667 | |
del buf2667 | |
buf2670 = buf2584; del buf2584 # reuse | |
# Source Nodes: [c_215, mul_402, silu_30], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2666, buf2668, buf2662, buf2670, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2662 | |
del buf2668 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_216], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2671 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2670, torch.int8) | |
buf2672 = buf2671[0] | |
buf2673 = buf2671[1] | |
del buf2671 | |
# Source Nodes: [input_650], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2674 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2670, buf2672, buf2673, -128, 127, torch.int8) | |
buf2675 = buf2674 | |
del buf2674 | |
# Source Nodes: [input_651], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2676 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2675, buf2672, buf2673, -128, 127, torch.int8, torch.bfloat16) | |
del buf2672 | |
del buf2673 | |
del buf2675 | |
buf2677 = buf2676 | |
del buf2676 | |
# Source Nodes: [w_dq_216], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2678 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg807_1, arg808_1, arg809_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg807_1 | |
del arg808_1 | |
del arg809_1 | |
buf2679 = buf2678 | |
del buf2678 | |
buf2680 = reinterpret_tensor(buf2666, (1, 4096), (4096, 1), 0); del buf2666 # reuse | |
# Source Nodes: [c_216], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2677, buf2679, buf2680, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2679 | |
buf2682 = buf2649; del buf2649 # reuse | |
# Source Nodes: [add_217, h_31, mean_62, mul_403, out_29, out_30, pow_63, rsqrt_62, x_fp32_62, x_normed_62, y_31], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2647, buf2561, buf2594, buf2680, arg810_1, buf2682, 1, 4096, grid=grid(1), stream=stream0) | |
del arg810_1 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_217], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2683 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8) | |
buf2684 = buf2683[0] | |
buf2685 = buf2683[1] | |
del buf2683 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_218], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2686 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8) | |
buf2687 = buf2686[0] | |
buf2688 = buf2686[1] | |
del buf2686 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_219], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2689 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8) | |
buf2690 = buf2689[0] | |
buf2691 = buf2689[1] | |
del buf2689 | |
buf2692 = buf2606; del buf2606 # reuse | |
# Source Nodes: [max_32], Original ATen: [aten.max] | |
triton_poi_fused_max_1.run(arg3_1, buf2692, 1, grid=grid(1), stream=stream0) | |
u31 = buf2692.item() | |
buf2693 = None | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
# Source Nodes: [input_653], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2694 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2684, buf2685, -128, 127, torch.int8) | |
buf2695 = buf2694 | |
del buf2694 | |
# Source Nodes: [input_654], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2696 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2695, buf2684, buf2685, -128, 127, torch.int8, torch.bfloat16) | |
del buf2684 | |
del buf2685 | |
del buf2695 | |
buf2697 = buf2696 | |
del buf2696 | |
# Source Nodes: [w_dq_217], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2698 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg811_1, arg812_1, arg813_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg811_1 | |
del arg812_1 | |
del arg813_1 | |
buf2699 = buf2698 | |
del buf2698 | |
buf2700 = reinterpret_tensor(buf2659, (1, 4096), (4096, 1), 0); del buf2659 # reuse | |
# Source Nodes: [c_217], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2697, buf2699, buf2700, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2697 | |
del buf2699 | |
# Source Nodes: [input_656], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2701 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2687, buf2688, -128, 127, torch.int8) | |
buf2702 = buf2701 | |
del buf2701 | |
# Source Nodes: [input_657], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2703 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2702, buf2687, buf2688, -128, 127, torch.int8, torch.bfloat16) | |
del buf2687 | |
del buf2688 | |
del buf2702 | |
buf2704 = buf2703 | |
del buf2703 | |
# Source Nodes: [w_dq_218], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2705 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg814_1, arg815_1, arg816_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg814_1 | |
del arg815_1 | |
del arg816_1 | |
buf2706 = buf2705 | |
del buf2705 | |
buf2707 = buf2629; del buf2629 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2704, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2706, (4096, 1024), (1, 4096), 0), out=buf2707) | |
del buf2704 | |
del buf2706 | |
# Source Nodes: [input_659], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2709 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2690, buf2691, -128, 127, torch.int8) | |
del buf2682 | |
buf2710 = buf2709 | |
del buf2709 | |
# Source Nodes: [input_660], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2711 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2710, buf2690, buf2691, -128, 127, torch.int8, torch.bfloat16) | |
del buf2690 | |
del buf2691 | |
del buf2710 | |
buf2712 = buf2711 | |
del buf2711 | |
# Source Nodes: [w_dq_219], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2713 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg817_1, arg818_1, arg819_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg817_1 | |
del arg818_1 | |
del arg819_1 | |
buf2714 = buf2713 | |
del buf2713 | |
buf2715 = buf2621; del buf2621 # reuse | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(reinterpret_tensor(buf2712, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2714, (4096, 1024), (1, 4096), 0), out=buf2715) | |
del buf2714 | |
buf2717 = reinterpret_tensor(buf2712, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2712 # reuse | |
# Source Nodes: [output_62, setitem_62, setitem_63], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put] | |
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2707, arg820_1, buf2715, buf2700, arg821_1, arg822_1, buf2717, 4096, grid=grid(4096), stream=stream0) | |
del arg820_1 | |
del buf2700 | |
del buf2707 | |
del buf2715 | |
buf2718 = buf2632; del buf2632 # reuse | |
# Source Nodes: [output_62], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2718, 65536, grid=grid(65536), stream=stream0) | |
del arg2_1 | |
del arg3_1 | |
# Source Nodes: [output_62], Original ATen: [aten._scaled_dot_product_efficient_attention] | |
buf2719 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2717, arg821_1, arg822_1, buf2718, False) | |
del arg821_1 | |
del arg822_1 | |
del buf2717 | |
del buf2718 | |
buf2720 = buf2719[0] | |
del buf2719 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_220], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2724 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2720, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8) | |
buf2725 = buf2724[0] | |
buf2726 = buf2724[1] | |
del buf2724 | |
# Source Nodes: [input_662], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2727 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2720, (1, 1, 4096), (4096, 4096, 1), 0), buf2725, buf2726, -128, 127, torch.int8) | |
buf2728 = buf2727 | |
del buf2727 | |
# Source Nodes: [input_663], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2729 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2728, buf2725, buf2726, -128, 127, torch.int8, torch.bfloat16) | |
del buf2725 | |
del buf2726 | |
del buf2728 | |
buf2730 = buf2729 | |
del buf2729 | |
# Source Nodes: [w_dq_220], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2731 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg823_1, arg824_1, arg825_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg823_1 | |
del arg824_1 | |
del arg825_1 | |
buf2732 = buf2731 | |
del buf2731 | |
buf2733 = reinterpret_tensor(buf2720, (1, 4096), (4096, 1), 0); del buf2720 # reuse | |
# Source Nodes: [c_220], Original ATen: [aten.mm] | |
triton_tem_fused_mm_2.run(buf2730, buf2732, buf2733, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2732 | |
buf2734 = buf2561; del buf2561 # reuse | |
buf2736 = buf2730; del buf2730 # reuse | |
# Source Nodes: [add_222, h_31, h_32, mean_63, mul_413, mul_414, out_29, out_30, pow_64, rsqrt_63, x_fp32_63, x_normed_63], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2734, buf2733, buf2647, buf2594, buf2680, arg826_1, buf2736, 1, 4096, grid=grid(1), stream=stream0) | |
del arg826_1 | |
del buf2594 | |
del buf2647 | |
del buf2680 | |
del buf2733 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_221], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2737 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2736, torch.int8) | |
buf2738 = buf2737[0] | |
buf2739 = buf2737[1] | |
del buf2737 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_222], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2740 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2736, torch.int8) | |
buf2741 = buf2740[0] | |
buf2742 = buf2740[1] | |
del buf2740 | |
# Source Nodes: [input_665], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2743 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2736, buf2738, buf2739, -128, 127, torch.int8) | |
buf2744 = buf2743 | |
del buf2743 | |
# Source Nodes: [input_666], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2745 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2744, buf2738, buf2739, -128, 127, torch.int8, torch.bfloat16) | |
del buf2738 | |
del buf2739 | |
del buf2744 | |
buf2746 = buf2745 | |
del buf2745 | |
# Source Nodes: [w_dq_221], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2747 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg827_1, arg828_1, arg829_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg827_1 | |
del arg828_1 | |
del arg829_1 | |
buf2748 = buf2747 | |
del buf2747 | |
buf2749 = reinterpret_tensor(buf2677, (1, 14336), (14336, 1), 0); del buf2677 # reuse | |
# Source Nodes: [c_221], Original ATen: [aten.mm] | |
triton_tem_fused_mm_6.run(buf2746, buf2748, buf2749, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2746 | |
del buf2748 | |
# Source Nodes: [input_668], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2750 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2736, buf2741, buf2742, -128, 127, torch.int8) | |
buf2751 = buf2750 | |
del buf2750 | |
# Source Nodes: [input_669], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2752 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2751, buf2741, buf2742, -128, 127, torch.int8, torch.bfloat16) | |
del buf2741 | |
del buf2742 | |
del buf2751 | |
buf2753 = buf2752 | |
del buf2752 | |
# Source Nodes: [w_dq_222], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2754 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg830_1, arg831_1, arg832_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg830_1 | |
del arg831_1 | |
del arg832_1 | |
buf2755 = buf2754 | |
del buf2754 | |
buf2757 = buf2670; del buf2670 # reuse | |
# Source Nodes: [c_222, mul_415, silu_31], Original ATen: [aten.mm, aten.mul, aten.silu] | |
triton_tem_fused_mm_mul_silu_7.run(buf2753, buf2755, buf2749, buf2757, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0) | |
del buf2749 | |
del buf2755 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_223], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2758 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2757, torch.int8) | |
buf2759 = buf2758[0] | |
buf2760 = buf2758[1] | |
del buf2758 | |
# Source Nodes: [input_671], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2761 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2757, buf2759, buf2760, -128, 127, torch.int8) | |
del buf2757 | |
buf2762 = buf2761 | |
del buf2761 | |
# Source Nodes: [input_672], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2763 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2762, buf2759, buf2760, -128, 127, torch.int8, torch.bfloat16) | |
del buf2759 | |
del buf2760 | |
del buf2762 | |
buf2764 = buf2763 | |
del buf2763 | |
# Source Nodes: [w_dq_223], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2765 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg833_1, arg834_1, arg835_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg833_1 | |
del arg834_1 | |
del arg835_1 | |
buf2766 = buf2765 | |
del buf2765 | |
buf2767 = reinterpret_tensor(buf2753, (1, 4096), (4096, 1), 0); del buf2753 # reuse | |
# Source Nodes: [c_223], Original ATen: [aten.mm] | |
triton_tem_fused_mm_8.run(buf2764, buf2766, buf2767, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0) | |
del buf2764 | |
del buf2766 | |
buf2769 = buf2736; del buf2736 # reuse | |
# Source Nodes: [add_224, h_33, mean_64, mul_416, out_31, pow_65, rsqrt_64, x_fp32_64, x_normed_64], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2734, buf2767, arg836_1, buf2769, 1, 4096, grid=grid(1), stream=stream0) | |
del arg836_1 | |
del buf2734 | |
del buf2767 | |
# Source Nodes: [choose_qparams_per_token_asymmetric_224], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric] | |
buf2770 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2769, torch.int8) | |
buf2771 = buf2770[0] | |
buf2772 = buf2770[1] | |
del buf2770 | |
# Source Nodes: [input_674], Original ATen: [quantized_decomposed.quantize_per_token] | |
buf2773 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2769, buf2771, buf2772, -128, 127, torch.int8) | |
del buf2769 | |
buf2774 = buf2773 | |
del buf2773 | |
# Source Nodes: [input_675], Original ATen: [quantized_decomposed.dequantize_per_token] | |
buf2775 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2774, buf2771, buf2772, -128, 127, torch.int8, torch.bfloat16) | |
del buf2771 | |
del buf2772 | |
del buf2774 | |
buf2776 = buf2775 | |
del buf2775 | |
# Source Nodes: [w_dq_224], Original ATen: [quantized_decomposed.dequantize_per_channel_group] | |
buf2777 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg837_1, arg838_1, arg839_1, -8, 7, torch.int8, 256, torch.bfloat16) | |
del arg837_1 | |
del arg838_1 | |
del arg839_1 | |
buf2778 = buf2777 | |
del buf2777 | |
buf2780 = empty_strided_cuda((1, 128256), (128256, 1), torch.float32) | |
# Source Nodes: [c_224, logits_1], Original ATen: [aten.div, aten.mm] | |
triton_tem_fused_div_mm_16.run(buf2776, buf2778, buf2780, grid=torch._inductor.kernel.mm_common.mm_grid(1, 128256, meta2), stream=stream0) | |
del buf2776 | |
del buf2778 | |
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk] | |
buf2781 = torch.ops.aten.topk.default(buf2780, 300) | |
buf2782 = buf2781[0] | |
del buf2781 | |
buf2784 = reinterpret_tensor(buf2692, (1, ), (1, ), 0); del buf2692 # reuse | |
# Source Nodes: [], Original ATen: [] | |
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf2784) | |
buf2786 = empty_strided_cuda((1, 1, 16), (16, 16, 1), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf2780, buf2782, buf2786, 16, 8016, grid=grid(16), stream=stream0) | |
buf2787 = empty_strided_cuda((1, 1), (1, 1), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf2786, buf2787, 1, 16, grid=grid(1), stream=stream0) | |
buf2788 = buf2786; del buf2786 # reuse | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf2780, buf2782, buf2787, buf2788, 16, 8016, grid=grid(16), stream=stream0) | |
buf2789 = empty_strided_cuda((1, 1), (1, 1), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf2788, buf2789, 1, 16, grid=grid(1), stream=stream0) | |
del buf2788 | |
buf2791 = empty_strided_cuda((1, 1), (1, 1), torch.int32) | |
# Source Nodes: [argmax, logits_2, lt, probs, q_128, to_450, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf2784, buf2780, buf2782, buf2787, buf2789, buf2791, 0, 1, 128256, grid=grid(1), stream=stream0) | |
del buf2780 | |
del buf2782 | |
del buf2784 | |
del buf2787 | |
del buf2789 | |
return (buf2791, 1 + u0, 1 + u1, 1 + u2, 1 + u3, 1 + u4, 1 + u5, 1 + u6, 1 + u7, 1 + u8, 1 + u9, 1 + u10, 1 + u11, 1 + u12, 1 + u13, 1 + u14, 1 + u15, 1 + u16, 1 + u17, 1 + u18, 1 + u19, 1 + u20, 1 + u21, 1 + u22, 1 + u23, 1 + u24, 1 + u25, 1 + u26, 1 + u27, 1 + u28, 1 + u29, 1 + u30, 1 + u31, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
arg0_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32) | |
arg1_1 = rand_strided((128256, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg2_1 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.bool) | |
arg3_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int64) | |
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg5_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg6_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg7_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg8_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg9_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg10_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg11_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg12_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg13_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg14_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg15_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg16_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg17_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg18_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg19_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg21_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg22_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg23_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg24_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg25_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg26_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg27_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg28_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg29_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg31_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg32_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg33_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg34_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg35_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg36_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg37_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg38_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg39_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg40_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg41_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg42_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg43_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg44_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg45_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg47_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg48_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg49_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg50_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg51_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg52_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg53_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg54_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg55_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg57_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg58_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg59_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg60_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg61_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg62_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg63_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg64_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg65_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg67_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg68_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg69_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg70_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg71_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg72_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg73_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg74_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg75_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg76_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg77_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg78_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg79_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg80_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg81_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg82_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg83_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg84_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg85_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg86_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg87_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg88_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg89_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg90_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg91_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg92_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg93_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg94_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg95_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg96_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg97_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg98_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg99_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg100_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg101_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg102_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg103_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg104_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg105_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg106_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg107_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg108_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg109_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg110_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg111_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg112_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg113_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg114_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg115_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg116_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg117_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg118_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg119_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg120_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg121_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg122_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg123_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg124_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg125_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg126_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg127_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg128_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg129_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg130_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg131_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg132_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg133_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg134_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg135_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg136_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg137_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg138_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg139_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg140_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg141_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg142_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg143_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg144_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg145_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg146_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg147_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg148_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg149_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg150_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg151_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg152_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg153_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg154_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg155_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg156_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg157_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg158_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg159_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg160_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg161_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg162_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg163_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg164_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg165_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg166_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg167_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg168_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg169_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg170_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg171_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg172_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg173_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg174_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg175_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg176_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg177_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg178_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg179_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg180_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg181_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg182_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg183_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg184_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg185_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg186_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg187_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg188_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg189_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg190_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg191_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg192_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg193_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg194_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg195_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg196_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg197_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg198_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg199_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg200_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg201_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg202_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg203_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg204_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg205_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg206_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg207_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg208_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg209_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg210_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg211_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg212_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg213_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg214_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg215_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg216_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg217_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg218_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg219_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg220_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg221_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg222_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg223_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg224_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg225_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg226_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg227_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg228_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg229_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg230_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg231_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg232_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg233_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg234_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg235_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg236_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg237_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg238_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg239_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg240_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg241_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg242_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg243_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg244_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg245_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg246_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg247_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg248_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg249_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg250_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg251_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg252_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg253_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg254_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg255_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg256_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg257_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg258_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg259_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg260_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg261_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg262_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg263_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg264_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg265_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg266_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg267_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg268_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg269_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg270_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg271_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg272_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg273_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg274_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg275_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg276_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg277_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg278_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg279_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg280_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg281_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg282_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg283_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg284_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg285_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg286_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg287_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg288_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg289_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg290_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg291_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg292_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg293_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg294_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg295_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg296_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg297_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg298_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg299_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg300_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg301_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg302_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg303_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg304_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg305_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg306_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg307_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg308_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg309_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg310_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg311_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg312_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg313_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg314_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg315_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg316_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg317_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg318_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg319_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg320_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg321_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg322_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg323_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg324_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg325_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg326_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg327_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg328_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg329_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg330_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg331_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg332_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg333_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg334_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg335_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg336_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg337_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg338_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg339_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg340_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg341_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg342_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg343_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg344_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg345_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg346_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg347_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg348_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg349_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg350_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg351_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg352_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg353_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg354_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg355_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg356_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg357_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg358_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg359_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg360_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg361_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg362_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg363_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg364_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg365_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg366_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg367_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg368_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg369_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg370_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg371_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg372_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg373_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg374_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg375_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg376_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg377_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg378_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg379_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg380_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg381_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg382_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg383_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg384_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg385_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg386_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg387_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg388_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg389_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg390_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg391_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg392_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg393_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg394_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg395_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg396_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg397_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg398_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg399_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg400_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg401_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg402_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg403_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg404_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg405_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg406_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg407_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg408_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg409_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg410_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg411_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg412_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg413_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg414_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg415_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg416_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg417_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg418_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg419_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg420_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg421_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg422_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg423_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg424_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg425_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg426_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg427_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg428_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg429_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg430_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg431_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg432_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg433_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg434_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg435_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg436_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg437_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg438_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg439_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg440_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg441_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg442_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg443_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg444_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg445_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg446_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg447_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg448_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg449_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg450_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg451_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg452_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg453_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg454_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg455_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg456_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg457_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg458_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg459_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg460_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg461_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg462_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg463_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg464_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg465_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg466_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg467_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg468_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg469_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg470_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg471_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg472_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg473_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg474_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg475_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg476_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg477_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg478_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg479_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg480_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg481_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg482_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg483_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg484_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg485_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg486_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg487_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg488_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg489_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg490_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg491_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg492_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg493_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg494_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg495_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg496_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg497_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg498_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg499_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg500_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg501_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg502_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg503_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg504_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg505_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg506_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg507_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg508_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg509_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg510_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg511_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg512_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg513_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg514_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg515_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg516_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg517_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg518_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg519_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg520_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg521_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg522_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg523_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg524_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg525_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg526_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg527_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg528_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg529_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg530_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg531_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg532_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg533_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg534_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg535_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg536_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg537_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg538_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg539_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg540_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg541_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg542_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg543_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg544_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg545_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg546_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg547_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg548_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg549_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg550_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg551_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg552_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg553_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg554_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg555_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg556_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg557_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg558_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg559_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg560_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg561_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg562_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg563_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg564_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg565_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg566_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg567_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg568_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg569_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg570_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg571_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg572_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg573_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg574_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg575_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg576_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg577_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg578_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg579_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg580_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg581_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg582_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg583_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg584_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg585_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg586_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg587_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg588_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg589_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg590_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg591_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg592_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg593_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg594_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg595_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg596_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg597_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg598_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg599_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg600_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg601_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg602_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg603_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg604_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg605_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg606_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg607_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg608_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg609_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg610_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg611_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg612_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg613_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg614_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg615_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg616_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg617_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg618_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg619_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg620_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg621_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg622_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg623_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg624_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg625_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg626_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg627_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg628_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg629_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg630_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg631_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg632_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg633_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg634_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg635_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg636_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg637_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg638_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg639_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg640_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg641_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg642_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg643_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg644_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg645_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg646_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg647_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg648_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg649_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg650_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg651_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg652_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg653_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg654_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg655_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg656_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg657_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg658_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg659_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg660_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg661_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg662_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg663_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg664_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg665_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg666_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg667_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg668_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg669_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg670_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg671_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg672_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg673_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg674_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg675_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg676_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg677_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg678_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg679_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg680_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg681_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg682_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg683_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg684_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg685_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg686_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg687_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg688_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg689_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg690_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg691_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg692_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg693_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg694_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg695_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg696_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg697_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg698_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg699_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg700_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg701_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg702_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg703_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg704_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg705_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg706_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg707_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg708_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg709_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg710_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg711_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg712_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg713_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg714_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg715_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg716_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg717_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg718_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg719_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg720_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg721_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg722_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg723_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg724_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg725_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg726_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg727_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg728_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg729_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg730_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg731_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg732_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg733_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg734_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg735_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg736_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg737_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg738_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg739_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg740_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg741_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg742_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg743_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg744_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg745_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg746_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg747_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg748_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg749_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg750_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg751_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg752_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg753_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg754_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg755_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg756_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg757_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg758_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg759_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg760_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg761_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg762_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg763_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg764_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg765_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg766_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg767_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg768_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg769_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg770_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg771_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg772_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg773_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg774_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg775_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg776_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg777_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg778_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg779_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg780_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg781_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg782_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg783_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg784_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg785_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg786_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg787_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg788_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg789_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg790_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg791_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg792_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg793_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg794_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg795_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg796_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg797_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg798_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg799_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg800_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg801_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg802_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg803_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg804_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg805_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg806_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg807_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg808_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg809_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg810_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg811_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg812_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg813_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg814_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg815_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg816_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg817_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg818_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg819_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg820_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32) | |
arg821_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg822_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg823_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg824_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg825_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg826_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg827_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg828_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg829_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg830_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg831_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg832_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg833_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8) | |
arg834_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg835_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg836_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
arg837_1 = rand_strided((128256, 4096), (4096, 1), device='cuda:0', dtype=torch.int8) | |
arg838_1 = rand_strided((128256, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
arg839_1 = rand_strided((128256, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16) | |
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, 
arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, 
arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1, arg456_1, arg457_1, arg458_1, arg459_1, arg460_1, arg461_1, arg462_1, arg463_1, arg464_1, arg465_1, arg466_1, arg467_1, arg468_1, arg469_1, arg470_1, arg471_1, arg472_1, arg473_1, arg474_1, arg475_1, arg476_1, arg477_1, arg478_1, arg479_1, arg480_1, arg481_1, arg482_1, arg483_1, arg484_1, arg485_1, arg486_1, arg487_1, arg488_1, arg489_1, arg490_1, arg491_1, arg492_1, arg493_1, arg494_1, arg495_1, arg496_1, arg497_1, arg498_1, arg499_1, arg500_1, arg501_1, arg502_1, arg503_1, arg504_1, arg505_1, arg506_1, arg507_1, arg508_1, arg509_1, arg510_1, arg511_1, arg512_1, arg513_1, arg514_1, arg515_1, arg516_1, arg517_1, arg518_1, arg519_1, arg520_1, arg521_1, arg522_1, arg523_1, arg524_1, arg525_1, arg526_1, arg527_1, arg528_1, arg529_1, arg530_1, arg531_1, arg532_1, arg533_1, arg534_1, arg535_1, arg536_1, arg537_1, arg538_1, arg539_1, arg540_1, arg541_1, arg542_1, arg543_1, arg544_1, arg545_1, arg546_1, arg547_1, arg548_1, arg549_1, arg550_1, arg551_1, arg552_1, arg553_1, arg554_1, arg555_1, arg556_1, arg557_1, arg558_1, arg559_1, arg560_1, arg561_1, arg562_1, arg563_1, arg564_1, arg565_1, arg566_1, arg567_1, arg568_1, arg569_1, arg570_1, arg571_1, arg572_1, arg573_1, arg574_1, arg575_1, arg576_1, arg577_1, arg578_1, arg579_1, arg580_1, arg581_1, arg582_1, arg583_1, arg584_1, arg585_1, arg586_1, arg587_1, arg588_1, arg589_1, arg590_1, arg591_1, arg592_1, arg593_1, arg594_1, arg595_1, arg596_1, arg597_1, arg598_1, arg599_1, arg600_1, arg601_1, arg602_1, arg603_1, arg604_1, arg605_1, arg606_1, arg607_1, arg608_1, 
arg609_1, arg610_1, arg611_1, arg612_1, arg613_1, arg614_1, arg615_1, arg616_1, arg617_1, arg618_1, arg619_1, arg620_1, arg621_1, arg622_1, arg623_1, arg624_1, arg625_1, arg626_1, arg627_1, arg628_1, arg629_1, arg630_1, arg631_1, arg632_1, arg633_1, arg634_1, arg635_1, arg636_1, arg637_1, arg638_1, arg639_1, arg640_1, arg641_1, arg642_1, arg643_1, arg644_1, arg645_1, arg646_1, arg647_1, arg648_1, arg649_1, arg650_1, arg651_1, arg652_1, arg653_1, arg654_1, arg655_1, arg656_1, arg657_1, arg658_1, arg659_1, arg660_1, arg661_1, arg662_1, arg663_1, arg664_1, arg665_1, arg666_1, arg667_1, arg668_1, arg669_1, arg670_1, arg671_1, arg672_1, arg673_1, arg674_1, arg675_1, arg676_1, arg677_1, arg678_1, arg679_1, arg680_1, arg681_1, arg682_1, arg683_1, arg684_1, arg685_1, arg686_1, arg687_1, arg688_1, arg689_1, arg690_1, arg691_1, arg692_1, arg693_1, arg694_1, arg695_1, arg696_1, arg697_1, arg698_1, arg699_1, arg700_1, arg701_1, arg702_1, arg703_1, arg704_1, arg705_1, arg706_1, arg707_1, arg708_1, arg709_1, arg710_1, arg711_1, arg712_1, arg713_1, arg714_1, arg715_1, arg716_1, arg717_1, arg718_1, arg719_1, arg720_1, arg721_1, arg722_1, arg723_1, arg724_1, arg725_1, arg726_1, arg727_1, arg728_1, arg729_1, arg730_1, arg731_1, arg732_1, arg733_1, arg734_1, arg735_1, arg736_1, arg737_1, arg738_1, arg739_1, arg740_1, arg741_1, arg742_1, arg743_1, arg744_1, arg745_1, arg746_1, arg747_1, arg748_1, arg749_1, arg750_1, arg751_1, arg752_1, arg753_1, arg754_1, arg755_1, arg756_1, arg757_1, arg758_1, arg759_1, arg760_1, arg761_1, arg762_1, arg763_1, arg764_1, arg765_1, arg766_1, arg767_1, arg768_1, arg769_1, arg770_1, arg771_1, arg772_1, arg773_1, arg774_1, arg775_1, arg776_1, arg777_1, arg778_1, arg779_1, arg780_1, arg781_1, arg782_1, arg783_1, arg784_1, arg785_1, arg786_1, arg787_1, arg788_1, arg789_1, arg790_1, arg791_1, arg792_1, arg793_1, arg794_1, arg795_1, arg796_1, arg797_1, arg798_1, arg799_1, arg800_1, arg801_1, arg802_1, arg803_1, arg804_1, arg805_1, arg806_1, arg807_1, arg808_1, 
arg809_1, arg810_1, arg811_1, arg812_1, arg813_1, arg814_1, arg815_1, arg816_1, arg817_1, arg818_1, arg819_1, arg820_1, arg821_1, arg822_1, arg823_1, arg824_1, arg825_1, arg826_1, arg827_1, arg828_1, arg829_1, arg830_1, arg831_1, arg832_1, arg833_1, arg834_1, arg835_1, arg836_1, arg837_1, arg838_1, arg839_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
# Script entry point: runs only when this generated module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # The benchmark driver is imported inside the guard, so it is only
    # loaded on the script path.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # The first argument is the literal string 'None' (not the None object);
    # presumably it is the benchmark name -- TODO confirm against
    # torch._inductor.wrapper_benchmark. The second argument is the
    # benchmark function defined earlier in this file.
    compiled_module_main(
        'None',
        benchmark_compiled_module,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment