Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
Created June 6, 2024 01:04
Show Gist options
  • Save gau-nernst/cde24dabe000f11991030609fc497a80 to your computer and use it in GitHub Desktop.
Save gau-nernst/cde24dabe000f11991030609fc497a80 to your computer and use it in GitHub Desktop.
gpt-fast w/ FP6-LLM
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases used by the rest of this generated file.
# Operator namespaces looked up on torch.ops:
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard / allocation helpers taken from torch._C._dynamo.guards:
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
# Buffer helpers taken from the torch.ops.inductor namespace:
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Single AsyncCompile instance; every kernel below is registered through
# async_compile.triton(<kernel name>, <kernel source>, device_str='cuda').
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_ubuntu/35/c35n52yrfyq3r2uvr6syoon44qkntahvhxuhmylcpsxnin76axfd.py
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_1 => convert_element_type
# mean => mean
# mul => mul
# out => fp16act_fp6weight_linear
# x => embedding
#
# Summary (derived from the kernel body below):
#   idx  = in_ptr0[0], with 32000 added when negative (device-asserted to lie in [0, 32000))
#   row  = in_ptr1[4096*idx : 4096*idx + 4096]           (loaded and widened to float32)
#   pass 1: accumulate sum(row * row) over the 4096 reduction elements
#   pass 2: out_ptr1[r] = row[r] * rsqrt(sum / 4096.0 + 1e-05) * in_ptr2[r]
# i.e. an embedding-row gather followed by an RMS-style normalisation and an
# element-wise scale by in_ptr2.
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton, so the function body and both `for` loop bodies are indented.
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0 = async_compile.triton('triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp8 * tmp8
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tl.where(rmask, tmp12, _tmp11)
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tmp13 = tl.load(in_ptr0 + (0))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp16 = tmp14 + tmp15
        tmp17 = tmp14 < 0
        tmp18 = tl.where(tmp17, tmp16, tmp14)
        tl.device_assert((0 <= tmp18) & (tmp18 < 32000), "index out of bounds: 0 <= tmp18 < 32000")
        tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tmp20.to(tl.float32)
        tmp22 = 4096.0
        tmp23 = tmp11 / tmp22
        tmp24 = 1e-05
        tmp25 = tmp23 + tmp24
        tmp26 = libdevice.rsqrt(tmp25)
        tmp27 = tmp21 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tmp30 = tmp28 * tmp29
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_ubuntu/ix/cix3cikar5jbote3hymtcnf5jtrl75mvovejwvdsj67ds2nqaynm.py
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put]
# setitem => index_put
# setitem_1 => index_put_1
# y => convert_element_type_6
#
# Summary (derived from the kernel body below), for xindex in [0, 4096) with
# x0 = xindex % 128 and x1 = xindex // 128:
#   pos  = in_ptr0[0]; tmp5 wraps it into [0, 208), tmp16 wraps it into [0, 2048)
#   (c, d) = in_ptr2[128*tmp16 + 2*(x0//2)], in_ptr2[128*tmp16 + 2*(x0//2) + 1]
#   For a pair (a, b) of adjacent in_ptr1 elements the kernel produces
#       even x0:  a*c - b*d        odd x0:  b*c + a*d
#   (a complex-multiply / rotary-style rotation of each adjacent pair).
#   - pair read at offset 4096 of in_ptr1 -> out_ptr0[x0 + 128*tmp5 + 26624*x1]
#   - pair read at offset 0 of in_ptr1, then scaled by 0.29730177875068026
#     (numerically 128 ** -0.25) -> out_ptr1[xindex]
#   - in_ptr1[8192 + xindex] is copied unchanged -> out_ptr2[x0 + 128*tmp5 + 26624*x1]
# out_ptr0 and out_ptr2 are listed as mutated args: the stores scatter into
# existing buffers at row tmp5 (presumably key/value caches -- confirm with caller).
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton.
triton_poi_fused_bmm_index_put_1 = async_compile.triton('triton_poi_fused_bmm_index_put_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4096],
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_bmm_index_put_1', 'mutated_arg_names': ['out_ptr0', 'out_ptr2'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_bmm_index_put_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 128
    x1 = (xindex // 128)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp71 = tl.load(in_ptr1 + (8192 + x2), None).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 208, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 208), "index out of bounds: 0 <= tmp5 < 208")
    tmp7 = x0 % 2
    tmp8 = tl.full([1], 0, tl.int64)
    tmp9 = tmp7 >= tmp8
    tmp10 = tl.full([1], 1, tl.int64)
    tmp11 = tmp7 < tmp10
    tmp12 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp12.to(tl.float32)
    tmp14 = tl.full([XBLOCK], 2048, tl.int32)
    tmp15 = tmp1 + tmp14
    tmp16 = tl.where(tmp4, tmp15, tmp1)
    tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
    tmp18 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tmp13 * tmp19
    tmp21 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp20 - tmp25
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp11, tmp26, tmp27)
    tmp29 = tmp7 >= tmp10
    tmp30 = tl.full([1], 2, tl.int64)
    tmp31 = tmp7 < tmp30
    tmp32 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp33 = tmp32.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp29), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
    tmp35 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp36 = tmp35.to(tl.float32)
    tmp37 = tmp33 * tmp36
    tmp38 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp39 = tmp38.to(tl.float32)
    tmp40 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp41 = tmp40.to(tl.float32)
    tmp42 = tmp39 * tmp41
    tmp43 = tmp37 + tmp42
    tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
    tmp45 = tl.where(tmp29, tmp43, tmp44)
    tmp46 = tl.where(tmp11, tmp28, tmp45)
    tmp47 = tmp46.to(tl.float32)
    tmp48 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp49 = tmp48.to(tl.float32)
    tmp50 = tmp49 * tmp19
    tmp51 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp52 = tmp51.to(tl.float32)
    tmp53 = tmp52 * tmp24
    tmp54 = tmp50 - tmp53
    tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype)
    tmp56 = tl.where(tmp11, tmp54, tmp55)
    tmp57 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp58 = tmp57.to(tl.float32)
    tmp59 = tmp58 * tmp36
    tmp60 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp61 = tmp60.to(tl.float32)
    tmp62 = tmp61 * tmp41
    tmp63 = tmp59 + tmp62
    tmp64 = tl.full(tmp63.shape, 0.0, tmp63.dtype)
    tmp65 = tl.where(tmp29, tmp63, tmp64)
    tmp66 = tl.where(tmp11, tmp56, tmp65)
    tmp67 = tmp66.to(tl.float32)
    tmp68 = 0.29730177875068026
    tmp69 = tmp67 * tmp68
    tmp70 = tmp69.to(tl.float32)
    tl.store(out_ptr0 + (x0 + (128*tmp5) + (26624*x1)), tmp47, None)
    tl.store(out_ptr1 + (x2), tmp70, None)
    tl.store(out_ptr2 + (x0 + (128*tmp5) + (26624*x1)), tmp71, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/tu/ctuxbryvllsmitwtw4gk35f6e4uug6renb6km4llzq2czdpete6i.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_13, sum_1
#
# Summary (derived from the kernel body below), for x3 in [0, 6656) with
# x1 = x3 // 208 and a reduction over r in [0, 128):
#   out_ptr0[x3] = sum_r in_ptr0[r + 128*x1] * (in_ptr1[r + 128*x3] * 0.29730177875068026)
# i.e. a dot product of one 128-element in_ptr0 row (shared by 208 consecutive
# outputs) with a scaled 128-element in_ptr1 row, accumulated in float32.
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton, so the function body and the `for` loop body are indented.
triton_red_fused_bmm_2 = async_compile.triton('triton_red_fused_bmm_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[8192, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_2(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 6656
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 208)
    x3 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (128*x1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r2 + (128*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = 0.29730177875068026
        tmp3 = tmp1 * tmp2
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp0 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask & xmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp7, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/f6/cf6hoyusglegqlcllwth4hbv7jzk7xu6ol3qg56cytjysskalpwn.py
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
# mask => index
# y => add_3, amax, convert_element_type_11, convert_element_type_9, exp, full_default, full_default_1, logical_not, sub_2, sum_2, where
#
# Summary (derived from the kernel body below), for x0 in [0, 32) and r1 in
# [0, 208) (single persistent tile, RBLOCK = 256):
#   pos   = in_ptr1[0], with 208 added when negative (device-asserted in [0, 208))
#   bias  = -inf where in_ptr2[r1 + 208*pos] == 0, else 0.0
#   s     = in_ptr0[r1 + 208*x0] + bias
#   out_ptr2[r1 + 208*x0] = exp(s - max_r(s)) / sum_r(exp(s - max_r(s)))
# i.e. a masked, max-subtracted softmax over each 208-element row, where the
# boolean mask row is selected by the index stored in in_ptr1.
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton.
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3 = async_compile.triton('triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[32, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*i32', 2: '*i1', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 32
    rnumel = 208
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (208*x0)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp1 = tmp0.to(tl.float32)
    tmp4 = tl.full([XBLOCK, RBLOCK], 208, tl.int32)
    tmp5 = tmp3 + tmp4
    tmp6 = tmp3 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp3)
    tl.device_assert((0 <= tmp7) & (tmp7 < 208), "index out of bounds: 0 <= tmp7 < 208")
    tmp9 = tl.load(in_ptr2 + (r1 + (208*tmp7)), rmask, other=0.0).to(tl.int1)
    tmp10 = tmp9 == 0
    tmp11 = float("-inf")
    tmp12 = 0.0
    tmp13 = tl.where(tmp10, tmp11, tmp12)
    tmp14 = tmp1 + tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, float("-inf"))
    tmp19 = triton_helpers.max2(tmp18, 1)[:, None]
    tmp20 = tmp15 - tmp19
    tmp21 = tl_math.exp(tmp20)
    tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK])
    tmp24 = tl.where(rmask & xmask, tmp22, 0)
    tmp25 = tl.sum(tmp24, 1)[:, None]
    tmp26 = tmp21 / tmp25
    tmp27 = tmp26.to(tl.float32)
    tmp28 = tmp27.to(tl.float32)
    tl.store(out_ptr2 + (r1 + (208*x0)), tmp28, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/3q/c3q5eet7mn3gx23dlvno24ooilcrlirl64ahmhluvxu6ynwq452m.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_14, sum_3
#
# Summary (derived from the kernel body below), for x3 in [0, 8192) with
# x0 = x3 % 128, x1 = x3 // 128 and a reduction over r in [0, 104):
#   out_ptr0[x3] = sum_r in_ptr0[r + 104*x1] * in_ptr1[x0 + 128*r + 13312*x1]
# Each x1 group covers 104 rows of in_ptr1 (13312 = 104 * 128), so every output
# is a partial sum over one 104-row slice, accumulated in float32.
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton, so the function body and the `for` loop body are indented.
triton_red_fused_bmm_4 = async_compile.triton('triton_red_fused_bmm_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[8192, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_4(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 104
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 128)
    x0 = xindex % 128
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (104*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (128*r2) + (13312*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/s7/cs7bh5txamreduogo3dbpgu62hdnknommfet7uxuau6q2gi24u6j.py
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
# out_1 => fp16act_fp6weight_linear_1
# y => mul_14, sum_3
#
# Summary (derived from the kernel body below), for x3 in [0, 4096) with
# x0 = x3 % 128 and x1 = x3 // 128:
#   out_ptr1[x3] = in_ptr0[x0 + 256*x1] + in_ptr0[x0 + 128 + 256*x1]
# i.e. it adds the two 128-wide partial results stored back to back for each
# x1 group and writes the sum to out_ptr1 (an fp16 pointer per the signature).
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton.
triton_per_fused_bmm_fp16act_fp6weight_linear_5 = async_compile.triton('triton_per_fused_bmm_fp16act_fp6weight_linear_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[4096, 2],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_bmm_fp16act_fp6weight_linear_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_bmm_fp16act_fp6weight_linear_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r2 = rindex
    x0 = xindex % 128
    x1 = (xindex // 128)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r2) + (256*x1)), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x3), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/e6/ce6earzjiin77ifcgxqhqghmkxait7rqhahkb2zasp4hadk7bw4d.py
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
# add_4 => add_5
# float_4 => convert_element_type_14
# h => add_4
# mean_1 => mean_1
# mul_11 => mul_15
# mul_12 => mul_16
# mul_13 => mul_17
# output_1 => convert_element_type_15
# rsqrt_1 => rsqrt_1
# x => embedding
#
# Summary (derived from the kernel body below):
#   idx  = in_ptr0[0], with 32000 added when negative (device-asserted to lie in [0, 32000))
#   h[r] = in_ptr1[r + 4096*idx] + in_ptr2[r]            for r in [0, 4096)
#   pass 1: accumulate sum(h * h) over the 4096 reduction elements
#   pass 2: out_ptr1[r] = h[r] * rsqrt(sum / 4096.0 + 1e-05) * in_ptr3[r]
# i.e. the gathered embedding row plus in_ptr2 is RMS-style normalised and
# scaled element-wise by in_ptr3; h is recomputed in pass 2 rather than stored.
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton, so the function body and both `for` loop bodies are indented.
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tmp15 = tl.load(in_ptr0 + (0))
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp23 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp18 = tmp16 + tmp17
        tmp19 = tmp16 < 0
        tmp20 = tl.where(tmp19, tmp18, tmp16)
        tl.device_assert((0 <= tmp20) & (tmp20 < 32000), "index out of bounds: 0 <= tmp20 < 32000")
        tmp22 = tl.load(in_ptr1 + (r0 + (4096*tmp20)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp24 = tmp22 + tmp23
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 4096.0
        tmp27 = tmp13 / tmp26
        tmp28 = 1e-05
        tmp29 = tmp27 + tmp28
        tmp30 = libdevice.rsqrt(tmp29)
        tmp31 = tmp25 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp34 = tmp32 * tmp33
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/7x/c7xzb5ucrv7a232bufgvbjemlk4jgdh7ujoroxm3v7jxuduwdfif.py
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
# out_4 => fp16act_fp6weight_linear_4
#
# Summary (derived from the kernel body below), for x0 in [0, 11008):
#   a = in_out_ptr0[x0]; b = in_ptr0[x0]              (both widened to float32)
#   in_out_ptr0[x0] = (a * sigmoid(a)) * b
# i.e. an in-place SiLU-gated product: in_out_ptr0 is both read and overwritten
# (it is the only entry in mutated_arg_names).
# NOTE: the kernel source string must keep its indentation -- it is compiled as
# Python by Triton.
triton_poi_fused_fp16act_fp6weight_linear_7 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16384],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_7', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_fp16act_fp6weight_linear_7(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/xb/cxbu5uv7truaeft3qae6exsdva5w6j3jruqooy6lvbsp53jek26v.py
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_5 => convert_element_type_18
# h => add_4
# mean_2 => mean_2
# mul_15 => mul_20
# out_5 => add_6
# out_6 => fp16act_fp6weight_linear_5
# x => embedding
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_ptr0  : int32 index; only element 0 is read. A negative value is wrapped
#              by adding 32000, then device-asserted to lie in [0, 32000).
#   in_ptr1  : fp16 table, gathered at offset r0 + 4096 * index.
#   in_ptr2, in_ptr3 : fp16 vectors added elementwise to the gathered row.
#   in_ptr4  : fp16 vector multiplied into the normalized result.
#   out_ptr1 : fp16 destination.
# Pass 1 accumulates, in fp32, the sum of squares of (row + in_ptr2 + in_ptr3).
# Pass 2 recomputes that sum per element and stores
#   value * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr4.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
        tmp16 = _tmp15 + tmp14
        _tmp15 = tl.where(rmask, tmp16, _tmp15)
    tmp15 = tl.sum(_tmp15, 1)[:, None]
    tmp17 = tl.load(in_ptr0 + (0))
    tmp18 = tl.broadcast_to(tmp17, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp25 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp20 = tmp18 + tmp19
        tmp21 = tmp18 < 0
        tmp22 = tl.where(tmp21, tmp20, tmp18)
        tl.device_assert((0 <= tmp22) & (tmp22 < 32000), "index out of bounds: 0 <= tmp22 < 32000")
        tmp24 = tl.load(in_ptr1 + (r0 + (4096*tmp22)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tmp24 + tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tmp30 = 4096.0
        tmp31 = tmp15 / tmp30
        tmp32 = 1e-05
        tmp33 = tmp31 + tmp32
        tmp34 = libdevice.rsqrt(tmp33)
        tmp35 = tmp29 * tmp34
        tmp36 = tmp35.to(tl.float32)
        tmp38 = tmp36 * tmp37
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/ko/ckonz56ntbc3wxvboqbs6j6xmjdrqma6fl5rtmem5grbfiymtcs4.py
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_8 => convert_element_type_32
# h => add_4
# h_1 => add_11
# mean_3 => mean_3
# mul_26 => mul_35
# out_5 => add_6
# out_8 => fp16act_fp6weight_linear_7
# out_9 => fp16act_fp6weight_linear_8
# x => embedding
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1) that
# also updates in_out_ptr0 in place (listed in 'mutated_arg_names').
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_out_ptr0 : fp16 vector; read as an addend, then overwritten with the sum.
#   in_ptr0     : int32 index; only element 0 is read. A negative value is
#                 wrapped by adding 32000, then asserted to lie in [0, 32000).
#   in_ptr1     : fp16 table, gathered at offset r0 + 4096 * index.
#   in_ptr2, in_ptr3 : fp16 vectors added elementwise.
#   in_ptr4     : fp16 vector multiplied into the normalized result.
#   out_ptr1, out_ptr2 : fp16 destinations; both receive the same values.
# Pass 1 computes s = row + in_out_ptr0 + in_ptr2 + in_ptr3, writes s back to
# in_out_ptr0 and accumulates the fp32 sum of squares of s.
# Pass 2 reloads s from in_out_ptr0 and stores
#   s * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr4.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i32', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp13 = tmp11 + tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp14 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
        tmp18 = _tmp17 + tmp16
        _tmp17 = tl.where(rmask, tmp18, _tmp17)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask)
    tmp17 = tl.sum(_tmp17, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = 4096.0
        tmp22 = tmp17 / tmp21
        tmp23 = 1e-05
        tmp24 = tmp22 + tmp23
        tmp25 = libdevice.rsqrt(tmp24)
        tmp26 = tmp20 * tmp25
        tmp27 = tmp26.to(tl.float32)
        tmp29 = tmp27 * tmp28
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/23/c23scxttsdz72vesaap27tazzxervcixdrt7hwfybhp2nrfu4kw4.py
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_9 => convert_element_type_36
# mean_4 => mean_4
# mul_30 => mul_40
# out_11 => add_13
# out_12 => fp16act_fp6weight_linear_10
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_ptr0, in_ptr1 : fp16 vectors summed elementwise.
#   in_ptr2          : fp16 vector multiplied into the normalized result.
#   out_ptr1         : fp16 destination.
# Pass 1 accumulates, in fp32, the sum of squares of (in_ptr0 + in_ptr1).
# Pass 2 recomputes that sum per element and stores
#   value * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr2.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3 * tmp3
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
        tmp7 = _tmp6 + tmp5
        _tmp6 = tl.where(rmask, tmp7, _tmp6)
    tmp6 = tl.sum(_tmp6, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = 4096.0
        tmp13 = tmp6 / tmp12
        tmp14 = 1e-05
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp18 = tmp17.to(tl.float32)
        tmp20 = tmp18 * tmp19
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/sg/csgnrv5nhw4eic7f275ngxcfz73e62dtja6zilbn6f5a2qhvqn54.py
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
# add_16 => add_19
# float_12 => convert_element_type_50
# h_2 => add_18
# mean_5 => mean_5
# mul_41 => mul_55
# mul_42 => mul_56
# mul_43 => mul_57
# out_11 => add_13
# output_5 => convert_element_type_51
# rsqrt_5 => rsqrt_5
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_ptr0, in_ptr1, in_ptr2 : fp16 vectors summed elementwise.
#   in_ptr3                   : fp16 vector multiplied into the normalized result.
#   out_ptr1                  : fp16 destination.
# Pass 1 accumulates, in fp32, the sum of squares of
# (in_ptr0 + in_ptr1 + in_ptr2).
# Pass 2 recomputes that sum per element and stores
#   value * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr3.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_mean_mul_rsqrt_11 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_rsqrt_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tmp10 + tmp11
        tmp14 = tmp12 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp8 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/nw/cnwia6b5wdcoewnme63eazcbz6dvzogn3stxnaty267p4vw2mmwy.py
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_13 => convert_element_type_54
# h_2 => add_18
# mean_6 => mean_6
# mul_45 => mul_60
# out_11 => add_13
# out_17 => add_20
# out_18 => fp16act_fp6weight_linear_15
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1).
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_ptr0 .. in_ptr3 : fp16 vectors summed elementwise.
#   in_ptr4            : fp16 vector multiplied into the normalized result.
#   out_ptr1           : fp16 destination.
# Pass 1 accumulates, in fp32, the sum of squares of
# (in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3).
# Pass 2 recomputes that sum per element and stores
#   value * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr4.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp14 = tmp12 + tmp13
        tmp16 = tmp14 + tmp15
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = 4096.0
        tmp21 = tmp10 / tmp20
        tmp22 = 1e-05
        tmp23 = tmp21 + tmp22
        tmp24 = libdevice.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp26 = tmp25.to(tl.float32)
        tmp28 = tmp26 * tmp27
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/yi/cyiywyx2b2ap35djo57g2tkxu3hdej3o4ootd3nr6bigecahyoeu.py
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_16 => convert_element_type_68
# h_2 => add_18
# h_3 => add_25
# mean_7 => mean_7
# mul_56 => mul_75
# out_11 => add_13
# out_17 => add_20
# out_20 => fp16act_fp6weight_linear_17
# out_21 => fp16act_fp6weight_linear_18
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1) that
# also updates in_out_ptr0 in place (listed in 'mutated_arg_names').
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_out_ptr0        : fp16 vector; read as the first addend, then
#                        overwritten with the running sum.
#   in_ptr0 .. in_ptr3 : fp16 vectors added elementwise.
#   in_ptr4            : fp16 vector multiplied into the normalized result.
#   out_ptr1, out_ptr2 : fp16 destinations; both receive the same values.
# Pass 1 computes s = in_out_ptr0 + in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3,
# writes s back to in_out_ptr0 and accumulates the fp32 sum of squares of s.
# Pass 2 reloads s from in_out_ptr0 and stores
#   s * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr4.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/74/c74rqt5v7cgztcvqz3g36kdattrzmlrq33illnplnagfadtt7nqv.py
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
# out_22 => fp16act_fp6weight_linear_19
# Pointwise kernel over xnumel = 11008 elements, updating in_out_ptr0 in place
# (listed in 'mutated_arg_names').
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_out_ptr0 : fp16 vector; read as the multiplier, then overwritten.
#   in_ptr0     : fp16 vector passed through x * sigmoid(x).
# Per element: in_out_ptr0 = (in_ptr0 * sigmoid(in_ptr0)) * in_out_ptr0,
# with the arithmetic done on values converted to fp32.
triton_poi_fused_fp16act_fp6weight_linear_14 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16384],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_fp16act_fp6weight_linear_14(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/od/codwx422sav4ibjcz7e3u2ii64iq4gtjlj6xmjifbwmqjhhlbprf.py
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_24 => convert_element_type_104
# h_4 => add_32
# h_5 => add_39
# mean_11 => mean_11
# mul_86 => mul_115
# out_23 => add_27
# out_29 => add_34
# out_32 => fp16act_fp6weight_linear_27
# out_33 => fp16act_fp6weight_linear_28
# Two-pass reduction kernel over rnumel = 4096 elements (xnumel = 1) that
# also updates in_out_ptr0 in place (listed in 'mutated_arg_names').
# Pointer dtypes below are taken from the 'signature' entry of triton_meta.
#   in_out_ptr0        : fp16 vector; read as the second addend, then
#                        overwritten with the running sum.
#   in_ptr0 .. in_ptr3 : fp16 vectors added elementwise.
#   in_ptr4            : fp16 vector multiplied into the normalized result.
#   out_ptr1, out_ptr2 : fp16 destinations; both receive the same values.
# Pass 1 computes s = in_ptr0 + in_out_ptr0 + in_ptr1 + in_ptr2 + in_ptr3,
# writes s back to in_out_ptr0 and accumulates the fp32 sum of squares of s.
# Pass 2 reloads s from in_out_ptr0 and stores
#   s * rsqrt(sum_of_squares / 4096 + 1e-05) * in_ptr4.
# NOTE: xmask is computed by the generated code but never used.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/z6/cz6oxi5i6lto4h7ib75rlkpzqsocefhcbwhm7nhdpxravdtgiuro.py
# Source Nodes: [logits_1], Original ATen: [aten.div]
# logits_1 => div_32
# Pointwise kernel over xnumel = 32000 elements, updating in_out_ptr0 in place
# (listed in 'mutated_arg_names'; declared '*fp16' in triton_meta).
# Per element: in_out_ptr0 = in_out_ptr0 * 1.25, computed in fp32.
# NOTE(review): the generated header labels this node aten.div, so 1.25 is
# presumably the reciprocal of a divisor of 0.8 - confirm against the caller.
triton_poi_fused_div_16 = async_compile.triton('triton_poi_fused_div_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[32768],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_div_16(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 32000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = 1.25
    tmp2 = tmp0 * tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/xv/cxvjbig2sbormp2kxcubuvv2h5v3vxuc4jta7llq4rnh6kwuxhuo.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax_32, convert_element_type_578
#
# Looped reduction computing a per-row maximum of thresholded values (the
# amax_32 node of the softmax listed above).
#   in_ptr0:  '*fp16' values addressed as r1 + 8000*x0, i.e. 4 rows of 8000.
#   in_ptr1:  '*fp16' buffer; only element 199 is read and used as a threshold.
#             Presumably the smallest kept top-k value -- TODO confirm.
#   out_ptr0: '*fp32'; receives one maximum per row (4 values).
# Elements strictly below the threshold are replaced by -inf before the max.
# xnumel/rnumel are overridden to 4 and 8000 in the kernel body; 4*8000 equals
# the 32000 elements handled by triton_poi_fused_div_16, so this looks like the
# first stage of a split reduction over that buffer -- TODO confirm.
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[4, 8192],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax_lt_scalar_tensor_where_17(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 4
rnumel = 8000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
_tmp8 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp0 < tmp2
tmp4 = float("-inf")
tmp5 = tl.where(tmp3, tmp4, tmp0)
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp9 = triton_helpers.maximum(_tmp8, tmp7)
_tmp8 = tl.where(rmask & xmask, tmp9, _tmp8)
tmp8 = triton_helpers.max2(_tmp8, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp8, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/hz/chzqxkcllimh3vsvmsxoviuxtefrsfgsde3el6pgijogxpaxrs7r.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax_32, convert_element_type_578
#
# Persistent reduction (no loop; RBLOCK is fixed to 4) that takes the maximum
# of four float32 values.
#   in_ptr0:  '*fp32'; elements 0..3 are read.
#   out_ptr0: '*fp32'; the maximum is written to offset 0.
# Presumably this combines the four per-row maxima produced by
# triton_red_fused__softmax_lt_scalar_tensor_where_17 into the global softmax
# max -- TODO confirm against the call site.
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[1, 4],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_lt_scalar_tensor_where_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4
RBLOCK: tl.constexpr = 4
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(rmask, tmp1, float("-inf"))
tmp4 = triton_helpers.max2(tmp3, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/za/czakkdksd2yxt6hanh7vxqjmbj7ara4d4rza6bwjwwt4stxakxtm.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, exp_32, sub_96, sum_97
#
# Looped reduction computing per-row sums of exp(value - shift) over
# thresholded values (the sub_96 / exp_32 / sum_97 nodes listed above).
#   in_ptr0:  '*fp16' values addressed as r1 + 8000*x0, i.e. 4 rows of 8000.
#   in_ptr1:  '*fp16'; element 199 is the threshold. Values strictly below it
#             are replaced by -inf (so they contribute exp(-inf) to the sum).
#   in_ptr2:  '*fp32'; element 0 is subtracted from every value before exp.
#             Presumably the global max written by
#             triton_per_fused__softmax_lt_scalar_tensor_where_18 -- TODO confirm.
#   out_ptr0: '*fp32'; receives one partial sum per row (4 values).
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[4, 8192],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax_lt_scalar_tensor_where_19(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 4
rnumel = 8000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp7 = tl.load(in_ptr2 + (0))
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp0 < tmp2
tmp4 = float("-inf")
tmp5 = tl.where(tmp3, tmp4, tmp0)
tmp6 = tmp5.to(tl.float32)
tmp9 = tmp6 - tmp8
tmp10 = tl_math.exp(tmp9)
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tl.where(rmask & xmask, tmp13, _tmp12)
tmp12 = tl.sum(_tmp12, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp12, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/aa/caad4sxzjec6myyhypd4e6odaasgzubfdw6lhsfkhrv5jqg4uc7g.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, exp_32, sub_96, sum_97
#
# Persistent reduction (no loop; RBLOCK is fixed to 4) that adds up four
# float32 values.
#   in_ptr0:  '*fp32'; elements 0..3 are read.
#   out_ptr0: '*fp32'; the total is written to offset 0.
# Presumably this combines the four partial sums produced by
# triton_red_fused__softmax_lt_scalar_tensor_where_19 into the softmax
# denominator -- TODO confirm against the call site.
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[1, 4],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_lt_scalar_tensor_where_20(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4
RBLOCK: tl.constexpr = 4
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(rmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/ch/cch6v3lnecv5tfe5ipqpowh4xge3iiqm4nrz67gnjbwlbump6mow.py
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
# argmax => argmax
# idx_next => convert_element_type_582
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, convert_element_type_579, div_33, exp_32, sub_96
# q_128 => convert_element_type_581, log1p, mul_643, neg
# truediv_1 => div_34
#
# Fused sampling step over a single row of 32000 elements (xnumel = 1,
# rnumel = 32000 are overridden in the kernel body).
#   in_out_ptr0:      '*fp16'; values are read, then overwritten in place with
#                     exp(masked - in_ptr1[0]) / in_ptr2[0] (the `probs` node).
#   in_ptr0:          '*fp16'; element 199 is the threshold. Values strictly
#                     below it are replaced by -inf before the exp.
#   in_ptr1, in_ptr2: '*fp32' scalars read at offset 0. Presumably the softmax
#                     max and denominator produced by the preceding kernels
#                     -- TODO confirm against the call site.
#   in_ptr3:          '*i64'; RNG seed read at index load_seed_offset.
#   out_ptr2:         '*i32'; receives the selected index at offset 0.
# For each element the kernel draws u = tl.rand(seed, element_index) in-kernel
# (no noise buffer is loaded), forms q = -log1p(-u) (the aten.exponential /
# q_128 node), and scores the element as prob / q. A running
# maximum_with_index over the loop followed by max_with_index yields the index
# of the largest score, which is cast to int32 and stored.
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 32768],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*i64', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 32000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
tmp1 = tl.load(in_ptr0 + (199)).to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp7 = tl.load(in_ptr1 + (0))
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
tmp11 = tl.load(in_ptr2 + (0))
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
_tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
_tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp3 = tmp0 < tmp2
tmp4 = float("-inf")
tmp5 = tl.where(tmp3, tmp4, tmp0)
tmp6 = tmp5.to(tl.float32)
tmp9 = tmp6 - tmp8
tmp10 = tl_math.exp(tmp9)
tmp13 = tmp10 / tmp12
tmp14 = tmp13.to(tl.float32)
tmp15 = tl.load(in_ptr3 + load_seed_offset)
tmp16 = r0
tmp17 = tl.rand(tmp15, (tmp16).to(tl.uint32))
tmp18 = -tmp17
tmp19 = libdevice.log1p(tmp18)
tmp20 = -1.0
tmp21 = tmp19 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp14 / tmp22
tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK])
_tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index(
_tmp25, _tmp25_index, tmp24, rindex
)
_tmp25 = tl.where(rmask, _tmp25_next, _tmp25)
_tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index)
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask)
_, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1)
tmp25 = tmp25_tmp[:, None]
tmp26 = tmp25.to(tl.int32)
tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
''', device_str='cuda')
# Presumably blocks until every kernel submitted through async_compile.triton
# in this module has finished compiling and rebinds the corresponding
# module-level names in globals() to the compiled results -- TODO confirm
# against torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# Remove the module-level AsyncCompile handle once compilation has been awaited.
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, 
arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, 
arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1 = args
args.clear()
assert_size_stride(arg0_1, (4096, ), (1, ))
assert_size_stride(arg1_1, (4096, ), (1, ))
assert_size_stride(arg2_1, (4096, ), (1, ))
assert_size_stride(arg3_1, (4096, ), (1, ))
assert_size_stride(arg4_1, (4096, ), (1, ))
assert_size_stride(arg5_1, (4096, ), (1, ))
assert_size_stride(arg6_1, (4096, ), (1, ))
assert_size_stride(arg7_1, (4096, ), (1, ))
assert_size_stride(arg8_1, (4096, ), (1, ))
assert_size_stride(arg9_1, (4096, ), (1, ))
assert_size_stride(arg10_1, (4096, ), (1, ))
assert_size_stride(arg11_1, (4096, ), (1, ))
assert_size_stride(arg12_1, (4096, ), (1, ))
assert_size_stride(arg13_1, (4096, ), (1, ))
assert_size_stride(arg14_1, (4096, ), (1, ))
assert_size_stride(arg15_1, (4096, ), (1, ))
assert_size_stride(arg16_1, (4096, ), (1, ))
assert_size_stride(arg17_1, (4096, ), (1, ))
assert_size_stride(arg18_1, (4096, ), (1, ))
assert_size_stride(arg19_1, (4096, ), (1, ))
assert_size_stride(arg20_1, (4096, ), (1, ))
assert_size_stride(arg21_1, (4096, ), (1, ))
assert_size_stride(arg22_1, (4096, ), (1, ))
assert_size_stride(arg23_1, (4096, ), (1, ))
assert_size_stride(arg24_1, (4096, ), (1, ))
assert_size_stride(arg25_1, (4096, ), (1, ))
assert_size_stride(arg26_1, (4096, ), (1, ))
assert_size_stride(arg27_1, (4096, ), (1, ))
assert_size_stride(arg28_1, (4096, ), (1, ))
assert_size_stride(arg29_1, (4096, ), (1, ))
assert_size_stride(arg30_1, (4096, ), (1, ))
assert_size_stride(arg31_1, (4096, ), (1, ))
assert_size_stride(arg32_1, (4096, ), (1, ))
assert_size_stride(arg33_1, (4096, ), (1, ))
assert_size_stride(arg34_1, (4096, ), (1, ))
assert_size_stride(arg35_1, (4096, ), (1, ))
assert_size_stride(arg36_1, (4096, ), (1, ))
assert_size_stride(arg37_1, (4096, ), (1, ))
assert_size_stride(arg38_1, (4096, ), (1, ))
assert_size_stride(arg39_1, (4096, ), (1, ))
assert_size_stride(arg40_1, (4096, ), (1, ))
assert_size_stride(arg41_1, (4096, ), (1, ))
assert_size_stride(arg42_1, (4096, ), (1, ))
assert_size_stride(arg43_1, (4096, ), (1, ))
assert_size_stride(arg44_1, (4096, ), (1, ))
assert_size_stride(arg45_1, (4096, ), (1, ))
assert_size_stride(arg46_1, (4096, ), (1, ))
assert_size_stride(arg47_1, (4096, ), (1, ))
assert_size_stride(arg48_1, (4096, ), (1, ))
assert_size_stride(arg49_1, (4096, ), (1, ))
assert_size_stride(arg50_1, (4096, ), (1, ))
assert_size_stride(arg51_1, (4096, ), (1, ))
assert_size_stride(arg52_1, (4096, ), (1, ))
assert_size_stride(arg53_1, (4096, ), (1, ))
assert_size_stride(arg54_1, (4096, ), (1, ))
assert_size_stride(arg55_1, (4096, ), (1, ))
assert_size_stride(arg56_1, (4096, ), (1, ))
assert_size_stride(arg57_1, (4096, ), (1, ))
assert_size_stride(arg58_1, (4096, ), (1, ))
assert_size_stride(arg59_1, (4096, ), (1, ))
assert_size_stride(arg60_1, (4096, ), (1, ))
assert_size_stride(arg61_1, (4096, ), (1, ))
assert_size_stride(arg62_1, (4096, ), (1, ))
assert_size_stride(arg63_1, (4096, ), (1, ))
assert_size_stride(arg64_1, (4096, ), (1, ))
assert_size_stride(arg65_1, (32000, 4096), (4096, 1))
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg67_1, (208, 208), (208, 1))
assert_size_stride(arg68_1, (12288, 768), (768, 1))
assert_size_stride(arg69_1, (12288, ), (1, ))
assert_size_stride(arg70_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg71_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg72_1, (4096, 768), (768, 1))
assert_size_stride(arg73_1, (4096, ), (1, ))
assert_size_stride(arg74_1, (11008, 768), (768, 1))
assert_size_stride(arg75_1, (11008, ), (1, ))
assert_size_stride(arg76_1, (11008, 768), (768, 1))
assert_size_stride(arg77_1, (11008, ), (1, ))
assert_size_stride(arg78_1, (4096, 2064), (2064, 1))
assert_size_stride(arg79_1, (4096, ), (1, ))
assert_size_stride(arg80_1, (12288, 768), (768, 1))
assert_size_stride(arg81_1, (12288, ), (1, ))
assert_size_stride(arg82_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg83_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg84_1, (4096, 768), (768, 1))
assert_size_stride(arg85_1, (4096, ), (1, ))
assert_size_stride(arg86_1, (11008, 768), (768, 1))
assert_size_stride(arg87_1, (11008, ), (1, ))
assert_size_stride(arg88_1, (11008, 768), (768, 1))
assert_size_stride(arg89_1, (11008, ), (1, ))
assert_size_stride(arg90_1, (4096, 2064), (2064, 1))
assert_size_stride(arg91_1, (4096, ), (1, ))
assert_size_stride(arg92_1, (12288, 768), (768, 1))
assert_size_stride(arg93_1, (12288, ), (1, ))
assert_size_stride(arg94_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg95_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg96_1, (4096, 768), (768, 1))
assert_size_stride(arg97_1, (4096, ), (1, ))
assert_size_stride(arg98_1, (11008, 768), (768, 1))
assert_size_stride(arg99_1, (11008, ), (1, ))
assert_size_stride(arg100_1, (11008, 768), (768, 1))
assert_size_stride(arg101_1, (11008, ), (1, ))
assert_size_stride(arg102_1, (4096, 2064), (2064, 1))
assert_size_stride(arg103_1, (4096, ), (1, ))
assert_size_stride(arg104_1, (12288, 768), (768, 1))
assert_size_stride(arg105_1, (12288, ), (1, ))
assert_size_stride(arg106_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg107_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg108_1, (4096, 768), (768, 1))
assert_size_stride(arg109_1, (4096, ), (1, ))
assert_size_stride(arg110_1, (11008, 768), (768, 1))
assert_size_stride(arg111_1, (11008, ), (1, ))
assert_size_stride(arg112_1, (11008, 768), (768, 1))
assert_size_stride(arg113_1, (11008, ), (1, ))
assert_size_stride(arg114_1, (4096, 2064), (2064, 1))
assert_size_stride(arg115_1, (4096, ), (1, ))
assert_size_stride(arg116_1, (12288, 768), (768, 1))
assert_size_stride(arg117_1, (12288, ), (1, ))
assert_size_stride(arg118_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg119_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg120_1, (4096, 768), (768, 1))
assert_size_stride(arg121_1, (4096, ), (1, ))
assert_size_stride(arg122_1, (11008, 768), (768, 1))
assert_size_stride(arg123_1, (11008, ), (1, ))
assert_size_stride(arg124_1, (11008, 768), (768, 1))
assert_size_stride(arg125_1, (11008, ), (1, ))
assert_size_stride(arg126_1, (4096, 2064), (2064, 1))
assert_size_stride(arg127_1, (4096, ), (1, ))
assert_size_stride(arg128_1, (12288, 768), (768, 1))
assert_size_stride(arg129_1, (12288, ), (1, ))
assert_size_stride(arg130_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg131_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg132_1, (4096, 768), (768, 1))
assert_size_stride(arg133_1, (4096, ), (1, ))
assert_size_stride(arg134_1, (11008, 768), (768, 1))
assert_size_stride(arg135_1, (11008, ), (1, ))
assert_size_stride(arg136_1, (11008, 768), (768, 1))
assert_size_stride(arg137_1, (11008, ), (1, ))
assert_size_stride(arg138_1, (4096, 2064), (2064, 1))
assert_size_stride(arg139_1, (4096, ), (1, ))
assert_size_stride(arg140_1, (12288, 768), (768, 1))
assert_size_stride(arg141_1, (12288, ), (1, ))
assert_size_stride(arg142_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg143_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg144_1, (4096, 768), (768, 1))
assert_size_stride(arg145_1, (4096, ), (1, ))
assert_size_stride(arg146_1, (11008, 768), (768, 1))
assert_size_stride(arg147_1, (11008, ), (1, ))
assert_size_stride(arg148_1, (11008, 768), (768, 1))
assert_size_stride(arg149_1, (11008, ), (1, ))
assert_size_stride(arg150_1, (4096, 2064), (2064, 1))
assert_size_stride(arg151_1, (4096, ), (1, ))
assert_size_stride(arg152_1, (12288, 768), (768, 1))
assert_size_stride(arg153_1, (12288, ), (1, ))
assert_size_stride(arg154_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg155_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg156_1, (4096, 768), (768, 1))
assert_size_stride(arg157_1, (4096, ), (1, ))
assert_size_stride(arg158_1, (11008, 768), (768, 1))
assert_size_stride(arg159_1, (11008, ), (1, ))
assert_size_stride(arg160_1, (11008, 768), (768, 1))
assert_size_stride(arg161_1, (11008, ), (1, ))
assert_size_stride(arg162_1, (4096, 2064), (2064, 1))
assert_size_stride(arg163_1, (4096, ), (1, ))
assert_size_stride(arg164_1, (12288, 768), (768, 1))
assert_size_stride(arg165_1, (12288, ), (1, ))
assert_size_stride(arg166_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg167_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg168_1, (4096, 768), (768, 1))
assert_size_stride(arg169_1, (4096, ), (1, ))
assert_size_stride(arg170_1, (11008, 768), (768, 1))
assert_size_stride(arg171_1, (11008, ), (1, ))
assert_size_stride(arg172_1, (11008, 768), (768, 1))
assert_size_stride(arg173_1, (11008, ), (1, ))
assert_size_stride(arg174_1, (4096, 2064), (2064, 1))
assert_size_stride(arg175_1, (4096, ), (1, ))
assert_size_stride(arg176_1, (12288, 768), (768, 1))
assert_size_stride(arg177_1, (12288, ), (1, ))
assert_size_stride(arg178_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg179_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg180_1, (4096, 768), (768, 1))
assert_size_stride(arg181_1, (4096, ), (1, ))
assert_size_stride(arg182_1, (11008, 768), (768, 1))
assert_size_stride(arg183_1, (11008, ), (1, ))
assert_size_stride(arg184_1, (11008, 768), (768, 1))
assert_size_stride(arg185_1, (11008, ), (1, ))
assert_size_stride(arg186_1, (4096, 2064), (2064, 1))
assert_size_stride(arg187_1, (4096, ), (1, ))
assert_size_stride(arg188_1, (12288, 768), (768, 1))
assert_size_stride(arg189_1, (12288, ), (1, ))
assert_size_stride(arg190_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg191_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg192_1, (4096, 768), (768, 1))
assert_size_stride(arg193_1, (4096, ), (1, ))
assert_size_stride(arg194_1, (11008, 768), (768, 1))
assert_size_stride(arg195_1, (11008, ), (1, ))
assert_size_stride(arg196_1, (11008, 768), (768, 1))
assert_size_stride(arg197_1, (11008, ), (1, ))
assert_size_stride(arg198_1, (4096, 2064), (2064, 1))
assert_size_stride(arg199_1, (4096, ), (1, ))
assert_size_stride(arg200_1, (12288, 768), (768, 1))
assert_size_stride(arg201_1, (12288, ), (1, ))
assert_size_stride(arg202_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg203_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg204_1, (4096, 768), (768, 1))
assert_size_stride(arg205_1, (4096, ), (1, ))
assert_size_stride(arg206_1, (11008, 768), (768, 1))
assert_size_stride(arg207_1, (11008, ), (1, ))
assert_size_stride(arg208_1, (11008, 768), (768, 1))
assert_size_stride(arg209_1, (11008, ), (1, ))
assert_size_stride(arg210_1, (4096, 2064), (2064, 1))
assert_size_stride(arg211_1, (4096, ), (1, ))
assert_size_stride(arg212_1, (12288, 768), (768, 1))
assert_size_stride(arg213_1, (12288, ), (1, ))
assert_size_stride(arg214_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg215_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg216_1, (4096, 768), (768, 1))
assert_size_stride(arg217_1, (4096, ), (1, ))
assert_size_stride(arg218_1, (11008, 768), (768, 1))
assert_size_stride(arg219_1, (11008, ), (1, ))
assert_size_stride(arg220_1, (11008, 768), (768, 1))
assert_size_stride(arg221_1, (11008, ), (1, ))
assert_size_stride(arg222_1, (4096, 2064), (2064, 1))
assert_size_stride(arg223_1, (4096, ), (1, ))
assert_size_stride(arg224_1, (12288, 768), (768, 1))
assert_size_stride(arg225_1, (12288, ), (1, ))
assert_size_stride(arg226_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg227_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg228_1, (4096, 768), (768, 1))
assert_size_stride(arg229_1, (4096, ), (1, ))
assert_size_stride(arg230_1, (11008, 768), (768, 1))
assert_size_stride(arg231_1, (11008, ), (1, ))
assert_size_stride(arg232_1, (11008, 768), (768, 1))
assert_size_stride(arg233_1, (11008, ), (1, ))
assert_size_stride(arg234_1, (4096, 2064), (2064, 1))
assert_size_stride(arg235_1, (4096, ), (1, ))
assert_size_stride(arg236_1, (12288, 768), (768, 1))
assert_size_stride(arg237_1, (12288, ), (1, ))
assert_size_stride(arg238_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg239_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg240_1, (4096, 768), (768, 1))
assert_size_stride(arg241_1, (4096, ), (1, ))
assert_size_stride(arg242_1, (11008, 768), (768, 1))
assert_size_stride(arg243_1, (11008, ), (1, ))
assert_size_stride(arg244_1, (11008, 768), (768, 1))
assert_size_stride(arg245_1, (11008, ), (1, ))
assert_size_stride(arg246_1, (4096, 2064), (2064, 1))
assert_size_stride(arg247_1, (4096, ), (1, ))
assert_size_stride(arg248_1, (12288, 768), (768, 1))
assert_size_stride(arg249_1, (12288, ), (1, ))
assert_size_stride(arg250_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg251_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg252_1, (4096, 768), (768, 1))
assert_size_stride(arg253_1, (4096, ), (1, ))
assert_size_stride(arg254_1, (11008, 768), (768, 1))
assert_size_stride(arg255_1, (11008, ), (1, ))
assert_size_stride(arg256_1, (11008, 768), (768, 1))
assert_size_stride(arg257_1, (11008, ), (1, ))
assert_size_stride(arg258_1, (4096, 2064), (2064, 1))
assert_size_stride(arg259_1, (4096, ), (1, ))
assert_size_stride(arg260_1, (12288, 768), (768, 1))
assert_size_stride(arg261_1, (12288, ), (1, ))
assert_size_stride(arg262_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg263_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg264_1, (4096, 768), (768, 1))
assert_size_stride(arg265_1, (4096, ), (1, ))
assert_size_stride(arg266_1, (11008, 768), (768, 1))
assert_size_stride(arg267_1, (11008, ), (1, ))
assert_size_stride(arg268_1, (11008, 768), (768, 1))
assert_size_stride(arg269_1, (11008, ), (1, ))
assert_size_stride(arg270_1, (4096, 2064), (2064, 1))
assert_size_stride(arg271_1, (4096, ), (1, ))
assert_size_stride(arg272_1, (12288, 768), (768, 1))
assert_size_stride(arg273_1, (12288, ), (1, ))
assert_size_stride(arg274_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg275_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg276_1, (4096, 768), (768, 1))
assert_size_stride(arg277_1, (4096, ), (1, ))
assert_size_stride(arg278_1, (11008, 768), (768, 1))
assert_size_stride(arg279_1, (11008, ), (1, ))
assert_size_stride(arg280_1, (11008, 768), (768, 1))
assert_size_stride(arg281_1, (11008, ), (1, ))
assert_size_stride(arg282_1, (4096, 2064), (2064, 1))
assert_size_stride(arg283_1, (4096, ), (1, ))
assert_size_stride(arg284_1, (12288, 768), (768, 1))
assert_size_stride(arg285_1, (12288, ), (1, ))
assert_size_stride(arg286_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg287_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg288_1, (4096, 768), (768, 1))
assert_size_stride(arg289_1, (4096, ), (1, ))
assert_size_stride(arg290_1, (11008, 768), (768, 1))
assert_size_stride(arg291_1, (11008, ), (1, ))
assert_size_stride(arg292_1, (11008, 768), (768, 1))
assert_size_stride(arg293_1, (11008, ), (1, ))
assert_size_stride(arg294_1, (4096, 2064), (2064, 1))
assert_size_stride(arg295_1, (4096, ), (1, ))
assert_size_stride(arg296_1, (12288, 768), (768, 1))
assert_size_stride(arg297_1, (12288, ), (1, ))
assert_size_stride(arg298_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg299_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg300_1, (4096, 768), (768, 1))
assert_size_stride(arg301_1, (4096, ), (1, ))
assert_size_stride(arg302_1, (11008, 768), (768, 1))
assert_size_stride(arg303_1, (11008, ), (1, ))
assert_size_stride(arg304_1, (11008, 768), (768, 1))
assert_size_stride(arg305_1, (11008, ), (1, ))
assert_size_stride(arg306_1, (4096, 2064), (2064, 1))
assert_size_stride(arg307_1, (4096, ), (1, ))
assert_size_stride(arg308_1, (12288, 768), (768, 1))
assert_size_stride(arg309_1, (12288, ), (1, ))
assert_size_stride(arg310_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg311_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg312_1, (4096, 768), (768, 1))
assert_size_stride(arg313_1, (4096, ), (1, ))
assert_size_stride(arg314_1, (11008, 768), (768, 1))
assert_size_stride(arg315_1, (11008, ), (1, ))
assert_size_stride(arg316_1, (11008, 768), (768, 1))
assert_size_stride(arg317_1, (11008, ), (1, ))
assert_size_stride(arg318_1, (4096, 2064), (2064, 1))
assert_size_stride(arg319_1, (4096, ), (1, ))
assert_size_stride(arg320_1, (12288, 768), (768, 1))
assert_size_stride(arg321_1, (12288, ), (1, ))
assert_size_stride(arg322_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg323_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg324_1, (4096, 768), (768, 1))
assert_size_stride(arg325_1, (4096, ), (1, ))
assert_size_stride(arg326_1, (11008, 768), (768, 1))
assert_size_stride(arg327_1, (11008, ), (1, ))
assert_size_stride(arg328_1, (11008, 768), (768, 1))
assert_size_stride(arg329_1, (11008, ), (1, ))
assert_size_stride(arg330_1, (4096, 2064), (2064, 1))
assert_size_stride(arg331_1, (4096, ), (1, ))
assert_size_stride(arg332_1, (12288, 768), (768, 1))
assert_size_stride(arg333_1, (12288, ), (1, ))
assert_size_stride(arg334_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg335_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg336_1, (4096, 768), (768, 1))
assert_size_stride(arg337_1, (4096, ), (1, ))
assert_size_stride(arg338_1, (11008, 768), (768, 1))
assert_size_stride(arg339_1, (11008, ), (1, ))
assert_size_stride(arg340_1, (11008, 768), (768, 1))
assert_size_stride(arg341_1, (11008, ), (1, ))
assert_size_stride(arg342_1, (4096, 2064), (2064, 1))
assert_size_stride(arg343_1, (4096, ), (1, ))
assert_size_stride(arg344_1, (12288, 768), (768, 1))
assert_size_stride(arg345_1, (12288, ), (1, ))
assert_size_stride(arg346_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg347_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg348_1, (4096, 768), (768, 1))
assert_size_stride(arg349_1, (4096, ), (1, ))
assert_size_stride(arg350_1, (11008, 768), (768, 1))
assert_size_stride(arg351_1, (11008, ), (1, ))
assert_size_stride(arg352_1, (11008, 768), (768, 1))
assert_size_stride(arg353_1, (11008, ), (1, ))
assert_size_stride(arg354_1, (4096, 2064), (2064, 1))
assert_size_stride(arg355_1, (4096, ), (1, ))
assert_size_stride(arg356_1, (12288, 768), (768, 1))
assert_size_stride(arg357_1, (12288, ), (1, ))
assert_size_stride(arg358_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg359_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg360_1, (4096, 768), (768, 1))
assert_size_stride(arg361_1, (4096, ), (1, ))
assert_size_stride(arg362_1, (11008, 768), (768, 1))
assert_size_stride(arg363_1, (11008, ), (1, ))
assert_size_stride(arg364_1, (11008, 768), (768, 1))
assert_size_stride(arg365_1, (11008, ), (1, ))
assert_size_stride(arg366_1, (4096, 2064), (2064, 1))
assert_size_stride(arg367_1, (4096, ), (1, ))
assert_size_stride(arg368_1, (12288, 768), (768, 1))
assert_size_stride(arg369_1, (12288, ), (1, ))
assert_size_stride(arg370_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg371_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg372_1, (4096, 768), (768, 1))
assert_size_stride(arg373_1, (4096, ), (1, ))
assert_size_stride(arg374_1, (11008, 768), (768, 1))
assert_size_stride(arg375_1, (11008, ), (1, ))
assert_size_stride(arg376_1, (11008, 768), (768, 1))
assert_size_stride(arg377_1, (11008, ), (1, ))
assert_size_stride(arg378_1, (4096, 2064), (2064, 1))
assert_size_stride(arg379_1, (4096, ), (1, ))
assert_size_stride(arg380_1, (12288, 768), (768, 1))
assert_size_stride(arg381_1, (12288, ), (1, ))
assert_size_stride(arg382_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg383_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg384_1, (4096, 768), (768, 1))
assert_size_stride(arg385_1, (4096, ), (1, ))
assert_size_stride(arg386_1, (11008, 768), (768, 1))
assert_size_stride(arg387_1, (11008, ), (1, ))
assert_size_stride(arg388_1, (11008, 768), (768, 1))
assert_size_stride(arg389_1, (11008, ), (1, ))
assert_size_stride(arg390_1, (4096, 2064), (2064, 1))
assert_size_stride(arg391_1, (4096, ), (1, ))
assert_size_stride(arg392_1, (12288, 768), (768, 1))
assert_size_stride(arg393_1, (12288, ), (1, ))
assert_size_stride(arg394_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg395_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg396_1, (4096, 768), (768, 1))
assert_size_stride(arg397_1, (4096, ), (1, ))
assert_size_stride(arg398_1, (11008, 768), (768, 1))
assert_size_stride(arg399_1, (11008, ), (1, ))
assert_size_stride(arg400_1, (11008, 768), (768, 1))
assert_size_stride(arg401_1, (11008, ), (1, ))
assert_size_stride(arg402_1, (4096, 2064), (2064, 1))
assert_size_stride(arg403_1, (4096, ), (1, ))
assert_size_stride(arg404_1, (12288, 768), (768, 1))
assert_size_stride(arg405_1, (12288, ), (1, ))
assert_size_stride(arg406_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg407_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg408_1, (4096, 768), (768, 1))
assert_size_stride(arg409_1, (4096, ), (1, ))
assert_size_stride(arg410_1, (11008, 768), (768, 1))
assert_size_stride(arg411_1, (11008, ), (1, ))
assert_size_stride(arg412_1, (11008, 768), (768, 1))
assert_size_stride(arg413_1, (11008, ), (1, ))
assert_size_stride(arg414_1, (4096, 2064), (2064, 1))
assert_size_stride(arg415_1, (4096, ), (1, ))
assert_size_stride(arg416_1, (12288, 768), (768, 1))
assert_size_stride(arg417_1, (12288, ), (1, ))
assert_size_stride(arg418_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg419_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg420_1, (4096, 768), (768, 1))
assert_size_stride(arg421_1, (4096, ), (1, ))
assert_size_stride(arg422_1, (11008, 768), (768, 1))
assert_size_stride(arg423_1, (11008, ), (1, ))
assert_size_stride(arg424_1, (11008, 768), (768, 1))
assert_size_stride(arg425_1, (11008, ), (1, ))
assert_size_stride(arg426_1, (4096, 2064), (2064, 1))
assert_size_stride(arg427_1, (4096, ), (1, ))
assert_size_stride(arg428_1, (12288, 768), (768, 1))
assert_size_stride(arg429_1, (12288, ), (1, ))
assert_size_stride(arg430_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg431_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg432_1, (4096, 768), (768, 1))
assert_size_stride(arg433_1, (4096, ), (1, ))
assert_size_stride(arg434_1, (11008, 768), (768, 1))
assert_size_stride(arg435_1, (11008, ), (1, ))
assert_size_stride(arg436_1, (11008, 768), (768, 1))
assert_size_stride(arg437_1, (11008, ), (1, ))
assert_size_stride(arg438_1, (4096, 2064), (2064, 1))
assert_size_stride(arg439_1, (4096, ), (1, ))
assert_size_stride(arg440_1, (12288, 768), (768, 1))
assert_size_stride(arg441_1, (12288, ), (1, ))
assert_size_stride(arg442_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg443_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg444_1, (4096, 768), (768, 1))
assert_size_stride(arg445_1, (4096, ), (1, ))
assert_size_stride(arg446_1, (11008, 768), (768, 1))
assert_size_stride(arg447_1, (11008, ), (1, ))
assert_size_stride(arg448_1, (11008, 768), (768, 1))
assert_size_stride(arg449_1, (11008, ), (1, ))
assert_size_stride(arg450_1, (4096, 2064), (2064, 1))
assert_size_stride(arg451_1, (4096, ), (1, ))
assert_size_stride(arg452_1, (32000, 768), (768, 1))
assert_size_stride(arg453_1, (32000, ), (1, ))
assert_size_stride(arg454_1, (1, ), (1, ))
assert_size_stride(arg455_1, (1, 1), (1, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16)
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0.run(arg455_1, arg65_1, arg0_1, buf1, 1, 4096, grid=grid(1), stream=stream0)
del arg0_1
# Source Nodes: [out], Original ATen: [torchao.fp16act_fp6weight_linear]
buf2 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf1, arg68_1, arg69_1, 1)
del arg68_1
del arg69_1
buf3 = buf2
del buf2
buf5 = empty_strided_cuda((32, 1, 128), (128, 4096, 1), torch.float32)
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf3, arg66_1, arg70_1, buf5, arg71_1, 4096, grid=grid(4096), stream=stream0)
del buf3
buf6 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32)
# Source Nodes: [y], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf5, arg70_1, buf6, 6656, 128, grid=grid(6656), stream=stream0)
del arg70_1
buf10 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32)
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf6, arg454_1, arg67_1, buf10, 32, 208, grid=grid(32), stream=stream0)
buf11 = empty_strided_cuda((32, 1, 128, 2), (256, 8192, 1, 128), torch.float32)
# Source Nodes: [y], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf10, arg71_1, buf11, 8192, 104, grid=grid(8192), stream=stream0)
del arg71_1
buf13 = buf1; del buf1 # reuse
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf11, buf13, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_1], Original ATen: [torchao.fp16act_fp6weight_linear]
buf14 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf13, arg72_1, arg73_1, 13)
del arg72_1
del arg73_1
buf15 = buf14
del buf14
buf17 = reinterpret_tensor(buf13, (1, 1, 4096), (4096, 4096, 1), 0); del buf13 # reuse
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6.run(arg455_1, arg65_1, buf15, arg1_1, buf17, 1, 4096, grid=grid(1), stream=stream0)
del arg1_1
# Source Nodes: [out_2], Original ATen: [torchao.fp16act_fp6weight_linear]
buf18 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg74_1, arg75_1, 1)
del arg74_1
del arg75_1
buf19 = buf18
del buf18
# Source Nodes: [out_3], Original ATen: [torchao.fp16act_fp6weight_linear]
buf20 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg76_1, arg77_1, 1)
del arg76_1
del arg77_1
buf21 = buf20
del buf20
buf22 = buf19; del buf19 # reuse
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf22, buf21, 11008, grid=grid(11008), stream=stream0)
del buf21
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
buf23 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf22, arg78_1, arg79_1, 13)
del arg78_1
del arg79_1
del buf22
buf24 = buf23
del buf23
buf26 = reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0); del buf17 # reuse
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8.run(arg455_1, arg65_1, buf15, buf24, arg2_1, buf26, 1, 4096, grid=grid(1), stream=stream0)
del arg2_1
# Source Nodes: [out_6], Original ATen: [torchao.fp16act_fp6weight_linear]
buf27 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf26, arg80_1, arg81_1, 1)
del arg80_1
del arg81_1
buf28 = buf27
del buf27
buf30 = buf5; del buf5 # reuse
# Source Nodes: [setitem_2, setitem_3, y_3], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf28, arg66_1, arg82_1, buf30, arg83_1, 4096, grid=grid(4096), stream=stream0)
del buf28
buf31 = buf10; del buf10 # reuse
# Source Nodes: [y_3], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf30, arg82_1, buf31, 6656, 128, grid=grid(6656), stream=stream0)
del arg82_1
buf35 = buf6; del buf6 # reuse
# Source Nodes: [mask, y_3], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf31, arg454_1, arg67_1, buf35, 32, 208, grid=grid(32), stream=stream0)
buf36 = buf11; del buf11 # reuse
# Source Nodes: [y_3], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf35, arg83_1, buf36, 8192, 104, grid=grid(8192), stream=stream0)
del arg83_1
buf38 = buf26; del buf26 # reuse
# Source Nodes: [out_7, y_3], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf36, buf38, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_7], Original ATen: [torchao.fp16act_fp6weight_linear]
buf39 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf38, arg84_1, arg85_1, 13)
del arg84_1
del arg85_1
buf40 = buf39
del buf39
buf41 = reinterpret_tensor(buf15, (1, 1, 4096), (4096, 4096, 1), 0); del buf15 # reuse
buf43 = buf38; del buf38 # reuse
buf46 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16)
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9.run(buf41, arg455_1, arg65_1, buf24, buf40, arg3_1, buf43, buf46, 1, 4096, grid=grid(1), stream=stream0)
del arg3_1
del arg455_1
del arg65_1
del buf24
del buf40
# Source Nodes: [out_8], Original ATen: [torchao.fp16act_fp6weight_linear]
buf44 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf43, arg86_1, arg87_1, 1)
del arg86_1
del arg87_1
buf45 = buf44
del buf44
# Source Nodes: [out_9], Original ATen: [torchao.fp16act_fp6weight_linear]
buf47 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf46, arg88_1, arg89_1, 1)
del arg88_1
del arg89_1
buf48 = buf47
del buf47
buf49 = buf45; del buf45 # reuse
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf49, buf48, 11008, grid=grid(11008), stream=stream0)
del buf48
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear]
buf50 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf49, arg90_1, arg91_1, 13)
del arg90_1
del arg91_1
del buf49
buf51 = buf50
del buf50
buf53 = buf46; del buf46 # reuse
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf41, buf51, arg4_1, buf53, 1, 4096, grid=grid(1), stream=stream0)
del arg4_1
# Source Nodes: [out_12], Original ATen: [torchao.fp16act_fp6weight_linear]
buf54 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf53, arg92_1, arg93_1, 1)
del arg92_1
del arg93_1
buf55 = buf54
del buf54
buf57 = buf30; del buf30 # reuse
# Source Nodes: [setitem_4, setitem_5, y_6], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf55, arg66_1, arg94_1, buf57, arg95_1, 4096, grid=grid(4096), stream=stream0)
del buf55
buf58 = buf35; del buf35 # reuse
# Source Nodes: [y_6], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf57, arg94_1, buf58, 6656, 128, grid=grid(6656), stream=stream0)
del arg94_1
buf62 = buf31; del buf31 # reuse
# Source Nodes: [mask, y_6], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf58, arg454_1, arg67_1, buf62, 32, 208, grid=grid(32), stream=stream0)
buf63 = buf36; del buf36 # reuse
# Source Nodes: [y_6], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf62, arg95_1, buf63, 8192, 104, grid=grid(8192), stream=stream0)
del arg95_1
buf65 = buf53; del buf53 # reuse
# Source Nodes: [out_13, y_6], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf63, buf65, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_13], Original ATen: [torchao.fp16act_fp6weight_linear]
buf66 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf65, arg96_1, arg97_1, 13)
del arg96_1
del arg97_1
buf67 = buf66
del buf66
buf69 = reinterpret_tensor(buf65, (1, 1, 4096), (4096, 4096, 1), 0); del buf65 # reuse
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf41, buf51, buf67, arg5_1, buf69, 1, 4096, grid=grid(1), stream=stream0)
del arg5_1
# Source Nodes: [out_14], Original ATen: [torchao.fp16act_fp6weight_linear]
buf70 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg98_1, arg99_1, 1)
del arg98_1
del arg99_1
buf71 = buf70
del buf70
# Source Nodes: [out_15], Original ATen: [torchao.fp16act_fp6weight_linear]
buf72 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg100_1, arg101_1, 1)
del arg100_1
del arg101_1
buf73 = buf72
del buf72
buf74 = buf71; del buf71 # reuse
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf74, buf73, 11008, grid=grid(11008), stream=stream0)
del buf73
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear]
buf75 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf74, arg102_1, arg103_1, 13)
del arg102_1
del arg103_1
del buf74
buf76 = buf75
del buf75
buf78 = reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0); del buf69 # reuse
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf41, buf51, buf67, buf76, arg6_1, buf78, 1, 4096, grid=grid(1), stream=stream0)
del arg6_1
# Source Nodes: [out_18], Original ATen: [torchao.fp16act_fp6weight_linear]
buf79 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf78, arg104_1, arg105_1, 1)
del arg104_1
del arg105_1
buf80 = buf79
del buf79
buf82 = buf57; del buf57 # reuse
# Source Nodes: [setitem_6, setitem_7, y_9], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf80, arg66_1, arg106_1, buf82, arg107_1, 4096, grid=grid(4096), stream=stream0)
del buf80
buf83 = buf62; del buf62 # reuse
# Source Nodes: [y_9], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf82, arg106_1, buf83, 6656, 128, grid=grid(6656), stream=stream0)
del arg106_1
buf87 = buf58; del buf58 # reuse
# Source Nodes: [mask, y_9], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf83, arg454_1, arg67_1, buf87, 32, 208, grid=grid(32), stream=stream0)
buf88 = buf63; del buf63 # reuse
# Source Nodes: [y_9], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf87, arg107_1, buf88, 8192, 104, grid=grid(8192), stream=stream0)
del arg107_1
buf90 = buf78; del buf78 # reuse
# Source Nodes: [out_19, y_9], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf88, buf90, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_19], Original ATen: [torchao.fp16act_fp6weight_linear]
buf91 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf90, arg108_1, arg109_1, 13)
del arg108_1
del arg109_1
buf92 = buf91
del buf91
buf93 = buf41; del buf41 # reuse
buf95 = buf90; del buf90 # reuse
buf98 = buf43; del buf43 # reuse
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf93, buf51, buf67, buf76, buf92, arg7_1, buf95, buf98, 1, 4096, grid=grid(1), stream=stream0)
del arg7_1
del buf51
del buf67
del buf76
del buf92
# Source Nodes: [out_20], Original ATen: [torchao.fp16act_fp6weight_linear]
buf96 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf95, arg110_1, arg111_1, 1)
del arg110_1
del arg111_1
buf97 = buf96
del buf96
# Source Nodes: [out_21], Original ATen: [torchao.fp16act_fp6weight_linear]
buf99 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf98, arg112_1, arg113_1, 1)
del arg112_1
del arg113_1
buf100 = buf99
del buf99
buf101 = buf100; del buf100 # reuse
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_14.run(buf101, buf97, 11008, grid=grid(11008), stream=stream0)
del buf97
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
buf102 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf101, arg114_1, arg115_1, 13)
del arg114_1
del arg115_1
del buf101
buf103 = buf102
del buf102
buf105 = buf98; del buf98 # reuse
# Source Nodes: [float_17, mean_8, mul_60, out_23, out_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf93, buf103, arg8_1, buf105, 1, 4096, grid=grid(1), stream=stream0)
del arg8_1
# Source Nodes: [out_24], Original ATen: [torchao.fp16act_fp6weight_linear]
buf106 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf105, arg116_1, arg117_1, 1)
del arg116_1
del arg117_1
buf107 = buf106
del buf106
buf109 = buf82; del buf82 # reuse
# Source Nodes: [setitem_8, setitem_9, y_12], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf107, arg66_1, arg118_1, buf109, arg119_1, 4096, grid=grid(4096), stream=stream0)
del buf107
buf110 = buf87; del buf87 # reuse
# Source Nodes: [y_12], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf109, arg118_1, buf110, 6656, 128, grid=grid(6656), stream=stream0)
del arg118_1
buf114 = buf83; del buf83 # reuse
# Source Nodes: [mask, y_12], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf110, arg454_1, arg67_1, buf114, 32, 208, grid=grid(32), stream=stream0)
buf115 = buf88; del buf88 # reuse
# Source Nodes: [y_12], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf114, arg119_1, buf115, 8192, 104, grid=grid(8192), stream=stream0)
del arg119_1
buf117 = buf105; del buf105 # reuse
# Source Nodes: [out_25, y_12], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf115, buf117, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_25], Original ATen: [torchao.fp16act_fp6weight_linear]
buf118 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf117, arg120_1, arg121_1, 13)
del arg120_1
del arg121_1
buf119 = buf118
del buf118
buf121 = reinterpret_tensor(buf117, (1, 1, 4096), (4096, 4096, 1), 0); del buf117 # reuse
# Source Nodes: [add_28, float_20, h_4, mean_9, mul_71, mul_72, mul_73, out_23, output_9, rsqrt_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf93, buf103, buf119, arg9_1, buf121, 1, 4096, grid=grid(1), stream=stream0)
del arg9_1
# Source Nodes: [out_26], Original ATen: [torchao.fp16act_fp6weight_linear]
buf122 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg122_1, arg123_1, 1)
del arg122_1
del arg123_1
buf123 = buf122
del buf122
# Source Nodes: [out_27], Original ATen: [torchao.fp16act_fp6weight_linear]
buf124 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg124_1, arg125_1, 1)
del arg124_1
del arg125_1
buf125 = buf124
del buf124
buf126 = buf123; del buf123 # reuse
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf126, buf125, 11008, grid=grid(11008), stream=stream0)
del buf125
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear]
buf127 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf126, arg126_1, arg127_1, 13)
del arg126_1
del arg127_1
del buf126
buf128 = buf127
del buf127
buf130 = reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0); del buf121 # reuse
# Source Nodes: [float_21, h_4, mean_10, mul_75, out_23, out_29, out_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf93, buf103, buf119, buf128, arg10_1, buf130, 1, 4096, grid=grid(1), stream=stream0)
del arg10_1
# Source Nodes: [out_30], Original ATen: [torchao.fp16act_fp6weight_linear]
buf131 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf130, arg128_1, arg129_1, 1)
del arg128_1
del arg129_1
buf132 = buf131
del buf131
buf134 = buf109; del buf109 # reuse
# Source Nodes: [setitem_10, setitem_11, y_15], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf132, arg66_1, arg130_1, buf134, arg131_1, 4096, grid=grid(4096), stream=stream0)
del buf132
buf135 = buf114; del buf114 # reuse
# Source Nodes: [y_15], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf134, arg130_1, buf135, 6656, 128, grid=grid(6656), stream=stream0)
del arg130_1
buf139 = buf110; del buf110 # reuse
# Source Nodes: [mask, y_15], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf135, arg454_1, arg67_1, buf139, 32, 208, grid=grid(32), stream=stream0)
buf140 = buf115; del buf115 # reuse
# Source Nodes: [y_15], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf139, arg131_1, buf140, 8192, 104, grid=grid(8192), stream=stream0)
del arg131_1
buf142 = buf130; del buf130 # reuse
# Source Nodes: [out_31, y_15], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf140, buf142, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_31], Original ATen: [torchao.fp16act_fp6weight_linear]
buf143 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf142, arg132_1, arg133_1, 13)
del arg132_1
del arg133_1
buf144 = buf143
del buf143
buf145 = reinterpret_tensor(buf103, (1, 1, 4096), (4096, 4096, 1), 0); del buf103 # reuse
buf147 = buf142; del buf142 # reuse
buf150 = buf95; del buf95 # reuse
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15.run(buf145, buf93, buf119, buf128, buf144, arg11_1, buf147, buf150, 1, 4096, grid=grid(1), stream=stream0)
del arg11_1
del buf119
del buf128
del buf144
del buf93
# Source Nodes: [out_32], Original ATen: [torchao.fp16act_fp6weight_linear]
buf148 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf147, arg134_1, arg135_1, 1)
del arg134_1
del arg135_1
buf149 = buf148
del buf148
# Source Nodes: [out_33], Original ATen: [torchao.fp16act_fp6weight_linear]
buf151 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf150, arg136_1, arg137_1, 1)
del arg136_1
del arg137_1
buf152 = buf151
del buf151
buf153 = buf149; del buf149 # reuse
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf153, buf152, 11008, grid=grid(11008), stream=stream0)
del buf152
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear]
buf154 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf153, arg138_1, arg139_1, 13)
del arg138_1
del arg139_1
del buf153
buf155 = buf154
del buf154
buf157 = buf150; del buf150 # reuse
# Source Nodes: [float_25, mean_12, mul_90, out_35, out_36], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf145, buf155, arg12_1, buf157, 1, 4096, grid=grid(1), stream=stream0)
del arg12_1
# Source Nodes: [out_36], Original ATen: [torchao.fp16act_fp6weight_linear]
buf158 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf157, arg140_1, arg141_1, 1)
del arg140_1
del arg141_1
buf159 = buf158
del buf158
buf161 = buf134; del buf134 # reuse
# Source Nodes: [setitem_12, setitem_13, y_18], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf159, arg66_1, arg142_1, buf161, arg143_1, 4096, grid=grid(4096), stream=stream0)
del buf159
buf162 = buf139; del buf139 # reuse
# Source Nodes: [y_18], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf161, arg142_1, buf162, 6656, 128, grid=grid(6656), stream=stream0)
del arg142_1
buf166 = buf135; del buf135 # reuse
# Source Nodes: [mask, y_18], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf162, arg454_1, arg67_1, buf166, 32, 208, grid=grid(32), stream=stream0)
buf167 = buf140; del buf140 # reuse
# Source Nodes: [y_18], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf166, arg143_1, buf167, 8192, 104, grid=grid(8192), stream=stream0)
del arg143_1
buf169 = buf157; del buf157 # reuse
# Source Nodes: [out_37, y_18], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf167, buf169, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_37], Original ATen: [torchao.fp16act_fp6weight_linear]
buf170 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf169, arg144_1, arg145_1, 13)
del arg144_1
del arg145_1
buf171 = buf170
del buf170
buf173 = reinterpret_tensor(buf169, (1, 1, 4096), (4096, 4096, 1), 0); del buf169 # reuse
# Source Nodes: [add_40, float_28, h_6, mean_13, mul_101, mul_102, mul_103, out_35, output_13, rsqrt_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf145, buf155, buf171, arg13_1, buf173, 1, 4096, grid=grid(1), stream=stream0)
del arg13_1
# Source Nodes: [out_38], Original ATen: [torchao.fp16act_fp6weight_linear]
buf174 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg146_1, arg147_1, 1)
del arg146_1
del arg147_1
buf175 = buf174
del buf174
# Source Nodes: [out_39], Original ATen: [torchao.fp16act_fp6weight_linear]
buf176 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg148_1, arg149_1, 1)
del arg148_1
del arg149_1
buf177 = buf176
del buf176
buf178 = buf175; del buf175 # reuse
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf178, buf177, 11008, grid=grid(11008), stream=stream0)
del buf177
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear]
buf179 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf178, arg150_1, arg151_1, 13)
del arg150_1
del arg151_1
del buf178
buf180 = buf179
del buf179
buf182 = reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0); del buf173 # reuse
# Source Nodes: [float_29, h_6, mean_14, mul_105, out_35, out_41, out_42], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf145, buf155, buf171, buf180, arg14_1, buf182, 1, 4096, grid=grid(1), stream=stream0)
del arg14_1
# Source Nodes: [out_42], Original ATen: [torchao.fp16act_fp6weight_linear]
buf183 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf182, arg152_1, arg153_1, 1)
del arg152_1
del arg153_1
buf184 = buf183
del buf183
buf186 = buf161; del buf161 # reuse
# Source Nodes: [setitem_14, setitem_15, y_21], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf184, arg66_1, arg154_1, buf186, arg155_1, 4096, grid=grid(4096), stream=stream0)
del buf184
buf187 = buf166; del buf166 # reuse
# Source Nodes: [y_21], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf186, arg154_1, buf187, 6656, 128, grid=grid(6656), stream=stream0)
del arg154_1
buf191 = buf162; del buf162 # reuse
# Source Nodes: [mask, y_21], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf187, arg454_1, arg67_1, buf191, 32, 208, grid=grid(32), stream=stream0)
buf192 = buf167; del buf167 # reuse
# Source Nodes: [y_21], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf191, arg155_1, buf192, 8192, 104, grid=grid(8192), stream=stream0)
del arg155_1
buf194 = buf182; del buf182 # reuse
# Source Nodes: [out_43, y_21], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf192, buf194, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_43], Original ATen: [torchao.fp16act_fp6weight_linear]
buf195 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf194, arg156_1, arg157_1, 13)
del arg156_1
del arg157_1
buf196 = buf195
del buf195
buf197 = buf145; del buf145 # reuse
buf199 = buf194; del buf194 # reuse
buf202 = buf147; del buf147 # reuse
# Source Nodes: [float_32, h_6, h_7, mean_15, mul_116, out_35, out_41, out_44, out_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf197, buf155, buf171, buf180, buf196, arg15_1, buf199, buf202, 1, 4096, grid=grid(1), stream=stream0)
del arg15_1
del buf155
del buf171
del buf180
del buf196
# Source Nodes: [out_44], Original ATen: [torchao.fp16act_fp6weight_linear]
buf200 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf199, arg158_1, arg159_1, 1)
del arg158_1
del arg159_1
buf201 = buf200
del buf200
# Source Nodes: [out_45], Original ATen: [torchao.fp16act_fp6weight_linear]
buf203 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf202, arg160_1, arg161_1, 1)
del arg160_1
del arg161_1
buf204 = buf203
del buf203
buf205 = buf201; del buf201 # reuse
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf205, buf204, 11008, grid=grid(11008), stream=stream0)
del buf204
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear]
buf206 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf205, arg162_1, arg163_1, 13)
del arg162_1
del arg163_1
del buf205
buf207 = buf206
del buf206
buf209 = buf202; del buf202 # reuse
# Source Nodes: [float_33, mean_16, mul_120, out_47, out_48], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf197, buf207, arg16_1, buf209, 1, 4096, grid=grid(1), stream=stream0)
del arg16_1
# Source Nodes: [out_48], Original ATen: [torchao.fp16act_fp6weight_linear]
buf210 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf209, arg164_1, arg165_1, 1)
del arg164_1
del arg165_1
buf211 = buf210
del buf210
buf213 = buf186; del buf186 # reuse
# Source Nodes: [setitem_16, setitem_17, y_24], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf211, arg66_1, arg166_1, buf213, arg167_1, 4096, grid=grid(4096), stream=stream0)
del buf211
buf214 = buf191; del buf191 # reuse
# Source Nodes: [y_24], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf213, arg166_1, buf214, 6656, 128, grid=grid(6656), stream=stream0)
del arg166_1
buf218 = buf187; del buf187 # reuse
# Source Nodes: [mask, y_24], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf214, arg454_1, arg67_1, buf218, 32, 208, grid=grid(32), stream=stream0)
buf219 = buf192; del buf192 # reuse
# Source Nodes: [y_24], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf218, arg167_1, buf219, 8192, 104, grid=grid(8192), stream=stream0)
del arg167_1
buf221 = buf209; del buf209 # reuse
# Source Nodes: [out_49, y_24], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf219, buf221, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_49], Original ATen: [torchao.fp16act_fp6weight_linear]
buf222 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf221, arg168_1, arg169_1, 13)
del arg168_1
del arg169_1
buf223 = buf222
del buf222
buf225 = reinterpret_tensor(buf221, (1, 1, 4096), (4096, 4096, 1), 0); del buf221 # reuse
# Source Nodes: [add_52, float_36, h_8, mean_17, mul_131, mul_132, mul_133, out_47, output_17, rsqrt_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf197, buf207, buf223, arg17_1, buf225, 1, 4096, grid=grid(1), stream=stream0)
del arg17_1
# Source Nodes: [out_50], Original ATen: [torchao.fp16act_fp6weight_linear]
buf226 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg170_1, arg171_1, 1)
del arg170_1
del arg171_1
buf227 = buf226
del buf226
# Source Nodes: [out_51], Original ATen: [torchao.fp16act_fp6weight_linear]
buf228 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg172_1, arg173_1, 1)
del arg172_1
del arg173_1
buf229 = buf228
del buf228
buf230 = buf227; del buf227 # reuse
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf230, buf229, 11008, grid=grid(11008), stream=stream0)
del buf229
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear]
buf231 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf230, arg174_1, arg175_1, 13)
del arg174_1
del arg175_1
del buf230
buf232 = buf231
del buf231
buf234 = reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0); del buf225 # reuse
# Source Nodes: [float_37, h_8, mean_18, mul_135, out_47, out_53, out_54], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf197, buf207, buf223, buf232, arg18_1, buf234, 1, 4096, grid=grid(1), stream=stream0)
del arg18_1
# Source Nodes: [out_54], Original ATen: [torchao.fp16act_fp6weight_linear]
buf235 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf234, arg176_1, arg177_1, 1)
del arg176_1
del arg177_1
buf236 = buf235
del buf235
buf238 = buf213; del buf213 # reuse
# Source Nodes: [setitem_18, setitem_19, y_27], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf236, arg66_1, arg178_1, buf238, arg179_1, 4096, grid=grid(4096), stream=stream0)
del buf236
buf239 = buf218; del buf218 # reuse
# Source Nodes: [y_27], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf238, arg178_1, buf239, 6656, 128, grid=grid(6656), stream=stream0)
del arg178_1
buf243 = buf214; del buf214 # reuse
# Source Nodes: [mask, y_27], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf239, arg454_1, arg67_1, buf243, 32, 208, grid=grid(32), stream=stream0)
buf244 = buf219; del buf219 # reuse
# Source Nodes: [y_27], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf243, arg179_1, buf244, 8192, 104, grid=grid(8192), stream=stream0)
del arg179_1
buf246 = buf234; del buf234 # reuse
# Source Nodes: [out_55, y_27], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf244, buf246, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_55], Original ATen: [torchao.fp16act_fp6weight_linear]
buf247 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf246, arg180_1, arg181_1, 13)
del arg180_1
del arg181_1
buf248 = buf247
del buf247
buf249 = buf197; del buf197 # reuse
buf251 = buf246; del buf246 # reuse
buf254 = buf199; del buf199 # reuse
# Source Nodes: [float_40, h_8, h_9, mean_19, mul_146, out_47, out_53, out_56, out_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf249, buf207, buf223, buf232, buf248, arg19_1, buf251, buf254, 1, 4096, grid=grid(1), stream=stream0)
del arg19_1
del buf207
del buf223
del buf232
del buf248
# Source Nodes: [out_56], Original ATen: [torchao.fp16act_fp6weight_linear]
buf252 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf251, arg182_1, arg183_1, 1)
del arg182_1
del arg183_1
buf253 = buf252
del buf252
# Source Nodes: [out_57], Original ATen: [torchao.fp16act_fp6weight_linear]
buf255 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf254, arg184_1, arg185_1, 1)
del arg184_1
del arg185_1
buf256 = buf255
del buf255
buf257 = buf253; del buf253 # reuse
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf257, buf256, 11008, grid=grid(11008), stream=stream0)
del buf256
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear]
buf258 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf257, arg186_1, arg187_1, 13)
del arg186_1
del arg187_1
del buf257
buf259 = buf258
del buf258
buf261 = buf254; del buf254 # reuse
# Source Nodes: [float_41, mean_20, mul_150, out_59, out_60], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf249, buf259, arg20_1, buf261, 1, 4096, grid=grid(1), stream=stream0)
del arg20_1
# Source Nodes: [out_60], Original ATen: [torchao.fp16act_fp6weight_linear]
buf262 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf261, arg188_1, arg189_1, 1)
del arg188_1
del arg189_1
buf263 = buf262
del buf262
buf265 = buf238; del buf238 # reuse
# Source Nodes: [setitem_20, setitem_21, y_30], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf263, arg66_1, arg190_1, buf265, arg191_1, 4096, grid=grid(4096), stream=stream0)
del buf263
buf266 = buf243; del buf243 # reuse
# Source Nodes: [y_30], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf265, arg190_1, buf266, 6656, 128, grid=grid(6656), stream=stream0)
del arg190_1
buf270 = buf239; del buf239 # reuse
# Source Nodes: [mask, y_30], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf266, arg454_1, arg67_1, buf270, 32, 208, grid=grid(32), stream=stream0)
buf271 = buf244; del buf244 # reuse
# Source Nodes: [y_30], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf270, arg191_1, buf271, 8192, 104, grid=grid(8192), stream=stream0)
del arg191_1
buf273 = buf261; del buf261 # reuse
# Source Nodes: [out_61, y_30], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf271, buf273, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_61], Original ATen: [torchao.fp16act_fp6weight_linear]
buf274 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf273, arg192_1, arg193_1, 13)
del arg192_1
del arg193_1
buf275 = buf274
del buf274
buf277 = reinterpret_tensor(buf273, (1, 1, 4096), (4096, 4096, 1), 0); del buf273 # reuse
# Source Nodes: [add_64, float_44, h_10, mean_21, mul_161, mul_162, mul_163, out_59, output_21, rsqrt_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf249, buf259, buf275, arg21_1, buf277, 1, 4096, grid=grid(1), stream=stream0)
del arg21_1
# Source Nodes: [out_62], Original ATen: [torchao.fp16act_fp6weight_linear]
buf278 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg194_1, arg195_1, 1)
del arg194_1
del arg195_1
buf279 = buf278
del buf278
# Source Nodes: [out_63], Original ATen: [torchao.fp16act_fp6weight_linear]
buf280 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg196_1, arg197_1, 1)
del arg196_1
del arg197_1
buf281 = buf280
del buf280
buf282 = buf279; del buf279 # reuse
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf282, buf281, 11008, grid=grid(11008), stream=stream0)
del buf281
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear]
buf283 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf282, arg198_1, arg199_1, 13)
del arg198_1
del arg199_1
del buf282
buf284 = buf283
del buf283
buf286 = reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0); del buf277 # reuse
# Source Nodes: [float_45, h_10, mean_22, mul_165, out_59, out_65, out_66], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf249, buf259, buf275, buf284, arg22_1, buf286, 1, 4096, grid=grid(1), stream=stream0)
del arg22_1
# Source Nodes: [out_66], Original ATen: [torchao.fp16act_fp6weight_linear]
buf287 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf286, arg200_1, arg201_1, 1)
del arg200_1
del arg201_1
buf288 = buf287
del buf287
buf290 = buf265; del buf265 # reuse
# Source Nodes: [setitem_22, setitem_23, y_33], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf288, arg66_1, arg202_1, buf290, arg203_1, 4096, grid=grid(4096), stream=stream0)
del buf288
buf291 = buf270; del buf270 # reuse
# Source Nodes: [y_33], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf290, arg202_1, buf291, 6656, 128, grid=grid(6656), stream=stream0)
del arg202_1
buf295 = buf266; del buf266 # reuse
# Source Nodes: [mask, y_33], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf291, arg454_1, arg67_1, buf295, 32, 208, grid=grid(32), stream=stream0)
buf296 = buf271; del buf271 # reuse
# Source Nodes: [y_33], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf295, arg203_1, buf296, 8192, 104, grid=grid(8192), stream=stream0)
del arg203_1
buf298 = buf286; del buf286 # reuse
# Source Nodes: [out_67, y_33], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf296, buf298, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_67], Original ATen: [torchao.fp16act_fp6weight_linear]
buf299 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf298, arg204_1, arg205_1, 13)
del arg204_1
del arg205_1
buf300 = buf299
del buf299
buf301 = buf249; del buf249 # reuse
buf303 = buf298; del buf298 # reuse
buf306 = buf251; del buf251 # reuse
# Source Nodes: [float_48, h_10, h_11, mean_23, mul_176, out_59, out_65, out_68, out_69], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf301, buf259, buf275, buf284, buf300, arg23_1, buf303, buf306, 1, 4096, grid=grid(1), stream=stream0)
del arg23_1
del buf259
del buf275
del buf284
del buf300
# Source Nodes: [out_68], Original ATen: [torchao.fp16act_fp6weight_linear]
buf304 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf303, arg206_1, arg207_1, 1)
del arg206_1
del arg207_1
buf305 = buf304
del buf304
# Source Nodes: [out_69], Original ATen: [torchao.fp16act_fp6weight_linear]
buf307 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf306, arg208_1, arg209_1, 1)
del arg208_1
del arg209_1
buf308 = buf307
del buf307
buf309 = buf305; del buf305 # reuse
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf309, buf308, 11008, grid=grid(11008), stream=stream0)
del buf308
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear]
buf310 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf309, arg210_1, arg211_1, 13)
del arg210_1
del arg211_1
del buf309
buf311 = buf310
del buf310
buf313 = buf306; del buf306 # reuse
# Source Nodes: [float_49, mean_24, mul_180, out_71, out_72], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf301, buf311, arg24_1, buf313, 1, 4096, grid=grid(1), stream=stream0)
del arg24_1
# Source Nodes: [out_72], Original ATen: [torchao.fp16act_fp6weight_linear]
buf314 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf313, arg212_1, arg213_1, 1)
del arg212_1
del arg213_1
buf315 = buf314
del buf314
buf317 = buf290; del buf290 # reuse
# Source Nodes: [setitem_24, setitem_25, y_36], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf315, arg66_1, arg214_1, buf317, arg215_1, 4096, grid=grid(4096), stream=stream0)
del buf315
buf318 = buf295; del buf295 # reuse
# Source Nodes: [y_36], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf317, arg214_1, buf318, 6656, 128, grid=grid(6656), stream=stream0)
del arg214_1
buf322 = buf291; del buf291 # reuse
# Source Nodes: [mask, y_36], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf318, arg454_1, arg67_1, buf322, 32, 208, grid=grid(32), stream=stream0)
buf323 = buf296; del buf296 # reuse
# Source Nodes: [y_36], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf322, arg215_1, buf323, 8192, 104, grid=grid(8192), stream=stream0)
del arg215_1
buf325 = buf313; del buf313 # reuse
# Source Nodes: [out_73, y_36], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf323, buf325, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_73], Original ATen: [torchao.fp16act_fp6weight_linear]
buf326 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf325, arg216_1, arg217_1, 13)
del arg216_1
del arg217_1
buf327 = buf326
del buf326
buf329 = reinterpret_tensor(buf325, (1, 1, 4096), (4096, 4096, 1), 0); del buf325 # reuse
# Source Nodes: [add_76, float_52, h_12, mean_25, mul_191, mul_192, mul_193, out_71, output_25, rsqrt_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf301, buf311, buf327, arg25_1, buf329, 1, 4096, grid=grid(1), stream=stream0)
del arg25_1
# Source Nodes: [out_74], Original ATen: [torchao.fp16act_fp6weight_linear]
buf330 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg218_1, arg219_1, 1)
del arg218_1
del arg219_1
buf331 = buf330
del buf330
# Source Nodes: [out_75], Original ATen: [torchao.fp16act_fp6weight_linear]
buf332 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg220_1, arg221_1, 1)
del arg220_1
del arg221_1
buf333 = buf332
del buf332
buf334 = buf331; del buf331 # reuse
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf334, buf333, 11008, grid=grid(11008), stream=stream0)
del buf333
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear]
buf335 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf334, arg222_1, arg223_1, 13)
del arg222_1
del arg223_1
del buf334
buf336 = buf335
del buf335
buf338 = reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0); del buf329 # reuse
# Source Nodes: [float_53, h_12, mean_26, mul_195, out_71, out_77, out_78], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf301, buf311, buf327, buf336, arg26_1, buf338, 1, 4096, grid=grid(1), stream=stream0)
del arg26_1
# Source Nodes: [out_78], Original ATen: [torchao.fp16act_fp6weight_linear]
buf339 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf338, arg224_1, arg225_1, 1)
del arg224_1
del arg225_1
buf340 = buf339
del buf339
buf342 = buf317; del buf317 # reuse
# Source Nodes: [setitem_26, setitem_27, y_39], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf340, arg66_1, arg226_1, buf342, arg227_1, 4096, grid=grid(4096), stream=stream0)
del buf340
buf343 = buf322; del buf322 # reuse
# Source Nodes: [y_39], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf342, arg226_1, buf343, 6656, 128, grid=grid(6656), stream=stream0)
del arg226_1
buf347 = buf318; del buf318 # reuse
# Source Nodes: [mask, y_39], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf343, arg454_1, arg67_1, buf347, 32, 208, grid=grid(32), stream=stream0)
buf348 = buf323; del buf323 # reuse
# Source Nodes: [y_39], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf347, arg227_1, buf348, 8192, 104, grid=grid(8192), stream=stream0)
del arg227_1
buf350 = buf338; del buf338 # reuse
# Source Nodes: [out_79, y_39], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf348, buf350, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_79], Original ATen: [torchao.fp16act_fp6weight_linear]
buf351 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf350, arg228_1, arg229_1, 13)
del arg228_1
del arg229_1
buf352 = buf351
del buf351
buf353 = buf301; del buf301 # reuse
buf355 = buf350; del buf350 # reuse
buf358 = buf303; del buf303 # reuse
# Source Nodes: [float_56, h_12, h_13, mean_27, mul_206, out_71, out_77, out_80, out_81], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf353, buf311, buf327, buf336, buf352, arg27_1, buf355, buf358, 1, 4096, grid=grid(1), stream=stream0)
del arg27_1
del buf311
del buf327
del buf336
del buf352
# Source Nodes: [out_80], Original ATen: [torchao.fp16act_fp6weight_linear]
buf356 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf355, arg230_1, arg231_1, 1)
del arg230_1
del arg231_1
buf357 = buf356
del buf356
# Source Nodes: [out_81], Original ATen: [torchao.fp16act_fp6weight_linear]
buf359 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf358, arg232_1, arg233_1, 1)
del arg232_1
del arg233_1
buf360 = buf359
del buf359
buf361 = buf357; del buf357 # reuse
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf361, buf360, 11008, grid=grid(11008), stream=stream0)
del buf360
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear]
buf362 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf361, arg234_1, arg235_1, 13)
del arg234_1
del arg235_1
del buf361
buf363 = buf362
del buf362
buf365 = buf358; del buf358 # reuse
# Source Nodes: [float_57, mean_28, mul_210, out_83, out_84], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf353, buf363, arg28_1, buf365, 1, 4096, grid=grid(1), stream=stream0)
del arg28_1
# Source Nodes: [out_84], Original ATen: [torchao.fp16act_fp6weight_linear]
buf366 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf365, arg236_1, arg237_1, 1)
del arg236_1
del arg237_1
buf367 = buf366
del buf366
buf369 = buf342; del buf342 # reuse
# Source Nodes: [setitem_28, setitem_29, y_42], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf367, arg66_1, arg238_1, buf369, arg239_1, 4096, grid=grid(4096), stream=stream0)
del buf367
buf370 = buf347; del buf347 # reuse
# Source Nodes: [y_42], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf369, arg238_1, buf370, 6656, 128, grid=grid(6656), stream=stream0)
del arg238_1
buf374 = buf343; del buf343 # reuse
# Source Nodes: [mask, y_42], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf370, arg454_1, arg67_1, buf374, 32, 208, grid=grid(32), stream=stream0)
buf375 = buf348; del buf348 # reuse
# Source Nodes: [y_42], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf374, arg239_1, buf375, 8192, 104, grid=grid(8192), stream=stream0)
del arg239_1
buf377 = buf365; del buf365 # reuse
# Source Nodes: [out_85, y_42], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf375, buf377, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_85], Original ATen: [torchao.fp16act_fp6weight_linear]
buf378 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf377, arg240_1, arg241_1, 13)
del arg240_1
del arg241_1
buf379 = buf378
del buf378
buf381 = reinterpret_tensor(buf377, (1, 1, 4096), (4096, 4096, 1), 0); del buf377 # reuse
# Source Nodes: [add_88, float_60, h_14, mean_29, mul_221, mul_222, mul_223, out_83, output_29, rsqrt_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf353, buf363, buf379, arg29_1, buf381, 1, 4096, grid=grid(1), stream=stream0)
del arg29_1
# Source Nodes: [out_86], Original ATen: [torchao.fp16act_fp6weight_linear]
buf382 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg242_1, arg243_1, 1)
del arg242_1
del arg243_1
buf383 = buf382
del buf382
# Source Nodes: [out_87], Original ATen: [torchao.fp16act_fp6weight_linear]
buf384 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg244_1, arg245_1, 1)
del arg244_1
del arg245_1
buf385 = buf384
del buf384
buf386 = buf383; del buf383 # reuse
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf386, buf385, 11008, grid=grid(11008), stream=stream0)
del buf385
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear]
buf387 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf386, arg246_1, arg247_1, 13)
del arg246_1
del arg247_1
del buf386
buf388 = buf387
del buf387
buf390 = reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0); del buf381 # reuse
# Source Nodes: [float_61, h_14, mean_30, mul_225, out_83, out_89, out_90], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf353, buf363, buf379, buf388, arg30_1, buf390, 1, 4096, grid=grid(1), stream=stream0)
del arg30_1
# Source Nodes: [out_90], Original ATen: [torchao.fp16act_fp6weight_linear]
buf391 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf390, arg248_1, arg249_1, 1)
del arg248_1
del arg249_1
buf392 = buf391
del buf391
buf394 = buf369; del buf369 # reuse
# Source Nodes: [setitem_30, setitem_31, y_45], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf392, arg66_1, arg250_1, buf394, arg251_1, 4096, grid=grid(4096), stream=stream0)
del buf392
buf395 = buf374; del buf374 # reuse
# Source Nodes: [y_45], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf394, arg250_1, buf395, 6656, 128, grid=grid(6656), stream=stream0)
del arg250_1
buf399 = buf370; del buf370 # reuse
# Source Nodes: [mask, y_45], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf395, arg454_1, arg67_1, buf399, 32, 208, grid=grid(32), stream=stream0)
buf400 = buf375; del buf375 # reuse
# Source Nodes: [y_45], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf399, arg251_1, buf400, 8192, 104, grid=grid(8192), stream=stream0)
del arg251_1
buf402 = buf390; del buf390 # reuse
# Source Nodes: [out_91, y_45], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf400, buf402, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_91], Original ATen: [torchao.fp16act_fp6weight_linear]
buf403 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf402, arg252_1, arg253_1, 13)
del arg252_1
del arg253_1
buf404 = buf403
del buf403
buf405 = buf353; del buf353 # reuse
buf407 = buf402; del buf402 # reuse
buf410 = buf355; del buf355 # reuse
# Source Nodes: [float_64, h_14, h_15, mean_31, mul_236, out_83, out_89, out_92, out_93], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf405, buf363, buf379, buf388, buf404, arg31_1, buf407, buf410, 1, 4096, grid=grid(1), stream=stream0)
del arg31_1
del buf363
del buf379
del buf388
del buf404
# Source Nodes: [out_92], Original ATen: [torchao.fp16act_fp6weight_linear]
buf408 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf407, arg254_1, arg255_1, 1)
del arg254_1
del arg255_1
buf409 = buf408
del buf408
# Source Nodes: [out_93], Original ATen: [torchao.fp16act_fp6weight_linear]
buf411 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf410, arg256_1, arg257_1, 1)
del arg256_1
del arg257_1
buf412 = buf411
del buf411
buf413 = buf409; del buf409 # reuse
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf413, buf412, 11008, grid=grid(11008), stream=stream0)
del buf412
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear]
buf414 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf413, arg258_1, arg259_1, 13)
del arg258_1
del arg259_1
del buf413
buf415 = buf414
del buf414
buf417 = buf410; del buf410 # reuse
# Source Nodes: [float_65, mean_32, mul_240, out_95, out_96], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf405, buf415, arg32_1, buf417, 1, 4096, grid=grid(1), stream=stream0)
del arg32_1
# Source Nodes: [out_96], Original ATen: [torchao.fp16act_fp6weight_linear]
buf418 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf417, arg260_1, arg261_1, 1)
del arg260_1
del arg261_1
buf419 = buf418
del buf418
buf421 = buf394; del buf394 # reuse
# Source Nodes: [setitem_32, setitem_33, y_48], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf419, arg66_1, arg262_1, buf421, arg263_1, 4096, grid=grid(4096), stream=stream0)
del buf419
buf422 = buf399; del buf399 # reuse
# Source Nodes: [y_48], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf421, arg262_1, buf422, 6656, 128, grid=grid(6656), stream=stream0)
del arg262_1
buf426 = buf395; del buf395 # reuse
# Source Nodes: [mask, y_48], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf422, arg454_1, arg67_1, buf426, 32, 208, grid=grid(32), stream=stream0)
buf427 = buf400; del buf400 # reuse
# Source Nodes: [y_48], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf426, arg263_1, buf427, 8192, 104, grid=grid(8192), stream=stream0)
del arg263_1
buf429 = buf417; del buf417 # reuse
# Source Nodes: [out_97, y_48], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf427, buf429, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_97], Original ATen: [torchao.fp16act_fp6weight_linear]
buf430 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf429, arg264_1, arg265_1, 13)
del arg264_1
del arg265_1
buf431 = buf430
del buf430
buf433 = reinterpret_tensor(buf429, (1, 1, 4096), (4096, 4096, 1), 0); del buf429 # reuse
# Source Nodes: [add_100, float_68, h_16, mean_33, mul_251, mul_252, mul_253, out_95, output_33, rsqrt_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf405, buf415, buf431, arg33_1, buf433, 1, 4096, grid=grid(1), stream=stream0)
del arg33_1
# Source Nodes: [out_98], Original ATen: [torchao.fp16act_fp6weight_linear]
buf434 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg266_1, arg267_1, 1)
del arg266_1
del arg267_1
buf435 = buf434
del buf434
# Source Nodes: [out_99], Original ATen: [torchao.fp16act_fp6weight_linear]
buf436 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg268_1, arg269_1, 1)
del arg268_1
del arg269_1
buf437 = buf436
del buf436
buf438 = buf435; del buf435 # reuse
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf438, buf437, 11008, grid=grid(11008), stream=stream0)
del buf437
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear]
buf439 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf438, arg270_1, arg271_1, 13)
del arg270_1
del arg271_1
del buf438
buf440 = buf439
del buf439
buf442 = reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0); del buf433 # reuse
# Source Nodes: [float_69, h_16, mean_34, mul_255, out_101, out_102, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf405, buf415, buf431, buf440, arg34_1, buf442, 1, 4096, grid=grid(1), stream=stream0)
del arg34_1
# Source Nodes: [out_102], Original ATen: [torchao.fp16act_fp6weight_linear]
buf443 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf442, arg272_1, arg273_1, 1)
del arg272_1
del arg273_1
buf444 = buf443
del buf443
buf446 = buf421; del buf421 # reuse
# Source Nodes: [setitem_34, setitem_35, y_51], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf444, arg66_1, arg274_1, buf446, arg275_1, 4096, grid=grid(4096), stream=stream0)
del buf444
buf447 = buf426; del buf426 # reuse
# Source Nodes: [y_51], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf446, arg274_1, buf447, 6656, 128, grid=grid(6656), stream=stream0)
del arg274_1
buf451 = buf422; del buf422 # reuse
# Source Nodes: [mask, y_51], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf447, arg454_1, arg67_1, buf451, 32, 208, grid=grid(32), stream=stream0)
buf452 = buf427; del buf427 # reuse
# Source Nodes: [y_51], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf451, arg275_1, buf452, 8192, 104, grid=grid(8192), stream=stream0)
del arg275_1
buf454 = buf442; del buf442 # reuse
# Source Nodes: [out_103, y_51], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf452, buf454, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_103], Original ATen: [torchao.fp16act_fp6weight_linear]
buf455 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf454, arg276_1, arg277_1, 13)
del arg276_1
del arg277_1
buf456 = buf455
del buf455
buf457 = buf405; del buf405 # reuse
buf459 = buf454; del buf454 # reuse
buf462 = buf407; del buf407 # reuse
# Source Nodes: [float_72, h_16, h_17, mean_35, mul_266, out_101, out_104, out_105, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf457, buf415, buf431, buf440, buf456, arg35_1, buf459, buf462, 1, 4096, grid=grid(1), stream=stream0)
del arg35_1
del buf415
del buf431
del buf440
del buf456
# Source Nodes: [out_104], Original ATen: [torchao.fp16act_fp6weight_linear]
buf460 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf459, arg278_1, arg279_1, 1)
del arg278_1
del arg279_1
buf461 = buf460
del buf460
# Source Nodes: [out_105], Original ATen: [torchao.fp16act_fp6weight_linear]
buf463 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf462, arg280_1, arg281_1, 1)
del arg280_1
del arg281_1
buf464 = buf463
del buf463
buf465 = buf461; del buf461 # reuse
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf465, buf464, 11008, grid=grid(11008), stream=stream0)
del buf464
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear]
buf466 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf465, arg282_1, arg283_1, 13)
del arg282_1
del arg283_1
del buf465
buf467 = buf466
del buf466
buf469 = buf462; del buf462 # reuse
# Source Nodes: [float_73, mean_36, mul_270, out_107, out_108], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf457, buf467, arg36_1, buf469, 1, 4096, grid=grid(1), stream=stream0)
del arg36_1
# Source Nodes: [out_108], Original ATen: [torchao.fp16act_fp6weight_linear]
buf470 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf469, arg284_1, arg285_1, 1)
del arg284_1
del arg285_1
buf471 = buf470
del buf470
buf473 = buf446; del buf446 # reuse
# Source Nodes: [setitem_36, setitem_37, y_54], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf471, arg66_1, arg286_1, buf473, arg287_1, 4096, grid=grid(4096), stream=stream0)
del buf471
buf474 = buf451; del buf451 # reuse
# Source Nodes: [y_54], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf473, arg286_1, buf474, 6656, 128, grid=grid(6656), stream=stream0)
del arg286_1
buf478 = buf447; del buf447 # reuse
# Source Nodes: [mask, y_54], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf474, arg454_1, arg67_1, buf478, 32, 208, grid=grid(32), stream=stream0)
buf479 = buf452; del buf452 # reuse
# Source Nodes: [y_54], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf478, arg287_1, buf479, 8192, 104, grid=grid(8192), stream=stream0)
del arg287_1
buf481 = buf469; del buf469 # reuse
# Source Nodes: [out_109, y_54], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf479, buf481, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_109], Original ATen: [torchao.fp16act_fp6weight_linear]
buf482 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf481, arg288_1, arg289_1, 13)
del arg288_1
del arg289_1
buf483 = buf482
del buf482
buf485 = reinterpret_tensor(buf481, (1, 1, 4096), (4096, 4096, 1), 0); del buf481 # reuse
# Source Nodes: [add_112, float_76, h_18, mean_37, mul_281, mul_282, mul_283, out_107, output_37, rsqrt_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf457, buf467, buf483, arg37_1, buf485, 1, 4096, grid=grid(1), stream=stream0)
del arg37_1
# Source Nodes: [out_110], Original ATen: [torchao.fp16act_fp6weight_linear]
buf486 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg290_1, arg291_1, 1)
del arg290_1
del arg291_1
buf487 = buf486
del buf486
# Source Nodes: [out_111], Original ATen: [torchao.fp16act_fp6weight_linear]
buf488 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg292_1, arg293_1, 1)
del arg292_1
del arg293_1
buf489 = buf488
del buf488
buf490 = buf487; del buf487 # reuse
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf490, buf489, 11008, grid=grid(11008), stream=stream0)
del buf489
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear]
buf491 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf490, arg294_1, arg295_1, 13)
del arg294_1
del arg295_1
del buf490
buf492 = buf491
del buf491
buf494 = reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0); del buf485 # reuse
# Source Nodes: [float_77, h_18, mean_38, mul_285, out_107, out_113, out_114], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf457, buf467, buf483, buf492, arg38_1, buf494, 1, 4096, grid=grid(1), stream=stream0)
del arg38_1
# Source Nodes: [out_114], Original ATen: [torchao.fp16act_fp6weight_linear]
buf495 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf494, arg296_1, arg297_1, 1)
del arg296_1
del arg297_1
buf496 = buf495
del buf495
buf498 = buf473; del buf473 # reuse
# Source Nodes: [setitem_38, setitem_39, y_57], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf496, arg66_1, arg298_1, buf498, arg299_1, 4096, grid=grid(4096), stream=stream0)
del buf496
buf499 = buf478; del buf478 # reuse
# Source Nodes: [y_57], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf498, arg298_1, buf499, 6656, 128, grid=grid(6656), stream=stream0)
del arg298_1
buf503 = buf474; del buf474 # reuse
# Source Nodes: [mask, y_57], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf499, arg454_1, arg67_1, buf503, 32, 208, grid=grid(32), stream=stream0)
buf504 = buf479; del buf479 # reuse
# Source Nodes: [y_57], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf503, arg299_1, buf504, 8192, 104, grid=grid(8192), stream=stream0)
del arg299_1
buf506 = buf494; del buf494 # reuse
# Source Nodes: [out_115, y_57], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf504, buf506, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_115], Original ATen: [torchao.fp16act_fp6weight_linear]
buf507 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf506, arg300_1, arg301_1, 13)
del arg300_1
del arg301_1
buf508 = buf507
del buf507
buf509 = buf457; del buf457 # reuse
buf511 = buf506; del buf506 # reuse
buf514 = buf459; del buf459 # reuse
# Source Nodes: [float_80, h_18, h_19, mean_39, mul_296, out_107, out_113, out_116, out_117], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf509, buf467, buf483, buf492, buf508, arg39_1, buf511, buf514, 1, 4096, grid=grid(1), stream=stream0)
del arg39_1
del buf467
del buf483
del buf492
del buf508
# Source Nodes: [out_116], Original ATen: [torchao.fp16act_fp6weight_linear]
buf512 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf511, arg302_1, arg303_1, 1)
del arg302_1
del arg303_1
buf513 = buf512
del buf512
# Source Nodes: [out_117], Original ATen: [torchao.fp16act_fp6weight_linear]
buf515 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf514, arg304_1, arg305_1, 1)
del arg304_1
del arg305_1
buf516 = buf515
del buf515
buf517 = buf513; del buf513 # reuse
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf517, buf516, 11008, grid=grid(11008), stream=stream0)
del buf516
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear]
buf518 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf517, arg306_1, arg307_1, 13)
del arg306_1
del arg307_1
del buf517
buf519 = buf518
del buf518
buf521 = buf514; del buf514 # reuse
# Source Nodes: [float_81, mean_40, mul_300, out_119, out_120], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf509, buf519, arg40_1, buf521, 1, 4096, grid=grid(1), stream=stream0)
del arg40_1
# Source Nodes: [out_120], Original ATen: [torchao.fp16act_fp6weight_linear]
buf522 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf521, arg308_1, arg309_1, 1)
del arg308_1
del arg309_1
buf523 = buf522
del buf522
buf525 = buf498; del buf498 # reuse
# Source Nodes: [setitem_40, setitem_41, y_60], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf523, arg66_1, arg310_1, buf525, arg311_1, 4096, grid=grid(4096), stream=stream0)
del buf523
buf526 = buf503; del buf503 # reuse
# Source Nodes: [y_60], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf525, arg310_1, buf526, 6656, 128, grid=grid(6656), stream=stream0)
del arg310_1
buf530 = buf499; del buf499 # reuse
# Source Nodes: [mask, y_60], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf526, arg454_1, arg67_1, buf530, 32, 208, grid=grid(32), stream=stream0)
buf531 = buf504; del buf504 # reuse
# Source Nodes: [y_60], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf530, arg311_1, buf531, 8192, 104, grid=grid(8192), stream=stream0)
del arg311_1
buf533 = buf521; del buf521 # reuse
# Source Nodes: [out_121, y_60], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf531, buf533, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_121], Original ATen: [torchao.fp16act_fp6weight_linear]
buf534 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf533, arg312_1, arg313_1, 13)
del arg312_1
del arg313_1
buf535 = buf534
del buf534
buf537 = reinterpret_tensor(buf533, (1, 1, 4096), (4096, 4096, 1), 0); del buf533 # reuse
# Source Nodes: [add_124, float_84, h_20, mean_41, mul_311, mul_312, mul_313, out_119, output_41, rsqrt_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf509, buf519, buf535, arg41_1, buf537, 1, 4096, grid=grid(1), stream=stream0)
del arg41_1
# Source Nodes: [out_122], Original ATen: [torchao.fp16act_fp6weight_linear]
buf538 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg314_1, arg315_1, 1)
del arg314_1
del arg315_1
buf539 = buf538
del buf538
# Source Nodes: [out_123], Original ATen: [torchao.fp16act_fp6weight_linear]
buf540 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg316_1, arg317_1, 1)
del arg316_1
del arg317_1
buf541 = buf540
del buf540
buf542 = buf539; del buf539 # reuse
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf542, buf541, 11008, grid=grid(11008), stream=stream0)
del buf541
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear]
buf543 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf542, arg318_1, arg319_1, 13)
del arg318_1
del arg319_1
del buf542
buf544 = buf543
del buf543
buf546 = reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0); del buf537 # reuse
# Source Nodes: [float_85, h_20, mean_42, mul_315, out_119, out_125, out_126], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf509, buf519, buf535, buf544, arg42_1, buf546, 1, 4096, grid=grid(1), stream=stream0)
del arg42_1
# Source Nodes: [out_126], Original ATen: [torchao.fp16act_fp6weight_linear]
buf547 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf546, arg320_1, arg321_1, 1)
del arg320_1
del arg321_1
buf548 = buf547
del buf547
buf550 = buf525; del buf525 # reuse
# Source Nodes: [setitem_42, setitem_43, y_63], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf548, arg66_1, arg322_1, buf550, arg323_1, 4096, grid=grid(4096), stream=stream0)
del buf548
buf551 = buf530; del buf530 # reuse
# Source Nodes: [y_63], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf550, arg322_1, buf551, 6656, 128, grid=grid(6656), stream=stream0)
del arg322_1
buf555 = buf526; del buf526 # reuse
# Source Nodes: [mask, y_63], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf551, arg454_1, arg67_1, buf555, 32, 208, grid=grid(32), stream=stream0)
buf556 = buf531; del buf531 # reuse
# Source Nodes: [y_63], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf555, arg323_1, buf556, 8192, 104, grid=grid(8192), stream=stream0)
del arg323_1
buf558 = buf546; del buf546 # reuse
# Source Nodes: [out_127, y_63], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf556, buf558, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_127], Original ATen: [torchao.fp16act_fp6weight_linear]
buf559 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf558, arg324_1, arg325_1, 13)
del arg324_1
del arg325_1
buf560 = buf559
del buf559
buf561 = buf509; del buf509 # reuse
buf563 = buf558; del buf558 # reuse
buf566 = buf511; del buf511 # reuse
# Source Nodes: [float_88, h_20, h_21, mean_43, mul_326, out_119, out_125, out_128, out_129], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf561, buf519, buf535, buf544, buf560, arg43_1, buf563, buf566, 1, 4096, grid=grid(1), stream=stream0)
del arg43_1
del buf519
del buf535
del buf544
del buf560
# Source Nodes: [out_128], Original ATen: [torchao.fp16act_fp6weight_linear]
buf564 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf563, arg326_1, arg327_1, 1)
del arg326_1
del arg327_1
buf565 = buf564
del buf564
# Source Nodes: [out_129], Original ATen: [torchao.fp16act_fp6weight_linear]
buf567 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf566, arg328_1, arg329_1, 1)
del arg328_1
del arg329_1
buf568 = buf567
del buf567
buf569 = buf565; del buf565 # reuse
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf569, buf568, 11008, grid=grid(11008), stream=stream0)
del buf568
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear]
buf570 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf569, arg330_1, arg331_1, 13)
del arg330_1
del arg331_1
del buf569
buf571 = buf570
del buf570
buf573 = buf566; del buf566 # reuse
# Source Nodes: [float_89, mean_44, mul_330, out_131, out_132], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf561, buf571, arg44_1, buf573, 1, 4096, grid=grid(1), stream=stream0)
del arg44_1
# Source Nodes: [out_132], Original ATen: [torchao.fp16act_fp6weight_linear]
buf574 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf573, arg332_1, arg333_1, 1)
del arg332_1
del arg333_1
buf575 = buf574
del buf574
buf577 = buf550; del buf550 # reuse
# Source Nodes: [setitem_44, setitem_45, y_66], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf575, arg66_1, arg334_1, buf577, arg335_1, 4096, grid=grid(4096), stream=stream0)
del buf575
buf578 = buf555; del buf555 # reuse
# Source Nodes: [y_66], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf577, arg334_1, buf578, 6656, 128, grid=grid(6656), stream=stream0)
del arg334_1
buf582 = buf551; del buf551 # reuse
# Source Nodes: [mask, y_66], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf578, arg454_1, arg67_1, buf582, 32, 208, grid=grid(32), stream=stream0)
buf583 = buf556; del buf556 # reuse
# Source Nodes: [y_66], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf582, arg335_1, buf583, 8192, 104, grid=grid(8192), stream=stream0)
del arg335_1
buf585 = buf573; del buf573 # reuse
# Source Nodes: [out_133, y_66], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf583, buf585, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_133], Original ATen: [torchao.fp16act_fp6weight_linear]
buf586 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf585, arg336_1, arg337_1, 13)
del arg336_1
del arg337_1
buf587 = buf586
del buf586
buf589 = reinterpret_tensor(buf585, (1, 1, 4096), (4096, 4096, 1), 0); del buf585 # reuse
# Source Nodes: [add_136, float_92, h_22, mean_45, mul_341, mul_342, mul_343, out_131, output_45, rsqrt_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf561, buf571, buf587, arg45_1, buf589, 1, 4096, grid=grid(1), stream=stream0)
del arg45_1
# Source Nodes: [out_134], Original ATen: [torchao.fp16act_fp6weight_linear]
buf590 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg338_1, arg339_1, 1)
del arg338_1
del arg339_1
buf591 = buf590
del buf590
# Source Nodes: [out_135], Original ATen: [torchao.fp16act_fp6weight_linear]
buf592 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg340_1, arg341_1, 1)
del arg340_1
del arg341_1
buf593 = buf592
del buf592
buf594 = buf591; del buf591 # reuse
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf594, buf593, 11008, grid=grid(11008), stream=stream0)
del buf593
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear]
buf595 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf594, arg342_1, arg343_1, 13)
del arg342_1
del arg343_1
del buf594
buf596 = buf595
del buf595
buf598 = reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0); del buf589 # reuse
# Source Nodes: [float_93, h_22, mean_46, mul_345, out_131, out_137, out_138], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf561, buf571, buf587, buf596, arg46_1, buf598, 1, 4096, grid=grid(1), stream=stream0)
del arg46_1
# Source Nodes: [out_138], Original ATen: [torchao.fp16act_fp6weight_linear]
buf599 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf598, arg344_1, arg345_1, 1)
del arg344_1
del arg345_1
buf600 = buf599
del buf599
buf602 = buf577; del buf577 # reuse
# Source Nodes: [setitem_46, setitem_47, y_69], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf600, arg66_1, arg346_1, buf602, arg347_1, 4096, grid=grid(4096), stream=stream0)
del buf600
buf603 = buf582; del buf582 # reuse
# Source Nodes: [y_69], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf602, arg346_1, buf603, 6656, 128, grid=grid(6656), stream=stream0)
del arg346_1
buf607 = buf578; del buf578 # reuse
# Source Nodes: [mask, y_69], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf603, arg454_1, arg67_1, buf607, 32, 208, grid=grid(32), stream=stream0)
buf608 = buf583; del buf583 # reuse
# Source Nodes: [y_69], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf607, arg347_1, buf608, 8192, 104, grid=grid(8192), stream=stream0)
del arg347_1
buf610 = buf598; del buf598 # reuse
# Source Nodes: [out_139, y_69], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf608, buf610, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_139], Original ATen: [torchao.fp16act_fp6weight_linear]
buf611 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf610, arg348_1, arg349_1, 13)
del arg348_1
del arg349_1
buf612 = buf611
del buf611
buf613 = buf561; del buf561 # reuse
buf615 = buf610; del buf610 # reuse
buf618 = buf563; del buf563 # reuse
# Source Nodes: [float_96, h_22, h_23, mean_47, mul_356, out_131, out_137, out_140, out_141], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf613, buf571, buf587, buf596, buf612, arg47_1, buf615, buf618, 1, 4096, grid=grid(1), stream=stream0)
del arg47_1
del buf571
del buf587
del buf596
del buf612
# Source Nodes: [out_140], Original ATen: [torchao.fp16act_fp6weight_linear]
buf616 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf615, arg350_1, arg351_1, 1)
del arg350_1
del arg351_1
buf617 = buf616
del buf616
# Source Nodes: [out_141], Original ATen: [torchao.fp16act_fp6weight_linear]
buf619 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf618, arg352_1, arg353_1, 1)
del arg352_1
del arg353_1
buf620 = buf619
del buf619
buf621 = buf617; del buf617 # reuse
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf621, buf620, 11008, grid=grid(11008), stream=stream0)
del buf620
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear]
buf622 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf621, arg354_1, arg355_1, 13)
del arg354_1
del arg355_1
del buf621
buf623 = buf622
del buf622
buf625 = buf618; del buf618 # reuse
# Source Nodes: [float_97, mean_48, mul_360, out_143, out_144], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf613, buf623, arg48_1, buf625, 1, 4096, grid=grid(1), stream=stream0)
del arg48_1
# Source Nodes: [out_144], Original ATen: [torchao.fp16act_fp6weight_linear]
buf626 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf625, arg356_1, arg357_1, 1)
del arg356_1
del arg357_1
buf627 = buf626
del buf626
buf629 = buf602; del buf602 # reuse
# Source Nodes: [setitem_48, setitem_49, y_72], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf627, arg66_1, arg358_1, buf629, arg359_1, 4096, grid=grid(4096), stream=stream0)
del buf627
buf630 = buf607; del buf607 # reuse
# Source Nodes: [y_72], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf629, arg358_1, buf630, 6656, 128, grid=grid(6656), stream=stream0)
del arg358_1
buf634 = buf603; del buf603 # reuse
# Source Nodes: [mask, y_72], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf630, arg454_1, arg67_1, buf634, 32, 208, grid=grid(32), stream=stream0)
buf635 = buf608; del buf608 # reuse
# Source Nodes: [y_72], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf634, arg359_1, buf635, 8192, 104, grid=grid(8192), stream=stream0)
del arg359_1
buf637 = buf625; del buf625 # reuse
# Source Nodes: [out_145, y_72], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf635, buf637, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_145], Original ATen: [torchao.fp16act_fp6weight_linear]
buf638 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf637, arg360_1, arg361_1, 13)
del arg360_1
del arg361_1
buf639 = buf638
del buf638
buf641 = reinterpret_tensor(buf637, (1, 1, 4096), (4096, 4096, 1), 0); del buf637 # reuse
# Source Nodes: [add_148, float_100, h_24, mean_49, mul_371, mul_372, mul_373, out_143, output_49, rsqrt_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf613, buf623, buf639, arg49_1, buf641, 1, 4096, grid=grid(1), stream=stream0)
del arg49_1
# Source Nodes: [out_146], Original ATen: [torchao.fp16act_fp6weight_linear]
buf642 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg362_1, arg363_1, 1)
del arg362_1
del arg363_1
buf643 = buf642
del buf642
# Source Nodes: [out_147], Original ATen: [torchao.fp16act_fp6weight_linear]
buf644 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg364_1, arg365_1, 1)
del arg364_1
del arg365_1
buf645 = buf644
del buf644
buf646 = buf643; del buf643 # reuse
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf646, buf645, 11008, grid=grid(11008), stream=stream0)
del buf645
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear]
buf647 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf646, arg366_1, arg367_1, 13)
del arg366_1
del arg367_1
del buf646
buf648 = buf647
del buf647
buf650 = reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0); del buf641 # reuse
# Source Nodes: [float_101, h_24, mean_50, mul_375, out_143, out_149, out_150], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf613, buf623, buf639, buf648, arg50_1, buf650, 1, 4096, grid=grid(1), stream=stream0)
del arg50_1
# Source Nodes: [out_150], Original ATen: [torchao.fp16act_fp6weight_linear]
buf651 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf650, arg368_1, arg369_1, 1)
del arg368_1
del arg369_1
buf652 = buf651
del buf651
buf654 = buf629; del buf629 # reuse
# Source Nodes: [setitem_50, setitem_51, y_75], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf652, arg66_1, arg370_1, buf654, arg371_1, 4096, grid=grid(4096), stream=stream0)
del buf652
buf655 = buf634; del buf634 # reuse
# Source Nodes: [y_75], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf654, arg370_1, buf655, 6656, 128, grid=grid(6656), stream=stream0)
del arg370_1
buf659 = buf630; del buf630 # reuse
# Source Nodes: [mask, y_75], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf655, arg454_1, arg67_1, buf659, 32, 208, grid=grid(32), stream=stream0)
buf660 = buf635; del buf635 # reuse
# Source Nodes: [y_75], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf659, arg371_1, buf660, 8192, 104, grid=grid(8192), stream=stream0)
del arg371_1
buf662 = buf650; del buf650 # reuse
# Source Nodes: [out_151, y_75], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf660, buf662, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_151], Original ATen: [torchao.fp16act_fp6weight_linear]
buf663 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf662, arg372_1, arg373_1, 13)
del arg372_1
del arg373_1
buf664 = buf663
del buf663
buf665 = buf613; del buf613 # reuse
buf667 = buf662; del buf662 # reuse
buf670 = buf615; del buf615 # reuse
# Source Nodes: [float_104, h_24, h_25, mean_51, mul_386, out_143, out_149, out_152, out_153], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf665, buf623, buf639, buf648, buf664, arg51_1, buf667, buf670, 1, 4096, grid=grid(1), stream=stream0)
del arg51_1
del buf623
del buf639
del buf648
del buf664
# Source Nodes: [out_152], Original ATen: [torchao.fp16act_fp6weight_linear]
buf668 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf667, arg374_1, arg375_1, 1)
del arg374_1
del arg375_1
buf669 = buf668
del buf668
# Source Nodes: [out_153], Original ATen: [torchao.fp16act_fp6weight_linear]
buf671 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf670, arg376_1, arg377_1, 1)
del arg376_1
del arg377_1
buf672 = buf671
del buf671
buf673 = buf669; del buf669 # reuse
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf673, buf672, 11008, grid=grid(11008), stream=stream0)
del buf672
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear]
buf674 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf673, arg378_1, arg379_1, 13)
del arg378_1
del arg379_1
del buf673
buf675 = buf674
del buf674
buf677 = buf670; del buf670 # reuse
# Source Nodes: [float_105, mean_52, mul_390, out_155, out_156], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf665, buf675, arg52_1, buf677, 1, 4096, grid=grid(1), stream=stream0)
del arg52_1
# Source Nodes: [out_156], Original ATen: [torchao.fp16act_fp6weight_linear]
buf678 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf677, arg380_1, arg381_1, 1)
del arg380_1
del arg381_1
buf679 = buf678
del buf678
buf681 = buf654; del buf654 # reuse
# Source Nodes: [setitem_52, setitem_53, y_78], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf679, arg66_1, arg382_1, buf681, arg383_1, 4096, grid=grid(4096), stream=stream0)
del buf679
buf682 = buf659; del buf659 # reuse
# Source Nodes: [y_78], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf681, arg382_1, buf682, 6656, 128, grid=grid(6656), stream=stream0)
del arg382_1
buf686 = buf655; del buf655 # reuse
# Source Nodes: [mask, y_78], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf682, arg454_1, arg67_1, buf686, 32, 208, grid=grid(32), stream=stream0)
buf687 = buf660; del buf660 # reuse
# Source Nodes: [y_78], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf686, arg383_1, buf687, 8192, 104, grid=grid(8192), stream=stream0)
del arg383_1
buf689 = buf677; del buf677 # reuse
# Source Nodes: [out_157, y_78], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf687, buf689, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_157], Original ATen: [torchao.fp16act_fp6weight_linear]
buf690 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf689, arg384_1, arg385_1, 13)
del arg384_1
del arg385_1
buf691 = buf690
del buf690
buf693 = reinterpret_tensor(buf689, (1, 1, 4096), (4096, 4096, 1), 0); del buf689 # reuse
# Source Nodes: [add_160, float_108, h_26, mean_53, mul_401, mul_402, mul_403, out_155, output_53, rsqrt_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf665, buf675, buf691, arg53_1, buf693, 1, 4096, grid=grid(1), stream=stream0)
del arg53_1
# Source Nodes: [out_158], Original ATen: [torchao.fp16act_fp6weight_linear]
buf694 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg386_1, arg387_1, 1)
del arg386_1
del arg387_1
buf695 = buf694
del buf694
# Source Nodes: [out_159], Original ATen: [torchao.fp16act_fp6weight_linear]
buf696 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg388_1, arg389_1, 1)
del arg388_1
del arg389_1
buf697 = buf696
del buf696
buf698 = buf695; del buf695 # reuse
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf698, buf697, 11008, grid=grid(11008), stream=stream0)
del buf697
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear]
buf699 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf698, arg390_1, arg391_1, 13)
del arg390_1
del arg391_1
del buf698
buf700 = buf699
del buf699
buf702 = reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0); del buf693 # reuse
# Source Nodes: [float_109, h_26, mean_54, mul_405, out_155, out_161, out_162], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf665, buf675, buf691, buf700, arg54_1, buf702, 1, 4096, grid=grid(1), stream=stream0)
del arg54_1
# Source Nodes: [out_162], Original ATen: [torchao.fp16act_fp6weight_linear]
buf703 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf702, arg392_1, arg393_1, 1)
del arg392_1
del arg393_1
buf704 = buf703
del buf703
buf706 = buf681; del buf681 # reuse
# Source Nodes: [setitem_54, setitem_55, y_81], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf704, arg66_1, arg394_1, buf706, arg395_1, 4096, grid=grid(4096), stream=stream0)
del buf704
buf707 = buf686; del buf686 # reuse
# Source Nodes: [y_81], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf706, arg394_1, buf707, 6656, 128, grid=grid(6656), stream=stream0)
del arg394_1
buf711 = buf682; del buf682 # reuse
# Source Nodes: [mask, y_81], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf707, arg454_1, arg67_1, buf711, 32, 208, grid=grid(32), stream=stream0)
buf712 = buf687; del buf687 # reuse
# Source Nodes: [y_81], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf711, arg395_1, buf712, 8192, 104, grid=grid(8192), stream=stream0)
del arg395_1
buf714 = buf702; del buf702 # reuse
# Source Nodes: [out_163, y_81], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf712, buf714, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_163], Original ATen: [torchao.fp16act_fp6weight_linear]
buf715 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf714, arg396_1, arg397_1, 13)
del arg396_1
del arg397_1
buf716 = buf715
del buf715
buf717 = buf665; del buf665 # reuse
buf719 = buf714; del buf714 # reuse
buf722 = buf667; del buf667 # reuse
# Source Nodes: [float_112, h_26, h_27, mean_55, mul_416, out_155, out_161, out_164, out_165], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf717, buf675, buf691, buf700, buf716, arg55_1, buf719, buf722, 1, 4096, grid=grid(1), stream=stream0)
del arg55_1
del buf675
del buf691
del buf700
del buf716
# Source Nodes: [out_164], Original ATen: [torchao.fp16act_fp6weight_linear]
buf720 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf719, arg398_1, arg399_1, 1)
del arg398_1
del arg399_1
buf721 = buf720
del buf720
# Source Nodes: [out_165], Original ATen: [torchao.fp16act_fp6weight_linear]
buf723 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf722, arg400_1, arg401_1, 1)
del arg400_1
del arg401_1
buf724 = buf723
del buf723
buf725 = buf721; del buf721 # reuse
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf725, buf724, 11008, grid=grid(11008), stream=stream0)
del buf724
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear]
buf726 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf725, arg402_1, arg403_1, 13)
del arg402_1
del arg403_1
del buf725
buf727 = buf726
del buf726
buf729 = buf722; del buf722 # reuse
# Source Nodes: [float_113, mean_56, mul_420, out_167, out_168], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf717, buf727, arg56_1, buf729, 1, 4096, grid=grid(1), stream=stream0)
del arg56_1
# Source Nodes: [out_168], Original ATen: [torchao.fp16act_fp6weight_linear]
buf730 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf729, arg404_1, arg405_1, 1)
del arg404_1
del arg405_1
buf731 = buf730
del buf730
buf733 = buf706; del buf706 # reuse
# Source Nodes: [setitem_56, setitem_57, y_84], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf731, arg66_1, arg406_1, buf733, arg407_1, 4096, grid=grid(4096), stream=stream0)
del buf731
buf734 = buf711; del buf711 # reuse
# Source Nodes: [y_84], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf733, arg406_1, buf734, 6656, 128, grid=grid(6656), stream=stream0)
del arg406_1
buf738 = buf707; del buf707 # reuse
# Source Nodes: [mask, y_84], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf734, arg454_1, arg67_1, buf738, 32, 208, grid=grid(32), stream=stream0)
buf739 = buf712; del buf712 # reuse
# Source Nodes: [y_84], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf738, arg407_1, buf739, 8192, 104, grid=grid(8192), stream=stream0)
del arg407_1
buf741 = buf729; del buf729 # reuse
# Source Nodes: [out_169, y_84], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf739, buf741, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_169], Original ATen: [torchao.fp16act_fp6weight_linear]
buf742 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf741, arg408_1, arg409_1, 13)
del arg408_1
del arg409_1
buf743 = buf742
del buf742
buf745 = reinterpret_tensor(buf741, (1, 1, 4096), (4096, 4096, 1), 0); del buf741 # reuse
# Source Nodes: [add_172, float_116, h_28, mean_57, mul_431, mul_432, mul_433, out_167, output_57, rsqrt_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf717, buf727, buf743, arg57_1, buf745, 1, 4096, grid=grid(1), stream=stream0)
del arg57_1
# Source Nodes: [out_170], Original ATen: [torchao.fp16act_fp6weight_linear]
buf746 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg410_1, arg411_1, 1)
del arg410_1
del arg411_1
buf747 = buf746
del buf746
# Source Nodes: [out_171], Original ATen: [torchao.fp16act_fp6weight_linear]
buf748 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg412_1, arg413_1, 1)
del arg412_1
del arg413_1
buf749 = buf748
del buf748
buf750 = buf747; del buf747 # reuse
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf750, buf749, 11008, grid=grid(11008), stream=stream0)
del buf749
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear]
buf751 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf750, arg414_1, arg415_1, 13)
del arg414_1
del arg415_1
del buf750
buf752 = buf751
del buf751
buf754 = reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0); del buf745 # reuse
# Source Nodes: [float_117, h_28, mean_58, mul_435, out_167, out_173, out_174], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf717, buf727, buf743, buf752, arg58_1, buf754, 1, 4096, grid=grid(1), stream=stream0)
del arg58_1
# Source Nodes: [out_174], Original ATen: [torchao.fp16act_fp6weight_linear]
buf755 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf754, arg416_1, arg417_1, 1)
del arg416_1
del arg417_1
buf756 = buf755
del buf755
buf758 = buf733; del buf733 # reuse
# Source Nodes: [setitem_58, setitem_59, y_87], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf756, arg66_1, arg418_1, buf758, arg419_1, 4096, grid=grid(4096), stream=stream0)
del buf756
buf759 = buf738; del buf738 # reuse
# Source Nodes: [y_87], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf758, arg418_1, buf759, 6656, 128, grid=grid(6656), stream=stream0)
del arg418_1
buf763 = buf734; del buf734 # reuse
# Source Nodes: [mask, y_87], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf759, arg454_1, arg67_1, buf763, 32, 208, grid=grid(32), stream=stream0)
buf764 = buf739; del buf739 # reuse
# Source Nodes: [y_87], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf763, arg419_1, buf764, 8192, 104, grid=grid(8192), stream=stream0)
del arg419_1
buf766 = buf754; del buf754 # reuse
# Source Nodes: [out_175, y_87], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf764, buf766, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_175], Original ATen: [torchao.fp16act_fp6weight_linear]
buf767 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf766, arg420_1, arg421_1, 13)
del arg420_1
del arg421_1
buf768 = buf767
del buf767
buf769 = buf717; del buf717 # reuse
buf771 = buf766; del buf766 # reuse
buf774 = buf719; del buf719 # reuse
# Source Nodes: [float_120, h_28, h_29, mean_59, mul_446, out_167, out_173, out_176, out_177], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf769, buf727, buf743, buf752, buf768, arg59_1, buf771, buf774, 1, 4096, grid=grid(1), stream=stream0)
del arg59_1
del buf727
del buf743
del buf752
del buf768
# Source Nodes: [out_176], Original ATen: [torchao.fp16act_fp6weight_linear]
buf772 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf771, arg422_1, arg423_1, 1)
del arg422_1
del arg423_1
buf773 = buf772
del buf772
# Source Nodes: [out_177], Original ATen: [torchao.fp16act_fp6weight_linear]
buf775 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf774, arg424_1, arg425_1, 1)
del arg424_1
del arg425_1
buf776 = buf775
del buf775
buf777 = buf773; del buf773 # reuse
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf777, buf776, 11008, grid=grid(11008), stream=stream0)
del buf776
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear]
buf778 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf777, arg426_1, arg427_1, 13)
del arg426_1
del arg427_1
del buf777
buf779 = buf778
del buf778
buf781 = buf774; del buf774 # reuse
# Source Nodes: [float_121, mean_60, mul_450, out_179, out_180], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf769, buf779, arg60_1, buf781, 1, 4096, grid=grid(1), stream=stream0)
del arg60_1
# Source Nodes: [out_180], Original ATen: [torchao.fp16act_fp6weight_linear]
buf782 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf781, arg428_1, arg429_1, 1)
del arg428_1
del arg429_1
buf783 = buf782
del buf782
buf785 = buf758; del buf758 # reuse
# Source Nodes: [setitem_60, setitem_61, y_90], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf783, arg66_1, arg430_1, buf785, arg431_1, 4096, grid=grid(4096), stream=stream0)
del buf783
buf786 = buf763; del buf763 # reuse
# Source Nodes: [y_90], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf785, arg430_1, buf786, 6656, 128, grid=grid(6656), stream=stream0)
del arg430_1
buf790 = buf759; del buf759 # reuse
# Source Nodes: [mask, y_90], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf786, arg454_1, arg67_1, buf790, 32, 208, grid=grid(32), stream=stream0)
buf791 = buf764; del buf764 # reuse
# Source Nodes: [y_90], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf790, arg431_1, buf791, 8192, 104, grid=grid(8192), stream=stream0)
del arg431_1
buf793 = buf781; del buf781 # reuse
# Source Nodes: [out_181, y_90], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf791, buf793, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_181], Original ATen: [torchao.fp16act_fp6weight_linear]
buf794 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf793, arg432_1, arg433_1, 13)
del arg432_1
del arg433_1
buf795 = buf794
del buf794
buf797 = reinterpret_tensor(buf793, (1, 1, 4096), (4096, 4096, 1), 0); del buf793 # reuse
# Source Nodes: [add_184, float_124, h_30, mean_61, mul_461, mul_462, mul_463, out_179, output_61, rsqrt_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf769, buf779, buf795, arg61_1, buf797, 1, 4096, grid=grid(1), stream=stream0)
del arg61_1
# Source Nodes: [out_182], Original ATen: [torchao.fp16act_fp6weight_linear]
buf798 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg434_1, arg435_1, 1)
del arg434_1
del arg435_1
buf799 = buf798
del buf798
# Source Nodes: [out_183], Original ATen: [torchao.fp16act_fp6weight_linear]
buf800 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg436_1, arg437_1, 1)
del arg436_1
del arg437_1
buf801 = buf800
del buf800
buf802 = buf799; del buf799 # reuse
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf802, buf801, 11008, grid=grid(11008), stream=stream0)
del buf801
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear]
buf803 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf802, arg438_1, arg439_1, 13)
del arg438_1
del arg439_1
del buf802
buf804 = buf803
del buf803
buf806 = reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0); del buf797 # reuse
# Source Nodes: [float_125, h_30, mean_62, mul_465, out_179, out_185, out_186], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf769, buf779, buf795, buf804, arg62_1, buf806, 1, 4096, grid=grid(1), stream=stream0)
del arg62_1
# Source Nodes: [out_186], Original ATen: [torchao.fp16act_fp6weight_linear]
buf807 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf806, arg440_1, arg441_1, 1)
del arg440_1
del arg441_1
buf808 = buf807
del buf807
buf810 = buf785; del buf785 # reuse
# Source Nodes: [setitem_62, setitem_63, y_93], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf808, arg66_1, arg442_1, buf810, arg443_1, 4096, grid=grid(4096), stream=stream0)
del arg66_1
del buf808
buf811 = buf790; del buf790 # reuse
# Source Nodes: [y_93], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf810, arg442_1, buf811, 6656, 128, grid=grid(6656), stream=stream0)
del arg442_1
del buf810
buf815 = buf786; del buf786 # reuse
# Source Nodes: [mask, y_93], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf811, arg454_1, arg67_1, buf815, 32, 208, grid=grid(32), stream=stream0)
del arg454_1
del arg67_1
del buf811
buf816 = buf791; del buf791 # reuse
# Source Nodes: [y_93], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf815, arg443_1, buf816, 8192, 104, grid=grid(8192), stream=stream0)
del arg443_1
del buf815
buf818 = buf806; del buf806 # reuse
# Source Nodes: [out_187, y_93], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf816, buf818, 4096, 2, grid=grid(4096), stream=stream0)
del buf816
# Source Nodes: [out_187], Original ATen: [torchao.fp16act_fp6weight_linear]
buf819 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf818, arg444_1, arg445_1, 13)
del arg444_1
del arg445_1
buf820 = buf819
del buf819
buf821 = buf769; del buf769 # reuse
buf823 = buf818; del buf818 # reuse
buf826 = buf771; del buf771 # reuse
# Source Nodes: [float_128, h_30, h_31, mean_63, mul_476, out_179, out_185, out_188, out_189], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf821, buf779, buf795, buf804, buf820, arg63_1, buf823, buf826, 1, 4096, grid=grid(1), stream=stream0)
del arg63_1
del buf779
del buf795
del buf804
del buf820
# Source Nodes: [out_188], Original ATen: [torchao.fp16act_fp6weight_linear]
buf824 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf823, arg446_1, arg447_1, 1)
del arg446_1
del arg447_1
del buf823
buf825 = buf824
del buf824
# Source Nodes: [out_189], Original ATen: [torchao.fp16act_fp6weight_linear]
buf827 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf826, arg448_1, arg449_1, 1)
del arg448_1
del arg449_1
buf828 = buf827
del buf827
buf829 = buf825; del buf825 # reuse
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf829, buf828, 11008, grid=grid(11008), stream=stream0)
del buf828
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear]
buf830 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf829, arg450_1, arg451_1, 13)
del arg450_1
del arg451_1
del buf829
buf831 = buf830
del buf830
buf833 = buf826; del buf826 # reuse
# Source Nodes: [float_129, mean_64, mul_480, out_191, out_192], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf821, buf831, arg64_1, buf833, 1, 4096, grid=grid(1), stream=stream0)
del arg64_1
del buf821
del buf831
# Source Nodes: [out_192], Original ATen: [torchao.fp16act_fp6weight_linear]
buf834 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf833, arg452_1, arg453_1, 1)
del arg452_1
del arg453_1
del buf833
buf835 = buf834
del buf834
buf836 = reinterpret_tensor(buf835, (32000, ), (1, ), 0); del buf835 # reuse
# Source Nodes: [logits_1], Original ATen: [aten.div]
triton_poi_fused_div_16.run(buf836, 32000, grid=grid(32000), stream=stream0)
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk]
buf837 = aten.topk.default(buf836, 200)
buf838 = buf837[0]
del buf837
buf840 = empty_strided_cuda((1, 4), (4, 1), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf836, buf838, buf840, 4, 8000, grid=grid(4), stream=stream0)
buf841 = empty_strided_cuda((1, ), (1, ), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf840, buf841, 1, 4, grid=grid(1), stream=stream0)
buf842 = buf840; del buf840 # reuse
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf836, buf838, buf841, buf842, 4, 8000, grid=grid(4), stream=stream0)
buf843 = empty_strided_cuda((1, ), (1, ), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf842, buf843, 1, 4, grid=grid(1), stream=stream0)
del buf842
buf845 = empty_strided_cuda((1, ), (1, ), torch.int64)
# Source Nodes: [], Original ATen: []
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf845)
buf844 = buf836; del buf836 # reuse
buf848 = empty_strided_cuda((1, ), (1, ), torch.int32)
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf844, buf838, buf841, buf843, buf845, buf848, 0, 1, 32000, grid=grid(1), stream=stream0)
del buf838
del buf841
del buf843
del buf845
return (buf848, buf844, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg2_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg3_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg5_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg6_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg7_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg8_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg9_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg10_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg11_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg12_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg13_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg14_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg15_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg16_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg17_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg18_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg19_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg21_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg22_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg23_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg24_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg25_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg26_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg27_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg28_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg29_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg31_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg32_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg33_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg34_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg35_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg36_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg37_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg38_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg39_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg40_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg41_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg42_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg43_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg44_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg45_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg47_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg48_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg49_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg50_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg51_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg52_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg53_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg54_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg55_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg57_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg58_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg59_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg60_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg61_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg62_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg63_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg64_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg65_1 = rand_strided((32000, 4096), (4096, 1), device='cuda:0', dtype=torch.float16)
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float16)
arg67_1 = rand_strided((208, 208), (208, 1), device='cuda:0', dtype=torch.bool)
arg68_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg69_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg70_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg71_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg72_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg73_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg74_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg75_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg76_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg77_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg78_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg79_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg80_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg81_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg82_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg83_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg84_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg85_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg86_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg87_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg88_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg89_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg90_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg91_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg92_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg93_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg94_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg95_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg96_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg97_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg98_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg99_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg100_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg101_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg102_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg103_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg104_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg105_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg106_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg107_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg108_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg109_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg110_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg111_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg112_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg113_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg114_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg115_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg116_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg117_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg118_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg119_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg120_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg121_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg122_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg123_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg124_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg125_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg126_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg127_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg128_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg129_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg130_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg131_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg132_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg133_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg134_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg135_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg136_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg137_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg138_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg139_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg140_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg141_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg142_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg143_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg144_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg145_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg146_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg147_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg148_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg149_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg150_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg151_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg152_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg153_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg154_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg155_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg156_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg157_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg158_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg159_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg160_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg161_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg162_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg163_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg164_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg165_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg166_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg167_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg168_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg169_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg170_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg171_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg172_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg173_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg174_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg175_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg176_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg177_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg178_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg179_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg180_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg181_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg182_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg183_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg184_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg185_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg186_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg187_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg188_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg189_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg190_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg191_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg192_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg193_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg194_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg195_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg196_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg197_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg198_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg199_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg200_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg201_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg202_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg203_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg204_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg205_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg206_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg207_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg208_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg209_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg210_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg211_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg212_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg213_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg214_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg215_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg216_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg217_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg218_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg219_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg220_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg221_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg222_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg223_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg224_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg225_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg226_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg227_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg228_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg229_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg230_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg231_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg232_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg233_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg234_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg235_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg236_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg237_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg238_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg239_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg240_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg241_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg242_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg243_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg244_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg245_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg246_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg247_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg248_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg249_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg250_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg251_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg252_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg253_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg254_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg255_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg256_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg257_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg258_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg259_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg260_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg261_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg262_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg263_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg264_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg265_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg266_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg267_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg268_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg269_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg270_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg271_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg272_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg273_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg274_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg275_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg276_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg277_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg278_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg279_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg280_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg281_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg282_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg283_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg284_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg285_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg286_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg287_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg288_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg289_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg290_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg291_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg292_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg293_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg294_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg295_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg296_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg297_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg298_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg299_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg300_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg301_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg302_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg303_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg304_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg305_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg306_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg307_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg308_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg309_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg310_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg311_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg312_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg313_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg314_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg315_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg316_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg317_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg318_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg319_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg320_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg321_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg322_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg323_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg324_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg325_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg326_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg327_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg328_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg329_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg330_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg331_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg332_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg333_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg334_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg335_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg336_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg337_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg338_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg339_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg340_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg341_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg342_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg343_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg344_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg345_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg346_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg347_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg348_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg349_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg350_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg351_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg352_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg353_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg354_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg355_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg356_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg357_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg358_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg359_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg360_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg361_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg362_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg363_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg364_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg365_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg366_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg367_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg368_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg369_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg370_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg371_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg372_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg373_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg374_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg375_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg376_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg377_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg378_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg379_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg380_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg381_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg382_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg383_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg384_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg385_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg386_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg387_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg388_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg389_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg390_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg391_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg392_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg393_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg394_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg395_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg396_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg397_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg398_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg399_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg400_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg401_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg402_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg403_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg404_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg405_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg406_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg407_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg408_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg409_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg410_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg411_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg412_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg413_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg414_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg415_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg416_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg417_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg418_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg419_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg420_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg421_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg422_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg423_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg424_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg425_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg426_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg427_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg428_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg429_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg430_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg431_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg432_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg433_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg434_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg435_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg436_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg437_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg438_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg439_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg440_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg441_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg442_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg443_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg444_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg445_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg446_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg447_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg448_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg449_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg450_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg451_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg452_1 = rand_strided((32000, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg453_1 = rand_strided((32000, ), (1, ), device='cuda:0', dtype=torch.float16)
arg454_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int32)
arg455_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, 
arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, 
arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1])
return print_performance(fn, times=times, repeat=repeat)
# Script entry point: when this generated module is executed directly, hand
# benchmark_compiled_module to Inductor's compiled_module_main driver. The
# import sits under the guard, so it only happens in script mode.
# NOTE(review): the first argument is the literal string 'None' — presumably a
# placeholder benchmark name; confirm against compiled_module_main.
# NOTE(review): the two statements after the `if` are at column 0 in this copy
# of the file; they appear to be the guard's body with indentation lost.
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
# this is set: torch._inductor.config.triton.cudagraph_trees = False
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Module-level short names bound once at import time.
# Operator namespaces:
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Private guard / allocation helpers bound to local names
# (presumably used by the generated call() wrapper, which is outside this view):
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Handle through which the Triton kernels that follow are registered
# via async_compile.triton(name, source, device_str=...).
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_ubuntu/35/c35n52yrfyq3r2uvr6syoon44qkntahvhxuhmylcpsxnin76axfd.py
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_1 => convert_element_type
# mean => mean
# mul => mul
# out => fp16act_fp6weight_linear
# x => embedding
#
# Reduction kernel (xnumel=1, rnumel=4096). Pointer types per triton_meta:
#   in_ptr0  *i32   - a single index is read from element 0
#   in_ptr1  *fp16  - table addressed as r + 4096*index (rows of 4096 elements)
#   in_ptr2  *fp16  - per-element scale, addressed by r
#   out_ptr1 *fp16  - 4096 outputs, addressed by r
# A negative index is wrapped by adding 32000, and the result is
# device-asserted to lie in [0, 32000).
# Pass 1 accumulates the sum of squares of the selected row in fp32.
# Pass 2 reloads the row and stores
#   row[r] * rsqrt(sum_sq / 4096.0 + 1e-05) * in_ptr2[r]
# i.e. an RMS-style normalisation of the looked-up row followed by a scale.
# NOTE(review): the kernel source inside the string has no indentation in this
# copy, so the loop extents described above are inferred from statement order;
# confirm against the originally generated file.
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0 = async_compile.triton('triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[1, 4096],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1
rnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
tmp0 = tl.load(in_ptr0 + (0))
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp3 = tmp1 + tmp2
tmp4 = tmp1 < 0
tmp5 = tl.where(tmp4, tmp3, tmp1)
tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp8 * tmp8
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp12 = _tmp11 + tmp10
_tmp11 = tl.where(rmask, tmp12, _tmp11)
tmp11 = tl.sum(_tmp11, 1)[:, None]
tmp13 = tl.load(in_ptr0 + (0))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r0 = rindex
tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp16 = tmp14 + tmp15
tmp17 = tmp14 < 0
tmp18 = tl.where(tmp17, tmp16, tmp14)
tl.device_assert((0 <= tmp18) & (tmp18 < 32000), "index out of bounds: 0 <= tmp18 < 32000")
tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp21 = tmp20.to(tl.float32)
tmp22 = 4096.0
tmp23 = tmp11 / tmp22
tmp24 = 1e-05
tmp25 = tmp23 + tmp24
tmp26 = libdevice.rsqrt(tmp25)
tmp27 = tmp21 * tmp26
tmp28 = tmp27.to(tl.float32)
tmp30 = tmp28 * tmp29
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_ubuntu/ix/cix3cikar5jbote3hymtcnf5jtrl75mvovejwvdsj67ds2nqaynm.py
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put]
# setitem => index_put
# setitem_1 => index_put_1
# y => convert_element_type_6
#
# Pointwise kernel over xnumel=4096 elements, decomposed as
# x0 = xindex % 128, x1 = xindex // 128, x2 = xindex.
# Pointer types per triton_meta: in_ptr0 *i32, in_ptr1 *fp16, in_ptr2 *fp16,
# out_ptr0 *fp16, out_ptr1 *fp32, out_ptr2 *fp16.
# in_ptr0[0] is a single index that is wrapped and device-asserted twice:
#   - against 208  (tmp5),  used in the out_ptr0 / out_ptr2 store offsets;
#   - against 2048 (tmp16), used to address in_ptr2 as rows of 128 elements.
# in_ptr1 is read at three base offsets: 0, 4096 and 8192.
# For the segments at 0 and 4096, each adjacent (even, odd) pair (a, b) is
# combined with the pair (c, s) loaded from in_ptr2:
#   even lanes -> a*c - b*s      odd lanes -> b*c + a*s
# (a complex-multiply style rotation of each pair).
# Stores:
#   out_ptr0[x0 + 128*tmp5 + 26624*x1] <- rotated segment at offset 4096
#   out_ptr1[x2]                       <- rotated segment at offset 0, times
#                                         0.29730177875068026 (about 128 ** -0.25)
#   out_ptr2[x0 + 128*tmp5 + 26624*x1] <- in_ptr1[8192 + x2], copied unchanged
# inductor_meta lists out_ptr0 and out_ptr2 as mutated arguments.
# NOTE(review): presumably the three segments are query/key/value and
# out_ptr0/out_ptr2 are key/value caches — not established here; verify
# against the caller.
triton_poi_fused_bmm_index_put_1 = async_compile.triton('triton_poi_fused_bmm_index_put_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[4096],
filename=__file__,
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_bmm_index_put_1', 'mutated_arg_names': ['out_ptr0', 'out_ptr2'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_bmm_index_put_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
xnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 128
x1 = (xindex // 128)
x2 = xindex
tmp0 = tl.load(in_ptr0 + (0))
tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
tmp71 = tl.load(in_ptr1 + (8192 + x2), None).to(tl.float32)
tmp2 = tl.full([XBLOCK], 208, tl.int32)
tmp3 = tmp1 + tmp2
tmp4 = tmp1 < 0
tmp5 = tl.where(tmp4, tmp3, tmp1)
tl.device_assert((0 <= tmp5) & (tmp5 < 208), "index out of bounds: 0 <= tmp5 < 208")
tmp7 = x0 % 2
tmp8 = tl.full([1], 0, tl.int64)
tmp9 = tmp7 >= tmp8
tmp10 = tl.full([1], 1, tl.int64)
tmp11 = tmp7 < tmp10
tmp12 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp13 = tmp12.to(tl.float32)
tmp14 = tl.full([XBLOCK], 2048, tl.int32)
tmp15 = tmp1 + tmp14
tmp16 = tl.where(tmp4, tmp15, tmp1)
tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
tmp18 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp13 * tmp19
tmp21 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp22 = tmp21.to(tl.float32)
tmp23 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp24 = tmp23.to(tl.float32)
tmp25 = tmp22 * tmp24
tmp26 = tmp20 - tmp25
tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
tmp28 = tl.where(tmp11, tmp26, tmp27)
tmp29 = tmp7 >= tmp10
tmp30 = tl.full([1], 2, tl.int64)
tmp31 = tmp7 < tmp30
tmp32 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp33 = tmp32.to(tl.float32)
tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp29), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
tmp35 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp36 = tmp35.to(tl.float32)
tmp37 = tmp33 * tmp36
tmp38 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp39 = tmp38.to(tl.float32)
tmp40 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp41 = tmp40.to(tl.float32)
tmp42 = tmp39 * tmp41
tmp43 = tmp37 + tmp42
tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
tmp45 = tl.where(tmp29, tmp43, tmp44)
tmp46 = tl.where(tmp11, tmp28, tmp45)
tmp47 = tmp46.to(tl.float32)
tmp48 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp49 = tmp48.to(tl.float32)
tmp50 = tmp49 * tmp19
tmp51 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp52 = tmp51.to(tl.float32)
tmp53 = tmp52 * tmp24
tmp54 = tmp50 - tmp53
tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype)
tmp56 = tl.where(tmp11, tmp54, tmp55)
tmp57 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp58 = tmp57.to(tl.float32)
tmp59 = tmp58 * tmp36
tmp60 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp61 = tmp60.to(tl.float32)
tmp62 = tmp61 * tmp41
tmp63 = tmp59 + tmp62
tmp64 = tl.full(tmp63.shape, 0.0, tmp63.dtype)
tmp65 = tl.where(tmp29, tmp63, tmp64)
tmp66 = tl.where(tmp11, tmp56, tmp65)
tmp67 = tmp66.to(tl.float32)
tmp68 = 0.29730177875068026
tmp69 = tmp67 * tmp68
tmp70 = tmp69.to(tl.float32)
tl.store(out_ptr0 + (x0 + (128*tmp5) + (26624*x1)), tmp47, None)
tl.store(out_ptr1 + (x2), tmp70, None)
tl.store(out_ptr2 + (x0 + (128*tmp5) + (26624*x1)), tmp71, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/tu/ctuxbryvllsmitwtw4gk35f6e4uug6renb6km4llzq2czdpete6i.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_13, sum_1
#
# Reduction kernel: xnumel=6656 outputs, each reducing rnumel=128 terms.
# Pointer types per triton_meta: in_ptr0 *fp32, in_ptr1 *fp16, out_ptr0 *fp32.
# With x1 = xindex // 208 and x3 = xindex:
#   out_ptr0[x3] = sum_r in_ptr0[r + 128*x1]
#                        * (in_ptr1[r + 128*x3] * 0.29730177875068026)
# So each run of 208 consecutive outputs shares one 128-element in_ptr0 vector,
# while every output has its own 128-element in_ptr1 row. The accumulator is
# fp32 and both loads and the store are masked with xmask.
# NOTE(review): the kernel source inside the string has no indentation in this
# copy; the loop body extent is inferred from statement order.
triton_red_fused_bmm_2 = async_compile.triton('triton_red_fused_bmm_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[8192, 128],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_2(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 6656
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x1 = (xindex // 208)
x3 = xindex
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (r2 + (128*x1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (r2 + (128*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = 0.29730177875068026
tmp3 = tmp1 * tmp2
tmp4 = tmp3.to(tl.float32)
tmp5 = tmp0 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(rmask & xmask, tmp8, _tmp7)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp7, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/f6/cf6hoyusglegqlcllwth4hbv7jzk7xu6ol3qg56cytjysskalpwn.py
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
# mask => index
# y => add_3, amax, convert_element_type_11, convert_element_type_9, exp, full_default, full_default_1, logical_not, sub_2, sum_2, where
#
# Persistent reduction: xnumel=32 rows by rnumel=208 columns, with RBLOCK=256
# (lanes with rindex >= 208 are masked off by rmask).
# Pointer types per triton_meta: in_ptr0 *fp32, in_ptr1 *i32, in_ptr2 *i1,
# out_ptr2 *fp32.
# in_ptr1[0] is a single index; a negative value is wrapped by adding 208 and
# the result is device-asserted to lie in [0, 208). It selects the mask row
# in_ptr2[r + 208*index].
# Where the mask element equals 0, -inf is added to the in_ptr0 value;
# otherwise 0.0 is added.
# The masked row then goes through a max-subtracted softmax:
#   exp(v - rowmax) / sum(exp(v - rowmax))
# and the result is stored to out_ptr2[r + 208*x0].
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3 = async_compile.triton('triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[32, 256],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*i32', 2: '*i1', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 32
rnumel = 208
RBLOCK: tl.constexpr = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (208*x0)), rmask & xmask, other=0.0)
tmp2 = tl.load(in_ptr1 + (0))
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
tmp1 = tmp0.to(tl.float32)
tmp4 = tl.full([XBLOCK, RBLOCK], 208, tl.int32)
tmp5 = tmp3 + tmp4
tmp6 = tmp3 < 0
tmp7 = tl.where(tmp6, tmp5, tmp3)
tl.device_assert((0 <= tmp7) & (tmp7 < 208), "index out of bounds: 0 <= tmp7 < 208")
tmp9 = tl.load(in_ptr2 + (r1 + (208*tmp7)), rmask, other=0.0).to(tl.int1)
tmp10 = tmp9 == 0
tmp11 = float("-inf")
tmp12 = 0.0
tmp13 = tl.where(tmp10, tmp11, tmp12)
tmp14 = tmp1 + tmp13
tmp15 = tmp14.to(tl.float32)
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
tmp18 = tl.where(rmask & xmask, tmp16, float("-inf"))
tmp19 = triton_helpers.max2(tmp18, 1)[:, None]
tmp20 = tmp15 - tmp19
tmp21 = tl_math.exp(tmp20)
tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK])
tmp24 = tl.where(rmask & xmask, tmp22, 0)
tmp25 = tl.sum(tmp24, 1)[:, None]
tmp26 = tmp21 / tmp25
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp27.to(tl.float32)
tl.store(out_ptr2 + (r1 + (208*x0)), tmp28, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/3q/c3q5eet7mn3gx23dlvno24ooilcrlirl64ahmhluvxu6ynwq452m.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_14, sum_3
#
# Reduction kernel: xnumel=8192 outputs, each reducing rnumel=104 terms.
# Pointer types per triton_meta: in_ptr0 *fp32, in_ptr1 *fp16, out_ptr0 *fp32.
# With x1 = xindex // 128 and x0 = xindex % 128:
#   out_ptr0[xindex] = sum_{r < 104} in_ptr0[r + 104*x1]
#                                    * in_ptr1[x0 + 128*r + 13312*x1]
# Loads are masked with rmask only and the store is unmasked (no xmask).
# NOTE(review): 13312 == 104 * 128, and 104 is half of the 208-wide rows used
# by the neighbouring kernels, so this looks like a split reduction producing
# partial sums that a later kernel combines — confirm against the caller.
# NOTE(review): the kernel source inside the string has no indentation in this
# copy; the loop body extent is inferred from statement order.
triton_red_fused_bmm_4 = async_compile.triton('triton_red_fused_bmm_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[8192, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_4(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 8192
rnumel = 104
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x1 = (xindex // 128)
x0 = xindex % 128
_tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (r2 + (104*x1)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (128*r2) + (13312*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 * tmp2
tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
tmp6 = _tmp5 + tmp4
_tmp5 = tl.where(rmask, tmp6, _tmp5)
tmp5 = tl.sum(_tmp5, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/s7/cs7bh5txamreduogo3dbpgu62hdnknommfet7uxuau6q2gi24u6j.py
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
# out_1 => fp16act_fp6weight_linear_1
# y => mul_14, sum_3
# Persistent-reduction Triton kernel (whole reduction axis handled in one
# block, RBLOCK == rnumel == 2).
#
# For each of the 4096 output positions x, with x0 = x % 128 and x1 = x // 128,
# it sums the two values in_ptr0[x0 + 128*r + 256*x1] for r in {0, 1} and
# writes the result to out_ptr1[x].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0  (*fp32): input values to be summed.
#   out_ptr1 (*fp16): one sum per x.
#   xnumel, rnumel  : overwritten with the constants 4096 and 2 in the body.
#
# The store is unmasked (mask=None); `xmask` is computed but not used.
# NOTE(review): presumably this combines the two partial sums of a split
# reduction produced by an earlier kernel -- confirm against the call site.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function body is significant.
triton_per_fused_bmm_fp16act_fp6weight_linear_5 = async_compile.triton('triton_per_fused_bmm_fp16act_fp6weight_linear_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[4096, 2],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_bmm_fp16act_fp6weight_linear_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_bmm_fp16act_fp6weight_linear_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r2 = rindex
    x0 = xindex % 128
    x1 = (xindex // 128)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r2) + (256*x1)), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x3), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/e6/ce6earzjiin77ifcgxqhqghmkxait7rqhahkb2zasp4hadk7bw4d.py
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
# add_4 => add_5
# float_4 => convert_element_type_14
# h => add_4
# mean_1 => mean_1
# mul_11 => mul_15
# mul_12 => mul_16
# mul_13 => mul_17
# output_1 => convert_element_type_15
# rsqrt_1 => rsqrt_1
# x => embedding
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096), processed RBLOCK elements at a time.
#
# Pass 1: reads one integer index from in_ptr0[0], wraps a negative index by
#         adding 32000, device-asserts 0 <= index < 32000, then for every r
#         loads in_ptr1[r + 4096*index] + in_ptr2[r] and accumulates the sum
#         of squares in a float32 accumulator.
# Pass 2: recomputes the same per-element sum, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr3[r] and
#         stores the result to out_ptr1[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0  (*i32) : single index selecting a row of in_ptr1.
#   in_ptr1  (*fp16): table with row stride 4096 (32000 rows per the assert).
#   in_ptr2  (*fp16): per-element addend.
#   in_ptr3  (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1 (*fp16): output, 4096 elements.
#   xnumel, rnumel  : overwritten with the constants 1 and 4096 in the body.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tmp15 = tl.load(in_ptr0 + (0))
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp23 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp18 = tmp16 + tmp17
        tmp19 = tmp16 < 0
        tmp20 = tl.where(tmp19, tmp18, tmp16)
        tl.device_assert((0 <= tmp20) & (tmp20 < 32000), "index out of bounds: 0 <= tmp20 < 32000")
        tmp22 = tl.load(in_ptr1 + (r0 + (4096*tmp20)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp24 = tmp22 + tmp23
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 4096.0
        tmp27 = tmp13 / tmp26
        tmp28 = 1e-05
        tmp29 = tmp27 + tmp28
        tmp30 = libdevice.rsqrt(tmp29)
        tmp31 = tmp25 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp34 = tmp32 * tmp33
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/7x/c7xzb5ucrv7a232bufgvbjemlk4jgdh7ujoroxm3v7jxuduwdfif.py
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
# out_4 => fp16act_fp6weight_linear_4
# Pointwise Triton kernel over 11008 elements, updating in_out_ptr0 in place.
#
# For each x < 11008 it loads v = in_out_ptr0[x] and g = in_ptr0[x] (both
# converted to float32), computes (v * sigmoid(v)) * g -- i.e. SiLU(v) * g --
# and stores the result back to in_out_ptr0[x].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_out_ptr0 (*fp16): input v, overwritten with the result
#                        (listed in `mutated_arg_names`).
#   in_ptr0     (*fp16): elementwise multiplier g.
#   xnumel             : overwritten with the constant 11008 in the body.
#
# Loads and the store are masked with `xindex < xnumel` because XBLOCK need
# not divide 11008 evenly.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function body is significant.
triton_poi_fused_fp16act_fp6weight_linear_7 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16384],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_7', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_fp16act_fp6weight_linear_7(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/xb/cxbu5uv7truaeft3qae6exsdva5w6j3jruqooy6lvbsp53jek26v.py
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_5 => convert_element_type_18
# h => add_4
# mean_2 => mean_2
# mul_15 => mul_20
# out_5 => add_6
# out_6 => fp16act_fp6weight_linear_5
# x => embedding
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096), processed RBLOCK elements at a time.
#
# Pass 1: reads one integer index from in_ptr0[0], wraps a negative index by
#         adding 32000, device-asserts 0 <= index < 32000, then for every r
#         forms in_ptr1[r + 4096*index] + in_ptr2[r] + in_ptr3[r] and
#         accumulates the sum of squares in a float32 accumulator.
# Pass 2: recomputes the same per-element sum, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr4[r] and
#         stores the result to out_ptr1[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0  (*i32) : single index selecting a row of in_ptr1.
#   in_ptr1  (*fp16): table with row stride 4096 (32000 rows per the assert).
#   in_ptr2  (*fp16): first per-element addend.
#   in_ptr3  (*fp16): second per-element addend.
#   in_ptr4  (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1 (*fp16): output, 4096 elements.
#   xnumel, rnumel  : overwritten with the constants 1 and 4096 in the body.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
        tmp16 = _tmp15 + tmp14
        _tmp15 = tl.where(rmask, tmp16, _tmp15)
    tmp15 = tl.sum(_tmp15, 1)[:, None]
    tmp17 = tl.load(in_ptr0 + (0))
    tmp18 = tl.broadcast_to(tmp17, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp25 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp20 = tmp18 + tmp19
        tmp21 = tmp18 < 0
        tmp22 = tl.where(tmp21, tmp20, tmp18)
        tl.device_assert((0 <= tmp22) & (tmp22 < 32000), "index out of bounds: 0 <= tmp22 < 32000")
        tmp24 = tl.load(in_ptr1 + (r0 + (4096*tmp22)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tmp24 + tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tmp30 = 4096.0
        tmp31 = tmp15 / tmp30
        tmp32 = 1e-05
        tmp33 = tmp31 + tmp32
        tmp34 = libdevice.rsqrt(tmp33)
        tmp35 = tmp29 * tmp34
        tmp36 = tmp35.to(tl.float32)
        tmp38 = tmp36 * tmp37
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/ko/ckonz56ntbc3wxvboqbs6j6xmjdrqma6fl5rtmem5grbfiymtcs4.py
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_8 => convert_element_type_32
# h => add_4
# h_1 => add_11
# mean_3 => mean_3
# mul_26 => mul_35
# out_5 => add_6
# out_8 => fp16act_fp6weight_linear_7
# out_9 => fp16act_fp6weight_linear_8
# x => embedding
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096) that also updates in_out_ptr0 in place.
#
# Pass 1: reads one integer index from in_ptr0[0], wraps a negative index by
#         adding 32000, device-asserts 0 <= index < 32000, then for every r
#         forms s = in_ptr1[r + 4096*index] + in_out_ptr0[r] + in_ptr2[r]
#         + in_ptr3[r], accumulates s*s in a float32 accumulator and writes s
#         back to in_out_ptr0[r].
# Pass 2: re-reads s from in_out_ptr0, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr4[r] and
#         stores the identical result to both out_ptr1[r] and out_ptr2[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_out_ptr0 (*fp16): addend on entry, overwritten with the summed value s
#                        (listed in `mutated_arg_names`).
#   in_ptr0     (*i32) : single index selecting a row of in_ptr1.
#   in_ptr1     (*fp16): table with row stride 4096 (32000 rows per the assert).
#   in_ptr2     (*fp16): per-element addend.
#   in_ptr3     (*fp16): per-element addend.
#   in_ptr4     (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1, out_ptr2 (*fp16): two copies of the output.
#   xnumel, rnumel     : overwritten with the constants 1 and 4096 in the body.
#
# The in-place store must stay inside the first loop: it uses that iteration's
# r0/rmask, and the second loop reads the stored values back.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i32', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp13 = tmp11 + tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp14 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
        tmp18 = _tmp17 + tmp16
        _tmp17 = tl.where(rmask, tmp18, _tmp17)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask)
    tmp17 = tl.sum(_tmp17, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = 4096.0
        tmp22 = tmp17 / tmp21
        tmp23 = 1e-05
        tmp24 = tmp22 + tmp23
        tmp25 = libdevice.rsqrt(tmp24)
        tmp26 = tmp20 * tmp25
        tmp27 = tmp26.to(tl.float32)
        tmp29 = tmp27 * tmp28
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/23/c23scxttsdz72vesaap27tazzxervcixdrt7hwfybhp2nrfu4kw4.py
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_9 => convert_element_type_36
# mean_4 => mean_4
# mul_30 => mul_40
# out_11 => add_13
# out_12 => fp16act_fp6weight_linear_10
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096), processed RBLOCK elements at a time.
#
# Pass 1: for every r forms in_ptr0[r] + in_ptr1[r] and accumulates the sum
#         of squares in a float32 accumulator.
# Pass 2: recomputes the same per-element sum, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr2[r] and
#         stores the result to out_ptr1[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0  (*fp16): first per-element addend.
#   in_ptr1  (*fp16): second per-element addend.
#   in_ptr2  (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1 (*fp16): output, 4096 elements.
#   xnumel, rnumel  : overwritten with the constants 1 and 4096 in the body.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3 * tmp3
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
        tmp7 = _tmp6 + tmp5
        _tmp6 = tl.where(rmask, tmp7, _tmp6)
    tmp6 = tl.sum(_tmp6, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = 4096.0
        tmp13 = tmp6 / tmp12
        tmp14 = 1e-05
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp18 = tmp17.to(tl.float32)
        tmp20 = tmp18 * tmp19
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/sg/csgnrv5nhw4eic7f275ngxcfz73e62dtja6zilbn6f5a2qhvqn54.py
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
# add_16 => add_19
# float_12 => convert_element_type_50
# h_2 => add_18
# mean_5 => mean_5
# mul_41 => mul_55
# mul_42 => mul_56
# mul_43 => mul_57
# out_11 => add_13
# output_5 => convert_element_type_51
# rsqrt_5 => rsqrt_5
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096), processed RBLOCK elements at a time.
#
# Pass 1: for every r forms in_ptr0[r] + in_ptr1[r] + in_ptr2[r] and
#         accumulates the sum of squares in a float32 accumulator.
# Pass 2: recomputes the same per-element sum, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr3[r] and
#         stores the result to out_ptr1[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0, in_ptr1, in_ptr2 (*fp16): the three per-element addends.
#   in_ptr3  (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1 (*fp16): output, 4096 elements.
#   xnumel, rnumel  : overwritten with the constants 1 and 4096 in the body.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_mean_mul_rsqrt_11 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_rsqrt_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tmp10 + tmp11
        tmp14 = tmp12 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp8 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/nw/cnwia6b5wdcoewnme63eazcbz6dvzogn3stxnaty267p4vw2mmwy.py
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_13 => convert_element_type_54
# h_2 => add_18
# mean_6 => mean_6
# mul_45 => mul_60
# out_11 => add_13
# out_17 => add_20
# out_18 => fp16act_fp6weight_linear_15
# Looped-reduction Triton kernel over a single row of 4096 elements
# (xnumel = 1, rnumel = 4096), processed RBLOCK elements at a time.
#
# Pass 1: for every r forms in_ptr0[r] + in_ptr1[r] + in_ptr2[r] + in_ptr3[r]
#         and accumulates the sum of squares in a float32 accumulator.
# Pass 2: recomputes the same per-element sum, multiplies it by
#         rsqrt(sum_of_squares / 4096 + 1e-05), multiplies by in_ptr4[r] and
#         stores the result to out_ptr1[r].
#
# Arguments (dtypes taken from the `signature` entry of triton_meta):
#   in_ptr0 .. in_ptr3 (*fp16): the four per-element addends.
#   in_ptr4  (*fp16): per-element multiplier applied after normalisation.
#   out_ptr1 (*fp16): output, 4096 elements.
#   xnumel, rnumel  : overwritten with the constants 1 and 4096 in the body.
#
# The string below is compiled as Python source by Triton, so the indentation
# of the function and loop bodies is significant.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp14 = tmp12 + tmp13
        tmp16 = tmp14 + tmp15
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = 4096.0
        tmp21 = tmp10 / tmp20
        tmp22 = 1e-05
        tmp23 = tmp21 + tmp22
        tmp24 = libdevice.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp26 = tmp25.to(tl.float32)
        tmp28 = tmp26 * tmp27
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/yi/cyiywyx2b2ap35djo57g2tkxu3hdej3o4ootd3nr6bigecahyoeu.py
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_16 => convert_element_type_68
# h_2 => add_18
# h_3 => add_25
# mean_7 => mean_7
# mul_56 => mul_75
# out_11 => add_13
# out_17 => add_20
# out_20 => fp16act_fp6weight_linear_17
# out_21 => fp16act_fp6weight_linear_18
# Two-pass reduction kernel over one row of 4096 fp16 elements (xnumel == 1).
#   Pass 1: adds in_out_ptr0 + in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3
#           elementwise, writes that sum back into in_out_ptr0 and accumulates
#           the sum of its squares in fp32.
#   Pass 2: re-reads the stored sum, scales it by
#           rsqrt(sum_of_squares / 4096 + 1e-05) and by in_ptr4, then stores
#           the same value to both out_ptr1 and out_ptr2.
# NOTE(review): this looks like fused residual adds followed by an RMSNorm
# whose output feeds two consumers - confirm against the model source.
# The kernel source is a runtime string that is parsed as Python, so the
# nesting of the function and loop bodies below is significant.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/74/c74rqt5v7cgztcvqz3g36kdattrzmlrq33illnplnagfadtt7nqv.py
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
# out_22 => fp16act_fp6weight_linear_19
# Pointwise kernel over 11008 fp16 elements. For each element it computes
# in_ptr0 * sigmoid(in_ptr0) in fp32, multiplies that by the value already in
# in_out_ptr0 and stores the product back into in_out_ptr0 (in place).
# NOTE(review): this matches a SiLU-gated product (SwiGLU-style feed-forward)
# - confirm against the model source.
# The kernel source is a runtime string that is parsed as Python, so the
# indentation of the function body below is significant.
triton_poi_fused_fp16act_fp6weight_linear_14 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16384],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_fp16act_fp6weight_linear_14(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/od/codwx422sav4ibjcz7e3u2ii64iq4gtjlj6xmjifbwmqjhhlbprf.py
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_24 => convert_element_type_104
# h_4 => add_32
# h_5 => add_39
# mean_11 => mean_11
# mul_86 => mul_115
# out_23 => add_27
# out_29 => add_34
# out_32 => fp16act_fp6weight_linear_27
# out_33 => fp16act_fp6weight_linear_28
# Two-pass reduction kernel over one row of 4096 fp16 elements (xnumel == 1).
#   Pass 1: adds in_ptr0 + in_out_ptr0 + in_ptr1 + in_ptr2 + in_ptr3
#           elementwise, writes that sum back into in_out_ptr0 and accumulates
#           the sum of its squares in fp32.
#   Pass 2: re-reads the stored sum, scales it by
#           rsqrt(sum_of_squares / 4096 + 1e-05) and by in_ptr4, then stores
#           the same value to both out_ptr1 and out_ptr2.
# The only visible difference from kernel ..._mean_mul_13 is the order in
# which in_ptr0 and in_out_ptr0 enter the first addition.
# NOTE(review): this looks like fused residual adds followed by an RMSNorm
# whose output feeds two consumers - confirm against the model source.
# The kernel source is a runtime string that is parsed as Python, so the
# nesting of the function and loop bodies below is significant.
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = 4096.0
        tmp17 = tmp12 / tmp16
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/z6/cz6oxi5i6lto4h7ib75rlkpzqsocefhcbwhm7nhdpxravdtgiuro.py
# Source Nodes: [logits_1], Original ATen: [aten.div]
# logits_1 => div_32
# Pointwise kernel over 32000 fp16 elements: multiplies every element of
# in_out_ptr0 by 1.25 and stores it back in place.
# The generated header attributes this to aten.div; multiplying by 1.25 is
# the same as dividing by 0.8.
# NOTE(review): 0.8 is presumably the sampling temperature applied to the
# logits - confirm against the caller.
# The kernel source is a runtime string that is parsed as Python, so the
# indentation of the function body below is significant.
triton_poi_fused_div_16 = async_compile.triton('triton_poi_fused_div_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[32768],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_div_16(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 32000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = 1.25
    tmp2 = tmp0 * tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/xv/cxvjbig2sbormp2kxcubuvv2h5v3vxuc4jta7llq4rnh6kwuxhuo.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax_32, convert_element_type_578
# Reduction kernel that reads in_ptr0 as 4 rows of 8000 contiguous fp16
# elements (offset r1 + 8000 * x0). Every element strictly below the scalar
# stored at in_ptr1[199] is replaced by -inf, and the maximum of each row is
# written as fp32 to out_ptr0[x0].
# NOTE(review): presumably stage one of a split max for a top-k-filtered
# softmax, with in_ptr1[199] being the k-th largest logit - confirm against
# the caller.
# The kernel source is a runtime string that is parsed as Python, so the
# nesting of the function and loop bodies below is significant.
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[4, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax_lt_scalar_tensor_where_17(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4
    rnumel = 8000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp8 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp0 < tmp2
        tmp4 = float("-inf")
        tmp5 = tl.where(tmp3, tmp4, tmp0)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp9 = triton_helpers.maximum(_tmp8, tmp7)
        _tmp8 = tl.where(rmask & xmask, tmp9, _tmp8)
    tmp8 = triton_helpers.max2(_tmp8, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp8, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/hz/chzqxkcllimh3vsvmsxoviuxtefrsfgsde3el6pgijogxpaxrs7r.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => amax_32, convert_element_type_578
# Persistent reduction kernel: loads the 4 fp32 values at in_ptr0[0..3],
# takes their maximum and stores it to out_ptr0[0].
# NOTE(review): presumably stage two of the split max, combining the four
# per-row maxima produced by kernel ..._where_17 - confirm against call().
# The kernel source is a runtime string that is parsed as Python, so the
# indentation of the function body below is significant.
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[1, 4],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_lt_scalar_tensor_where_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4
    RBLOCK: tl.constexpr = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, float("-inf"))
    tmp4 = triton_helpers.max2(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/za/czakkdksd2yxt6hanh7vxqjmbj7ara4d4rza6bwjwwt4stxakxtm.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, exp_32, sub_96, sum_97
# Reduction kernel that reads in_ptr0 as 4 rows of 8000 contiguous fp16
# elements (offset r1 + 8000 * x0). Every element strictly below the scalar
# at in_ptr1[199] is replaced by -inf; the fp32 scalar at in_ptr2[0] is then
# subtracted, the result is exponentiated, and the per-row sum is written as
# fp32 to out_ptr0[x0].
# NOTE(review): presumably stage one of a split softmax denominator, with
# in_ptr2[0] being the global max from kernel ..._where_18 - confirm against
# call().
# The kernel source is a runtime string that is parsed as Python, so the
# nesting of the function and loop bodies below is significant.
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[4, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax_lt_scalar_tensor_where_19(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4
    rnumel = 8000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp7 = tl.load(in_ptr2 + (0))
    tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp0 < tmp2
        tmp4 = float("-inf")
        tmp5 = tl.where(tmp3, tmp4, tmp0)
        tmp6 = tmp5.to(tl.float32)
        tmp9 = tmp6 - tmp8
        tmp10 = tl_math.exp(tmp9)
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask & xmask, tmp13, _tmp12)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp12, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/aa/caad4sxzjec6myyhypd4e6odaasgzubfdw6lhsfkhrv5jqg4uc7g.py
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, exp_32, sub_96, sum_97
# Persistent reduction kernel: loads the 4 fp32 values at in_ptr0[0..3],
# adds them up and stores the total to out_ptr0[0].
# NOTE(review): presumably stage two of the split softmax denominator,
# combining the four per-row sums produced by kernel ..._where_19 - confirm
# against call().
# The kernel source is a runtime string that is parsed as Python, so the
# indentation of the function body below is significant.
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[1, 4],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_lt_scalar_tensor_where_20(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4
    RBLOCK: tl.constexpr = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/ch/cch6v3lnecv5tfe5ipqpowh4xge3iiqm4nrz67gnjbwlbump6mow.py
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
# argmax => argmax
# idx_next => convert_element_type_582
# logits_2 => full_default_64, where_32
# lt => lt
# probs => convert_element_type_578, convert_element_type_579, div_33, exp_32, sub_96
# q_128 => convert_element_type_581, log1p, mul_643, neg
# truediv_1 => div_34
# Reduction kernel over one row of 32000 fp16 elements (xnumel == 1).
# For each element of in_out_ptr0 it:
#   1. replaces the value by -inf when it is strictly below in_ptr0[199];
#   2. computes exp(value - in_ptr1[0]) / in_ptr2[0] and writes that back
#      into in_out_ptr0;
#   3. divides that quotient by q = -log1p(-u), where u comes from
#      tl.rand(seed, element index) and the seed is read from
#      in_ptr3[load_seed_offset];
#   4. tracks the running maximum of the ratio together with its index.
# The index of the largest ratio is stored as int32 to out_ptr2[0].
# NOTE(review): presumably in_ptr1[0] / in_ptr2[0] are the softmax max and
# denominator from kernels ..._where_18 / ..._where_20, and the argmax of
# probs / q is how the next token is sampled without a host sync - confirm
# against the model's sampling code.
# The kernel source is a runtime string that is parsed as Python, so the
# nesting of the function and loop bodies below is significant.
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 32768],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*i64', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 32000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp1 = tl.load(in_ptr0 + (199)).to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp7 = tl.load(in_ptr1 + (0))
    tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
    tmp11 = tl.load(in_ptr2 + (0))
    tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
    _tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    _tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tmp0 < tmp2
        tmp4 = float("-inf")
        tmp5 = tl.where(tmp3, tmp4, tmp0)
        tmp6 = tmp5.to(tl.float32)
        tmp9 = tmp6 - tmp8
        tmp10 = tl_math.exp(tmp9)
        tmp13 = tmp10 / tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tl.load(in_ptr3 + load_seed_offset)
        tmp16 = r0
        tmp17 = tl.rand(tmp15, (tmp16).to(tl.uint32))
        tmp18 = -tmp17
        tmp19 = libdevice.log1p(tmp18)
        tmp20 = -1.0
        tmp21 = tmp19 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp14 / tmp22
        tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK])
        _tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index(
            _tmp25, _tmp25_index, tmp24, rindex
        )
        _tmp25 = tl.where(rmask, _tmp25_next, _tmp25)
        _tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask)
    _, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1)
    tmp25 = tmp25_tmp[:, None]
    tmp26 = tmp25.to(tl.int32)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None)
''', device_str='cuda')
# Hand the module namespace to the compile helper's wait() call.
# NOTE(review): presumably this blocks until the kernels queued through
# async_compile.triton(...) are built and rebinds the module-level kernel
# names to the finished results - confirm against
# torch._inductor.async_compile.
async_compile.wait(globals())
# Remove the helper's module-level name once the wait has returned.
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, 
arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, 
arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1 = args
args.clear()
assert_size_stride(arg0_1, (4096, ), (1, ))
assert_size_stride(arg1_1, (4096, ), (1, ))
assert_size_stride(arg2_1, (4096, ), (1, ))
assert_size_stride(arg3_1, (4096, ), (1, ))
assert_size_stride(arg4_1, (4096, ), (1, ))
assert_size_stride(arg5_1, (4096, ), (1, ))
assert_size_stride(arg6_1, (4096, ), (1, ))
assert_size_stride(arg7_1, (4096, ), (1, ))
assert_size_stride(arg8_1, (4096, ), (1, ))
assert_size_stride(arg9_1, (4096, ), (1, ))
assert_size_stride(arg10_1, (4096, ), (1, ))
assert_size_stride(arg11_1, (4096, ), (1, ))
assert_size_stride(arg12_1, (4096, ), (1, ))
assert_size_stride(arg13_1, (4096, ), (1, ))
assert_size_stride(arg14_1, (4096, ), (1, ))
assert_size_stride(arg15_1, (4096, ), (1, ))
assert_size_stride(arg16_1, (4096, ), (1, ))
assert_size_stride(arg17_1, (4096, ), (1, ))
assert_size_stride(arg18_1, (4096, ), (1, ))
assert_size_stride(arg19_1, (4096, ), (1, ))
assert_size_stride(arg20_1, (4096, ), (1, ))
assert_size_stride(arg21_1, (4096, ), (1, ))
assert_size_stride(arg22_1, (4096, ), (1, ))
assert_size_stride(arg23_1, (4096, ), (1, ))
assert_size_stride(arg24_1, (4096, ), (1, ))
assert_size_stride(arg25_1, (4096, ), (1, ))
assert_size_stride(arg26_1, (4096, ), (1, ))
assert_size_stride(arg27_1, (4096, ), (1, ))
assert_size_stride(arg28_1, (4096, ), (1, ))
assert_size_stride(arg29_1, (4096, ), (1, ))
assert_size_stride(arg30_1, (4096, ), (1, ))
assert_size_stride(arg31_1, (4096, ), (1, ))
assert_size_stride(arg32_1, (4096, ), (1, ))
assert_size_stride(arg33_1, (4096, ), (1, ))
assert_size_stride(arg34_1, (4096, ), (1, ))
assert_size_stride(arg35_1, (4096, ), (1, ))
assert_size_stride(arg36_1, (4096, ), (1, ))
assert_size_stride(arg37_1, (4096, ), (1, ))
assert_size_stride(arg38_1, (4096, ), (1, ))
assert_size_stride(arg39_1, (4096, ), (1, ))
assert_size_stride(arg40_1, (4096, ), (1, ))
assert_size_stride(arg41_1, (4096, ), (1, ))
assert_size_stride(arg42_1, (4096, ), (1, ))
assert_size_stride(arg43_1, (4096, ), (1, ))
assert_size_stride(arg44_1, (4096, ), (1, ))
assert_size_stride(arg45_1, (4096, ), (1, ))
assert_size_stride(arg46_1, (4096, ), (1, ))
assert_size_stride(arg47_1, (4096, ), (1, ))
assert_size_stride(arg48_1, (4096, ), (1, ))
assert_size_stride(arg49_1, (4096, ), (1, ))
assert_size_stride(arg50_1, (4096, ), (1, ))
assert_size_stride(arg51_1, (4096, ), (1, ))
assert_size_stride(arg52_1, (4096, ), (1, ))
assert_size_stride(arg53_1, (4096, ), (1, ))
assert_size_stride(arg54_1, (4096, ), (1, ))
assert_size_stride(arg55_1, (4096, ), (1, ))
assert_size_stride(arg56_1, (4096, ), (1, ))
assert_size_stride(arg57_1, (4096, ), (1, ))
assert_size_stride(arg58_1, (4096, ), (1, ))
assert_size_stride(arg59_1, (4096, ), (1, ))
assert_size_stride(arg60_1, (4096, ), (1, ))
assert_size_stride(arg61_1, (4096, ), (1, ))
assert_size_stride(arg62_1, (4096, ), (1, ))
assert_size_stride(arg63_1, (4096, ), (1, ))
assert_size_stride(arg64_1, (4096, ), (1, ))
assert_size_stride(arg65_1, (32000, 4096), (4096, 1))
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1))
assert_size_stride(arg67_1, (208, 208), (208, 1))
assert_size_stride(arg68_1, (12288, 768), (768, 1))
assert_size_stride(arg69_1, (12288, ), (1, ))
assert_size_stride(arg70_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg71_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg72_1, (4096, 768), (768, 1))
assert_size_stride(arg73_1, (4096, ), (1, ))
assert_size_stride(arg74_1, (11008, 768), (768, 1))
assert_size_stride(arg75_1, (11008, ), (1, ))
assert_size_stride(arg76_1, (11008, 768), (768, 1))
assert_size_stride(arg77_1, (11008, ), (1, ))
assert_size_stride(arg78_1, (4096, 2064), (2064, 1))
assert_size_stride(arg79_1, (4096, ), (1, ))
assert_size_stride(arg80_1, (12288, 768), (768, 1))
assert_size_stride(arg81_1, (12288, ), (1, ))
assert_size_stride(arg82_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg83_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg84_1, (4096, 768), (768, 1))
assert_size_stride(arg85_1, (4096, ), (1, ))
assert_size_stride(arg86_1, (11008, 768), (768, 1))
assert_size_stride(arg87_1, (11008, ), (1, ))
assert_size_stride(arg88_1, (11008, 768), (768, 1))
assert_size_stride(arg89_1, (11008, ), (1, ))
assert_size_stride(arg90_1, (4096, 2064), (2064, 1))
assert_size_stride(arg91_1, (4096, ), (1, ))
assert_size_stride(arg92_1, (12288, 768), (768, 1))
assert_size_stride(arg93_1, (12288, ), (1, ))
assert_size_stride(arg94_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg95_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg96_1, (4096, 768), (768, 1))
assert_size_stride(arg97_1, (4096, ), (1, ))
assert_size_stride(arg98_1, (11008, 768), (768, 1))
assert_size_stride(arg99_1, (11008, ), (1, ))
assert_size_stride(arg100_1, (11008, 768), (768, 1))
assert_size_stride(arg101_1, (11008, ), (1, ))
assert_size_stride(arg102_1, (4096, 2064), (2064, 1))
assert_size_stride(arg103_1, (4096, ), (1, ))
assert_size_stride(arg104_1, (12288, 768), (768, 1))
assert_size_stride(arg105_1, (12288, ), (1, ))
assert_size_stride(arg106_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg107_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg108_1, (4096, 768), (768, 1))
assert_size_stride(arg109_1, (4096, ), (1, ))
assert_size_stride(arg110_1, (11008, 768), (768, 1))
assert_size_stride(arg111_1, (11008, ), (1, ))
assert_size_stride(arg112_1, (11008, 768), (768, 1))
assert_size_stride(arg113_1, (11008, ), (1, ))
assert_size_stride(arg114_1, (4096, 2064), (2064, 1))
assert_size_stride(arg115_1, (4096, ), (1, ))
assert_size_stride(arg116_1, (12288, 768), (768, 1))
assert_size_stride(arg117_1, (12288, ), (1, ))
assert_size_stride(arg118_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg119_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg120_1, (4096, 768), (768, 1))
assert_size_stride(arg121_1, (4096, ), (1, ))
assert_size_stride(arg122_1, (11008, 768), (768, 1))
assert_size_stride(arg123_1, (11008, ), (1, ))
assert_size_stride(arg124_1, (11008, 768), (768, 1))
assert_size_stride(arg125_1, (11008, ), (1, ))
assert_size_stride(arg126_1, (4096, 2064), (2064, 1))
assert_size_stride(arg127_1, (4096, ), (1, ))
assert_size_stride(arg128_1, (12288, 768), (768, 1))
assert_size_stride(arg129_1, (12288, ), (1, ))
assert_size_stride(arg130_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg131_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg132_1, (4096, 768), (768, 1))
assert_size_stride(arg133_1, (4096, ), (1, ))
assert_size_stride(arg134_1, (11008, 768), (768, 1))
assert_size_stride(arg135_1, (11008, ), (1, ))
assert_size_stride(arg136_1, (11008, 768), (768, 1))
assert_size_stride(arg137_1, (11008, ), (1, ))
assert_size_stride(arg138_1, (4096, 2064), (2064, 1))
assert_size_stride(arg139_1, (4096, ), (1, ))
assert_size_stride(arg140_1, (12288, 768), (768, 1))
assert_size_stride(arg141_1, (12288, ), (1, ))
assert_size_stride(arg142_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg143_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg144_1, (4096, 768), (768, 1))
assert_size_stride(arg145_1, (4096, ), (1, ))
assert_size_stride(arg146_1, (11008, 768), (768, 1))
assert_size_stride(arg147_1, (11008, ), (1, ))
assert_size_stride(arg148_1, (11008, 768), (768, 1))
assert_size_stride(arg149_1, (11008, ), (1, ))
assert_size_stride(arg150_1, (4096, 2064), (2064, 1))
assert_size_stride(arg151_1, (4096, ), (1, ))
assert_size_stride(arg152_1, (12288, 768), (768, 1))
assert_size_stride(arg153_1, (12288, ), (1, ))
assert_size_stride(arg154_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg155_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg156_1, (4096, 768), (768, 1))
assert_size_stride(arg157_1, (4096, ), (1, ))
assert_size_stride(arg158_1, (11008, 768), (768, 1))
assert_size_stride(arg159_1, (11008, ), (1, ))
assert_size_stride(arg160_1, (11008, 768), (768, 1))
assert_size_stride(arg161_1, (11008, ), (1, ))
assert_size_stride(arg162_1, (4096, 2064), (2064, 1))
assert_size_stride(arg163_1, (4096, ), (1, ))
assert_size_stride(arg164_1, (12288, 768), (768, 1))
assert_size_stride(arg165_1, (12288, ), (1, ))
assert_size_stride(arg166_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg167_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg168_1, (4096, 768), (768, 1))
assert_size_stride(arg169_1, (4096, ), (1, ))
assert_size_stride(arg170_1, (11008, 768), (768, 1))
assert_size_stride(arg171_1, (11008, ), (1, ))
assert_size_stride(arg172_1, (11008, 768), (768, 1))
assert_size_stride(arg173_1, (11008, ), (1, ))
assert_size_stride(arg174_1, (4096, 2064), (2064, 1))
assert_size_stride(arg175_1, (4096, ), (1, ))
assert_size_stride(arg176_1, (12288, 768), (768, 1))
assert_size_stride(arg177_1, (12288, ), (1, ))
assert_size_stride(arg178_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg179_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg180_1, (4096, 768), (768, 1))
assert_size_stride(arg181_1, (4096, ), (1, ))
assert_size_stride(arg182_1, (11008, 768), (768, 1))
assert_size_stride(arg183_1, (11008, ), (1, ))
assert_size_stride(arg184_1, (11008, 768), (768, 1))
assert_size_stride(arg185_1, (11008, ), (1, ))
assert_size_stride(arg186_1, (4096, 2064), (2064, 1))
assert_size_stride(arg187_1, (4096, ), (1, ))
assert_size_stride(arg188_1, (12288, 768), (768, 1))
assert_size_stride(arg189_1, (12288, ), (1, ))
assert_size_stride(arg190_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg191_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg192_1, (4096, 768), (768, 1))
assert_size_stride(arg193_1, (4096, ), (1, ))
assert_size_stride(arg194_1, (11008, 768), (768, 1))
assert_size_stride(arg195_1, (11008, ), (1, ))
assert_size_stride(arg196_1, (11008, 768), (768, 1))
assert_size_stride(arg197_1, (11008, ), (1, ))
assert_size_stride(arg198_1, (4096, 2064), (2064, 1))
assert_size_stride(arg199_1, (4096, ), (1, ))
assert_size_stride(arg200_1, (12288, 768), (768, 1))
assert_size_stride(arg201_1, (12288, ), (1, ))
assert_size_stride(arg202_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg203_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg204_1, (4096, 768), (768, 1))
assert_size_stride(arg205_1, (4096, ), (1, ))
assert_size_stride(arg206_1, (11008, 768), (768, 1))
assert_size_stride(arg207_1, (11008, ), (1, ))
assert_size_stride(arg208_1, (11008, 768), (768, 1))
assert_size_stride(arg209_1, (11008, ), (1, ))
assert_size_stride(arg210_1, (4096, 2064), (2064, 1))
assert_size_stride(arg211_1, (4096, ), (1, ))
assert_size_stride(arg212_1, (12288, 768), (768, 1))
assert_size_stride(arg213_1, (12288, ), (1, ))
assert_size_stride(arg214_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg215_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg216_1, (4096, 768), (768, 1))
assert_size_stride(arg217_1, (4096, ), (1, ))
assert_size_stride(arg218_1, (11008, 768), (768, 1))
assert_size_stride(arg219_1, (11008, ), (1, ))
assert_size_stride(arg220_1, (11008, 768), (768, 1))
assert_size_stride(arg221_1, (11008, ), (1, ))
assert_size_stride(arg222_1, (4096, 2064), (2064, 1))
assert_size_stride(arg223_1, (4096, ), (1, ))
assert_size_stride(arg224_1, (12288, 768), (768, 1))
assert_size_stride(arg225_1, (12288, ), (1, ))
assert_size_stride(arg226_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg227_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg228_1, (4096, 768), (768, 1))
assert_size_stride(arg229_1, (4096, ), (1, ))
assert_size_stride(arg230_1, (11008, 768), (768, 1))
assert_size_stride(arg231_1, (11008, ), (1, ))
assert_size_stride(arg232_1, (11008, 768), (768, 1))
assert_size_stride(arg233_1, (11008, ), (1, ))
assert_size_stride(arg234_1, (4096, 2064), (2064, 1))
assert_size_stride(arg235_1, (4096, ), (1, ))
assert_size_stride(arg236_1, (12288, 768), (768, 1))
assert_size_stride(arg237_1, (12288, ), (1, ))
assert_size_stride(arg238_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg239_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg240_1, (4096, 768), (768, 1))
assert_size_stride(arg241_1, (4096, ), (1, ))
assert_size_stride(arg242_1, (11008, 768), (768, 1))
assert_size_stride(arg243_1, (11008, ), (1, ))
assert_size_stride(arg244_1, (11008, 768), (768, 1))
assert_size_stride(arg245_1, (11008, ), (1, ))
assert_size_stride(arg246_1, (4096, 2064), (2064, 1))
assert_size_stride(arg247_1, (4096, ), (1, ))
assert_size_stride(arg248_1, (12288, 768), (768, 1))
assert_size_stride(arg249_1, (12288, ), (1, ))
assert_size_stride(arg250_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg251_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg252_1, (4096, 768), (768, 1))
assert_size_stride(arg253_1, (4096, ), (1, ))
assert_size_stride(arg254_1, (11008, 768), (768, 1))
assert_size_stride(arg255_1, (11008, ), (1, ))
assert_size_stride(arg256_1, (11008, 768), (768, 1))
assert_size_stride(arg257_1, (11008, ), (1, ))
assert_size_stride(arg258_1, (4096, 2064), (2064, 1))
assert_size_stride(arg259_1, (4096, ), (1, ))
assert_size_stride(arg260_1, (12288, 768), (768, 1))
assert_size_stride(arg261_1, (12288, ), (1, ))
assert_size_stride(arg262_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg263_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg264_1, (4096, 768), (768, 1))
assert_size_stride(arg265_1, (4096, ), (1, ))
assert_size_stride(arg266_1, (11008, 768), (768, 1))
assert_size_stride(arg267_1, (11008, ), (1, ))
assert_size_stride(arg268_1, (11008, 768), (768, 1))
assert_size_stride(arg269_1, (11008, ), (1, ))
assert_size_stride(arg270_1, (4096, 2064), (2064, 1))
assert_size_stride(arg271_1, (4096, ), (1, ))
assert_size_stride(arg272_1, (12288, 768), (768, 1))
assert_size_stride(arg273_1, (12288, ), (1, ))
assert_size_stride(arg274_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg275_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg276_1, (4096, 768), (768, 1))
assert_size_stride(arg277_1, (4096, ), (1, ))
assert_size_stride(arg278_1, (11008, 768), (768, 1))
assert_size_stride(arg279_1, (11008, ), (1, ))
assert_size_stride(arg280_1, (11008, 768), (768, 1))
assert_size_stride(arg281_1, (11008, ), (1, ))
assert_size_stride(arg282_1, (4096, 2064), (2064, 1))
assert_size_stride(arg283_1, (4096, ), (1, ))
assert_size_stride(arg284_1, (12288, 768), (768, 1))
assert_size_stride(arg285_1, (12288, ), (1, ))
assert_size_stride(arg286_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg287_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg288_1, (4096, 768), (768, 1))
assert_size_stride(arg289_1, (4096, ), (1, ))
assert_size_stride(arg290_1, (11008, 768), (768, 1))
assert_size_stride(arg291_1, (11008, ), (1, ))
assert_size_stride(arg292_1, (11008, 768), (768, 1))
assert_size_stride(arg293_1, (11008, ), (1, ))
assert_size_stride(arg294_1, (4096, 2064), (2064, 1))
assert_size_stride(arg295_1, (4096, ), (1, ))
assert_size_stride(arg296_1, (12288, 768), (768, 1))
assert_size_stride(arg297_1, (12288, ), (1, ))
assert_size_stride(arg298_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg299_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg300_1, (4096, 768), (768, 1))
assert_size_stride(arg301_1, (4096, ), (1, ))
assert_size_stride(arg302_1, (11008, 768), (768, 1))
assert_size_stride(arg303_1, (11008, ), (1, ))
assert_size_stride(arg304_1, (11008, 768), (768, 1))
assert_size_stride(arg305_1, (11008, ), (1, ))
assert_size_stride(arg306_1, (4096, 2064), (2064, 1))
assert_size_stride(arg307_1, (4096, ), (1, ))
assert_size_stride(arg308_1, (12288, 768), (768, 1))
assert_size_stride(arg309_1, (12288, ), (1, ))
assert_size_stride(arg310_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg311_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg312_1, (4096, 768), (768, 1))
assert_size_stride(arg313_1, (4096, ), (1, ))
assert_size_stride(arg314_1, (11008, 768), (768, 1))
assert_size_stride(arg315_1, (11008, ), (1, ))
assert_size_stride(arg316_1, (11008, 768), (768, 1))
assert_size_stride(arg317_1, (11008, ), (1, ))
assert_size_stride(arg318_1, (4096, 2064), (2064, 1))
assert_size_stride(arg319_1, (4096, ), (1, ))
assert_size_stride(arg320_1, (12288, 768), (768, 1))
assert_size_stride(arg321_1, (12288, ), (1, ))
assert_size_stride(arg322_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg323_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg324_1, (4096, 768), (768, 1))
assert_size_stride(arg325_1, (4096, ), (1, ))
assert_size_stride(arg326_1, (11008, 768), (768, 1))
assert_size_stride(arg327_1, (11008, ), (1, ))
assert_size_stride(arg328_1, (11008, 768), (768, 1))
assert_size_stride(arg329_1, (11008, ), (1, ))
assert_size_stride(arg330_1, (4096, 2064), (2064, 1))
assert_size_stride(arg331_1, (4096, ), (1, ))
assert_size_stride(arg332_1, (12288, 768), (768, 1))
assert_size_stride(arg333_1, (12288, ), (1, ))
assert_size_stride(arg334_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg335_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg336_1, (4096, 768), (768, 1))
assert_size_stride(arg337_1, (4096, ), (1, ))
assert_size_stride(arg338_1, (11008, 768), (768, 1))
assert_size_stride(arg339_1, (11008, ), (1, ))
assert_size_stride(arg340_1, (11008, 768), (768, 1))
assert_size_stride(arg341_1, (11008, ), (1, ))
assert_size_stride(arg342_1, (4096, 2064), (2064, 1))
assert_size_stride(arg343_1, (4096, ), (1, ))
assert_size_stride(arg344_1, (12288, 768), (768, 1))
assert_size_stride(arg345_1, (12288, ), (1, ))
assert_size_stride(arg346_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg347_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg348_1, (4096, 768), (768, 1))
assert_size_stride(arg349_1, (4096, ), (1, ))
assert_size_stride(arg350_1, (11008, 768), (768, 1))
assert_size_stride(arg351_1, (11008, ), (1, ))
assert_size_stride(arg352_1, (11008, 768), (768, 1))
assert_size_stride(arg353_1, (11008, ), (1, ))
assert_size_stride(arg354_1, (4096, 2064), (2064, 1))
assert_size_stride(arg355_1, (4096, ), (1, ))
assert_size_stride(arg356_1, (12288, 768), (768, 1))
assert_size_stride(arg357_1, (12288, ), (1, ))
assert_size_stride(arg358_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg359_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg360_1, (4096, 768), (768, 1))
assert_size_stride(arg361_1, (4096, ), (1, ))
assert_size_stride(arg362_1, (11008, 768), (768, 1))
assert_size_stride(arg363_1, (11008, ), (1, ))
assert_size_stride(arg364_1, (11008, 768), (768, 1))
assert_size_stride(arg365_1, (11008, ), (1, ))
assert_size_stride(arg366_1, (4096, 2064), (2064, 1))
assert_size_stride(arg367_1, (4096, ), (1, ))
assert_size_stride(arg368_1, (12288, 768), (768, 1))
assert_size_stride(arg369_1, (12288, ), (1, ))
assert_size_stride(arg370_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg371_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg372_1, (4096, 768), (768, 1))
assert_size_stride(arg373_1, (4096, ), (1, ))
assert_size_stride(arg374_1, (11008, 768), (768, 1))
assert_size_stride(arg375_1, (11008, ), (1, ))
assert_size_stride(arg376_1, (11008, 768), (768, 1))
assert_size_stride(arg377_1, (11008, ), (1, ))
assert_size_stride(arg378_1, (4096, 2064), (2064, 1))
assert_size_stride(arg379_1, (4096, ), (1, ))
assert_size_stride(arg380_1, (12288, 768), (768, 1))
assert_size_stride(arg381_1, (12288, ), (1, ))
assert_size_stride(arg382_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg383_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg384_1, (4096, 768), (768, 1))
assert_size_stride(arg385_1, (4096, ), (1, ))
assert_size_stride(arg386_1, (11008, 768), (768, 1))
assert_size_stride(arg387_1, (11008, ), (1, ))
assert_size_stride(arg388_1, (11008, 768), (768, 1))
assert_size_stride(arg389_1, (11008, ), (1, ))
assert_size_stride(arg390_1, (4096, 2064), (2064, 1))
assert_size_stride(arg391_1, (4096, ), (1, ))
assert_size_stride(arg392_1, (12288, 768), (768, 1))
assert_size_stride(arg393_1, (12288, ), (1, ))
assert_size_stride(arg394_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg395_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg396_1, (4096, 768), (768, 1))
assert_size_stride(arg397_1, (4096, ), (1, ))
assert_size_stride(arg398_1, (11008, 768), (768, 1))
assert_size_stride(arg399_1, (11008, ), (1, ))
assert_size_stride(arg400_1, (11008, 768), (768, 1))
assert_size_stride(arg401_1, (11008, ), (1, ))
assert_size_stride(arg402_1, (4096, 2064), (2064, 1))
assert_size_stride(arg403_1, (4096, ), (1, ))
assert_size_stride(arg404_1, (12288, 768), (768, 1))
assert_size_stride(arg405_1, (12288, ), (1, ))
assert_size_stride(arg406_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg407_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg408_1, (4096, 768), (768, 1))
assert_size_stride(arg409_1, (4096, ), (1, ))
assert_size_stride(arg410_1, (11008, 768), (768, 1))
assert_size_stride(arg411_1, (11008, ), (1, ))
assert_size_stride(arg412_1, (11008, 768), (768, 1))
assert_size_stride(arg413_1, (11008, ), (1, ))
assert_size_stride(arg414_1, (4096, 2064), (2064, 1))
assert_size_stride(arg415_1, (4096, ), (1, ))
assert_size_stride(arg416_1, (12288, 768), (768, 1))
assert_size_stride(arg417_1, (12288, ), (1, ))
assert_size_stride(arg418_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg419_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg420_1, (4096, 768), (768, 1))
assert_size_stride(arg421_1, (4096, ), (1, ))
assert_size_stride(arg422_1, (11008, 768), (768, 1))
assert_size_stride(arg423_1, (11008, ), (1, ))
assert_size_stride(arg424_1, (11008, 768), (768, 1))
assert_size_stride(arg425_1, (11008, ), (1, ))
assert_size_stride(arg426_1, (4096, 2064), (2064, 1))
assert_size_stride(arg427_1, (4096, ), (1, ))
assert_size_stride(arg428_1, (12288, 768), (768, 1))
assert_size_stride(arg429_1, (12288, ), (1, ))
assert_size_stride(arg430_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg431_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg432_1, (4096, 768), (768, 1))
assert_size_stride(arg433_1, (4096, ), (1, ))
assert_size_stride(arg434_1, (11008, 768), (768, 1))
assert_size_stride(arg435_1, (11008, ), (1, ))
assert_size_stride(arg436_1, (11008, 768), (768, 1))
assert_size_stride(arg437_1, (11008, ), (1, ))
assert_size_stride(arg438_1, (4096, 2064), (2064, 1))
assert_size_stride(arg439_1, (4096, ), (1, ))
assert_size_stride(arg440_1, (12288, 768), (768, 1))
assert_size_stride(arg441_1, (12288, ), (1, ))
assert_size_stride(arg442_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg443_1, (1, 32, 208, 128), (851968, 26624, 128, 1))
assert_size_stride(arg444_1, (4096, 768), (768, 1))
assert_size_stride(arg445_1, (4096, ), (1, ))
assert_size_stride(arg446_1, (11008, 768), (768, 1))
assert_size_stride(arg447_1, (11008, ), (1, ))
assert_size_stride(arg448_1, (11008, 768), (768, 1))
assert_size_stride(arg449_1, (11008, ), (1, ))
assert_size_stride(arg450_1, (4096, 2064), (2064, 1))
assert_size_stride(arg451_1, (4096, ), (1, ))
assert_size_stride(arg452_1, (32000, 768), (768, 1))
assert_size_stride(arg453_1, (32000, ), (1, ))
assert_size_stride(arg454_1, (1, ), (1, ))
assert_size_stride(arg455_1, (1, 1), (1, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16)
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0.run(arg455_1, arg65_1, arg0_1, buf1, 1, 4096, grid=grid(1), stream=stream0)
del arg0_1
# Source Nodes: [out], Original ATen: [torchao.fp16act_fp6weight_linear]
buf2 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf1, arg68_1, arg69_1, 1)
del arg68_1
del arg69_1
buf3 = buf2
del buf2
buf5 = empty_strided_cuda((32, 1, 128), (128, 4096, 1), torch.float32)
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf3, arg66_1, arg70_1, buf5, arg71_1, 4096, grid=grid(4096), stream=stream0)
del buf3
buf6 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32)
# Source Nodes: [y], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf5, arg70_1, buf6, 6656, 128, grid=grid(6656), stream=stream0)
del arg70_1
buf10 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32)
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf6, arg454_1, arg67_1, buf10, 32, 208, grid=grid(32), stream=stream0)
buf11 = empty_strided_cuda((32, 1, 128, 2), (256, 8192, 1, 128), torch.float32)
# Source Nodes: [y], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf10, arg71_1, buf11, 8192, 104, grid=grid(8192), stream=stream0)
del arg71_1
buf13 = buf1; del buf1 # reuse
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf11, buf13, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_1], Original ATen: [torchao.fp16act_fp6weight_linear]
buf14 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf13, arg72_1, arg73_1, 13)
del arg72_1
del arg73_1
buf15 = buf14
del buf14
buf17 = reinterpret_tensor(buf13, (1, 1, 4096), (4096, 4096, 1), 0); del buf13 # reuse
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6.run(arg455_1, arg65_1, buf15, arg1_1, buf17, 1, 4096, grid=grid(1), stream=stream0)
del arg1_1
# Source Nodes: [out_2], Original ATen: [torchao.fp16act_fp6weight_linear]
buf18 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg74_1, arg75_1, 1)
del arg74_1
del arg75_1
buf19 = buf18
del buf18
# Source Nodes: [out_3], Original ATen: [torchao.fp16act_fp6weight_linear]
buf20 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg76_1, arg77_1, 1)
del arg76_1
del arg77_1
buf21 = buf20
del buf20
buf22 = buf19; del buf19 # reuse
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf22, buf21, 11008, grid=grid(11008), stream=stream0)
del buf21
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
buf23 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf22, arg78_1, arg79_1, 13)
del arg78_1
del arg79_1
del buf22
buf24 = buf23
del buf23
buf26 = reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0); del buf17 # reuse
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8.run(arg455_1, arg65_1, buf15, buf24, arg2_1, buf26, 1, 4096, grid=grid(1), stream=stream0)
del arg2_1
# Source Nodes: [out_6], Original ATen: [torchao.fp16act_fp6weight_linear]
buf27 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf26, arg80_1, arg81_1, 1)
del arg80_1
del arg81_1
buf28 = buf27
del buf27
buf30 = buf5; del buf5 # reuse
# Source Nodes: [setitem_2, setitem_3, y_3], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf28, arg66_1, arg82_1, buf30, arg83_1, 4096, grid=grid(4096), stream=stream0)
del buf28
buf31 = buf10; del buf10 # reuse
# Source Nodes: [y_3], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf30, arg82_1, buf31, 6656, 128, grid=grid(6656), stream=stream0)
del arg82_1
buf35 = buf6; del buf6 # reuse
# Source Nodes: [mask, y_3], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf31, arg454_1, arg67_1, buf35, 32, 208, grid=grid(32), stream=stream0)
buf36 = buf11; del buf11 # reuse
# Source Nodes: [y_3], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf35, arg83_1, buf36, 8192, 104, grid=grid(8192), stream=stream0)
del arg83_1
buf38 = buf26; del buf26 # reuse
# Source Nodes: [out_7, y_3], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf36, buf38, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_7], Original ATen: [torchao.fp16act_fp6weight_linear]
buf39 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf38, arg84_1, arg85_1, 13)
del arg84_1
del arg85_1
buf40 = buf39
del buf39
buf41 = reinterpret_tensor(buf15, (1, 1, 4096), (4096, 4096, 1), 0); del buf15 # reuse
buf43 = buf38; del buf38 # reuse
buf46 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16)
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9.run(buf41, arg455_1, arg65_1, buf24, buf40, arg3_1, buf43, buf46, 1, 4096, grid=grid(1), stream=stream0)
del arg3_1
del arg455_1
del arg65_1
del buf24
del buf40
# Source Nodes: [out_8], Original ATen: [torchao.fp16act_fp6weight_linear]
buf44 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf43, arg86_1, arg87_1, 1)
del arg86_1
del arg87_1
buf45 = buf44
del buf44
# Source Nodes: [out_9], Original ATen: [torchao.fp16act_fp6weight_linear]
buf47 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf46, arg88_1, arg89_1, 1)
del arg88_1
del arg89_1
buf48 = buf47
del buf47
buf49 = buf45; del buf45 # reuse
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf49, buf48, 11008, grid=grid(11008), stream=stream0)
del buf48
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear]
buf50 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf49, arg90_1, arg91_1, 13)
del arg90_1
del arg91_1
del buf49
buf51 = buf50
del buf50
buf53 = buf46; del buf46 # reuse
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf41, buf51, arg4_1, buf53, 1, 4096, grid=grid(1), stream=stream0)
del arg4_1
# Source Nodes: [out_12], Original ATen: [torchao.fp16act_fp6weight_linear]
buf54 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf53, arg92_1, arg93_1, 1)
del arg92_1
del arg93_1
buf55 = buf54
del buf54
buf57 = buf30; del buf30 # reuse
# Source Nodes: [setitem_4, setitem_5, y_6], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf55, arg66_1, arg94_1, buf57, arg95_1, 4096, grid=grid(4096), stream=stream0)
del buf55
buf58 = buf35; del buf35 # reuse
# Source Nodes: [y_6], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf57, arg94_1, buf58, 6656, 128, grid=grid(6656), stream=stream0)
del arg94_1
buf62 = buf31; del buf31 # reuse
# Source Nodes: [mask, y_6], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf58, arg454_1, arg67_1, buf62, 32, 208, grid=grid(32), stream=stream0)
buf63 = buf36; del buf36 # reuse
# Source Nodes: [y_6], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf62, arg95_1, buf63, 8192, 104, grid=grid(8192), stream=stream0)
del arg95_1
buf65 = buf53; del buf53 # reuse
# Source Nodes: [out_13, y_6], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf63, buf65, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_13], Original ATen: [torchao.fp16act_fp6weight_linear]
buf66 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf65, arg96_1, arg97_1, 13)
del arg96_1
del arg97_1
buf67 = buf66
del buf66
buf69 = reinterpret_tensor(buf65, (1, 1, 4096), (4096, 4096, 1), 0); del buf65 # reuse
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf41, buf51, buf67, arg5_1, buf69, 1, 4096, grid=grid(1), stream=stream0)
del arg5_1
# Source Nodes: [out_14], Original ATen: [torchao.fp16act_fp6weight_linear]
buf70 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg98_1, arg99_1, 1)
del arg98_1
del arg99_1
buf71 = buf70
del buf70
# Source Nodes: [out_15], Original ATen: [torchao.fp16act_fp6weight_linear]
buf72 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg100_1, arg101_1, 1)
del arg100_1
del arg101_1
buf73 = buf72
del buf72
buf74 = buf71; del buf71 # reuse
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf74, buf73, 11008, grid=grid(11008), stream=stream0)
del buf73
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear]
buf75 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf74, arg102_1, arg103_1, 13)
del arg102_1
del arg103_1
del buf74
buf76 = buf75
del buf75
buf78 = reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0); del buf69 # reuse
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf41, buf51, buf67, buf76, arg6_1, buf78, 1, 4096, grid=grid(1), stream=stream0)
del arg6_1
# Source Nodes: [out_18], Original ATen: [torchao.fp16act_fp6weight_linear]
buf79 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf78, arg104_1, arg105_1, 1)
del arg104_1
del arg105_1
buf80 = buf79
del buf79
buf82 = buf57; del buf57 # reuse
# Source Nodes: [setitem_6, setitem_7, y_9], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf80, arg66_1, arg106_1, buf82, arg107_1, 4096, grid=grid(4096), stream=stream0)
del buf80
buf83 = buf62; del buf62 # reuse
# Source Nodes: [y_9], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf82, arg106_1, buf83, 6656, 128, grid=grid(6656), stream=stream0)
del arg106_1
buf87 = buf58; del buf58 # reuse
# Source Nodes: [mask, y_9], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf83, arg454_1, arg67_1, buf87, 32, 208, grid=grid(32), stream=stream0)
buf88 = buf63; del buf63 # reuse
# Source Nodes: [y_9], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf87, arg107_1, buf88, 8192, 104, grid=grid(8192), stream=stream0)
del arg107_1
buf90 = buf78; del buf78 # reuse
# Source Nodes: [out_19, y_9], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf88, buf90, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_19], Original ATen: [torchao.fp16act_fp6weight_linear]
buf91 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf90, arg108_1, arg109_1, 13)
del arg108_1
del arg109_1
buf92 = buf91
del buf91
buf93 = buf41; del buf41 # reuse
buf95 = buf90; del buf90 # reuse
buf98 = buf43; del buf43 # reuse
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf93, buf51, buf67, buf76, buf92, arg7_1, buf95, buf98, 1, 4096, grid=grid(1), stream=stream0)
del arg7_1
del buf51
del buf67
del buf76
del buf92
# Source Nodes: [out_20], Original ATen: [torchao.fp16act_fp6weight_linear]
buf96 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf95, arg110_1, arg111_1, 1)
del arg110_1
del arg111_1
buf97 = buf96
del buf96
# Source Nodes: [out_21], Original ATen: [torchao.fp16act_fp6weight_linear]
buf99 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf98, arg112_1, arg113_1, 1)
del arg112_1
del arg113_1
buf100 = buf99
del buf99
buf101 = buf100; del buf100 # reuse
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_14.run(buf101, buf97, 11008, grid=grid(11008), stream=stream0)
del buf97
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear]
buf102 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf101, arg114_1, arg115_1, 13)
del arg114_1
del arg115_1
del buf101
buf103 = buf102
del buf102
buf105 = buf98; del buf98 # reuse
# Source Nodes: [float_17, mean_8, mul_60, out_23, out_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf93, buf103, arg8_1, buf105, 1, 4096, grid=grid(1), stream=stream0)
del arg8_1
# Source Nodes: [out_24], Original ATen: [torchao.fp16act_fp6weight_linear]
buf106 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf105, arg116_1, arg117_1, 1)
del arg116_1
del arg117_1
buf107 = buf106
del buf106
buf109 = buf82; del buf82 # reuse
# Source Nodes: [setitem_8, setitem_9, y_12], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf107, arg66_1, arg118_1, buf109, arg119_1, 4096, grid=grid(4096), stream=stream0)
del buf107
buf110 = buf87; del buf87 # reuse
# Source Nodes: [y_12], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf109, arg118_1, buf110, 6656, 128, grid=grid(6656), stream=stream0)
del arg118_1
buf114 = buf83; del buf83 # reuse
# Source Nodes: [mask, y_12], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf110, arg454_1, arg67_1, buf114, 32, 208, grid=grid(32), stream=stream0)
buf115 = buf88; del buf88 # reuse
# Source Nodes: [y_12], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf114, arg119_1, buf115, 8192, 104, grid=grid(8192), stream=stream0)
del arg119_1
buf117 = buf105; del buf105 # reuse
# Source Nodes: [out_25, y_12], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf115, buf117, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_25], Original ATen: [torchao.fp16act_fp6weight_linear]
buf118 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf117, arg120_1, arg121_1, 13)
del arg120_1
del arg121_1
buf119 = buf118
del buf118
buf121 = reinterpret_tensor(buf117, (1, 1, 4096), (4096, 4096, 1), 0); del buf117 # reuse
# Source Nodes: [add_28, float_20, h_4, mean_9, mul_71, mul_72, mul_73, out_23, output_9, rsqrt_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf93, buf103, buf119, arg9_1, buf121, 1, 4096, grid=grid(1), stream=stream0)
del arg9_1
# Source Nodes: [out_26], Original ATen: [torchao.fp16act_fp6weight_linear]
buf122 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg122_1, arg123_1, 1)
del arg122_1
del arg123_1
buf123 = buf122
del buf122
# Source Nodes: [out_27], Original ATen: [torchao.fp16act_fp6weight_linear]
buf124 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg124_1, arg125_1, 1)
del arg124_1
del arg125_1
buf125 = buf124
del buf124
buf126 = buf123; del buf123 # reuse
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf126, buf125, 11008, grid=grid(11008), stream=stream0)
del buf125
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear]
buf127 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf126, arg126_1, arg127_1, 13)
del arg126_1
del arg127_1
del buf126
buf128 = buf127
del buf127
buf130 = reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0); del buf121 # reuse
# Source Nodes: [float_21, h_4, mean_10, mul_75, out_23, out_29, out_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf93, buf103, buf119, buf128, arg10_1, buf130, 1, 4096, grid=grid(1), stream=stream0)
del arg10_1
# Source Nodes: [out_30], Original ATen: [torchao.fp16act_fp6weight_linear]
buf131 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf130, arg128_1, arg129_1, 1)
del arg128_1
del arg129_1
buf132 = buf131
del buf131
buf134 = buf109; del buf109 # reuse
# Source Nodes: [setitem_10, setitem_11, y_15], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf132, arg66_1, arg130_1, buf134, arg131_1, 4096, grid=grid(4096), stream=stream0)
del buf132
buf135 = buf114; del buf114 # reuse
# Source Nodes: [y_15], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf134, arg130_1, buf135, 6656, 128, grid=grid(6656), stream=stream0)
del arg130_1
buf139 = buf110; del buf110 # reuse
# Source Nodes: [mask, y_15], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf135, arg454_1, arg67_1, buf139, 32, 208, grid=grid(32), stream=stream0)
buf140 = buf115; del buf115 # reuse
# Source Nodes: [y_15], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf139, arg131_1, buf140, 8192, 104, grid=grid(8192), stream=stream0)
del arg131_1
buf142 = buf130; del buf130 # reuse
# Source Nodes: [out_31, y_15], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf140, buf142, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_31], Original ATen: [torchao.fp16act_fp6weight_linear]
buf143 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf142, arg132_1, arg133_1, 13)
del arg132_1
del arg133_1
buf144 = buf143
del buf143
buf145 = reinterpret_tensor(buf103, (1, 1, 4096), (4096, 4096, 1), 0); del buf103 # reuse
buf147 = buf142; del buf142 # reuse
buf150 = buf95; del buf95 # reuse
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15.run(buf145, buf93, buf119, buf128, buf144, arg11_1, buf147, buf150, 1, 4096, grid=grid(1), stream=stream0)
del arg11_1
del buf119
del buf128
del buf144
del buf93
# Source Nodes: [out_32], Original ATen: [torchao.fp16act_fp6weight_linear]
buf148 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf147, arg134_1, arg135_1, 1)
del arg134_1
del arg135_1
buf149 = buf148
del buf148
# Source Nodes: [out_33], Original ATen: [torchao.fp16act_fp6weight_linear]
buf151 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf150, arg136_1, arg137_1, 1)
del arg136_1
del arg137_1
buf152 = buf151
del buf151
buf153 = buf149; del buf149 # reuse
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf153, buf152, 11008, grid=grid(11008), stream=stream0)
del buf152
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear]
buf154 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf153, arg138_1, arg139_1, 13)
del arg138_1
del arg139_1
del buf153
buf155 = buf154
del buf154
buf157 = buf150; del buf150 # reuse
# Source Nodes: [float_25, mean_12, mul_90, out_35, out_36], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf145, buf155, arg12_1, buf157, 1, 4096, grid=grid(1), stream=stream0)
del arg12_1
# Source Nodes: [out_36], Original ATen: [torchao.fp16act_fp6weight_linear]
buf158 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf157, arg140_1, arg141_1, 1)
del arg140_1
del arg141_1
buf159 = buf158
del buf158
buf161 = buf134; del buf134 # reuse
# Source Nodes: [setitem_12, setitem_13, y_18], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf159, arg66_1, arg142_1, buf161, arg143_1, 4096, grid=grid(4096), stream=stream0)
del buf159
buf162 = buf139; del buf139 # reuse
# Source Nodes: [y_18], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf161, arg142_1, buf162, 6656, 128, grid=grid(6656), stream=stream0)
del arg142_1
buf166 = buf135; del buf135 # reuse
# Source Nodes: [mask, y_18], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf162, arg454_1, arg67_1, buf166, 32, 208, grid=grid(32), stream=stream0)
buf167 = buf140; del buf140 # reuse
# Source Nodes: [y_18], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf166, arg143_1, buf167, 8192, 104, grid=grid(8192), stream=stream0)
del arg143_1
buf169 = buf157; del buf157 # reuse
# Source Nodes: [out_37, y_18], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf167, buf169, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_37], Original ATen: [torchao.fp16act_fp6weight_linear]
buf170 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf169, arg144_1, arg145_1, 13)
del arg144_1
del arg145_1
buf171 = buf170
del buf170
buf173 = reinterpret_tensor(buf169, (1, 1, 4096), (4096, 4096, 1), 0); del buf169 # reuse
# Source Nodes: [add_40, float_28, h_6, mean_13, mul_101, mul_102, mul_103, out_35, output_13, rsqrt_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf145, buf155, buf171, arg13_1, buf173, 1, 4096, grid=grid(1), stream=stream0)
del arg13_1
# Source Nodes: [out_38], Original ATen: [torchao.fp16act_fp6weight_linear]
buf174 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg146_1, arg147_1, 1)
del arg146_1
del arg147_1
buf175 = buf174
del buf174
# Source Nodes: [out_39], Original ATen: [torchao.fp16act_fp6weight_linear]
buf176 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg148_1, arg149_1, 1)
del arg148_1
del arg149_1
buf177 = buf176
del buf176
buf178 = buf175; del buf175 # reuse
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf178, buf177, 11008, grid=grid(11008), stream=stream0)
del buf177
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear]
buf179 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf178, arg150_1, arg151_1, 13)
del arg150_1
del arg151_1
del buf178
buf180 = buf179
del buf179
buf182 = reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0); del buf173 # reuse
# Source Nodes: [float_29, h_6, mean_14, mul_105, out_35, out_41, out_42], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf145, buf155, buf171, buf180, arg14_1, buf182, 1, 4096, grid=grid(1), stream=stream0)
del arg14_1
# Source Nodes: [out_42], Original ATen: [torchao.fp16act_fp6weight_linear]
buf183 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf182, arg152_1, arg153_1, 1)
del arg152_1
del arg153_1
buf184 = buf183
del buf183
buf186 = buf161; del buf161 # reuse
# Source Nodes: [setitem_14, setitem_15, y_21], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf184, arg66_1, arg154_1, buf186, arg155_1, 4096, grid=grid(4096), stream=stream0)
del buf184
buf187 = buf166; del buf166 # reuse
# Source Nodes: [y_21], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf186, arg154_1, buf187, 6656, 128, grid=grid(6656), stream=stream0)
del arg154_1
buf191 = buf162; del buf162 # reuse
# Source Nodes: [mask, y_21], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf187, arg454_1, arg67_1, buf191, 32, 208, grid=grid(32), stream=stream0)
buf192 = buf167; del buf167 # reuse
# Source Nodes: [y_21], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf191, arg155_1, buf192, 8192, 104, grid=grid(8192), stream=stream0)
del arg155_1
buf194 = buf182; del buf182 # reuse
# Source Nodes: [out_43, y_21], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf192, buf194, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_43], Original ATen: [torchao.fp16act_fp6weight_linear]
buf195 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf194, arg156_1, arg157_1, 13)
del arg156_1
del arg157_1
buf196 = buf195
del buf195
buf197 = buf145; del buf145 # reuse
buf199 = buf194; del buf194 # reuse
buf202 = buf147; del buf147 # reuse
# Source Nodes: [float_32, h_6, h_7, mean_15, mul_116, out_35, out_41, out_44, out_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf197, buf155, buf171, buf180, buf196, arg15_1, buf199, buf202, 1, 4096, grid=grid(1), stream=stream0)
del arg15_1
del buf155
del buf171
del buf180
del buf196
# Source Nodes: [out_44], Original ATen: [torchao.fp16act_fp6weight_linear]
buf200 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf199, arg158_1, arg159_1, 1)
del arg158_1
del arg159_1
buf201 = buf200
del buf200
# Source Nodes: [out_45], Original ATen: [torchao.fp16act_fp6weight_linear]
buf203 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf202, arg160_1, arg161_1, 1)
del arg160_1
del arg161_1
buf204 = buf203
del buf203
buf205 = buf201; del buf201 # reuse
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf205, buf204, 11008, grid=grid(11008), stream=stream0)
del buf204
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear]
buf206 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf205, arg162_1, arg163_1, 13)
del arg162_1
del arg163_1
del buf205
buf207 = buf206
del buf206
buf209 = buf202; del buf202 # reuse
# Source Nodes: [float_33, mean_16, mul_120, out_47, out_48], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf197, buf207, arg16_1, buf209, 1, 4096, grid=grid(1), stream=stream0)
del arg16_1
# Source Nodes: [out_48], Original ATen: [torchao.fp16act_fp6weight_linear]
buf210 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf209, arg164_1, arg165_1, 1)
del arg164_1
del arg165_1
buf211 = buf210
del buf210
buf213 = buf186; del buf186 # reuse
# Source Nodes: [setitem_16, setitem_17, y_24], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf211, arg66_1, arg166_1, buf213, arg167_1, 4096, grid=grid(4096), stream=stream0)
del buf211
buf214 = buf191; del buf191 # reuse
# Source Nodes: [y_24], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf213, arg166_1, buf214, 6656, 128, grid=grid(6656), stream=stream0)
del arg166_1
buf218 = buf187; del buf187 # reuse
# Source Nodes: [mask, y_24], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf214, arg454_1, arg67_1, buf218, 32, 208, grid=grid(32), stream=stream0)
buf219 = buf192; del buf192 # reuse
# Source Nodes: [y_24], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf218, arg167_1, buf219, 8192, 104, grid=grid(8192), stream=stream0)
del arg167_1
buf221 = buf209; del buf209 # reuse
# Source Nodes: [out_49, y_24], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf219, buf221, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_49], Original ATen: [torchao.fp16act_fp6weight_linear]
buf222 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf221, arg168_1, arg169_1, 13)
del arg168_1
del arg169_1
buf223 = buf222
del buf222
buf225 = reinterpret_tensor(buf221, (1, 1, 4096), (4096, 4096, 1), 0); del buf221 # reuse
# Source Nodes: [add_52, float_36, h_8, mean_17, mul_131, mul_132, mul_133, out_47, output_17, rsqrt_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf197, buf207, buf223, arg17_1, buf225, 1, 4096, grid=grid(1), stream=stream0)
del arg17_1
# Source Nodes: [out_50], Original ATen: [torchao.fp16act_fp6weight_linear]
buf226 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg170_1, arg171_1, 1)
del arg170_1
del arg171_1
buf227 = buf226
del buf226
# Source Nodes: [out_51], Original ATen: [torchao.fp16act_fp6weight_linear]
buf228 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg172_1, arg173_1, 1)
del arg172_1
del arg173_1
buf229 = buf228
del buf228
buf230 = buf227; del buf227 # reuse
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf230, buf229, 11008, grid=grid(11008), stream=stream0)
del buf229
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear]
buf231 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf230, arg174_1, arg175_1, 13)
del arg174_1
del arg175_1
del buf230
buf232 = buf231
del buf231
buf234 = reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0); del buf225 # reuse
# Source Nodes: [float_37, h_8, mean_18, mul_135, out_47, out_53, out_54], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf197, buf207, buf223, buf232, arg18_1, buf234, 1, 4096, grid=grid(1), stream=stream0)
del arg18_1
# Source Nodes: [out_54], Original ATen: [torchao.fp16act_fp6weight_linear]
buf235 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf234, arg176_1, arg177_1, 1)
del arg176_1
del arg177_1
buf236 = buf235
del buf235
buf238 = buf213; del buf213 # reuse
# Source Nodes: [setitem_18, setitem_19, y_27], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf236, arg66_1, arg178_1, buf238, arg179_1, 4096, grid=grid(4096), stream=stream0)
del buf236
buf239 = buf218; del buf218 # reuse
# Source Nodes: [y_27], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf238, arg178_1, buf239, 6656, 128, grid=grid(6656), stream=stream0)
del arg178_1
buf243 = buf214; del buf214 # reuse
# Source Nodes: [mask, y_27], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf239, arg454_1, arg67_1, buf243, 32, 208, grid=grid(32), stream=stream0)
buf244 = buf219; del buf219 # reuse
# Source Nodes: [y_27], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf243, arg179_1, buf244, 8192, 104, grid=grid(8192), stream=stream0)
del arg179_1
buf246 = buf234; del buf234 # reuse
# Source Nodes: [out_55, y_27], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf244, buf246, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_55], Original ATen: [torchao.fp16act_fp6weight_linear]
buf247 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf246, arg180_1, arg181_1, 13)
del arg180_1
del arg181_1
buf248 = buf247
del buf247
buf249 = buf197; del buf197 # reuse
buf251 = buf246; del buf246 # reuse
buf254 = buf199; del buf199 # reuse
# Source Nodes: [float_40, h_8, h_9, mean_19, mul_146, out_47, out_53, out_56, out_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf249, buf207, buf223, buf232, buf248, arg19_1, buf251, buf254, 1, 4096, grid=grid(1), stream=stream0)
del arg19_1
del buf207
del buf223
del buf232
del buf248
# Source Nodes: [out_56], Original ATen: [torchao.fp16act_fp6weight_linear]
buf252 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf251, arg182_1, arg183_1, 1)
del arg182_1
del arg183_1
buf253 = buf252
del buf252
# Source Nodes: [out_57], Original ATen: [torchao.fp16act_fp6weight_linear]
buf255 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf254, arg184_1, arg185_1, 1)
del arg184_1
del arg185_1
buf256 = buf255
del buf255
buf257 = buf253; del buf253 # reuse
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf257, buf256, 11008, grid=grid(11008), stream=stream0)
del buf256
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear]
buf258 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf257, arg186_1, arg187_1, 13)
del arg186_1
del arg187_1
del buf257
buf259 = buf258
del buf258
buf261 = buf254; del buf254 # reuse
# Source Nodes: [float_41, mean_20, mul_150, out_59, out_60], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf249, buf259, arg20_1, buf261, 1, 4096, grid=grid(1), stream=stream0)
del arg20_1
# Source Nodes: [out_60], Original ATen: [torchao.fp16act_fp6weight_linear]
buf262 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf261, arg188_1, arg189_1, 1)
del arg188_1
del arg189_1
buf263 = buf262
del buf262
buf265 = buf238; del buf238 # reuse
# Source Nodes: [setitem_20, setitem_21, y_30], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf263, arg66_1, arg190_1, buf265, arg191_1, 4096, grid=grid(4096), stream=stream0)
del buf263
buf266 = buf243; del buf243 # reuse
# Source Nodes: [y_30], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf265, arg190_1, buf266, 6656, 128, grid=grid(6656), stream=stream0)
del arg190_1
buf270 = buf239; del buf239 # reuse
# Source Nodes: [mask, y_30], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf266, arg454_1, arg67_1, buf270, 32, 208, grid=grid(32), stream=stream0)
buf271 = buf244; del buf244 # reuse
# Source Nodes: [y_30], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf270, arg191_1, buf271, 8192, 104, grid=grid(8192), stream=stream0)
del arg191_1
buf273 = buf261; del buf261 # reuse
# Source Nodes: [out_61, y_30], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf271, buf273, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_61], Original ATen: [torchao.fp16act_fp6weight_linear]
buf274 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf273, arg192_1, arg193_1, 13)
del arg192_1
del arg193_1
buf275 = buf274
del buf274
buf277 = reinterpret_tensor(buf273, (1, 1, 4096), (4096, 4096, 1), 0); del buf273 # reuse
# Source Nodes: [add_64, float_44, h_10, mean_21, mul_161, mul_162, mul_163, out_59, output_21, rsqrt_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf249, buf259, buf275, arg21_1, buf277, 1, 4096, grid=grid(1), stream=stream0)
del arg21_1
# Source Nodes: [out_62], Original ATen: [torchao.fp16act_fp6weight_linear]
buf278 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg194_1, arg195_1, 1)
del arg194_1
del arg195_1
buf279 = buf278
del buf278
# Source Nodes: [out_63], Original ATen: [torchao.fp16act_fp6weight_linear]
buf280 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg196_1, arg197_1, 1)
del arg196_1
del arg197_1
buf281 = buf280
del buf280
buf282 = buf279; del buf279 # reuse
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf282, buf281, 11008, grid=grid(11008), stream=stream0)
del buf281
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear]
buf283 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf282, arg198_1, arg199_1, 13)
del arg198_1
del arg199_1
del buf282
buf284 = buf283
del buf283
buf286 = reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0); del buf277 # reuse
# Source Nodes: [float_45, h_10, mean_22, mul_165, out_59, out_65, out_66], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf249, buf259, buf275, buf284, arg22_1, buf286, 1, 4096, grid=grid(1), stream=stream0)
del arg22_1
# Source Nodes: [out_66], Original ATen: [torchao.fp16act_fp6weight_linear]
buf287 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf286, arg200_1, arg201_1, 1)
del arg200_1
del arg201_1
buf288 = buf287
del buf287
buf290 = buf265; del buf265 # reuse
# Source Nodes: [setitem_22, setitem_23, y_33], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf288, arg66_1, arg202_1, buf290, arg203_1, 4096, grid=grid(4096), stream=stream0)
del buf288
buf291 = buf270; del buf270 # reuse
# Source Nodes: [y_33], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf290, arg202_1, buf291, 6656, 128, grid=grid(6656), stream=stream0)
del arg202_1
buf295 = buf266; del buf266 # reuse
# Source Nodes: [mask, y_33], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf291, arg454_1, arg67_1, buf295, 32, 208, grid=grid(32), stream=stream0)
buf296 = buf271; del buf271 # reuse
# Source Nodes: [y_33], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf295, arg203_1, buf296, 8192, 104, grid=grid(8192), stream=stream0)
del arg203_1
buf298 = buf286; del buf286 # reuse
# Source Nodes: [out_67, y_33], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf296, buf298, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_67], Original ATen: [torchao.fp16act_fp6weight_linear]
buf299 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf298, arg204_1, arg205_1, 13)
del arg204_1
del arg205_1
buf300 = buf299
del buf299
buf301 = buf249; del buf249 # reuse
buf303 = buf298; del buf298 # reuse
buf306 = buf251; del buf251 # reuse
# Source Nodes: [float_48, h_10, h_11, mean_23, mul_176, out_59, out_65, out_68, out_69], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf301, buf259, buf275, buf284, buf300, arg23_1, buf303, buf306, 1, 4096, grid=grid(1), stream=stream0)
del arg23_1
del buf259
del buf275
del buf284
del buf300
# Source Nodes: [out_68], Original ATen: [torchao.fp16act_fp6weight_linear]
buf304 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf303, arg206_1, arg207_1, 1)
del arg206_1
del arg207_1
buf305 = buf304
del buf304
# Source Nodes: [out_69], Original ATen: [torchao.fp16act_fp6weight_linear]
buf307 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf306, arg208_1, arg209_1, 1)
del arg208_1
del arg209_1
buf308 = buf307
del buf307
buf309 = buf305; del buf305 # reuse
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf309, buf308, 11008, grid=grid(11008), stream=stream0)
del buf308
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear]
buf310 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf309, arg210_1, arg211_1, 13)
del arg210_1
del arg211_1
del buf309
buf311 = buf310
del buf310
buf313 = buf306; del buf306 # reuse
# Source Nodes: [float_49, mean_24, mul_180, out_71, out_72], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf301, buf311, arg24_1, buf313, 1, 4096, grid=grid(1), stream=stream0)
del arg24_1
# Source Nodes: [out_72], Original ATen: [torchao.fp16act_fp6weight_linear]
buf314 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf313, arg212_1, arg213_1, 1)
del arg212_1
del arg213_1
buf315 = buf314
del buf314
buf317 = buf290; del buf290 # reuse
# Source Nodes: [setitem_24, setitem_25, y_36], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf315, arg66_1, arg214_1, buf317, arg215_1, 4096, grid=grid(4096), stream=stream0)
del buf315
buf318 = buf295; del buf295 # reuse
# Source Nodes: [y_36], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf317, arg214_1, buf318, 6656, 128, grid=grid(6656), stream=stream0)
del arg214_1
buf322 = buf291; del buf291 # reuse
# Source Nodes: [mask, y_36], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf318, arg454_1, arg67_1, buf322, 32, 208, grid=grid(32), stream=stream0)
buf323 = buf296; del buf296 # reuse
# Source Nodes: [y_36], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf322, arg215_1, buf323, 8192, 104, grid=grid(8192), stream=stream0)
del arg215_1
buf325 = buf313; del buf313 # reuse
# Source Nodes: [out_73, y_36], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf323, buf325, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_73], Original ATen: [torchao.fp16act_fp6weight_linear]
buf326 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf325, arg216_1, arg217_1, 13)
del arg216_1
del arg217_1
buf327 = buf326
del buf326
buf329 = reinterpret_tensor(buf325, (1, 1, 4096), (4096, 4096, 1), 0); del buf325 # reuse
# Source Nodes: [add_76, float_52, h_12, mean_25, mul_191, mul_192, mul_193, out_71, output_25, rsqrt_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf301, buf311, buf327, arg25_1, buf329, 1, 4096, grid=grid(1), stream=stream0)
del arg25_1
# Source Nodes: [out_74], Original ATen: [torchao.fp16act_fp6weight_linear]
buf330 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg218_1, arg219_1, 1)
del arg218_1
del arg219_1
buf331 = buf330
del buf330
# Source Nodes: [out_75], Original ATen: [torchao.fp16act_fp6weight_linear]
buf332 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg220_1, arg221_1, 1)
del arg220_1
del arg221_1
buf333 = buf332
del buf332
buf334 = buf331; del buf331 # reuse
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf334, buf333, 11008, grid=grid(11008), stream=stream0)
del buf333
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear]
buf335 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf334, arg222_1, arg223_1, 13)
del arg222_1
del arg223_1
del buf334
buf336 = buf335
del buf335
buf338 = reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0); del buf329 # reuse
# Source Nodes: [float_53, h_12, mean_26, mul_195, out_71, out_77, out_78], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf301, buf311, buf327, buf336, arg26_1, buf338, 1, 4096, grid=grid(1), stream=stream0)
del arg26_1
# Source Nodes: [out_78], Original ATen: [torchao.fp16act_fp6weight_linear]
buf339 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf338, arg224_1, arg225_1, 1)
del arg224_1
del arg225_1
buf340 = buf339
del buf339
buf342 = buf317; del buf317 # reuse
# Source Nodes: [setitem_26, setitem_27, y_39], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf340, arg66_1, arg226_1, buf342, arg227_1, 4096, grid=grid(4096), stream=stream0)
del buf340
buf343 = buf322; del buf322 # reuse
# Source Nodes: [y_39], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf342, arg226_1, buf343, 6656, 128, grid=grid(6656), stream=stream0)
del arg226_1
buf347 = buf318; del buf318 # reuse
# Source Nodes: [mask, y_39], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf343, arg454_1, arg67_1, buf347, 32, 208, grid=grid(32), stream=stream0)
buf348 = buf323; del buf323 # reuse
# Source Nodes: [y_39], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf347, arg227_1, buf348, 8192, 104, grid=grid(8192), stream=stream0)
del arg227_1
buf350 = buf338; del buf338 # reuse
# Source Nodes: [out_79, y_39], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf348, buf350, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_79], Original ATen: [torchao.fp16act_fp6weight_linear]
buf351 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf350, arg228_1, arg229_1, 13)
del arg228_1
del arg229_1
buf352 = buf351
del buf351
buf353 = buf301; del buf301 # reuse
buf355 = buf350; del buf350 # reuse
buf358 = buf303; del buf303 # reuse
# Source Nodes: [float_56, h_12, h_13, mean_27, mul_206, out_71, out_77, out_80, out_81], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf353, buf311, buf327, buf336, buf352, arg27_1, buf355, buf358, 1, 4096, grid=grid(1), stream=stream0)
del arg27_1
del buf311
del buf327
del buf336
del buf352
# Source Nodes: [out_80], Original ATen: [torchao.fp16act_fp6weight_linear]
buf356 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf355, arg230_1, arg231_1, 1)
del arg230_1
del arg231_1
buf357 = buf356
del buf356
# Source Nodes: [out_81], Original ATen: [torchao.fp16act_fp6weight_linear]
buf359 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf358, arg232_1, arg233_1, 1)
del arg232_1
del arg233_1
buf360 = buf359
del buf359
buf361 = buf357; del buf357 # reuse
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf361, buf360, 11008, grid=grid(11008), stream=stream0)
del buf360
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear]
buf362 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf361, arg234_1, arg235_1, 13)
del arg234_1
del arg235_1
del buf361
buf363 = buf362
del buf362
buf365 = buf358; del buf358 # reuse
# Source Nodes: [float_57, mean_28, mul_210, out_83, out_84], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf353, buf363, arg28_1, buf365, 1, 4096, grid=grid(1), stream=stream0)
del arg28_1
# Source Nodes: [out_84], Original ATen: [torchao.fp16act_fp6weight_linear]
buf366 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf365, arg236_1, arg237_1, 1)
del arg236_1
del arg237_1
buf367 = buf366
del buf366
buf369 = buf342; del buf342 # reuse
# Source Nodes: [setitem_28, setitem_29, y_42], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf367, arg66_1, arg238_1, buf369, arg239_1, 4096, grid=grid(4096), stream=stream0)
del buf367
buf370 = buf347; del buf347 # reuse
# Source Nodes: [y_42], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf369, arg238_1, buf370, 6656, 128, grid=grid(6656), stream=stream0)
del arg238_1
buf374 = buf343; del buf343 # reuse
# Source Nodes: [mask, y_42], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf370, arg454_1, arg67_1, buf374, 32, 208, grid=grid(32), stream=stream0)
buf375 = buf348; del buf348 # reuse
# Source Nodes: [y_42], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf374, arg239_1, buf375, 8192, 104, grid=grid(8192), stream=stream0)
del arg239_1
buf377 = buf365; del buf365 # reuse
# Source Nodes: [out_85, y_42], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf375, buf377, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_85], Original ATen: [torchao.fp16act_fp6weight_linear]
buf378 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf377, arg240_1, arg241_1, 13)
del arg240_1
del arg241_1
buf379 = buf378
del buf378
buf381 = reinterpret_tensor(buf377, (1, 1, 4096), (4096, 4096, 1), 0); del buf377 # reuse
# Source Nodes: [add_88, float_60, h_14, mean_29, mul_221, mul_222, mul_223, out_83, output_29, rsqrt_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf353, buf363, buf379, arg29_1, buf381, 1, 4096, grid=grid(1), stream=stream0)
del arg29_1
# Source Nodes: [out_86], Original ATen: [torchao.fp16act_fp6weight_linear]
buf382 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg242_1, arg243_1, 1)
del arg242_1
del arg243_1
buf383 = buf382
del buf382
# Source Nodes: [out_87], Original ATen: [torchao.fp16act_fp6weight_linear]
buf384 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg244_1, arg245_1, 1)
del arg244_1
del arg245_1
buf385 = buf384
del buf384
buf386 = buf383; del buf383 # reuse
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf386, buf385, 11008, grid=grid(11008), stream=stream0)
del buf385
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear]
buf387 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf386, arg246_1, arg247_1, 13)
del arg246_1
del arg247_1
del buf386
buf388 = buf387
del buf387
buf390 = reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0); del buf381 # reuse
# Source Nodes: [float_61, h_14, mean_30, mul_225, out_83, out_89, out_90], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf353, buf363, buf379, buf388, arg30_1, buf390, 1, 4096, grid=grid(1), stream=stream0)
del arg30_1
# Source Nodes: [out_90], Original ATen: [torchao.fp16act_fp6weight_linear]
buf391 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf390, arg248_1, arg249_1, 1)
del arg248_1
del arg249_1
buf392 = buf391
del buf391
buf394 = buf369; del buf369 # reuse
# Source Nodes: [setitem_30, setitem_31, y_45], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf392, arg66_1, arg250_1, buf394, arg251_1, 4096, grid=grid(4096), stream=stream0)
del buf392
buf395 = buf374; del buf374 # reuse
# Source Nodes: [y_45], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf394, arg250_1, buf395, 6656, 128, grid=grid(6656), stream=stream0)
del arg250_1
buf399 = buf370; del buf370 # reuse
# Source Nodes: [mask, y_45], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf395, arg454_1, arg67_1, buf399, 32, 208, grid=grid(32), stream=stream0)
buf400 = buf375; del buf375 # reuse
# Source Nodes: [y_45], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf399, arg251_1, buf400, 8192, 104, grid=grid(8192), stream=stream0)
del arg251_1
buf402 = buf390; del buf390 # reuse
# Source Nodes: [out_91, y_45], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf400, buf402, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_91], Original ATen: [torchao.fp16act_fp6weight_linear]
buf403 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf402, arg252_1, arg253_1, 13)
del arg252_1
del arg253_1
buf404 = buf403
del buf403
buf405 = buf353; del buf353 # reuse
buf407 = buf402; del buf402 # reuse
buf410 = buf355; del buf355 # reuse
# Source Nodes: [float_64, h_14, h_15, mean_31, mul_236, out_83, out_89, out_92, out_93], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf405, buf363, buf379, buf388, buf404, arg31_1, buf407, buf410, 1, 4096, grid=grid(1), stream=stream0)
del arg31_1
del buf363
del buf379
del buf388
del buf404
# Source Nodes: [out_92], Original ATen: [torchao.fp16act_fp6weight_linear]
buf408 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf407, arg254_1, arg255_1, 1)
del arg254_1
del arg255_1
buf409 = buf408
del buf408
# Source Nodes: [out_93], Original ATen: [torchao.fp16act_fp6weight_linear]
buf411 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf410, arg256_1, arg257_1, 1)
del arg256_1
del arg257_1
buf412 = buf411
del buf411
buf413 = buf409; del buf409 # reuse
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf413, buf412, 11008, grid=grid(11008), stream=stream0)
del buf412
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear]
buf414 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf413, arg258_1, arg259_1, 13)
del arg258_1
del arg259_1
del buf413
buf415 = buf414
del buf414
buf417 = buf410; del buf410 # reuse
# Source Nodes: [float_65, mean_32, mul_240, out_95, out_96], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf405, buf415, arg32_1, buf417, 1, 4096, grid=grid(1), stream=stream0)
del arg32_1
# Source Nodes: [out_96], Original ATen: [torchao.fp16act_fp6weight_linear]
buf418 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf417, arg260_1, arg261_1, 1)
del arg260_1
del arg261_1
buf419 = buf418
del buf418
buf421 = buf394; del buf394 # reuse
# Source Nodes: [setitem_32, setitem_33, y_48], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf419, arg66_1, arg262_1, buf421, arg263_1, 4096, grid=grid(4096), stream=stream0)
del buf419
buf422 = buf399; del buf399 # reuse
# Source Nodes: [y_48], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf421, arg262_1, buf422, 6656, 128, grid=grid(6656), stream=stream0)
del arg262_1
buf426 = buf395; del buf395 # reuse
# Source Nodes: [mask, y_48], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf422, arg454_1, arg67_1, buf426, 32, 208, grid=grid(32), stream=stream0)
buf427 = buf400; del buf400 # reuse
# Source Nodes: [y_48], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf426, arg263_1, buf427, 8192, 104, grid=grid(8192), stream=stream0)
del arg263_1
buf429 = buf417; del buf417 # reuse
# Source Nodes: [out_97, y_48], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf427, buf429, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_97], Original ATen: [torchao.fp16act_fp6weight_linear]
buf430 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf429, arg264_1, arg265_1, 13)
del arg264_1
del arg265_1
buf431 = buf430
del buf430
buf433 = reinterpret_tensor(buf429, (1, 1, 4096), (4096, 4096, 1), 0); del buf429 # reuse
# Source Nodes: [add_100, float_68, h_16, mean_33, mul_251, mul_252, mul_253, out_95, output_33, rsqrt_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf405, buf415, buf431, arg33_1, buf433, 1, 4096, grid=grid(1), stream=stream0)
del arg33_1
# Source Nodes: [out_98], Original ATen: [torchao.fp16act_fp6weight_linear]
buf434 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg266_1, arg267_1, 1)
del arg266_1
del arg267_1
buf435 = buf434
del buf434
# Source Nodes: [out_99], Original ATen: [torchao.fp16act_fp6weight_linear]
buf436 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg268_1, arg269_1, 1)
del arg268_1
del arg269_1
buf437 = buf436
del buf436
buf438 = buf435; del buf435 # reuse
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf438, buf437, 11008, grid=grid(11008), stream=stream0)
del buf437
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear]
buf439 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf438, arg270_1, arg271_1, 13)
del arg270_1
del arg271_1
del buf438
buf440 = buf439
del buf439
buf442 = reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0); del buf433 # reuse
# Source Nodes: [float_69, h_16, mean_34, mul_255, out_101, out_102, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf405, buf415, buf431, buf440, arg34_1, buf442, 1, 4096, grid=grid(1), stream=stream0)
del arg34_1
# Source Nodes: [out_102], Original ATen: [torchao.fp16act_fp6weight_linear]
buf443 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf442, arg272_1, arg273_1, 1)
del arg272_1
del arg273_1
buf444 = buf443
del buf443
buf446 = buf421; del buf421 # reuse
# Source Nodes: [setitem_34, setitem_35, y_51], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf444, arg66_1, arg274_1, buf446, arg275_1, 4096, grid=grid(4096), stream=stream0)
del buf444
buf447 = buf426; del buf426 # reuse
# Source Nodes: [y_51], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf446, arg274_1, buf447, 6656, 128, grid=grid(6656), stream=stream0)
del arg274_1
buf451 = buf422; del buf422 # reuse
# Source Nodes: [mask, y_51], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf447, arg454_1, arg67_1, buf451, 32, 208, grid=grid(32), stream=stream0)
buf452 = buf427; del buf427 # reuse
# Source Nodes: [y_51], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf451, arg275_1, buf452, 8192, 104, grid=grid(8192), stream=stream0)
del arg275_1
buf454 = buf442; del buf442 # reuse
# Source Nodes: [out_103, y_51], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf452, buf454, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_103], Original ATen: [torchao.fp16act_fp6weight_linear]
buf455 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf454, arg276_1, arg277_1, 13)
del arg276_1
del arg277_1
buf456 = buf455
del buf455
buf457 = buf405; del buf405 # reuse
buf459 = buf454; del buf454 # reuse
buf462 = buf407; del buf407 # reuse
# Source Nodes: [float_72, h_16, h_17, mean_35, mul_266, out_101, out_104, out_105, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf457, buf415, buf431, buf440, buf456, arg35_1, buf459, buf462, 1, 4096, grid=grid(1), stream=stream0)
del arg35_1
del buf415
del buf431
del buf440
del buf456
# Source Nodes: [out_104], Original ATen: [torchao.fp16act_fp6weight_linear]
buf460 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf459, arg278_1, arg279_1, 1)
del arg278_1
del arg279_1
buf461 = buf460
del buf460
# Source Nodes: [out_105], Original ATen: [torchao.fp16act_fp6weight_linear]
buf463 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf462, arg280_1, arg281_1, 1)
del arg280_1
del arg281_1
buf464 = buf463
del buf463
buf465 = buf461; del buf461 # reuse
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf465, buf464, 11008, grid=grid(11008), stream=stream0)
del buf464
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear]
buf466 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf465, arg282_1, arg283_1, 13)
del arg282_1
del arg283_1
del buf465
buf467 = buf466
del buf466
buf469 = buf462; del buf462 # reuse
# Source Nodes: [float_73, mean_36, mul_270, out_107, out_108], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf457, buf467, arg36_1, buf469, 1, 4096, grid=grid(1), stream=stream0)
del arg36_1
# Source Nodes: [out_108], Original ATen: [torchao.fp16act_fp6weight_linear]
buf470 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf469, arg284_1, arg285_1, 1)
del arg284_1
del arg285_1
buf471 = buf470
del buf470
buf473 = buf446; del buf446 # reuse
# Source Nodes: [setitem_36, setitem_37, y_54], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf471, arg66_1, arg286_1, buf473, arg287_1, 4096, grid=grid(4096), stream=stream0)
del buf471
buf474 = buf451; del buf451 # reuse
# Source Nodes: [y_54], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf473, arg286_1, buf474, 6656, 128, grid=grid(6656), stream=stream0)
del arg286_1
buf478 = buf447; del buf447 # reuse
# Source Nodes: [mask, y_54], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf474, arg454_1, arg67_1, buf478, 32, 208, grid=grid(32), stream=stream0)
buf479 = buf452; del buf452 # reuse
# Source Nodes: [y_54], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf478, arg287_1, buf479, 8192, 104, grid=grid(8192), stream=stream0)
del arg287_1
buf481 = buf469; del buf469 # reuse
# Source Nodes: [out_109, y_54], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf479, buf481, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_109], Original ATen: [torchao.fp16act_fp6weight_linear]
buf482 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf481, arg288_1, arg289_1, 13)
del arg288_1
del arg289_1
buf483 = buf482
del buf482
buf485 = reinterpret_tensor(buf481, (1, 1, 4096), (4096, 4096, 1), 0); del buf481 # reuse
# Source Nodes: [add_112, float_76, h_18, mean_37, mul_281, mul_282, mul_283, out_107, output_37, rsqrt_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf457, buf467, buf483, arg37_1, buf485, 1, 4096, grid=grid(1), stream=stream0)
del arg37_1
# Source Nodes: [out_110], Original ATen: [torchao.fp16act_fp6weight_linear]
buf486 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg290_1, arg291_1, 1)
del arg290_1
del arg291_1
buf487 = buf486
del buf486
# Source Nodes: [out_111], Original ATen: [torchao.fp16act_fp6weight_linear]
buf488 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg292_1, arg293_1, 1)
del arg292_1
del arg293_1
buf489 = buf488
del buf488
buf490 = buf487; del buf487 # reuse
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf490, buf489, 11008, grid=grid(11008), stream=stream0)
del buf489
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear]
buf491 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf490, arg294_1, arg295_1, 13)
del arg294_1
del arg295_1
del buf490
buf492 = buf491
del buf491
buf494 = reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0); del buf485 # reuse
# Source Nodes: [float_77, h_18, mean_38, mul_285, out_107, out_113, out_114], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf457, buf467, buf483, buf492, arg38_1, buf494, 1, 4096, grid=grid(1), stream=stream0)
del arg38_1
# Source Nodes: [out_114], Original ATen: [torchao.fp16act_fp6weight_linear]
buf495 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf494, arg296_1, arg297_1, 1)
del arg296_1
del arg297_1
buf496 = buf495
del buf495
buf498 = buf473; del buf473 # reuse
# Source Nodes: [setitem_38, setitem_39, y_57], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf496, arg66_1, arg298_1, buf498, arg299_1, 4096, grid=grid(4096), stream=stream0)
del buf496
buf499 = buf478; del buf478 # reuse
# Source Nodes: [y_57], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf498, arg298_1, buf499, 6656, 128, grid=grid(6656), stream=stream0)
del arg298_1
buf503 = buf474; del buf474 # reuse
# Source Nodes: [mask, y_57], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf499, arg454_1, arg67_1, buf503, 32, 208, grid=grid(32), stream=stream0)
buf504 = buf479; del buf479 # reuse
# Source Nodes: [y_57], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf503, arg299_1, buf504, 8192, 104, grid=grid(8192), stream=stream0)
del arg299_1
buf506 = buf494; del buf494 # reuse
# Source Nodes: [out_115, y_57], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf504, buf506, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_115], Original ATen: [torchao.fp16act_fp6weight_linear]
buf507 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf506, arg300_1, arg301_1, 13)
del arg300_1
del arg301_1
buf508 = buf507
del buf507
buf509 = buf457; del buf457 # reuse
buf511 = buf506; del buf506 # reuse
buf514 = buf459; del buf459 # reuse
# Source Nodes: [float_80, h_18, h_19, mean_39, mul_296, out_107, out_113, out_116, out_117], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf509, buf467, buf483, buf492, buf508, arg39_1, buf511, buf514, 1, 4096, grid=grid(1), stream=stream0)
del arg39_1
del buf467
del buf483
del buf492
del buf508
# Source Nodes: [out_116], Original ATen: [torchao.fp16act_fp6weight_linear]
buf512 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf511, arg302_1, arg303_1, 1)
del arg302_1
del arg303_1
buf513 = buf512
del buf512
# Source Nodes: [out_117], Original ATen: [torchao.fp16act_fp6weight_linear]
buf515 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf514, arg304_1, arg305_1, 1)
del arg304_1
del arg305_1
buf516 = buf515
del buf515
buf517 = buf513; del buf513 # reuse
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf517, buf516, 11008, grid=grid(11008), stream=stream0)
del buf516
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear]
buf518 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf517, arg306_1, arg307_1, 13)
del arg306_1
del arg307_1
del buf517
buf519 = buf518
del buf518
buf521 = buf514; del buf514 # reuse
# Source Nodes: [float_81, mean_40, mul_300, out_119, out_120], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf509, buf519, arg40_1, buf521, 1, 4096, grid=grid(1), stream=stream0)
del arg40_1
# Source Nodes: [out_120], Original ATen: [torchao.fp16act_fp6weight_linear]
buf522 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf521, arg308_1, arg309_1, 1)
del arg308_1
del arg309_1
buf523 = buf522
del buf522
buf525 = buf498; del buf498 # reuse
# Source Nodes: [setitem_40, setitem_41, y_60], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf523, arg66_1, arg310_1, buf525, arg311_1, 4096, grid=grid(4096), stream=stream0)
del buf523
buf526 = buf503; del buf503 # reuse
# Source Nodes: [y_60], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf525, arg310_1, buf526, 6656, 128, grid=grid(6656), stream=stream0)
del arg310_1
buf530 = buf499; del buf499 # reuse
# Source Nodes: [mask, y_60], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf526, arg454_1, arg67_1, buf530, 32, 208, grid=grid(32), stream=stream0)
buf531 = buf504; del buf504 # reuse
# Source Nodes: [y_60], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf530, arg311_1, buf531, 8192, 104, grid=grid(8192), stream=stream0)
del arg311_1
buf533 = buf521; del buf521 # reuse
# Source Nodes: [out_121, y_60], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf531, buf533, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_121], Original ATen: [torchao.fp16act_fp6weight_linear]
buf534 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf533, arg312_1, arg313_1, 13)
del arg312_1
del arg313_1
buf535 = buf534
del buf534
buf537 = reinterpret_tensor(buf533, (1, 1, 4096), (4096, 4096, 1), 0); del buf533 # reuse
# Source Nodes: [add_124, float_84, h_20, mean_41, mul_311, mul_312, mul_313, out_119, output_41, rsqrt_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf509, buf519, buf535, arg41_1, buf537, 1, 4096, grid=grid(1), stream=stream0)
del arg41_1
# Source Nodes: [out_122], Original ATen: [torchao.fp16act_fp6weight_linear]
buf538 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg314_1, arg315_1, 1)
del arg314_1
del arg315_1
buf539 = buf538
del buf538
# Source Nodes: [out_123], Original ATen: [torchao.fp16act_fp6weight_linear]
buf540 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg316_1, arg317_1, 1)
del arg316_1
del arg317_1
buf541 = buf540
del buf540
buf542 = buf539; del buf539 # reuse
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf542, buf541, 11008, grid=grid(11008), stream=stream0)
del buf541
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear]
buf543 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf542, arg318_1, arg319_1, 13)
del arg318_1
del arg319_1
del buf542
buf544 = buf543
del buf543
buf546 = reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0); del buf537 # reuse
# Source Nodes: [float_85, h_20, mean_42, mul_315, out_119, out_125, out_126], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf509, buf519, buf535, buf544, arg42_1, buf546, 1, 4096, grid=grid(1), stream=stream0)
del arg42_1
# Source Nodes: [out_126], Original ATen: [torchao.fp16act_fp6weight_linear]
buf547 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf546, arg320_1, arg321_1, 1)
del arg320_1
del arg321_1
buf548 = buf547
del buf547
buf550 = buf525; del buf525 # reuse
# Source Nodes: [setitem_42, setitem_43, y_63], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf548, arg66_1, arg322_1, buf550, arg323_1, 4096, grid=grid(4096), stream=stream0)
del buf548
buf551 = buf530; del buf530 # reuse
# Source Nodes: [y_63], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf550, arg322_1, buf551, 6656, 128, grid=grid(6656), stream=stream0)
del arg322_1
buf555 = buf526; del buf526 # reuse
# Source Nodes: [mask, y_63], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf551, arg454_1, arg67_1, buf555, 32, 208, grid=grid(32), stream=stream0)
buf556 = buf531; del buf531 # reuse
# Source Nodes: [y_63], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf555, arg323_1, buf556, 8192, 104, grid=grid(8192), stream=stream0)
del arg323_1
buf558 = buf546; del buf546 # reuse
# Source Nodes: [out_127, y_63], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf556, buf558, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_127], Original ATen: [torchao.fp16act_fp6weight_linear]
buf559 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf558, arg324_1, arg325_1, 13)
del arg324_1
del arg325_1
buf560 = buf559
del buf559
buf561 = buf509; del buf509 # reuse
buf563 = buf558; del buf558 # reuse
buf566 = buf511; del buf511 # reuse
# Source Nodes: [float_88, h_20, h_21, mean_43, mul_326, out_119, out_125, out_128, out_129], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf561, buf519, buf535, buf544, buf560, arg43_1, buf563, buf566, 1, 4096, grid=grid(1), stream=stream0)
del arg43_1
del buf519
del buf535
del buf544
del buf560
# Source Nodes: [out_128], Original ATen: [torchao.fp16act_fp6weight_linear]
buf564 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf563, arg326_1, arg327_1, 1)
del arg326_1
del arg327_1
buf565 = buf564
del buf564
# Source Nodes: [out_129], Original ATen: [torchao.fp16act_fp6weight_linear]
buf567 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf566, arg328_1, arg329_1, 1)
del arg328_1
del arg329_1
buf568 = buf567
del buf567
buf569 = buf565; del buf565 # reuse
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf569, buf568, 11008, grid=grid(11008), stream=stream0)
del buf568
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear]
buf570 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf569, arg330_1, arg331_1, 13)
del arg330_1
del arg331_1
del buf569
buf571 = buf570
del buf570
buf573 = buf566; del buf566 # reuse
# Source Nodes: [float_89, mean_44, mul_330, out_131, out_132], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf561, buf571, arg44_1, buf573, 1, 4096, grid=grid(1), stream=stream0)
del arg44_1
# Source Nodes: [out_132], Original ATen: [torchao.fp16act_fp6weight_linear]
buf574 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf573, arg332_1, arg333_1, 1)
del arg332_1
del arg333_1
buf575 = buf574
del buf574
buf577 = buf550; del buf550 # reuse
# Source Nodes: [setitem_44, setitem_45, y_66], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf575, arg66_1, arg334_1, buf577, arg335_1, 4096, grid=grid(4096), stream=stream0)
del buf575
buf578 = buf555; del buf555 # reuse
# Source Nodes: [y_66], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf577, arg334_1, buf578, 6656, 128, grid=grid(6656), stream=stream0)
del arg334_1
buf582 = buf551; del buf551 # reuse
# Source Nodes: [mask, y_66], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf578, arg454_1, arg67_1, buf582, 32, 208, grid=grid(32), stream=stream0)
buf583 = buf556; del buf556 # reuse
# Source Nodes: [y_66], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf582, arg335_1, buf583, 8192, 104, grid=grid(8192), stream=stream0)
del arg335_1
buf585 = buf573; del buf573 # reuse
# Source Nodes: [out_133, y_66], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf583, buf585, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_133], Original ATen: [torchao.fp16act_fp6weight_linear]
buf586 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf585, arg336_1, arg337_1, 13)
del arg336_1
del arg337_1
buf587 = buf586
del buf586
buf589 = reinterpret_tensor(buf585, (1, 1, 4096), (4096, 4096, 1), 0); del buf585 # reuse
# Source Nodes: [add_136, float_92, h_22, mean_45, mul_341, mul_342, mul_343, out_131, output_45, rsqrt_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf561, buf571, buf587, arg45_1, buf589, 1, 4096, grid=grid(1), stream=stream0)
del arg45_1
# Source Nodes: [out_134], Original ATen: [torchao.fp16act_fp6weight_linear]
buf590 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg338_1, arg339_1, 1)
del arg338_1
del arg339_1
buf591 = buf590
del buf590
# Source Nodes: [out_135], Original ATen: [torchao.fp16act_fp6weight_linear]
buf592 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg340_1, arg341_1, 1)
del arg340_1
del arg341_1
buf593 = buf592
del buf592
buf594 = buf591; del buf591 # reuse
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf594, buf593, 11008, grid=grid(11008), stream=stream0)
del buf593
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear]
buf595 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf594, arg342_1, arg343_1, 13)
del arg342_1
del arg343_1
del buf594
buf596 = buf595
del buf595
buf598 = reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0); del buf589 # reuse
# Source Nodes: [float_93, h_22, mean_46, mul_345, out_131, out_137, out_138], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf561, buf571, buf587, buf596, arg46_1, buf598, 1, 4096, grid=grid(1), stream=stream0)
del arg46_1
# Source Nodes: [out_138], Original ATen: [torchao.fp16act_fp6weight_linear]
buf599 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf598, arg344_1, arg345_1, 1)
del arg344_1
del arg345_1
buf600 = buf599
del buf599
buf602 = buf577; del buf577 # reuse
# Source Nodes: [setitem_46, setitem_47, y_69], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf600, arg66_1, arg346_1, buf602, arg347_1, 4096, grid=grid(4096), stream=stream0)
del buf600
buf603 = buf582; del buf582 # reuse
# Source Nodes: [y_69], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf602, arg346_1, buf603, 6656, 128, grid=grid(6656), stream=stream0)
del arg346_1
buf607 = buf578; del buf578 # reuse
# Source Nodes: [mask, y_69], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf603, arg454_1, arg67_1, buf607, 32, 208, grid=grid(32), stream=stream0)
buf608 = buf583; del buf583 # reuse
# Source Nodes: [y_69], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf607, arg347_1, buf608, 8192, 104, grid=grid(8192), stream=stream0)
del arg347_1
buf610 = buf598; del buf598 # reuse
# Source Nodes: [out_139, y_69], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf608, buf610, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_139], Original ATen: [torchao.fp16act_fp6weight_linear]
buf611 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf610, arg348_1, arg349_1, 13)
del arg348_1
del arg349_1
buf612 = buf611
del buf611
buf613 = buf561; del buf561 # reuse
buf615 = buf610; del buf610 # reuse
buf618 = buf563; del buf563 # reuse
# Source Nodes: [float_96, h_22, h_23, mean_47, mul_356, out_131, out_137, out_140, out_141], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf613, buf571, buf587, buf596, buf612, arg47_1, buf615, buf618, 1, 4096, grid=grid(1), stream=stream0)
del arg47_1
del buf571
del buf587
del buf596
del buf612
# Source Nodes: [out_140], Original ATen: [torchao.fp16act_fp6weight_linear]
buf616 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf615, arg350_1, arg351_1, 1)
del arg350_1
del arg351_1
buf617 = buf616
del buf616
# Source Nodes: [out_141], Original ATen: [torchao.fp16act_fp6weight_linear]
buf619 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf618, arg352_1, arg353_1, 1)
del arg352_1
del arg353_1
buf620 = buf619
del buf619
buf621 = buf617; del buf617 # reuse
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf621, buf620, 11008, grid=grid(11008), stream=stream0)
del buf620
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear]
buf622 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf621, arg354_1, arg355_1, 13)
del arg354_1
del arg355_1
del buf621
buf623 = buf622
del buf622
buf625 = buf618; del buf618 # reuse
# Source Nodes: [float_97, mean_48, mul_360, out_143, out_144], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf613, buf623, arg48_1, buf625, 1, 4096, grid=grid(1), stream=stream0)
del arg48_1
# Source Nodes: [out_144], Original ATen: [torchao.fp16act_fp6weight_linear]
buf626 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf625, arg356_1, arg357_1, 1)
del arg356_1
del arg357_1
buf627 = buf626
del buf626
buf629 = buf602; del buf602 # reuse
# Source Nodes: [setitem_48, setitem_49, y_72], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf627, arg66_1, arg358_1, buf629, arg359_1, 4096, grid=grid(4096), stream=stream0)
del buf627
buf630 = buf607; del buf607 # reuse
# Source Nodes: [y_72], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf629, arg358_1, buf630, 6656, 128, grid=grid(6656), stream=stream0)
del arg358_1
buf634 = buf603; del buf603 # reuse
# Source Nodes: [mask, y_72], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf630, arg454_1, arg67_1, buf634, 32, 208, grid=grid(32), stream=stream0)
buf635 = buf608; del buf608 # reuse
# Source Nodes: [y_72], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf634, arg359_1, buf635, 8192, 104, grid=grid(8192), stream=stream0)
del arg359_1
buf637 = buf625; del buf625 # reuse
# Source Nodes: [out_145, y_72], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf635, buf637, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_145], Original ATen: [torchao.fp16act_fp6weight_linear]
buf638 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf637, arg360_1, arg361_1, 13)
del arg360_1
del arg361_1
buf639 = buf638
del buf638
buf641 = reinterpret_tensor(buf637, (1, 1, 4096), (4096, 4096, 1), 0); del buf637 # reuse
# Source Nodes: [add_148, float_100, h_24, mean_49, mul_371, mul_372, mul_373, out_143, output_49, rsqrt_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf613, buf623, buf639, arg49_1, buf641, 1, 4096, grid=grid(1), stream=stream0)
del arg49_1
# Source Nodes: [out_146], Original ATen: [torchao.fp16act_fp6weight_linear]
buf642 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg362_1, arg363_1, 1)
del arg362_1
del arg363_1
buf643 = buf642
del buf642
# Source Nodes: [out_147], Original ATen: [torchao.fp16act_fp6weight_linear]
buf644 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg364_1, arg365_1, 1)
del arg364_1
del arg365_1
buf645 = buf644
del buf644
buf646 = buf643; del buf643 # reuse
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf646, buf645, 11008, grid=grid(11008), stream=stream0)
del buf645
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear]
buf647 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf646, arg366_1, arg367_1, 13)
del arg366_1
del arg367_1
del buf646
buf648 = buf647
del buf647
buf650 = reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0); del buf641 # reuse
# Source Nodes: [float_101, h_24, mean_50, mul_375, out_143, out_149, out_150], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf613, buf623, buf639, buf648, arg50_1, buf650, 1, 4096, grid=grid(1), stream=stream0)
del arg50_1
# Source Nodes: [out_150], Original ATen: [torchao.fp16act_fp6weight_linear]
buf651 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf650, arg368_1, arg369_1, 1)
del arg368_1
del arg369_1
buf652 = buf651
del buf651
buf654 = buf629; del buf629 # reuse
# Source Nodes: [setitem_50, setitem_51, y_75], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf652, arg66_1, arg370_1, buf654, arg371_1, 4096, grid=grid(4096), stream=stream0)
del buf652
buf655 = buf634; del buf634 # reuse
# Source Nodes: [y_75], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf654, arg370_1, buf655, 6656, 128, grid=grid(6656), stream=stream0)
del arg370_1
buf659 = buf630; del buf630 # reuse
# Source Nodes: [mask, y_75], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf655, arg454_1, arg67_1, buf659, 32, 208, grid=grid(32), stream=stream0)
buf660 = buf635; del buf635 # reuse
# Source Nodes: [y_75], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf659, arg371_1, buf660, 8192, 104, grid=grid(8192), stream=stream0)
del arg371_1
buf662 = buf650; del buf650 # reuse
# Source Nodes: [out_151, y_75], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf660, buf662, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_151], Original ATen: [torchao.fp16act_fp6weight_linear]
buf663 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf662, arg372_1, arg373_1, 13)
del arg372_1
del arg373_1
buf664 = buf663
del buf663
buf665 = buf613; del buf613 # reuse
buf667 = buf662; del buf662 # reuse
buf670 = buf615; del buf615 # reuse
# Source Nodes: [float_104, h_24, h_25, mean_51, mul_386, out_143, out_149, out_152, out_153], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf665, buf623, buf639, buf648, buf664, arg51_1, buf667, buf670, 1, 4096, grid=grid(1), stream=stream0)
del arg51_1
del buf623
del buf639
del buf648
del buf664
# Source Nodes: [out_152], Original ATen: [torchao.fp16act_fp6weight_linear]
buf668 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf667, arg374_1, arg375_1, 1)
del arg374_1
del arg375_1
buf669 = buf668
del buf668
# Source Nodes: [out_153], Original ATen: [torchao.fp16act_fp6weight_linear]
buf671 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf670, arg376_1, arg377_1, 1)
del arg376_1
del arg377_1
buf672 = buf671
del buf671
buf673 = buf669; del buf669 # reuse
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf673, buf672, 11008, grid=grid(11008), stream=stream0)
del buf672
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear]
buf674 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf673, arg378_1, arg379_1, 13)
del arg378_1
del arg379_1
del buf673
buf675 = buf674
del buf674
buf677 = buf670; del buf670 # reuse
# Source Nodes: [float_105, mean_52, mul_390, out_155, out_156], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf665, buf675, arg52_1, buf677, 1, 4096, grid=grid(1), stream=stream0)
del arg52_1
# Source Nodes: [out_156], Original ATen: [torchao.fp16act_fp6weight_linear]
buf678 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf677, arg380_1, arg381_1, 1)
del arg380_1
del arg381_1
buf679 = buf678
del buf678
buf681 = buf654; del buf654 # reuse
# Source Nodes: [setitem_52, setitem_53, y_78], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf679, arg66_1, arg382_1, buf681, arg383_1, 4096, grid=grid(4096), stream=stream0)
del buf679
buf682 = buf659; del buf659 # reuse
# Source Nodes: [y_78], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf681, arg382_1, buf682, 6656, 128, grid=grid(6656), stream=stream0)
del arg382_1
buf686 = buf655; del buf655 # reuse
# Source Nodes: [mask, y_78], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf682, arg454_1, arg67_1, buf686, 32, 208, grid=grid(32), stream=stream0)
buf687 = buf660; del buf660 # reuse
# Source Nodes: [y_78], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf686, arg383_1, buf687, 8192, 104, grid=grid(8192), stream=stream0)
del arg383_1
buf689 = buf677; del buf677 # reuse
# Source Nodes: [out_157, y_78], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf687, buf689, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_157], Original ATen: [torchao.fp16act_fp6weight_linear]
buf690 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf689, arg384_1, arg385_1, 13)
del arg384_1
del arg385_1
buf691 = buf690
del buf690
buf693 = reinterpret_tensor(buf689, (1, 1, 4096), (4096, 4096, 1), 0); del buf689 # reuse
# Source Nodes: [add_160, float_108, h_26, mean_53, mul_401, mul_402, mul_403, out_155, output_53, rsqrt_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf665, buf675, buf691, arg53_1, buf693, 1, 4096, grid=grid(1), stream=stream0)
del arg53_1
# Source Nodes: [out_158], Original ATen: [torchao.fp16act_fp6weight_linear]
buf694 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg386_1, arg387_1, 1)
del arg386_1
del arg387_1
buf695 = buf694
del buf694
# Source Nodes: [out_159], Original ATen: [torchao.fp16act_fp6weight_linear]
buf696 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg388_1, arg389_1, 1)
del arg388_1
del arg389_1
buf697 = buf696
del buf696
buf698 = buf695; del buf695 # reuse
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf698, buf697, 11008, grid=grid(11008), stream=stream0)
del buf697
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear]
buf699 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf698, arg390_1, arg391_1, 13)
del arg390_1
del arg391_1
del buf698
buf700 = buf699
del buf699
buf702 = reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0); del buf693 # reuse
# Source Nodes: [float_109, h_26, mean_54, mul_405, out_155, out_161, out_162], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf665, buf675, buf691, buf700, arg54_1, buf702, 1, 4096, grid=grid(1), stream=stream0)
del arg54_1
# Source Nodes: [out_162], Original ATen: [torchao.fp16act_fp6weight_linear]
buf703 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf702, arg392_1, arg393_1, 1)
del arg392_1
del arg393_1
buf704 = buf703
del buf703
buf706 = buf681; del buf681 # reuse
# Source Nodes: [setitem_54, setitem_55, y_81], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf704, arg66_1, arg394_1, buf706, arg395_1, 4096, grid=grid(4096), stream=stream0)
del buf704
buf707 = buf686; del buf686 # reuse
# Source Nodes: [y_81], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf706, arg394_1, buf707, 6656, 128, grid=grid(6656), stream=stream0)
del arg394_1
buf711 = buf682; del buf682 # reuse
# Source Nodes: [mask, y_81], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf707, arg454_1, arg67_1, buf711, 32, 208, grid=grid(32), stream=stream0)
buf712 = buf687; del buf687 # reuse
# Source Nodes: [y_81], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf711, arg395_1, buf712, 8192, 104, grid=grid(8192), stream=stream0)
del arg395_1
buf714 = buf702; del buf702 # reuse
# Source Nodes: [out_163, y_81], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf712, buf714, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_163], Original ATen: [torchao.fp16act_fp6weight_linear]
buf715 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf714, arg396_1, arg397_1, 13)
del arg396_1
del arg397_1
buf716 = buf715
del buf715
buf717 = buf665; del buf665 # reuse
buf719 = buf714; del buf714 # reuse
buf722 = buf667; del buf667 # reuse
# Source Nodes: [float_112, h_26, h_27, mean_55, mul_416, out_155, out_161, out_164, out_165], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf717, buf675, buf691, buf700, buf716, arg55_1, buf719, buf722, 1, 4096, grid=grid(1), stream=stream0)
del arg55_1
del buf675
del buf691
del buf700
del buf716
# Source Nodes: [out_164], Original ATen: [torchao.fp16act_fp6weight_linear]
buf720 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf719, arg398_1, arg399_1, 1)
del arg398_1
del arg399_1
buf721 = buf720
del buf720
# Source Nodes: [out_165], Original ATen: [torchao.fp16act_fp6weight_linear]
buf723 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf722, arg400_1, arg401_1, 1)
del arg400_1
del arg401_1
buf724 = buf723
del buf723
buf725 = buf721; del buf721 # reuse
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf725, buf724, 11008, grid=grid(11008), stream=stream0)
del buf724
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear]
buf726 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf725, arg402_1, arg403_1, 13)
del arg402_1
del arg403_1
del buf725
buf727 = buf726
del buf726
buf729 = buf722; del buf722 # reuse
# Source Nodes: [float_113, mean_56, mul_420, out_167, out_168], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf717, buf727, arg56_1, buf729, 1, 4096, grid=grid(1), stream=stream0)
del arg56_1
# Source Nodes: [out_168], Original ATen: [torchao.fp16act_fp6weight_linear]
buf730 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf729, arg404_1, arg405_1, 1)
del arg404_1
del arg405_1
buf731 = buf730
del buf730
buf733 = buf706; del buf706 # reuse
# Source Nodes: [setitem_56, setitem_57, y_84], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf731, arg66_1, arg406_1, buf733, arg407_1, 4096, grid=grid(4096), stream=stream0)
del buf731
buf734 = buf711; del buf711 # reuse
# Source Nodes: [y_84], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf733, arg406_1, buf734, 6656, 128, grid=grid(6656), stream=stream0)
del arg406_1
buf738 = buf707; del buf707 # reuse
# Source Nodes: [mask, y_84], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf734, arg454_1, arg67_1, buf738, 32, 208, grid=grid(32), stream=stream0)
buf739 = buf712; del buf712 # reuse
# Source Nodes: [y_84], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf738, arg407_1, buf739, 8192, 104, grid=grid(8192), stream=stream0)
del arg407_1
buf741 = buf729; del buf729 # reuse
# Source Nodes: [out_169, y_84], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf739, buf741, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_169], Original ATen: [torchao.fp16act_fp6weight_linear]
buf742 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf741, arg408_1, arg409_1, 13)
del arg408_1
del arg409_1
buf743 = buf742
del buf742
buf745 = reinterpret_tensor(buf741, (1, 1, 4096), (4096, 4096, 1), 0); del buf741 # reuse
# Source Nodes: [add_172, float_116, h_28, mean_57, mul_431, mul_432, mul_433, out_167, output_57, rsqrt_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf717, buf727, buf743, arg57_1, buf745, 1, 4096, grid=grid(1), stream=stream0)
del arg57_1
# Source Nodes: [out_170], Original ATen: [torchao.fp16act_fp6weight_linear]
buf746 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg410_1, arg411_1, 1)
del arg410_1
del arg411_1
buf747 = buf746
del buf746
# Source Nodes: [out_171], Original ATen: [torchao.fp16act_fp6weight_linear]
buf748 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg412_1, arg413_1, 1)
del arg412_1
del arg413_1
buf749 = buf748
del buf748
buf750 = buf747; del buf747 # reuse
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf750, buf749, 11008, grid=grid(11008), stream=stream0)
del buf749
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear]
buf751 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf750, arg414_1, arg415_1, 13)
del arg414_1
del arg415_1
del buf750
buf752 = buf751
del buf751
buf754 = reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0); del buf745 # reuse
# Source Nodes: [float_117, h_28, mean_58, mul_435, out_167, out_173, out_174], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf717, buf727, buf743, buf752, arg58_1, buf754, 1, 4096, grid=grid(1), stream=stream0)
del arg58_1
# Source Nodes: [out_174], Original ATen: [torchao.fp16act_fp6weight_linear]
buf755 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf754, arg416_1, arg417_1, 1)
del arg416_1
del arg417_1
buf756 = buf755
del buf755
buf758 = buf733; del buf733 # reuse
# Source Nodes: [setitem_58, setitem_59, y_87], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf756, arg66_1, arg418_1, buf758, arg419_1, 4096, grid=grid(4096), stream=stream0)
del buf756
buf759 = buf738; del buf738 # reuse
# Source Nodes: [y_87], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf758, arg418_1, buf759, 6656, 128, grid=grid(6656), stream=stream0)
del arg418_1
buf763 = buf734; del buf734 # reuse
# Source Nodes: [mask, y_87], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf759, arg454_1, arg67_1, buf763, 32, 208, grid=grid(32), stream=stream0)
buf764 = buf739; del buf739 # reuse
# Source Nodes: [y_87], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf763, arg419_1, buf764, 8192, 104, grid=grid(8192), stream=stream0)
del arg419_1
buf766 = buf754; del buf754 # reuse
# Source Nodes: [out_175, y_87], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf764, buf766, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_175], Original ATen: [torchao.fp16act_fp6weight_linear]
buf767 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf766, arg420_1, arg421_1, 13)
del arg420_1
del arg421_1
buf768 = buf767
del buf767
buf769 = buf717; del buf717 # reuse
buf771 = buf766; del buf766 # reuse
buf774 = buf719; del buf719 # reuse
# Source Nodes: [float_120, h_28, h_29, mean_59, mul_446, out_167, out_173, out_176, out_177], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf769, buf727, buf743, buf752, buf768, arg59_1, buf771, buf774, 1, 4096, grid=grid(1), stream=stream0)
del arg59_1
del buf727
del buf743
del buf752
del buf768
# Source Nodes: [out_176], Original ATen: [torchao.fp16act_fp6weight_linear]
buf772 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf771, arg422_1, arg423_1, 1)
del arg422_1
del arg423_1
buf773 = buf772
del buf772
# Source Nodes: [out_177], Original ATen: [torchao.fp16act_fp6weight_linear]
buf775 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf774, arg424_1, arg425_1, 1)
del arg424_1
del arg425_1
buf776 = buf775
del buf775
buf777 = buf773; del buf773 # reuse
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf777, buf776, 11008, grid=grid(11008), stream=stream0)
del buf776
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear]
buf778 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf777, arg426_1, arg427_1, 13)
del arg426_1
del arg427_1
del buf777
buf779 = buf778
del buf778
buf781 = buf774; del buf774 # reuse
# Source Nodes: [float_121, mean_60, mul_450, out_179, out_180], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf769, buf779, arg60_1, buf781, 1, 4096, grid=grid(1), stream=stream0)
del arg60_1
# Source Nodes: [out_180], Original ATen: [torchao.fp16act_fp6weight_linear]
buf782 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf781, arg428_1, arg429_1, 1)
del arg428_1
del arg429_1
buf783 = buf782
del buf782
buf785 = buf758; del buf758 # reuse
# Source Nodes: [setitem_60, setitem_61, y_90], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf783, arg66_1, arg430_1, buf785, arg431_1, 4096, grid=grid(4096), stream=stream0)
del buf783
buf786 = buf763; del buf763 # reuse
# Source Nodes: [y_90], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf785, arg430_1, buf786, 6656, 128, grid=grid(6656), stream=stream0)
del arg430_1
buf790 = buf759; del buf759 # reuse
# Source Nodes: [mask, y_90], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf786, arg454_1, arg67_1, buf790, 32, 208, grid=grid(32), stream=stream0)
buf791 = buf764; del buf764 # reuse
# Source Nodes: [y_90], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf790, arg431_1, buf791, 8192, 104, grid=grid(8192), stream=stream0)
del arg431_1
buf793 = buf781; del buf781 # reuse
# Source Nodes: [out_181, y_90], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf791, buf793, 4096, 2, grid=grid(4096), stream=stream0)
# Source Nodes: [out_181], Original ATen: [torchao.fp16act_fp6weight_linear]
buf794 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf793, arg432_1, arg433_1, 13)
del arg432_1
del arg433_1
buf795 = buf794
del buf794
buf797 = reinterpret_tensor(buf793, (1, 1, 4096), (4096, 4096, 1), 0); del buf793 # reuse
# Source Nodes: [add_184, float_124, h_30, mean_61, mul_461, mul_462, mul_463, out_179, output_61, rsqrt_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt]
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf769, buf779, buf795, arg61_1, buf797, 1, 4096, grid=grid(1), stream=stream0)
del arg61_1
# Source Nodes: [out_182], Original ATen: [torchao.fp16act_fp6weight_linear]
buf798 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg434_1, arg435_1, 1)
del arg434_1
del arg435_1
buf799 = buf798
del buf798
# Source Nodes: [out_183], Original ATen: [torchao.fp16act_fp6weight_linear]
buf800 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg436_1, arg437_1, 1)
del arg436_1
del arg437_1
buf801 = buf800
del buf800
buf802 = buf799; del buf799 # reuse
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf802, buf801, 11008, grid=grid(11008), stream=stream0)
del buf801
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear]
buf803 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf802, arg438_1, arg439_1, 13)
del arg438_1
del arg439_1
del buf802
buf804 = buf803
del buf803
buf806 = reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0); del buf797 # reuse
# Source Nodes: [float_125, h_30, mean_62, mul_465, out_179, out_185, out_186], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf769, buf779, buf795, buf804, arg62_1, buf806, 1, 4096, grid=grid(1), stream=stream0)
del arg62_1
# Source Nodes: [out_186], Original ATen: [torchao.fp16act_fp6weight_linear]
buf807 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf806, arg440_1, arg441_1, 1)
del arg440_1
del arg441_1
buf808 = buf807
del buf807
buf810 = buf785; del buf785 # reuse
# Source Nodes: [setitem_62, setitem_63, y_93], Original ATen: [aten.bmm, aten.index_put]
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf808, arg66_1, arg442_1, buf810, arg443_1, 4096, grid=grid(4096), stream=stream0)
del arg66_1
del buf808
buf811 = buf790; del buf790 # reuse
# Source Nodes: [y_93], Original ATen: [aten.bmm]
triton_red_fused_bmm_2.run(buf810, arg442_1, buf811, 6656, 128, grid=grid(6656), stream=stream0)
del arg442_1
del buf810
buf815 = buf786; del buf786 # reuse
# Source Nodes: [mask, y_93], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf811, arg454_1, arg67_1, buf815, 32, 208, grid=grid(32), stream=stream0)
del arg454_1
del arg67_1
del buf811
buf816 = buf791; del buf791 # reuse
# Source Nodes: [y_93], Original ATen: [aten.bmm]
triton_red_fused_bmm_4.run(buf815, arg443_1, buf816, 8192, 104, grid=grid(8192), stream=stream0)
del arg443_1
del buf815
buf818 = buf806; del buf806 # reuse
# Source Nodes: [out_187, y_93], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf816, buf818, 4096, 2, grid=grid(4096), stream=stream0)
del buf816
# Source Nodes: [out_187], Original ATen: [torchao.fp16act_fp6weight_linear]
buf819 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf818, arg444_1, arg445_1, 13)
del arg444_1
del arg445_1
buf820 = buf819
del buf819
buf821 = buf769; del buf769 # reuse
buf823 = buf818; del buf818 # reuse
buf826 = buf771; del buf771 # reuse
# Source Nodes: [float_128, h_30, h_31, mean_63, mul_476, out_179, out_185, out_188, out_189], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf821, buf779, buf795, buf804, buf820, arg63_1, buf823, buf826, 1, 4096, grid=grid(1), stream=stream0)
del arg63_1
del buf779
del buf795
del buf804
del buf820
# Source Nodes: [out_188], Original ATen: [torchao.fp16act_fp6weight_linear]
buf824 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf823, arg446_1, arg447_1, 1)
del arg446_1
del arg447_1
del buf823
buf825 = buf824
del buf824
# Source Nodes: [out_189], Original ATen: [torchao.fp16act_fp6weight_linear]
buf827 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf826, arg448_1, arg449_1, 1)
del arg448_1
del arg449_1
buf828 = buf827
del buf827
buf829 = buf825; del buf825 # reuse
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear]
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf829, buf828, 11008, grid=grid(11008), stream=stream0)
del buf828
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear]
buf830 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf829, arg450_1, arg451_1, 13)
del arg450_1
del arg451_1
del buf829
buf831 = buf830
del buf830
buf833 = buf826; del buf826 # reuse
# Source Nodes: [float_129, mean_64, mul_480, out_191, out_192], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf821, buf831, arg64_1, buf833, 1, 4096, grid=grid(1), stream=stream0)
del arg64_1
del buf821
del buf831
# Source Nodes: [out_192], Original ATen: [torchao.fp16act_fp6weight_linear]
buf834 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf833, arg452_1, arg453_1, 1)
del arg452_1
del arg453_1
del buf833
buf835 = buf834
del buf834
buf836 = reinterpret_tensor(buf835, (32000, ), (1, ), 0); del buf835 # reuse
# Source Nodes: [logits_1], Original ATen: [aten.div]
triton_poi_fused_div_16.run(buf836, 32000, grid=grid(32000), stream=stream0)
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk]
buf837 = aten.topk.default(buf836, 200)
buf838 = buf837[0]
del buf837
buf840 = empty_strided_cuda((1, 4), (4, 1), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf836, buf838, buf840, 4, 8000, grid=grid(4), stream=stream0)
buf841 = empty_strided_cuda((1, ), (1, ), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf840, buf841, 1, 4, grid=grid(1), stream=stream0)
buf842 = buf840; del buf840 # reuse
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf836, buf838, buf841, buf842, 4, 8000, grid=grid(4), stream=stream0)
buf843 = empty_strided_cuda((1, ), (1, ), torch.float32)
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where]
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf842, buf843, 1, 4, grid=grid(1), stream=stream0)
del buf842
buf845 = empty_strided_cuda((1, ), (1, ), torch.int64)
# Source Nodes: [], Original ATen: []
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf845)
buf844 = buf836; del buf836 # reuse
buf848 = empty_strided_cuda((1, ), (1, ), torch.int32)
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where]
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf844, buf838, buf841, buf843, buf845, buf848, 0, 1, 32000, grid=grid(1), stream=stream0)
del buf838
del buf841
del buf843
del buf845
return (buf848, buf844, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg2_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg3_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg5_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg6_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg7_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg8_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg9_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg10_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg11_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg12_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg13_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg14_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg15_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg16_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg17_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg18_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg19_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg21_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg22_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg23_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg24_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg25_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg26_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg27_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg28_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg29_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg31_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg32_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg33_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg34_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg35_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg36_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg37_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg38_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg39_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg40_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg41_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg42_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg43_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg44_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg45_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg47_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg48_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg49_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg50_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg51_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg52_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg53_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg54_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg55_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg57_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg58_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg59_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg60_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg61_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg62_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg63_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg64_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg65_1 = rand_strided((32000, 4096), (4096, 1), device='cuda:0', dtype=torch.float16)
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float16)
arg67_1 = rand_strided((208, 208), (208, 1), device='cuda:0', dtype=torch.bool)
arg68_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg69_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg70_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg71_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg72_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg73_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg74_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg75_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg76_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg77_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg78_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg79_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg80_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg81_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg82_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg83_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg84_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg85_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg86_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg87_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg88_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg89_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg90_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg91_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg92_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg93_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg94_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg95_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg96_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg97_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg98_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg99_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg100_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg101_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg102_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg103_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg104_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg105_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg106_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg107_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg108_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg109_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg110_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg111_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg112_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg113_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg114_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg115_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg116_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg117_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg118_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg119_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg120_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg121_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg122_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg123_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg124_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg125_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg126_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg127_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg128_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg129_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg130_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg131_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg132_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg133_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg134_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg135_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg136_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg137_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg138_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg139_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg140_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg141_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg142_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg143_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg144_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg145_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg146_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg147_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg148_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg149_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg150_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg151_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg152_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg153_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg154_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg155_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg156_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg157_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg158_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg159_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg160_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg161_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg162_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg163_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg164_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg165_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg166_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg167_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg168_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg169_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg170_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg171_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg172_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg173_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg174_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg175_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg176_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg177_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg178_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg179_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg180_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg181_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg182_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg183_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg184_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg185_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg186_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg187_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg188_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg189_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg190_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg191_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg192_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg193_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg194_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg195_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg196_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg197_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg198_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg199_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg200_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg201_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg202_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg203_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg204_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg205_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg206_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg207_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg208_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg209_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg210_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg211_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg212_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg213_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg214_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg215_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg216_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg217_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg218_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg219_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg220_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg221_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg222_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg223_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg224_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg225_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg226_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg227_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg228_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg229_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg230_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg231_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg232_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg233_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg234_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg235_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg236_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg237_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg238_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg239_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg240_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg241_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg242_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg243_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg244_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg245_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg246_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg247_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg248_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg249_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg250_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg251_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg252_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg253_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg254_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg255_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg256_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg257_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg258_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg259_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg260_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg261_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg262_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg263_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg264_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg265_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg266_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg267_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg268_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg269_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg270_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg271_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg272_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg273_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg274_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg275_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg276_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg277_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg278_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg279_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg280_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg281_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg282_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg283_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg284_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg285_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg286_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg287_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg288_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg289_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg290_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg291_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg292_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg293_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg294_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg295_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg296_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg297_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg298_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg299_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg300_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg301_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg302_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg303_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg304_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg305_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg306_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg307_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg308_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg309_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg310_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg311_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg312_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg313_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg314_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg315_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg316_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg317_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg318_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg319_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg320_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg321_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg322_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg323_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg324_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg325_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg326_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg327_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg328_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg329_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg330_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg331_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg332_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg333_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg334_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg335_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg336_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg337_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg338_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg339_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg340_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg341_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg342_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg343_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg344_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg345_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg346_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg347_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg348_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg349_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg350_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg351_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg352_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg353_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg354_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg355_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg356_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg357_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg358_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg359_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg360_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg361_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg362_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg363_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg364_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg365_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg366_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg367_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg368_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg369_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg370_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg371_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg372_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg373_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg374_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg375_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg376_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg377_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg378_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg379_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg380_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg381_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg382_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg383_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg384_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg385_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg386_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg387_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg388_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg389_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg390_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg391_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg392_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg393_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg394_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg395_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg396_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg397_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg398_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg399_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg400_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg401_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg402_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg403_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg404_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg405_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg406_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg407_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg408_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg409_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg410_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg411_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg412_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg413_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg414_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg415_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg416_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg417_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg418_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg419_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg420_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg421_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg422_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg423_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg424_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg425_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg426_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg427_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg428_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg429_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg430_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg431_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg432_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg433_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg434_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg435_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg436_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg437_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg438_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg439_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg440_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg441_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16)
arg442_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg443_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16)
arg444_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg445_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg446_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg447_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg448_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg449_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16)
arg450_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32)
arg451_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16)
arg452_1 = rand_strided((32000, 768), (768, 1), device='cuda:0', dtype=torch.int32)
arg453_1 = rand_strided((32000, ), (1, ), device='cuda:0', dtype=torch.float16)
arg454_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int32)
arg455_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, 
arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, 
arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1])
return print_performance(fn, times=times, repeat=repeat)
# Script entry point: runs only when this generated module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # Imported lazily so the benchmark driver is only loaded for script runs.
    import torch._inductor.wrapper_benchmark as wrapper_benchmark

    # Pass the benchmark callable to Inductor's driver. The first argument is
    # the literal string 'None' exactly as emitted by the code generator
    # (presumably a benchmark-name placeholder -- TODO confirm).
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment