Created
June 25, 2024 00:00
-
-
Save shunting314/4da3b5c2ee7f9470ac2c70cd788bddf8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Short aliases for the operator namespaces used by the generated code.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized

# Guard / allocation helpers bound once at module level.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool

# Every kernel below is registered through this single AsyncCompile instance.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/mt/cmtlbu4aoetpujfnvf6x4euykpu33uxpy6shtforep4hry72bijk.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_1, full_default_2
#
# Two-pass reduction over xnumel=8192 rows and rnumel=30522 columns.
#   Pass 1: per row, sums g[r] = (-1.0 if r == target else 0.0) * scale, where
#           target is in_ptr0[row] with the value -100 remapped to 0 and
#           scale is in_ptr1[0] / in_ptr2[0] (0.0 for rows whose target is -100).
#   Pass 2: stores in_ptr3 + (g - exp(in_ptr4 - in_ptr5[row] - in_ptr6[row]) * row_sum)
#           to out_ptr1.
# in_ptr3 is indexed with a row stride of 30522, while in_ptr4 and out_ptr1 use
# a row stride of 30528.
# The string also embeds a standalone benchmark harness (get_args / call /
# benchmark_all_configs / __main__).
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0 = async_compile.triton('triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[8192, 32768],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: '*fp16', 5: '*fp32', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.500348424}
)
@triton.jit
def triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 30522
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    _tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = tmp11 / tmp13
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(rmask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp30 = tl.load(in_ptr1 + (0))
    tmp31 = tl.broadcast_to(tmp30, [XBLOCK, RBLOCK])
    tmp32 = tl.load(in_ptr2 + (0))
    tmp33 = tl.broadcast_to(tmp32, [XBLOCK, RBLOCK])
    tmp39 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
    tmp41 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp20 = tl.load(in_ptr3 + (r1 + (30522*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tl.full([1, 1], -100, tl.int64)
        tmp22 = tmp0 != tmp21
        tmp23 = tl.full([1, 1], 0, tl.int64)
        tmp24 = tl.where(tmp22, tmp0, tmp23)
        tmp25 = r1
        tmp26 = tmp24 == tmp25
        tmp27 = -1.0
        tmp28 = 0.0
        tmp29 = tl.where(tmp26, tmp27, tmp28)
        tmp34 = tmp31 / tmp33
        tmp35 = tl.where(tmp22, tmp34, tmp28)
        tmp36 = tmp29 * tmp35
        tmp38 = tmp37.to(tl.float32)
        tmp40 = tmp38 - tmp39
        tmp42 = tmp40 - tmp41
        tmp43 = tmp42.to(tl.float32)
        tmp44 = tmp43.to(tl.float32)
        tmp45 = tl_math.exp(tmp44)
        tmp46 = tmp45 * tmp18
        tmp47 = tmp36 - tmp46
        tmp48 = tmp47.to(tl.float32)
        tmp49 = tmp20 + tmp48
        tl.store(out_ptr1 + (r1 + (30528*x0)), tmp49, rmask)


def get_args():
    arg_0 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_5 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.run(*args, 8192, 30522, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.benchmark_all_configs(*args, 8192, 30522, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.500348424
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# Host-side Triton / Inductor runtime helpers used when launching the kernels.
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_shunting/mr/cmrmqdcrcihizdvgq5xqmrayhaooh5p5y3ha4icjpyutweyhrrbl.py
# Source Nodes: [], Original ATen: []
#
# Pointwise copy over 250085376 elements treated as rows of width 30528.
# Elements whose column (index % 30528) is below 30522 are loaded from in_ptr0;
# the remaining six columns of every row are written as 0.0 (masked load with
# other=0.0). The store itself is unmasked.
# The string also embeds a standalone benchmark harness.
triton_poi_fused_1 = async_compile.triton('triton_poi_fused_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[268435456],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.0002432},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 250085376
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 30528
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 30522, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp3, None)


def get_args():
    arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 30528), (30528, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_1.run(*args, 250085376, grid=grid(250085376), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_1.benchmark_all_configs(*args, 250085376, grid=grid(250085376))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.0002432
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kr/ckrtxiooml7ugwbki5lfdkf26ifptclbwzfuczjsgv3337j6cnlt.py
# Source Nodes: [], Original ATen: []
#
# Pointwise copy over 23445504 elements (30528 rows of width 768).
# Rows with index (element index // 768) below 30522 are loaded from in_ptr0;
# the remaining six rows are written as 0.0 (masked load with other=0.0).
# The store itself is unmasked.
# The string also embeds a standalone benchmark harness.
triton_poi_fused_2 = async_compile.triton('triton_poi_fused_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0937728},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23445504
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x1 = (xindex // 768)
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 30522, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp3, None)


def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30528, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_2.run(*args, 23445504, grid=grid(23445504), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_2.benchmark_all_configs(*args, 23445504, grid=grid(23445504))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.0937728
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/p4/cp4m2j65mbzodehwkp3jqfofigkvv3uayonfyitzbmq3ro7b62pt.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
#
# Column-sum reduction: for each of xnumel=30522 columns, accumulates
# in_ptr0[x0 + 30528 * r] over r in [0, 8192) in float32 and stores the result
# to out_ptr1[x0]. Both the x and r dimensions are masked.
# The string also embeds a standalone benchmark harness.
triton_red_fused__to_copy_sum_3 = async_compile.triton('triton_red_fused__to_copy_sum_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[32768, 8192],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.500194536}
)
@triton.jit
def triton_red_fused__to_copy_sum_3(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 30522
    rnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (30528*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tmp4 = tmp2.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp4, xmask)


def get_args():
    arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30522,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_sum_3.run(*args, 30522, 8192, grid=grid(30522), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_sum_3.benchmark_all_configs(*args, 30522, 8192, grid=grid(30522))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.500194536
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/nx/cnxl5yf7bimfnoatnitspyqvw7ppsv2sxfrbw2q6pvaf653wnqsg.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Elementwise dtype conversion over 23440896 elements: loads from in_ptr0
# (fp16 per the triton_meta signature), converts to float32 and stores to
# out_ptr0 (fp32 per the signature). Loads and stores are masked by xnumel.
# The string also embeds a standalone benchmark harness.
triton_poi_fused__to_copy_4 = async_compile.triton('triton_poi_fused__to_copy_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.140645376},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23440896
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)


def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_4.run(*args, 23440896, grid=grid(23440896), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_4.benchmark_all_configs(*args, 23440896, grid=grid(23440896))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.140645376
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/yy/cyyvpnhm2jzov5qsqyvwzq5jr53s4xnzjf5k276kkys4x3j5o3m6.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
#
# Persistent reduction: one program per row (XBLOCK = 1) over 8192 rows, with
# 768 valid columns inside a 1024-wide block (masked by rmask).
# Per row it computes, with g = in_ptr0 * in_ptr1[col] and x = in_ptr2:
#   gelu   = 0.5 * x * (1 + erf(x * 0.7071067811865476))      # constant is 1/sqrt(2)
#   xhat   = (gelu - in_ptr3[row]) * in_ptr4[row]
#   ln_bwd = in_ptr4[row] * (1/768) * (768 * g - sum(g) - xhat * sum(g * xhat))
#   out    = ln_bwd * (0.5 * (1 + erf(x / sqrt(2))) + x * exp(-0.5 * x * x) * 0.3989422804014327)
# (0.3989422804014327 is 1/sqrt(2*pi); 0.0013020833333333333 is 1/768.)
# The result is stored to out_ptr3 at r1 + 768 * row.
# The string also embeds a standalone benchmark harness.
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5 = async_compile.triton('triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.037817344}
)
@triton.jit
def triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp20 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp9 = tmp8.to(tl.float32)
    tmp10 = 0.5
    tmp11 = tmp9 * tmp10
    tmp12 = 0.7071067811865476
    tmp13 = tmp9 * tmp12
    tmp14 = libdevice.erf(tmp13)
    tmp15 = 1.0
    tmp16 = tmp14 + tmp15
    tmp17 = tmp11 * tmp16
    tmp19 = tmp17 - tmp18
    tmp21 = tmp19 * tmp20
    tmp22 = tmp3 * tmp21
    tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
    tmp25 = tl.where(rmask, tmp23, 0)
    tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
    tmp27 = 0.0013020833333333333
    tmp28 = tmp20 * tmp27
    tmp29 = 768.0
    tmp30 = tmp3 * tmp29
    tmp31 = tmp30 - tmp7
    tmp32 = tmp21 * tmp26
    tmp33 = tmp31 - tmp32
    tmp34 = tmp28 * tmp33
    tmp35 = tmp16 * tmp10
    tmp36 = tmp9 * tmp9
    tmp37 = -0.5
    tmp38 = tmp36 * tmp37
    tmp39 = tl_math.exp(tmp38)
    tmp40 = 0.3989422804014327
    tmp41 = tmp39 * tmp40
    tmp42 = tmp9 * tmp41
    tmp43 = tmp35 + tmp42
    tmp44 = tmp34 * tmp43
    tmp45 = tmp44.to(tl.float32)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp45, rmask)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.037817344
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/p2/cp26z4vf45nz2fihk4sr3hatolkq5voe5ay6tuzetss46ifomndv.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
#
# Partial column reduction: xnumel=49152 outputs are split as x0 = index % 768
# (column) and x1 = index // 768 (chunk); each output reduces rnumel=128 rows,
# reading in_ptr0/in_ptr1 at x0 + 768 * r2 + 98304 * x1 and the per-row values
# in_ptr2/in_ptr3 at r2 + 128 * x1.
#   out_ptr0 = sum_r in_ptr0 * ((gelu(in_ptr1) - in_ptr2) * in_ptr3)
#   out_ptr1 = sum_r in_ptr0
# where gelu(x) = 0.5 * x * (1 + erf(x * 0.7071067811865476)).
# The string also embeds a standalone benchmark harness.
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6 = async_compile.triton('triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025624576}
)
@triton.jit
def triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp21 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tl.load(in_ptr3 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = 0.5
        tmp5 = tmp3 * tmp4
        tmp6 = 0.7071067811865476
        tmp7 = tmp3 * tmp6
        tmp8 = libdevice.erf(tmp7)
        tmp9 = 1.0
        tmp10 = tmp8 + tmp9
        tmp11 = tmp5 * tmp10
        tmp13 = tmp11 - tmp12
        tmp15 = tmp13 * tmp14
        tmp16 = tmp1 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(rmask, tmp19, _tmp18)
        tmp20 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp22 = _tmp21 + tmp20
        _tmp21 = tl.where(rmask, tmp22, _tmp21)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp18, None)
    tmp21 = tl.sum(_tmp21, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp21, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.025624576
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/be/cbevumbrgmzdfgdapkc3uyo5g3sg7zr3sizocbvwld6n22bimq7o.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
#
# Persistent reduction: for every x in [0, 768) the kernel sums the 64 values
# in_ptr0[x + 768*r] (r = 0..63) and stores the total in out_ptr0[x].
# NOTE(review): the (768, 64) input looks like per-chunk partial sums written by
# an earlier split reduction -- confirm against the call site.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7 = async_compile.triton('triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[1024, 64],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    # RBLOCK equals rnumel, so the reduction axis needs no masking.
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    # Out-of-range x rows contribute zero to the sum.
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp4, xmask)


def get_args():
    arg_0 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(*args, 768, 64, grid=grid(768), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.benchmark_all_configs(*args, 768, 64, grid=grid(768))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/dg/cdgwv5eeos3ltgqpvtb7olrpc6yd4n6mytrgvxk656fu3eavxcgq.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction over 49152 outputs with 128 reduced elements each.
# With x0 = xindex % 768 and x1 = xindex // 768, the kernel accumulates in fp32
# the fp16 values in_ptr0[x0 + 768*r + 98304*x1] for r = 0..127 and stores the
# sum in out_ptr0[xindex].
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_red_fused_sum_8 = async_compile.triton('triton_red_fused_sum_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_8(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    # fp32 accumulator tile, reduced along the r axis after the loop.
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        # Only lanes inside the reduction range update the accumulator.
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_8.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_8.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/de/cdefudbzyfjtf3vfo5cjqbkh4nayk6d2w5vztk6sxrz3diyaoomj.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
#
# Persistent reduction: for every x in [0, 768) the kernel sums the 64 values
# in_ptr0[x + 768*r] (r = 0..63), casts the total to fp32 and stores it in
# out_ptr1[x].
# NOTE(review): the (1, 768, 64) input looks like per-chunk partial sums written
# by an earlier split reduction -- confirm against the call site.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_per_fused__to_copy_sum_9 = async_compile.triton('triton_per_fused__to_copy_sum_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[1024, 64],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_sum_9(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    # RBLOCK equals rnumel, so the reduction axis needs no masking.
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    # Out-of-range x rows contribute zero to the sum.
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)


def get_args():
    arg_0 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_sum_9.run(*args, 768, 64, grid=grid(768), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_sum_9.benchmark_all_configs(*args, 768, 64, grid=grid(768))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/aa/caaivtwgxhkxzzagzke745nxumkqey4pkyuq3ww7l7r3nilsqszw.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Pointwise dtype conversion: reads 589824 fp16 elements from in_ptr0, upcasts
# them to fp32 and writes them to the same flat offsets of out_ptr0.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_poi_fused__to_copy_10 = async_compile.triton('triton_poi_fused__to_copy_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_10(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # The runtime size is overwritten with the constant this kernel was specialised for.
    xnumel = 589824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    # Loads and stores are issued without a mask (mask argument is None).
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_10.run(*args, 589824, grid=grid(589824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_10.benchmark_all_configs(*args, 589824, grid=grid(589824))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.003538944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/3x/c3xzzagzsfle5jswbt57zqxejbypuvykxyswwi5k2leqkor7uv76.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
#
# Persistent reduction with one program per row (8192 rows, 768 elements each,
# padded to RBLOCK = 1024 and masked). Per row x and element r it computes
#   g    = fp32(in_ptr0[r + 768*x]) * in_ptr1[r]
#   s1   = sum_r g
#   s2   = sum_r g * in_ptr2[r + 768*x]
#   out2 = in_ptr3[x] * (768 * g - s1 - in_ptr2[r + 768*x] * s2)
#   out3 = out2 * (float(in_ptr4[r + 768*x]) * 1.1111111111111112)
# and stores out2 to out_ptr2 and out3 to out_ptr3.
# NOTE(review): presumably in_ptr1 is the layer-norm weight, in_ptr2 the
# normalised activations, in_ptr3 a per-row scale and in_ptr4 the dropout keep
# mask with 1/(1-p) = 1.111... (p = 0.1) -- confirm against the call site.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11 = async_compile.triton('triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*i1', 5: '*fp32', 6: '*fp16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.081824768}
)
@triton.jit
def triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, xnumel, rnumel):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    # RBLOCK (1024) is larger than rnumel (768); the tail lanes are masked off.
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0)
    tmp14 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp22 = tl.load(in_ptr4 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    # First row reduction: sum of tmp3 over the row.
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    # Second row reduction: sum of tmp3 * tmp8 over the row.
    tmp9 = tmp3 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [RBLOCK])
    tmp12 = tl.where(rmask, tmp10, 0)
    tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))
    tmp15 = 768.0
    tmp16 = tmp3 * tmp15
    tmp17 = tmp16 - tmp7
    tmp18 = tmp8 * tmp13
    tmp19 = tmp17 - tmp18
    tmp20 = tmp14 * tmp19
    tmp21 = tmp20.to(tl.float32)
    # Scale by the boolean mask (as float) times 1.1111111111111112.
    tmp23 = tmp22.to(tl.float32)
    tmp24 = 1.1111111111111112
    tmp25 = tmp23 * tmp24
    tmp26 = tmp21 * tmp25
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp20, rmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp26, rmask)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.081824768
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ix/cixvree2ydqv56di7ljjja76jndkhgeq7umgiaw3evahe2aocyji.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
#
# Looped reduction over 49152 outputs with 128 reduced elements each, producing
# two sums per output. With x0 = xindex % 768, x1 = xindex // 768 and
# off = x0 + 768*r + 98304*x1 (r = 0..127):
#   out_ptr0[xindex] = sum_r fp32(in_ptr0[off]) * in_ptr1[off]
#   out_ptr1[xindex] = sum_r fp32(in_ptr0[off])
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_red_fused__to_copy_native_layer_norm_backward_12 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_backward_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_backward_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.038141952}
)
@triton.jit
def triton_red_fused__to_copy_native_layer_norm_backward_12(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    # Two fp32 accumulator tiles, one per reduction output.
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        # Only lanes inside the reduction range update the accumulators.
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
        tmp7 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, None)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp8, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_native_layer_norm_backward_12.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_native_layer_norm_backward_12.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.038141952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/h5/ch5onztw533j5pafqjnjfivj5tmonqevomegommgb7ujoga557vl.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction over 49152 outputs with 128 reduced elements each.
# With x0 = xindex % 768 and x1 = xindex // 768, the kernel accumulates in fp32
# the fp16 values in_ptr0[x0 + 768*r + 98304*x1] for r = 0..127 and stores the
# sum in out_ptr0[xindex].
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_red_fused_sum_13 = async_compile.triton('triton_red_fused_sum_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_13(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # The runtime sizes are overwritten with the constants this kernel was specialised for.
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    # fp32 accumulator tile, reduced along the r axis after the loop.
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        # Only lanes inside the reduction range update the accumulator.
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_13.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_13.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xa/cxaftrsj77n67s4u7dsn2e4cavvigeasorghvno527bfpnxdwa6r.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Pointwise dtype conversion: reads 2359296 fp16 elements from in_ptr0, upcasts
# them to fp32 and writes them to the same flat offsets of out_ptr0.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__).
triton_poi_fused__to_copy_14 = async_compile.triton('triton_poi_fused__to_copy_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_14(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # The runtime size is overwritten with the constant this kernel was specialised for.
    xnumel = 2359296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    # Loads and stores are issued without a mask (mask argument is None).
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_14.run(*args, 2359296, grid=grid(2359296), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_14.benchmark_all_configs(*args, 2359296, grid=grid(2359296))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.014155776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/hf/chfhbekl56hr4wmn7e4jzntlpgkbay7rkk3uwcw34wubtw4mtp7c.py
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
# hidden_states_92 => add_96, convert_element_type_485, erf_11, mul_155
#
# Pointwise in-place kernel over 25165824 fp16 elements. For each element it
# reads g from in_out_ptr0 and x from in_ptr0, upcasts both to fp32, computes
#     0.5 * (1 + erf(x * 0.7071067811865476))
#       + x * exp(-0.5 * x * x) * 0.3989422804014327
# (0.7071... = 1/sqrt(2), 0.3989... = 1/sqrt(2*pi)), multiplies it by g and
# writes the product back to in_out_ptr0, which is the only mutated argument.
# Presumably g is the incoming gradient and x the saved GELU input -- verify
# against the caller.
# Loads and stores pass mask=None; the `xmask` local is never used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation and blank lines inside the kernel source were
# reconstructed from statement structure -- confirm against the generated file.
triton_poi_fused_gelu_gelu_backward_15 = async_compile.triton('triton_poi_fused_gelu_gelu_backward_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_gelu_backward_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.150994944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_gelu_gelu_backward_15(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 25165824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = 0.7071067811865476
    tmp5 = tmp3 * tmp4
    tmp6 = libdevice.erf(tmp5)
    tmp7 = 1.0
    tmp8 = tmp6 + tmp7
    tmp9 = 0.5
    tmp10 = tmp8 * tmp9
    tmp11 = tmp3 * tmp3
    tmp12 = -0.5
    tmp13 = tmp11 * tmp12
    tmp14 = tl_math.exp(tmp13)
    tmp15 = 0.3989422804014327
    tmp16 = tmp14 * tmp15
    tmp17 = tmp3 * tmp16
    tmp18 = tmp10 + tmp17
    tmp19 = tmp1 * tmp18
    tmp20 = tmp19.to(tl.float32)
    tl.store(in_out_ptr0 + (x0), tmp20, None)


def get_args():
    arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_gelu_gelu_backward_15.run(*args, 25165824, grid=grid(25165824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_gelu_gelu_backward_15.benchmark_all_configs(*args, 25165824, grid=grid(25165824))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.150994944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/fh/cfhrsmzvructcw4ytv6jmh4xhobkaqx2h23x7f66d374p3qloz4y.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction kernel (ReductionHint.OUTER) producing 98304 fp32 outputs,
# each the sum of 256 fp16 inputs upcast to fp32.
# With x0 = xindex % 3072 and x1 = xindex // 3072, output xindex accumulates
# in_ptr0[x0 + 3072*r + 786432*x1] for r in [0, 256); since 786432 = 3072*256,
# each output sums one column (x0) over a block of 256 consecutive
# 3072-element rows selected by x1.
# Accumulation happens in the [XBLOCK, RBLOCK] fp32 tile `_tmp2`, RBLOCK
# columns per loop iteration; tl.sum over axis 1 collapses it after the loop.
# The store is unmasked and `xmask` is never used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation (including the extent of the `for` body) and blank
# lines inside the kernel source were reconstructed from statement structure --
# confirm against the generated file.
triton_red_fused_sum_16 = async_compile.triton('triton_red_fused_sum_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[131072, 256],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_16', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_sum_16(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 98304
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 3072
    x1 = (xindex // 3072)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (3072*r2) + (786432*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_16.run(*args, 98304, 256, grid=grid(98304), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_16.benchmark_all_configs(*args, 98304, 256, grid=grid(98304))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.050724864
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/i2/ci2zrjeg52uyf2byr2fphdezj5wiruk6qxrv4gylsb5x5sdq7p37.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
#
# Persistent reduction kernel (ReductionHint.OUTER): the whole reduction axis
# (rnumel = 32, RBLOCK = 32) is loaded in one tile, with no loop.
# For each of the 3072 outputs x0 it sums the fp32 values
# in_ptr0[x0 + 3072*r] for r in [0, 32) and stores the fp32 result to
# out_ptr1[x0]. The load, the pre-sum tl.where and the store are all guarded by
# xmask = xindex < xnumel; `rmask` and `roffset` are defined but not used.
# The trailing .to(tl.float32) on the reduced value is the aten._to_copy part
# of the fusion (a same-dtype cast as written).
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation and blank lines inside the kernel source were
# reconstructed from statement structure -- confirm against the generated file.
triton_per_fused__to_copy_sum_17 = async_compile.triton('triton_per_fused__to_copy_sum_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[4096, 32],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.000405504}
)
@triton.jit
def triton_per_fused__to_copy_sum_17(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 3072
    rnumel = 32
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (3072*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)


def get_args():
    arg_0 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_sum_17.run(*args, 3072, 32, grid=grid(3072), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_sum_17.benchmark_all_configs(*args, 3072, 32, grid=grid(3072))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.000405504
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/wo/cworhwsjrdvmasqjh2lg6aetbo4ztpk5ats5rsegareen6mwnibu.py
# Source Nodes: [], Original ATen: [aten._to_copy]
#
# Pointwise dtype-conversion kernel. Each program instance handles XBLOCK
# consecutive elements: it loads fp16 values from in_ptr0 and stores them to the
# fp32 buffer out_ptr0 (pointer dtypes per the triton_meta signature).
# xnumel is overwritten with the constant 2359296 inside the kernel; loads and
# stores pass mask=None, and the `xmask` local is computed but never used.
# The embedded module also contains a standalone benchmark harness
# (get_args / call / benchmark_all_configs / __main__) that runs the kernel on
# contiguous (3072, 768) tensors and prints time, GB moved and GB/s.
# NOTE(review): indentation and blank lines inside the kernel source were
# reconstructed from statement structure -- confirm against the generated file.
triton_poi_fused__to_copy_18 = async_compile.triton('triton_poi_fused__to_copy_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_18(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2359296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_18.run(*args, 2359296, grid=grid(2359296), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_18.benchmark_all_configs(*args, 2359296, grid=grid(2359296))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.014155776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xl/cxluz6c4say445ydf43oeaw5l6cwrfehyxz6txytoztcrndtykgu.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
#
# Persistent reduction kernel (ReductionHint.INNER, no_x_dim=True): XBLOCK is
# fixed to 1, so each program handles one row x0 of 768 elements inside a
# 1024-wide tile masked by rmask = rindex < 768. There are 8192 rows.
# Per row, with r the column index:
#   s     = in_ptr0[row, r] + fp32(in_ptr1[row, r])        (in_ptr1 is fp16)
#   t     = s * in_ptr2[r]
#   sum_t = sum_r t
#   dot   = sum_r (t * in_ptr3[row, r])
#   d     = in_ptr4[row] * (768 * t - sum_t - in_ptr3[row, r] * dot)
#   out_ptr2[row, r] = d                                   (fp32 buffer)
#   out_ptr3[row, r] = d * (fp32(in_ptr5[row, r]) * 1.1111111111111112)
#                                                          (fp16 buffer)
# in_ptr5 is a boolean buffer ('*i1'); 1.1111111111111112 equals 1/0.9.
# Presumably in_ptr2 is the LayerNorm weight, in_ptr3 the normalized
# activations, in_ptr4 a per-row scale derived from rstd, and in_ptr5 the
# dropout keep-mask for p=0.1 -- verify against the caller.
# `xmask` and `roffset` are defined but not used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation and blank lines inside the kernel source were
# reconstructed from statement structure -- confirm against the generated file.
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i1', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 6, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.106990592}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0)
    tmp16 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp24 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp5 = tmp3 * tmp4
    tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
    tmp8 = tl.where(rmask, tmp6, 0)
    tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp8, 0))
    tmp11 = tmp5 * tmp10
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask, tmp12, 0)
    tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp17 = 768.0
    tmp18 = tmp5 * tmp17
    tmp19 = tmp18 - tmp9
    tmp20 = tmp10 * tmp15
    tmp21 = tmp19 - tmp20
    tmp22 = tmp16 * tmp21
    tmp23 = tmp22.to(tl.float32)
    tmp25 = tmp24.to(tl.float32)
    tmp26 = 1.1111111111111112
    tmp27 = tmp25 * tmp26
    tmp28 = tmp23 * tmp27
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp22, rmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.106990592
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/7d/c7dgbzvqana3egdgcxvl64as7g5ng2lnohfx5nuiwhgjgnqfzgzf.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
#
# Looped reduction kernel (ReductionHint.OUTER) producing two fp32 outputs per
# xindex, for 49152 xindex values, each reducing over 128 elements.
# With x0 = xindex % 768 and x1 = xindex // 768, every load uses the offset
# x0 + 768*r + 98304*x1 for r in [0, 128); since 98304 = 768*128, each xindex
# covers one column (x0) across a block of 128 consecutive 768-element rows.
# Per element: s = in_ptr0[...] + fp32(in_ptr1[...])   (in_ptr1 is fp16)
#   out_ptr0[xindex] = sum_r s * in_ptr2[...]
#   out_ptr1[xindex] = sum_r s
# Presumably these are partial sums for the LayerNorm weight and bias
# gradients, reduced further elsewhere -- verify against the caller.
# Both accumulators are [XBLOCK, RBLOCK] fp32 tiles collapsed with tl.sum over
# axis 1 after the loop; stores are unmasked and `xmask` is never used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation (including the extent of the `for` body) and blank
# lines inside the kernel source were reconstructed from statement structure --
# confirm against the generated file.
triton_red_fused__to_copy_add_native_layer_norm_backward_20 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.063307776}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp3 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask, tmp8, _tmp7)
        tmp9 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp7, None)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp10, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_native_layer_norm_backward_20.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.063307776
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/wl/cwl6ydl7tr2jbzx6zhrjkb6rmp4sbhadfjg44rsjimnvb7lidvb6.py
# Source Nodes: [], Original ATen: [aten.clone]
#
# Pointwise layout-changing copy of 6291456 fp16 elements. The flat output
# index is decomposed as
#   x0 = xindex % 64, x1 = (xindex // 64) % 512,
#   x2 = (xindex // 32768) % 12, x3 = xindex // 393216
# and the element is read from in_ptr0 at x0 + 64*x2 + 768*x1 + 393216*x3,
# then written to out_ptr0 at the flat index. With the benchmark tensors
# ((8192, 768) contiguous in, (16, 12, 512, 64) contiguous out) this amounts to
# out[x3, x2, x1, x0] = in[x3*512 + x1, x2*64 + x0].
# Presumably this splits a hidden dimension of 768 into 12 heads of 64 and
# moves the head axis ahead of the sequence axis -- verify against the caller.
# Loads and stores pass mask=None; the `xmask` local is never used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation and blank lines inside the kernel source were
# reconstructed from statement structure -- confirm against the generated file.
triton_poi_fused_clone_21 = async_compile.triton('triton_poi_fused_clone_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[8388608],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025165824},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_21(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6291456
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 64
    x1 = (xindex // 64) % 512
    x2 = (xindex // 32768) % 12
    x3 = (xindex // 393216)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (393216*x3)), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)


def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((16, 12, 512, 64), (393216, 32768, 64, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_21.run(*args, 6291456, grid=grid(6291456), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_clone_21.benchmark_all_configs(*args, 6291456, grid=grid(6291456))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.025165824
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/2w/c2w56xbyktkjz3c5fbogsijpit33ejsgka5ndga7xju74glcq3wd.py
# Source Nodes: [], Original ATen: [aten.sum]
#
# Looped reduction kernel (ReductionHint.OUTER) producing 49152 fp32 outputs,
# each the sum of 128 fp16 inputs upcast to fp32.
# With x0 = xindex % 768 and x1 = xindex // 768, output xindex accumulates
# in_ptr0[x0 + 768*r + 98304*x1] for r in [0, 128); since 98304 = 768*128,
# each output sums one column (x0) across a block of 128 consecutive
# 768-element rows selected by x1.
# Accumulation happens in the [XBLOCK, RBLOCK] fp32 tile `_tmp2`, RBLOCK
# columns per loop iteration; tl.sum over axis 1 collapses it after the loop.
# The store is unmasked and `xmask` is never used.
# The embedded module also contains a standalone benchmark harness that prints
# time, GB moved and GB/s.
# NOTE(review): indentation (including the extent of the `for` body) and blank
# lines inside the kernel source were reconstructed from statement structure --
# confirm against the generated file.
triton_red_fused_sum_22 = async_compile.triton('triton_red_fused_sum_22', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_22', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_22(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)


def get_args():
    arg_0 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_22.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_22.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/jl/cjldro7q6f5pxz6zuzrjak63sjnv3jqrpdferyvy23uyblyxgymd.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
#
# Persistent reduction kernel (Inductor-generated): one program per row
# (XBLOCK = 1, 8192 rows), with a single RBLOCK = 1024 tile covering the 768
# reduction elements (lanes >= 768 are masked off by rmask).
# Per row x and column r it computes, in fp32:
#   g   = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3        (in_ptr1..3 are fp16, upcast)
#   gw  = g * in_ptr4[r]
#   s1  = sum_r(gw);  s2 = sum_r(gw * in_ptr5)
#   d   = in_ptr6[x] * (768 * gw - s1 - in_ptr5 * s2)  -> stored to out_ptr3
#   out = d * (in_ptr7 as float * 1.1111111111111112)  -> stored to out_ptr4 (fp16 per signature)
# NOTE(review): per the kernel name this is the layer-norm input gradient followed
# by a dropout backward; 1.1111111111111112 is presumably 1/(1-p) with p = 0.1 — confirm.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy.
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*i1', 8: '*fp32', 9: '*fp16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.132156416}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp16 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
    tmp22 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    tmp30 = tl.load(in_ptr7 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp3 + tmp5
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp11 = tmp9 * tmp10
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask, tmp12, 0)
    tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp17 = tmp11 * tmp16
    tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
    tmp20 = tl.where(rmask, tmp18, 0)
    tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
    tmp23 = 768.0
    tmp24 = tmp11 * tmp23
    tmp25 = tmp24 - tmp15
    tmp26 = tmp16 * tmp21
    tmp27 = tmp25 - tmp26
    tmp28 = tmp22 * tmp27
    tmp29 = tmp28.to(tl.float32)
    tmp31 = tmp30.to(tl.float32)
    tmp32 = 1.1111111111111112
    tmp33 = tmp31 * tmp32
    tmp34 = tmp29 * tmp33
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
    tl.store(out_ptr4 + (r1 + (768*x0)), tmp34, rmask)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.132156416
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/nh/cnhqvaojz2ly5fypxyjca4dmreema6cy5ukpvnferpa7nqrba3em.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
#
# Looped reduction kernel (Inductor-generated) producing two partial sums per output.
#   * 49152 outputs (xnumel), each reducing 128 inputs (rnumel); the element read for
#     output x = (x1, x0) and reduction step r sits at x0 + 768*r + 98304*x1.
#   * g = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3 (in_ptr1..3 are fp16, upcast to fp32).
#   * out_ptr0[x] = sum_r(g * in_ptr4);  out_ptr1[x] = sum_r(g).
# NOTE(review): per the kernel name these look like per-chunk partial sums for the
# layer-norm weight and bias gradients — confirm against the caller.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy; the loop body is assumed to end at the second
# accumulator update, with the tl.sum/tl.store pairs after the loop.
triton_red_fused__to_copy_add_native_layer_norm_backward_24 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_24', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0884736}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp16 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr4 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp6 + tmp8
        tmp11 = tmp9 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(rmask, tmp14, _tmp13)
        tmp15 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
        tmp17 = _tmp16 + tmp15
        _tmp16 = tl.where(rmask, tmp17, _tmp16)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp13, None)
    tmp16 = tl.sum(_tmp16, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp16, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_native_layer_norm_backward_24.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.0884736
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/te/cte77lf37am2mn24wajel7rv6zip6voecouwoqp2tx4rqtc4kzcb.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
#
# Pointwise fill kernel (Inductor-generated): writes 0.0 to the first 1536 elements
# of out_ptr0, masking lanes at or beyond xnumel. The embedded benchmark harness
# passes a contiguous (2, 768) fp32 buffer.
# NOTE(review): presumably zero-initialises an embedding-gradient buffer before a
# later scatter-add — confirm against the caller.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy.
triton_poi_fused_embedding_dense_backward_25 = async_compile.triton('triton_poi_fused_embedding_dense_backward_25', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[2048],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_25', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6.144e-06},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_25(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)


def get_args():
    arg_0 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_embedding_dense_backward_25.run(*args, 1536, grid=grid(1536), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_embedding_dense_backward_25.benchmark_all_configs(*args, 1536, grid=grid(1536))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 6.144e-06
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/6f/c6fngau4jk3lnabptysnet4fu7qocsijshhqwcxzkwwgptradatm.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
#
# Pointwise fill kernel (Inductor-generated): writes 0.0 to the first 23440896
# elements of out_ptr0, masking lanes at or beyond xnumel. The embedded benchmark
# harness passes a contiguous (30522, 768) fp32 buffer (30522 * 768 = 23440896).
# NOTE(review): presumably zero-initialises an embedding-gradient buffer before a
# later scatter-add — confirm against the caller.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy.
triton_poi_fused_embedding_dense_backward_26 = async_compile.triton('triton_poi_fused_embedding_dense_backward_26', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.093763584},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_26(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23440896
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)


def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_embedding_dense_backward_26.run(*args, 23440896, grid=grid(23440896), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_embedding_dense_backward_26.benchmark_all_configs(*args, 23440896, grid=grid(23440896))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.093763584
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ha/chan35efrgnml6b6lkx5heboj7kad62m5wsfwbpqx4tq72w4wxpt.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_2
#
# Persistent reduction kernel (Inductor-generated): one program per row
# (XBLOCK = 1, 8192 rows), with a single RBLOCK = 1024 tile covering the 768
# reduction elements (lanes >= 768 are masked off by rmask).
# Per row x and column r it computes, in fp32:
#   g   = (in_out_ptr0 + in_ptr0 + in_ptr1 + in_ptr2) * (in_ptr3 as float * 1.1111111111111112)
#         -> written back in place to in_out_ptr0
#   gw  = g * in_ptr4[r]
#   s1  = sum_r(gw);  s2 = sum_r(gw * in_ptr5)
#   d   = in_ptr6[x] * (768 * gw - s1 - in_ptr5 * s2)  -> stored to out_ptr2
# It then scatter-adds d with relaxed atomics into two row-indexed tables:
#   * out_ptr3, row = in_ptr7[x % 512] (negative indices wrapped by +2);
#     the added value is 0.0 when the index equals -1.
#   * out_ptr4, row = in_ptr8[x] (negative indices wrapped by +30522);
#     the added value is 0.0 when the index equals 0.
# in_out_ptr0, out_ptr3 and out_ptr4 are declared as mutated arguments.
# NOTE(review): 1.1111111111111112 is presumably 1/(1-p) for dropout p = 0.1, and the
# two tables look like position/word embedding gradients (harness shapes (2, 768)
# and (30522, 768)) — confirm against the caller.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy.
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*i1', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*i64', 9: '*i64', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: 'i32', 14: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27', 'mutated_arg_names': ['in_out_ptr0', 'out_ptr3', 'out_ptr4'], 'no_x_dim': True, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.169980928}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    x2 = xindex % 512
    tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
    tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp15 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp21 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
    tmp27 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    tmp34 = tl.load(in_ptr7 + (x2), None, eviction_policy='evict_last')
    tmp43 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp3 + tmp5
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp11 = tmp10.to(tl.float32)
    tmp12 = 1.1111111111111112
    tmp13 = tmp11 * tmp12
    tmp14 = tmp9 * tmp13
    tmp16 = tmp14 * tmp15
    tmp17 = tl.broadcast_to(tmp16, [RBLOCK])
    tmp19 = tl.where(rmask, tmp17, 0)
    tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
    tmp22 = tmp16 * tmp21
    tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
    tmp25 = tl.where(rmask, tmp23, 0)
    tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
    tmp28 = 768.0
    tmp29 = tmp16 * tmp28
    tmp30 = tmp29 - tmp20
    tmp31 = tmp21 * tmp26
    tmp32 = tmp30 - tmp31
    tmp33 = tmp27 * tmp32
    tmp35 = tl.full([RBLOCK], 2, tl.int32)
    tmp36 = tmp34 + tmp35
    tmp37 = tmp34 < 0
    tmp38 = tl.where(tmp37, tmp36, tmp34)
    tmp39 = tl.full([1], -1, tl.int64)
    tmp40 = tmp34 == tmp39
    tmp41 = 0.0
    tmp42 = tl.where(tmp40, tmp41, tmp33)
    tmp44 = tl.full([RBLOCK], 30522, tl.int32)
    tmp45 = tmp43 + tmp44
    tmp46 = tmp43 < 0
    tmp47 = tl.where(tmp46, tmp45, tmp43)
    tmp48 = tl.full([1], 0, tl.int64)
    tmp49 = tmp43 == tmp48
    tmp50 = tl.where(tmp49, tmp41, tmp33)
    tl.store(in_out_ptr0 + (r1 + (768*x0)), tmp14, rmask)
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp33, rmask)
    tl.atomic_add(out_ptr3 + (tl.broadcast_to(r1 + (768*tmp38), [RBLOCK])), tmp42, rmask, sem='relaxed')
    tl.atomic_add(out_ptr4 + (tl.broadcast_to(r1 + (768*tmp47), [RBLOCK])), tmp50, rmask, sem='relaxed')


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_8 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    arg_9 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    arg_10 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_11 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    arg_12 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.run(*args, 8192, 768, grid=grid(8192), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.169980928
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/c7/cc74oeyzfdenaonorpxyc6q6zpqctt3hj57l7j5h46ugv67z2gei.py
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
#
# Looped reduction kernel (Inductor-generated) producing two partial sums per output.
#   * 49152 outputs (xnumel), each reducing 128 inputs (rnumel); the element read for
#     output x = (x1, x0) and reduction step r sits at x0 + 768*r + 98304*x1.
#   * out_ptr0[x] = sum_r(in_ptr0 * in_ptr1);  out_ptr1[x] = sum_r(in_ptr0).
# NOTE(review): per the kernel name these look like per-chunk partial sums for the
# layer-norm weight and bias gradients — confirm against the caller.
# NOTE(review): indentation of the embedded source was reconstructed from a
# whitespace-flattened copy; the loop body is assumed to end at the second
# accumulator update, with the tl.sum/tl.store pairs after the loop.
triton_red_fused_native_layer_norm_backward_28 = async_compile.triton('triton_red_fused_native_layer_norm_backward_28', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid

@triton_heuristics.reduction(
    size_hints=[65536, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_28', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_native_layer_norm_backward_28(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
        tmp6 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask, tmp8, _tmp7)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp4, None)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp7, None)


def get_args():
    arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_native_layer_norm_backward_28.run(*args, 49152, 128, grid=grid(49152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_native_layer_norm_backward_28.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.050724864
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/55/c55h64q7s3knjikeayrydmub7yk3opg73oy3todkps33rlu5tz62.py | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
#
# Passes the Triton source below to async_compile.triton under the name
# 'triton_poi_fused_embedding_dense_backward_29' and binds the result to that name.
#
# What the embedded source contains:
#   * A pointwise kernel taking a single fp32 pointer (out_ptr0). Each program
#     covers indices program_id(0) * XBLOCK + [0, XBLOCK) and stores the constant
#     0.0 at out_ptr0 + index. xnumel is overwritten with 393216 inside the kernel.
#   * The store is issued with mask None, and the xmask value that is built is
#     never used.
#     NOTE(review): this presumably relies on every candidate XBLOCK dividing
#     393216 -- confirm against the pointwise heuristics.
#   * get_args / call / benchmark_all_configs and a __main__ section that launch
#     the kernel on cuda:0 with grid(393216) over a (512, 768) float32 buffer and
#     print timing and GB/s based on num_gb = 0.001572864.
#
# The text between the triple quotes is runtime data handed to the compiler, so
# it is deliberately left untouched here.
triton_poi_fused_embedding_dense_backward_29 = async_compile.triton('triton_poi_fused_embedding_dense_backward_29', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.pointwise( | |
size_hints=[524288], | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_29', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.001572864}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_embedding_dense_backward_29(out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 393216 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = 0.0 | |
tl.store(out_ptr0 + (x0), tmp0, None) | |
def get_args(): | |
arg_0 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_embedding_dense_backward_29.run(*args, 393216, grid=grid(393216), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_poi_fused_embedding_dense_backward_29.benchmark_all_configs(*args, 393216, grid=grid(393216)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.001572864 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_shunting/wv/cwv53zxut7fjvmy4vtbp25bc3pksunq7xtkpoeutsmskctsgwxni.py | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum] | |
# masked_lm_loss => full_default_2 | |
#
# Passes the Triton source below to async_compile.triton under the name
# 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30' and binds
# the result to that name.
#
# What the embedded kernel does for each x index (xnumel is overwritten with
# 393216, rnumel with 16, RBLOCK is the constant 16):
#   1. Loads 16 fp32 values from in_ptr0 at x + 393216 * r and sums them over r
#      (tmp3). With the (16, 512, 768) example input built by get_args, that is
#      a sum over the leading dimension.
#   2. Loads an int64 index from in_ptr1 at x // 768 (tmp4).
#   3. Adds 512 to negative indices (tmp8) and replaces the summed value with
#      0.0 wherever the loaded index equals -1 (tmp12).
#   4. Atomically adds tmp12 into out_ptr1 at (x % 768) + 768 * tmp8, using
#      relaxed semantics.
#
# Consequences visible in the code:
#   * An index of -1 is wrapped to row 511 and still issues an atomic add, but
#     of 0.0.
#   * out_ptr1 is accumulated into, not overwritten; inductor_meta lists it in
#     'mutated_arg_names'.
#     NOTE(review): the caller presumably zero-fills that buffer beforehand --
#     confirm in the wrapper's call().
#   * All loads and the atomic add use mask None; xmask and rmask are built but
#     never used.
# NOTE(review): -1 is presumably the embedding's padding_idx and 512 the number
# of embedding rows -- confirm against the model definition.
#
# The text between the triple quotes is runtime data handed to the compiler, so
# it is deliberately left untouched here.
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30 = async_compile.triton('triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid | |
@triton_heuristics.persistent_reduction( | |
size_hints=[524288, 16], | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30', 'mutated_arg_names': ['out_ptr1'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.026742784} | |
) | |
@triton.jit | |
def triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr): | |
xnumel = 393216 | |
rnumel = 16 | |
RBLOCK: tl.constexpr = 16 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) | |
rindex = tl.arange(0, RBLOCK)[None, :] | |
roffset = 0 | |
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) | |
r1 = rindex | |
x0 = xindex | |
x3 = (xindex // 768) | |
x2 = xindex % 768 | |
tmp0 = tl.load(in_ptr0 + (x0 + (393216*r1)), None) | |
tmp4 = tl.load(in_ptr1 + (x3), None, eviction_policy='evict_last') | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) | |
tmp3 = tl.sum(tmp1, 1)[:, None] | |
tmp5 = tl.full([XBLOCK, 1], 512, tl.int32) | |
tmp6 = tmp4 + tmp5 | |
tmp7 = tmp4 < 0 | |
tmp8 = tl.where(tmp7, tmp6, tmp4) | |
tmp9 = tl.full([1, 1], -1, tl.int64) | |
tmp10 = tmp4 == tmp9 | |
tmp11 = 0.0 | |
tmp12 = tl.where(tmp10, tmp11, tmp3) | |
tl.atomic_add(out_ptr1 + (x2 + (768*tmp8)), tmp12, None, sem='relaxed') | |
def get_args(): | |
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
arg_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
arg_2 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32) | |
return arg_0, arg_1, arg_2, | |
def call(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
stream0 = get_raw_stream(0) | |
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.run(*args, 393216, 16, grid=grid(393216), stream=stream0) | |
def benchmark_all_configs(args): | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
return triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.benchmark_all_configs(*args, 393216, 16, grid=grid(393216)) | |
if __name__ == '__main__': | |
from triton.testing import do_bench | |
args = get_args() | |
ms = do_bench(lambda: call(args), rep=40, fast_flush=True) | |
num_gb = 0.026742784 | |
gb_per_s = num_gb / (ms / 1e3) | |
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s") | |
''', device_str='cuda') | |
# Hand this module's global namespace to the AsyncCompile instance, then remove
# the module-level name so nothing later in the file can submit more kernels
# through it.
# NOTE(review): wait() presumably blocks until every async_compile.triton(...)
# submission above has finished and rebinds those names in globals() to the
# compiled kernels -- confirm against torch._inductor.async_compile.
async_compile.wait(globals()) | |
del async_compile | |
def call(args): | |
primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, 
getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2 = args | |
args.clear() | |
assert_size_stride(primals_4, (768, ), (1, )) | |
assert_size_stride(primals_14, (768, ), (1, )) | |
assert_size_stride(primals_20, (768, ), (1, )) | |
assert_size_stride(primals_30, (768, ), (1, )) | |
assert_size_stride(primals_36, (768, ), (1, )) | |
assert_size_stride(primals_46, (768, ), (1, )) | |
assert_size_stride(primals_52, (768, ), (1, )) | |
assert_size_stride(primals_62, (768, ), (1, )) | |
assert_size_stride(primals_68, (768, ), (1, )) | |
assert_size_stride(primals_78, (768, ), (1, )) | |
assert_size_stride(primals_84, (768, ), (1, )) | |
assert_size_stride(primals_94, (768, ), (1, )) | |
assert_size_stride(primals_100, (768, ), (1, )) | |
assert_size_stride(primals_110, (768, ), (1, )) | |
assert_size_stride(primals_116, (768, ), (1, )) | |
assert_size_stride(primals_126, (768, ), (1, )) | |
assert_size_stride(primals_132, (768, ), (1, )) | |
assert_size_stride(primals_142, (768, ), (1, )) | |
assert_size_stride(primals_148, (768, ), (1, )) | |
assert_size_stride(primals_158, (768, ), (1, )) | |
assert_size_stride(primals_164, (768, ), (1, )) | |
assert_size_stride(primals_174, (768, ), (1, )) | |
assert_size_stride(primals_180, (768, ), (1, )) | |
assert_size_stride(primals_190, (768, ), (1, )) | |
assert_size_stride(primals_196, (768, ), (1, )) | |
assert_size_stride(primals_200, (768, ), (1, )) | |
assert_size_stride(primals_204, (1, 512), (512, 1)) | |
assert_size_stride(primals_205, (1, 512), (512, 1)) | |
assert_size_stride(primals_206, (16, 512), (512, 1)) | |
assert_size_stride(primals_207, (16, 512), (512, 1)) | |
assert_size_stride(mul_1, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(gt, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_66, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_67, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_68, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_129, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_130, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_131, (), ()) | |
assert_size_stride(getitem_132, (), ()) | |
assert_size_stride(view_16, (8192, 768), (768, 1)) | |
assert_size_stride(gt_2, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_9, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_18, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_4, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_20, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_3, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_16, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_22, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_60, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_61, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_62, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_122, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_123, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_124, (), ()) | |
assert_size_stride(getitem_125, (), ()) | |
assert_size_stride(view_38, (8192, 768), (768, 1)) | |
assert_size_stride(gt_5, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_22, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_40, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_10, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_42, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_6, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_29, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_44, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_54, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_55, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_56, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_115, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_116, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_117, (), ()) | |
assert_size_stride(getitem_118, (), ()) | |
assert_size_stride(view_60, (8192, 768), (768, 1)) | |
assert_size_stride(gt_8, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_35, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_62, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_16, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_64, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_9, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_42, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_66, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_48, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_49, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_50, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_108, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_109, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_110, (), ()) | |
assert_size_stride(getitem_111, (), ()) | |
assert_size_stride(view_82, (8192, 768), (768, 1)) | |
assert_size_stride(gt_11, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_48, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_84, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_22, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_86, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_12, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_55, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_88, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_42, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_43, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_44, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_101, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_102, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_103, (), ()) | |
assert_size_stride(getitem_104, (), ()) | |
assert_size_stride(view_104, (8192, 768), (768, 1)) | |
assert_size_stride(gt_14, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_61, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_106, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_28, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_108, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_15, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_68, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_110, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_36, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_37, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_38, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_94, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_95, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_96, (), ()) | |
assert_size_stride(getitem_97, (), ()) | |
assert_size_stride(view_126, (8192, 768), (768, 1)) | |
assert_size_stride(gt_17, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_74, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_128, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_34, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_130, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_18, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_81, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_132, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_30, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_31, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_32, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_87, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_88, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_89, (), ()) | |
assert_size_stride(getitem_90, (), ()) | |
assert_size_stride(view_148, (8192, 768), (768, 1)) | |
assert_size_stride(gt_20, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_87, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_150, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_40, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_152, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_21, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_94, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_154, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_24, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_25, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_26, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_80, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_81, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_82, (), ()) | |
assert_size_stride(getitem_83, (), ()) | |
assert_size_stride(view_170, (8192, 768), (768, 1)) | |
assert_size_stride(gt_23, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_100, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_172, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_46, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_174, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_24, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_107, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_176, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_18, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_19, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_20, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_73, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_74, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_75, (), ()) | |
assert_size_stride(getitem_76, (), ()) | |
assert_size_stride(view_192, (8192, 768), (768, 1)) | |
assert_size_stride(gt_26, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_113, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_194, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_52, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_196, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_27, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_120, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_198, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_12, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_13, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_14, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_66, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_67, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_68, (), ()) | |
assert_size_stride(getitem_69, (), ()) | |
assert_size_stride(view_214, (8192, 768), (768, 1)) | |
assert_size_stride(gt_29, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_126, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_216, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_58, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_218, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_30, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_133, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_220, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default_6, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_7, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_8, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_59, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_60, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_61, (), ()) | |
assert_size_stride(getitem_62, (), ()) | |
assert_size_stride(view_236, (8192, 768), (768, 1)) | |
assert_size_stride(gt_32, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_139, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_238, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_64, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_240, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_33, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_146, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_242, (8192, 768), (768, 1)) | |
assert_size_stride(permute_default, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_1, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(permute_default_2, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_52, (16, 12, 512, 64), (393216, 64, 768, 1)) | |
assert_size_stride(getitem_53, (16, 12, 512), (6144, 512, 1)) | |
assert_size_stride(getitem_54, (), ()) | |
assert_size_stride(getitem_55, (), ()) | |
assert_size_stride(view_258, (8192, 768), (768, 1)) | |
assert_size_stride(gt_35, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_152, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_260, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_70, (8192, 3072), (3072, 1)) | |
assert_size_stride(view_262, (8192, 3072), (3072, 1)) | |
assert_size_stride(gt_36, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(mul_159, (16, 512, 768), (393216, 768, 1)) | |
assert_size_stride(view_264, (8192, 768), (768, 1)) | |
assert_size_stride(addmm_72, (8192, 768), (768, 1)) | |
assert_size_stride(getitem_51, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(rsqrt_25, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(view_266, (8192, 768), (768, 1)) | |
assert_size_stride(view_267, (16, 512, 30522), (15630336, 30528, 1)) | |
assert_size_stride(amax_12, (8192, 1), (1, 1)) | |
assert_size_stride(log, (8192, 1), (1, 1)) | |
assert_size_stride(convert_element_type_510, (), ()) | |
assert_size_stride(permute_134, (30522, 768), (768, 1)) | |
assert_size_stride(permute_138, (768, 768), (768, 1)) | |
assert_size_stride(div_27, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_142, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_146, (3072, 768), (768, 1)) | |
assert_size_stride(div_28, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_150, (768, 768), (768, 1)) | |
assert_size_stride(permute_162, (768, 768), (768, 1)) | |
assert_size_stride(permute_167, (768, 768), (768, 1)) | |
assert_size_stride(permute_171, (768, 768), (768, 1)) | |
assert_size_stride(div_30, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_175, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_179, (3072, 768), (768, 1)) | |
assert_size_stride(div_31, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_183, (768, 768), (768, 1)) | |
assert_size_stride(permute_195, (768, 768), (768, 1)) | |
assert_size_stride(permute_200, (768, 768), (768, 1)) | |
assert_size_stride(permute_204, (768, 768), (768, 1)) | |
assert_size_stride(div_33, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_208, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_212, (3072, 768), (768, 1)) | |
assert_size_stride(div_34, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_216, (768, 768), (768, 1)) | |
assert_size_stride(permute_228, (768, 768), (768, 1)) | |
assert_size_stride(permute_233, (768, 768), (768, 1)) | |
assert_size_stride(permute_237, (768, 768), (768, 1)) | |
assert_size_stride(div_36, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_241, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_245, (3072, 768), (768, 1)) | |
assert_size_stride(div_37, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_249, (768, 768), (768, 1)) | |
assert_size_stride(permute_261, (768, 768), (768, 1)) | |
assert_size_stride(permute_266, (768, 768), (768, 1)) | |
assert_size_stride(permute_270, (768, 768), (768, 1)) | |
assert_size_stride(div_39, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_274, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_278, (3072, 768), (768, 1)) | |
assert_size_stride(div_40, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_282, (768, 768), (768, 1)) | |
assert_size_stride(permute_294, (768, 768), (768, 1)) | |
assert_size_stride(permute_299, (768, 768), (768, 1)) | |
assert_size_stride(permute_303, (768, 768), (768, 1)) | |
assert_size_stride(div_42, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_307, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_311, (3072, 768), (768, 1)) | |
assert_size_stride(div_43, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_315, (768, 768), (768, 1)) | |
assert_size_stride(permute_327, (768, 768), (768, 1)) | |
assert_size_stride(permute_332, (768, 768), (768, 1)) | |
assert_size_stride(permute_336, (768, 768), (768, 1)) | |
assert_size_stride(div_45, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_340, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_344, (3072, 768), (768, 1)) | |
assert_size_stride(div_46, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_348, (768, 768), (768, 1)) | |
assert_size_stride(permute_360, (768, 768), (768, 1)) | |
assert_size_stride(permute_365, (768, 768), (768, 1)) | |
assert_size_stride(permute_369, (768, 768), (768, 1)) | |
assert_size_stride(div_48, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_373, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_377, (3072, 768), (768, 1)) | |
assert_size_stride(div_49, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_381, (768, 768), (768, 1)) | |
assert_size_stride(permute_393, (768, 768), (768, 1)) | |
assert_size_stride(permute_398, (768, 768), (768, 1)) | |
assert_size_stride(permute_402, (768, 768), (768, 1)) | |
assert_size_stride(div_51, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_406, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_410, (3072, 768), (768, 1)) | |
assert_size_stride(div_52, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_414, (768, 768), (768, 1)) | |
assert_size_stride(permute_426, (768, 768), (768, 1)) | |
assert_size_stride(permute_431, (768, 768), (768, 1)) | |
assert_size_stride(permute_435, (768, 768), (768, 1)) | |
assert_size_stride(div_54, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_439, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_443, (3072, 768), (768, 1)) | |
assert_size_stride(div_55, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_447, (768, 768), (768, 1)) | |
assert_size_stride(permute_459, (768, 768), (768, 1)) | |
assert_size_stride(permute_464, (768, 768), (768, 1)) | |
assert_size_stride(permute_468, (768, 768), (768, 1)) | |
assert_size_stride(div_57, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_472, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_476, (3072, 768), (768, 1)) | |
assert_size_stride(div_58, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_480, (768, 768), (768, 1)) | |
assert_size_stride(permute_492, (768, 768), (768, 1)) | |
assert_size_stride(permute_497, (768, 768), (768, 1)) | |
assert_size_stride(permute_501, (768, 768), (768, 1)) | |
assert_size_stride(div_60, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_505, (768, 3072), (3072, 1)) | |
assert_size_stride(permute_509, (3072, 768), (768, 1)) | |
assert_size_stride(div_61, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(permute_513, (768, 768), (768, 1)) | |
assert_size_stride(permute_525, (768, 768), (768, 1)) | |
assert_size_stride(permute_530, (768, 768), (768, 1)) | |
assert_size_stride(permute_534, (768, 768), (768, 1)) | |
assert_size_stride(div_63, (16, 512, 1), (512, 1, 1)) | |
assert_size_stride(tangents_1, (), ()) | |
assert_size_stride(tangents_2, (16, 512, 30522), (15627264, 30522, 1)) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf1 = empty_strided_cuda((16, 512, 30522), (15630336, 30528, 1), torch.float16) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.run(primals_207, tangents_1, convert_element_type_510, tangents_2, view_267, amax_12, log, buf1, 8192, 30522, grid=grid(8192), stream=stream0) | |
del amax_12 | |
del convert_element_type_510 | |
del log | |
del primals_207 | |
del tangents_1 | |
del tangents_2 | |
del view_267 | |
buf2 = empty_strided_cuda((8192, 30528), (30528, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
triton_poi_fused_1.run(buf1, buf2, 250085376, grid=grid(250085376), stream=stream0) | |
buf3 = empty_strided_cuda((30528, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
triton_poi_fused_2.run(permute_134, buf3, 23445504, grid=grid(23445504), stream=stream0) | |
del permute_134 | |
buf4 = empty_strided_cuda((8192, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [] | |
extern_kernels.mm(buf2, buf3, out=buf4) | |
del buf2 | |
del buf3 | |
buf5 = empty_strided_cuda((30522, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf1, (30522, 8192), (1, 30528), 0), view_266, out=buf5) | |
del view_266 | |
buf8 = empty_strided_cuda((30522, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_red_fused__to_copy_sum_3.run(buf1, buf8, 30522, 8192, grid=grid(30522), stream=stream0) | |
del buf1 | |
buf7 = empty_strided_cuda((30522, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_4.run(buf5, buf7, 23440896, grid=grid(23440896), stream=stream0) | |
del buf5 | |
buf16 = empty_strided_cuda((8192, 768), (768, 1), torch.float16) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view] | |
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.run(buf4, primals_200, addmm_72, getitem_51, rsqrt_25, buf16, 8192, 768, grid=grid(8192), stream=stream0) | |
del primals_200 | |
buf11 = empty_strided_cuda((768, 64), (1, 768), torch.float32) | |
buf13 = empty_strided_cuda((768, 64), (1, 768), torch.float32) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.run(buf4, addmm_72, getitem_51, rsqrt_25, buf11, buf13, 49152, 128, grid=grid(49152), stream=stream0) | |
del addmm_72 | |
del getitem_51 | |
del rsqrt_25 | |
buf12 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf11, buf12, 768, 64, grid=grid(768), stream=stream0) | |
buf14 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf13, buf14, 768, 64, grid=grid(768), stream=stream0) | |
buf17 = buf4; del buf4 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(buf16, permute_138, out=buf17) | |
del permute_138 | |
buf18 = empty_strided_cuda((768, 768), (768, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf16, (768, 8192), (1, 768), 0), view_264, out=buf18) | |
del view_264 | |
buf19 = reinterpret_tensor(buf13, (1, 768, 64), (49152, 1, 768), 0); del buf13 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_8.run(buf16, buf19, 49152, 128, grid=grid(49152), stream=stream0) | |
buf22 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf19, buf22, 768, 64, grid=grid(768), stream=stream0) | |
buf21 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf18, buf21, 589824, grid=grid(589824), stream=stream0) | |
buf25 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32) | |
buf30 = reinterpret_tensor(buf16, (16, 512, 768), (393216, 768, 1), 0); del buf16 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.run(buf17, primals_196, mul_159, div_27, gt_36, buf25, buf30, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_27 | |
del gt_36 | |
del primals_196 | |
buf26 = reinterpret_tensor(buf19, (768, 64), (1, 768), 0); del buf19 # reuse | |
buf28 = buf11; del buf11 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_native_layer_norm_backward_12.run(buf17, mul_159, buf26, buf28, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_159 | |
buf27 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf26, buf27, 768, 64, grid=grid(768), stream=stream0) | |
buf29 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf28, buf29, 768, 64, grid=grid(768), stream=stream0) | |
buf31 = empty_strided_cuda((8192, 3072), (3072, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf30, (8192, 768), (768, 1), 0), permute_142, out=buf31) | |
del permute_142 | |
buf32 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16) | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf30, (768, 8192), (1, 768), 0), view_262, out=buf32) | |
del view_262 | |
buf33 = reinterpret_tensor(buf28, (1, 768, 64), (49152, 1, 768), 0); del buf28 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf30, buf33, 49152, 128, grid=grid(49152), stream=stream0) | |
buf36 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf33, buf36, 768, 64, grid=grid(768), stream=stream0) | |
buf35 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf32, buf35, 2359296, grid=grid(2359296), stream=stream0) | |
buf37 = reinterpret_tensor(buf31, (16, 512, 3072), (1572864, 3072, 1), 0); del buf31 # reuse | |
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf37, addmm_70, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_70 | |
buf38 = reinterpret_tensor(buf30, (8192, 768), (768, 1), 0); del buf30 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf37, (8192, 3072), (3072, 1), 0), permute_146, out=buf38) | |
del permute_146 | |
buf39 = reinterpret_tensor(buf32, (3072, 768), (768, 1), 0); del buf32 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf37, (3072, 8192), (1, 3072), 0), view_260, out=buf39) | |
del view_260 | |
buf40 = empty_strided_cuda((1, 3072, 32), (98304, 1, 3072), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf37, buf40, 98304, 256, grid=grid(98304), stream=stream0) | |
buf43 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf40, buf43, 3072, 32, grid=grid(3072), stream=stream0) | |
buf42 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf39, buf42, 2359296, grid=grid(2359296), stream=stream0) | |
buf46 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32) | |
buf51 = reinterpret_tensor(buf17, (16, 512, 768), (393216, 768, 1), 0); del buf17 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf25, buf38, primals_190, mul_152, div_28, gt_35, buf46, buf51, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_28 | |
del gt_35 | |
del primals_190 | |
buf47 = reinterpret_tensor(buf33, (768, 64), (1, 768), 0); del buf33 # reuse | |
buf49 = buf26; del buf26 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf25, buf38, mul_152, buf47, buf49, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_152 | |
buf48 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf47, buf48, 768, 64, grid=grid(768), stream=stream0) | |
buf50 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf49, buf50, 768, 64, grid=grid(768), stream=stream0) | |
buf52 = buf38; del buf38 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf51, (8192, 768), (768, 1), 0), permute_150, out=buf52) | |
del permute_150 | |
buf53 = buf18; del buf18 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf51, (768, 8192), (1, 768), 0), view_258, out=buf53) | |
del view_258 | |
buf54 = reinterpret_tensor(buf49, (1, 768, 64), (49152, 1, 768), 0); del buf49 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf51, buf54, 49152, 128, grid=grid(49152), stream=stream0) | |
buf57 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf54, buf57, 768, 64, grid=grid(768), stream=stream0) | |
buf56 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf53, buf56, 589824, grid=grid(589824), stream=stream0) | |
buf58 = reinterpret_tensor(buf51, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf51 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf52, buf58, 6291456, grid=grid(6291456), stream=stream0) | |
del buf52 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf59 = aten._scaled_dot_product_flash_attention_backward.default(buf58, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, None, None, 512, 512, 0.1, False, getitem_54, getitem_55, scale=0.125) | |
del getitem_52 | |
del getitem_53 | |
del getitem_54 | |
del getitem_55 | |
del permute_default | |
del permute_default_1 | |
del permute_default_2 | |
buf60 = buf59[0] | |
buf61 = buf59[1] | |
buf62 = buf59[2] | |
del buf59 | |
buf63 = reinterpret_tensor(buf58, (8192, 768), (768, 1), 0); del buf58 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf62, (8192, 768), (768, 1), 0), permute_162, out=buf63) | |
del permute_162 | |
buf64 = buf53; del buf53 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf62, (768, 8192), (1, 768), 0), view_242, out=buf64) | |
buf65 = buf54; del buf54 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf62, buf65, 49152, 128, grid=grid(49152), stream=stream0) | |
buf68 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf65, buf68, 768, 64, grid=grid(768), stream=stream0) | |
buf67 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf64, buf67, 589824, grid=grid(589824), stream=stream0) | |
buf69 = reinterpret_tensor(buf62, (8192, 768), (768, 1), 0); del buf62 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf61, (8192, 768), (768, 1), 0), permute_167, out=buf69) | |
del permute_167 | |
buf70 = buf64; del buf64 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf61, (768, 8192), (1, 768), 0), view_242, out=buf70) | |
buf71 = buf65; del buf65 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf61, buf71, 49152, 128, grid=grid(49152), stream=stream0) | |
buf74 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf71, buf74, 768, 64, grid=grid(768), stream=stream0) | |
buf73 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf70, buf73, 589824, grid=grid(589824), stream=stream0) | |
buf75 = reinterpret_tensor(buf61, (8192, 768), (768, 1), 0); del buf61 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf60, (8192, 768), (768, 1), 0), permute_171, out=buf75) | |
del permute_171 | |
buf76 = buf70; del buf70 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf60, (768, 8192), (1, 768), 0), view_242, out=buf76) | |
del view_242 | |
buf77 = buf71; del buf71 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf60, buf77, 49152, 128, grid=grid(49152), stream=stream0) | |
buf80 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf77, buf80, 768, 64, grid=grid(768), stream=stream0) | |
buf79 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf76, buf79, 589824, grid=grid(589824), stream=stream0) | |
buf84 = buf25; del buf25 # reuse | |
buf89 = reinterpret_tensor(buf60, (16, 512, 768), (393216, 768, 1), 0); del buf60 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf46, buf63, buf69, buf75, primals_180, mul_146, div_30, gt_33, buf84, buf89, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_30 | |
del gt_33 | |
del primals_180 | |
buf85 = reinterpret_tensor(buf77, (768, 64), (1, 768), 0); del buf77 # reuse | |
buf87 = buf47; del buf47 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf46, buf63, buf69, buf75, mul_146, buf85, buf87, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf63 | |
del buf69 | |
del mul_146 | |
buf86 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf85, buf86, 768, 64, grid=grid(768), stream=stream0) | |
buf88 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf87, buf88, 768, 64, grid=grid(768), stream=stream0) | |
buf90 = reinterpret_tensor(buf37, (8192, 3072), (3072, 1), 0); del buf37 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf89, (8192, 768), (768, 1), 0), permute_175, out=buf90) | |
del permute_175 | |
buf91 = reinterpret_tensor(buf39, (768, 3072), (3072, 1), 0); del buf39 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf89, (768, 8192), (1, 768), 0), view_240, out=buf91) | |
del view_240 | |
buf92 = reinterpret_tensor(buf87, (1, 768, 64), (49152, 1, 768), 0); del buf87 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf89, buf92, 49152, 128, grid=grid(49152), stream=stream0) | |
buf95 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf92, buf95, 768, 64, grid=grid(768), stream=stream0) | |
buf94 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf91, buf94, 2359296, grid=grid(2359296), stream=stream0) | |
buf96 = reinterpret_tensor(buf90, (16, 512, 3072), (1572864, 3072, 1), 0); del buf90 # reuse | |
# Source Nodes: [hidden_states_84], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf96, addmm_64, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_64 | |
buf97 = reinterpret_tensor(buf89, (8192, 768), (768, 1), 0); del buf89 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf96, (8192, 3072), (3072, 1), 0), permute_179, out=buf97) | |
del permute_179 | |
buf98 = reinterpret_tensor(buf91, (3072, 768), (768, 1), 0); del buf91 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf96, (3072, 8192), (1, 3072), 0), view_238, out=buf98) | |
del view_238 | |
buf99 = buf40; del buf40 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf96, buf99, 98304, 256, grid=grid(98304), stream=stream0) | |
buf102 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf99, buf102, 3072, 32, grid=grid(3072), stream=stream0) | |
buf101 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf98, buf101, 2359296, grid=grid(2359296), stream=stream0) | |
buf105 = buf46; del buf46 # reuse | |
buf110 = reinterpret_tensor(buf75, (16, 512, 768), (393216, 768, 1), 0); del buf75 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf84, buf97, primals_174, mul_139, div_31, gt_32, buf105, buf110, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_31 | |
del gt_32 | |
del primals_174 | |
buf106 = reinterpret_tensor(buf92, (768, 64), (1, 768), 0); del buf92 # reuse | |
buf108 = buf85; del buf85 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf84, buf97, mul_139, buf106, buf108, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_139 | |
buf107 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf106, buf107, 768, 64, grid=grid(768), stream=stream0) | |
buf109 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf108, buf109, 768, 64, grid=grid(768), stream=stream0) | |
buf111 = buf97; del buf97 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf110, (8192, 768), (768, 1), 0), permute_183, out=buf111) | |
del permute_183 | |
buf112 = buf76; del buf76 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf110, (768, 8192), (1, 768), 0), view_236, out=buf112) | |
del view_236 | |
buf113 = reinterpret_tensor(buf108, (1, 768, 64), (49152, 1, 768), 0); del buf108 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf110, buf113, 49152, 128, grid=grid(49152), stream=stream0) | |
buf116 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf113, buf116, 768, 64, grid=grid(768), stream=stream0) | |
buf115 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf112, buf115, 589824, grid=grid(589824), stream=stream0) | |
buf117 = reinterpret_tensor(buf110, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf110 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf111, buf117, 6291456, grid=grid(6291456), stream=stream0) | |
del buf111 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf118 = aten._scaled_dot_product_flash_attention_backward.default(buf117, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, None, None, 512, 512, 0.1, False, getitem_61, getitem_62, scale=0.125) | |
del getitem_59 | |
del getitem_60 | |
del getitem_61 | |
del getitem_62 | |
del permute_default_6 | |
del permute_default_7 | |
del permute_default_8 | |
buf119 = buf118[0] | |
buf120 = buf118[1] | |
buf121 = buf118[2] | |
del buf118 | |
buf122 = reinterpret_tensor(buf117, (8192, 768), (768, 1), 0); del buf117 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf121, (8192, 768), (768, 1), 0), permute_195, out=buf122) | |
del permute_195 | |
buf123 = buf112; del buf112 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf121, (768, 8192), (1, 768), 0), view_220, out=buf123) | |
buf124 = buf113; del buf113 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf121, buf124, 49152, 128, grid=grid(49152), stream=stream0) | |
buf127 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf124, buf127, 768, 64, grid=grid(768), stream=stream0) | |
buf126 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf123, buf126, 589824, grid=grid(589824), stream=stream0) | |
buf128 = reinterpret_tensor(buf121, (8192, 768), (768, 1), 0); del buf121 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf120, (8192, 768), (768, 1), 0), permute_200, out=buf128) | |
del permute_200 | |
buf129 = buf123; del buf123 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf120, (768, 8192), (1, 768), 0), view_220, out=buf129) | |
buf130 = buf124; del buf124 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf120, buf130, 49152, 128, grid=grid(49152), stream=stream0) | |
buf133 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf130, buf133, 768, 64, grid=grid(768), stream=stream0) | |
buf132 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf129, buf132, 589824, grid=grid(589824), stream=stream0) | |
buf134 = reinterpret_tensor(buf120, (8192, 768), (768, 1), 0); del buf120 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf119, (8192, 768), (768, 1), 0), permute_204, out=buf134) | |
del permute_204 | |
buf135 = buf129; del buf129 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf119, (768, 8192), (1, 768), 0), view_220, out=buf135) | |
del view_220 | |
buf136 = buf130; del buf130 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf119, buf136, 49152, 128, grid=grid(49152), stream=stream0) | |
buf139 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf136, buf139, 768, 64, grid=grid(768), stream=stream0) | |
buf138 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf135, buf138, 589824, grid=grid(589824), stream=stream0) | |
buf143 = buf84; del buf84 # reuse | |
buf148 = reinterpret_tensor(buf119, (16, 512, 768), (393216, 768, 1), 0); del buf119 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf105, buf122, buf128, buf134, primals_164, mul_133, div_33, gt_30, buf143, buf148, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_33 | |
del gt_30 | |
del primals_164 | |
buf144 = reinterpret_tensor(buf136, (768, 64), (1, 768), 0); del buf136 # reuse | |
buf146 = buf106; del buf106 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf105, buf122, buf128, buf134, mul_133, buf144, buf146, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf122 | |
del buf128 | |
del mul_133 | |
buf145 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf144, buf145, 768, 64, grid=grid(768), stream=stream0) | |
buf147 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf146, buf147, 768, 64, grid=grid(768), stream=stream0) | |
buf149 = reinterpret_tensor(buf96, (8192, 3072), (3072, 1), 0); del buf96 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf148, (8192, 768), (768, 1), 0), permute_208, out=buf149) | |
del permute_208 | |
buf150 = reinterpret_tensor(buf98, (768, 3072), (3072, 1), 0); del buf98 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf148, (768, 8192), (1, 768), 0), view_218, out=buf150) | |
del view_218 | |
buf151 = reinterpret_tensor(buf146, (1, 768, 64), (49152, 1, 768), 0); del buf146 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf148, buf151, 49152, 128, grid=grid(49152), stream=stream0) | |
buf154 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf151, buf154, 768, 64, grid=grid(768), stream=stream0) | |
buf153 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf150, buf153, 2359296, grid=grid(2359296), stream=stream0) | |
buf155 = reinterpret_tensor(buf149, (16, 512, 3072), (1572864, 3072, 1), 0); del buf149 # reuse | |
# Source Nodes: [hidden_states_76], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf155, addmm_58, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_58 | |
buf156 = reinterpret_tensor(buf148, (8192, 768), (768, 1), 0); del buf148 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf155, (8192, 3072), (3072, 1), 0), permute_212, out=buf156) | |
del permute_212 | |
buf157 = reinterpret_tensor(buf150, (3072, 768), (768, 1), 0); del buf150 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf155, (3072, 8192), (1, 3072), 0), view_216, out=buf157) | |
del view_216 | |
buf158 = buf99; del buf99 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf155, buf158, 98304, 256, grid=grid(98304), stream=stream0) | |
buf161 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf158, buf161, 3072, 32, grid=grid(3072), stream=stream0) | |
buf160 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf157, buf160, 2359296, grid=grid(2359296), stream=stream0) | |
buf164 = buf105; del buf105 # reuse | |
buf169 = reinterpret_tensor(buf134, (16, 512, 768), (393216, 768, 1), 0); del buf134 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf143, buf156, primals_158, mul_126, div_34, gt_29, buf164, buf169, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_34 | |
del gt_29 | |
del primals_158 | |
buf165 = reinterpret_tensor(buf151, (768, 64), (1, 768), 0); del buf151 # reuse | |
buf167 = buf144; del buf144 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf143, buf156, mul_126, buf165, buf167, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_126 | |
buf166 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf165, buf166, 768, 64, grid=grid(768), stream=stream0) | |
buf168 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf167, buf168, 768, 64, grid=grid(768), stream=stream0) | |
buf170 = buf156; del buf156 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf169, (8192, 768), (768, 1), 0), permute_216, out=buf170) | |
del permute_216 | |
buf171 = buf135; del buf135 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf169, (768, 8192), (1, 768), 0), view_214, out=buf171) | |
del view_214 | |
buf172 = reinterpret_tensor(buf167, (1, 768, 64), (49152, 1, 768), 0); del buf167 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf169, buf172, 49152, 128, grid=grid(49152), stream=stream0) | |
buf175 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf172, buf175, 768, 64, grid=grid(768), stream=stream0) | |
buf174 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf171, buf174, 589824, grid=grid(589824), stream=stream0) | |
buf176 = reinterpret_tensor(buf169, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf169 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf170, buf176, 6291456, grid=grid(6291456), stream=stream0) | |
del buf170 | |
# Source Nodes: [], Original ATen: [aten.clone]; NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf177 = aten._scaled_dot_product_flash_attention_backward.default(buf176, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, None, None, 512, 512, 0.1, False, getitem_68, getitem_69, scale=0.125) | |
del getitem_66 | |
del getitem_67 | |
del getitem_68 | |
del getitem_69 | |
del permute_default_12 | |
del permute_default_13 | |
del permute_default_14 | |
buf178 = buf177[0] | |
buf179 = buf177[1] | |
buf180 = buf177[2] | |
del buf177 | |
buf181 = reinterpret_tensor(buf176, (8192, 768), (768, 1), 0); del buf176 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf180, (8192, 768), (768, 1), 0), permute_228, out=buf181) | |
del permute_228 | |
buf182 = buf171; del buf171 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf180, (768, 8192), (1, 768), 0), view_198, out=buf182) | |
buf183 = buf172; del buf172 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf180, buf183, 49152, 128, grid=grid(49152), stream=stream0) | |
buf186 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf183, buf186, 768, 64, grid=grid(768), stream=stream0) | |
buf185 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf182, buf185, 589824, grid=grid(589824), stream=stream0) | |
buf187 = reinterpret_tensor(buf180, (8192, 768), (768, 1), 0); del buf180 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf179, (8192, 768), (768, 1), 0), permute_233, out=buf187) | |
del permute_233 | |
buf188 = buf182; del buf182 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf179, (768, 8192), (1, 768), 0), view_198, out=buf188) | |
buf189 = buf183; del buf183 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf179, buf189, 49152, 128, grid=grid(49152), stream=stream0) | |
buf192 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf189, buf192, 768, 64, grid=grid(768), stream=stream0) | |
buf191 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf188, buf191, 589824, grid=grid(589824), stream=stream0) | |
buf193 = reinterpret_tensor(buf179, (8192, 768), (768, 1), 0); del buf179 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf178, (8192, 768), (768, 1), 0), permute_237, out=buf193) | |
del permute_237 | |
buf194 = buf188; del buf188 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf178, (768, 8192), (1, 768), 0), view_198, out=buf194) | |
del view_198 | |
buf195 = buf189; del buf189 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf178, buf195, 49152, 128, grid=grid(49152), stream=stream0) | |
buf198 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf195, buf198, 768, 64, grid=grid(768), stream=stream0) | |
buf197 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf194, buf197, 589824, grid=grid(589824), stream=stream0) | |
buf202 = buf143; del buf143 # reuse | |
buf207 = reinterpret_tensor(buf178, (16, 512, 768), (393216, 768, 1), 0); del buf178 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf164, buf181, buf187, buf193, primals_148, mul_120, div_36, gt_27, buf202, buf207, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_36 | |
del gt_27 | |
del primals_148 | |
buf203 = reinterpret_tensor(buf195, (768, 64), (1, 768), 0); del buf195 # reuse | |
buf205 = buf165; del buf165 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf164, buf181, buf187, buf193, mul_120, buf203, buf205, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf181 | |
del buf187 | |
del mul_120 | |
buf204 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf203, buf204, 768, 64, grid=grid(768), stream=stream0) | |
buf206 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf205, buf206, 768, 64, grid=grid(768), stream=stream0) | |
buf208 = reinterpret_tensor(buf155, (8192, 3072), (3072, 1), 0); del buf155 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf207, (8192, 768), (768, 1), 0), permute_241, out=buf208) | |
del permute_241 | |
buf209 = reinterpret_tensor(buf157, (768, 3072), (3072, 1), 0); del buf157 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf207, (768, 8192), (1, 768), 0), view_196, out=buf209) | |
del view_196 | |
buf210 = reinterpret_tensor(buf205, (1, 768, 64), (49152, 1, 768), 0); del buf205 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf207, buf210, 49152, 128, grid=grid(49152), stream=stream0) | |
buf213 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf210, buf213, 768, 64, grid=grid(768), stream=stream0) | |
buf212 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf209, buf212, 2359296, grid=grid(2359296), stream=stream0) | |
buf214 = reinterpret_tensor(buf208, (16, 512, 3072), (1572864, 3072, 1), 0); del buf208 # reuse | |
# Source Nodes: [hidden_states_68], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf214, addmm_52, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_52 | |
buf215 = reinterpret_tensor(buf207, (8192, 768), (768, 1), 0); del buf207 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf214, (8192, 3072), (3072, 1), 0), permute_245, out=buf215) | |
del permute_245 | |
buf216 = reinterpret_tensor(buf209, (3072, 768), (768, 1), 0); del buf209 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf214, (3072, 8192), (1, 3072), 0), view_194, out=buf216) | |
del view_194 | |
buf217 = buf158; del buf158 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf214, buf217, 98304, 256, grid=grid(98304), stream=stream0) | |
buf220 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf217, buf220, 3072, 32, grid=grid(3072), stream=stream0) | |
buf219 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf216, buf219, 2359296, grid=grid(2359296), stream=stream0) | |
buf223 = buf164; del buf164 # reuse | |
buf228 = reinterpret_tensor(buf193, (16, 512, 768), (393216, 768, 1), 0); del buf193 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf202, buf215, primals_142, mul_113, div_37, gt_26, buf223, buf228, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_37 | |
del gt_26 | |
del primals_142 | |
buf224 = reinterpret_tensor(buf210, (768, 64), (1, 768), 0); del buf210 # reuse | |
buf226 = buf203; del buf203 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf202, buf215, mul_113, buf224, buf226, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_113 | |
buf225 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf224, buf225, 768, 64, grid=grid(768), stream=stream0) | |
buf227 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf226, buf227, 768, 64, grid=grid(768), stream=stream0) | |
buf229 = buf215; del buf215 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf228, (8192, 768), (768, 1), 0), permute_249, out=buf229) | |
del permute_249 | |
buf230 = buf194; del buf194 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf228, (768, 8192), (1, 768), 0), view_192, out=buf230) | |
del view_192 | |
buf231 = reinterpret_tensor(buf226, (1, 768, 64), (49152, 1, 768), 0); del buf226 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf228, buf231, 49152, 128, grid=grid(49152), stream=stream0) | |
buf234 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf231, buf234, 768, 64, grid=grid(768), stream=stream0) | |
buf233 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf230, buf233, 589824, grid=grid(589824), stream=stream0) | |
buf235 = reinterpret_tensor(buf228, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf228 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf229, buf235, 6291456, grid=grid(6291456), stream=stream0) | |
del buf229 | |
# Source Nodes: [], Original ATen: [aten.clone]; NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf236 = aten._scaled_dot_product_flash_attention_backward.default(buf235, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, None, None, 512, 512, 0.1, False, getitem_75, getitem_76, scale=0.125) | |
del getitem_73 | |
del getitem_74 | |
del getitem_75 | |
del getitem_76 | |
del permute_default_18 | |
del permute_default_19 | |
del permute_default_20 | |
buf237 = buf236[0] | |
buf238 = buf236[1] | |
buf239 = buf236[2] | |
del buf236 | |
buf240 = reinterpret_tensor(buf235, (8192, 768), (768, 1), 0); del buf235 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf239, (8192, 768), (768, 1), 0), permute_261, out=buf240) | |
del permute_261 | |
buf241 = buf230; del buf230 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf239, (768, 8192), (1, 768), 0), view_176, out=buf241) | |
buf242 = buf231; del buf231 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf239, buf242, 49152, 128, grid=grid(49152), stream=stream0) | |
buf245 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf242, buf245, 768, 64, grid=grid(768), stream=stream0) | |
buf244 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf241, buf244, 589824, grid=grid(589824), stream=stream0) | |
buf246 = reinterpret_tensor(buf239, (8192, 768), (768, 1), 0); del buf239 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf238, (8192, 768), (768, 1), 0), permute_266, out=buf246) | |
del permute_266 | |
buf247 = buf241; del buf241 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf238, (768, 8192), (1, 768), 0), view_176, out=buf247) | |
buf248 = buf242; del buf242 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf238, buf248, 49152, 128, grid=grid(49152), stream=stream0) | |
buf251 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf248, buf251, 768, 64, grid=grid(768), stream=stream0) | |
buf250 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf247, buf250, 589824, grid=grid(589824), stream=stream0) | |
buf252 = reinterpret_tensor(buf238, (8192, 768), (768, 1), 0); del buf238 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf237, (8192, 768), (768, 1), 0), permute_270, out=buf252) | |
del permute_270 | |
buf253 = buf247; del buf247 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf237, (768, 8192), (1, 768), 0), view_176, out=buf253) | |
del view_176 | |
buf254 = buf248; del buf248 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf237, buf254, 49152, 128, grid=grid(49152), stream=stream0) | |
buf257 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf254, buf257, 768, 64, grid=grid(768), stream=stream0) | |
buf256 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf253, buf256, 589824, grid=grid(589824), stream=stream0) | |
buf261 = buf202; del buf202 # reuse | |
buf266 = reinterpret_tensor(buf237, (16, 512, 768), (393216, 768, 1), 0); del buf237 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf223, buf240, buf246, buf252, primals_132, mul_107, div_39, gt_24, buf261, buf266, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_39 | |
del gt_24 | |
del primals_132 | |
buf262 = reinterpret_tensor(buf254, (768, 64), (1, 768), 0); del buf254 # reuse | |
buf264 = buf224; del buf224 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf223, buf240, buf246, buf252, mul_107, buf262, buf264, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf240 | |
del buf246 | |
del mul_107 | |
buf263 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf262, buf263, 768, 64, grid=grid(768), stream=stream0) | |
buf265 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf264, buf265, 768, 64, grid=grid(768), stream=stream0) | |
buf267 = reinterpret_tensor(buf214, (8192, 3072), (3072, 1), 0); del buf214 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf266, (8192, 768), (768, 1), 0), permute_274, out=buf267) | |
del permute_274 | |
buf268 = reinterpret_tensor(buf216, (768, 3072), (3072, 1), 0); del buf216 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf266, (768, 8192), (1, 768), 0), view_174, out=buf268) | |
del view_174 | |
buf269 = reinterpret_tensor(buf264, (1, 768, 64), (49152, 1, 768), 0); del buf264 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf266, buf269, 49152, 128, grid=grid(49152), stream=stream0) | |
buf272 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf269, buf272, 768, 64, grid=grid(768), stream=stream0) | |
buf271 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf268, buf271, 2359296, grid=grid(2359296), stream=stream0) | |
buf273 = reinterpret_tensor(buf267, (16, 512, 3072), (1572864, 3072, 1), 0); del buf267 # reuse | |
# Source Nodes: [hidden_states_60], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf273, addmm_46, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_46 | |
buf274 = reinterpret_tensor(buf266, (8192, 768), (768, 1), 0); del buf266 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf273, (8192, 3072), (3072, 1), 0), permute_278, out=buf274) | |
del permute_278 | |
buf275 = reinterpret_tensor(buf268, (3072, 768), (768, 1), 0); del buf268 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf273, (3072, 8192), (1, 3072), 0), view_172, out=buf275) | |
del view_172 | |
buf276 = buf217; del buf217 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf273, buf276, 98304, 256, grid=grid(98304), stream=stream0) | |
buf279 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf276, buf279, 3072, 32, grid=grid(3072), stream=stream0) | |
buf278 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf275, buf278, 2359296, grid=grid(2359296), stream=stream0) | |
buf282 = buf223; del buf223 # reuse | |
buf287 = reinterpret_tensor(buf252, (16, 512, 768), (393216, 768, 1), 0); del buf252 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf261, buf274, primals_126, mul_100, div_40, gt_23, buf282, buf287, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_40 | |
del gt_23 | |
del primals_126 | |
buf283 = reinterpret_tensor(buf269, (768, 64), (1, 768), 0); del buf269 # reuse | |
buf285 = buf262; del buf262 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf261, buf274, mul_100, buf283, buf285, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_100 | |
buf284 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf283, buf284, 768, 64, grid=grid(768), stream=stream0) | |
buf286 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf285, buf286, 768, 64, grid=grid(768), stream=stream0) | |
buf288 = buf274; del buf274 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf287, (8192, 768), (768, 1), 0), permute_282, out=buf288) | |
del permute_282 | |
buf289 = buf253; del buf253 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf287, (768, 8192), (1, 768), 0), view_170, out=buf289) | |
del view_170 | |
buf290 = reinterpret_tensor(buf285, (1, 768, 64), (49152, 1, 768), 0); del buf285 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf287, buf290, 49152, 128, grid=grid(49152), stream=stream0) | |
buf293 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf290, buf293, 768, 64, grid=grid(768), stream=stream0) | |
buf292 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf289, buf292, 589824, grid=grid(589824), stream=stream0) | |
buf294 = reinterpret_tensor(buf287, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf287 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf288, buf294, 6291456, grid=grid(6291456), stream=stream0) | |
del buf288 | |
# Source Nodes: [], Original ATen: [aten.clone]; NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf295 = aten._scaled_dot_product_flash_attention_backward.default(buf294, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, None, None, 512, 512, 0.1, False, getitem_82, getitem_83, scale=0.125) | |
del getitem_80 | |
del getitem_81 | |
del getitem_82 | |
del getitem_83 | |
del permute_default_24 | |
del permute_default_25 | |
del permute_default_26 | |
buf296 = buf295[0] | |
buf297 = buf295[1] | |
buf298 = buf295[2] | |
del buf295 | |
buf299 = reinterpret_tensor(buf294, (8192, 768), (768, 1), 0); del buf294 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf298, (8192, 768), (768, 1), 0), permute_294, out=buf299) | |
del permute_294 | |
buf300 = buf289; del buf289 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf298, (768, 8192), (1, 768), 0), view_154, out=buf300) | |
buf301 = buf290; del buf290 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf298, buf301, 49152, 128, grid=grid(49152), stream=stream0) | |
buf304 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf301, buf304, 768, 64, grid=grid(768), stream=stream0) | |
buf303 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf300, buf303, 589824, grid=grid(589824), stream=stream0) | |
buf305 = reinterpret_tensor(buf298, (8192, 768), (768, 1), 0); del buf298 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf297, (8192, 768), (768, 1), 0), permute_299, out=buf305) | |
del permute_299 | |
buf306 = buf300; del buf300 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf297, (768, 8192), (1, 768), 0), view_154, out=buf306) | |
buf307 = buf301; del buf301 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf297, buf307, 49152, 128, grid=grid(49152), stream=stream0) | |
buf310 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf307, buf310, 768, 64, grid=grid(768), stream=stream0) | |
buf309 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf306, buf309, 589824, grid=grid(589824), stream=stream0) | |
buf311 = reinterpret_tensor(buf297, (8192, 768), (768, 1), 0); del buf297 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf296, (8192, 768), (768, 1), 0), permute_303, out=buf311) | |
del permute_303 | |
buf312 = buf306; del buf306 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf296, (768, 8192), (1, 768), 0), view_154, out=buf312) | |
del view_154 | |
buf313 = buf307; del buf307 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf296, buf313, 49152, 128, grid=grid(49152), stream=stream0) | |
buf316 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf313, buf316, 768, 64, grid=grid(768), stream=stream0) | |
buf315 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf312, buf315, 589824, grid=grid(589824), stream=stream0) | |
buf320 = buf261; del buf261 # reuse | |
buf325 = reinterpret_tensor(buf296, (16, 512, 768), (393216, 768, 1), 0); del buf296 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf282, buf299, buf305, buf311, primals_116, mul_94, div_42, gt_21, buf320, buf325, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_42 | |
del gt_21 | |
del primals_116 | |
buf321 = reinterpret_tensor(buf313, (768, 64), (1, 768), 0); del buf313 # reuse | |
buf323 = buf283; del buf283 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf282, buf299, buf305, buf311, mul_94, buf321, buf323, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf299 | |
del buf305 | |
del mul_94 | |
buf322 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf321, buf322, 768, 64, grid=grid(768), stream=stream0) | |
buf324 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf323, buf324, 768, 64, grid=grid(768), stream=stream0) | |
buf326 = reinterpret_tensor(buf273, (8192, 3072), (3072, 1), 0); del buf273 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf325, (8192, 768), (768, 1), 0), permute_307, out=buf326) | |
del permute_307 | |
buf327 = reinterpret_tensor(buf275, (768, 3072), (3072, 1), 0); del buf275 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf325, (768, 8192), (1, 768), 0), view_152, out=buf327) | |
del view_152 | |
buf328 = reinterpret_tensor(buf323, (1, 768, 64), (49152, 1, 768), 0); del buf323 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf325, buf328, 49152, 128, grid=grid(49152), stream=stream0) | |
buf331 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf328, buf331, 768, 64, grid=grid(768), stream=stream0) | |
buf330 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf327, buf330, 2359296, grid=grid(2359296), stream=stream0) | |
buf332 = reinterpret_tensor(buf326, (16, 512, 3072), (1572864, 3072, 1), 0); del buf326 # reuse | |
# Source Nodes: [hidden_states_52], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf332, addmm_40, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_40 | |
buf333 = reinterpret_tensor(buf325, (8192, 768), (768, 1), 0); del buf325 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf332, (8192, 3072), (3072, 1), 0), permute_311, out=buf333) | |
del permute_311 | |
buf334 = reinterpret_tensor(buf327, (3072, 768), (768, 1), 0); del buf327 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf332, (3072, 8192), (1, 3072), 0), view_150, out=buf334) | |
del view_150 | |
buf335 = buf276; del buf276 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf332, buf335, 98304, 256, grid=grid(98304), stream=stream0) | |
buf338 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf335, buf338, 3072, 32, grid=grid(3072), stream=stream0) | |
buf337 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf334, buf337, 2359296, grid=grid(2359296), stream=stream0) | |
buf341 = buf282; del buf282 # reuse | |
buf346 = reinterpret_tensor(buf311, (16, 512, 768), (393216, 768, 1), 0); del buf311 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf320, buf333, primals_110, mul_87, div_43, gt_20, buf341, buf346, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_43 | |
del gt_20 | |
del primals_110 | |
buf342 = reinterpret_tensor(buf328, (768, 64), (1, 768), 0); del buf328 # reuse | |
buf344 = buf321; del buf321 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf320, buf333, mul_87, buf342, buf344, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_87 | |
buf343 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf342, buf343, 768, 64, grid=grid(768), stream=stream0) | |
buf345 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf344, buf345, 768, 64, grid=grid(768), stream=stream0) | |
buf347 = buf333; del buf333 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf346, (8192, 768), (768, 1), 0), permute_315, out=buf347) | |
del permute_315 | |
buf348 = buf312; del buf312 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf346, (768, 8192), (1, 768), 0), view_148, out=buf348) | |
del view_148 | |
buf349 = reinterpret_tensor(buf344, (1, 768, 64), (49152, 1, 768), 0); del buf344 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf346, buf349, 49152, 128, grid=grid(49152), stream=stream0) | |
buf352 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf349, buf352, 768, 64, grid=grid(768), stream=stream0) | |
buf351 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf348, buf351, 589824, grid=grid(589824), stream=stream0) | |
buf353 = reinterpret_tensor(buf346, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf346 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf347, buf353, 6291456, grid=grid(6291456), stream=stream0) | |
del buf347 | |
# Source Nodes: [], Original ATen: [aten.clone]  NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf354 = aten._scaled_dot_product_flash_attention_backward.default(buf353, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, None, None, 512, 512, 0.1, False, getitem_89, getitem_90, scale=0.125) | |
del getitem_87 | |
del getitem_88 | |
del getitem_89 | |
del getitem_90 | |
del permute_default_30 | |
del permute_default_31 | |
del permute_default_32 | |
buf355 = buf354[0] | |
buf356 = buf354[1] | |
buf357 = buf354[2] | |
del buf354 | |
buf358 = reinterpret_tensor(buf353, (8192, 768), (768, 1), 0); del buf353 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf357, (8192, 768), (768, 1), 0), permute_327, out=buf358) | |
del permute_327 | |
buf359 = buf348; del buf348 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf357, (768, 8192), (1, 768), 0), view_132, out=buf359) | |
buf360 = buf349; del buf349 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf357, buf360, 49152, 128, grid=grid(49152), stream=stream0) | |
buf363 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf360, buf363, 768, 64, grid=grid(768), stream=stream0) | |
buf362 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf359, buf362, 589824, grid=grid(589824), stream=stream0) | |
buf364 = reinterpret_tensor(buf357, (8192, 768), (768, 1), 0); del buf357 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf356, (8192, 768), (768, 1), 0), permute_332, out=buf364) | |
del permute_332 | |
buf365 = buf359; del buf359 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf356, (768, 8192), (1, 768), 0), view_132, out=buf365) | |
buf366 = buf360; del buf360 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf356, buf366, 49152, 128, grid=grid(49152), stream=stream0) | |
buf369 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf366, buf369, 768, 64, grid=grid(768), stream=stream0) | |
buf368 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf365, buf368, 589824, grid=grid(589824), stream=stream0) | |
buf370 = reinterpret_tensor(buf356, (8192, 768), (768, 1), 0); del buf356 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf355, (8192, 768), (768, 1), 0), permute_336, out=buf370) | |
del permute_336 | |
buf371 = buf365; del buf365 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf355, (768, 8192), (1, 768), 0), view_132, out=buf371) | |
del view_132 | |
buf372 = buf366; del buf366 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf355, buf372, 49152, 128, grid=grid(49152), stream=stream0) | |
buf375 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf372, buf375, 768, 64, grid=grid(768), stream=stream0) | |
buf374 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf371, buf374, 589824, grid=grid(589824), stream=stream0) | |
buf379 = buf320; del buf320 # reuse | |
buf384 = reinterpret_tensor(buf355, (16, 512, 768), (393216, 768, 1), 0); del buf355 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf341, buf358, buf364, buf370, primals_100, mul_81, div_45, gt_18, buf379, buf384, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_45 | |
del gt_18 | |
del primals_100 | |
buf380 = reinterpret_tensor(buf372, (768, 64), (1, 768), 0); del buf372 # reuse | |
buf382 = buf342; del buf342 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf341, buf358, buf364, buf370, mul_81, buf380, buf382, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf358 | |
del buf364 | |
del mul_81 | |
buf381 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf380, buf381, 768, 64, grid=grid(768), stream=stream0) | |
buf383 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf382, buf383, 768, 64, grid=grid(768), stream=stream0) | |
buf385 = reinterpret_tensor(buf332, (8192, 3072), (3072, 1), 0); del buf332 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf384, (8192, 768), (768, 1), 0), permute_340, out=buf385) | |
del permute_340 | |
buf386 = reinterpret_tensor(buf334, (768, 3072), (3072, 1), 0); del buf334 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf384, (768, 8192), (1, 768), 0), view_130, out=buf386) | |
del view_130 | |
buf387 = reinterpret_tensor(buf382, (1, 768, 64), (49152, 1, 768), 0); del buf382 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf384, buf387, 49152, 128, grid=grid(49152), stream=stream0) | |
buf390 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf387, buf390, 768, 64, grid=grid(768), stream=stream0) | |
buf389 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf386, buf389, 2359296, grid=grid(2359296), stream=stream0) | |
buf391 = reinterpret_tensor(buf385, (16, 512, 3072), (1572864, 3072, 1), 0); del buf385 # reuse | |
# Source Nodes: [hidden_states_44], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf391, addmm_34, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_34 | |
buf392 = reinterpret_tensor(buf384, (8192, 768), (768, 1), 0); del buf384 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf391, (8192, 3072), (3072, 1), 0), permute_344, out=buf392) | |
del permute_344 | |
buf393 = reinterpret_tensor(buf386, (3072, 768), (768, 1), 0); del buf386 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf391, (3072, 8192), (1, 3072), 0), view_128, out=buf393) | |
del view_128 | |
buf394 = buf335; del buf335 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf391, buf394, 98304, 256, grid=grid(98304), stream=stream0) | |
buf397 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf394, buf397, 3072, 32, grid=grid(3072), stream=stream0) | |
buf396 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf393, buf396, 2359296, grid=grid(2359296), stream=stream0) | |
buf400 = buf341; del buf341 # reuse | |
buf405 = reinterpret_tensor(buf370, (16, 512, 768), (393216, 768, 1), 0); del buf370 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf379, buf392, primals_94, mul_74, div_46, gt_17, buf400, buf405, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_46 | |
del gt_17 | |
del primals_94 | |
buf401 = reinterpret_tensor(buf387, (768, 64), (1, 768), 0); del buf387 # reuse | |
buf403 = buf380; del buf380 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf379, buf392, mul_74, buf401, buf403, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_74 | |
buf402 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf401, buf402, 768, 64, grid=grid(768), stream=stream0) | |
buf404 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf403, buf404, 768, 64, grid=grid(768), stream=stream0) | |
buf406 = buf392; del buf392 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf405, (8192, 768), (768, 1), 0), permute_348, out=buf406) | |
del permute_348 | |
buf407 = buf371; del buf371 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf405, (768, 8192), (1, 768), 0), view_126, out=buf407) | |
del view_126 | |
buf408 = reinterpret_tensor(buf403, (1, 768, 64), (49152, 1, 768), 0); del buf403 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf405, buf408, 49152, 128, grid=grid(49152), stream=stream0) | |
buf411 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf408, buf411, 768, 64, grid=grid(768), stream=stream0) | |
buf410 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf407, buf410, 589824, grid=grid(589824), stream=stream0) | |
buf412 = reinterpret_tensor(buf405, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf405 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf406, buf412, 6291456, grid=grid(6291456), stream=stream0) | |
del buf406 | |
# Source Nodes: [], Original ATen: [aten.clone]  NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf413 = aten._scaled_dot_product_flash_attention_backward.default(buf412, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, None, None, 512, 512, 0.1, False, getitem_96, getitem_97, scale=0.125) | |
del getitem_94 | |
del getitem_95 | |
del getitem_96 | |
del getitem_97 | |
del permute_default_36 | |
del permute_default_37 | |
del permute_default_38 | |
buf414 = buf413[0] | |
buf415 = buf413[1] | |
buf416 = buf413[2] | |
del buf413 | |
buf417 = reinterpret_tensor(buf412, (8192, 768), (768, 1), 0); del buf412 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf416, (8192, 768), (768, 1), 0), permute_360, out=buf417) | |
del permute_360 | |
buf418 = buf407; del buf407 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf416, (768, 8192), (1, 768), 0), view_110, out=buf418) | |
buf419 = buf408; del buf408 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf416, buf419, 49152, 128, grid=grid(49152), stream=stream0) | |
buf422 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf419, buf422, 768, 64, grid=grid(768), stream=stream0) | |
buf421 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf418, buf421, 589824, grid=grid(589824), stream=stream0) | |
buf423 = reinterpret_tensor(buf416, (8192, 768), (768, 1), 0); del buf416 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf415, (8192, 768), (768, 1), 0), permute_365, out=buf423) | |
del permute_365 | |
buf424 = buf418; del buf418 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf415, (768, 8192), (1, 768), 0), view_110, out=buf424) | |
buf425 = buf419; del buf419 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf415, buf425, 49152, 128, grid=grid(49152), stream=stream0) | |
buf428 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf425, buf428, 768, 64, grid=grid(768), stream=stream0) | |
buf427 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf424, buf427, 589824, grid=grid(589824), stream=stream0) | |
buf429 = reinterpret_tensor(buf415, (8192, 768), (768, 1), 0); del buf415 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf414, (8192, 768), (768, 1), 0), permute_369, out=buf429) | |
del permute_369 | |
buf430 = buf424; del buf424 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf414, (768, 8192), (1, 768), 0), view_110, out=buf430) | |
del view_110 | |
buf431 = buf425; del buf425 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf414, buf431, 49152, 128, grid=grid(49152), stream=stream0) | |
buf434 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf431, buf434, 768, 64, grid=grid(768), stream=stream0) | |
buf433 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf430, buf433, 589824, grid=grid(589824), stream=stream0) | |
buf438 = buf379; del buf379 # reuse | |
buf443 = reinterpret_tensor(buf414, (16, 512, 768), (393216, 768, 1), 0); del buf414 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf400, buf417, buf423, buf429, primals_84, mul_68, div_48, gt_15, buf438, buf443, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_48 | |
del gt_15 | |
del primals_84 | |
buf439 = reinterpret_tensor(buf431, (768, 64), (1, 768), 0); del buf431 # reuse | |
buf441 = buf401; del buf401 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf400, buf417, buf423, buf429, mul_68, buf439, buf441, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf417 | |
del buf423 | |
del mul_68 | |
buf440 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf439, buf440, 768, 64, grid=grid(768), stream=stream0) | |
buf442 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf441, buf442, 768, 64, grid=grid(768), stream=stream0) | |
buf444 = reinterpret_tensor(buf391, (8192, 3072), (3072, 1), 0); del buf391 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf443, (8192, 768), (768, 1), 0), permute_373, out=buf444) | |
del permute_373 | |
buf445 = reinterpret_tensor(buf393, (768, 3072), (3072, 1), 0); del buf393 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf443, (768, 8192), (1, 768), 0), view_108, out=buf445) | |
del view_108 | |
buf446 = reinterpret_tensor(buf441, (1, 768, 64), (49152, 1, 768), 0); del buf441 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf443, buf446, 49152, 128, grid=grid(49152), stream=stream0) | |
buf449 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf446, buf449, 768, 64, grid=grid(768), stream=stream0) | |
buf448 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf445, buf448, 2359296, grid=grid(2359296), stream=stream0) | |
buf450 = reinterpret_tensor(buf444, (16, 512, 3072), (1572864, 3072, 1), 0); del buf444 # reuse | |
# Source Nodes: [hidden_states_36], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf450, addmm_28, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_28 | |
buf451 = reinterpret_tensor(buf443, (8192, 768), (768, 1), 0); del buf443 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf450, (8192, 3072), (3072, 1), 0), permute_377, out=buf451) | |
del permute_377 | |
buf452 = reinterpret_tensor(buf445, (3072, 768), (768, 1), 0); del buf445 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf450, (3072, 8192), (1, 3072), 0), view_106, out=buf452) | |
del view_106 | |
buf453 = buf394; del buf394 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf450, buf453, 98304, 256, grid=grid(98304), stream=stream0) | |
buf456 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf453, buf456, 3072, 32, grid=grid(3072), stream=stream0) | |
buf455 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf452, buf455, 2359296, grid=grid(2359296), stream=stream0) | |
buf459 = buf400; del buf400 # reuse | |
buf464 = reinterpret_tensor(buf429, (16, 512, 768), (393216, 768, 1), 0); del buf429 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf438, buf451, primals_78, mul_61, div_49, gt_14, buf459, buf464, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_49 | |
del gt_14 | |
del primals_78 | |
buf460 = reinterpret_tensor(buf446, (768, 64), (1, 768), 0); del buf446 # reuse | |
buf462 = buf439; del buf439 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf438, buf451, mul_61, buf460, buf462, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_61 | |
buf461 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf460, buf461, 768, 64, grid=grid(768), stream=stream0) | |
buf463 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf462, buf463, 768, 64, grid=grid(768), stream=stream0) | |
buf465 = buf451; del buf451 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf464, (8192, 768), (768, 1), 0), permute_381, out=buf465) | |
del permute_381 | |
buf466 = buf430; del buf430 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf464, (768, 8192), (1, 768), 0), view_104, out=buf466) | |
del view_104 | |
buf467 = reinterpret_tensor(buf462, (1, 768, 64), (49152, 1, 768), 0); del buf462 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf464, buf467, 49152, 128, grid=grid(49152), stream=stream0) | |
buf470 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf467, buf470, 768, 64, grid=grid(768), stream=stream0) | |
buf469 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf466, buf469, 589824, grid=grid(589824), stream=stream0) | |
buf471 = reinterpret_tensor(buf464, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf464 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf465, buf471, 6291456, grid=grid(6291456), stream=stream0) | |
del buf465 | |
# Source Nodes: [], Original ATen: [aten.clone]  NOTE: the call below is aten._scaled_dot_product_flash_attention_backward, not a clone | |
buf472 = aten._scaled_dot_product_flash_attention_backward.default(buf471, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, None, None, 512, 512, 0.1, False, getitem_103, getitem_104, scale=0.125) | |
del getitem_101 | |
del getitem_102 | |
del getitem_103 | |
del getitem_104 | |
del permute_default_42 | |
del permute_default_43 | |
del permute_default_44 | |
buf473 = buf472[0] | |
buf474 = buf472[1] | |
buf475 = buf472[2] | |
del buf472 | |
buf476 = reinterpret_tensor(buf471, (8192, 768), (768, 1), 0); del buf471 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf475, (8192, 768), (768, 1), 0), permute_393, out=buf476) | |
del permute_393 | |
buf477 = buf466; del buf466 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf475, (768, 8192), (1, 768), 0), view_88, out=buf477) | |
buf478 = buf467; del buf467 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf475, buf478, 49152, 128, grid=grid(49152), stream=stream0) | |
buf481 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf478, buf481, 768, 64, grid=grid(768), stream=stream0) | |
buf480 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf477, buf480, 589824, grid=grid(589824), stream=stream0) | |
buf482 = reinterpret_tensor(buf475, (8192, 768), (768, 1), 0); del buf475 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf474, (8192, 768), (768, 1), 0), permute_398, out=buf482) | |
del permute_398 | |
buf483 = buf477; del buf477 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf474, (768, 8192), (1, 768), 0), view_88, out=buf483) | |
buf484 = buf478; del buf478 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf474, buf484, 49152, 128, grid=grid(49152), stream=stream0) | |
buf487 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf484, buf487, 768, 64, grid=grid(768), stream=stream0) | |
buf486 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf483, buf486, 589824, grid=grid(589824), stream=stream0) | |
buf488 = reinterpret_tensor(buf474, (8192, 768), (768, 1), 0); del buf474 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf473, (8192, 768), (768, 1), 0), permute_402, out=buf488) | |
del permute_402 | |
buf489 = buf483; del buf483 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf473, (768, 8192), (1, 768), 0), view_88, out=buf489) | |
del view_88 | |
buf490 = buf484; del buf484 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf473, buf490, 49152, 128, grid=grid(49152), stream=stream0) | |
buf493 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf490, buf493, 768, 64, grid=grid(768), stream=stream0) | |
buf492 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf489, buf492, 589824, grid=grid(589824), stream=stream0) | |
buf497 = buf438; del buf438 # reuse | |
buf502 = reinterpret_tensor(buf473, (16, 512, 768), (393216, 768, 1), 0); del buf473 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf459, buf476, buf482, buf488, primals_68, mul_55, div_51, gt_12, buf497, buf502, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_51 | |
del gt_12 | |
del primals_68 | |
buf498 = reinterpret_tensor(buf490, (768, 64), (1, 768), 0); del buf490 # reuse | |
buf500 = buf460; del buf460 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf459, buf476, buf482, buf488, mul_55, buf498, buf500, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf476 | |
del buf482 | |
del mul_55 | |
buf499 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf498, buf499, 768, 64, grid=grid(768), stream=stream0) | |
buf501 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf500, buf501, 768, 64, grid=grid(768), stream=stream0) | |
buf503 = reinterpret_tensor(buf450, (8192, 3072), (3072, 1), 0); del buf450 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf502, (8192, 768), (768, 1), 0), permute_406, out=buf503) | |
del permute_406 | |
buf504 = reinterpret_tensor(buf452, (768, 3072), (3072, 1), 0); del buf452 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf502, (768, 8192), (1, 768), 0), view_86, out=buf504) | |
del view_86 | |
buf505 = reinterpret_tensor(buf500, (1, 768, 64), (49152, 1, 768), 0); del buf500 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf502, buf505, 49152, 128, grid=grid(49152), stream=stream0) | |
buf508 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf505, buf508, 768, 64, grid=grid(768), stream=stream0) | |
buf507 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf504, buf507, 2359296, grid=grid(2359296), stream=stream0) | |
buf509 = reinterpret_tensor(buf503, (16, 512, 3072), (1572864, 3072, 1), 0); del buf503 # reuse | |
# Source Nodes: [hidden_states_28], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf509, addmm_22, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_22 | |
buf510 = reinterpret_tensor(buf502, (8192, 768), (768, 1), 0); del buf502 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf509, (8192, 3072), (3072, 1), 0), permute_410, out=buf510) | |
del permute_410 | |
buf511 = reinterpret_tensor(buf504, (3072, 768), (768, 1), 0); del buf504 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf509, (3072, 8192), (1, 3072), 0), view_84, out=buf511) | |
del view_84 | |
buf512 = buf453; del buf453 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf509, buf512, 98304, 256, grid=grid(98304), stream=stream0) | |
buf515 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf512, buf515, 3072, 32, grid=grid(3072), stream=stream0) | |
buf514 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf511, buf514, 2359296, grid=grid(2359296), stream=stream0) | |
buf518 = buf459; del buf459 # reuse | |
buf523 = reinterpret_tensor(buf488, (16, 512, 768), (393216, 768, 1), 0); del buf488 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf497, buf510, primals_62, mul_48, div_52, gt_11, buf518, buf523, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_52 | |
del gt_11 | |
del primals_62 | |
buf519 = reinterpret_tensor(buf505, (768, 64), (1, 768), 0); del buf505 # reuse | |
buf521 = buf498; del buf498 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf497, buf510, mul_48, buf519, buf521, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_48 | |
buf520 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf519, buf520, 768, 64, grid=grid(768), stream=stream0) | |
buf522 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf521, buf522, 768, 64, grid=grid(768), stream=stream0) | |
buf524 = buf510; del buf510 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf523, (8192, 768), (768, 1), 0), permute_414, out=buf524) | |
del permute_414 | |
buf525 = buf489; del buf489 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf523, (768, 8192), (1, 768), 0), view_82, out=buf525) | |
del view_82 | |
buf526 = reinterpret_tensor(buf521, (1, 768, 64), (49152, 1, 768), 0); del buf521 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf523, buf526, 49152, 128, grid=grid(49152), stream=stream0) | |
buf529 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf526, buf529, 768, 64, grid=grid(768), stream=stream0) | |
buf528 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf525, buf528, 589824, grid=grid(589824), stream=stream0) | |
buf530 = reinterpret_tensor(buf523, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf523 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf524, buf530, 6291456, grid=grid(6291456), stream=stream0) | |
del buf524 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf531 = aten._scaled_dot_product_flash_attention_backward.default(buf530, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, None, None, 512, 512, 0.1, False, getitem_110, getitem_111, scale=0.125) | |
del getitem_108 | |
del getitem_109 | |
del getitem_110 | |
del getitem_111 | |
del permute_default_48 | |
del permute_default_49 | |
del permute_default_50 | |
buf532 = buf531[0] | |
buf533 = buf531[1] | |
buf534 = buf531[2] | |
del buf531 | |
buf535 = reinterpret_tensor(buf530, (8192, 768), (768, 1), 0); del buf530 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf534, (8192, 768), (768, 1), 0), permute_426, out=buf535) | |
del permute_426 | |
buf536 = buf525; del buf525 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf534, (768, 8192), (1, 768), 0), view_66, out=buf536) | |
buf537 = buf526; del buf526 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf534, buf537, 49152, 128, grid=grid(49152), stream=stream0) | |
buf540 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf537, buf540, 768, 64, grid=grid(768), stream=stream0) | |
buf539 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf536, buf539, 589824, grid=grid(589824), stream=stream0) | |
buf541 = reinterpret_tensor(buf534, (8192, 768), (768, 1), 0); del buf534 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf533, (8192, 768), (768, 1), 0), permute_431, out=buf541) | |
del permute_431 | |
buf542 = buf536; del buf536 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf533, (768, 8192), (1, 768), 0), view_66, out=buf542) | |
buf543 = buf537; del buf537 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf533, buf543, 49152, 128, grid=grid(49152), stream=stream0) | |
buf546 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf543, buf546, 768, 64, grid=grid(768), stream=stream0) | |
buf545 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf542, buf545, 589824, grid=grid(589824), stream=stream0) | |
buf547 = reinterpret_tensor(buf533, (8192, 768), (768, 1), 0); del buf533 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf532, (8192, 768), (768, 1), 0), permute_435, out=buf547) | |
del permute_435 | |
buf548 = buf542; del buf542 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf532, (768, 8192), (1, 768), 0), view_66, out=buf548) | |
del view_66 | |
buf549 = buf543; del buf543 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf532, buf549, 49152, 128, grid=grid(49152), stream=stream0) | |
buf552 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf549, buf552, 768, 64, grid=grid(768), stream=stream0) | |
buf551 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf548, buf551, 589824, grid=grid(589824), stream=stream0) | |
buf556 = buf497; del buf497 # reuse | |
buf561 = reinterpret_tensor(buf532, (16, 512, 768), (393216, 768, 1), 0); del buf532 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf518, buf535, buf541, buf547, primals_52, mul_42, div_54, gt_9, buf556, buf561, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_54 | |
del gt_9 | |
del primals_52 | |
buf557 = reinterpret_tensor(buf549, (768, 64), (1, 768), 0); del buf549 # reuse | |
buf559 = buf519; del buf519 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf518, buf535, buf541, buf547, mul_42, buf557, buf559, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf535 | |
del buf541 | |
del mul_42 | |
buf558 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf557, buf558, 768, 64, grid=grid(768), stream=stream0) | |
buf560 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf559, buf560, 768, 64, grid=grid(768), stream=stream0) | |
buf562 = reinterpret_tensor(buf509, (8192, 3072), (3072, 1), 0); del buf509 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf561, (8192, 768), (768, 1), 0), permute_439, out=buf562) | |
del permute_439 | |
buf563 = reinterpret_tensor(buf511, (768, 3072), (3072, 1), 0); del buf511 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf561, (768, 8192), (1, 768), 0), view_64, out=buf563) | |
del view_64 | |
buf564 = reinterpret_tensor(buf559, (1, 768, 64), (49152, 1, 768), 0); del buf559 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf561, buf564, 49152, 128, grid=grid(49152), stream=stream0) | |
buf567 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf564, buf567, 768, 64, grid=grid(768), stream=stream0) | |
buf566 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf563, buf566, 2359296, grid=grid(2359296), stream=stream0) | |
buf568 = reinterpret_tensor(buf562, (16, 512, 3072), (1572864, 3072, 1), 0); del buf562 # reuse | |
# Source Nodes: [hidden_states_20], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf568, addmm_16, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_16 | |
buf569 = reinterpret_tensor(buf561, (8192, 768), (768, 1), 0); del buf561 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf568, (8192, 3072), (3072, 1), 0), permute_443, out=buf569) | |
del permute_443 | |
buf570 = reinterpret_tensor(buf563, (3072, 768), (768, 1), 0); del buf563 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf568, (3072, 8192), (1, 3072), 0), view_62, out=buf570) | |
del view_62 | |
buf571 = buf512; del buf512 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf568, buf571, 98304, 256, grid=grid(98304), stream=stream0) | |
buf574 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf571, buf574, 3072, 32, grid=grid(3072), stream=stream0) | |
buf573 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf570, buf573, 2359296, grid=grid(2359296), stream=stream0) | |
buf577 = buf518; del buf518 # reuse | |
buf582 = reinterpret_tensor(buf547, (16, 512, 768), (393216, 768, 1), 0); del buf547 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf556, buf569, primals_46, mul_35, div_55, gt_8, buf577, buf582, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_55 | |
del gt_8 | |
del primals_46 | |
buf578 = reinterpret_tensor(buf564, (768, 64), (1, 768), 0); del buf564 # reuse | |
buf580 = buf557; del buf557 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf556, buf569, mul_35, buf578, buf580, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_35 | |
buf579 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf578, buf579, 768, 64, grid=grid(768), stream=stream0) | |
buf581 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf580, buf581, 768, 64, grid=grid(768), stream=stream0) | |
buf583 = buf569; del buf569 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf582, (8192, 768), (768, 1), 0), permute_447, out=buf583) | |
del permute_447 | |
buf584 = buf548; del buf548 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf582, (768, 8192), (1, 768), 0), view_60, out=buf584) | |
del view_60 | |
buf585 = reinterpret_tensor(buf580, (1, 768, 64), (49152, 1, 768), 0); del buf580 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf582, buf585, 49152, 128, grid=grid(49152), stream=stream0) | |
buf588 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf585, buf588, 768, 64, grid=grid(768), stream=stream0) | |
buf587 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf584, buf587, 589824, grid=grid(589824), stream=stream0) | |
buf589 = reinterpret_tensor(buf582, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf582 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf583, buf589, 6291456, grid=grid(6291456), stream=stream0) | |
del buf583 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf590 = aten._scaled_dot_product_flash_attention_backward.default(buf589, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, None, None, 512, 512, 0.1, False, getitem_117, getitem_118, scale=0.125) | |
del getitem_115 | |
del getitem_116 | |
del getitem_117 | |
del getitem_118 | |
del permute_default_54 | |
del permute_default_55 | |
del permute_default_56 | |
buf591 = buf590[0] | |
buf592 = buf590[1] | |
buf593 = buf590[2] | |
del buf590 | |
buf594 = reinterpret_tensor(buf589, (8192, 768), (768, 1), 0); del buf589 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf593, (8192, 768), (768, 1), 0), permute_459, out=buf594) | |
del permute_459 | |
buf595 = buf584; del buf584 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf593, (768, 8192), (1, 768), 0), view_44, out=buf595) | |
buf596 = buf585; del buf585 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf593, buf596, 49152, 128, grid=grid(49152), stream=stream0) | |
buf599 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf596, buf599, 768, 64, grid=grid(768), stream=stream0) | |
buf598 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf595, buf598, 589824, grid=grid(589824), stream=stream0) | |
buf600 = reinterpret_tensor(buf593, (8192, 768), (768, 1), 0); del buf593 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf592, (8192, 768), (768, 1), 0), permute_464, out=buf600) | |
del permute_464 | |
buf601 = buf595; del buf595 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf592, (768, 8192), (1, 768), 0), view_44, out=buf601) | |
buf602 = buf596; del buf596 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf592, buf602, 49152, 128, grid=grid(49152), stream=stream0) | |
buf605 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf602, buf605, 768, 64, grid=grid(768), stream=stream0) | |
buf604 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf601, buf604, 589824, grid=grid(589824), stream=stream0) | |
buf606 = reinterpret_tensor(buf592, (8192, 768), (768, 1), 0); del buf592 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf591, (8192, 768), (768, 1), 0), permute_468, out=buf606) | |
del permute_468 | |
buf607 = buf601; del buf601 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf591, (768, 8192), (1, 768), 0), view_44, out=buf607) | |
del view_44 | |
buf608 = buf602; del buf602 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf591, buf608, 49152, 128, grid=grid(49152), stream=stream0) | |
buf611 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf608, buf611, 768, 64, grid=grid(768), stream=stream0) | |
buf610 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf607, buf610, 589824, grid=grid(589824), stream=stream0) | |
buf615 = buf556; del buf556 # reuse | |
buf620 = reinterpret_tensor(buf591, (16, 512, 768), (393216, 768, 1), 0); del buf591 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf577, buf594, buf600, buf606, primals_36, mul_29, div_57, gt_6, buf615, buf620, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_57 | |
del gt_6 | |
del primals_36 | |
buf616 = reinterpret_tensor(buf608, (768, 64), (1, 768), 0); del buf608 # reuse | |
buf618 = buf578; del buf578 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf577, buf594, buf600, buf606, mul_29, buf616, buf618, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf594 | |
del buf600 | |
del mul_29 | |
buf617 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf616, buf617, 768, 64, grid=grid(768), stream=stream0) | |
buf619 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf618, buf619, 768, 64, grid=grid(768), stream=stream0) | |
buf621 = reinterpret_tensor(buf568, (8192, 3072), (3072, 1), 0); del buf568 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf620, (8192, 768), (768, 1), 0), permute_472, out=buf621) | |
del permute_472 | |
buf622 = reinterpret_tensor(buf570, (768, 3072), (3072, 1), 0); del buf570 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf620, (768, 8192), (1, 768), 0), view_42, out=buf622) | |
del view_42 | |
buf623 = reinterpret_tensor(buf618, (1, 768, 64), (49152, 1, 768), 0); del buf618 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf620, buf623, 49152, 128, grid=grid(49152), stream=stream0) | |
buf626 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf623, buf626, 768, 64, grid=grid(768), stream=stream0) | |
buf625 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf622, buf625, 2359296, grid=grid(2359296), stream=stream0) | |
buf627 = reinterpret_tensor(buf621, (16, 512, 3072), (1572864, 3072, 1), 0); del buf621 # reuse | |
# Source Nodes: [hidden_states_12], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf627, addmm_10, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_10 | |
buf628 = reinterpret_tensor(buf620, (8192, 768), (768, 1), 0); del buf620 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf627, (8192, 3072), (3072, 1), 0), permute_476, out=buf628) | |
del permute_476 | |
buf629 = reinterpret_tensor(buf622, (3072, 768), (768, 1), 0); del buf622 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf627, (3072, 8192), (1, 3072), 0), view_40, out=buf629) | |
del view_40 | |
buf630 = buf571; del buf571 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf627, buf630, 98304, 256, grid=grid(98304), stream=stream0) | |
buf633 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf630, buf633, 3072, 32, grid=grid(3072), stream=stream0) | |
buf632 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf629, buf632, 2359296, grid=grid(2359296), stream=stream0) | |
buf636 = buf577; del buf577 # reuse | |
buf641 = reinterpret_tensor(buf606, (16, 512, 768), (393216, 768, 1), 0); del buf606 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf615, buf628, primals_30, mul_22, div_58, gt_5, buf636, buf641, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_58 | |
del gt_5 | |
del primals_30 | |
buf637 = reinterpret_tensor(buf623, (768, 64), (1, 768), 0); del buf623 # reuse | |
buf639 = buf616; del buf616 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf615, buf628, mul_22, buf637, buf639, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_22 | |
buf638 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf637, buf638, 768, 64, grid=grid(768), stream=stream0) | |
buf640 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf639, buf640, 768, 64, grid=grid(768), stream=stream0) | |
buf642 = buf628; del buf628 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf641, (8192, 768), (768, 1), 0), permute_480, out=buf642) | |
del permute_480 | |
buf643 = buf607; del buf607 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf641, (768, 8192), (1, 768), 0), view_38, out=buf643) | |
del view_38 | |
buf644 = reinterpret_tensor(buf639, (1, 768, 64), (49152, 1, 768), 0); del buf639 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf641, buf644, 49152, 128, grid=grid(49152), stream=stream0) | |
buf647 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf644, buf647, 768, 64, grid=grid(768), stream=stream0) | |
buf646 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf643, buf646, 589824, grid=grid(589824), stream=stream0) | |
buf648 = reinterpret_tensor(buf641, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf641 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf642, buf648, 6291456, grid=grid(6291456), stream=stream0) | |
del buf642 | |
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward] | |
buf649 = aten._scaled_dot_product_flash_attention_backward.default(buf648, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, None, None, 512, 512, 0.1, False, getitem_124, getitem_125, scale=0.125) | |
del getitem_122 | |
del getitem_123 | |
del getitem_124 | |
del getitem_125 | |
del permute_default_60 | |
del permute_default_61 | |
del permute_default_62 | |
buf650 = buf649[0] | |
buf651 = buf649[1] | |
buf652 = buf649[2] | |
del buf649 | |
buf653 = reinterpret_tensor(buf648, (8192, 768), (768, 1), 0); del buf648 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf652, (8192, 768), (768, 1), 0), permute_492, out=buf653) | |
del permute_492 | |
buf654 = buf643; del buf643 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf652, (768, 8192), (1, 768), 0), view_22, out=buf654) | |
buf655 = buf644; del buf644 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf652, buf655, 49152, 128, grid=grid(49152), stream=stream0) | |
buf658 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf655, buf658, 768, 64, grid=grid(768), stream=stream0) | |
buf657 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf654, buf657, 589824, grid=grid(589824), stream=stream0) | |
buf659 = reinterpret_tensor(buf652, (8192, 768), (768, 1), 0); del buf652 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf651, (8192, 768), (768, 1), 0), permute_497, out=buf659) | |
del permute_497 | |
buf660 = buf654; del buf654 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf651, (768, 8192), (1, 768), 0), view_22, out=buf660) | |
buf661 = buf655; del buf655 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf651, buf661, 49152, 128, grid=grid(49152), stream=stream0) | |
buf664 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf661, buf664, 768, 64, grid=grid(768), stream=stream0) | |
buf663 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf660, buf663, 589824, grid=grid(589824), stream=stream0) | |
buf665 = reinterpret_tensor(buf651, (8192, 768), (768, 1), 0); del buf651 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf650, (8192, 768), (768, 1), 0), permute_501, out=buf665) | |
del permute_501 | |
buf666 = buf660; del buf660 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf650, (768, 8192), (1, 768), 0), view_22, out=buf666) | |
del view_22 | |
buf667 = buf661; del buf661 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf650, buf667, 49152, 128, grid=grid(49152), stream=stream0) | |
buf670 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf667, buf670, 768, 64, grid=grid(768), stream=stream0) | |
buf669 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf666, buf669, 589824, grid=grid(589824), stream=stream0) | |
buf674 = buf615; del buf615 # reuse | |
buf679 = reinterpret_tensor(buf650, (16, 512, 768), (393216, 768, 1), 0); del buf650 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf636, buf653, buf659, buf665, primals_20, mul_16, div_60, gt_3, buf674, buf679, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_60 | |
del gt_3 | |
del primals_20 | |
buf675 = reinterpret_tensor(buf667, (768, 64), (1, 768), 0); del buf667 # reuse | |
buf677 = buf637; del buf637 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf636, buf653, buf659, buf665, mul_16, buf675, buf677, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf653 | |
del buf659 | |
del mul_16 | |
buf676 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf675, buf676, 768, 64, grid=grid(768), stream=stream0) | |
buf678 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf677, buf678, 768, 64, grid=grid(768), stream=stream0) | |
buf680 = reinterpret_tensor(buf627, (8192, 3072), (3072, 1), 0); del buf627 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf679, (8192, 768), (768, 1), 0), permute_505, out=buf680) | |
del permute_505 | |
buf681 = reinterpret_tensor(buf629, (768, 3072), (3072, 1), 0); del buf629 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf679, (768, 8192), (1, 768), 0), view_20, out=buf681) | |
del view_20 | |
buf682 = reinterpret_tensor(buf677, (1, 768, 64), (49152, 1, 768), 0); del buf677 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf679, buf682, 49152, 128, grid=grid(49152), stream=stream0) | |
buf685 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf682, buf685, 768, 64, grid=grid(768), stream=stream0) | |
buf684 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_14.run(buf681, buf684, 2359296, grid=grid(2359296), stream=stream0) | |
buf686 = reinterpret_tensor(buf680, (16, 512, 3072), (1572864, 3072, 1), 0); del buf680 # reuse | |
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu, aten.gelu_backward] | |
triton_poi_fused_gelu_gelu_backward_15.run(buf686, addmm_4, 25165824, grid=grid(25165824), stream=stream0) | |
del addmm_4 | |
buf687 = reinterpret_tensor(buf679, (8192, 768), (768, 1), 0); del buf679 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf686, (8192, 3072), (3072, 1), 0), permute_509, out=buf687) | |
del permute_509 | |
buf688 = reinterpret_tensor(buf681, (3072, 768), (768, 1), 0); del buf681 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf686, (3072, 8192), (1, 3072), 0), view_18, out=buf688) | |
del view_18 | |
buf689 = buf630; del buf630 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_16.run(buf686, buf689, 98304, 256, grid=grid(98304), stream=stream0) | |
del buf686 | |
buf692 = empty_strided_cuda((3072, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_17.run(buf689, buf692, 3072, 32, grid=grid(3072), stream=stream0) | |
del buf689 | |
buf691 = empty_strided_cuda((3072, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_18.run(buf688, buf691, 2359296, grid=grid(2359296), stream=stream0) | |
del buf688 | |
buf695 = buf636; del buf636 # reuse | |
buf700 = reinterpret_tensor(buf665, (16, 512, 768), (393216, 768, 1), 0); del buf665 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf674, buf687, primals_14, mul_9, div_61, gt_2, buf695, buf700, 8192, 768, grid=grid(8192), stream=stream0) | |
del div_61 | |
del gt_2 | |
del primals_14 | |
buf696 = reinterpret_tensor(buf682, (768, 64), (1, 768), 0); del buf682 # reuse | |
buf698 = buf675; del buf675 # reuse | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf674, buf687, mul_9, buf696, buf698, 49152, 128, grid=grid(49152), stream=stream0) | |
del mul_9 | |
buf697 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf696, buf697, 768, 64, grid=grid(768), stream=stream0) | |
buf699 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf698, buf699, 768, 64, grid=grid(768), stream=stream0) | |
buf701 = buf687; del buf687 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf700, (8192, 768), (768, 1), 0), permute_513, out=buf701) | |
del permute_513 | |
buf702 = buf666; del buf666 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf700, (768, 8192), (1, 768), 0), view_16, out=buf702) | |
del view_16 | |
buf703 = reinterpret_tensor(buf698, (1, 768, 64), (49152, 1, 768), 0); del buf698 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_13.run(buf700, buf703, 49152, 128, grid=grid(49152), stream=stream0) | |
buf706 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf703, buf706, 768, 64, grid=grid(768), stream=stream0) | |
buf705 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf702, buf705, 589824, grid=grid(589824), stream=stream0) | |
buf707 = reinterpret_tensor(buf700, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf700 # reuse | |
# Source Nodes: [], Original ATen: [aten.clone] | |
triton_poi_fused_clone_21.run(buf701, buf707, 6291456, grid=grid(6291456), stream=stream0) | |
del buf701 | |
# Source Nodes: [], Original ATen: [aten.clone] | |
buf708 = aten._scaled_dot_product_flash_attention_backward.default(buf707, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, None, None, 512, 512, 0.1, False, getitem_131, getitem_132, scale=0.125) | |
del getitem_129 | |
del getitem_130 | |
del getitem_131 | |
del getitem_132 | |
del permute_default_66 | |
del permute_default_67 | |
del permute_default_68 | |
buf709 = buf708[0] | |
buf710 = buf708[1] | |
buf711 = buf708[2] | |
del buf708 | |
buf712 = reinterpret_tensor(buf707, (8192, 768), (768, 1), 0); del buf707 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf711, (8192, 768), (768, 1), 0), permute_525, out=buf712) | |
del permute_525 | |
buf713 = buf702; del buf702 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf711, (768, 8192), (1, 768), 0), view, out=buf713) | |
buf714 = buf703; del buf703 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf711, buf714, 49152, 128, grid=grid(49152), stream=stream0) | |
buf717 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf714, buf717, 768, 64, grid=grid(768), stream=stream0) | |
buf716 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf713, buf716, 589824, grid=grid(589824), stream=stream0) | |
buf718 = reinterpret_tensor(buf711, (8192, 768), (768, 1), 0); del buf711 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf710, (8192, 768), (768, 1), 0), permute_530, out=buf718) | |
del permute_530 | |
buf719 = buf713; del buf713 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf710, (768, 8192), (1, 768), 0), view, out=buf719) | |
buf720 = buf714; del buf714 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf710, buf720, 49152, 128, grid=grid(49152), stream=stream0) | |
buf723 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf720, buf723, 768, 64, grid=grid(768), stream=stream0) | |
buf722 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf719, buf722, 589824, grid=grid(589824), stream=stream0) | |
buf724 = reinterpret_tensor(buf710, (8192, 768), (768, 1), 0); del buf710 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf709, (8192, 768), (768, 1), 0), permute_534, out=buf724) | |
del permute_534 | |
buf725 = buf719; del buf719 # reuse | |
# Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf709, (768, 8192), (1, 768), 0), view, out=buf725) | |
del view | |
buf726 = buf720; del buf720 # reuse | |
# Source Nodes: [], Original ATen: [aten.sum] | |
triton_red_fused_sum_22.run(buf709, buf726, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf709 | |
buf729 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum] | |
triton_per_fused__to_copy_sum_9.run(buf726, buf729, 768, 64, grid=grid(768), stream=stream0) | |
buf728 = empty_strided_cuda((768, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten._to_copy] | |
triton_poi_fused__to_copy_10.run(buf725, buf728, 589824, grid=grid(589824), stream=stream0) | |
del buf725 | |
buf741 = empty_strided_cuda((2, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_25.run(buf741, 1536, grid=grid(1536), stream=stream0) | |
buf743 = empty_strided_cuda((30522, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_26.run(buf743, 23440896, grid=grid(23440896), stream=stream0) | |
buf730 = buf695; del buf695 # reuse | |
buf733 = buf674; del buf674 # reuse | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward] | |
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.run(buf730, buf712, buf718, buf724, gt, primals_4, mul_1, div_63, primals_204, primals_206, buf733, buf741, buf743, 8192, 768, grid=grid(8192), stream=stream0) | |
del buf712 | |
del buf718 | |
del buf724 | |
del div_63 | |
del gt | |
del primals_204 | |
del primals_206 | |
del primals_4 | |
buf734 = reinterpret_tensor(buf726, (768, 64), (1, 768), 0); del buf726 # reuse | |
buf736 = buf696; del buf696 # reuse | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_red_fused_native_layer_norm_backward_28.run(buf730, mul_1, buf734, buf736, 49152, 128, grid=grid(49152), stream=stream0) | |
del buf730 | |
del mul_1 | |
buf735 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf734, buf735, 768, 64, grid=grid(768), stream=stream0) | |
del buf734 | |
buf737 = empty_strided_cuda((768, ), (1, ), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward] | |
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf736, buf737, 768, 64, grid=grid(768), stream=stream0) | |
del buf736 | |
buf739 = empty_strided_cuda((512, 768), (768, 1), torch.float32) | |
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
triton_poi_fused_embedding_dense_backward_29.run(buf739, 393216, grid=grid(393216), stream=stream0) | |
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum] | |
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.run(buf733, primals_205, buf739, 393216, 16, grid=grid(393216), stream=stream0) | |
del buf733 | |
del primals_205 | |
return (buf743, buf741, buf739, buf735, buf737, buf728, buf729, buf722, buf723, buf716, buf717, buf705, buf706, buf697, buf699, buf691, buf692, buf684, buf685, buf676, buf678, buf669, buf670, buf663, buf664, buf657, buf658, buf646, buf647, buf638, buf640, buf632, buf633, buf625, buf626, buf617, buf619, buf610, buf611, buf604, buf605, buf598, buf599, buf587, buf588, buf579, buf581, buf573, buf574, buf566, buf567, buf558, buf560, buf551, buf552, buf545, buf546, buf539, buf540, buf528, buf529, buf520, buf522, buf514, buf515, buf507, buf508, buf499, buf501, buf492, buf493, buf486, buf487, buf480, buf481, buf469, buf470, buf461, buf463, buf455, buf456, buf448, buf449, buf440, buf442, buf433, buf434, buf427, buf428, buf421, buf422, buf410, buf411, buf402, buf404, buf396, buf397, buf389, buf390, buf381, buf383, buf374, buf375, buf368, buf369, buf362, buf363, buf351, buf352, buf343, buf345, buf337, buf338, buf330, buf331, buf322, buf324, buf315, buf316, buf309, buf310, buf303, buf304, buf292, buf293, buf284, buf286, buf278, buf279, buf271, buf272, buf263, buf265, buf256, buf257, buf250, buf251, buf244, buf245, buf233, buf234, buf225, buf227, buf219, buf220, buf212, buf213, buf204, buf206, buf197, buf198, buf191, buf192, buf185, buf186, buf174, buf175, buf166, buf168, buf160, buf161, buf153, buf154, buf145, buf147, buf138, buf139, buf132, buf133, buf126, buf127, buf115, buf116, buf107, buf109, buf101, buf102, buf94, buf95, buf86, buf88, buf79, buf80, buf73, buf74, buf67, buf68, buf56, buf57, buf48, buf50, buf42, buf43, buf35, buf36, buf27, buf29, buf21, buf22, buf12, buf14, buf7, buf8, None, None, None, None, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_14 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_20 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_30 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_36 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_46 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_52 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_62 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_68 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_78 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_84 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_94 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_100 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_110 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_116 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_126 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_132 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_142 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_148 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_158 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_164 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_174 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_180 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_190 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_196 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_200 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_204 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_205 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_206 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
primals_207 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) | |
mul_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
gt = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
view = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_67 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_68 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_129 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_130 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_131 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_132 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_16 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_18 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_4 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_20 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_16 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_22 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_60 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_61 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_62 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_122 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_123 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_124 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_125 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_38 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_22 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_40 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_10 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_42 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_44 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_54 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_55 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_56 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_115 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_116 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_117 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_118 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_60 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_62 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_16 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_42 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_66 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_48 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_49 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_50 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_108 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_109 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_110 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_111 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_82 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_11 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_48 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_84 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_22 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_86 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_12 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_55 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_88 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_42 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_43 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_44 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_101 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_102 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_103 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_104 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_104 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_14 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_61 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_106 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_28 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_108 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_15 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_68 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_110 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_36 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_37 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_38 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_94 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_95 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_96 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_97 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_126 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_17 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_74 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_128 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_34 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_130 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_18 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_81 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_132 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_30 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_31 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_32 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_87 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_88 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_89 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_90 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_148 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_20 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_87 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_150 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_40 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_152 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_21 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_94 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_154 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_24 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_25 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_26 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_80 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_81 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_82 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_83 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_170 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_23 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_100 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_172 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_46 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_174 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_24 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_107 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_176 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_18 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_19 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_20 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_73 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_74 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_75 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_76 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_192 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_26 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_113 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_194 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_52 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_196 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_27 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_120 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_198 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_12 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_13 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_14 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_67 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_68 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_69 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_214 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_126 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_216 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_58 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_218 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_30 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_133 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_220 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_6 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_7 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_8 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_59 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_60 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_61 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_62 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_236 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_32 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_139 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_238 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_240 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_33 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_146 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_242 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_1 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
permute_default_2 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_52 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_53 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32) | |
getitem_54 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
getitem_55 = rand_strided((), (), device='cuda:0', dtype=torch.int64) | |
view_258 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
gt_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_152 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_260 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_70 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
view_262 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
gt_36 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool) | |
mul_159 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32) | |
view_264 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
addmm_72 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
getitem_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_25 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_266 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
view_267 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16) | |
amax_12 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
log = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
convert_element_type_510 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
permute_134 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_138 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_27 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_142 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_146 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_28 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_150 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_162 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_167 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_171 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_30 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_175 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_179 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_31 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_183 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_195 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_200 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_204 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_33 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_208 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_212 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_34 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_216 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_228 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_233 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_237 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_36 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_241 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_245 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_37 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_249 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_261 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_266 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_270 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_39 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_274 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_278 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_40 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_282 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_294 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_299 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_303 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_42 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_307 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_311 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_43 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_315 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_327 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_332 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_336 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_45 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_340 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_344 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_46 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_348 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_360 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_365 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_369 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_48 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_373 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_377 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_49 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_381 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_393 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_398 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_402 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_406 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_410 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_52 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_414 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_426 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_431 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_435 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_54 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_439 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_443 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_55 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_447 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_459 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_464 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_468 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_57 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_472 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_476 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_58 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_480 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_492 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_497 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_501 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_60 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_505 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16) | |
permute_509 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_61 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_513 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_525 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_530 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
permute_534 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16) | |
div_63 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32) | |
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
tangents_2 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16) | |
fn = lambda: call([primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, 
getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2]) | |
return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported as a module.
    # The helper is imported inside the guard, so it is loaded only on the
    # script path.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # Pass the benchmark callable defined in this file to the helper, together
    # with the label 'BertForMaskedLM'.
    compiled_module_main('BertForMaskedLM', benchmark_compiled_module)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment