# gpt-fast w/ FP6-LLM
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_ubuntu/35/c35n52yrfyq3r2uvr6syoon44qkntahvhxuhmylcpsxnin76axfd.py
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_1 => convert_element_type
# mean => mean
# mul => mul
# out => fp16act_fp6weight_linear
# x => embedding
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0 = async_compile.triton('triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp8 * tmp8
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp12 = _tmp11 + tmp10
        _tmp11 = tl.where(rmask, tmp12, _tmp11)
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    tmp13 = tl.load(in_ptr0 + (0))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp16 = tmp14 + tmp15
        tmp17 = tmp14 < 0
        tmp18 = tl.where(tmp17, tmp16, tmp14)
        tl.device_assert((0 <= tmp18) & (tmp18 < 32000), "index out of bounds: 0 <= tmp18 < 32000")
        tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tmp20.to(tl.float32)
        tmp22 = 4096.0
        tmp23 = tmp11 / tmp22
        tmp24 = 1e-05
        tmp25 = tmp23 + tmp24
        tmp26 = libdevice.rsqrt(tmp25)
        tmp27 = tmp21 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tmp30 = tmp28 * tmp29
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask)
''', device_str='cuda')
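
# The fused kernel above combines the token-embedding gather with RMSNorm for
# a single decoded token (hidden size 4096, vocab 32000, eps=1e-05), producing
# the fp16 input that the fp6 linear then consumes as a separate extern kernel.
# A minimal eager-mode sketch of the same math (illustrative only; the names
# below are assumptions and the function is never called by this graph):
def _eager_embedding_rmsnorm(idx, tok_embeddings, norm_weight, eps=1e-05):
    x = tok_embeddings[idx]                       # aten.embedding, fp16
    xf = x.float()                                # reduce in fp32, as the kernel does
    rms = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (xf * rms).half() * norm_weight        # cast back, then scale by the norm weight
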
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_ubuntu/ix/cix3cikar5jbote3hymtcnf5jtrl75mvovejwvdsj67ds2nqaynm.py
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put]
# setitem => index_put
# setitem_1 => index_put_1
# y => convert_element_type_6
triton_poi_fused_bmm_index_put_1 = async_compile.triton('triton_poi_fused_bmm_index_put_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4096],
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_bmm_index_put_1', 'mutated_arg_names': ['out_ptr0', 'out_ptr2'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_bmm_index_put_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 128
    x1 = (xindex // 128)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp71 = tl.load(in_ptr1 + (8192 + x2), None).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 208, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 208), "index out of bounds: 0 <= tmp5 < 208")
    tmp7 = x0 % 2
    tmp8 = tl.full([1], 0, tl.int64)
    tmp9 = tmp7 >= tmp8
    tmp10 = tl.full([1], 1, tl.int64)
    tmp11 = tmp7 < tmp10
    tmp12 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp12.to(tl.float32)
    tmp14 = tl.full([XBLOCK], 2048, tl.int32)
    tmp15 = tmp1 + tmp14
    tmp16 = tl.where(tmp4, tmp15, tmp1)
    tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
    tmp18 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tmp13 * tmp19
    tmp21 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp20 - tmp25
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp11, tmp26, tmp27)
    tmp29 = tmp7 >= tmp10
    tmp30 = tl.full([1], 2, tl.int64)
    tmp31 = tmp7 < tmp30
    tmp32 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp33 = tmp32.to(tl.float32)
    tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp29), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048")
    tmp35 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp36 = tmp35.to(tl.float32)
    tmp37 = tmp33 * tmp36
    tmp38 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp39 = tmp38.to(tl.float32)
    tmp40 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp41 = tmp40.to(tl.float32)
    tmp42 = tmp39 * tmp41
    tmp43 = tmp37 + tmp42
    tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
    tmp45 = tl.where(tmp29, tmp43, tmp44)
    tmp46 = tl.where(tmp11, tmp28, tmp45)
    tmp47 = tmp46.to(tl.float32)
    tmp48 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp49 = tmp48.to(tl.float32)
    tmp50 = tmp49 * tmp19
    tmp51 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp52 = tmp51.to(tl.float32)
    tmp53 = tmp52 * tmp24
    tmp54 = tmp50 - tmp53
    tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype)
    tmp56 = tl.where(tmp11, tmp54, tmp55)
    tmp57 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp58 = tmp57.to(tl.float32)
    tmp59 = tmp58 * tmp36
    tmp60 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp61 = tmp60.to(tl.float32)
    tmp62 = tmp61 * tmp41
    tmp63 = tmp59 + tmp62
    tmp64 = tl.full(tmp63.shape, 0.0, tmp63.dtype)
    tmp65 = tl.where(tmp29, tmp63, tmp64)
    tmp66 = tl.where(tmp11, tmp56, tmp65)
    tmp67 = tmp66.to(tl.float32)
    tmp68 = 0.29730177875068026
    tmp69 = tmp67 * tmp68
    tmp70 = tmp69.to(tl.float32)
    tl.store(out_ptr0 + (x0 + (128*tmp5) + (26624*x1)), tmp47, None)
    tl.store(out_ptr1 + (x2), tmp70, None)
    tl.store(out_ptr2 + (x0 + (128*tmp5) + (26624*x1)), tmp71, None)
''', device_str='cuda')
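
# The fused kernel above applies interleaved rotary embeddings to the q and k
# halves of a fused qkv projection (32 heads x 128 head_dim per tensor), writes
# the rotated k and the untouched v into 208-slot KV caches at the current
# position, and pre-scales q by 128**-0.25 = 0.29730177875068026 (half of the
# 1/sqrt(head_dim) softmax scale; the matching half is applied to k in the
# score kernel below). A hedged eager sketch under those layout assumptions;
# all names are illustrative and the function is never called by this graph:
def _eager_rope_and_cache(qkv, freqs_cis, k_cache, v_cache, input_pos):
    q, k, v = qkv.split(4096)                     # one decoded token
    q = q.view(32, 64, 2).float()                 # (head, dim pair, even/odd)
    k = k.view(32, 64, 2).float()
    cos = freqs_cis[input_pos, :, 0]              # fp32 rope table, 2048 positions
    sin = freqs_cis[input_pos, :, 1]
    def rope(x):
        return torch.stack((x[..., 0] * cos - x[..., 1] * sin,
                            x[..., 0] * sin + x[..., 1] * cos), dim=-1)
    k_cache[:, input_pos] = rope(k).reshape(32, 128).half()   # aten.index_put
    v_cache[:, input_pos] = v.view(32, 128)                   # v is cached as-is
    return (rope(q).reshape(32, 128) * 128 ** -0.25).half()
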
# kernel path: /tmp/torchinductor_ubuntu/tu/ctuxbryvllsmitwtw4gk35f6e4uug6renb6km4llzq2czdpete6i.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_13, sum_1
triton_red_fused_bmm_2 = async_compile.triton('triton_red_fused_bmm_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[8192, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_2(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 6656
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 208)
    x3 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (128*x1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r2 + (128*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = 0.29730177875068026
        tmp3 = tmp1 * tmp2
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp0 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask & xmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp7, xmask)
''', device_str='cuda')
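
# The reduction kernel above computes the decode-step attention scores
# q @ k^T over all 208 cache slots (xnumel = 32 heads x 208 slots, reduced
# over head_dim = 128). Combined with the q pre-scaling in the previous
# kernel, the extra 128**-0.25 factor here gives the usual 1/sqrt(128).
# A sketch under those assumptions (illustrative names, never called here):
def _eager_attn_scores(q_scaled, k_cache):
    # q_scaled: (32, 128) fp32; k_cache: (32, 208, 128) fp16
    return torch.einsum('hd,hsd->hs', q_scaled,
                        k_cache.float() * 128 ** -0.25)
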
# kernel path: /tmp/torchinductor_ubuntu/f6/cf6hoyusglegqlcllwth4hbv7jzk7xu6ol3qg56cytjysskalpwn.py
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like]
# mask => index
# y => add_3, amax, convert_element_type_11, convert_element_type_9, exp, full_default, full_default_1, logical_not, sub_2, sum_2, where
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3 = async_compile.triton('triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[32, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*i32', 2: '*i1', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 32
    rnumel = 208
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (208*x0)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp1 = tmp0.to(tl.float32)
    tmp4 = tl.full([XBLOCK, RBLOCK], 208, tl.int32)
    tmp5 = tmp3 + tmp4
    tmp6 = tmp3 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp3)
    tl.device_assert((0 <= tmp7) & (tmp7 < 208), "index out of bounds: 0 <= tmp7 < 208")
    tmp9 = tl.load(in_ptr2 + (r1 + (208*tmp7)), rmask, other=0.0).to(tl.int1)
    tmp10 = tmp9 == 0
    tmp11 = float("-inf")
    tmp12 = 0.0
    tmp13 = tl.where(tmp10, tmp11, tmp12)
    tmp14 = tmp1 + tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, float("-inf"))
    tmp19 = triton_helpers.max2(tmp18, 1)[:, None]
    tmp20 = tmp15 - tmp19
    tmp21 = tl_math.exp(tmp20)
    tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK])
    tmp24 = tl.where(rmask & xmask, tmp22, 0)
    tmp25 = tl.sum(tmp24, 1)[:, None]
    tmp26 = tmp21 / tmp25
    tmp27 = tmp26.to(tl.float32)
    tmp28 = tmp27.to(tl.float32)
    tl.store(out_ptr2 + (r1 + (208*x0)), tmp28, rmask & xmask)
''', device_str='cuda')
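
# The persistent reduction above fuses the causal-mask lookup with a
# numerically stable softmax over the 208 KV slots: the mask row for the
# current position fills disallowed slots with -inf before normalizing.
# A sketch (illustrative names; `causal_mask` is assumed boolean):
def _eager_masked_softmax(scores, causal_mask, input_pos):
    s = scores.masked_fill(~causal_mask[input_pos], float('-inf'))
    s = s - s.max(-1, keepdim=True).values        # rowwise max for stability
    p = s.exp()
    return p / p.sum(-1, keepdim=True)
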
# kernel path: /tmp/torchinductor_ubuntu/3q/c3q5eet7mn3gx23dlvno24ooilcrlirl64ahmhluvxu6ynwq452m.py
# Source Nodes: [y], Original ATen: [aten.bmm]
# y => mul_14, sum_3
triton_red_fused_bmm_4 = async_compile.triton('triton_red_fused_bmm_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[8192, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused_bmm_4(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 104
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 128)
    x0 = xindex % 128
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (104*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (128*r2) + (13312*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/s7/cs7bh5txamreduogo3dbpgu62hdnknommfet7uxuau6q2gi24u6j.py
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear]
# out_1 => fp16act_fp6weight_linear_1
# y => mul_14, sum_3
triton_per_fused_bmm_fp16act_fp6weight_linear_5 = async_compile.triton('triton_per_fused_bmm_fp16act_fp6weight_linear_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[4096, 2],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_bmm_fp16act_fp6weight_linear_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_bmm_fp16act_fp6weight_linear_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r2 = rindex
    x0 = xindex % 128
    x1 = (xindex // 128)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r2) + (256*x1)), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x3), tmp5, None)
''', device_str='cuda')
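
# The two kernels above implement attn_probs @ v_cache as a split reduction:
# the first reduces each half of the 208 cache slots (rnumel = 104) to a
# partial sum per head, the second adds the two partials and casts to fp16
# for the fp6 output projection. A sketch under those assumptions:
def _eager_attn_output(probs, v_cache):
    # probs: (32, 208) fp32; v_cache: (32, 208, 128) fp16
    halves = [torch.einsum('hs,hsd->hd',
                           probs[:, i * 104:(i + 1) * 104],
                           v_cache[:, i * 104:(i + 1) * 104].float())
              for i in range(2)]                  # split-K style partials
    return (halves[0] + halves[1]).half().reshape(4096)
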
# kernel path: /tmp/torchinductor_ubuntu/e6/ce6earzjiin77ifcgxqhqghmkxait7rqhahkb2zasp4hadk7bw4d.py
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt]
# add_4 => add_5
# float_4 => convert_element_type_14
# h => add_4
# mean_1 => mean_1
# mul_11 => mul_15
# mul_12 => mul_16
# mul_13 => mul_17
# output_1 => convert_element_type_15
# rsqrt_1 => rsqrt_1
# x => embedding
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tmp15 = tl.load(in_ptr0 + (0))
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp23 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp18 = tmp16 + tmp17
        tmp19 = tmp16 < 0
        tmp20 = tl.where(tmp19, tmp18, tmp16)
        tl.device_assert((0 <= tmp20) & (tmp20 < 32000), "index out of bounds: 0 <= tmp20 < 32000")
        tmp22 = tl.load(in_ptr1 + (r0 + (4096*tmp20)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp24 = tmp22 + tmp23
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 4096.0
        tmp27 = tmp13 / tmp26
        tmp28 = 1e-05
        tmp29 = tmp27 + tmp28
        tmp30 = libdevice.rsqrt(tmp29)
        tmp31 = tmp25 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp34 = tmp32 * tmp33
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask)
''', device_str='cuda')
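
# Same RMSNorm pattern as _eager_embedding_rmsnorm above, but applied to the
# residual sum h = embedding(idx) + attention_output before the MLP linears.
# A sketch of just the residual variant (illustrative names):
def _eager_residual_rmsnorm(x, residual, norm_weight, eps=1e-05):
    h = (x + residual).float()
    rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (h * rms).half() * norm_weight
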
# kernel path: /tmp/torchinductor_ubuntu/7x/c7xzb5ucrv7a232bufgvbjemlk4jgdh7ujoroxm3v7jxuduwdfif.py
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear]
# out_4 => fp16act_fp6weight_linear_4
triton_poi_fused_fp16act_fp6weight_linear_7 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16384],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_7', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_fp16act_fp6weight_linear_7(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 11008
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(in_out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
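
# The pointwise kernel above is the SwiGLU gate of the Llama MLP, applied in
# place over the 11008-wide hidden activation: silu(w1(x)) computed in fp32,
# cast to fp16, then multiplied by w3(x). A sketch (illustrative names):
def _eager_swiglu(x1, x3):
    return torch.nn.functional.silu(x1.float()).half() * x3
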
# kernel path: /tmp/torchinductor_ubuntu/xb/cxbu5uv7truaeft3qae6exsdva5w6j3jruqooy6lvbsp53jek26v.py
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_5 => convert_element_type_18
# h => add_4
# mean_2 => mean_2
# mul_15 => mul_20
# out_5 => add_6
# out_6 => fp16act_fp6weight_linear_5
# x => embedding
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
        tmp16 = _tmp15 + tmp14
        _tmp15 = tl.where(rmask, tmp16, _tmp15)
    tmp15 = tl.sum(_tmp15, 1)[:, None]
    tmp17 = tl.load(in_ptr0 + (0))
    tmp18 = tl.broadcast_to(tmp17, [XBLOCK, RBLOCK])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp25 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp20 = tmp18 + tmp19
        tmp21 = tmp18 < 0
        tmp22 = tl.where(tmp21, tmp20, tmp18)
        tl.device_assert((0 <= tmp22) & (tmp22 < 32000), "index out of bounds: 0 <= tmp22 < 32000")
        tmp24 = tl.load(in_ptr1 + (r0 + (4096*tmp22)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tmp24 + tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tmp30 = 4096.0
        tmp31 = tmp15 / tmp30
        tmp32 = 1e-05
        tmp33 = tmp31 + tmp32
        tmp34 = libdevice.rsqrt(tmp33)
        tmp35 = tmp29 * tmp34
        tmp36 = tmp35.to(tl.float32)
        tmp38 = tmp36 * tmp37
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_ubuntu/ko/ckonz56ntbc3wxvboqbs6j6xmjdrqma6fl5rtmem5grbfiymtcs4.py
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear]
# float_8 => convert_element_type_32
# h => add_4
# h_1 => add_11
# mean_3 => mean_3
# mul_26 => mul_35
# out_5 => add_6
# out_8 => fp16act_fp6weight_linear_7
# out_9 => fp16act_fp6weight_linear_8
# x => embedding
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[1, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i32', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    _tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp8 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
        tmp3 = tmp1 + tmp2
        tmp4 = tmp1 < 0
        tmp5 = tl.where(tmp4, tmp3, tmp1)
        tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000")
        tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp13 = tmp11 + tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp14 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
        tmp18 = _tmp17 + tmp16
        _tmp17 = tl.where(rmask, tmp18, _tmp17)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask)
    tmp17 = tl.sum(_tmp17, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = 4096.0
        tmp22 = tmp17 / tmp21
        tmp23 = 1e-05
        tmp24 = tmp22 + tmp23
        tmp25 = libdevice.rsqrt(tmp24)
        tmp26 = tmp20 * tmp25
        tmp27 = tmp26.to(tl.float32)
        tmp29 = tmp27 * tmp28
        tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
        tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask)
''', device_str='cuda')
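
# The kernel above shows the pattern used for the deeper layers: the residual
# chain (embedding plus per-layer outputs) is folded into the hidden-state
# buffer in place (in_out_ptr0), and the RMSNorm result is stored twice, one
# copy per downstream consumer. A sketch of that pattern (illustrative names):
def _eager_accumulate_and_norm(h, residuals, norm_weight, eps=1e-05):
    for r in residuals:
        h += r                                    # in-place, like in_out_ptr0
    xf = h.float()
    out = (xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)).half() * norm_weight
    return out, out.clone()                       # two output buffers
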
# kernel path: /tmp/torchinductor_ubuntu/23/c23scxttsdz72vesaap27tazzxervcixdrt7hwfybhp2nrfu4kw4.py | |
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_9 => convert_element_type_36 | |
# mean_4 => mean_4 | |
# mul_30 => mul_40 | |
# out_11 => add_13 | |
# out_12 => fp16act_fp6weight_linear_10 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp3 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK]) | |
tmp7 = _tmp6 + tmp5 | |
_tmp6 = tl.where(rmask, tmp7, _tmp6) | |
tmp6 = tl.sum(_tmp6, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp10 = tmp8 + tmp9 | |
tmp11 = tmp10.to(tl.float32) | |
tmp12 = 4096.0 | |
tmp13 = tmp6 / tmp12 | |
tmp14 = 1e-05 | |
tmp15 = tmp13 + tmp14 | |
tmp16 = libdevice.rsqrt(tmp15) | |
tmp17 = tmp11 * tmp16 | |
tmp18 = tmp17.to(tl.float32) | |
tmp20 = tmp18 * tmp19 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/sg/csgnrv5nhw4eic7f275ngxcfz73e62dtja6zilbn6f5a2qhvqn54.py | |
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
# add_16 => add_19 | |
# float_12 => convert_element_type_50 | |
# h_2 => add_18 | |
# mean_5 => mean_5 | |
# mul_41 => mul_55 | |
# mul_42 => mul_56 | |
# mul_43 => mul_57 | |
# out_11 => add_13 | |
# output_5 => convert_element_type_51 | |
# rsqrt_5 => rsqrt_5 | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_rsqrt_11', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_mean_mul_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp5 * tmp5 | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK]) | |
tmp9 = _tmp8 + tmp7 | |
_tmp8 = tl.where(rmask, tmp9, _tmp8) | |
tmp8 = tl.sum(_tmp8, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp13 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp12 = tmp10 + tmp11 | |
tmp14 = tmp12 + tmp13 | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp8 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/nw/cnwia6b5wdcoewnme63eazcbz6dvzogn3stxnaty267p4vw2mmwy.py | |
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_13 => convert_element_type_54 | |
# h_2 => add_18 | |
# mean_6 => mean_6 | |
# mul_45 => mul_60 | |
# out_11 => add_13 | |
# out_17 => add_20 | |
# out_18 => fp16act_fp6weight_linear_15 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp7 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(rmask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp14 = tmp12 + tmp13 | |
tmp16 = tmp14 + tmp15 | |
tmp18 = tmp16 + tmp17 | |
tmp19 = tmp18.to(tl.float32) | |
tmp20 = 4096.0 | |
tmp21 = tmp10 / tmp20 | |
tmp22 = 1e-05 | |
tmp23 = tmp21 + tmp22 | |
tmp24 = libdevice.rsqrt(tmp23) | |
tmp25 = tmp19 * tmp24 | |
tmp26 = tmp25.to(tl.float32) | |
tmp28 = tmp26 * tmp27 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/yi/cyiywyx2b2ap35djo57g2tkxu3hdej3o4ootd3nr6bigecahyoeu.py | |
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_16 => convert_element_type_68 | |
# h_2 => add_18 | |
# h_3 => add_25 | |
# mean_7 => mean_7 | |
# mul_56 => mul_75 | |
# out_11 => add_13 | |
# out_17 => add_20 | |
# out_20 => fp16act_fp6weight_linear_17 | |
# out_21 => fp16act_fp6weight_linear_18 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp8 = tmp6 + tmp7 | |
tmp9 = tmp8.to(tl.float32) | |
tmp10 = tmp9 * tmp9 | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask, tmp13, _tmp12) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp12 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
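# Added note: kernel _13 above is the in-place variant of the same fusion. | |
# Pass 1 writes the summed residual back through in_out_ptr0 while it | |
# accumulates the sum of squares; pass 2 stores the normalized result to both | |
# out_ptr1 and out_ptr2, matching the two FP6 linears (out_20, out_21) that | |
# consume it in the source nodes. | |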
# kernel path: /tmp/torchinductor_ubuntu/74/c74rqt5v7cgztcvqz3g36kdattrzmlrq33illnplnagfadtt7nqv.py | |
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
# out_22 => fp16act_fp6weight_linear_19 | |
triton_poi_fused_fp16act_fp6weight_linear_14 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_14', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[16384], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_fp16act_fp6weight_linear_14(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 11008 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
tmp5 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tmp2 = tl.sigmoid(tmp1) | |
tmp3 = tmp1 * tmp2 | |
tmp4 = tmp3.to(tl.float32) | |
tmp6 = tmp4 * tmp5 | |
tl.store(in_out_ptr0 + (x0), tmp6, xmask) | |
''', device_str='cuda') | |
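# Added note: the pointwise kernel above is the SwiGLU gate between two FP6 | |
# linear outputs of the 11008-wide MLP: SiLU is computed in fp32, cast back | |
# to fp16, then multiplied into the other projection in place. A minimal | |
# sketch (argument names are assumptions, not from the graph): | |
def _ref_swiglu(gate, up): | |
    import torch | |
    # silu in fp32 for accuracy, fp16 cast, then an fp16 multiply | |
    return torch.nn.functional.silu(gate.float()).to(gate.dtype) * up | |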
# kernel path: /tmp/torchinductor_ubuntu/od/codwx422sav4ibjcz7e3u2ii64iq4gtjlj6xmjifbwmqjhhlbprf.py | |
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_24 => convert_element_type_104 | |
# h_4 => add_32 | |
# h_5 => add_39 | |
# mean_11 => mean_11 | |
# mul_86 => mul_115 | |
# out_23 => add_27 | |
# out_29 => add_34 | |
# out_32 => fp16act_fp6weight_linear_27 | |
# out_33 => fp16act_fp6weight_linear_28 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp8 = tmp6 + tmp7 | |
tmp9 = tmp8.to(tl.float32) | |
tmp10 = tmp9 * tmp9 | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask, tmp13, _tmp12) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp12 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
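# Added note: kernel _15 appears identical in structure to _13 apart from the | |
# order of its first two loads; Inductor emits a fresh kernel for each such | |
# variant rather than reusing _13. | |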
# kernel path: /tmp/torchinductor_ubuntu/z6/cz6oxi5i6lto4h7ib75rlkpzqsocefhcbwhm7nhdpxravdtgiuro.py | |
# Source Nodes: [logits_1], Original ATen: [aten.div] | |
# logits_1 => div_32 | |
triton_poi_fused_div_16 = async_compile.triton('triton_poi_fused_div_16', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[32768], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_div_16(in_out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 32000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = 1.25 | |
tmp2 = tmp0 * tmp1 | |
tl.store(in_out_ptr0 + (x0), tmp2, xmask) | |
''', device_str='cuda') | |
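# Added note: 'logits_1 => div_32' is scaling of the 32000-entry vocabulary | |
# logits; multiplying by the constant 1.25 is the strength-reduced form of | |
# dividing by 0.8, consistent with sampling at temperature 0.8. | |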
# kernel path: /tmp/torchinductor_ubuntu/xv/cxvjbig2sbormp2kxcubuvv2h5v3vxuc4jta7llq4rnh6kwuxhuo.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => amax_32, convert_element_type_578 | |
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_17', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[4, 8192], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 4), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax_lt_scalar_tensor_where_17(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 4 | |
rnumel = 8000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
_tmp8 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK]) | |
tmp9 = triton_helpers.maximum(_tmp8, tmp7) | |
_tmp8 = tl.where(rmask & xmask, tmp9, _tmp8) | |
tmp8 = triton_helpers.max2(_tmp8, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp8, xmask) | |
''', device_str='cuda') | |
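# Added commentary: kernels 17-20 implement a numerically stable softmax over | |
# the 32000 logits, split into 4 chunks of 8000. This first kernel masks | |
# every logit below in_ptr1[199] (consistent with top_k = 200: index 199 of | |
# the descending-sorted values is the cutoff) and reduces a per-chunk max. | |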
# kernel path: /tmp/torchinductor_ubuntu/hz/chzqxkcllimh3vsvmsxoviuxtefrsfgsde3el6pgijogxpaxrs7r.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => amax_32, convert_element_type_578 | |
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_18', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[1, 4], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__softmax_lt_scalar_tensor_where_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4 | |
RBLOCK: tl.constexpr = 4 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.where(rmask, tmp1, float("-inf")) | |
tmp4 = triton_helpers.max2(tmp3, 1)[:, None] | |
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None) | |
''', device_str='cuda') | |
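# Added note: second stage of the split max - this persistent reduction folds | |
# the 4 per-chunk maxima from kernel 17 into the single global max used to | |
# stabilize the exponentials. | |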
# kernel path: /tmp/torchinductor_ubuntu/za/czakkdksd2yxt6hanh7vxqjmbj7ara4d4rza6bwjwwt4stxakxtm.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, exp_32, sub_96, sum_97 | |
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_19', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[4, 8192], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax_lt_scalar_tensor_where_19(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 4 | |
rnumel = 8000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
tmp7 = tl.load(in_ptr2 + (0)) | |
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK]) | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp9 = tmp6 - tmp8 | |
tmp10 = tl_math.exp(tmp9) | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask & xmask, tmp13, _tmp12) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp12, xmask) | |
''', device_str='cuda') | |
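# Added note: with the global max available, this kernel re-applies the top-k | |
# mask and accumulates sum(exp(x - max)) in fp32, again one partial sum per | |
# 8000-logit chunk. | |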
# kernel path: /tmp/torchinductor_ubuntu/aa/caad4sxzjec6myyhypd4e6odaasgzubfdw6lhsfkhrv5jqg4uc7g.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, exp_32, sub_96, sum_97 | |
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_20', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[1, 4], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__softmax_lt_scalar_tensor_where_20(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4 | |
RBLOCK: tl.constexpr = 4 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.where(rmask, tmp1, 0) | |
tmp4 = tl.sum(tmp3, 1)[:, None] | |
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None) | |
''', device_str='cuda') | |
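# Added commentary: kernel 20 folds the 4 partial sums into the softmax | |
# denominator. A rough eager equivalent of the whole 17-20 pipeline (k and | |
# the chunking are inferred from the shapes, not stated in the graph): | |
def _ref_topk_softmax(logits, k=200): | |
    import torch | |
    cutoff = torch.topk(logits, k).values[..., -1]   # k-th largest logit | |
    masked = torch.where(logits < cutoff, float("-inf"), logits.float()) | |
    m = masked.max()                     # kernels 17/18: chunk, then global | |
    denom = torch.exp(masked - m).sum()  # kernels 19/20: chunk, then global | |
    return torch.exp(masked - m) / denom | |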
# kernel path: /tmp/torchinductor_ubuntu/ch/cch6v3lnecv5tfe5ipqpowh4xge3iiqm4nrz67gnjbwlbump6mow.py | |
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where] | |
# argmax => argmax | |
# idx_next => convert_element_type_582 | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, convert_element_type_579, div_33, exp_32, sub_96 | |
# q_128 => convert_element_type_581, log1p, mul_643, neg | |
# truediv_1 => div_34 | |
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 32768], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*i64', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 32000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp1 = tl.load(in_ptr0 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
tmp7 = tl.load(in_ptr1 + (0)) | |
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK]) | |
tmp11 = tl.load(in_ptr2 + (0)) | |
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK]) | |
_tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) | |
_tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp9 = tmp6 - tmp8 | |
tmp10 = tl_math.exp(tmp9) | |
tmp13 = tmp10 / tmp12 | |
tmp14 = tmp13.to(tl.float32) | |
tmp15 = tl.load(in_ptr3 + load_seed_offset) | |
tmp16 = r0 | |
tmp17 = tl.rand(tmp15, (tmp16).to(tl.uint32)) | |
tmp18 = -tmp17 | |
tmp19 = libdevice.log1p(tmp18) | |
tmp20 = -1.0 | |
tmp21 = tmp19 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tmp14 / tmp22 | |
tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK]) | |
_tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index( | |
_tmp25, _tmp25_index, tmp24, rindex | |
) | |
_tmp25 = tl.where(rmask, _tmp25_next, _tmp25) | |
_tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask) | |
_, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1) | |
tmp25 = tmp25_tmp[:, None] | |
tmp26 = tmp25.to(tl.int32) | |
tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None) | |
''', device_str='cuda') | |
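# Added commentary: the final kernel fuses the softmax division with sampling. | |
# Dividing each probability by an i.i.d. Exponential(1) draw (generated here | |
# as -log1p(-u) for u ~ U[0,1)) and taking the argmax draws an index with | |
# probability proportional to probs - an exponential-race form of the | |
# Gumbel-max trick, consistent with gpt-fast's multinomial_sample_one_no_sync. | |
# Eager sketch (illustrative): | |
def _ref_sample_one(probs): | |
    import torch | |
    q = torch.empty_like(probs).exponential_(1)      # Exp(1) noise | |
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(torch.int32) | |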
async_compile.wait(globals()) | |
del async_compile | |
def call(args): | |
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1 = args | |
args.clear() | |
assert_size_stride(arg0_1, (4096, ), (1, )) | |
assert_size_stride(arg1_1, (4096, ), (1, )) | |
assert_size_stride(arg2_1, (4096, ), (1, )) | |
assert_size_stride(arg3_1, (4096, ), (1, )) | |
assert_size_stride(arg4_1, (4096, ), (1, )) | |
assert_size_stride(arg5_1, (4096, ), (1, )) | |
assert_size_stride(arg6_1, (4096, ), (1, )) | |
assert_size_stride(arg7_1, (4096, ), (1, )) | |
assert_size_stride(arg8_1, (4096, ), (1, )) | |
assert_size_stride(arg9_1, (4096, ), (1, )) | |
assert_size_stride(arg10_1, (4096, ), (1, )) | |
assert_size_stride(arg11_1, (4096, ), (1, )) | |
assert_size_stride(arg12_1, (4096, ), (1, )) | |
assert_size_stride(arg13_1, (4096, ), (1, )) | |
assert_size_stride(arg14_1, (4096, ), (1, )) | |
assert_size_stride(arg15_1, (4096, ), (1, )) | |
assert_size_stride(arg16_1, (4096, ), (1, )) | |
assert_size_stride(arg17_1, (4096, ), (1, )) | |
assert_size_stride(arg18_1, (4096, ), (1, )) | |
assert_size_stride(arg19_1, (4096, ), (1, )) | |
assert_size_stride(arg20_1, (4096, ), (1, )) | |
assert_size_stride(arg21_1, (4096, ), (1, )) | |
assert_size_stride(arg22_1, (4096, ), (1, )) | |
assert_size_stride(arg23_1, (4096, ), (1, )) | |
assert_size_stride(arg24_1, (4096, ), (1, )) | |
assert_size_stride(arg25_1, (4096, ), (1, )) | |
assert_size_stride(arg26_1, (4096, ), (1, )) | |
assert_size_stride(arg27_1, (4096, ), (1, )) | |
assert_size_stride(arg28_1, (4096, ), (1, )) | |
assert_size_stride(arg29_1, (4096, ), (1, )) | |
assert_size_stride(arg30_1, (4096, ), (1, )) | |
assert_size_stride(arg31_1, (4096, ), (1, )) | |
assert_size_stride(arg32_1, (4096, ), (1, )) | |
assert_size_stride(arg33_1, (4096, ), (1, )) | |
assert_size_stride(arg34_1, (4096, ), (1, )) | |
assert_size_stride(arg35_1, (4096, ), (1, )) | |
assert_size_stride(arg36_1, (4096, ), (1, )) | |
assert_size_stride(arg37_1, (4096, ), (1, )) | |
assert_size_stride(arg38_1, (4096, ), (1, )) | |
assert_size_stride(arg39_1, (4096, ), (1, )) | |
assert_size_stride(arg40_1, (4096, ), (1, )) | |
assert_size_stride(arg41_1, (4096, ), (1, )) | |
assert_size_stride(arg42_1, (4096, ), (1, )) | |
assert_size_stride(arg43_1, (4096, ), (1, )) | |
assert_size_stride(arg44_1, (4096, ), (1, )) | |
assert_size_stride(arg45_1, (4096, ), (1, )) | |
assert_size_stride(arg46_1, (4096, ), (1, )) | |
assert_size_stride(arg47_1, (4096, ), (1, )) | |
assert_size_stride(arg48_1, (4096, ), (1, )) | |
assert_size_stride(arg49_1, (4096, ), (1, )) | |
assert_size_stride(arg50_1, (4096, ), (1, )) | |
assert_size_stride(arg51_1, (4096, ), (1, )) | |
assert_size_stride(arg52_1, (4096, ), (1, )) | |
assert_size_stride(arg53_1, (4096, ), (1, )) | |
assert_size_stride(arg54_1, (4096, ), (1, )) | |
assert_size_stride(arg55_1, (4096, ), (1, )) | |
assert_size_stride(arg56_1, (4096, ), (1, )) | |
assert_size_stride(arg57_1, (4096, ), (1, )) | |
assert_size_stride(arg58_1, (4096, ), (1, )) | |
assert_size_stride(arg59_1, (4096, ), (1, )) | |
assert_size_stride(arg60_1, (4096, ), (1, )) | |
assert_size_stride(arg61_1, (4096, ), (1, )) | |
assert_size_stride(arg62_1, (4096, ), (1, )) | |
assert_size_stride(arg63_1, (4096, ), (1, )) | |
assert_size_stride(arg64_1, (4096, ), (1, )) | |
assert_size_stride(arg65_1, (32000, 4096), (4096, 1)) | |
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg67_1, (208, 208), (208, 1)) | |
assert_size_stride(arg68_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg69_1, (12288, ), (1, )) | |
assert_size_stride(arg70_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg71_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg72_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg73_1, (4096, ), (1, )) | |
assert_size_stride(arg74_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg75_1, (11008, ), (1, )) | |
assert_size_stride(arg76_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg77_1, (11008, ), (1, )) | |
assert_size_stride(arg78_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg79_1, (4096, ), (1, )) | |
assert_size_stride(arg80_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg81_1, (12288, ), (1, )) | |
assert_size_stride(arg82_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg83_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg84_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg85_1, (4096, ), (1, )) | |
assert_size_stride(arg86_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg87_1, (11008, ), (1, )) | |
assert_size_stride(arg88_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg89_1, (11008, ), (1, )) | |
assert_size_stride(arg90_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg91_1, (4096, ), (1, )) | |
assert_size_stride(arg92_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg93_1, (12288, ), (1, )) | |
assert_size_stride(arg94_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg95_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg96_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg97_1, (4096, ), (1, )) | |
assert_size_stride(arg98_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg99_1, (11008, ), (1, )) | |
assert_size_stride(arg100_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg101_1, (11008, ), (1, )) | |
assert_size_stride(arg102_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg103_1, (4096, ), (1, )) | |
assert_size_stride(arg104_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg105_1, (12288, ), (1, )) | |
assert_size_stride(arg106_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg107_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg108_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg109_1, (4096, ), (1, )) | |
assert_size_stride(arg110_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg111_1, (11008, ), (1, )) | |
assert_size_stride(arg112_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg113_1, (11008, ), (1, )) | |
assert_size_stride(arg114_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg115_1, (4096, ), (1, )) | |
assert_size_stride(arg116_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg117_1, (12288, ), (1, )) | |
assert_size_stride(arg118_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg119_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg120_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg121_1, (4096, ), (1, )) | |
assert_size_stride(arg122_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg123_1, (11008, ), (1, )) | |
assert_size_stride(arg124_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg125_1, (11008, ), (1, )) | |
assert_size_stride(arg126_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg127_1, (4096, ), (1, )) | |
assert_size_stride(arg128_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg129_1, (12288, ), (1, )) | |
assert_size_stride(arg130_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg131_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg132_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg133_1, (4096, ), (1, )) | |
assert_size_stride(arg134_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg135_1, (11008, ), (1, )) | |
assert_size_stride(arg136_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg137_1, (11008, ), (1, )) | |
assert_size_stride(arg138_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg139_1, (4096, ), (1, )) | |
assert_size_stride(arg140_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg141_1, (12288, ), (1, )) | |
assert_size_stride(arg142_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg143_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg144_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg145_1, (4096, ), (1, )) | |
assert_size_stride(arg146_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg147_1, (11008, ), (1, )) | |
assert_size_stride(arg148_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg149_1, (11008, ), (1, )) | |
assert_size_stride(arg150_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg151_1, (4096, ), (1, )) | |
assert_size_stride(arg152_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg153_1, (12288, ), (1, )) | |
assert_size_stride(arg154_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg155_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg156_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg157_1, (4096, ), (1, )) | |
assert_size_stride(arg158_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg159_1, (11008, ), (1, )) | |
assert_size_stride(arg160_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg161_1, (11008, ), (1, )) | |
assert_size_stride(arg162_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg163_1, (4096, ), (1, )) | |
assert_size_stride(arg164_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg165_1, (12288, ), (1, )) | |
assert_size_stride(arg166_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg167_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg168_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg169_1, (4096, ), (1, )) | |
assert_size_stride(arg170_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg171_1, (11008, ), (1, )) | |
assert_size_stride(arg172_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg173_1, (11008, ), (1, )) | |
assert_size_stride(arg174_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg175_1, (4096, ), (1, )) | |
assert_size_stride(arg176_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg177_1, (12288, ), (1, )) | |
assert_size_stride(arg178_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg179_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg180_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg181_1, (4096, ), (1, )) | |
assert_size_stride(arg182_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg183_1, (11008, ), (1, )) | |
assert_size_stride(arg184_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg185_1, (11008, ), (1, )) | |
assert_size_stride(arg186_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg187_1, (4096, ), (1, )) | |
assert_size_stride(arg188_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg189_1, (12288, ), (1, )) | |
assert_size_stride(arg190_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg191_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg192_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg193_1, (4096, ), (1, )) | |
assert_size_stride(arg194_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg195_1, (11008, ), (1, )) | |
assert_size_stride(arg196_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg197_1, (11008, ), (1, )) | |
assert_size_stride(arg198_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg199_1, (4096, ), (1, )) | |
assert_size_stride(arg200_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg201_1, (12288, ), (1, )) | |
assert_size_stride(arg202_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg203_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg204_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg205_1, (4096, ), (1, )) | |
assert_size_stride(arg206_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg207_1, (11008, ), (1, )) | |
assert_size_stride(arg208_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg209_1, (11008, ), (1, )) | |
assert_size_stride(arg210_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg211_1, (4096, ), (1, )) | |
assert_size_stride(arg212_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg213_1, (12288, ), (1, )) | |
assert_size_stride(arg214_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg215_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg216_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg217_1, (4096, ), (1, )) | |
assert_size_stride(arg218_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg219_1, (11008, ), (1, )) | |
assert_size_stride(arg220_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg221_1, (11008, ), (1, )) | |
assert_size_stride(arg222_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg223_1, (4096, ), (1, )) | |
assert_size_stride(arg224_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg225_1, (12288, ), (1, )) | |
assert_size_stride(arg226_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg227_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg228_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg229_1, (4096, ), (1, )) | |
assert_size_stride(arg230_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg231_1, (11008, ), (1, )) | |
assert_size_stride(arg232_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg233_1, (11008, ), (1, )) | |
assert_size_stride(arg234_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg235_1, (4096, ), (1, )) | |
assert_size_stride(arg236_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg237_1, (12288, ), (1, )) | |
assert_size_stride(arg238_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg239_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg240_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg241_1, (4096, ), (1, )) | |
assert_size_stride(arg242_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg243_1, (11008, ), (1, )) | |
assert_size_stride(arg244_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg245_1, (11008, ), (1, )) | |
assert_size_stride(arg246_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg247_1, (4096, ), (1, )) | |
assert_size_stride(arg248_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg249_1, (12288, ), (1, )) | |
assert_size_stride(arg250_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg251_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg252_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg253_1, (4096, ), (1, )) | |
assert_size_stride(arg254_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg255_1, (11008, ), (1, )) | |
assert_size_stride(arg256_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg257_1, (11008, ), (1, )) | |
assert_size_stride(arg258_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg259_1, (4096, ), (1, )) | |
assert_size_stride(arg260_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg261_1, (12288, ), (1, )) | |
assert_size_stride(arg262_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg263_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg264_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg265_1, (4096, ), (1, )) | |
assert_size_stride(arg266_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg267_1, (11008, ), (1, )) | |
assert_size_stride(arg268_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg269_1, (11008, ), (1, )) | |
assert_size_stride(arg270_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg271_1, (4096, ), (1, )) | |
assert_size_stride(arg272_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg273_1, (12288, ), (1, )) | |
assert_size_stride(arg274_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg275_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg276_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg277_1, (4096, ), (1, )) | |
assert_size_stride(arg278_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg279_1, (11008, ), (1, )) | |
assert_size_stride(arg280_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg281_1, (11008, ), (1, )) | |
assert_size_stride(arg282_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg283_1, (4096, ), (1, )) | |
assert_size_stride(arg284_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg285_1, (12288, ), (1, )) | |
assert_size_stride(arg286_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg287_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg288_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg289_1, (4096, ), (1, )) | |
assert_size_stride(arg290_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg291_1, (11008, ), (1, )) | |
assert_size_stride(arg292_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg293_1, (11008, ), (1, )) | |
assert_size_stride(arg294_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg295_1, (4096, ), (1, )) | |
assert_size_stride(arg296_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg297_1, (12288, ), (1, )) | |
assert_size_stride(arg298_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg299_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg300_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg301_1, (4096, ), (1, )) | |
assert_size_stride(arg302_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg303_1, (11008, ), (1, )) | |
assert_size_stride(arg304_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg305_1, (11008, ), (1, )) | |
assert_size_stride(arg306_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg307_1, (4096, ), (1, )) | |
assert_size_stride(arg308_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg309_1, (12288, ), (1, )) | |
assert_size_stride(arg310_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg311_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg312_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg313_1, (4096, ), (1, )) | |
assert_size_stride(arg314_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg315_1, (11008, ), (1, )) | |
assert_size_stride(arg316_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg317_1, (11008, ), (1, )) | |
assert_size_stride(arg318_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg319_1, (4096, ), (1, )) | |
assert_size_stride(arg320_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg321_1, (12288, ), (1, )) | |
assert_size_stride(arg322_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg323_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg324_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg325_1, (4096, ), (1, )) | |
assert_size_stride(arg326_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg327_1, (11008, ), (1, )) | |
assert_size_stride(arg328_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg329_1, (11008, ), (1, )) | |
assert_size_stride(arg330_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg331_1, (4096, ), (1, )) | |
assert_size_stride(arg332_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg333_1, (12288, ), (1, )) | |
assert_size_stride(arg334_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg335_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg336_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg337_1, (4096, ), (1, )) | |
assert_size_stride(arg338_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg339_1, (11008, ), (1, )) | |
assert_size_stride(arg340_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg341_1, (11008, ), (1, )) | |
assert_size_stride(arg342_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg343_1, (4096, ), (1, )) | |
assert_size_stride(arg344_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg345_1, (12288, ), (1, )) | |
assert_size_stride(arg346_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg347_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg348_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg349_1, (4096, ), (1, )) | |
assert_size_stride(arg350_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg351_1, (11008, ), (1, )) | |
assert_size_stride(arg352_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg353_1, (11008, ), (1, )) | |
assert_size_stride(arg354_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg355_1, (4096, ), (1, )) | |
assert_size_stride(arg356_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg357_1, (12288, ), (1, )) | |
assert_size_stride(arg358_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg359_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg360_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg361_1, (4096, ), (1, )) | |
assert_size_stride(arg362_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg363_1, (11008, ), (1, )) | |
assert_size_stride(arg364_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg365_1, (11008, ), (1, )) | |
assert_size_stride(arg366_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg367_1, (4096, ), (1, )) | |
assert_size_stride(arg368_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg369_1, (12288, ), (1, )) | |
assert_size_stride(arg370_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg371_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg372_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg373_1, (4096, ), (1, )) | |
assert_size_stride(arg374_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg375_1, (11008, ), (1, )) | |
assert_size_stride(arg376_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg377_1, (11008, ), (1, )) | |
assert_size_stride(arg378_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg379_1, (4096, ), (1, )) | |
assert_size_stride(arg380_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg381_1, (12288, ), (1, )) | |
assert_size_stride(arg382_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg383_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg384_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg385_1, (4096, ), (1, )) | |
assert_size_stride(arg386_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg387_1, (11008, ), (1, )) | |
assert_size_stride(arg388_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg389_1, (11008, ), (1, )) | |
assert_size_stride(arg390_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg391_1, (4096, ), (1, )) | |
assert_size_stride(arg392_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg393_1, (12288, ), (1, )) | |
assert_size_stride(arg394_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg395_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg396_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg397_1, (4096, ), (1, )) | |
assert_size_stride(arg398_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg399_1, (11008, ), (1, )) | |
assert_size_stride(arg400_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg401_1, (11008, ), (1, )) | |
assert_size_stride(arg402_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg403_1, (4096, ), (1, )) | |
assert_size_stride(arg404_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg405_1, (12288, ), (1, )) | |
assert_size_stride(arg406_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg407_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg408_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg409_1, (4096, ), (1, )) | |
assert_size_stride(arg410_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg411_1, (11008, ), (1, )) | |
assert_size_stride(arg412_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg413_1, (11008, ), (1, )) | |
assert_size_stride(arg414_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg415_1, (4096, ), (1, )) | |
assert_size_stride(arg416_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg417_1, (12288, ), (1, )) | |
assert_size_stride(arg418_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg419_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg420_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg421_1, (4096, ), (1, )) | |
assert_size_stride(arg422_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg423_1, (11008, ), (1, )) | |
assert_size_stride(arg424_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg425_1, (11008, ), (1, )) | |
assert_size_stride(arg426_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg427_1, (4096, ), (1, )) | |
assert_size_stride(arg428_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg429_1, (12288, ), (1, )) | |
assert_size_stride(arg430_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg431_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg432_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg433_1, (4096, ), (1, )) | |
assert_size_stride(arg434_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg435_1, (11008, ), (1, )) | |
assert_size_stride(arg436_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg437_1, (11008, ), (1, )) | |
assert_size_stride(arg438_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg439_1, (4096, ), (1, )) | |
assert_size_stride(arg440_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg441_1, (12288, ), (1, )) | |
assert_size_stride(arg442_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg443_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg444_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg445_1, (4096, ), (1, )) | |
assert_size_stride(arg446_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg447_1, (11008, ), (1, )) | |
assert_size_stride(arg448_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg449_1, (11008, ), (1, )) | |
assert_size_stride(arg450_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg451_1, (4096, ), (1, )) | |
assert_size_stride(arg452_1, (32000, 768), (768, 1)) | |
assert_size_stride(arg453_1, (32000, ), (1, )) | |
assert_size_stride(arg454_1, (1, ), (1, )) | |
assert_size_stride(arg455_1, (1, 1), (1, 1)) | |
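# Note on the argument layout above (inferred from shapes; AOT inductor does
# not preserve parameter names): each (N, K*6/32) tensor is an FP6-packed
# weight matrix -- 6-bit weights packed into (presumably int32) words, so
# in_features 4096 packs to 768 columns and 11008 packs to 2064 -- and each is
# paired with a per-output-channel fp16 scale vector of shape (N,). The
# (1, 32, 208, 128) buffers are per-layer KV caches (32 heads, 208 cache
# slots, head_dim 128); (12288, 768) is plausibly the fused QKV projection
# (3 * 4096 outputs) and (32000, 768) the output head. arg454_1 and arg455_1
# look like the current input_pos and token index for single-token decode.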
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16) | |
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0.run(arg455_1, arg65_1, arg0_1, buf1, 1, 4096, grid=grid(1), stream=stream0) | |
del arg0_1 | |
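# The reduction kernel above fuses the token-embedding gather with the first
# RMSNorm, so the normalized fp16 activations land directly in buf1. A rough
# eager-mode sketch (eps assumed; the rsqrt is inferred from the sibling
# kernels that name it explicitly):
#   x = arg65_1[arg455_1].float()                       # embedding + _to_copy
#   rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
#   buf1 = (x * rms * arg0_1).half()                    # mean/mul -> fp16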
# Source Nodes: [out], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf2 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf1, arg68_1, arg69_1, 1) | |
del arg68_1 | |
del arg69_1 | |
buf3 = buf2 | |
del buf2 | |
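# torch.ops.torchao.fp16act_fp6weight_linear(act, packed_w, scales, splitK)
# is torchao's FP6-LLM GEMM: fp16 activations against 6-bit weights that the
# CUDA kernel dequantizes on the fly. Semantically (a sketch only; dequant_fp6
# is illustrative, not a real API):
#   w = dequant_fp6(arg68_1).half() * arg69_1[:, None]  # (12288, 4096) fp16
#   buf3 = buf1 @ w.t()                                 # fused QKV projection
# The trailing integer is the kernel's split-K factor (1 here, 13 for the
# output/down projections below), presumably chosen by a size heuristic for
# this M=1 decode shape.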
buf5 = empty_strided_cuda((32, 1, 128), (128, 4096, 1), torch.float32) | |
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf3, arg66_1, arg70_1, buf5, arg71_1, 4096, grid=grid(4096), stream=stream0) | |
del buf3 | |
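# The pointwise kernel above is the per-step KV-cache update: it splits the
# fused QKV row, applies what is plausibly the rotary embedding (arg66_1),
# index_put's the new k/v at input_pos (arg454_1) into the (1, 32, 208, 128)
# caches (arg70_1 / arg71_1), and lays the query out as (32, 1, 128) in buf5
# for the batched matmuls. Sketch:
#   q, k, v = buf3.view(3, 32, 128)                     # illustrative split
#   k_cache[0, :, pos], v_cache[0, :, pos] = rope(k), v # aten.index_put
#   buf5 = rope(q).view(32, 1, 128).float()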
buf6 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32) | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf5, arg70_1, buf6, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg70_1 | |
buf10 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32) | |
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf6, arg454_1, arg67_1, buf10, 32, 208, grid=grid(32), stream=stream0) | |
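# Attention scores for the single query token: kernel _2 computed
# buf6 = q @ k_cache^T over all 208 cache slots, and the kernel above adds
# the causal-mask row selected by input_pos (arg67_1 is evidently the mask
# buffer) and takes a numerically stable softmax in fp32:
#   scores = buf6.masked_fill(~mask[pos], float('-inf'))  # sketch
#   buf10 = torch.softmax(scores, dim=-1)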
buf11 = empty_strided_cuda((32, 1, 128, 2), (256, 8192, 1, 128), torch.float32) | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf10, arg71_1, buf11, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg71_1 | |
buf13 = buf1; del buf1 # reuse | |
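# buf13 = buf1; del buf1 is inductor's memory planning at work: the del drops
# the last other reference so the same CUDA allocation is recycled for the
# next intermediate instead of round-tripping through the caching allocator.
# The same rename-and-reuse pattern recurs throughout the listing.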
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf11, buf13, 4096, 2, grid=grid(4096), stream=stream0) | |
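# probs @ v_cache is done as a split reduction: kernel _4 reduced the 208
# cache slots in two chunks of 104, leaving partial sums in buf11 of shape
# (32, 1, 128, 2); kernel _5 above sums the two partials and casts to fp16
# into buf13, which then feeds the output projection.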
# Source Nodes: [out_1], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf14 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf13, arg72_1, arg73_1, 13) | |
del arg72_1 | |
del arg73_1 | |
buf15 = buf14 | |
del buf14 | |
buf17 = reinterpret_tensor(buf13, (1, 1, 4096), (4096, 4096, 1), 0); del buf13 # reuse | |
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6.run(arg455_1, arg65_1, buf15, arg1_1, buf17, 1, 4096, grid=grid(1), stream=stream0) | |
del arg1_1 | |
# Source Nodes: [out_2], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf18 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg74_1, arg75_1, 1) | |
del arg74_1 | |
del arg75_1 | |
buf19 = buf18 | |
del buf18 | |
# Source Nodes: [out_3], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf20 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg76_1, arg77_1, 1) | |
del arg76_1 | |
del arg77_1 | |
buf21 = buf20 | |
del buf20 | |
buf22 = buf19; del buf19 # reuse | |
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf22, buf21, 11008, grid=grid(11008), stream=stream0) | |
del buf21 | |
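# Kernel _7 is the SwiGLU gate, consistent with gpt-fast's FeedForward
# (w2(silu(w1(x)) * w3(x))): it multiplies the gate and up projections
# elementwise, in place over the 11008 hidden features. Sketch:
#   buf22 = torch.nn.functional.silu(buf22) * buf21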
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf23 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf22, arg78_1, arg79_1, 13) | |
del arg78_1 | |
del arg79_1 | |
del buf22 | |
buf24 = buf23 | |
del buf23 | |
buf26 = reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0); del buf17 # reuse | |
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8.run(arg455_1, arg65_1, buf15, buf24, arg2_1, buf26, 1, 4096, grid=grid(1), stream=stream0) | |
del arg2_1 | |
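# Kernel _8 recomputes the layer-0 residual from its addends (embedding +
# attention out buf15 + MLP out buf24) rather than materializing it, and
# applies layer 1's attention RMSNorm in the same reduction, writing the
# normalized fp16 row to buf26.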
# Source Nodes: [out_6], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf27 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf26, arg80_1, arg81_1, 1) | |
del arg80_1 | |
del arg81_1 | |
buf28 = buf27 | |
del buf27 | |
buf30 = buf5; del buf5 # reuse | |
# Source Nodes: [setitem_2, setitem_3, y_3], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf28, arg66_1, arg82_1, buf30, arg83_1, 4096, grid=grid(4096), stream=stream0) | |
del buf28 | |
buf31 = buf10; del buf10 # reuse | |
# Source Nodes: [y_3], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf30, arg82_1, buf31, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg82_1 | |
buf35 = buf6; del buf6 # reuse | |
# Source Nodes: [mask, y_3], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf31, arg454_1, arg67_1, buf35, 32, 208, grid=grid(32), stream=stream0) | |
buf36 = buf11; del buf11 # reuse | |
# Source Nodes: [y_3], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf35, arg83_1, buf36, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg83_1 | |
buf38 = buf26; del buf26 # reuse | |
# Source Nodes: [out_7, y_3], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf36, buf38, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_7], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf39 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf38, arg84_1, arg85_1, 13) | |
del arg84_1 | |
del arg85_1 | |
buf40 = buf39 | |
del buf39 | |
buf41 = reinterpret_tensor(buf15, (1, 1, 4096), (4096, 4096, 1), 0); del buf15 # reuse | |
buf43 = buf38; del buf38 # reuse | |
buf46 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16) | |
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9.run(buf41, arg455_1, arg65_1, buf24, buf40, arg3_1, buf43, buf46, 1, 4096, grid=grid(1), stream=stream0) | |
del arg3_1 | |
del arg455_1 | |
del arg65_1 | |
del buf24 | |
del buf40 | |
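# Kernel _9 is the layer boundary: it finally materializes the running
# residual into buf41 (the embedding inputs arg455_1/arg65_1 are freed right
# after, so later layers update this buffer instead), and in the same pass
# emits the FFN-normalized activations twice (buf43 and buf46) to feed the
# w1 and w3 projections below.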
# Source Nodes: [out_8], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf44 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf43, arg86_1, arg87_1, 1) | |
del arg86_1 | |
del arg87_1 | |
buf45 = buf44 | |
del buf44 | |
# Source Nodes: [out_9], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf47 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf46, arg88_1, arg89_1, 1) | |
del arg88_1 | |
del arg89_1 | |
buf48 = buf47 | |
del buf47 | |
buf49 = buf45; del buf45 # reuse | |
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf49, buf48, 11008, grid=grid(11008), stream=stream0) | |
del buf48 | |
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf50 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf49, arg90_1, arg91_1, 13) | |
del arg90_1 | |
del arg91_1 | |
del buf49 | |
buf51 = buf50 | |
del buf50 | |
buf53 = buf46; del buf46 # reuse | |
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf41, buf51, arg4_1, buf53, 1, 4096, grid=grid(1), stream=stream0) | |
del arg4_1 | |
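# Kernel _10 is the steady-state attention RMSNorm: residual buf41 plus the
# layer's MLP output buf51, summed and normalized in one reduction to produce
# the next attention input buf53. With the embedding now folded into buf41,
# the remaining layers cycle through the same _10/_11/_12/_13 kernel family
# with only the weight arguments advancing.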
# Source Nodes: [out_12], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf54 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf53, arg92_1, arg93_1, 1) | |
del arg92_1 | |
del arg93_1 | |
buf55 = buf54 | |
del buf54 | |
buf57 = buf30; del buf30 # reuse | |
# Source Nodes: [setitem_4, setitem_5, y_6], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf55, arg66_1, arg94_1, buf57, arg95_1, 4096, grid=grid(4096), stream=stream0) | |
del buf55 | |
buf58 = buf35; del buf35 # reuse | |
# Source Nodes: [y_6], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf57, arg94_1, buf58, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg94_1 | |
buf62 = buf31; del buf31 # reuse | |
# Source Nodes: [mask, y_6], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf58, arg454_1, arg67_1, buf62, 32, 208, grid=grid(32), stream=stream0) | |
buf63 = buf36; del buf36 # reuse | |
# Source Nodes: [y_6], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf62, arg95_1, buf63, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg95_1 | |
buf65 = buf53; del buf53 # reuse | |
# Source Nodes: [out_13, y_6], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf63, buf65, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_13], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf66 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf65, arg96_1, arg97_1, 13) | |
del arg96_1 | |
del arg97_1 | |
buf67 = buf66 | |
del buf66 | |
buf69 = reinterpret_tensor(buf65, (1, 1, 4096), (4096, 4096, 1), 0); del buf65 # reuse | |
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf41, buf51, buf67, arg5_1, buf69, 1, 4096, grid=grid(1), stream=stream0) | |
del arg5_1 | |
# Source Nodes: [out_14], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf70 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg98_1, arg99_1, 1) | |
del arg98_1 | |
del arg99_1 | |
buf71 = buf70 | |
del buf70 | |
# Source Nodes: [out_15], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf72 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg100_1, arg101_1, 1) | |
del arg100_1 | |
del arg101_1 | |
buf73 = buf72 | |
del buf72 | |
buf74 = buf71; del buf71 # reuse | |
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf74, buf73, 11008, grid=grid(11008), stream=stream0) | |
del buf73 | |
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf75 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf74, arg102_1, arg103_1, 13) | |
del arg102_1 | |
del arg103_1 | |
del buf74 | |
buf76 = buf75 | |
del buf75 | |
buf78 = reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0); del buf69 # reuse | |
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf41, buf51, buf67, buf76, arg6_1, buf78, 1, 4096, grid=grid(1), stream=stream0) | |
del arg6_1 | |
# Source Nodes: [out_18], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf79 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf78, arg104_1, arg105_1, 1) | |
del arg104_1 | |
del arg105_1 | |
buf80 = buf79 | |
del buf79 | |
buf82 = buf57; del buf57 # reuse | |
# Source Nodes: [setitem_6, setitem_7, y_9], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf80, arg66_1, arg106_1, buf82, arg107_1, 4096, grid=grid(4096), stream=stream0) | |
del buf80 | |
buf83 = buf62; del buf62 # reuse | |
# Source Nodes: [y_9], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf82, arg106_1, buf83, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg106_1 | |
buf87 = buf58; del buf58 # reuse | |
# Source Nodes: [mask, y_9], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf83, arg454_1, arg67_1, buf87, 32, 208, grid=grid(32), stream=stream0) | |
buf88 = buf63; del buf63 # reuse | |
# Source Nodes: [y_9], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf87, arg107_1, buf88, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg107_1 | |
buf90 = buf78; del buf78 # reuse | |
# Source Nodes: [out_19, y_9], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf88, buf90, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_19], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf91 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf90, arg108_1, arg109_1, 13) | |
del arg108_1 | |
del arg109_1 | |
buf92 = buf91 | |
del buf91 | |
buf93 = buf41; del buf41 # reuse | |
buf95 = buf90; del buf90 # reuse | |
buf98 = buf43; del buf43 # reuse | |
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf93, buf51, buf67, buf76, buf92, arg7_1, buf95, buf98, 1, 4096, grid=grid(1), stream=stream0) | |
del arg7_1 | |
del buf51 | |
del buf67 | |
del buf76 | |
del buf92 | |
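# Kernel _13 is the recurring two-layer epilogue: between _13 calls the
# residual is not written back -- kernels _10/_11/_12 recompute the sum from
# the saved addends -- and here the accumulated terms (buf51, buf67, buf76,
# buf92) are folded into the residual buffer in place (buf93 reuses buf41's
# storage) while the two FFN-norm outputs for the next layer are produced.
# From this point the listing repeats the same per-layer pattern verbatim.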
# Source Nodes: [out_20], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf96 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf95, arg110_1, arg111_1, 1) | |
del arg110_1 | |
del arg111_1 | |
buf97 = buf96 | |
del buf96 | |
# Source Nodes: [out_21], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf99 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf98, arg112_1, arg113_1, 1) | |
del arg112_1 | |
del arg113_1 | |
buf100 = buf99 | |
del buf99 | |
buf101 = buf100; del buf100 # reuse | |
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_14.run(buf101, buf97, 11008, grid=grid(11008), stream=stream0) | |
del buf97 | |
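# Kernel _14 is evidently the same SwiGLU gate as _7 with the write target
# swapped: the elementwise silu(gate) * up result lands in buf101 (the w3
# projection's storage) instead of the w1 buffer.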
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf102 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf101, arg114_1, arg115_1, 13) | |
del arg114_1 | |
del arg115_1 | |
del buf101 | |
buf103 = buf102 | |
del buf102 | |
buf105 = buf98; del buf98 # reuse | |
# Source Nodes: [float_17, mean_8, mul_60, out_23, out_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf93, buf103, arg8_1, buf105, 1, 4096, grid=grid(1), stream=stream0) | |
del arg8_1 | |
# Source Nodes: [out_24], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf106 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf105, arg116_1, arg117_1, 1) | |
del arg116_1 | |
del arg117_1 | |
buf107 = buf106 | |
del buf106 | |
buf109 = buf82; del buf82 # reuse | |
# Source Nodes: [setitem_8, setitem_9, y_12], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf107, arg66_1, arg118_1, buf109, arg119_1, 4096, grid=grid(4096), stream=stream0) | |
del buf107 | |
buf110 = buf87; del buf87 # reuse | |
# Source Nodes: [y_12], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf109, arg118_1, buf110, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg118_1 | |
buf114 = buf83; del buf83 # reuse | |
# Source Nodes: [mask, y_12], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf110, arg454_1, arg67_1, buf114, 32, 208, grid=grid(32), stream=stream0) | |
buf115 = buf88; del buf88 # reuse | |
# Source Nodes: [y_12], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf114, arg119_1, buf115, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg119_1 | |
buf117 = buf105; del buf105 # reuse | |
# Source Nodes: [out_25, y_12], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf115, buf117, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_25], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf118 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf117, arg120_1, arg121_1, 13) | |
del arg120_1 | |
del arg121_1 | |
buf119 = buf118 | |
del buf118 | |
buf121 = reinterpret_tensor(buf117, (1, 1, 4096), (4096, 4096, 1), 0); del buf117 # reuse | |
# Source Nodes: [add_28, float_20, h_4, mean_9, mul_71, mul_72, mul_73, out_23, output_9, rsqrt_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf93, buf103, buf119, arg9_1, buf121, 1, 4096, grid=grid(1), stream=stream0) | |
del arg9_1 | |
# Source Nodes: [out_26], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf122 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg122_1, arg123_1, 1) | |
del arg122_1 | |
del arg123_1 | |
buf123 = buf122 | |
del buf122 | |
# Source Nodes: [out_27], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf124 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg124_1, arg125_1, 1) | |
del arg124_1 | |
del arg125_1 | |
buf125 = buf124 | |
del buf124 | |
buf126 = buf123; del buf123 # reuse | |
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf126, buf125, 11008, grid=grid(11008), stream=stream0) | |
del buf125 | |
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf127 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf126, arg126_1, arg127_1, 13) | |
del arg126_1 | |
del arg127_1 | |
del buf126 | |
buf128 = buf127 | |
del buf127 | |
buf130 = reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0); del buf121 # reuse | |
# Source Nodes: [float_21, h_4, mean_10, mul_75, out_23, out_29, out_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf93, buf103, buf119, buf128, arg10_1, buf130, 1, 4096, grid=grid(1), stream=stream0) | |
del arg10_1 | |
# Source Nodes: [out_30], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf131 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf130, arg128_1, arg129_1, 1) | |
del arg128_1 | |
del arg129_1 | |
buf132 = buf131 | |
del buf131 | |
buf134 = buf109; del buf109 # reuse | |
# Source Nodes: [setitem_10, setitem_11, y_15], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf132, arg66_1, arg130_1, buf134, arg131_1, 4096, grid=grid(4096), stream=stream0) | |
del buf132 | |
buf135 = buf114; del buf114 # reuse | |
# Source Nodes: [y_15], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf134, arg130_1, buf135, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg130_1 | |
buf139 = buf110; del buf110 # reuse | |
# Source Nodes: [mask, y_15], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf135, arg454_1, arg67_1, buf139, 32, 208, grid=grid(32), stream=stream0) | |
buf140 = buf115; del buf115 # reuse | |
# Source Nodes: [y_15], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf139, arg131_1, buf140, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg131_1 | |
buf142 = buf130; del buf130 # reuse | |
# Source Nodes: [out_31, y_15], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf140, buf142, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_31], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf143 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf142, arg132_1, arg133_1, 13) | |
del arg132_1 | |
del arg133_1 | |
buf144 = buf143 | |
del buf143 | |
buf145 = reinterpret_tensor(buf103, (1, 1, 4096), (4096, 4096, 1), 0); del buf103 # reuse | |
buf147 = buf142; del buf142 # reuse | |
buf150 = buf95; del buf95 # reuse | |
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15.run(buf145, buf93, buf119, buf128, buf144, arg11_1, buf147, buf150, 1, 4096, grid=grid(1), stream=stream0) | |
del arg11_1 | |
del buf119 | |
del buf128 | |
del buf144 | |
del buf93 | |
# Source Nodes: [out_32], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf148 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf147, arg134_1, arg135_1, 1) | |
del arg134_1 | |
del arg135_1 | |
buf149 = buf148 | |
del buf148 | |
# Source Nodes: [out_33], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf151 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf150, arg136_1, arg137_1, 1) | |
del arg136_1 | |
del arg137_1 | |
buf152 = buf151 | |
del buf151 | |
buf153 = buf149; del buf149 # reuse | |
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf153, buf152, 11008, grid=grid(11008), stream=stream0) | |
del buf152 | |
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf154 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf153, arg138_1, arg139_1, 13) | |
del arg138_1 | |
del arg139_1 | |
del buf153 | |
buf155 = buf154 | |
del buf154 | |
buf157 = buf150; del buf150 # reuse | |
# Source Nodes: [float_25, mean_12, mul_90, out_35, out_36], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf145, buf155, arg12_1, buf157, 1, 4096, grid=grid(1), stream=stream0) | |
del arg12_1 | |
# Source Nodes: [out_36], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf158 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf157, arg140_1, arg141_1, 1) | |
del arg140_1 | |
del arg141_1 | |
buf159 = buf158 | |
del buf158 | |
buf161 = buf134; del buf134 # reuse | |
# Source Nodes: [setitem_12, setitem_13, y_18], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf159, arg66_1, arg142_1, buf161, arg143_1, 4096, grid=grid(4096), stream=stream0) | |
del buf159 | |
buf162 = buf139; del buf139 # reuse | |
# Source Nodes: [y_18], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf161, arg142_1, buf162, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg142_1 | |
buf166 = buf135; del buf135 # reuse | |
# Source Nodes: [mask, y_18], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf162, arg454_1, arg67_1, buf166, 32, 208, grid=grid(32), stream=stream0) | |
buf167 = buf140; del buf140 # reuse | |
# Source Nodes: [y_18], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf166, arg143_1, buf167, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg143_1 | |
buf169 = buf157; del buf157 # reuse | |
# Source Nodes: [out_37, y_18], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf167, buf169, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_37], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf170 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf169, arg144_1, arg145_1, 13) | |
del arg144_1 | |
del arg145_1 | |
buf171 = buf170 | |
del buf170 | |
buf173 = reinterpret_tensor(buf169, (1, 1, 4096), (4096, 4096, 1), 0); del buf169 # reuse | |
# Source Nodes: [add_40, float_28, h_6, mean_13, mul_101, mul_102, mul_103, out_35, output_13, rsqrt_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf145, buf155, buf171, arg13_1, buf173, 1, 4096, grid=grid(1), stream=stream0) | |
del arg13_1 | |
# Source Nodes: [out_38], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf174 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg146_1, arg147_1, 1) | |
del arg146_1 | |
del arg147_1 | |
buf175 = buf174 | |
del buf174 | |
# Source Nodes: [out_39], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf176 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg148_1, arg149_1, 1) | |
del arg148_1 | |
del arg149_1 | |
buf177 = buf176 | |
del buf176 | |
buf178 = buf175; del buf175 # reuse | |
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf178, buf177, 11008, grid=grid(11008), stream=stream0) | |
del buf177 | |
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf179 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf178, arg150_1, arg151_1, 13) | |
del arg150_1 | |
del arg151_1 | |
del buf178 | |
buf180 = buf179 | |
del buf179 | |
buf182 = reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0); del buf173 # reuse | |
# Source Nodes: [float_29, h_6, mean_14, mul_105, out_35, out_41, out_42], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf145, buf155, buf171, buf180, arg14_1, buf182, 1, 4096, grid=grid(1), stream=stream0) | |
del arg14_1 | |
# Source Nodes: [out_42], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf183 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf182, arg152_1, arg153_1, 1) | |
del arg152_1 | |
del arg153_1 | |
buf184 = buf183 | |
del buf183 | |
buf186 = buf161; del buf161 # reuse | |
# Source Nodes: [setitem_14, setitem_15, y_21], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf184, arg66_1, arg154_1, buf186, arg155_1, 4096, grid=grid(4096), stream=stream0) | |
del buf184 | |
buf187 = buf166; del buf166 # reuse | |
# Source Nodes: [y_21], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf186, arg154_1, buf187, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg154_1 | |
buf191 = buf162; del buf162 # reuse | |
# Source Nodes: [mask, y_21], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf187, arg454_1, arg67_1, buf191, 32, 208, grid=grid(32), stream=stream0) | |
buf192 = buf167; del buf167 # reuse | |
# Source Nodes: [y_21], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf191, arg155_1, buf192, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg155_1 | |
buf194 = buf182; del buf182 # reuse | |
# Source Nodes: [out_43, y_21], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf192, buf194, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_43], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf195 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf194, arg156_1, arg157_1, 13) | |
del arg156_1 | |
del arg157_1 | |
buf196 = buf195 | |
del buf195 | |
buf197 = buf145; del buf145 # reuse | |
buf199 = buf194; del buf194 # reuse | |
buf202 = buf147; del buf147 # reuse | |
# Source Nodes: [float_32, h_6, h_7, mean_15, mul_116, out_35, out_41, out_44, out_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf197, buf155, buf171, buf180, buf196, arg15_1, buf199, buf202, 1, 4096, grid=grid(1), stream=stream0) | |
del arg15_1 | |
del buf155 | |
del buf171 | |
del buf180 | |
del buf196 | |
# Source Nodes: [out_44], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf200 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf199, arg158_1, arg159_1, 1) | |
del arg158_1 | |
del arg159_1 | |
buf201 = buf200 | |
del buf200 | |
# Source Nodes: [out_45], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf203 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf202, arg160_1, arg161_1, 1) | |
del arg160_1 | |
del arg161_1 | |
buf204 = buf203 | |
del buf203 | |
buf205 = buf201; del buf201 # reuse | |
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf205, buf204, 11008, grid=grid(11008), stream=stream0) | |
del buf204 | |
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf206 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf205, arg162_1, arg163_1, 13) | |
del arg162_1 | |
del arg163_1 | |
del buf205 | |
buf207 = buf206 | |
del buf206 | |
buf209 = buf202; del buf202 # reuse | |
# Source Nodes: [float_33, mean_16, mul_120, out_47, out_48], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf197, buf207, arg16_1, buf209, 1, 4096, grid=grid(1), stream=stream0) | |
del arg16_1 | |
# Source Nodes: [out_48], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf210 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf209, arg164_1, arg165_1, 1) | |
del arg164_1 | |
del arg165_1 | |
buf211 = buf210 | |
del buf210 | |
buf213 = buf186; del buf186 # reuse | |
# Source Nodes: [setitem_16, setitem_17, y_24], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf211, arg66_1, arg166_1, buf213, arg167_1, 4096, grid=grid(4096), stream=stream0) | |
del buf211 | |
buf214 = buf191; del buf191 # reuse | |
# Source Nodes: [y_24], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf213, arg166_1, buf214, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg166_1 | |
buf218 = buf187; del buf187 # reuse | |
# Source Nodes: [mask, y_24], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf214, arg454_1, arg67_1, buf218, 32, 208, grid=grid(32), stream=stream0) | |
buf219 = buf192; del buf192 # reuse | |
# Source Nodes: [y_24], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf218, arg167_1, buf219, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg167_1 | |
buf221 = buf209; del buf209 # reuse | |
# Source Nodes: [out_49, y_24], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf219, buf221, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_49], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf222 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf221, arg168_1, arg169_1, 13) | |
del arg168_1 | |
del arg169_1 | |
buf223 = buf222 | |
del buf222 | |
buf225 = reinterpret_tensor(buf221, (1, 1, 4096), (4096, 4096, 1), 0); del buf221 # reuse | |
# Source Nodes: [add_52, float_36, h_8, mean_17, mul_131, mul_132, mul_133, out_47, output_17, rsqrt_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf197, buf207, buf223, arg17_1, buf225, 1, 4096, grid=grid(1), stream=stream0) | |
del arg17_1 | |
# Source Nodes: [out_50], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf226 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg170_1, arg171_1, 1) | |
del arg170_1 | |
del arg171_1 | |
buf227 = buf226 | |
del buf226 | |
# Source Nodes: [out_51], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf228 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg172_1, arg173_1, 1) | |
del arg172_1 | |
del arg173_1 | |
buf229 = buf228 | |
del buf228 | |
buf230 = buf227; del buf227 # reuse | |
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf230, buf229, 11008, grid=grid(11008), stream=stream0) | |
del buf229 | |
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf231 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf230, arg174_1, arg175_1, 13) | |
del arg174_1 | |
del arg175_1 | |
del buf230 | |
buf232 = buf231 | |
del buf231 | |
buf234 = reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0); del buf225 # reuse | |
# Source Nodes: [float_37, h_8, mean_18, mul_135, out_47, out_53, out_54], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf197, buf207, buf223, buf232, arg18_1, buf234, 1, 4096, grid=grid(1), stream=stream0) | |
del arg18_1 | |
# Source Nodes: [out_54], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf235 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf234, arg176_1, arg177_1, 1) | |
del arg176_1 | |
del arg177_1 | |
buf236 = buf235 | |
del buf235 | |
buf238 = buf213; del buf213 # reuse | |
# Source Nodes: [setitem_18, setitem_19, y_27], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf236, arg66_1, arg178_1, buf238, arg179_1, 4096, grid=grid(4096), stream=stream0) | |
del buf236 | |
buf239 = buf218; del buf218 # reuse | |
# Source Nodes: [y_27], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf238, arg178_1, buf239, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg178_1 | |
buf243 = buf214; del buf214 # reuse | |
# Source Nodes: [mask, y_27], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf239, arg454_1, arg67_1, buf243, 32, 208, grid=grid(32), stream=stream0) | |
buf244 = buf219; del buf219 # reuse | |
# Source Nodes: [y_27], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf243, arg179_1, buf244, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg179_1 | |
buf246 = buf234; del buf234 # reuse | |
# Source Nodes: [out_55, y_27], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf244, buf246, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_55], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf247 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf246, arg180_1, arg181_1, 13) | |
del arg180_1 | |
del arg181_1 | |
buf248 = buf247 | |
del buf247 | |
buf249 = buf197; del buf197 # reuse | |
buf251 = buf246; del buf246 # reuse | |
buf254 = buf199; del buf199 # reuse | |
# Source Nodes: [float_40, h_8, h_9, mean_19, mul_146, out_47, out_53, out_56, out_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf249, buf207, buf223, buf232, buf248, arg19_1, buf251, buf254, 1, 4096, grid=grid(1), stream=stream0) | |
del arg19_1 | |
del buf207 | |
del buf223 | |
del buf232 | |
del buf248 | |
# Source Nodes: [out_56], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf252 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf251, arg182_1, arg183_1, 1) | |
del arg182_1 | |
del arg183_1 | |
buf253 = buf252 | |
del buf252 | |
# Source Nodes: [out_57], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf255 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf254, arg184_1, arg185_1, 1) | |
del arg184_1 | |
del arg185_1 | |
buf256 = buf255 | |
del buf255 | |
buf257 = buf253; del buf253 # reuse | |
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf257, buf256, 11008, grid=grid(11008), stream=stream0) | |
del buf256 | |
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf258 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf257, arg186_1, arg187_1, 13) | |
del arg186_1 | |
del arg187_1 | |
del buf257 | |
buf259 = buf258 | |
del buf258 | |
buf261 = buf254; del buf254 # reuse | |
# Source Nodes: [float_41, mean_20, mul_150, out_59, out_60], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf249, buf259, arg20_1, buf261, 1, 4096, grid=grid(1), stream=stream0) | |
del arg20_1 | |
# Source Nodes: [out_60], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf262 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf261, arg188_1, arg189_1, 1) | |
del arg188_1 | |
del arg189_1 | |
buf263 = buf262 | |
del buf262 | |
buf265 = buf238; del buf238 # reuse | |
# Source Nodes: [setitem_20, setitem_21, y_30], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf263, arg66_1, arg190_1, buf265, arg191_1, 4096, grid=grid(4096), stream=stream0) | |
del buf263 | |
buf266 = buf243; del buf243 # reuse | |
# Source Nodes: [y_30], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf265, arg190_1, buf266, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg190_1 | |
buf270 = buf239; del buf239 # reuse | |
# Source Nodes: [mask, y_30], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf266, arg454_1, arg67_1, buf270, 32, 208, grid=grid(32), stream=stream0) | |
buf271 = buf244; del buf244 # reuse | |
# Source Nodes: [y_30], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf270, arg191_1, buf271, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg191_1 | |
buf273 = buf261; del buf261 # reuse | |
# Source Nodes: [out_61, y_30], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf271, buf273, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_61], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf274 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf273, arg192_1, arg193_1, 13) | |
del arg192_1 | |
del arg193_1 | |
buf275 = buf274 | |
del buf274 | |
buf277 = reinterpret_tensor(buf273, (1, 1, 4096), (4096, 4096, 1), 0); del buf273 # reuse | |
# Source Nodes: [add_64, float_44, h_10, mean_21, mul_161, mul_162, mul_163, out_59, output_21, rsqrt_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf249, buf259, buf275, arg21_1, buf277, 1, 4096, grid=grid(1), stream=stream0) | |
del arg21_1 | |
# Source Nodes: [out_62], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf278 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg194_1, arg195_1, 1) | |
del arg194_1 | |
del arg195_1 | |
buf279 = buf278 | |
del buf278 | |
# Source Nodes: [out_63], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf280 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg196_1, arg197_1, 1) | |
del arg196_1 | |
del arg197_1 | |
buf281 = buf280 | |
del buf280 | |
buf282 = buf279; del buf279 # reuse | |
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf282, buf281, 11008, grid=grid(11008), stream=stream0) | |
del buf281 | |
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf283 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf282, arg198_1, arg199_1, 13) | |
del arg198_1 | |
del arg199_1 | |
del buf282 | |
buf284 = buf283 | |
del buf283 | |
buf286 = reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0); del buf277 # reuse | |
# Source Nodes: [float_45, h_10, mean_22, mul_165, out_59, out_65, out_66], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf249, buf259, buf275, buf284, arg22_1, buf286, 1, 4096, grid=grid(1), stream=stream0) | |
del arg22_1 | |
# Source Nodes: [out_66], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf287 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf286, arg200_1, arg201_1, 1) | |
del arg200_1 | |
del arg201_1 | |
buf288 = buf287 | |
del buf287 | |
buf290 = buf265; del buf265 # reuse | |
# Source Nodes: [setitem_22, setitem_23, y_33], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf288, arg66_1, arg202_1, buf290, arg203_1, 4096, grid=grid(4096), stream=stream0) | |
del buf288 | |
buf291 = buf270; del buf270 # reuse | |
# Source Nodes: [y_33], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf290, arg202_1, buf291, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg202_1 | |
buf295 = buf266; del buf266 # reuse | |
# Source Nodes: [mask, y_33], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf291, arg454_1, arg67_1, buf295, 32, 208, grid=grid(32), stream=stream0) | |
buf296 = buf271; del buf271 # reuse | |
# Source Nodes: [y_33], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf295, arg203_1, buf296, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg203_1 | |
buf298 = buf286; del buf286 # reuse | |
# Source Nodes: [out_67, y_33], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf296, buf298, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_67], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf299 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf298, arg204_1, arg205_1, 13) | |
del arg204_1 | |
del arg205_1 | |
buf300 = buf299 | |
del buf299 | |
buf301 = buf249; del buf249 # reuse | |
buf303 = buf298; del buf298 # reuse | |
buf306 = buf251; del buf251 # reuse | |
# Source Nodes: [float_48, h_10, h_11, mean_23, mul_176, out_59, out_65, out_68, out_69], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf301, buf259, buf275, buf284, buf300, arg23_1, buf303, buf306, 1, 4096, grid=grid(1), stream=stream0) | |
del arg23_1 | |
del buf259 | |
del buf275 | |
del buf284 | |
del buf300 | |
# Source Nodes: [out_68], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf304 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf303, arg206_1, arg207_1, 1) | |
del arg206_1 | |
del arg207_1 | |
buf305 = buf304 | |
del buf304 | |
# Source Nodes: [out_69], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf307 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf306, arg208_1, arg209_1, 1) | |
del arg208_1 | |
del arg209_1 | |
buf308 = buf307 | |
del buf307 | |
buf309 = buf305; del buf305 # reuse | |
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf309, buf308, 11008, grid=grid(11008), stream=stream0) | |
del buf308 | |
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf310 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf309, arg210_1, arg211_1, 13) | |
del arg210_1 | |
del arg211_1 | |
del buf309 | |
buf311 = buf310 | |
del buf310 | |
buf313 = buf306; del buf306 # reuse | |
# Source Nodes: [float_49, mean_24, mul_180, out_71, out_72], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf301, buf311, arg24_1, buf313, 1, 4096, grid=grid(1), stream=stream0) | |
del arg24_1 | |
# Source Nodes: [out_72], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf314 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf313, arg212_1, arg213_1, 1) | |
del arg212_1 | |
del arg213_1 | |
buf315 = buf314 | |
del buf314 | |
buf317 = buf290; del buf290 # reuse | |
# Source Nodes: [setitem_24, setitem_25, y_36], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf315, arg66_1, arg214_1, buf317, arg215_1, 4096, grid=grid(4096), stream=stream0) | |
del buf315 | |
buf318 = buf295; del buf295 # reuse | |
# Source Nodes: [y_36], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf317, arg214_1, buf318, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg214_1 | |
buf322 = buf291; del buf291 # reuse | |
# Source Nodes: [mask, y_36], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf318, arg454_1, arg67_1, buf322, 32, 208, grid=grid(32), stream=stream0) | |
buf323 = buf296; del buf296 # reuse | |
# Source Nodes: [y_36], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf322, arg215_1, buf323, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg215_1 | |
buf325 = buf313; del buf313 # reuse | |
# Source Nodes: [out_73, y_36], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf323, buf325, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_73], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf326 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf325, arg216_1, arg217_1, 13) | |
del arg216_1 | |
del arg217_1 | |
buf327 = buf326 | |
del buf326 | |
buf329 = reinterpret_tensor(buf325, (1, 1, 4096), (4096, 4096, 1), 0); del buf325 # reuse | |
# Source Nodes: [add_76, float_52, h_12, mean_25, mul_191, mul_192, mul_193, out_71, output_25, rsqrt_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf301, buf311, buf327, arg25_1, buf329, 1, 4096, grid=grid(1), stream=stream0) | |
del arg25_1 | |
# Source Nodes: [out_74], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf330 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg218_1, arg219_1, 1) | |
del arg218_1 | |
del arg219_1 | |
buf331 = buf330 | |
del buf330 | |
# Source Nodes: [out_75], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf332 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg220_1, arg221_1, 1) | |
del arg220_1 | |
del arg221_1 | |
buf333 = buf332 | |
del buf332 | |
buf334 = buf331; del buf331 # reuse | |
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf334, buf333, 11008, grid=grid(11008), stream=stream0) | |
del buf333 | |
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf335 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf334, arg222_1, arg223_1, 13) | |
del arg222_1 | |
del arg223_1 | |
del buf334 | |
buf336 = buf335 | |
del buf335 | |
buf338 = reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0); del buf329 # reuse | |
# Source Nodes: [float_53, h_12, mean_26, mul_195, out_71, out_77, out_78], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf301, buf311, buf327, buf336, arg26_1, buf338, 1, 4096, grid=grid(1), stream=stream0) | |
del arg26_1 | |
# Source Nodes: [out_78], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf339 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf338, arg224_1, arg225_1, 1) | |
del arg224_1 | |
del arg225_1 | |
buf340 = buf339 | |
del buf339 | |
buf342 = buf317; del buf317 # reuse | |
# Source Nodes: [setitem_26, setitem_27, y_39], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf340, arg66_1, arg226_1, buf342, arg227_1, 4096, grid=grid(4096), stream=stream0) | |
del buf340 | |
buf343 = buf322; del buf322 # reuse | |
# Source Nodes: [y_39], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf342, arg226_1, buf343, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg226_1 | |
buf347 = buf318; del buf318 # reuse | |
# Source Nodes: [mask, y_39], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf343, arg454_1, arg67_1, buf347, 32, 208, grid=grid(32), stream=stream0) | |
buf348 = buf323; del buf323 # reuse | |
# Source Nodes: [y_39], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf347, arg227_1, buf348, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg227_1 | |
buf350 = buf338; del buf338 # reuse | |
# Source Nodes: [out_79, y_39], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf348, buf350, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_79], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf351 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf350, arg228_1, arg229_1, 13) | |
del arg228_1 | |
del arg229_1 | |
buf352 = buf351 | |
del buf351 | |
buf353 = buf301; del buf301 # reuse | |
buf355 = buf350; del buf350 # reuse | |
buf358 = buf303; del buf303 # reuse | |
# Source Nodes: [float_56, h_12, h_13, mean_27, mul_206, out_71, out_77, out_80, out_81], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf353, buf311, buf327, buf336, buf352, arg27_1, buf355, buf358, 1, 4096, grid=grid(1), stream=stream0) | |
del arg27_1 | |
del buf311 | |
del buf327 | |
del buf336 | |
del buf352 | |
# Source Nodes: [out_80], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf356 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf355, arg230_1, arg231_1, 1) | |
del arg230_1 | |
del arg231_1 | |
buf357 = buf356 | |
del buf356 | |
# Source Nodes: [out_81], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf359 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf358, arg232_1, arg233_1, 1) | |
del arg232_1 | |
del arg233_1 | |
buf360 = buf359 | |
del buf359 | |
buf361 = buf357; del buf357 # reuse | |
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf361, buf360, 11008, grid=grid(11008), stream=stream0) | |
del buf360 | |
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf362 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf361, arg234_1, arg235_1, 13) | |
del arg234_1 | |
del arg235_1 | |
del buf361 | |
buf363 = buf362 | |
del buf362 | |
buf365 = buf358; del buf358 # reuse | |
# Source Nodes: [float_57, mean_28, mul_210, out_83, out_84], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf353, buf363, arg28_1, buf365, 1, 4096, grid=grid(1), stream=stream0) | |
del arg28_1 | |
# Source Nodes: [out_84], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf366 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf365, arg236_1, arg237_1, 1) | |
del arg236_1 | |
del arg237_1 | |
buf367 = buf366 | |
del buf366 | |
buf369 = buf342; del buf342 # reuse | |
# Source Nodes: [setitem_28, setitem_29, y_42], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf367, arg66_1, arg238_1, buf369, arg239_1, 4096, grid=grid(4096), stream=stream0) | |
del buf367 | |
buf370 = buf347; del buf347 # reuse | |
# Source Nodes: [y_42], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf369, arg238_1, buf370, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg238_1 | |
buf374 = buf343; del buf343 # reuse | |
# Source Nodes: [mask, y_42], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf370, arg454_1, arg67_1, buf374, 32, 208, grid=grid(32), stream=stream0) | |
buf375 = buf348; del buf348 # reuse | |
# Source Nodes: [y_42], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf374, arg239_1, buf375, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg239_1 | |
buf377 = buf365; del buf365 # reuse | |
# Source Nodes: [out_85, y_42], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf375, buf377, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_85], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf378 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf377, arg240_1, arg241_1, 13) | |
del arg240_1 | |
del arg241_1 | |
buf379 = buf378 | |
del buf378 | |
buf381 = reinterpret_tensor(buf377, (1, 1, 4096), (4096, 4096, 1), 0); del buf377 # reuse | |
# Source Nodes: [add_88, float_60, h_14, mean_29, mul_221, mul_222, mul_223, out_83, output_29, rsqrt_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf353, buf363, buf379, arg29_1, buf381, 1, 4096, grid=grid(1), stream=stream0) | |
del arg29_1 | |
# Source Nodes: [out_86], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf382 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg242_1, arg243_1, 1) | |
del arg242_1 | |
del arg243_1 | |
buf383 = buf382 | |
del buf382 | |
# Source Nodes: [out_87], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf384 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg244_1, arg245_1, 1) | |
del arg244_1 | |
del arg245_1 | |
buf385 = buf384 | |
del buf384 | |
buf386 = buf383; del buf383 # reuse | |
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf386, buf385, 11008, grid=grid(11008), stream=stream0) | |
del buf385 | |
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf387 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf386, arg246_1, arg247_1, 13) | |
del arg246_1 | |
del arg247_1 | |
del buf386 | |
buf388 = buf387 | |
del buf387 | |
buf390 = reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0); del buf381 # reuse | |
# Source Nodes: [float_61, h_14, mean_30, mul_225, out_83, out_89, out_90], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf353, buf363, buf379, buf388, arg30_1, buf390, 1, 4096, grid=grid(1), stream=stream0) | |
del arg30_1 | |
# Source Nodes: [out_90], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf391 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf390, arg248_1, arg249_1, 1) | |
del arg248_1 | |
del arg249_1 | |
buf392 = buf391 | |
del buf391 | |
buf394 = buf369; del buf369 # reuse | |
# Source Nodes: [setitem_30, setitem_31, y_45], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf392, arg66_1, arg250_1, buf394, arg251_1, 4096, grid=grid(4096), stream=stream0) | |
del buf392 | |
buf395 = buf374; del buf374 # reuse | |
# Source Nodes: [y_45], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf394, arg250_1, buf395, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg250_1 | |
buf399 = buf370; del buf370 # reuse | |
# Source Nodes: [mask, y_45], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf395, arg454_1, arg67_1, buf399, 32, 208, grid=grid(32), stream=stream0) | |
buf400 = buf375; del buf375 # reuse | |
# Source Nodes: [y_45], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf399, arg251_1, buf400, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg251_1 | |
buf402 = buf390; del buf390 # reuse | |
# Source Nodes: [out_91, y_45], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf400, buf402, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_91], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf403 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf402, arg252_1, arg253_1, 13) | |
del arg252_1 | |
del arg253_1 | |
buf404 = buf403 | |
del buf403 | |
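# Residual updates h_14/h_15: the per-layer pattern sketched above repeats with the next layer's FP6 weights and scales.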
buf405 = buf353; del buf353 # reuse | |
buf407 = buf402; del buf402 # reuse | |
buf410 = buf355; del buf355 # reuse | |
# Source Nodes: [float_64, h_14, h_15, mean_31, mul_236, out_83, out_89, out_92, out_93], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf405, buf363, buf379, buf388, buf404, arg31_1, buf407, buf410, 1, 4096, grid=grid(1), stream=stream0) | |
del arg31_1 | |
del buf363 | |
del buf379 | |
del buf388 | |
del buf404 | |
# Source Nodes: [out_92], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf408 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf407, arg254_1, arg255_1, 1) | |
del arg254_1 | |
del arg255_1 | |
buf409 = buf408 | |
del buf408 | |
# Source Nodes: [out_93], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf411 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf410, arg256_1, arg257_1, 1) | |
del arg256_1 | |
del arg257_1 | |
buf412 = buf411 | |
del buf411 | |
buf413 = buf409; del buf409 # reuse | |
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf413, buf412, 11008, grid=grid(11008), stream=stream0) | |
del buf412 | |
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf414 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf413, arg258_1, arg259_1, 13) | |
del arg258_1 | |
del arg259_1 | |
del buf413 | |
buf415 = buf414 | |
del buf414 | |
buf417 = buf410; del buf410 # reuse | |
# Source Nodes: [float_65, mean_32, mul_240, out_95, out_96], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf405, buf415, arg32_1, buf417, 1, 4096, grid=grid(1), stream=stream0) | |
del arg32_1 | |
# Source Nodes: [out_96], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf418 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf417, arg260_1, arg261_1, 1) | |
del arg260_1 | |
del arg261_1 | |
buf419 = buf418 | |
del buf418 | |
buf421 = buf394; del buf394 # reuse | |
# Source Nodes: [setitem_32, setitem_33, y_48], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf419, arg66_1, arg262_1, buf421, arg263_1, 4096, grid=grid(4096), stream=stream0) | |
del buf419 | |
buf422 = buf399; del buf399 # reuse | |
# Source Nodes: [y_48], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf421, arg262_1, buf422, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg262_1 | |
buf426 = buf395; del buf395 # reuse | |
# Source Nodes: [mask, y_48], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf422, arg454_1, arg67_1, buf426, 32, 208, grid=grid(32), stream=stream0) | |
buf427 = buf400; del buf400 # reuse | |
# Source Nodes: [y_48], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf426, arg263_1, buf427, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg263_1 | |
buf429 = buf417; del buf417 # reuse | |
# Source Nodes: [out_97, y_48], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf427, buf429, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_97], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf430 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf429, arg264_1, arg265_1, 13) | |
del arg264_1 | |
del arg265_1 | |
buf431 = buf430 | |
del buf430 | |
buf433 = reinterpret_tensor(buf429, (1, 1, 4096), (4096, 4096, 1), 0); del buf429 # reuse | |
# Source Nodes: [add_100, float_68, h_16, mean_33, mul_251, mul_252, mul_253, out_95, output_33, rsqrt_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf405, buf415, buf431, arg33_1, buf433, 1, 4096, grid=grid(1), stream=stream0) | |
del arg33_1 | |
# Source Nodes: [out_98], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf434 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg266_1, arg267_1, 1) | |
del arg266_1 | |
del arg267_1 | |
buf435 = buf434 | |
del buf434 | |
# Source Nodes: [out_99], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf436 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg268_1, arg269_1, 1) | |
del arg268_1 | |
del arg269_1 | |
buf437 = buf436 | |
del buf436 | |
buf438 = buf435; del buf435 # reuse | |
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf438, buf437, 11008, grid=grid(11008), stream=stream0) | |
del buf437 | |
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf439 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf438, arg270_1, arg271_1, 13) | |
del arg270_1 | |
del arg271_1 | |
del buf438 | |
buf440 = buf439 | |
del buf439 | |
buf442 = reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0); del buf433 # reuse | |
# Source Nodes: [float_69, h_16, mean_34, mul_255, out_101, out_102, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf405, buf415, buf431, buf440, arg34_1, buf442, 1, 4096, grid=grid(1), stream=stream0) | |
del arg34_1 | |
# Source Nodes: [out_102], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf443 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf442, arg272_1, arg273_1, 1) | |
del arg272_1 | |
del arg273_1 | |
buf444 = buf443 | |
del buf443 | |
buf446 = buf421; del buf421 # reuse | |
# Source Nodes: [setitem_34, setitem_35, y_51], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf444, arg66_1, arg274_1, buf446, arg275_1, 4096, grid=grid(4096), stream=stream0) | |
del buf444 | |
buf447 = buf426; del buf426 # reuse | |
# Source Nodes: [y_51], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf446, arg274_1, buf447, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg274_1 | |
buf451 = buf422; del buf422 # reuse | |
# Source Nodes: [mask, y_51], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf447, arg454_1, arg67_1, buf451, 32, 208, grid=grid(32), stream=stream0) | |
buf452 = buf427; del buf427 # reuse | |
# Source Nodes: [y_51], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf451, arg275_1, buf452, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg275_1 | |
buf454 = buf442; del buf442 # reuse | |
# Source Nodes: [out_103, y_51], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf452, buf454, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_103], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf455 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf454, arg276_1, arg277_1, 13) | |
del arg276_1 | |
del arg277_1 | |
buf456 = buf455 | |
del buf455 | |
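# Residual updates h_16/h_17; same fused attention + SwiGLU layout.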
buf457 = buf405; del buf405 # reuse | |
buf459 = buf454; del buf454 # reuse | |
buf462 = buf407; del buf407 # reuse | |
# Source Nodes: [float_72, h_16, h_17, mean_35, mul_266, out_101, out_104, out_105, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf457, buf415, buf431, buf440, buf456, arg35_1, buf459, buf462, 1, 4096, grid=grid(1), stream=stream0) | |
del arg35_1 | |
del buf415 | |
del buf431 | |
del buf440 | |
del buf456 | |
# Source Nodes: [out_104], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf460 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf459, arg278_1, arg279_1, 1) | |
del arg278_1 | |
del arg279_1 | |
buf461 = buf460 | |
del buf460 | |
# Source Nodes: [out_105], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf463 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf462, arg280_1, arg281_1, 1) | |
del arg280_1 | |
del arg281_1 | |
buf464 = buf463 | |
del buf463 | |
buf465 = buf461; del buf461 # reuse | |
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf465, buf464, 11008, grid=grid(11008), stream=stream0) | |
del buf464 | |
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf466 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf465, arg282_1, arg283_1, 13) | |
del arg282_1 | |
del arg283_1 | |
del buf465 | |
buf467 = buf466 | |
del buf466 | |
buf469 = buf462; del buf462 # reuse | |
# Source Nodes: [float_73, mean_36, mul_270, out_107, out_108], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf457, buf467, arg36_1, buf469, 1, 4096, grid=grid(1), stream=stream0) | |
del arg36_1 | |
# Source Nodes: [out_108], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf470 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf469, arg284_1, arg285_1, 1) | |
del arg284_1 | |
del arg285_1 | |
buf471 = buf470 | |
del buf470 | |
buf473 = buf446; del buf446 # reuse | |
# Source Nodes: [setitem_36, setitem_37, y_54], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf471, arg66_1, arg286_1, buf473, arg287_1, 4096, grid=grid(4096), stream=stream0) | |
del buf471 | |
buf474 = buf451; del buf451 # reuse | |
# Source Nodes: [y_54], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf473, arg286_1, buf474, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg286_1 | |
buf478 = buf447; del buf447 # reuse | |
# Source Nodes: [mask, y_54], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf474, arg454_1, arg67_1, buf478, 32, 208, grid=grid(32), stream=stream0) | |
buf479 = buf452; del buf452 # reuse | |
# Source Nodes: [y_54], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf478, arg287_1, buf479, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg287_1 | |
buf481 = buf469; del buf469 # reuse | |
# Source Nodes: [out_109, y_54], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf479, buf481, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_109], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf482 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf481, arg288_1, arg289_1, 13) | |
del arg288_1 | |
del arg289_1 | |
buf483 = buf482 | |
del buf482 | |
buf485 = reinterpret_tensor(buf481, (1, 1, 4096), (4096, 4096, 1), 0); del buf481 # reuse | |
# Source Nodes: [add_112, float_76, h_18, mean_37, mul_281, mul_282, mul_283, out_107, output_37, rsqrt_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf457, buf467, buf483, arg37_1, buf485, 1, 4096, grid=grid(1), stream=stream0) | |
del arg37_1 | |
# Source Nodes: [out_110], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf486 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg290_1, arg291_1, 1) | |
del arg290_1 | |
del arg291_1 | |
buf487 = buf486 | |
del buf486 | |
# Source Nodes: [out_111], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf488 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg292_1, arg293_1, 1) | |
del arg292_1 | |
del arg293_1 | |
buf489 = buf488 | |
del buf488 | |
buf490 = buf487; del buf487 # reuse | |
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf490, buf489, 11008, grid=grid(11008), stream=stream0) | |
del buf489 | |
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf491 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf490, arg294_1, arg295_1, 13) | |
del arg294_1 | |
del arg295_1 | |
del buf490 | |
buf492 = buf491 | |
del buf491 | |
buf494 = reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0); del buf485 # reuse | |
# Source Nodes: [float_77, h_18, mean_38, mul_285, out_107, out_113, out_114], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf457, buf467, buf483, buf492, arg38_1, buf494, 1, 4096, grid=grid(1), stream=stream0) | |
del arg38_1 | |
# Source Nodes: [out_114], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf495 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf494, arg296_1, arg297_1, 1) | |
del arg296_1 | |
del arg297_1 | |
buf496 = buf495 | |
del buf495 | |
buf498 = buf473; del buf473 # reuse | |
# Source Nodes: [setitem_38, setitem_39, y_57], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf496, arg66_1, arg298_1, buf498, arg299_1, 4096, grid=grid(4096), stream=stream0) | |
del buf496 | |
buf499 = buf478; del buf478 # reuse | |
# Source Nodes: [y_57], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf498, arg298_1, buf499, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg298_1 | |
buf503 = buf474; del buf474 # reuse | |
# Source Nodes: [mask, y_57], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf499, arg454_1, arg67_1, buf503, 32, 208, grid=grid(32), stream=stream0) | |
buf504 = buf479; del buf479 # reuse | |
# Source Nodes: [y_57], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf503, arg299_1, buf504, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg299_1 | |
buf506 = buf494; del buf494 # reuse | |
# Source Nodes: [out_115, y_57], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf504, buf506, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_115], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf507 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf506, arg300_1, arg301_1, 13) | |
del arg300_1 | |
del arg301_1 | |
buf508 = buf507 | |
del buf507 | |
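# Residual updates h_18/h_19.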
buf509 = buf457; del buf457 # reuse | |
buf511 = buf506; del buf506 # reuse | |
buf514 = buf459; del buf459 # reuse | |
# Source Nodes: [float_80, h_18, h_19, mean_39, mul_296, out_107, out_113, out_116, out_117], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf509, buf467, buf483, buf492, buf508, arg39_1, buf511, buf514, 1, 4096, grid=grid(1), stream=stream0) | |
del arg39_1 | |
del buf467 | |
del buf483 | |
del buf492 | |
del buf508 | |
# Source Nodes: [out_116], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf512 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf511, arg302_1, arg303_1, 1) | |
del arg302_1 | |
del arg303_1 | |
buf513 = buf512 | |
del buf512 | |
# Source Nodes: [out_117], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf515 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf514, arg304_1, arg305_1, 1) | |
del arg304_1 | |
del arg305_1 | |
buf516 = buf515 | |
del buf515 | |
buf517 = buf513; del buf513 # reuse | |
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf517, buf516, 11008, grid=grid(11008), stream=stream0) | |
del buf516 | |
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf518 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf517, arg306_1, arg307_1, 13) | |
del arg306_1 | |
del arg307_1 | |
del buf517 | |
buf519 = buf518 | |
del buf518 | |
buf521 = buf514; del buf514 # reuse | |
# Source Nodes: [float_81, mean_40, mul_300, out_119, out_120], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf509, buf519, arg40_1, buf521, 1, 4096, grid=grid(1), stream=stream0) | |
del arg40_1 | |
# Source Nodes: [out_120], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf522 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf521, arg308_1, arg309_1, 1) | |
del arg308_1 | |
del arg309_1 | |
buf523 = buf522 | |
del buf522 | |
buf525 = buf498; del buf498 # reuse | |
# Source Nodes: [setitem_40, setitem_41, y_60], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf523, arg66_1, arg310_1, buf525, arg311_1, 4096, grid=grid(4096), stream=stream0) | |
del buf523 | |
buf526 = buf503; del buf503 # reuse | |
# Source Nodes: [y_60], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf525, arg310_1, buf526, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg310_1 | |
buf530 = buf499; del buf499 # reuse | |
# Source Nodes: [mask, y_60], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf526, arg454_1, arg67_1, buf530, 32, 208, grid=grid(32), stream=stream0) | |
buf531 = buf504; del buf504 # reuse | |
# Source Nodes: [y_60], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf530, arg311_1, buf531, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg311_1 | |
buf533 = buf521; del buf521 # reuse | |
# Source Nodes: [out_121, y_60], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf531, buf533, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_121], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf534 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf533, arg312_1, arg313_1, 13) | |
del arg312_1 | |
del arg313_1 | |
buf535 = buf534 | |
del buf534 | |
buf537 = reinterpret_tensor(buf533, (1, 1, 4096), (4096, 4096, 1), 0); del buf533 # reuse | |
# Source Nodes: [add_124, float_84, h_20, mean_41, mul_311, mul_312, mul_313, out_119, output_41, rsqrt_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf509, buf519, buf535, arg41_1, buf537, 1, 4096, grid=grid(1), stream=stream0) | |
del arg41_1 | |
# Source Nodes: [out_122], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf538 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg314_1, arg315_1, 1) | |
del arg314_1 | |
del arg315_1 | |
buf539 = buf538 | |
del buf538 | |
# Source Nodes: [out_123], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf540 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg316_1, arg317_1, 1) | |
del arg316_1 | |
del arg317_1 | |
buf541 = buf540 | |
del buf540 | |
buf542 = buf539; del buf539 # reuse | |
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf542, buf541, 11008, grid=grid(11008), stream=stream0) | |
del buf541 | |
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf543 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf542, arg318_1, arg319_1, 13) | |
del arg318_1 | |
del arg319_1 | |
del buf542 | |
buf544 = buf543 | |
del buf543 | |
buf546 = reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0); del buf537 # reuse | |
# Source Nodes: [float_85, h_20, mean_42, mul_315, out_119, out_125, out_126], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf509, buf519, buf535, buf544, arg42_1, buf546, 1, 4096, grid=grid(1), stream=stream0) | |
del arg42_1 | |
# Source Nodes: [out_126], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf547 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf546, arg320_1, arg321_1, 1) | |
del arg320_1 | |
del arg321_1 | |
buf548 = buf547 | |
del buf547 | |
buf550 = buf525; del buf525 # reuse | |
# Source Nodes: [setitem_42, setitem_43, y_63], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf548, arg66_1, arg322_1, buf550, arg323_1, 4096, grid=grid(4096), stream=stream0) | |
del buf548 | |
buf551 = buf530; del buf530 # reuse | |
# Source Nodes: [y_63], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf550, arg322_1, buf551, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg322_1 | |
buf555 = buf526; del buf526 # reuse | |
# Source Nodes: [mask, y_63], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf551, arg454_1, arg67_1, buf555, 32, 208, grid=grid(32), stream=stream0) | |
buf556 = buf531; del buf531 # reuse | |
# Source Nodes: [y_63], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf555, arg323_1, buf556, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg323_1 | |
buf558 = buf546; del buf546 # reuse | |
# Source Nodes: [out_127, y_63], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf556, buf558, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_127], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf559 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf558, arg324_1, arg325_1, 13) | |
del arg324_1 | |
del arg325_1 | |
buf560 = buf559 | |
del buf559 | |
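# Residual updates h_20/h_21.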
buf561 = buf509; del buf509 # reuse | |
buf563 = buf558; del buf558 # reuse | |
buf566 = buf511; del buf511 # reuse | |
# Source Nodes: [float_88, h_20, h_21, mean_43, mul_326, out_119, out_125, out_128, out_129], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf561, buf519, buf535, buf544, buf560, arg43_1, buf563, buf566, 1, 4096, grid=grid(1), stream=stream0) | |
del arg43_1 | |
del buf519 | |
del buf535 | |
del buf544 | |
del buf560 | |
# Source Nodes: [out_128], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf564 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf563, arg326_1, arg327_1, 1) | |
del arg326_1 | |
del arg327_1 | |
buf565 = buf564 | |
del buf564 | |
# Source Nodes: [out_129], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf567 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf566, arg328_1, arg329_1, 1) | |
del arg328_1 | |
del arg329_1 | |
buf568 = buf567 | |
del buf567 | |
buf569 = buf565; del buf565 # reuse | |
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf569, buf568, 11008, grid=grid(11008), stream=stream0) | |
del buf568 | |
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf570 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf569, arg330_1, arg331_1, 13) | |
del arg330_1 | |
del arg331_1 | |
del buf569 | |
buf571 = buf570 | |
del buf570 | |
buf573 = buf566; del buf566 # reuse | |
# Source Nodes: [float_89, mean_44, mul_330, out_131, out_132], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf561, buf571, arg44_1, buf573, 1, 4096, grid=grid(1), stream=stream0) | |
del arg44_1 | |
# Source Nodes: [out_132], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf574 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf573, arg332_1, arg333_1, 1) | |
del arg332_1 | |
del arg333_1 | |
buf575 = buf574 | |
del buf574 | |
buf577 = buf550; del buf550 # reuse | |
# Source Nodes: [setitem_44, setitem_45, y_66], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf575, arg66_1, arg334_1, buf577, arg335_1, 4096, grid=grid(4096), stream=stream0) | |
del buf575 | |
buf578 = buf555; del buf555 # reuse | |
# Source Nodes: [y_66], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf577, arg334_1, buf578, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg334_1 | |
buf582 = buf551; del buf551 # reuse | |
# Source Nodes: [mask, y_66], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf578, arg454_1, arg67_1, buf582, 32, 208, grid=grid(32), stream=stream0) | |
buf583 = buf556; del buf556 # reuse | |
# Source Nodes: [y_66], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf582, arg335_1, buf583, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg335_1 | |
buf585 = buf573; del buf573 # reuse | |
# Source Nodes: [out_133, y_66], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf583, buf585, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_133], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf586 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf585, arg336_1, arg337_1, 13) | |
del arg336_1 | |
del arg337_1 | |
buf587 = buf586 | |
del buf586 | |
buf589 = reinterpret_tensor(buf585, (1, 1, 4096), (4096, 4096, 1), 0); del buf585 # reuse | |
# Source Nodes: [add_136, float_92, h_22, mean_45, mul_341, mul_342, mul_343, out_131, output_45, rsqrt_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf561, buf571, buf587, arg45_1, buf589, 1, 4096, grid=grid(1), stream=stream0) | |
del arg45_1 | |
# Source Nodes: [out_134], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf590 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg338_1, arg339_1, 1) | |
del arg338_1 | |
del arg339_1 | |
buf591 = buf590 | |
del buf590 | |
# Source Nodes: [out_135], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf592 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg340_1, arg341_1, 1) | |
del arg340_1 | |
del arg341_1 | |
buf593 = buf592 | |
del buf592 | |
buf594 = buf591; del buf591 # reuse | |
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf594, buf593, 11008, grid=grid(11008), stream=stream0) | |
del buf593 | |
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf595 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf594, arg342_1, arg343_1, 13) | |
del arg342_1 | |
del arg343_1 | |
del buf594 | |
buf596 = buf595 | |
del buf595 | |
buf598 = reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0); del buf589 # reuse | |
# Source Nodes: [float_93, h_22, mean_46, mul_345, out_131, out_137, out_138], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf561, buf571, buf587, buf596, arg46_1, buf598, 1, 4096, grid=grid(1), stream=stream0) | |
del arg46_1 | |
# Source Nodes: [out_138], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf599 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf598, arg344_1, arg345_1, 1) | |
del arg344_1 | |
del arg345_1 | |
buf600 = buf599 | |
del buf599 | |
buf602 = buf577; del buf577 # reuse | |
# Source Nodes: [setitem_46, setitem_47, y_69], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf600, arg66_1, arg346_1, buf602, arg347_1, 4096, grid=grid(4096), stream=stream0) | |
del buf600 | |
buf603 = buf582; del buf582 # reuse | |
# Source Nodes: [y_69], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf602, arg346_1, buf603, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg346_1 | |
buf607 = buf578; del buf578 # reuse | |
# Source Nodes: [mask, y_69], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf603, arg454_1, arg67_1, buf607, 32, 208, grid=grid(32), stream=stream0) | |
buf608 = buf583; del buf583 # reuse | |
# Source Nodes: [y_69], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf607, arg347_1, buf608, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg347_1 | |
buf610 = buf598; del buf598 # reuse | |
# Source Nodes: [out_139, y_69], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf608, buf610, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_139], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf611 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf610, arg348_1, arg349_1, 13) | |
del arg348_1 | |
del arg349_1 | |
buf612 = buf611 | |
del buf611 | |
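# Residual updates h_22/h_23.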
buf613 = buf561; del buf561 # reuse | |
buf615 = buf610; del buf610 # reuse | |
buf618 = buf563; del buf563 # reuse | |
# Source Nodes: [float_96, h_22, h_23, mean_47, mul_356, out_131, out_137, out_140, out_141], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf613, buf571, buf587, buf596, buf612, arg47_1, buf615, buf618, 1, 4096, grid=grid(1), stream=stream0) | |
del arg47_1 | |
del buf571 | |
del buf587 | |
del buf596 | |
del buf612 | |
# Source Nodes: [out_140], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf616 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf615, arg350_1, arg351_1, 1) | |
del arg350_1 | |
del arg351_1 | |
buf617 = buf616 | |
del buf616 | |
# Source Nodes: [out_141], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf619 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf618, arg352_1, arg353_1, 1) | |
del arg352_1 | |
del arg353_1 | |
buf620 = buf619 | |
del buf619 | |
buf621 = buf617; del buf617 # reuse | |
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf621, buf620, 11008, grid=grid(11008), stream=stream0) | |
del buf620 | |
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf622 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf621, arg354_1, arg355_1, 13) | |
del arg354_1 | |
del arg355_1 | |
del buf621 | |
buf623 = buf622 | |
del buf622 | |
buf625 = buf618; del buf618 # reuse | |
# Source Nodes: [float_97, mean_48, mul_360, out_143, out_144], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf613, buf623, arg48_1, buf625, 1, 4096, grid=grid(1), stream=stream0) | |
del arg48_1 | |
# Source Nodes: [out_144], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf626 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf625, arg356_1, arg357_1, 1) | |
del arg356_1 | |
del arg357_1 | |
buf627 = buf626 | |
del buf626 | |
buf629 = buf602; del buf602 # reuse | |
# Source Nodes: [setitem_48, setitem_49, y_72], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf627, arg66_1, arg358_1, buf629, arg359_1, 4096, grid=grid(4096), stream=stream0) | |
del buf627 | |
buf630 = buf607; del buf607 # reuse | |
# Source Nodes: [y_72], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf629, arg358_1, buf630, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg358_1 | |
buf634 = buf603; del buf603 # reuse | |
# Source Nodes: [mask, y_72], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf630, arg454_1, arg67_1, buf634, 32, 208, grid=grid(32), stream=stream0) | |
buf635 = buf608; del buf608 # reuse | |
# Source Nodes: [y_72], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf634, arg359_1, buf635, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg359_1 | |
buf637 = buf625; del buf625 # reuse | |
# Source Nodes: [out_145, y_72], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf635, buf637, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_145], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf638 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf637, arg360_1, arg361_1, 13) | |
del arg360_1 | |
del arg361_1 | |
buf639 = buf638 | |
del buf638 | |
buf641 = reinterpret_tensor(buf637, (1, 1, 4096), (4096, 4096, 1), 0); del buf637 # reuse | |
# Source Nodes: [add_148, float_100, h_24, mean_49, mul_371, mul_372, mul_373, out_143, output_49, rsqrt_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf613, buf623, buf639, arg49_1, buf641, 1, 4096, grid=grid(1), stream=stream0) | |
del arg49_1 | |
# Source Nodes: [out_146], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf642 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg362_1, arg363_1, 1) | |
del arg362_1 | |
del arg363_1 | |
buf643 = buf642 | |
del buf642 | |
# Source Nodes: [out_147], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf644 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg364_1, arg365_1, 1) | |
del arg364_1 | |
del arg365_1 | |
buf645 = buf644 | |
del buf644 | |
buf646 = buf643; del buf643 # reuse | |
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf646, buf645, 11008, grid=grid(11008), stream=stream0) | |
del buf645 | |
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf647 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf646, arg366_1, arg367_1, 13) | |
del arg366_1 | |
del arg367_1 | |
del buf646 | |
buf648 = buf647 | |
del buf647 | |
buf650 = reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0); del buf641 # reuse | |
# Source Nodes: [float_101, h_24, mean_50, mul_375, out_143, out_149, out_150], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf613, buf623, buf639, buf648, arg50_1, buf650, 1, 4096, grid=grid(1), stream=stream0) | |
del arg50_1 | |
# Source Nodes: [out_150], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf651 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf650, arg368_1, arg369_1, 1) | |
del arg368_1 | |
del arg369_1 | |
buf652 = buf651 | |
del buf651 | |
buf654 = buf629; del buf629 # reuse | |
# Source Nodes: [setitem_50, setitem_51, y_75], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf652, arg66_1, arg370_1, buf654, arg371_1, 4096, grid=grid(4096), stream=stream0) | |
del buf652 | |
buf655 = buf634; del buf634 # reuse | |
# Source Nodes: [y_75], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf654, arg370_1, buf655, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg370_1 | |
buf659 = buf630; del buf630 # reuse | |
# Source Nodes: [mask, y_75], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf655, arg454_1, arg67_1, buf659, 32, 208, grid=grid(32), stream=stream0) | |
buf660 = buf635; del buf635 # reuse | |
# Source Nodes: [y_75], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf659, arg371_1, buf660, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg371_1 | |
buf662 = buf650; del buf650 # reuse | |
# Source Nodes: [out_151, y_75], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf660, buf662, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_151], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf663 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf662, arg372_1, arg373_1, 13) | |
del arg372_1 | |
del arg373_1 | |
buf664 = buf663 | |
del buf663 | |
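# Residual updates h_24/h_25.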
buf665 = buf613; del buf613 # reuse | |
buf667 = buf662; del buf662 # reuse | |
buf670 = buf615; del buf615 # reuse | |
# Source Nodes: [float_104, h_24, h_25, mean_51, mul_386, out_143, out_149, out_152, out_153], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf665, buf623, buf639, buf648, buf664, arg51_1, buf667, buf670, 1, 4096, grid=grid(1), stream=stream0) | |
del arg51_1 | |
del buf623 | |
del buf639 | |
del buf648 | |
del buf664 | |
# Source Nodes: [out_152], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf668 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf667, arg374_1, arg375_1, 1) | |
del arg374_1 | |
del arg375_1 | |
buf669 = buf668 | |
del buf668 | |
# Source Nodes: [out_153], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf671 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf670, arg376_1, arg377_1, 1) | |
del arg376_1 | |
del arg377_1 | |
buf672 = buf671 | |
del buf671 | |
buf673 = buf669; del buf669 # reuse | |
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf673, buf672, 11008, grid=grid(11008), stream=stream0) | |
del buf672 | |
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf674 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf673, arg378_1, arg379_1, 13) | |
del arg378_1 | |
del arg379_1 | |
del buf673 | |
buf675 = buf674 | |
del buf674 | |
buf677 = buf670; del buf670 # reuse | |
# Source Nodes: [float_105, mean_52, mul_390, out_155, out_156], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf665, buf675, arg52_1, buf677, 1, 4096, grid=grid(1), stream=stream0) | |
del arg52_1 | |
# Source Nodes: [out_156], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf678 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf677, arg380_1, arg381_1, 1) | |
del arg380_1 | |
del arg381_1 | |
buf679 = buf678 | |
del buf678 | |
buf681 = buf654; del buf654 # reuse | |
# Source Nodes: [setitem_52, setitem_53, y_78], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf679, arg66_1, arg382_1, buf681, arg383_1, 4096, grid=grid(4096), stream=stream0) | |
del buf679 | |
buf682 = buf659; del buf659 # reuse | |
# Source Nodes: [y_78], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf681, arg382_1, buf682, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg382_1 | |
buf686 = buf655; del buf655 # reuse | |
# Source Nodes: [mask, y_78], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf682, arg454_1, arg67_1, buf686, 32, 208, grid=grid(32), stream=stream0) | |
buf687 = buf660; del buf660 # reuse | |
# Source Nodes: [y_78], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf686, arg383_1, buf687, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg383_1 | |
buf689 = buf677; del buf677 # reuse | |
# Source Nodes: [out_157, y_78], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf687, buf689, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_157], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf690 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf689, arg384_1, arg385_1, 13) | |
del arg384_1 | |
del arg385_1 | |
buf691 = buf690 | |
del buf690 | |
buf693 = reinterpret_tensor(buf689, (1, 1, 4096), (4096, 4096, 1), 0); del buf689 # reuse | |
# Source Nodes: [add_160, float_108, h_26, mean_53, mul_401, mul_402, mul_403, out_155, output_53, rsqrt_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf665, buf675, buf691, arg53_1, buf693, 1, 4096, grid=grid(1), stream=stream0) | |
del arg53_1 | |
# Source Nodes: [out_158], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf694 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg386_1, arg387_1, 1) | |
del arg386_1 | |
del arg387_1 | |
buf695 = buf694 | |
del buf694 | |
# Source Nodes: [out_159], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf696 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg388_1, arg389_1, 1) | |
del arg388_1 | |
del arg389_1 | |
buf697 = buf696 | |
del buf696 | |
buf698 = buf695; del buf695 # reuse | |
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf698, buf697, 11008, grid=grid(11008), stream=stream0) | |
del buf697 | |
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf699 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf698, arg390_1, arg391_1, 13) | |
del arg390_1 | |
del arg391_1 | |
del buf698 | |
buf700 = buf699 | |
del buf699 | |
buf702 = reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0); del buf693 # reuse | |
# Source Nodes: [float_109, h_26, mean_54, mul_405, out_155, out_161, out_162], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf665, buf675, buf691, buf700, arg54_1, buf702, 1, 4096, grid=grid(1), stream=stream0) | |
del arg54_1 | |
# Source Nodes: [out_162], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf703 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf702, arg392_1, arg393_1, 1) | |
del arg392_1 | |
del arg393_1 | |
buf704 = buf703 | |
del buf703 | |
buf706 = buf681; del buf681 # reuse | |
# Source Nodes: [setitem_54, setitem_55, y_81], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf704, arg66_1, arg394_1, buf706, arg395_1, 4096, grid=grid(4096), stream=stream0) | |
del buf704 | |
buf707 = buf686; del buf686 # reuse | |
# Source Nodes: [y_81], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf706, arg394_1, buf707, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg394_1 | |
buf711 = buf682; del buf682 # reuse | |
# Source Nodes: [mask, y_81], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf707, arg454_1, arg67_1, buf711, 32, 208, grid=grid(32), stream=stream0) | |
buf712 = buf687; del buf687 # reuse | |
# Source Nodes: [y_81], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf711, arg395_1, buf712, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg395_1 | |
buf714 = buf702; del buf702 # reuse | |
# Source Nodes: [out_163, y_81], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf712, buf714, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_163], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf715 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf714, arg396_1, arg397_1, 13) | |
del arg396_1 | |
del arg397_1 | |
buf716 = buf715 | |
del buf715 | |
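# Residual updates h_26/h_27.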
buf717 = buf665; del buf665 # reuse | |
buf719 = buf714; del buf714 # reuse | |
buf722 = buf667; del buf667 # reuse | |
# Source Nodes: [float_112, h_26, h_27, mean_55, mul_416, out_155, out_161, out_164, out_165], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf717, buf675, buf691, buf700, buf716, arg55_1, buf719, buf722, 1, 4096, grid=grid(1), stream=stream0) | |
del arg55_1 | |
del buf675 | |
del buf691 | |
del buf700 | |
del buf716 | |
# Source Nodes: [out_164], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf720 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf719, arg398_1, arg399_1, 1) | |
del arg398_1 | |
del arg399_1 | |
buf721 = buf720 | |
del buf720 | |
# Source Nodes: [out_165], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf723 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf722, arg400_1, arg401_1, 1) | |
del arg400_1 | |
del arg401_1 | |
buf724 = buf723 | |
del buf723 | |
buf725 = buf721; del buf721 # reuse | |
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf725, buf724, 11008, grid=grid(11008), stream=stream0) | |
del buf724 | |
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf726 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf725, arg402_1, arg403_1, 13) | |
del arg402_1 | |
del arg403_1 | |
del buf725 | |
buf727 = buf726 | |
del buf726 | |
buf729 = buf722; del buf722 # reuse | |
# Source Nodes: [float_113, mean_56, mul_420, out_167, out_168], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf717, buf727, arg56_1, buf729, 1, 4096, grid=grid(1), stream=stream0) | |
del arg56_1 | |
# Source Nodes: [out_168], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf730 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf729, arg404_1, arg405_1, 1) | |
del arg404_1 | |
del arg405_1 | |
buf731 = buf730 | |
del buf730 | |
buf733 = buf706; del buf706 # reuse | |
# Source Nodes: [setitem_56, setitem_57, y_84], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf731, arg66_1, arg406_1, buf733, arg407_1, 4096, grid=grid(4096), stream=stream0) | |
del buf731 | |
buf734 = buf711; del buf711 # reuse | |
# Source Nodes: [y_84], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf733, arg406_1, buf734, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg406_1 | |
buf738 = buf707; del buf707 # reuse | |
# Source Nodes: [mask, y_84], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf734, arg454_1, arg67_1, buf738, 32, 208, grid=grid(32), stream=stream0) | |
buf739 = buf712; del buf712 # reuse | |
# Source Nodes: [y_84], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf738, arg407_1, buf739, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg407_1 | |
buf741 = buf729; del buf729 # reuse | |
# Source Nodes: [out_169, y_84], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf739, buf741, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_169], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf742 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf741, arg408_1, arg409_1, 13) | |
del arg408_1 | |
del arg409_1 | |
buf743 = buf742 | |
del buf742 | |
buf745 = reinterpret_tensor(buf741, (1, 1, 4096), (4096, 4096, 1), 0); del buf741 # reuse | |
# Source Nodes: [add_172, float_116, h_28, mean_57, mul_431, mul_432, mul_433, out_167, output_57, rsqrt_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf717, buf727, buf743, arg57_1, buf745, 1, 4096, grid=grid(1), stream=stream0) | |
del arg57_1 | |
# Source Nodes: [out_170], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf746 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg410_1, arg411_1, 1) | |
del arg410_1 | |
del arg411_1 | |
buf747 = buf746 | |
del buf746 | |
# Source Nodes: [out_171], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf748 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg412_1, arg413_1, 1) | |
del arg412_1 | |
del arg413_1 | |
buf749 = buf748 | |
del buf748 | |
buf750 = buf747; del buf747 # reuse | |
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf750, buf749, 11008, grid=grid(11008), stream=stream0) | |
del buf749 | |
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf751 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf750, arg414_1, arg415_1, 13) | |
del arg414_1 | |
del arg415_1 | |
del buf750 | |
buf752 = buf751 | |
del buf751 | |
buf754 = reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0); del buf745 # reuse | |
# Source Nodes: [float_117, h_28, mean_58, mul_435, out_167, out_173, out_174], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf717, buf727, buf743, buf752, arg58_1, buf754, 1, 4096, grid=grid(1), stream=stream0) | |
del arg58_1 | |
# Source Nodes: [out_174], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf755 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf754, arg416_1, arg417_1, 1) | |
del arg416_1 | |
del arg417_1 | |
buf756 = buf755 | |
del buf755 | |
buf758 = buf733; del buf733 # reuse | |
# Source Nodes: [setitem_58, setitem_59, y_87], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf756, arg66_1, arg418_1, buf758, arg419_1, 4096, grid=grid(4096), stream=stream0) | |
del buf756 | |
buf759 = buf738; del buf738 # reuse | |
# Source Nodes: [y_87], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf758, arg418_1, buf759, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg418_1 | |
buf763 = buf734; del buf734 # reuse | |
# Source Nodes: [mask, y_87], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf759, arg454_1, arg67_1, buf763, 32, 208, grid=grid(32), stream=stream0) | |
buf764 = buf739; del buf739 # reuse | |
# Source Nodes: [y_87], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf763, arg419_1, buf764, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg419_1 | |
buf766 = buf754; del buf754 # reuse | |
# Source Nodes: [out_175, y_87], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf764, buf766, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_175], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf767 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf766, arg420_1, arg421_1, 13) | |
del arg420_1 | |
del arg421_1 | |
buf768 = buf767 | |
del buf767 | |
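# Residual updates h_28/h_29.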
buf769 = buf717; del buf717 # reuse | |
buf771 = buf766; del buf766 # reuse | |
buf774 = buf719; del buf719 # reuse | |
# Source Nodes: [float_120, h_28, h_29, mean_59, mul_446, out_167, out_173, out_176, out_177], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf769, buf727, buf743, buf752, buf768, arg59_1, buf771, buf774, 1, 4096, grid=grid(1), stream=stream0) | |
del arg59_1 | |
del buf727 | |
del buf743 | |
del buf752 | |
del buf768 | |
# Source Nodes: [out_176], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf772 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf771, arg422_1, arg423_1, 1) | |
del arg422_1 | |
del arg423_1 | |
buf773 = buf772 | |
del buf772 | |
# Source Nodes: [out_177], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf775 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf774, arg424_1, arg425_1, 1) | |
del arg424_1 | |
del arg425_1 | |
buf776 = buf775 | |
del buf775 | |
buf777 = buf773; del buf773 # reuse | |
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf777, buf776, 11008, grid=grid(11008), stream=stream0) | |
del buf776 | |
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf778 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf777, arg426_1, arg427_1, 13) | |
del arg426_1 | |
del arg427_1 | |
del buf777 | |
buf779 = buf778 | |
del buf778 | |
buf781 = buf774; del buf774 # reuse | |
# Source Nodes: [float_121, mean_60, mul_450, out_179, out_180], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf769, buf779, arg60_1, buf781, 1, 4096, grid=grid(1), stream=stream0) | |
del arg60_1 | |
# Source Nodes: [out_180], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf782 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf781, arg428_1, arg429_1, 1) | |
del arg428_1 | |
del arg429_1 | |
buf783 = buf782 | |
del buf782 | |
buf785 = buf758; del buf758 # reuse | |
# Source Nodes: [setitem_60, setitem_61, y_90], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf783, arg66_1, arg430_1, buf785, arg431_1, 4096, grid=grid(4096), stream=stream0) | |
del buf783 | |
buf786 = buf763; del buf763 # reuse | |
# Source Nodes: [y_90], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf785, arg430_1, buf786, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg430_1 | |
buf790 = buf759; del buf759 # reuse | |
# Source Nodes: [mask, y_90], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf786, arg454_1, arg67_1, buf790, 32, 208, grid=grid(32), stream=stream0) | |
buf791 = buf764; del buf764 # reuse | |
# Source Nodes: [y_90], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf790, arg431_1, buf791, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg431_1 | |
buf793 = buf781; del buf781 # reuse | |
# Source Nodes: [out_181, y_90], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf791, buf793, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_181], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf794 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf793, arg432_1, arg433_1, 13) | |
del arg432_1 | |
del arg433_1 | |
buf795 = buf794 | |
del buf794 | |
buf797 = reinterpret_tensor(buf793, (1, 1, 4096), (4096, 4096, 1), 0); del buf793 # reuse | |
# Source Nodes: [add_184, float_124, h_30, mean_61, mul_461, mul_462, mul_463, out_179, output_61, rsqrt_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf769, buf779, buf795, arg61_1, buf797, 1, 4096, grid=grid(1), stream=stream0) | |
del arg61_1 | |
# Source Nodes: [out_182], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf798 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg434_1, arg435_1, 1) | |
del arg434_1 | |
del arg435_1 | |
buf799 = buf798 | |
del buf798 | |
# Source Nodes: [out_183], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf800 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg436_1, arg437_1, 1) | |
del arg436_1 | |
del arg437_1 | |
buf801 = buf800 | |
del buf800 | |
buf802 = buf799; del buf799 # reuse | |
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf802, buf801, 11008, grid=grid(11008), stream=stream0) | |
del buf801 | |
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf803 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf802, arg438_1, arg439_1, 13) | |
del arg438_1 | |
del arg439_1 | |
del buf802 | |
buf804 = buf803 | |
del buf803 | |
buf806 = reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0); del buf797 # reuse | |
# Source Nodes: [float_125, h_30, mean_62, mul_465, out_179, out_185, out_186], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf769, buf779, buf795, buf804, arg62_1, buf806, 1, 4096, grid=grid(1), stream=stream0) | |
del arg62_1 | |
# Source Nodes: [out_186], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf807 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf806, arg440_1, arg441_1, 1) | |
del arg440_1 | |
del arg441_1 | |
buf808 = buf807 | |
del buf807 | |
buf810 = buf785; del buf785 # reuse | |
# Source Nodes: [setitem_62, setitem_63, y_93], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf808, arg66_1, arg442_1, buf810, arg443_1, 4096, grid=grid(4096), stream=stream0) | |
del arg66_1 | |
del buf808 | |
buf811 = buf790; del buf790 # reuse | |
# Source Nodes: [y_93], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf810, arg442_1, buf811, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg442_1 | |
del buf810 | |
buf815 = buf786; del buf786 # reuse | |
# Source Nodes: [mask, y_93], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf811, arg454_1, arg67_1, buf815, 32, 208, grid=grid(32), stream=stream0) | |
del arg454_1 | |
del arg67_1 | |
del buf811 | |
buf816 = buf791; del buf791 # reuse | |
# Source Nodes: [y_93], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf815, arg443_1, buf816, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg443_1 | |
del buf815 | |
buf818 = buf806; del buf806 # reuse | |
# Source Nodes: [out_187, y_93], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf816, buf818, 4096, 2, grid=grid(4096), stream=stream0) | |
del buf816 | |
# Source Nodes: [out_187], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf819 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf818, arg444_1, arg445_1, 13) | |
del arg444_1 | |
del arg445_1 | |
buf820 = buf819 | |
del buf819 | |
buf821 = buf769; del buf769 # reuse | |
buf823 = buf818; del buf818 # reuse | |
buf826 = buf771; del buf771 # reuse | |
# Source Nodes: [float_128, h_30, h_31, mean_63, mul_476, out_179, out_185, out_188, out_189], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf821, buf779, buf795, buf804, buf820, arg63_1, buf823, buf826, 1, 4096, grid=grid(1), stream=stream0) | |
del arg63_1 | |
del buf779 | |
del buf795 | |
del buf804 | |
del buf820 | |
# Source Nodes: [out_188], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf824 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf823, arg446_1, arg447_1, 1) | |
del arg446_1 | |
del arg447_1 | |
del buf823 | |
buf825 = buf824 | |
del buf824 | |
# Source Nodes: [out_189], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf827 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf826, arg448_1, arg449_1, 1) | |
del arg448_1 | |
del arg449_1 | |
buf828 = buf827 | |
del buf827 | |
buf829 = buf825; del buf825 # reuse | |
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf829, buf828, 11008, grid=grid(11008), stream=stream0) | |
del buf828 | |
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf830 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf829, arg450_1, arg451_1, 13) | |
del arg450_1 | |
del arg451_1 | |
del buf829 | |
buf831 = buf830 | |
del buf830 | |
buf833 = buf826; del buf826 # reuse | |
# Source Nodes: [float_129, mean_64, mul_480, out_191, out_192], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf821, buf831, arg64_1, buf833, 1, 4096, grid=grid(1), stream=stream0) | |
del arg64_1 | |
del buf821 | |
del buf831 | |
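# LM head: the final RMSNorm output (scaled by arg64_1) is projected through the
# FP6-packed output embedding (arg452_1: (32000, 768) int32, arg453_1: fp16 scales)
# to produce the 32000-entry logits.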
# Source Nodes: [out_192], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf834 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf833, arg452_1, arg453_1, 1) | |
del arg452_1 | |
del arg453_1 | |
del buf833 | |
buf835 = buf834 | |
del buf834 | |
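# Sampling epilogue: the fused kernels below implement gpt-fast-style sampling on the
# 32000 logits -- divide by temperature, keep the top 200, softmax, then draw a token
# via the exponential/argmax trick. A plain-PyTorch sketch follows call() below.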
buf836 = reinterpret_tensor(buf835, (32000, ), (1, ), 0); del buf835 # reuse | |
# Source Nodes: [logits_1], Original ATen: [aten.div] | |
triton_poi_fused_div_16.run(buf836, 32000, grid=grid(32000), stream=stream0) | |
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk] | |
buf837 = aten.topk.default(buf836, 200) | |
buf838 = buf837[0] | |
del buf837 | |
buf840 = empty_strided_cuda((1, 4), (4, 1), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf836, buf838, buf840, 4, 8000, grid=grid(4), stream=stream0) | |
buf841 = empty_strided_cuda((1, ), (1, ), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf840, buf841, 1, 4, grid=grid(1), stream=stream0) | |
buf842 = buf840; del buf840 # reuse | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf836, buf838, buf841, buf842, 4, 8000, grid=grid(4), stream=stream0) | |
buf843 = empty_strided_cuda((1, ), (1, ), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf842, buf843, 1, 4, grid=grid(1), stream=stream0) | |
del buf842 | |
buf845 = empty_strided_cuda((1, ), (1, ), torch.int64) | |
# Source Nodes: [], Original ATen: [] | |
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf845) | |
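# Draws a random int64, presumably to seed the Philox-based exponential draw in the
# fused sampling kernel below.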
buf844 = buf836; del buf836 # reuse | |
buf848 = empty_strided_cuda((1, ), (1, ), torch.int32) | |
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf844, buf838, buf841, buf843, buf845, buf848, 0, 1, 32000, grid=grid(1), stream=stream0) | |
del buf838 | |
del buf841 | |
del buf843 | |
del buf845 | |
return (buf848, buf844, ) | |
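# For reference only: a minimal eager-mode sketch of what the fused sampling kernels in
# call() compute, modeled on gpt-fast's logits_to_probs()/sample(). The function name and
# the temperature value are assumptions for illustration; the compiled graph bakes its own
# constants into the kernels.
def _reference_sample(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 200) -> torch.Tensor:
    logits = logits / max(temperature, 1e-5)                            # triton_poi_fused_div_16
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))              # aten.topk
    logits = torch.where(logits < v[..., -1:], -float("inf"), logits)   # lt / scalar_tensor / where
    probs = torch.softmax(logits, dim=-1)                               # fused _softmax passes
    q = torch.empty_like(probs).exponential_(1)                         # Exp(1) noise (uses the random seed)
    return torch.argmax(probs / q, dim=-1).to(torch.int32)              # idx_next via the argmax trick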
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
arg0_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg1_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg2_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg3_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg5_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg6_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg7_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg8_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg9_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg10_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg11_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg12_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg13_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg14_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg15_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg16_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg17_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg18_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg19_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg21_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg22_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg23_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg24_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg25_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg26_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg27_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg28_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg29_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg31_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg32_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg33_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg34_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg35_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg36_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg37_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg38_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg39_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg40_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg41_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg42_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg43_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg44_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg45_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg47_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg48_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg49_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg50_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg51_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg52_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg53_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg54_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg55_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg57_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg58_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg59_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg60_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg61_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg62_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg63_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg64_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg65_1 = rand_strided((32000, 4096), (4096, 1), device='cuda:0', dtype=torch.float16) | |
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float16) | |
arg67_1 = rand_strided((208, 208), (208, 1), device='cuda:0', dtype=torch.bool) | |
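# Per-transformer-block weight layout (the 12-argument pattern repeated 32x below):
# FP6-packed wqkv (12288, 768) int32 + (12288,) fp16 scales, fp16 K/V caches
# (1, 32, 208, 128), packed wo (4096, 768) + scales, packed w1/w3 (11008, 768) + scales,
# and packed w2 (4096, 2064) + scales. arg65_1 is the fp16 token embedding, arg66_1 the
# rotary table, arg67_1 the causal mask.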
arg68_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg69_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg70_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg71_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg72_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg73_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg74_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg75_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg76_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg77_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg78_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg79_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg80_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg81_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg82_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg83_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg84_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg85_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg86_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg87_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg88_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg89_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg90_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg91_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg92_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg93_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg94_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg95_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg96_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg97_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg98_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg99_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg100_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg101_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg102_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg103_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg104_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg105_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg106_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg107_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg108_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg109_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg110_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg111_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg112_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg113_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg114_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg115_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg116_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg117_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg118_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg119_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg120_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg121_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg122_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg123_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg124_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg125_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg126_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg127_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg128_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg129_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg130_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg131_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg132_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg133_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg134_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg135_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg136_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg137_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg138_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg139_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg140_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg141_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg142_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg143_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg144_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg145_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg146_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg147_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg148_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg149_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg150_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg151_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg152_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg153_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg154_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg155_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg156_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg157_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg158_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg159_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg160_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg161_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg162_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg163_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg164_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg165_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg166_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg167_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg168_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg169_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg170_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg171_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg172_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg173_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg174_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg175_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg176_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg177_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg178_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg179_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg180_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg181_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg182_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg183_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg184_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg185_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg186_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg187_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg188_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg189_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg190_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg191_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg192_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg193_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg194_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg195_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg196_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg197_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg198_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg199_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg200_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg201_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg202_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg203_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg204_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg205_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg206_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg207_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg208_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg209_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg210_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg211_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg212_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg213_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg214_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg215_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg216_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg217_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg218_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg219_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg220_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg221_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg222_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg223_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg224_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg225_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg226_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg227_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg228_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg229_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg230_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg231_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg232_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg233_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg234_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg235_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg236_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg237_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg238_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg239_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg240_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg241_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg242_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg243_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg244_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg245_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg246_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg247_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg248_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg249_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg250_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg251_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg252_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg253_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg254_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg255_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg256_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg257_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg258_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg259_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg260_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg261_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg262_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg263_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg264_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg265_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg266_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg267_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg268_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg269_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg270_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg271_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg272_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg273_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg274_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg275_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg276_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg277_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg278_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg279_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg280_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg281_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg282_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg283_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg284_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg285_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg286_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg287_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg288_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg289_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg290_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg291_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg292_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg293_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg294_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg295_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg296_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg297_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg298_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg299_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg300_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg301_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg302_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg303_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg304_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg305_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg306_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg307_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg308_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg309_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg310_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg311_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg312_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg313_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg314_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg315_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg316_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg317_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg318_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg319_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg320_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg321_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg322_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg323_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg324_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg325_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg326_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg327_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg328_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg329_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg330_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg331_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg332_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg333_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg334_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg335_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg336_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg337_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg338_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg339_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg340_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg341_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg342_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg343_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg344_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg345_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg346_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg347_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg348_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg349_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg350_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg351_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg352_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg353_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg354_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg355_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg356_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg357_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg358_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg359_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg360_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg361_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg362_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg363_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg364_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg365_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg366_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg367_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg368_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg369_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg370_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg371_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg372_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg373_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg374_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg375_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg376_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg377_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg378_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg379_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg380_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg381_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg382_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg383_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg384_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg385_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg386_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg387_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg388_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg389_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg390_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg391_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg392_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg393_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg394_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg395_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg396_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg397_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg398_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg399_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg400_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg401_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg402_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg403_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg404_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg405_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg406_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg407_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg408_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg409_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg410_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg411_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg412_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg413_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg414_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg415_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg416_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg417_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg418_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg419_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg420_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg421_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg422_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg423_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg424_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg425_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg426_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg427_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg428_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg429_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg430_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg431_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg432_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg433_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg434_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg435_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg436_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg437_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg438_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg439_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg440_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg441_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg442_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg443_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg444_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg445_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg446_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg447_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg448_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg449_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg450_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg451_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg452_1 = rand_strided((32000, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg453_1 = rand_strided((32000, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg454_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int32) | |
arg455_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32) | |
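# arg454_1/arg455_1 are presumably the decode inputs: input_pos with shape (1,) and the
# current token id with shape (1, 1).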
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__": | |
from torch._inductor.wrapper_benchmark import compiled_module_main | |
compiled_module_main('None', benchmark_compiled_module) |
# NOTE: second compiled module below; it was generated with torch._inductor.config.triton.cudagraph_trees = False | |
# AOT ID: ['0_inference'] | |
from ctypes import c_void_p, c_long | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor | |
async_compile = AsyncCompile() | |
# kernel path: /tmp/torchinductor_ubuntu/35/c35n52yrfyq3r2uvr6syoon44qkntahvhxuhmylcpsxnin76axfd.py | |
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_1 => convert_element_type | |
# mean => mean | |
# mul => mul | |
# out => fp16act_fp6weight_linear | |
# x => embedding | |
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0 = async_compile.triton('triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp0 = tl.load(in_ptr0 + (0)) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp3 = tmp1 + tmp2 | |
tmp4 = tmp1 < 0 | |
tmp5 = tl.where(tmp4, tmp3, tmp1) | |
tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000") | |
tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp8 = tmp7.to(tl.float32) | |
tmp9 = tmp8 * tmp8 | |
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK]) | |
tmp12 = _tmp11 + tmp10 | |
_tmp11 = tl.where(rmask, tmp12, _tmp11) | |
tmp11 = tl.sum(_tmp11, 1)[:, None] | |
tmp13 = tl.load(in_ptr0 + (0)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK]) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp29 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp16 = tmp14 + tmp15 | |
tmp17 = tmp14 < 0 | |
tmp18 = tl.where(tmp17, tmp16, tmp14) | |
tl.device_assert((0 <= tmp18) & (tmp18 < 32000), "index out of bounds: 0 <= tmp18 < 32000") | |
tmp20 = tl.load(in_ptr1 + (r0 + (4096*tmp18)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp21 = tmp20.to(tl.float32) | |
tmp22 = 4096.0 | |
tmp23 = tmp11 / tmp22 | |
tmp24 = 1e-05 | |
tmp25 = tmp23 + tmp24 | |
tmp26 = libdevice.rsqrt(tmp25) | |
tmp27 = tmp21 * tmp26 | |
tmp28 = tmp27.to(tl.float32) | |
tmp30 = tmp28 * tmp29 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp30, rmask) | |
''', device_str='cuda') | |
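# ---------------------------------------------------------------------------
# Hedged eager-mode sketch of the fusion above: an embedding lookup followed
# by RMSNorm (mean of squares -> rsqrt -> scale by the norm weight), with the
# reduction done in fp32 and the result cast back to fp16 before the weight
# multiply, as the kernel does. Names (`tok_ids`, `emb_weight`,
# `norm_weight`) are illustrative, not taken from the compiled graph.
# ---------------------------------------------------------------------------
import torch

def rmsnorm_embedding_sketch(tok_ids, emb_weight, norm_weight, eps=1e-05):
    x = emb_weight[tok_ids]                                      # aten.embedding (fp16)
    xf = x.float()                                               # aten._to_copy
    scale = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)  # mean + rsqrt
    return (xf * scale).to(x.dtype) * norm_weight                # aten.mul, back in fp16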
import triton | |
import triton.language as tl | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
# kernel path: /tmp/torchinductor_ubuntu/ix/cix3cikar5jbote3hymtcnf5jtrl75mvovejwvdsj67ds2nqaynm.py | |
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put] | |
# setitem => index_put | |
# setitem_1 => index_put_1 | |
# y => convert_element_type_6 | |
triton_poi_fused_bmm_index_put_1 = async_compile.triton('triton_poi_fused_bmm_index_put_1', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[4096], | |
filename=__file__, | |
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_bmm_index_put_1', 'mutated_arg_names': ['out_ptr0', 'out_ptr2'], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_bmm_index_put_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex % 128 | |
x1 = (xindex // 128) | |
x2 = xindex | |
tmp0 = tl.load(in_ptr0 + (0)) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK]) | |
tmp71 = tl.load(in_ptr1 + (8192 + x2), None).to(tl.float32) | |
tmp2 = tl.full([XBLOCK], 208, tl.int32) | |
tmp3 = tmp1 + tmp2 | |
tmp4 = tmp1 < 0 | |
tmp5 = tl.where(tmp4, tmp3, tmp1) | |
tl.device_assert((0 <= tmp5) & (tmp5 < 208), "index out of bounds: 0 <= tmp5 < 208") | |
tmp7 = x0 % 2 | |
tmp8 = tl.full([1], 0, tl.int64) | |
tmp9 = tmp7 >= tmp8 | |
tmp10 = tl.full([1], 1, tl.int64) | |
tmp11 = tmp7 < tmp10 | |
tmp12 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp13 = tmp12.to(tl.float32) | |
tmp14 = tl.full([XBLOCK], 2048, tl.int32) | |
tmp15 = tmp1 + tmp14 | |
tmp16 = tl.where(tmp4, tmp15, tmp1) | |
tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp11), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048") | |
tmp18 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp19 = tmp18.to(tl.float32) | |
tmp20 = tmp13 * tmp19 | |
tmp21 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp24 = tmp23.to(tl.float32) | |
tmp25 = tmp22 * tmp24 | |
tmp26 = tmp20 - tmp25 | |
tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) | |
tmp28 = tl.where(tmp11, tmp26, tmp27) | |
tmp29 = tmp7 >= tmp10 | |
tmp30 = tl.full([1], 2, tl.int64) | |
tmp31 = tmp7 < tmp30 | |
tmp32 = tl.load(in_ptr1 + (4097 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp33 = tmp32.to(tl.float32) | |
tl.device_assert(((0 <= tl.broadcast_to(tmp16, [XBLOCK])) & (tl.broadcast_to(tmp16, [XBLOCK]) < 2048)) | ~(tmp29), "index out of bounds: 0 <= tl.broadcast_to(tmp16, [XBLOCK]) < 2048") | |
tmp35 = tl.load(in_ptr2 + ((2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp36 = tmp35.to(tl.float32) | |
tmp37 = tmp33 * tmp36 | |
tmp38 = tl.load(in_ptr1 + (4096 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp39 = tmp38.to(tl.float32) | |
tmp40 = tl.load(in_ptr2 + (1 + (2*(x0 // 2)) + (128*tmp16)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp41 = tmp40.to(tl.float32) | |
tmp42 = tmp39 * tmp41 | |
tmp43 = tmp37 + tmp42 | |
tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype) | |
tmp45 = tl.where(tmp29, tmp43, tmp44) | |
tmp46 = tl.where(tmp11, tmp28, tmp45) | |
tmp47 = tmp46.to(tl.float32) | |
tmp48 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp49 = tmp48.to(tl.float32) | |
tmp50 = tmp49 * tmp19 | |
tmp51 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp52 = tmp51.to(tl.float32) | |
tmp53 = tmp52 * tmp24 | |
tmp54 = tmp50 - tmp53 | |
tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype) | |
tmp56 = tl.where(tmp11, tmp54, tmp55) | |
tmp57 = tl.load(in_ptr1 + (1 + (2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp58 = tmp57.to(tl.float32) | |
tmp59 = tmp58 * tmp36 | |
tmp60 = tl.load(in_ptr1 + ((2*(x0 // 2)) + (128*x1)), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp61 = tmp60.to(tl.float32) | |
tmp62 = tmp61 * tmp41 | |
tmp63 = tmp59 + tmp62 | |
tmp64 = tl.full(tmp63.shape, 0.0, tmp63.dtype) | |
tmp65 = tl.where(tmp29, tmp63, tmp64) | |
tmp66 = tl.where(tmp11, tmp56, tmp65) | |
tmp67 = tmp66.to(tl.float32) | |
tmp68 = 0.29730177875068026 | |
tmp69 = tmp67 * tmp68 | |
tmp70 = tmp69.to(tl.float32) | |
tl.store(out_ptr0 + (x0 + (128*tmp5) + (26624*x1)), tmp47, None) | |
tl.store(out_ptr1 + (x2), tmp70, None) | |
tl.store(out_ptr2 + (x0 + (128*tmp5) + (26624*x1)), tmp71, None) | |
''', device_str='cuda') | |
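# ---------------------------------------------------------------------------
# Hedged sketch of the rotary embedding the kernel above applies to the
# interleaved (even, odd) lanes: the even-lane branch computes x1*cos - x2*sin
# and the odd-lane branch x2*cos + x1*sin, matching the two tl.where paths.
# The same kernel also writes the rotated k and the raw v into the 208-slot
# KV cache (the two index_put stores). The constant 0.29730177875068026 is
# 128**-0.25: both q and k carry head_dim**-0.25 so their dot product ends up
# with the usual 1/sqrt(head_dim). Tensor names are illustrative.
# ---------------------------------------------------------------------------
import torch

def rope_interleaved_sketch(x, cos, sin):
    # x: (..., head_dim) with rotary pairs interleaved; cos/sin: (..., head_dim // 2)
    x1, x2 = x[..., 0::2].float(), x[..., 1::2].float()
    out = torch.empty_like(x, dtype=torch.float32)
    out[..., 0::2] = x1 * cos - x2 * sin   # even lanes (tmp26 branch)
    out[..., 1::2] = x2 * cos + x1 * sin   # odd lanes (tmp43 branch)
    return out.to(x.dtype)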
# kernel path: /tmp/torchinductor_ubuntu/tu/ctuxbryvllsmitwtw4gk35f6e4uug6renb6km4llzq2czdpete6i.py | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
# y => mul_13, sum_1 | |
triton_red_fused_bmm_2 = async_compile.triton('triton_red_fused_bmm_2', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[8192, 128], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused_bmm_2(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 6656 | |
rnumel = 128 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x1 = (xindex // 208) | |
x3 = xindex | |
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r2 = rindex | |
tmp0 = tl.load(in_ptr0 + (r2 + (128*x1)), rmask & xmask, eviction_policy='evict_last', other=0.0) | |
tmp1 = tl.load(in_ptr1 + (r2 + (128*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = 0.29730177875068026 | |
tmp3 = tmp1 * tmp2 | |
tmp4 = tmp3.to(tl.float32) | |
tmp5 = tmp0 * tmp4 | |
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK]) | |
tmp8 = _tmp7 + tmp6 | |
_tmp7 = tl.where(rmask & xmask, tmp8, _tmp7) | |
tmp7 = tl.sum(_tmp7, 1)[:, None] | |
tl.store(out_ptr0 + (x3), tmp7, xmask) | |
''', device_str='cuda') | |
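# ---------------------------------------------------------------------------
# Hedged sketch of the decode-step score computation above: one query token
# against a 208-entry KV cache for 32 heads of head_dim 128 (xnumel = 6656 =
# 32 * 208, rnumel = 128), written as a batched matmul rather than the
# generated per-row reduction. Both operands carry the 128**-0.25 factor, so
# the product is scaled by 1/sqrt(128).
# ---------------------------------------------------------------------------
import torch

def decode_scores_sketch(q, k_cache):
    # q: (32, 1, 128), k_cache: (32, 208, 128) -> scores: (32, 1, 208)
    scale = 128 ** -0.25
    return (q.float() * scale) @ (k_cache.float() * scale).transpose(-1, -2)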
# kernel path: /tmp/torchinductor_ubuntu/f6/cf6hoyusglegqlcllwth4hbv7jzk7xu6ol3qg56cytjysskalpwn.py | |
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
# mask => index | |
# y => add_3, amax, convert_element_type_11, convert_element_type_9, exp, full_default, full_default_1, logical_not, sub_2, sum_2, where | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3 = async_compile.triton('triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[32, 256], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*i32', 2: '*i1', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 32 | |
rnumel = 208 | |
RBLOCK: tl.constexpr = 256 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r1 = rindex | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (208*x0)), rmask & xmask, other=0.0) | |
tmp2 = tl.load(in_ptr1 + (0)) | |
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK]) | |
tmp1 = tmp0.to(tl.float32) | |
tmp4 = tl.full([XBLOCK, RBLOCK], 208, tl.int32) | |
tmp5 = tmp3 + tmp4 | |
tmp6 = tmp3 < 0 | |
tmp7 = tl.where(tmp6, tmp5, tmp3) | |
tl.device_assert((0 <= tmp7) & (tmp7 < 208), "index out of bounds: 0 <= tmp7 < 208") | |
tmp9 = tl.load(in_ptr2 + (r1 + (208*tmp7)), rmask, other=0.0).to(tl.int1) | |
tmp10 = tmp9 == 0 | |
tmp11 = float("-inf") | |
tmp12 = 0.0 | |
tmp13 = tl.where(tmp10, tmp11, tmp12) | |
tmp14 = tmp1 + tmp13 | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK]) | |
tmp18 = tl.where(rmask & xmask, tmp16, float("-inf")) | |
tmp19 = triton_helpers.max2(tmp18, 1)[:, None] | |
tmp20 = tmp15 - tmp19 | |
tmp21 = tl_math.exp(tmp20) | |
tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK]) | |
tmp24 = tl.where(rmask & xmask, tmp22, 0) | |
tmp25 = tl.sum(tmp24, 1)[:, None] | |
tmp26 = tmp21 / tmp25 | |
tmp27 = tmp26.to(tl.float32) | |
tmp28 = tmp27.to(tl.float32) | |
tl.store(out_ptr2 + (r1 + (208*x0)), tmp28, rmask & xmask) | |
''', device_str='cuda') | |
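# ---------------------------------------------------------------------------
# Hedged sketch of the masked softmax above: the kernel gathers the causal
# mask row for the current input position (the index load from in_ptr2),
# fills masked slots with -inf, then applies a numerically stable softmax
# (max subtraction, exp, sum, divide) in fp32. Eager equivalent:
# ---------------------------------------------------------------------------
import torch

def masked_softmax_sketch(scores, keep_mask):
    # keep_mask: bool, True = attend; masked slots become -inf before softmax
    filled = scores.float().masked_fill(~keep_mask, float("-inf"))
    return torch.softmax(filled, dim=-1)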
# kernel path: /tmp/torchinductor_ubuntu/3q/c3q5eet7mn3gx23dlvno24ooilcrlirl64ahmhluvxu6ynwq452m.py | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
# y => mul_14, sum_3 | |
triton_red_fused_bmm_4 = async_compile.triton('triton_red_fused_bmm_4', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[8192, 128], | |
reduction_hint=ReductionHint.OUTER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused_bmm_4(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 8192 | |
rnumel = 104 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x1 = (xindex // 128) | |
x0 = xindex % 128 | |
_tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
x3 = xindex | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r2 = rindex | |
tmp0 = tl.load(in_ptr0 + (r2 + (104*x1)), rmask, eviction_policy='evict_last', other=0.0) | |
tmp1 = tl.load(in_ptr1 + (x0 + (128*r2) + (13312*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tmp1.to(tl.float32) | |
tmp3 = tmp0 * tmp2 | |
tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK]) | |
tmp6 = _tmp5 + tmp4 | |
_tmp5 = tl.where(rmask, tmp6, _tmp5) | |
tmp5 = tl.sum(_tmp5, 1)[:, None] | |
tl.store(out_ptr0 + (x3), tmp5, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/s7/cs7bh5txamreduogo3dbpgu62hdnknommfet7uxuau6q2gi24u6j.py | |
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
# out_1 => fp16act_fp6weight_linear_1 | |
# y => mul_14, sum_3 | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5 = async_compile.triton('triton_per_fused_bmm_fp16act_fp6weight_linear_5', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[4096, 2], | |
reduction_hint=ReductionHint.OUTER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_bmm_fp16act_fp6weight_linear_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused_bmm_fp16act_fp6weight_linear_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 4096 | |
rnumel = 2 | |
RBLOCK: tl.constexpr = 2 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r2 = rindex | |
x0 = xindex % 128 | |
x1 = (xindex // 128) | |
x3 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0 + (128*r2) + (256*x1)), rmask, other=0.0) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.where(rmask, tmp1, 0) | |
tmp4 = tl.sum(tmp3, 1)[:, None] | |
tmp5 = tmp4.to(tl.float32) | |
tl.store(out_ptr1 + (x3), tmp5, None) | |
''', device_str='cuda') | |
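# ---------------------------------------------------------------------------
# Hedged sketch of the two kernels above taken together: probs @ v_cache is
# computed as a split reduction -- each half of the 208 cached positions is
# reduced separately (rnumel = 104 in the first kernel), and the two partials
# are then summed and cast to fp16 (rnumel = 2 in the second). Equivalent
# dense form, with illustrative shapes:
# ---------------------------------------------------------------------------
import torch

def attn_value_sketch(probs, v_cache):
    # probs: (32, 1, 208) fp32, v_cache: (32, 208, 128) fp16
    partial_a = probs[:, :, :104] @ v_cache[:, :104].float()   # first split
    partial_b = probs[:, :, 104:] @ v_cache[:, 104:].float()   # second split
    return (partial_a + partial_b).half()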
# kernel path: /tmp/torchinductor_ubuntu/e6/ce6earzjiin77ifcgxqhqghmkxait7rqhahkb2zasp4hadk7bw4d.py | |
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt] | |
# add_4 => add_5 | |
# float_4 => convert_element_type_14 | |
# h => add_4 | |
# mean_1 => mean_1 | |
# mul_11 => mul_15 | |
# mul_12 => mul_16 | |
# mul_13 => mul_17 | |
# output_1 => convert_element_type_15 | |
# rsqrt_1 => rsqrt_1 | |
# x => embedding | |
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp0 = tl.load(in_ptr0 + (0)) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
_tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp3 = tmp1 + tmp2 | |
tmp4 = tmp1 < 0 | |
tmp5 = tl.where(tmp4, tmp3, tmp1) | |
tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000") | |
tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp9 = tmp7 + tmp8 | |
tmp10 = tmp9.to(tl.float32) | |
tmp11 = tmp10 * tmp10 | |
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK]) | |
tmp14 = _tmp13 + tmp12 | |
_tmp13 = tl.where(rmask, tmp14, _tmp13) | |
tmp13 = tl.sum(_tmp13, 1)[:, None] | |
tmp15 = tl.load(in_ptr0 + (0)) | |
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK]) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp23 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp33 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp17 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp18 = tmp16 + tmp17 | |
tmp19 = tmp16 < 0 | |
tmp20 = tl.where(tmp19, tmp18, tmp16) | |
tl.device_assert((0 <= tmp20) & (tmp20 < 32000), "index out of bounds: 0 <= tmp20 < 32000") | |
tmp22 = tl.load(in_ptr1 + (r0 + (4096*tmp20)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp24 = tmp22 + tmp23 | |
tmp25 = tmp24.to(tl.float32) | |
tmp26 = 4096.0 | |
tmp27 = tmp13 / tmp26 | |
tmp28 = 1e-05 | |
tmp29 = tmp27 + tmp28 | |
tmp30 = libdevice.rsqrt(tmp29) | |
tmp31 = tmp25 * tmp30 | |
tmp32 = tmp31.to(tl.float32) | |
tmp34 = tmp32 * tmp33 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp34, rmask) | |
''', device_str='cuda') | |
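# ---------------------------------------------------------------------------
# Hedged sketch of the fusion above: the attention output is added to the
# embedding (the block's residual stream), then RMSNorm produces the FFN
# input. The kernel makes two reduction passes over the same 4096 elements,
# one to accumulate sum(x*x) and one to normalize; eager equivalent:
# ---------------------------------------------------------------------------
import torch

def residual_rmsnorm_sketch(residual, branch_out, norm_weight, eps=1e-05):
    h = (residual + branch_out).float()
    scale = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (h * scale).half() * norm_weight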
# kernel path: /tmp/torchinductor_ubuntu/7x/c7xzb5ucrv7a232bufgvbjemlk4jgdh7ujoroxm3v7jxuduwdfif.py | |
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear] | |
# out_4 => fp16act_fp6weight_linear_4 | |
triton_poi_fused_fp16act_fp6weight_linear_7 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_7', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[16384], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_7', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_fp16act_fp6weight_linear_7(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 11008 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32) | |
tmp5 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tmp2 = tl.sigmoid(tmp1) | |
tmp3 = tmp1 * tmp2 | |
tmp4 = tmp3.to(tl.float32) | |
tmp6 = tmp4 * tmp5 | |
tl.store(in_out_ptr0 + (x0), tmp6, xmask) | |
''', device_str='cuda') | |
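# ---------------------------------------------------------------------------
# Hedged sketch of the pointwise kernel above: the SwiGLU gate of the FFN,
# silu(w1(x)) * w3(x) over the 11008-wide hidden, with the sigmoid evaluated
# in fp32 and the product written back in place in fp16. `gate` and `up`
# name the two FFN input-projection outputs; both names are illustrative.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def swiglu_sketch(gate, up):
    # gate, up: fp16 outputs of the two FFN input projections
    return F.silu(gate.float()).half() * up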
# kernel path: /tmp/torchinductor_ubuntu/xb/cxbu5uv7truaeft3qae6exsdva5w6j3jruqooy6lvbsp53jek26v.py | |
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_5 => convert_element_type_18 | |
# h => add_4 | |
# mean_2 => mean_2 | |
# mul_15 => mul_20 | |
# out_5 => add_6 | |
# out_6 => fp16act_fp6weight_linear_5 | |
# x => embedding | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*i32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp0 = tl.load(in_ptr0 + (0)) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
_tmp15 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp8 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp10 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp3 = tmp1 + tmp2 | |
tmp4 = tmp1 < 0 | |
tmp5 = tl.where(tmp4, tmp3, tmp1) | |
tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000") | |
tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp9 = tmp7 + tmp8 | |
tmp11 = tmp9 + tmp10 | |
tmp12 = tmp11.to(tl.float32) | |
tmp13 = tmp12 * tmp12 | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK]) | |
tmp16 = _tmp15 + tmp14 | |
_tmp15 = tl.where(rmask, tmp16, _tmp15) | |
tmp15 = tl.sum(_tmp15, 1)[:, None] | |
tmp17 = tl.load(in_ptr0 + (0)) | |
tmp18 = tl.broadcast_to(tmp17, [XBLOCK, RBLOCK]) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp25 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp37 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp19 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp20 = tmp18 + tmp19 | |
tmp21 = tmp18 < 0 | |
tmp22 = tl.where(tmp21, tmp20, tmp18) | |
tl.device_assert((0 <= tmp22) & (tmp22 < 32000), "index out of bounds: 0 <= tmp22 < 32000") | |
tmp24 = tl.load(in_ptr1 + (r0 + (4096*tmp22)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp26 = tmp24 + tmp25 | |
tmp28 = tmp26 + tmp27 | |
tmp29 = tmp28.to(tl.float32) | |
tmp30 = 4096.0 | |
tmp31 = tmp15 / tmp30 | |
tmp32 = 1e-05 | |
tmp33 = tmp31 + tmp32 | |
tmp34 = libdevice.rsqrt(tmp33) | |
tmp35 = tmp29 * tmp34 | |
tmp36 = tmp35.to(tl.float32) | |
tmp38 = tmp36 * tmp37 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp38, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/ko/ckonz56ntbc3wxvboqbs6j6xmjdrqma6fl5rtmem5grbfiymtcs4.py | |
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_8 => convert_element_type_32 | |
# h => add_4 | |
# h_1 => add_11 | |
# mean_3 => mean_3 | |
# mul_26 => mul_35 | |
# out_5 => add_6 | |
# out_8 => fp16act_fp6weight_linear_7 | |
# out_9 => fp16act_fp6weight_linear_8 | |
# x => embedding | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9 = async_compile.triton('triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*i32', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp0 = tl.load(in_ptr0 + (0)) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
_tmp17 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp8 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp10 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp12 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
tmp3 = tmp1 + tmp2 | |
tmp4 = tmp1 < 0 | |
tmp5 = tl.where(tmp4, tmp3, tmp1) | |
tl.device_assert((0 <= tmp5) & (tmp5 < 32000), "index out of bounds: 0 <= tmp5 < 32000") | |
tmp7 = tl.load(in_ptr1 + (r0 + (4096*tmp5)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp9 = tmp7 + tmp8 | |
tmp11 = tmp9 + tmp10 | |
tmp13 = tmp11 + tmp12 | |
tmp14 = tmp13.to(tl.float32) | |
tmp15 = tmp14 * tmp14 | |
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK]) | |
tmp18 = _tmp17 + tmp16 | |
_tmp17 = tl.where(rmask, tmp18, _tmp17) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp13, rmask) | |
tmp17 = tl.sum(_tmp17, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp19 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp28 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp20 = tmp19.to(tl.float32) | |
tmp21 = 4096.0 | |
tmp22 = tmp17 / tmp21 | |
tmp23 = 1e-05 | |
tmp24 = tmp22 + tmp23 | |
tmp25 = libdevice.rsqrt(tmp24) | |
tmp26 = tmp20 * tmp25 | |
tmp27 = tmp26.to(tl.float32) | |
tmp29 = tmp27 * tmp28 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask) | |
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp29, rmask) | |
''', device_str='cuda') | |
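# ---------------------------------------------------------------------------
# Hedged sketch of the in-place variant above: the running residual stream is
# folded into one buffer (the in_out_ptr0 store accumulates the embedding,
# attention, and FFN contributions), then RMSNorm of the updated stream is
# written to two destinations (out_ptr1 and out_ptr2), since the normalized
# activation feeds two consumers in the next block.
# ---------------------------------------------------------------------------
import torch

def inplace_residual_rmsnorm_sketch(stream, *addends, norm_weight, eps=1e-05):
    for a in addends:          # mirrors the chained adds before the reduction
        stream += a            # in-place, like the in_out_ptr0 store
    h = stream.float()
    scale = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    out = (h * scale).half() * norm_weight
    return out, out            # same value stored to both output pointers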
# kernel path: /tmp/torchinductor_ubuntu/23/c23scxttsdz72vesaap27tazzxervcixdrt7hwfybhp2nrfu4kw4.py | |
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_9 => convert_element_type_36 | |
# mean_4 => mean_4 | |
# mul_30 => mul_40 | |
# out_11 => add_13 | |
# out_12 => fp16act_fp6weight_linear_10 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {4: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=(4,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp3 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK]) | |
tmp7 = _tmp6 + tmp5 | |
_tmp6 = tl.where(rmask, tmp7, _tmp6) | |
tmp6 = tl.sum(_tmp6, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp8 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp9 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp19 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp10 = tmp8 + tmp9 | |
tmp11 = tmp10.to(tl.float32) | |
tmp12 = 4096.0 | |
tmp13 = tmp6 / tmp12 | |
tmp14 = 1e-05 | |
tmp15 = tmp13 + tmp14 | |
tmp16 = libdevice.rsqrt(tmp15) | |
tmp17 = tmp11 * tmp16 | |
tmp18 = tmp17.to(tl.float32) | |
tmp20 = tmp18 * tmp19 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp20, rmask) | |
''', device_str='cuda') | |
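# ---------------------------------------------------------------------------
# Note: the remaining fused kernels in this module repeat the pattern above.
# Each subsequent variant is an RMSNorm whose input is the residual stream
# expressed as a chain with one extra addend per transformer block, until an
# in-place (`in_out_ptr0`) variant folds the chain back into a single buffer
# so it does not keep growing.
# ---------------------------------------------------------------------------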
# kernel path: /tmp/torchinductor_ubuntu/sg/csgnrv5nhw4eic7f275ngxcfz73e62dtja6zilbn6f5a2qhvqn54.py | |
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
# add_16 => add_19 | |
# float_12 => convert_element_type_50 | |
# h_2 => add_18 | |
# mean_5 => mean_5 | |
# mul_41 => mul_55 | |
# mul_42 => mul_56 | |
# mul_43 => mul_57 | |
# out_11 => add_13 | |
# output_5 => convert_element_type_51 | |
# rsqrt_5 => rsqrt_5 | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_rsqrt_11', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_rsqrt_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_mean_mul_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp5 * tmp5 | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK]) | |
tmp9 = _tmp8 + tmp7 | |
_tmp8 = tl.where(rmask, tmp9, _tmp8) | |
tmp8 = tl.sum(_tmp8, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp10 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp11 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp13 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp12 = tmp10 + tmp11 | |
tmp14 = tmp12 + tmp13 | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp8 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/nw/cnwia6b5wdcoewnme63eazcbz6dvzogn3stxnaty267p4vw2mmwy.py | |
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_13 => convert_element_type_54 | |
# h_2 => add_18 | |
# mean_6 => mean_6 | |
# mul_45 => mul_60 | |
# out_11 => add_13 | |
# out_17 => add_20 | |
# out_18 => fp16act_fp6weight_linear_15 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {6: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=(6,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp7 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(rmask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp12 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp13 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp17 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp14 = tmp12 + tmp13 | |
tmp16 = tmp14 + tmp15 | |
tmp18 = tmp16 + tmp17 | |
tmp19 = tmp18.to(tl.float32) | |
tmp20 = 4096.0 | |
tmp21 = tmp10 / tmp20 | |
tmp22 = 1e-05 | |
tmp23 = tmp21 + tmp22 | |
tmp24 = libdevice.rsqrt(tmp23) | |
tmp25 = tmp19 * tmp24 | |
tmp26 = tmp25.to(tl.float32) | |
tmp28 = tmp26 * tmp27 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp28, rmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/yi/cyiywyx2b2ap35djo57g2tkxu3hdej3o4ootd3nr6bigecahyoeu.py | |
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_16 => convert_element_type_68 | |
# h_2 => add_18 | |
# h_3 => add_25 | |
# mean_7 => mean_7 | |
# mul_56 => mul_75 | |
# out_11 => add_13 | |
# out_17 => add_20 | |
# out_20 => fp16act_fp6weight_linear_17 | |
# out_21 => fp16act_fp6weight_linear_18 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp8 = tmp6 + tmp7 | |
tmp9 = tmp8.to(tl.float32) | |
tmp10 = tmp9 * tmp9 | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask, tmp13, _tmp12) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp12 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
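# Same fused add + RMSNorm pattern; this variant accumulates several deferred residual | |
# additions (h_2, h_3, out_11, out_17) in place into in_out_ptr0 and writes the | |
# normalized result to two buffers, since it feeds two separate fp6-weight linears | |
# (out_20 and out_21). | |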
# kernel path: /tmp/torchinductor_ubuntu/74/c74rqt5v7cgztcvqz3g36kdattrzmlrq33illnplnagfadtt7nqv.py | |
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
# out_22 => fp16act_fp6weight_linear_19 | |
triton_poi_fused_fp16act_fp6weight_linear_14 = async_compile.triton('triton_poi_fused_fp16act_fp6weight_linear_14', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[16384], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fp16act_fp6weight_linear_14', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_fp16act_fp6weight_linear_14(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 11008 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
tmp5 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tmp2 = tl.sigmoid(tmp1) | |
tmp3 = tmp1 * tmp2 | |
tmp4 = tmp3.to(tl.float32) | |
tmp6 = tmp4 * tmp5 | |
tl.store(in_out_ptr0 + (x0), tmp6, xmask) | |
''', device_str='cuda') | |
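# This pointwise kernel is the SwiGLU gate of the LLaMA-style MLP: silu(w1(x)) * w3(x), | |
# with both 11008-wide projections already produced by the fp6 linear ops and the | |
# product written back in place over the w3 output. Equivalent eager sketch | |
# (illustrative only): | |
def _swiglu_reference(gate_proj: torch.Tensor, up_proj: torch.Tensor) -> torch.Tensor: | |
    # silu(x) = x * sigmoid(x), computed in fp32 then cast back, as in the kernel above | |
    g = gate_proj.float() | |
    return (g * torch.sigmoid(g)).to(gate_proj.dtype) * up_proj | |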
# kernel path: /tmp/torchinductor_ubuntu/od/codwx422sav4ibjcz7e3u2ii64iq4gtjlj6xmjifbwmqjhhlbprf.py | |
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
# float_24 => convert_element_type_104 | |
# h_4 => add_32 | |
# h_5 => add_39 | |
# mean_11 => mean_11 | |
# mul_86 => mul_115 | |
# out_23 => add_27 | |
# out_29 => add_34 | |
# out_32 => fp16act_fp6weight_linear_27 | |
# out_33 => fp16act_fp6weight_linear_28 | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15 = async_compile.triton('triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 4096], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: '*fp16', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {8: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=(8,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4096 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp5 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp7 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp2 + tmp3 | |
tmp6 = tmp4 + tmp5 | |
tmp8 = tmp6 + tmp7 | |
tmp9 = tmp8.to(tl.float32) | |
tmp10 = tmp9 * tmp9 | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask, tmp13, _tmp12) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp8, rmask) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp14 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp23 = tl.load(in_ptr4 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = 4096.0 | |
tmp17 = tmp12 / tmp16 | |
tmp18 = 1e-05 | |
tmp19 = tmp17 + tmp18 | |
tmp20 = libdevice.rsqrt(tmp19) | |
tmp21 = tmp15 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp24 = tmp22 * tmp23 | |
tl.store(out_ptr1 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
tl.store(out_ptr2 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp24, rmask) | |
''', device_str='cuda') | |
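# Near-identical to triton_red_fused..._13 above (only the load order of the residual | |
# operands differs), emitted separately for the corresponding fused adds + RMSNorm of a | |
# later transformer block (h_4/h_5, out_23/out_29). | |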
# kernel path: /tmp/torchinductor_ubuntu/z6/cz6oxi5i6lto4h7ib75rlkpzqsocefhcbwhm7nhdpxravdtgiuro.py | |
# Source Nodes: [logits_1], Original ATen: [aten.div] | |
# logits_1 => div_32 | |
triton_poi_fused_div_16 = async_compile.triton('triton_poi_fused_div_16', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.pointwise( | |
size_hints=[32768], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_div_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_div_16(in_out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 32000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = 1.25 | |
tmp2 = tmp0 * tmp1 | |
tl.store(in_out_ptr0 + (x0), tmp2, xmask) | |
''', device_str='cuda') | |
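# Scales the final 32000 vocab logits by 1.25, i.e. division by a sampling temperature | |
# of 0.8 (presumably gpt-fast's default); Inductor folds the division into a multiply | |
# by the reciprocal. Eager equivalent: logits = logits / temperature. | |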
# kernel path: /tmp/torchinductor_ubuntu/xv/cxvjbig2sbormp2kxcubuvv2h5v3vxuc4jta7llq4rnh6kwuxhuo.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => amax_32, convert_element_type_578 | |
triton_red_fused__softmax_lt_scalar_tensor_where_17 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_17', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[4, 8192], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 4), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax_lt_scalar_tensor_where_17(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 4 | |
rnumel = 8000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
_tmp8 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK]) | |
tmp9 = triton_helpers.maximum(_tmp8, tmp7) | |
_tmp8 = tl.where(rmask & xmask, tmp9, _tmp8) | |
tmp8 = triton_helpers.max2(_tmp8, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp8, xmask) | |
''', device_str='cuda') | |
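# First stage of a split softmax over the 32000-logit vocab: logits below the 200th | |
# largest value (loaded from in_ptr1 + 199, consistent with top_k=200, gpt-fast's | |
# default) are masked to -inf, and each of the 4 splits of 8000 logits reduces to a | |
# partial running max. A sketch of the masking step (illustrative; `logits` is a | |
# hypothetical vocab-sized fp16 tensor): | |
def _topk_mask_reference(logits: torch.Tensor, k: int = 200) -> torch.Tensor: | |
    v, _ = torch.topk(logits, k) | |
    # keep only the top-k logits; everything below the k-th largest value becomes -inf | |
    return torch.where(logits < v[..., -1, None], torch.full_like(logits, float("-inf")), logits) | |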
# kernel path: /tmp/torchinductor_ubuntu/hz/chzqxkcllimh3vsvmsxoviuxtefrsfgsde3el6pgijogxpaxrs7r.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => amax_32, convert_element_type_578 | |
triton_per_fused__softmax_lt_scalar_tensor_where_18 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_18', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[1, 4], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__softmax_lt_scalar_tensor_where_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4 | |
RBLOCK: tl.constexpr = 4 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.where(rmask, tmp1, float("-inf")) | |
tmp4 = triton_helpers.max2(tmp3, 1)[:, None] | |
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None) | |
''', device_str='cuda') | |
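# Second stage: reduces the 4 partial maxima from the previous kernel to the single | |
# global max used for the numerically stable exp(logit - max). | |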
# kernel path: /tmp/torchinductor_ubuntu/za/czakkdksd2yxt6hanh7vxqjmbj7ara4d4rza6bwjwwt4stxakxtm.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, exp_32, sub_96, sum_97 | |
triton_red_fused__softmax_lt_scalar_tensor_where_19 = async_compile.triton('triton_red_fused__softmax_lt_scalar_tensor_where_19', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[4, 8192], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 5), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax_lt_scalar_tensor_where_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax_lt_scalar_tensor_where_19(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 4 | |
rnumel = 8000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
tmp1 = tl.load(in_ptr1 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
tmp7 = tl.load(in_ptr2 + (0)) | |
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK]) | |
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (8000*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp9 = tmp6 - tmp8 | |
tmp10 = tl_math.exp(tmp9) | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK]) | |
tmp13 = _tmp12 + tmp11 | |
_tmp12 = tl.where(rmask & xmask, tmp13, _tmp12) | |
tmp12 = tl.sum(_tmp12, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp12, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_ubuntu/aa/caad4sxzjec6myyhypd4e6odaasgzubfdw6lhsfkhrv5jqg4uc7g.py | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, exp_32, sub_96, sum_97 | |
triton_per_fused__softmax_lt_scalar_tensor_where_20 = async_compile.triton('triton_per_fused__softmax_lt_scalar_tensor_where_20', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.persistent_reduction( | |
size_hints=[1, 4], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {2: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(2,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax_lt_scalar_tensor_where_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__softmax_lt_scalar_tensor_where_20(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 4 | |
RBLOCK: tl.constexpr = 4 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.where(rmask, tmp1, 0) | |
tmp4 = tl.sum(tmp3, 1)[:, None] | |
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None) | |
''', device_str='cuda') | |
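# Kernels _19/_20 repeat the same two-stage split for the softmax denominator: each | |
# split accumulates sum(exp(x - max)) over its 8000 masked logits, then the 4 partial | |
# sums are reduced to the scalar normalizer. Together with the max kernels this is the | |
# standard stable softmax, softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))). | |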
# kernel path: /tmp/torchinductor_ubuntu/ch/cch6v3lnecv5tfe5ipqpowh4xge3iiqm4nrz67gnjbwlbump6mow.py | |
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where] | |
# argmax => argmax | |
# idx_next => convert_element_type_582 | |
# logits_2 => full_default_64, where_32 | |
# lt => lt | |
# probs => convert_element_type_578, convert_element_type_579, div_33, exp_32, sub_96 | |
# q_128 => convert_element_type_581, log1p, mul_643, neg | |
# truediv_1 => div_34 | |
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21 = async_compile.triton('triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
@triton_heuristics.reduction( | |
size_hints=[1, 32768], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*i64', 5: '*i32', 6: 'i32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=76), 'constants': {7: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 8), equal_to_1=(7,))]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '5002d3b4f89584c56f8df13bbe2510318f7ec3b73ba9efd611640fb08b2fe939', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 1 | |
rnumel = 32000 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
tmp1 = tl.load(in_ptr0 + (199)).to(tl.float32) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
tmp7 = tl.load(in_ptr1 + (0)) | |
tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK]) | |
tmp11 = tl.load(in_ptr2 + (0)) | |
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK]) | |
_tmp25 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) | |
_tmp25_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r0 = rindex | |
tmp0 = tl.load(in_out_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tmp0 < tmp2 | |
tmp4 = float("-inf") | |
tmp5 = tl.where(tmp3, tmp4, tmp0) | |
tmp6 = tmp5.to(tl.float32) | |
tmp9 = tmp6 - tmp8 | |
tmp10 = tl_math.exp(tmp9) | |
tmp13 = tmp10 / tmp12 | |
tmp14 = tmp13.to(tl.float32) | |
tmp15 = tl.load(in_ptr3 + load_seed_offset) | |
tmp16 = r0 | |
tmp17 = tl.rand(tmp15, (tmp16).to(tl.uint32)) | |
tmp18 = -tmp17 | |
tmp19 = libdevice.log1p(tmp18) | |
tmp20 = -1.0 | |
tmp21 = tmp19 * tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tmp14 / tmp22 | |
tmp24 = tl.broadcast_to(tmp23, [XBLOCK, RBLOCK]) | |
_tmp25_next, _tmp25_index_next = triton_helpers.maximum_with_index( | |
_tmp25, _tmp25_index, tmp24, rindex | |
) | |
_tmp25 = tl.where(rmask, _tmp25_next, _tmp25) | |
_tmp25_index = tl.where(rmask, _tmp25_index_next, _tmp25_index) | |
tl.store(in_out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp14, rmask) | |
_, tmp25_tmp = triton_helpers.max_with_index(_tmp25, _tmp25_index, 1) | |
tmp25 = tmp25_tmp[:, None] | |
tmp26 = tmp25.to(tl.int32) | |
tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp26, None) | |
''', device_str='cuda') | |
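# The final fused kernel normalizes the masked logits into probabilities (stored back | |
# into in_out_ptr0) and samples a token via the exponential-race trick, as in gpt-fast's | |
# multinomial_sample_one_no_sync: with q ~ Exp(1) generated as -log1p(-u) from uniform u, | |
# argmax(probs / q) is distributed like torch.multinomial(probs, 1) but needs no | |
# CPU/GPU sync. Illustrative sketch: | |
def _sample_one_reference(probs: torch.Tensor) -> torch.Tensor: | |
    # one Exp(1) sample per vocab entry; the winner of the race is the sampled token | |
    q = torch.empty_like(probs).exponential_(1) | |
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) | |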
async_compile.wait(globals()) | |
del async_compile | |
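# Entry point generated by Inductor: `call` receives the flattened list of 456 module | |
# inputs, which appear to be the per-layer RMSNorm weights, the packed fp6 weight | |
# tensors with their fp16 scale vectors, the KV-cache buffers, the RoPE frequency table, | |
# the causal-mask table, and the token/position inputs, and runs one decode step. | |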
def call(args): | |
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1 = args | |
args.clear() | |
assert_size_stride(arg0_1, (4096, ), (1, )) | |
assert_size_stride(arg1_1, (4096, ), (1, )) | |
assert_size_stride(arg2_1, (4096, ), (1, )) | |
assert_size_stride(arg3_1, (4096, ), (1, )) | |
assert_size_stride(arg4_1, (4096, ), (1, )) | |
assert_size_stride(arg5_1, (4096, ), (1, )) | |
assert_size_stride(arg6_1, (4096, ), (1, )) | |
assert_size_stride(arg7_1, (4096, ), (1, )) | |
assert_size_stride(arg8_1, (4096, ), (1, )) | |
assert_size_stride(arg9_1, (4096, ), (1, )) | |
assert_size_stride(arg10_1, (4096, ), (1, )) | |
assert_size_stride(arg11_1, (4096, ), (1, )) | |
assert_size_stride(arg12_1, (4096, ), (1, )) | |
assert_size_stride(arg13_1, (4096, ), (1, )) | |
assert_size_stride(arg14_1, (4096, ), (1, )) | |
assert_size_stride(arg15_1, (4096, ), (1, )) | |
assert_size_stride(arg16_1, (4096, ), (1, )) | |
assert_size_stride(arg17_1, (4096, ), (1, )) | |
assert_size_stride(arg18_1, (4096, ), (1, )) | |
assert_size_stride(arg19_1, (4096, ), (1, )) | |
assert_size_stride(arg20_1, (4096, ), (1, )) | |
assert_size_stride(arg21_1, (4096, ), (1, )) | |
assert_size_stride(arg22_1, (4096, ), (1, )) | |
assert_size_stride(arg23_1, (4096, ), (1, )) | |
assert_size_stride(arg24_1, (4096, ), (1, )) | |
assert_size_stride(arg25_1, (4096, ), (1, )) | |
assert_size_stride(arg26_1, (4096, ), (1, )) | |
assert_size_stride(arg27_1, (4096, ), (1, )) | |
assert_size_stride(arg28_1, (4096, ), (1, )) | |
assert_size_stride(arg29_1, (4096, ), (1, )) | |
assert_size_stride(arg30_1, (4096, ), (1, )) | |
assert_size_stride(arg31_1, (4096, ), (1, )) | |
assert_size_stride(arg32_1, (4096, ), (1, )) | |
assert_size_stride(arg33_1, (4096, ), (1, )) | |
assert_size_stride(arg34_1, (4096, ), (1, )) | |
assert_size_stride(arg35_1, (4096, ), (1, )) | |
assert_size_stride(arg36_1, (4096, ), (1, )) | |
assert_size_stride(arg37_1, (4096, ), (1, )) | |
assert_size_stride(arg38_1, (4096, ), (1, )) | |
assert_size_stride(arg39_1, (4096, ), (1, )) | |
assert_size_stride(arg40_1, (4096, ), (1, )) | |
assert_size_stride(arg41_1, (4096, ), (1, )) | |
assert_size_stride(arg42_1, (4096, ), (1, )) | |
assert_size_stride(arg43_1, (4096, ), (1, )) | |
assert_size_stride(arg44_1, (4096, ), (1, )) | |
assert_size_stride(arg45_1, (4096, ), (1, )) | |
assert_size_stride(arg46_1, (4096, ), (1, )) | |
assert_size_stride(arg47_1, (4096, ), (1, )) | |
assert_size_stride(arg48_1, (4096, ), (1, )) | |
assert_size_stride(arg49_1, (4096, ), (1, )) | |
assert_size_stride(arg50_1, (4096, ), (1, )) | |
assert_size_stride(arg51_1, (4096, ), (1, )) | |
assert_size_stride(arg52_1, (4096, ), (1, )) | |
assert_size_stride(arg53_1, (4096, ), (1, )) | |
assert_size_stride(arg54_1, (4096, ), (1, )) | |
assert_size_stride(arg55_1, (4096, ), (1, )) | |
assert_size_stride(arg56_1, (4096, ), (1, )) | |
assert_size_stride(arg57_1, (4096, ), (1, )) | |
assert_size_stride(arg58_1, (4096, ), (1, )) | |
assert_size_stride(arg59_1, (4096, ), (1, )) | |
assert_size_stride(arg60_1, (4096, ), (1, )) | |
assert_size_stride(arg61_1, (4096, ), (1, )) | |
assert_size_stride(arg62_1, (4096, ), (1, )) | |
assert_size_stride(arg63_1, (4096, ), (1, )) | |
assert_size_stride(arg64_1, (4096, ), (1, )) | |
assert_size_stride(arg65_1, (32000, 4096), (4096, 1)) | |
assert_size_stride(arg66_1, (2048, 64, 2), (128, 2, 1)) | |
assert_size_stride(arg67_1, (208, 208), (208, 1)) | |
assert_size_stride(arg68_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg69_1, (12288, ), (1, )) | |
assert_size_stride(arg70_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg71_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg72_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg73_1, (4096, ), (1, )) | |
assert_size_stride(arg74_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg75_1, (11008, ), (1, )) | |
assert_size_stride(arg76_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg77_1, (11008, ), (1, )) | |
assert_size_stride(arg78_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg79_1, (4096, ), (1, )) | |
assert_size_stride(arg80_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg81_1, (12288, ), (1, )) | |
assert_size_stride(arg82_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg83_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg84_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg85_1, (4096, ), (1, )) | |
assert_size_stride(arg86_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg87_1, (11008, ), (1, )) | |
assert_size_stride(arg88_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg89_1, (11008, ), (1, )) | |
assert_size_stride(arg90_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg91_1, (4096, ), (1, )) | |
assert_size_stride(arg92_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg93_1, (12288, ), (1, )) | |
assert_size_stride(arg94_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg95_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg96_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg97_1, (4096, ), (1, )) | |
assert_size_stride(arg98_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg99_1, (11008, ), (1, )) | |
assert_size_stride(arg100_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg101_1, (11008, ), (1, )) | |
assert_size_stride(arg102_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg103_1, (4096, ), (1, )) | |
assert_size_stride(arg104_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg105_1, (12288, ), (1, )) | |
assert_size_stride(arg106_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg107_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg108_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg109_1, (4096, ), (1, )) | |
assert_size_stride(arg110_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg111_1, (11008, ), (1, )) | |
assert_size_stride(arg112_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg113_1, (11008, ), (1, )) | |
assert_size_stride(arg114_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg115_1, (4096, ), (1, )) | |
assert_size_stride(arg116_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg117_1, (12288, ), (1, )) | |
assert_size_stride(arg118_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg119_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg120_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg121_1, (4096, ), (1, )) | |
assert_size_stride(arg122_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg123_1, (11008, ), (1, )) | |
assert_size_stride(arg124_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg125_1, (11008, ), (1, )) | |
assert_size_stride(arg126_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg127_1, (4096, ), (1, )) | |
assert_size_stride(arg128_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg129_1, (12288, ), (1, )) | |
assert_size_stride(arg130_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg131_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg132_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg133_1, (4096, ), (1, )) | |
assert_size_stride(arg134_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg135_1, (11008, ), (1, )) | |
assert_size_stride(arg136_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg137_1, (11008, ), (1, )) | |
assert_size_stride(arg138_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg139_1, (4096, ), (1, )) | |
assert_size_stride(arg140_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg141_1, (12288, ), (1, )) | |
assert_size_stride(arg142_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg143_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg144_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg145_1, (4096, ), (1, )) | |
assert_size_stride(arg146_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg147_1, (11008, ), (1, )) | |
assert_size_stride(arg148_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg149_1, (11008, ), (1, )) | |
assert_size_stride(arg150_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg151_1, (4096, ), (1, )) | |
assert_size_stride(arg152_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg153_1, (12288, ), (1, )) | |
assert_size_stride(arg154_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg155_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg156_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg157_1, (4096, ), (1, )) | |
assert_size_stride(arg158_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg159_1, (11008, ), (1, )) | |
assert_size_stride(arg160_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg161_1, (11008, ), (1, )) | |
assert_size_stride(arg162_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg163_1, (4096, ), (1, )) | |
assert_size_stride(arg164_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg165_1, (12288, ), (1, )) | |
assert_size_stride(arg166_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg167_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg168_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg169_1, (4096, ), (1, )) | |
assert_size_stride(arg170_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg171_1, (11008, ), (1, )) | |
assert_size_stride(arg172_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg173_1, (11008, ), (1, )) | |
assert_size_stride(arg174_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg175_1, (4096, ), (1, )) | |
assert_size_stride(arg176_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg177_1, (12288, ), (1, )) | |
assert_size_stride(arg178_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg179_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg180_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg181_1, (4096, ), (1, )) | |
assert_size_stride(arg182_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg183_1, (11008, ), (1, )) | |
assert_size_stride(arg184_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg185_1, (11008, ), (1, )) | |
assert_size_stride(arg186_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg187_1, (4096, ), (1, )) | |
assert_size_stride(arg188_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg189_1, (12288, ), (1, )) | |
assert_size_stride(arg190_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg191_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg192_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg193_1, (4096, ), (1, )) | |
assert_size_stride(arg194_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg195_1, (11008, ), (1, )) | |
assert_size_stride(arg196_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg197_1, (11008, ), (1, )) | |
assert_size_stride(arg198_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg199_1, (4096, ), (1, )) | |
assert_size_stride(arg200_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg201_1, (12288, ), (1, )) | |
assert_size_stride(arg202_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg203_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg204_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg205_1, (4096, ), (1, )) | |
assert_size_stride(arg206_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg207_1, (11008, ), (1, )) | |
assert_size_stride(arg208_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg209_1, (11008, ), (1, )) | |
assert_size_stride(arg210_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg211_1, (4096, ), (1, )) | |
assert_size_stride(arg212_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg213_1, (12288, ), (1, )) | |
assert_size_stride(arg214_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg215_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg216_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg217_1, (4096, ), (1, )) | |
assert_size_stride(arg218_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg219_1, (11008, ), (1, )) | |
assert_size_stride(arg220_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg221_1, (11008, ), (1, )) | |
assert_size_stride(arg222_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg223_1, (4096, ), (1, )) | |
assert_size_stride(arg224_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg225_1, (12288, ), (1, )) | |
assert_size_stride(arg226_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg227_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg228_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg229_1, (4096, ), (1, )) | |
assert_size_stride(arg230_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg231_1, (11008, ), (1, )) | |
assert_size_stride(arg232_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg233_1, (11008, ), (1, )) | |
assert_size_stride(arg234_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg235_1, (4096, ), (1, )) | |
assert_size_stride(arg236_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg237_1, (12288, ), (1, )) | |
assert_size_stride(arg238_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg239_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg240_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg241_1, (4096, ), (1, )) | |
assert_size_stride(arg242_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg243_1, (11008, ), (1, )) | |
assert_size_stride(arg244_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg245_1, (11008, ), (1, )) | |
assert_size_stride(arg246_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg247_1, (4096, ), (1, )) | |
assert_size_stride(arg248_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg249_1, (12288, ), (1, )) | |
assert_size_stride(arg250_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg251_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg252_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg253_1, (4096, ), (1, )) | |
assert_size_stride(arg254_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg255_1, (11008, ), (1, )) | |
assert_size_stride(arg256_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg257_1, (11008, ), (1, )) | |
assert_size_stride(arg258_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg259_1, (4096, ), (1, )) | |
assert_size_stride(arg260_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg261_1, (12288, ), (1, )) | |
assert_size_stride(arg262_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg263_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg264_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg265_1, (4096, ), (1, )) | |
assert_size_stride(arg266_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg267_1, (11008, ), (1, )) | |
assert_size_stride(arg268_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg269_1, (11008, ), (1, )) | |
assert_size_stride(arg270_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg271_1, (4096, ), (1, )) | |
assert_size_stride(arg272_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg273_1, (12288, ), (1, )) | |
assert_size_stride(arg274_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg275_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg276_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg277_1, (4096, ), (1, )) | |
assert_size_stride(arg278_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg279_1, (11008, ), (1, )) | |
assert_size_stride(arg280_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg281_1, (11008, ), (1, )) | |
assert_size_stride(arg282_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg283_1, (4096, ), (1, )) | |
assert_size_stride(arg284_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg285_1, (12288, ), (1, )) | |
assert_size_stride(arg286_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg287_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg288_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg289_1, (4096, ), (1, )) | |
assert_size_stride(arg290_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg291_1, (11008, ), (1, )) | |
assert_size_stride(arg292_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg293_1, (11008, ), (1, )) | |
assert_size_stride(arg294_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg295_1, (4096, ), (1, )) | |
assert_size_stride(arg296_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg297_1, (12288, ), (1, )) | |
assert_size_stride(arg298_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg299_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg300_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg301_1, (4096, ), (1, )) | |
assert_size_stride(arg302_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg303_1, (11008, ), (1, )) | |
assert_size_stride(arg304_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg305_1, (11008, ), (1, )) | |
assert_size_stride(arg306_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg307_1, (4096, ), (1, )) | |
assert_size_stride(arg308_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg309_1, (12288, ), (1, )) | |
assert_size_stride(arg310_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg311_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg312_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg313_1, (4096, ), (1, )) | |
assert_size_stride(arg314_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg315_1, (11008, ), (1, )) | |
assert_size_stride(arg316_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg317_1, (11008, ), (1, )) | |
assert_size_stride(arg318_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg319_1, (4096, ), (1, )) | |
assert_size_stride(arg320_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg321_1, (12288, ), (1, )) | |
assert_size_stride(arg322_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg323_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg324_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg325_1, (4096, ), (1, )) | |
assert_size_stride(arg326_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg327_1, (11008, ), (1, )) | |
assert_size_stride(arg328_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg329_1, (11008, ), (1, )) | |
assert_size_stride(arg330_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg331_1, (4096, ), (1, )) | |
assert_size_stride(arg332_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg333_1, (12288, ), (1, )) | |
assert_size_stride(arg334_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg335_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg336_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg337_1, (4096, ), (1, )) | |
assert_size_stride(arg338_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg339_1, (11008, ), (1, )) | |
assert_size_stride(arg340_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg341_1, (11008, ), (1, )) | |
assert_size_stride(arg342_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg343_1, (4096, ), (1, )) | |
assert_size_stride(arg344_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg345_1, (12288, ), (1, )) | |
assert_size_stride(arg346_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg347_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg348_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg349_1, (4096, ), (1, )) | |
assert_size_stride(arg350_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg351_1, (11008, ), (1, )) | |
assert_size_stride(arg352_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg353_1, (11008, ), (1, )) | |
assert_size_stride(arg354_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg355_1, (4096, ), (1, )) | |
assert_size_stride(arg356_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg357_1, (12288, ), (1, )) | |
assert_size_stride(arg358_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg359_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg360_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg361_1, (4096, ), (1, )) | |
assert_size_stride(arg362_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg363_1, (11008, ), (1, )) | |
assert_size_stride(arg364_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg365_1, (11008, ), (1, )) | |
assert_size_stride(arg366_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg367_1, (4096, ), (1, )) | |
assert_size_stride(arg368_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg369_1, (12288, ), (1, )) | |
assert_size_stride(arg370_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg371_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg372_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg373_1, (4096, ), (1, )) | |
assert_size_stride(arg374_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg375_1, (11008, ), (1, )) | |
assert_size_stride(arg376_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg377_1, (11008, ), (1, )) | |
assert_size_stride(arg378_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg379_1, (4096, ), (1, )) | |
assert_size_stride(arg380_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg381_1, (12288, ), (1, )) | |
assert_size_stride(arg382_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg383_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg384_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg385_1, (4096, ), (1, )) | |
assert_size_stride(arg386_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg387_1, (11008, ), (1, )) | |
assert_size_stride(arg388_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg389_1, (11008, ), (1, )) | |
assert_size_stride(arg390_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg391_1, (4096, ), (1, )) | |
assert_size_stride(arg392_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg393_1, (12288, ), (1, )) | |
assert_size_stride(arg394_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg395_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg396_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg397_1, (4096, ), (1, )) | |
assert_size_stride(arg398_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg399_1, (11008, ), (1, )) | |
assert_size_stride(arg400_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg401_1, (11008, ), (1, )) | |
assert_size_stride(arg402_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg403_1, (4096, ), (1, )) | |
assert_size_stride(arg404_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg405_1, (12288, ), (1, )) | |
assert_size_stride(arg406_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg407_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg408_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg409_1, (4096, ), (1, )) | |
assert_size_stride(arg410_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg411_1, (11008, ), (1, )) | |
assert_size_stride(arg412_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg413_1, (11008, ), (1, )) | |
assert_size_stride(arg414_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg415_1, (4096, ), (1, )) | |
assert_size_stride(arg416_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg417_1, (12288, ), (1, )) | |
assert_size_stride(arg418_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg419_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg420_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg421_1, (4096, ), (1, )) | |
assert_size_stride(arg422_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg423_1, (11008, ), (1, )) | |
assert_size_stride(arg424_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg425_1, (11008, ), (1, )) | |
assert_size_stride(arg426_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg427_1, (4096, ), (1, )) | |
assert_size_stride(arg428_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg429_1, (12288, ), (1, )) | |
assert_size_stride(arg430_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg431_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg432_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg433_1, (4096, ), (1, )) | |
assert_size_stride(arg434_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg435_1, (11008, ), (1, )) | |
assert_size_stride(arg436_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg437_1, (11008, ), (1, )) | |
assert_size_stride(arg438_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg439_1, (4096, ), (1, )) | |
assert_size_stride(arg440_1, (12288, 768), (768, 1)) | |
assert_size_stride(arg441_1, (12288, ), (1, )) | |
assert_size_stride(arg442_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg443_1, (1, 32, 208, 128), (851968, 26624, 128, 1)) | |
assert_size_stride(arg444_1, (4096, 768), (768, 1)) | |
assert_size_stride(arg445_1, (4096, ), (1, )) | |
assert_size_stride(arg446_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg447_1, (11008, ), (1, )) | |
assert_size_stride(arg448_1, (11008, 768), (768, 1)) | |
assert_size_stride(arg449_1, (11008, ), (1, )) | |
assert_size_stride(arg450_1, (4096, 2064), (2064, 1)) | |
assert_size_stride(arg451_1, (4096, ), (1, )) | |
assert_size_stride(arg452_1, (32000, 768), (768, 1)) | |
assert_size_stride(arg453_1, (32000, ), (1, )) | |
assert_size_stride(arg454_1, (1, ), (1, )) | |
assert_size_stride(arg455_1, (1, 1), (1, 1)) | |
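# Note on the argument layout (inferred from the shapes asserted above, not
# from the original model source): FP6-quantized weights are passed as packed
# int32 tensors of shape (N, K * 6 / 32) plus a per-output-channel fp16 scale
# of shape (N,). For example, a (12288, 4096) fused QKV projection packs to
# (12288, 4096 * 6 / 32) = (12288, 768), and a (4096, 11008) down projection
# packs to (4096, 11008 * 6 / 32) = (4096, 2064). The (1, 32, 208, 128)
# tensors look like per-layer KV caches (batch, n_heads, max_cached_seq_len,
# head_dim), where 208 is presumably the padded max sequence length chosen at
# compile time; arg454_1 looks like the current input position and arg455_1
# like the single (1, 1) token id for one decode step.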
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16) | |
# Source Nodes: [float_1, mean, mul, out, x], Original ATen: [aten._to_copy, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_embedding_fp16act_fp6weight_linear_mean_mul_0.run(arg455_1, arg65_1, arg0_1, buf1, 1, 4096, grid=grid(1), stream=stream0) | |
del arg0_1 | |
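# This first Triton kernel appears to fuse the token embedding lookup
# (aten.embedding) with a reduction consistent with RMSNorm (mean of x^2 in
# fp32, then x * rsqrt(mean + eps) * weight) and the cast to fp16, producing
# in one pass the normalized activation buf1 that feeds the QKV projection.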
# Source Nodes: [out], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf2 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf1, arg68_1, arg69_1, 1) | |
del arg68_1 | |
del arg69_1 | |
buf3 = buf2 | |
del buf2 | |
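# Every matmul in this graph goes through the custom torchao op
#   torch.ops.torchao.fp16act_fp6weight_linear(act, packed_w, scales, splitK)
# (the FP6-LLM kernel). Conceptually, ignoring the bit-packing details, each
# call computes something like
#   out[m, n] = sum_k act[m, k] * dequant_fp6(packed_w[n, :])[k] * scales[n]
# The trailing integer is the split-K factor of the GEMM: in this trace it is
# 1 for the 4096 -> {12288, 11008} projections and 13 for the
# {4096, 11008} -> 4096 ones, presumably tuned per output shape for the
# M = 1 decode case.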
buf5 = empty_strided_cuda((32, 1, 128), (128, 4096, 1), torch.float32) | |
# Source Nodes: [setitem, setitem_1, y], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf3, arg66_1, arg70_1, buf5, arg71_1, 4096, grid=grid(4096), stream=stream0) | |
del buf3 | |
buf6 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32) | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf5, arg70_1, buf6, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg70_1 | |
buf10 = empty_strided_cuda((32, 1, 208), (208, 6656, 1), torch.float32) | |
# Source Nodes: [mask, y], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf6, arg454_1, arg67_1, buf10, 32, 208, grid=grid(32), stream=stream0) | |
buf11 = empty_strided_cuda((32, 1, 128, 2), (256, 8192, 1, 128), torch.float32) | |
# Source Nodes: [y], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf10, arg71_1, buf11, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg71_1 | |
buf13 = buf1; del buf1 # reuse | |
# Source Nodes: [out_1, y], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf11, buf13, 4096, 2, grid=grid(4096), stream=stream0) | |
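# The block above is one decode step of attention expressed over the KV
# cache: kernel _1 takes the fresh QKV (buf3), applies what looks like the
# rotary embedding (arg66_1), index_put-s the new k/v at input_pos (arg454_1)
# into the (1, 32, 208, 128) caches, and writes q per head into buf5; kernel
# _2 computes q @ k^T over all 208 cache slots per head; kernel _3 masks
# invalid positions via arg67_1 and softmaxes; kernel _4 multiplies by v with
# the 208-long reduction split into two chunks of 104; and kernel _5 sums the
# two partial results into the (1, 4096) fp16 activation for the output
# projection below.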
# Source Nodes: [out_1], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf14 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf13, arg72_1, arg73_1, 13) | |
del arg72_1 | |
del arg73_1 | |
buf15 = buf14 | |
del buf14 | |
buf17 = reinterpret_tensor(buf13, (1, 1, 4096), (4096, 4096, 1), 0); del buf13 # reuse | |
# Source Nodes: [add_4, float_4, h, mean_1, mul_11, mul_12, mul_13, output_1, rsqrt_1, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_embedding_mean_mul_rsqrt_6.run(arg455_1, arg65_1, buf15, arg1_1, buf17, 1, 4096, grid=grid(1), stream=stream0) | |
del arg1_1 | |
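# Inductor's memory planner recycles buffers aggressively throughout this
# graph: `del bufN # reuse` hands a dead buffer's storage to the next
# allocation, and reinterpret_tensor reshapes it without a copy, so the whole
# multi-layer decode step appears to run in a handful of persistent scratch
# buffers.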
# Source Nodes: [out_2], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf18 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg74_1, arg75_1, 1) | |
del arg74_1 | |
del arg75_1 | |
buf19 = buf18 | |
del buf18 | |
# Source Nodes: [out_3], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf20 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0), arg76_1, arg77_1, 1) | |
del arg76_1 | |
del arg77_1 | |
buf21 = buf20 | |
del buf20 | |
buf22 = buf19; del buf19 # reuse | |
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf22, buf21, 11008, grid=grid(11008), stream=stream0) | |
del buf21 | |
# Source Nodes: [out_4], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf23 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf22, arg78_1, arg79_1, 13) | |
del arg78_1 | |
del arg79_1 | |
del buf22 | |
buf24 = buf23 | |
del buf23 | |
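# The feed-forward block is a SwiGLU MLP: out_2 and out_3 are the gate and up
# projections to 11008, kernel _7 fuses what appears to be silu(gate) * up in
# place, and out_4 projects back to 4096 with the packed FP6 down weight.
# Roughly:
#   h = w2( silu(w1(x)) * w3(x) )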
buf26 = reinterpret_tensor(buf17, (1, 4096), (4096, 1), 0); del buf17 # reuse | |
# Source Nodes: [float_5, h, mean_2, mul_15, out_5, out_6, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_8.run(arg455_1, arg65_1, buf15, buf24, arg2_1, buf26, 1, 4096, grid=grid(1), stream=stream0) | |
del arg2_1 | |
# Source Nodes: [out_6], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf27 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf26, arg80_1, arg81_1, 1) | |
del arg80_1 | |
del arg81_1 | |
buf28 = buf27 | |
del buf27 | |
buf30 = buf5; del buf5 # reuse | |
# Source Nodes: [setitem_2, setitem_3, y_3], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf28, arg66_1, arg82_1, buf30, arg83_1, 4096, grid=grid(4096), stream=stream0) | |
del buf28 | |
buf31 = buf10; del buf10 # reuse | |
# Source Nodes: [y_3], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf30, arg82_1, buf31, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg82_1 | |
buf35 = buf6; del buf6 # reuse | |
# Source Nodes: [mask, y_3], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf31, arg454_1, arg67_1, buf35, 32, 208, grid=grid(32), stream=stream0) | |
buf36 = buf11; del buf11 # reuse | |
# Source Nodes: [y_3], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf35, arg83_1, buf36, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg83_1 | |
buf38 = buf26; del buf26 # reuse | |
# Source Nodes: [out_7, y_3], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf36, buf38, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_7], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf39 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf38, arg84_1, arg85_1, 13) | |
del arg84_1 | |
del arg85_1 | |
buf40 = buf39 | |
del buf39 | |
buf41 = reinterpret_tensor(buf15, (1, 1, 4096), (4096, 4096, 1), 0); del buf15 # reuse | |
buf43 = buf38; del buf38 # reuse | |
buf46 = empty_strided_cuda((1, 4096), (4096, 1), torch.float16) | |
# Source Nodes: [float_8, h, h_1, mean_3, mul_26, out_5, out_8, out_9, x], Original ATen: [aten._to_copy, aten.add, aten.embedding, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_embedding_fp16act_fp6weight_linear_mean_mul_9.run(buf41, arg455_1, arg65_1, buf24, buf40, arg3_1, buf43, buf46, 1, 4096, grid=grid(1), stream=stream0) | |
del arg3_1 | |
del arg455_1 | |
del arg65_1 | |
del buf24 | |
del buf40 | |
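# From here on, the residual adds are fused with the next RMSNorm: kernel _9
# (and the _13/_15 variants later) accumulates the running residual stream in
# place (buf41) and emits two identical-looking normalized fp16 copies
# (buf43, buf46), one for each of the two parallel gate/up FP6 linears that
# follow, so no separate add or norm kernel is ever launched.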
# Source Nodes: [out_8], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf44 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf43, arg86_1, arg87_1, 1) | |
del arg86_1 | |
del arg87_1 | |
buf45 = buf44 | |
del buf44 | |
# Source Nodes: [out_9], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf47 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf46, arg88_1, arg89_1, 1) | |
del arg88_1 | |
del arg89_1 | |
buf48 = buf47 | |
del buf47 | |
buf49 = buf45; del buf45 # reuse | |
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf49, buf48, 11008, grid=grid(11008), stream=stream0) | |
del buf48 | |
# Source Nodes: [out_10], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf50 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf49, arg90_1, arg91_1, 13) | |
del arg90_1 | |
del arg91_1 | |
del buf49 | |
buf51 = buf50 | |
del buf50 | |
buf53 = buf46; del buf46 # reuse | |
# Source Nodes: [float_9, mean_4, mul_30, out_11, out_12], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf41, buf51, arg4_1, buf53, 1, 4096, grid=grid(1), stream=stream0) | |
del arg4_1 | |
# Source Nodes: [out_12], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf54 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf53, arg92_1, arg93_1, 1) | |
del arg92_1 | |
del arg93_1 | |
buf55 = buf54 | |
del buf54 | |
buf57 = buf30; del buf30 # reuse | |
# Source Nodes: [setitem_4, setitem_5, y_6], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf55, arg66_1, arg94_1, buf57, arg95_1, 4096, grid=grid(4096), stream=stream0) | |
del buf55 | |
buf58 = buf35; del buf35 # reuse | |
# Source Nodes: [y_6], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf57, arg94_1, buf58, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg94_1 | |
buf62 = buf31; del buf31 # reuse | |
# Source Nodes: [mask, y_6], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf58, arg454_1, arg67_1, buf62, 32, 208, grid=grid(32), stream=stream0) | |
buf63 = buf36; del buf36 # reuse | |
# Source Nodes: [y_6], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf62, arg95_1, buf63, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg95_1 | |
buf65 = buf53; del buf53 # reuse | |
# Source Nodes: [out_13, y_6], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf63, buf65, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_13], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf66 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf65, arg96_1, arg97_1, 13) | |
del arg96_1 | |
del arg97_1 | |
buf67 = buf66 | |
del buf66 | |
buf69 = reinterpret_tensor(buf65, (1, 1, 4096), (4096, 4096, 1), 0); del buf65 # reuse | |
# Source Nodes: [add_16, float_12, h_2, mean_5, mul_41, mul_42, mul_43, out_11, output_5, rsqrt_5], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf41, buf51, buf67, arg5_1, buf69, 1, 4096, grid=grid(1), stream=stream0) | |
del arg5_1 | |
# Source Nodes: [out_14], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf70 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg98_1, arg99_1, 1) | |
del arg98_1 | |
del arg99_1 | |
buf71 = buf70 | |
del buf70 | |
# Source Nodes: [out_15], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf72 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0), arg100_1, arg101_1, 1) | |
del arg100_1 | |
del arg101_1 | |
buf73 = buf72 | |
del buf72 | |
buf74 = buf71; del buf71 # reuse | |
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf74, buf73, 11008, grid=grid(11008), stream=stream0) | |
del buf73 | |
# Source Nodes: [out_16], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf75 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf74, arg102_1, arg103_1, 13) | |
del arg102_1 | |
del arg103_1 | |
del buf74 | |
buf76 = buf75 | |
del buf75 | |
buf78 = reinterpret_tensor(buf69, (1, 4096), (4096, 1), 0); del buf69 # reuse | |
# Source Nodes: [float_13, h_2, mean_6, mul_45, out_11, out_17, out_18], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf41, buf51, buf67, buf76, arg6_1, buf78, 1, 4096, grid=grid(1), stream=stream0) | |
del arg6_1 | |
# Source Nodes: [out_18], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf79 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf78, arg104_1, arg105_1, 1) | |
del arg104_1 | |
del arg105_1 | |
buf80 = buf79 | |
del buf79 | |
buf82 = buf57; del buf57 # reuse | |
# Source Nodes: [setitem_6, setitem_7, y_9], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf80, arg66_1, arg106_1, buf82, arg107_1, 4096, grid=grid(4096), stream=stream0) | |
del buf80 | |
buf83 = buf62; del buf62 # reuse | |
# Source Nodes: [y_9], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf82, arg106_1, buf83, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg106_1 | |
buf87 = buf58; del buf58 # reuse | |
# Source Nodes: [mask, y_9], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf83, arg454_1, arg67_1, buf87, 32, 208, grid=grid(32), stream=stream0) | |
buf88 = buf63; del buf63 # reuse | |
# Source Nodes: [y_9], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf87, arg107_1, buf88, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg107_1 | |
buf90 = buf78; del buf78 # reuse | |
# Source Nodes: [out_19, y_9], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf88, buf90, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_19], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf91 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf90, arg108_1, arg109_1, 13) | |
del arg108_1 | |
del arg109_1 | |
buf92 = buf91 | |
del buf91 | |
buf93 = buf41; del buf41 # reuse | |
buf95 = buf90; del buf90 # reuse | |
buf98 = buf43; del buf43 # reuse | |
# Source Nodes: [float_16, h_2, h_3, mean_7, mul_56, out_11, out_17, out_20, out_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf93, buf51, buf67, buf76, buf92, arg7_1, buf95, buf98, 1, 4096, grid=grid(1), stream=stream0) | |
del arg7_1 | |
del buf51 | |
del buf67 | |
del buf76 | |
del buf92 | |
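# The remainder of the file repeats this exact per-layer pattern (fused
# norm -> QKV FP6 linear -> KV-cache attention -> output FP6 linear -> fused
# residual+norm -> SwiGLU FP6 MLP) for every remaining transformer layer;
# only the argN_1 weight indices and the buffer numbers advance.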
# Source Nodes: [out_20], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf96 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf95, arg110_1, arg111_1, 1) | |
del arg110_1 | |
del arg111_1 | |
buf97 = buf96 | |
del buf96 | |
# Source Nodes: [out_21], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf99 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf98, arg112_1, arg113_1, 1) | |
del arg112_1 | |
del arg113_1 | |
buf100 = buf99 | |
del buf99 | |
buf101 = buf100; del buf100 # reuse | |
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_14.run(buf101, buf97, 11008, grid=grid(11008), stream=stream0) | |
del buf97 | |
# Source Nodes: [out_22], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf102 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf101, arg114_1, arg115_1, 13) | |
del arg114_1 | |
del arg115_1 | |
del buf101 | |
buf103 = buf102 | |
del buf102 | |
buf105 = buf98; del buf98 # reuse | |
# Source Nodes: [float_17, mean_8, mul_60, out_23, out_24], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf93, buf103, arg8_1, buf105, 1, 4096, grid=grid(1), stream=stream0) | |
del arg8_1 | |
# Source Nodes: [out_24], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf106 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf105, arg116_1, arg117_1, 1) | |
del arg116_1 | |
del arg117_1 | |
buf107 = buf106 | |
del buf106 | |
buf109 = buf82; del buf82 # reuse | |
# Source Nodes: [setitem_8, setitem_9, y_12], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf107, arg66_1, arg118_1, buf109, arg119_1, 4096, grid=grid(4096), stream=stream0) | |
del buf107 | |
buf110 = buf87; del buf87 # reuse | |
# Source Nodes: [y_12], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf109, arg118_1, buf110, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg118_1 | |
buf114 = buf83; del buf83 # reuse | |
# Source Nodes: [mask, y_12], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf110, arg454_1, arg67_1, buf114, 32, 208, grid=grid(32), stream=stream0) | |
buf115 = buf88; del buf88 # reuse | |
# Source Nodes: [y_12], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf114, arg119_1, buf115, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg119_1 | |
buf117 = buf105; del buf105 # reuse | |
# Source Nodes: [out_25, y_12], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf115, buf117, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_25], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf118 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf117, arg120_1, arg121_1, 13) | |
del arg120_1 | |
del arg121_1 | |
buf119 = buf118 | |
del buf118 | |
buf121 = reinterpret_tensor(buf117, (1, 1, 4096), (4096, 4096, 1), 0); del buf117 # reuse | |
# Source Nodes: [add_28, float_20, h_4, mean_9, mul_71, mul_72, mul_73, out_23, output_9, rsqrt_9], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf93, buf103, buf119, arg9_1, buf121, 1, 4096, grid=grid(1), stream=stream0) | |
del arg9_1 | |
# Source Nodes: [out_26], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf122 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg122_1, arg123_1, 1) | |
del arg122_1 | |
del arg123_1 | |
buf123 = buf122 | |
del buf122 | |
# Source Nodes: [out_27], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf124 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0), arg124_1, arg125_1, 1) | |
del arg124_1 | |
del arg125_1 | |
buf125 = buf124 | |
del buf124 | |
buf126 = buf123; del buf123 # reuse | |
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf126, buf125, 11008, grid=grid(11008), stream=stream0) | |
del buf125 | |
# Source Nodes: [out_28], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf127 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf126, arg126_1, arg127_1, 13) | |
del arg126_1 | |
del arg127_1 | |
del buf126 | |
buf128 = buf127 | |
del buf127 | |
buf130 = reinterpret_tensor(buf121, (1, 4096), (4096, 1), 0); del buf121 # reuse | |
# Source Nodes: [float_21, h_4, mean_10, mul_75, out_23, out_29, out_30], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf93, buf103, buf119, buf128, arg10_1, buf130, 1, 4096, grid=grid(1), stream=stream0) | |
del arg10_1 | |
# Source Nodes: [out_30], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf131 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf130, arg128_1, arg129_1, 1) | |
del arg128_1 | |
del arg129_1 | |
buf132 = buf131 | |
del buf131 | |
buf134 = buf109; del buf109 # reuse | |
# Source Nodes: [setitem_10, setitem_11, y_15], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf132, arg66_1, arg130_1, buf134, arg131_1, 4096, grid=grid(4096), stream=stream0) | |
del buf132 | |
buf135 = buf114; del buf114 # reuse | |
# Source Nodes: [y_15], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf134, arg130_1, buf135, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg130_1 | |
buf139 = buf110; del buf110 # reuse | |
# Source Nodes: [mask, y_15], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf135, arg454_1, arg67_1, buf139, 32, 208, grid=grid(32), stream=stream0) | |
buf140 = buf115; del buf115 # reuse | |
# Source Nodes: [y_15], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf139, arg131_1, buf140, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg131_1 | |
buf142 = buf130; del buf130 # reuse | |
# Source Nodes: [out_31, y_15], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf140, buf142, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_31], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf143 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf142, arg132_1, arg133_1, 13) | |
del arg132_1 | |
del arg133_1 | |
buf144 = buf143 | |
del buf143 | |
buf145 = reinterpret_tensor(buf103, (1, 1, 4096), (4096, 4096, 1), 0); del buf103 # reuse | |
buf147 = buf142; del buf142 # reuse | |
buf150 = buf95; del buf95 # reuse | |
# Source Nodes: [float_24, h_4, h_5, mean_11, mul_86, out_23, out_29, out_32, out_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_15.run(buf145, buf93, buf119, buf128, buf144, arg11_1, buf147, buf150, 1, 4096, grid=grid(1), stream=stream0) | |
del arg11_1 | |
del buf119 | |
del buf128 | |
del buf144 | |
del buf93 | |
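# For orientation, a rough eager-mode equivalent of one layer of this trace
# (hypothetical helper names; the compiled module itself only exposes
# call(args)):
#
#   a = rmsnorm(h, attn_norm_w)
#   h = h + attn(a, kv_cache, input_pos)        # the y_N / out_N kernels
#   f = rmsnorm(h, ffn_norm_w)
#   h = h + w2(silu(w1(f)) * w3(f))             # FP6 linears + kernel _7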
# Source Nodes: [out_32], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf148 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf147, arg134_1, arg135_1, 1) | |
del arg134_1 | |
del arg135_1 | |
buf149 = buf148 | |
del buf148 | |
# Source Nodes: [out_33], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf151 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf150, arg136_1, arg137_1, 1) | |
del arg136_1 | |
del arg137_1 | |
buf152 = buf151 | |
del buf151 | |
buf153 = buf149; del buf149 # reuse | |
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf153, buf152, 11008, grid=grid(11008), stream=stream0) | |
del buf152 | |
# Source Nodes: [out_34], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf154 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf153, arg138_1, arg139_1, 13) | |
del arg138_1 | |
del arg139_1 | |
del buf153 | |
buf155 = buf154 | |
del buf154 | |
buf157 = buf150; del buf150 # reuse | |
# Source Nodes: [float_25, mean_12, mul_90, out_35, out_36], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf145, buf155, arg12_1, buf157, 1, 4096, grid=grid(1), stream=stream0) | |
del arg12_1 | |
# Source Nodes: [out_36], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf158 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf157, arg140_1, arg141_1, 1) | |
del arg140_1 | |
del arg141_1 | |
buf159 = buf158 | |
del buf158 | |
buf161 = buf134; del buf134 # reuse | |
# Source Nodes: [setitem_12, setitem_13, y_18], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf159, arg66_1, arg142_1, buf161, arg143_1, 4096, grid=grid(4096), stream=stream0) | |
del buf159 | |
buf162 = buf139; del buf139 # reuse | |
# Source Nodes: [y_18], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf161, arg142_1, buf162, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg142_1 | |
buf166 = buf135; del buf135 # reuse | |
# Source Nodes: [mask, y_18], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf162, arg454_1, arg67_1, buf166, 32, 208, grid=grid(32), stream=stream0) | |
buf167 = buf140; del buf140 # reuse | |
# Source Nodes: [y_18], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf166, arg143_1, buf167, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg143_1 | |
buf169 = buf157; del buf157 # reuse | |
# Source Nodes: [out_37, y_18], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf167, buf169, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_37], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf170 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf169, arg144_1, arg145_1, 13) | |
del arg144_1 | |
del arg145_1 | |
buf171 = buf170 | |
del buf170 | |
buf173 = reinterpret_tensor(buf169, (1, 1, 4096), (4096, 4096, 1), 0); del buf169 # reuse | |
# Source Nodes: [add_40, float_28, h_6, mean_13, mul_101, mul_102, mul_103, out_35, output_13, rsqrt_13], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf145, buf155, buf171, arg13_1, buf173, 1, 4096, grid=grid(1), stream=stream0) | |
del arg13_1 | |
# Source Nodes: [out_38], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf174 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg146_1, arg147_1, 1) | |
del arg146_1 | |
del arg147_1 | |
buf175 = buf174 | |
del buf174 | |
# Source Nodes: [out_39], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf176 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0), arg148_1, arg149_1, 1) | |
del arg148_1 | |
del arg149_1 | |
buf177 = buf176 | |
del buf176 | |
buf178 = buf175; del buf175 # reuse | |
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf178, buf177, 11008, grid=grid(11008), stream=stream0) | |
del buf177 | |
# Source Nodes: [out_40], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf179 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf178, arg150_1, arg151_1, 13) | |
del arg150_1 | |
del arg151_1 | |
del buf178 | |
buf180 = buf179 | |
del buf179 | |
buf182 = reinterpret_tensor(buf173, (1, 4096), (4096, 1), 0); del buf173 # reuse | |
# Source Nodes: [float_29, h_6, mean_14, mul_105, out_35, out_41, out_42], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf145, buf155, buf171, buf180, arg14_1, buf182, 1, 4096, grid=grid(1), stream=stream0) | |
del arg14_1 | |
# Source Nodes: [out_42], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf183 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf182, arg152_1, arg153_1, 1) | |
del arg152_1 | |
del arg153_1 | |
buf184 = buf183 | |
del buf183 | |
buf186 = buf161; del buf161 # reuse | |
# Source Nodes: [setitem_14, setitem_15, y_21], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf184, arg66_1, arg154_1, buf186, arg155_1, 4096, grid=grid(4096), stream=stream0) | |
del buf184 | |
buf187 = buf166; del buf166 # reuse | |
# Source Nodes: [y_21], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf186, arg154_1, buf187, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg154_1 | |
buf191 = buf162; del buf162 # reuse | |
# Source Nodes: [mask, y_21], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf187, arg454_1, arg67_1, buf191, 32, 208, grid=grid(32), stream=stream0) | |
buf192 = buf167; del buf167 # reuse | |
# Source Nodes: [y_21], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf191, arg155_1, buf192, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg155_1 | |
buf194 = buf182; del buf182 # reuse | |
# Source Nodes: [out_43, y_21], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf192, buf194, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_43], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf195 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf194, arg156_1, arg157_1, 13) | |
del arg156_1 | |
del arg157_1 | |
buf196 = buf195 | |
del buf195 | |
buf197 = buf145; del buf145 # reuse | |
buf199 = buf194; del buf194 # reuse | |
buf202 = buf147; del buf147 # reuse | |
# Source Nodes: [float_32, h_6, h_7, mean_15, mul_116, out_35, out_41, out_44, out_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf197, buf155, buf171, buf180, buf196, arg15_1, buf199, buf202, 1, 4096, grid=grid(1), stream=stream0) | |
del arg15_1 | |
del buf155 | |
del buf171 | |
del buf180 | |
del buf196 | |
# Source Nodes: [out_44], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf200 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf199, arg158_1, arg159_1, 1) | |
del arg158_1 | |
del arg159_1 | |
buf201 = buf200 | |
del buf200 | |
# Source Nodes: [out_45], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf203 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf202, arg160_1, arg161_1, 1) | |
del arg160_1 | |
del arg161_1 | |
buf204 = buf203 | |
del buf203 | |
buf205 = buf201; del buf201 # reuse | |
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf205, buf204, 11008, grid=grid(11008), stream=stream0) | |
del buf204 | |
# Source Nodes: [out_46], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf206 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf205, arg162_1, arg163_1, 13) | |
del arg162_1 | |
del arg163_1 | |
del buf205 | |
buf207 = buf206 | |
del buf206 | |
buf209 = buf202; del buf202 # reuse | |
# Source Nodes: [float_33, mean_16, mul_120, out_47, out_48], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf197, buf207, arg16_1, buf209, 1, 4096, grid=grid(1), stream=stream0) | |
del arg16_1 | |
# Source Nodes: [out_48], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf210 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf209, arg164_1, arg165_1, 1) | |
del arg164_1 | |
del arg165_1 | |
buf211 = buf210 | |
del buf210 | |
buf213 = buf186; del buf186 # reuse | |
# Source Nodes: [setitem_16, setitem_17, y_24], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf211, arg66_1, arg166_1, buf213, arg167_1, 4096, grid=grid(4096), stream=stream0) | |
del buf211 | |
buf214 = buf191; del buf191 # reuse | |
# Source Nodes: [y_24], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf213, arg166_1, buf214, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg166_1 | |
buf218 = buf187; del buf187 # reuse | |
# Source Nodes: [mask, y_24], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf214, arg454_1, arg67_1, buf218, 32, 208, grid=grid(32), stream=stream0) | |
buf219 = buf192; del buf192 # reuse | |
# Source Nodes: [y_24], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf218, arg167_1, buf219, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg167_1 | |
buf221 = buf209; del buf209 # reuse | |
# Source Nodes: [out_49, y_24], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf219, buf221, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_49], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf222 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf221, arg168_1, arg169_1, 13) | |
del arg168_1 | |
del arg169_1 | |
buf223 = buf222 | |
del buf222 | |
buf225 = reinterpret_tensor(buf221, (1, 1, 4096), (4096, 4096, 1), 0); del buf221 # reuse | |
# Source Nodes: [add_52, float_36, h_8, mean_17, mul_131, mul_132, mul_133, out_47, output_17, rsqrt_17], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf197, buf207, buf223, arg17_1, buf225, 1, 4096, grid=grid(1), stream=stream0) | |
del arg17_1 | |
# Source Nodes: [out_50], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf226 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg170_1, arg171_1, 1) | |
del arg170_1 | |
del arg171_1 | |
buf227 = buf226 | |
del buf226 | |
# Source Nodes: [out_51], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf228 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0), arg172_1, arg173_1, 1) | |
del arg172_1 | |
del arg173_1 | |
buf229 = buf228 | |
del buf228 | |
buf230 = buf227; del buf227 # reuse | |
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf230, buf229, 11008, grid=grid(11008), stream=stream0) | |
del buf229 | |
# Source Nodes: [out_52], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf231 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf230, arg174_1, arg175_1, 13) | |
del arg174_1 | |
del arg175_1 | |
del buf230 | |
buf232 = buf231 | |
del buf231 | |
buf234 = reinterpret_tensor(buf225, (1, 4096), (4096, 1), 0); del buf225 # reuse | |
# Source Nodes: [float_37, h_8, mean_18, mul_135, out_47, out_53, out_54], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf197, buf207, buf223, buf232, arg18_1, buf234, 1, 4096, grid=grid(1), stream=stream0) | |
del arg18_1 | |
# Source Nodes: [out_54], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf235 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf234, arg176_1, arg177_1, 1) | |
del arg176_1 | |
del arg177_1 | |
buf236 = buf235 | |
del buf235 | |
buf238 = buf213; del buf213 # reuse | |
# Source Nodes: [setitem_18, setitem_19, y_27], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf236, arg66_1, arg178_1, buf238, arg179_1, 4096, grid=grid(4096), stream=stream0) | |
del buf236 | |
buf239 = buf218; del buf218 # reuse | |
# Source Nodes: [y_27], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf238, arg178_1, buf239, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg178_1 | |
buf243 = buf214; del buf214 # reuse | |
# Source Nodes: [mask, y_27], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf239, arg454_1, arg67_1, buf243, 32, 208, grid=grid(32), stream=stream0) | |
buf244 = buf219; del buf219 # reuse | |
# Source Nodes: [y_27], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf243, arg179_1, buf244, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg179_1 | |
buf246 = buf234; del buf234 # reuse | |
# Source Nodes: [out_55, y_27], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf244, buf246, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_55], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf247 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf246, arg180_1, arg181_1, 13) | |
del arg180_1 | |
del arg181_1 | |
buf248 = buf247 | |
del buf247 | |
buf249 = buf197; del buf197 # reuse | |
buf251 = buf246; del buf246 # reuse | |
buf254 = buf199; del buf199 # reuse | |
# Source Nodes: [float_40, h_8, h_9, mean_19, mul_146, out_47, out_53, out_56, out_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf249, buf207, buf223, buf232, buf248, arg19_1, buf251, buf254, 1, 4096, grid=grid(1), stream=stream0) | |
del arg19_1 | |
del buf207 | |
del buf223 | |
del buf232 | |
del buf248 | |
# Source Nodes: [out_56], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf252 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf251, arg182_1, arg183_1, 1) | |
del arg182_1 | |
del arg183_1 | |
buf253 = buf252 | |
del buf252 | |
# Source Nodes: [out_57], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf255 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf254, arg184_1, arg185_1, 1) | |
del arg184_1 | |
del arg185_1 | |
buf256 = buf255 | |
del buf255 | |
buf257 = buf253; del buf253 # reuse | |
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf257, buf256, 11008, grid=grid(11008), stream=stream0) | |
del buf256 | |
# Source Nodes: [out_58], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf258 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf257, arg186_1, arg187_1, 13) | |
del arg186_1 | |
del arg187_1 | |
del buf257 | |
buf259 = buf258 | |
del buf258 | |
buf261 = buf254; del buf254 # reuse | |
# Source Nodes: [float_41, mean_20, mul_150, out_59, out_60], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf249, buf259, arg20_1, buf261, 1, 4096, grid=grid(1), stream=stream0) | |
del arg20_1 | |
# Source Nodes: [out_60], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf262 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf261, arg188_1, arg189_1, 1) | |
del arg188_1 | |
del arg189_1 | |
buf263 = buf262 | |
del buf262 | |
buf265 = buf238; del buf238 # reuse | |
# Source Nodes: [setitem_20, setitem_21, y_30], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf263, arg66_1, arg190_1, buf265, arg191_1, 4096, grid=grid(4096), stream=stream0) | |
del buf263 | |
buf266 = buf243; del buf243 # reuse | |
# Source Nodes: [y_30], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf265, arg190_1, buf266, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg190_1 | |
buf270 = buf239; del buf239 # reuse | |
# Source Nodes: [mask, y_30], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf266, arg454_1, arg67_1, buf270, 32, 208, grid=grid(32), stream=stream0) | |
buf271 = buf244; del buf244 # reuse | |
# Source Nodes: [y_30], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf270, arg191_1, buf271, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg191_1 | |
buf273 = buf261; del buf261 # reuse | |
# Source Nodes: [out_61, y_30], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf271, buf273, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_61], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf274 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf273, arg192_1, arg193_1, 13) | |
del arg192_1 | |
del arg193_1 | |
buf275 = buf274 | |
del buf274 | |
buf277 = reinterpret_tensor(buf273, (1, 1, 4096), (4096, 4096, 1), 0); del buf273 # reuse | |
# Source Nodes: [add_64, float_44, h_10, mean_21, mul_161, mul_162, mul_163, out_59, output_21, rsqrt_21], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf249, buf259, buf275, arg21_1, buf277, 1, 4096, grid=grid(1), stream=stream0) | |
del arg21_1 | |
# Source Nodes: [out_62], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf278 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg194_1, arg195_1, 1) | |
del arg194_1 | |
del arg195_1 | |
buf279 = buf278 | |
del buf278 | |
# Source Nodes: [out_63], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf280 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0), arg196_1, arg197_1, 1) | |
del arg196_1 | |
del arg197_1 | |
buf281 = buf280 | |
del buf280 | |
buf282 = buf279; del buf279 # reuse | |
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf282, buf281, 11008, grid=grid(11008), stream=stream0) | |
del buf281 | |
# Source Nodes: [out_64], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf283 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf282, arg198_1, arg199_1, 13) | |
del arg198_1 | |
del arg199_1 | |
del buf282 | |
buf284 = buf283 | |
del buf283 | |
buf286 = reinterpret_tensor(buf277, (1, 4096), (4096, 1), 0); del buf277 # reuse | |
# Source Nodes: [float_45, h_10, mean_22, mul_165, out_59, out_65, out_66], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf249, buf259, buf275, buf284, arg22_1, buf286, 1, 4096, grid=grid(1), stream=stream0) | |
del arg22_1 | |
# Source Nodes: [out_66], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf287 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf286, arg200_1, arg201_1, 1) | |
del arg200_1 | |
del arg201_1 | |
buf288 = buf287 | |
del buf287 | |
buf290 = buf265; del buf265 # reuse | |
# Source Nodes: [setitem_22, setitem_23, y_33], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf288, arg66_1, arg202_1, buf290, arg203_1, 4096, grid=grid(4096), stream=stream0) | |
del buf288 | |
buf291 = buf270; del buf270 # reuse | |
# Source Nodes: [y_33], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf290, arg202_1, buf291, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg202_1 | |
buf295 = buf266; del buf266 # reuse | |
# Source Nodes: [mask, y_33], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf291, arg454_1, arg67_1, buf295, 32, 208, grid=grid(32), stream=stream0) | |
buf296 = buf271; del buf271 # reuse | |
# Source Nodes: [y_33], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf295, arg203_1, buf296, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg203_1 | |
buf298 = buf286; del buf286 # reuse | |
# Source Nodes: [out_67, y_33], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf296, buf298, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_67], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf299 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf298, arg204_1, arg205_1, 13) | |
del arg204_1 | |
del arg205_1 | |
buf300 = buf299 | |
del buf299 | |
buf301 = buf249; del buf249 # reuse | |
buf303 = buf298; del buf298 # reuse | |
buf306 = buf251; del buf251 # reuse | |
# Source Nodes: [float_48, h_10, h_11, mean_23, mul_176, out_59, out_65, out_68, out_69], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf301, buf259, buf275, buf284, buf300, arg23_1, buf303, buf306, 1, 4096, grid=grid(1), stream=stream0) | |
del arg23_1 | |
del buf259 | |
del buf275 | |
del buf284 | |
del buf300 | |
# Source Nodes: [out_68], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf304 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf303, arg206_1, arg207_1, 1) | |
del arg206_1 | |
del arg207_1 | |
buf305 = buf304 | |
del buf304 | |
# Source Nodes: [out_69], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf307 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf306, arg208_1, arg209_1, 1) | |
del arg208_1 | |
del arg209_1 | |
buf308 = buf307 | |
del buf307 | |
buf309 = buf305; del buf305 # reuse | |
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf309, buf308, 11008, grid=grid(11008), stream=stream0) | |
del buf308 | |
# Source Nodes: [out_70], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf310 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf309, arg210_1, arg211_1, 13) | |
del arg210_1 | |
del arg211_1 | |
del buf309 | |
buf311 = buf310 | |
del buf310 | |
buf313 = buf306; del buf306 # reuse | |
# Source Nodes: [float_49, mean_24, mul_180, out_71, out_72], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf301, buf311, arg24_1, buf313, 1, 4096, grid=grid(1), stream=stream0) | |
del arg24_1 | |
# Source Nodes: [out_72], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf314 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf313, arg212_1, arg213_1, 1) | |
del arg212_1 | |
del arg213_1 | |
buf315 = buf314 | |
del buf314 | |
buf317 = buf290; del buf290 # reuse | |
# Source Nodes: [setitem_24, setitem_25, y_36], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf315, arg66_1, arg214_1, buf317, arg215_1, 4096, grid=grid(4096), stream=stream0) | |
del buf315 | |
buf318 = buf295; del buf295 # reuse | |
# Source Nodes: [y_36], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf317, arg214_1, buf318, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg214_1 | |
buf322 = buf291; del buf291 # reuse | |
# Source Nodes: [mask, y_36], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf318, arg454_1, arg67_1, buf322, 32, 208, grid=grid(32), stream=stream0) | |
buf323 = buf296; del buf296 # reuse | |
# Source Nodes: [y_36], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf322, arg215_1, buf323, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg215_1 | |
buf325 = buf313; del buf313 # reuse | |
# Source Nodes: [out_73, y_36], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf323, buf325, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_73], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf326 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf325, arg216_1, arg217_1, 13) | |
del arg216_1 | |
del arg217_1 | |
buf327 = buf326 | |
del buf326 | |
buf329 = reinterpret_tensor(buf325, (1, 1, 4096), (4096, 4096, 1), 0); del buf325 # reuse | |
# Source Nodes: [add_76, float_52, h_12, mean_25, mul_191, mul_192, mul_193, out_71, output_25, rsqrt_25], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf301, buf311, buf327, arg25_1, buf329, 1, 4096, grid=grid(1), stream=stream0) | |
del arg25_1 | |
# Source Nodes: [out_74], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf330 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg218_1, arg219_1, 1) | |
del arg218_1 | |
del arg219_1 | |
buf331 = buf330 | |
del buf330 | |
# Source Nodes: [out_75], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf332 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0), arg220_1, arg221_1, 1) | |
del arg220_1 | |
del arg221_1 | |
buf333 = buf332 | |
del buf332 | |
buf334 = buf331; del buf331 # reuse | |
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf334, buf333, 11008, grid=grid(11008), stream=stream0) | |
del buf333 | |
# Source Nodes: [out_76], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf335 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf334, arg222_1, arg223_1, 13) | |
del arg222_1 | |
del arg223_1 | |
del buf334 | |
buf336 = buf335 | |
del buf335 | |
buf338 = reinterpret_tensor(buf329, (1, 4096), (4096, 1), 0); del buf329 # reuse | |
# Source Nodes: [float_53, h_12, mean_26, mul_195, out_71, out_77, out_78], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf301, buf311, buf327, buf336, arg26_1, buf338, 1, 4096, grid=grid(1), stream=stream0) | |
del arg26_1 | |
# Source Nodes: [out_78], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf339 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf338, arg224_1, arg225_1, 1) | |
del arg224_1 | |
del arg225_1 | |
buf340 = buf339 | |
del buf339 | |
buf342 = buf317; del buf317 # reuse | |
# Source Nodes: [setitem_26, setitem_27, y_39], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf340, arg66_1, arg226_1, buf342, arg227_1, 4096, grid=grid(4096), stream=stream0) | |
del buf340 | |
buf343 = buf322; del buf322 # reuse | |
# Source Nodes: [y_39], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf342, arg226_1, buf343, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg226_1 | |
buf347 = buf318; del buf318 # reuse | |
# Source Nodes: [mask, y_39], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf343, arg454_1, arg67_1, buf347, 32, 208, grid=grid(32), stream=stream0) | |
buf348 = buf323; del buf323 # reuse | |
# Source Nodes: [y_39], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf347, arg227_1, buf348, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg227_1 | |
buf350 = buf338; del buf338 # reuse | |
# Source Nodes: [out_79, y_39], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf348, buf350, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_79], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf351 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf350, arg228_1, arg229_1, 13) | |
del arg228_1 | |
del arg229_1 | |
buf352 = buf351 | |
del buf351 | |
buf353 = buf301; del buf301 # reuse | |
buf355 = buf350; del buf350 # reuse | |
buf358 = buf303; del buf303 # reuse | |
# Source Nodes: [float_56, h_12, h_13, mean_27, mul_206, out_71, out_77, out_80, out_81], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf353, buf311, buf327, buf336, buf352, arg27_1, buf355, buf358, 1, 4096, grid=grid(1), stream=stream0) | |
del arg27_1 | |
del buf311 | |
del buf327 | |
del buf336 | |
del buf352 | |
# Source Nodes: [out_80], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf356 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf355, arg230_1, arg231_1, 1) | |
del arg230_1 | |
del arg231_1 | |
buf357 = buf356 | |
del buf356 | |
# Source Nodes: [out_81], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf359 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf358, arg232_1, arg233_1, 1) | |
del arg232_1 | |
del arg233_1 | |
buf360 = buf359 | |
del buf359 | |
buf361 = buf357; del buf357 # reuse | |
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf361, buf360, 11008, grid=grid(11008), stream=stream0) | |
del buf360 | |
# Source Nodes: [out_82], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf362 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf361, arg234_1, arg235_1, 13) | |
del arg234_1 | |
del arg235_1 | |
del buf361 | |
buf363 = buf362 | |
del buf362 | |
buf365 = buf358; del buf358 # reuse | |
# Source Nodes: [float_57, mean_28, mul_210, out_83, out_84], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf353, buf363, arg28_1, buf365, 1, 4096, grid=grid(1), stream=stream0) | |
del arg28_1 | |
# Source Nodes: [out_84], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf366 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf365, arg236_1, arg237_1, 1) | |
del arg236_1 | |
del arg237_1 | |
buf367 = buf366 | |
del buf366 | |
buf369 = buf342; del buf342 # reuse | |
# Source Nodes: [setitem_28, setitem_29, y_42], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf367, arg66_1, arg238_1, buf369, arg239_1, 4096, grid=grid(4096), stream=stream0) | |
del buf367 | |
buf370 = buf347; del buf347 # reuse | |
# Source Nodes: [y_42], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf369, arg238_1, buf370, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg238_1 | |
buf374 = buf343; del buf343 # reuse | |
# Source Nodes: [mask, y_42], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf370, arg454_1, arg67_1, buf374, 32, 208, grid=grid(32), stream=stream0) | |
buf375 = buf348; del buf348 # reuse | |
# Source Nodes: [y_42], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf374, arg239_1, buf375, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg239_1 | |
buf377 = buf365; del buf365 # reuse | |
# Source Nodes: [out_85, y_42], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf375, buf377, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_85], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf378 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf377, arg240_1, arg241_1, 13) | |
del arg240_1 | |
del arg241_1 | |
buf379 = buf378 | |
del buf378 | |
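# reinterpret_tensor is a zero-copy view: the same 4096-element
# allocation is exposed as a (1, 1, 4096) output for the fused norm
# kernel here and re-viewed as (1, 4096) for the linear calls below.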
buf381 = reinterpret_tensor(buf377, (1, 1, 4096), (4096, 4096, 1), 0); del buf377 # reuse | |
# Source Nodes: [add_88, float_60, h_14, mean_29, mul_221, mul_222, mul_223, out_83, output_29, rsqrt_29], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf353, buf363, buf379, arg29_1, buf381, 1, 4096, grid=grid(1), stream=stream0) | |
del arg29_1 | |
# Source Nodes: [out_86], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf382 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg242_1, arg243_1, 1) | |
del arg242_1 | |
del arg243_1 | |
buf383 = buf382 | |
del buf382 | |
# Source Nodes: [out_87], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf384 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0), arg244_1, arg245_1, 1) | |
del arg244_1 | |
del arg245_1 | |
buf385 = buf384 | |
del buf384 | |
buf386 = buf383; del buf383 # reuse | |
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf386, buf385, 11008, grid=grid(11008), stream=stream0) | |
del buf385 | |
# Source Nodes: [out_88], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf387 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf386, arg246_1, arg247_1, 13) | |
del arg246_1 | |
del arg247_1 | |
del buf386 | |
buf388 = buf387 | |
del buf387 | |
buf390 = reinterpret_tensor(buf381, (1, 4096), (4096, 1), 0); del buf381 # reuse | |
# Source Nodes: [float_61, h_14, mean_30, mul_225, out_83, out_89, out_90], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf353, buf363, buf379, buf388, arg30_1, buf390, 1, 4096, grid=grid(1), stream=stream0) | |
del arg30_1 | |
# Source Nodes: [out_90], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf391 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf390, arg248_1, arg249_1, 1) | |
del arg248_1 | |
del arg249_1 | |
buf392 = buf391 | |
del buf391 | |
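# arg454_1, arg66_1 and arg67_1 recur in every attention block and are
# never freed; they are most likely the current input_pos, the rotary-
# embedding cache, and the causal mask, shared across all layers.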
buf394 = buf369; del buf369 # reuse | |
# Source Nodes: [setitem_30, setitem_31, y_45], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf392, arg66_1, arg250_1, buf394, arg251_1, 4096, grid=grid(4096), stream=stream0) | |
del buf392 | |
buf395 = buf374; del buf374 # reuse | |
# Source Nodes: [y_45], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf394, arg250_1, buf395, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg250_1 | |
buf399 = buf370; del buf370 # reuse | |
# Source Nodes: [mask, y_45], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf395, arg454_1, arg67_1, buf399, 32, 208, grid=grid(32), stream=stream0) | |
buf400 = buf375; del buf375 # reuse | |
# Source Nodes: [y_45], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf399, arg251_1, buf400, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg251_1 | |
buf402 = buf390; del buf390 # reuse | |
# Source Nodes: [out_91, y_45], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf400, buf402, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_91], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf403 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf402, arg252_1, arg253_1, 13) | |
del arg252_1 | |
del arg253_1 | |
buf404 = buf403 | |
del buf403 | |
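# Block boundary: a single fused reduction folds both residual adds
# (h = h + attn_out + mlp_out, per the source nodes listed below) into
# the accumulator buffer in place, and emits what appear to be two
# copies of the FFN-norm result, one for each of the gate/up
# projections that consume it next.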
buf405 = buf353; del buf353 # reuse | |
buf407 = buf402; del buf402 # reuse | |
buf410 = buf355; del buf355 # reuse | |
# Source Nodes: [float_64, h_14, h_15, mean_31, mul_236, out_83, out_89, out_92, out_93], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf405, buf363, buf379, buf388, buf404, arg31_1, buf407, buf410, 1, 4096, grid=grid(1), stream=stream0) | |
del arg31_1 | |
del buf363 | |
del buf379 | |
del buf388 | |
del buf404 | |
# Source Nodes: [out_92], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf408 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf407, arg254_1, arg255_1, 1) | |
del arg254_1 | |
del arg255_1 | |
buf409 = buf408 | |
del buf408 | |
# Source Nodes: [out_93], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf411 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf410, arg256_1, arg257_1, 1) | |
del arg256_1 | |
del arg257_1 | |
buf412 = buf411 | |
del buf411 | |
buf413 = buf409; del buf409 # reuse | |
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf413, buf412, 11008, grid=grid(11008), stream=stream0) | |
del buf412 | |
# Source Nodes: [out_94], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf414 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf413, arg258_1, arg259_1, 13) | |
del arg258_1 | |
del arg259_1 | |
del buf413 | |
buf415 = buf414 | |
del buf414 | |
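# For reference, each repetition above corresponds roughly to this
# eager-mode sketch of a Llama-style block (illustrative names only, not
# identifiers from the generated code; `fp6_linear` stands for
# torch.ops.torchao.fp16act_fp6weight_linear with the weight, scales and
# split-K arguments baked in):
#
#   def block(x, kv_cache, input_pos, freqs_cis, mask):
#       a = rms_norm(x, attn_norm_w)                  # fused reduction kernel
#       q, k, v = split(fp6_linear(a, wqkv))          # splitK = 1
#       q, k = apply_rope(q, k, freqs_cis)
#       k, v = kv_cache.update(input_pos, k, v)       # fused index_put kernel
#       y = softmax(q @ k.T / sqrt(d) + mask) @ v     # the bmm kernel chain
#       h = x + fp6_linear(y, wo)                     # splitK = 13
#       g = rms_norm(h, ffn_norm_w)
#       u = silu(fp6_linear(g, w1)) * fp6_linear(g, w3)   # splitK = 1
#       return h + fp6_linear(u, w2)                  # splitK = 13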
buf417 = buf410; del buf410 # reuse | |
# Source Nodes: [float_65, mean_32, mul_240, out_95, out_96], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf405, buf415, arg32_1, buf417, 1, 4096, grid=grid(1), stream=stream0) | |
del arg32_1 | |
# Source Nodes: [out_96], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf418 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf417, arg260_1, arg261_1, 1) | |
del arg260_1 | |
del arg261_1 | |
buf419 = buf418 | |
del buf418 | |
buf421 = buf394; del buf394 # reuse | |
# Source Nodes: [setitem_32, setitem_33, y_48], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf419, arg66_1, arg262_1, buf421, arg263_1, 4096, grid=grid(4096), stream=stream0) | |
del buf419 | |
buf422 = buf399; del buf399 # reuse | |
# Source Nodes: [y_48], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf421, arg262_1, buf422, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg262_1 | |
buf426 = buf395; del buf395 # reuse | |
# Source Nodes: [mask, y_48], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf422, arg454_1, arg67_1, buf426, 32, 208, grid=grid(32), stream=stream0) | |
buf427 = buf400; del buf400 # reuse | |
# Source Nodes: [y_48], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf426, arg263_1, buf427, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg263_1 | |
buf429 = buf417; del buf417 # reuse | |
# Source Nodes: [out_97, y_48], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf427, buf429, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_97], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf430 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf429, arg264_1, arg265_1, 13) | |
del arg264_1 | |
del arg265_1 | |
buf431 = buf430 | |
del buf430 | |
buf433 = reinterpret_tensor(buf429, (1, 1, 4096), (4096, 4096, 1), 0); del buf429 # reuse | |
# Source Nodes: [add_100, float_68, h_16, mean_33, mul_251, mul_252, mul_253, out_95, output_33, rsqrt_33], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf405, buf415, buf431, arg33_1, buf433, 1, 4096, grid=grid(1), stream=stream0) | |
del arg33_1 | |
# Source Nodes: [out_98], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf434 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg266_1, arg267_1, 1) | |
del arg266_1 | |
del arg267_1 | |
buf435 = buf434 | |
del buf434 | |
# Source Nodes: [out_99], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf436 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0), arg268_1, arg269_1, 1) | |
del arg268_1 | |
del arg269_1 | |
buf437 = buf436 | |
del buf436 | |
buf438 = buf435; del buf435 # reuse | |
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf438, buf437, 11008, grid=grid(11008), stream=stream0) | |
del buf437 | |
# Source Nodes: [out_100], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf439 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf438, arg270_1, arg271_1, 13) | |
del arg270_1 | |
del arg271_1 | |
del buf438 | |
buf440 = buf439 | |
del buf439 | |
buf442 = reinterpret_tensor(buf433, (1, 4096), (4096, 1), 0); del buf433 # reuse | |
# Source Nodes: [float_69, h_16, mean_34, mul_255, out_101, out_102, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf405, buf415, buf431, buf440, arg34_1, buf442, 1, 4096, grid=grid(1), stream=stream0) | |
del arg34_1 | |
# Source Nodes: [out_102], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf443 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf442, arg272_1, arg273_1, 1) | |
del arg272_1 | |
del arg273_1 | |
buf444 = buf443 | |
del buf443 | |
buf446 = buf421; del buf421 # reuse | |
# Source Nodes: [setitem_34, setitem_35, y_51], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf444, arg66_1, arg274_1, buf446, arg275_1, 4096, grid=grid(4096), stream=stream0) | |
del buf444 | |
buf447 = buf426; del buf426 # reuse | |
# Source Nodes: [y_51], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf446, arg274_1, buf447, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg274_1 | |
buf451 = buf422; del buf422 # reuse | |
# Source Nodes: [mask, y_51], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf447, arg454_1, arg67_1, buf451, 32, 208, grid=grid(32), stream=stream0) | |
buf452 = buf427; del buf427 # reuse | |
# Source Nodes: [y_51], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf451, arg275_1, buf452, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg275_1 | |
buf454 = buf442; del buf442 # reuse | |
# Source Nodes: [out_103, y_51], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf452, buf454, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_103], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf455 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf454, arg276_1, arg277_1, 13) | |
del arg276_1 | |
del arg277_1 | |
buf456 = buf455 | |
del buf455 | |
buf457 = buf405; del buf405 # reuse | |
buf459 = buf454; del buf454 # reuse | |
buf462 = buf407; del buf407 # reuse | |
# Source Nodes: [float_72, h_16, h_17, mean_35, mul_266, out_101, out_104, out_105, out_95], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf457, buf415, buf431, buf440, buf456, arg35_1, buf459, buf462, 1, 4096, grid=grid(1), stream=stream0) | |
del arg35_1 | |
del buf415 | |
del buf431 | |
del buf440 | |
del buf456 | |
# Source Nodes: [out_104], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf460 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf459, arg278_1, arg279_1, 1) | |
del arg278_1 | |
del arg279_1 | |
buf461 = buf460 | |
del buf460 | |
# Source Nodes: [out_105], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf463 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf462, arg280_1, arg281_1, 1) | |
del arg280_1 | |
del arg281_1 | |
buf464 = buf463 | |
del buf463 | |
buf465 = buf461; del buf461 # reuse | |
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf465, buf464, 11008, grid=grid(11008), stream=stream0) | |
del buf464 | |
# Source Nodes: [out_106], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf466 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf465, arg282_1, arg283_1, 13) | |
del arg282_1 | |
del arg283_1 | |
del buf465 | |
buf467 = buf466 | |
del buf466 | |
buf469 = buf462; del buf462 # reuse | |
# Source Nodes: [float_73, mean_36, mul_270, out_107, out_108], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf457, buf467, arg36_1, buf469, 1, 4096, grid=grid(1), stream=stream0) | |
del arg36_1 | |
# Source Nodes: [out_108], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf470 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf469, arg284_1, arg285_1, 1) | |
del arg284_1 | |
del arg285_1 | |
buf471 = buf470 | |
del buf470 | |
buf473 = buf446; del buf446 # reuse | |
# Source Nodes: [setitem_36, setitem_37, y_54], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf471, arg66_1, arg286_1, buf473, arg287_1, 4096, grid=grid(4096), stream=stream0) | |
del buf471 | |
buf474 = buf451; del buf451 # reuse | |
# Source Nodes: [y_54], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf473, arg286_1, buf474, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg286_1 | |
buf478 = buf447; del buf447 # reuse | |
# Source Nodes: [mask, y_54], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf474, arg454_1, arg67_1, buf478, 32, 208, grid=grid(32), stream=stream0) | |
buf479 = buf452; del buf452 # reuse | |
# Source Nodes: [y_54], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf478, arg287_1, buf479, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg287_1 | |
buf481 = buf469; del buf469 # reuse | |
# Source Nodes: [out_109, y_54], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf479, buf481, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_109], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf482 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf481, arg288_1, arg289_1, 13) | |
del arg288_1 | |
del arg289_1 | |
buf483 = buf482 | |
del buf482 | |
buf485 = reinterpret_tensor(buf481, (1, 1, 4096), (4096, 4096, 1), 0); del buf481 # reuse | |
# Source Nodes: [add_112, float_76, h_18, mean_37, mul_281, mul_282, mul_283, out_107, output_37, rsqrt_37], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf457, buf467, buf483, arg37_1, buf485, 1, 4096, grid=grid(1), stream=stream0) | |
del arg37_1 | |
# Source Nodes: [out_110], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf486 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg290_1, arg291_1, 1) | |
del arg290_1 | |
del arg291_1 | |
buf487 = buf486 | |
del buf486 | |
# Source Nodes: [out_111], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf488 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0), arg292_1, arg293_1, 1) | |
del arg292_1 | |
del arg293_1 | |
buf489 = buf488 | |
del buf488 | |
buf490 = buf487; del buf487 # reuse | |
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf490, buf489, 11008, grid=grid(11008), stream=stream0) | |
del buf489 | |
# Source Nodes: [out_112], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf491 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf490, arg294_1, arg295_1, 13) | |
del arg294_1 | |
del arg295_1 | |
del buf490 | |
buf492 = buf491 | |
del buf491 | |
buf494 = reinterpret_tensor(buf485, (1, 4096), (4096, 1), 0); del buf485 # reuse | |
# Source Nodes: [float_77, h_18, mean_38, mul_285, out_107, out_113, out_114], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf457, buf467, buf483, buf492, arg38_1, buf494, 1, 4096, grid=grid(1), stream=stream0) | |
del arg38_1 | |
# Source Nodes: [out_114], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf495 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf494, arg296_1, arg297_1, 1) | |
del arg296_1 | |
del arg297_1 | |
buf496 = buf495 | |
del buf495 | |
buf498 = buf473; del buf473 # reuse | |
# Source Nodes: [setitem_38, setitem_39, y_57], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf496, arg66_1, arg298_1, buf498, arg299_1, 4096, grid=grid(4096), stream=stream0) | |
del buf496 | |
buf499 = buf478; del buf478 # reuse | |
# Source Nodes: [y_57], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf498, arg298_1, buf499, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg298_1 | |
buf503 = buf474; del buf474 # reuse | |
# Source Nodes: [mask, y_57], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf499, arg454_1, arg67_1, buf503, 32, 208, grid=grid(32), stream=stream0) | |
buf504 = buf479; del buf479 # reuse | |
# Source Nodes: [y_57], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf503, arg299_1, buf504, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg299_1 | |
buf506 = buf494; del buf494 # reuse | |
# Source Nodes: [out_115, y_57], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf504, buf506, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_115], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf507 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf506, arg300_1, arg301_1, 13) | |
del arg300_1 | |
del arg301_1 | |
buf508 = buf507 | |
del buf507 | |
buf509 = buf457; del buf457 # reuse | |
buf511 = buf506; del buf506 # reuse | |
buf514 = buf459; del buf459 # reuse | |
# Source Nodes: [float_80, h_18, h_19, mean_39, mul_296, out_107, out_113, out_116, out_117], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf509, buf467, buf483, buf492, buf508, arg39_1, buf511, buf514, 1, 4096, grid=grid(1), stream=stream0) | |
del arg39_1 | |
del buf467 | |
del buf483 | |
del buf492 | |
del buf508 | |
# Source Nodes: [out_116], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf512 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf511, arg302_1, arg303_1, 1) | |
del arg302_1 | |
del arg303_1 | |
buf513 = buf512 | |
del buf512 | |
# Source Nodes: [out_117], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf515 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf514, arg304_1, arg305_1, 1) | |
del arg304_1 | |
del arg305_1 | |
buf516 = buf515 | |
del buf515 | |
buf517 = buf513; del buf513 # reuse | |
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf517, buf516, 11008, grid=grid(11008), stream=stream0) | |
del buf516 | |
# Source Nodes: [out_118], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf518 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf517, arg306_1, arg307_1, 13) | |
del arg306_1 | |
del arg307_1 | |
del buf517 | |
buf519 = buf518 | |
del buf518 | |
buf521 = buf514; del buf514 # reuse | |
# Source Nodes: [float_81, mean_40, mul_300, out_119, out_120], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf509, buf519, arg40_1, buf521, 1, 4096, grid=grid(1), stream=stream0) | |
del arg40_1 | |
# Source Nodes: [out_120], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf522 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf521, arg308_1, arg309_1, 1) | |
del arg308_1 | |
del arg309_1 | |
buf523 = buf522 | |
del buf522 | |
buf525 = buf498; del buf498 # reuse | |
# Source Nodes: [setitem_40, setitem_41, y_60], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf523, arg66_1, arg310_1, buf525, arg311_1, 4096, grid=grid(4096), stream=stream0) | |
del buf523 | |
buf526 = buf503; del buf503 # reuse | |
# Source Nodes: [y_60], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf525, arg310_1, buf526, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg310_1 | |
buf530 = buf499; del buf499 # reuse | |
# Source Nodes: [mask, y_60], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf526, arg454_1, arg67_1, buf530, 32, 208, grid=grid(32), stream=stream0) | |
buf531 = buf504; del buf504 # reuse | |
# Source Nodes: [y_60], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf530, arg311_1, buf531, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg311_1 | |
buf533 = buf521; del buf521 # reuse | |
# Source Nodes: [out_121, y_60], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf531, buf533, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_121], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf534 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf533, arg312_1, arg313_1, 13) | |
del arg312_1 | |
del arg313_1 | |
buf535 = buf534 | |
del buf534 | |
buf537 = reinterpret_tensor(buf533, (1, 1, 4096), (4096, 4096, 1), 0); del buf533 # reuse | |
# Source Nodes: [add_124, float_84, h_20, mean_41, mul_311, mul_312, mul_313, out_119, output_41, rsqrt_41], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf509, buf519, buf535, arg41_1, buf537, 1, 4096, grid=grid(1), stream=stream0) | |
del arg41_1 | |
# Source Nodes: [out_122], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf538 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg314_1, arg315_1, 1) | |
del arg314_1 | |
del arg315_1 | |
buf539 = buf538 | |
del buf538 | |
# Source Nodes: [out_123], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf540 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0), arg316_1, arg317_1, 1) | |
del arg316_1 | |
del arg317_1 | |
buf541 = buf540 | |
del buf540 | |
buf542 = buf539; del buf539 # reuse | |
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf542, buf541, 11008, grid=grid(11008), stream=stream0) | |
del buf541 | |
# Source Nodes: [out_124], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf543 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf542, arg318_1, arg319_1, 13) | |
del arg318_1 | |
del arg319_1 | |
del buf542 | |
buf544 = buf543 | |
del buf543 | |
buf546 = reinterpret_tensor(buf537, (1, 4096), (4096, 1), 0); del buf537 # reuse | |
# Source Nodes: [float_85, h_20, mean_42, mul_315, out_119, out_125, out_126], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf509, buf519, buf535, buf544, arg42_1, buf546, 1, 4096, grid=grid(1), stream=stream0) | |
del arg42_1 | |
# Source Nodes: [out_126], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf547 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf546, arg320_1, arg321_1, 1) | |
del arg320_1 | |
del arg321_1 | |
buf548 = buf547 | |
del buf547 | |
buf550 = buf525; del buf525 # reuse | |
# Source Nodes: [setitem_42, setitem_43, y_63], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf548, arg66_1, arg322_1, buf550, arg323_1, 4096, grid=grid(4096), stream=stream0) | |
del buf548 | |
buf551 = buf530; del buf530 # reuse | |
# Source Nodes: [y_63], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf550, arg322_1, buf551, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg322_1 | |
buf555 = buf526; del buf526 # reuse | |
# Source Nodes: [mask, y_63], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf551, arg454_1, arg67_1, buf555, 32, 208, grid=grid(32), stream=stream0) | |
buf556 = buf531; del buf531 # reuse | |
# Source Nodes: [y_63], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf555, arg323_1, buf556, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg323_1 | |
buf558 = buf546; del buf546 # reuse | |
# Source Nodes: [out_127, y_63], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf556, buf558, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_127], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf559 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf558, arg324_1, arg325_1, 13) | |
del arg324_1 | |
del arg325_1 | |
buf560 = buf559 | |
del buf559 | |
buf561 = buf509; del buf509 # reuse | |
buf563 = buf558; del buf558 # reuse | |
buf566 = buf511; del buf511 # reuse | |
# Source Nodes: [float_88, h_20, h_21, mean_43, mul_326, out_119, out_125, out_128, out_129], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf561, buf519, buf535, buf544, buf560, arg43_1, buf563, buf566, 1, 4096, grid=grid(1), stream=stream0) | |
del arg43_1 | |
del buf519 | |
del buf535 | |
del buf544 | |
del buf560 | |
# Source Nodes: [out_128], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf564 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf563, arg326_1, arg327_1, 1) | |
del arg326_1 | |
del arg327_1 | |
buf565 = buf564 | |
del buf564 | |
# Source Nodes: [out_129], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf567 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf566, arg328_1, arg329_1, 1) | |
del arg328_1 | |
del arg329_1 | |
buf568 = buf567 | |
del buf567 | |
buf569 = buf565; del buf565 # reuse | |
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf569, buf568, 11008, grid=grid(11008), stream=stream0) | |
del buf568 | |
# Source Nodes: [out_130], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf570 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf569, arg330_1, arg331_1, 13) | |
del arg330_1 | |
del arg331_1 | |
del buf569 | |
buf571 = buf570 | |
del buf570 | |
buf573 = buf566; del buf566 # reuse | |
# Source Nodes: [float_89, mean_44, mul_330, out_131, out_132], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf561, buf571, arg44_1, buf573, 1, 4096, grid=grid(1), stream=stream0) | |
del arg44_1 | |
# Source Nodes: [out_132], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf574 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf573, arg332_1, arg333_1, 1) | |
del arg332_1 | |
del arg333_1 | |
buf575 = buf574 | |
del buf574 | |
buf577 = buf550; del buf550 # reuse | |
# Source Nodes: [setitem_44, setitem_45, y_66], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf575, arg66_1, arg334_1, buf577, arg335_1, 4096, grid=grid(4096), stream=stream0) | |
del buf575 | |
buf578 = buf555; del buf555 # reuse | |
# Source Nodes: [y_66], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf577, arg334_1, buf578, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg334_1 | |
buf582 = buf551; del buf551 # reuse | |
# Source Nodes: [mask, y_66], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf578, arg454_1, arg67_1, buf582, 32, 208, grid=grid(32), stream=stream0) | |
buf583 = buf556; del buf556 # reuse | |
# Source Nodes: [y_66], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf582, arg335_1, buf583, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg335_1 | |
buf585 = buf573; del buf573 # reuse | |
# Source Nodes: [out_133, y_66], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf583, buf585, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_133], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf586 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf585, arg336_1, arg337_1, 13) | |
del arg336_1 | |
del arg337_1 | |
buf587 = buf586 | |
del buf586 | |
buf589 = reinterpret_tensor(buf585, (1, 1, 4096), (4096, 4096, 1), 0); del buf585 # reuse | |
# Source Nodes: [add_136, float_92, h_22, mean_45, mul_341, mul_342, mul_343, out_131, output_45, rsqrt_45], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf561, buf571, buf587, arg45_1, buf589, 1, 4096, grid=grid(1), stream=stream0) | |
del arg45_1 | |
# Source Nodes: [out_134], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf590 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg338_1, arg339_1, 1) | |
del arg338_1 | |
del arg339_1 | |
buf591 = buf590 | |
del buf590 | |
# Source Nodes: [out_135], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf592 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0), arg340_1, arg341_1, 1) | |
del arg340_1 | |
del arg341_1 | |
buf593 = buf592 | |
del buf592 | |
buf594 = buf591; del buf591 # reuse | |
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf594, buf593, 11008, grid=grid(11008), stream=stream0) | |
del buf593 | |
# Source Nodes: [out_136], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf595 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf594, arg342_1, arg343_1, 13) | |
del arg342_1 | |
del arg343_1 | |
del buf594 | |
buf596 = buf595 | |
del buf595 | |
buf598 = reinterpret_tensor(buf589, (1, 4096), (4096, 1), 0); del buf589 # reuse | |
# Source Nodes: [float_93, h_22, mean_46, mul_345, out_131, out_137, out_138], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf561, buf571, buf587, buf596, arg46_1, buf598, 1, 4096, grid=grid(1), stream=stream0) | |
del arg46_1 | |
# Source Nodes: [out_138], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf599 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf598, arg344_1, arg345_1, 1) | |
del arg344_1 | |
del arg345_1 | |
buf600 = buf599 | |
del buf599 | |
buf602 = buf577; del buf577 # reuse | |
# Source Nodes: [setitem_46, setitem_47, y_69], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf600, arg66_1, arg346_1, buf602, arg347_1, 4096, grid=grid(4096), stream=stream0) | |
del buf600 | |
buf603 = buf582; del buf582 # reuse | |
# Source Nodes: [y_69], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf602, arg346_1, buf603, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg346_1 | |
buf607 = buf578; del buf578 # reuse | |
# Source Nodes: [mask, y_69], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf603, arg454_1, arg67_1, buf607, 32, 208, grid=grid(32), stream=stream0) | |
buf608 = buf583; del buf583 # reuse | |
# Source Nodes: [y_69], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf607, arg347_1, buf608, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg347_1 | |
buf610 = buf598; del buf598 # reuse | |
# Source Nodes: [out_139, y_69], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf608, buf610, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_139], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf611 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf610, arg348_1, arg349_1, 13) | |
del arg348_1 | |
del arg349_1 | |
buf612 = buf611 | |
del buf611 | |
buf613 = buf561; del buf561 # reuse | |
buf615 = buf610; del buf610 # reuse | |
buf618 = buf563; del buf563 # reuse | |
# Source Nodes: [float_96, h_22, h_23, mean_47, mul_356, out_131, out_137, out_140, out_141], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf613, buf571, buf587, buf596, buf612, arg47_1, buf615, buf618, 1, 4096, grid=grid(1), stream=stream0) | |
del arg47_1 | |
del buf571 | |
del buf587 | |
del buf596 | |
del buf612 | |
# Source Nodes: [out_140], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf616 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf615, arg350_1, arg351_1, 1) | |
del arg350_1 | |
del arg351_1 | |
buf617 = buf616 | |
del buf616 | |
# Source Nodes: [out_141], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf619 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf618, arg352_1, arg353_1, 1) | |
del arg352_1 | |
del arg353_1 | |
buf620 = buf619 | |
del buf619 | |
buf621 = buf617; del buf617 # reuse | |
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf621, buf620, 11008, grid=grid(11008), stream=stream0) | |
del buf620 | |
# Source Nodes: [out_142], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf622 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf621, arg354_1, arg355_1, 13) | |
del arg354_1 | |
del arg355_1 | |
del buf621 | |
buf623 = buf622 | |
del buf622 | |
buf625 = buf618; del buf618 # reuse | |
# Source Nodes: [float_97, mean_48, mul_360, out_143, out_144], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf613, buf623, arg48_1, buf625, 1, 4096, grid=grid(1), stream=stream0) | |
del arg48_1 | |
# Source Nodes: [out_144], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf626 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf625, arg356_1, arg357_1, 1) | |
del arg356_1 | |
del arg357_1 | |
buf627 = buf626 | |
del buf626 | |
buf629 = buf602; del buf602 # reuse | |
# Source Nodes: [setitem_48, setitem_49, y_72], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf627, arg66_1, arg358_1, buf629, arg359_1, 4096, grid=grid(4096), stream=stream0) | |
del buf627 | |
buf630 = buf607; del buf607 # reuse | |
# Source Nodes: [y_72], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf629, arg358_1, buf630, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg358_1 | |
buf634 = buf603; del buf603 # reuse | |
# Source Nodes: [mask, y_72], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf630, arg454_1, arg67_1, buf634, 32, 208, grid=grid(32), stream=stream0) | |
buf635 = buf608; del buf608 # reuse | |
# Source Nodes: [y_72], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf634, arg359_1, buf635, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg359_1 | |
buf637 = buf625; del buf625 # reuse | |
# Source Nodes: [out_145, y_72], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf635, buf637, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_145], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf638 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf637, arg360_1, arg361_1, 13) | |
del arg360_1 | |
del arg361_1 | |
buf639 = buf638 | |
del buf638 | |
buf641 = reinterpret_tensor(buf637, (1, 1, 4096), (4096, 4096, 1), 0); del buf637 # reuse | |
# Source Nodes: [add_148, float_100, h_24, mean_49, mul_371, mul_372, mul_373, out_143, output_49, rsqrt_49], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf613, buf623, buf639, arg49_1, buf641, 1, 4096, grid=grid(1), stream=stream0) | |
del arg49_1 | |
# Source Nodes: [out_146], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf642 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg362_1, arg363_1, 1) | |
del arg362_1 | |
del arg363_1 | |
buf643 = buf642 | |
del buf642 | |
# Source Nodes: [out_147], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf644 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0), arg364_1, arg365_1, 1) | |
del arg364_1 | |
del arg365_1 | |
buf645 = buf644 | |
del buf644 | |
buf646 = buf643; del buf643 # reuse | |
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf646, buf645, 11008, grid=grid(11008), stream=stream0) | |
del buf645 | |
# Source Nodes: [out_148], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf647 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf646, arg366_1, arg367_1, 13) | |
del arg366_1 | |
del arg367_1 | |
del buf646 | |
buf648 = buf647 | |
del buf647 | |
buf650 = reinterpret_tensor(buf641, (1, 4096), (4096, 1), 0); del buf641 # reuse | |
# Source Nodes: [float_101, h_24, mean_50, mul_375, out_143, out_149, out_150], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf613, buf623, buf639, buf648, arg50_1, buf650, 1, 4096, grid=grid(1), stream=stream0) | |
del arg50_1 | |
# Source Nodes: [out_150], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf651 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf650, arg368_1, arg369_1, 1) | |
del arg368_1 | |
del arg369_1 | |
buf652 = buf651 | |
del buf651 | |
buf654 = buf629; del buf629 # reuse | |
# Source Nodes: [setitem_50, setitem_51, y_75], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf652, arg66_1, arg370_1, buf654, arg371_1, 4096, grid=grid(4096), stream=stream0) | |
del buf652 | |
buf655 = buf634; del buf634 # reuse | |
# Source Nodes: [y_75], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf654, arg370_1, buf655, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg370_1 | |
buf659 = buf630; del buf630 # reuse | |
# Source Nodes: [mask, y_75], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf655, arg454_1, arg67_1, buf659, 32, 208, grid=grid(32), stream=stream0) | |
buf660 = buf635; del buf635 # reuse | |
# Source Nodes: [y_75], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf659, arg371_1, buf660, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg371_1 | |
buf662 = buf650; del buf650 # reuse | |
# Source Nodes: [out_151, y_75], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf660, buf662, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_151], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf663 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf662, arg372_1, arg373_1, 13) | |
del arg372_1 | |
del arg373_1 | |
buf664 = buf663 | |
del buf663 | |
buf665 = buf613; del buf613 # reuse | |
buf667 = buf662; del buf662 # reuse | |
buf670 = buf615; del buf615 # reuse | |
# Source Nodes: [float_104, h_24, h_25, mean_51, mul_386, out_143, out_149, out_152, out_153], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf665, buf623, buf639, buf648, buf664, arg51_1, buf667, buf670, 1, 4096, grid=grid(1), stream=stream0) | |
del arg51_1 | |
del buf623 | |
del buf639 | |
del buf648 | |
del buf664 | |
# Source Nodes: [out_152], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf668 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf667, arg374_1, arg375_1, 1) | |
del arg374_1 | |
del arg375_1 | |
buf669 = buf668 | |
del buf668 | |
# Source Nodes: [out_153], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf671 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf670, arg376_1, arg377_1, 1) | |
del arg376_1 | |
del arg377_1 | |
buf672 = buf671 | |
del buf671 | |
buf673 = buf669; del buf669 # reuse | |
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf673, buf672, 11008, grid=grid(11008), stream=stream0) | |
del buf672 | |
# Source Nodes: [out_154], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf674 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf673, arg378_1, arg379_1, 13) | |
del arg378_1 | |
del arg379_1 | |
del buf673 | |
buf675 = buf674 | |
del buf674 | |
buf677 = buf670; del buf670 # reuse | |
# Source Nodes: [float_105, mean_52, mul_390, out_155, out_156], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf665, buf675, arg52_1, buf677, 1, 4096, grid=grid(1), stream=stream0) | |
del arg52_1 | |
# Source Nodes: [out_156], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf678 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf677, arg380_1, arg381_1, 1) | |
del arg380_1 | |
del arg381_1 | |
buf679 = buf678 | |
del buf678 | |
buf681 = buf654; del buf654 # reuse | |
# Source Nodes: [setitem_52, setitem_53, y_78], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf679, arg66_1, arg382_1, buf681, arg383_1, 4096, grid=grid(4096), stream=stream0) | |
del buf679 | |
buf682 = buf659; del buf659 # reuse | |
# Source Nodes: [y_78], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf681, arg382_1, buf682, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg382_1 | |
buf686 = buf655; del buf655 # reuse | |
# Source Nodes: [mask, y_78], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf682, arg454_1, arg67_1, buf686, 32, 208, grid=grid(32), stream=stream0) | |
buf687 = buf660; del buf660 # reuse | |
# Source Nodes: [y_78], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf686, arg383_1, buf687, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg383_1 | |
buf689 = buf677; del buf677 # reuse | |
# Source Nodes: [out_157, y_78], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf687, buf689, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_157], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf690 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf689, arg384_1, arg385_1, 13) | |
del arg384_1 | |
del arg385_1 | |
buf691 = buf690 | |
del buf690 | |
buf693 = reinterpret_tensor(buf689, (1, 1, 4096), (4096, 4096, 1), 0); del buf689 # reuse | |
# Source Nodes: [add_160, float_108, h_26, mean_53, mul_401, mul_402, mul_403, out_155, output_53, rsqrt_53], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf665, buf675, buf691, arg53_1, buf693, 1, 4096, grid=grid(1), stream=stream0) | |
del arg53_1 | |
# Source Nodes: [out_158], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf694 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg386_1, arg387_1, 1) | |
del arg386_1 | |
del arg387_1 | |
buf695 = buf694 | |
del buf694 | |
# Source Nodes: [out_159], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf696 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0), arg388_1, arg389_1, 1) | |
del arg388_1 | |
del arg389_1 | |
buf697 = buf696 | |
del buf696 | |
buf698 = buf695; del buf695 # reuse | |
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf698, buf697, 11008, grid=grid(11008), stream=stream0) | |
del buf697 | |
# Source Nodes: [out_160], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf699 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf698, arg390_1, arg391_1, 13) | |
del arg390_1 | |
del arg391_1 | |
del buf698 | |
buf700 = buf699 | |
del buf699 | |
buf702 = reinterpret_tensor(buf693, (1, 4096), (4096, 1), 0); del buf693 # reuse | |
# Source Nodes: [float_109, h_26, mean_54, mul_405, out_155, out_161, out_162], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf665, buf675, buf691, buf700, arg54_1, buf702, 1, 4096, grid=grid(1), stream=stream0) | |
del arg54_1 | |
# Source Nodes: [out_162], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf703 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf702, arg392_1, arg393_1, 1) | |
del arg392_1 | |
del arg393_1 | |
buf704 = buf703 | |
del buf703 | |
buf706 = buf681; del buf681 # reuse | |
# Source Nodes: [setitem_54, setitem_55, y_81], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf704, arg66_1, arg394_1, buf706, arg395_1, 4096, grid=grid(4096), stream=stream0) | |
del buf704 | |
buf707 = buf686; del buf686 # reuse | |
# Source Nodes: [y_81], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf706, arg394_1, buf707, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg394_1 | |
buf711 = buf682; del buf682 # reuse | |
# Source Nodes: [mask, y_81], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf707, arg454_1, arg67_1, buf711, 32, 208, grid=grid(32), stream=stream0) | |
buf712 = buf687; del buf687 # reuse | |
# Source Nodes: [y_81], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf711, arg395_1, buf712, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg395_1 | |
buf714 = buf702; del buf702 # reuse | |
# Source Nodes: [out_163, y_81], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf712, buf714, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_163], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf715 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf714, arg396_1, arg397_1, 13) | |
del arg396_1 | |
del arg397_1 | |
buf716 = buf715 | |
del buf715 | |
buf717 = buf665; del buf665 # reuse | |
buf719 = buf714; del buf714 # reuse | |
buf722 = buf667; del buf667 # reuse | |
# Source Nodes: [float_112, h_26, h_27, mean_55, mul_416, out_155, out_161, out_164, out_165], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf717, buf675, buf691, buf700, buf716, arg55_1, buf719, buf722, 1, 4096, grid=grid(1), stream=stream0) | |
del arg55_1 | |
del buf675 | |
del buf691 | |
del buf700 | |
del buf716 | |
# Source Nodes: [out_164], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf720 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf719, arg398_1, arg399_1, 1) | |
del arg398_1 | |
del arg399_1 | |
buf721 = buf720 | |
del buf720 | |
# Source Nodes: [out_165], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf723 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf722, arg400_1, arg401_1, 1) | |
del arg400_1 | |
del arg401_1 | |
buf724 = buf723 | |
del buf723 | |
buf725 = buf721; del buf721 # reuse | |
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf725, buf724, 11008, grid=grid(11008), stream=stream0) | |
del buf724 | |
# Source Nodes: [out_166], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf726 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf725, arg402_1, arg403_1, 13) | |
del arg402_1 | |
del arg403_1 | |
del buf725 | |
buf727 = buf726 | |
del buf726 | |
buf729 = buf722; del buf722 # reuse | |
# Source Nodes: [float_113, mean_56, mul_420, out_167, out_168], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf717, buf727, arg56_1, buf729, 1, 4096, grid=grid(1), stream=stream0) | |
del arg56_1 | |
# Source Nodes: [out_168], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf730 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf729, arg404_1, arg405_1, 1) | |
del arg404_1 | |
del arg405_1 | |
buf731 = buf730 | |
del buf730 | |
buf733 = buf706; del buf706 # reuse | |
# Source Nodes: [setitem_56, setitem_57, y_84], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf731, arg66_1, arg406_1, buf733, arg407_1, 4096, grid=grid(4096), stream=stream0) | |
del buf731 | |
buf734 = buf711; del buf711 # reuse | |
# Source Nodes: [y_84], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf733, arg406_1, buf734, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg406_1 | |
buf738 = buf707; del buf707 # reuse | |
# Source Nodes: [mask, y_84], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf734, arg454_1, arg67_1, buf738, 32, 208, grid=grid(32), stream=stream0) | |
buf739 = buf712; del buf712 # reuse | |
# Source Nodes: [y_84], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf738, arg407_1, buf739, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg407_1 | |
buf741 = buf729; del buf729 # reuse | |
# Source Nodes: [out_169, y_84], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf739, buf741, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_169], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf742 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf741, arg408_1, arg409_1, 13) | |
del arg408_1 | |
del arg409_1 | |
buf743 = buf742 | |
del buf742 | |
buf745 = reinterpret_tensor(buf741, (1, 1, 4096), (4096, 4096, 1), 0); del buf741 # reuse | |
# Source Nodes: [add_172, float_116, h_28, mean_57, mul_431, mul_432, mul_433, out_167, output_57, rsqrt_57], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf717, buf727, buf743, arg57_1, buf745, 1, 4096, grid=grid(1), stream=stream0) | |
del arg57_1 | |
# Source Nodes: [out_170], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf746 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg410_1, arg411_1, 1) | |
del arg410_1 | |
del arg411_1 | |
buf747 = buf746 | |
del buf746 | |
# Source Nodes: [out_171], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf748 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0), arg412_1, arg413_1, 1) | |
del arg412_1 | |
del arg413_1 | |
buf749 = buf748 | |
del buf748 | |
buf750 = buf747; del buf747 # reuse | |
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf750, buf749, 11008, grid=grid(11008), stream=stream0) | |
del buf749 | |
# Source Nodes: [out_172], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf751 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf750, arg414_1, arg415_1, 13) | |
del arg414_1 | |
del arg415_1 | |
del buf750 | |
buf752 = buf751 | |
del buf751 | |
buf754 = reinterpret_tensor(buf745, (1, 4096), (4096, 1), 0); del buf745 # reuse | |
# Source Nodes: [float_117, h_28, mean_58, mul_435, out_167, out_173, out_174], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf717, buf727, buf743, buf752, arg58_1, buf754, 1, 4096, grid=grid(1), stream=stream0) | |
del arg58_1 | |
# Source Nodes: [out_174], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf755 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf754, arg416_1, arg417_1, 1) | |
del arg416_1 | |
del arg417_1 | |
buf756 = buf755 | |
del buf755 | |
buf758 = buf733; del buf733 # reuse | |
# Source Nodes: [setitem_58, setitem_59, y_87], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf756, arg66_1, arg418_1, buf758, arg419_1, 4096, grid=grid(4096), stream=stream0) | |
del buf756 | |
buf759 = buf738; del buf738 # reuse | |
# Source Nodes: [y_87], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf758, arg418_1, buf759, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg418_1 | |
buf763 = buf734; del buf734 # reuse | |
# Source Nodes: [mask, y_87], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf759, arg454_1, arg67_1, buf763, 32, 208, grid=grid(32), stream=stream0) | |
buf764 = buf739; del buf739 # reuse | |
# Source Nodes: [y_87], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf763, arg419_1, buf764, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg419_1 | |
buf766 = buf754; del buf754 # reuse | |
# Source Nodes: [out_175, y_87], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf764, buf766, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_175], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf767 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf766, arg420_1, arg421_1, 13) | |
del arg420_1 | |
del arg421_1 | |
buf768 = buf767 | |
del buf767 | |
buf769 = buf717; del buf717 # reuse | |
buf771 = buf766; del buf766 # reuse | |
buf774 = buf719; del buf719 # reuse | |
# Source Nodes: [float_120, h_28, h_29, mean_59, mul_446, out_167, out_173, out_176, out_177], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf769, buf727, buf743, buf752, buf768, arg59_1, buf771, buf774, 1, 4096, grid=grid(1), stream=stream0) | |
del arg59_1 | |
del buf727 | |
del buf743 | |
del buf752 | |
del buf768 | |
# Source Nodes: [out_176], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf772 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf771, arg422_1, arg423_1, 1) | |
del arg422_1 | |
del arg423_1 | |
buf773 = buf772 | |
del buf772 | |
# Source Nodes: [out_177], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf775 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf774, arg424_1, arg425_1, 1) | |
del arg424_1 | |
del arg425_1 | |
buf776 = buf775 | |
del buf775 | |
buf777 = buf773; del buf773 # reuse | |
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf777, buf776, 11008, grid=grid(11008), stream=stream0) | |
del buf776 | |
# Source Nodes: [out_178], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf778 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf777, arg426_1, arg427_1, 13) | |
del arg426_1 | |
del arg427_1 | |
del buf777 | |
buf779 = buf778 | |
del buf778 | |
buf781 = buf774; del buf774 # reuse | |
# Source Nodes: [float_121, mean_60, mul_450, out_179, out_180], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf769, buf779, arg60_1, buf781, 1, 4096, grid=grid(1), stream=stream0) | |
del arg60_1 | |
# Source Nodes: [out_180], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf782 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf781, arg428_1, arg429_1, 1) | |
del arg428_1 | |
del arg429_1 | |
buf783 = buf782 | |
del buf782 | |
buf785 = buf758; del buf758 # reuse | |
# Source Nodes: [setitem_60, setitem_61, y_90], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf783, arg66_1, arg430_1, buf785, arg431_1, 4096, grid=grid(4096), stream=stream0) | |
del buf783 | |
buf786 = buf763; del buf763 # reuse | |
# Source Nodes: [y_90], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf785, arg430_1, buf786, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg430_1 | |
buf790 = buf759; del buf759 # reuse | |
# Source Nodes: [mask, y_90], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf786, arg454_1, arg67_1, buf790, 32, 208, grid=grid(32), stream=stream0) | |
buf791 = buf764; del buf764 # reuse | |
# Source Nodes: [y_90], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf790, arg431_1, buf791, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg431_1 | |
buf793 = buf781; del buf781 # reuse | |
# Source Nodes: [out_181, y_90], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf791, buf793, 4096, 2, grid=grid(4096), stream=stream0) | |
# Source Nodes: [out_181], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf794 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf793, arg432_1, arg433_1, 13) | |
del arg432_1 | |
del arg433_1 | |
buf795 = buf794 | |
del buf794 | |
buf797 = reinterpret_tensor(buf793, (1, 1, 4096), (4096, 4096, 1), 0); del buf793 # reuse | |
# Source Nodes: [add_184, float_124, h_30, mean_61, mul_461, mul_462, mul_463, out_179, output_61, rsqrt_61], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, aten.rsqrt] | |
triton_red_fused__to_copy_add_mean_mul_rsqrt_11.run(buf769, buf779, buf795, arg61_1, buf797, 1, 4096, grid=grid(1), stream=stream0) | |
del arg61_1 | |
# Source Nodes: [out_182], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf798 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg434_1, arg435_1, 1) | |
del arg434_1 | |
del arg435_1 | |
buf799 = buf798 | |
del buf798 | |
# Source Nodes: [out_183], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf800 = torch.ops.torchao.fp16act_fp6weight_linear.default(reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0), arg436_1, arg437_1, 1) | |
del arg436_1 | |
del arg437_1 | |
buf801 = buf800 | |
del buf800 | |
buf802 = buf799; del buf799 # reuse | |
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf802, buf801, 11008, grid=grid(11008), stream=stream0) | |
del buf801 | |
# Source Nodes: [out_184], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf803 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf802, arg438_1, arg439_1, 13) | |
del arg438_1 | |
del arg439_1 | |
del buf802 | |
buf804 = buf803 | |
del buf803 | |
buf806 = reinterpret_tensor(buf797, (1, 4096), (4096, 1), 0); del buf797 # reuse | |
# Source Nodes: [float_125, h_30, mean_62, mul_465, out_179, out_185, out_186], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_12.run(buf769, buf779, buf795, buf804, arg62_1, buf806, 1, 4096, grid=grid(1), stream=stream0) | |
del arg62_1 | |
# Source Nodes: [out_186], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf807 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf806, arg440_1, arg441_1, 1) | |
del arg440_1 | |
del arg441_1 | |
buf808 = buf807 | |
del buf807 | |
buf810 = buf785; del buf785 # reuse | |
# Source Nodes: [setitem_62, setitem_63, y_93], Original ATen: [aten.bmm, aten.index_put] | |
triton_poi_fused_bmm_index_put_1.run(arg454_1, buf808, arg66_1, arg442_1, buf810, arg443_1, 4096, grid=grid(4096), stream=stream0) | |
del arg66_1 | |
del buf808 | |
buf811 = buf790; del buf790 # reuse | |
# Source Nodes: [y_93], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_2.run(buf810, arg442_1, buf811, 6656, 128, grid=grid(6656), stream=stream0) | |
del arg442_1 | |
del buf810 | |
buf815 = buf786; del buf786 # reuse | |
# Source Nodes: [mask, y_93], Original ATen: [aten._softmax, aten.add, aten.bmm, aten.index, aten.logical_not, aten.masked_fill, aten.zeros_like] | |
triton_per_fused__softmax_add_bmm_index_logical_not_masked_fill_zeros_like_3.run(buf811, arg454_1, arg67_1, buf815, 32, 208, grid=grid(32), stream=stream0) | |
del arg454_1 | |
del arg67_1 | |
del buf811 | |
buf816 = buf791; del buf791 # reuse | |
# Source Nodes: [y_93], Original ATen: [aten.bmm] | |
triton_red_fused_bmm_4.run(buf815, arg443_1, buf816, 8192, 104, grid=grid(8192), stream=stream0) | |
del arg443_1 | |
del buf815 | |
buf818 = buf806; del buf806 # reuse | |
# Source Nodes: [out_187, y_93], Original ATen: [aten.bmm, torchao.fp16act_fp6weight_linear] | |
triton_per_fused_bmm_fp16act_fp6weight_linear_5.run(buf816, buf818, 4096, 2, grid=grid(4096), stream=stream0) | |
del buf816 | |
# Source Nodes: [out_187], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf819 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf818, arg444_1, arg445_1, 13) | |
del arg444_1 | |
del arg445_1 | |
buf820 = buf819 | |
del buf819 | |
buf821 = buf769; del buf769 # reuse | |
buf823 = buf818; del buf818 # reuse | |
buf826 = buf771; del buf771 # reuse | |
# Source Nodes: [float_128, h_30, h_31, mean_63, mul_476, out_179, out_185, out_188, out_189], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_13.run(buf821, buf779, buf795, buf804, buf820, arg63_1, buf823, buf826, 1, 4096, grid=grid(1), stream=stream0) | |
del arg63_1 | |
del buf779 | |
del buf795 | |
del buf804 | |
del buf820 | |
# Source Nodes: [out_188], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf824 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf823, arg446_1, arg447_1, 1) | |
del arg446_1 | |
del arg447_1 | |
del buf823 | |
buf825 = buf824 | |
del buf824 | |
# Source Nodes: [out_189], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf827 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf826, arg448_1, arg449_1, 1) | |
del arg448_1 | |
del arg449_1 | |
buf828 = buf827 | |
del buf827 | |
buf829 = buf825; del buf825 # reuse | |
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear] | |
triton_poi_fused_fp16act_fp6weight_linear_7.run(buf829, buf828, 11008, grid=grid(11008), stream=stream0) | |
del buf828 | |
# Source Nodes: [out_190], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf830 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf829, arg450_1, arg451_1, 13) | |
del arg450_1 | |
del arg451_1 | |
del buf829 | |
buf831 = buf830 | |
del buf830 | |
buf833 = buf826; del buf826 # reuse | |
# Source Nodes: [float_129, mean_64, mul_480, out_191, out_192], Original ATen: [aten._to_copy, aten.add, aten.mean, aten.mul, torchao.fp16act_fp6weight_linear] | |
triton_red_fused__to_copy_add_fp16act_fp6weight_linear_mean_mul_10.run(buf821, buf831, arg64_1, buf833, 1, 4096, grid=grid(1), stream=stream0) | |
del arg64_1 | |
del buf821 | |
del buf831 | |
# Source Nodes: [out_192], Original ATen: [torchao.fp16act_fp6weight_linear] | |
buf834 = torch.ops.torchao.fp16act_fp6weight_linear.default(buf833, arg452_1, arg453_1, 1) | |
del arg452_1 | |
del arg453_1 | |
del buf833 | |
buf835 = buf834 | |
del buf834 | |
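# buf835 holds the final lm-head logits over the 32000-token vocabulary. The tail below
# appears to implement gpt-fast-style sampling: divide by temperature (triton_poi_fused_div_16),
# take the top 200 logits, softmax with everything below the 200th logit masked out, then
# draw a token with the exponential-noise argmax trick (a categorical sample), returning
# (next_token, probs).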
buf836 = reinterpret_tensor(buf835, (32000, ), (1, ), 0); del buf835 # reuse | |
# Source Nodes: [logits_1], Original ATen: [aten.div] | |
triton_poi_fused_div_16.run(buf836, 32000, grid=grid(32000), stream=stream0) | |
# Source Nodes: [logits_1, topk], Original ATen: [aten.div, aten.topk] | |
buf837 = aten.topk.default(buf836, 200) | |
buf838 = buf837[0] | |
del buf837 | |
buf840 = empty_strided_cuda((1, 4), (4, 1), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_17.run(buf836, buf838, buf840, 4, 8000, grid=grid(4), stream=stream0) | |
buf841 = empty_strided_cuda((1, ), (1, ), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_18.run(buf840, buf841, 1, 4, grid=grid(1), stream=stream0) | |
buf842 = buf840; del buf840 # reuse | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax_lt_scalar_tensor_where_19.run(buf836, buf838, buf841, buf842, 4, 8000, grid=grid(4), stream=stream0) | |
buf843 = empty_strided_cuda((1, ), (1, ), torch.float32) | |
# Source Nodes: [logits_2, lt, probs], Original ATen: [aten._softmax, aten.lt, aten.scalar_tensor, aten.where] | |
triton_per_fused__softmax_lt_scalar_tensor_where_20.run(buf842, buf843, 1, 4, grid=grid(1), stream=stream0) | |
del buf842 | |
buf845 = empty_strided_cuda((1, ), (1, ), torch.int64) | |
# Source Nodes: [], Original ATen: [] | |
aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf845) | |
buf844 = buf836; del buf836 # reuse | |
buf848 = empty_strided_cuda((1, ), (1, ), torch.int32) | |
# Source Nodes: [argmax, idx_next, logits_2, lt, probs, q_128, truediv_1], Original ATen: [aten._softmax, aten._to_copy, aten.argmax, aten.div, aten.exponential, aten.lt, aten.scalar_tensor, aten.where] | |
triton_red_fused__softmax__to_copy_argmax_div_exponential_lt_scalar_tensor_where_21.run(buf844, buf838, buf841, buf843, buf845, buf848, 0, 1, 32000, grid=grid(1), stream=stream0) | |
del buf838 | |
del buf841 | |
del buf843 | |
del buf845 | |
return (buf848, buf844, ) | |
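
# Benchmark entry point: builds random tensors with the same shapes, strides, and dtypes
# as the real weights and KV-cache state, then times repeated calls of the compiled
# graph above via print_performance.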
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
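# Judging by shapes, the inputs decompose as follows (Llama-7B-style, 32 layers):
# arg0..arg64: (4096,) fp16 RMSNorm weights; arg65: (32000, 4096) fp16 token embedding;
# arg66: rotary freqs table; arg67: (208, 208) bool causal mask. Then per layer: a
# (12288, 768) int32 packed FP6 qkv weight with (12288,) fp16 scales, two (1, 32, 208, 128)
# fp16 KV caches, a (4096, 768) wo, two (11008, 768) w1/w3, and a (4096, 2064) w2, each
# with matching scales (FP6 packing stores K * 6 / 32 int32 words per output row).
# arg452/453: (32000, 768) packed output head plus scales; arg454: input_pos; arg455: the
# current token id.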
arg0_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg1_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg2_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg3_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg4_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg5_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg6_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg7_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg8_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg9_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg10_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg11_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg12_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg13_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg14_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg15_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg16_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg17_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg18_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg19_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg20_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg21_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg22_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg23_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg24_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg25_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg26_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg27_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg28_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg29_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg30_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg31_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg32_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg33_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg34_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg35_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg36_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg37_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg38_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg39_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg40_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg41_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg42_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg43_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg44_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg45_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg46_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg47_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg48_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg49_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg50_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg51_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg52_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg53_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg54_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg55_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg56_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg57_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg58_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg59_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg60_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg61_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg62_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg63_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg64_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg65_1 = rand_strided((32000, 4096), (4096, 1), device='cuda:0', dtype=torch.float16) | |
arg66_1 = rand_strided((2048, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float16) | |
arg67_1 = rand_strided((208, 208), (208, 1), device='cuda:0', dtype=torch.bool) | |
arg68_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg69_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg70_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg71_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg72_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg73_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg74_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg75_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg76_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg77_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg78_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg79_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg80_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg81_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg82_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg83_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg84_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg85_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg86_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg87_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg88_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg89_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg90_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg91_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg92_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg93_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg94_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg95_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg96_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg97_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg98_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg99_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg100_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg101_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg102_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg103_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg104_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg105_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg106_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg107_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg108_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg109_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg110_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg111_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg112_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg113_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg114_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg115_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg116_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg117_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg118_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg119_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg120_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg121_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg122_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg123_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg124_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg125_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg126_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg127_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg128_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg129_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg130_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg131_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg132_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg133_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg134_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg135_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg136_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg137_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg138_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg139_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg140_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg141_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg142_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg143_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg144_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg145_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg146_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg147_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg148_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg149_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg150_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg151_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg152_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg153_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg154_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg155_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg156_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg157_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg158_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg159_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg160_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg161_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg162_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg163_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg164_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg165_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg166_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg167_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg168_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg169_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg170_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg171_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg172_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg173_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg174_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg175_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg176_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg177_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg178_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg179_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg180_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg181_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg182_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg183_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg184_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg185_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg186_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg187_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg188_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg189_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg190_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg191_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg192_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg193_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg194_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg195_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg196_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg197_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg198_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg199_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg200_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg201_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg202_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg203_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg204_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg205_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg206_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg207_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg208_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg209_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg210_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg211_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg212_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg213_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg214_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg215_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg216_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg217_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg218_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg219_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg220_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg221_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg222_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg223_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg224_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg225_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg226_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg227_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg228_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg229_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg230_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg231_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg232_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg233_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg234_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg235_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg236_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg237_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg238_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg239_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg240_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg241_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg242_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg243_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg244_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg245_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg246_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg247_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg248_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg249_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg250_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg251_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg252_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg253_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg254_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg255_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg256_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg257_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg258_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg259_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg260_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg261_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg262_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg263_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg264_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg265_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg266_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg267_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg268_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg269_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg270_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg271_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg272_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg273_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg274_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg275_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg276_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg277_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg278_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg279_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg280_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg281_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg282_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg283_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg284_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg285_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg286_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg287_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg288_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg289_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg290_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg291_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg292_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg293_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg294_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg295_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg296_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg297_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg298_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg299_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg300_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg301_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg302_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg303_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg304_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg305_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg306_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg307_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg308_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg309_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg310_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg311_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg312_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg313_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg314_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg315_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg316_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg317_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg318_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg319_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg320_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg321_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg322_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg323_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg324_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg325_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg326_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg327_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg328_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg329_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg330_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg331_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg332_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg333_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg334_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg335_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg336_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg337_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg338_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg339_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg340_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg341_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg342_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg343_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg344_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg345_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg346_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg347_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg348_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg349_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg350_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg351_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg352_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg353_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg354_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg355_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg356_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg357_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg358_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg359_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg360_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg361_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg362_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg363_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg364_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg365_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg366_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg367_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg368_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg369_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg370_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg371_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg372_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg373_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg374_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg375_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg376_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg377_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg378_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg379_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg380_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg381_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg382_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg383_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg384_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg385_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg386_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg387_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg388_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg389_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg390_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg391_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg392_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg393_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg394_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg395_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg396_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg397_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg398_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg399_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg400_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg401_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg402_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg403_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg404_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg405_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg406_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg407_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg408_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg409_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg410_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg411_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg412_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg413_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg414_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg415_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg416_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg417_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg418_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg419_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg420_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg421_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg422_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg423_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg424_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg425_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg426_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg427_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg428_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg429_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg430_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg431_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg432_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg433_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg434_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg435_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg436_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg437_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg438_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg439_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg440_1 = rand_strided((12288, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg441_1 = rand_strided((12288, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg442_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg443_1 = rand_strided((1, 32, 208, 128), (851968, 26624, 128, 1), device='cuda:0', dtype=torch.float16) | |
arg444_1 = rand_strided((4096, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg445_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg446_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg447_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg448_1 = rand_strided((11008, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg449_1 = rand_strided((11008, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg450_1 = rand_strided((4096, 2064), (2064, 1), device='cuda:0', dtype=torch.int32) | |
arg451_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg452_1 = rand_strided((32000, 768), (768, 1), device='cuda:0', dtype=torch.int32) | |
arg453_1 = rand_strided((32000, ), (1, ), device='cuda:0', dtype=torch.float16) | |
arg454_1 = rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int32) | |
arg455_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.int32) | |
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1, arg325_1, arg326_1, arg327_1, arg328_1, arg329_1, arg330_1, arg331_1, arg332_1, arg333_1, arg334_1, arg335_1, arg336_1, arg337_1, arg338_1, arg339_1, arg340_1, arg341_1, arg342_1, arg343_1, arg344_1, arg345_1, arg346_1, arg347_1, arg348_1, arg349_1, arg350_1, arg351_1, arg352_1, arg353_1, arg354_1, arg355_1, arg356_1, arg357_1, arg358_1, arg359_1, arg360_1, arg361_1, arg362_1, arg363_1, arg364_1, arg365_1, arg366_1, arg367_1, arg368_1, arg369_1, arg370_1, arg371_1, arg372_1, arg373_1, arg374_1, arg375_1, arg376_1, arg377_1, arg378_1, arg379_1, arg380_1, arg381_1, arg382_1, arg383_1, arg384_1, arg385_1, arg386_1, arg387_1, arg388_1, arg389_1, arg390_1, arg391_1, arg392_1, arg393_1, arg394_1, arg395_1, arg396_1, arg397_1, arg398_1, arg399_1, arg400_1, arg401_1, arg402_1, arg403_1, arg404_1, arg405_1, arg406_1, arg407_1, arg408_1, arg409_1, arg410_1, arg411_1, arg412_1, arg413_1, arg414_1, arg415_1, arg416_1, arg417_1, arg418_1, arg419_1, arg420_1, arg421_1, arg422_1, arg423_1, arg424_1, arg425_1, arg426_1, arg427_1, arg428_1, arg429_1, arg430_1, arg431_1, arg432_1, arg433_1, arg434_1, arg435_1, arg436_1, arg437_1, arg438_1, arg439_1, arg440_1, arg441_1, arg442_1, arg443_1, arg444_1, arg445_1, arg446_1, arg447_1, arg448_1, arg449_1, arg450_1, arg451_1, arg452_1, arg453_1, arg454_1, arg455_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__": | |
from torch._inductor.wrapper_benchmark import compiled_module_main | |
compiled_module_main('None', benchmark_compiled_module) |
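
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the generated output: a minimal standalone
# exercise of the custom op used throughout this file. It assumes a torchao
# build that registers torchao::fp16act_fp6weight_linear (the FP6-LLM kernel);
# the random int32 "weights" are placeholders for real FP6-packed data, so
# only the shapes and dtypes are meaningful, not the numerical result.
# ---------------------------------------------------------------------------
def _fp6_linear_smoke_test():
    M, K, N = 1, 4096, 11008          # decode batch, in-features, out-features
    act = torch.randn(M, K, device="cuda", dtype=torch.float16)
    packed_w = torch.randint(         # FP6 packing: K * 6 / 32 int32 words per row
        -2**31, 2**31 - 1, (N, K * 6 // 32), device="cuda", dtype=torch.int32)
    scales = torch.randn(N, device="cuda", dtype=torch.float16)
    out = torch.ops.torchao.fp16act_fp6weight_linear.default(act, packed_w, scales, 1)
    assert out.shape == (M, N) and out.dtype == torch.float16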