Created
June 25, 2024 00:00
-
-
Save shunting314/caac76dfa1cc494dd032feb320a915d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_backward'] | |
from ctypes import c_void_p, c_long | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
# Module-level aliases and helpers used by the generated wrapper code.
# NOTE(review): in this copy every line ends with a stray " | |" and has lost
# its indentation -- presumably residue from a web/table extraction rather than
# part of the generated program; confirm against the original file.
# Short handles for torch operator namespaces.
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
# Helpers bound once from torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
# Pool-allocation op taken from the inductor operator namespace.
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
# Single AsyncCompile instance; kernels in this file are registered through
# its .triton(name, source, device_str=...) method.
async_compile = AsyncCompile() | |
# kernel path: /tmp/torchinductor_shunting/3i/c3iwze52kpy2zykepgw7shz2ksspmjdafig5gxepuphlkw5djj2f.py | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward] | |
# masked_lm_loss => full_default_1 | |
# Registers the Triton source of a pointwise kernel via async_compile.triton.
# As written in the embedded source, the kernel:
#   - overrides xnumel with the literal 250036224 (= 8192 * 30522),
#   - splits the flat index into x0 = xindex % 30522 and x1 = xindex // 30522,
#   - stores the constant 0.0 to out_ptr0 + x0 + 30528*x1 with no mask, i.e.
#     rows are addressed with a stride of 30528 but only the first 30522
#     entries of each row are written.
# The string also carries a standalone harness (get_args / call /
# benchmark_all_configs / __main__) that builds a (8192, 30522) float32 tensor
# with strides (30528, 1) and prints time and GB/s measured with do_bench.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_poi_fused_nll_loss_backward_nll_loss_forward_0 = async_compile.triton('triton_poi_fused_nll_loss_backward_nll_loss_forward_0', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[268435456], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_nll_loss_forward_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.000144896}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_nll_loss_backward_nll_loss_forward_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 250036224 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex % 30522 | |
x1 = (xindex // 30522) | |
tmp0 = 0.0 | |
tl.store(out_ptr0 + (x0 + (30528*x1)), tmp0, None) | |
def get_args(): | |
arg_0 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_nll_loss_backward_nll_loss_forward_0.run(*args, 250036224, grid=grid(250036224), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused_nll_loss_backward_nll_loss_forward_0.benchmark_all_configs(*args, 250036224, grid=grid(250036224)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 1.000144896 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
# kernel path: /tmp/torchinductor_shunting/yt/cytwgbywfm7dtcvhkxpg6hkvllgzhv7o7bxsxaihptlzwjox3k7d.py | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward] | |
# masked_lm_loss => full_default_1 | |
# Registers the Triton source of a pointwise kernel over 8192 elements
# (xnumel is overridden with that literal inside the kernel).
# Per element x0 the embedded source:
#   - loads an int64 value from in_ptr0 + x0,
#   - replaces values equal to -100 with 0 (presumably the nll_loss
#     ignore_index -- TODO confirm),
#   - adds 30522 to negative values, then device-asserts 0 <= index < 30522,
#   - stores -1.0 to out_ptr0 + index + 30528*x0 with no mask.
# inductor_meta lists 'out_ptr0' under 'mutated_arg_names'.
# The embedded harness builds a (16, 512) int64 tensor and a (8192, 30522)
# float32 tensor with strides (30528, 1).
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_poi_fused_nll_loss_backward_nll_loss_forward_1 = async_compile.triton('triton_poi_fused_nll_loss_backward_nll_loss_forward_1', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[8192], | |
filename=__file__, | |
triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_nll_loss_forward_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 9.8304e-05}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_nll_loss_backward_nll_loss_forward_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 8192 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None) | |
tmp1 = tl.full([1], -100, tl.int64) | |
tmp2 = tmp0 != tmp1 | |
tmp3 = tl.full([1], 0, tl.int64) | |
tmp4 = tl.where(tmp2, tmp0, tmp3) | |
tmp5 = tl.full([XBLOCK], 30522, tl.int32) | |
tmp6 = tmp4 + tmp5 | |
tmp7 = tmp4 < 0 | |
tmp8 = tl.where(tmp7, tmp6, tmp4) | |
tl.device_assert((0 <= tmp8) & (tmp8 < 30522), "index out of bounds: 0 <= tmp8 < 30522") | |
tmp10 = -1.0 | |
tl.store(out_ptr0 + (tmp8 + (30528*x0)), tmp10, None) | |
def get_args(): | |
arg_0 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
arg_1 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_nll_loss_backward_nll_loss_forward_1.run(*args, 8192, grid=grid(8192), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused_nll_loss_backward_nll_loss_forward_1.benchmark_all_configs(*args, 8192, grid=grid(8192)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 9.8304e-05 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/2u/c2u2begnomazgjf7whgacopubkixewbtygsefgr73rsj6xnr6zj7.py | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward] | |
# masked_lm_loss => full_default_2 | |
# Registers the Triton source of a two-pass reduction kernel; inside the
# kernel xnumel is overridden with 8192 (rows) and rnumel with 30522 (columns).
# With s = in_ptr2[0] / in_ptr3[0] where in_ptr1[x0] != -100, else 0.0:
#   Pass 1: per row x0, accumulates row_sum = sum over r of
#           in_ptr0[r + 30528*x0] * s, masked by r < rnumel.
#   Pass 2: per element, computes
#           in_ptr4[r + 30522*x0]
#             + (in_ptr0[r + 30528*x0] * s
#                - exp(in_ptr5[r + 30528*x0] - in_ptr6[x0] - in_ptr7[x0]) * row_sum)
#           and stores it to out_ptr1 + r + 30528*x0 under rmask.
# Note that in_ptr4 is indexed with a row stride of 30522 while the other 2-D
# operands use a row stride of 30528.
# The -100 comparison is presumably the nll_loss ignore_index -- TODO confirm.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2 = async_compile.triton('triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.reduction( | |
size_hints=[8192, 32768], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: '*fp16', 6: '*fp32', 7: '*fp32', 8: '*fp16', 9: 'i32', 10: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.500348424} | |
) | |
@triton.jit | |
def triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 8192 | |
rnumel = 30522 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last') | |
tmp4 = tl.load(in_ptr2 + (0)) | |
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK]) | |
tmp6 = tl.load(in_ptr3 + (0)) | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK]) | |
_tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0) | |
tmp2 = tl.full([1, 1], -100, tl.int64) | |
tmp3 = tmp1 != tmp2 | |
tmp8 = tmp5 / tmp7 | |
tmp9 = 0.0 | |
tmp10 = tl.where(tmp3, tmp8, tmp9) | |
tmp11 = tmp0 * tmp10 | |
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK]) | |
tmp14 = _tmp13 + tmp12 | |
_tmp13 = tl.where(rmask, tmp14, _tmp13) | |
tmp13 = tl.sum(_tmp13, 1)[:, None] | |
tmp19 = tl.load(in_ptr2 + (0)) | |
tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK]) | |
tmp21 = tl.load(in_ptr3 + (0)) | |
tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK]) | |
tmp29 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last') | |
tmp31 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last') | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp15 = tl.load(in_ptr4 + (r1 + (30522*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp16 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0) | |
tmp27 = tl.load(in_ptr5 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp17 = tl.full([1, 1], -100, tl.int64) | |
tmp18 = tmp1 != tmp17 | |
tmp23 = tmp20 / tmp22 | |
tmp24 = 0.0 | |
tmp25 = tl.where(tmp18, tmp23, tmp24) | |
tmp26 = tmp16 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp30 = tmp28 - tmp29 | |
tmp32 = tmp30 - tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tmp34 = tmp33.to(tl.float32) | |
tmp35 = tl_math.exp(tmp34) | |
tmp36 = tmp35 * tmp13 | |
tmp37 = tmp26 - tmp36 | |
tmp38 = tmp37.to(tl.float32) | |
tmp39 = tmp15 + tmp38 | |
tl.store(out_ptr1 + (r1 + (30528*x0)), tmp39, rmask) | |
def get_args(): | |
arg_0 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32) | |
arg_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
arg_2 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
arg_3 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
arg_4 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16) | |
arg_5 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
arg_6 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
arg_7 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
arg_8 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.run(*args, 8192, 30522, grid=grid(8192), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.benchmark_all_configs(*args, 8192, 30522, grid=grid(8192)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 1.500348424 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/ny/cny7ookgvrcyfphkgshx7jmxx4kzb2nhzoiwqlawy3qkjrc5jc3x.py | |
# Source Nodes: [], Original ATen: [] | |
# Registers the Triton source of a pointwise kernel over 250085376
# (= 8192 * 30528) elements; xnumel is overridden with that literal.
# For each flat index x2 the embedded source loads in_ptr0[x2] only where
# x2 % 30528 < 30522 (substituting 0.0 elsewhere) and stores the result to
# out_ptr0[x2] with no mask, so the last 6 entries of every 30528-wide row of
# the output are written as zero.
# The embedded harness pairs a (16, 512, 30522) float16 input with strides
# (15630336, 30528, 1) with a contiguous (8192, 30528) float16 output.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_poi_fused_3 = async_compile.triton('triton_poi_fused_3', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[268435456], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.0002432}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 250085376 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex % 30528 | |
x2 = xindex | |
tmp0 = x0 | |
tmp1 = tl.full([1], 30522, tl.int64) | |
tmp2 = tmp0 < tmp1 | |
tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32) | |
tl.store(out_ptr0 + (x2), tmp3, None) | |
def get_args(): | |
arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
arg_1 = rand_strided((8192, 30528), (30528, 1), device='cuda:0', dtype=torch.float16) | |
return arg_0, arg_1, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_3.run(*args, 250085376, grid=grid(250085376), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused_3.benchmark_all_configs(*args, 250085376, grid=grid(250085376)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 1.0002432 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/kl/ckldmag4lbnoqqb2xa3zz7bszzgzh6df4wqv2lgxnhv7xzrvm2a3.py | |
# Source Nodes: [], Original ATen: [] | |
# Registers the Triton source of a pointwise kernel over 23445504
# (= 30528 * 768) elements; xnumel is overridden with that literal.
# For each flat index x2 the embedded source loads in_ptr0[x2] only where
# x2 // 768 < 30522 (substituting 0.0 elsewhere) and stores the result to
# out_ptr0[x2] with no mask, so rows 30522..30527 of the 768-wide output are
# written as zero.
# The embedded harness pairs a (30522, 768) float16 input with a (30528, 768)
# float16 output, both contiguous.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[33554432], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0937728}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 23445504 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x1 = (xindex // 768) | |
x2 = xindex | |
tmp0 = x1 | |
tmp1 = tl.full([1], 30522, tl.int64) | |
tmp2 = tmp0 < tmp1 | |
tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32) | |
tl.store(out_ptr0 + (x2), tmp3, None) | |
def get_args(): | |
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_1 = rand_strided((30528, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
return arg_0, arg_1, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_4.run(*args, 23445504, grid=grid(23445504), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused_4.benchmark_all_configs(*args, 23445504, grid=grid(23445504)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.0937728 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/kv/ckvnavggszj55y6z3joj3ix2bejtrydnigqdoidqzr33bzb5ewym.py | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
# Registers the Triton source of a reduction kernel; inside the kernel xnumel
# is overridden with 30522 and rnumel with 8192.
# For each x0 < 30522 the embedded source sums in_ptr0[x0 + 30528*r1] over
# r1 in [0, 8192), accumulating in float32, and stores the result to
# out_ptr1[x0] under xmask. Loads and the accumulator update are both masked
# by rmask & xmask.
# The embedded harness pairs a (16, 512, 30522) float16 input with strides
# (15630336, 30528, 1) with a (30522,) float32 output.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_red_fused__to_copy_sum_5 = async_compile.triton('triton_red_fused__to_copy_sum_5', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.reduction( | |
size_hints=[32768, 8192], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.500194536} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_sum_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 30522 | |
rnumel = 8192 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex | |
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r1 = rindex | |
tmp0 = tl.load(in_ptr0 + (x0 + (30528*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = _tmp2 + tmp1 | |
_tmp2 = tl.where(rmask & xmask, tmp3, _tmp2) | |
tmp2 = tl.sum(_tmp2, 1)[:, None] | |
tmp4 = tmp2.to(tl.float32) | |
tl.store(out_ptr1 + (x0), tmp4, xmask) | |
def get_args(): | |
arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
arg_1 = rand_strided((30522,), (1,), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_sum_5.run(*args, 30522, 8192, grid=grid(30522), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_red_fused__to_copy_sum_5.benchmark_all_configs(*args, 30522, 8192, grid=grid(30522)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.500194536 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/l6/cl6ssrtjbepd5tssmygz54usro7zta7vm56aayideyxsk3wu3wsb.py | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
# Registers the Triton source of a pointwise kernel over 23440896
# (= 30522 * 768) elements; xnumel is overridden with that literal.
# For each flat index x0 < xnumel the embedded source loads in_ptr0[x0],
# converts it to float32 and stores it to out_ptr0[x0]; both the load and the
# store are guarded by xmask. The triton_meta signature declares argument 0 as
# *fp16 and argument 1 as *fp32.
# The embedded harness pairs a (30522, 768) float16 input with a (30522, 768)
# float32 output, both contiguous.
# NOTE(review): the literal below is runtime text and is left untouched, but in
# this copy every line ends in " | |" and indentation is missing -- presumably
# extraction residue; verify against the original generated file before use.
triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[33554432], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.140645376}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 23440896 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tl.store(out_ptr0 + (x0), tmp1, xmask) | |
def get_args(): | |
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused__to_copy_6.run(*args, 23440896, grid=grid(23440896), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused__to_copy_6.benchmark_all_configs(*args, 23440896, grid=grid(23440896)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.140645376 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/mc/cmcvv4a23higyc5z6e7nkc7q4ndfid6u7bhzw6yni6f5ofdbjoue.py | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view] | |
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163 | |
# hidden_states_98 => mul_164, sub_38 | |
# Fused layer-norm-backward + GELU-backward kernel (persistent reduction).
#
# Launch shape: one program per row (XBLOCK = 1) over 8192 rows; each row has
# 768 elements and is covered by a single 1024-wide block masked by `rmask`.
#
# Per row the kernel:
#   1. multiplies the fp16 incoming gradient (in_ptr0) by the per-column fp32
#      vector in_ptr1 and sums the product over the row (tmp7);
#   2. recomputes the erf-based GELU of in_ptr2, subtracts the per-row scalar
#      in_ptr3 and scales by the per-row scalar in_ptr4 (presumably the saved
#      layer-norm mean and rstd -- confirm against the forward graph);
#   3. sums (scaled gradient * normalized value) over the row (tmp26) and
#      combines both sums into the layer-norm input gradient
#      (0.0013020833333333333 == 1 / 768);
#   4. multiplies by the GELU derivative
#      0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x * x / 2) / sqrt(2 * pi)
#      (0.3989422804014327 == 1 / sqrt(2 * pi)) and stores to out_ptr3.
#
# The helper functions and the __main__ section inside the source string make
# the generated kernel file runnable on its own as a micro-benchmark.
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7 = async_compile.triton('triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.037817344}
)
@triton.jit
def triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp20 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp9 = tmp8.to(tl.float32)
    tmp10 = 0.5
    tmp11 = tmp9 * tmp10
    tmp12 = 0.7071067811865476
    tmp13 = tmp9 * tmp12
    tmp14 = libdevice.erf(tmp13)
    tmp15 = 1.0
    tmp16 = tmp14 + tmp15
    tmp17 = tmp11 * tmp16
    tmp19 = tmp17 - tmp18
    tmp21 = tmp19 * tmp20
    tmp22 = tmp3 * tmp21
    tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
    tmp25 = tl.where(rmask, tmp23, 0)
    tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
    tmp27 = 0.0013020833333333333
    tmp28 = tmp20 * tmp27
    tmp29 = 768.0
    tmp30 = tmp3 * tmp29
    tmp31 = tmp30 - tmp7
    tmp32 = tmp21 * tmp26
    tmp33 = tmp31 - tmp32
    tmp34 = tmp28 * tmp33
    tmp35 = tmp16 * tmp10
    tmp36 = tmp9 * tmp9
    tmp37 = -0.5
    tmp38 = tmp36 * tmp37
    tmp39 = tl_math.exp(tmp38)
    tmp40 = 0.3989422804014327
    tmp41 = tmp39 * tmp40
    tmp42 = tmp9 * tmp41
    tmp43 = tmp35 + tmp42
    tmp44 = tmp34 * tmp43
    tmp45 = tmp44.to(tl.float32)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp45, rmask)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.037817344
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xn/cxn73hmq2otzylcbqlg6bpgfz5pcbvnul2dl43ea3uixl37qrhue.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
# First stage of a two-stage column reduction (looped reduction kernel).
#
# xnumel = 49152 = 768 columns * 64 row-chunks; rnumel = 128 rows per chunk.
# For output element x3, x0 = x3 % 768 selects the column and x1 = x3 // 768
# selects the chunk, so the kernel reads rows (r2 + 128 * x1) of the two
# (8192, 768) fp16 inputs.
#
# Two partial sums are accumulated in fp32 for every (column, chunk) pair:
#   out_ptr0: sum of in_ptr0 * ((gelu(in_ptr1) - in_ptr2) * in_ptr3)
#   out_ptr1: sum of in_ptr0
# where gelu is the erf-based form and in_ptr2 / in_ptr3 are per-row fp32
# scalars (presumably the saved layer-norm mean and rstd -- confirm against
# the forward graph). The 64 partials per column are left for a follow-up
# reduction to finish.
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8 = async_compile.triton('triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025624576}
)
@triton.jit
def triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp21 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tl.load(in_ptr3 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = 0.5
        tmp5 = tmp3 * tmp4
        tmp6 = 0.7071067811865476
        tmp7 = tmp3 * tmp6
        tmp8 = libdevice.erf(tmp7)
        tmp9 = 1.0
        tmp10 = tmp8 + tmp9
        tmp11 = tmp5 * tmp10
        tmp13 = tmp11 - tmp12
        tmp15 = tmp13 * tmp14
        tmp16 = tmp1 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(rmask, tmp19, _tmp18)
        tmp20 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp22 = _tmp21 + tmp20
        _tmp21 = tl.where(rmask, tmp22, _tmp21)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp18, None)
    tmp21 = tl.sum(_tmp21, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp21, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.025624576
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xb/cxbiab5uy5kj4w6v3uoxvyncbfkyqo4xmh7csh6vz5yz4dut6xo2.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
# Second stage of a two-stage column reduction (persistent reduction).
#
# xnumel = 768 columns, rnumel = RBLOCK = 64 partial sums per column. Column
# x0 reads in_ptr0[x0 + 768 * r1] for r1 in [0, 64), i.e. a (768, 64) fp32
# buffer with strides (1, 768), sums the 64 values, and writes the total to
# out_ptr0[x0]. Only the x dimension is masked because the reduction length
# equals RBLOCK exactly.
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9 = async_compile.triton('triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[1024, 64],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp4, xmask)


def get_args():
    arg_0 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(*args, 768, 64, grid=grid(768), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.benchmark_all_configs(*args, 768, 64, grid=grid(768))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xw/cxwo2kn3atkuyrqgmindh7m4zzqaakdmurrxxe3trabcblnihnwq.py
# Source Nodes: [], Original ATen: [aten.sum]
# First stage of a two-stage column sum (looped reduction kernel).
#
# xnumel = 49152 = 768 columns * 64 row-chunks; rnumel = 128 rows per chunk.
# For output element x3, x0 = x3 % 768 is the column and x1 = x3 // 768 is the
# chunk, so the kernel reads rows (r2 + 128 * x1) of the (8192, 768) fp16
# input, accumulates them in fp32, and stores one partial sum per
# (column, chunk) pair to out_ptr0[x3].
triton_red_fused_sum_10 = async_compile.triton('triton_red_fused_sum_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_10(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_10.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_10.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/bw/cbwoo26iavmb6dt37knpfmallrwdgfbg4bmorajo736d3byifyha.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
# Second stage of a two-stage column sum (persistent reduction).
#
# xnumel = 768 columns, rnumel = RBLOCK = 64 partial sums per column. Column
# x0 reads in_ptr0[x0 + 768 * r1] for r1 in [0, 64), sums the 64 fp32 values,
# casts the total with .to(tl.float32) and stores it to out_ptr1[x0] (an fp32
# pointer per the kernel signature). Only the x dimension is masked because
# the reduction length equals RBLOCK exactly.
triton_per_fused__to_copy_sum_11 = async_compile.triton('triton_per_fused__to_copy_sum_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[1024, 64],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_sum_11(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)


def get_args():
    arg_0 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_sum_11.run(*args, 768, 64, grid=grid(768), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_sum_11.benchmark_all_configs(*args, 768, 64, grid=grid(768))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/mz/cmzinemkxyqageuju6a26rj4aqmkt2l7ysoziwfzbw253z3tb63k.py
# Source Nodes: [], Original ATen: [aten._to_copy]
# Pointwise dtype-conversion kernel.
#
# Copies 589824 (= 768 * 768) contiguous elements from the fp16 buffer in_ptr0
# to the fp32 buffer out_ptr0, upcasting each value on load. No mask is
# applied to the loads or stores (they pass None); `xmask` is built but unused.
triton_poi_fused__to_copy_12 = async_compile.triton('triton_poi_fused__to_copy_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_12(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 589824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_12.run(*args, 589824, grid=grid(589824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_12.benchmark_all_configs(*args, 589824, grid=grid(589824))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.003538944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/l4/cl4sftqv3rdsorngubndfo524ezgoqyc7gyppiqa5qu2yltoipgv.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
# Fused layer-norm-backward + dropout-backward kernel (persistent reduction).
#
# Launch shape: one program per row (XBLOCK = 1) over 8192 rows; each row has
# 768 elements and is covered by a single 1024-wide block masked by `rmask`.
#
# Per row the kernel:
#   1. multiplies the fp16 incoming gradient (in_ptr0) by the per-column fp32
#      vector in_ptr1 and sums the product over the row (tmp7);
#   2. sums (scaled gradient * in_ptr2) over the row (tmp13), where in_ptr2 is
#      a full fp32 row (presumably the saved normalized activations --
#      confirm against the forward graph);
#   3. computes in_ptr3 * (768 * scaled - tmp7 - in_ptr2 * tmp13) with the
#      per-row fp32 scalar in_ptr3 and stores it to out_ptr2 (fp32);
#   4. multiplies that result by the boolean mask in_ptr4 scaled by
#      1.1111111111111112 (== 1 / 0.9, consistent with dropout p = 0.1) and
#      stores it to out_ptr3 (fp16).
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13 = async_compile.triton('triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*i1', 5: '*fp32', 6: '*fp16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.081824768}
)
@triton.jit
def triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0)
    tmp14 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp22 = tl.load(in_ptr4 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp9 = tmp3 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [RBLOCK])
    tmp12 = tl.where(rmask, tmp10, 0)
    tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))
    tmp15 = 768.0
    tmp16 = tmp3 * tmp15
    tmp17 = tmp16 - tmp7
    tmp18 = tmp8 * tmp13
    tmp19 = tmp17 - tmp18
    tmp20 = tmp14 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp22.to(tl.float32)
    tmp24 = 1.1111111111111112
    tmp25 = tmp23 * tmp24
    tmp26 = tmp21 * tmp25
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp20, rmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp26, rmask)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.081824768
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/pc/cpcxsiinpg3g3z4rnd4e54gqzv3auqdg34x7kolps2lll75bhrdh.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_red_fused__to_copy_native_layer_norm_backward_14 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_backward_14', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.reduction( | |
size_hints=[65536, 128], | |
reduction_hint=ReductionHint.OUTER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_backward_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.038141952} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_native_layer_norm_backward_14(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
xnumel = 49152 | |
rnumel = 128 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) | |
rbase = tl.arange(0, RBLOCK)[None, :] | |
x0 = xindex % 768 | |
x1 = (xindex // 768) | |
_tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
x3 = xindex | |
_tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
for roffset in range(0, rnumel, RBLOCK): | |
rindex = roffset + rbase | |
rmask = rindex < rnumel | |
r2 = rindex | |
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0) | |
tmp1 = tmp0.to(tl.float32) | |
tmp3 = tmp1 * tmp2 | |
tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK]) | |
tmp6 = _tmp5 + tmp4 | |
_tmp5 = tl.where(rmask, tmp6, _tmp5) | |
tmp7 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) | |
tmp9 = _tmp8 + tmp7 | |
_tmp8 = tl.where(rmask, tmp9, _tmp8) | |
tmp5 = tl.sum(_tmp5, 1)[:, None] | |
tl.store(out_ptr0 + (x3), tmp5, None) | |
tmp8 = tl.sum(_tmp8, 1)[:, None] | |
tl.store(out_ptr1 + (x3), tmp8, None) | |
def get_args(): | |
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32) | |
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, arg_2, arg_3, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_native_layer_norm_backward_14.run(*args, 49152, 128, grid=grid(49152), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_red_fused__to_copy_native_layer_norm_backward_14.benchmark_all_configs(*args, 49152, 128, grid=grid(49152)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.038141952 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/tk/ctkue5wyr5xsiml26vprfqtgmp2eaql2enslkqgq3exemgpikz3o.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction kernel: 49152 outputs, each reducing 128 elements.
# With x0 = xindex % 768 and x1 = xindex // 768, the reduction reads in_ptr0
# (fp16, upcast to fp32) at offset x0 + 768*r2 + 98304*x1 for r2 = 0..127 and
# writes the fp32 sum to out_ptr0[xindex].
triton_red_fused_sum_15 = async_compile.triton('triton_red_fused_sum_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_15', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_15(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_15.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_15.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/us/cusyhkdtuhzsuipmuax2gi4pj3qgfxdgv62fqhekapxgkifowy6q.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Pointwise kernel over 2359296 elements: loads each fp16 value from in_ptr0,
# converts it to fp32 and stores it at the same flat index in out_ptr0
# (declared '*fp32' in the signature). No masking is needed because loads and
# stores are issued with mask=None.
triton_poi_fused__to_copy_16 = async_compile.triton('triton_poi_fused__to_copy_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_16', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_16(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2359296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_16.run(*args, 2359296, grid=grid(2359296), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_16.benchmark_all_configs(*args, 2359296, grid=grid(2359296))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.014155776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/im/cimmrvcpkky3hw3retgbay4kloi52mxpkldjifj3bt6hbcrvgqtv.py
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
# hidden_states_92 => add_96, convert_element_type_485, erf_11, mul_155
#
# Pointwise, in-place kernel over 25165824 elements. With g read from
# in_out_ptr0 and x read from in_ptr0 (both fp16, upcast to fp32), it stores
#   g * (0.5 * (1 + erf(x * 0.7071067811865476))
#        + x * exp(-0.5 * x * x) * 0.3989422804014327)
# back into in_out_ptr0, which is why 'in_out_ptr0' is listed under
# mutated_arg_names. The constants are numerically 1/sqrt(2) and 1/sqrt(2*pi),
# so the bracketed factor is the derivative of the erf-based GELU.
triton_poi_fused_gelu_gelu_backward_17 = async_compile.triton('triton_poi_fused_gelu_gelu_backward_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_gelu_backward_17', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.150994944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_gelu_gelu_backward_17(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 25165824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = 0.7071067811865476
    tmp5 = tmp3 * tmp4
    tmp6 = libdevice.erf(tmp5)
    tmp7 = 1.0
    tmp8 = tmp6 + tmp7
    tmp9 = 0.5
    tmp10 = tmp8 * tmp9
    tmp11 = tmp3 * tmp3
    tmp12 = -0.5
    tmp13 = tmp11 * tmp12
    tmp14 = tl_math.exp(tmp13)
    tmp15 = 0.3989422804014327
    tmp16 = tmp14 * tmp15
    tmp17 = tmp3 * tmp16
    tmp18 = tmp10 + tmp17
    tmp19 = tmp1 * tmp18
    tmp20 = tmp19.to(tl.float32)
    tl.store(in_out_ptr0 + (x0), tmp20, None)


def get_args():
    arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_gelu_gelu_backward_17.run(*args, 25165824, grid=grid(25165824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_gelu_gelu_backward_17.benchmark_all_configs(*args, 25165824, grid=grid(25165824))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.150994944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/le/cleh7tp3rcdkx6xu5hdckexqubzkgpkgi7yarvwkzdkc76jba35y.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction kernel: 98304 outputs, each reducing 256 elements.
# With x0 = xindex % 3072 and x1 = xindex // 3072, the reduction reads in_ptr0
# (fp16, upcast to fp32) at offset x0 + 3072*r2 + 786432*x1 for r2 = 0..255
# and writes the fp32 sum to out_ptr0[xindex].
triton_red_fused_sum_18 = async_compile.triton('triton_red_fused_sum_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[131072, 256],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_sum_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 98304
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 3072
    x1 = (xindex // 3072)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (3072*r2) + (786432*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_18.run(*args, 98304, 256, grid=grid(98304), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_18.benchmark_all_configs(*args, 98304, 256, grid=grid(98304))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.050724864
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/uv/cuvj2aolby6dsvhktohp2xnbrouwrm3i5iwofuf5cf2gxqrzwi4p.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
#
# Persistent reduction kernel: 3072 outputs, each reducing 32 elements in a
# single RBLOCK=32 tile (no reduction loop). For each x0 it reads in_ptr0 at
# x0 + 3072*r1 for r1 = 0..31, sums over r1 and stores the result to
# out_ptr1[x0]. Loads and the store are guarded by xmask = xindex < xnumel
# because XBLOCK need not divide 3072.
triton_per_fused__to_copy_sum_19 = async_compile.triton('triton_per_fused__to_copy_sum_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[4096, 32],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.000405504}
)
@triton.jit
def triton_per_fused__to_copy_sum_19(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 3072
    rnumel = 32
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (3072*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)


def get_args():
    arg_0 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_sum_19.run(*args, 3072, 32, grid=grid(3072), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_sum_19.benchmark_all_configs(*args, 3072, 32, grid=grid(3072))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.000405504
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/dp/cdp3sl7e4pey7rtyji754dhbkoxfiwn6srdtuu3ygal2f76wqy6g.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Pointwise kernel over 2359296 elements: loads each fp16 value from in_ptr0,
# converts it to fp32 and stores it at the same flat index in out_ptr0
# (declared '*fp32' in the signature). The benchmark harness exercises it on
# contiguous (3072, 768) buffers.
triton_poi_fused__to_copy_20 = async_compile.triton('triton_poi_fused__to_copy_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_20(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2359296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_20.run(*args, 2359296, grid=grid(2359296), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_20.benchmark_all_configs(*args, 2359296, grid=grid(2359296))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.014155776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/u2/cu2zu355cpgyqkbajs47g6dc25icfvkaqnjfftpy4rwsoyxwd6df.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
#
# Persistent reduction kernel with XBLOCK=1: one program per row x0 (8192
# rows), reducing 768 elements inside a single RBLOCK=1024 tile; rmask
# (rindex < 768) disables the padding lanes. Per row, with r the column index:
#   a  = in_ptr0[x0, r] + float(in_ptr1[x0, r])
#   b  = a * in_ptr2[r]
#   s1 = sum_r(b)
#   s2 = sum_r(b * in_ptr3[x0, r])
#   t  = in_ptr4[x0] * (768 * b - s1 - in_ptr3[x0, r] * s2)
#   out_ptr2[x0, r] = t
#   out_ptr3[x0, r] = t * (float(in_ptr5[x0, r]) * 1.1111111111111112)
# in_ptr5 is declared '*i1' in the signature (a boolean mask); the constant
# 1.1111111111111112 is numerically 1/0.9.
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i1', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 6, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.106990592}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0)
    tmp16 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp24 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp5 = tmp3 * tmp4
    tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
    tmp8 = tl.where(rmask, tmp6, 0)
    tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp8, 0))
    tmp11 = tmp5 * tmp10
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask, tmp12, 0)
    tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp17 = 768.0
    tmp18 = tmp5 * tmp17
    tmp19 = tmp18 - tmp9
    tmp20 = tmp10 * tmp15
    tmp21 = tmp19 - tmp20
    tmp22 = tmp16 * tmp21
    tmp23 = tmp22.to(tl.float32)
    tmp25 = tmp24.to(tl.float32)
    tmp26 = 1.1111111111111112
    tmp27 = tmp25 * tmp26
    tmp28 = tmp23 * tmp27
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp22, rmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.106990592
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kf/ckfk5xfnycnrpmf5qszpqq2qcsb7kwgpnosphd6ynan6dvn3d5gp.py | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
#
# Looped Triton reduction kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# For every x in [0, 49152) the kernel accumulates over r in [0, 128), in RBLOCK
# sized steps, using fp32 accumulators:
#     s           = in_ptr0[i] + float32(in_ptr1[i])
#     out_ptr0[x] = sum_r s * in_ptr2[i]
#     out_ptr1[x] = sum_r s
# where i = (x % 768) + 768 * r + 98304 * (x // 768).
# The xnumel / rnumel runtime arguments are overwritten by the constants 49152 and
# 128 at the top of the kernel; loads are masked on r only, stores are unmasked.
# NOTE(review): the "native_layer_norm_backward" role of the two sums comes from
# the generated name / ATen note above, not from anything the kernel checks.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_red_fused__to_copy_add_native_layer_norm_backward_22 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_22', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_22', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.063307776}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp3 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask, tmp8, _tmp7)
        tmp9 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp7, None)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp10, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_native_layer_norm_backward_22.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.063307776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/fy/cfyrxilgpaoobuleeggh5n6i3vfcj6pqju5o3najjzmc764wrhlw.py
# Source Nodes: [], Original ATen: [aten.clone]
#
# Pointwise Triton kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# Copies 6291456 elements from in_ptr0 to out_ptr0 (both declared '*fp16' in
# triton_meta) while reordering them. The flat output index idx is decomposed as
#     x0 = idx % 64
#     x1 = (idx // 64) % 512
#     x2 = (idx // 32768) % 12
#     x3 = idx // 393216
# and out_ptr0[idx] is read from in_ptr0[x0 + 64*x2 + 768*x1 + 393216*x3].
# Neither the load nor the store is masked (mask argument is None), and the
# xnumel runtime argument is overwritten by the constant 6291456.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_poi_fused_clone_23 = async_compile.triton('triton_poi_fused_clone_23', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[8388608],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_23', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025165824},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_23(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6291456
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 64
    x1 = (xindex // 64) % 512
    x2 = (xindex // 32768) % 12
    x3 = (xindex // 393216)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (393216*x3)), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((16, 12, 512, 64), (393216, 32768, 64, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_23.run(*args, 6291456, grid=grid(6291456), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_clone_23.benchmark_all_configs(*args, 6291456, grid=grid(6291456))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.025165824
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/lj/cljluzoqrqvi7wumengog7jpxwcpskgi6tjet3lc7j2all5nibdx.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped Triton reduction kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# For every x in [0, 49152) the kernel computes, with an fp32 accumulator,
#     out_ptr0[x] = sum over r in [0, 128) of float32(in_ptr0[i])
# where i = (x % 768) + 768 * r + 98304 * (x // 768).
# in_ptr0 is declared '*fp16' and out_ptr0 '*fp32' in triton_meta. The xnumel /
# rnumel runtime arguments are overwritten by the constants 49152 and 128; loads
# are masked on r only and the final store is unmasked.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_red_fused_sum_24 = async_compile.triton('triton_red_fused_sum_24', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_24(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_24.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_24.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/pg/cpgss3xinmgtgberwusouop6ykgkwstr6e2bi33zume6biol33tv.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
#
# Persistent Triton reduction kernel plus a standalone benchmark harness, handed
# to async_compile.triton() as source text for device 'cuda'.
#
# One program handles one row x0 = program_id(0) (XBLOCK = 1) and covers the 768
# reduction elements with a single RBLOCK = 1024 tile masked by r < 768.
# With j = r + 768 * x0, per element:
#     g   = (in_ptr0[j] + float32(in_ptr1[j]) + float32(in_ptr2[j])
#            + float32(in_ptr3[j])) * in_ptr4[r]
#     s1  = sum_r g
#     s2  = sum_r g * in_ptr5[j]
#     d   = in_ptr6[x0] * (768.0 * g - s1 - in_ptr5[j] * s2)
#     out_ptr3[j] = d
#     out_ptr4[j] = d * (float32(in_ptr7[j]) * 1.1111111111111112)
# in_ptr7 is loaded and cast to tl.int1 ('*i1' in triton_meta).
# NOTE(review): 1.1111111111111112 is presumably the 1 / (1 - p) dropout rescale
# for p = 0.1 -- confirm against the forward graph.
# The xnumel / rnumel runtime arguments are overwritten by 8192 and 768.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*i1', 8: '*fp32', 9: '*fp16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.132156416}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp16 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
    tmp22 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    tmp30 = tl.load(in_ptr7 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp3 + tmp5
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp11 = tmp9 * tmp10
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask, tmp12, 0)
    tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp17 = tmp11 * tmp16
    tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
    tmp20 = tl.where(rmask, tmp18, 0)
    tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
    tmp23 = 768.0
    tmp24 = tmp11 * tmp23
    tmp25 = tmp24 - tmp15
    tmp26 = tmp16 * tmp21
    tmp27 = tmp25 - tmp26
    tmp28 = tmp22 * tmp27
    tmp29 = tmp28.to(tl.float32)
    tmp31 = tmp30.to(tl.float32)
    tmp32 = 1.1111111111111112
    tmp33 = tmp31 * tmp32
    tmp34 = tmp29 * tmp33
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
    tl.store(out_ptr4 + (r1 + (768*x0)), tmp34, rmask)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.132156416
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kk/ckkxhfownaqlblc6owmbx2ivrxr2v5fia4yyt5hpkdsas4vq57lh.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
#
# Looped Triton reduction kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# For every x in [0, 49152) the kernel accumulates over r in [0, 128), in RBLOCK
# sized steps, using fp32 accumulators:
#     t           = in_ptr0[i] + float32(in_ptr1[i]) + float32(in_ptr2[i])
#                   + float32(in_ptr3[i])
#     out_ptr0[x] = sum_r t * in_ptr4[i]
#     out_ptr1[x] = sum_r t
# where i = (x % 768) + 768 * r + 98304 * (x // 768).
# in_ptr1..in_ptr3 are declared '*fp16' in triton_meta; the rest are '*fp32'.
# The xnumel / rnumel runtime arguments are overwritten by the constants 49152 and
# 128; loads are masked on r only, stores are unmasked.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_red_fused__to_copy_add_native_layer_norm_backward_26 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_26', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0884736}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_26(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp16 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr4 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp6 + tmp8
        tmp11 = tmp9 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
        tmp15 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp17 = _tmp16 + tmp15
        _tmp16 = tl.where(rmask, tmp17, _tmp16)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp13, None)
    tmp16 = tl.sum(_tmp16, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp16, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_native_layer_norm_backward_26.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.0884736
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xt/cxtwhwavzlx26uc42l5ttjvmznisiph6donsz4lua3ibmgh7atnd.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
#
# Pointwise Triton kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# Writes the constant 0.0 to out_ptr0[idx] for every idx < 1536; the store is
# masked with xindex < xnumel because the launch grid may overshoot that count.
# The kernel has no inputs, and the xnumel runtime argument is overwritten by the
# constant 1536.
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_poi_fused_embedding_dense_backward_27 = async_compile.triton('triton_poi_fused_embedding_dense_backward_27', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[2048],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_27', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6.144e-06},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_27(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)


def get_args():
    arg_0 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_embedding_dense_backward_27.run(*args, 1536, grid=grid(1536), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_embedding_dense_backward_27.benchmark_all_configs(*args, 1536, grid=grid(1536))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 6.144e-06
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ux/cuxomnrffakbliaucrregsixhu7biugelgox2sqnovnhnqnsgirh.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
#
# Pointwise Triton kernel plus a standalone benchmark harness, handed to
# async_compile.triton() as source text for device 'cuda'.
#
# Writes the constant 0.0 to out_ptr0[idx] for every idx < 23440896; the store
# is masked with xindex < xnumel because the launch grid may overshoot that count.
# The kernel has no inputs, and the xnumel runtime argument is overwritten by the
# constant 23440896 (the benchmark harness uses a (30522, 768) fp32 buffer).
#
# The embedded source is a runtime string: it is reproduced token-for-token, with
# only its indentation restored and page-extraction residue removed.
triton_poi_fused_embedding_dense_backward_28 = async_compile.triton('triton_poi_fused_embedding_dense_backward_28', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_28', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.093763584},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_28(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23440896
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)


def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_embedding_dense_backward_28.run(*args, 23440896, grid=grid(23440896), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_embedding_dense_backward_28.benchmark_all_configs(*args, 23440896, grid=grid(23440896))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.093763584
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ux/cuxgjhibr4qqfqusc6n5la3fhval4hxzobfn4dzyazk5dckxavok.py | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward] | |
# masked_lm_loss => full_default_2 | |
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.persistent_reduction( | |
size_hints=[8192, 1024], | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*i1', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*i64', 9: '*i64', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: 'i32', 14: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29', 'mutated_arg_names': ['in_out_ptr0', 'out_ptr3', 'out_ptr4'], 'no_x_dim': True, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.169980928} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel): | |
xnumel = 8192 | |
XBLOCK: tl.constexpr = 1 | |
rnumel = 768 | |
RBLOCK: tl.constexpr = 1024 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([RBLOCK], True, tl.int1) | |
rindex = tl.arange(0, RBLOCK)[:] | |
roffset = 0 | |
rmask = rindex < rnumel | |
r1 = rindex | |
x0 = xindex | |
x2 = xindex % 512 | |
tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), rmask, other=0.0) | |
tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) | |
tmp4 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) | |
tmp7 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) | |
tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1) | |
tmp15 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0) | |
tmp21 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0) | |
tmp27 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last') | |
tmp34 = tl.load(in_ptr7 + (x2), None, eviction_policy='evict_last') | |
tmp43 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last') | |
tmp2 = tmp1.to(tl.float32) | |
tmp3 = tmp0 + tmp2 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp3 + tmp5 | |
tmp8 = tmp7.to(tl.float32) | |
tmp9 = tmp6 + tmp8 | |
tmp11 = tmp10.to(tl.float32) | |
tmp12 = 1.1111111111111112 | |
tmp13 = tmp11 * tmp12 | |
tmp14 = tmp9 * tmp13 | |
tmp16 = tmp14 * tmp15 | |
tmp17 = tl.broadcast_to(tmp16, [RBLOCK]) | |
tmp19 = tl.where(rmask, tmp17, 0) | |
tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0)) | |
tmp22 = tmp16 * tmp21 | |
tmp23 = tl.broadcast_to(tmp22, [RBLOCK]) | |
tmp25 = tl.where(rmask, tmp23, 0) | |
tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0)) | |
tmp28 = 768.0 | |
tmp29 = tmp16 * tmp28 | |
tmp30 = tmp29 - tmp20 | |
tmp31 = tmp21 * tmp26 | |
tmp32 = tmp30 - tmp31 | |
tmp33 = tmp27 * tmp32 | |
tmp35 = tl.full([RBLOCK], 2, tl.int32) | |
tmp36 = tmp34 + tmp35 | |
tmp37 = tmp34 < 0 | |
tmp38 = tl.where(tmp37, tmp36, tmp34) | |
tmp39 = tl.full([1], -1, tl.int64) | |
tmp40 = tmp34 == tmp39 | |
tmp41 = 0.0 | |
tmp42 = tl.where(tmp40, tmp41, tmp33) | |
tmp44 = tl.full([RBLOCK], 30522, tl.int32) | |
tmp45 = tmp43 + tmp44 | |
tmp46 = tmp43 < 0 | |
tmp47 = tl.where(tmp46, tmp45, tmp43) | |
tmp48 = tl.full([1], 0, tl.int64) | |
tmp49 = tmp43 == tmp48 | |
tmp50 = tl.where(tmp49, tmp41, tmp33) | |
tl.store(in_out_ptr0 + (r1 + (768*x0)), tmp14, rmask) | |
tl.store(out_ptr2 + (r1 + (768*x0)), tmp33, rmask) | |
tl.atomic_add(out_ptr3 + (tl.broadcast_to(r1 + (768*tmp38), [RBLOCK])), tmp42, rmask, sem='relaxed') | |
tl.atomic_add(out_ptr4 + (tl.broadcast_to(r1 + (768*tmp47), [RBLOCK])), tmp50, rmask, sem='relaxed') | |
def get_args(): | |
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32) | |
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
arg_7 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_8 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
arg_9 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
arg_10 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
arg_11 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32) | |
arg_12 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.run(*args, 8192, 768, grid=grid(8192), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.benchmark_all_configs(*args, 8192, 768, grid=grid(8192)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.169980928 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/wc/cwcm7d2opg7dvfg3iqa43co7xxrrnilj7flh2kmutnq2zqi4vlfi.py
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
#
# Hands the Triton source below to async_compile.triton(...) under the name
# 'triton_red_fused_native_layer_norm_backward_30' with device_str='cuda'.
#
# What the kernel computes (xnumel = 49152 = 768 * 64, rnumel = 128):
#   x0 = x % 768, x1 = x // 768, offset(r) = x0 + 768*r + 98304*x1
#   out_ptr0[x] = sum over r in [0, 128) of in_ptr0[offset(r)] * in_ptr1[offset(r)]
#   out_ptr1[x] = sum over r in [0, 128) of in_ptr0[offset(r)]
# Because 98304 = 768 * 128, each x1 selects a run of 128 consecutive
# 768-element rows and each output is that run's sum for column x0.
# The reduction walks r in RBLOCK-sized steps, accumulating into the
# [XBLOCK, RBLOCK] tiles _tmp4 / _tmp7, which are collapsed with tl.sum
# after the loop. Loads are masked by rmask only; xmask is built all-True
# and never used, and both stores pass mask=None.
# NOTE(review): going by the ATen tag these look like the partial sums of a
# layer-norm backward split reduction -- confirm against the caller.
#
# Everything from get_args() down is a benchmark harness that lives inside
# the kernel source; its timing code is guarded by `__name__ == '__main__'`.
triton_red_fused_native_layer_norm_backward_30 = async_compile.triton('triton_red_fused_native_layer_norm_backward_30', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_30', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_native_layer_norm_backward_30(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
        tmp6 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask, tmp8, _tmp7)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp4, None)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp7, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_native_layer_norm_backward_30.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_native_layer_norm_backward_30.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.050724864
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/oo/coouyldiohu5nljnik4bjbkcy56erkcsydxpoco2pnqxwiivtdx2.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
#
# Hands the Triton source below to async_compile.triton(...) under the name
# 'triton_poi_fused_embedding_dense_backward_31' with device_str='cuda'.
#
# What the kernel computes (xnumel = 393216 = 512 * 768):
#   out_ptr0[x] = 0.0 for every x handled by the launch grid.
# It performs no loads ('num_load': 0). The store passes mask=None; xmask is
# built all-True and never used.
# NOTE(review): presumably this zero-fills a gradient buffer that a later
# kernel accumulates into with tl.atomic_add -- confirm against the caller.
#
# Everything from get_args() down is a benchmark harness that lives inside
# the kernel source; its timing code is guarded by `__name__ == '__main__'`.
triton_poi_fused_embedding_dense_backward_31 = async_compile.triton('triton_poi_fused_embedding_dense_backward_31', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[524288],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_31', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.001572864},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_31(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 393216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, None)


def get_args():
    arg_0 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_embedding_dense_backward_31.run(*args, 393216, grid=grid(393216), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_embedding_dense_backward_31.benchmark_all_configs(*args, 393216, grid=grid(393216))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.001572864
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/my/cmybittjc2jf2j2woo7yukrwbfrxkubzfhfznkvzq3czd6munu5v.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum]
# masked_lm_loss => full_default_2
#
# Hands the Triton source below to async_compile.triton(...) under the name
# 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32' with
# device_str='cuda'.
#
# What the kernel computes (xnumel = 393216 = 512 * 768, rnumel = RBLOCK = 16):
#   s[x]   = sum over r in [0, 16) of in_ptr0[x + 393216*r]
#   idx    = in_ptr1[x // 768]                      (int64)
#   row    = idx + 512 if idx < 0 else idx          (negative index wrapped)
#   val    = 0.0 if idx == -1 else s[x]             (tested on the raw index)
#   out_ptr1[(x % 768) + 768*row] += val            (tl.atomic_add, sem='relaxed')
# The whole reduction axis fits in one tile, so there is no r-loop; rmask and
# xmask are all-True and every load and the atomic add pass mask=None.
# out_ptr1 is accumulated into rather than overwritten (it is the only entry
# in 'mutated_arg_names'), so its prior contents matter.
# NOTE(review): -1 is presumably the embedding's padding index, whose gradient
# is zeroed -- confirm against the forward graph.
#
# Everything from get_args() down is a benchmark harness that lives inside
# the kernel source; its timing code is guarded by `__name__ == '__main__'`.
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32 = async_compile.triton('triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[524288, 16],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32', 'mutated_arg_names': ['out_ptr1'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.026742784}
)
@triton.jit
def triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 393216
    rnumel = 16
    RBLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    x3 = (xindex // 768)
    x2 = xindex % 768
    tmp0 = tl.load(in_ptr0 + (x0 + (393216*r1)), None)
    tmp4 = tl.load(in_ptr1 + (x3), None, eviction_policy='evict_last')
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.sum(tmp1, 1)[:, None]
    tmp5 = tl.full([XBLOCK, 1], 512, tl.int32)
    tmp6 = tmp4 + tmp5
    tmp7 = tmp4 < 0
    tmp8 = tl.where(tmp7, tmp6, tmp4)
    tmp9 = tl.full([1, 1], -1, tl.int64)
    tmp10 = tmp4 == tmp9
    tmp11 = 0.0
    tmp12 = tl.where(tmp10, tmp11, tmp3)
    tl.atomic_add(out_ptr1 + (x2 + (768*tmp8)), tmp12, None, sem='relaxed')


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    arg_2 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.run(*args, 393216, 16, grid=grid(393216), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.benchmark_all_configs(*args, 393216, 16, grid=grid(393216))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.026742784
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# Call wait() on the AsyncCompile instance, passing this module's globals().
# NOTE(review): presumably this blocks until the kernels submitted through
# async_compile.triton(...) are built and rebinds their names in globals()
# to the compiled results -- confirm against AsyncCompile.wait.
async_compile.wait(globals())
# Remove the module-level `async_compile` name.
del async_compile
def call(args): | |
primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, 
getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2 = args | |
args.clear() | |
assert_size_stride(primals_4, (768, ), (1, )) | |
assert_size_stride(primals_14, (768, ), (1, )) | |
assert_size_stride(primals_20, (768, ), (1, )) | |
assert_size_stride(primals_30, (768, ), (1, )) | |
assert_size_stride(primals_36, (768, ), (1, )) | |
assert_size_stride(primals_46, (768, ), (1, )) | |
assert_size_stride(primals_52, (768, ), (1, )) | |
assert_size_stride(primals_62, (768, ), (1, )) | |
assert_size_stride(primals_68, (768, ), (1, )) | |
assert_size_stride(primals_78, (768, ), (1, )) | |
assert_size_stride(primals_84, (768, ), (1, )) | |
assert_size_stride(primals_94, (768, ), (1, )) | |
assert_size_stride(primals_100, (768, ), (1, )) | |
assert_size_stride(primals_110, (768, ), (1, )) | |
assert_size_stride(primals_116, (768, ), (1, )) | |
assert_size_stride(primals_126, (768, ), (1, )) | |
assert_size_stride(primals_132, (768, ), (1, )) | |
assert_size_stride(primals_142, (768, ), (1, )) | |
assert_size_stride(primals_148, (768, ), (1, )) | |
assert_size_stride(primals_158, (768, ), (1, )) | |
assert_size_stride(primals_164, (768, ), (1, )) | |
assert_size_stride(primals_174, (768, ), (1, )) | |
assert_size_stride(primals_180, (768, ), (1, )) | |
assert_size_stride(primals_190, (768, ), (1, )) | |
assert_size_stride(primals_196, (768, ), (1, )) | |
assert_size_stride(primals_200, (768, ), (1, )) | |
assert_size_stride(primals_204, (1, 512), (512, 1)) | |
assert_size_stride(primals_205, (1, 512), (512, 1)) | |
assert_size_stride(primals_206, (16, 512), (512, 1)) | |
assert_size_stride(primals_207, (16, 512), (512, 1)) | |
assert_size_stride(mul_1, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(gt, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_66, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_67, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_68, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_129, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_130, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_131, (), ()) | |
assert_size_stride(getitem_132, (), ()) | |
assert_size_stride(view_16, (8192, 768), (768, 1)) | |
assert_size_stride(gt_2, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_9, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_18, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_4, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_20, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_3, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_16, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_22, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_60, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_61, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_62, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_122, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_123, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_124, (), ()) | |
assert_size_stride(getitem_125, (), ()) | |
assert_size_stride(view_38, (8192, 768), (768, 1)) | |
assert_size_stride(gt_5, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_22, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_40, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_10, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_42, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_6, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_29, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_44, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_54, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_55, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_56, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_115, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_116, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_117, (), ()) | |
assert_size_stride(getitem_118, (), ()) | |
assert_size_stride(view_60, (8192, 768), (768, 1)) | |
assert_size_stride(gt_8, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_35, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_62, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_16, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_64, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_9, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_42, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_66, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_48, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_49, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_50, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_108, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_109, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_110, (), ()) | |
assert_size_stride(getitem_111, (), ()) | |
assert_size_stride(view_82, (8192, 768), (768, 1)) | |
assert_size_stride(gt_11, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_48, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_84, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_22, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_86, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_12, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_55, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_88, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_42, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_43, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_44, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_101, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_102, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_103, (), ()) | |
assert_size_stride(getitem_104, (), ()) | |
assert_size_stride(view_104, (8192, 768), (768, 1)) | |
assert_size_stride(gt_14, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_61, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_106, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_28, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_108, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_15, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_68, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_110, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_36, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_37, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_38, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_94, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_95, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_96, (), ()) | |
assert_size_stride(getitem_97, (), ()) | |
assert_size_stride(view_126, (8192, 768), (768, 1)) | |
assert_size_stride(gt_17, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_74, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_128, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_34, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_130, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_18, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_81, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_132, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_30, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_31, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_32, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_87, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_88, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_89, (), ()) | |
assert_size_stride(getitem_90, (), ()) | |
assert_size_stride(view_148, (8192, 768), (768, 1)) | |
assert_size_stride(gt_20, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_87, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_150, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_40, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_152, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_21, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_94, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_154, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_24, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_25, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_26, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_80, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_81, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_82, (), ()) | |
assert_size_stride(getitem_83, (), ()) | |
assert_size_stride(view_170, (8192, 768), (768, 1)) | |
assert_size_stride(gt_23, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_100, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_172, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_46, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_174, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_24, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_107, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_176, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_18, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_19, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_20, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_73, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_74, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_75, (), ()) | |
assert_size_stride(getitem_76, (), ()) | |
assert_size_stride(view_192, (8192, 768), (768, 1)) | |
assert_size_stride(gt_26, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_113, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_194, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_52, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_196, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_27, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_120, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_198, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_12, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_13, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_14, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_66, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_67, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_68, (), ()) | |
assert_size_stride(getitem_69, (), ()) | |
assert_size_stride(view_214, (8192, 768), (768, 1)) | |
assert_size_stride(gt_29, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_126, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_216, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_58, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_218, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_30, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_133, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_220, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_6, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_7, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_8, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_59, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_60, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_61, (), ()) | |
assert_size_stride(getitem_62, (), ()) | |
assert_size_stride(view_236, (8192, 768), (768, 1)) | |
assert_size_stride(gt_32, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_139, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_238, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_64, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_240, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_33, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_146, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_242, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_1, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_2, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_52, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_53, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_54, (), ()) | |
assert_size_stride(getitem_55, (), ()) | |
assert_size_stride(view_258, (8192, 768), (768, 1)) | |
assert_size_stride(gt_35, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_152, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_260, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_70, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_262, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_36, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_159, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_264, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_72, (8192, 768), (768, 1)) | |
assert_size_stride(getitem_51, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(rsqrt_25, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(view_266, (8192, 768), (768, 1)) | |
assert_size_stride(view_267, (16, 512, 30522), (15630336, 30528, 1)) | |
assert_size_stride(amax_12, (8192, 1), (1, 1)) | |
assert_size_stride(log, (8192, 1), (1, 1)) | |
assert_size_stride(convert_element_type_510, (), ()) | |
assert_size_stride(permute_134, (30522, 768), (768, 1)) | |
assert_size_stride(permute_138, (768, 768), (768, 1)) | |
assert_size_stride(div_27, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_142, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_146, (3072, 768), (768, 1)) | |
assert_size_stride(div_28, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_150, (768, 768), (768, 1)) | |
assert_size_stride(permute_162, (768, 768), (768, 1)) | |
assert_size_stride(permute_167, (768, 768), (768, 1)) | |
assert_size_stride(permute_171, (768, 768), (768, 1)) | |
assert_size_stride(div_30, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_175, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_179, (3072, 768), (768, 1)) | |
assert_size_stride(div_31, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_183, (768, 768), (768, 1)) | |
assert_size_stride(permute_195, (768, 768), (768, 1)) | |
assert_size_stride(permute_200, (768, 768), (768, 1)) | |
assert_size_stride(permute_204, (768, 768), (768, 1)) | |
assert_size_stride(div_33, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_208, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_212, (3072, 768), (768, 1)) | |
assert_size_stride(div_34, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_216, (768, 768), (768, 1)) | |
assert_size_stride(permute_228, (768, 768), (768, 1)) | |
assert_size_stride(permute_233, (768, 768), (768, 1)) | |
assert_size_stride(permute_237, (768, 768), (768, 1)) | |
assert_size_stride(div_36, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_241, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_245, (3072, 768), (768, 1)) | |
assert_size_stride(div_37, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_249, (768, 768), (768, 1)) | |
assert_size_stride(permute_261, (768, 768), (768, 1)) | |
assert_size_stride(permute_266, (768, 768), (768, 1)) | |
assert_size_stride(permute_270, (768, 768), (768, 1)) | |
assert_size_stride(div_39, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_274, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_278, (3072, 768), (768, 1)) | |
assert_size_stride(div_40, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_282, (768, 768), (768, 1)) | |
assert_size_stride(permute_294, (768, 768), (768, 1)) | |
assert_size_stride(permute_299, (768, 768), (768, 1)) | |
assert_size_stride(permute_303, (768, 768), (768, 1)) | |
assert_size_stride(div_42, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_307, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_311, (3072, 768), (768, 1)) | |
assert_size_stride(div_43, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_315, (768, 768), (768, 1)) | |
assert_size_stride(permute_327, (768, 768), (768, 1)) | |
assert_size_stride(permute_332, (768, 768), (768, 1)) | |
assert_size_stride(permute_336, (768, 768), (768, 1)) | |
assert_size_stride(div_45, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_340, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_344, (3072, 768), (768, 1)) | |
assert_size_stride(div_46, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_348, (768, 768), (768, 1)) | |
assert_size_stride(permute_360, (768, 768), (768, 1)) | |
assert_size_stride(permute_365, (768, 768), (768, 1)) | |
assert_size_stride(permute_369, (768, 768), (768, 1)) | |
assert_size_stride(div_48, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_373, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_377, (3072, 768), (768, 1)) | |
assert_size_stride(div_49, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_381, (768, 768), (768, 1)) | |
assert_size_stride(permute_393, (768, 768), (768, 1)) | |
assert_size_stride(permute_398, (768, 768), (768, 1)) | |
assert_size_stride(permute_402, (768, 768), (768, 1)) | |
assert_size_stride(div_51, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_406, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_410, (3072, 768), (768, 1)) | |
assert_size_stride(div_52, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_414, (768, 768), (768, 1)) | |
assert_size_stride(permute_426, (768, 768), (768, 1)) | |
assert_size_stride(permute_431, (768, 768), (768, 1)) | |
assert_size_stride(permute_435, (768, 768), (768, 1)) | |
assert_size_stride(div_54, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_439, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_443, (3072, 768), (768, 1)) | |
assert_size_stride(div_55, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_447, (768, 768), (768, 1)) | |
assert_size_stride(permute_459, (768, 768), (768, 1)) | |
assert_size_stride(permute_464, (768, 768), (768, 1)) | |
assert_size_stride(permute_468, (768, 768), (768, 1)) | |
assert_size_stride(div_57, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_472, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_476, (3072, 768), (768, 1)) | |
assert_size_stride(div_58, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_480, (768, 768), (768, 1)) | |
assert_size_stride(permute_492, (768, 768), (768, 1)) | |
assert_size_stride(permute_497, (768, 768), (768, 1)) | |
assert_size_stride(permute_501, (768, 768), (768, 1)) | |
assert_size_stride(div_60, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_505, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_509, (3072, 768), (768, 1)) | |
assert_size_stride(div_61, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_513, (768, 768), (768, 1)) | |
assert_size_stride(permute_525, (768, 768), (768, 1)) | |
assert_size_stride(permute_530, (768, 768), (768, 1)) | |
assert_size_stride(permute_534, (768, 768), (768, 1)) | |
assert_size_stride(div_63, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(tangents_1, (), ()) | |
assert_size_stride(tangents_2, (16, 512, 30522), (15627264, 30522, 1)) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf0 = empty_strided_cuda((8192, 30522), (30528, 1), torch.float32) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_nll_loss_backward_nll_loss_forward_0.run(buf0, 250036224, grid=grid(250036224), stream=stream0) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward] | |
triton_poi_fused_nll_loss_backward_nll_loss_forward_1.run(primals_207, buf0, 8192, grid=grid(8192), stream=stream0) | |
buf3 = empty_strided_cuda((16, 512, 30522), (15630336, 30528, 1), torch.float16) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward] | |
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.run(buf0, primals_207, tangents_1, convert_element_type_510, tangents_2, view_267, amax_12, log, buf3, 8192, 30522, grid=grid(8192), stream=stream0) | |
del amax_12 | |
del buf0 | |
del convert_element_type_510 | |
del log | |
del primals_207 | |
del tangents_1 | |
del tangents_2 | |
del view_267 | |
buf4 = empty_strided_cuda((8192, 30528), (30528, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
triton_poi_fused_3.run(buf3, buf4, 250085376, grid=grid(250085376), stream=stream0) | |
buf5 = empty_strided_cuda((30528, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
triton_poi_fused_4.run(permute_134, buf5, 23445504, grid=grid(23445504), stream=stream0) | |
del permute_134 | |
buf6 = empty_strided_cuda((8192, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(buf4, buf5, out=buf6) | |
del buf4 | |
del buf5 | |
buf7 = empty_strided_cuda((30522, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf3, (30522, 8192), (1, 30528), 0), view_266, out=buf7) | |
del view_266 | |
buf10 = empty_strided_cuda((30522, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_red_fused__to_copy_sum_5.run(buf3, buf10, 30522, 8192, grid=grid(30522), stream=stream0) | |
del buf3 | |
buf9 = empty_strided_cuda((30522, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_6.run(buf7, buf9, 23440896, grid=grid(23440896), stream=stream0) | |
del buf7 | |
buf18 = empty_strided_cuda((8192, 768), (768, 1), torch.float16) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view] | |
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.run(buf6, primals_200, addmm_72, getitem_51, rsqrt_25, buf18, 8192, 768, grid=grid(8192), stream=stream0) | |
del primals_200 | |
buf13 = empty_strided_cuda((768, 64), (1, 768), torch.float32) | |
buf15 = empty_strided_cuda((768, 64), (1, 768), torch.float32) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.run(buf6, addmm_72, getitem_51, rsqrt_25, buf13, buf15, 49152, 128, grid=grid(49152), stream=stream0) | |
del addmm_72 | |
del getitem_51 | |
del rsqrt_25 | |
buf14 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf13, buf14, 768, 64, grid=grid(768), stream=stream0) | |
buf16 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf15, buf16, 768, 64, grid=grid(768), stream=stream0) | |
buf19 = buf6; del buf6 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(buf18, permute_138, out=buf19) | |
del permute_138 | |
buf20 = empty_strided_cuda((768, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf18, (768, 8192), (1, 768), 0), view_264, out=buf20) | |
del view_264 | |
buf21 = reinterpret_tensor(buf15, (1, 768, 64), (49152, 1, 768), 0); del buf15 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_10.run(buf18, buf21, 49152, 128, grid=grid(49152), stream=stream0) | |
buf24 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf21, buf24, 768, 64, grid=grid(768), stream=stream0) | |
buf23 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf20, buf23, 589824, grid=grid(589824), stream=stream0) | |
buf27 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32) | |
buf32 = reinterpret_tensor(buf18, (16, 512, 768), (393216, 768, 1), 0); del buf18 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.run(buf19, primals_196, mul_159, div_27, gt_36, buf27, buf32, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_27 | |
del gt_36 | |
del primals_196 | |
buf28 = reinterpret_tensor(buf21, (768, 64), (1, 768), 0); del buf21 # reuse | |
buf30 = buf13; del buf13 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_native_layer_norm_backward_14.run(buf19, mul_159, buf28, buf30, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_159 | |
buf29 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf28, buf29, 768, 64, grid=grid(768), stream=stream0) | |
buf31 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf30, buf31, 768, 64, grid=grid(768), stream=stream0) | |
buf33 = empty_strided_cuda((8192, 3072), (3072, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf32, (8192, 768), (768, 1), 0), permute_142, out=buf33) | |
del permute_142 | |
buf34 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf32, (768, 8192), (1, 768), 0), view_262, out=buf34) | |
del view_262 | |
buf35 = reinterpret_tensor(buf30, (1, 768, 64), (49152, 1, 768), 0); del buf30 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf32, buf35, 49152, 128, grid=grid(49152), stream=stream0) | |
buf38 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf35, buf38, 768, 64, grid=grid(768), stream=stream0) | |
buf37 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf34, buf37, 2359296, grid=grid(2359296), stream=stream0) | |
buf39 = reinterpret_tensor(buf33, (16, 512, 3072), (1572864, 3072, 1), 0); del buf33 # reuse | |
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf39, addmm_70, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_70 | |
buf40 = reinterpret_tensor(buf32, (8192, 768), (768, 1), 0); del buf32 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf39, (8192, 3072), (3072, 1), 0), permute_146, out=buf40) | |
del permute_146 | |
buf41 = reinterpret_tensor(buf34, (3072, 768), (768, 1), 0); del buf34 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf39, (3072, 8192), (1, 3072), 0), view_260, out=buf41) | |
del view_260 | |
buf42 = empty_strided_cuda((1, 3072, 32), (98304, 1, 3072), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf39, buf42, 98304, 256, grid=grid(98304), stream=stream0) | |
buf45 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf42, buf45, 3072, 32, grid=grid(3072), stream=stream0) | |
buf44 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf41, buf44, 2359296, grid=grid(2359296), stream=stream0) | |
buf48 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32) | |
buf53 = reinterpret_tensor(buf19, (16, 512, 768), (393216, 768, 1), 0); del buf19 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf27, buf40, primals_190, mul_152, div_28, gt_35, buf48, buf53, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_28 | |
del gt_35 | |
del primals_190 | |
buf49 = reinterpret_tensor(buf35, (768, 64), (1, 768), 0); del buf35 # reuse | |
buf51 = buf28; del buf28 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf27, buf40, mul_152, buf49, buf51, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_152 | |
buf50 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf49, buf50, 768, 64, grid=grid(768), stream=stream0) | |
buf52 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf51, buf52, 768, 64, grid=grid(768), stream=stream0) | |
buf54 = buf40; del buf40 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf53, (8192, 768), (768, 1), 0), permute_150, out=buf54) | |
del permute_150 | |
buf55 = buf20; del buf20 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf53, (768, 8192), (1, 768), 0), view_258, out=buf55) | |
del view_258 | |
buf56 = reinterpret_tensor(buf51, (1, 768, 64), (49152, 1, 768), 0); del buf51 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf53, buf56, 49152, 128, grid=grid(49152), stream=stream0) | |
buf59 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf56, buf59, 768, 64, grid=grid(768), stream=stream0) | |
buf58 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf55, buf58, 589824, grid=grid(589824), stream=stream0) | |
buf60 = reinterpret_tensor(buf53, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf53 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf54, buf60, 6291456, grid=grid(6291456), stream=stream0) | |
del buf54 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf61 = aten._scaled_dot_product_flash_attention_backward.default(buf60, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, None, None, 512, 512, 0.1, False, getitem_54, getitem_55, scale=0.125) | |
del getitem_52 | |
del getitem_53 | |
del getitem_54 | |
del getitem_55 | |
del permute_default | |
del permute_default_1 | |
del permute_default_2 | |
buf62 = buf61[0] | |
buf63 = buf61[1] | |
buf64 = buf61[2] | |
del buf61 | |
buf65 = reinterpret_tensor(buf60, (8192, 768), (768, 1), 0); del buf60 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf64, (8192, 768), (768, 1), 0), permute_162, out=buf65) | |
del permute_162 | |
buf66 = buf55; del buf55 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf64, (768, 8192), (1, 768), 0), view_242, out=buf66) | |
buf67 = buf56; del buf56 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf64, buf67, 49152, 128, grid=grid(49152), stream=stream0) | |
buf70 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf67, buf70, 768, 64, grid=grid(768), stream=stream0) | |
buf69 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf66, buf69, 589824, grid=grid(589824), stream=stream0) | |
buf71 = reinterpret_tensor(buf64, (8192, 768), (768, 1), 0); del buf64 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf63, (8192, 768), (768, 1), 0), permute_167, out=buf71) | |
del permute_167 | |
buf72 = buf66; del buf66 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf63, (768, 8192), (1, 768), 0), view_242, out=buf72) | |
buf73 = buf67; del buf67 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf63, buf73, 49152, 128, grid=grid(49152), stream=stream0) | |
buf76 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf73, buf76, 768, 64, grid=grid(768), stream=stream0) | |
buf75 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf72, buf75, 589824, grid=grid(589824), stream=stream0) | |
buf77 = reinterpret_tensor(buf63, (8192, 768), (768, 1), 0); del buf63 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf62, (8192, 768), (768, 1), 0), permute_171, out=buf77) | |
del permute_171 | |
buf78 = buf72; del buf72 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf62, (768, 8192), (1, 768), 0), view_242, out=buf78) | |
del view_242 | |
buf79 = buf73; del buf73 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf62, buf79, 49152, 128, grid=grid(49152), stream=stream0) | |
buf82 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf79, buf82, 768, 64, grid=grid(768), stream=stream0) | |
buf81 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf78, buf81, 589824, grid=grid(589824), stream=stream0) | |
buf86 = buf27; del buf27 # reuse | |
buf91 = reinterpret_tensor(buf62, (16, 512, 768), (393216, 768, 1), 0); del buf62 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf48, buf65, buf71, buf77, primals_180, mul_146, div_30, gt_33, buf86, buf91, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_30 | |
del gt_33 | |
del primals_180 | |
buf87 = reinterpret_tensor(buf79, (768, 64), (1, 768), 0); del buf79 # reuse | |
buf89 = buf49; del buf49 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf48, buf65, buf71, buf77, mul_146, buf87, buf89, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf65 | |
del buf71 | |
del mul_146 | |
buf88 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf87, buf88, 768, 64, grid=grid(768), stream=stream0) | |
buf90 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf89, buf90, 768, 64, grid=grid(768), stream=stream0) | |
buf92 = reinterpret_tensor(buf39, (8192, 3072), (3072, 1), 0); del buf39 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf91, (8192, 768), (768, 1), 0), permute_175, out=buf92) | |
del permute_175 | |
buf93 = reinterpret_tensor(buf41, (768, 3072), (3072, 1), 0); del buf41 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf91, (768, 8192), (1, 768), 0), view_240, out=buf93) | |
del view_240 | |
buf94 = reinterpret_tensor(buf89, (1, 768, 64), (49152, 1, 768), 0); del buf89 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf91, buf94, 49152, 128, grid=grid(49152), stream=stream0) | |
buf97 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf94, buf97, 768, 64, grid=grid(768), stream=stream0) | |
buf96 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf93, buf96, 2359296, grid=grid(2359296), stream=stream0) | |
buf98 = reinterpret_tensor(buf92, (16, 512, 3072), (1572864, 3072, 1), 0); del buf92 # reuse | |
# Source Nodes: [hidden_states_84], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf98, addmm_64, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_64 | |
buf99 = reinterpret_tensor(buf91, (8192, 768), (768, 1), 0); del buf91 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf98, (8192, 3072), (3072, 1), 0), permute_179, out=buf99) | |
del permute_179 | |
buf100 = reinterpret_tensor(buf93, (3072, 768), (768, 1), 0); del buf93 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf98, (3072, 8192), (1, 3072), 0), view_238, out=buf100) | |
del view_238 | |
buf101 = buf42; del buf42 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf98, buf101, 98304, 256, grid=grid(98304), stream=stream0) | |
buf104 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf101, buf104, 3072, 32, grid=grid(3072), stream=stream0) | |
buf103 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf100, buf103, 2359296, grid=grid(2359296), stream=stream0) | |
buf107 = buf48; del buf48 # reuse | |
buf112 = reinterpret_tensor(buf77, (16, 512, 768), (393216, 768, 1), 0); del buf77 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf86, buf99, primals_174, mul_139, div_31, gt_32, buf107, buf112, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_31 | |
del gt_32 | |
del primals_174 | |
buf108 = reinterpret_tensor(buf94, (768, 64), (1, 768), 0); del buf94 # reuse | |
buf110 = buf87; del buf87 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf86, buf99, mul_139, buf108, buf110, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_139 | |
buf109 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf108, buf109, 768, 64, grid=grid(768), stream=stream0) | |
buf111 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf110, buf111, 768, 64, grid=grid(768), stream=stream0) | |
buf113 = buf99; del buf99 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf112, (8192, 768), (768, 1), 0), permute_183, out=buf113) | |
del permute_183 | |
buf114 = buf78; del buf78 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf112, (768, 8192), (1, 768), 0), view_236, out=buf114) | |
del view_236 | |
buf115 = reinterpret_tensor(buf110, (1, 768, 64), (49152, 1, 768), 0); del buf110 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf112, buf115, 49152, 128, grid=grid(49152), stream=stream0) | |
buf118 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf115, buf118, 768, 64, grid=grid(768), stream=stream0) | |
buf117 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf114, buf117, 589824, grid=grid(589824), stream=stream0) | |
buf119 = reinterpret_tensor(buf112, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf112 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf113, buf119, 6291456, grid=grid(6291456), stream=stream0) | |
del buf113 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf120 = aten._scaled_dot_product_flash_attention_backward.default(buf119, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, None, None, 512, 512, 0.1, False, getitem_61, getitem_62, scale=0.125) | |
del getitem_59 | |
del getitem_60 | |
del getitem_61 | |
del getitem_62 | |
del permute_default_6 | |
del permute_default_7 | |
del permute_default_8 | |
buf121 = buf120[0] | |
buf122 = buf120[1] | |
buf123 = buf120[2] | |
del buf120 | |
buf124 = reinterpret_tensor(buf119, (8192, 768), (768, 1), 0); del buf119 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf123, (8192, 768), (768, 1), 0), permute_195, out=buf124) | |
del permute_195 | |
buf125 = buf114; del buf114 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf123, (768, 8192), (1, 768), 0), view_220, out=buf125) | |
buf126 = buf115; del buf115 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf123, buf126, 49152, 128, grid=grid(49152), stream=stream0) | |
buf129 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf126, buf129, 768, 64, grid=grid(768), stream=stream0) | |
buf128 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf125, buf128, 589824, grid=grid(589824), stream=stream0) | |
buf130 = reinterpret_tensor(buf123, (8192, 768), (768, 1), 0); del buf123 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf122, (8192, 768), (768, 1), 0), permute_200, out=buf130) | |
del permute_200 | |
buf131 = buf125; del buf125 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf122, (768, 8192), (1, 768), 0), view_220, out=buf131) | |
buf132 = buf126; del buf126 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf122, buf132, 49152, 128, grid=grid(49152), stream=stream0) | |
buf135 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf132, buf135, 768, 64, grid=grid(768), stream=stream0) | |
buf134 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf131, buf134, 589824, grid=grid(589824), stream=stream0) | |
buf136 = reinterpret_tensor(buf122, (8192, 768), (768, 1), 0); del buf122 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf121, (8192, 768), (768, 1), 0), permute_204, out=buf136) | |
del permute_204 | |
buf137 = buf131; del buf131 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf121, (768, 8192), (1, 768), 0), view_220, out=buf137) | |
del view_220 | |
buf138 = buf132; del buf132 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf121, buf138, 49152, 128, grid=grid(49152), stream=stream0) | |
buf141 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf138, buf141, 768, 64, grid=grid(768), stream=stream0) | |
buf140 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf137, buf140, 589824, grid=grid(589824), stream=stream0) | |
buf145 = buf86; del buf86 # reuse | |
buf150 = reinterpret_tensor(buf121, (16, 512, 768), (393216, 768, 1), 0); del buf121 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf107, buf124, buf130, buf136, primals_164, mul_133, div_33, gt_30, buf145, buf150, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_33 | |
del gt_30 | |
del primals_164 | |
buf146 = reinterpret_tensor(buf138, (768, 64), (1, 768), 0); del buf138 # reuse | |
buf148 = buf108; del buf108 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf107, buf124, buf130, buf136, mul_133, buf146, buf148, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf124 | |
del buf130 | |
del mul_133 | |
buf147 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf146, buf147, 768, 64, grid=grid(768), stream=stream0) | |
buf149 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf148, buf149, 768, 64, grid=grid(768), stream=stream0) | |
buf151 = reinterpret_tensor(buf98, (8192, 3072), (3072, 1), 0); del buf98 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf150, (8192, 768), (768, 1), 0), permute_208, out=buf151) | |
del permute_208 | |
buf152 = reinterpret_tensor(buf100, (768, 3072), (3072, 1), 0); del buf100 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf150, (768, 8192), (1, 768), 0), view_218, out=buf152) | |
del view_218 | |
buf153 = reinterpret_tensor(buf148, (1, 768, 64), (49152, 1, 768), 0); del buf148 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf150, buf153, 49152, 128, grid=grid(49152), stream=stream0) | |
buf156 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf153, buf156, 768, 64, grid=grid(768), stream=stream0) | |
buf155 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf152, buf155, 2359296, grid=grid(2359296), stream=stream0) | |
buf157 = reinterpret_tensor(buf151, (16, 512, 3072), (1572864, 3072, 1), 0); del buf151 # reuse | |
# Source Nodes: [hidden_states_76], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf157, addmm_58, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_58 | |
buf158 = reinterpret_tensor(buf150, (8192, 768), (768, 1), 0); del buf150 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf157, (8192, 3072), (3072, 1), 0), permute_212, out=buf158) | |
del permute_212 | |
buf159 = reinterpret_tensor(buf152, (3072, 768), (768, 1), 0); del buf152 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf157, (3072, 8192), (1, 3072), 0), view_216, out=buf159) | |
del view_216 | |
buf160 = buf101; del buf101 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf157, buf160, 98304, 256, grid=grid(98304), stream=stream0) | |
buf163 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf160, buf163, 3072, 32, grid=grid(3072), stream=stream0) | |
buf162 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf159, buf162, 2359296, grid=grid(2359296), stream=stream0) | |
buf166 = buf107; del buf107 # reuse | |
buf171 = reinterpret_tensor(buf136, (16, 512, 768), (393216, 768, 1), 0); del buf136 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf145, buf158, primals_158, mul_126, div_34, gt_29, buf166, buf171, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_34 | |
del gt_29 | |
del primals_158 | |
buf167 = reinterpret_tensor(buf153, (768, 64), (1, 768), 0); del buf153 # reuse | |
buf169 = buf146; del buf146 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf145, buf158, mul_126, buf167, buf169, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_126 | |
buf168 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf167, buf168, 768, 64, grid=grid(768), stream=stream0) | |
buf170 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf169, buf170, 768, 64, grid=grid(768), stream=stream0) | |
buf172 = buf158; del buf158 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf171, (8192, 768), (768, 1), 0), permute_216, out=buf172) | |
del permute_216 | |
buf173 = buf137; del buf137 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf171, (768, 8192), (1, 768), 0), view_214, out=buf173) | |
del view_214 | |
buf174 = reinterpret_tensor(buf169, (1, 768, 64), (49152, 1, 768), 0); del buf169 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf171, buf174, 49152, 128, grid=grid(49152), stream=stream0) | |
buf177 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf174, buf177, 768, 64, grid=grid(768), stream=stream0) | |
buf176 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf173, buf176, 589824, grid=grid(589824), stream=stream0) | |
buf178 = reinterpret_tensor(buf171, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf171 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf172, buf178, 6291456, grid=grid(6291456), stream=stream0) | |
del buf172 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf179 = aten._scaled_dot_product_flash_attention_backward.default(buf178, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, None, None, 512, 512, 0.1, False, getitem_68, getitem_69, scale=0.125) | |
del getitem_66 | |
del getitem_67 | |
del getitem_68 | |
del getitem_69 | |
del permute_default_12 | |
del permute_default_13 | |
del permute_default_14 | |
buf180 = buf179[0] | |
buf181 = buf179[1] | |
buf182 = buf179[2] | |
del buf179 | |
buf183 = reinterpret_tensor(buf178, (8192, 768), (768, 1), 0); del buf178 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf182, (8192, 768), (768, 1), 0), permute_228, out=buf183) | |
del permute_228 | |
buf184 = buf173; del buf173 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf182, (768, 8192), (1, 768), 0), view_198, out=buf184) | |
buf185 = buf174; del buf174 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf182, buf185, 49152, 128, grid=grid(49152), stream=stream0) | |
buf188 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf185, buf188, 768, 64, grid=grid(768), stream=stream0) | |
buf187 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf184, buf187, 589824, grid=grid(589824), stream=stream0) | |
buf189 = reinterpret_tensor(buf182, (8192, 768), (768, 1), 0); del buf182 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf181, (8192, 768), (768, 1), 0), permute_233, out=buf189) | |
del permute_233 | |
buf190 = buf184; del buf184 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf181, (768, 8192), (1, 768), 0), view_198, out=buf190) | |
buf191 = buf185; del buf185 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf181, buf191, 49152, 128, grid=grid(49152), stream=stream0) | |
buf194 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf191, buf194, 768, 64, grid=grid(768), stream=stream0) | |
buf193 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf190, buf193, 589824, grid=grid(589824), stream=stream0) | |
buf195 = reinterpret_tensor(buf181, (8192, 768), (768, 1), 0); del buf181 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf180, (8192, 768), (768, 1), 0), permute_237, out=buf195) | |
del permute_237 | |
buf196 = buf190; del buf190 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf180, (768, 8192), (1, 768), 0), view_198, out=buf196) | |
del view_198 | |
buf197 = buf191; del buf191 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf180, buf197, 49152, 128, grid=grid(49152), stream=stream0) | |
buf200 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf197, buf200, 768, 64, grid=grid(768), stream=stream0) | |
buf199 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf196, buf199, 589824, grid=grid(589824), stream=stream0) | |
buf204 = buf145; del buf145 # reuse | |
buf209 = reinterpret_tensor(buf180, (16, 512, 768), (393216, 768, 1), 0); del buf180 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf166, buf183, buf189, buf195, primals_148, mul_120, div_36, gt_27, buf204, buf209, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_36 | |
del gt_27 | |
del primals_148 | |
buf205 = reinterpret_tensor(buf197, (768, 64), (1, 768), 0); del buf197 # reuse | |
buf207 = buf167; del buf167 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf166, buf183, buf189, buf195, mul_120, buf205, buf207, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf183 | |
del buf189 | |
del mul_120 | |
buf206 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf205, buf206, 768, 64, grid=grid(768), stream=stream0) | |
buf208 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf207, buf208, 768, 64, grid=grid(768), stream=stream0) | |
buf210 = reinterpret_tensor(buf157, (8192, 3072), (3072, 1), 0); del buf157 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf209, (8192, 768), (768, 1), 0), permute_241, out=buf210) | |
del permute_241 | |
buf211 = reinterpret_tensor(buf159, (768, 3072), (3072, 1), 0); del buf159 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf209, (768, 8192), (1, 768), 0), view_196, out=buf211) | |
del view_196 | |
buf212 = reinterpret_tensor(buf207, (1, 768, 64), (49152, 1, 768), 0); del buf207 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf209, buf212, 49152, 128, grid=grid(49152), stream=stream0) | |
buf215 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf212, buf215, 768, 64, grid=grid(768), stream=stream0) | |
buf214 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf211, buf214, 2359296, grid=grid(2359296), stream=stream0) | |
buf216 = reinterpret_tensor(buf210, (16, 512, 3072), (1572864, 3072, 1), 0); del buf210 # reuse | |
# Source Nodes: [hidden_states_68], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf216, addmm_52, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_52 | |
buf217 = reinterpret_tensor(buf209, (8192, 768), (768, 1), 0); del buf209 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf216, (8192, 3072), (3072, 1), 0), permute_245, out=buf217) | |
del permute_245 | |
buf218 = reinterpret_tensor(buf211, (3072, 768), (768, 1), 0); del buf211 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf216, (3072, 8192), (1, 3072), 0), view_194, out=buf218) | |
del view_194 | |
buf219 = buf160; del buf160 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf216, buf219, 98304, 256, grid=grid(98304), stream=stream0) | |
buf222 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf219, buf222, 3072, 32, grid=grid(3072), stream=stream0) | |
buf221 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf218, buf221, 2359296, grid=grid(2359296), stream=stream0) | |
buf225 = buf166; del buf166 # reuse | |
buf230 = reinterpret_tensor(buf195, (16, 512, 768), (393216, 768, 1), 0); del buf195 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf204, buf217, primals_142, mul_113, div_37, gt_26, buf225, buf230, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_37 | |
del gt_26 | |
del primals_142 | |
buf226 = reinterpret_tensor(buf212, (768, 64), (1, 768), 0); del buf212 # reuse | |
buf228 = buf205; del buf205 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf204, buf217, mul_113, buf226, buf228, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_113 | |
buf227 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf226, buf227, 768, 64, grid=grid(768), stream=stream0) | |
buf229 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf228, buf229, 768, 64, grid=grid(768), stream=stream0) | |
buf231 = buf217; del buf217 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf230, (8192, 768), (768, 1), 0), permute_249, out=buf231) | |
del permute_249 | |
buf232 = buf196; del buf196 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf230, (768, 8192), (1, 768), 0), view_192, out=buf232) | |
del view_192 | |
buf233 = reinterpret_tensor(buf228, (1, 768, 64), (49152, 1, 768), 0); del buf228 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf230, buf233, 49152, 128, grid=grid(49152), stream=stream0) | |
buf236 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf233, buf236, 768, 64, grid=grid(768), stream=stream0) | |
buf235 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf232, buf235, 589824, grid=grid(589824), stream=stream0) | |
buf237 = reinterpret_tensor(buf230, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf230 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf231, buf237, 6291456, grid=grid(6291456), stream=stream0) | |
del buf231 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf238 = aten._scaled_dot_product_flash_attention_backward.default(buf237, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, None, None, 512, 512, 0.1, False, getitem_75, getitem_76, scale=0.125) | |
del getitem_73 | |
del getitem_74 | |
del getitem_75 | |
del getitem_76 | |
del permute_default_18 | |
del permute_default_19 | |
del permute_default_20 | |
buf239 = buf238[0] | |
buf240 = buf238[1] | |
buf241 = buf238[2] | |
del buf238 | |
buf242 = reinterpret_tensor(buf237, (8192, 768), (768, 1), 0); del buf237 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf241, (8192, 768), (768, 1), 0), permute_261, out=buf242) | |
del permute_261 | |
buf243 = buf232; del buf232 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf241, (768, 8192), (1, 768), 0), view_176, out=buf243) | |
buf244 = buf233; del buf233 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf241, buf244, 49152, 128, grid=grid(49152), stream=stream0) | |
buf247 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf244, buf247, 768, 64, grid=grid(768), stream=stream0) | |
buf246 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf243, buf246, 589824, grid=grid(589824), stream=stream0) | |
buf248 = reinterpret_tensor(buf241, (8192, 768), (768, 1), 0); del buf241 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf240, (8192, 768), (768, 1), 0), permute_266, out=buf248) | |
del permute_266 | |
buf249 = buf243; del buf243 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf240, (768, 8192), (1, 768), 0), view_176, out=buf249) | |
buf250 = buf244; del buf244 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf240, buf250, 49152, 128, grid=grid(49152), stream=stream0) | |
buf253 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf250, buf253, 768, 64, grid=grid(768), stream=stream0) | |
buf252 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf249, buf252, 589824, grid=grid(589824), stream=stream0) | |
buf254 = reinterpret_tensor(buf240, (8192, 768), (768, 1), 0); del buf240 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf239, (8192, 768), (768, 1), 0), permute_270, out=buf254) | |
del permute_270 | |
buf255 = buf249; del buf249 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf239, (768, 8192), (1, 768), 0), view_176, out=buf255) | |
del view_176 | |
buf256 = buf250; del buf250 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf239, buf256, 49152, 128, grid=grid(49152), stream=stream0) | |
buf259 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf256, buf259, 768, 64, grid=grid(768), stream=stream0) | |
buf258 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf255, buf258, 589824, grid=grid(589824), stream=stream0) | |
buf263 = buf204; del buf204 # reuse | |
buf268 = reinterpret_tensor(buf239, (16, 512, 768), (393216, 768, 1), 0); del buf239 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf225, buf242, buf248, buf254, primals_132, mul_107, div_39, gt_24, buf263, buf268, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_39 | |
del gt_24 | |
del primals_132 | |
buf264 = reinterpret_tensor(buf256, (768, 64), (1, 768), 0); del buf256 # reuse | |
buf266 = buf226; del buf226 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf225, buf242, buf248, buf254, mul_107, buf264, buf266, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf242 | |
del buf248 | |
del mul_107 | |
buf265 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf264, buf265, 768, 64, grid=grid(768), stream=stream0) | |
buf267 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf266, buf267, 768, 64, grid=grid(768), stream=stream0) | |
buf269 = reinterpret_tensor(buf216, (8192, 3072), (3072, 1), 0); del buf216 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf268, (8192, 768), (768, 1), 0), permute_274, out=buf269) | |
del permute_274 | |
buf270 = reinterpret_tensor(buf218, (768, 3072), (3072, 1), 0); del buf218 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf268, (768, 8192), (1, 768), 0), view_174, out=buf270) | |
del view_174 | |
buf271 = reinterpret_tensor(buf266, (1, 768, 64), (49152, 1, 768), 0); del buf266 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf268, buf271, 49152, 128, grid=grid(49152), stream=stream0) | |
buf274 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf271, buf274, 768, 64, grid=grid(768), stream=stream0) | |
buf273 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf270, buf273, 2359296, grid=grid(2359296), stream=stream0) | |
buf275 = reinterpret_tensor(buf269, (16, 512, 3072), (1572864, 3072, 1), 0); del buf269 # reuse | |
# Source Nodes: [hidden_states_60], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf275, addmm_46, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_46 | |
buf276 = reinterpret_tensor(buf268, (8192, 768), (768, 1), 0); del buf268 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf275, (8192, 3072), (3072, 1), 0), permute_278, out=buf276) | |
del permute_278 | |
buf277 = reinterpret_tensor(buf270, (3072, 768), (768, 1), 0); del buf270 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf275, (3072, 8192), (1, 3072), 0), view_172, out=buf277) | |
del view_172 | |
buf278 = buf219; del buf219 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf275, buf278, 98304, 256, grid=grid(98304), stream=stream0) | |
buf281 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf278, buf281, 3072, 32, grid=grid(3072), stream=stream0) | |
buf280 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf277, buf280, 2359296, grid=grid(2359296), stream=stream0) | |
buf284 = buf225; del buf225 # reuse | |
buf289 = reinterpret_tensor(buf254, (16, 512, 768), (393216, 768, 1), 0); del buf254 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf263, buf276, primals_126, mul_100, div_40, gt_23, buf284, buf289, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_40 | |
del gt_23 | |
del primals_126 | |
buf285 = reinterpret_tensor(buf271, (768, 64), (1, 768), 0); del buf271 # reuse | |
buf287 = buf264; del buf264 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf263, buf276, mul_100, buf285, buf287, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_100 | |
buf286 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf285, buf286, 768, 64, grid=grid(768), stream=stream0) | |
buf288 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf287, buf288, 768, 64, grid=grid(768), stream=stream0) | |
buf290 = buf276; del buf276 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf289, (8192, 768), (768, 1), 0), permute_282, out=buf290) | |
del permute_282 | |
buf291 = buf255; del buf255 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf289, (768, 8192), (1, 768), 0), view_170, out=buf291) | |
del view_170 | |
buf292 = reinterpret_tensor(buf287, (1, 768, 64), (49152, 1, 768), 0); del buf287 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf289, buf292, 49152, 128, grid=grid(49152), stream=stream0) | |
buf295 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf292, buf295, 768, 64, grid=grid(768), stream=stream0) | |
buf294 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf291, buf294, 589824, grid=grid(589824), stream=stream0) | |
buf296 = reinterpret_tensor(buf289, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf289 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf290, buf296, 6291456, grid=grid(6291456), stream=stream0) | |
del buf290 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf297 = aten._scaled_dot_product_flash_attention_backward.default(buf296, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, None, None, 512, 512, 0.1, False, getitem_82, getitem_83, scale=0.125) | |
del getitem_80 | |
del getitem_81 | |
del getitem_82 | |
del getitem_83 | |
del permute_default_24 | |
del permute_default_25 | |
del permute_default_26 | |
buf298 = buf297[0] | |
buf299 = buf297[1] | |
buf300 = buf297[2] | |
del buf297 | |
buf301 = reinterpret_tensor(buf296, (8192, 768), (768, 1), 0); del buf296 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf300, (8192, 768), (768, 1), 0), permute_294, out=buf301) | |
del permute_294 | |
buf302 = buf291; del buf291 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf300, (768, 8192), (1, 768), 0), view_154, out=buf302) | |
buf303 = buf292; del buf292 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf300, buf303, 49152, 128, grid=grid(49152), stream=stream0) | |
buf306 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf303, buf306, 768, 64, grid=grid(768), stream=stream0) | |
buf305 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf302, buf305, 589824, grid=grid(589824), stream=stream0) | |
buf307 = reinterpret_tensor(buf300, (8192, 768), (768, 1), 0); del buf300 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf299, (8192, 768), (768, 1), 0), permute_299, out=buf307) | |
del permute_299 | |
buf308 = buf302; del buf302 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf299, (768, 8192), (1, 768), 0), view_154, out=buf308) | |
buf309 = buf303; del buf303 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf299, buf309, 49152, 128, grid=grid(49152), stream=stream0) | |
buf312 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf309, buf312, 768, 64, grid=grid(768), stream=stream0) | |
buf311 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf308, buf311, 589824, grid=grid(589824), stream=stream0) | |
buf313 = reinterpret_tensor(buf299, (8192, 768), (768, 1), 0); del buf299 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf298, (8192, 768), (768, 1), 0), permute_303, out=buf313) | |
del permute_303 | |
buf314 = buf308; del buf308 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf298, (768, 8192), (1, 768), 0), view_154, out=buf314) | |
del view_154 | |
buf315 = buf309; del buf309 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf298, buf315, 49152, 128, grid=grid(49152), stream=stream0) | |
buf318 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf315, buf318, 768, 64, grid=grid(768), stream=stream0) | |
buf317 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf314, buf317, 589824, grid=grid(589824), stream=stream0) | |
buf322 = buf263; del buf263 # reuse | |
buf327 = reinterpret_tensor(buf298, (16, 512, 768), (393216, 768, 1), 0); del buf298 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf284, buf301, buf307, buf313, primals_116, mul_94, div_42, gt_21, buf322, buf327, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_42 | |
del gt_21 | |
del primals_116 | |
buf323 = reinterpret_tensor(buf315, (768, 64), (1, 768), 0); del buf315 # reuse | |
buf325 = buf285; del buf285 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf284, buf301, buf307, buf313, mul_94, buf323, buf325, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf301 | |
del buf307 | |
del mul_94 | |
buf324 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf323, buf324, 768, 64, grid=grid(768), stream=stream0) | |
buf326 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf325, buf326, 768, 64, grid=grid(768), stream=stream0) | |
buf328 = reinterpret_tensor(buf275, (8192, 3072), (3072, 1), 0); del buf275 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf327, (8192, 768), (768, 1), 0), permute_307, out=buf328) | |
del permute_307 | |
buf329 = reinterpret_tensor(buf277, (768, 3072), (3072, 1), 0); del buf277 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf327, (768, 8192), (1, 768), 0), view_152, out=buf329) | |
del view_152 | |
buf330 = reinterpret_tensor(buf325, (1, 768, 64), (49152, 1, 768), 0); del buf325 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf327, buf330, 49152, 128, grid=grid(49152), stream=stream0) | |
buf333 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf330, buf333, 768, 64, grid=grid(768), stream=stream0) | |
buf332 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf329, buf332, 2359296, grid=grid(2359296), stream=stream0) | |
buf334 = reinterpret_tensor(buf328, (16, 512, 3072), (1572864, 3072, 1), 0); del buf328 # reuse | |
# Source Nodes: [hidden_states_52], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf334, addmm_40, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_40 | |
buf335 = reinterpret_tensor(buf327, (8192, 768), (768, 1), 0); del buf327 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf334, (8192, 3072), (3072, 1), 0), permute_311, out=buf335) | |
del permute_311 | |
buf336 = reinterpret_tensor(buf329, (3072, 768), (768, 1), 0); del buf329 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf334, (3072, 8192), (1, 3072), 0), view_150, out=buf336) | |
del view_150 | |
buf337 = buf278; del buf278 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf334, buf337, 98304, 256, grid=grid(98304), stream=stream0) | |
buf340 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf337, buf340, 3072, 32, grid=grid(3072), stream=stream0) | |
buf339 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf336, buf339, 2359296, grid=grid(2359296), stream=stream0) | |
buf343 = buf284; del buf284 # reuse | |
buf348 = reinterpret_tensor(buf313, (16, 512, 768), (393216, 768, 1), 0); del buf313 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf322, buf335, primals_110, mul_87, div_43, gt_20, buf343, buf348, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_43 | |
del gt_20 | |
del primals_110 | |
buf344 = reinterpret_tensor(buf330, (768, 64), (1, 768), 0); del buf330 # reuse | |
buf346 = buf323; del buf323 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf322, buf335, mul_87, buf344, buf346, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_87 | |
buf345 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf344, buf345, 768, 64, grid=grid(768), stream=stream0) | |
buf347 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf346, buf347, 768, 64, grid=grid(768), stream=stream0) | |
buf349 = buf335; del buf335 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf348, (8192, 768), (768, 1), 0), permute_315, out=buf349) | |
del permute_315 | |
buf350 = buf314; del buf314 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf348, (768, 8192), (1, 768), 0), view_148, out=buf350) | |
del view_148 | |
buf351 = reinterpret_tensor(buf346, (1, 768, 64), (49152, 1, 768), 0); del buf346 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf348, buf351, 49152, 128, grid=grid(49152), stream=stream0) | |
buf354 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf351, buf354, 768, 64, grid=grid(768), stream=stream0) | |
buf353 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf350, buf353, 589824, grid=grid(589824), stream=stream0) | |
buf355 = reinterpret_tensor(buf348, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf348 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf349, buf355, 6291456, grid=grid(6291456), stream=stream0) | |
del buf349 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf356 = aten._scaled_dot_product_flash_attention_backward.default(buf355, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, None, None, 512, 512, 0.1, False, getitem_89, getitem_90, scale=0.125) | |
del getitem_87 | |
del getitem_88 | |
del getitem_89 | |
del getitem_90 | |
del permute_default_30 | |
del permute_default_31 | |
del permute_default_32 | |
buf357 = buf356[0] | |
buf358 = buf356[1] | |
buf359 = buf356[2] | |
del buf356 | |
buf360 = reinterpret_tensor(buf355, (8192, 768), (768, 1), 0); del buf355 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf359, (8192, 768), (768, 1), 0), permute_327, out=buf360) | |
del permute_327 | |
buf361 = buf350; del buf350 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf359, (768, 8192), (1, 768), 0), view_132, out=buf361) | |
buf362 = buf351; del buf351 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf359, buf362, 49152, 128, grid=grid(49152), stream=stream0) | |
buf365 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf362, buf365, 768, 64, grid=grid(768), stream=stream0) | |
buf364 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf361, buf364, 589824, grid=grid(589824), stream=stream0) | |
buf366 = reinterpret_tensor(buf359, (8192, 768), (768, 1), 0); del buf359 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf358, (8192, 768), (768, 1), 0), permute_332, out=buf366) | |
del permute_332 | |
buf367 = buf361; del buf361 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf358, (768, 8192), (1, 768), 0), view_132, out=buf367) | |
buf368 = buf362; del buf362 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf358, buf368, 49152, 128, grid=grid(49152), stream=stream0) | |
buf371 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf368, buf371, 768, 64, grid=grid(768), stream=stream0) | |
buf370 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf367, buf370, 589824, grid=grid(589824), stream=stream0) | |
buf372 = reinterpret_tensor(buf358, (8192, 768), (768, 1), 0); del buf358 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf357, (8192, 768), (768, 1), 0), permute_336, out=buf372) | |
del permute_336 | |
buf373 = buf367; del buf367 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf357, (768, 8192), (1, 768), 0), view_132, out=buf373) | |
del view_132 | |
buf374 = buf368; del buf368 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf357, buf374, 49152, 128, grid=grid(49152), stream=stream0) | |
buf377 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf374, buf377, 768, 64, grid=grid(768), stream=stream0) | |
buf376 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf373, buf376, 589824, grid=grid(589824), stream=stream0) | |
buf381 = buf322; del buf322 # reuse | |
buf386 = reinterpret_tensor(buf357, (16, 512, 768), (393216, 768, 1), 0); del buf357 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf343, buf360, buf366, buf372, primals_100, mul_81, div_45, gt_18, buf381, buf386, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_45 | |
del gt_18 | |
del primals_100 | |
buf382 = reinterpret_tensor(buf374, (768, 64), (1, 768), 0); del buf374 # reuse | |
buf384 = buf344; del buf344 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf343, buf360, buf366, buf372, mul_81, buf382, buf384, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf360 | |
del buf366 | |
del mul_81 | |
buf383 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf382, buf383, 768, 64, grid=grid(768), stream=stream0) | |
buf385 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf384, buf385, 768, 64, grid=grid(768), stream=stream0) | |
buf387 = reinterpret_tensor(buf334, (8192, 3072), (3072, 1), 0); del buf334 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf386, (8192, 768), (768, 1), 0), permute_340, out=buf387) | |
del permute_340 | |
buf388 = reinterpret_tensor(buf336, (768, 3072), (3072, 1), 0); del buf336 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf386, (768, 8192), (1, 768), 0), view_130, out=buf388) | |
del view_130 | |
buf389 = reinterpret_tensor(buf384, (1, 768, 64), (49152, 1, 768), 0); del buf384 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf386, buf389, 49152, 128, grid=grid(49152), stream=stream0) | |
buf392 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf389, buf392, 768, 64, grid=grid(768), stream=stream0) | |
buf391 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf388, buf391, 2359296, grid=grid(2359296), stream=stream0) | |
buf393 = reinterpret_tensor(buf387, (16, 512, 3072), (1572864, 3072, 1), 0); del buf387 # reuse | |
# Source Nodes: [hidden_states_44], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf393, addmm_34, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_34 | |
buf394 = reinterpret_tensor(buf386, (8192, 768), (768, 1), 0); del buf386 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf393, (8192, 3072), (3072, 1), 0), permute_344, out=buf394) | |
del permute_344 | |
buf395 = reinterpret_tensor(buf388, (3072, 768), (768, 1), 0); del buf388 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf393, (3072, 8192), (1, 3072), 0), view_128, out=buf395) | |
del view_128 | |
buf396 = buf337; del buf337 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf393, buf396, 98304, 256, grid=grid(98304), stream=stream0) | |
buf399 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf396, buf399, 3072, 32, grid=grid(3072), stream=stream0) | |
buf398 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf395, buf398, 2359296, grid=grid(2359296), stream=stream0) | |
buf402 = buf343; del buf343 # reuse | |
buf407 = reinterpret_tensor(buf372, (16, 512, 768), (393216, 768, 1), 0); del buf372 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf381, buf394, primals_94, mul_74, div_46, gt_17, buf402, buf407, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_46 | |
del gt_17 | |
del primals_94 | |
buf403 = reinterpret_tensor(buf389, (768, 64), (1, 768), 0); del buf389 # reuse | |
buf405 = buf382; del buf382 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf381, buf394, mul_74, buf403, buf405, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_74 | |
buf404 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf403, buf404, 768, 64, grid=grid(768), stream=stream0) | |
buf406 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf405, buf406, 768, 64, grid=grid(768), stream=stream0) | |
buf408 = buf394; del buf394 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf407, (8192, 768), (768, 1), 0), permute_348, out=buf408) | |
del permute_348 | |
buf409 = buf373; del buf373 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf407, (768, 8192), (1, 768), 0), view_126, out=buf409) | |
del view_126 | |
buf410 = reinterpret_tensor(buf405, (1, 768, 64), (49152, 1, 768), 0); del buf405 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf407, buf410, 49152, 128, grid=grid(49152), stream=stream0) | |
buf413 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf410, buf413, 768, 64, grid=grid(768), stream=stream0) | |
buf412 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf409, buf412, 589824, grid=grid(589824), stream=stream0) | |
buf414 = reinterpret_tensor(buf407, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf407 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf408, buf414, 6291456, grid=grid(6291456), stream=stream0) | |
del buf408 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf415 = aten._scaled_dot_product_flash_attention_backward.default(buf414, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, None, None, 512, 512, 0.1, False, getitem_96, getitem_97, scale=0.125) | |
del getitem_94 | |
del getitem_95 | |
del getitem_96 | |
del getitem_97 | |
del permute_default_36 | |
del permute_default_37 | |
del permute_default_38 | |
buf416 = buf415[0] | |
buf417 = buf415[1] | |
buf418 = buf415[2] | |
del buf415 | |
buf419 = reinterpret_tensor(buf414, (8192, 768), (768, 1), 0); del buf414 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf418, (8192, 768), (768, 1), 0), permute_360, out=buf419) | |
del permute_360 | |
buf420 = buf409; del buf409 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf418, (768, 8192), (1, 768), 0), view_110, out=buf420) | |
buf421 = buf410; del buf410 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf418, buf421, 49152, 128, grid=grid(49152), stream=stream0) | |
buf424 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf421, buf424, 768, 64, grid=grid(768), stream=stream0) | |
buf423 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf420, buf423, 589824, grid=grid(589824), stream=stream0) | |
buf425 = reinterpret_tensor(buf418, (8192, 768), (768, 1), 0); del buf418 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf417, (8192, 768), (768, 1), 0), permute_365, out=buf425) | |
del permute_365 | |
buf426 = buf420; del buf420 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf417, (768, 8192), (1, 768), 0), view_110, out=buf426) | |
buf427 = buf421; del buf421 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf417, buf427, 49152, 128, grid=grid(49152), stream=stream0) | |
buf430 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf427, buf430, 768, 64, grid=grid(768), stream=stream0) | |
buf429 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf426, buf429, 589824, grid=grid(589824), stream=stream0) | |
buf431 = reinterpret_tensor(buf417, (8192, 768), (768, 1), 0); del buf417 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf416, (8192, 768), (768, 1), 0), permute_369, out=buf431) | |
del permute_369 | |
buf432 = buf426; del buf426 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf416, (768, 8192), (1, 768), 0), view_110, out=buf432) | |
del view_110 | |
buf433 = buf427; del buf427 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf416, buf433, 49152, 128, grid=grid(49152), stream=stream0) | |
buf436 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf433, buf436, 768, 64, grid=grid(768), stream=stream0) | |
buf435 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf432, buf435, 589824, grid=grid(589824), stream=stream0) | |
buf440 = buf381; del buf381 # reuse | |
buf445 = reinterpret_tensor(buf416, (16, 512, 768), (393216, 768, 1), 0); del buf416 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf402, buf419, buf425, buf431, primals_84, mul_68, div_48, gt_15, buf440, buf445, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_48 | |
del gt_15 | |
del primals_84 | |
buf441 = reinterpret_tensor(buf433, (768, 64), (1, 768), 0); del buf433 # reuse | |
buf443 = buf403; del buf403 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf402, buf419, buf425, buf431, mul_68, buf441, buf443, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf419 | |
del buf425 | |
del mul_68 | |
buf442 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf441, buf442, 768, 64, grid=grid(768), stream=stream0) | |
buf444 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf443, buf444, 768, 64, grid=grid(768), stream=stream0) | |
buf446 = reinterpret_tensor(buf393, (8192, 3072), (3072, 1), 0); del buf393 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf445, (8192, 768), (768, 1), 0), permute_373, out=buf446) | |
del permute_373 | |
buf447 = reinterpret_tensor(buf395, (768, 3072), (3072, 1), 0); del buf395 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf445, (768, 8192), (1, 768), 0), view_108, out=buf447) | |
del view_108 | |
buf448 = reinterpret_tensor(buf443, (1, 768, 64), (49152, 1, 768), 0); del buf443 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf445, buf448, 49152, 128, grid=grid(49152), stream=stream0) | |
buf451 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf448, buf451, 768, 64, grid=grid(768), stream=stream0) | |
buf450 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf447, buf450, 2359296, grid=grid(2359296), stream=stream0) | |
buf452 = reinterpret_tensor(buf446, (16, 512, 3072), (1572864, 3072, 1), 0); del buf446 # reuse | |
# Source Nodes: [hidden_states_36], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf452, addmm_28, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_28 | |
buf453 = reinterpret_tensor(buf445, (8192, 768), (768, 1), 0); del buf445 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf452, (8192, 3072), (3072, 1), 0), permute_377, out=buf453) | |
del permute_377 | |
buf454 = reinterpret_tensor(buf447, (3072, 768), (768, 1), 0); del buf447 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf452, (3072, 8192), (1, 3072), 0), view_106, out=buf454) | |
del view_106 | |
buf455 = buf396; del buf396 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf452, buf455, 98304, 256, grid=grid(98304), stream=stream0) | |
buf458 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf455, buf458, 3072, 32, grid=grid(3072), stream=stream0) | |
buf457 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf454, buf457, 2359296, grid=grid(2359296), stream=stream0) | |
buf461 = buf402; del buf402 # reuse | |
buf466 = reinterpret_tensor(buf431, (16, 512, 768), (393216, 768, 1), 0); del buf431 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf440, buf453, primals_78, mul_61, div_49, gt_14, buf461, buf466, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_49 | |
del gt_14 | |
del primals_78 | |
buf462 = reinterpret_tensor(buf448, (768, 64), (1, 768), 0); del buf448 # reuse | |
buf464 = buf441; del buf441 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf440, buf453, mul_61, buf462, buf464, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_61 | |
buf463 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf462, buf463, 768, 64, grid=grid(768), stream=stream0) | |
buf465 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf464, buf465, 768, 64, grid=grid(768), stream=stream0) | |
buf467 = buf453; del buf453 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf466, (8192, 768), (768, 1), 0), permute_381, out=buf467) | |
del permute_381 | |
buf468 = buf432; del buf432 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf466, (768, 8192), (1, 768), 0), view_104, out=buf468) | |
del view_104 | |
buf469 = reinterpret_tensor(buf464, (1, 768, 64), (49152, 1, 768), 0); del buf464 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf466, buf469, 49152, 128, grid=grid(49152), stream=stream0) | |
buf472 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf469, buf472, 768, 64, grid=grid(768), stream=stream0) | |
buf471 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf468, buf471, 589824, grid=grid(589824), stream=stream0) | |
buf473 = reinterpret_tensor(buf466, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf466 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf467, buf473, 6291456, grid=grid(6291456), stream=stream0) | |
del buf467 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf474 = aten._scaled_dot_product_flash_attention_backward.default(buf473, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, None, None, 512, 512, 0.1, False, getitem_103, getitem_104, scale=0.125) | |
del getitem_101 | |
del getitem_102 | |
del getitem_103 | |
del getitem_104 | |
del permute_default_42 | |
del permute_default_43 | |
del permute_default_44 | |
buf475 = buf474[0] | |
buf476 = buf474[1] | |
buf477 = buf474[2] | |
del buf474 | |
buf478 = reinterpret_tensor(buf473, (8192, 768), (768, 1), 0); del buf473 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf477, (8192, 768), (768, 1), 0), permute_393, out=buf478) | |
del permute_393 | |
buf479 = buf468; del buf468 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf477, (768, 8192), (1, 768), 0), view_88, out=buf479) | |
buf480 = buf469; del buf469 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf477, buf480, 49152, 128, grid=grid(49152), stream=stream0) | |
buf483 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf480, buf483, 768, 64, grid=grid(768), stream=stream0) | |
buf482 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf479, buf482, 589824, grid=grid(589824), stream=stream0) | |
buf484 = reinterpret_tensor(buf477, (8192, 768), (768, 1), 0); del buf477 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf476, (8192, 768), (768, 1), 0), permute_398, out=buf484) | |
del permute_398 | |
buf485 = buf479; del buf479 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf476, (768, 8192), (1, 768), 0), view_88, out=buf485) | |
buf486 = buf480; del buf480 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf476, buf486, 49152, 128, grid=grid(49152), stream=stream0) | |
buf489 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf486, buf489, 768, 64, grid=grid(768), stream=stream0) | |
buf488 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf485, buf488, 589824, grid=grid(589824), stream=stream0) | |
buf490 = reinterpret_tensor(buf476, (8192, 768), (768, 1), 0); del buf476 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf475, (8192, 768), (768, 1), 0), permute_402, out=buf490) | |
del permute_402 | |
buf491 = buf485; del buf485 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf475, (768, 8192), (1, 768), 0), view_88, out=buf491) | |
del view_88 | |
buf492 = buf486; del buf486 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf475, buf492, 49152, 128, grid=grid(49152), stream=stream0) | |
buf495 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf492, buf495, 768, 64, grid=grid(768), stream=stream0) | |
buf494 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf491, buf494, 589824, grid=grid(589824), stream=stream0) | |
buf499 = buf440; del buf440 # reuse | |
buf504 = reinterpret_tensor(buf475, (16, 512, 768), (393216, 768, 1), 0); del buf475 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf461, buf478, buf484, buf490, primals_68, mul_55, div_51, gt_12, buf499, buf504, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_51 | |
del gt_12 | |
del primals_68 | |
buf500 = reinterpret_tensor(buf492, (768, 64), (1, 768), 0); del buf492 # reuse | |
buf502 = buf462; del buf462 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf461, buf478, buf484, buf490, mul_55, buf500, buf502, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf478 | |
del buf484 | |
del mul_55 | |
buf501 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf500, buf501, 768, 64, grid=grid(768), stream=stream0) | |
buf503 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf502, buf503, 768, 64, grid=grid(768), stream=stream0) | |
buf505 = reinterpret_tensor(buf452, (8192, 3072), (3072, 1), 0); del buf452 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf504, (8192, 768), (768, 1), 0), permute_406, out=buf505) | |
del permute_406 | |
buf506 = reinterpret_tensor(buf454, (768, 3072), (3072, 1), 0); del buf454 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf504, (768, 8192), (1, 768), 0), view_86, out=buf506) | |
del view_86 | |
buf507 = reinterpret_tensor(buf502, (1, 768, 64), (49152, 1, 768), 0); del buf502 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf504, buf507, 49152, 128, grid=grid(49152), stream=stream0) | |
buf510 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf507, buf510, 768, 64, grid=grid(768), stream=stream0) | |
buf509 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf506, buf509, 2359296, grid=grid(2359296), stream=stream0) | |
buf511 = reinterpret_tensor(buf505, (16, 512, 3072), (1572864, 3072, 1), 0); del buf505 # reuse | |
# Source Nodes: [hidden_states_28], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf511, addmm_22, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_22 | |
buf512 = reinterpret_tensor(buf504, (8192, 768), (768, 1), 0); del buf504 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf511, (8192, 3072), (3072, 1), 0), permute_410, out=buf512) | |
del permute_410 | |
buf513 = reinterpret_tensor(buf506, (3072, 768), (768, 1), 0); del buf506 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf511, (3072, 8192), (1, 3072), 0), view_84, out=buf513) | |
del view_84 | |
buf514 = buf455; del buf455 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf511, buf514, 98304, 256, grid=grid(98304), stream=stream0) | |
buf517 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf514, buf517, 3072, 32, grid=grid(3072), stream=stream0) | |
buf516 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf513, buf516, 2359296, grid=grid(2359296), stream=stream0) | |
buf520 = buf461; del buf461 # reuse | |
buf525 = reinterpret_tensor(buf490, (16, 512, 768), (393216, 768, 1), 0); del buf490 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf499, buf512, primals_62, mul_48, div_52, gt_11, buf520, buf525, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_52 | |
del gt_11 | |
del primals_62 | |
buf521 = reinterpret_tensor(buf507, (768, 64), (1, 768), 0); del buf507 # reuse | |
buf523 = buf500; del buf500 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf499, buf512, mul_48, buf521, buf523, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_48 | |
buf522 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf521, buf522, 768, 64, grid=grid(768), stream=stream0) | |
buf524 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf523, buf524, 768, 64, grid=grid(768), stream=stream0) | |
buf526 = buf512; del buf512 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf525, (8192, 768), (768, 1), 0), permute_414, out=buf526) | |
del permute_414 | |
buf527 = buf491; del buf491 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf525, (768, 8192), (1, 768), 0), view_82, out=buf527) | |
del view_82 | |
buf528 = reinterpret_tensor(buf523, (1, 768, 64), (49152, 1, 768), 0); del buf523 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf525, buf528, 49152, 128, grid=grid(49152), stream=stream0) | |
buf531 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf528, buf531, 768, 64, grid=grid(768), stream=stream0) | |
buf530 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf527, buf530, 589824, grid=grid(589824), stream=stream0) | |
buf532 = reinterpret_tensor(buf525, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf525 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf526, buf532, 6291456, grid=grid(6291456), stream=stream0) | |
del buf526 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf533 = aten._scaled_dot_product_flash_attention_backward.default(buf532, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, None, None, 512, 512, 0.1, False, getitem_110, getitem_111, scale=0.125) | |
del getitem_108 | |
del getitem_109 | |
del getitem_110 | |
del getitem_111 | |
del permute_default_48 | |
del permute_default_49 | |
del permute_default_50 | |
buf534 = buf533[0] | |
buf535 = buf533[1] | |
buf536 = buf533[2] | |
del buf533 | |
buf537 = reinterpret_tensor(buf532, (8192, 768), (768, 1), 0); del buf532 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf536, (8192, 768), (768, 1), 0), permute_426, out=buf537) | |
del permute_426 | |
buf538 = buf527; del buf527 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf536, (768, 8192), (1, 768), 0), view_66, out=buf538) | |
buf539 = buf528; del buf528 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf536, buf539, 49152, 128, grid=grid(49152), stream=stream0) | |
buf542 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf539, buf542, 768, 64, grid=grid(768), stream=stream0) | |
buf541 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf538, buf541, 589824, grid=grid(589824), stream=stream0) | |
buf543 = reinterpret_tensor(buf536, (8192, 768), (768, 1), 0); del buf536 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf535, (8192, 768), (768, 1), 0), permute_431, out=buf543) | |
del permute_431 | |
buf544 = buf538; del buf538 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf535, (768, 8192), (1, 768), 0), view_66, out=buf544) | |
buf545 = buf539; del buf539 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf535, buf545, 49152, 128, grid=grid(49152), stream=stream0) | |
buf548 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf545, buf548, 768, 64, grid=grid(768), stream=stream0) | |
buf547 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf544, buf547, 589824, grid=grid(589824), stream=stream0) | |
buf549 = reinterpret_tensor(buf535, (8192, 768), (768, 1), 0); del buf535 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf534, (8192, 768), (768, 1), 0), permute_435, out=buf549) | |
del permute_435 | |
buf550 = buf544; del buf544 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf534, (768, 8192), (1, 768), 0), view_66, out=buf550) | |
del view_66 | |
buf551 = buf545; del buf545 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf534, buf551, 49152, 128, grid=grid(49152), stream=stream0) | |
buf554 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf551, buf554, 768, 64, grid=grid(768), stream=stream0) | |
buf553 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf550, buf553, 589824, grid=grid(589824), stream=stream0) | |
buf558 = buf499; del buf499 # reuse | |
buf563 = reinterpret_tensor(buf534, (16, 512, 768), (393216, 768, 1), 0); del buf534 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf520, buf537, buf543, buf549, primals_52, mul_42, div_54, gt_9, buf558, buf563, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_54 | |
del gt_9 | |
del primals_52 | |
buf559 = reinterpret_tensor(buf551, (768, 64), (1, 768), 0); del buf551 # reuse | |
buf561 = buf521; del buf521 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf520, buf537, buf543, buf549, mul_42, buf559, buf561, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf537 | |
del buf543 | |
del mul_42 | |
buf560 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf559, buf560, 768, 64, grid=grid(768), stream=stream0) | |
buf562 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf561, buf562, 768, 64, grid=grid(768), stream=stream0) | |
buf564 = reinterpret_tensor(buf511, (8192, 3072), (3072, 1), 0); del buf511 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf563, (8192, 768), (768, 1), 0), permute_439, out=buf564) | |
del permute_439 | |
buf565 = reinterpret_tensor(buf513, (768, 3072), (3072, 1), 0); del buf513 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf563, (768, 8192), (1, 768), 0), view_64, out=buf565) | |
del view_64 | |
buf566 = reinterpret_tensor(buf561, (1, 768, 64), (49152, 1, 768), 0); del buf561 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf563, buf566, 49152, 128, grid=grid(49152), stream=stream0) | |
buf569 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf566, buf569, 768, 64, grid=grid(768), stream=stream0) | |
buf568 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf565, buf568, 2359296, grid=grid(2359296), stream=stream0) | |
buf570 = reinterpret_tensor(buf564, (16, 512, 3072), (1572864, 3072, 1), 0); del buf564 # reuse | |
# Source Nodes: [hidden_states_20], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf570, addmm_16, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_16 | |
buf571 = reinterpret_tensor(buf563, (8192, 768), (768, 1), 0); del buf563 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf570, (8192, 3072), (3072, 1), 0), permute_443, out=buf571) | |
del permute_443 | |
buf572 = reinterpret_tensor(buf565, (3072, 768), (768, 1), 0); del buf565 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf570, (3072, 8192), (1, 3072), 0), view_62, out=buf572) | |
del view_62 | |
buf573 = buf514; del buf514 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf570, buf573, 98304, 256, grid=grid(98304), stream=stream0) | |
buf576 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf573, buf576, 3072, 32, grid=grid(3072), stream=stream0) | |
buf575 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf572, buf575, 2359296, grid=grid(2359296), stream=stream0) | |
buf579 = buf520; del buf520 # reuse | |
buf584 = reinterpret_tensor(buf549, (16, 512, 768), (393216, 768, 1), 0); del buf549 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf558, buf571, primals_46, mul_35, div_55, gt_8, buf579, buf584, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_55 | |
del gt_8 | |
del primals_46 | |
buf580 = reinterpret_tensor(buf566, (768, 64), (1, 768), 0); del buf566 # reuse | |
buf582 = buf559; del buf559 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf558, buf571, mul_35, buf580, buf582, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_35 | |
buf581 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf580, buf581, 768, 64, grid=grid(768), stream=stream0) | |
buf583 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf582, buf583, 768, 64, grid=grid(768), stream=stream0) | |
buf585 = buf571; del buf571 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf584, (8192, 768), (768, 1), 0), permute_447, out=buf585) | |
del permute_447 | |
buf586 = buf550; del buf550 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf584, (768, 8192), (1, 768), 0), view_60, out=buf586) | |
del view_60 | |
buf587 = reinterpret_tensor(buf582, (1, 768, 64), (49152, 1, 768), 0); del buf582 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf584, buf587, 49152, 128, grid=grid(49152), stream=stream0) | |
buf590 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf587, buf590, 768, 64, grid=grid(768), stream=stream0) | |
buf589 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf586, buf589, 589824, grid=grid(589824), stream=stream0) | |
buf591 = reinterpret_tensor(buf584, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf584 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf585, buf591, 6291456, grid=grid(6291456), stream=stream0) | |
del buf585 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf592 = aten._scaled_dot_product_flash_attention_backward.default(buf591, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, None, None, 512, 512, 0.1, False, getitem_117, getitem_118, scale=0.125) | |
del getitem_115 | |
del getitem_116 | |
del getitem_117 | |
del getitem_118 | |
del permute_default_54 | |
del permute_default_55 | |
del permute_default_56 | |
buf593 = buf592[0] | |
buf594 = buf592[1] | |
buf595 = buf592[2] | |
del buf592 | |
buf596 = reinterpret_tensor(buf591, (8192, 768), (768, 1), 0); del buf591 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf595, (8192, 768), (768, 1), 0), permute_459, out=buf596) | |
del permute_459 | |
buf597 = buf586; del buf586 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf595, (768, 8192), (1, 768), 0), view_44, out=buf597) | |
buf598 = buf587; del buf587 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf595, buf598, 49152, 128, grid=grid(49152), stream=stream0) | |
buf601 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf598, buf601, 768, 64, grid=grid(768), stream=stream0) | |
buf600 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf597, buf600, 589824, grid=grid(589824), stream=stream0) | |
buf602 = reinterpret_tensor(buf595, (8192, 768), (768, 1), 0); del buf595 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf594, (8192, 768), (768, 1), 0), permute_464, out=buf602) | |
del permute_464 | |
buf603 = buf597; del buf597 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf594, (768, 8192), (1, 768), 0), view_44, out=buf603) | |
buf604 = buf598; del buf598 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf594, buf604, 49152, 128, grid=grid(49152), stream=stream0) | |
buf607 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf604, buf607, 768, 64, grid=grid(768), stream=stream0) | |
buf606 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf603, buf606, 589824, grid=grid(589824), stream=stream0) | |
buf608 = reinterpret_tensor(buf594, (8192, 768), (768, 1), 0); del buf594 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf593, (8192, 768), (768, 1), 0), permute_468, out=buf608) | |
del permute_468 | |
buf609 = buf603; del buf603 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf593, (768, 8192), (1, 768), 0), view_44, out=buf609) | |
del view_44 | |
buf610 = buf604; del buf604 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf593, buf610, 49152, 128, grid=grid(49152), stream=stream0) | |
buf613 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf610, buf613, 768, 64, grid=grid(768), stream=stream0) | |
buf612 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf609, buf612, 589824, grid=grid(589824), stream=stream0) | |
buf617 = buf558; del buf558 # reuse | |
buf622 = reinterpret_tensor(buf593, (16, 512, 768), (393216, 768, 1), 0); del buf593 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf579, buf596, buf602, buf608, primals_36, mul_29, div_57, gt_6, buf617, buf622, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_57 | |
del gt_6 | |
del primals_36 | |
buf618 = reinterpret_tensor(buf610, (768, 64), (1, 768), 0); del buf610 # reuse | |
buf620 = buf580; del buf580 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf579, buf596, buf602, buf608, mul_29, buf618, buf620, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf596 | |
del buf602 | |
del mul_29 | |
buf619 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf618, buf619, 768, 64, grid=grid(768), stream=stream0) | |
buf621 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf620, buf621, 768, 64, grid=grid(768), stream=stream0) | |
buf623 = reinterpret_tensor(buf570, (8192, 3072), (3072, 1), 0); del buf570 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf622, (8192, 768), (768, 1), 0), permute_472, out=buf623) | |
del permute_472 | |
buf624 = reinterpret_tensor(buf572, (768, 3072), (3072, 1), 0); del buf572 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf622, (768, 8192), (1, 768), 0), view_42, out=buf624) | |
del view_42 | |
buf625 = reinterpret_tensor(buf620, (1, 768, 64), (49152, 1, 768), 0); del buf620 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf622, buf625, 49152, 128, grid=grid(49152), stream=stream0) | |
buf628 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf625, buf628, 768, 64, grid=grid(768), stream=stream0) | |
buf627 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf624, buf627, 2359296, grid=grid(2359296), stream=stream0) | |
buf629 = reinterpret_tensor(buf623, (16, 512, 3072), (1572864, 3072, 1), 0); del buf623 # reuse | |
# Source Nodes: [hidden_states_12], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf629, addmm_10, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_10 | |
buf630 = reinterpret_tensor(buf622, (8192, 768), (768, 1), 0); del buf622 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf629, (8192, 3072), (3072, 1), 0), permute_476, out=buf630) | |
del permute_476 | |
buf631 = reinterpret_tensor(buf624, (3072, 768), (768, 1), 0); del buf624 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf629, (3072, 8192), (1, 3072), 0), view_40, out=buf631) | |
del view_40 | |
buf632 = buf573; del buf573 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf629, buf632, 98304, 256, grid=grid(98304), stream=stream0) | |
buf635 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf632, buf635, 3072, 32, grid=grid(3072), stream=stream0) | |
buf634 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf631, buf634, 2359296, grid=grid(2359296), stream=stream0) | |
buf638 = buf579; del buf579 # reuse | |
buf643 = reinterpret_tensor(buf608, (16, 512, 768), (393216, 768, 1), 0); del buf608 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf617, buf630, primals_30, mul_22, div_58, gt_5, buf638, buf643, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_58 | |
del gt_5 | |
del primals_30 | |
buf639 = reinterpret_tensor(buf625, (768, 64), (1, 768), 0); del buf625 # reuse | |
buf641 = buf618; del buf618 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf617, buf630, mul_22, buf639, buf641, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_22 | |
buf640 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf639, buf640, 768, 64, grid=grid(768), stream=stream0) | |
buf642 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf641, buf642, 768, 64, grid=grid(768), stream=stream0) | |
buf644 = buf630; del buf630 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf643, (8192, 768), (768, 1), 0), permute_480, out=buf644) | |
del permute_480 | |
buf645 = buf609; del buf609 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf643, (768, 8192), (1, 768), 0), view_38, out=buf645) | |
del view_38 | |
buf646 = reinterpret_tensor(buf641, (1, 768, 64), (49152, 1, 768), 0); del buf641 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf643, buf646, 49152, 128, grid=grid(49152), stream=stream0) | |
buf649 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf646, buf649, 768, 64, grid=grid(768), stream=stream0) | |
buf648 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf645, buf648, 589824, grid=grid(589824), stream=stream0) | |
buf650 = reinterpret_tensor(buf643, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf643 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf644, buf650, 6291456, grid=grid(6291456), stream=stream0) | |
del buf644 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf651 = aten._scaled_dot_product_flash_attention_backward.default(buf650, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, None, None, 512, 512, 0.1, False, getitem_124, getitem_125, scale=0.125) | |
del getitem_122 | |
del getitem_123 | |
del getitem_124 | |
del getitem_125 | |
del permute_default_60 | |
del permute_default_61 | |
del permute_default_62 | |
buf652 = buf651[0] | |
buf653 = buf651[1] | |
buf654 = buf651[2] | |
del buf651 | |
buf655 = reinterpret_tensor(buf650, (8192, 768), (768, 1), 0); del buf650 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf654, (8192, 768), (768, 1), 0), permute_492, out=buf655) | |
del permute_492 | |
buf656 = buf645; del buf645 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf654, (768, 8192), (1, 768), 0), view_22, out=buf656) | |
buf657 = buf646; del buf646 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf654, buf657, 49152, 128, grid=grid(49152), stream=stream0) | |
buf660 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf657, buf660, 768, 64, grid=grid(768), stream=stream0) | |
buf659 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf656, buf659, 589824, grid=grid(589824), stream=stream0) | |
buf661 = reinterpret_tensor(buf654, (8192, 768), (768, 1), 0); del buf654 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf653, (8192, 768), (768, 1), 0), permute_497, out=buf661) | |
del permute_497 | |
buf662 = buf656; del buf656 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf653, (768, 8192), (1, 768), 0), view_22, out=buf662) | |
buf663 = buf657; del buf657 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf653, buf663, 49152, 128, grid=grid(49152), stream=stream0) | |
buf666 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf663, buf666, 768, 64, grid=grid(768), stream=stream0) | |
buf665 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf662, buf665, 589824, grid=grid(589824), stream=stream0) | |
buf667 = reinterpret_tensor(buf653, (8192, 768), (768, 1), 0); del buf653 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf652, (8192, 768), (768, 1), 0), permute_501, out=buf667) | |
del permute_501 | |
buf668 = buf662; del buf662 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf652, (768, 8192), (1, 768), 0), view_22, out=buf668) | |
del view_22 | |
buf669 = buf663; del buf663 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf652, buf669, 49152, 128, grid=grid(49152), stream=stream0) | |
buf672 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf669, buf672, 768, 64, grid=grid(768), stream=stream0) | |
buf671 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf668, buf671, 589824, grid=grid(589824), stream=stream0) | |
buf676 = buf617; del buf617 # reuse | |
buf681 = reinterpret_tensor(buf652, (16, 512, 768), (393216, 768, 1), 0); del buf652 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf638, buf655, buf661, buf667, primals_20, mul_16, div_60, gt_3, buf676, buf681, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_60 | |
del gt_3 | |
del primals_20 | |
buf677 = reinterpret_tensor(buf669, (768, 64), (1, 768), 0); del buf669 # reuse | |
buf679 = buf639; del buf639 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf638, buf655, buf661, buf667, mul_16, buf677, buf679, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf655 | |
del buf661 | |
del mul_16 | |
buf678 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf677, buf678, 768, 64, grid=grid(768), stream=stream0) | |
buf680 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf679, buf680, 768, 64, grid=grid(768), stream=stream0) | |
buf682 = reinterpret_tensor(buf629, (8192, 3072), (3072, 1), 0); del buf629 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf681, (8192, 768), (768, 1), 0), permute_505, out=buf682) | |
del permute_505 | |
buf683 = reinterpret_tensor(buf631, (768, 3072), (3072, 1), 0); del buf631 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf681, (768, 8192), (1, 768), 0), view_20, out=buf683) | |
del view_20 | |
buf684 = reinterpret_tensor(buf679, (1, 768, 64), (49152, 1, 768), 0); del buf679 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf681, buf684, 49152, 128, grid=grid(49152), stream=stream0) | |
buf687 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf684, buf687, 768, 64, grid=grid(768), stream=stream0) | |
buf686 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_16.run(buf683, buf686, 2359296, grid=grid(2359296), stream=stream0) | |
buf688 = reinterpret_tensor(buf682, (16, 512, 3072), (1572864, 3072, 1), 0); del buf682 # reuse | |
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_17.run(buf688, addmm_4, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_4 | |
buf689 = reinterpret_tensor(buf681, (8192, 768), (768, 1), 0); del buf681 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf688, (8192, 3072), (3072, 1), 0), permute_509, out=buf689) | |
del permute_509 | |
buf690 = reinterpret_tensor(buf683, (3072, 768), (768, 1), 0); del buf683 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf688, (3072, 8192), (1, 3072), 0), view_18, out=buf690) | |
del view_18 | |
buf691 = buf632; del buf632 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_18.run(buf688, buf691, 98304, 256, grid=grid(98304), stream=stream0) | |
del buf688 | |
buf694 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_19.run(buf691, buf694, 3072, 32, grid=grid(3072), stream=stream0) | |
del buf691 | |
buf693 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_20.run(buf690, buf693, 2359296, grid=grid(2359296), stream=stream0) | |
del buf690 | |
buf697 = buf638; del buf638 # reuse | |
buf702 = reinterpret_tensor(buf667, (16, 512, 768), (393216, 768, 1), 0); del buf667 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf676, buf689, primals_14, mul_9, div_61, gt_2, buf697, buf702, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_61 | |
del gt_2 | |
del primals_14 | |
buf698 = reinterpret_tensor(buf684, (768, 64), (1, 768), 0); del buf684 # reuse | |
buf700 = buf677; del buf677 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf676, buf689, mul_9, buf698, buf700, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_9 | |
buf699 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf698, buf699, 768, 64, grid=grid(768), stream=stream0) | |
buf701 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf700, buf701, 768, 64, grid=grid(768), stream=stream0) | |
buf703 = buf689; del buf689 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf702, (8192, 768), (768, 1), 0), permute_513, out=buf703) | |
del permute_513 | |
buf704 = buf668; del buf668 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf702, (768, 8192), (1, 768), 0), view_16, out=buf704) | |
del view_16 | |
buf705 = reinterpret_tensor(buf700, (1, 768, 64), (49152, 1, 768), 0); del buf700 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_15.run(buf702, buf705, 49152, 128, grid=grid(49152), stream=stream0) | |
buf708 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf705, buf708, 768, 64, grid=grid(768), stream=stream0) | |
buf707 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf704, buf707, 589824, grid=grid(589824), stream=stream0) | |
buf709 = reinterpret_tensor(buf702, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf702 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_23.run(buf703, buf709, 6291456, grid=grid(6291456), stream=stream0) | |
del buf703 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf710 = aten._scaled_dot_product_flash_attention_backward.default(buf709, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, None, None, 512, 512, 0.1, False, getitem_131, getitem_132, scale=0.125) | |
del getitem_129 | |
del getitem_130 | |
del getitem_131 | |
del getitem_132 | |
del permute_default_66 | |
del permute_default_67 | |
del permute_default_68 | |
buf711 = buf710[0] | |
buf712 = buf710[1] | |
buf713 = buf710[2] | |
del buf710 | |
buf714 = reinterpret_tensor(buf709, (8192, 768), (768, 1), 0); del buf709 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf713, (8192, 768), (768, 1), 0), permute_525, out=buf714) | |
del permute_525 | |
buf715 = buf704; del buf704 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf713, (768, 8192), (1, 768), 0), view, out=buf715) | |
buf716 = buf705; del buf705 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf713, buf716, 49152, 128, grid=grid(49152), stream=stream0) | |
buf719 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf716, buf719, 768, 64, grid=grid(768), stream=stream0) | |
buf718 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf715, buf718, 589824, grid=grid(589824), stream=stream0) | |
buf720 = reinterpret_tensor(buf713, (8192, 768), (768, 1), 0); del buf713 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf712, (8192, 768), (768, 1), 0), permute_530, out=buf720) | |
del permute_530 | |
buf721 = buf715; del buf715 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf712, (768, 8192), (1, 768), 0), view, out=buf721) | |
buf722 = buf716; del buf716 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf712, buf722, 49152, 128, grid=grid(49152), stream=stream0) | |
buf725 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf722, buf725, 768, 64, grid=grid(768), stream=stream0) | |
buf724 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf721, buf724, 589824, grid=grid(589824), stream=stream0) | |
buf726 = reinterpret_tensor(buf712, (8192, 768), (768, 1), 0); del buf712 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf711, (8192, 768), (768, 1), 0), permute_534, out=buf726) | |
del permute_534 | |
buf727 = buf721; del buf721 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf711, (768, 8192), (1, 768), 0), view, out=buf727) | |
del view | |
buf728 = buf722; del buf722 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_24.run(buf711, buf728, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf711 | |
buf731 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_11.run(buf728, buf731, 768, 64, grid=grid(768), stream=stream0) | |
buf730 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_12.run(buf727, buf730, 589824, grid=grid(589824), stream=stream0) | |
del buf727 | |
buf743 = empty_strided_cuda((2, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_27.run(buf743, 1536, grid=grid(1536), stream=stream0) | |
buf745 = empty_strided_cuda((30522, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_28.run(buf745, 23440896, grid=grid(23440896), stream=stream0) | |
buf732 = buf697; del buf697 # reuse | |
buf735 = buf676; del buf676 # reuse | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward] | |
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.run(buf732, buf714, buf720, buf726, gt, primals_4, mul_1, div_63, primals_204, primals_206, buf735, buf743, buf745, 8192, 768, grid=grid(8192), stream=stream0) | |
del buf714 | |
del buf720 | |
del buf726 | |
del div_63 | |
del gt | |
del primals_204 | |
del primals_206 | |
del primals_4 | |
buf736 = reinterpret_tensor(buf728, (768, 64), (1, 768), 0); del buf728 # reuse | |
buf738 = buf698; del buf698 # reuse | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_red_fused_native_layer_norm_backward_30.run(buf732, mul_1, buf736, buf738, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf732 | |
del mul_1 | |
buf737 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf736, buf737, 768, 64, grid=grid(768), stream=stream0) | |
del buf736 | |
buf739 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf738, buf739, 768, 64, grid=grid(768), stream=stream0) | |
del buf738 | |
buf741 = empty_strided_cuda((512, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_31.run(buf741, 393216, grid=grid(393216), stream=stream0) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum] | |
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.run(buf735, primals_205, buf741, 393216, 16, grid=grid(393216), stream=stream0) | |
del buf735 | |
del primals_205 | |
return (buf745, buf743, buf741, buf737, buf739, buf730, buf731, buf724, buf725, buf718, buf719, buf707, buf708, buf699, buf701, buf693, buf694, buf686, buf687, buf678, buf680, buf671, buf672, buf665, buf666, buf659, buf660, buf648, buf649, buf640, buf642, buf634, buf635, buf627, buf628, buf619, buf621, buf612, buf613, buf606, buf607, buf600, buf601, buf589, buf590, buf581, buf583, buf575, buf576, buf568, buf569, buf560, buf562, buf553, buf554, buf547, buf548, buf541, buf542, buf530, buf531, buf522, buf524, buf516, buf517, buf509, buf510, buf501, buf503, buf494, buf495, buf488, buf489, buf482, buf483, buf471, buf472, buf463, buf465, buf457, buf458, buf450, buf451, buf442, buf444, buf435, buf436, buf429, buf430, buf423, buf424, buf412, buf413, buf404, buf406, buf398, buf399, buf391, buf392, buf383, buf385, buf376, buf377, buf370, buf371, buf364, buf365, buf353, buf354, buf345, buf347, buf339, buf340, buf332, buf333, buf324, buf326, buf317, buf318, buf311, buf312, buf305, buf306, buf294, buf295, buf286, buf288, buf280, buf281, buf273, buf274, buf265, buf267, buf258, buf259, buf252, buf253, buf246, buf247, buf235, buf236, buf227, buf229, buf221, buf222, buf214, buf215, buf206, buf208, buf199, buf200, buf193, buf194, buf187, buf188, buf176, buf177, buf168, buf170, buf162, buf163, buf155, buf156, buf147, buf149, buf140, buf141, buf134, buf135, buf128, buf129, buf117, buf118, buf109, buf111, buf103, buf104, buf96, buf97, buf88, buf90, buf81, buf82, buf75, buf76, buf69, buf70, buf58, buf59, buf50, buf52, buf44, buf45, buf37, buf38, buf29, buf31, buf23, buf24, buf14, buf16, buf9, buf10, None, None, None, None, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_14 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_20 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_30 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_36 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_46 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_52 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_62 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_68 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_78 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_84 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_94 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_100 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_110 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_116 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_126 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_132 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_142 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_148 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_158 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_164 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_174 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_180 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_190 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_196 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_200 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_204 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_205 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_206 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_207 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
mul_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
gt = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
view = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_67 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_68 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_129 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_130 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_131 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_132 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_16 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_18 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_4 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_20 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_16 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_22 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_60 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_61 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_62 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_122 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_123 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_124 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_125 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_38 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_22 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_40 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_10 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_42 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_44 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_54 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_55 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_56 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_115 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_116 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_117 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_118 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_60 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_62 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_16 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_42 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_66 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_48 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_49 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_50 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_108 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_109 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_110 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_111 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_82 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_11 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_48 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_84 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_22 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_86 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_12 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_55 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_88 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_42 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_43 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_44 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_101 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_102 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_103 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_104 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_104 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_14 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_61 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_106 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_28 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_108 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_15 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_68 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_110 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_36 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_37 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_38 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_94 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_95 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_96 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_97 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_126 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_17 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_74 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_128 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_34 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_130 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_18 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_81 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_132 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_30 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_31 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_32 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_87 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_88 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_89 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_90 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_148 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_20 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_87 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_150 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_40 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_152 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_21 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_94 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_154 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_24 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_25 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_26 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_80 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_81 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_82 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_83 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_170 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_23 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_100 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_172 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_46 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_174 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_24 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_107 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_176 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_18 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_19 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_20 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_73 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_74 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_75 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_76 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_192 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_26 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_113 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_194 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_52 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_196 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_27 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_120 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_198 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_12 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_13 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_14 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_67 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_68 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_69 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_214 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_126 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_216 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_58 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_218 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_30 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_133 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_220 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_6 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_7 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_8 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_59 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_60 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_61 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_62 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_236 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_32 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_139 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_238 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_240 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_33 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_146 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_242 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_1 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_2 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_52 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_53 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_54 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_55 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_258 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_152 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_260 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_70 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_262 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_36 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_159 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_264 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_72 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_25 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_266 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
view_267 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
amax_12 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
log = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
convert_element_type_510 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
permute_134 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_138 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_27 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_142 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_146 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_28 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_150 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_162 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_167 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_171 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_30 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_175 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_179 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_31 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_183 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_195 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_200 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_204 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_33 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_208 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_212 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_34 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_216 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_228 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_233 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_237 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_36 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_241 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_245 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_37 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_249 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_261 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_266 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_270 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_39 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_274 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_278 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_40 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_282 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_294 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_299 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_303 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_42 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_307 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_311 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_43 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_315 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_327 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_332 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_336 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_45 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_340 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_344 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_46 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_348 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_360 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_365 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_369 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_48 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_373 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_377 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_49 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_381 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_393 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_398 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_402 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_406 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_410 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_52 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_414 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_426 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_431 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_435 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_54 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_439 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_443 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_55 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_447 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_459 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_464 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_468 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_57 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_472 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_476 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_58 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_480 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_492 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_497 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_501 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_60 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_505 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_509 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_61 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_513 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_525 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_530 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_534 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_63 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
tangents_2 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16) | |
fn = lambda: call([primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, 
getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2]) | |
return print_performance(fn, times=times, repeat=repeat) | |
# Script entry point: when this generated wrapper is executed directly (rather
# than imported), pass the model name and the benchmark function
# `benchmark_compiled_module` (referenced here; defined earlier in this file)
# to Inductor's `compiled_module_main` driver.
if __name__ == "__main__":
    # Imported inside the guard so that merely importing this module does not
    # load the benchmarking utilities.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main('BertForMaskedLM', benchmark_compiled_module)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment