Skip to content

Instantly share code, notes, and snippets.

@mreso
Created August 16, 2024 21:04
Show Gist options
  • Save mreso/61d5d384854d1a6ef67c331803116d41 to your computer and use it in GitHub Desktop.
Save mreso/61d5d384854d1a6ef67c331803116d41 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_mreso/cp/ccpansfhuhbzfmq6fiemul7defvnmxsdsnwsbejgv3aqviivkozv.py
# Source Nodes: [add, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add => add
# h => embedding
# mean => mean
# mul => mul
# pow_1 => pow_1
# rsqrt => rsqrt
# x_fp32 => convert_element_type
# x_normed => convert_element_type_1
# y => mul_1
#
# Fused embedding gather + RMS normalisation of a single row (xnumel = 1):
#   * row id = in_ptr0[0]; negative ids wrap by +128256 and are device-asserted
#     to lie in [0, 128256)
#   * pass 1 accumulates sum(x**2) in fp32 over the 4096 columns of that row of
#     the table in_ptr1
#   * pass 2 writes x * rsqrt(sum / 4096 + 1e-05) * in_ptr2[r] to out_ptr1
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp2 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 128256), "index out of bounds: 0 <= tmp5 < 128256")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp8 * tmp8
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tl.where(rmask, tmp12, _tmp11)
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tmp13 = tl.load(in_ptr0 + (0))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp16 = tmp14 + tmp15
        tmp17 = tmp14 < 0
        tmp18 = tl.where(tmp17, tmp16, tmp14)
        tl.device_assert((0 <= tmp18) & (tmp18 < 128256), "index out of bounds: 0 <= tmp18 < 128256")
        tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tmp20.to(tl.float32)
        tmp22 = 4096.0
        tmp23 = tmp11 / tmp22
        tmp24 = 1e-05
        tmp25 = tmp23 + tmp24
        tmp26 = libdevice.rsqrt(tmp25)
        tmp27 = tmp21 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tmp30 = tmp28 * tmp29
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_mreso/zh/czhm7z6vmgstxp34zha46q2znijjccrf4zpqaatdp4lazkbs7cyh.py
# Source Nodes: [max_1], Original ATen: [aten.max]
# max_1 => max_1
#
# With xnumel = 1 this max reduction degenerates to copying the single int64
# value in_ptr0[0] into out_ptr0[0].
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_poi_fused_max_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(1,), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_max_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/dm/cdmuegqtdrjgliya52h5xlbmwaerh3jvmz3o234f62w2wpha4dsv.py
# Source Nodes: [c], Original ATen: [aten.mm]
# c => mm
#
# Triton mm template for M = 1, N = 4096, K = 4096:
#   out_ptr0[n] = sum_k A[k] * B[n * 4096 + k]
# B is read with stride_bk = 1 and stride_bn = 4096, i.e. as an [N, K]
# row-major buffer (presumably a Linear weight — confirm against the model).
# Tiles: BLOCK_M=16, BLOCK_N=64, BLOCK_K=128; fp32 accumulation, no TF32.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_tem_fused_mm_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_2', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 128
    A = arg_A
    B = arg_B
    M = 1
    N = 4096
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (4096*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
import torch._inductor.kernel.mm_common
# Compile-time config matching the constexpr values hard-coded in the
# BLOCK_K=128 mm template kernel(s) above (e.g. triton_tem_fused_mm_2).
meta0 = dict(
    GROUP_M=8,
    EVEN_K=True,
    ALLOW_TF32=False,
    ACC_TYPE='tl.float32',
    B_PROLOGUE_CAST_TYPE=None,
    BLOCK_M=16,
    BLOCK_N=64,
    BLOCK_K=128,
)
# kernel path: /tmp/torchinductor_mreso/lz/clz2e2d6xj6n3oho4nys4dfgydzhykoegeiet3dzqpcasieoh5up.py
# Source Nodes: [output, setitem, setitem_1], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
# output => _scaled_dot_product_efficient_attention
# setitem => index_put
# setitem_1 => index_put_1
#
# Pointwise kernel over 4096 = 32 x 128 elements (x1 = xindex // 128 in 0..31,
# x0 = xindex % 128). p = in_ptr0[0] (negative values wrap by +2048,
# device-asserted to lie in [0, 2048)) selects row p of the fp32 table in_ptr2.
# For each pair (a, b) = elements (2*(x0//2), 2*(x0//2)+1) and (c, s) = the
# same pair from in_ptr2 row p, even lanes get a*c - b*s and odd lanes b*c + a*s.
#   * out_ptr0 at x0 + 128*p + 262144*x1 <- in_ptr1 row (x1 // 4), rotated
#   * out_ptr1 at x0 + 128*p + 262144*x1 <- in_ptr3 row (x1 // 4), unchanged
#   * out_ptr2 at x2                     <- in_ptr4 row x1, rotated
# out_ptr0/out_ptr1 are mutated in place (mutated_arg_names).
# NOTE(review): presumably rotary embedding of keys (8 heads repeated 4x) and
# values written into a KV cache, plus rotary embedding of queries — confirm.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4096],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*fp32', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3', 'mutated_arg_names': ['out_ptr0', 'out_ptr1'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x2 = xindex
    x0 = xindex % 128
    x1 = (xindex // 128)
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp41 = tl.load(in_ptr3 + (x0 + (128*(x1 // 4))), None).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 2048, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 2048), "index out of bounds: 0 <= tmp5 < 2048")
    tmp7 = x2 % 2
    tmp8 = tl.full([1], 0, tl.int64)
    tmp9 = tmp7 >= tmp8
    tmp10 = tl.full([1], 1, tl.int64)
    tmp11 = tmp7 < tmp10
    tmp12 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*(x1 // 4))), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp12.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp5, [XBLOCK])) & (tl.broadcast_to(tmp5, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp5, [XBLOCK]) < 2048")
    tmp15 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp5)), tmp11, eviction_policy='evict_last', other=0.0)
    tmp16 = tmp13 * tmp15
    tmp17 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*(x1 // 4))), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp5)), tmp11, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 * tmp19
    tmp21 = tmp16 - tmp20
    tmp22 = tl.full(tmp21.shape, 0.0, tmp21.dtype)
    tmp23 = tl.where(tmp11, tmp21, tmp22)
    tmp24 = tmp7 >= tmp10
    tmp25 = tl.full([1], 2, tl.int64)
    tmp26 = tmp7 < tmp25
    tmp27 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*(x1 // 4))), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp28 = tmp27.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp5, [XBLOCK])) & (tl.broadcast_to(tmp5, [XBLOCK]) < 2048)) | ~(tmp24), "index out of bounds: 0 <= tl.broadcast_to(tmp5, [XBLOCK]) < 2048")
    tmp30 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp5)), tmp24, eviction_policy='evict_last', other=0.0)
    tmp31 = tmp28 * tmp30
    tmp32 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*(x1 // 4))), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp33 = tmp32.to(tl.float32)
    tmp34 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp5)), tmp24, eviction_policy='evict_last', other=0.0)
    tmp35 = tmp33 * tmp34
    tmp36 = tmp31 + tmp35
    tmp37 = tl.full(tmp36.shape, 0.0, tmp36.dtype)
    tmp38 = tl.where(tmp24, tmp36, tmp37)
    tmp39 = tl.where(tmp11, tmp23, tmp38)
    tmp40 = tmp39.to(tl.float32)
    tmp42 = tl.load(in_ptr4 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp43 = tmp42.to(tl.float32)
    tmp44 = tmp43 * tmp15
    tmp45 = tl.load(in_ptr4 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp46 = tmp45.to(tl.float32)
    tmp47 = tmp46 * tmp19
    tmp48 = tmp44 - tmp47
    tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
    tmp50 = tl.where(tmp11, tmp48, tmp49)
    tmp51 = tl.load(in_ptr4 + (1 + (2*(x0 // 2)) + (128*x1)), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp52 = tmp51.to(tl.float32)
    tmp53 = tmp52 * tmp30
    tmp54 = tl.load(in_ptr4 + ((2*(x0 // 2)) + (128*x1)), tmp24, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp55 = tmp54.to(tl.float32)
    tmp56 = tmp55 * tmp34
    tmp57 = tmp53 + tmp56
    tmp58 = tl.full(tmp57.shape, 0.0, tmp57.dtype)
    tmp59 = tl.where(tmp24, tmp57, tmp58)
    tmp60 = tl.where(tmp11, tmp50, tmp59)
    tmp61 = tmp60.to(tl.float32)
    tl.store(out_ptr0 + (x0 + (128*tmp5) + (262144*x1)), tmp40, None)
    tl.store(out_ptr1 + (x0 + (128*tmp5) + (262144*x1)), tmp41, None)
    tl.store(out_ptr2 + (x2), tmp61, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/n6/cn6h3fga4dekm5cvy4s4fhrazz4zcn3kegoonkhz2piybh75ca7i.py
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention]
# output => _scaled_dot_product_efficient_attention
#
# Builds an additive attention bias: with p = in_ptr0[0] (negative values wrap
# by +2048, device-asserted to lie in [0, 2048)) and c = xindex % 2048, each of
# the 65536 = 32 * 2048 outputs is -inf where the boolean in_ptr1[c + 2048*p]
# is 0 and 0.0 otherwise, i.e. row p of the mask repeated 32 times.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_poi_fused__scaled_dot_product_efficient_attention_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[65536],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*i1', 2: '*bf16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 65536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 2048
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = tl.full([XBLOCK], 2048, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 2048), "index out of bounds: 0 <= tmp5 < 2048")
    tmp7 = tl.load(in_ptr1 + (x0 + (2048*tmp5)), None).to(tl.int1)
    tmp8 = tmp7 == 0
    tmp9 = float("-inf")
    tmp10 = 0.0
    tmp11 = tl.where(tmp8, tmp9, tmp10)
    tl.store(out_ptr0 + (x2), tmp11, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/hi/chi7k4nldga4ozz4jralolpcemic42rouob3i2giv2okoaj746yc.py
# Source Nodes: [add_5, h, h_1, mean_1, mul_10, mul_11, pow_2, rsqrt_1, x_fp32_1, x_normed_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_5 => add_4
# h => embedding
# h_1 => add_3
# mean_1 => mean_1
# mul_10 => mul_10
# mul_11 => mul_11
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# x_fp32_1 => convert_element_type_14
# x_normed_1 => convert_element_type_15
#
# RMS normalisation (xnumel = 1, 4096 columns) of the residual sum
#   h[r] = in_ptr0[r] + in_ptr2[r + 4096 * id],  id = in_ptr1[0]
# (negative ids wrap by +128256 and are device-asserted to lie in [0, 128256)).
# Pass 1 accumulates sum(h**2) in fp32; pass 2 recomputes h (the embedding
# gather is redone rather than read back) and writes
# h * rsqrt(sum / 4096 + 1e-05) * in_ptr3[r] to out_ptr1.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*i32', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp4 = tmp2 + tmp3
        tmp5 = tmp2 < 0
        tmp6 = tl.where(tmp5, tmp4, tmp2)
        tl.device_assert((0 <= tmp6) & (tmp6 < 128256), "index out of bounds: 0 <= tmp6 < 128256")
        tmp8 = tl.load(in_ptr2 + (r0 + (4096*tmp6)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp0 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tmp16 = tl.load(in_ptr1 + (0))
    tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp15 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp18 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
        tmp19 = tmp17 + tmp18
        tmp20 = tmp17 < 0
        tmp21 = tl.where(tmp20, tmp19, tmp17)
        tl.device_assert((0 <= tmp21) & (tmp21 < 128256), "index out of bounds: 0 <= tmp21 < 128256")
        tmp23 = tl.load(in_ptr2 + (r0 + (4096*tmp21)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp24 = tmp15 + tmp23
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 4096.0
        tmp27 = tmp13 / tmp26
        tmp28 = 1e-05
        tmp29 = tmp27 + tmp28
        tmp30 = libdevice.rsqrt(tmp29)
        tmp31 = tmp25 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp34 = tmp32 * tmp33
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/ur/curh37b7uwqy7x2z7p6suloz7vk3aiutkohhxggzhrcmv6ypkm2u.py
# Source Nodes: [c_4], Original ATen: [aten.mm]
# c_4 => mm_4
#
# Triton mm template for M = 1, N = 14336, K = 4096:
#   out_ptr0[n] = sum_k A[k] * B[n * 4096 + k]
# B is read with stride_bk = 1 and stride_bn = 4096, i.e. as an [N, K]
# row-major buffer (presumably a Linear weight — confirm against the model).
# Tiles: BLOCK_M=16, BLOCK_N=64, BLOCK_K=32, num_stages=4; fp32 accumulation.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_tem_fused_mm_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=4,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_6', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 1
    N = 14336
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (14336*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
# Compile-time config matching the constexpr values hard-coded in the
# BLOCK_K=32 mm template kernel(s) (e.g. triton_tem_fused_mm_6).
meta1 = dict(
    GROUP_M=8,
    EVEN_K=True,
    ALLOW_TF32=False,
    ACC_TYPE='tl.float32',
    B_PROLOGUE_CAST_TYPE=None,
    BLOCK_M=16,
    BLOCK_N=64,
    BLOCK_K=32,
)
# kernel path: /tmp/torchinductor_mreso/gg/cggdihc3hz6r76zojlkqkx4mohi7rtngag436gzcigqfbq6qr63m.py
# Source Nodes: [c_5, mul_12, silu], Original ATen: [aten.mm, aten.mul, aten.silu]
# c_5 => mm_5
# mul_12 => mul_13
# silu => convert_element_type_18, convert_element_type_19, mul_12, sigmoid
#
# Triton mm template (M = 1, N = 14336, K = 4096; B read as an [N, K]
# row-major buffer) with a fused epilogue:
#   out_ptr1[n] = silu(in_ptr2[n]) * (A @ B)[n],  silu(x) = x * sigmoid(x)
# where the silu is evaluated in fp32. Tiles: BLOCK_M=16, BLOCK_N=64,
# BLOCK_K=32, num_stages=5.
# The kernel is compiled from the string below, so the Python indentation of
# the embedded Triton source is significant and must be kept.
triton_tem_fused_mm_mul_silu_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
    num_stages=5,
    num_warps=4,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mm_mul_silu_7', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, in_ptr2, out_ptr1):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 1
    N = 14336
    K = 4096
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 4096
    stride_ak = 1
    stride_bk = 1
    stride_bn = 4096
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (14336*idx_m)
    tmp0 = tl.load(in_ptr2 + (tl.broadcast_to(xindex, acc.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = tmp4 * acc
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp5, mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/6m/c6m3krqefnziqfhazzhlqeozbuyylbyb2cp5ijykpbol2gg2v3qr.py
# Source Nodes: [c_6], Original ATen: [aten.mm]
# c_6 => mm_6
#
# Triton matmul template with no fused epilogue: out_ptr0 = A @ B.
# Problem size is hard-coded in the kernel text: M=1, N=4096, K=14336.
#
# Strides are also hard-coded:
#   - A is read with stride_am=14336 and stride_ak=1, i.e. row-major (M, K).
#   - B is read with stride_bk=1 and stride_bn=14336, so element (k, n) lives at
#     n*14336 + k. B is therefore an (N, K) row-major buffer used transposed
#     (presumably a linear-layer weight; TODO confirm at the call site).
#
# Tile config: BLOCK_M=16, BLOCK_N=64, BLOCK_K=128, num_stages=5, num_warps=4.
#   - EVEN_K=True, so the K loop runs unmasked.
#   - Accumulation is in fp32 (ACC_TYPE) with ALLOW_TF32=False.
#   - `rm % M` makes every row of the 16-row tile read row 0 of A. The store
#     mask (idx_m < M) then keeps only row 0.
#   - Because only row 0 survives the mask, storing at offset idx_n is
#     equivalent to storing at xindex. The `xindex` computed in the suffix is
#     unused.
#
# out_ptr0 is declared '*bf16' in triton_meta.
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_tem_fused_mm_8 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
num_stages=5,
num_warps=4,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'kernel_name': 'triton_tem_fused_mm_8', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = None
BLOCK_M : tl.constexpr = 16
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 128
A = arg_A
B = arg_B
M = 1
N = 4096
K = 14336
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = 14336
stride_ak = 1
stride_bk = 1
stride_bn = 14336
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
else:
ram = rm % M
if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
else:
rbn = rn % N
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (4096*idx_m)
tl.store(out_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), acc, mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/fg/cfghtgfe4n4ny6pchk5ahx43xbhxnqllrs5nwoae7yaniw5lxpyv.py
# Source Nodes: [add_7, h, h_1, mean_2, mul_13, out, pow_3, rsqrt_2, x_fp32_2, x_normed_2, y_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_7 => add_6
# h => embedding
# h_1 => add_3
# mean_2 => mean_2
# mul_13 => mul_14
# out => add_5
# pow_3 => pow_3
# rsqrt_2 => rsqrt_2
# x_fp32_2 => convert_element_type_24
# x_normed_2 => convert_element_type_25
# y_1 => mul_15
#
# Fused row-gather + two elementwise adds + RMS normalization over a single
# row of 4096 elements. The kernel text hard-codes xnumel = 1 and rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_ptr0 (bf16), in_ptr3 (bf16): length-4096 vectors read at r0.
#   in_ptr1 (i32): one index, read at offset 0. A negative index is wrapped by
#       adding 128256. The result is device_assert-ed to lie in [0, 128256).
#   in_ptr2 (bf16): table read at r0 + 4096 * index. Per "h => embedding"
#       above, this is presumably the token-embedding table; TODO confirm.
#   in_ptr4 (bf16): per-element scale, read only in the second pass.
#   out_ptr1 (bf16): receives the 4096-element result.
#
# Computation (every load is converted to fp32 first):
#   x   = (in_ptr0 + in_ptr2[index]) + in_ptr3
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr4
#
# Pass 1 accumulates sum(x * x) in RBLOCK-wide chunks. Pass 2 recomputes x
# (x is not written to memory) and writes the scaled result.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*i32', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
tmp1 = tl.load(in_ptr1 + (0))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
_tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
tmp4 = tmp2 + tmp3
tmp5 = tmp2 < 0
tmp6 = tl.where(tmp5, tmp4, tmp2)
tl.device_assert((0 <= tmp6) & (tmp6 < 128256), "index out of bounds: 0 <= tmp6 < 128256")
tmp8 = tl.load(in_ptr2 + (r0 + (4096*tmp6)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp9 = tmp0 + tmp8
tmp11 = tmp9 + tmp10
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp12 * tmp12
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
tmp16 = _tmp15 + tmp14
_tmp15 = tl.where(rmask, tmp16, _tmp15)
tmp15 = tl.sum(_tmp15, 1)[:, None]
tmp18 = tl.load(in_ptr1 + (0))
tmp19 = tl.broadcast_to(tmp18, [XBLOCK, RBLOCK])
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp17 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp20 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
tmp21 = tmp19 + tmp20
tmp22 = tmp19 < 0
tmp23 = tl.where(tmp22, tmp21, tmp19)
tl.device_assert((0 <= tmp23) & (tmp23 < 128256), "index out of bounds: 0 <= tmp23 < 128256")
tmp25 = tl.load(in_ptr2 + (r0 + (4096*tmp23)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp26 = tmp17 + tmp25
tmp28 = tmp26 + tmp27
tmp29 = tmp28.to(tl.float32)
tmp30 = 4096.0
tmp31 = tmp15 / tmp30
tmp32 = 1e-05
tmp33 = tmp31 + tmp32
tmp34 = libdevice.rsqrt(tmp33)
tmp35 = tmp29 * tmp34
tmp36 = tmp35.to(tl.float32)
tmp38 = tmp36 * tmp37
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/zo/czoyobha6z3774km535v7px7sq2dkfatukjcv4fgmxonhuumqv67.py
# Source Nodes: [add_12, h, h_1, h_2, mean_3, mul_23, mul_24, out, pow_4, rsqrt_3, x_fp32_3, x_normed_3], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_12 => add_10
# h => embedding
# h_1 => add_3
# h_2 => add_9
# mean_3 => mean_3
# mul_23 => mul_24
# mul_24 => mul_25
# out => add_5
# pow_4 => pow_4
# rsqrt_3 => rsqrt_3
# x_fp32_3 => convert_element_type_38
# x_normed_3 => convert_element_type_39
#
# Fused row-gather + residual accumulation (written back in place) + RMS
# normalization over one 4096-element row. The kernel text hard-codes
# xnumel = 1 and rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_out_ptr0 (bf16): read, then overwritten with x. It is listed in
#       mutated_arg_names.
#   in_ptr0 (bf16), in_ptr3 (bf16): length-4096 vectors read at r0.
#   in_ptr1 (i32): one index, read at offset 0. A negative index is wrapped by
#       adding 128256. The result is device_assert-ed to lie in [0, 128256).
#   in_ptr2 (bf16): table read at r0 + 4096 * index. This is presumably the
#       token-embedding table ("h => embedding"); TODO confirm.
#   in_ptr4 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the normalized, scaled row.
#
# Computation (loads are converted to fp32 first):
#   x   = in_out_ptr0 + ((in_ptr0 + in_ptr2[index]) + in_ptr3)
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr4
#
# Pass 1 stores x into in_out_ptr0 while accumulating sum(x * x) from the
# fp32 value. Pass 2 does not recompute x; it re-reads it from in_out_ptr0,
# which is declared '*bf16'.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*i32', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
tmp2 = tl.load(in_ptr1 + (0))
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
_tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp11 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tl.full([XBLOCK, RBLOCK], 128256, tl.int32)
tmp5 = tmp3 + tmp4
tmp6 = tmp3 < 0
tmp7 = tl.where(tmp6, tmp5, tmp3)
tl.device_assert((0 <= tmp7) & (tmp7 < 128256), "index out of bounds: 0 <= tmp7 < 128256")
tmp9 = tl.load(in_ptr2 + (r0 + (4096*tmp7)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp10 = tmp1 + tmp9
tmp12 = tmp10 + tmp11
tmp13 = tmp0 + tmp12
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp14 * tmp14
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
tmp18 = _tmp17 + tmp16
_tmp17 = tl.where(rmask, tmp18, _tmp17)
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask)
tmp17 = tl.sum(_tmp17, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp20 = tmp19.to(tl.float32)
tmp21 = 4096.0
tmp22 = tmp17 / tmp21
tmp23 = 1e-05
tmp24 = tmp22 + tmp23
tmp25 = libdevice.rsqrt(tmp24)
tmp26 = tmp20 * tmp25
tmp27 = tmp26.to(tl.float32)
tmp29 = tmp27 * tmp28
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/t7/ct7gses5bjf3ss3snvk3ainp7pao5q3rukpyjc5fe32c5zvf4q4d.py
# Source Nodes: [add_14, mean_4, mul_26, out_1, pow_5, rsqrt_4, x_fp32_4, x_normed_4, y_2], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_14 => add_12
# mean_4 => mean_4
# mul_26 => mul_28
# out_1 => add_11
# pow_5 => pow_5
# rsqrt_4 => rsqrt_4
# x_fp32_4 => convert_element_type_48
# x_normed_4 => convert_element_type_49
# y_2 => mul_29
#
# Fused elementwise add + RMS normalization over one 4096-element row. The
# kernel text hard-codes xnumel = 1 and rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_ptr0 (bf16), in_ptr1 (bf16): length-4096 vectors that are summed.
#   in_ptr2 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the result.
#
# Computation (loads are converted to fp32 first):
#   x   = in_ptr0 + in_ptr1
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr2
#
# Pass 1 accumulates sum(x * x). Pass 2 recomputes x, because x is never
# stored, and writes the result.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
tmp7 = _tmp6 + tmp5
_tmp6 = tl.where(rmask, tmp7, _tmp6)
tmp6 = tl.sum(_tmp6, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp10 = tmp8 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = 4096.0
tmp13 = tmp6 / tmp12
tmp14 = 1e-05
tmp15 = tmp13 + tmp14
tmp16 = libdevice.rsqrt(tmp15)
tmp17 = tmp11 * tmp16
tmp18 = tmp17.to(tl.float32)
tmp20 = tmp18 * tmp19
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/b6/cb67w5ewxshtrgdm622b4yogv47npqcmskcea3x3froy22yqknx4.py
# Source Nodes: [add_19, h_3, mean_5, mul_36, mul_37, out_1, pow_6, rsqrt_5, x_fp32_5, x_normed_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_19 => add_16
# h_3 => add_15
# mean_5 => mean_5
# mul_36 => mul_38
# mul_37 => mul_39
# out_1 => add_11
# pow_6 => pow_6
# rsqrt_5 => rsqrt_5
# x_fp32_5 => convert_element_type_62
# x_normed_5 => convert_element_type_63
#
# Fused three-way add + RMS normalization over one 4096-element row. The
# kernel text hard-codes xnumel = 1 and rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_ptr0, in_ptr1, in_ptr2 (bf16): length-4096 vectors that are summed.
#   in_ptr3 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the result.
#
# Computation (loads are converted to fp32 first):
#   x   = in_ptr0 + (in_ptr1 + in_ptr2)
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr3
#
# Pass 1 accumulates sum(x * x). Pass 2 recomputes x (it is not stored) and
# writes the result.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp1 + tmp2
tmp4 = tmp0 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp5 * tmp5
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp9 = _tmp8 + tmp7
_tmp8 = tl.where(rmask, tmp9, _tmp8)
tmp8 = tl.sum(_tmp8, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp12 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tmp11 + tmp12
tmp14 = tmp10 + tmp13
tmp15 = tmp14.to(tl.float32)
tmp16 = 4096.0
tmp17 = tmp8 / tmp16
tmp18 = 1e-05
tmp19 = tmp17 + tmp18
tmp20 = libdevice.rsqrt(tmp19)
tmp21 = tmp15 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp24 = tmp22 * tmp23
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/rb/crbbcpdu63ofr6njzrpxbpqsry65k5jtwokgjwbmsrhzlwsiwy53.py
# Source Nodes: [add_21, h_3, mean_6, mul_39, out_1, out_2, pow_7, rsqrt_6, x_fp32_6, x_normed_6, y_3], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_21 => add_18
# h_3 => add_15
# mean_6 => mean_6
# mul_39 => mul_42
# out_1 => add_11
# out_2 => add_17
# pow_7 => pow_7
# rsqrt_6 => rsqrt_6
# x_fp32_6 => convert_element_type_72
# x_normed_6 => convert_element_type_73
# y_3 => mul_43
#
# Fused four-way add + RMS normalization over one 4096-element row. The
# kernel text hard-codes xnumel = 1 and rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_ptr0, in_ptr1, in_ptr2, in_ptr3 (bf16): length-4096 vectors that are
#       summed.
#   in_ptr4 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the result.
#
# Computation (loads are converted to fp32 first):
#   x   = (in_ptr0 + (in_ptr1 + in_ptr2)) + in_ptr3
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr4
#
# Pass 1 accumulates sum(x * x). Pass 2 recomputes x (it is not stored) and
# writes the result.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp1 + tmp2
tmp4 = tmp0 + tmp3
tmp6 = tmp4 + tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp14 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp13 + tmp14
tmp16 = tmp12 + tmp15
tmp18 = tmp16 + tmp17
tmp19 = tmp18.to(tl.float32)
tmp20 = 4096.0
tmp21 = tmp10 / tmp20
tmp22 = 1e-05
tmp23 = tmp21 + tmp22
tmp24 = libdevice.rsqrt(tmp23)
tmp25 = tmp19 * tmp24
tmp26 = tmp25.to(tl.float32)
tmp28 = tmp26 * tmp27
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/tu/ctukwfouesobny732yngexqv32bgnsgz5l6q24aqfdzbci2ogg5b.py
# Source Nodes: [add_26, h_3, h_4, mean_7, mul_49, mul_50, out_1, out_2, pow_8, rsqrt_7, x_fp32_7, x_normed_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_26 => add_22
# h_3 => add_15
# h_4 => add_21
# mean_7 => mean_7
# mul_49 => mul_52
# mul_50 => mul_53
# out_1 => add_11
# out_2 => add_17
# pow_8 => pow_8
# rsqrt_7 => rsqrt_7
# x_fp32_7 => convert_element_type_86
# x_normed_7 => convert_element_type_87
#
# Fused five-way add (result written back in place) + RMS normalization over
# one 4096-element row. The kernel text hard-codes xnumel = 1 and
# rnumel = 4096.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_out_ptr0 (bf16): one of the summands; overwritten with x. It is listed
#       in mutated_arg_names.
#   in_ptr0, in_ptr1, in_ptr2, in_ptr3 (bf16): the other length-4096 summands.
#   in_ptr4 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the normalized, scaled row.
#
# Computation (loads are converted to fp32 first):
#   x   = in_ptr0 + ((in_ptr1 + (in_out_ptr0 + in_ptr2)) + in_ptr3)
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr4
#
# Pass 1 stores x into in_out_ptr0 while accumulating sum(x * x) from the
# fp32 value. Pass 2 re-reads x from in_out_ptr0 (declared '*bf16') instead of
# recomputing it.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tmp2 + tmp3
tmp5 = tmp1 + tmp4
tmp7 = tmp5 + tmp6
tmp8 = tmp0 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tl.where(rmask, tmp13, _tmp12)
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
tmp12 = tl.sum(_tmp12, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = 4096.0
tmp17 = tmp12 / tmp16
tmp18 = 1e-05
tmp19 = tmp17 + tmp18
tmp20 = libdevice.rsqrt(tmp19)
tmp21 = tmp15 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp24 = tmp22 * tmp23
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/nh/cnhnggduh5o3eeov7ogifj7cb5tdo3j5ndcivpiqhwnkp3u25u5n.py
# Source Nodes: [add_82, h_11, h_12, mean_23, mul_153, mul_154, out_10, out_9, pow_24, rsqrt_23, x_fp32_23, x_normed_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
# add_82 => add_70
# h_11 => add_63
# h_12 => add_69
# mean_23 => mean_23
# mul_153 => mul_164
# mul_154 => mul_165
# out_10 => add_65
# out_9 => add_59
# pow_24 => pow_24
# rsqrt_23 => rsqrt_23
# x_fp32_23 => convert_element_type_278
# x_normed_23 => convert_element_type_279
#
# Fused five-way add (result written back in place) + RMS normalization over
# one 4096-element row. The kernel text hard-codes xnumel = 1 and
# rnumel = 4096. It has the same structure as ..._14, but the summands are
# bound to the pointer arguments in a different order.
#
# Arguments (dtypes from triton_meta['signature']):
#   in_out_ptr0 (bf16): one of the summands; overwritten with x. It is listed
#       in mutated_arg_names.
#   in_ptr0, in_ptr1, in_ptr2, in_ptr3 (bf16): the other length-4096 summands.
#   in_ptr4 (bf16): per-element scale, applied in pass 2.
#   out_ptr1 (bf16): receives the normalized, scaled row.
#
# Computation (loads are converted to fp32 first):
#   x   = in_out_ptr0 + ((in_ptr0 + (in_ptr1 + in_ptr2)) + in_ptr3)
#   out = x * rsqrt(sum(x * x) / 4096 + 1e-05) * in_ptr4
#
# Pass 1 stores x into in_out_ptr0 while accumulating sum(x * x) from the
# fp32 value. Pass 2 re-reads x from in_out_ptr0 (declared '*bf16') instead of
# recomputing it.
#
# NOTE: the triple-quoted string is the kernel source passed to
# async_compile.triton. Any edit to it changes the compiled kernel.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tmp2 + tmp3
tmp5 = tmp1 + tmp4
tmp7 = tmp5 + tmp6
tmp8 = tmp0 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tl.where(rmask, tmp13, _tmp12)
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
tmp12 = tl.sum(_tmp12, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = 4096.0
tmp17 = tmp12 / tmp16
tmp18 = 1e-05
tmp19 = tmp17 + tmp18
tmp20 = libdevice.rsqrt(tmp19)
tmp21 = tmp15 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp24 = tmp22 * tmp23
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/dk/cdki6rbltj45344wgybatqntlalewzncir7p36oko7sux4egpdcm.py
# Source Nodes: [c_224, logits_1], Original ATen: [aten.div, aten.mm]
# c_224 => mm_224
# logits_1 => div
# Triton matmul template (`triton_heuristics.template`) fused with a scalar
# epilogue. Operand shapes and dtypes are baked into the kernel source:
#   * A: M=1 x K=4096, bf16, row-major (stride_am=4096, stride_ak=1)
#   * B: K=4096 x N=128256, bf16, read with stride_bk=1 / stride_bn=4096,
#     i.e. an N x K row-major matrix consumed transposed
# The kernel accumulates in fp32 (ACC_TYPE) with TF32 disabled. The epilogue
# multiplies by 1.6666666666666667 (== 1 / 0.6, the `div` in `logits_1`) and
# stores fp32 results at out_ptr1[m * 128256 + n].
# Why M=1 works with BLOCK_M=16: `rm % M` makes all 16 tile rows re-read
# row 0 of A, and the store mask `idx_m < M` discards the 15 padding rows.
# EVEN_K=True is valid since K=4096 is a multiple of BLOCK_K=128, so the
# unmasked `tl.load` calls never step past K.
# The constexpr block sizes below are mirrored by the `meta2` dict.
# NOTE(review): 0.6 is presumably a sampling temperature fixed at trace time —
# confirm against the model source; changing it requires recompilation.
triton_tem_fused_div_mm_16 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.template(
num_stages=4,
num_warps=4,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'kernel_name': 'triton_tem_fused_div_mm_16', 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(arg_A, arg_B, out_ptr1):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = None
BLOCK_M : tl.constexpr = 16
BLOCK_N : tl.constexpr = 128
BLOCK_K : tl.constexpr = 128
A = arg_A
B = arg_B
M = 1
N = 128256
K = 4096
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = 4096
stride_ak = 1
stride_bk = 1
stride_bn = 4096
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
else:
ram = rm % M
if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
else:
rbn = rn % N
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (128256*idx_m)
tmp0 = acc.to(tl.float32)
tmp1 = 1.6666666666666667
tmp2 = tmp0 * tmp1
tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp2, mask)
''', device_str='cuda')
# Meta-parameters for triton_tem_fused_div_mm_16. Each entry mirrors one of the
# tl.constexpr values hard-coded at the top of that kernel's source (GROUP_M,
# EVEN_K, ALLOW_TF32, ACC_TYPE, B_PROLOGUE_CAST_TYPE, BLOCK_M/N/K), so the two
# must be changed together.
meta2 = dict(
    GROUP_M=8,
    EVEN_K=True,
    ALLOW_TF32=False,
    ACC_TYPE='tl.float32',
    B_PROLOGUE_CAST_TYPE=None,
    BLOCK_M=16,
    BLOCK_N=128,
    BLOCK_K=128,
)
# kernel path: /tmp/torchinductor_mreso/lw/clwphkgp36abizew5ghfdhpfw7vyfjuel6fa5dordkziwvs7k47u.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax
# Stage 1 of a two-stage (split) softmax max over the 128256 fp32 logits in
# in_ptr0. The row is viewed as xnumel=16 chunks of rnumel=8016 elements
# (16 * 8016 == 128256), addressed as in_ptr0[x0 * 8016 + r1].
# Each element smaller than the scalar threshold in_ptr1[299] is replaced by
# -inf (the `lt` + `where` nodes). The per-chunk maximum of the result is then
# written to out_ptr0[x0], giving 16 partial maxima.
# NOTE(review): in_ptr1 is presumably the descending top-k values with k=300,
# so element 299 is the k-th largest logit (top-k filtering) — confirm in call().
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[16, 8192],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16
rnumel = 8016
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp1 = tl.load(in_ptr1 + (299))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
_tmp7 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (8016*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp0 < tmp2
tmp4 = float("-inf")
tmp5 = tl.where(tmp3, tmp4, tmp0)
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
tmp8 = triton_helpers.maximum(_tmp7, tmp6)
_tmp7 = tl.where(rmask & xmask, tmp8, _tmp7)
tmp7 = triton_helpers.max2(_tmp7, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp7, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/73/c73m5zklquxdkmoswttbn6lvjao4jtaanllldybhblahbrxtvlnu.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax
# Stage 2 of the split softmax max: loads 16 fp32 values from in_ptr0, takes
# their maximum, and stores the single scalar at out_ptr0[0].
# This is a persistent reduction: RBLOCK is fixed at 16 == rnumel, so the
# whole row is reduced in one pass with no loop and no masking.
# NOTE(review): in_ptr0 is presumably the 16 partial maxima written by
# triton_red_fused__softmax_lt_scalar_tensor_where_17 — confirm in call().
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[1, 16],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 1
rnumel = 16
RBLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), None)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = triton_helpers.max2(tmp1, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/p7/cp7u644vw4j3c6njbq2l3pgixqa34e47y4v7jqlflpndagvldspp.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => exp, sub_64, sum_1
# Stage 1 of the split softmax denominator. It uses the same 16 x 8016
# chunking of the 128256 logits in in_ptr0.
# For each element it re-applies the threshold mask (values < in_ptr1[299]
# become -inf), subtracts the scalar in_ptr2[0], and exponentiates. The
# per-chunk sum of the results is written to out_ptr0[x0].
# Masked elements become exp(-inf) == 0 and so add nothing to the sums.
# The mask is recomputed here rather than materialized between kernels.
# NOTE(review): in_ptr2[0] is presumably the global max stored by
# triton_per_fused__softmax_lt_scalar_tensor_where_18 — confirm in call().
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[16, 8192],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16
rnumel = 8016
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp1 = tl.load(in_ptr1 + (299))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp6 = tl.load(in_ptr2 + (0))
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (8016*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp0 < tmp2
tmp4 = float("-inf")
tmp5 = tl.where(tmp3, tmp4, tmp0)
tmp8 = tmp5 - tmp7
tmp9 = tl_math.exp(tmp8)
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp12 = _tmp11 + tmp10
_tmp11 = tl.where(rmask & xmask, tmp12, _tmp11)
tmp11 = tl.sum(_tmp11, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp11, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/df/cdfdwm76l6z5rzxg3kkcpq6h47bpkmm5n4zgve47fbz36i65dwyu.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => exp, sub_64, sum_1
# Stage 2 of the split softmax denominator: loads 16 fp32 values from
# in_ptr0, sums them, and stores the single scalar at out_ptr0[0].
# This is a persistent reduction (RBLOCK == rnumel == 16), so there is no
# reduction loop.
# NOTE(review): in_ptr0 is presumably the 16 partial sums written by
# triton_red_fused__softmax_lt_scalar_tensor_where_19 — confirm in call().
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[1, 16],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 1
rnumel = 16
RBLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), None)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.sum(tmp1, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_mreso/zt/cztz4kuxbvivmp6txwflqjx3qpqrnxard4octamzlea35cgxr3fn.py
# Source Nodes: [argmax, logits_2, lt, probs, q_128, to_450, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
# argmax => argmax
# logits_2 => full_default_64, where_32
# lt => lt
# probs => div_1, exp, sub_64
# q_128 => full_default_65, ge, log, mul_450, where_33
# to_450 => convert_element_type_773
# truediv_1 => div_2
# Final fused step over one row (xnumel=1) of rnumel=128256 fp32 logits in
# in_ptr1. For each index r it computes:
#   u     = tl.rand(seed, r), seed = in_ptr0[load_seed_offset] (int64 seeds)
#   probs = exp(masked_logit - in_ptr3[0]) / in_ptr4[0]
#           where masked_logit is -inf if logit < in_ptr2[299]
#   q     = -log(u); log(u) is replaced by -2**-24 when u >= 1 - 2**-24,
#           so q is never 0
#   value = probs / q
# It then keeps a running (max value, index) pair across the reduction loop
# and stores argmax(value), cast to int32, at out_ptr2[0].
# Since tl.rand is uniform on [0, 1), q is Exp(1)-distributed, and argmax of
# p / q draws one index from the categorical distribution p ("exponential
# race") without building a CDF.
# NOTE(review): in_ptr3[0] / in_ptr4[0] are presumably the max and exp-sum
# produced by the _18 and _20 kernels above, and this presumably implements a
# single torch.multinomial-style token sample — confirm in call().
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 131072],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '025D55A81D4601B19131CB9FF543A0FE014EB1904D412643095F7F0A8A7FF71B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 128256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
tmp4 = tl.load(in_ptr2 + (299))
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
tmp9 = tl.load(in_ptr3 + (0))
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp13 = tl.load(in_ptr4 + (0))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
_tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
_tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
tmp0 = tl.load(in_ptr0 + load_seed_offset)
tmp1 = r0
tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
tmp6 = tmp3 < tmp5
tmp7 = float("-inf")
tmp8 = tl.where(tmp6, tmp7, tmp3)
tmp11 = tmp8 - tmp10
tmp12 = tl_math.exp(tmp11)
tmp15 = tmp12 / tmp14
tmp16 = 0.9999999403953552
tmp17 = tmp2 >= tmp16
tmp18 = tl_math.log(tmp2)
tmp19 = -5.960464477539063e-08
tmp20 = tl.where(tmp17, tmp19, tmp18)
tmp21 = -1.0
tmp22 = tmp20 * tmp21
tmp23 = tmp15 / tmp22
tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK])
_tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index(
_tmp25, _tmp25_index, tmp24, rindex
)
_tmp25 = tl.where(rmask, _tmp25_next, _tmp25)
_tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index)
_, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1)
tmp25 = tmp25_tmp[:, None]
tmp26 = tmp25.to(tl.int32)
tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
''', device_str='cuda')
# Presumably blocks until every kernel submitted through async_compile.triton
# above has finished compiling, resolving the module-level handles in
# globals() to launchable kernels (verify against
# torch._inductor.async_compile.AsyncCompile.wait).
async_compile.wait(globals())
# The compiler object is not needed once compilation is done; remove the
# module-level name so call() cannot reach it.
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, 
arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, 
arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1, arg456_1, arg457_1, arg458_1, arg459_1, arg460_1, arg461_1, arg462_1, arg463_1, arg464_1, arg465_1, arg466_1, arg467_1, arg468_1, arg469_1, arg470_1, arg471_1, arg472_1, arg473_1, arg474_1, arg475_1, arg476_1, arg477_1, arg478_1, arg479_1, arg480_1, arg481_1, arg482_1, arg483_1, arg484_1, arg485_1, arg486_1, arg487_1, arg488_1, arg489_1, arg490_1, arg491_1, arg492_1, arg493_1, arg494_1, arg495_1, arg496_1, arg497_1, arg498_1, arg499_1, arg500_1, arg501_1, arg502_1, arg503_1, arg504_1, arg505_1, arg506_1, arg507_1, arg508_1, arg509_1, arg510_1, arg511_1, arg512_1, arg513_1, arg514_1, arg515_1, arg516_1, arg517_1, arg518_1, arg519_1, arg520_1, arg521_1, arg522_1, arg523_1, arg524_1, arg525_1, arg526_1, arg527_1, arg528_1, arg529_1, arg530_1, arg531_1, arg532_1, arg533_1, arg534_1, arg535_1, arg536_1, arg537_1, arg538_1, arg539_1, arg540_1, arg541_1, arg542_1, arg543_1, arg544_1, arg545_1, arg546_1, arg547_1, arg548_1, arg549_1, arg550_1, arg551_1, arg552_1, arg553_1, arg554_1, arg555_1, arg556_1, arg557_1, arg558_1, arg559_1, arg560_1, arg561_1, arg562_1, arg563_1, arg564_1, arg565_1, arg566_1, arg567_1, arg568_1, arg569_1, arg570_1, arg571_1, arg572_1, arg573_1, arg574_1, arg575_1, arg576_1, arg577_1, arg578_1, arg579_1, arg580_1, arg581_1, arg582_1, arg583_1, arg584_1, arg585_1, arg586_1, arg587_1, arg588_1, arg589_1, arg590_1, arg591_1, arg592_1, arg593_1, arg594_1, arg595_1, arg596_1, arg597_1, arg598_1, arg599_1, arg600_1, arg601_1, arg602_1, arg603_1, arg604_1, arg605_1, arg606_1, arg607_1, arg608_1, arg609_1, arg610_1, 
arg611_1, arg612_1, arg613_1, arg614_1, arg615_1, arg616_1, arg617_1, arg618_1, arg619_1, arg620_1, arg621_1, arg622_1, arg623_1, arg624_1, arg625_1, arg626_1, arg627_1, arg628_1, arg629_1, arg630_1, arg631_1, arg632_1, arg633_1, arg634_1, arg635_1, arg636_1, arg637_1, arg638_1, arg639_1, arg640_1, arg641_1, arg642_1, arg643_1, arg644_1, arg645_1, arg646_1, arg647_1, arg648_1, arg649_1, arg650_1, arg651_1, arg652_1, arg653_1, arg654_1, arg655_1, arg656_1, arg657_1, arg658_1, arg659_1, arg660_1, arg661_1, arg662_1, arg663_1, arg664_1, arg665_1, arg666_1, arg667_1, arg668_1, arg669_1, arg670_1, arg671_1, arg672_1, arg673_1, arg674_1, arg675_1, arg676_1, arg677_1, arg678_1, arg679_1, arg680_1, arg681_1, arg682_1, arg683_1, arg684_1, arg685_1, arg686_1, arg687_1, arg688_1, arg689_1, arg690_1, arg691_1, arg692_1, arg693_1, arg694_1, arg695_1, arg696_1, arg697_1, arg698_1, arg699_1, arg700_1, arg701_1, arg702_1, arg703_1, arg704_1, arg705_1, arg706_1, arg707_1, arg708_1, arg709_1, arg710_1, arg711_1, arg712_1, arg713_1, arg714_1, arg715_1, arg716_1, arg717_1, arg718_1, arg719_1, arg720_1, arg721_1, arg722_1, arg723_1, arg724_1, arg725_1, arg726_1, arg727_1, arg728_1, arg729_1, arg730_1, arg731_1, arg732_1, arg733_1, arg734_1, arg735_1, arg736_1, arg737_1, arg738_1, arg739_1, arg740_1, arg741_1, arg742_1, arg743_1, arg744_1, arg745_1, arg746_1, arg747_1, arg748_1, arg749_1, arg750_1, arg751_1, arg752_1, arg753_1, arg754_1, arg755_1, arg756_1, arg757_1, arg758_1, arg759_1, arg760_1, arg761_1, arg762_1, arg763_1, arg764_1, arg765_1, arg766_1, arg767_1, arg768_1, arg769_1, arg770_1, arg771_1, arg772_1, arg773_1, arg774_1, arg775_1, arg776_1, arg777_1, arg778_1, arg779_1, arg780_1, arg781_1, arg782_1, arg783_1, arg784_1, arg785_1, arg786_1, arg787_1, arg788_1, arg789_1, arg790_1, arg791_1, arg792_1, arg793_1, arg794_1, arg795_1, arg796_1, arg797_1, arg798_1, arg799_1, arg800_1, arg801_1, arg802_1, arg803_1, arg804_1, arg805_1, arg806_1, arg807_1, arg808_1, arg809_1, arg810_1, 
arg811_1, arg812_1, arg813_1, arg814_1, arg815_1, arg816_1, arg817_1, arg818_1, arg819_1, arg820_1, arg821_1, arg822_1, arg823_1, arg824_1, arg825_1, arg826_1, arg827_1, arg828_1, arg829_1, arg830_1, arg831_1, arg832_1, arg833_1, arg834_1, arg835_1, arg836_1, arg837_1, arg838_1, arg839_1 = args
args.clear()
assert_size_stride(arg0_1, (1, 1), (1, 1))
assert_size_stride(arg1_1, (128256, 4096), (4096, 1))
assert_size_stride(arg2_1, (2048, 2048), (2048, 1))
assert_size_stride(arg3_1, (1, ), (1, ))
assert_size_stride(arg4_1, (4096, ), (1, ))
assert_size_stride(arg5_1, (4096, 4096), (4096, 1))
assert_size_stride(arg6_1, (4096, 16), (16, 1))
assert_size_stride(arg7_1, (4096, 16), (16, 1))
assert_size_stride(arg8_1, (1024, 4096), (4096, 1))
assert_size_stride(arg9_1, (1024, 16), (16, 1))
assert_size_stride(arg10_1, (1024, 16), (16, 1))
assert_size_stride(arg11_1, (1024, 4096), (4096, 1))
assert_size_stride(arg12_1, (1024, 16), (16, 1))
assert_size_stride(arg13_1, (1024, 16), (16, 1))
assert_size_stride(arg14_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg15_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg16_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg17_1, (4096, 4096), (4096, 1))
assert_size_stride(arg18_1, (4096, 16), (16, 1))
assert_size_stride(arg19_1, (4096, 16), (16, 1))
assert_size_stride(arg20_1, (4096, ), (1, ))
assert_size_stride(arg21_1, (14336, 4096), (4096, 1))
assert_size_stride(arg22_1, (14336, 16), (16, 1))
assert_size_stride(arg23_1, (14336, 16), (16, 1))
assert_size_stride(arg24_1, (14336, 4096), (4096, 1))
assert_size_stride(arg25_1, (14336, 16), (16, 1))
assert_size_stride(arg26_1, (14336, 16), (16, 1))
assert_size_stride(arg27_1, (4096, 14336), (14336, 1))
assert_size_stride(arg28_1, (4096, 56), (56, 1))
assert_size_stride(arg29_1, (4096, 56), (56, 1))
assert_size_stride(arg30_1, (4096, ), (1, ))
assert_size_stride(arg31_1, (4096, 4096), (4096, 1))
assert_size_stride(arg32_1, (4096, 16), (16, 1))
assert_size_stride(arg33_1, (4096, 16), (16, 1))
assert_size_stride(arg34_1, (1024, 4096), (4096, 1))
assert_size_stride(arg35_1, (1024, 16), (16, 1))
assert_size_stride(arg36_1, (1024, 16), (16, 1))
assert_size_stride(arg37_1, (1024, 4096), (4096, 1))
assert_size_stride(arg38_1, (1024, 16), (16, 1))
assert_size_stride(arg39_1, (1024, 16), (16, 1))
assert_size_stride(arg40_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg41_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg42_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg43_1, (4096, 4096), (4096, 1))
assert_size_stride(arg44_1, (4096, 16), (16, 1))
assert_size_stride(arg45_1, (4096, 16), (16, 1))
assert_size_stride(arg46_1, (4096, ), (1, ))
assert_size_stride(arg47_1, (14336, 4096), (4096, 1))
assert_size_stride(arg48_1, (14336, 16), (16, 1))
assert_size_stride(arg49_1, (14336, 16), (16, 1))
assert_size_stride(arg50_1, (14336, 4096), (4096, 1))
assert_size_stride(arg51_1, (14336, 16), (16, 1))
assert_size_stride(arg52_1, (14336, 16), (16, 1))
assert_size_stride(arg53_1, (4096, 14336), (14336, 1))
assert_size_stride(arg54_1, (4096, 56), (56, 1))
assert_size_stride(arg55_1, (4096, 56), (56, 1))
assert_size_stride(arg56_1, (4096, ), (1, ))
assert_size_stride(arg57_1, (4096, 4096), (4096, 1))
assert_size_stride(arg58_1, (4096, 16), (16, 1))
assert_size_stride(arg59_1, (4096, 16), (16, 1))
assert_size_stride(arg60_1, (1024, 4096), (4096, 1))
assert_size_stride(arg61_1, (1024, 16), (16, 1))
assert_size_stride(arg62_1, (1024, 16), (16, 1))
assert_size_stride(arg63_1, (1024, 4096), (4096, 1))
assert_size_stride(arg64_1, (1024, 16), (16, 1))
assert_size_stride(arg65_1, (1024, 16), (16, 1))
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg67_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg68_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg69_1, (4096, 4096), (4096, 1))
assert_size_stride(arg70_1, (4096, 16), (16, 1))
assert_size_stride(arg71_1, (4096, 16), (16, 1))
assert_size_stride(arg72_1, (4096, ), (1, ))
assert_size_stride(arg73_1, (14336, 4096), (4096, 1))
assert_size_stride(arg74_1, (14336, 16), (16, 1))
assert_size_stride(arg75_1, (14336, 16), (16, 1))
assert_size_stride(arg76_1, (14336, 4096), (4096, 1))
assert_size_stride(arg77_1, (14336, 16), (16, 1))
assert_size_stride(arg78_1, (14336, 16), (16, 1))
assert_size_stride(arg79_1, (4096, 14336), (14336, 1))
assert_size_stride(arg80_1, (4096, 56), (56, 1))
assert_size_stride(arg81_1, (4096, 56), (56, 1))
assert_size_stride(arg82_1, (4096, ), (1, ))
assert_size_stride(arg83_1, (4096, 4096), (4096, 1))
assert_size_stride(arg84_1, (4096, 16), (16, 1))
assert_size_stride(arg85_1, (4096, 16), (16, 1))
assert_size_stride(arg86_1, (1024, 4096), (4096, 1))
assert_size_stride(arg87_1, (1024, 16), (16, 1))
assert_size_stride(arg88_1, (1024, 16), (16, 1))
assert_size_stride(arg89_1, (1024, 4096), (4096, 1))
assert_size_stride(arg90_1, (1024, 16), (16, 1))
assert_size_stride(arg91_1, (1024, 16), (16, 1))
assert_size_stride(arg92_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg93_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg94_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg95_1, (4096, 4096), (4096, 1))
assert_size_stride(arg96_1, (4096, 16), (16, 1))
assert_size_stride(arg97_1, (4096, 16), (16, 1))
assert_size_stride(arg98_1, (4096, ), (1, ))
assert_size_stride(arg99_1, (14336, 4096), (4096, 1))
assert_size_stride(arg100_1, (14336, 16), (16, 1))
assert_size_stride(arg101_1, (14336, 16), (16, 1))
assert_size_stride(arg102_1, (14336, 4096), (4096, 1))
assert_size_stride(arg103_1, (14336, 16), (16, 1))
assert_size_stride(arg104_1, (14336, 16), (16, 1))
assert_size_stride(arg105_1, (4096, 14336), (14336, 1))
assert_size_stride(arg106_1, (4096, 56), (56, 1))
assert_size_stride(arg107_1, (4096, 56), (56, 1))
assert_size_stride(arg108_1, (4096, ), (1, ))
assert_size_stride(arg109_1, (4096, 4096), (4096, 1))
assert_size_stride(arg110_1, (4096, 16), (16, 1))
assert_size_stride(arg111_1, (4096, 16), (16, 1))
assert_size_stride(arg112_1, (1024, 4096), (4096, 1))
assert_size_stride(arg113_1, (1024, 16), (16, 1))
assert_size_stride(arg114_1, (1024, 16), (16, 1))
assert_size_stride(arg115_1, (1024, 4096), (4096, 1))
assert_size_stride(arg116_1, (1024, 16), (16, 1))
assert_size_stride(arg117_1, (1024, 16), (16, 1))
assert_size_stride(arg118_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg119_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg120_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg121_1, (4096, 4096), (4096, 1))
assert_size_stride(arg122_1, (4096, 16), (16, 1))
assert_size_stride(arg123_1, (4096, 16), (16, 1))
assert_size_stride(arg124_1, (4096, ), (1, ))
assert_size_stride(arg125_1, (14336, 4096), (4096, 1))
assert_size_stride(arg126_1, (14336, 16), (16, 1))
assert_size_stride(arg127_1, (14336, 16), (16, 1))
assert_size_stride(arg128_1, (14336, 4096), (4096, 1))
assert_size_stride(arg129_1, (14336, 16), (16, 1))
assert_size_stride(arg130_1, (14336, 16), (16, 1))
assert_size_stride(arg131_1, (4096, 14336), (14336, 1))
assert_size_stride(arg132_1, (4096, 56), (56, 1))
assert_size_stride(arg133_1, (4096, 56), (56, 1))
assert_size_stride(arg134_1, (4096, ), (1, ))
assert_size_stride(arg135_1, (4096, 4096), (4096, 1))
assert_size_stride(arg136_1, (4096, 16), (16, 1))
assert_size_stride(arg137_1, (4096, 16), (16, 1))
assert_size_stride(arg138_1, (1024, 4096), (4096, 1))
assert_size_stride(arg139_1, (1024, 16), (16, 1))
assert_size_stride(arg140_1, (1024, 16), (16, 1))
assert_size_stride(arg141_1, (1024, 4096), (4096, 1))
assert_size_stride(arg142_1, (1024, 16), (16, 1))
assert_size_stride(arg143_1, (1024, 16), (16, 1))
assert_size_stride(arg144_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg145_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg146_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg147_1, (4096, 4096), (4096, 1))
assert_size_stride(arg148_1, (4096, 16), (16, 1))
assert_size_stride(arg149_1, (4096, 16), (16, 1))
assert_size_stride(arg150_1, (4096, ), (1, ))
assert_size_stride(arg151_1, (14336, 4096), (4096, 1))
assert_size_stride(arg152_1, (14336, 16), (16, 1))
assert_size_stride(arg153_1, (14336, 16), (16, 1))
assert_size_stride(arg154_1, (14336, 4096), (4096, 1))
assert_size_stride(arg155_1, (14336, 16), (16, 1))
assert_size_stride(arg156_1, (14336, 16), (16, 1))
assert_size_stride(arg157_1, (4096, 14336), (14336, 1))
assert_size_stride(arg158_1, (4096, 56), (56, 1))
assert_size_stride(arg159_1, (4096, 56), (56, 1))
assert_size_stride(arg160_1, (4096, ), (1, ))
assert_size_stride(arg161_1, (4096, 4096), (4096, 1))
assert_size_stride(arg162_1, (4096, 16), (16, 1))
assert_size_stride(arg163_1, (4096, 16), (16, 1))
assert_size_stride(arg164_1, (1024, 4096), (4096, 1))
assert_size_stride(arg165_1, (1024, 16), (16, 1))
assert_size_stride(arg166_1, (1024, 16), (16, 1))
assert_size_stride(arg167_1, (1024, 4096), (4096, 1))
assert_size_stride(arg168_1, (1024, 16), (16, 1))
assert_size_stride(arg169_1, (1024, 16), (16, 1))
assert_size_stride(arg170_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg171_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg172_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg173_1, (4096, 4096), (4096, 1))
assert_size_stride(arg174_1, (4096, 16), (16, 1))
assert_size_stride(arg175_1, (4096, 16), (16, 1))
assert_size_stride(arg176_1, (4096, ), (1, ))
assert_size_stride(arg177_1, (14336, 4096), (4096, 1))
assert_size_stride(arg178_1, (14336, 16), (16, 1))
assert_size_stride(arg179_1, (14336, 16), (16, 1))
assert_size_stride(arg180_1, (14336, 4096), (4096, 1))
assert_size_stride(arg181_1, (14336, 16), (16, 1))
assert_size_stride(arg182_1, (14336, 16), (16, 1))
assert_size_stride(arg183_1, (4096, 14336), (14336, 1))
assert_size_stride(arg184_1, (4096, 56), (56, 1))
assert_size_stride(arg185_1, (4096, 56), (56, 1))
assert_size_stride(arg186_1, (4096, ), (1, ))
assert_size_stride(arg187_1, (4096, 4096), (4096, 1))
assert_size_stride(arg188_1, (4096, 16), (16, 1))
assert_size_stride(arg189_1, (4096, 16), (16, 1))
assert_size_stride(arg190_1, (1024, 4096), (4096, 1))
assert_size_stride(arg191_1, (1024, 16), (16, 1))
assert_size_stride(arg192_1, (1024, 16), (16, 1))
assert_size_stride(arg193_1, (1024, 4096), (4096, 1))
assert_size_stride(arg194_1, (1024, 16), (16, 1))
assert_size_stride(arg195_1, (1024, 16), (16, 1))
assert_size_stride(arg196_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg197_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg198_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg199_1, (4096, 4096), (4096, 1))
assert_size_stride(arg200_1, (4096, 16), (16, 1))
assert_size_stride(arg201_1, (4096, 16), (16, 1))
assert_size_stride(arg202_1, (4096, ), (1, ))
assert_size_stride(arg203_1, (14336, 4096), (4096, 1))
assert_size_stride(arg204_1, (14336, 16), (16, 1))
assert_size_stride(arg205_1, (14336, 16), (16, 1))
assert_size_stride(arg206_1, (14336, 4096), (4096, 1))
assert_size_stride(arg207_1, (14336, 16), (16, 1))
assert_size_stride(arg208_1, (14336, 16), (16, 1))
assert_size_stride(arg209_1, (4096, 14336), (14336, 1))
assert_size_stride(arg210_1, (4096, 56), (56, 1))
assert_size_stride(arg211_1, (4096, 56), (56, 1))
assert_size_stride(arg212_1, (4096, ), (1, ))
assert_size_stride(arg213_1, (4096, 4096), (4096, 1))
assert_size_stride(arg214_1, (4096, 16), (16, 1))
assert_size_stride(arg215_1, (4096, 16), (16, 1))
assert_size_stride(arg216_1, (1024, 4096), (4096, 1))
assert_size_stride(arg217_1, (1024, 16), (16, 1))
assert_size_stride(arg218_1, (1024, 16), (16, 1))
assert_size_stride(arg219_1, (1024, 4096), (4096, 1))
assert_size_stride(arg220_1, (1024, 16), (16, 1))
assert_size_stride(arg221_1, (1024, 16), (16, 1))
assert_size_stride(arg222_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg223_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg224_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg225_1, (4096, 4096), (4096, 1))
assert_size_stride(arg226_1, (4096, 16), (16, 1))
assert_size_stride(arg227_1, (4096, 16), (16, 1))
assert_size_stride(arg228_1, (4096, ), (1, ))
assert_size_stride(arg229_1, (14336, 4096), (4096, 1))
assert_size_stride(arg230_1, (14336, 16), (16, 1))
assert_size_stride(arg231_1, (14336, 16), (16, 1))
assert_size_stride(arg232_1, (14336, 4096), (4096, 1))
assert_size_stride(arg233_1, (14336, 16), (16, 1))
assert_size_stride(arg234_1, (14336, 16), (16, 1))
assert_size_stride(arg235_1, (4096, 14336), (14336, 1))
assert_size_stride(arg236_1, (4096, 56), (56, 1))
assert_size_stride(arg237_1, (4096, 56), (56, 1))
assert_size_stride(arg238_1, (4096, ), (1, ))
assert_size_stride(arg239_1, (4096, 4096), (4096, 1))
assert_size_stride(arg240_1, (4096, 16), (16, 1))
assert_size_stride(arg241_1, (4096, 16), (16, 1))
assert_size_stride(arg242_1, (1024, 4096), (4096, 1))
assert_size_stride(arg243_1, (1024, 16), (16, 1))
assert_size_stride(arg244_1, (1024, 16), (16, 1))
assert_size_stride(arg245_1, (1024, 4096), (4096, 1))
assert_size_stride(arg246_1, (1024, 16), (16, 1))
assert_size_stride(arg247_1, (1024, 16), (16, 1))
assert_size_stride(arg248_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg249_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg250_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg251_1, (4096, 4096), (4096, 1))
assert_size_stride(arg252_1, (4096, 16), (16, 1))
assert_size_stride(arg253_1, (4096, 16), (16, 1))
assert_size_stride(arg254_1, (4096, ), (1, ))
assert_size_stride(arg255_1, (14336, 4096), (4096, 1))
assert_size_stride(arg256_1, (14336, 16), (16, 1))
assert_size_stride(arg257_1, (14336, 16), (16, 1))
assert_size_stride(arg258_1, (14336, 4096), (4096, 1))
assert_size_stride(arg259_1, (14336, 16), (16, 1))
assert_size_stride(arg260_1, (14336, 16), (16, 1))
assert_size_stride(arg261_1, (4096, 14336), (14336, 1))
assert_size_stride(arg262_1, (4096, 56), (56, 1))
assert_size_stride(arg263_1, (4096, 56), (56, 1))
assert_size_stride(arg264_1, (4096, ), (1, ))
assert_size_stride(arg265_1, (4096, 4096), (4096, 1))
assert_size_stride(arg266_1, (4096, 16), (16, 1))
assert_size_stride(arg267_1, (4096, 16), (16, 1))
assert_size_stride(arg268_1, (1024, 4096), (4096, 1))
assert_size_stride(arg269_1, (1024, 16), (16, 1))
assert_size_stride(arg270_1, (1024, 16), (16, 1))
assert_size_stride(arg271_1, (1024, 4096), (4096, 1))
assert_size_stride(arg272_1, (1024, 16), (16, 1))
assert_size_stride(arg273_1, (1024, 16), (16, 1))
assert_size_stride(arg274_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg275_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg276_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg277_1, (4096, 4096), (4096, 1))
assert_size_stride(arg278_1, (4096, 16), (16, 1))
assert_size_stride(arg279_1, (4096, 16), (16, 1))
assert_size_stride(arg280_1, (4096, ), (1, ))
assert_size_stride(arg281_1, (14336, 4096), (4096, 1))
assert_size_stride(arg282_1, (14336, 16), (16, 1))
assert_size_stride(arg283_1, (14336, 16), (16, 1))
assert_size_stride(arg284_1, (14336, 4096), (4096, 1))
assert_size_stride(arg285_1, (14336, 16), (16, 1))
assert_size_stride(arg286_1, (14336, 16), (16, 1))
assert_size_stride(arg287_1, (4096, 14336), (14336, 1))
assert_size_stride(arg288_1, (4096, 56), (56, 1))
assert_size_stride(arg289_1, (4096, 56), (56, 1))
assert_size_stride(arg290_1, (4096, ), (1, ))
assert_size_stride(arg291_1, (4096, 4096), (4096, 1))
assert_size_stride(arg292_1, (4096, 16), (16, 1))
assert_size_stride(arg293_1, (4096, 16), (16, 1))
assert_size_stride(arg294_1, (1024, 4096), (4096, 1))
assert_size_stride(arg295_1, (1024, 16), (16, 1))
assert_size_stride(arg296_1, (1024, 16), (16, 1))
assert_size_stride(arg297_1, (1024, 4096), (4096, 1))
assert_size_stride(arg298_1, (1024, 16), (16, 1))
assert_size_stride(arg299_1, (1024, 16), (16, 1))
assert_size_stride(arg300_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg301_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg302_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg303_1, (4096, 4096), (4096, 1))
assert_size_stride(arg304_1, (4096, 16), (16, 1))
assert_size_stride(arg305_1, (4096, 16), (16, 1))
assert_size_stride(arg306_1, (4096, ), (1, ))
assert_size_stride(arg307_1, (14336, 4096), (4096, 1))
assert_size_stride(arg308_1, (14336, 16), (16, 1))
assert_size_stride(arg309_1, (14336, 16), (16, 1))
assert_size_stride(arg310_1, (14336, 4096), (4096, 1))
assert_size_stride(arg311_1, (14336, 16), (16, 1))
assert_size_stride(arg312_1, (14336, 16), (16, 1))
assert_size_stride(arg313_1, (4096, 14336), (14336, 1))
assert_size_stride(arg314_1, (4096, 56), (56, 1))
assert_size_stride(arg315_1, (4096, 56), (56, 1))
assert_size_stride(arg316_1, (4096, ), (1, ))
assert_size_stride(arg317_1, (4096, 4096), (4096, 1))
assert_size_stride(arg318_1, (4096, 16), (16, 1))
assert_size_stride(arg319_1, (4096, 16), (16, 1))
assert_size_stride(arg320_1, (1024, 4096), (4096, 1))
assert_size_stride(arg321_1, (1024, 16), (16, 1))
assert_size_stride(arg322_1, (1024, 16), (16, 1))
assert_size_stride(arg323_1, (1024, 4096), (4096, 1))
assert_size_stride(arg324_1, (1024, 16), (16, 1))
assert_size_stride(arg325_1, (1024, 16), (16, 1))
assert_size_stride(arg326_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg327_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg328_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg329_1, (4096, 4096), (4096, 1))
assert_size_stride(arg330_1, (4096, 16), (16, 1))
assert_size_stride(arg331_1, (4096, 16), (16, 1))
assert_size_stride(arg332_1, (4096, ), (1, ))
assert_size_stride(arg333_1, (14336, 4096), (4096, 1))
assert_size_stride(arg334_1, (14336, 16), (16, 1))
assert_size_stride(arg335_1, (14336, 16), (16, 1))
assert_size_stride(arg336_1, (14336, 4096), (4096, 1))
assert_size_stride(arg337_1, (14336, 16), (16, 1))
assert_size_stride(arg338_1, (14336, 16), (16, 1))
assert_size_stride(arg339_1, (4096, 14336), (14336, 1))
assert_size_stride(arg340_1, (4096, 56), (56, 1))
assert_size_stride(arg341_1, (4096, 56), (56, 1))
assert_size_stride(arg342_1, (4096, ), (1, ))
assert_size_stride(arg343_1, (4096, 4096), (4096, 1))
assert_size_stride(arg344_1, (4096, 16), (16, 1))
assert_size_stride(arg345_1, (4096, 16), (16, 1))
assert_size_stride(arg346_1, (1024, 4096), (4096, 1))
assert_size_stride(arg347_1, (1024, 16), (16, 1))
assert_size_stride(arg348_1, (1024, 16), (16, 1))
assert_size_stride(arg349_1, (1024, 4096), (4096, 1))
assert_size_stride(arg350_1, (1024, 16), (16, 1))
assert_size_stride(arg351_1, (1024, 16), (16, 1))
assert_size_stride(arg352_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg353_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg354_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg355_1, (4096, 4096), (4096, 1))
assert_size_stride(arg356_1, (4096, 16), (16, 1))
assert_size_stride(arg357_1, (4096, 16), (16, 1))
assert_size_stride(arg358_1, (4096, ), (1, ))
assert_size_stride(arg359_1, (14336, 4096), (4096, 1))
assert_size_stride(arg360_1, (14336, 16), (16, 1))
assert_size_stride(arg361_1, (14336, 16), (16, 1))
assert_size_stride(arg362_1, (14336, 4096), (4096, 1))
assert_size_stride(arg363_1, (14336, 16), (16, 1))
assert_size_stride(arg364_1, (14336, 16), (16, 1))
assert_size_stride(arg365_1, (4096, 14336), (14336, 1))
assert_size_stride(arg366_1, (4096, 56), (56, 1))
assert_size_stride(arg367_1, (4096, 56), (56, 1))
assert_size_stride(arg368_1, (4096, ), (1, ))
assert_size_stride(arg369_1, (4096, 4096), (4096, 1))
assert_size_stride(arg370_1, (4096, 16), (16, 1))
assert_size_stride(arg371_1, (4096, 16), (16, 1))
assert_size_stride(arg372_1, (1024, 4096), (4096, 1))
assert_size_stride(arg373_1, (1024, 16), (16, 1))
assert_size_stride(arg374_1, (1024, 16), (16, 1))
assert_size_stride(arg375_1, (1024, 4096), (4096, 1))
assert_size_stride(arg376_1, (1024, 16), (16, 1))
assert_size_stride(arg377_1, (1024, 16), (16, 1))
assert_size_stride(arg378_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg379_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg380_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg381_1, (4096, 4096), (4096, 1))
assert_size_stride(arg382_1, (4096, 16), (16, 1))
assert_size_stride(arg383_1, (4096, 16), (16, 1))
assert_size_stride(arg384_1, (4096, ), (1, ))
assert_size_stride(arg385_1, (14336, 4096), (4096, 1))
assert_size_stride(arg386_1, (14336, 16), (16, 1))
assert_size_stride(arg387_1, (14336, 16), (16, 1))
assert_size_stride(arg388_1, (14336, 4096), (4096, 1))
assert_size_stride(arg389_1, (14336, 16), (16, 1))
assert_size_stride(arg390_1, (14336, 16), (16, 1))
assert_size_stride(arg391_1, (4096, 14336), (14336, 1))
assert_size_stride(arg392_1, (4096, 56), (56, 1))
assert_size_stride(arg393_1, (4096, 56), (56, 1))
assert_size_stride(arg394_1, (4096, ), (1, ))
assert_size_stride(arg395_1, (4096, 4096), (4096, 1))
assert_size_stride(arg396_1, (4096, 16), (16, 1))
assert_size_stride(arg397_1, (4096, 16), (16, 1))
assert_size_stride(arg398_1, (1024, 4096), (4096, 1))
assert_size_stride(arg399_1, (1024, 16), (16, 1))
assert_size_stride(arg400_1, (1024, 16), (16, 1))
assert_size_stride(arg401_1, (1024, 4096), (4096, 1))
assert_size_stride(arg402_1, (1024, 16), (16, 1))
assert_size_stride(arg403_1, (1024, 16), (16, 1))
assert_size_stride(arg404_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg405_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg406_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg407_1, (4096, 4096), (4096, 1))
assert_size_stride(arg408_1, (4096, 16), (16, 1))
assert_size_stride(arg409_1, (4096, 16), (16, 1))
assert_size_stride(arg410_1, (4096, ), (1, ))
assert_size_stride(arg411_1, (14336, 4096), (4096, 1))
assert_size_stride(arg412_1, (14336, 16), (16, 1))
assert_size_stride(arg413_1, (14336, 16), (16, 1))
assert_size_stride(arg414_1, (14336, 4096), (4096, 1))
assert_size_stride(arg415_1, (14336, 16), (16, 1))
assert_size_stride(arg416_1, (14336, 16), (16, 1))
assert_size_stride(arg417_1, (4096, 14336), (14336, 1))
assert_size_stride(arg418_1, (4096, 56), (56, 1))
assert_size_stride(arg419_1, (4096, 56), (56, 1))
assert_size_stride(arg420_1, (4096, ), (1, ))
assert_size_stride(arg421_1, (4096, 4096), (4096, 1))
assert_size_stride(arg422_1, (4096, 16), (16, 1))
assert_size_stride(arg423_1, (4096, 16), (16, 1))
assert_size_stride(arg424_1, (1024, 4096), (4096, 1))
assert_size_stride(arg425_1, (1024, 16), (16, 1))
assert_size_stride(arg426_1, (1024, 16), (16, 1))
assert_size_stride(arg427_1, (1024, 4096), (4096, 1))
assert_size_stride(arg428_1, (1024, 16), (16, 1))
assert_size_stride(arg429_1, (1024, 16), (16, 1))
assert_size_stride(arg430_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg431_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg432_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg433_1, (4096, 4096), (4096, 1))
assert_size_stride(arg434_1, (4096, 16), (16, 1))
assert_size_stride(arg435_1, (4096, 16), (16, 1))
assert_size_stride(arg436_1, (4096, ), (1, ))
assert_size_stride(arg437_1, (14336, 4096), (4096, 1))
assert_size_stride(arg438_1, (14336, 16), (16, 1))
assert_size_stride(arg439_1, (14336, 16), (16, 1))
assert_size_stride(arg440_1, (14336, 4096), (4096, 1))
assert_size_stride(arg441_1, (14336, 16), (16, 1))
assert_size_stride(arg442_1, (14336, 16), (16, 1))
assert_size_stride(arg443_1, (4096, 14336), (14336, 1))
assert_size_stride(arg444_1, (4096, 56), (56, 1))
assert_size_stride(arg445_1, (4096, 56), (56, 1))
assert_size_stride(arg446_1, (4096, ), (1, ))
assert_size_stride(arg447_1, (4096, 4096), (4096, 1))
assert_size_stride(arg448_1, (4096, 16), (16, 1))
assert_size_stride(arg449_1, (4096, 16), (16, 1))
assert_size_stride(arg450_1, (1024, 4096), (4096, 1))
assert_size_stride(arg451_1, (1024, 16), (16, 1))
assert_size_stride(arg452_1, (1024, 16), (16, 1))
assert_size_stride(arg453_1, (1024, 4096), (4096, 1))
assert_size_stride(arg454_1, (1024, 16), (16, 1))
assert_size_stride(arg455_1, (1024, 16), (16, 1))
assert_size_stride(arg456_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg457_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg458_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg459_1, (4096, 4096), (4096, 1))
assert_size_stride(arg460_1, (4096, 16), (16, 1))
assert_size_stride(arg461_1, (4096, 16), (16, 1))
assert_size_stride(arg462_1, (4096, ), (1, ))
assert_size_stride(arg463_1, (14336, 4096), (4096, 1))
assert_size_stride(arg464_1, (14336, 16), (16, 1))
assert_size_stride(arg465_1, (14336, 16), (16, 1))
assert_size_stride(arg466_1, (14336, 4096), (4096, 1))
assert_size_stride(arg467_1, (14336, 16), (16, 1))
assert_size_stride(arg468_1, (14336, 16), (16, 1))
assert_size_stride(arg469_1, (4096, 14336), (14336, 1))
assert_size_stride(arg470_1, (4096, 56), (56, 1))
assert_size_stride(arg471_1, (4096, 56), (56, 1))
assert_size_stride(arg472_1, (4096, ), (1, ))
assert_size_stride(arg473_1, (4096, 4096), (4096, 1))
assert_size_stride(arg474_1, (4096, 16), (16, 1))
assert_size_stride(arg475_1, (4096, 16), (16, 1))
assert_size_stride(arg476_1, (1024, 4096), (4096, 1))
assert_size_stride(arg477_1, (1024, 16), (16, 1))
assert_size_stride(arg478_1, (1024, 16), (16, 1))
assert_size_stride(arg479_1, (1024, 4096), (4096, 1))
assert_size_stride(arg480_1, (1024, 16), (16, 1))
assert_size_stride(arg481_1, (1024, 16), (16, 1))
assert_size_stride(arg482_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg483_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg484_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg485_1, (4096, 4096), (4096, 1))
assert_size_stride(arg486_1, (4096, 16), (16, 1))
assert_size_stride(arg487_1, (4096, 16), (16, 1))
assert_size_stride(arg488_1, (4096, ), (1, ))
assert_size_stride(arg489_1, (14336, 4096), (4096, 1))
assert_size_stride(arg490_1, (14336, 16), (16, 1))
assert_size_stride(arg491_1, (14336, 16), (16, 1))
assert_size_stride(arg492_1, (14336, 4096), (4096, 1))
assert_size_stride(arg493_1, (14336, 16), (16, 1))
assert_size_stride(arg494_1, (14336, 16), (16, 1))
assert_size_stride(arg495_1, (4096, 14336), (14336, 1))
assert_size_stride(arg496_1, (4096, 56), (56, 1))
assert_size_stride(arg497_1, (4096, 56), (56, 1))
assert_size_stride(arg498_1, (4096, ), (1, ))
assert_size_stride(arg499_1, (4096, 4096), (4096, 1))
assert_size_stride(arg500_1, (4096, 16), (16, 1))
assert_size_stride(arg501_1, (4096, 16), (16, 1))
assert_size_stride(arg502_1, (1024, 4096), (4096, 1))
assert_size_stride(arg503_1, (1024, 16), (16, 1))
assert_size_stride(arg504_1, (1024, 16), (16, 1))
assert_size_stride(arg505_1, (1024, 4096), (4096, 1))
assert_size_stride(arg506_1, (1024, 16), (16, 1))
assert_size_stride(arg507_1, (1024, 16), (16, 1))
assert_size_stride(arg508_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg509_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg510_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg511_1, (4096, 4096), (4096, 1))
assert_size_stride(arg512_1, (4096, 16), (16, 1))
assert_size_stride(arg513_1, (4096, 16), (16, 1))
assert_size_stride(arg514_1, (4096, ), (1, ))
assert_size_stride(arg515_1, (14336, 4096), (4096, 1))
assert_size_stride(arg516_1, (14336, 16), (16, 1))
assert_size_stride(arg517_1, (14336, 16), (16, 1))
assert_size_stride(arg518_1, (14336, 4096), (4096, 1))
assert_size_stride(arg519_1, (14336, 16), (16, 1))
assert_size_stride(arg520_1, (14336, 16), (16, 1))
assert_size_stride(arg521_1, (4096, 14336), (14336, 1))
assert_size_stride(arg522_1, (4096, 56), (56, 1))
assert_size_stride(arg523_1, (4096, 56), (56, 1))
assert_size_stride(arg524_1, (4096, ), (1, ))
assert_size_stride(arg525_1, (4096, 4096), (4096, 1))
assert_size_stride(arg526_1, (4096, 16), (16, 1))
assert_size_stride(arg527_1, (4096, 16), (16, 1))
assert_size_stride(arg528_1, (1024, 4096), (4096, 1))
assert_size_stride(arg529_1, (1024, 16), (16, 1))
assert_size_stride(arg530_1, (1024, 16), (16, 1))
assert_size_stride(arg531_1, (1024, 4096), (4096, 1))
assert_size_stride(arg532_1, (1024, 16), (16, 1))
assert_size_stride(arg533_1, (1024, 16), (16, 1))
assert_size_stride(arg534_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg535_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg536_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg537_1, (4096, 4096), (4096, 1))
assert_size_stride(arg538_1, (4096, 16), (16, 1))
assert_size_stride(arg539_1, (4096, 16), (16, 1))
assert_size_stride(arg540_1, (4096, ), (1, ))
assert_size_stride(arg541_1, (14336, 4096), (4096, 1))
assert_size_stride(arg542_1, (14336, 16), (16, 1))
assert_size_stride(arg543_1, (14336, 16), (16, 1))
assert_size_stride(arg544_1, (14336, 4096), (4096, 1))
assert_size_stride(arg545_1, (14336, 16), (16, 1))
assert_size_stride(arg546_1, (14336, 16), (16, 1))
assert_size_stride(arg547_1, (4096, 14336), (14336, 1))
assert_size_stride(arg548_1, (4096, 56), (56, 1))
assert_size_stride(arg549_1, (4096, 56), (56, 1))
assert_size_stride(arg550_1, (4096, ), (1, ))
assert_size_stride(arg551_1, (4096, 4096), (4096, 1))
assert_size_stride(arg552_1, (4096, 16), (16, 1))
assert_size_stride(arg553_1, (4096, 16), (16, 1))
assert_size_stride(arg554_1, (1024, 4096), (4096, 1))
assert_size_stride(arg555_1, (1024, 16), (16, 1))
assert_size_stride(arg556_1, (1024, 16), (16, 1))
assert_size_stride(arg557_1, (1024, 4096), (4096, 1))
assert_size_stride(arg558_1, (1024, 16), (16, 1))
assert_size_stride(arg559_1, (1024, 16), (16, 1))
assert_size_stride(arg560_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg561_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg562_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg563_1, (4096, 4096), (4096, 1))
assert_size_stride(arg564_1, (4096, 16), (16, 1))
assert_size_stride(arg565_1, (4096, 16), (16, 1))
assert_size_stride(arg566_1, (4096, ), (1, ))
assert_size_stride(arg567_1, (14336, 4096), (4096, 1))
assert_size_stride(arg568_1, (14336, 16), (16, 1))
assert_size_stride(arg569_1, (14336, 16), (16, 1))
assert_size_stride(arg570_1, (14336, 4096), (4096, 1))
assert_size_stride(arg571_1, (14336, 16), (16, 1))
assert_size_stride(arg572_1, (14336, 16), (16, 1))
assert_size_stride(arg573_1, (4096, 14336), (14336, 1))
assert_size_stride(arg574_1, (4096, 56), (56, 1))
assert_size_stride(arg575_1, (4096, 56), (56, 1))
assert_size_stride(arg576_1, (4096, ), (1, ))
assert_size_stride(arg577_1, (4096, 4096), (4096, 1))
assert_size_stride(arg578_1, (4096, 16), (16, 1))
assert_size_stride(arg579_1, (4096, 16), (16, 1))
assert_size_stride(arg580_1, (1024, 4096), (4096, 1))
assert_size_stride(arg581_1, (1024, 16), (16, 1))
assert_size_stride(arg582_1, (1024, 16), (16, 1))
assert_size_stride(arg583_1, (1024, 4096), (4096, 1))
assert_size_stride(arg584_1, (1024, 16), (16, 1))
assert_size_stride(arg585_1, (1024, 16), (16, 1))
assert_size_stride(arg586_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg587_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg588_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg589_1, (4096, 4096), (4096, 1))
assert_size_stride(arg590_1, (4096, 16), (16, 1))
assert_size_stride(arg591_1, (4096, 16), (16, 1))
assert_size_stride(arg592_1, (4096, ), (1, ))
assert_size_stride(arg593_1, (14336, 4096), (4096, 1))
assert_size_stride(arg594_1, (14336, 16), (16, 1))
assert_size_stride(arg595_1, (14336, 16), (16, 1))
assert_size_stride(arg596_1, (14336, 4096), (4096, 1))
assert_size_stride(arg597_1, (14336, 16), (16, 1))
assert_size_stride(arg598_1, (14336, 16), (16, 1))
assert_size_stride(arg599_1, (4096, 14336), (14336, 1))
assert_size_stride(arg600_1, (4096, 56), (56, 1))
assert_size_stride(arg601_1, (4096, 56), (56, 1))
assert_size_stride(arg602_1, (4096, ), (1, ))
assert_size_stride(arg603_1, (4096, 4096), (4096, 1))
assert_size_stride(arg604_1, (4096, 16), (16, 1))
assert_size_stride(arg605_1, (4096, 16), (16, 1))
assert_size_stride(arg606_1, (1024, 4096), (4096, 1))
assert_size_stride(arg607_1, (1024, 16), (16, 1))
assert_size_stride(arg608_1, (1024, 16), (16, 1))
assert_size_stride(arg609_1, (1024, 4096), (4096, 1))
assert_size_stride(arg610_1, (1024, 16), (16, 1))
assert_size_stride(arg611_1, (1024, 16), (16, 1))
assert_size_stride(arg612_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg613_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg614_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg615_1, (4096, 4096), (4096, 1))
assert_size_stride(arg616_1, (4096, 16), (16, 1))
assert_size_stride(arg617_1, (4096, 16), (16, 1))
assert_size_stride(arg618_1, (4096, ), (1, ))
assert_size_stride(arg619_1, (14336, 4096), (4096, 1))
assert_size_stride(arg620_1, (14336, 16), (16, 1))
assert_size_stride(arg621_1, (14336, 16), (16, 1))
assert_size_stride(arg622_1, (14336, 4096), (4096, 1))
assert_size_stride(arg623_1, (14336, 16), (16, 1))
assert_size_stride(arg624_1, (14336, 16), (16, 1))
assert_size_stride(arg625_1, (4096, 14336), (14336, 1))
assert_size_stride(arg626_1, (4096, 56), (56, 1))
assert_size_stride(arg627_1, (4096, 56), (56, 1))
assert_size_stride(arg628_1, (4096, ), (1, ))
assert_size_stride(arg629_1, (4096, 4096), (4096, 1))
assert_size_stride(arg630_1, (4096, 16), (16, 1))
assert_size_stride(arg631_1, (4096, 16), (16, 1))
assert_size_stride(arg632_1, (1024, 4096), (4096, 1))
assert_size_stride(arg633_1, (1024, 16), (16, 1))
assert_size_stride(arg634_1, (1024, 16), (16, 1))
assert_size_stride(arg635_1, (1024, 4096), (4096, 1))
assert_size_stride(arg636_1, (1024, 16), (16, 1))
assert_size_stride(arg637_1, (1024, 16), (16, 1))
assert_size_stride(arg638_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg639_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg640_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg641_1, (4096, 4096), (4096, 1))
assert_size_stride(arg642_1, (4096, 16), (16, 1))
assert_size_stride(arg643_1, (4096, 16), (16, 1))
assert_size_stride(arg644_1, (4096, ), (1, ))
assert_size_stride(arg645_1, (14336, 4096), (4096, 1))
assert_size_stride(arg646_1, (14336, 16), (16, 1))
assert_size_stride(arg647_1, (14336, 16), (16, 1))
assert_size_stride(arg648_1, (14336, 4096), (4096, 1))
assert_size_stride(arg649_1, (14336, 16), (16, 1))
assert_size_stride(arg650_1, (14336, 16), (16, 1))
assert_size_stride(arg651_1, (4096, 14336), (14336, 1))
assert_size_stride(arg652_1, (4096, 56), (56, 1))
assert_size_stride(arg653_1, (4096, 56), (56, 1))
assert_size_stride(arg654_1, (4096, ), (1, ))
assert_size_stride(arg655_1, (4096, 4096), (4096, 1))
assert_size_stride(arg656_1, (4096, 16), (16, 1))
assert_size_stride(arg657_1, (4096, 16), (16, 1))
assert_size_stride(arg658_1, (1024, 4096), (4096, 1))
assert_size_stride(arg659_1, (1024, 16), (16, 1))
assert_size_stride(arg660_1, (1024, 16), (16, 1))
assert_size_stride(arg661_1, (1024, 4096), (4096, 1))
assert_size_stride(arg662_1, (1024, 16), (16, 1))
assert_size_stride(arg663_1, (1024, 16), (16, 1))
assert_size_stride(arg664_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg665_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg666_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg667_1, (4096, 4096), (4096, 1))
assert_size_stride(arg668_1, (4096, 16), (16, 1))
assert_size_stride(arg669_1, (4096, 16), (16, 1))
assert_size_stride(arg670_1, (4096, ), (1, ))
assert_size_stride(arg671_1, (14336, 4096), (4096, 1))
assert_size_stride(arg672_1, (14336, 16), (16, 1))
assert_size_stride(arg673_1, (14336, 16), (16, 1))
assert_size_stride(arg674_1, (14336, 4096), (4096, 1))
assert_size_stride(arg675_1, (14336, 16), (16, 1))
assert_size_stride(arg676_1, (14336, 16), (16, 1))
assert_size_stride(arg677_1, (4096, 14336), (14336, 1))
assert_size_stride(arg678_1, (4096, 56), (56, 1))
assert_size_stride(arg679_1, (4096, 56), (56, 1))
assert_size_stride(arg680_1, (4096, ), (1, ))
assert_size_stride(arg681_1, (4096, 4096), (4096, 1))
assert_size_stride(arg682_1, (4096, 16), (16, 1))
assert_size_stride(arg683_1, (4096, 16), (16, 1))
assert_size_stride(arg684_1, (1024, 4096), (4096, 1))
assert_size_stride(arg685_1, (1024, 16), (16, 1))
assert_size_stride(arg686_1, (1024, 16), (16, 1))
assert_size_stride(arg687_1, (1024, 4096), (4096, 1))
assert_size_stride(arg688_1, (1024, 16), (16, 1))
assert_size_stride(arg689_1, (1024, 16), (16, 1))
assert_size_stride(arg690_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg691_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg692_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg693_1, (4096, 4096), (4096, 1))
assert_size_stride(arg694_1, (4096, 16), (16, 1))
assert_size_stride(arg695_1, (4096, 16), (16, 1))
assert_size_stride(arg696_1, (4096, ), (1, ))
assert_size_stride(arg697_1, (14336, 4096), (4096, 1))
assert_size_stride(arg698_1, (14336, 16), (16, 1))
assert_size_stride(arg699_1, (14336, 16), (16, 1))
assert_size_stride(arg700_1, (14336, 4096), (4096, 1))
assert_size_stride(arg701_1, (14336, 16), (16, 1))
assert_size_stride(arg702_1, (14336, 16), (16, 1))
assert_size_stride(arg703_1, (4096, 14336), (14336, 1))
assert_size_stride(arg704_1, (4096, 56), (56, 1))
assert_size_stride(arg705_1, (4096, 56), (56, 1))
assert_size_stride(arg706_1, (4096, ), (1, ))
assert_size_stride(arg707_1, (4096, 4096), (4096, 1))
assert_size_stride(arg708_1, (4096, 16), (16, 1))
assert_size_stride(arg709_1, (4096, 16), (16, 1))
assert_size_stride(arg710_1, (1024, 4096), (4096, 1))
assert_size_stride(arg711_1, (1024, 16), (16, 1))
assert_size_stride(arg712_1, (1024, 16), (16, 1))
assert_size_stride(arg713_1, (1024, 4096), (4096, 1))
assert_size_stride(arg714_1, (1024, 16), (16, 1))
assert_size_stride(arg715_1, (1024, 16), (16, 1))
assert_size_stride(arg716_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg717_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg718_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg719_1, (4096, 4096), (4096, 1))
assert_size_stride(arg720_1, (4096, 16), (16, 1))
assert_size_stride(arg721_1, (4096, 16), (16, 1))
assert_size_stride(arg722_1, (4096, ), (1, ))
assert_size_stride(arg723_1, (14336, 4096), (4096, 1))
assert_size_stride(arg724_1, (14336, 16), (16, 1))
assert_size_stride(arg725_1, (14336, 16), (16, 1))
assert_size_stride(arg726_1, (14336, 4096), (4096, 1))
assert_size_stride(arg727_1, (14336, 16), (16, 1))
assert_size_stride(arg728_1, (14336, 16), (16, 1))
assert_size_stride(arg729_1, (4096, 14336), (14336, 1))
assert_size_stride(arg730_1, (4096, 56), (56, 1))
assert_size_stride(arg731_1, (4096, 56), (56, 1))
assert_size_stride(arg732_1, (4096, ), (1, ))
assert_size_stride(arg733_1, (4096, 4096), (4096, 1))
assert_size_stride(arg734_1, (4096, 16), (16, 1))
assert_size_stride(arg735_1, (4096, 16), (16, 1))
assert_size_stride(arg736_1, (1024, 4096), (4096, 1))
assert_size_stride(arg737_1, (1024, 16), (16, 1))
assert_size_stride(arg738_1, (1024, 16), (16, 1))
assert_size_stride(arg739_1, (1024, 4096), (4096, 1))
assert_size_stride(arg740_1, (1024, 16), (16, 1))
assert_size_stride(arg741_1, (1024, 16), (16, 1))
assert_size_stride(arg742_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg743_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg744_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg745_1, (4096, 4096), (4096, 1))
assert_size_stride(arg746_1, (4096, 16), (16, 1))
assert_size_stride(arg747_1, (4096, 16), (16, 1))
assert_size_stride(arg748_1, (4096, ), (1, ))
assert_size_stride(arg749_1, (14336, 4096), (4096, 1))
assert_size_stride(arg750_1, (14336, 16), (16, 1))
assert_size_stride(arg751_1, (14336, 16), (16, 1))
assert_size_stride(arg752_1, (14336, 4096), (4096, 1))
assert_size_stride(arg753_1, (14336, 16), (16, 1))
assert_size_stride(arg754_1, (14336, 16), (16, 1))
assert_size_stride(arg755_1, (4096, 14336), (14336, 1))
assert_size_stride(arg756_1, (4096, 56), (56, 1))
assert_size_stride(arg757_1, (4096, 56), (56, 1))
assert_size_stride(arg758_1, (4096, ), (1, ))
assert_size_stride(arg759_1, (4096, 4096), (4096, 1))
assert_size_stride(arg760_1, (4096, 16), (16, 1))
assert_size_stride(arg761_1, (4096, 16), (16, 1))
assert_size_stride(arg762_1, (1024, 4096), (4096, 1))
assert_size_stride(arg763_1, (1024, 16), (16, 1))
assert_size_stride(arg764_1, (1024, 16), (16, 1))
assert_size_stride(arg765_1, (1024, 4096), (4096, 1))
assert_size_stride(arg766_1, (1024, 16), (16, 1))
assert_size_stride(arg767_1, (1024, 16), (16, 1))
assert_size_stride(arg768_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg769_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg770_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg771_1, (4096, 4096), (4096, 1))
assert_size_stride(arg772_1, (4096, 16), (16, 1))
assert_size_stride(arg773_1, (4096, 16), (16, 1))
assert_size_stride(arg774_1, (4096, ), (1, ))
assert_size_stride(arg775_1, (14336, 4096), (4096, 1))
assert_size_stride(arg776_1, (14336, 16), (16, 1))
assert_size_stride(arg777_1, (14336, 16), (16, 1))
assert_size_stride(arg778_1, (14336, 4096), (4096, 1))
assert_size_stride(arg779_1, (14336, 16), (16, 1))
assert_size_stride(arg780_1, (14336, 16), (16, 1))
assert_size_stride(arg781_1, (4096, 14336), (14336, 1))
assert_size_stride(arg782_1, (4096, 56), (56, 1))
assert_size_stride(arg783_1, (4096, 56), (56, 1))
assert_size_stride(arg784_1, (4096, ), (1, ))
assert_size_stride(arg785_1, (4096, 4096), (4096, 1))
assert_size_stride(arg786_1, (4096, 16), (16, 1))
assert_size_stride(arg787_1, (4096, 16), (16, 1))
assert_size_stride(arg788_1, (1024, 4096), (4096, 1))
assert_size_stride(arg789_1, (1024, 16), (16, 1))
assert_size_stride(arg790_1, (1024, 16), (16, 1))
assert_size_stride(arg791_1, (1024, 4096), (4096, 1))
assert_size_stride(arg792_1, (1024, 16), (16, 1))
assert_size_stride(arg793_1, (1024, 16), (16, 1))
assert_size_stride(arg794_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg795_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg796_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg797_1, (4096, 4096), (4096, 1))
assert_size_stride(arg798_1, (4096, 16), (16, 1))
assert_size_stride(arg799_1, (4096, 16), (16, 1))
assert_size_stride(arg800_1, (4096, ), (1, ))
assert_size_stride(arg801_1, (14336, 4096), (4096, 1))
assert_size_stride(arg802_1, (14336, 16), (16, 1))
assert_size_stride(arg803_1, (14336, 16), (16, 1))
assert_size_stride(arg804_1, (14336, 4096), (4096, 1))
assert_size_stride(arg805_1, (14336, 16), (16, 1))
assert_size_stride(arg806_1, (14336, 16), (16, 1))
assert_size_stride(arg807_1, (4096, 14336), (14336, 1))
assert_size_stride(arg808_1, (4096, 56), (56, 1))
assert_size_stride(arg809_1, (4096, 56), (56, 1))
assert_size_stride(arg810_1, (4096, ), (1, ))
assert_size_stride(arg811_1, (4096, 4096), (4096, 1))
assert_size_stride(arg812_1, (4096, 16), (16, 1))
assert_size_stride(arg813_1, (4096, 16), (16, 1))
assert_size_stride(arg814_1, (1024, 4096), (4096, 1))
assert_size_stride(arg815_1, (1024, 16), (16, 1))
assert_size_stride(arg816_1, (1024, 16), (16, 1))
assert_size_stride(arg817_1, (1024, 4096), (4096, 1))
assert_size_stride(arg818_1, (1024, 16), (16, 1))
assert_size_stride(arg819_1, (1024, 16), (16, 1))
assert_size_stride(arg820_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg821_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg822_1, (1, 32, 2048, 128), (8388608, 262144, 128, 1))
assert_size_stride(arg823_1, (4096, 4096), (4096, 1))
assert_size_stride(arg824_1, (4096, 16), (16, 1))
assert_size_stride(arg825_1, (4096, 16), (16, 1))
assert_size_stride(arg826_1, (4096, ), (1, ))
assert_size_stride(arg827_1, (14336, 4096), (4096, 1))
assert_size_stride(arg828_1, (14336, 16), (16, 1))
assert_size_stride(arg829_1, (14336, 16), (16, 1))
assert_size_stride(arg830_1, (14336, 4096), (4096, 1))
assert_size_stride(arg831_1, (14336, 16), (16, 1))
assert_size_stride(arg832_1, (14336, 16), (16, 1))
assert_size_stride(arg833_1, (4096, 14336), (14336, 1))
assert_size_stride(arg834_1, (4096, 56), (56, 1))
assert_size_stride(arg835_1, (4096, 56), (56, 1))
assert_size_stride(arg836_1, (4096, ), (1, ))
assert_size_stride(arg837_1, (128256, 4096), (4096, 1))
assert_size_stride(arg838_1, (128256, 16), (16, 1))
assert_size_stride(arg839_1, (128256, 16), (16, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((1, 1, 4096), (4096, 4096, 1), torch.bfloat16)
# Source Nodes: [add, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg0_1, arg1_1, arg4_1, buf1, 1, 4096, grid=grid(1), stream=stream0)
del arg4_1
# Source Nodes: [add, choose_qparams_per_token_asymmetric, h, mean, mul, pow_1, rsqrt, x_fp32, x_normed, y], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt, quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8)
buf3 = buf2[0]
buf4 = buf2[1]
del buf2
# Source Nodes: [choose_qparams_per_token_asymmetric_1], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf5 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8)
buf6 = buf5[0]
buf7 = buf5[1]
del buf5
# Source Nodes: [choose_qparams_per_token_asymmetric_2], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf8 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1, torch.int8)
buf9 = buf8[0]
buf10 = buf8[1]
del buf8
buf11 = empty_strided_cuda((), (), torch.int64)
# Source Nodes: [max_1], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf11, 1, grid=grid(1), stream=stream0)
u0 = buf11.item()
buf12 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_2], Original ATen: [quantized_decomposed.quantize_per_token]
buf13 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf3, buf4, -128, 127, torch.int8)
buf14 = buf13
del buf13
# Source Nodes: [input_3], Original ATen: [quantized_decomposed.dequantize_per_token]
buf15 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf14, buf3, buf4, -128, 127, torch.int8, torch.bfloat16)
del buf14
del buf3
del buf4
buf16 = buf15
del buf15
# Source Nodes: [w_dq], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf17 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg5_1, arg6_1, arg7_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg5_1
del arg6_1
del arg7_1
buf18 = buf17
del buf17
buf19 = empty_strided_cuda((1, 4096), (4096, 1), torch.bfloat16)
# Source Nodes: [c], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf16, buf18, buf19, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf16
del buf18
# Source Nodes: [input_5], Original ATen: [quantized_decomposed.quantize_per_token]
buf20 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf6, buf7, -128, 127, torch.int8)
buf21 = buf20
del buf20
# Source Nodes: [input_6], Original ATen: [quantized_decomposed.dequantize_per_token]
buf22 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf21, buf6, buf7, -128, 127, torch.int8, torch.bfloat16)
del buf21
del buf6
del buf7
buf23 = buf22
del buf22
# Source Nodes: [w_dq_1], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf24 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg8_1, arg9_1, arg10_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg10_1
del arg8_1
del arg9_1
buf25 = buf24
del buf24
buf26 = empty_strided_cuda((1, 1024), (1024, 1), torch.bfloat16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf23, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf25, (4096, 1024), (1, 4096), 0), out=buf26)
del buf23
del buf25
# Source Nodes: [input_8], Original ATen: [quantized_decomposed.quantize_per_token]
buf28 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1, buf9, buf10, -128, 127, torch.int8)
del buf1
buf29 = buf28
del buf28
# Source Nodes: [input_9], Original ATen: [quantized_decomposed.dequantize_per_token]
buf30 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf29, buf9, buf10, -128, 127, torch.int8, torch.bfloat16)
del buf10
del buf29
del buf9
buf31 = buf30
del buf30
# Source Nodes: [w_dq_2], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf32 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg11_1, arg12_1, arg13_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg11_1
del arg12_1
del arg13_1
buf33 = buf32
del buf32
buf34 = empty_strided_cuda((1, 1024), (1024, 1), torch.bfloat16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf31, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf33, (4096, 1024), (1, 4096), 0), out=buf34)
del buf33
buf36 = reinterpret_tensor(buf31, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf31 # reuse
# Source Nodes: [output, setitem, setitem_1], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf26, arg14_1, buf34, buf19, arg15_1, arg16_1, buf36, 4096, grid=grid(4096), stream=stream0)
del arg14_1
del buf19
buf37 = empty_strided_cuda((1, 32, 1, 2048), (65536, 2048, 2048, 1), torch.bfloat16)
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf37, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf38 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf36, arg15_1, arg16_1, buf37, False)
del arg15_1
del arg16_1
del buf36
buf39 = buf38[0]
del buf38
# Source Nodes: [choose_qparams_per_token_asymmetric_3], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf43 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf39, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf44 = buf43[0]
buf45 = buf43[1]
del buf43
# Source Nodes: [input_11], Original ATen: [quantized_decomposed.quantize_per_token]
buf46 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf39, (1, 1, 4096), (4096, 4096, 1), 0), buf44, buf45, -128, 127, torch.int8)
buf47 = buf46
del buf46
# Source Nodes: [input_12], Original ATen: [quantized_decomposed.dequantize_per_token]
buf48 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf47, buf44, buf45, -128, 127, torch.int8, torch.bfloat16)
del buf44
del buf45
del buf47
buf49 = buf48
del buf48
# Source Nodes: [w_dq_3], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf50 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg17_1, arg18_1, arg19_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg17_1
del arg18_1
del arg19_1
buf51 = buf50
del buf50
buf52 = reinterpret_tensor(buf39, (1, 4096), (4096, 1), 0); del buf39 # reuse
# Source Nodes: [c_3], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf49, buf51, buf52, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf51
buf54 = buf49; del buf49 # reuse
# Source Nodes: [add_5, h, h_1, mean_1, mul_10, mul_11, pow_2, rsqrt_1, x_fp32_1, x_normed_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5.run(buf52, arg0_1, arg1_1, arg20_1, buf54, 1, 4096, grid=grid(1), stream=stream0)
del arg20_1
# Source Nodes: [choose_qparams_per_token_asymmetric_4], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf55 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf54, torch.int8)
buf56 = buf55[0]
buf57 = buf55[1]
del buf55
# Source Nodes: [choose_qparams_per_token_asymmetric_5], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf58 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf54, torch.int8)
buf59 = buf58[0]
buf60 = buf58[1]
del buf58
# Source Nodes: [input_14], Original ATen: [quantized_decomposed.quantize_per_token]
buf61 = torch.ops.quantized_decomposed.quantize_per_token.default(buf54, buf56, buf57, -128, 127, torch.int8)
buf62 = buf61
del buf61
# Source Nodes: [input_15], Original ATen: [quantized_decomposed.dequantize_per_token]
buf63 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf62, buf56, buf57, -128, 127, torch.int8, torch.bfloat16)
del buf56
del buf57
del buf62
buf64 = buf63
del buf63
# Source Nodes: [w_dq_4], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf65 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg21_1, arg22_1, arg23_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg21_1
del arg22_1
del arg23_1
buf66 = buf65
del buf65
buf67 = empty_strided_cuda((1, 14336), (14336, 1), torch.bfloat16)
# Source Nodes: [c_4], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf64, buf66, buf67, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf66
# Source Nodes: [input_17], Original ATen: [quantized_decomposed.quantize_per_token]
buf68 = torch.ops.quantized_decomposed.quantize_per_token.default(buf54, buf59, buf60, -128, 127, torch.int8)
buf69 = buf68
del buf68
# Source Nodes: [input_18], Original ATen: [quantized_decomposed.dequantize_per_token]
buf70 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf69, buf59, buf60, -128, 127, torch.int8, torch.bfloat16)
del buf59
del buf60
del buf69
buf71 = buf70
del buf70
# Source Nodes: [w_dq_5], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf72 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg24_1, arg25_1, arg26_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg24_1
del arg25_1
del arg26_1
buf73 = buf72
del buf72
buf75 = empty_strided_cuda((1, 1, 14336), (14336, 14336, 1), torch.bfloat16)
# Source Nodes: [c_5, mul_12, silu], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf71, buf73, buf67, buf75, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf67
del buf73
# Source Nodes: [choose_qparams_per_token_asymmetric_6], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf76 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf75, torch.int8)
buf77 = buf76[0]
buf78 = buf76[1]
del buf76
# Source Nodes: [input_20], Original ATen: [quantized_decomposed.quantize_per_token]
buf79 = torch.ops.quantized_decomposed.quantize_per_token.default(buf75, buf77, buf78, -128, 127, torch.int8)
buf80 = buf79
del buf79
# Source Nodes: [input_21], Original ATen: [quantized_decomposed.dequantize_per_token]
buf81 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf80, buf77, buf78, -128, 127, torch.int8, torch.bfloat16)
del buf77
del buf78
del buf80
buf82 = buf81
del buf81
# Source Nodes: [w_dq_6], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf83 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg27_1, arg28_1, arg29_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg27_1
del arg28_1
del arg29_1
buf84 = buf83
del buf83
buf85 = reinterpret_tensor(buf71, (1, 4096), (4096, 1), 0); del buf71 # reuse
# Source Nodes: [c_6], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf82, buf84, buf85, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf84
buf87 = buf54; del buf54 # reuse
# Source Nodes: [add_7, h, h_1, mean_2, mul_13, out, pow_3, rsqrt_2, x_fp32_2, x_normed_2, y_1], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf52, arg0_1, arg1_1, buf85, arg30_1, buf87, 1, 4096, grid=grid(1), stream=stream0)
del arg30_1
# Source Nodes: [choose_qparams_per_token_asymmetric_7], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf88 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8)
buf89 = buf88[0]
buf90 = buf88[1]
del buf88
# Source Nodes: [choose_qparams_per_token_asymmetric_8], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf91 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8)
buf92 = buf91[0]
buf93 = buf91[1]
del buf91
# Source Nodes: [choose_qparams_per_token_asymmetric_9], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf94 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf87, torch.int8)
buf95 = buf94[0]
buf96 = buf94[1]
del buf94
buf97 = buf11; del buf11 # reuse
# Source Nodes: [max_2], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf97, 1, grid=grid(1), stream=stream0)
u1 = buf97.item()
buf98 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_23], Original ATen: [quantized_decomposed.quantize_per_token]
buf99 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf89, buf90, -128, 127, torch.int8)
buf100 = buf99
del buf99
# Source Nodes: [input_24], Original ATen: [quantized_decomposed.dequantize_per_token]
buf101 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf100, buf89, buf90, -128, 127, torch.int8, torch.bfloat16)
del buf100
del buf89
del buf90
buf102 = buf101
del buf101
# Source Nodes: [w_dq_7], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf103 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg31_1, arg32_1, arg33_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg31_1
del arg32_1
del arg33_1
buf104 = buf103
del buf103
buf105 = reinterpret_tensor(buf64, (1, 4096), (4096, 1), 0); del buf64 # reuse
# Source Nodes: [c_7], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf102, buf104, buf105, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf102
del buf104
# Source Nodes: [input_26], Original ATen: [quantized_decomposed.quantize_per_token]
buf106 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf92, buf93, -128, 127, torch.int8)
buf107 = buf106
del buf106
# Source Nodes: [input_27], Original ATen: [quantized_decomposed.dequantize_per_token]
buf108 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf107, buf92, buf93, -128, 127, torch.int8, torch.bfloat16)
del buf107
del buf92
del buf93
buf109 = buf108
del buf108
# Source Nodes: [w_dq_8], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf110 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg34_1, arg35_1, arg36_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg34_1
del arg35_1
del arg36_1
buf111 = buf110
del buf110
buf112 = buf34; del buf34 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf109, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf111, (4096, 1024), (1, 4096), 0), out=buf112)
del buf109
del buf111
# Source Nodes: [input_29], Original ATen: [quantized_decomposed.quantize_per_token]
buf114 = torch.ops.quantized_decomposed.quantize_per_token.default(buf87, buf95, buf96, -128, 127, torch.int8)
del buf87
buf115 = buf114
del buf114
# Source Nodes: [input_30], Original ATen: [quantized_decomposed.dequantize_per_token]
buf116 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf115, buf95, buf96, -128, 127, torch.int8, torch.bfloat16)
del buf115
del buf95
del buf96
buf117 = buf116
del buf116
# Source Nodes: [w_dq_9], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf118 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg37_1, arg38_1, arg39_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg37_1
del arg38_1
del arg39_1
buf119 = buf118
del buf118
buf120 = buf26; del buf26 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf117, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf119, (4096, 1024), (1, 4096), 0), out=buf120)
del buf119
buf122 = reinterpret_tensor(buf117, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf117 # reuse
# Source Nodes: [output_2, setitem_2, setitem_3], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf112, arg40_1, buf120, buf105, arg41_1, arg42_1, buf122, 4096, grid=grid(4096), stream=stream0)
del arg40_1
del buf105
buf123 = buf37; del buf37 # reuse
# Source Nodes: [output_2], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf123, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_2], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf124 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf122, arg41_1, arg42_1, buf123, False)
del arg41_1
del arg42_1
del buf122
buf125 = buf124[0]
del buf124
# Source Nodes: [choose_qparams_per_token_asymmetric_10], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf129 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf125, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf130 = buf129[0]
buf131 = buf129[1]
del buf129
# Source Nodes: [input_32], Original ATen: [quantized_decomposed.quantize_per_token]
buf132 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf125, (1, 1, 4096), (4096, 4096, 1), 0), buf130, buf131, -128, 127, torch.int8)
buf133 = buf132
del buf132
# Source Nodes: [input_33], Original ATen: [quantized_decomposed.dequantize_per_token]
buf134 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf133, buf130, buf131, -128, 127, torch.int8, torch.bfloat16)
del buf130
del buf131
del buf133
buf135 = buf134
del buf134
# Source Nodes: [w_dq_10], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf136 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg43_1, arg44_1, arg45_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg43_1
del arg44_1
del arg45_1
buf137 = buf136
del buf136
buf138 = reinterpret_tensor(buf125, (1, 4096), (4096, 1), 0); del buf125 # reuse
# Source Nodes: [c_10], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf135, buf137, buf138, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf137
buf139 = reinterpret_tensor(buf138, (1, 1, 4096), (4096, 4096, 1), 0); del buf138 # reuse
buf141 = buf135; del buf135 # reuse
# Source Nodes: [add_12, h, h_1, h_2, mean_3, mul_23, mul_24, out, pow_4, rsqrt_3, x_fp32_3, x_normed_3], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_10.run(buf139, buf52, arg0_1, arg1_1, buf85, arg46_1, buf141, 1, 4096, grid=grid(1), stream=stream0)
del arg0_1
del arg1_1
del arg46_1
del buf52
del buf85
# Source Nodes: [choose_qparams_per_token_asymmetric_11], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf142 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf141, torch.int8)
buf143 = buf142[0]
buf144 = buf142[1]
del buf142
# Source Nodes: [choose_qparams_per_token_asymmetric_12], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf145 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf141, torch.int8)
buf146 = buf145[0]
buf147 = buf145[1]
del buf145
# Source Nodes: [input_35], Original ATen: [quantized_decomposed.quantize_per_token]
buf148 = torch.ops.quantized_decomposed.quantize_per_token.default(buf141, buf143, buf144, -128, 127, torch.int8)
buf149 = buf148
del buf148
# Source Nodes: [input_36], Original ATen: [quantized_decomposed.dequantize_per_token]
buf150 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf149, buf143, buf144, -128, 127, torch.int8, torch.bfloat16)
del buf143
del buf144
del buf149
buf151 = buf150
del buf150
# Source Nodes: [w_dq_11], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf152 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg47_1, arg48_1, arg49_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg47_1
del arg48_1
del arg49_1
buf153 = buf152
del buf152
buf154 = reinterpret_tensor(buf82, (1, 14336), (14336, 1), 0); del buf82 # reuse
# Source Nodes: [c_11], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf151, buf153, buf154, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf153
# Source Nodes: [input_38], Original ATen: [quantized_decomposed.quantize_per_token]
buf155 = torch.ops.quantized_decomposed.quantize_per_token.default(buf141, buf146, buf147, -128, 127, torch.int8)
buf156 = buf155
del buf155
# Source Nodes: [input_39], Original ATen: [quantized_decomposed.dequantize_per_token]
buf157 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf156, buf146, buf147, -128, 127, torch.int8, torch.bfloat16)
del buf146
del buf147
del buf156
buf158 = buf157
del buf157
# Source Nodes: [w_dq_12], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf159 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg50_1, arg51_1, arg52_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg50_1
del arg51_1
del arg52_1
buf160 = buf159
del buf159
buf162 = buf75; del buf75 # reuse
# Source Nodes: [c_12, mul_25, silu_1], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf158, buf160, buf154, buf162, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf154
del buf160
# Source Nodes: [choose_qparams_per_token_asymmetric_13], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf163 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf162, torch.int8)
buf164 = buf163[0]
buf165 = buf163[1]
del buf163
# Source Nodes: [input_41], Original ATen: [quantized_decomposed.quantize_per_token]
buf166 = torch.ops.quantized_decomposed.quantize_per_token.default(buf162, buf164, buf165, -128, 127, torch.int8)
buf167 = buf166
del buf166
# Source Nodes: [input_42], Original ATen: [quantized_decomposed.dequantize_per_token]
buf168 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf167, buf164, buf165, -128, 127, torch.int8, torch.bfloat16)
del buf164
del buf165
del buf167
buf169 = buf168
del buf168
# Source Nodes: [w_dq_13], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf170 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg53_1, arg54_1, arg55_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg53_1
del arg54_1
del arg55_1
buf171 = buf170
del buf170
buf172 = reinterpret_tensor(buf158, (1, 4096), (4096, 1), 0); del buf158 # reuse
# Source Nodes: [c_13], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf169, buf171, buf172, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf171
buf174 = buf141; del buf141 # reuse
# Source Nodes: [add_14, mean_4, mul_26, out_1, pow_5, rsqrt_4, x_fp32_4, x_normed_4, y_2], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf139, buf172, arg56_1, buf174, 1, 4096, grid=grid(1), stream=stream0)
del arg56_1
# Source Nodes: [choose_qparams_per_token_asymmetric_14], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf175 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8)
buf176 = buf175[0]
buf177 = buf175[1]
del buf175
# Source Nodes: [choose_qparams_per_token_asymmetric_15], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf178 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8)
buf179 = buf178[0]
buf180 = buf178[1]
del buf178
# Source Nodes: [choose_qparams_per_token_asymmetric_16], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf181 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf174, torch.int8)
buf182 = buf181[0]
buf183 = buf181[1]
del buf181
buf184 = buf97; del buf97 # reuse
# Source Nodes: [max_3], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf184, 1, grid=grid(1), stream=stream0)
u2 = buf184.item()
buf185 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_44], Original ATen: [quantized_decomposed.quantize_per_token]
buf186 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf176, buf177, -128, 127, torch.int8)
buf187 = buf186
del buf186
# Source Nodes: [input_45], Original ATen: [quantized_decomposed.dequantize_per_token]
buf188 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf187, buf176, buf177, -128, 127, torch.int8, torch.bfloat16)
del buf176
del buf177
del buf187
buf189 = buf188
del buf188
# Source Nodes: [w_dq_14], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf190 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg57_1, arg58_1, arg59_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg57_1
del arg58_1
del arg59_1
buf191 = buf190
del buf190
buf192 = reinterpret_tensor(buf151, (1, 4096), (4096, 1), 0); del buf151 # reuse
# Source Nodes: [c_14], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf189, buf191, buf192, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf189
del buf191
# Source Nodes: [input_47], Original ATen: [quantized_decomposed.quantize_per_token]
buf193 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf179, buf180, -128, 127, torch.int8)
buf194 = buf193
del buf193
# Source Nodes: [input_48], Original ATen: [quantized_decomposed.dequantize_per_token]
buf195 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf194, buf179, buf180, -128, 127, torch.int8, torch.bfloat16)
del buf179
del buf180
del buf194
buf196 = buf195
del buf195
# Source Nodes: [w_dq_15], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf197 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg60_1, arg61_1, arg62_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg60_1
del arg61_1
del arg62_1
buf198 = buf197
del buf197
buf199 = buf120; del buf120 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf196, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf198, (4096, 1024), (1, 4096), 0), out=buf199)
del buf196
del buf198
# Source Nodes: [input_50], Original ATen: [quantized_decomposed.quantize_per_token]
buf201 = torch.ops.quantized_decomposed.quantize_per_token.default(buf174, buf182, buf183, -128, 127, torch.int8)
del buf174
buf202 = buf201
del buf201
# Source Nodes: [input_51], Original ATen: [quantized_decomposed.dequantize_per_token]
buf203 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf202, buf182, buf183, -128, 127, torch.int8, torch.bfloat16)
del buf182
del buf183
del buf202
buf204 = buf203
del buf203
# Source Nodes: [w_dq_16], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf205 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg63_1, arg64_1, arg65_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg63_1
del arg64_1
del arg65_1
buf206 = buf205
del buf205
buf207 = buf112; del buf112 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf204, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf206, (4096, 1024), (1, 4096), 0), out=buf207)
del buf206
buf209 = reinterpret_tensor(buf204, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf204 # reuse
# Source Nodes: [output_4, setitem_4, setitem_5], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf199, arg66_1, buf207, buf192, arg67_1, arg68_1, buf209, 4096, grid=grid(4096), stream=stream0)
del arg66_1
del buf192
buf210 = buf123; del buf123 # reuse
# Source Nodes: [output_4], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf210, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_4], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf211 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf209, arg67_1, arg68_1, buf210, False)
del arg67_1
del arg68_1
del buf209
buf212 = buf211[0]
del buf211
# Source Nodes: [choose_qparams_per_token_asymmetric_17], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf216 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf212, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf217 = buf216[0]
buf218 = buf216[1]
del buf216
# Source Nodes: [input_53], Original ATen: [quantized_decomposed.quantize_per_token]
buf219 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf212, (1, 1, 4096), (4096, 4096, 1), 0), buf217, buf218, -128, 127, torch.int8)
buf220 = buf219
del buf219
# Source Nodes: [input_54], Original ATen: [quantized_decomposed.dequantize_per_token]
buf221 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf220, buf217, buf218, -128, 127, torch.int8, torch.bfloat16)
del buf217
del buf218
del buf220
buf222 = buf221
del buf221
# Source Nodes: [w_dq_17], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf223 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg69_1, arg70_1, arg71_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg69_1
del arg70_1
del arg71_1
buf224 = buf223
del buf223
buf225 = reinterpret_tensor(buf212, (1, 4096), (4096, 1), 0); del buf212 # reuse
# Source Nodes: [c_17], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf222, buf224, buf225, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf224
buf227 = buf222; del buf222 # reuse
# Source Nodes: [add_19, h_3, mean_5, mul_36, mul_37, out_1, pow_6, rsqrt_5, x_fp32_5, x_normed_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf225, buf139, buf172, arg72_1, buf227, 1, 4096, grid=grid(1), stream=stream0)
del arg72_1
# Source Nodes: [choose_qparams_per_token_asymmetric_18], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf228 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf227, torch.int8)
buf229 = buf228[0]
buf230 = buf228[1]
del buf228
# Source Nodes: [choose_qparams_per_token_asymmetric_19], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf231 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf227, torch.int8)
buf232 = buf231[0]
buf233 = buf231[1]
del buf231
# Source Nodes: [input_56], Original ATen: [quantized_decomposed.quantize_per_token]
buf234 = torch.ops.quantized_decomposed.quantize_per_token.default(buf227, buf229, buf230, -128, 127, torch.int8)
buf235 = buf234
del buf234
# Source Nodes: [input_57], Original ATen: [quantized_decomposed.dequantize_per_token]
buf236 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf235, buf229, buf230, -128, 127, torch.int8, torch.bfloat16)
del buf229
del buf230
del buf235
buf237 = buf236
del buf236
# Source Nodes: [w_dq_18], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf238 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg73_1, arg74_1, arg75_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg73_1
del arg74_1
del arg75_1
buf239 = buf238
del buf238
buf240 = reinterpret_tensor(buf169, (1, 14336), (14336, 1), 0); del buf169 # reuse
# Source Nodes: [c_18], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf237, buf239, buf240, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf239
# Source Nodes: [input_59], Original ATen: [quantized_decomposed.quantize_per_token]
buf241 = torch.ops.quantized_decomposed.quantize_per_token.default(buf227, buf232, buf233, -128, 127, torch.int8)
buf242 = buf241
del buf241
# Source Nodes: [input_60], Original ATen: [quantized_decomposed.dequantize_per_token]
buf243 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf242, buf232, buf233, -128, 127, torch.int8, torch.bfloat16)
del buf232
del buf233
del buf242
buf244 = buf243
del buf243
# Source Nodes: [w_dq_19], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf245 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg76_1, arg77_1, arg78_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg76_1
del arg77_1
del arg78_1
buf246 = buf245
del buf245
buf248 = buf162; del buf162 # reuse
# Source Nodes: [c_19, mul_38, silu_2], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf244, buf246, buf240, buf248, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf240
del buf246
# Source Nodes: [choose_qparams_per_token_asymmetric_20], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf249 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf248, torch.int8)
buf250 = buf249[0]
buf251 = buf249[1]
del buf249
# Source Nodes: [input_62], Original ATen: [quantized_decomposed.quantize_per_token]
buf252 = torch.ops.quantized_decomposed.quantize_per_token.default(buf248, buf250, buf251, -128, 127, torch.int8)
buf253 = buf252
del buf252
# Source Nodes: [input_63], Original ATen: [quantized_decomposed.dequantize_per_token]
buf254 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf253, buf250, buf251, -128, 127, torch.int8, torch.bfloat16)
del buf250
del buf251
del buf253
buf255 = buf254
del buf254
# Source Nodes: [w_dq_20], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf256 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg79_1, arg80_1, arg81_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg79_1
del arg80_1
del arg81_1
buf257 = buf256
del buf256
buf258 = reinterpret_tensor(buf244, (1, 4096), (4096, 1), 0); del buf244 # reuse
# Source Nodes: [c_20], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf255, buf257, buf258, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf257
buf260 = buf227; del buf227 # reuse
# Source Nodes: [add_21, h_3, mean_6, mul_39, out_1, out_2, pow_7, rsqrt_6, x_fp32_6, x_normed_6, y_3], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf225, buf139, buf172, buf258, arg82_1, buf260, 1, 4096, grid=grid(1), stream=stream0)
del arg82_1
# Source Nodes: [choose_qparams_per_token_asymmetric_21], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf261 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8)
buf262 = buf261[0]
buf263 = buf261[1]
del buf261
# Source Nodes: [choose_qparams_per_token_asymmetric_22], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf264 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8)
buf265 = buf264[0]
buf266 = buf264[1]
del buf264
# Source Nodes: [choose_qparams_per_token_asymmetric_23], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf267 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf260, torch.int8)
buf268 = buf267[0]
buf269 = buf267[1]
del buf267
buf270 = buf184; del buf184 # reuse
# Source Nodes: [max_4], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf270, 1, grid=grid(1), stream=stream0)
u3 = buf270.item()
buf271 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_65], Original ATen: [quantized_decomposed.quantize_per_token]
buf272 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf262, buf263, -128, 127, torch.int8)
buf273 = buf272
del buf272
# Source Nodes: [input_66], Original ATen: [quantized_decomposed.dequantize_per_token]
buf274 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf273, buf262, buf263, -128, 127, torch.int8, torch.bfloat16)
del buf262
del buf263
del buf273
buf275 = buf274
del buf274
# Source Nodes: [w_dq_21], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf276 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg83_1, arg84_1, arg85_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg83_1
del arg84_1
del arg85_1
buf277 = buf276
del buf276
buf278 = reinterpret_tensor(buf237, (1, 4096), (4096, 1), 0); del buf237 # reuse
# Source Nodes: [c_21], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf275, buf277, buf278, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf275
del buf277
# Source Nodes: [input_68], Original ATen: [quantized_decomposed.quantize_per_token]
buf279 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf265, buf266, -128, 127, torch.int8)
buf280 = buf279
del buf279
# Source Nodes: [input_69], Original ATen: [quantized_decomposed.dequantize_per_token]
buf281 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf280, buf265, buf266, -128, 127, torch.int8, torch.bfloat16)
del buf265
del buf266
del buf280
buf282 = buf281
del buf281
# Source Nodes: [w_dq_22], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf283 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg86_1, arg87_1, arg88_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg86_1
del arg87_1
del arg88_1
buf284 = buf283
del buf283
buf285 = buf207; del buf207 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf282, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf284, (4096, 1024), (1, 4096), 0), out=buf285)
del buf282
del buf284
# Source Nodes: [input_71], Original ATen: [quantized_decomposed.quantize_per_token]
buf287 = torch.ops.quantized_decomposed.quantize_per_token.default(buf260, buf268, buf269, -128, 127, torch.int8)
del buf260
buf288 = buf287
del buf287
# Source Nodes: [input_72], Original ATen: [quantized_decomposed.dequantize_per_token]
buf289 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf288, buf268, buf269, -128, 127, torch.int8, torch.bfloat16)
del buf268
del buf269
del buf288
buf290 = buf289
del buf289
# Source Nodes: [w_dq_23], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf291 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg89_1, arg90_1, arg91_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg89_1
del arg90_1
del arg91_1
buf292 = buf291
del buf291
buf293 = buf199; del buf199 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf290, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf292, (4096, 1024), (1, 4096), 0), out=buf293)
del buf292
buf295 = reinterpret_tensor(buf290, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf290 # reuse
# Source Nodes: [output_6, setitem_6, setitem_7], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf285, arg92_1, buf293, buf278, arg93_1, arg94_1, buf295, 4096, grid=grid(4096), stream=stream0)
del arg92_1
del buf278
buf296 = buf210; del buf210 # reuse
# Source Nodes: [output_6], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf296, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_6], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf297 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf295, arg93_1, arg94_1, buf296, False)
del arg93_1
del arg94_1
del buf295
buf298 = buf297[0]
del buf297
# Source Nodes: [choose_qparams_per_token_asymmetric_24], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf302 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf298, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf303 = buf302[0]
buf304 = buf302[1]
del buf302
# Source Nodes: [input_74], Original ATen: [quantized_decomposed.quantize_per_token]
buf305 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf298, (1, 1, 4096), (4096, 4096, 1), 0), buf303, buf304, -128, 127, torch.int8)
buf306 = buf305
del buf305
# Source Nodes: [input_75], Original ATen: [quantized_decomposed.dequantize_per_token]
buf307 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf306, buf303, buf304, -128, 127, torch.int8, torch.bfloat16)
del buf303
del buf304
del buf306
buf308 = buf307
del buf307
# Source Nodes: [w_dq_24], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf309 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg95_1, arg96_1, arg97_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg95_1
del arg96_1
del arg97_1
buf310 = buf309
del buf309
buf311 = reinterpret_tensor(buf298, (1, 4096), (4096, 1), 0); del buf298 # reuse
# Source Nodes: [c_24], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf308, buf310, buf311, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf310
buf312 = buf139; del buf139 # reuse
buf314 = buf308; del buf308 # reuse
# Source Nodes: [add_26, h_3, h_4, mean_7, mul_49, mul_50, out_1, out_2, pow_8, rsqrt_7, x_fp32_7, x_normed_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf312, buf311, buf225, buf172, buf258, arg98_1, buf314, 1, 4096, grid=grid(1), stream=stream0)
del arg98_1
del buf172
del buf225
del buf258
del buf311
# Source Nodes: [choose_qparams_per_token_asymmetric_25], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf315 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf314, torch.int8)
buf316 = buf315[0]
buf317 = buf315[1]
del buf315
# Source Nodes: [choose_qparams_per_token_asymmetric_26], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf318 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf314, torch.int8)
buf319 = buf318[0]
buf320 = buf318[1]
del buf318
# Source Nodes: [input_77], Original ATen: [quantized_decomposed.quantize_per_token]
buf321 = torch.ops.quantized_decomposed.quantize_per_token.default(buf314, buf316, buf317, -128, 127, torch.int8)
buf322 = buf321
del buf321
# Source Nodes: [input_78], Original ATen: [quantized_decomposed.dequantize_per_token]
buf323 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf322, buf316, buf317, -128, 127, torch.int8, torch.bfloat16)
del buf316
del buf317
del buf322
buf324 = buf323
del buf323
# Source Nodes: [w_dq_25], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf325 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg99_1, arg100_1, arg101_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg100_1
del arg101_1
del arg99_1
buf326 = buf325
del buf325
buf327 = reinterpret_tensor(buf255, (1, 14336), (14336, 1), 0); del buf255 # reuse
# Source Nodes: [c_25], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf324, buf326, buf327, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf326
# Source Nodes: [input_80], Original ATen: [quantized_decomposed.quantize_per_token]
buf328 = torch.ops.quantized_decomposed.quantize_per_token.default(buf314, buf319, buf320, -128, 127, torch.int8)
buf329 = buf328
del buf328
# Source Nodes: [input_81], Original ATen: [quantized_decomposed.dequantize_per_token]
buf330 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf329, buf319, buf320, -128, 127, torch.int8, torch.bfloat16)
del buf319
del buf320
del buf329
buf331 = buf330
del buf330
# Source Nodes: [w_dq_26], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf332 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg102_1, arg103_1, arg104_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg102_1
del arg103_1
del arg104_1
buf333 = buf332
del buf332
buf335 = buf248; del buf248 # reuse
# Source Nodes: [c_26, mul_51, silu_3], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf331, buf333, buf327, buf335, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf327
del buf333
# Source Nodes: [choose_qparams_per_token_asymmetric_27], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf336 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf335, torch.int8)
buf337 = buf336[0]
buf338 = buf336[1]
del buf336
# Source Nodes: [input_83], Original ATen: [quantized_decomposed.quantize_per_token]
buf339 = torch.ops.quantized_decomposed.quantize_per_token.default(buf335, buf337, buf338, -128, 127, torch.int8)
buf340 = buf339
del buf339
# Source Nodes: [input_84], Original ATen: [quantized_decomposed.dequantize_per_token]
buf341 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf340, buf337, buf338, -128, 127, torch.int8, torch.bfloat16)
del buf337
del buf338
del buf340
buf342 = buf341
del buf341
# Source Nodes: [w_dq_27], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf343 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg105_1, arg106_1, arg107_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg105_1
del arg106_1
del arg107_1
buf344 = buf343
del buf343
buf345 = reinterpret_tensor(buf331, (1, 4096), (4096, 1), 0); del buf331 # reuse
# Source Nodes: [c_27], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf342, buf344, buf345, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf344
buf347 = buf314; del buf314 # reuse
# Source Nodes: [add_28, mean_8, mul_52, out_3, pow_9, rsqrt_8, x_fp32_8, x_normed_8, y_4], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf312, buf345, arg108_1, buf347, 1, 4096, grid=grid(1), stream=stream0)
del arg108_1
# Source Nodes: [choose_qparams_per_token_asymmetric_28], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf348 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8)
buf349 = buf348[0]
buf350 = buf348[1]
del buf348
# Source Nodes: [choose_qparams_per_token_asymmetric_29], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf351 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8)
buf352 = buf351[0]
buf353 = buf351[1]
del buf351
# Source Nodes: [choose_qparams_per_token_asymmetric_30], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf354 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf347, torch.int8)
buf355 = buf354[0]
buf356 = buf354[1]
del buf354
buf357 = buf270; del buf270 # reuse
# Source Nodes: [max_5], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf357, 1, grid=grid(1), stream=stream0)
u4 = buf357.item()
buf358 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_86], Original ATen: [quantized_decomposed.quantize_per_token]
buf359 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf349, buf350, -128, 127, torch.int8)
buf360 = buf359
del buf359
# Source Nodes: [input_87], Original ATen: [quantized_decomposed.dequantize_per_token]
buf361 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf360, buf349, buf350, -128, 127, torch.int8, torch.bfloat16)
del buf349
del buf350
del buf360
buf362 = buf361
del buf361
# Source Nodes: [w_dq_28], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf363 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg109_1, arg110_1, arg111_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg109_1
del arg110_1
del arg111_1
buf364 = buf363
del buf363
buf365 = reinterpret_tensor(buf324, (1, 4096), (4096, 1), 0); del buf324 # reuse
# Source Nodes: [c_28], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf362, buf364, buf365, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf362
del buf364
# Source Nodes: [input_89], Original ATen: [quantized_decomposed.quantize_per_token]
buf366 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf352, buf353, -128, 127, torch.int8)
buf367 = buf366
del buf366
# Source Nodes: [input_90], Original ATen: [quantized_decomposed.dequantize_per_token]
buf368 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf367, buf352, buf353, -128, 127, torch.int8, torch.bfloat16)
del buf352
del buf353
del buf367
buf369 = buf368
del buf368
# Source Nodes: [w_dq_29], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf370 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg112_1, arg113_1, arg114_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg112_1
del arg113_1
del arg114_1
buf371 = buf370
del buf370
buf372 = buf293; del buf293 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf369, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf371, (4096, 1024), (1, 4096), 0), out=buf372)
del buf369
del buf371
# Source Nodes: [input_92], Original ATen: [quantized_decomposed.quantize_per_token]
buf374 = torch.ops.quantized_decomposed.quantize_per_token.default(buf347, buf355, buf356, -128, 127, torch.int8)
del buf347
buf375 = buf374
del buf374
# Source Nodes: [input_93], Original ATen: [quantized_decomposed.dequantize_per_token]
buf376 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf375, buf355, buf356, -128, 127, torch.int8, torch.bfloat16)
del buf355
del buf356
del buf375
buf377 = buf376
del buf376
# Source Nodes: [w_dq_30], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf378 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg115_1, arg116_1, arg117_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg115_1
del arg116_1
del arg117_1
buf379 = buf378
del buf378
buf380 = buf285; del buf285 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf377, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf379, (4096, 1024), (1, 4096), 0), out=buf380)
del buf379
buf382 = reinterpret_tensor(buf377, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf377 # reuse
# Source Nodes: [output_8, setitem_8, setitem_9], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf372, arg118_1, buf380, buf365, arg119_1, arg120_1, buf382, 4096, grid=grid(4096), stream=stream0)
del arg118_1
del buf365
buf383 = buf296; del buf296 # reuse
# Source Nodes: [output_8], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf383, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_8], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf384 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf382, arg119_1, arg120_1, buf383, False)
del arg119_1
del arg120_1
del buf382
buf385 = buf384[0]
del buf384
# Source Nodes: [choose_qparams_per_token_asymmetric_31], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf389 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf385, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf390 = buf389[0]
buf391 = buf389[1]
del buf389
# Source Nodes: [input_95], Original ATen: [quantized_decomposed.quantize_per_token]
buf392 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf385, (1, 1, 4096), (4096, 4096, 1), 0), buf390, buf391, -128, 127, torch.int8)
buf393 = buf392
del buf392
# Source Nodes: [input_96], Original ATen: [quantized_decomposed.dequantize_per_token]
buf394 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf393, buf390, buf391, -128, 127, torch.int8, torch.bfloat16)
del buf390
del buf391
del buf393
buf395 = buf394
del buf394
# Source Nodes: [w_dq_31], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf396 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg121_1, arg122_1, arg123_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg121_1
del arg122_1
del arg123_1
buf397 = buf396
del buf396
buf398 = reinterpret_tensor(buf385, (1, 4096), (4096, 1), 0); del buf385 # reuse
# Source Nodes: [c_31], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf395, buf397, buf398, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf397
buf400 = buf395; del buf395 # reuse
# Source Nodes: [add_33, h_5, mean_9, mul_62, mul_63, out_3, pow_10, rsqrt_9, x_fp32_9, x_normed_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf398, buf312, buf345, arg124_1, buf400, 1, 4096, grid=grid(1), stream=stream0)
del arg124_1
# Source Nodes: [choose_qparams_per_token_asymmetric_32], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf401 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf400, torch.int8)
buf402 = buf401[0]
buf403 = buf401[1]
del buf401
# Source Nodes: [choose_qparams_per_token_asymmetric_33], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf404 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf400, torch.int8)
buf405 = buf404[0]
buf406 = buf404[1]
del buf404
# Source Nodes: [input_98], Original ATen: [quantized_decomposed.quantize_per_token]
buf407 = torch.ops.quantized_decomposed.quantize_per_token.default(buf400, buf402, buf403, -128, 127, torch.int8)
buf408 = buf407
del buf407
# Source Nodes: [input_99], Original ATen: [quantized_decomposed.dequantize_per_token]
buf409 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf408, buf402, buf403, -128, 127, torch.int8, torch.bfloat16)
del buf402
del buf403
del buf408
buf410 = buf409
del buf409
# Source Nodes: [w_dq_32], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf411 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg125_1, arg126_1, arg127_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg125_1
del arg126_1
del arg127_1
buf412 = buf411
del buf411
buf413 = reinterpret_tensor(buf342, (1, 14336), (14336, 1), 0); del buf342 # reuse
# Source Nodes: [c_32], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf410, buf412, buf413, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf412
# Source Nodes: [input_101], Original ATen: [quantized_decomposed.quantize_per_token]
buf414 = torch.ops.quantized_decomposed.quantize_per_token.default(buf400, buf405, buf406, -128, 127, torch.int8)
buf415 = buf414
del buf414
# Source Nodes: [input_102], Original ATen: [quantized_decomposed.dequantize_per_token]
buf416 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf415, buf405, buf406, -128, 127, torch.int8, torch.bfloat16)
del buf405
del buf406
del buf415
buf417 = buf416
del buf416
# Source Nodes: [w_dq_33], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf418 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg128_1, arg129_1, arg130_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg128_1
del arg129_1
del arg130_1
buf419 = buf418
del buf418
buf421 = buf335; del buf335 # reuse
# Source Nodes: [c_33, mul_64, silu_4], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf417, buf419, buf413, buf421, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf413
del buf419
# Source Nodes: [choose_qparams_per_token_asymmetric_34], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf422 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf421, torch.int8)
buf423 = buf422[0]
buf424 = buf422[1]
del buf422
# Source Nodes: [input_104], Original ATen: [quantized_decomposed.quantize_per_token]
buf425 = torch.ops.quantized_decomposed.quantize_per_token.default(buf421, buf423, buf424, -128, 127, torch.int8)
buf426 = buf425
del buf425
# Source Nodes: [input_105], Original ATen: [quantized_decomposed.dequantize_per_token]
buf427 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf426, buf423, buf424, -128, 127, torch.int8, torch.bfloat16)
del buf423
del buf424
del buf426
buf428 = buf427
del buf427
# Source Nodes: [w_dq_34], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf429 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg131_1, arg132_1, arg133_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg131_1
del arg132_1
del arg133_1
buf430 = buf429
del buf429
buf431 = reinterpret_tensor(buf417, (1, 4096), (4096, 1), 0); del buf417 # reuse
# Source Nodes: [c_34], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf428, buf430, buf431, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf430
buf433 = buf400; del buf400 # reuse
# Source Nodes: [add_35, h_5, mean_10, mul_65, out_3, out_4, pow_11, rsqrt_10, x_fp32_10, x_normed_10, y_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf398, buf312, buf345, buf431, arg134_1, buf433, 1, 4096, grid=grid(1), stream=stream0)
del arg134_1
# Source Nodes: [choose_qparams_per_token_asymmetric_35], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf434 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8)
buf435 = buf434[0]
buf436 = buf434[1]
del buf434
# Source Nodes: [choose_qparams_per_token_asymmetric_36], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf437 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8)
buf438 = buf437[0]
buf439 = buf437[1]
del buf437
# Source Nodes: [choose_qparams_per_token_asymmetric_37], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf440 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf433, torch.int8)
buf441 = buf440[0]
buf442 = buf440[1]
del buf440
buf443 = buf357; del buf357 # reuse
# Source Nodes: [max_6], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf443, 1, grid=grid(1), stream=stream0)
u5 = buf443.item()
buf444 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_107], Original ATen: [quantized_decomposed.quantize_per_token]
buf445 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf435, buf436, -128, 127, torch.int8)
buf446 = buf445
del buf445
# Source Nodes: [input_108], Original ATen: [quantized_decomposed.dequantize_per_token]
buf447 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf446, buf435, buf436, -128, 127, torch.int8, torch.bfloat16)
del buf435
del buf436
del buf446
buf448 = buf447
del buf447
# Source Nodes: [w_dq_35], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf449 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg135_1, arg136_1, arg137_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg135_1
del arg136_1
del arg137_1
buf450 = buf449
del buf449
buf451 = reinterpret_tensor(buf410, (1, 4096), (4096, 1), 0); del buf410 # reuse
# Source Nodes: [c_35], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf448, buf450, buf451, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf448
del buf450
# Source Nodes: [input_110], Original ATen: [quantized_decomposed.quantize_per_token]
buf452 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf438, buf439, -128, 127, torch.int8)
buf453 = buf452
del buf452
# Source Nodes: [input_111], Original ATen: [quantized_decomposed.dequantize_per_token]
buf454 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf453, buf438, buf439, -128, 127, torch.int8, torch.bfloat16)
del buf438
del buf439
del buf453
buf455 = buf454
del buf454
# Source Nodes: [w_dq_36], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf456 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg138_1, arg139_1, arg140_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg138_1
del arg139_1
del arg140_1
buf457 = buf456
del buf456
buf458 = buf380; del buf380 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf455, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf457, (4096, 1024), (1, 4096), 0), out=buf458)
del buf455
del buf457
# Source Nodes: [input_113], Original ATen: [quantized_decomposed.quantize_per_token]
buf460 = torch.ops.quantized_decomposed.quantize_per_token.default(buf433, buf441, buf442, -128, 127, torch.int8)
del buf433
buf461 = buf460
del buf460
# Source Nodes: [input_114], Original ATen: [quantized_decomposed.dequantize_per_token]
buf462 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf461, buf441, buf442, -128, 127, torch.int8, torch.bfloat16)
del buf441
del buf442
del buf461
buf463 = buf462
del buf462
# Source Nodes: [w_dq_37], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf464 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg141_1, arg142_1, arg143_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg141_1
del arg142_1
del arg143_1
buf465 = buf464
del buf464
buf466 = buf372; del buf372 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf463, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf465, (4096, 1024), (1, 4096), 0), out=buf466)
del buf465
buf468 = reinterpret_tensor(buf463, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf463 # reuse
# Source Nodes: [output_10, setitem_10, setitem_11], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf458, arg144_1, buf466, buf451, arg145_1, arg146_1, buf468, 4096, grid=grid(4096), stream=stream0)
del arg144_1
del buf451
buf469 = buf383; del buf383 # reuse
# Source Nodes: [output_10], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf469, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_10], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf470 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf468, arg145_1, arg146_1, buf469, False)
del arg145_1
del arg146_1
del buf468
buf471 = buf470[0]
del buf470
# Source Nodes: [choose_qparams_per_token_asymmetric_38], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf475 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf471, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf476 = buf475[0]
buf477 = buf475[1]
del buf475
# Source Nodes: [input_116], Original ATen: [quantized_decomposed.quantize_per_token]
buf478 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf471, (1, 1, 4096), (4096, 4096, 1), 0), buf476, buf477, -128, 127, torch.int8)
buf479 = buf478
del buf478
# Source Nodes: [input_117], Original ATen: [quantized_decomposed.dequantize_per_token]
buf480 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf479, buf476, buf477, -128, 127, torch.int8, torch.bfloat16)
del buf476
del buf477
del buf479
buf481 = buf480
del buf480
# Source Nodes: [w_dq_38], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf482 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg147_1, arg148_1, arg149_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg147_1
del arg148_1
del arg149_1
buf483 = buf482
del buf482
buf484 = reinterpret_tensor(buf471, (1, 4096), (4096, 1), 0); del buf471 # reuse
# Source Nodes: [c_38], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf481, buf483, buf484, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf483
buf485 = buf312; del buf312 # reuse
buf487 = buf481; del buf481 # reuse
# Source Nodes: [add_40, h_5, h_6, mean_11, mul_75, mul_76, out_3, out_4, pow_12, rsqrt_11, x_fp32_11, x_normed_11], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf485, buf484, buf398, buf345, buf431, arg150_1, buf487, 1, 4096, grid=grid(1), stream=stream0)
del arg150_1
del buf345
del buf398
del buf431
del buf484
# Source Nodes: [choose_qparams_per_token_asymmetric_39], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf488 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf487, torch.int8)
buf489 = buf488[0]
buf490 = buf488[1]
del buf488
# Source Nodes: [choose_qparams_per_token_asymmetric_40], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf491 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf487, torch.int8)
buf492 = buf491[0]
buf493 = buf491[1]
del buf491
# Source Nodes: [input_119], Original ATen: [quantized_decomposed.quantize_per_token]
buf494 = torch.ops.quantized_decomposed.quantize_per_token.default(buf487, buf489, buf490, -128, 127, torch.int8)
buf495 = buf494
del buf494
# Source Nodes: [input_120], Original ATen: [quantized_decomposed.dequantize_per_token]
buf496 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf495, buf489, buf490, -128, 127, torch.int8, torch.bfloat16)
del buf489
del buf490
del buf495
buf497 = buf496
del buf496
# Source Nodes: [w_dq_39], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf498 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg151_1, arg152_1, arg153_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg151_1
del arg152_1
del arg153_1
buf499 = buf498
del buf498
buf500 = reinterpret_tensor(buf428, (1, 14336), (14336, 1), 0); del buf428 # reuse
# Source Nodes: [c_39], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf497, buf499, buf500, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf499
# Source Nodes: [input_122], Original ATen: [quantized_decomposed.quantize_per_token]
buf501 = torch.ops.quantized_decomposed.quantize_per_token.default(buf487, buf492, buf493, -128, 127, torch.int8)
buf502 = buf501
del buf501
# Source Nodes: [input_123], Original ATen: [quantized_decomposed.dequantize_per_token]
buf503 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf502, buf492, buf493, -128, 127, torch.int8, torch.bfloat16)
del buf492
del buf493
del buf502
buf504 = buf503
del buf503
# Source Nodes: [w_dq_40], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf505 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg154_1, arg155_1, arg156_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg154_1
del arg155_1
del arg156_1
buf506 = buf505
del buf505
buf508 = buf421; del buf421 # reuse
# Source Nodes: [c_40, mul_77, silu_5], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf504, buf506, buf500, buf508, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf500
del buf506
# Source Nodes: [choose_qparams_per_token_asymmetric_41], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf509 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf508, torch.int8)
buf510 = buf509[0]
buf511 = buf509[1]
del buf509
# Source Nodes: [input_125], Original ATen: [quantized_decomposed.quantize_per_token]
buf512 = torch.ops.quantized_decomposed.quantize_per_token.default(buf508, buf510, buf511, -128, 127, torch.int8)
buf513 = buf512
del buf512
# Source Nodes: [input_126], Original ATen: [quantized_decomposed.dequantize_per_token]
buf514 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf513, buf510, buf511, -128, 127, torch.int8, torch.bfloat16)
del buf510
del buf511
del buf513
buf515 = buf514
del buf514
# Source Nodes: [w_dq_41], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf516 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg157_1, arg158_1, arg159_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg157_1
del arg158_1
del arg159_1
buf517 = buf516
del buf516
buf518 = reinterpret_tensor(buf504, (1, 4096), (4096, 1), 0); del buf504 # reuse
# Source Nodes: [c_41], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf515, buf517, buf518, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf517
buf520 = buf487; del buf487 # reuse
# Source Nodes: [add_42, mean_12, mul_78, out_5, pow_13, rsqrt_12, x_fp32_12, x_normed_12, y_6], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf485, buf518, arg160_1, buf520, 1, 4096, grid=grid(1), stream=stream0)
del arg160_1
# Source Nodes: [choose_qparams_per_token_asymmetric_42], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf521 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8)
buf522 = buf521[0]
buf523 = buf521[1]
del buf521
# Source Nodes: [choose_qparams_per_token_asymmetric_43], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf524 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8)
buf525 = buf524[0]
buf526 = buf524[1]
del buf524
# Source Nodes: [choose_qparams_per_token_asymmetric_44], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf527 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf520, torch.int8)
buf528 = buf527[0]
buf529 = buf527[1]
del buf527
buf530 = buf443; del buf443 # reuse
# Source Nodes: [max_7], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf530, 1, grid=grid(1), stream=stream0)
u6 = buf530.item()
buf531 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_128], Original ATen: [quantized_decomposed.quantize_per_token]
buf532 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf522, buf523, -128, 127, torch.int8)
buf533 = buf532
del buf532
# Source Nodes: [input_129], Original ATen: [quantized_decomposed.dequantize_per_token]
buf534 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf533, buf522, buf523, -128, 127, torch.int8, torch.bfloat16)
del buf522
del buf523
del buf533
buf535 = buf534
del buf534
# Source Nodes: [w_dq_42], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf536 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg161_1, arg162_1, arg163_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg161_1
del arg162_1
del arg163_1
buf537 = buf536
del buf536
buf538 = reinterpret_tensor(buf497, (1, 4096), (4096, 1), 0); del buf497 # reuse
# Source Nodes: [c_42], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf535, buf537, buf538, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf535
del buf537
# Source Nodes: [input_131], Original ATen: [quantized_decomposed.quantize_per_token]
buf539 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf525, buf526, -128, 127, torch.int8)
buf540 = buf539
del buf539
# Source Nodes: [input_132], Original ATen: [quantized_decomposed.dequantize_per_token]
buf541 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf540, buf525, buf526, -128, 127, torch.int8, torch.bfloat16)
del buf525
del buf526
del buf540
buf542 = buf541
del buf541
# Source Nodes: [w_dq_43], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf543 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg164_1, arg165_1, arg166_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg164_1
del arg165_1
del arg166_1
buf544 = buf543
del buf543
buf545 = buf466; del buf466 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf542, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf544, (4096, 1024), (1, 4096), 0), out=buf545)
del buf542
del buf544
# Source Nodes: [input_134], Original ATen: [quantized_decomposed.quantize_per_token]
buf547 = torch.ops.quantized_decomposed.quantize_per_token.default(buf520, buf528, buf529, -128, 127, torch.int8)
del buf520
buf548 = buf547
del buf547
# Source Nodes: [input_135], Original ATen: [quantized_decomposed.dequantize_per_token]
buf549 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf548, buf528, buf529, -128, 127, torch.int8, torch.bfloat16)
del buf528
del buf529
del buf548
buf550 = buf549
del buf549
# Source Nodes: [w_dq_44], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf551 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg167_1, arg168_1, arg169_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg167_1
del arg168_1
del arg169_1
buf552 = buf551
del buf551
buf553 = buf458; del buf458 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf550, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf552, (4096, 1024), (1, 4096), 0), out=buf553)
del buf552
buf555 = reinterpret_tensor(buf550, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf550 # reuse
# Source Nodes: [output_12, setitem_12, setitem_13], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf545, arg170_1, buf553, buf538, arg171_1, arg172_1, buf555, 4096, grid=grid(4096), stream=stream0)
del arg170_1
del buf538
buf556 = buf469; del buf469 # reuse
# Source Nodes: [output_12], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf556, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_12], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf557 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf555, arg171_1, arg172_1, buf556, False)
del arg171_1
del arg172_1
del buf555
buf558 = buf557[0]
del buf557
# Source Nodes: [choose_qparams_per_token_asymmetric_45], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf562 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf558, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf563 = buf562[0]
buf564 = buf562[1]
del buf562
# Source Nodes: [input_137], Original ATen: [quantized_decomposed.quantize_per_token]
buf565 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf558, (1, 1, 4096), (4096, 4096, 1), 0), buf563, buf564, -128, 127, torch.int8)
buf566 = buf565
del buf565
# Source Nodes: [input_138], Original ATen: [quantized_decomposed.dequantize_per_token]
buf567 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf566, buf563, buf564, -128, 127, torch.int8, torch.bfloat16)
del buf563
del buf564
del buf566
buf568 = buf567
del buf567
# Source Nodes: [w_dq_45], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf569 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg173_1, arg174_1, arg175_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg173_1
del arg174_1
del arg175_1
buf570 = buf569
del buf569
buf571 = reinterpret_tensor(buf558, (1, 4096), (4096, 1), 0); del buf558 # reuse
# Source Nodes: [c_45], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf568, buf570, buf571, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf570
buf573 = buf568; del buf568 # reuse
# Source Nodes: [add_47, h_7, mean_13, mul_88, mul_89, out_5, pow_14, rsqrt_13, x_fp32_13, x_normed_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf571, buf485, buf518, arg176_1, buf573, 1, 4096, grid=grid(1), stream=stream0)
del arg176_1
# Source Nodes: [choose_qparams_per_token_asymmetric_46], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf574 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf573, torch.int8)
buf575 = buf574[0]
buf576 = buf574[1]
del buf574
# Source Nodes: [choose_qparams_per_token_asymmetric_47], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf577 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf573, torch.int8)
buf578 = buf577[0]
buf579 = buf577[1]
del buf577
# Source Nodes: [input_140], Original ATen: [quantized_decomposed.quantize_per_token]
buf580 = torch.ops.quantized_decomposed.quantize_per_token.default(buf573, buf575, buf576, -128, 127, torch.int8)
buf581 = buf580
del buf580
# Source Nodes: [input_141], Original ATen: [quantized_decomposed.dequantize_per_token]
buf582 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf581, buf575, buf576, -128, 127, torch.int8, torch.bfloat16)
del buf575
del buf576
del buf581
buf583 = buf582
del buf582
# Source Nodes: [w_dq_46], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf584 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg177_1, arg178_1, arg179_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg177_1
del arg178_1
del arg179_1
buf585 = buf584
del buf584
buf586 = reinterpret_tensor(buf515, (1, 14336), (14336, 1), 0); del buf515 # reuse
# Source Nodes: [c_46], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf583, buf585, buf586, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf585
# Source Nodes: [input_143], Original ATen: [quantized_decomposed.quantize_per_token]
buf587 = torch.ops.quantized_decomposed.quantize_per_token.default(buf573, buf578, buf579, -128, 127, torch.int8)
buf588 = buf587
del buf587
# Source Nodes: [input_144], Original ATen: [quantized_decomposed.dequantize_per_token]
buf589 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf588, buf578, buf579, -128, 127, torch.int8, torch.bfloat16)
del buf578
del buf579
del buf588
buf590 = buf589
del buf589
# Source Nodes: [w_dq_47], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf591 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg180_1, arg181_1, arg182_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg180_1
del arg181_1
del arg182_1
buf592 = buf591
del buf591
buf594 = buf508; del buf508 # reuse
# Source Nodes: [c_47, mul_90, silu_6], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf590, buf592, buf586, buf594, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf586
del buf592
# Source Nodes: [choose_qparams_per_token_asymmetric_48], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf595 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf594, torch.int8)
buf596 = buf595[0]
buf597 = buf595[1]
del buf595
# Source Nodes: [input_146], Original ATen: [quantized_decomposed.quantize_per_token]
buf598 = torch.ops.quantized_decomposed.quantize_per_token.default(buf594, buf596, buf597, -128, 127, torch.int8)
buf599 = buf598
del buf598
# Source Nodes: [input_147], Original ATen: [quantized_decomposed.dequantize_per_token]
buf600 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf599, buf596, buf597, -128, 127, torch.int8, torch.bfloat16)
del buf596
del buf597
del buf599
buf601 = buf600
del buf600
# Source Nodes: [w_dq_48], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf602 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg183_1, arg184_1, arg185_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg183_1
del arg184_1
del arg185_1
buf603 = buf602
del buf602
buf604 = reinterpret_tensor(buf590, (1, 4096), (4096, 1), 0); del buf590 # reuse
# Source Nodes: [c_48], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf601, buf603, buf604, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf603
buf606 = buf573; del buf573 # reuse
# Source Nodes: [add_49, h_7, mean_14, mul_91, out_5, out_6, pow_15, rsqrt_14, x_fp32_14, x_normed_14, y_7], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf571, buf485, buf518, buf604, arg186_1, buf606, 1, 4096, grid=grid(1), stream=stream0)
del arg186_1
# Source Nodes: [choose_qparams_per_token_asymmetric_49], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf607 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8)
buf608 = buf607[0]
buf609 = buf607[1]
del buf607
# Source Nodes: [choose_qparams_per_token_asymmetric_50], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf610 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8)
buf611 = buf610[0]
buf612 = buf610[1]
del buf610
# Source Nodes: [choose_qparams_per_token_asymmetric_51], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf613 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf606, torch.int8)
buf614 = buf613[0]
buf615 = buf613[1]
del buf613
buf616 = buf530; del buf530 # reuse
# Source Nodes: [max_8], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf616, 1, grid=grid(1), stream=stream0)
u7 = buf616.item()
buf617 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_149], Original ATen: [quantized_decomposed.quantize_per_token]
buf618 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf608, buf609, -128, 127, torch.int8)
buf619 = buf618
del buf618
# Source Nodes: [input_150], Original ATen: [quantized_decomposed.dequantize_per_token]
buf620 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf619, buf608, buf609, -128, 127, torch.int8, torch.bfloat16)
del buf608
del buf609
del buf619
buf621 = buf620
del buf620
# Source Nodes: [w_dq_49], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf622 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg187_1, arg188_1, arg189_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg187_1
del arg188_1
del arg189_1
buf623 = buf622
del buf622
buf624 = reinterpret_tensor(buf583, (1, 4096), (4096, 1), 0); del buf583 # reuse
# Source Nodes: [c_49], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf621, buf623, buf624, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf621
del buf623
# Source Nodes: [input_152], Original ATen: [quantized_decomposed.quantize_per_token]
buf625 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf611, buf612, -128, 127, torch.int8)
buf626 = buf625
del buf625
# Source Nodes: [input_153], Original ATen: [quantized_decomposed.dequantize_per_token]
buf627 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf626, buf611, buf612, -128, 127, torch.int8, torch.bfloat16)
del buf611
del buf612
del buf626
buf628 = buf627
del buf627
# Source Nodes: [w_dq_50], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf629 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg190_1, arg191_1, arg192_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg190_1
del arg191_1
del arg192_1
buf630 = buf629
del buf629
buf631 = buf553; del buf553 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf628, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf630, (4096, 1024), (1, 4096), 0), out=buf631)
del buf628
del buf630
# Source Nodes: [input_155], Original ATen: [quantized_decomposed.quantize_per_token]
buf633 = torch.ops.quantized_decomposed.quantize_per_token.default(buf606, buf614, buf615, -128, 127, torch.int8)
del buf606
buf634 = buf633
del buf633
# Source Nodes: [input_156], Original ATen: [quantized_decomposed.dequantize_per_token]
buf635 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf634, buf614, buf615, -128, 127, torch.int8, torch.bfloat16)
del buf614
del buf615
del buf634
buf636 = buf635
del buf635
# Source Nodes: [w_dq_51], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf637 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg193_1, arg194_1, arg195_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg193_1
del arg194_1
del arg195_1
buf638 = buf637
del buf637
buf639 = buf545; del buf545 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf636, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf638, (4096, 1024), (1, 4096), 0), out=buf639)
del buf638
buf641 = reinterpret_tensor(buf636, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf636 # reuse
# Source Nodes: [output_14, setitem_14, setitem_15], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf631, arg196_1, buf639, buf624, arg197_1, arg198_1, buf641, 4096, grid=grid(4096), stream=stream0)
del arg196_1
del buf624
buf642 = buf556; del buf556 # reuse
# Source Nodes: [output_14], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf642, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_14], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf643 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf641, arg197_1, arg198_1, buf642, False)
del arg197_1
del arg198_1
del buf641
buf644 = buf643[0]
del buf643
# Source Nodes: [choose_qparams_per_token_asymmetric_52], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf648 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf644, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf649 = buf648[0]
buf650 = buf648[1]
del buf648
# Source Nodes: [input_158], Original ATen: [quantized_decomposed.quantize_per_token]
buf651 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf644, (1, 1, 4096), (4096, 4096, 1), 0), buf649, buf650, -128, 127, torch.int8)
buf652 = buf651
del buf651
# Source Nodes: [input_159], Original ATen: [quantized_decomposed.dequantize_per_token]
buf653 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf652, buf649, buf650, -128, 127, torch.int8, torch.bfloat16)
del buf649
del buf650
del buf652
buf654 = buf653
del buf653
# Source Nodes: [w_dq_52], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf655 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg199_1, arg200_1, arg201_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg199_1
del arg200_1
del arg201_1
buf656 = buf655
del buf655
buf657 = reinterpret_tensor(buf644, (1, 4096), (4096, 1), 0); del buf644 # reuse
# Source Nodes: [c_52], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf654, buf656, buf657, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf656
buf658 = buf485; del buf485 # reuse
buf660 = buf654; del buf654 # reuse
# Source Nodes: [add_54, h_7, h_8, mean_15, mul_101, mul_102, out_5, out_6, pow_16, rsqrt_15, x_fp32_15, x_normed_15], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf658, buf657, buf571, buf518, buf604, arg202_1, buf660, 1, 4096, grid=grid(1), stream=stream0)
del arg202_1
del buf518
del buf571
del buf604
del buf657
# Source Nodes: [choose_qparams_per_token_asymmetric_53], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf661 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf660, torch.int8)
buf662 = buf661[0]
buf663 = buf661[1]
del buf661
# Source Nodes: [choose_qparams_per_token_asymmetric_54], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf664 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf660, torch.int8)
buf665 = buf664[0]
buf666 = buf664[1]
del buf664
# Source Nodes: [input_161], Original ATen: [quantized_decomposed.quantize_per_token]
buf667 = torch.ops.quantized_decomposed.quantize_per_token.default(buf660, buf662, buf663, -128, 127, torch.int8)
buf668 = buf667
del buf667
# Source Nodes: [input_162], Original ATen: [quantized_decomposed.dequantize_per_token]
buf669 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf668, buf662, buf663, -128, 127, torch.int8, torch.bfloat16)
del buf662
del buf663
del buf668
buf670 = buf669
del buf669
# Source Nodes: [w_dq_53], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf671 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg203_1, arg204_1, arg205_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg203_1
del arg204_1
del arg205_1
buf672 = buf671
del buf671
buf673 = reinterpret_tensor(buf601, (1, 14336), (14336, 1), 0); del buf601 # reuse
# Source Nodes: [c_53], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf670, buf672, buf673, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf672
# Source Nodes: [input_164], Original ATen: [quantized_decomposed.quantize_per_token]
buf674 = torch.ops.quantized_decomposed.quantize_per_token.default(buf660, buf665, buf666, -128, 127, torch.int8)
buf675 = buf674
del buf674
# Source Nodes: [input_165], Original ATen: [quantized_decomposed.dequantize_per_token]
buf676 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf675, buf665, buf666, -128, 127, torch.int8, torch.bfloat16)
del buf665
del buf666
del buf675
buf677 = buf676
del buf676
# Source Nodes: [w_dq_54], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf678 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg206_1, arg207_1, arg208_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg206_1
del arg207_1
del arg208_1
buf679 = buf678
del buf678
buf681 = buf594; del buf594 # reuse
# Source Nodes: [c_54, mul_103, silu_7], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf677, buf679, buf673, buf681, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf673
del buf679
# Source Nodes: [choose_qparams_per_token_asymmetric_55], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf682 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf681, torch.int8)
buf683 = buf682[0]
buf684 = buf682[1]
del buf682
# Source Nodes: [input_167], Original ATen: [quantized_decomposed.quantize_per_token]
buf685 = torch.ops.quantized_decomposed.quantize_per_token.default(buf681, buf683, buf684, -128, 127, torch.int8)
buf686 = buf685
del buf685
# Source Nodes: [input_168], Original ATen: [quantized_decomposed.dequantize_per_token]
buf687 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf686, buf683, buf684, -128, 127, torch.int8, torch.bfloat16)
del buf683
del buf684
del buf686
buf688 = buf687
del buf687
# Source Nodes: [w_dq_55], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf689 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg209_1, arg210_1, arg211_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg209_1
del arg210_1
del arg211_1
buf690 = buf689
del buf689
buf691 = reinterpret_tensor(buf677, (1, 4096), (4096, 1), 0); del buf677 # reuse
# Source Nodes: [c_55], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf688, buf690, buf691, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf690
buf693 = buf660; del buf660 # reuse
# Source Nodes: [add_56, mean_16, mul_104, out_7, pow_17, rsqrt_16, x_fp32_16, x_normed_16, y_8], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf658, buf691, arg212_1, buf693, 1, 4096, grid=grid(1), stream=stream0)
del arg212_1
# Source Nodes: [choose_qparams_per_token_asymmetric_56], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf694 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8)
buf695 = buf694[0]
buf696 = buf694[1]
del buf694
# Source Nodes: [choose_qparams_per_token_asymmetric_57], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf697 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8)
buf698 = buf697[0]
buf699 = buf697[1]
del buf697
# Source Nodes: [choose_qparams_per_token_asymmetric_58], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf700 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf693, torch.int8)
buf701 = buf700[0]
buf702 = buf700[1]
del buf700
buf703 = buf616; del buf616 # reuse
# Source Nodes: [max_9], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf703, 1, grid=grid(1), stream=stream0)
u8 = buf703.item()
buf704 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_170], Original ATen: [quantized_decomposed.quantize_per_token]
buf705 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf695, buf696, -128, 127, torch.int8)
buf706 = buf705
del buf705
# Source Nodes: [input_171], Original ATen: [quantized_decomposed.dequantize_per_token]
buf707 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf706, buf695, buf696, -128, 127, torch.int8, torch.bfloat16)
del buf695
del buf696
del buf706
buf708 = buf707
del buf707
# Source Nodes: [w_dq_56], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf709 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg213_1, arg214_1, arg215_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg213_1
del arg214_1
del arg215_1
buf710 = buf709
del buf709
buf711 = reinterpret_tensor(buf670, (1, 4096), (4096, 1), 0); del buf670 # reuse
# Source Nodes: [c_56], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf708, buf710, buf711, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf708
del buf710
# Source Nodes: [input_173], Original ATen: [quantized_decomposed.quantize_per_token]
buf712 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf698, buf699, -128, 127, torch.int8)
buf713 = buf712
del buf712
# Source Nodes: [input_174], Original ATen: [quantized_decomposed.dequantize_per_token]
buf714 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf713, buf698, buf699, -128, 127, torch.int8, torch.bfloat16)
del buf698
del buf699
del buf713
buf715 = buf714
del buf714
# Source Nodes: [w_dq_57], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf716 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg216_1, arg217_1, arg218_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg216_1
del arg217_1
del arg218_1
buf717 = buf716
del buf716
buf718 = buf639; del buf639 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf715, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf717, (4096, 1024), (1, 4096), 0), out=buf718)
del buf715
del buf717
# Source Nodes: [input_176], Original ATen: [quantized_decomposed.quantize_per_token]
buf720 = torch.ops.quantized_decomposed.quantize_per_token.default(buf693, buf701, buf702, -128, 127, torch.int8)
del buf693
buf721 = buf720
del buf720
# Source Nodes: [input_177], Original ATen: [quantized_decomposed.dequantize_per_token]
buf722 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf721, buf701, buf702, -128, 127, torch.int8, torch.bfloat16)
del buf701
del buf702
del buf721
buf723 = buf722
del buf722
# Source Nodes: [w_dq_58], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf724 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg219_1, arg220_1, arg221_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg219_1
del arg220_1
del arg221_1
buf725 = buf724
del buf724
buf726 = buf631; del buf631 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf723, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf725, (4096, 1024), (1, 4096), 0), out=buf726)
del buf725
buf728 = reinterpret_tensor(buf723, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf723 # reuse
# Source Nodes: [output_16, setitem_16, setitem_17], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf718, arg222_1, buf726, buf711, arg223_1, arg224_1, buf728, 4096, grid=grid(4096), stream=stream0)
del arg222_1
del buf711
buf729 = buf642; del buf642 # reuse
# Source Nodes: [output_16], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf729, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_16], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf730 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf728, arg223_1, arg224_1, buf729, False)
del arg223_1
del arg224_1
del buf728
buf731 = buf730[0]
del buf730
# Source Nodes: [choose_qparams_per_token_asymmetric_59], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf735 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf731, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf736 = buf735[0]
buf737 = buf735[1]
del buf735
# Source Nodes: [input_179], Original ATen: [quantized_decomposed.quantize_per_token]
buf738 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf731, (1, 1, 4096), (4096, 4096, 1), 0), buf736, buf737, -128, 127, torch.int8)
buf739 = buf738
del buf738
# Source Nodes: [input_180], Original ATen: [quantized_decomposed.dequantize_per_token]
buf740 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf739, buf736, buf737, -128, 127, torch.int8, torch.bfloat16)
del buf736
del buf737
del buf739
buf741 = buf740
del buf740
# Source Nodes: [w_dq_59], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf742 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg225_1, arg226_1, arg227_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg225_1
del arg226_1
del arg227_1
buf743 = buf742
del buf742
buf744 = reinterpret_tensor(buf731, (1, 4096), (4096, 1), 0); del buf731 # reuse
# Source Nodes: [c_59], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf741, buf743, buf744, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf743
buf746 = buf741; del buf741 # reuse
# Source Nodes: [add_61, h_9, mean_17, mul_114, mul_115, out_7, pow_18, rsqrt_17, x_fp32_17, x_normed_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf744, buf658, buf691, arg228_1, buf746, 1, 4096, grid=grid(1), stream=stream0)
del arg228_1
# Source Nodes: [choose_qparams_per_token_asymmetric_60], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf747 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf746, torch.int8)
buf748 = buf747[0]
buf749 = buf747[1]
del buf747
# Source Nodes: [choose_qparams_per_token_asymmetric_61], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf750 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf746, torch.int8)
buf751 = buf750[0]
buf752 = buf750[1]
del buf750
# Source Nodes: [input_182], Original ATen: [quantized_decomposed.quantize_per_token]
buf753 = torch.ops.quantized_decomposed.quantize_per_token.default(buf746, buf748, buf749, -128, 127, torch.int8)
buf754 = buf753
del buf753
# Source Nodes: [input_183], Original ATen: [quantized_decomposed.dequantize_per_token]
buf755 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf754, buf748, buf749, -128, 127, torch.int8, torch.bfloat16)
del buf748
del buf749
del buf754
buf756 = buf755
del buf755
# Source Nodes: [w_dq_60], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf757 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg229_1, arg230_1, arg231_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg229_1
del arg230_1
del arg231_1
buf758 = buf757
del buf757
buf759 = reinterpret_tensor(buf688, (1, 14336), (14336, 1), 0); del buf688 # reuse
# Source Nodes: [c_60], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf756, buf758, buf759, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf758
# Source Nodes: [input_185], Original ATen: [quantized_decomposed.quantize_per_token]
buf760 = torch.ops.quantized_decomposed.quantize_per_token.default(buf746, buf751, buf752, -128, 127, torch.int8)
buf761 = buf760
del buf760
# Source Nodes: [input_186], Original ATen: [quantized_decomposed.dequantize_per_token]
buf762 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf761, buf751, buf752, -128, 127, torch.int8, torch.bfloat16)
del buf751
del buf752
del buf761
buf763 = buf762
del buf762
# Source Nodes: [w_dq_61], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf764 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg232_1, arg233_1, arg234_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg232_1
del arg233_1
del arg234_1
buf765 = buf764
del buf764
buf767 = buf681; del buf681 # reuse
# Source Nodes: [c_61, mul_116, silu_8], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf763, buf765, buf759, buf767, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf759
del buf765
# Source Nodes: [choose_qparams_per_token_asymmetric_62], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf768 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf767, torch.int8)
buf769 = buf768[0]
buf770 = buf768[1]
del buf768
# Source Nodes: [input_188], Original ATen: [quantized_decomposed.quantize_per_token]
buf771 = torch.ops.quantized_decomposed.quantize_per_token.default(buf767, buf769, buf770, -128, 127, torch.int8)
buf772 = buf771
del buf771
# Source Nodes: [input_189], Original ATen: [quantized_decomposed.dequantize_per_token]
buf773 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf772, buf769, buf770, -128, 127, torch.int8, torch.bfloat16)
del buf769
del buf770
del buf772
buf774 = buf773
del buf773
# Source Nodes: [w_dq_62], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf775 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg235_1, arg236_1, arg237_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg235_1
del arg236_1
del arg237_1
buf776 = buf775
del buf775
buf777 = reinterpret_tensor(buf763, (1, 4096), (4096, 1), 0); del buf763 # reuse
# Source Nodes: [c_62], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf774, buf776, buf777, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf776
buf779 = buf746; del buf746 # reuse
# Source Nodes: [add_63, h_9, mean_18, mul_117, out_7, out_8, pow_19, rsqrt_18, x_fp32_18, x_normed_18, y_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf744, buf658, buf691, buf777, arg238_1, buf779, 1, 4096, grid=grid(1), stream=stream0)
del arg238_1
# Source Nodes: [choose_qparams_per_token_asymmetric_63], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf780 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8)
buf781 = buf780[0]
buf782 = buf780[1]
del buf780
# Source Nodes: [choose_qparams_per_token_asymmetric_64], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf783 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8)
buf784 = buf783[0]
buf785 = buf783[1]
del buf783
# Source Nodes: [choose_qparams_per_token_asymmetric_65], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf786 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf779, torch.int8)
buf787 = buf786[0]
buf788 = buf786[1]
del buf786
buf789 = buf703; del buf703 # reuse
# Source Nodes: [max_10], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf789, 1, grid=grid(1), stream=stream0)
u9 = buf789.item()
buf790 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_191], Original ATen: [quantized_decomposed.quantize_per_token]
buf791 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf781, buf782, -128, 127, torch.int8)
buf792 = buf791
del buf791
# Source Nodes: [input_192], Original ATen: [quantized_decomposed.dequantize_per_token]
buf793 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf792, buf781, buf782, -128, 127, torch.int8, torch.bfloat16)
del buf781
del buf782
del buf792
buf794 = buf793
del buf793
# Source Nodes: [w_dq_63], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf795 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg239_1, arg240_1, arg241_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg239_1
del arg240_1
del arg241_1
buf796 = buf795
del buf795
buf797 = reinterpret_tensor(buf756, (1, 4096), (4096, 1), 0); del buf756 # reuse
# Source Nodes: [c_63], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf794, buf796, buf797, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf794
del buf796
# Source Nodes: [input_194], Original ATen: [quantized_decomposed.quantize_per_token]
buf798 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf784, buf785, -128, 127, torch.int8)
buf799 = buf798
del buf798
# Source Nodes: [input_195], Original ATen: [quantized_decomposed.dequantize_per_token]
buf800 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf799, buf784, buf785, -128, 127, torch.int8, torch.bfloat16)
del buf784
del buf785
del buf799
buf801 = buf800
del buf800
# Source Nodes: [w_dq_64], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf802 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg242_1, arg243_1, arg244_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg242_1
del arg243_1
del arg244_1
buf803 = buf802
del buf802
buf804 = buf726; del buf726 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf801, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf803, (4096, 1024), (1, 4096), 0), out=buf804)
del buf801
del buf803
# Source Nodes: [input_197], Original ATen: [quantized_decomposed.quantize_per_token]
buf806 = torch.ops.quantized_decomposed.quantize_per_token.default(buf779, buf787, buf788, -128, 127, torch.int8)
del buf779
buf807 = buf806
del buf806
# Source Nodes: [input_198], Original ATen: [quantized_decomposed.dequantize_per_token]
buf808 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf807, buf787, buf788, -128, 127, torch.int8, torch.bfloat16)
del buf787
del buf788
del buf807
buf809 = buf808
del buf808
# Source Nodes: [w_dq_65], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf810 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg245_1, arg246_1, arg247_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg245_1
del arg246_1
del arg247_1
buf811 = buf810
del buf810
buf812 = buf718; del buf718 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf809, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf811, (4096, 1024), (1, 4096), 0), out=buf812)
del buf811
buf814 = reinterpret_tensor(buf809, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf809 # reuse
# Source Nodes: [output_18, setitem_18, setitem_19], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf804, arg248_1, buf812, buf797, arg249_1, arg250_1, buf814, 4096, grid=grid(4096), stream=stream0)
del arg248_1
del buf797
buf815 = buf729; del buf729 # reuse
# Source Nodes: [output_18], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf815, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_18], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf816 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf814, arg249_1, arg250_1, buf815, False)
del arg249_1
del arg250_1
del buf814
buf817 = buf816[0]
del buf816
# Source Nodes: [choose_qparams_per_token_asymmetric_66], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf821 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf817, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf822 = buf821[0]
buf823 = buf821[1]
del buf821
# Source Nodes: [input_200], Original ATen: [quantized_decomposed.quantize_per_token]
buf824 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf817, (1, 1, 4096), (4096, 4096, 1), 0), buf822, buf823, -128, 127, torch.int8)
buf825 = buf824
del buf824
# Source Nodes: [input_201], Original ATen: [quantized_decomposed.dequantize_per_token]
buf826 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf825, buf822, buf823, -128, 127, torch.int8, torch.bfloat16)
del buf822
del buf823
del buf825
buf827 = buf826
del buf826
# Source Nodes: [w_dq_66], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf828 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg251_1, arg252_1, arg253_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg251_1
del arg252_1
del arg253_1
buf829 = buf828
del buf828
buf830 = reinterpret_tensor(buf817, (1, 4096), (4096, 1), 0); del buf817 # reuse
# Source Nodes: [c_66], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf827, buf829, buf830, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf829
buf831 = buf658; del buf658 # reuse
buf833 = buf827; del buf827 # reuse
# Source Nodes: [add_68, h_10, h_9, mean_19, mul_127, mul_128, out_7, out_8, pow_20, rsqrt_19, x_fp32_19, x_normed_19], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf831, buf830, buf744, buf691, buf777, arg254_1, buf833, 1, 4096, grid=grid(1), stream=stream0)
del arg254_1
del buf691
del buf744
del buf777
del buf830
# Source Nodes: [choose_qparams_per_token_asymmetric_67], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf834 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf833, torch.int8)
buf835 = buf834[0]
buf836 = buf834[1]
del buf834
# Source Nodes: [choose_qparams_per_token_asymmetric_68], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf837 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf833, torch.int8)
buf838 = buf837[0]
buf839 = buf837[1]
del buf837
# Source Nodes: [input_203], Original ATen: [quantized_decomposed.quantize_per_token]
buf840 = torch.ops.quantized_decomposed.quantize_per_token.default(buf833, buf835, buf836, -128, 127, torch.int8)
buf841 = buf840
del buf840
# Source Nodes: [input_204], Original ATen: [quantized_decomposed.dequantize_per_token]
buf842 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf841, buf835, buf836, -128, 127, torch.int8, torch.bfloat16)
del buf835
del buf836
del buf841
buf843 = buf842
del buf842
# Source Nodes: [w_dq_67], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf844 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg255_1, arg256_1, arg257_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg255_1
del arg256_1
del arg257_1
buf845 = buf844
del buf844
buf846 = reinterpret_tensor(buf774, (1, 14336), (14336, 1), 0); del buf774 # reuse
# Source Nodes: [c_67], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf843, buf845, buf846, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf845
# Source Nodes: [input_206], Original ATen: [quantized_decomposed.quantize_per_token]
buf847 = torch.ops.quantized_decomposed.quantize_per_token.default(buf833, buf838, buf839, -128, 127, torch.int8)
buf848 = buf847
del buf847
# Source Nodes: [input_207], Original ATen: [quantized_decomposed.dequantize_per_token]
buf849 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf848, buf838, buf839, -128, 127, torch.int8, torch.bfloat16)
del buf838
del buf839
del buf848
buf850 = buf849
del buf849
# Source Nodes: [w_dq_68], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf851 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg258_1, arg259_1, arg260_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg258_1
del arg259_1
del arg260_1
buf852 = buf851
del buf851
buf854 = buf767; del buf767 # reuse
# Source Nodes: [c_68, mul_129, silu_9], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf850, buf852, buf846, buf854, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf846
del buf852
# Source Nodes: [choose_qparams_per_token_asymmetric_69], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf855 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf854, torch.int8)
buf856 = buf855[0]
buf857 = buf855[1]
del buf855
# Source Nodes: [input_209], Original ATen: [quantized_decomposed.quantize_per_token]
buf858 = torch.ops.quantized_decomposed.quantize_per_token.default(buf854, buf856, buf857, -128, 127, torch.int8)
buf859 = buf858
del buf858
# Source Nodes: [input_210], Original ATen: [quantized_decomposed.dequantize_per_token]
buf860 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf859, buf856, buf857, -128, 127, torch.int8, torch.bfloat16)
del buf856
del buf857
del buf859
buf861 = buf860
del buf860
# Source Nodes: [w_dq_69], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf862 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg261_1, arg262_1, arg263_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg261_1
del arg262_1
del arg263_1
buf863 = buf862
del buf862
buf864 = reinterpret_tensor(buf850, (1, 4096), (4096, 1), 0); del buf850 # reuse
# Source Nodes: [c_69], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf861, buf863, buf864, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf863
buf866 = buf833; del buf833 # reuse
# Source Nodes: [add_70, mean_20, mul_130, out_9, pow_21, rsqrt_20, x_fp32_20, x_normed_20, y_10], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf831, buf864, arg264_1, buf866, 1, 4096, grid=grid(1), stream=stream0)
del arg264_1
# Source Nodes: [choose_qparams_per_token_asymmetric_70], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf867 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8)
buf868 = buf867[0]
buf869 = buf867[1]
del buf867
# Source Nodes: [choose_qparams_per_token_asymmetric_71], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf870 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8)
buf871 = buf870[0]
buf872 = buf870[1]
del buf870
# Source Nodes: [choose_qparams_per_token_asymmetric_72], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf873 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf866, torch.int8)
buf874 = buf873[0]
buf875 = buf873[1]
del buf873
buf876 = buf789; del buf789 # reuse
# Source Nodes: [max_11], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf876, 1, grid=grid(1), stream=stream0)
u10 = buf876.item()
buf877 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_212], Original ATen: [quantized_decomposed.quantize_per_token]
buf878 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf868, buf869, -128, 127, torch.int8)
buf879 = buf878
del buf878
# Source Nodes: [input_213], Original ATen: [quantized_decomposed.dequantize_per_token]
buf880 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf879, buf868, buf869, -128, 127, torch.int8, torch.bfloat16)
del buf868
del buf869
del buf879
buf881 = buf880
del buf880
# Source Nodes: [w_dq_70], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf882 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg265_1, arg266_1, arg267_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg265_1
del arg266_1
del arg267_1
buf883 = buf882
del buf882
buf884 = reinterpret_tensor(buf843, (1, 4096), (4096, 1), 0); del buf843 # reuse
# Source Nodes: [c_70], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf881, buf883, buf884, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf881
del buf883
# Source Nodes: [input_215], Original ATen: [quantized_decomposed.quantize_per_token]
buf885 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf871, buf872, -128, 127, torch.int8)
buf886 = buf885
del buf885
# Source Nodes: [input_216], Original ATen: [quantized_decomposed.dequantize_per_token]
buf887 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf886, buf871, buf872, -128, 127, torch.int8, torch.bfloat16)
del buf871
del buf872
del buf886
buf888 = buf887
del buf887
# Source Nodes: [w_dq_71], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf889 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg268_1, arg269_1, arg270_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg268_1
del arg269_1
del arg270_1
buf890 = buf889
del buf889
buf891 = buf812; del buf812 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf888, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf890, (4096, 1024), (1, 4096), 0), out=buf891)
del buf888
del buf890
# Source Nodes: [input_218], Original ATen: [quantized_decomposed.quantize_per_token]
buf893 = torch.ops.quantized_decomposed.quantize_per_token.default(buf866, buf874, buf875, -128, 127, torch.int8)
del buf866
buf894 = buf893
del buf893
# Source Nodes: [input_219], Original ATen: [quantized_decomposed.dequantize_per_token]
buf895 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf894, buf874, buf875, -128, 127, torch.int8, torch.bfloat16)
del buf874
del buf875
del buf894
buf896 = buf895
del buf895
# Source Nodes: [w_dq_72], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf897 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg271_1, arg272_1, arg273_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg271_1
del arg272_1
del arg273_1
buf898 = buf897
del buf897
buf899 = buf804; del buf804 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf896, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf898, (4096, 1024), (1, 4096), 0), out=buf899)
del buf898
buf901 = reinterpret_tensor(buf896, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf896 # reuse
# Source Nodes: [output_20, setitem_20, setitem_21], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf891, arg274_1, buf899, buf884, arg275_1, arg276_1, buf901, 4096, grid=grid(4096), stream=stream0)
del arg274_1
del buf884
buf902 = buf815; del buf815 # reuse
# Source Nodes: [output_20], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf902, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_20], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf903 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf901, arg275_1, arg276_1, buf902, False)
del arg275_1
del arg276_1
del buf901
buf904 = buf903[0]
del buf903
# Source Nodes: [choose_qparams_per_token_asymmetric_73], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf908 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf904, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf909 = buf908[0]
buf910 = buf908[1]
del buf908
# Source Nodes: [input_221], Original ATen: [quantized_decomposed.quantize_per_token]
buf911 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf904, (1, 1, 4096), (4096, 4096, 1), 0), buf909, buf910, -128, 127, torch.int8)
buf912 = buf911
del buf911
# Source Nodes: [input_222], Original ATen: [quantized_decomposed.dequantize_per_token]
buf913 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf912, buf909, buf910, -128, 127, torch.int8, torch.bfloat16)
del buf909
del buf910
del buf912
buf914 = buf913
del buf913
# Source Nodes: [w_dq_73], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf915 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg277_1, arg278_1, arg279_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg277_1
del arg278_1
del arg279_1
buf916 = buf915
del buf915
buf917 = reinterpret_tensor(buf904, (1, 4096), (4096, 1), 0); del buf904 # reuse
# Source Nodes: [c_73], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf914, buf916, buf917, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf916
buf919 = buf914; del buf914 # reuse
# Source Nodes: [add_75, h_11, mean_21, mul_140, mul_141, out_9, pow_22, rsqrt_21, x_fp32_21, x_normed_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf917, buf831, buf864, arg280_1, buf919, 1, 4096, grid=grid(1), stream=stream0)
del arg280_1
# Source Nodes: [choose_qparams_per_token_asymmetric_74], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf920 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf919, torch.int8)
buf921 = buf920[0]
buf922 = buf920[1]
del buf920
# Source Nodes: [choose_qparams_per_token_asymmetric_75], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf923 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf919, torch.int8)
buf924 = buf923[0]
buf925 = buf923[1]
del buf923
# Source Nodes: [input_224], Original ATen: [quantized_decomposed.quantize_per_token]
buf926 = torch.ops.quantized_decomposed.quantize_per_token.default(buf919, buf921, buf922, -128, 127, torch.int8)
buf927 = buf926
del buf926
# Source Nodes: [input_225], Original ATen: [quantized_decomposed.dequantize_per_token]
buf928 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf927, buf921, buf922, -128, 127, torch.int8, torch.bfloat16)
del buf921
del buf922
del buf927
buf929 = buf928
del buf928
# Source Nodes: [w_dq_74], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf930 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg281_1, arg282_1, arg283_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg281_1
del arg282_1
del arg283_1
buf931 = buf930
del buf930
buf932 = reinterpret_tensor(buf861, (1, 14336), (14336, 1), 0); del buf861 # reuse
# Source Nodes: [c_74], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf929, buf931, buf932, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf931
# Source Nodes: [input_227], Original ATen: [quantized_decomposed.quantize_per_token]
buf933 = torch.ops.quantized_decomposed.quantize_per_token.default(buf919, buf924, buf925, -128, 127, torch.int8)
buf934 = buf933
del buf933
# Source Nodes: [input_228], Original ATen: [quantized_decomposed.dequantize_per_token]
buf935 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf934, buf924, buf925, -128, 127, torch.int8, torch.bfloat16)
del buf924
del buf925
del buf934
buf936 = buf935
del buf935
# Source Nodes: [w_dq_75], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf937 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg284_1, arg285_1, arg286_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg284_1
del arg285_1
del arg286_1
buf938 = buf937
del buf937
buf940 = buf854; del buf854 # reuse
# Source Nodes: [c_75, mul_142, silu_10], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf936, buf938, buf932, buf940, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf932
del buf938
# Source Nodes: [choose_qparams_per_token_asymmetric_76], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf941 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf940, torch.int8)
buf942 = buf941[0]
buf943 = buf941[1]
del buf941
# Source Nodes: [input_230], Original ATen: [quantized_decomposed.quantize_per_token]
buf944 = torch.ops.quantized_decomposed.quantize_per_token.default(buf940, buf942, buf943, -128, 127, torch.int8)
buf945 = buf944
del buf944
# Source Nodes: [input_231], Original ATen: [quantized_decomposed.dequantize_per_token]
buf946 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf945, buf942, buf943, -128, 127, torch.int8, torch.bfloat16)
del buf942
del buf943
del buf945
buf947 = buf946
del buf946
# Source Nodes: [w_dq_76], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf948 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg287_1, arg288_1, arg289_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg287_1
del arg288_1
del arg289_1
buf949 = buf948
del buf948
buf950 = reinterpret_tensor(buf936, (1, 4096), (4096, 1), 0); del buf936 # reuse
# Source Nodes: [c_76], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf947, buf949, buf950, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf949
buf952 = buf919; del buf919 # reuse
# Source Nodes: [add_77, h_11, mean_22, mul_143, out_10, out_9, pow_23, rsqrt_22, x_fp32_22, x_normed_22, y_11], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf917, buf831, buf864, buf950, arg290_1, buf952, 1, 4096, grid=grid(1), stream=stream0)
del arg290_1
# Source Nodes: [choose_qparams_per_token_asymmetric_77], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf953 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8)
buf954 = buf953[0]
buf955 = buf953[1]
del buf953
# Source Nodes: [choose_qparams_per_token_asymmetric_78], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf956 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8)
buf957 = buf956[0]
buf958 = buf956[1]
del buf956
# Source Nodes: [choose_qparams_per_token_asymmetric_79], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf959 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf952, torch.int8)
buf960 = buf959[0]
buf961 = buf959[1]
del buf959
buf962 = buf876; del buf876 # reuse
# Source Nodes: [max_12], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf962, 1, grid=grid(1), stream=stream0)
u11 = buf962.item()
buf963 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_233], Original ATen: [quantized_decomposed.quantize_per_token]
buf964 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf954, buf955, -128, 127, torch.int8)
buf965 = buf964
del buf964
# Source Nodes: [input_234], Original ATen: [quantized_decomposed.dequantize_per_token]
buf966 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf965, buf954, buf955, -128, 127, torch.int8, torch.bfloat16)
del buf954
del buf955
del buf965
buf967 = buf966
del buf966
# Source Nodes: [w_dq_77], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf968 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg291_1, arg292_1, arg293_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg291_1
del arg292_1
del arg293_1
buf969 = buf968
del buf968
buf970 = reinterpret_tensor(buf929, (1, 4096), (4096, 1), 0); del buf929 # reuse
# Source Nodes: [c_77], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf967, buf969, buf970, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf967
del buf969
# Source Nodes: [input_236], Original ATen: [quantized_decomposed.quantize_per_token]
buf971 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf957, buf958, -128, 127, torch.int8)
buf972 = buf971
del buf971
# Source Nodes: [input_237], Original ATen: [quantized_decomposed.dequantize_per_token]
buf973 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf972, buf957, buf958, -128, 127, torch.int8, torch.bfloat16)
del buf957
del buf958
del buf972
buf974 = buf973
del buf973
# Source Nodes: [w_dq_78], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf975 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg294_1, arg295_1, arg296_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg294_1
del arg295_1
del arg296_1
buf976 = buf975
del buf975
buf977 = buf899; del buf899 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf974, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf976, (4096, 1024), (1, 4096), 0), out=buf977)
del buf974
del buf976
# Source Nodes: [input_239], Original ATen: [quantized_decomposed.quantize_per_token]
buf979 = torch.ops.quantized_decomposed.quantize_per_token.default(buf952, buf960, buf961, -128, 127, torch.int8)
del buf952
buf980 = buf979
del buf979
# Source Nodes: [input_240], Original ATen: [quantized_decomposed.dequantize_per_token]
buf981 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf980, buf960, buf961, -128, 127, torch.int8, torch.bfloat16)
del buf960
del buf961
del buf980
buf982 = buf981
del buf981
# Source Nodes: [w_dq_79], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf983 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg297_1, arg298_1, arg299_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg297_1
del arg298_1
del arg299_1
buf984 = buf983
del buf983
buf985 = buf891; del buf891 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf982, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf984, (4096, 1024), (1, 4096), 0), out=buf985)
del buf984
buf987 = reinterpret_tensor(buf982, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf982 # reuse
# Source Nodes: [output_22, setitem_22, setitem_23], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf977, arg300_1, buf985, buf970, arg301_1, arg302_1, buf987, 4096, grid=grid(4096), stream=stream0)
del arg300_1
del buf970
buf988 = buf902; del buf902 # reuse
# Source Nodes: [output_22], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf988, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_22], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf989 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf987, arg301_1, arg302_1, buf988, False)
del arg301_1
del arg302_1
del buf987
buf990 = buf989[0]
del buf989
# Source Nodes: [choose_qparams_per_token_asymmetric_80], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf994 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf990, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf995 = buf994[0]
buf996 = buf994[1]
del buf994
# Source Nodes: [input_242], Original ATen: [quantized_decomposed.quantize_per_token]
buf997 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf990, (1, 1, 4096), (4096, 4096, 1), 0), buf995, buf996, -128, 127, torch.int8)
buf998 = buf997
del buf997
# Source Nodes: [input_243], Original ATen: [quantized_decomposed.dequantize_per_token]
buf999 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf998, buf995, buf996, -128, 127, torch.int8, torch.bfloat16)
del buf995
del buf996
del buf998
buf1000 = buf999
del buf999
# Source Nodes: [w_dq_80], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1001 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg303_1, arg304_1, arg305_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg303_1
del arg304_1
del arg305_1
buf1002 = buf1001
del buf1001
buf1003 = reinterpret_tensor(buf990, (1, 4096), (4096, 1), 0); del buf990 # reuse
# Source Nodes: [c_80], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1000, buf1002, buf1003, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1002
buf1004 = reinterpret_tensor(buf1003, (1, 1, 4096), (4096, 4096, 1), 0); del buf1003 # reuse
buf1006 = buf1000; del buf1000 # reuse
# Source Nodes: [add_82, h_11, h_12, mean_23, mul_153, mul_154, out_10, out_9, pow_24, rsqrt_23, x_fp32_23, x_normed_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_15.run(buf1004, buf917, buf831, buf864, buf950, arg306_1, buf1006, 1, 4096, grid=grid(1), stream=stream0)
del arg306_1
del buf831
del buf864
del buf917
del buf950
# Source Nodes: [choose_qparams_per_token_asymmetric_81], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1007 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1006, torch.int8)
buf1008 = buf1007[0]
buf1009 = buf1007[1]
del buf1007
# Source Nodes: [choose_qparams_per_token_asymmetric_82], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1010 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1006, torch.int8)
buf1011 = buf1010[0]
buf1012 = buf1010[1]
del buf1010
# Source Nodes: [input_245], Original ATen: [quantized_decomposed.quantize_per_token]
buf1013 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1006, buf1008, buf1009, -128, 127, torch.int8)
buf1014 = buf1013
del buf1013
# Source Nodes: [input_246], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1015 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1014, buf1008, buf1009, -128, 127, torch.int8, torch.bfloat16)
del buf1008
del buf1009
del buf1014
buf1016 = buf1015
del buf1015
# Source Nodes: [w_dq_81], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1017 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg307_1, arg308_1, arg309_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg307_1
del arg308_1
del arg309_1
buf1018 = buf1017
del buf1017
buf1019 = reinterpret_tensor(buf947, (1, 14336), (14336, 1), 0); del buf947 # reuse
# Source Nodes: [c_81], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1016, buf1018, buf1019, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1018
# Source Nodes: [input_248], Original ATen: [quantized_decomposed.quantize_per_token]
buf1020 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1006, buf1011, buf1012, -128, 127, torch.int8)
buf1021 = buf1020
del buf1020
# Source Nodes: [input_249], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1022 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1021, buf1011, buf1012, -128, 127, torch.int8, torch.bfloat16)
del buf1011
del buf1012
del buf1021
buf1023 = buf1022
del buf1022
# Source Nodes: [w_dq_82], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1024 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg310_1, arg311_1, arg312_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg310_1
del arg311_1
del arg312_1
buf1025 = buf1024
del buf1024
buf1027 = buf940; del buf940 # reuse
# Source Nodes: [c_82, mul_155, silu_11], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1023, buf1025, buf1019, buf1027, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1019
del buf1025
# Source Nodes: [choose_qparams_per_token_asymmetric_83], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1028 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1027, torch.int8)
buf1029 = buf1028[0]
buf1030 = buf1028[1]
del buf1028
# Source Nodes: [input_251], Original ATen: [quantized_decomposed.quantize_per_token]
buf1031 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1027, buf1029, buf1030, -128, 127, torch.int8)
buf1032 = buf1031
del buf1031
# Source Nodes: [input_252], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1033 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1032, buf1029, buf1030, -128, 127, torch.int8, torch.bfloat16)
del buf1029
del buf1030
del buf1032
buf1034 = buf1033
del buf1033
# Source Nodes: [w_dq_83], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1035 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg313_1, arg314_1, arg315_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg313_1
del arg314_1
del arg315_1
buf1036 = buf1035
del buf1035
buf1037 = reinterpret_tensor(buf1023, (1, 4096), (4096, 1), 0); del buf1023 # reuse
# Source Nodes: [c_83], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1034, buf1036, buf1037, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1036
buf1039 = buf1006; del buf1006 # reuse
# Source Nodes: [add_84, mean_24, mul_156, out_11, pow_25, rsqrt_24, x_fp32_24, x_normed_24, y_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1004, buf1037, arg316_1, buf1039, 1, 4096, grid=grid(1), stream=stream0)
del arg316_1
# Source Nodes: [choose_qparams_per_token_asymmetric_84], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1040 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8)
buf1041 = buf1040[0]
buf1042 = buf1040[1]
del buf1040
# Source Nodes: [choose_qparams_per_token_asymmetric_85], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1043 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8)
buf1044 = buf1043[0]
buf1045 = buf1043[1]
del buf1043
# Source Nodes: [choose_qparams_per_token_asymmetric_86], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1046 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1039, torch.int8)
buf1047 = buf1046[0]
buf1048 = buf1046[1]
del buf1046
buf1049 = buf962; del buf962 # reuse
# Source Nodes: [max_13], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1049, 1, grid=grid(1), stream=stream0)
u12 = buf1049.item()
buf1050 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_254], Original ATen: [quantized_decomposed.quantize_per_token]
buf1051 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1041, buf1042, -128, 127, torch.int8)
buf1052 = buf1051
del buf1051
# Source Nodes: [input_255], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1053 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1052, buf1041, buf1042, -128, 127, torch.int8, torch.bfloat16)
del buf1041
del buf1042
del buf1052
buf1054 = buf1053
del buf1053
# Source Nodes: [w_dq_84], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1055 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg317_1, arg318_1, arg319_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg317_1
del arg318_1
del arg319_1
buf1056 = buf1055
del buf1055
buf1057 = reinterpret_tensor(buf1016, (1, 4096), (4096, 1), 0); del buf1016 # reuse
# Source Nodes: [c_84], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1054, buf1056, buf1057, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1054
del buf1056
# Source Nodes: [input_257], Original ATen: [quantized_decomposed.quantize_per_token]
buf1058 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1044, buf1045, -128, 127, torch.int8)
buf1059 = buf1058
del buf1058
# Source Nodes: [input_258], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1060 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1059, buf1044, buf1045, -128, 127, torch.int8, torch.bfloat16)
del buf1044
del buf1045
del buf1059
buf1061 = buf1060
del buf1060
# Source Nodes: [w_dq_85], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1062 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg320_1, arg321_1, arg322_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg320_1
del arg321_1
del arg322_1
buf1063 = buf1062
del buf1062
buf1064 = buf985; del buf985 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1061, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1063, (4096, 1024), (1, 4096), 0), out=buf1064)
del buf1061
del buf1063
# Source Nodes: [input_260], Original ATen: [quantized_decomposed.quantize_per_token]
buf1066 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1039, buf1047, buf1048, -128, 127, torch.int8)
del buf1039
buf1067 = buf1066
del buf1066
# Source Nodes: [input_261], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1068 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1067, buf1047, buf1048, -128, 127, torch.int8, torch.bfloat16)
del buf1047
del buf1048
del buf1067
buf1069 = buf1068
del buf1068
# Source Nodes: [w_dq_86], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1070 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg323_1, arg324_1, arg325_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg323_1
del arg324_1
del arg325_1
buf1071 = buf1070
del buf1070
buf1072 = buf977; del buf977 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1069, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1071, (4096, 1024), (1, 4096), 0), out=buf1072)
del buf1071
buf1074 = reinterpret_tensor(buf1069, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1069 # reuse
# Source Nodes: [output_24, setitem_24, setitem_25], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1064, arg326_1, buf1072, buf1057, arg327_1, arg328_1, buf1074, 4096, grid=grid(4096), stream=stream0)
del arg326_1
del buf1057
buf1075 = buf988; del buf988 # reuse
# Source Nodes: [output_24], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1075, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_24], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1076 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1074, arg327_1, arg328_1, buf1075, False)
del arg327_1
del arg328_1
del buf1074
buf1077 = buf1076[0]
del buf1076
# Source Nodes: [choose_qparams_per_token_asymmetric_87], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1081 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1077, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1082 = buf1081[0]
buf1083 = buf1081[1]
del buf1081
# Source Nodes: [input_263], Original ATen: [quantized_decomposed.quantize_per_token]
buf1084 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1077, (1, 1, 4096), (4096, 4096, 1), 0), buf1082, buf1083, -128, 127, torch.int8)
buf1085 = buf1084
del buf1084
# Source Nodes: [input_264], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1086 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1085, buf1082, buf1083, -128, 127, torch.int8, torch.bfloat16)
del buf1082
del buf1083
del buf1085
buf1087 = buf1086
del buf1086
# Source Nodes: [w_dq_87], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1088 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg329_1, arg330_1, arg331_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg329_1
del arg330_1
del arg331_1
buf1089 = buf1088
del buf1088
buf1090 = reinterpret_tensor(buf1077, (1, 4096), (4096, 1), 0); del buf1077 # reuse
# Source Nodes: [c_87], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1087, buf1089, buf1090, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1089
buf1092 = buf1087; del buf1087 # reuse
# Source Nodes: [add_89, h_13, mean_25, mul_166, mul_167, out_11, pow_26, rsqrt_25, x_fp32_25, x_normed_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1090, buf1004, buf1037, arg332_1, buf1092, 1, 4096, grid=grid(1), stream=stream0)
del arg332_1
# Source Nodes: [choose_qparams_per_token_asymmetric_88], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1093 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1092, torch.int8)
buf1094 = buf1093[0]
buf1095 = buf1093[1]
del buf1093
# Source Nodes: [choose_qparams_per_token_asymmetric_89], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1096 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1092, torch.int8)
buf1097 = buf1096[0]
buf1098 = buf1096[1]
del buf1096
# Source Nodes: [input_266], Original ATen: [quantized_decomposed.quantize_per_token]
buf1099 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1092, buf1094, buf1095, -128, 127, torch.int8)
buf1100 = buf1099
del buf1099
# Source Nodes: [input_267], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1101 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1100, buf1094, buf1095, -128, 127, torch.int8, torch.bfloat16)
del buf1094
del buf1095
del buf1100
buf1102 = buf1101
del buf1101
# Source Nodes: [w_dq_88], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1103 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg333_1, arg334_1, arg335_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg333_1
del arg334_1
del arg335_1
buf1104 = buf1103
del buf1103
buf1105 = reinterpret_tensor(buf1034, (1, 14336), (14336, 1), 0); del buf1034 # reuse
# Source Nodes: [c_88], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1102, buf1104, buf1105, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1104
# Source Nodes: [input_269], Original ATen: [quantized_decomposed.quantize_per_token]
buf1106 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1092, buf1097, buf1098, -128, 127, torch.int8)
buf1107 = buf1106
del buf1106
# Source Nodes: [input_270], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1108 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1107, buf1097, buf1098, -128, 127, torch.int8, torch.bfloat16)
del buf1097
del buf1098
del buf1107
buf1109 = buf1108
del buf1108
# Source Nodes: [w_dq_89], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1110 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg336_1, arg337_1, arg338_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg336_1
del arg337_1
del arg338_1
buf1111 = buf1110
del buf1110
buf1113 = buf1027; del buf1027 # reuse
# Source Nodes: [c_89, mul_168, silu_12], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1109, buf1111, buf1105, buf1113, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1105
del buf1111
# Source Nodes: [choose_qparams_per_token_asymmetric_90], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1114 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1113, torch.int8)
buf1115 = buf1114[0]
buf1116 = buf1114[1]
del buf1114
# Source Nodes: [input_272], Original ATen: [quantized_decomposed.quantize_per_token]
buf1117 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1113, buf1115, buf1116, -128, 127, torch.int8)
buf1118 = buf1117
del buf1117
# Source Nodes: [input_273], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1119 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1118, buf1115, buf1116, -128, 127, torch.int8, torch.bfloat16)
del buf1115
del buf1116
del buf1118
buf1120 = buf1119
del buf1119
# Source Nodes: [w_dq_90], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1121 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg339_1, arg340_1, arg341_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg339_1
del arg340_1
del arg341_1
buf1122 = buf1121
del buf1121
buf1123 = reinterpret_tensor(buf1109, (1, 4096), (4096, 1), 0); del buf1109 # reuse
# Source Nodes: [c_90], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1120, buf1122, buf1123, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1122
buf1125 = buf1092; del buf1092 # reuse
# Source Nodes: [add_91, h_13, mean_26, mul_169, out_11, out_12, pow_27, rsqrt_26, x_fp32_26, x_normed_26, y_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1090, buf1004, buf1037, buf1123, arg342_1, buf1125, 1, 4096, grid=grid(1), stream=stream0)
del arg342_1
# Source Nodes: [choose_qparams_per_token_asymmetric_91], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1126 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8)
buf1127 = buf1126[0]
buf1128 = buf1126[1]
del buf1126
# Source Nodes: [choose_qparams_per_token_asymmetric_92], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1129 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8)
buf1130 = buf1129[0]
buf1131 = buf1129[1]
del buf1129
# Source Nodes: [choose_qparams_per_token_asymmetric_93], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1132 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1125, torch.int8)
buf1133 = buf1132[0]
buf1134 = buf1132[1]
del buf1132
buf1135 = buf1049; del buf1049 # reuse
# Source Nodes: [max_14], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1135, 1, grid=grid(1), stream=stream0)
u13 = buf1135.item()
buf1136 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_275], Original ATen: [quantized_decomposed.quantize_per_token]
buf1137 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1127, buf1128, -128, 127, torch.int8)
buf1138 = buf1137
del buf1137
# Source Nodes: [input_276], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1139 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1138, buf1127, buf1128, -128, 127, torch.int8, torch.bfloat16)
del buf1127
del buf1128
del buf1138
buf1140 = buf1139
del buf1139
# Source Nodes: [w_dq_91], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1141 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg343_1, arg344_1, arg345_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg343_1
del arg344_1
del arg345_1
buf1142 = buf1141
del buf1141
buf1143 = reinterpret_tensor(buf1102, (1, 4096), (4096, 1), 0); del buf1102 # reuse
# Source Nodes: [c_91], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1140, buf1142, buf1143, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1140
del buf1142
# Source Nodes: [input_278], Original ATen: [quantized_decomposed.quantize_per_token]
buf1144 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1130, buf1131, -128, 127, torch.int8)
buf1145 = buf1144
del buf1144
# Source Nodes: [input_279], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1146 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1145, buf1130, buf1131, -128, 127, torch.int8, torch.bfloat16)
del buf1130
del buf1131
del buf1145
buf1147 = buf1146
del buf1146
# Source Nodes: [w_dq_92], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1148 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg346_1, arg347_1, arg348_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg346_1
del arg347_1
del arg348_1
buf1149 = buf1148
del buf1148
buf1150 = buf1072; del buf1072 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1147, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1149, (4096, 1024), (1, 4096), 0), out=buf1150)
del buf1147
del buf1149
# Source Nodes: [input_281], Original ATen: [quantized_decomposed.quantize_per_token]
buf1152 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1125, buf1133, buf1134, -128, 127, torch.int8)
del buf1125
buf1153 = buf1152
del buf1152
# Source Nodes: [input_282], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1154 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1153, buf1133, buf1134, -128, 127, torch.int8, torch.bfloat16)
del buf1133
del buf1134
del buf1153
buf1155 = buf1154
del buf1154
# Source Nodes: [w_dq_93], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1156 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg349_1, arg350_1, arg351_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg349_1
del arg350_1
del arg351_1
buf1157 = buf1156
del buf1156
buf1158 = buf1064; del buf1064 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1155, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1157, (4096, 1024), (1, 4096), 0), out=buf1158)
del buf1157
buf1160 = reinterpret_tensor(buf1155, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1155 # reuse
# Source Nodes: [output_26, setitem_26, setitem_27], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1150, arg352_1, buf1158, buf1143, arg353_1, arg354_1, buf1160, 4096, grid=grid(4096), stream=stream0)
del arg352_1
del buf1143
buf1161 = buf1075; del buf1075 # reuse
# Source Nodes: [output_26], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1161, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_26], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1162 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1160, arg353_1, arg354_1, buf1161, False)
del arg353_1
del arg354_1
del buf1160
buf1163 = buf1162[0]
del buf1162
# Source Nodes: [choose_qparams_per_token_asymmetric_94], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1167 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1163, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1168 = buf1167[0]
buf1169 = buf1167[1]
del buf1167
# Source Nodes: [input_284], Original ATen: [quantized_decomposed.quantize_per_token]
buf1170 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1163, (1, 1, 4096), (4096, 4096, 1), 0), buf1168, buf1169, -128, 127, torch.int8)
buf1171 = buf1170
del buf1170
# Source Nodes: [input_285], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1172 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1171, buf1168, buf1169, -128, 127, torch.int8, torch.bfloat16)
del buf1168
del buf1169
del buf1171
buf1173 = buf1172
del buf1172
# Source Nodes: [w_dq_94], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1174 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg355_1, arg356_1, arg357_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg355_1
del arg356_1
del arg357_1
buf1175 = buf1174
del buf1174
buf1176 = reinterpret_tensor(buf1163, (1, 4096), (4096, 1), 0); del buf1163 # reuse
# Source Nodes: [c_94], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1173, buf1175, buf1176, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1175
buf1177 = buf1004; del buf1004 # reuse
buf1179 = buf1173; del buf1173 # reuse
# Source Nodes: [add_96, h_13, h_14, mean_27, mul_179, mul_180, out_11, out_12, pow_28, rsqrt_27, x_fp32_27, x_normed_27], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1177, buf1176, buf1090, buf1037, buf1123, arg358_1, buf1179, 1, 4096, grid=grid(1), stream=stream0)
del arg358_1
del buf1037
del buf1090
del buf1123
del buf1176
# Source Nodes: [choose_qparams_per_token_asymmetric_95], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1180 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1179, torch.int8)
buf1181 = buf1180[0]
buf1182 = buf1180[1]
del buf1180
# Source Nodes: [choose_qparams_per_token_asymmetric_96], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1183 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1179, torch.int8)
buf1184 = buf1183[0]
buf1185 = buf1183[1]
del buf1183
# Source Nodes: [input_287], Original ATen: [quantized_decomposed.quantize_per_token]
buf1186 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1179, buf1181, buf1182, -128, 127, torch.int8)
buf1187 = buf1186
del buf1186
# Source Nodes: [input_288], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1188 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1187, buf1181, buf1182, -128, 127, torch.int8, torch.bfloat16)
del buf1181
del buf1182
del buf1187
buf1189 = buf1188
del buf1188
# Source Nodes: [w_dq_95], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1190 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg359_1, arg360_1, arg361_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg359_1
del arg360_1
del arg361_1
buf1191 = buf1190
del buf1190
buf1192 = reinterpret_tensor(buf1120, (1, 14336), (14336, 1), 0); del buf1120 # reuse
# Source Nodes: [c_95], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1189, buf1191, buf1192, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1191
# Source Nodes: [input_290], Original ATen: [quantized_decomposed.quantize_per_token]
buf1193 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1179, buf1184, buf1185, -128, 127, torch.int8)
buf1194 = buf1193
del buf1193
# Source Nodes: [input_291], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1195 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1194, buf1184, buf1185, -128, 127, torch.int8, torch.bfloat16)
del buf1184
del buf1185
del buf1194
buf1196 = buf1195
del buf1195
# Source Nodes: [w_dq_96], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1197 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg362_1, arg363_1, arg364_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg362_1
del arg363_1
del arg364_1
buf1198 = buf1197
del buf1197
buf1200 = buf1113; del buf1113 # reuse
# Source Nodes: [c_96, mul_181, silu_13], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1196, buf1198, buf1192, buf1200, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1192
del buf1198
# Source Nodes: [choose_qparams_per_token_asymmetric_97], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1201 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1200, torch.int8)
buf1202 = buf1201[0]
buf1203 = buf1201[1]
del buf1201
# Source Nodes: [input_293], Original ATen: [quantized_decomposed.quantize_per_token]
buf1204 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1200, buf1202, buf1203, -128, 127, torch.int8)
buf1205 = buf1204
del buf1204
# Source Nodes: [input_294], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1206 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1205, buf1202, buf1203, -128, 127, torch.int8, torch.bfloat16)
del buf1202
del buf1203
del buf1205
buf1207 = buf1206
del buf1206
# Source Nodes: [w_dq_97], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1208 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg365_1, arg366_1, arg367_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg365_1
del arg366_1
del arg367_1
buf1209 = buf1208
del buf1208
buf1210 = reinterpret_tensor(buf1196, (1, 4096), (4096, 1), 0); del buf1196 # reuse
# Source Nodes: [c_97], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1207, buf1209, buf1210, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1209
buf1212 = buf1179; del buf1179 # reuse
# Source Nodes: [add_98, mean_28, mul_182, out_13, pow_29, rsqrt_28, x_fp32_28, x_normed_28, y_14], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1177, buf1210, arg368_1, buf1212, 1, 4096, grid=grid(1), stream=stream0)
del arg368_1
# Source Nodes: [choose_qparams_per_token_asymmetric_98], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1213 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8)
buf1214 = buf1213[0]
buf1215 = buf1213[1]
del buf1213
# Source Nodes: [choose_qparams_per_token_asymmetric_99], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1216 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8)
buf1217 = buf1216[0]
buf1218 = buf1216[1]
del buf1216
# Source Nodes: [choose_qparams_per_token_asymmetric_100], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1219 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1212, torch.int8)
buf1220 = buf1219[0]
buf1221 = buf1219[1]
del buf1219
buf1222 = buf1135; del buf1135 # reuse
# Source Nodes: [max_15], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1222, 1, grid=grid(1), stream=stream0)
u14 = buf1222.item()
buf1223 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_296], Original ATen: [quantized_decomposed.quantize_per_token]
buf1224 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1214, buf1215, -128, 127, torch.int8)
buf1225 = buf1224
del buf1224
# Source Nodes: [input_297], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1226 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1225, buf1214, buf1215, -128, 127, torch.int8, torch.bfloat16)
del buf1214
del buf1215
del buf1225
buf1227 = buf1226
del buf1226
# Source Nodes: [w_dq_98], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1228 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg369_1, arg370_1, arg371_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg369_1
del arg370_1
del arg371_1
buf1229 = buf1228
del buf1228
buf1230 = reinterpret_tensor(buf1189, (1, 4096), (4096, 1), 0); del buf1189 # reuse
# Source Nodes: [c_98], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1227, buf1229, buf1230, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1227
del buf1229
# Source Nodes: [input_299], Original ATen: [quantized_decomposed.quantize_per_token]
buf1231 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1217, buf1218, -128, 127, torch.int8)
buf1232 = buf1231
del buf1231
# Source Nodes: [input_300], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1233 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1232, buf1217, buf1218, -128, 127, torch.int8, torch.bfloat16)
del buf1217
del buf1218
del buf1232
buf1234 = buf1233
del buf1233
# Source Nodes: [w_dq_99], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1235 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg372_1, arg373_1, arg374_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg372_1
del arg373_1
del arg374_1
buf1236 = buf1235
del buf1235
buf1237 = buf1158; del buf1158 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1234, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1236, (4096, 1024), (1, 4096), 0), out=buf1237)
del buf1234
del buf1236
# Source Nodes: [input_302], Original ATen: [quantized_decomposed.quantize_per_token]
buf1239 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1212, buf1220, buf1221, -128, 127, torch.int8)
del buf1212
buf1240 = buf1239
del buf1239
# Source Nodes: [input_303], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1241 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1240, buf1220, buf1221, -128, 127, torch.int8, torch.bfloat16)
del buf1220
del buf1221
del buf1240
buf1242 = buf1241
del buf1241
# Source Nodes: [w_dq_100], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1243 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg375_1, arg376_1, arg377_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg375_1
del arg376_1
del arg377_1
buf1244 = buf1243
del buf1243
buf1245 = buf1150; del buf1150 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1242, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1244, (4096, 1024), (1, 4096), 0), out=buf1245)
del buf1244
buf1247 = reinterpret_tensor(buf1242, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1242 # reuse
# Source Nodes: [output_28, setitem_28, setitem_29], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1237, arg378_1, buf1245, buf1230, arg379_1, arg380_1, buf1247, 4096, grid=grid(4096), stream=stream0)
del arg378_1
del buf1230
buf1248 = buf1161; del buf1161 # reuse
# Source Nodes: [output_28], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1248, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_28], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1249 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1247, arg379_1, arg380_1, buf1248, False)
del arg379_1
del arg380_1
del buf1247
buf1250 = buf1249[0]
del buf1249
# Source Nodes: [choose_qparams_per_token_asymmetric_101], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1254 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1250, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1255 = buf1254[0]
buf1256 = buf1254[1]
del buf1254
# Source Nodes: [input_305], Original ATen: [quantized_decomposed.quantize_per_token]
buf1257 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1250, (1, 1, 4096), (4096, 4096, 1), 0), buf1255, buf1256, -128, 127, torch.int8)
buf1258 = buf1257
del buf1257
# Source Nodes: [input_306], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1259 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1258, buf1255, buf1256, -128, 127, torch.int8, torch.bfloat16)
del buf1255
del buf1256
del buf1258
buf1260 = buf1259
del buf1259
# Source Nodes: [w_dq_101], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1261 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg381_1, arg382_1, arg383_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg381_1
del arg382_1
del arg383_1
buf1262 = buf1261
del buf1261
buf1263 = reinterpret_tensor(buf1250, (1, 4096), (4096, 1), 0); del buf1250 # reuse
# Source Nodes: [c_101], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1260, buf1262, buf1263, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1262
buf1265 = buf1260; del buf1260 # reuse
# Source Nodes: [add_103, h_15, mean_29, mul_192, mul_193, out_13, pow_30, rsqrt_29, x_fp32_29, x_normed_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1263, buf1177, buf1210, arg384_1, buf1265, 1, 4096, grid=grid(1), stream=stream0)
del arg384_1
# Source Nodes: [choose_qparams_per_token_asymmetric_102], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1266 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1265, torch.int8)
buf1267 = buf1266[0]
buf1268 = buf1266[1]
del buf1266
# Source Nodes: [choose_qparams_per_token_asymmetric_103], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1269 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1265, torch.int8)
buf1270 = buf1269[0]
buf1271 = buf1269[1]
del buf1269
# Source Nodes: [input_308], Original ATen: [quantized_decomposed.quantize_per_token]
buf1272 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1265, buf1267, buf1268, -128, 127, torch.int8)
buf1273 = buf1272
del buf1272
# Source Nodes: [input_309], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1274 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1273, buf1267, buf1268, -128, 127, torch.int8, torch.bfloat16)
del buf1267
del buf1268
del buf1273
buf1275 = buf1274
del buf1274
# Source Nodes: [w_dq_102], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1276 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg385_1, arg386_1, arg387_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg385_1
del arg386_1
del arg387_1
buf1277 = buf1276
del buf1276
buf1278 = reinterpret_tensor(buf1207, (1, 14336), (14336, 1), 0); del buf1207 # reuse
# Source Nodes: [c_102], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1275, buf1277, buf1278, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1277
# Source Nodes: [input_311], Original ATen: [quantized_decomposed.quantize_per_token]
buf1279 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1265, buf1270, buf1271, -128, 127, torch.int8)
buf1280 = buf1279
del buf1279
# Source Nodes: [input_312], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1281 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1280, buf1270, buf1271, -128, 127, torch.int8, torch.bfloat16)
del buf1270
del buf1271
del buf1280
buf1282 = buf1281
del buf1281
# Source Nodes: [w_dq_103], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1283 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg388_1, arg389_1, arg390_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg388_1
del arg389_1
del arg390_1
buf1284 = buf1283
del buf1283
buf1286 = buf1200; del buf1200 # reuse
# Source Nodes: [c_103, mul_194, silu_14], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1282, buf1284, buf1278, buf1286, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1278
del buf1284
# Source Nodes: [choose_qparams_per_token_asymmetric_104], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1287 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1286, torch.int8)
buf1288 = buf1287[0]
buf1289 = buf1287[1]
del buf1287
# Source Nodes: [input_314], Original ATen: [quantized_decomposed.quantize_per_token]
buf1290 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1286, buf1288, buf1289, -128, 127, torch.int8)
buf1291 = buf1290
del buf1290
# Source Nodes: [input_315], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1292 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1291, buf1288, buf1289, -128, 127, torch.int8, torch.bfloat16)
del buf1288
del buf1289
del buf1291
buf1293 = buf1292
del buf1292
# Source Nodes: [w_dq_104], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1294 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg391_1, arg392_1, arg393_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg391_1
del arg392_1
del arg393_1
buf1295 = buf1294
del buf1294
buf1296 = reinterpret_tensor(buf1282, (1, 4096), (4096, 1), 0); del buf1282 # reuse
# Source Nodes: [c_104], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1293, buf1295, buf1296, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1295
buf1298 = buf1265; del buf1265 # reuse
# Source Nodes: [add_105, h_15, mean_30, mul_195, out_13, out_14, pow_31, rsqrt_30, x_fp32_30, x_normed_30, y_15], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1263, buf1177, buf1210, buf1296, arg394_1, buf1298, 1, 4096, grid=grid(1), stream=stream0)
del arg394_1
# Source Nodes: [choose_qparams_per_token_asymmetric_105], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1299 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8)
buf1300 = buf1299[0]
buf1301 = buf1299[1]
del buf1299
# Source Nodes: [choose_qparams_per_token_asymmetric_106], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1302 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8)
buf1303 = buf1302[0]
buf1304 = buf1302[1]
del buf1302
# Source Nodes: [choose_qparams_per_token_asymmetric_107], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1305 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1298, torch.int8)
buf1306 = buf1305[0]
buf1307 = buf1305[1]
del buf1305
buf1308 = buf1222; del buf1222 # reuse
# Source Nodes: [max_16], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1308, 1, grid=grid(1), stream=stream0)
u15 = buf1308.item()
buf1309 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_317], Original ATen: [quantized_decomposed.quantize_per_token]
buf1310 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1300, buf1301, -128, 127, torch.int8)
buf1311 = buf1310
del buf1310
# Source Nodes: [input_318], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1312 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1311, buf1300, buf1301, -128, 127, torch.int8, torch.bfloat16)
del buf1300
del buf1301
del buf1311
buf1313 = buf1312
del buf1312
# Source Nodes: [w_dq_105], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1314 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg395_1, arg396_1, arg397_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg395_1
del arg396_1
del arg397_1
buf1315 = buf1314
del buf1314
buf1316 = reinterpret_tensor(buf1275, (1, 4096), (4096, 1), 0); del buf1275 # reuse
# Source Nodes: [c_105], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1313, buf1315, buf1316, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1313
del buf1315
# Source Nodes: [input_320], Original ATen: [quantized_decomposed.quantize_per_token]
buf1317 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1303, buf1304, -128, 127, torch.int8)
buf1318 = buf1317
del buf1317
# Source Nodes: [input_321], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1319 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1318, buf1303, buf1304, -128, 127, torch.int8, torch.bfloat16)
del buf1303
del buf1304
del buf1318
buf1320 = buf1319
del buf1319
# Source Nodes: [w_dq_106], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1321 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg398_1, arg399_1, arg400_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg398_1
del arg399_1
del arg400_1
buf1322 = buf1321
del buf1321
buf1323 = buf1245; del buf1245 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1320, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1322, (4096, 1024), (1, 4096), 0), out=buf1323)
del buf1320
del buf1322
# Source Nodes: [input_323], Original ATen: [quantized_decomposed.quantize_per_token]
buf1325 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1298, buf1306, buf1307, -128, 127, torch.int8)
del buf1298
buf1326 = buf1325
del buf1325
# Source Nodes: [input_324], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1327 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1326, buf1306, buf1307, -128, 127, torch.int8, torch.bfloat16)
del buf1306
del buf1307
del buf1326
buf1328 = buf1327
del buf1327
# Source Nodes: [w_dq_107], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1329 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg401_1, arg402_1, arg403_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg401_1
del arg402_1
del arg403_1
buf1330 = buf1329
del buf1329
buf1331 = buf1237; del buf1237 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1328, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1330, (4096, 1024), (1, 4096), 0), out=buf1331)
del buf1330
buf1333 = reinterpret_tensor(buf1328, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1328 # reuse
# Source Nodes: [output_30, setitem_30, setitem_31], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1323, arg404_1, buf1331, buf1316, arg405_1, arg406_1, buf1333, 4096, grid=grid(4096), stream=stream0)
del arg404_1
del buf1316
buf1334 = buf1248; del buf1248 # reuse
# Source Nodes: [output_30], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1334, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_30], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1335 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1333, arg405_1, arg406_1, buf1334, False)
del arg405_1
del arg406_1
del buf1333
buf1336 = buf1335[0]
del buf1335
# Source Nodes: [choose_qparams_per_token_asymmetric_108], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1340 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1336, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1341 = buf1340[0]
buf1342 = buf1340[1]
del buf1340
# Source Nodes: [input_326], Original ATen: [quantized_decomposed.quantize_per_token]
buf1343 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1336, (1, 1, 4096), (4096, 4096, 1), 0), buf1341, buf1342, -128, 127, torch.int8)
buf1344 = buf1343
del buf1343
# Source Nodes: [input_327], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1345 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1344, buf1341, buf1342, -128, 127, torch.int8, torch.bfloat16)
del buf1341
del buf1342
del buf1344
buf1346 = buf1345
del buf1345
# Source Nodes: [w_dq_108], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1347 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg407_1, arg408_1, arg409_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg407_1
del arg408_1
del arg409_1
buf1348 = buf1347
del buf1347
buf1349 = reinterpret_tensor(buf1336, (1, 4096), (4096, 1), 0); del buf1336 # reuse
# Source Nodes: [c_108], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1346, buf1348, buf1349, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1348
buf1350 = buf1177; del buf1177 # reuse
buf1352 = buf1346; del buf1346 # reuse
# Source Nodes: [add_110, h_15, h_16, mean_31, mul_205, mul_206, out_13, out_14, pow_32, rsqrt_31, x_fp32_31, x_normed_31], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1350, buf1349, buf1263, buf1210, buf1296, arg410_1, buf1352, 1, 4096, grid=grid(1), stream=stream0)
del arg410_1
del buf1210
del buf1263
del buf1296
del buf1349
# Source Nodes: [choose_qparams_per_token_asymmetric_109], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1353 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1352, torch.int8)
buf1354 = buf1353[0]
buf1355 = buf1353[1]
del buf1353
# Source Nodes: [choose_qparams_per_token_asymmetric_110], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1356 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1352, torch.int8)
buf1357 = buf1356[0]
buf1358 = buf1356[1]
del buf1356
# Source Nodes: [input_329], Original ATen: [quantized_decomposed.quantize_per_token]
buf1359 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1352, buf1354, buf1355, -128, 127, torch.int8)
buf1360 = buf1359
del buf1359
# Source Nodes: [input_330], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1361 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1360, buf1354, buf1355, -128, 127, torch.int8, torch.bfloat16)
del buf1354
del buf1355
del buf1360
buf1362 = buf1361
del buf1361
# Source Nodes: [w_dq_109], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1363 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg411_1, arg412_1, arg413_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg411_1
del arg412_1
del arg413_1
buf1364 = buf1363
del buf1363
buf1365 = reinterpret_tensor(buf1293, (1, 14336), (14336, 1), 0); del buf1293 # reuse
# Source Nodes: [c_109], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1362, buf1364, buf1365, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1364
# Source Nodes: [input_332], Original ATen: [quantized_decomposed.quantize_per_token]
buf1366 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1352, buf1357, buf1358, -128, 127, torch.int8)
buf1367 = buf1366
del buf1366
# Source Nodes: [input_333], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1368 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1367, buf1357, buf1358, -128, 127, torch.int8, torch.bfloat16)
del buf1357
del buf1358
del buf1367
buf1369 = buf1368
del buf1368
# Source Nodes: [w_dq_110], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1370 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg414_1, arg415_1, arg416_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg414_1
del arg415_1
del arg416_1
buf1371 = buf1370
del buf1370
buf1373 = buf1286; del buf1286 # reuse
# Source Nodes: [c_110, mul_207, silu_15], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1369, buf1371, buf1365, buf1373, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1365
del buf1371
# Source Nodes: [choose_qparams_per_token_asymmetric_111], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1374 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1373, torch.int8)
buf1375 = buf1374[0]
buf1376 = buf1374[1]
del buf1374
# Source Nodes: [input_335], Original ATen: [quantized_decomposed.quantize_per_token]
buf1377 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1373, buf1375, buf1376, -128, 127, torch.int8)
buf1378 = buf1377
del buf1377
# Source Nodes: [input_336], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1379 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1378, buf1375, buf1376, -128, 127, torch.int8, torch.bfloat16)
del buf1375
del buf1376
del buf1378
buf1380 = buf1379
del buf1379
# Source Nodes: [w_dq_111], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1381 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg417_1, arg418_1, arg419_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg417_1
del arg418_1
del arg419_1
buf1382 = buf1381
del buf1381
buf1383 = reinterpret_tensor(buf1369, (1, 4096), (4096, 1), 0); del buf1369 # reuse
# Source Nodes: [c_111], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1380, buf1382, buf1383, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1382
buf1385 = buf1352; del buf1352 # reuse
# Source Nodes: [add_112, mean_32, mul_208, out_15, pow_33, rsqrt_32, x_fp32_32, x_normed_32, y_16], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1350, buf1383, arg420_1, buf1385, 1, 4096, grid=grid(1), stream=stream0)
del arg420_1
# Source Nodes: [choose_qparams_per_token_asymmetric_112], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1386 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8)
buf1387 = buf1386[0]
buf1388 = buf1386[1]
del buf1386
# Source Nodes: [choose_qparams_per_token_asymmetric_113], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1389 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8)
buf1390 = buf1389[0]
buf1391 = buf1389[1]
del buf1389
# Source Nodes: [choose_qparams_per_token_asymmetric_114], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1392 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1385, torch.int8)
buf1393 = buf1392[0]
buf1394 = buf1392[1]
del buf1392
buf1395 = buf1308; del buf1308 # reuse
# Source Nodes: [max_17], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1395, 1, grid=grid(1), stream=stream0)
u16 = buf1395.item()
buf1396 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_338], Original ATen: [quantized_decomposed.quantize_per_token]
buf1397 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1387, buf1388, -128, 127, torch.int8)
buf1398 = buf1397
del buf1397
# Source Nodes: [input_339], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1399 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1398, buf1387, buf1388, -128, 127, torch.int8, torch.bfloat16)
del buf1387
del buf1388
del buf1398
buf1400 = buf1399
del buf1399
# Source Nodes: [w_dq_112], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1401 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg421_1, arg422_1, arg423_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg421_1
del arg422_1
del arg423_1
buf1402 = buf1401
del buf1401
buf1403 = reinterpret_tensor(buf1362, (1, 4096), (4096, 1), 0); del buf1362 # reuse
# Source Nodes: [c_112], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1400, buf1402, buf1403, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1400
del buf1402
# Source Nodes: [input_341], Original ATen: [quantized_decomposed.quantize_per_token]
buf1404 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1390, buf1391, -128, 127, torch.int8)
buf1405 = buf1404
del buf1404
# Source Nodes: [input_342], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1406 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1405, buf1390, buf1391, -128, 127, torch.int8, torch.bfloat16)
del buf1390
del buf1391
del buf1405
buf1407 = buf1406
del buf1406
# Source Nodes: [w_dq_113], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1408 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg424_1, arg425_1, arg426_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg424_1
del arg425_1
del arg426_1
buf1409 = buf1408
del buf1408
buf1410 = buf1331; del buf1331 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1407, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1409, (4096, 1024), (1, 4096), 0), out=buf1410)
del buf1407
del buf1409
# Source Nodes: [input_344], Original ATen: [quantized_decomposed.quantize_per_token]
buf1412 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1385, buf1393, buf1394, -128, 127, torch.int8)
del buf1385
buf1413 = buf1412
del buf1412
# Source Nodes: [input_345], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1414 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1413, buf1393, buf1394, -128, 127, torch.int8, torch.bfloat16)
del buf1393
del buf1394
del buf1413
buf1415 = buf1414
del buf1414
# Source Nodes: [w_dq_114], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1416 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg427_1, arg428_1, arg429_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg427_1
del arg428_1
del arg429_1
buf1417 = buf1416
del buf1416
buf1418 = buf1323; del buf1323 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1415, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1417, (4096, 1024), (1, 4096), 0), out=buf1418)
del buf1417
buf1420 = reinterpret_tensor(buf1415, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1415 # reuse
# Source Nodes: [output_32, setitem_32, setitem_33], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1410, arg430_1, buf1418, buf1403, arg431_1, arg432_1, buf1420, 4096, grid=grid(4096), stream=stream0)
del arg430_1
del buf1403
buf1421 = buf1334; del buf1334 # reuse
# Source Nodes: [output_32], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1421, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_32], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1422 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1420, arg431_1, arg432_1, buf1421, False)
del arg431_1
del arg432_1
del buf1420
buf1423 = buf1422[0]
del buf1422
# Source Nodes: [choose_qparams_per_token_asymmetric_115], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1427 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1423, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1428 = buf1427[0]
buf1429 = buf1427[1]
del buf1427
# Source Nodes: [input_347], Original ATen: [quantized_decomposed.quantize_per_token]
buf1430 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1423, (1, 1, 4096), (4096, 4096, 1), 0), buf1428, buf1429, -128, 127, torch.int8)
buf1431 = buf1430
del buf1430
# Source Nodes: [input_348], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1432 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1431, buf1428, buf1429, -128, 127, torch.int8, torch.bfloat16)
del buf1428
del buf1429
del buf1431
buf1433 = buf1432
del buf1432
# Source Nodes: [w_dq_115], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1434 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg433_1, arg434_1, arg435_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg433_1
del arg434_1
del arg435_1
buf1435 = buf1434
del buf1434
buf1436 = reinterpret_tensor(buf1423, (1, 4096), (4096, 1), 0); del buf1423 # reuse
# Source Nodes: [c_115], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1433, buf1435, buf1436, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1435
buf1438 = buf1433; del buf1433 # reuse
# Source Nodes: [add_117, h_17, mean_33, mul_218, mul_219, out_15, pow_34, rsqrt_33, x_fp32_33, x_normed_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1436, buf1350, buf1383, arg436_1, buf1438, 1, 4096, grid=grid(1), stream=stream0)
del arg436_1
# Source Nodes: [choose_qparams_per_token_asymmetric_116], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1439 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1438, torch.int8)
buf1440 = buf1439[0]
buf1441 = buf1439[1]
del buf1439
# Source Nodes: [choose_qparams_per_token_asymmetric_117], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1442 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1438, torch.int8)
buf1443 = buf1442[0]
buf1444 = buf1442[1]
del buf1442
# Source Nodes: [input_350], Original ATen: [quantized_decomposed.quantize_per_token]
buf1445 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1438, buf1440, buf1441, -128, 127, torch.int8)
buf1446 = buf1445
del buf1445
# Source Nodes: [input_351], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1447 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1446, buf1440, buf1441, -128, 127, torch.int8, torch.bfloat16)
del buf1440
del buf1441
del buf1446
buf1448 = buf1447
del buf1447
# Source Nodes: [w_dq_116], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1449 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg437_1, arg438_1, arg439_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg437_1
del arg438_1
del arg439_1
buf1450 = buf1449
del buf1449
buf1451 = reinterpret_tensor(buf1380, (1, 14336), (14336, 1), 0); del buf1380 # reuse
# Source Nodes: [c_116], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1448, buf1450, buf1451, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1450
# Source Nodes: [input_353], Original ATen: [quantized_decomposed.quantize_per_token]
buf1452 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1438, buf1443, buf1444, -128, 127, torch.int8)
buf1453 = buf1452
del buf1452
# Source Nodes: [input_354], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1454 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1453, buf1443, buf1444, -128, 127, torch.int8, torch.bfloat16)
del buf1443
del buf1444
del buf1453
buf1455 = buf1454
del buf1454
# Source Nodes: [w_dq_117], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1456 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg440_1, arg441_1, arg442_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg440_1
del arg441_1
del arg442_1
buf1457 = buf1456
del buf1456
buf1459 = buf1373; del buf1373 # reuse
# Source Nodes: [c_117, mul_220, silu_16], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1455, buf1457, buf1451, buf1459, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1451
del buf1457
# Source Nodes: [choose_qparams_per_token_asymmetric_118], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1460 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1459, torch.int8)
buf1461 = buf1460[0]
buf1462 = buf1460[1]
del buf1460
# Source Nodes: [input_356], Original ATen: [quantized_decomposed.quantize_per_token]
buf1463 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1459, buf1461, buf1462, -128, 127, torch.int8)
buf1464 = buf1463
del buf1463
# Source Nodes: [input_357], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1465 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1464, buf1461, buf1462, -128, 127, torch.int8, torch.bfloat16)
del buf1461
del buf1462
del buf1464
buf1466 = buf1465
del buf1465
# Source Nodes: [w_dq_118], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1467 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg443_1, arg444_1, arg445_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg443_1
del arg444_1
del arg445_1
buf1468 = buf1467
del buf1467
buf1469 = reinterpret_tensor(buf1455, (1, 4096), (4096, 1), 0); del buf1455 # reuse
# Source Nodes: [c_118], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1466, buf1468, buf1469, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1468
buf1471 = buf1438; del buf1438 # reuse
# Source Nodes: [add_119, h_17, mean_34, mul_221, out_15, out_16, pow_35, rsqrt_34, x_fp32_34, x_normed_34, y_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1436, buf1350, buf1383, buf1469, arg446_1, buf1471, 1, 4096, grid=grid(1), stream=stream0)
del arg446_1
# Source Nodes: [choose_qparams_per_token_asymmetric_119], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1472 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8)
buf1473 = buf1472[0]
buf1474 = buf1472[1]
del buf1472
# Source Nodes: [choose_qparams_per_token_asymmetric_120], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1475 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8)
buf1476 = buf1475[0]
buf1477 = buf1475[1]
del buf1475
# Source Nodes: [choose_qparams_per_token_asymmetric_121], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1478 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1471, torch.int8)
buf1479 = buf1478[0]
buf1480 = buf1478[1]
del buf1478
buf1481 = buf1395; del buf1395 # reuse
# Source Nodes: [max_18], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1481, 1, grid=grid(1), stream=stream0)
u17 = buf1481.item()
buf1482 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_359], Original ATen: [quantized_decomposed.quantize_per_token]
buf1483 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1473, buf1474, -128, 127, torch.int8)
buf1484 = buf1483
del buf1483
# Source Nodes: [input_360], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1485 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1484, buf1473, buf1474, -128, 127, torch.int8, torch.bfloat16)
del buf1473
del buf1474
del buf1484
buf1486 = buf1485
del buf1485
# Source Nodes: [w_dq_119], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1487 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg447_1, arg448_1, arg449_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg447_1
del arg448_1
del arg449_1
buf1488 = buf1487
del buf1487
buf1489 = reinterpret_tensor(buf1448, (1, 4096), (4096, 1), 0); del buf1448 # reuse
# Source Nodes: [c_119], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1486, buf1488, buf1489, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1486
del buf1488
# Source Nodes: [input_362], Original ATen: [quantized_decomposed.quantize_per_token]
buf1490 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1476, buf1477, -128, 127, torch.int8)
buf1491 = buf1490
del buf1490
# Source Nodes: [input_363], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1492 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1491, buf1476, buf1477, -128, 127, torch.int8, torch.bfloat16)
del buf1476
del buf1477
del buf1491
buf1493 = buf1492
del buf1492
# Source Nodes: [w_dq_120], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1494 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg450_1, arg451_1, arg452_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg450_1
del arg451_1
del arg452_1
buf1495 = buf1494
del buf1494
buf1496 = buf1418; del buf1418 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1493, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1495, (4096, 1024), (1, 4096), 0), out=buf1496)
del buf1493
del buf1495
# Source Nodes: [input_365], Original ATen: [quantized_decomposed.quantize_per_token]
buf1498 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1471, buf1479, buf1480, -128, 127, torch.int8)
del buf1471
buf1499 = buf1498
del buf1498
# Source Nodes: [input_366], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1500 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1499, buf1479, buf1480, -128, 127, torch.int8, torch.bfloat16)
del buf1479
del buf1480
del buf1499
buf1501 = buf1500
del buf1500
# Source Nodes: [w_dq_121], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1502 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg453_1, arg454_1, arg455_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg453_1
del arg454_1
del arg455_1
buf1503 = buf1502
del buf1502
buf1504 = buf1410; del buf1410 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1501, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1503, (4096, 1024), (1, 4096), 0), out=buf1504)
del buf1503
buf1506 = reinterpret_tensor(buf1501, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1501 # reuse
# Source Nodes: [output_34, setitem_34, setitem_35], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1496, arg456_1, buf1504, buf1489, arg457_1, arg458_1, buf1506, 4096, grid=grid(4096), stream=stream0)
del arg456_1
del buf1489
buf1507 = buf1421; del buf1421 # reuse
# Source Nodes: [output_34], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1507, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_34], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1508 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1506, arg457_1, arg458_1, buf1507, False)
del arg457_1
del arg458_1
del buf1506
buf1509 = buf1508[0]
del buf1508
# Source Nodes: [choose_qparams_per_token_asymmetric_122], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1513 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1509, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1514 = buf1513[0]
buf1515 = buf1513[1]
del buf1513
# Source Nodes: [input_368], Original ATen: [quantized_decomposed.quantize_per_token]
buf1516 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1509, (1, 1, 4096), (4096, 4096, 1), 0), buf1514, buf1515, -128, 127, torch.int8)
buf1517 = buf1516
del buf1516
# Source Nodes: [input_369], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1518 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1517, buf1514, buf1515, -128, 127, torch.int8, torch.bfloat16)
del buf1514
del buf1515
del buf1517
buf1519 = buf1518
del buf1518
# Source Nodes: [w_dq_122], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1520 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg459_1, arg460_1, arg461_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg459_1
del arg460_1
del arg461_1
buf1521 = buf1520
del buf1520
buf1522 = reinterpret_tensor(buf1509, (1, 4096), (4096, 1), 0); del buf1509 # reuse
# Source Nodes: [c_122], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1519, buf1521, buf1522, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1521
buf1523 = buf1350; del buf1350 # reuse
buf1525 = buf1519; del buf1519 # reuse
# Source Nodes: [add_124, h_17, h_18, mean_35, mul_231, mul_232, out_15, out_16, pow_36, rsqrt_35, x_fp32_35, x_normed_35], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1523, buf1522, buf1436, buf1383, buf1469, arg462_1, buf1525, 1, 4096, grid=grid(1), stream=stream0)
del arg462_1
del buf1383
del buf1436
del buf1469
del buf1522
# Source Nodes: [choose_qparams_per_token_asymmetric_123], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1526 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1525, torch.int8)
buf1527 = buf1526[0]
buf1528 = buf1526[1]
del buf1526
# Source Nodes: [choose_qparams_per_token_asymmetric_124], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1529 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1525, torch.int8)
buf1530 = buf1529[0]
buf1531 = buf1529[1]
del buf1529
# Source Nodes: [input_371], Original ATen: [quantized_decomposed.quantize_per_token]
buf1532 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1525, buf1527, buf1528, -128, 127, torch.int8)
buf1533 = buf1532
del buf1532
# Source Nodes: [input_372], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1534 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1533, buf1527, buf1528, -128, 127, torch.int8, torch.bfloat16)
del buf1527
del buf1528
del buf1533
buf1535 = buf1534
del buf1534
# Source Nodes: [w_dq_123], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1536 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg463_1, arg464_1, arg465_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg463_1
del arg464_1
del arg465_1
buf1537 = buf1536
del buf1536
buf1538 = reinterpret_tensor(buf1466, (1, 14336), (14336, 1), 0); del buf1466 # reuse
# Source Nodes: [c_123], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1535, buf1537, buf1538, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1537
# Source Nodes: [input_374], Original ATen: [quantized_decomposed.quantize_per_token]
buf1539 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1525, buf1530, buf1531, -128, 127, torch.int8)
buf1540 = buf1539
del buf1539
# Source Nodes: [input_375], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1541 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1540, buf1530, buf1531, -128, 127, torch.int8, torch.bfloat16)
del buf1530
del buf1531
del buf1540
buf1542 = buf1541
del buf1541
# Source Nodes: [w_dq_124], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1543 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg466_1, arg467_1, arg468_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg466_1
del arg467_1
del arg468_1
buf1544 = buf1543
del buf1543
buf1546 = buf1459; del buf1459 # reuse
# Source Nodes: [c_124, mul_233, silu_17], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1542, buf1544, buf1538, buf1546, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1538
del buf1544
# Source Nodes: [choose_qparams_per_token_asymmetric_125], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1547 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1546, torch.int8)
buf1548 = buf1547[0]
buf1549 = buf1547[1]
del buf1547
# Source Nodes: [input_377], Original ATen: [quantized_decomposed.quantize_per_token]
buf1550 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1546, buf1548, buf1549, -128, 127, torch.int8)
buf1551 = buf1550
del buf1550
# Source Nodes: [input_378], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1552 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1551, buf1548, buf1549, -128, 127, torch.int8, torch.bfloat16)
del buf1548
del buf1549
del buf1551
buf1553 = buf1552
del buf1552
# Source Nodes: [w_dq_125], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1554 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg469_1, arg470_1, arg471_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg469_1
del arg470_1
del arg471_1
buf1555 = buf1554
del buf1554
buf1556 = reinterpret_tensor(buf1542, (1, 4096), (4096, 1), 0); del buf1542 # reuse
# Source Nodes: [c_125], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1553, buf1555, buf1556, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1555
buf1558 = buf1525; del buf1525 # reuse
# Source Nodes: [add_126, mean_36, mul_234, out_17, pow_37, rsqrt_36, x_fp32_36, x_normed_36, y_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1523, buf1556, arg472_1, buf1558, 1, 4096, grid=grid(1), stream=stream0)
del arg472_1
# Source Nodes: [choose_qparams_per_token_asymmetric_126], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1559 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8)
buf1560 = buf1559[0]
buf1561 = buf1559[1]
del buf1559
# Source Nodes: [choose_qparams_per_token_asymmetric_127], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1562 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8)
buf1563 = buf1562[0]
buf1564 = buf1562[1]
del buf1562
# Source Nodes: [choose_qparams_per_token_asymmetric_128], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1565 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1558, torch.int8)
buf1566 = buf1565[0]
buf1567 = buf1565[1]
del buf1565
buf1568 = buf1481; del buf1481 # reuse
# Source Nodes: [max_19], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1568, 1, grid=grid(1), stream=stream0)
u18 = buf1568.item()
buf1569 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_380], Original ATen: [quantized_decomposed.quantize_per_token]
buf1570 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1560, buf1561, -128, 127, torch.int8)
buf1571 = buf1570
del buf1570
# Source Nodes: [input_381], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1572 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1571, buf1560, buf1561, -128, 127, torch.int8, torch.bfloat16)
del buf1560
del buf1561
del buf1571
buf1573 = buf1572
del buf1572
# Source Nodes: [w_dq_126], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1574 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg473_1, arg474_1, arg475_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg473_1
del arg474_1
del arg475_1
buf1575 = buf1574
del buf1574
buf1576 = reinterpret_tensor(buf1535, (1, 4096), (4096, 1), 0); del buf1535 # reuse
# Source Nodes: [c_126], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1573, buf1575, buf1576, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1573
del buf1575
# Source Nodes: [input_383], Original ATen: [quantized_decomposed.quantize_per_token]
buf1577 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1563, buf1564, -128, 127, torch.int8)
buf1578 = buf1577
del buf1577
# Source Nodes: [input_384], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1579 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1578, buf1563, buf1564, -128, 127, torch.int8, torch.bfloat16)
del buf1563
del buf1564
del buf1578
buf1580 = buf1579
del buf1579
# Source Nodes: [w_dq_127], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1581 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg476_1, arg477_1, arg478_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg476_1
del arg477_1
del arg478_1
buf1582 = buf1581
del buf1581
buf1583 = buf1504; del buf1504 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1580, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1582, (4096, 1024), (1, 4096), 0), out=buf1583)
del buf1580
del buf1582
# Source Nodes: [input_386], Original ATen: [quantized_decomposed.quantize_per_token]
buf1585 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1558, buf1566, buf1567, -128, 127, torch.int8)
del buf1558
buf1586 = buf1585
del buf1585
# Source Nodes: [input_387], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1587 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1586, buf1566, buf1567, -128, 127, torch.int8, torch.bfloat16)
del buf1566
del buf1567
del buf1586
buf1588 = buf1587
del buf1587
# Source Nodes: [w_dq_128], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1589 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg479_1, arg480_1, arg481_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg479_1
del arg480_1
del arg481_1
buf1590 = buf1589
del buf1589
buf1591 = buf1496; del buf1496 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1588, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1590, (4096, 1024), (1, 4096), 0), out=buf1591)
del buf1590
buf1593 = reinterpret_tensor(buf1588, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1588 # reuse
# Source Nodes: [output_36, setitem_36, setitem_37], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1583, arg482_1, buf1591, buf1576, arg483_1, arg484_1, buf1593, 4096, grid=grid(4096), stream=stream0)
del arg482_1
del buf1576
buf1594 = buf1507; del buf1507 # reuse
# Source Nodes: [output_36], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1594, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_36], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1595 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1593, arg483_1, arg484_1, buf1594, False)
del arg483_1
del arg484_1
del buf1593
buf1596 = buf1595[0]
del buf1595
# Source Nodes: [choose_qparams_per_token_asymmetric_129], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1600 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1596, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1601 = buf1600[0]
buf1602 = buf1600[1]
del buf1600
# Source Nodes: [input_389], Original ATen: [quantized_decomposed.quantize_per_token]
buf1603 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1596, (1, 1, 4096), (4096, 4096, 1), 0), buf1601, buf1602, -128, 127, torch.int8)
buf1604 = buf1603
del buf1603
# Source Nodes: [input_390], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1605 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1604, buf1601, buf1602, -128, 127, torch.int8, torch.bfloat16)
del buf1601
del buf1602
del buf1604
buf1606 = buf1605
del buf1605
# Source Nodes: [w_dq_129], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1607 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg485_1, arg486_1, arg487_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg485_1
del arg486_1
del arg487_1
buf1608 = buf1607
del buf1607
buf1609 = reinterpret_tensor(buf1596, (1, 4096), (4096, 1), 0); del buf1596 # reuse
# Source Nodes: [c_129], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1606, buf1608, buf1609, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1608
buf1611 = buf1606; del buf1606 # reuse
# Source Nodes: [add_131, h_19, mean_37, mul_244, mul_245, out_17, pow_38, rsqrt_37, x_fp32_37, x_normed_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1609, buf1523, buf1556, arg488_1, buf1611, 1, 4096, grid=grid(1), stream=stream0)
del arg488_1
# Source Nodes: [choose_qparams_per_token_asymmetric_130], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1612 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1611, torch.int8)
buf1613 = buf1612[0]
buf1614 = buf1612[1]
del buf1612
# Source Nodes: [choose_qparams_per_token_asymmetric_131], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1615 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1611, torch.int8)
buf1616 = buf1615[0]
buf1617 = buf1615[1]
del buf1615
# Source Nodes: [input_392], Original ATen: [quantized_decomposed.quantize_per_token]
buf1618 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1611, buf1613, buf1614, -128, 127, torch.int8)
buf1619 = buf1618
del buf1618
# Source Nodes: [input_393], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1620 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1619, buf1613, buf1614, -128, 127, torch.int8, torch.bfloat16)
del buf1613
del buf1614
del buf1619
buf1621 = buf1620
del buf1620
# Source Nodes: [w_dq_130], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1622 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg489_1, arg490_1, arg491_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg489_1
del arg490_1
del arg491_1
buf1623 = buf1622
del buf1622
buf1624 = reinterpret_tensor(buf1553, (1, 14336), (14336, 1), 0); del buf1553 # reuse
# Source Nodes: [c_130], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1621, buf1623, buf1624, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1623
# Source Nodes: [input_395], Original ATen: [quantized_decomposed.quantize_per_token]
buf1625 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1611, buf1616, buf1617, -128, 127, torch.int8)
buf1626 = buf1625
del buf1625
# Source Nodes: [input_396], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1627 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1626, buf1616, buf1617, -128, 127, torch.int8, torch.bfloat16)
del buf1616
del buf1617
del buf1626
buf1628 = buf1627
del buf1627
# Source Nodes: [w_dq_131], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1629 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg492_1, arg493_1, arg494_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg492_1
del arg493_1
del arg494_1
buf1630 = buf1629
del buf1629
buf1632 = buf1546; del buf1546 # reuse
# Source Nodes: [c_131, mul_246, silu_18], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1628, buf1630, buf1624, buf1632, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1624
del buf1630
# Source Nodes: [choose_qparams_per_token_asymmetric_132], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1633 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1632, torch.int8)
buf1634 = buf1633[0]
buf1635 = buf1633[1]
del buf1633
# Source Nodes: [input_398], Original ATen: [quantized_decomposed.quantize_per_token]
buf1636 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1632, buf1634, buf1635, -128, 127, torch.int8)
buf1637 = buf1636
del buf1636
# Source Nodes: [input_399], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1638 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1637, buf1634, buf1635, -128, 127, torch.int8, torch.bfloat16)
del buf1634
del buf1635
del buf1637
buf1639 = buf1638
del buf1638
# Source Nodes: [w_dq_132], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1640 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg495_1, arg496_1, arg497_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg495_1
del arg496_1
del arg497_1
buf1641 = buf1640
del buf1640
buf1642 = reinterpret_tensor(buf1628, (1, 4096), (4096, 1), 0); del buf1628 # reuse
# Source Nodes: [c_132], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1639, buf1641, buf1642, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1641
buf1644 = buf1611; del buf1611 # reuse
# Source Nodes: [add_133, h_19, mean_38, mul_247, out_17, out_18, pow_39, rsqrt_38, x_fp32_38, x_normed_38, y_19], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1609, buf1523, buf1556, buf1642, arg498_1, buf1644, 1, 4096, grid=grid(1), stream=stream0)
del arg498_1
# Source Nodes: [choose_qparams_per_token_asymmetric_133], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1645 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8)
buf1646 = buf1645[0]
buf1647 = buf1645[1]
del buf1645
# Source Nodes: [choose_qparams_per_token_asymmetric_134], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1648 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8)
buf1649 = buf1648[0]
buf1650 = buf1648[1]
del buf1648
# Source Nodes: [choose_qparams_per_token_asymmetric_135], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1651 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1644, torch.int8)
buf1652 = buf1651[0]
buf1653 = buf1651[1]
del buf1651
buf1654 = buf1568; del buf1568 # reuse
# Source Nodes: [max_20], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1654, 1, grid=grid(1), stream=stream0)
u19 = buf1654.item()
buf1655 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_401], Original ATen: [quantized_decomposed.quantize_per_token]
buf1656 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1646, buf1647, -128, 127, torch.int8)
buf1657 = buf1656
del buf1656
# Source Nodes: [input_402], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1658 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1657, buf1646, buf1647, -128, 127, torch.int8, torch.bfloat16)
del buf1646
del buf1647
del buf1657
buf1659 = buf1658
del buf1658
# Source Nodes: [w_dq_133], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1660 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg499_1, arg500_1, arg501_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg499_1
del arg500_1
del arg501_1
buf1661 = buf1660
del buf1660
buf1662 = reinterpret_tensor(buf1621, (1, 4096), (4096, 1), 0); del buf1621 # reuse
# Source Nodes: [c_133], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1659, buf1661, buf1662, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1659
del buf1661
# Source Nodes: [input_404], Original ATen: [quantized_decomposed.quantize_per_token]
buf1663 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1649, buf1650, -128, 127, torch.int8)
buf1664 = buf1663
del buf1663
# Source Nodes: [input_405], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1665 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1664, buf1649, buf1650, -128, 127, torch.int8, torch.bfloat16)
del buf1649
del buf1650
del buf1664
buf1666 = buf1665
del buf1665
# Source Nodes: [w_dq_134], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1667 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg502_1, arg503_1, arg504_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg502_1
del arg503_1
del arg504_1
buf1668 = buf1667
del buf1667
buf1669 = buf1591; del buf1591 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1666, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1668, (4096, 1024), (1, 4096), 0), out=buf1669)
del buf1666
del buf1668
# Source Nodes: [input_407], Original ATen: [quantized_decomposed.quantize_per_token]
buf1671 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1644, buf1652, buf1653, -128, 127, torch.int8)
del buf1644
buf1672 = buf1671
del buf1671
# Source Nodes: [input_408], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1673 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1672, buf1652, buf1653, -128, 127, torch.int8, torch.bfloat16)
del buf1652
del buf1653
del buf1672
buf1674 = buf1673
del buf1673
# Source Nodes: [w_dq_135], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1675 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg505_1, arg506_1, arg507_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg505_1
del arg506_1
del arg507_1
buf1676 = buf1675
del buf1675
buf1677 = buf1583; del buf1583 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1674, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1676, (4096, 1024), (1, 4096), 0), out=buf1677)
del buf1676
buf1679 = reinterpret_tensor(buf1674, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1674 # reuse
# Source Nodes: [output_38, setitem_38, setitem_39], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1669, arg508_1, buf1677, buf1662, arg509_1, arg510_1, buf1679, 4096, grid=grid(4096), stream=stream0)
del arg508_1
del buf1662
buf1680 = buf1594; del buf1594 # reuse
# Source Nodes: [output_38], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1680, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_38], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1681 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1679, arg509_1, arg510_1, buf1680, False)
del arg509_1
del arg510_1
del buf1679
buf1682 = buf1681[0]
del buf1681
# Source Nodes: [choose_qparams_per_token_asymmetric_136], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1686 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1682, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1687 = buf1686[0]
buf1688 = buf1686[1]
del buf1686
# Source Nodes: [input_410], Original ATen: [quantized_decomposed.quantize_per_token]
buf1689 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1682, (1, 1, 4096), (4096, 4096, 1), 0), buf1687, buf1688, -128, 127, torch.int8)
buf1690 = buf1689
del buf1689
# Source Nodes: [input_411], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1691 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1690, buf1687, buf1688, -128, 127, torch.int8, torch.bfloat16)
del buf1687
del buf1688
del buf1690
buf1692 = buf1691
del buf1691
# Source Nodes: [w_dq_136], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1693 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg511_1, arg512_1, arg513_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg511_1
del arg512_1
del arg513_1
buf1694 = buf1693
del buf1693
buf1695 = reinterpret_tensor(buf1682, (1, 4096), (4096, 1), 0); del buf1682 # reuse
# Source Nodes: [c_136], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1692, buf1694, buf1695, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1694
buf1696 = buf1523; del buf1523 # reuse
buf1698 = buf1692; del buf1692 # reuse
# Source Nodes: [add_138, h_19, h_20, mean_39, mul_257, mul_258, out_17, out_18, pow_40, rsqrt_39, x_fp32_39, x_normed_39], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1696, buf1695, buf1609, buf1556, buf1642, arg514_1, buf1698, 1, 4096, grid=grid(1), stream=stream0)
del arg514_1
del buf1556
del buf1609
del buf1642
del buf1695
# Source Nodes: [choose_qparams_per_token_asymmetric_137], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1699 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1698, torch.int8)
buf1700 = buf1699[0]
buf1701 = buf1699[1]
del buf1699
# Source Nodes: [choose_qparams_per_token_asymmetric_138], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1702 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1698, torch.int8)
buf1703 = buf1702[0]
buf1704 = buf1702[1]
del buf1702
# Source Nodes: [input_413], Original ATen: [quantized_decomposed.quantize_per_token]
buf1705 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1698, buf1700, buf1701, -128, 127, torch.int8)
buf1706 = buf1705
del buf1705
# Source Nodes: [input_414], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1707 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1706, buf1700, buf1701, -128, 127, torch.int8, torch.bfloat16)
del buf1700
del buf1701
del buf1706
buf1708 = buf1707
del buf1707
# Source Nodes: [w_dq_137], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1709 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg515_1, arg516_1, arg517_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg515_1
del arg516_1
del arg517_1
buf1710 = buf1709
del buf1709
buf1711 = reinterpret_tensor(buf1639, (1, 14336), (14336, 1), 0); del buf1639 # reuse
# Source Nodes: [c_137], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1708, buf1710, buf1711, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1710
# Source Nodes: [input_416], Original ATen: [quantized_decomposed.quantize_per_token]
buf1712 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1698, buf1703, buf1704, -128, 127, torch.int8)
buf1713 = buf1712
del buf1712
# Source Nodes: [input_417], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1714 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1713, buf1703, buf1704, -128, 127, torch.int8, torch.bfloat16)
del buf1703
del buf1704
del buf1713
buf1715 = buf1714
del buf1714
# Source Nodes: [w_dq_138], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1716 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg518_1, arg519_1, arg520_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg518_1
del arg519_1
del arg520_1
buf1717 = buf1716
del buf1716
buf1719 = buf1632; del buf1632 # reuse
# Source Nodes: [c_138, mul_259, silu_19], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1715, buf1717, buf1711, buf1719, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1711
del buf1717
# Source Nodes: [choose_qparams_per_token_asymmetric_139], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1720 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1719, torch.int8)
buf1721 = buf1720[0]
buf1722 = buf1720[1]
del buf1720
# Source Nodes: [input_419], Original ATen: [quantized_decomposed.quantize_per_token]
buf1723 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1719, buf1721, buf1722, -128, 127, torch.int8)
buf1724 = buf1723
del buf1723
# Source Nodes: [input_420], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1725 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1724, buf1721, buf1722, -128, 127, torch.int8, torch.bfloat16)
del buf1721
del buf1722
del buf1724
buf1726 = buf1725
del buf1725
# Source Nodes: [w_dq_139], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1727 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg521_1, arg522_1, arg523_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg521_1
del arg522_1
del arg523_1
buf1728 = buf1727
del buf1727
buf1729 = reinterpret_tensor(buf1715, (1, 4096), (4096, 1), 0); del buf1715 # reuse
# Source Nodes: [c_139], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1726, buf1728, buf1729, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1728
buf1731 = buf1698; del buf1698 # reuse
# Source Nodes: [add_140, mean_40, mul_260, out_19, pow_41, rsqrt_40, x_fp32_40, x_normed_40, y_20], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1696, buf1729, arg524_1, buf1731, 1, 4096, grid=grid(1), stream=stream0)
del arg524_1
# Source Nodes: [choose_qparams_per_token_asymmetric_140], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1732 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8)
buf1733 = buf1732[0]
buf1734 = buf1732[1]
del buf1732
# Source Nodes: [choose_qparams_per_token_asymmetric_141], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1735 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8)
buf1736 = buf1735[0]
buf1737 = buf1735[1]
del buf1735
# Source Nodes: [choose_qparams_per_token_asymmetric_142], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1738 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1731, torch.int8)
buf1739 = buf1738[0]
buf1740 = buf1738[1]
del buf1738
buf1741 = buf1654; del buf1654 # reuse
# Source Nodes: [max_21], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1741, 1, grid=grid(1), stream=stream0)
u20 = buf1741.item()
buf1742 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_422], Original ATen: [quantized_decomposed.quantize_per_token]
buf1743 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1733, buf1734, -128, 127, torch.int8)
buf1744 = buf1743
del buf1743
# Source Nodes: [input_423], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1745 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1744, buf1733, buf1734, -128, 127, torch.int8, torch.bfloat16)
del buf1733
del buf1734
del buf1744
buf1746 = buf1745
del buf1745
# Source Nodes: [w_dq_140], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1747 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg525_1, arg526_1, arg527_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg525_1
del arg526_1
del arg527_1
buf1748 = buf1747
del buf1747
buf1749 = reinterpret_tensor(buf1708, (1, 4096), (4096, 1), 0); del buf1708 # reuse
# Source Nodes: [c_140], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1746, buf1748, buf1749, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1746
del buf1748
# Source Nodes: [input_425], Original ATen: [quantized_decomposed.quantize_per_token]
buf1750 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1736, buf1737, -128, 127, torch.int8)
buf1751 = buf1750
del buf1750
# Source Nodes: [input_426], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1752 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1751, buf1736, buf1737, -128, 127, torch.int8, torch.bfloat16)
del buf1736
del buf1737
del buf1751
buf1753 = buf1752
del buf1752
# Source Nodes: [w_dq_141], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1754 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg528_1, arg529_1, arg530_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg528_1
del arg529_1
del arg530_1
buf1755 = buf1754
del buf1754
buf1756 = buf1677; del buf1677 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1753, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1755, (4096, 1024), (1, 4096), 0), out=buf1756)
del buf1753
del buf1755
# Source Nodes: [input_428], Original ATen: [quantized_decomposed.quantize_per_token]
buf1758 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1731, buf1739, buf1740, -128, 127, torch.int8)
del buf1731
buf1759 = buf1758
del buf1758
# Source Nodes: [input_429], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1760 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1759, buf1739, buf1740, -128, 127, torch.int8, torch.bfloat16)
del buf1739
del buf1740
del buf1759
buf1761 = buf1760
del buf1760
# Source Nodes: [w_dq_142], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1762 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg531_1, arg532_1, arg533_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg531_1
del arg532_1
del arg533_1
buf1763 = buf1762
del buf1762
buf1764 = buf1669; del buf1669 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1761, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1763, (4096, 1024), (1, 4096), 0), out=buf1764)
del buf1763
buf1766 = reinterpret_tensor(buf1761, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1761 # reuse
# Source Nodes: [output_40, setitem_40, setitem_41], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1756, arg534_1, buf1764, buf1749, arg535_1, arg536_1, buf1766, 4096, grid=grid(4096), stream=stream0)
del arg534_1
del buf1749
buf1767 = buf1680; del buf1680 # reuse
# Source Nodes: [output_40], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1767, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_40], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1768 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1766, arg535_1, arg536_1, buf1767, False)
del arg535_1
del arg536_1
del buf1766
buf1769 = buf1768[0]
del buf1768
# Source Nodes: [choose_qparams_per_token_asymmetric_143], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1773 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1769, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1774 = buf1773[0]
buf1775 = buf1773[1]
del buf1773
# Source Nodes: [input_431], Original ATen: [quantized_decomposed.quantize_per_token]
buf1776 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1769, (1, 1, 4096), (4096, 4096, 1), 0), buf1774, buf1775, -128, 127, torch.int8)
buf1777 = buf1776
del buf1776
# Source Nodes: [input_432], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1778 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1777, buf1774, buf1775, -128, 127, torch.int8, torch.bfloat16)
del buf1774
del buf1775
del buf1777
buf1779 = buf1778
del buf1778
# Source Nodes: [w_dq_143], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1780 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg537_1, arg538_1, arg539_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg537_1
del arg538_1
del arg539_1
buf1781 = buf1780
del buf1780
buf1782 = reinterpret_tensor(buf1769, (1, 4096), (4096, 1), 0); del buf1769 # reuse
# Source Nodes: [c_143], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1779, buf1781, buf1782, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1781
buf1784 = buf1779; del buf1779 # reuse
# Source Nodes: [add_145, h_21, mean_41, mul_270, mul_271, out_19, pow_42, rsqrt_41, x_fp32_41, x_normed_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1782, buf1696, buf1729, arg540_1, buf1784, 1, 4096, grid=grid(1), stream=stream0)
del arg540_1
# Source Nodes: [choose_qparams_per_token_asymmetric_144], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1785 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1784, torch.int8)
buf1786 = buf1785[0]
buf1787 = buf1785[1]
del buf1785
# Source Nodes: [choose_qparams_per_token_asymmetric_145], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1788 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1784, torch.int8)
buf1789 = buf1788[0]
buf1790 = buf1788[1]
del buf1788
# Source Nodes: [input_434], Original ATen: [quantized_decomposed.quantize_per_token]
buf1791 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1784, buf1786, buf1787, -128, 127, torch.int8)
buf1792 = buf1791
del buf1791
# Source Nodes: [input_435], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1793 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1792, buf1786, buf1787, -128, 127, torch.int8, torch.bfloat16)
del buf1786
del buf1787
del buf1792
buf1794 = buf1793
del buf1793
# Source Nodes: [w_dq_144], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1795 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg541_1, arg542_1, arg543_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg541_1
del arg542_1
del arg543_1
buf1796 = buf1795
del buf1795
buf1797 = reinterpret_tensor(buf1726, (1, 14336), (14336, 1), 0); del buf1726 # reuse
# Source Nodes: [c_144], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1794, buf1796, buf1797, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1796
# Source Nodes: [input_437], Original ATen: [quantized_decomposed.quantize_per_token]
buf1798 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1784, buf1789, buf1790, -128, 127, torch.int8)
buf1799 = buf1798
del buf1798
# Source Nodes: [input_438], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1800 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1799, buf1789, buf1790, -128, 127, torch.int8, torch.bfloat16)
del buf1789
del buf1790
del buf1799
buf1801 = buf1800
del buf1800
# Source Nodes: [w_dq_145], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1802 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg544_1, arg545_1, arg546_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg544_1
del arg545_1
del arg546_1
buf1803 = buf1802
del buf1802
buf1805 = buf1719; del buf1719 # reuse
# Source Nodes: [c_145, mul_272, silu_20], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1801, buf1803, buf1797, buf1805, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1797
del buf1803
# Source Nodes: [choose_qparams_per_token_asymmetric_146], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1806 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1805, torch.int8)
buf1807 = buf1806[0]
buf1808 = buf1806[1]
del buf1806
# Source Nodes: [input_440], Original ATen: [quantized_decomposed.quantize_per_token]
buf1809 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1805, buf1807, buf1808, -128, 127, torch.int8)
buf1810 = buf1809
del buf1809
# Source Nodes: [input_441], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1811 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1810, buf1807, buf1808, -128, 127, torch.int8, torch.bfloat16)
del buf1807
del buf1808
del buf1810
buf1812 = buf1811
del buf1811
# Source Nodes: [w_dq_146], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1813 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg547_1, arg548_1, arg549_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg547_1
del arg548_1
del arg549_1
buf1814 = buf1813
del buf1813
buf1815 = reinterpret_tensor(buf1801, (1, 4096), (4096, 1), 0); del buf1801 # reuse
# Source Nodes: [c_146], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1812, buf1814, buf1815, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1814
buf1817 = buf1784; del buf1784 # reuse
# Source Nodes: [add_147, h_21, mean_42, mul_273, out_19, out_20, pow_43, rsqrt_42, x_fp32_42, x_normed_42, y_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1782, buf1696, buf1729, buf1815, arg550_1, buf1817, 1, 4096, grid=grid(1), stream=stream0)
del arg550_1
# Source Nodes: [choose_qparams_per_token_asymmetric_147], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1818 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8)
buf1819 = buf1818[0]
buf1820 = buf1818[1]
del buf1818
# Source Nodes: [choose_qparams_per_token_asymmetric_148], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1821 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8)
buf1822 = buf1821[0]
buf1823 = buf1821[1]
del buf1821
# Source Nodes: [choose_qparams_per_token_asymmetric_149], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1824 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1817, torch.int8)
buf1825 = buf1824[0]
buf1826 = buf1824[1]
del buf1824
buf1827 = buf1741; del buf1741 # reuse
# Source Nodes: [max_22], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1827, 1, grid=grid(1), stream=stream0)
u21 = buf1827.item()
buf1828 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_443], Original ATen: [quantized_decomposed.quantize_per_token]
buf1829 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1819, buf1820, -128, 127, torch.int8)
buf1830 = buf1829
del buf1829
# Source Nodes: [input_444], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1831 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1830, buf1819, buf1820, -128, 127, torch.int8, torch.bfloat16)
del buf1819
del buf1820
del buf1830
buf1832 = buf1831
del buf1831
# Source Nodes: [w_dq_147], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1833 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg551_1, arg552_1, arg553_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg551_1
del arg552_1
del arg553_1
buf1834 = buf1833
del buf1833
buf1835 = reinterpret_tensor(buf1794, (1, 4096), (4096, 1), 0); del buf1794 # reuse
# Source Nodes: [c_147], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1832, buf1834, buf1835, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1832
del buf1834
# Source Nodes: [input_446], Original ATen: [quantized_decomposed.quantize_per_token]
buf1836 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1822, buf1823, -128, 127, torch.int8)
buf1837 = buf1836
del buf1836
# Source Nodes: [input_447], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1838 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1837, buf1822, buf1823, -128, 127, torch.int8, torch.bfloat16)
del buf1822
del buf1823
del buf1837
buf1839 = buf1838
del buf1838
# Source Nodes: [w_dq_148], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1840 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg554_1, arg555_1, arg556_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg554_1
del arg555_1
del arg556_1
buf1841 = buf1840
del buf1840
buf1842 = buf1764; del buf1764 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1839, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1841, (4096, 1024), (1, 4096), 0), out=buf1842)
del buf1839
del buf1841
# Source Nodes: [input_449], Original ATen: [quantized_decomposed.quantize_per_token]
buf1844 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1817, buf1825, buf1826, -128, 127, torch.int8)
del buf1817
buf1845 = buf1844
del buf1844
# Source Nodes: [input_450], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1846 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1845, buf1825, buf1826, -128, 127, torch.int8, torch.bfloat16)
del buf1825
del buf1826
del buf1845
buf1847 = buf1846
del buf1846
# Source Nodes: [w_dq_149], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1848 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg557_1, arg558_1, arg559_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg557_1
del arg558_1
del arg559_1
buf1849 = buf1848
del buf1848
buf1850 = buf1756; del buf1756 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1847, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1849, (4096, 1024), (1, 4096), 0), out=buf1850)
del buf1849
buf1852 = reinterpret_tensor(buf1847, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1847 # reuse
# Source Nodes: [output_42, setitem_42, setitem_43], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1842, arg560_1, buf1850, buf1835, arg561_1, arg562_1, buf1852, 4096, grid=grid(4096), stream=stream0)
del arg560_1
del buf1835
buf1853 = buf1767; del buf1767 # reuse
# Source Nodes: [output_42], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1853, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_42], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1854 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1852, arg561_1, arg562_1, buf1853, False)
del arg561_1
del arg562_1
del buf1852
buf1855 = buf1854[0]
del buf1854
# Source Nodes: [choose_qparams_per_token_asymmetric_150], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1859 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1855, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1860 = buf1859[0]
buf1861 = buf1859[1]
del buf1859
# Source Nodes: [input_452], Original ATen: [quantized_decomposed.quantize_per_token]
buf1862 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1855, (1, 1, 4096), (4096, 4096, 1), 0), buf1860, buf1861, -128, 127, torch.int8)
buf1863 = buf1862
del buf1862
# Source Nodes: [input_453], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1864 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1863, buf1860, buf1861, -128, 127, torch.int8, torch.bfloat16)
del buf1860
del buf1861
del buf1863
buf1865 = buf1864
del buf1864
# Source Nodes: [w_dq_150], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1866 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg563_1, arg564_1, arg565_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg563_1
del arg564_1
del arg565_1
buf1867 = buf1866
del buf1866
buf1868 = reinterpret_tensor(buf1855, (1, 4096), (4096, 1), 0); del buf1855 # reuse
# Source Nodes: [c_150], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1865, buf1867, buf1868, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1867
buf1869 = buf1696; del buf1696 # reuse
buf1871 = buf1865; del buf1865 # reuse
# Source Nodes: [add_152, h_21, h_22, mean_43, mul_283, mul_284, out_19, out_20, pow_44, rsqrt_43, x_fp32_43, x_normed_43], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf1869, buf1868, buf1782, buf1729, buf1815, arg566_1, buf1871, 1, 4096, grid=grid(1), stream=stream0)
del arg566_1
del buf1729
del buf1782
del buf1815
del buf1868
# Source Nodes: [choose_qparams_per_token_asymmetric_151], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1872 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1871, torch.int8)
buf1873 = buf1872[0]
buf1874 = buf1872[1]
del buf1872
# Source Nodes: [choose_qparams_per_token_asymmetric_152], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1875 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1871, torch.int8)
buf1876 = buf1875[0]
buf1877 = buf1875[1]
del buf1875
# Source Nodes: [input_455], Original ATen: [quantized_decomposed.quantize_per_token]
buf1878 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1871, buf1873, buf1874, -128, 127, torch.int8)
buf1879 = buf1878
del buf1878
# Source Nodes: [input_456], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1880 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1879, buf1873, buf1874, -128, 127, torch.int8, torch.bfloat16)
del buf1873
del buf1874
del buf1879
buf1881 = buf1880
del buf1880
# Source Nodes: [w_dq_151], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1882 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg567_1, arg568_1, arg569_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg567_1
del arg568_1
del arg569_1
buf1883 = buf1882
del buf1882
buf1884 = reinterpret_tensor(buf1812, (1, 14336), (14336, 1), 0); del buf1812 # reuse
# Source Nodes: [c_151], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1881, buf1883, buf1884, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1883
# Source Nodes: [input_458], Original ATen: [quantized_decomposed.quantize_per_token]
buf1885 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1871, buf1876, buf1877, -128, 127, torch.int8)
buf1886 = buf1885
del buf1885
# Source Nodes: [input_459], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1887 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1886, buf1876, buf1877, -128, 127, torch.int8, torch.bfloat16)
del buf1876
del buf1877
del buf1886
buf1888 = buf1887
del buf1887
# Source Nodes: [w_dq_152], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1889 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg570_1, arg571_1, arg572_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg570_1
del arg571_1
del arg572_1
buf1890 = buf1889
del buf1889
buf1892 = buf1805; del buf1805 # reuse
# Source Nodes: [c_152, mul_285, silu_21], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1888, buf1890, buf1884, buf1892, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1884
del buf1890
# Source Nodes: [choose_qparams_per_token_asymmetric_153], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1893 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1892, torch.int8)
buf1894 = buf1893[0]
buf1895 = buf1893[1]
del buf1893
# Source Nodes: [input_461], Original ATen: [quantized_decomposed.quantize_per_token]
buf1896 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1892, buf1894, buf1895, -128, 127, torch.int8)
buf1897 = buf1896
del buf1896
# Source Nodes: [input_462], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1898 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1897, buf1894, buf1895, -128, 127, torch.int8, torch.bfloat16)
del buf1894
del buf1895
del buf1897
buf1899 = buf1898
del buf1898
# Source Nodes: [w_dq_153], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1900 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg573_1, arg574_1, arg575_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg573_1
del arg574_1
del arg575_1
buf1901 = buf1900
del buf1900
buf1902 = reinterpret_tensor(buf1888, (1, 4096), (4096, 1), 0); del buf1888 # reuse
# Source Nodes: [c_153], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1899, buf1901, buf1902, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1901
buf1904 = buf1871; del buf1871 # reuse
# Source Nodes: [add_154, mean_44, mul_286, out_21, pow_45, rsqrt_44, x_fp32_44, x_normed_44, y_22], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf1869, buf1902, arg576_1, buf1904, 1, 4096, grid=grid(1), stream=stream0)
del arg576_1
# Source Nodes: [choose_qparams_per_token_asymmetric_154], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1905 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8)
buf1906 = buf1905[0]
buf1907 = buf1905[1]
del buf1905
# Source Nodes: [choose_qparams_per_token_asymmetric_155], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1908 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8)
buf1909 = buf1908[0]
buf1910 = buf1908[1]
del buf1908
# Source Nodes: [choose_qparams_per_token_asymmetric_156], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1911 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1904, torch.int8)
buf1912 = buf1911[0]
buf1913 = buf1911[1]
del buf1911
buf1914 = buf1827; del buf1827 # reuse
# Source Nodes: [max_23], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf1914, 1, grid=grid(1), stream=stream0)
u22 = buf1914.item()
buf1915 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_464], Original ATen: [quantized_decomposed.quantize_per_token]
buf1916 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1906, buf1907, -128, 127, torch.int8)
buf1917 = buf1916
del buf1916
# Source Nodes: [input_465], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1918 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1917, buf1906, buf1907, -128, 127, torch.int8, torch.bfloat16)
del buf1906
del buf1907
del buf1917
buf1919 = buf1918
del buf1918
# Source Nodes: [w_dq_154], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1920 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg577_1, arg578_1, arg579_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg577_1
del arg578_1
del arg579_1
buf1921 = buf1920
del buf1920
buf1922 = reinterpret_tensor(buf1881, (1, 4096), (4096, 1), 0); del buf1881 # reuse
# Source Nodes: [c_154], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1919, buf1921, buf1922, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1919
del buf1921
# Source Nodes: [input_467], Original ATen: [quantized_decomposed.quantize_per_token]
buf1923 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1909, buf1910, -128, 127, torch.int8)
buf1924 = buf1923
del buf1923
# Source Nodes: [input_468], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1925 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1924, buf1909, buf1910, -128, 127, torch.int8, torch.bfloat16)
del buf1909
del buf1910
del buf1924
buf1926 = buf1925
del buf1925
# Source Nodes: [w_dq_155], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1927 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg580_1, arg581_1, arg582_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg580_1
del arg581_1
del arg582_1
buf1928 = buf1927
del buf1927
buf1929 = buf1850; del buf1850 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1926, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1928, (4096, 1024), (1, 4096), 0), out=buf1929)
del buf1926
del buf1928
# Source Nodes: [input_470], Original ATen: [quantized_decomposed.quantize_per_token]
buf1931 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1904, buf1912, buf1913, -128, 127, torch.int8)
del buf1904
buf1932 = buf1931
del buf1931
# Source Nodes: [input_471], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1933 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1932, buf1912, buf1913, -128, 127, torch.int8, torch.bfloat16)
del buf1912
del buf1913
del buf1932
buf1934 = buf1933
del buf1933
# Source Nodes: [w_dq_156], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1935 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg583_1, arg584_1, arg585_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg583_1
del arg584_1
del arg585_1
buf1936 = buf1935
del buf1935
buf1937 = buf1842; del buf1842 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1934, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf1936, (4096, 1024), (1, 4096), 0), out=buf1937)
del buf1936
buf1939 = reinterpret_tensor(buf1934, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf1934 # reuse
# Source Nodes: [output_44, setitem_44, setitem_45], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf1929, arg586_1, buf1937, buf1922, arg587_1, arg588_1, buf1939, 4096, grid=grid(4096), stream=stream0)
del arg586_1
del buf1922
buf1940 = buf1853; del buf1853 # reuse
# Source Nodes: [output_44], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf1940, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_44], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf1941 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf1939, arg587_1, arg588_1, buf1940, False)
del arg587_1
del arg588_1
del buf1939
buf1942 = buf1941[0]
del buf1941
# Source Nodes: [choose_qparams_per_token_asymmetric_157], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1946 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf1942, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf1947 = buf1946[0]
buf1948 = buf1946[1]
del buf1946
# Source Nodes: [input_473], Original ATen: [quantized_decomposed.quantize_per_token]
buf1949 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf1942, (1, 1, 4096), (4096, 4096, 1), 0), buf1947, buf1948, -128, 127, torch.int8)
buf1950 = buf1949
del buf1949
# Source Nodes: [input_474], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1951 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1950, buf1947, buf1948, -128, 127, torch.int8, torch.bfloat16)
del buf1947
del buf1948
del buf1950
buf1952 = buf1951
del buf1951
# Source Nodes: [w_dq_157], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1953 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg589_1, arg590_1, arg591_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg589_1
del arg590_1
del arg591_1
buf1954 = buf1953
del buf1953
buf1955 = reinterpret_tensor(buf1942, (1, 4096), (4096, 1), 0); del buf1942 # reuse
# Source Nodes: [c_157], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf1952, buf1954, buf1955, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1954
# Fused reduction for source nodes add_159 / mean_45 / pow_46 / rsqrt_45 /
# x_normed_45. Judging by the node names, this is an RMSNorm-style normalization.
# The trailing `1, 4096` are presumably xnumel / rnumel: one row reduced over
# 4096 elements.
# Inputs: buf1955 (c_157 output), buf1869, buf1902, and arg592_1 (presumably the
# norm weight — TODO confirm). Output: buf1957, which takes over the dead
# buf1952 allocation.
# None of the inputs is freed here. The residual appears to be kept as separate
# terms rather than materialized — TODO confirm in the kernel source.
buf1957 = buf1952; del buf1952 # reuse
# Source Nodes: [add_159, h_23, mean_45, mul_296, mul_297, out_21, pow_46, rsqrt_45, x_fp32_45, x_normed_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf1955, buf1869, buf1902, arg592_1, buf1957, 1, 4096, grid=grid(1), stream=stream0)
del arg592_1
# Per-token int8 qparams for the normalized activation buf1957, computed twice:
# _158 feeds c_158 below, and _159 (buf1962 / buf1963) feeds c_159 later.
# NOTE(review): both calls take identical arguments (buf1957, torch.int8). If
# the op is deterministic they produce identical scales/zero points, and the
# matching quantize/dequantize round trip is also repeated. This is redundant
# work that Inductor did not CSE. Fix it in the model/quantizer source, not in
# this generated file.
# Source Nodes: [choose_qparams_per_token_asymmetric_158], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1958 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1957, torch.int8)
buf1959 = buf1958[0]
buf1960 = buf1958[1]
del buf1958
# Source Nodes: [choose_qparams_per_token_asymmetric_159], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1961 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1957, torch.int8)
buf1962 = buf1961[0]
buf1963 = buf1961[1]
del buf1961
# Quantized linear c_158: quantize buf1957 to [-128, 127] and dequantize it back
# to bfloat16. Dequantize the weight arg593_1 ([-8, 7], group size 256) to
# bfloat16. Matmul into buf1970, a (1, 14336) view over buf1899's storage.
# buf1967 is not freed here; its storage is reused later in this function.
# Source Nodes: [input_476], Original ATen: [quantized_decomposed.quantize_per_token]
buf1964 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1957, buf1959, buf1960, -128, 127, torch.int8)
buf1965 = buf1964
del buf1964
# Source Nodes: [input_477], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1966 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1965, buf1959, buf1960, -128, 127, torch.int8, torch.bfloat16)
del buf1959
del buf1960
del buf1965
buf1967 = buf1966
del buf1966
# Source Nodes: [w_dq_158], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1968 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg593_1, arg594_1, arg595_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg593_1
del arg594_1
del arg595_1
buf1969 = buf1968
del buf1968
buf1970 = reinterpret_tensor(buf1899, (1, 14336), (14336, 1), 0); del buf1899 # reuse
# Source Nodes: [c_158], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf1967, buf1969, buf1970, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1969
# Quantized linear c_159 fused with silu_22 / mul_298.
# Quantize buf1957 with the second qparam pair (buf1962 / buf1963) and
# dequantize it to bfloat16. Dequantize the weight arg596_1 ([-8, 7], group size
# 256) to bfloat16.
# triton_tem_fused_mm_mul_silu_7 then computes buf1974 @ weight and combines it
# with the c_158 result buf1970 via silu and mul, writing (1, 14336) into
# buf1978 (reusing buf1892). Which operand gets the silu is defined in the
# kernel source, not visible here.
# Source Nodes: [input_479], Original ATen: [quantized_decomposed.quantize_per_token]
buf1971 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1957, buf1962, buf1963, -128, 127, torch.int8)
buf1972 = buf1971
del buf1971
# Source Nodes: [input_480], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1973 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1972, buf1962, buf1963, -128, 127, torch.int8, torch.bfloat16)
del buf1962
del buf1963
del buf1972
buf1974 = buf1973
del buf1973
# Source Nodes: [w_dq_159], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1975 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg596_1, arg597_1, arg598_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg596_1
del arg597_1
del arg598_1
buf1976 = buf1975
del buf1975
buf1978 = buf1892; del buf1892 # reuse
# Source Nodes: [c_159, mul_298, silu_22], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf1974, buf1976, buf1970, buf1978, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf1970
del buf1976
# Quantized linear c_160 on the (1, 14336) activation buf1978:
# - per-token int8 qparams, then a quantize/dequantize round trip to bfloat16;
# - dequantize the weight arg599_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf1988, a (1, 4096) view over buf1974's storage.
# buf1978 and buf1985 stay live here; their storage is reused later.
# Source Nodes: [choose_qparams_per_token_asymmetric_160], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1979 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1978, torch.int8)
buf1980 = buf1979[0]
buf1981 = buf1979[1]
del buf1979
# Source Nodes: [input_482], Original ATen: [quantized_decomposed.quantize_per_token]
buf1982 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1978, buf1980, buf1981, -128, 127, torch.int8)
buf1983 = buf1982
del buf1982
# Source Nodes: [input_483], Original ATen: [quantized_decomposed.dequantize_per_token]
buf1984 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf1983, buf1980, buf1981, -128, 127, torch.int8, torch.bfloat16)
del buf1980
del buf1981
del buf1983
buf1985 = buf1984
del buf1984
# Source Nodes: [w_dq_160], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf1986 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg599_1, arg600_1, arg601_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg599_1
del arg600_1
del arg601_1
buf1987 = buf1986
del buf1986
buf1988 = reinterpret_tensor(buf1974, (1, 4096), (4096, 1), 0); del buf1974 # reuse
# Source Nodes: [c_160], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf1985, buf1987, buf1988, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf1987
# Fused reduction for source nodes add_161 / mean_46 / rsqrt_46 / x_normed_46
# (RMSNorm-style, judging by the names). One row of 4096 elements, presumably.
# Inputs: buf1955, buf1869, buf1902, buf1988 (c_160 output), and arg602_1
# (presumably the norm weight — TODO confirm).
# Output: buf1990, which reuses buf1957's allocation.
# Then per-token int8 qparams for buf1990, computed three times: _161 feeds
# c_161, _162 feeds the w_dq_162 matmul, and _163 feeds the w_dq_163 matmul.
# NOTE(review): all three calls take identical arguments (buf1990, torch.int8).
# If the op is deterministic, two of them, and the matching quantize/dequantize
# round trips, are redundant. Fix upstream in the model/quantizer, not in this
# generated file.
buf1990 = buf1957; del buf1957 # reuse
# Source Nodes: [add_161, h_23, mean_46, mul_299, out_21, out_22, pow_47, rsqrt_46, x_fp32_46, x_normed_46, y_23], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf1955, buf1869, buf1902, buf1988, arg602_1, buf1990, 1, 4096, grid=grid(1), stream=stream0)
del arg602_1
# Source Nodes: [choose_qparams_per_token_asymmetric_161], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1991 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8)
buf1992 = buf1991[0]
buf1993 = buf1991[1]
del buf1991
# Source Nodes: [choose_qparams_per_token_asymmetric_162], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1994 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8)
buf1995 = buf1994[0]
buf1996 = buf1994[1]
del buf1994
# Source Nodes: [choose_qparams_per_token_asymmetric_163], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf1997 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf1990, torch.int8)
buf1998 = buf1997[0]
buf1999 = buf1997[1]
del buf1997
# Source node max_24: triton_poi_fused_max_1 reduces arg3_1 into the
# one-element buf2000 (reusing buf1914). `.item()` then copies the result to the
# host as u23.
# Because buf2000 is written by a Triton kernel on stream0, `.item()` must wait
# for the device: this is a host/device sync point.
# buf2001 = None is a placeholder. Neither it nor u23 is referenced in the rest
# of this visible chunk. (arg3_1 presumably holds input positions — TODO confirm.)
# NOTE(review): the same kernel over the same arg3_1 runs again before later
# sections (max_25, max_26 in this chunk). If arg3_1 is not mutated in between,
# the value and the sync could be computed once, upstream in the model source.
# NOTE(review): torch.cuda.set_device(0) directly inside _DeviceGuard(0) appears
# redundant with the guard.
# Indentation is lost in this copy; in the generated file the statements after
# `with` are its nested body.
buf2000 = buf1914; del buf1914 # reuse
# Source Nodes: [max_24], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2000, 1, grid=grid(1), stream=stream0)
u23 = buf2000.item()
buf2001 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Quantized linear c_161 on the normalized activation buf1990, using the
# _161 qparams (buf1992 / buf1993):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg603_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2008, a (1, 4096) view over buf1967's storage.
# Source Nodes: [input_485], Original ATen: [quantized_decomposed.quantize_per_token]
buf2002 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1992, buf1993, -128, 127, torch.int8)
buf2003 = buf2002
del buf2002
# Source Nodes: [input_486], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2004 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2003, buf1992, buf1993, -128, 127, torch.int8, torch.bfloat16)
del buf1992
del buf1993
del buf2003
buf2005 = buf2004
del buf2004
# Source Nodes: [w_dq_161], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2006 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg603_1, arg604_1, arg605_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg603_1
del arg604_1
del arg605_1
buf2007 = buf2006
del buf2006
buf2008 = reinterpret_tensor(buf1967, (1, 4096), (4096, 1), 0); del buf1967 # reuse
# Source Nodes: [c_161], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2005, buf2007, buf2008, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2005
del buf2007
# Quantized 4096 -> 1024 projection on buf1990, using the _162 qparams
# (buf1995 / buf1996):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg606_1 ([-8, 7], group size 256) to bfloat16;
# - extern mm of x (1, 4096) against the weight viewed as (4096, 1024) with
#   strides (1, 4096), i.e. the transpose of a row-major (1024, 4096) buffer.
# The result goes into buf2015 (reusing buf1937).
# Source Nodes: [input_488], Original ATen: [quantized_decomposed.quantize_per_token]
buf2009 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1995, buf1996, -128, 127, torch.int8)
buf2010 = buf2009
del buf2009
# Source Nodes: [input_489], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2011 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2010, buf1995, buf1996, -128, 127, torch.int8, torch.bfloat16)
del buf1995
del buf1996
del buf2010
buf2012 = buf2011
del buf2011
# Source Nodes: [w_dq_162], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2013 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg606_1, arg607_1, arg608_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg606_1
del arg607_1
del arg608_1
buf2014 = buf2013
del buf2013
buf2015 = buf1937; del buf1937 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2012, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2014, (4096, 1024), (1, 4096), 0), out=buf2015)
del buf2012
del buf2014
# Quantized 4096 -> 1024 projection on buf1990, using the _163 qparams
# (buf1998 / buf1999). This is the last use of buf1990, so it is freed right
# after quantization.
# - dequantize the weight arg609_1 ([-8, 7], group size 256) to bfloat16;
# - extern mm against the weight viewed as its (4096, 1024) transpose, writing
#   into buf2023 (reusing buf1929).
# buf2020 is kept alive; its storage is reinterpreted for the attention query
# right after this block.
# Source Nodes: [input_491], Original ATen: [quantized_decomposed.quantize_per_token]
buf2017 = torch.ops.quantized_decomposed.quantize_per_token.default(buf1990, buf1998, buf1999, -128, 127, torch.int8)
del buf1990
buf2018 = buf2017
del buf2017
# Source Nodes: [input_492], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2019 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2018, buf1998, buf1999, -128, 127, torch.int8, torch.bfloat16)
del buf1998
del buf1999
del buf2018
buf2020 = buf2019
del buf2019
# Source Nodes: [w_dq_163], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2021 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg609_1, arg610_1, arg611_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg609_1
del arg610_1
del arg611_1
buf2022 = buf2021
del buf2021
buf2023 = buf1929; del buf1929 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2020, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2022, (4096, 1024), (1, 4096), 0), out=buf2023)
del buf2022
# Attention for source node output_46.
# index_put_3 (source nodes setitem_46 / setitem_47 plus an SDPA prologue) reads
# arg3_1, the two 1024-wide projections buf2015 / buf2023, arg612_1, and the
# 4096-wide projection buf2008. It writes 4096 elements into buf2025, which
# reuses buf2020's storage viewed as (1, 32, 1, 128).
# Given the index_put nodes, and that arg613_1 / arg614_1 are passed as
# key / value below, it presumably scatters the new K/V into those buffers at
# positions taken from arg3_1 — TODO confirm in the kernel source. arg612_1's
# role is not visible here.
# Kernel _4 fills the 65536-element buf2026 (reusing buf1940) from arg3_1 and
# arg2_1. Since buf2026 is passed as attn_bias, it is presumably a mask.
# NOTE(review): _4 is re-run with the same (arg3_1, arg2_1) inputs before each
# attention call in this chunk; if those are unchanged, the result could be
# reused.
# SDPA arguments: query buf2025, key arg613_1, value arg614_1, attn_bias buf2026,
# compute_log_sumexp=False. Element [0] of the result is the attention output.
buf2025 = reinterpret_tensor(buf2020, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2020 # reuse
# Source Nodes: [output_46, setitem_46, setitem_47], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2015, arg612_1, buf2023, buf2008, arg613_1, arg614_1, buf2025, 4096, grid=grid(4096), stream=stream0)
del arg612_1
del buf2008
buf2026 = buf1940; del buf1940 # reuse
# Source Nodes: [output_46], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2026, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_46], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2027 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2025, arg613_1, arg614_1, buf2026, False)
del arg613_1
del arg614_1
del buf2025
buf2028 = buf2027[0]
del buf2027
# Quantized linear c_164 on the attention output buf2028 (viewed as (1, 1, 4096)):
# - per-token int8 qparams, then a quantize/dequantize round trip to bfloat16;
# - dequantize the weight arg615_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2041, a (1, 4096) view over buf2028's storage. buf2028's
#   values are no longer needed by then.
# buf2038 stays live; its allocation is reused by the next fused norm.
# Source Nodes: [choose_qparams_per_token_asymmetric_164], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2032 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2028, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2033 = buf2032[0]
buf2034 = buf2032[1]
del buf2032
# Source Nodes: [input_494], Original ATen: [quantized_decomposed.quantize_per_token]
buf2035 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2028, (1, 1, 4096), (4096, 4096, 1), 0), buf2033, buf2034, -128, 127, torch.int8)
buf2036 = buf2035
del buf2035
# Source Nodes: [input_495], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2037 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2036, buf2033, buf2034, -128, 127, torch.int8, torch.bfloat16)
del buf2033
del buf2034
del buf2036
buf2038 = buf2037
del buf2037
# Source Nodes: [w_dq_164], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2039 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg615_1, arg616_1, arg617_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg615_1
del arg616_1
del arg617_1
buf2040 = buf2039
del buf2039
buf2041 = reinterpret_tensor(buf2028, (1, 4096), (4096, 1), 0); del buf2028 # reuse
# Source Nodes: [c_164], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2038, buf2040, buf2041, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2040
# Fused reduction for source nodes add_166 / h_24 / mean_47 / rsqrt_47 /
# x_normed_47 (RMSNorm-style, judging by the names).
# buf2042 takes over buf1869 and is passed first. buf2041, buf1955, buf1902 and
# buf1988 are freed right after this call, while buf2042 stays live and is read
# by later norms. So the kernel presumably folds those residual terms into
# buf2042 in place — TODO confirm in the kernel source.
# arg618_1 is presumably the norm weight — TODO confirm. The normalized output
# goes to buf2044, which reuses buf2038's allocation.
# Then per-token int8 qparams for buf2044, computed twice: _165 feeds c_165 and
# _166 feeds c_166.
# NOTE(review): both calls take identical arguments (buf2044, torch.int8). If
# the op is deterministic, one call and its matching quantize/dequantize round
# trip are redundant. Fix upstream, not in this generated file.
buf2042 = buf1869; del buf1869 # reuse
buf2044 = buf2038; del buf2038 # reuse
# Source Nodes: [add_166, h_23, h_24, mean_47, mul_309, mul_310, out_21, out_22, pow_48, rsqrt_47, x_fp32_47, x_normed_47], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2042, buf2041, buf1955, buf1902, buf1988, arg618_1, buf2044, 1, 4096, grid=grid(1), stream=stream0)
del arg618_1
del buf1902
del buf1955
del buf1988
del buf2041
# Source Nodes: [choose_qparams_per_token_asymmetric_165], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2045 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2044, torch.int8)
buf2046 = buf2045[0]
buf2047 = buf2045[1]
del buf2045
# Source Nodes: [choose_qparams_per_token_asymmetric_166], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2048 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2044, torch.int8)
buf2049 = buf2048[0]
buf2050 = buf2048[1]
del buf2048
# Quantized linear c_165 on buf2044, using the _165 qparams (buf2046 / buf2047):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg619_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2057, a (1, 14336) view over buf1985's storage.
# buf2054 stays live; its storage is reused later.
# Source Nodes: [input_497], Original ATen: [quantized_decomposed.quantize_per_token]
buf2051 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2044, buf2046, buf2047, -128, 127, torch.int8)
buf2052 = buf2051
del buf2051
# Source Nodes: [input_498], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2053 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2052, buf2046, buf2047, -128, 127, torch.int8, torch.bfloat16)
del buf2046
del buf2047
del buf2052
buf2054 = buf2053
del buf2053
# Source Nodes: [w_dq_165], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2055 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg619_1, arg620_1, arg621_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg619_1
del arg620_1
del arg621_1
buf2056 = buf2055
del buf2055
buf2057 = reinterpret_tensor(buf1985, (1, 14336), (14336, 1), 0); del buf1985 # reuse
# Source Nodes: [c_165], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2054, buf2056, buf2057, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2056
# Quantized linear c_166 fused with silu_23 / mul_311.
# Quantize buf2044 with the _166 qparams (buf2049 / buf2050) and dequantize it
# to bfloat16. Dequantize the weight arg622_1 ([-8, 7], group size 256) to
# bfloat16.
# The fused kernel computes buf2061 @ weight and combines it with the c_165
# result buf2057 via silu and mul, writing (1, 14336) into buf2065 (reusing
# buf1978). Which operand gets the silu is defined in the kernel source.
# Source Nodes: [input_500], Original ATen: [quantized_decomposed.quantize_per_token]
buf2058 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2044, buf2049, buf2050, -128, 127, torch.int8)
buf2059 = buf2058
del buf2058
# Source Nodes: [input_501], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2060 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2059, buf2049, buf2050, -128, 127, torch.int8, torch.bfloat16)
del buf2049
del buf2050
del buf2059
buf2061 = buf2060
del buf2060
# Source Nodes: [w_dq_166], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2062 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg622_1, arg623_1, arg624_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg622_1
del arg623_1
del arg624_1
buf2063 = buf2062
del buf2062
buf2065 = buf1978; del buf1978 # reuse
# Source Nodes: [c_166, mul_311, silu_23], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2061, buf2063, buf2057, buf2065, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2057
del buf2063
# Quantized linear c_167 on the (1, 14336) activation buf2065:
# - per-token int8 qparams, then a quantize/dequantize round trip to bfloat16;
# - dequantize the weight arg625_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2075, a (1, 4096) view over buf2061's storage.
# Source Nodes: [choose_qparams_per_token_asymmetric_167], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2066 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2065, torch.int8)
buf2067 = buf2066[0]
buf2068 = buf2066[1]
del buf2066
# Source Nodes: [input_503], Original ATen: [quantized_decomposed.quantize_per_token]
buf2069 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2065, buf2067, buf2068, -128, 127, torch.int8)
buf2070 = buf2069
del buf2069
# Source Nodes: [input_504], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2071 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2070, buf2067, buf2068, -128, 127, torch.int8, torch.bfloat16)
del buf2067
del buf2068
del buf2070
buf2072 = buf2071
del buf2071
# Source Nodes: [w_dq_167], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2073 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg625_1, arg626_1, arg627_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg625_1
del arg626_1
del arg627_1
buf2074 = buf2073
del buf2073
buf2075 = reinterpret_tensor(buf2061, (1, 4096), (4096, 1), 0); del buf2061 # reuse
# Source Nodes: [c_167], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2072, buf2074, buf2075, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2074
# Fused reduction for source nodes add_168 / mean_48 / rsqrt_48 / x_normed_48
# (RMSNorm-style, judging by the names).
# Inputs: buf2042, buf2075 (c_167 output), and arg628_1 (presumably the norm
# weight — TODO confirm). Output: buf2077, which reuses buf2044's allocation.
# Then per-token int8 qparams for buf2077, computed three times: _168 feeds
# c_168, _169 feeds the w_dq_169 matmul, and _170 feeds the w_dq_170 matmul.
# NOTE(review): all three calls take identical arguments (buf2077, torch.int8).
# If the op is deterministic, two of them, and the matching quantize/dequantize
# round trips, are redundant. Fix upstream in the model/quantizer.
buf2077 = buf2044; del buf2044 # reuse
# Source Nodes: [add_168, mean_48, mul_312, out_23, pow_49, rsqrt_48, x_fp32_48, x_normed_48, y_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2042, buf2075, arg628_1, buf2077, 1, 4096, grid=grid(1), stream=stream0)
del arg628_1
# Source Nodes: [choose_qparams_per_token_asymmetric_168], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2078 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8)
buf2079 = buf2078[0]
buf2080 = buf2078[1]
del buf2078
# Source Nodes: [choose_qparams_per_token_asymmetric_169], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2081 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8)
buf2082 = buf2081[0]
buf2083 = buf2081[1]
del buf2081
# Source Nodes: [choose_qparams_per_token_asymmetric_170], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2084 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2077, torch.int8)
buf2085 = buf2084[0]
buf2086 = buf2084[1]
del buf2084
# Source node max_25: triton_poi_fused_max_1 reduces arg3_1 into the
# one-element buf2087 (reusing buf2000). `.item()` copies the result to the host
# as u24, which forces a host/device sync on stream0's work so far.
# buf2088 = None is a placeholder. Neither it nor u24 is referenced in the rest
# of this visible chunk.
# NOTE(review): the same kernel with the same arg3_1 input already ran for max_24.
# If arg3_1 is unchanged, recomputing it, and syncing again, is avoidable
# upstream in the model source.
# NOTE(review): torch.cuda.set_device(0) inside _DeviceGuard(0) appears redundant.
# Indentation is lost in this copy; in the generated file the statements after
# `with` are its nested body.
buf2087 = buf2000; del buf2000 # reuse
# Source Nodes: [max_25], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2087, 1, grid=grid(1), stream=stream0)
u24 = buf2087.item()
buf2088 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Quantized linear c_168 on buf2077, using the _168 qparams (buf2079 / buf2080):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg629_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2095, a (1, 4096) view over buf2054's storage.
# Source Nodes: [input_506], Original ATen: [quantized_decomposed.quantize_per_token]
buf2089 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2079, buf2080, -128, 127, torch.int8)
buf2090 = buf2089
del buf2089
# Source Nodes: [input_507], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2091 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2090, buf2079, buf2080, -128, 127, torch.int8, torch.bfloat16)
del buf2079
del buf2080
del buf2090
buf2092 = buf2091
del buf2091
# Source Nodes: [w_dq_168], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2093 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg629_1, arg630_1, arg631_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg629_1
del arg630_1
del arg631_1
buf2094 = buf2093
del buf2093
buf2095 = reinterpret_tensor(buf2054, (1, 4096), (4096, 1), 0); del buf2054 # reuse
# Source Nodes: [c_168], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2092, buf2094, buf2095, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2092
del buf2094
# Quantized 4096 -> 1024 projection on buf2077, using the _169 qparams
# (buf2082 / buf2083):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg632_1 ([-8, 7], group size 256) to bfloat16;
# - extern mm against the weight viewed as its (4096, 1024) transpose, strides
#   (1, 4096), writing into buf2102 (reusing buf2023).
# Source Nodes: [input_509], Original ATen: [quantized_decomposed.quantize_per_token]
buf2096 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2082, buf2083, -128, 127, torch.int8)
buf2097 = buf2096
del buf2096
# Source Nodes: [input_510], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2098 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2097, buf2082, buf2083, -128, 127, torch.int8, torch.bfloat16)
del buf2082
del buf2083
del buf2097
buf2099 = buf2098
del buf2098
# Source Nodes: [w_dq_169], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2100 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg632_1, arg633_1, arg634_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg632_1
del arg633_1
del arg634_1
buf2101 = buf2100
del buf2100
buf2102 = buf2023; del buf2023 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2099, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2101, (4096, 1024), (1, 4096), 0), out=buf2102)
del buf2099
del buf2101
# Quantized 4096 -> 1024 projection on buf2077, using the _170 qparams
# (buf2085 / buf2086). This is the last use of buf2077, so it is freed right
# after quantization.
# - dequantize the weight arg635_1 ([-8, 7], group size 256) to bfloat16;
# - extern mm against the weight viewed as its (4096, 1024) transpose, writing
#   into buf2110 (reusing buf2015).
# buf2107 is kept alive; its storage is reinterpreted for the attention query
# right after this block.
# Source Nodes: [input_512], Original ATen: [quantized_decomposed.quantize_per_token]
buf2104 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2077, buf2085, buf2086, -128, 127, torch.int8)
del buf2077
buf2105 = buf2104
del buf2104
# Source Nodes: [input_513], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2106 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2105, buf2085, buf2086, -128, 127, torch.int8, torch.bfloat16)
del buf2085
del buf2086
del buf2105
buf2107 = buf2106
del buf2106
# Source Nodes: [w_dq_170], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2108 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg635_1, arg636_1, arg637_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg635_1
del arg636_1
del arg637_1
buf2109 = buf2108
del buf2108
buf2110 = buf2015; del buf2015 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2107, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2109, (4096, 1024), (1, 4096), 0), out=buf2110)
del buf2109
# Attention for source node output_48.
# index_put_3 (setitem_48 / setitem_49 plus an SDPA prologue) reads arg3_1, the
# 1024-wide projections buf2102 / buf2110, arg638_1, and the 4096-wide
# projection buf2095. It writes 4096 elements into buf2112, which reuses
# buf2107's storage viewed as (1, 32, 1, 128).
# Given the index_put nodes, and that arg639_1 / arg640_1 are passed as
# key / value below, it presumably scatters the new K/V into them at positions
# from arg3_1 — TODO confirm in the kernel source.
# Kernel _4 refills the 65536-element attn_bias buffer buf2113 (reusing buf2026)
# from arg3_1 and arg2_1.
# NOTE(review): these are the same inputs as the previous attention call's mask
# kernel; if they are unchanged, this recomputes identical data.
# SDPA arguments: query buf2112, key arg639_1, value arg640_1, attn_bias buf2113,
# compute_log_sumexp=False. Element [0] of the result is the attention output.
buf2112 = reinterpret_tensor(buf2107, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2107 # reuse
# Source Nodes: [output_48, setitem_48, setitem_49], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2102, arg638_1, buf2110, buf2095, arg639_1, arg640_1, buf2112, 4096, grid=grid(4096), stream=stream0)
del arg638_1
del buf2095
buf2113 = buf2026; del buf2026 # reuse
# Source Nodes: [output_48], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2113, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_48], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2114 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2112, arg639_1, arg640_1, buf2113, False)
del arg639_1
del arg640_1
del buf2112
buf2115 = buf2114[0]
del buf2114
# Quantized linear c_171 on the attention output buf2115 (viewed as (1, 1, 4096)):
# - per-token int8 qparams, then a quantize/dequantize round trip to bfloat16;
# - dequantize the weight arg641_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2128, a (1, 4096) view over buf2115's storage.
# buf2125 stays live; its allocation is reused by the next fused norm.
# Source Nodes: [choose_qparams_per_token_asymmetric_171], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2119 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2115, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2120 = buf2119[0]
buf2121 = buf2119[1]
del buf2119
# Source Nodes: [input_515], Original ATen: [quantized_decomposed.quantize_per_token]
buf2122 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2115, (1, 1, 4096), (4096, 4096, 1), 0), buf2120, buf2121, -128, 127, torch.int8)
buf2123 = buf2122
del buf2122
# Source Nodes: [input_516], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2124 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2123, buf2120, buf2121, -128, 127, torch.int8, torch.bfloat16)
del buf2120
del buf2121
del buf2123
buf2125 = buf2124
del buf2124
# Source Nodes: [w_dq_171], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2126 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg641_1, arg642_1, arg643_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg641_1
del arg642_1
del arg643_1
buf2127 = buf2126
del buf2126
buf2128 = reinterpret_tensor(buf2115, (1, 4096), (4096, 1), 0); del buf2115 # reuse
# Source Nodes: [c_171], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2125, buf2127, buf2128, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2127
# Fused reduction for source nodes add_173 / h_25 / mean_49 / rsqrt_49 /
# x_normed_49 (RMSNorm-style, judging by the names).
# Inputs: buf2128 (c_171 output), buf2042, buf2075, and arg644_1 (presumably the
# norm weight — TODO confirm). Output: buf2130, which reuses buf2125's
# allocation. The inputs are not freed here.
# Then per-token int8 qparams for buf2130, computed twice: _172 feeds c_172 and
# _173 feeds c_173.
# NOTE(review): both calls take identical arguments (buf2130, torch.int8). If
# the op is deterministic, one call and its quantize/dequantize round trip are
# redundant. Fix upstream, not in this generated file.
buf2130 = buf2125; del buf2125 # reuse
# Source Nodes: [add_173, h_25, mean_49, mul_322, mul_323, out_23, pow_50, rsqrt_49, x_fp32_49, x_normed_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2128, buf2042, buf2075, arg644_1, buf2130, 1, 4096, grid=grid(1), stream=stream0)
del arg644_1
# Source Nodes: [choose_qparams_per_token_asymmetric_172], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2131 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2130, torch.int8)
buf2132 = buf2131[0]
buf2133 = buf2131[1]
del buf2131
# Source Nodes: [choose_qparams_per_token_asymmetric_173], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2134 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2130, torch.int8)
buf2135 = buf2134[0]
buf2136 = buf2134[1]
del buf2134
# Quantized linear c_172 on buf2130, using the _172 qparams (buf2132 / buf2133):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg645_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2143, a (1, 14336) view over buf2072's storage.
# buf2140 stays live; its storage is reused later.
# Source Nodes: [input_518], Original ATen: [quantized_decomposed.quantize_per_token]
buf2137 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2130, buf2132, buf2133, -128, 127, torch.int8)
buf2138 = buf2137
del buf2137
# Source Nodes: [input_519], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2139 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2138, buf2132, buf2133, -128, 127, torch.int8, torch.bfloat16)
del buf2132
del buf2133
del buf2138
buf2140 = buf2139
del buf2139
# Source Nodes: [w_dq_172], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2141 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg645_1, arg646_1, arg647_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg645_1
del arg646_1
del arg647_1
buf2142 = buf2141
del buf2141
buf2143 = reinterpret_tensor(buf2072, (1, 14336), (14336, 1), 0); del buf2072 # reuse
# Source Nodes: [c_172], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2140, buf2142, buf2143, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2142
# Quantized linear c_173 fused with silu_24 / mul_324.
# Quantize buf2130 with the _173 qparams (buf2135 / buf2136) and dequantize it
# to bfloat16. Dequantize the weight arg648_1 ([-8, 7], group size 256) to
# bfloat16.
# The fused kernel computes buf2147 @ weight and combines it with the c_172
# result buf2143 via silu and mul, writing (1, 14336) into buf2151 (reusing
# buf2065). Which operand gets the silu is defined in the kernel source.
# Source Nodes: [input_521], Original ATen: [quantized_decomposed.quantize_per_token]
buf2144 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2130, buf2135, buf2136, -128, 127, torch.int8)
buf2145 = buf2144
del buf2144
# Source Nodes: [input_522], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2146 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2145, buf2135, buf2136, -128, 127, torch.int8, torch.bfloat16)
del buf2135
del buf2136
del buf2145
buf2147 = buf2146
del buf2146
# Source Nodes: [w_dq_173], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2148 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg648_1, arg649_1, arg650_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg648_1
del arg649_1
del arg650_1
buf2149 = buf2148
del buf2148
buf2151 = buf2065; del buf2065 # reuse
# Source Nodes: [c_173, mul_324, silu_24], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2147, buf2149, buf2143, buf2151, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2143
del buf2149
# Quantized linear c_174 on the (1, 14336) activation buf2151:
# - per-token int8 qparams, then a quantize/dequantize round trip to bfloat16;
# - dequantize the weight arg651_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2161, a (1, 4096) view over buf2147's storage.
# Source Nodes: [choose_qparams_per_token_asymmetric_174], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2152 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2151, torch.int8)
buf2153 = buf2152[0]
buf2154 = buf2152[1]
del buf2152
# Source Nodes: [input_524], Original ATen: [quantized_decomposed.quantize_per_token]
buf2155 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2151, buf2153, buf2154, -128, 127, torch.int8)
buf2156 = buf2155
del buf2155
# Source Nodes: [input_525], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2157 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2156, buf2153, buf2154, -128, 127, torch.int8, torch.bfloat16)
del buf2153
del buf2154
del buf2156
buf2158 = buf2157
del buf2157
# Source Nodes: [w_dq_174], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2159 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg651_1, arg652_1, arg653_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg651_1
del arg652_1
del arg653_1
buf2160 = buf2159
del buf2159
buf2161 = reinterpret_tensor(buf2147, (1, 4096), (4096, 1), 0); del buf2147 # reuse
# Source Nodes: [c_174], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2158, buf2160, buf2161, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2160
# Fused reduction for source nodes add_175 / mean_50 / rsqrt_50 / x_normed_50
# (RMSNorm-style, judging by the names).
# Inputs: buf2128, buf2042, buf2075, buf2161 (c_174 output), and arg654_1
# (presumably the norm weight — TODO confirm). Output: buf2163, which reuses
# buf2130's allocation.
# Then per-token int8 qparams for buf2163, computed three times: _175 feeds
# c_175, _176 feeds the w_dq_176 matmul, and _177 feeds the w_dq_177 matmul.
# NOTE(review): all three calls take identical arguments (buf2163, torch.int8).
# If the op is deterministic, two of them, and the matching quantize/dequantize
# round trips, are redundant. Fix upstream in the model/quantizer.
buf2163 = buf2130; del buf2130 # reuse
# Source Nodes: [add_175, h_25, mean_50, mul_325, out_23, out_24, pow_51, rsqrt_50, x_fp32_50, x_normed_50, y_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2128, buf2042, buf2075, buf2161, arg654_1, buf2163, 1, 4096, grid=grid(1), stream=stream0)
del arg654_1
# Source Nodes: [choose_qparams_per_token_asymmetric_175], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2164 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8)
buf2165 = buf2164[0]
buf2166 = buf2164[1]
del buf2164
# Source Nodes: [choose_qparams_per_token_asymmetric_176], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2167 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8)
buf2168 = buf2167[0]
buf2169 = buf2167[1]
del buf2167
# Source Nodes: [choose_qparams_per_token_asymmetric_177], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2170 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2163, torch.int8)
buf2171 = buf2170[0]
buf2172 = buf2170[1]
del buf2170
# Source node max_26: triton_poi_fused_max_1 reduces arg3_1 into the
# one-element buf2173 (reusing buf2087). `.item()` copies the result to the host
# as u25, which forces a host/device sync.
# buf2174 = None is a placeholder. Neither it nor u25 is referenced in the rest
# of this visible chunk.
# NOTE(review): this repeats the max over the same arg3_1 input computed for
# max_24 / max_25; if arg3_1 is unchanged, the value and the sync are avoidable
# upstream.
# NOTE(review): torch.cuda.set_device(0) inside _DeviceGuard(0) appears redundant.
# Indentation is lost in this copy; in the generated file the statements after
# `with` are its nested body.
buf2173 = buf2087; del buf2087 # reuse
# Source Nodes: [max_26], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2173, 1, grid=grid(1), stream=stream0)
u25 = buf2173.item()
buf2174 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Quantized linear c_175 on buf2163, using the _175 qparams (buf2165 / buf2166):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg655_1 ([-8, 7], group size 256) to bfloat16;
# - matmul into buf2181, a (1, 4096) view over buf2140's storage.
# Source Nodes: [input_527], Original ATen: [quantized_decomposed.quantize_per_token]
buf2175 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2165, buf2166, -128, 127, torch.int8)
buf2176 = buf2175
del buf2175
# Source Nodes: [input_528], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2177 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2176, buf2165, buf2166, -128, 127, torch.int8, torch.bfloat16)
del buf2165
del buf2166
del buf2176
buf2178 = buf2177
del buf2177
# Source Nodes: [w_dq_175], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2179 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg655_1, arg656_1, arg657_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg655_1
del arg656_1
del arg657_1
buf2180 = buf2179
del buf2179
buf2181 = reinterpret_tensor(buf2140, (1, 4096), (4096, 1), 0); del buf2140 # reuse
# Source Nodes: [c_175], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2178, buf2180, buf2181, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2178
del buf2180
# Quantized 4096 -> 1024 projection on buf2163, using the _176 qparams
# (buf2168 / buf2169):
# - quantize to [-128, 127], then dequantize back to bfloat16;
# - dequantize the weight arg658_1 ([-8, 7], group size 256) to bfloat16;
# - extern mm against the weight viewed as its (4096, 1024) transpose, strides
#   (1, 4096), writing into buf2188 (reusing buf2110).
# Source Nodes: [input_530], Original ATen: [quantized_decomposed.quantize_per_token]
buf2182 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2168, buf2169, -128, 127, torch.int8)
buf2183 = buf2182
del buf2182
# Source Nodes: [input_531], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2184 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2183, buf2168, buf2169, -128, 127, torch.int8, torch.bfloat16)
del buf2168
del buf2169
del buf2183
buf2185 = buf2184
del buf2184
# Source Nodes: [w_dq_176], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2186 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg658_1, arg659_1, arg660_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg658_1
del arg659_1
del arg660_1
buf2187 = buf2186
del buf2186
buf2188 = buf2110; del buf2110 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2185, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2187, (4096, 1024), (1, 4096), 0), out=buf2188)
del buf2185
del buf2187
# Source Nodes: [input_533], Original ATen: [quantized_decomposed.quantize_per_token]
buf2190 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2163, buf2171, buf2172, -128, 127, torch.int8)
del buf2163
buf2191 = buf2190
del buf2190
# Source Nodes: [input_534], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2192 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2191, buf2171, buf2172, -128, 127, torch.int8, torch.bfloat16)
del buf2171
del buf2172
del buf2191
buf2193 = buf2192
del buf2192
# Source Nodes: [w_dq_177], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2194 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg661_1, arg662_1, arg663_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg661_1
del arg662_1
del arg663_1
buf2195 = buf2194
del buf2194
buf2196 = buf2102; del buf2102 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2193, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2195, (4096, 1024), (1, 4096), 0), out=buf2196)
del buf2195
buf2198 = reinterpret_tensor(buf2193, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2193 # reuse
# Source Nodes: [output_50, setitem_50, setitem_51], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2188, arg664_1, buf2196, buf2181, arg665_1, arg666_1, buf2198, 4096, grid=grid(4096), stream=stream0)
del arg664_1
del buf2181
buf2199 = buf2113; del buf2113 # reuse
# Source Nodes: [output_50], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2199, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_50], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2200 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2198, arg665_1, arg666_1, buf2199, False)
del arg665_1
del arg666_1
del buf2198
buf2201 = buf2200[0]
del buf2200
# Source Nodes: [choose_qparams_per_token_asymmetric_178], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2205 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2201, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2206 = buf2205[0]
buf2207 = buf2205[1]
del buf2205
# Source Nodes: [input_536], Original ATen: [quantized_decomposed.quantize_per_token]
buf2208 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2201, (1, 1, 4096), (4096, 4096, 1), 0), buf2206, buf2207, -128, 127, torch.int8)
buf2209 = buf2208
del buf2208
# Source Nodes: [input_537], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2210 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2209, buf2206, buf2207, -128, 127, torch.int8, torch.bfloat16)
del buf2206
del buf2207
del buf2209
buf2211 = buf2210
del buf2210
# Source Nodes: [w_dq_178], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2212 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg667_1, arg668_1, arg669_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg667_1
del arg668_1
del arg669_1
buf2213 = buf2212
del buf2212
buf2214 = reinterpret_tensor(buf2201, (1, 4096), (4096, 1), 0); del buf2201 # reuse
# Source Nodes: [c_178], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2211, buf2213, buf2214, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2213
buf2215 = buf2042; del buf2042 # reuse
buf2217 = buf2211; del buf2211 # reuse
# Source Nodes: [add_180, h_25, h_26, mean_51, mul_335, mul_336, out_23, out_24, pow_52, rsqrt_51, x_fp32_51, x_normed_51], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2215, buf2214, buf2128, buf2075, buf2161, arg670_1, buf2217, 1, 4096, grid=grid(1), stream=stream0)
del arg670_1
del buf2075
del buf2128
del buf2161
del buf2214
# Source Nodes: [choose_qparams_per_token_asymmetric_179], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2218 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2217, torch.int8)
buf2219 = buf2218[0]
buf2220 = buf2218[1]
del buf2218
# Source Nodes: [choose_qparams_per_token_asymmetric_180], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2221 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2217, torch.int8)
buf2222 = buf2221[0]
buf2223 = buf2221[1]
del buf2221
# Source Nodes: [input_539], Original ATen: [quantized_decomposed.quantize_per_token]
buf2224 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2217, buf2219, buf2220, -128, 127, torch.int8)
buf2225 = buf2224
del buf2224
# Source Nodes: [input_540], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2226 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2225, buf2219, buf2220, -128, 127, torch.int8, torch.bfloat16)
del buf2219
del buf2220
del buf2225
buf2227 = buf2226
del buf2226
# Source Nodes: [w_dq_179], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2228 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg671_1, arg672_1, arg673_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg671_1
del arg672_1
del arg673_1
buf2229 = buf2228
del buf2228
buf2230 = reinterpret_tensor(buf2158, (1, 14336), (14336, 1), 0); del buf2158 # reuse
# Source Nodes: [c_179], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2227, buf2229, buf2230, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2229
# Source Nodes: [input_542], Original ATen: [quantized_decomposed.quantize_per_token]
buf2231 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2217, buf2222, buf2223, -128, 127, torch.int8)
buf2232 = buf2231
del buf2231
# Source Nodes: [input_543], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2233 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2232, buf2222, buf2223, -128, 127, torch.int8, torch.bfloat16)
del buf2222
del buf2223
del buf2232
buf2234 = buf2233
del buf2233
# Source Nodes: [w_dq_180], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2235 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg674_1, arg675_1, arg676_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg674_1
del arg675_1
del arg676_1
buf2236 = buf2235
del buf2235
buf2238 = buf2151; del buf2151 # reuse
# Source Nodes: [c_180, mul_337, silu_25], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2234, buf2236, buf2230, buf2238, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2230
del buf2236
# Source Nodes: [choose_qparams_per_token_asymmetric_181], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2239 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2238, torch.int8)
buf2240 = buf2239[0]
buf2241 = buf2239[1]
del buf2239
# Source Nodes: [input_545], Original ATen: [quantized_decomposed.quantize_per_token]
buf2242 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2238, buf2240, buf2241, -128, 127, torch.int8)
buf2243 = buf2242
del buf2242
# Source Nodes: [input_546], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2244 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2243, buf2240, buf2241, -128, 127, torch.int8, torch.bfloat16)
del buf2240
del buf2241
del buf2243
buf2245 = buf2244
del buf2244
# Source Nodes: [w_dq_181], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2246 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg677_1, arg678_1, arg679_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg677_1
del arg678_1
del arg679_1
buf2247 = buf2246
del buf2246
buf2248 = reinterpret_tensor(buf2234, (1, 4096), (4096, 1), 0); del buf2234 # reuse
# Source Nodes: [c_181], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2245, buf2247, buf2248, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2247
buf2250 = buf2217; del buf2217 # reuse
# Source Nodes: [add_182, mean_52, mul_338, out_25, pow_53, rsqrt_52, x_fp32_52, x_normed_52, y_26], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2215, buf2248, arg680_1, buf2250, 1, 4096, grid=grid(1), stream=stream0)
del arg680_1
# Source Nodes: [choose_qparams_per_token_asymmetric_182], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2251 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8)
buf2252 = buf2251[0]
buf2253 = buf2251[1]
del buf2251
# Source Nodes: [choose_qparams_per_token_asymmetric_183], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2254 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8)
buf2255 = buf2254[0]
buf2256 = buf2254[1]
del buf2254
# Source Nodes: [choose_qparams_per_token_asymmetric_184], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2257 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2250, torch.int8)
buf2258 = buf2257[0]
buf2259 = buf2257[1]
del buf2257
buf2260 = buf2173; del buf2173 # reuse
# Source Nodes: [max_27], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2260, 1, grid=grid(1), stream=stream0)
u26 = buf2260.item()
buf2261 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_548], Original ATen: [quantized_decomposed.quantize_per_token]
buf2262 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2252, buf2253, -128, 127, torch.int8)
buf2263 = buf2262
del buf2262
# Source Nodes: [input_549], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2264 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2263, buf2252, buf2253, -128, 127, torch.int8, torch.bfloat16)
del buf2252
del buf2253
del buf2263
buf2265 = buf2264
del buf2264
# Source Nodes: [w_dq_182], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2266 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg681_1, arg682_1, arg683_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg681_1
del arg682_1
del arg683_1
buf2267 = buf2266
del buf2266
buf2268 = reinterpret_tensor(buf2227, (1, 4096), (4096, 1), 0); del buf2227 # reuse
# Source Nodes: [c_182], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2265, buf2267, buf2268, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2265
del buf2267
# Source Nodes: [input_551], Original ATen: [quantized_decomposed.quantize_per_token]
buf2269 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2255, buf2256, -128, 127, torch.int8)
buf2270 = buf2269
del buf2269
# Source Nodes: [input_552], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2271 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2270, buf2255, buf2256, -128, 127, torch.int8, torch.bfloat16)
del buf2255
del buf2256
del buf2270
buf2272 = buf2271
del buf2271
# Source Nodes: [w_dq_183], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2273 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg684_1, arg685_1, arg686_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg684_1
del arg685_1
del arg686_1
buf2274 = buf2273
del buf2273
buf2275 = buf2196; del buf2196 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2272, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2274, (4096, 1024), (1, 4096), 0), out=buf2275)
del buf2272
del buf2274
# Source Nodes: [input_554], Original ATen: [quantized_decomposed.quantize_per_token]
buf2277 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2250, buf2258, buf2259, -128, 127, torch.int8)
del buf2250
buf2278 = buf2277
del buf2277
# Source Nodes: [input_555], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2279 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2278, buf2258, buf2259, -128, 127, torch.int8, torch.bfloat16)
del buf2258
del buf2259
del buf2278
buf2280 = buf2279
del buf2279
# Source Nodes: [w_dq_184], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2281 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg687_1, arg688_1, arg689_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg687_1
del arg688_1
del arg689_1
buf2282 = buf2281
del buf2281
buf2283 = buf2188; del buf2188 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2280, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2282, (4096, 1024), (1, 4096), 0), out=buf2283)
del buf2282
buf2285 = reinterpret_tensor(buf2280, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2280 # reuse
# Source Nodes: [output_52, setitem_52, setitem_53], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2275, arg690_1, buf2283, buf2268, arg691_1, arg692_1, buf2285, 4096, grid=grid(4096), stream=stream0)
del arg690_1
del buf2268
buf2286 = buf2199; del buf2199 # reuse
# Source Nodes: [output_52], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2286, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_52], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2287 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2285, arg691_1, arg692_1, buf2286, False)
del arg691_1
del arg692_1
del buf2285
buf2288 = buf2287[0]
del buf2287
# Source Nodes: [choose_qparams_per_token_asymmetric_185], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2292 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2288, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2293 = buf2292[0]
buf2294 = buf2292[1]
del buf2292
# Source Nodes: [input_557], Original ATen: [quantized_decomposed.quantize_per_token]
buf2295 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2288, (1, 1, 4096), (4096, 4096, 1), 0), buf2293, buf2294, -128, 127, torch.int8)
buf2296 = buf2295
del buf2295
# Source Nodes: [input_558], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2297 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2296, buf2293, buf2294, -128, 127, torch.int8, torch.bfloat16)
del buf2293
del buf2294
del buf2296
buf2298 = buf2297
del buf2297
# Source Nodes: [w_dq_185], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2299 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg693_1, arg694_1, arg695_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg693_1
del arg694_1
del arg695_1
buf2300 = buf2299
del buf2299
buf2301 = reinterpret_tensor(buf2288, (1, 4096), (4096, 1), 0); del buf2288 # reuse
# Source Nodes: [c_185], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2298, buf2300, buf2301, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2300
buf2303 = buf2298; del buf2298 # reuse
# Source Nodes: [add_187, h_27, mean_53, mul_348, mul_349, out_25, pow_54, rsqrt_53, x_fp32_53, x_normed_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2301, buf2215, buf2248, arg696_1, buf2303, 1, 4096, grid=grid(1), stream=stream0)
del arg696_1
# Source Nodes: [choose_qparams_per_token_asymmetric_186], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2304 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2303, torch.int8)
buf2305 = buf2304[0]
buf2306 = buf2304[1]
del buf2304
# Source Nodes: [choose_qparams_per_token_asymmetric_187], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2307 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2303, torch.int8)
buf2308 = buf2307[0]
buf2309 = buf2307[1]
del buf2307
# Source Nodes: [input_560], Original ATen: [quantized_decomposed.quantize_per_token]
buf2310 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2303, buf2305, buf2306, -128, 127, torch.int8)
buf2311 = buf2310
del buf2310
# Source Nodes: [input_561], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2312 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2311, buf2305, buf2306, -128, 127, torch.int8, torch.bfloat16)
del buf2305
del buf2306
del buf2311
buf2313 = buf2312
del buf2312
# Source Nodes: [w_dq_186], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2314 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg697_1, arg698_1, arg699_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg697_1
del arg698_1
del arg699_1
buf2315 = buf2314
del buf2314
buf2316 = reinterpret_tensor(buf2245, (1, 14336), (14336, 1), 0); del buf2245 # reuse
# Source Nodes: [c_186], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2313, buf2315, buf2316, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2315
# Source Nodes: [input_563], Original ATen: [quantized_decomposed.quantize_per_token]
buf2317 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2303, buf2308, buf2309, -128, 127, torch.int8)
buf2318 = buf2317
del buf2317
# Source Nodes: [input_564], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2319 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2318, buf2308, buf2309, -128, 127, torch.int8, torch.bfloat16)
del buf2308
del buf2309
del buf2318
buf2320 = buf2319
del buf2319
# Source Nodes: [w_dq_187], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2321 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg700_1, arg701_1, arg702_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg700_1
del arg701_1
del arg702_1
buf2322 = buf2321
del buf2321
buf2324 = buf2238; del buf2238 # reuse
# Source Nodes: [c_187, mul_350, silu_26], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2320, buf2322, buf2316, buf2324, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2316
del buf2322
# Source Nodes: [choose_qparams_per_token_asymmetric_188], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2325 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2324, torch.int8)
buf2326 = buf2325[0]
buf2327 = buf2325[1]
del buf2325
# Source Nodes: [input_566], Original ATen: [quantized_decomposed.quantize_per_token]
buf2328 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2324, buf2326, buf2327, -128, 127, torch.int8)
buf2329 = buf2328
del buf2328
# Source Nodes: [input_567], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2330 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2329, buf2326, buf2327, -128, 127, torch.int8, torch.bfloat16)
del buf2326
del buf2327
del buf2329
buf2331 = buf2330
del buf2330
# Source Nodes: [w_dq_188], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2332 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg703_1, arg704_1, arg705_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg703_1
del arg704_1
del arg705_1
buf2333 = buf2332
del buf2332
buf2334 = reinterpret_tensor(buf2320, (1, 4096), (4096, 1), 0); del buf2320 # reuse
# Source Nodes: [c_188], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2331, buf2333, buf2334, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2333
buf2336 = buf2303; del buf2303 # reuse
# Source Nodes: [add_189, h_27, mean_54, mul_351, out_25, out_26, pow_55, rsqrt_54, x_fp32_54, x_normed_54, y_27], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2301, buf2215, buf2248, buf2334, arg706_1, buf2336, 1, 4096, grid=grid(1), stream=stream0)
del arg706_1
# Source Nodes: [choose_qparams_per_token_asymmetric_189], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2337 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8)
buf2338 = buf2337[0]
buf2339 = buf2337[1]
del buf2337
# Source Nodes: [choose_qparams_per_token_asymmetric_190], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2340 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8)
buf2341 = buf2340[0]
buf2342 = buf2340[1]
del buf2340
# Source Nodes: [choose_qparams_per_token_asymmetric_191], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2343 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2336, torch.int8)
buf2344 = buf2343[0]
buf2345 = buf2343[1]
del buf2343
buf2346 = buf2260; del buf2260 # reuse
# Source Nodes: [max_28], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2346, 1, grid=grid(1), stream=stream0)
u27 = buf2346.item()
buf2347 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_569], Original ATen: [quantized_decomposed.quantize_per_token]
buf2348 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2338, buf2339, -128, 127, torch.int8)
buf2349 = buf2348
del buf2348
# Source Nodes: [input_570], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2350 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2349, buf2338, buf2339, -128, 127, torch.int8, torch.bfloat16)
del buf2338
del buf2339
del buf2349
buf2351 = buf2350
del buf2350
# Source Nodes: [w_dq_189], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2352 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg707_1, arg708_1, arg709_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg707_1
del arg708_1
del arg709_1
buf2353 = buf2352
del buf2352
buf2354 = reinterpret_tensor(buf2313, (1, 4096), (4096, 1), 0); del buf2313 # reuse
# Source Nodes: [c_189], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2351, buf2353, buf2354, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2351
del buf2353
# Source Nodes: [input_572], Original ATen: [quantized_decomposed.quantize_per_token]
buf2355 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2341, buf2342, -128, 127, torch.int8)
buf2356 = buf2355
del buf2355
# Source Nodes: [input_573], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2357 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2356, buf2341, buf2342, -128, 127, torch.int8, torch.bfloat16)
del buf2341
del buf2342
del buf2356
buf2358 = buf2357
del buf2357
# Source Nodes: [w_dq_190], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2359 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg710_1, arg711_1, arg712_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg710_1
del arg711_1
del arg712_1
buf2360 = buf2359
del buf2359
buf2361 = buf2283; del buf2283 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2358, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2360, (4096, 1024), (1, 4096), 0), out=buf2361)
del buf2358
del buf2360
# Source Nodes: [input_575], Original ATen: [quantized_decomposed.quantize_per_token]
buf2363 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2336, buf2344, buf2345, -128, 127, torch.int8)
del buf2336
buf2364 = buf2363
del buf2363
# Source Nodes: [input_576], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2365 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2364, buf2344, buf2345, -128, 127, torch.int8, torch.bfloat16)
del buf2344
del buf2345
del buf2364
buf2366 = buf2365
del buf2365
# Source Nodes: [w_dq_191], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2367 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg713_1, arg714_1, arg715_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg713_1
del arg714_1
del arg715_1
buf2368 = buf2367
del buf2367
buf2369 = buf2275; del buf2275 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2366, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2368, (4096, 1024), (1, 4096), 0), out=buf2369)
del buf2368
buf2371 = reinterpret_tensor(buf2366, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2366 # reuse
# Source Nodes: [output_54, setitem_54, setitem_55], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2361, arg716_1, buf2369, buf2354, arg717_1, arg718_1, buf2371, 4096, grid=grid(4096), stream=stream0)
del arg716_1
del buf2354
buf2372 = buf2286; del buf2286 # reuse
# Source Nodes: [output_54], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2372, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_54], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2373 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2371, arg717_1, arg718_1, buf2372, False)
del arg717_1
del arg718_1
del buf2371
buf2374 = buf2373[0]
del buf2373
# Source Nodes: [choose_qparams_per_token_asymmetric_192], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2378 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2374, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2379 = buf2378[0]
buf2380 = buf2378[1]
del buf2378
# Source Nodes: [input_578], Original ATen: [quantized_decomposed.quantize_per_token]
buf2381 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2374, (1, 1, 4096), (4096, 4096, 1), 0), buf2379, buf2380, -128, 127, torch.int8)
buf2382 = buf2381
del buf2381
# Source Nodes: [input_579], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2383 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2382, buf2379, buf2380, -128, 127, torch.int8, torch.bfloat16)
del buf2379
del buf2380
del buf2382
buf2384 = buf2383
del buf2383
# Source Nodes: [w_dq_192], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2385 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg719_1, arg720_1, arg721_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg719_1
del arg720_1
del arg721_1
buf2386 = buf2385
del buf2385
buf2387 = reinterpret_tensor(buf2374, (1, 4096), (4096, 1), 0); del buf2374 # reuse
# Source Nodes: [c_192], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2384, buf2386, buf2387, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2386
buf2388 = buf2215; del buf2215 # reuse
buf2390 = buf2384; del buf2384 # reuse
# Source Nodes: [add_194, h_27, h_28, mean_55, mul_361, mul_362, out_25, out_26, pow_56, rsqrt_55, x_fp32_55, x_normed_55], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2388, buf2387, buf2301, buf2248, buf2334, arg722_1, buf2390, 1, 4096, grid=grid(1), stream=stream0)
del arg722_1
del buf2248
del buf2301
del buf2334
del buf2387
# Source Nodes: [choose_qparams_per_token_asymmetric_193], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2391 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2390, torch.int8)
buf2392 = buf2391[0]
buf2393 = buf2391[1]
del buf2391
# Source Nodes: [choose_qparams_per_token_asymmetric_194], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2394 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2390, torch.int8)
buf2395 = buf2394[0]
buf2396 = buf2394[1]
del buf2394
# Source Nodes: [input_581], Original ATen: [quantized_decomposed.quantize_per_token]
buf2397 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2390, buf2392, buf2393, -128, 127, torch.int8)
buf2398 = buf2397
del buf2397
# Source Nodes: [input_582], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2399 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2398, buf2392, buf2393, -128, 127, torch.int8, torch.bfloat16)
del buf2392
del buf2393
del buf2398
buf2400 = buf2399
del buf2399
# Source Nodes: [w_dq_193], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2401 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg723_1, arg724_1, arg725_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg723_1
del arg724_1
del arg725_1
buf2402 = buf2401
del buf2401
buf2403 = reinterpret_tensor(buf2331, (1, 14336), (14336, 1), 0); del buf2331 # reuse
# Source Nodes: [c_193], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2400, buf2402, buf2403, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2402
# Source Nodes: [input_584], Original ATen: [quantized_decomposed.quantize_per_token]
buf2404 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2390, buf2395, buf2396, -128, 127, torch.int8)
buf2405 = buf2404
del buf2404
# Source Nodes: [input_585], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2406 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2405, buf2395, buf2396, -128, 127, torch.int8, torch.bfloat16)
del buf2395
del buf2396
del buf2405
buf2407 = buf2406
del buf2406
# Source Nodes: [w_dq_194], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2408 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg726_1, arg727_1, arg728_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg726_1
del arg727_1
del arg728_1
buf2409 = buf2408
del buf2408
buf2411 = buf2324; del buf2324 # reuse
# Source Nodes: [c_194, mul_363, silu_27], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2407, buf2409, buf2403, buf2411, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2403
del buf2409
# Source Nodes: [choose_qparams_per_token_asymmetric_195], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2412 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2411, torch.int8)
buf2413 = buf2412[0]
buf2414 = buf2412[1]
del buf2412
# Source Nodes: [input_587], Original ATen: [quantized_decomposed.quantize_per_token]
buf2415 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2411, buf2413, buf2414, -128, 127, torch.int8)
buf2416 = buf2415
del buf2415
# Source Nodes: [input_588], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2417 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2416, buf2413, buf2414, -128, 127, torch.int8, torch.bfloat16)
del buf2413
del buf2414
del buf2416
buf2418 = buf2417
del buf2417
# Source Nodes: [w_dq_195], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2419 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg729_1, arg730_1, arg731_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg729_1
del arg730_1
del arg731_1
buf2420 = buf2419
del buf2419
buf2421 = reinterpret_tensor(buf2407, (1, 4096), (4096, 1), 0); del buf2407 # reuse
# Source Nodes: [c_195], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2418, buf2420, buf2421, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2420
buf2423 = buf2390; del buf2390 # reuse
# Source Nodes: [add_196, mean_56, mul_364, out_27, pow_57, rsqrt_56, x_fp32_56, x_normed_56, y_28], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2388, buf2421, arg732_1, buf2423, 1, 4096, grid=grid(1), stream=stream0)
del arg732_1
# Source Nodes: [choose_qparams_per_token_asymmetric_196], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2424 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8)
buf2425 = buf2424[0]
buf2426 = buf2424[1]
del buf2424
# Source Nodes: [choose_qparams_per_token_asymmetric_197], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2427 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8)
buf2428 = buf2427[0]
buf2429 = buf2427[1]
del buf2427
# Source Nodes: [choose_qparams_per_token_asymmetric_198], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2430 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2423, torch.int8)
buf2431 = buf2430[0]
buf2432 = buf2430[1]
del buf2430
buf2433 = buf2346; del buf2346 # reuse
# Source Nodes: [max_29], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2433, 1, grid=grid(1), stream=stream0)
u28 = buf2433.item()
buf2434 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_590], Original ATen: [quantized_decomposed.quantize_per_token]
buf2435 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2425, buf2426, -128, 127, torch.int8)
buf2436 = buf2435
del buf2435
# Source Nodes: [input_591], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2437 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2436, buf2425, buf2426, -128, 127, torch.int8, torch.bfloat16)
del buf2425
del buf2426
del buf2436
buf2438 = buf2437
del buf2437
# Source Nodes: [w_dq_196], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2439 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg733_1, arg734_1, arg735_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg733_1
del arg734_1
del arg735_1
buf2440 = buf2439
del buf2439
buf2441 = reinterpret_tensor(buf2400, (1, 4096), (4096, 1), 0); del buf2400 # reuse
# Source Nodes: [c_196], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2438, buf2440, buf2441, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2438
del buf2440
# Source Nodes: [input_593], Original ATen: [quantized_decomposed.quantize_per_token]
buf2442 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2428, buf2429, -128, 127, torch.int8)
buf2443 = buf2442
del buf2442
# Source Nodes: [input_594], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2444 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2443, buf2428, buf2429, -128, 127, torch.int8, torch.bfloat16)
del buf2428
del buf2429
del buf2443
buf2445 = buf2444
del buf2444
# Source Nodes: [w_dq_197], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2446 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg736_1, arg737_1, arg738_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg736_1
del arg737_1
del arg738_1
buf2447 = buf2446
del buf2446
buf2448 = buf2369; del buf2369 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2445, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2447, (4096, 1024), (1, 4096), 0), out=buf2448)
del buf2445
del buf2447
# Source Nodes: [input_596], Original ATen: [quantized_decomposed.quantize_per_token]
buf2450 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2423, buf2431, buf2432, -128, 127, torch.int8)
del buf2423
buf2451 = buf2450
del buf2450
# Source Nodes: [input_597], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2452 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2451, buf2431, buf2432, -128, 127, torch.int8, torch.bfloat16)
del buf2431
del buf2432
del buf2451
buf2453 = buf2452
del buf2452
# Source Nodes: [w_dq_198], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2454 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg739_1, arg740_1, arg741_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg739_1
del arg740_1
del arg741_1
buf2455 = buf2454
del buf2454
buf2456 = buf2361; del buf2361 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2453, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2455, (4096, 1024), (1, 4096), 0), out=buf2456)
del buf2455
buf2458 = reinterpret_tensor(buf2453, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2453 # reuse
# Source Nodes: [output_56, setitem_56, setitem_57], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2448, arg742_1, buf2456, buf2441, arg743_1, arg744_1, buf2458, 4096, grid=grid(4096), stream=stream0)
del arg742_1
del buf2441
buf2459 = buf2372; del buf2372 # reuse
# Source Nodes: [output_56], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2459, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_56], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2460 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2458, arg743_1, arg744_1, buf2459, False)
del arg743_1
del arg744_1
del buf2458
buf2461 = buf2460[0]
del buf2460
# Source Nodes: [choose_qparams_per_token_asymmetric_199], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2465 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2461, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2466 = buf2465[0]
buf2467 = buf2465[1]
del buf2465
# Source Nodes: [input_599], Original ATen: [quantized_decomposed.quantize_per_token]
buf2468 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2461, (1, 1, 4096), (4096, 4096, 1), 0), buf2466, buf2467, -128, 127, torch.int8)
buf2469 = buf2468
del buf2468
# Source Nodes: [input_600], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2470 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2469, buf2466, buf2467, -128, 127, torch.int8, torch.bfloat16)
del buf2466
del buf2467
del buf2469
buf2471 = buf2470
del buf2470
# Source Nodes: [w_dq_199], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2472 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg745_1, arg746_1, arg747_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg745_1
del arg746_1
del arg747_1
buf2473 = buf2472
del buf2472
buf2474 = reinterpret_tensor(buf2461, (1, 4096), (4096, 1), 0); del buf2461 # reuse
# Source Nodes: [c_199], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2471, buf2473, buf2474, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2473
buf2476 = buf2471; del buf2471 # reuse
# Source Nodes: [add_201, h_29, mean_57, mul_374, mul_375, out_27, pow_58, rsqrt_57, x_fp32_57, x_normed_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2474, buf2388, buf2421, arg748_1, buf2476, 1, 4096, grid=grid(1), stream=stream0)
del arg748_1
# Source Nodes: [choose_qparams_per_token_asymmetric_200], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2477 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2476, torch.int8)
buf2478 = buf2477[0]
buf2479 = buf2477[1]
del buf2477
# Source Nodes: [choose_qparams_per_token_asymmetric_201], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2480 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2476, torch.int8)
buf2481 = buf2480[0]
buf2482 = buf2480[1]
del buf2480
# Source Nodes: [input_602], Original ATen: [quantized_decomposed.quantize_per_token]
buf2483 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2476, buf2478, buf2479, -128, 127, torch.int8)
buf2484 = buf2483
del buf2483
# Source Nodes: [input_603], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2485 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2484, buf2478, buf2479, -128, 127, torch.int8, torch.bfloat16)
del buf2478
del buf2479
del buf2484
buf2486 = buf2485
del buf2485
# Source Nodes: [w_dq_200], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2487 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg749_1, arg750_1, arg751_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg749_1
del arg750_1
del arg751_1
buf2488 = buf2487
del buf2487
buf2489 = reinterpret_tensor(buf2418, (1, 14336), (14336, 1), 0); del buf2418 # reuse
# Source Nodes: [c_200], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2486, buf2488, buf2489, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2488
# Source Nodes: [input_605], Original ATen: [quantized_decomposed.quantize_per_token]
buf2490 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2476, buf2481, buf2482, -128, 127, torch.int8)
buf2491 = buf2490
del buf2490
# Source Nodes: [input_606], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2492 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2491, buf2481, buf2482, -128, 127, torch.int8, torch.bfloat16)
del buf2481
del buf2482
del buf2491
buf2493 = buf2492
del buf2492
# Source Nodes: [w_dq_201], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2494 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg752_1, arg753_1, arg754_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg752_1
del arg753_1
del arg754_1
buf2495 = buf2494
del buf2494
buf2497 = buf2411; del buf2411 # reuse
# Source Nodes: [c_201, mul_376, silu_28], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2493, buf2495, buf2489, buf2497, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2489
del buf2495
# Source Nodes: [choose_qparams_per_token_asymmetric_202], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2498 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2497, torch.int8)
buf2499 = buf2498[0]
buf2500 = buf2498[1]
del buf2498
# Source Nodes: [input_608], Original ATen: [quantized_decomposed.quantize_per_token]
buf2501 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2497, buf2499, buf2500, -128, 127, torch.int8)
buf2502 = buf2501
del buf2501
# Source Nodes: [input_609], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2503 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2502, buf2499, buf2500, -128, 127, torch.int8, torch.bfloat16)
del buf2499
del buf2500
del buf2502
buf2504 = buf2503
del buf2503
# Source Nodes: [w_dq_202], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2505 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg755_1, arg756_1, arg757_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg755_1
del arg756_1
del arg757_1
buf2506 = buf2505
del buf2505
buf2507 = reinterpret_tensor(buf2493, (1, 4096), (4096, 1), 0); del buf2493 # reuse
# Source Nodes: [c_202], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2504, buf2506, buf2507, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2506
buf2509 = buf2476; del buf2476 # reuse
# Source Nodes: [add_203, h_29, mean_58, mul_377, out_27, out_28, pow_59, rsqrt_58, x_fp32_58, x_normed_58, y_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2474, buf2388, buf2421, buf2507, arg758_1, buf2509, 1, 4096, grid=grid(1), stream=stream0)
del arg758_1
# Source Nodes: [choose_qparams_per_token_asymmetric_203], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2510 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8)
buf2511 = buf2510[0]
buf2512 = buf2510[1]
del buf2510
# Source Nodes: [choose_qparams_per_token_asymmetric_204], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2513 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8)
buf2514 = buf2513[0]
buf2515 = buf2513[1]
del buf2513
# Source Nodes: [choose_qparams_per_token_asymmetric_205], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2516 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2509, torch.int8)
buf2517 = buf2516[0]
buf2518 = buf2516[1]
del buf2516
buf2519 = buf2433; del buf2433 # reuse
# Source Nodes: [max_30], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2519, 1, grid=grid(1), stream=stream0)
u29 = buf2519.item()
buf2520 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_611], Original ATen: [quantized_decomposed.quantize_per_token]
buf2521 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2511, buf2512, -128, 127, torch.int8)
buf2522 = buf2521
del buf2521
# Source Nodes: [input_612], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2523 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2522, buf2511, buf2512, -128, 127, torch.int8, torch.bfloat16)
del buf2511
del buf2512
del buf2522
buf2524 = buf2523
del buf2523
# Source Nodes: [w_dq_203], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2525 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg759_1, arg760_1, arg761_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg759_1
del arg760_1
del arg761_1
buf2526 = buf2525
del buf2525
buf2527 = reinterpret_tensor(buf2486, (1, 4096), (4096, 1), 0); del buf2486 # reuse
# Source Nodes: [c_203], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2524, buf2526, buf2527, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2524
del buf2526
# Source Nodes: [input_614], Original ATen: [quantized_decomposed.quantize_per_token]
buf2528 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2514, buf2515, -128, 127, torch.int8)
buf2529 = buf2528
del buf2528
# Source Nodes: [input_615], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2530 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2529, buf2514, buf2515, -128, 127, torch.int8, torch.bfloat16)
del buf2514
del buf2515
del buf2529
buf2531 = buf2530
del buf2530
# Source Nodes: [w_dq_204], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2532 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg762_1, arg763_1, arg764_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg762_1
del arg763_1
del arg764_1
buf2533 = buf2532
del buf2532
buf2534 = buf2456; del buf2456 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2531, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2533, (4096, 1024), (1, 4096), 0), out=buf2534)
del buf2531
del buf2533
# Source Nodes: [input_617], Original ATen: [quantized_decomposed.quantize_per_token]
buf2536 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2509, buf2517, buf2518, -128, 127, torch.int8)
del buf2509
buf2537 = buf2536
del buf2536
# Source Nodes: [input_618], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2538 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2537, buf2517, buf2518, -128, 127, torch.int8, torch.bfloat16)
del buf2517
del buf2518
del buf2537
buf2539 = buf2538
del buf2538
# Source Nodes: [w_dq_205], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2540 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg765_1, arg766_1, arg767_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg765_1
del arg766_1
del arg767_1
buf2541 = buf2540
del buf2540
buf2542 = buf2448; del buf2448 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2539, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2541, (4096, 1024), (1, 4096), 0), out=buf2542)
del buf2541
buf2544 = reinterpret_tensor(buf2539, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2539 # reuse
# Source Nodes: [output_58, setitem_58, setitem_59], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2534, arg768_1, buf2542, buf2527, arg769_1, arg770_1, buf2544, 4096, grid=grid(4096), stream=stream0)
del arg768_1
del buf2527
buf2545 = buf2459; del buf2459 # reuse
# Source Nodes: [output_58], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2545, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_58], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2546 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2544, arg769_1, arg770_1, buf2545, False)
del arg769_1
del arg770_1
del buf2544
buf2547 = buf2546[0]
del buf2546
# Source Nodes: [choose_qparams_per_token_asymmetric_206], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2551 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2547, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2552 = buf2551[0]
buf2553 = buf2551[1]
del buf2551
# Source Nodes: [input_620], Original ATen: [quantized_decomposed.quantize_per_token]
buf2554 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2547, (1, 1, 4096), (4096, 4096, 1), 0), buf2552, buf2553, -128, 127, torch.int8)
buf2555 = buf2554
del buf2554
# Source Nodes: [input_621], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2556 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2555, buf2552, buf2553, -128, 127, torch.int8, torch.bfloat16)
del buf2552
del buf2553
del buf2555
buf2557 = buf2556
del buf2556
# Source Nodes: [w_dq_206], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2558 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg771_1, arg772_1, arg773_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg771_1
del arg772_1
del arg773_1
buf2559 = buf2558
del buf2558
buf2560 = reinterpret_tensor(buf2547, (1, 4096), (4096, 1), 0); del buf2547 # reuse
# Source Nodes: [c_206], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2557, buf2559, buf2560, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2559
buf2561 = buf2388; del buf2388 # reuse
buf2563 = buf2557; del buf2557 # reuse
# Source Nodes: [add_208, h_29, h_30, mean_59, mul_387, mul_388, out_27, out_28, pow_60, rsqrt_59, x_fp32_59, x_normed_59], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2561, buf2560, buf2474, buf2421, buf2507, arg774_1, buf2563, 1, 4096, grid=grid(1), stream=stream0)
del arg774_1
del buf2421
del buf2474
del buf2507
del buf2560
# Source Nodes: [choose_qparams_per_token_asymmetric_207], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2564 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2563, torch.int8)
buf2565 = buf2564[0]
buf2566 = buf2564[1]
del buf2564
# Source Nodes: [choose_qparams_per_token_asymmetric_208], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2567 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2563, torch.int8)
buf2568 = buf2567[0]
buf2569 = buf2567[1]
del buf2567
# Source Nodes: [input_623], Original ATen: [quantized_decomposed.quantize_per_token]
buf2570 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2563, buf2565, buf2566, -128, 127, torch.int8)
buf2571 = buf2570
del buf2570
# Source Nodes: [input_624], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2572 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2571, buf2565, buf2566, -128, 127, torch.int8, torch.bfloat16)
del buf2565
del buf2566
del buf2571
buf2573 = buf2572
del buf2572
# Source Nodes: [w_dq_207], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2574 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg775_1, arg776_1, arg777_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg775_1
del arg776_1
del arg777_1
buf2575 = buf2574
del buf2574
buf2576 = reinterpret_tensor(buf2504, (1, 14336), (14336, 1), 0); del buf2504 # reuse
# Source Nodes: [c_207], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2573, buf2575, buf2576, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2575
# Source Nodes: [input_626], Original ATen: [quantized_decomposed.quantize_per_token]
buf2577 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2563, buf2568, buf2569, -128, 127, torch.int8)
buf2578 = buf2577
del buf2577
# Source Nodes: [input_627], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2579 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2578, buf2568, buf2569, -128, 127, torch.int8, torch.bfloat16)
del buf2568
del buf2569
del buf2578
buf2580 = buf2579
del buf2579
# Source Nodes: [w_dq_208], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2581 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg778_1, arg779_1, arg780_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg778_1
del arg779_1
del arg780_1
buf2582 = buf2581
del buf2581
buf2584 = buf2497; del buf2497 # reuse
# Source Nodes: [c_208, mul_389, silu_29], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2580, buf2582, buf2576, buf2584, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2576
del buf2582
# Source Nodes: [choose_qparams_per_token_asymmetric_209], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2585 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2584, torch.int8)
buf2586 = buf2585[0]
buf2587 = buf2585[1]
del buf2585
# Source Nodes: [input_629], Original ATen: [quantized_decomposed.quantize_per_token]
buf2588 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2584, buf2586, buf2587, -128, 127, torch.int8)
buf2589 = buf2588
del buf2588
# Source Nodes: [input_630], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2590 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2589, buf2586, buf2587, -128, 127, torch.int8, torch.bfloat16)
del buf2586
del buf2587
del buf2589
buf2591 = buf2590
del buf2590
# Source Nodes: [w_dq_209], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2592 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg781_1, arg782_1, arg783_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg781_1
del arg782_1
del arg783_1
buf2593 = buf2592
del buf2592
buf2594 = reinterpret_tensor(buf2580, (1, 4096), (4096, 1), 0); del buf2580 # reuse
# Source Nodes: [c_209], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2591, buf2593, buf2594, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2593
buf2596 = buf2563; del buf2563 # reuse
# Source Nodes: [add_210, mean_60, mul_390, out_29, pow_61, rsqrt_60, x_fp32_60, x_normed_60, y_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2561, buf2594, arg784_1, buf2596, 1, 4096, grid=grid(1), stream=stream0)
del arg784_1
# Source Nodes: [choose_qparams_per_token_asymmetric_210], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2597 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8)
buf2598 = buf2597[0]
buf2599 = buf2597[1]
del buf2597
# Source Nodes: [choose_qparams_per_token_asymmetric_211], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2600 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8)
buf2601 = buf2600[0]
buf2602 = buf2600[1]
del buf2600
# Source Nodes: [choose_qparams_per_token_asymmetric_212], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2603 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2596, torch.int8)
buf2604 = buf2603[0]
buf2605 = buf2603[1]
del buf2603
buf2606 = buf2519; del buf2519 # reuse
# Source Nodes: [max_31], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2606, 1, grid=grid(1), stream=stream0)
u30 = buf2606.item()
buf2607 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_632], Original ATen: [quantized_decomposed.quantize_per_token]
buf2608 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2598, buf2599, -128, 127, torch.int8)
buf2609 = buf2608
del buf2608
# Source Nodes: [input_633], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2610 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2609, buf2598, buf2599, -128, 127, torch.int8, torch.bfloat16)
del buf2598
del buf2599
del buf2609
buf2611 = buf2610
del buf2610
# Source Nodes: [w_dq_210], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2612 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg785_1, arg786_1, arg787_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg785_1
del arg786_1
del arg787_1
buf2613 = buf2612
del buf2612
buf2614 = reinterpret_tensor(buf2573, (1, 4096), (4096, 1), 0); del buf2573 # reuse
# Source Nodes: [c_210], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2611, buf2613, buf2614, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2611
del buf2613
# Source Nodes: [input_635], Original ATen: [quantized_decomposed.quantize_per_token]
buf2615 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2601, buf2602, -128, 127, torch.int8)
buf2616 = buf2615
del buf2615
# Source Nodes: [input_636], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2617 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2616, buf2601, buf2602, -128, 127, torch.int8, torch.bfloat16)
del buf2601
del buf2602
del buf2616
buf2618 = buf2617
del buf2617
# Source Nodes: [w_dq_211], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2619 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg788_1, arg789_1, arg790_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg788_1
del arg789_1
del arg790_1
buf2620 = buf2619
del buf2619
buf2621 = buf2542; del buf2542 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2618, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2620, (4096, 1024), (1, 4096), 0), out=buf2621)
del buf2618
del buf2620
# Source Nodes: [input_638], Original ATen: [quantized_decomposed.quantize_per_token]
buf2623 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2596, buf2604, buf2605, -128, 127, torch.int8)
del buf2596
buf2624 = buf2623
del buf2623
# Source Nodes: [input_639], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2625 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2624, buf2604, buf2605, -128, 127, torch.int8, torch.bfloat16)
del buf2604
del buf2605
del buf2624
buf2626 = buf2625
del buf2625
# Source Nodes: [w_dq_212], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2627 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg791_1, arg792_1, arg793_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg791_1
del arg792_1
del arg793_1
buf2628 = buf2627
del buf2627
buf2629 = buf2534; del buf2534 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2626, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2628, (4096, 1024), (1, 4096), 0), out=buf2629)
del buf2628
buf2631 = reinterpret_tensor(buf2626, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2626 # reuse
# Source Nodes: [output_60, setitem_60, setitem_61], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2621, arg794_1, buf2629, buf2614, arg795_1, arg796_1, buf2631, 4096, grid=grid(4096), stream=stream0)
del arg794_1
del buf2614
buf2632 = buf2545; del buf2545 # reuse
# Source Nodes: [output_60], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2632, 65536, grid=grid(65536), stream=stream0)
# Source Nodes: [output_60], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2633 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2631, arg795_1, arg796_1, buf2632, False)
del arg795_1
del arg796_1
del buf2631
buf2634 = buf2633[0]
del buf2633
# Source Nodes: [choose_qparams_per_token_asymmetric_213], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2638 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2634, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2639 = buf2638[0]
buf2640 = buf2638[1]
del buf2638
# Source Nodes: [input_641], Original ATen: [quantized_decomposed.quantize_per_token]
buf2641 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2634, (1, 1, 4096), (4096, 4096, 1), 0), buf2639, buf2640, -128, 127, torch.int8)
buf2642 = buf2641
del buf2641
# Source Nodes: [input_642], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2643 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2642, buf2639, buf2640, -128, 127, torch.int8, torch.bfloat16)
del buf2639
del buf2640
del buf2642
buf2644 = buf2643
del buf2643
# Source Nodes: [w_dq_213], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2645 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg797_1, arg798_1, arg799_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg797_1
del arg798_1
del arg799_1
buf2646 = buf2645
del buf2645
buf2647 = reinterpret_tensor(buf2634, (1, 4096), (4096, 1), 0); del buf2634 # reuse
# Source Nodes: [c_213], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2644, buf2646, buf2647, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2646
buf2649 = buf2644; del buf2644 # reuse
# Source Nodes: [add_215, h_31, mean_61, mul_400, mul_401, out_29, pow_62, rsqrt_61, x_fp32_61, x_normed_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf2647, buf2561, buf2594, arg800_1, buf2649, 1, 4096, grid=grid(1), stream=stream0)
del arg800_1
# Source Nodes: [choose_qparams_per_token_asymmetric_214], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2650 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2649, torch.int8)
buf2651 = buf2650[0]
buf2652 = buf2650[1]
del buf2650
# Source Nodes: [choose_qparams_per_token_asymmetric_215], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2653 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2649, torch.int8)
buf2654 = buf2653[0]
buf2655 = buf2653[1]
del buf2653
# Source Nodes: [input_644], Original ATen: [quantized_decomposed.quantize_per_token]
buf2656 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2649, buf2651, buf2652, -128, 127, torch.int8)
buf2657 = buf2656
del buf2656
# Source Nodes: [input_645], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2658 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2657, buf2651, buf2652, -128, 127, torch.int8, torch.bfloat16)
del buf2651
del buf2652
del buf2657
buf2659 = buf2658
del buf2658
# Source Nodes: [w_dq_214], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2660 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg801_1, arg802_1, arg803_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg801_1
del arg802_1
del arg803_1
buf2661 = buf2660
del buf2660
buf2662 = reinterpret_tensor(buf2591, (1, 14336), (14336, 1), 0); del buf2591 # reuse
# Source Nodes: [c_214], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2659, buf2661, buf2662, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2661
# Source Nodes: [input_647], Original ATen: [quantized_decomposed.quantize_per_token]
buf2663 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2649, buf2654, buf2655, -128, 127, torch.int8)
buf2664 = buf2663
del buf2663
# Source Nodes: [input_648], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2665 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2664, buf2654, buf2655, -128, 127, torch.int8, torch.bfloat16)
del buf2654
del buf2655
del buf2664
buf2666 = buf2665
del buf2665
# Source Nodes: [w_dq_215], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2667 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg804_1, arg805_1, arg806_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg804_1
del arg805_1
del arg806_1
buf2668 = buf2667
del buf2667
buf2670 = buf2584; del buf2584 # reuse
# Source Nodes: [c_215, mul_402, silu_30], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2666, buf2668, buf2662, buf2670, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2662
del buf2668
# Source Nodes: [choose_qparams_per_token_asymmetric_216], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2671 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2670, torch.int8)
buf2672 = buf2671[0]
buf2673 = buf2671[1]
del buf2671
# Source Nodes: [input_650], Original ATen: [quantized_decomposed.quantize_per_token]
buf2674 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2670, buf2672, buf2673, -128, 127, torch.int8)
buf2675 = buf2674
del buf2674
# Source Nodes: [input_651], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2676 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2675, buf2672, buf2673, -128, 127, torch.int8, torch.bfloat16)
del buf2672
del buf2673
del buf2675
buf2677 = buf2676
del buf2676
# Source Nodes: [w_dq_216], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2678 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg807_1, arg808_1, arg809_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg807_1
del arg808_1
del arg809_1
buf2679 = buf2678
del buf2678
buf2680 = reinterpret_tensor(buf2666, (1, 4096), (4096, 1), 0); del buf2666 # reuse
# Source Nodes: [c_216], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2677, buf2679, buf2680, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2679
buf2682 = buf2649; del buf2649 # reuse
# Source Nodes: [add_217, h_31, mean_62, mul_403, out_29, out_30, pow_63, rsqrt_62, x_fp32_62, x_normed_62, y_31], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf2647, buf2561, buf2594, buf2680, arg810_1, buf2682, 1, 4096, grid=grid(1), stream=stream0)
del arg810_1
# Source Nodes: [choose_qparams_per_token_asymmetric_217], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2683 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8)
buf2684 = buf2683[0]
buf2685 = buf2683[1]
del buf2683
# Source Nodes: [choose_qparams_per_token_asymmetric_218], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2686 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8)
buf2687 = buf2686[0]
buf2688 = buf2686[1]
del buf2686
# Source Nodes: [choose_qparams_per_token_asymmetric_219], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2689 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2682, torch.int8)
buf2690 = buf2689[0]
buf2691 = buf2689[1]
del buf2689
buf2692 = buf2606; del buf2606 # reuse
# Source Nodes: [max_32], Original ATen: [aten.max]
triton_poi_fused_max_1.run(arg3_1, buf2692, 1, grid=grid(1), stream=stream0)
u31 = buf2692.item()
buf2693 = None
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# Source Nodes: [input_653], Original ATen: [quantized_decomposed.quantize_per_token]
buf2694 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2684, buf2685, -128, 127, torch.int8)
buf2695 = buf2694
del buf2694
# Source Nodes: [input_654], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2696 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2695, buf2684, buf2685, -128, 127, torch.int8, torch.bfloat16)
del buf2684
del buf2685
del buf2695
buf2697 = buf2696
del buf2696
# Source Nodes: [w_dq_217], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2698 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg811_1, arg812_1, arg813_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg811_1
del arg812_1
del arg813_1
buf2699 = buf2698
del buf2698
buf2700 = reinterpret_tensor(buf2659, (1, 4096), (4096, 1), 0); del buf2659 # reuse
# Source Nodes: [c_217], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2697, buf2699, buf2700, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2697
del buf2699
# Source Nodes: [input_656], Original ATen: [quantized_decomposed.quantize_per_token]
buf2701 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2687, buf2688, -128, 127, torch.int8)
buf2702 = buf2701
del buf2701
# Source Nodes: [input_657], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2703 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2702, buf2687, buf2688, -128, 127, torch.int8, torch.bfloat16)
del buf2687
del buf2688
del buf2702
buf2704 = buf2703
del buf2703
# Source Nodes: [w_dq_218], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2705 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg814_1, arg815_1, arg816_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg814_1
del arg815_1
del arg816_1
buf2706 = buf2705
del buf2705
buf2707 = buf2629; del buf2629 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2704, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2706, (4096, 1024), (1, 4096), 0), out=buf2707)
del buf2704
del buf2706
# Source Nodes: [input_659], Original ATen: [quantized_decomposed.quantize_per_token]
buf2709 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2682, buf2690, buf2691, -128, 127, torch.int8)
del buf2682
buf2710 = buf2709
del buf2709
# Source Nodes: [input_660], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2711 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2710, buf2690, buf2691, -128, 127, torch.int8, torch.bfloat16)
del buf2690
del buf2691
del buf2710
buf2712 = buf2711
del buf2711
# Source Nodes: [w_dq_219], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2713 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg817_1, arg818_1, arg819_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg817_1
del arg818_1
del arg819_1
buf2714 = buf2713
del buf2713
buf2715 = buf2621; del buf2621 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf2712, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf2714, (4096, 1024), (1, 4096), 0), out=buf2715)
del buf2714
buf2717 = reinterpret_tensor(buf2712, (1, 32, 1, 128), (4096, 128, 4096, 1), 0); del buf2712 # reuse
# Source Nodes: [output_62, setitem_62, setitem_63], Original ATen: [aten._scaled_dot_product_efficient_attention, aten.index_put]
triton_poi_fused__scaled_dot_product_efficient_attention_index_put_3.run(arg3_1, buf2707, arg820_1, buf2715, buf2700, arg821_1, arg822_1, buf2717, 4096, grid=grid(4096), stream=stream0)
del arg820_1
del buf2700
del buf2707
del buf2715
buf2718 = buf2632; del buf2632 # reuse
# Source Nodes: [output_62], Original ATen: [aten._scaled_dot_product_efficient_attention]
triton_poi_fused__scaled_dot_product_efficient_attention_4.run(arg3_1, arg2_1, buf2718, 65536, grid=grid(65536), stream=stream0)
del arg2_1
del arg3_1
# Source Nodes: [output_62], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf2719 = torch.ops.aten._scaled_dot_product_efficient_attention.default(buf2717, arg821_1, arg822_1, buf2718, False)
del arg821_1
del arg822_1
del buf2717
del buf2718
buf2720 = buf2719[0]
del buf2719
# Source Nodes: [choose_qparams_per_token_asymmetric_220], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2724 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(reinterpret_tensor(buf2720, (1, 1, 4096), (4096, 4096, 1), 0), torch.int8)
buf2725 = buf2724[0]
buf2726 = buf2724[1]
del buf2724
# Source Nodes: [input_662], Original ATen: [quantized_decomposed.quantize_per_token]
buf2727 = torch.ops.quantized_decomposed.quantize_per_token.default(reinterpret_tensor(buf2720, (1, 1, 4096), (4096, 4096, 1), 0), buf2725, buf2726, -128, 127, torch.int8)
buf2728 = buf2727
del buf2727
# Source Nodes: [input_663], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2729 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2728, buf2725, buf2726, -128, 127, torch.int8, torch.bfloat16)
del buf2725
del buf2726
del buf2728
buf2730 = buf2729
del buf2729
# Source Nodes: [w_dq_220], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2731 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg823_1, arg824_1, arg825_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg823_1
del arg824_1
del arg825_1
buf2732 = buf2731
del buf2731
buf2733 = reinterpret_tensor(buf2720, (1, 4096), (4096, 1), 0); del buf2720 # reuse
# Source Nodes: [c_220], Original ATen: [aten.mm]
triton_tem_fused_mm_2.run(buf2730, buf2732, buf2733, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2732
buf2734 = buf2561; del buf2561 # reuse
buf2736 = buf2730; del buf2730 # reuse
# Source Nodes: [add_222, h_31, h_32, mean_63, mul_413, mul_414, out_29, out_30, pow_64, rsqrt_63, x_fp32_63, x_normed_63], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf2734, buf2733, buf2647, buf2594, buf2680, arg826_1, buf2736, 1, 4096, grid=grid(1), stream=stream0)
del arg826_1
del buf2594
del buf2647
del buf2680
del buf2733
# Source Nodes: [choose_qparams_per_token_asymmetric_221], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2737 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2736, torch.int8)
buf2738 = buf2737[0]
buf2739 = buf2737[1]
del buf2737
# Source Nodes: [choose_qparams_per_token_asymmetric_222], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2740 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2736, torch.int8)
buf2741 = buf2740[0]
buf2742 = buf2740[1]
del buf2740
# Source Nodes: [input_665], Original ATen: [quantized_decomposed.quantize_per_token]
buf2743 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2736, buf2738, buf2739, -128, 127, torch.int8)
buf2744 = buf2743
del buf2743
# Source Nodes: [input_666], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2745 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2744, buf2738, buf2739, -128, 127, torch.int8, torch.bfloat16)
del buf2738
del buf2739
del buf2744
buf2746 = buf2745
del buf2745
# Source Nodes: [w_dq_221], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2747 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg827_1, arg828_1, arg829_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg827_1
del arg828_1
del arg829_1
buf2748 = buf2747
del buf2747
buf2749 = reinterpret_tensor(buf2677, (1, 14336), (14336, 1), 0); del buf2677 # reuse
# Source Nodes: [c_221], Original ATen: [aten.mm]
triton_tem_fused_mm_6.run(buf2746, buf2748, buf2749, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2746
del buf2748
# Source Nodes: [input_668], Original ATen: [quantized_decomposed.quantize_per_token]
buf2750 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2736, buf2741, buf2742, -128, 127, torch.int8)
buf2751 = buf2750
del buf2750
# Source Nodes: [input_669], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2752 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2751, buf2741, buf2742, -128, 127, torch.int8, torch.bfloat16)
del buf2741
del buf2742
del buf2751
buf2753 = buf2752
del buf2752
# Source Nodes: [w_dq_222], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2754 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg830_1, arg831_1, arg832_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg830_1
del arg831_1
del arg832_1
buf2755 = buf2754
del buf2754
buf2757 = buf2670; del buf2670 # reuse
# Source Nodes: [c_222, mul_415, silu_31], Original ATen: [aten.mm, aten.mul, aten.silu]
triton_tem_fused_mm_mul_silu_7.run(buf2753, buf2755, buf2749, buf2757, grid=torch._inductor.kernel.mm_common.mm_grid(1, 14336, meta1), stream=stream0)
del buf2749
del buf2755
# Source Nodes: [choose_qparams_per_token_asymmetric_223], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2758 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2757, torch.int8)
buf2759 = buf2758[0]
buf2760 = buf2758[1]
del buf2758
# Source Nodes: [input_671], Original ATen: [quantized_decomposed.quantize_per_token]
buf2761 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2757, buf2759, buf2760, -128, 127, torch.int8)
del buf2757
buf2762 = buf2761
del buf2761
# Source Nodes: [input_672], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2763 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2762, buf2759, buf2760, -128, 127, torch.int8, torch.bfloat16)
del buf2759
del buf2760
del buf2762
buf2764 = buf2763
del buf2763
# Source Nodes: [w_dq_223], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2765 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg833_1, arg834_1, arg835_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg833_1
del arg834_1
del arg835_1
buf2766 = buf2765
del buf2765
buf2767 = reinterpret_tensor(buf2753, (1, 4096), (4096, 1), 0); del buf2753 # reuse
# Source Nodes: [c_223], Original ATen: [aten.mm]
triton_tem_fused_mm_8.run(buf2764, buf2766, buf2767, grid=torch._inductor.kernel.mm_common.mm_grid(1, 4096, meta0), stream=stream0)
del buf2764
del buf2766
buf2769 = buf2736; del buf2736 # reuse
# Source Nodes: [add_224, h_33, mean_64, mul_416, out_31, pow_65, rsqrt_64, x_fp32_64, x_normed_64], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.pow, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf2734, buf2767, arg836_1, buf2769, 1, 4096, grid=grid(1), stream=stream0)
del arg836_1
del buf2734
del buf2767
# Source Nodes: [choose_qparams_per_token_asymmetric_224], Original ATen: [quantized_decomposed.choose_qparams_per_token_asymmetric]
buf2770 = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(buf2769, torch.int8)
buf2771 = buf2770[0]
buf2772 = buf2770[1]
del buf2770
# Source Nodes: [input_674], Original ATen: [quantized_decomposed.quantize_per_token]
buf2773 = torch.ops.quantized_decomposed.quantize_per_token.default(buf2769, buf2771, buf2772, -128, 127, torch.int8)
del buf2769
buf2774 = buf2773
del buf2773
# Source Nodes: [input_675], Original ATen: [quantized_decomposed.dequantize_per_token]
buf2775 = torch.ops.quantized_decomposed.dequantize_per_token.default(buf2774, buf2771, buf2772, -128, 127, torch.int8, torch.bfloat16)
del buf2771
del buf2772
del buf2774
buf2776 = buf2775
del buf2775
# Source Nodes: [w_dq_224], Original ATen: [quantized_decomposed.dequantize_per_channel_group]
buf2777 = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(arg837_1, arg838_1, arg839_1, -8, 7, torch.int8, 256, torch.bfloat16)
del arg837_1
del arg838_1
del arg839_1
buf2778 = buf2777
del buf2777
buf2780 = empty_strided_cuda((1, 128256), (128256, 1), torch.float32)
# Source Nodes: [c_224, logits_1], Original ATen: [aten.div, aten.mm]
triton_tem_fused_div_mm_16.run(buf2776, buf2778, buf2780, grid=torch._inductor.kernel.mm_common.mm_grid(1, 128256, meta2), stream=stream0)
del buf2776
del buf2778
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk]
buf2781 = torch.ops.aten.topk.default(buf2780, 300)
buf2782 = buf2781[0]
del buf2781
buf2784 = reinterpret_tensor(buf2692, (1, ), (1, ), 0); del buf2692 # reuse
# Source Nodes: [], Original ATen: []
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf2784)
buf2786 = empty_strided_cuda((1, 1, 16), (16, 16, 1), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf2780, buf2782, buf2786, 16, 8016, grid=grid(16), stream=stream0)
buf2787 = empty_strided_cuda((1, 1), (1, 1), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf2786, buf2787, 1, 16, grid=grid(1), stream=stream0)
buf2788 = buf2786; del buf2786 # reuse
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf2780, buf2782, buf2787, buf2788, 16, 8016, grid=grid(16), stream=stream0)
buf2789 = empty_strided_cuda((1, 1), (1, 1), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf2788, buf2789, 1, 16, grid=grid(1), stream=stream0)
del buf2788
buf2791 = empty_strided_cuda((1, 1), (1, 1), torch.int32)
# Source Nodes: [argmax, logits_2, lt, probs, q_128, to_450, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf2784, buf2780, buf2782, buf2787, buf2789, buf2791, 0, 1, 128256, grid=grid(1), stream=stream0)
del buf2780
del buf2782
del buf2784
del buf2787
del buf2789
return (buf2791, 1 + u0, 1 + u1, 1 + u2, 1 + u3, 1 + u4, 1 + u5, 1 + u6, 1 + u7, 1 + u8, 1 + u9, 1 + u10, 1 + u11, 1 + u12, 1 + u13, 1 + u14, 1 + u15, 1 + u16, 1 + u17, 1 + u18, 1 + u19, 1 + u20, 1 + u21, 1 + u22, 1 + u23, 1 + u24, 1 + u25, 1 + u26, 1 + u27, 1 + u28, 1 + u29, 1 + u30, 1 + u31, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32)
arg1_1 = rand_strided((128256, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
arg2_1 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.bool)
arg3_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int64)
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg5_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg6_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg7_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg8_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg9_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg10_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg11_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg12_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg13_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg14_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg15_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg16_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg17_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg18_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg19_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg21_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg22_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg23_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg24_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg25_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg26_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg27_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg28_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg29_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg31_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg32_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg33_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg34_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg35_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg36_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg37_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg38_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg39_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg40_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg41_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg42_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg43_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg44_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg45_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg47_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg48_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg49_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg50_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg51_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg52_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg53_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg54_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg55_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg57_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg58_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg59_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg60_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg61_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg62_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg63_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg64_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg65_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg67_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg68_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg69_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg70_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg71_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg72_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg73_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg74_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg75_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg76_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg77_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg78_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg79_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg80_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg81_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg82_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg83_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg84_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg85_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg86_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg87_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg88_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg89_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg90_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg91_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg92_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg93_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg94_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg95_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg96_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg97_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg98_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg99_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg100_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg101_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg102_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg103_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg104_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg105_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg106_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg107_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg108_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg109_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg110_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg111_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg112_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg113_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg114_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg115_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg116_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg117_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg118_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg119_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg120_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg121_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg122_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg123_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg124_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg125_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg126_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg127_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg128_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg129_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg130_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg131_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg132_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg133_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg134_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg135_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg136_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg137_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg138_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg139_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg140_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg141_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg142_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg143_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg144_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg145_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg146_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg147_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg148_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg149_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg150_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg151_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg152_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg153_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg154_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg155_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg156_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg157_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg158_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg159_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg160_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg161_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg162_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg163_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg164_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg165_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg166_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg167_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg168_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg169_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg170_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg171_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg172_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg173_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg174_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg175_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg176_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg177_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg178_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg179_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg180_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg181_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg182_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg183_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg184_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg185_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg186_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg187_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg188_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg189_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg190_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg191_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg192_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg193_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg194_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg195_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg196_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg197_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg198_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg199_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg200_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg201_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg202_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg203_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg204_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg205_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg206_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg207_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg208_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg209_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg210_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg211_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg212_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg213_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg214_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg215_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg216_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg217_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg218_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg219_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg220_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg221_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg222_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg223_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg224_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg225_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg226_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg227_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg228_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg229_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg230_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg231_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg232_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg233_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg234_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg235_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg236_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg237_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg238_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg239_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg240_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg241_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg242_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg243_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg244_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg245_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg246_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg247_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg248_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg249_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg250_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg251_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg252_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg253_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg254_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg255_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg256_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg257_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg258_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg259_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg260_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg261_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg262_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg263_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg264_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg265_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg266_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg267_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg268_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg269_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg270_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg271_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg272_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg273_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg274_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg275_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg276_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg277_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg278_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg279_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg280_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg281_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg282_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg283_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg284_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg285_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg286_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg287_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg288_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg289_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg290_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg291_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg292_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg293_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg294_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg295_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg296_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg297_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg298_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg299_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg300_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg301_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg302_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg303_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg304_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg305_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg306_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg307_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg308_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg309_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg310_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg311_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg312_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg313_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg314_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg315_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg316_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg317_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg318_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg319_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg320_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg321_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg322_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg323_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg324_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg325_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg326_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg327_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg328_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg329_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg330_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg331_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg332_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg333_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg334_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg335_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg336_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg337_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg338_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg339_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg340_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg341_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg342_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg343_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg344_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg345_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg346_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg347_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg348_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg349_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg350_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg351_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg352_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg353_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg354_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg355_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg356_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg357_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg358_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg359_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg360_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg361_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg362_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg363_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg364_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg365_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg366_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg367_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg368_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg369_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg370_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg371_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg372_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg373_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg374_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg375_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg376_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg377_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg378_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg379_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg380_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg381_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg382_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg383_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg384_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg385_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg386_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg387_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg388_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg389_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg390_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg391_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg392_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg393_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg394_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg395_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg396_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg397_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg398_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg399_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg400_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg401_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg402_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg403_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg404_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg405_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg406_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg407_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg408_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg409_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg410_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg411_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg412_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg413_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg414_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg415_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg416_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg417_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg418_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg419_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg420_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg421_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg422_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg423_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg424_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg425_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg426_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg427_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg428_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg429_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg430_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg431_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg432_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg433_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg434_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg435_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg436_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg437_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg438_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg439_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg440_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg441_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg442_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg443_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg444_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg445_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg446_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg447_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg448_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg449_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg450_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg451_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg452_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg453_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg454_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg455_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg456_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg457_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg458_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg459_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg460_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg461_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg462_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg463_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg464_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg465_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg466_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg467_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg468_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg469_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg470_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg471_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg472_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg473_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg474_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg475_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg476_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg477_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg478_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg479_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg480_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg481_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg482_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg483_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg484_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg485_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg486_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg487_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg488_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg489_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg490_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg491_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg492_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg493_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg494_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg495_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg496_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg497_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg498_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg499_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg500_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg501_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg502_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg503_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg504_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg505_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg506_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg507_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg508_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg509_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg510_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg511_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg512_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg513_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg514_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg515_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg516_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg517_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg518_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg519_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg520_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg521_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg522_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg523_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg524_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg525_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg526_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg527_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg528_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg529_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg530_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg531_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg532_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg533_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg534_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg535_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg536_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg537_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg538_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg539_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg540_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg541_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg542_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg543_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg544_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg545_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg546_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg547_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg548_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg549_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg550_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg551_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg552_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg553_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg554_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg555_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg556_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg557_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg558_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg559_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg560_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg561_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg562_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg563_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg564_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg565_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg566_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg567_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg568_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg569_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg570_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg571_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg572_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg573_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg574_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg575_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg576_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg577_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg578_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg579_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg580_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg581_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg582_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg583_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg584_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg585_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg586_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg587_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg588_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg589_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg590_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg591_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg592_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg593_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg594_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg595_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg596_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg597_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg598_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg599_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg600_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg601_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg602_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg603_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg604_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg605_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg606_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg607_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg608_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg609_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg610_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg611_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg612_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg613_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg614_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg615_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg616_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg617_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg618_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg619_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg620_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg621_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg622_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg623_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg624_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg625_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg626_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg627_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg628_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg629_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg630_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg631_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg632_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg633_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg634_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg635_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg636_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg637_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg638_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg639_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg640_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg641_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg642_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg643_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg644_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg645_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg646_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg647_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg648_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg649_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg650_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg651_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg652_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg653_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg654_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg655_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg656_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg657_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg658_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg659_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg660_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg661_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg662_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg663_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg664_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg665_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg666_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg667_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg668_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg669_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg670_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg671_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg672_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg673_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg674_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg675_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg676_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg677_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg678_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg679_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg680_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg681_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg682_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg683_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg684_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg685_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg686_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg687_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg688_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg689_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg690_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg691_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg692_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg693_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg694_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg695_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg696_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg697_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg698_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg699_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg700_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg701_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg702_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg703_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg704_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg705_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg706_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg707_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg708_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg709_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg710_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg711_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg712_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg713_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg714_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg715_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg716_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg717_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg718_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg719_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg720_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg721_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg722_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg723_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg724_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg725_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg726_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg727_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg728_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg729_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg730_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg731_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg732_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg733_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg734_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg735_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg736_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg737_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg738_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg739_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg740_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg741_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg742_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg743_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg744_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg745_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg746_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg747_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg748_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg749_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg750_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg751_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg752_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg753_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg754_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg755_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg756_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg757_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg758_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg759_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg760_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg761_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg762_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg763_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg764_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg765_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg766_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg767_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg768_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg769_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg770_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg771_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg772_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg773_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg774_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg775_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg776_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg777_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg778_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg779_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg780_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg781_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg782_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg783_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg784_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg785_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg786_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg787_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg788_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg789_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg790_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg791_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg792_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg793_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg794_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg795_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg796_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg797_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg798_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg799_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg800_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg801_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg802_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg803_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg804_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg805_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg806_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg807_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg808_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg809_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg810_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg811_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg812_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg813_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg814_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg815_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg816_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg817_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg818_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg819_1 = rand_strided((1024, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg820_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
arg821_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg822_1 = rand_strided((1, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
arg823_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg824_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg825_1 = rand_strided((4096, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg826_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg827_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg828_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg829_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg830_1 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg831_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg832_1 = rand_strided((14336, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg833_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.int8)
arg834_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg835_1 = rand_strided((4096, 56), (56, 1), device='cuda:0', dtype=torch.bfloat16)
arg836_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
arg837_1 = rand_strided((128256, 4096), (4096, 1), device='cuda:0', dtype=torch.int8)
arg838_1 = rand_strided((128256, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
arg839_1 = rand_strided((128256, 16), (16, 1), device='cuda:0', dtype=torch.bfloat16)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, 
arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, 
arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1, arg456_1, arg457_1, arg458_1, arg459_1, arg460_1, arg461_1, arg462_1, arg463_1, arg464_1, arg465_1, arg466_1, arg467_1, arg468_1, arg469_1, arg470_1, arg471_1, arg472_1, arg473_1, arg474_1, arg475_1, arg476_1, arg477_1, arg478_1, arg479_1, arg480_1, arg481_1, arg482_1, arg483_1, arg484_1, arg485_1, arg486_1, arg487_1, arg488_1, arg489_1, arg490_1, arg491_1, arg492_1, arg493_1, arg494_1, arg495_1, arg496_1, arg497_1, arg498_1, arg499_1, arg500_1, arg501_1, arg502_1, arg503_1, arg504_1, arg505_1, arg506_1, arg507_1, arg508_1, arg509_1, arg510_1, arg511_1, arg512_1, arg513_1, arg514_1, arg515_1, arg516_1, arg517_1, arg518_1, arg519_1, arg520_1, arg521_1, arg522_1, arg523_1, arg524_1, arg525_1, arg526_1, arg527_1, arg528_1, arg529_1, arg530_1, arg531_1, arg532_1, arg533_1, arg534_1, arg535_1, arg536_1, arg537_1, arg538_1, arg539_1, arg540_1, arg541_1, arg542_1, arg543_1, arg544_1, arg545_1, arg546_1, arg547_1, arg548_1, arg549_1, arg550_1, arg551_1, arg552_1, arg553_1, arg554_1, arg555_1, arg556_1, arg557_1, arg558_1, arg559_1, arg560_1, arg561_1, arg562_1, arg563_1, arg564_1, arg565_1, arg566_1, arg567_1, arg568_1, arg569_1, arg570_1, arg571_1, arg572_1, arg573_1, arg574_1, arg575_1, arg576_1, arg577_1, arg578_1, arg579_1, arg580_1, arg581_1, arg582_1, arg583_1, arg584_1, arg585_1, arg586_1, arg587_1, arg588_1, arg589_1, arg590_1, arg591_1, arg592_1, arg593_1, arg594_1, arg595_1, arg596_1, arg597_1, arg598_1, arg599_1, arg600_1, arg601_1, arg602_1, arg603_1, arg604_1, arg605_1, arg606_1, arg607_1, arg608_1, 
arg609_1, arg610_1, arg611_1, arg612_1, arg613_1, arg614_1, arg615_1, arg616_1, arg617_1, arg618_1, arg619_1, arg620_1, arg621_1, arg622_1, arg623_1, arg624_1, arg625_1, arg626_1, arg627_1, arg628_1, arg629_1, arg630_1, arg631_1, arg632_1, arg633_1, arg634_1, arg635_1, arg636_1, arg637_1, arg638_1, arg639_1, arg640_1, arg641_1, arg642_1, arg643_1, arg644_1, arg645_1, arg646_1, arg647_1, arg648_1, arg649_1, arg650_1, arg651_1, arg652_1, arg653_1, arg654_1, arg655_1, arg656_1, arg657_1, arg658_1, arg659_1, arg660_1, arg661_1, arg662_1, arg663_1, arg664_1, arg665_1, arg666_1, arg667_1, arg668_1, arg669_1, arg670_1, arg671_1, arg672_1, arg673_1, arg674_1, arg675_1, arg676_1, arg677_1, arg678_1, arg679_1, arg680_1, arg681_1, arg682_1, arg683_1, arg684_1, arg685_1, arg686_1, arg687_1, arg688_1, arg689_1, arg690_1, arg691_1, arg692_1, arg693_1, arg694_1, arg695_1, arg696_1, arg697_1, arg698_1, arg699_1, arg700_1, arg701_1, arg702_1, arg703_1, arg704_1, arg705_1, arg706_1, arg707_1, arg708_1, arg709_1, arg710_1, arg711_1, arg712_1, arg713_1, arg714_1, arg715_1, arg716_1, arg717_1, arg718_1, arg719_1, arg720_1, arg721_1, arg722_1, arg723_1, arg724_1, arg725_1, arg726_1, arg727_1, arg728_1, arg729_1, arg730_1, arg731_1, arg732_1, arg733_1, arg734_1, arg735_1, arg736_1, arg737_1, arg738_1, arg739_1, arg740_1, arg741_1, arg742_1, arg743_1, arg744_1, arg745_1, arg746_1, arg747_1, arg748_1, arg749_1, arg750_1, arg751_1, arg752_1, arg753_1, arg754_1, arg755_1, arg756_1, arg757_1, arg758_1, arg759_1, arg760_1, arg761_1, arg762_1, arg763_1, arg764_1, arg765_1, arg766_1, arg767_1, arg768_1, arg769_1, arg770_1, arg771_1, arg772_1, arg773_1, arg774_1, arg775_1, arg776_1, arg777_1, arg778_1, arg779_1, arg780_1, arg781_1, arg782_1, arg783_1, arg784_1, arg785_1, arg786_1, arg787_1, arg788_1, arg789_1, arg790_1, arg791_1, arg792_1, arg793_1, arg794_1, arg795_1, arg796_1, arg797_1, arg798_1, arg799_1, arg800_1, arg801_1, arg802_1, arg803_1, arg804_1, arg805_1, arg806_1, arg807_1, arg808_1, 
arg809_1, arg810_1, arg811_1, arg812_1, arg813_1, arg814_1, arg815_1, arg816_1, arg817_1, arg818_1, arg819_1, arg820_1, arg821_1, arg822_1, arg823_1, arg824_1, arg825_1, arg826_1, arg827_1, arg828_1, arg829_1, arg830_1, arg831_1, arg832_1, arg833_1, arg834_1, arg835_1, arg836_1, arg837_1, arg838_1, arg839_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: this only runs when the module is executed directly,
    # not when it is imported (e.g. by Inductor's own loader).
    #
    # The import is deferred into the guard so that merely importing this
    # compiled module does not load the benchmark driver.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # Hand the benchmark entry point (defined earlier in this file) to
    # Inductor's driver. NOTE(review): the first argument appears to be a
    # benchmark name and was emitted by codegen as the *string* 'None'.
    # Presumably the driver also parses command-line flags before invoking
    # benchmark_compiled_module; confirm in torch._inductor.wrapper_benchmark.
    compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment