# AOT ID: ['0_backward']
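# NOTE (annotation, not part of the generated output): this file is the
# TorchInductor wrapper for an AOT-compiled backward graph. Judging from the
# shapes below (batch 16, sequence 512, hidden 768, vocab 30522 padded to
# 30528), it appears to be the backward pass of a BERT-base masked-LM head.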
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/mt/cmtlbu4aoetpujfnvf6x4euykpu33uxpy6shtforep4hry72bijk.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_1, full_default_2
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0 = async_compile.triton('triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[8192, 32768],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: '*fp16', 5: '*fp32', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.500348424}
)
@triton.jit
def triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 30522
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    _tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = tmp11 / tmp13
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(rmask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp30 = tl.load(in_ptr1 + (0))
    tmp31 = tl.broadcast_to(tmp30, [XBLOCK, RBLOCK])
    tmp32 = tl.load(in_ptr2 + (0))
    tmp33 = tl.broadcast_to(tmp32, [XBLOCK, RBLOCK])
    tmp39 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
    tmp41 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp20 = tl.load(in_ptr3 + (r1 + (30522*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp37 = tl.load(in_ptr4 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tl.full([1, 1], -100, tl.int64)
        tmp22 = tmp0 != tmp21
        tmp23 = tl.full([1, 1], 0, tl.int64)
        tmp24 = tl.where(tmp22, tmp0, tmp23)
        tmp25 = r1
        tmp26 = tmp24 == tmp25
        tmp27 = -1.0
        tmp28 = 0.0
        tmp29 = tl.where(tmp26, tmp27, tmp28)
        tmp34 = tmp31 / tmp33
        tmp35 = tl.where(tmp22, tmp34, tmp28)
        tmp36 = tmp29 * tmp35
        tmp38 = tmp37.to(tl.float32)
        tmp40 = tmp38 - tmp39
        tmp42 = tmp40 - tmp41
        tmp43 = tmp42.to(tl.float32)
        tmp44 = tmp43.to(tl.float32)
        tmp45 = tl_math.exp(tmp44)
        tmp46 = tmp45 * tmp18
        tmp47 = tmp36 - tmp46
        tmp48 = tmp47.to(tl.float32)
        tmp49 = tmp20 + tmp48
        tl.store(out_ptr1 + (r1 + (30528*x0)), tmp49, rmask)
def get_args():
    arg_0 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_5 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.run(*args, 8192, 30522, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.benchmark_all_configs(*args, 8192, 30522, grid=grid(8192))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.500348424
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
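# Hedged eager-mode sketch (annotation; names are mine, not from the graph) of
# what the fused kernel above computes: nll_loss_backward for mean reduction
# with ignore_index=-100, chained into _log_softmax_backward_data, plus the add
# of a second gradient contribution flowing into the same logits.
def _sketch_fused_ce_backward(targets, grad_loss, total_weight, grad_other,
                              logits, row_max, row_logsumexp):
    # targets: (N,) int64; grad_loss, total_weight: 0-dim fp32 tensors
    # grad_other, logits: (N, V) fp16; row_max, row_logsumexp: (N, 1) fp32
    targets = targets.reshape(-1, 1)
    valid = targets != -100                                   # ignore_index mask
    safe_t = torch.where(valid, targets, torch.zeros_like(targets))
    g = torch.zeros(logits.shape, dtype=torch.float32, device=logits.device)
    g.scatter_(1, safe_t, -1.0)                               # -1 at the target class
    g = g * torch.where(valid, grad_loss / total_weight, grad_loss.new_zeros(()))
    log_prob = logits.float() - row_max - row_logsumexp       # recomputed log_softmax
    grad_logits = g - log_prob.exp() * g.sum(dim=1, keepdim=True)
    return (grad_other.float() + grad_logits).half()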
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_shunting/mr/cmrmqdcrcihizdvgq5xqmrayhaooh5p5y3ha4icjpyutweyhrrbl.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_poi_fused_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[268435456],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.0002432},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 250085376
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 30528
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 30522, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp3, None)
def get_args():
    arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 30528), (30528, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_1.run(*args, 250085376, grid=grid(250085376), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_1.benchmark_all_configs(*args, 250085376, grid=grid(250085376))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.0002432
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kr/ckrtxiooml7ugwbki5lfdkf26ifptclbwzfuczjsgv3337j6cnlt.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_2 = async_compile.triton('triton_poi_fused_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0937728},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23445504
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x1 = (xindex // 768)
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 30522, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp3, None)
def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30528, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_2.run(*args, 23445504, grid=grid(23445504), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_2.benchmark_all_configs(*args, 23445504, grid=grid(23445504))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.0937728
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
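# Hedged sketch (annotation): the two pointwise kernels above just zero-pad the
# vocab dimension from 30522 to 30528 (a multiple of 16), presumably so the
# downstream GEMMs see aligned, evenly divisible shapes.
def _sketch_pad_vocab(grad_logits, decoder_weight, pad_to=30528):
    import torch.nn.functional as F
    padded_grad = F.pad(grad_logits, (0, pad_to - grad_logits.shape[-1]))               # triton_poi_fused_1
    padded_weight = F.pad(decoder_weight, (0, 0, 0, pad_to - decoder_weight.shape[0]))  # triton_poi_fused_2
    return padded_grad, padded_weight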
# kernel path: /tmp/torchinductor_shunting/p4/cp4m2j65mbzodehwkp3jqfofigkvv3uayonfyitzbmq3ro7b62pt.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_red_fused__to_copy_sum_3 = async_compile.triton('triton_red_fused__to_copy_sum_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[32768, 8192],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.500194536}
)
@triton.jit
def triton_red_fused__to_copy_sum_3(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 30522
    rnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (30528*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tmp4 = tmp2.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp4, xmask)
def get_args():
    arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30522,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_sum_3.run(*args, 30522, 8192, grid=grid(30522), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_sum_3.benchmark_all_configs(*args, 30522, 8192, grid=grid(30522))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.500194536
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
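# Hedged sketch (annotation): this reduction looks like the bias gradient of
# the vocab projection — the (8192, 30522) fp16 grad summed over rows with an
# fp32 accumulator; the :30522 slice drops the padding columns.
def _sketch_decoder_bias_grad(padded_grad):              # (8192, 30528) fp16
    return padded_grad[:, :30522].float().sum(dim=0)     # (30522,) fp32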
# kernel path: /tmp/torchinductor_shunting/nx/cnxl5yf7bimfnoatnitspyqvw7ppsv2sxfrbw2q6pvaf653wnqsg.py
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4 = async_compile.triton('triton_poi_fused__to_copy_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.140645376},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23440896
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
def get_args():
    arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_4.run(*args, 23440896, grid=grid(23440896), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_4.benchmark_all_configs(*args, 23440896, grid=grid(23440896))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.140645376
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
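# Hedged sketch (annotation): triton_poi_fused__to_copy_4 is a plain
# fp16 -> fp32 upcast of a (30522, 768) weight gradient, presumably to the
# optimizer's fp32 master dtype; triton_poi_fused__to_copy_10 and
# triton_poi_fused__to_copy_14 further below repeat the pattern for other
# parameter shapes.
def _sketch_upcast(grad_fp16):
    return grad_fp16.float()     # aten._to_copy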
# kernel path: /tmp/torchinductor_shunting/yy/cyyvpnhm2jzov5qsqyvwzq5jr53s4xnzjf5k276kkys4x3j5o3m6.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5 = async_compile.triton('triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.037817344}
)
@triton.jit
def triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp20 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp9 = tmp8.to(tl.float32)
    tmp10 = 0.5
    tmp11 = tmp9 * tmp10
    tmp12 = 0.7071067811865476
    tmp13 = tmp9 * tmp12
    tmp14 = libdevice.erf(tmp13)
    tmp15 = 1.0
    tmp16 = tmp14 + tmp15
    tmp17 = tmp11 * tmp16
    tmp19 = tmp17 - tmp18
    tmp21 = tmp19 * tmp20
    tmp22 = tmp3 * tmp21
    tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
    tmp25 = tl.where(rmask, tmp23, 0)
    tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
    tmp27 = 0.0013020833333333333
    tmp28 = tmp20 * tmp27
    tmp29 = 768.0
    tmp30 = tmp3 * tmp29
    tmp31 = tmp30 - tmp7
    tmp32 = tmp21 * tmp26
    tmp33 = tmp31 - tmp32
    tmp34 = tmp28 * tmp33
    tmp35 = tmp16 * tmp10
    tmp36 = tmp9 * tmp9
    tmp37 = -0.5
    tmp38 = tmp36 * tmp37
    tmp39 = tl_math.exp(tmp38)
    tmp40 = 0.3989422804014327
    tmp41 = tmp39 * tmp40
    tmp42 = tmp9 * tmp41
    tmp43 = tmp35 + tmp42
    tmp44 = tmp34 * tmp43
    tmp45 = tmp44.to(tl.float32)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp45, rmask)
def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.037817344
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
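# Hedged eager-mode sketch (annotation) of the fused per-row kernel above:
# the layer_norm input gradient chained into gelu backward, with gelu(x)
# recomputed from the saved pre-activation instead of being stored.
def _sketch_ln_gelu_backward(grad_out, gamma, pre_gelu, mean, rstd):
    # grad_out, pre_gelu: (N, 768) fp16; gamma: (768,) fp32; mean, rstd: (N, 1) fp32
    x = pre_gelu.float()
    cdf = 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))       # Phi(x)
    xhat = (x * cdf - mean) * rstd                              # recomputed gelu, then normalize
    g = grad_out.float() * gamma
    grad_ln = (rstd / 768.0) * (768.0 * g
                                - g.sum(-1, keepdim=True)
                                - xhat * (g * xhat).sum(-1, keepdim=True))
    gelu_grad = cdf + x * torch.exp(-0.5 * x * x) * 0.3989422804014327  # Phi(x) + x*phi(x)
    return (grad_ln * gelu_grad).half()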
# kernel path: /tmp/torchinductor_shunting/p2/cp26z4vf45nz2fihk4sr3hatolkq5voe5ay6tuzetss46ifomndv.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6 = async_compile.triton('triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025624576}
)
@triton.jit
def triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp21 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tl.load(in_ptr3 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = 0.5
        tmp5 = tmp3 * tmp4
        tmp6 = 0.7071067811865476
        tmp7 = tmp3 * tmp6
        tmp8 = libdevice.erf(tmp7)
        tmp9 = 1.0
        tmp10 = tmp8 + tmp9
        tmp11 = tmp5 * tmp10
        tmp13 = tmp11 - tmp12
        tmp15 = tmp13 * tmp14
        tmp16 = tmp1 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(rmask, tmp19, _tmp18)
        tmp20 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp22 = _tmp21 + tmp20
        _tmp21 = tl.where(rmask, tmp22, _tmp21)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp18, None)
    tmp21 = tl.sum(_tmp21, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp21, None)
def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_2 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.025624576
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/be/cbevumbrgmzdfgdapkc3uyo5g3sg7zr3sizocbvwld6n22bimq7o.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7 = async_compile.triton('triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[1024, 64],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp4, xmask)
def get_args():
    arg_0 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(*args, 768, 64, grid=grid(768), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.benchmark_all_configs(*args, 768, 64, grid=grid(768))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
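# Hedged sketch (annotation): the previous two kernels are a split-reduction
# pair producing the layer_norm weight and bias gradients. Stage 1 reduces the
# 8192 rows in 64 chunks of 128 for parallelism; stage 2 folds the 64 partials.
def _sketch_ln_weight_bias_grad(grad_out, pre_gelu, mean, rstd):
    x = pre_gelu.float()
    xhat = (x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865476)) - mean) * rstd
    g = grad_out.float()                                    # (8192, 768)
    partial_dw = (g * xhat).reshape(64, 128, 768).sum(1)    # stage 1 (kernel ..._6)
    partial_db = g.reshape(64, 128, 768).sum(1)
    return partial_dw.sum(0), partial_db.sum(0)             # stage 2 (kernel ..._7) -> (768,)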
# kernel path: /tmp/torchinductor_shunting/dg/cdgwv5eeos3ltgqpvtb7olrpc6yd4n6mytrgvxk656fu3eavxcgq.py
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_8 = async_compile.triton('triton_red_fused_sum_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_8(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_8.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_8.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.01277952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/de/cdefudbzyfjtf3vfo5cjqbkh4nayk6d2w5vztk6sxrz3diyaoomj.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9 = async_compile.triton('triton_per_fused__to_copy_sum_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[1024, 64],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_sum_9(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 768
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp5 = tmp4.to(tl.float32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)
def get_args():
    arg_0 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_sum_9.run(*args, 768, 64, grid=grid(768), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_sum_9.benchmark_all_configs(*args, 768, 64, grid=grid(768))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00019968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
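# Hedged sketch (annotation): kernels ..._sum_8 and ..._to_copy_sum_9 apply the
# same two-stage chunked sum to get a (768,) bias gradient out of an
# (8192, 768) fp16 grad; triton_red_fused__to_copy_native_layer_norm_backward_12
# and triton_red_fused_sum_13 further below reuse the identical pattern for the
# next layer_norm's gradients.
def _sketch_bias_grad(grad_out):                             # (8192, 768) fp16
    partial = grad_out.float().reshape(64, 128, 768).sum(1)  # stage 1 chunk sums
    return partial.sum(0)                                    # stage 2 -> (768,)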
# kernel path: /tmp/torchinductor_shunting/aa/caaivtwgxhkxzzagzke745nxumkqey4pkyuq3ww7l7r3nilsqszw.py
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10 = async_compile.triton('triton_poi_fused__to_copy_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[1048576],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_10(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 589824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
    arg_0 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_10.run(*args, 589824, grid=grid(589824), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_10.benchmark_all_configs(*args, 589824, grid=grid(589824))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.003538944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/3x/c3xzzagzsfle5jswbt57zqxejbypuvykxyswwi5k2leqkor7uv76.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11 = async_compile.triton('triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*i1', 5: '*fp32', 6: '*fp16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.081824768}
)
@triton.jit
def triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0)
    tmp14 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp22 = tl.load(in_ptr4 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp9 = tmp3 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [RBLOCK])
    tmp12 = tl.where(rmask, tmp10, 0)
    tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))
    tmp15 = 768.0
    tmp16 = tmp3 * tmp15
    tmp17 = tmp16 - tmp7
    tmp18 = tmp8 * tmp13
    tmp19 = tmp17 - tmp18
    tmp20 = tmp14 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp22.to(tl.float32)
    tmp24 = 1.1111111111111112
    tmp25 = tmp23 * tmp24
    tmp26 = tmp21 * tmp25
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp20, rmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp26, rmask)
def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
    arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.081824768
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
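# Hedged sketch (annotation): this kernel fuses the layer_norm input gradient
# with dropout backward (p = 0.1, hence the 1/0.9 = 1.111... rescale). Here the
# normalized activation xhat is read from a saved buffer rather than recomputed,
# and the per-row scale loaded from in_ptr3 appears to carry the usual
# rstd / 768 factor already folded in.
def _sketch_ln_dropout_backward(grad_out, gamma, xhat, row_scale, keep_mask):
    g = grad_out.float() * gamma                             # (8192, 768)
    grad_ln = row_scale * (768.0 * g
                           - g.sum(-1, keepdim=True)
                           - xhat * (g * xhat).sum(-1, keepdim=True))
    grad_dropout = grad_ln * keep_mask.float() * (1.0 / 0.9)
    return grad_ln, grad_dropout.half()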
# kernel path: /tmp/torchinductor_shunting/ix/cixvree2ydqv56di7ljjja76jndkhgeq7umgiaw3evahe2aocyji.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_red_fused__to_copy_native_layer_norm_backward_12 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_backward_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_backward_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.038141952}
)
@triton.jit
def triton_red_fused__to_copy_native_layer_norm_backward_12(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
        tmp7 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(rmask, tmp9, _tmp8)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, None)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    tl.store(out_ptr1 + (x3), tmp8, None)
def get_args():
    arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_native_layer_norm_backward_12.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_native_layer_norm_backward_12.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.038141952
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/h5/ch5onztw533j5pafqjnjfivj5tmonqevomegommgb7ujoga557vl.py
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13 = async_compile.triton('triton_red_fused_sum_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_13(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 49152
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 768
    x1 = (xindex // 768)
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_13.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_13.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.01277952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
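# triton_red_fused_sum_13 is the first stage of a split row-sum
# (ReductionHint.OUTER): the 16*512 = 8192 rows are cut into 64 chunks of 128
# rows, and each program accumulates one (column, chunk) pair into an fp32
# partial of shape (1, 768, 64); a later reduction folds the chunk axis.
# An eager-mode sketch of the same math (illustrative; the helper name is
# hypothetical):
def split_row_sum_stage1(x_fp16, chunks=64):
    rows = x_fp16.reshape(-1, x_fp16.shape[-1]).float()         # (8192, 768)
    per_chunk = rows.shape[0] // chunks                         # 128 rows/chunk
    partials = rows.reshape(chunks, per_chunk, -1).sum(dim=1)   # (64, 768)
    return partials.t().unsqueeze(0)                            # (1, 768, 64)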
# kernel path: /tmp/torchinductor_shunting/xa/cxaftrsj77n67s4u7dsn2e4cavvigeasorghvno527bfpnxdwa6r.py
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14 = async_compile.triton('triton_poi_fused__to_copy_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_14(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_14.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_14.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
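# triton_poi_fused__to_copy_14 is a pure dtype upcast: it streams a (768, 3072)
# fp16 tensor once and writes it back as fp32, which is also where its
# kernel_num_gb comes from: 2359296 elements * (2 + 4) bytes = 0.014155776 GB.
# Eager-mode equivalent (illustrative sketch):
def upcast_to_fp32(t_fp16):
    # load fp16, cast, store fp32 -- 6 bytes of traffic per element
    return t_fp16.float()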
# kernel path: /tmp/torchinductor_shunting/hf/chfhbekl56hr4wmn7e4jzntlpgkbay7rkk3uwcw34wubtw4mtp7c.py
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
# hidden_states_92 => add_96, convert_element_type_485, erf_11, mul_155
triton_poi_fused_gelu_gelu_backward_15 = async_compile.triton('triton_poi_fused_gelu_gelu_backward_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_gelu_backward_15', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.150994944},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_gelu_gelu_backward_15(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 25165824
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 0.7071067811865476
tmp5 = tmp3 * tmp4
tmp6 = libdevice.erf(tmp5)
tmp7 = 1.0
tmp8 = tmp6 + tmp7
tmp9 = 0.5
tmp10 = tmp8 * tmp9
tmp11 = tmp3 * tmp3
tmp12 = -0.5
tmp13 = tmp11 * tmp12
tmp14 = tl_math.exp(tmp13)
tmp15 = 0.3989422804014327
tmp16 = tmp14 * tmp15
tmp17 = tmp3 * tmp16
tmp18 = tmp10 + tmp17
tmp19 = tmp1 * tmp18
tmp20 = tmp19.to(tl.float32)
tl.store(in_out_ptr0 + (x0), tmp20, None)
def get_args():
arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_gelu_gelu_backward_15.run(*args, 25165824, grid=grid(25165824), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_gelu_gelu_backward_15.benchmark_all_configs(*args, 25165824, grid=grid(25165824))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.150994944
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
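# triton_poi_fused_gelu_gelu_backward_15 evaluates the exact (erf-based) GELU
# derivative GELU'(x) = Phi(x) + x * phi(x), where Phi(x) = 0.5 * (1 +
# erf(x / sqrt(2))) and phi(x) = exp(-x^2 / 2) / sqrt(2*pi) (the
# 0.3989422804014327 constant), in fp32, then multiplies by the upstream
# gradient and stores fp16 in place. A minimal eager-mode sketch of the same
# formula (illustrative; the helper name is hypothetical):
def gelu_backward_ref(grad_out_fp16, x_fp16):
    import math
    import torch
    g, x = grad_out_fp16.float(), x_fp16.float()
    cdf = 0.5 * (1.0 + torch.erf(x * (1.0 / math.sqrt(2.0))))         # Phi(x)
    pdf = torch.exp(-0.5 * x * x) * (1.0 / math.sqrt(2.0 * math.pi))  # phi(x)
    return (g * (cdf + x * pdf)).half()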
# kernel path: /tmp/torchinductor_shunting/fh/cfhrsmzvructcw4ytv6jmh4xhobkaqx2h23x7f66d374p3qloz4y.py
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16 = async_compile.triton('triton_red_fused_sum_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[131072, 256],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_16', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_sum_16(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 98304
rnumel = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 3072
x1 = (xindex // 3072)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (3072*r2) + (786432*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_16.run(*args, 98304, 256, grid=grid(98304), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_16.benchmark_all_configs(*args, 98304, 256, grid=grid(98304))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.050724864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/i2/ci2zrjeg52uyf2byr2fphdezj5wiruk6qxrv4gylsb5x5sdq7p37.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17 = async_compile.triton('triton_per_fused__to_copy_sum_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[4096, 32],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.000405504}
)
@triton.jit
def triton_per_fused__to_copy_sum_17(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 3072
rnumel = 32
RBLOCK: tl.constexpr = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (3072*r1)), xmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(xmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None]
tmp5 = tmp4.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp5, xmask)
def get_args():
arg_0 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_sum_17.run(*args, 3072, 32, grid=grid(3072), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_sum_17.benchmark_all_configs(*args, 3072, 32, grid=grid(3072))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.000405504
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
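# Kernels 16 and 17 together form a two-stage bias-style gradient reduction
# over the (16, 512, 3072) activation gradient: triton_red_fused_sum_16 emits
# 32 fp32 partial sums per channel (shape (1, 3072, 32)), and
# triton_per_fused__to_copy_sum_17 folds those partials into the final (3072,)
# vector in one persistent reduction. Eager-mode sketch of the pair
# (illustrative; the helper name is hypothetical):
def bias_grad_two_stage(x_fp16, chunks=32):
    rows = x_fp16.reshape(-1, x_fp16.shape[-1]).float()           # (8192, 3072)
    stage1 = rows.reshape(chunks, -1, rows.shape[-1]).sum(dim=1)  # (32, 3072)
    return stage1.sum(dim=0)                                      # (3072,)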
# kernel path: /tmp/torchinductor_shunting/wo/cworhwsjrdvmasqjh2lg6aetbo4ztpk5ats5rsegareen6mwnibu.py
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18 = async_compile.triton('triton_poi_fused__to_copy_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_18(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_18.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_18.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xl/cxluz6c4say445ydf43oeaw5l6cwrfehyxz6txytoztcrndtykgu.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i1', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 6, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.106990592}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0)
tmp16 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp24 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp3 * tmp4
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp8, 0))
tmp11 = tmp5 * tmp10
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask, tmp12, 0)
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp17 = 768.0
tmp18 = tmp5 * tmp17
tmp19 = tmp18 - tmp9
tmp20 = tmp10 * tmp15
tmp21 = tmp19 - tmp20
tmp22 = tmp16 * tmp21
tmp23 = tmp22.to(tl.float32)
tmp25 = tmp24.to(tl.float32)
tmp26 = 1.1111111111111112
tmp27 = tmp25 * tmp26
tmp28 = tmp23 * tmp27
tl.store(out_ptr2 + (r1 + (768*x0)), tmp22, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.106990592
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
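# triton_per_fused..._19 fuses the layer-norm input gradient with dropout
# backward: with g = dy * gamma it computes, per row of width N = 768,
# dx = s * (N * g - sum(g) - x_hat * sum(g * x_hat)), where s is the per-row
# scalar loaded from in_ptr4 (rstd with the 1/N factor folded in, under
# inductor's factoring), then rescales by the dropout mask times
# 1/(1 - p) = 1.1111111111111112 for p = 0.1. A minimal fp32 sketch under
# those assumptions (the helper name and the row_scale interpretation are
# assumptions, not confirmed by the generated source):
def layernorm_dropout_backward_ref(dy, gamma, x_hat, row_scale, keep_mask, p=0.1):
    g = dy * gamma
    n = dy.shape[-1]
    dx = row_scale * (n * g
                      - g.sum(-1, keepdim=True)
                      - x_hat * (g * x_hat).sum(-1, keepdim=True))
    dx_dropped = dx * keep_mask.float() * (1.0 / (1.0 - p))  # 1/(1-0.1)
    return dx, dx_dropped
# Kernels 23 and 27 below repeat this pattern with additional fused residual
# gradient additions ahead of the same per-row reduction.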
# kernel path: /tmp/torchinductor_shunting/7d/c7dgbzvqana3egdgcxvl64as7g5ng2lnohfx5nuiwhgjgnqfzgzf.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.063307776}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp3 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(rmask, tmp8, _tmp7)
tmp9 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask, tmp11, _tmp10)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp7, None)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp10, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_add_native_layer_norm_backward_20.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.063307776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
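# triton_red_fused..._20 is the matching first-stage reduction for the
# layer-norm parameter gradients: per 128-row chunk it accumulates
# sum(g * x_hat) (dgamma partials) and sum(g) (dbeta partials), where g is the
# combined fp32 + fp16 upstream gradient, producing two (768, 64) partial
# buffers. Eager-mode sketch (illustrative; names are hypothetical):
def layernorm_param_grad_partials(dy_fp32, dy_fp16, x_hat, chunks=64):
    g = (dy_fp32 + dy_fp16.float()).reshape(-1, dy_fp32.shape[-1])  # (8192, 768)
    xh = x_hat.reshape(-1, x_hat.shape[-1]).float()
    per = g.shape[0] // chunks
    dgamma = (g * xh).reshape(chunks, per, -1).sum(dim=1).t()  # (768, 64)
    dbeta = g.reshape(chunks, per, -1).sum(dim=1).t()          # (768, 64)
    return dgamma, dbeta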
# kernel path: /tmp/torchinductor_shunting/wl/cwl6ydl7tr2jbzx6zhrjkb6rmp4sbhadfjg44rsjimnvb7lidvb6.py
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21 = async_compile.triton('triton_poi_fused_clone_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025165824},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_21(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 6291456
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex % 64
x1 = (xindex // 64) % 512
x2 = (xindex // 32768) % 12
x3 = (xindex // 393216)
x4 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (393216*x3)), None).to(tl.float32)
tl.store(out_ptr0 + (x4), tmp0, None)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((16, 12, 512, 64), (393216, 32768, 64, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_clone_21.run(*args, 6291456, grid=grid(6291456), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_clone_21.benchmark_all_configs(*args, 6291456, grid=grid(6291456))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.025165824
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
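# triton_poi_fused_clone_21 materializes the (batch, seq, heads, head_dim) ->
# (batch, heads, seq, head_dim) transpose needed to hand the gradient back to
# the attention matmuls: x4 enumerates the contiguous output while the input
# offset x0 + 64*x2 + 768*x1 + 393216*x3 walks the (16, 512, 12, 64) view.
# Eager-mode equivalent (illustrative sketch):
def split_heads_ref(x_fp16, batch=16, seq=512, heads=12, head_dim=64):
    # (batch*seq, heads*head_dim) -> (batch, heads, seq, head_dim), materialized
    return (x_fp16.reshape(batch, seq, heads, head_dim)
                  .permute(0, 2, 1, 3)
                  .contiguous())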
# kernel path: /tmp/torchinductor_shunting/2w/c2w56xbyktkjz3c5fbogsijpit33ejsgka5ndga7xju74glcq3wd.py
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22 = async_compile.triton('triton_red_fused_sum_22', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_22', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_22(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_22.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_22.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.01277952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/jl/cjldro7q6f5pxz6zuzrjak63sjnv3jqrpdferyvy23uyblyxgymd.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*i1', 8: '*fp32', 9: '*fp16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.132156416}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp16 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
tmp22 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp30 = tl.load(in_ptr7 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp9 * tmp10
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask, tmp12, 0)
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp17 = tmp11 * tmp16
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp20 = tl.where(rmask, tmp18, 0)
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp23 = 768.0
tmp24 = tmp11 * tmp23
tmp25 = tmp24 - tmp15
tmp26 = tmp16 * tmp21
tmp27 = tmp25 - tmp26
tmp28 = tmp22 * tmp27
tmp29 = tmp28.to(tl.float32)
tmp31 = tmp30.to(tl.float32)
tmp32 = 1.1111111111111112
tmp33 = tmp31 * tmp32
tmp34 = tmp29 * tmp33
tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
tl.store(out_ptr4 + (r1 + (768*x0)), tmp34, rmask)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.132156416
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/nh/cnhqvaojz2ly5fypxyjca4dmreema6cy5ukpvnferpa7nqrba3em.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_24', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0884736}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp16 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp9 * tmp10
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
tmp14 = _tmp13 + tmp12
_tmp13 = tl.where(rmask, tmp14, _tmp13)
tmp15 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp17 = _tmp16 + tmp15
_tmp16 = tl.where(rmask, tmp17, _tmp16)
tmp13 = tl.sum(_tmp13, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp13, None)
tmp16 = tl.sum(_tmp16, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp16, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_add_native_layer_norm_backward_24.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.0884736
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/te/cte77lf37am2mn24wajel7rv6zip6voecouwoqp2tx4rqtc4kzcb.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_25 = async_compile.triton('triton_poi_fused_embedding_dense_backward_25', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[2048],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_25', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6.144e-06},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_25(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1536
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
def get_args():
arg_0 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_25.run(*args, 1536, grid=grid(1536), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_25.benchmark_all_configs(*args, 1536, grid=grid(1536))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 6.144e-06
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/6f/c6fngau4jk3lnabptysnet4fu7qocsijshhqwcxzkwwgptradatm.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_26 = async_compile.triton('triton_poi_fused_embedding_dense_backward_26', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.093763584},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_26(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 23440896
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
def get_args():
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_26.run(*args, 23440896, grid=grid(23440896), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_26.benchmark_all_configs(*args, 23440896, grid=grid(23440896))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.093763584
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
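# Kernels 25 and 26 only zero-initialize the fp32 gradient buffers for the
# (2, 768) and (30522, 768) embedding tables (token-type-sized and vocab-sized,
# respectively); kernel 27 below then scatters per-token row gradients into
# them with atomic adds. Eager-mode equivalent of the two fills (illustrative;
# the helper name is hypothetical):
def zeroed_embedding_grads(device='cuda'):
    import torch
    return (torch.zeros(2, 768, device=device, dtype=torch.float32),
            torch.zeros(30522, 768, device=device, dtype=torch.float32))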
# kernel path: /tmp/torchinductor_shunting/ha/chan35efrgnml6b6lkx5heboj7kad62m5wsfwbpqx4tq72w4wxpt.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_2
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*i1', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*i64', 9: '*i64', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: 'i32', 14: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27', 'mutated_arg_names': ['in_out_ptr0', 'out_ptr3', 'out_ptr4'], 'no_x_dim': True, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.169980928}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
x2 = xindex % 512
tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp15 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp21 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
tmp27 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp34 = tl.load(in_ptr7 + (x2), None, eviction_policy='evict_last')
tmp43 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp10.to(tl.float32)
tmp12 = 1.1111111111111112
tmp13 = tmp11 * tmp12
tmp14 = tmp9 * tmp13
tmp16 = tmp14 * tmp15
tmp17 = tl.broadcast_to(tmp16, [RBLOCK])
tmp19 = tl.where(rmask, tmp17, 0)
tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
tmp22 = tmp16 * tmp21
tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
tmp25 = tl.where(rmask, tmp23, 0)
tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
tmp28 = 768.0
tmp29 = tmp16 * tmp28
tmp30 = tmp29 - tmp20
tmp31 = tmp21 * tmp26
tmp32 = tmp30 - tmp31
tmp33 = tmp27 * tmp32
tmp35 = tl.full([RBLOCK], 2, tl.int32)
tmp36 = tmp34 + tmp35
tmp37 = tmp34 < 0
tmp38 = tl.where(tmp37, tmp36, tmp34)
tmp39 = tl.full([1], -1, tl.int64)
tmp40 = tmp34 == tmp39
tmp41 = 0.0
tmp42 = tl.where(tmp40, tmp41, tmp33)
tmp44 = tl.full([RBLOCK], 30522, tl.int32)
tmp45 = tmp43 + tmp44
tmp46 = tmp43 < 0
tmp47 = tl.where(tmp46, tmp45, tmp43)
tmp48 = tl.full([1], 0, tl.int64)
tmp49 = tmp43 == tmp48
tmp50 = tl.where(tmp49, tmp41, tmp33)
tl.store(in_out_ptr0 + (r1 + (768*x0)), tmp14, rmask)
tl.store(out_ptr2 + (r1 + (768*x0)), tmp33, rmask)
tl.atomic_add(out_ptr3 + (tl.broadcast_to(r1 + (768*tmp38), [RBLOCK])), tmp42, rmask, sem='relaxed')
tl.atomic_add(out_ptr4 + (tl.broadcast_to(r1 + (768*tmp47), [RBLOCK])), tmp50, rmask, sem='relaxed')
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_8 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_9 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_10 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_11 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_12 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.169980928
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
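# Note: every async_compile.triton(...) payload embeds a standalone harness (get_args / call /
# benchmark_all_configs) so the kernel can be timed in isolation; gb_per_s is derived from the
# inductor-estimated bytes moved (kernel_num_gb), not from measured traffic.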
# kernel path: /tmp/torchinductor_shunting/c7/cc74oeyzfdenaonorpxyc6q6zpqctt3hj57l7j5h46ugv67z2gei.py
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_red_fused_native_layer_norm_backward_28 = async_compile.triton('triton_red_fused_native_layer_norm_backward_28', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_28', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_native_layer_norm_backward_28(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
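# Split (two-stage) reduction, presumably for the LayerNorm weight/bias grads: _tmp4 accumulates
# sum(dy*xhat) and _tmp7 accumulates sum(dy) over 128-row chunks; the (768, 64) partials stored
# below are reduced to (768,) by a later kernel.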
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
tmp5 = _tmp4 + tmp3
_tmp4 = tl.where(rmask, tmp5, _tmp4)
tmp6 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(rmask, tmp8, _tmp7)
tmp4 = tl.sum(_tmp4, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp4, None)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp7, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_native_layer_norm_backward_28.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_native_layer_norm_backward_28.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.050724864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/55/c55h64q7s3knjikeayrydmub7yk3opg73oy3todkps33rlu5tz62.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_29 = async_compile.triton('triton_poi_fused_embedding_dense_backward_29', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[524288],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_29', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.001572864},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_29(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 393216
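# Zero-fill the (512, 768) gradient accumulator (393216 = 512*768), presumably for the position
# embedding, so later atomic adds can accumulate into it.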
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, None)
def get_args():
arg_0 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_29.run(*args, 393216, grid=grid(393216), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_29.benchmark_all_configs(*args, 393216, grid=grid(393216))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.001572864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/wv/cwv53zxut7fjvmy4vtbp25bc3pksunq7xtkpoeutsmskctsgwxni.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum]
# masked_lm_loss => full_default_2
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30 = async_compile.triton('triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[524288, 16],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30', 'mutated_arg_names': ['out_ptr1'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.026742784}
)
@triton.jit
def triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 393216
rnumel = 16
RBLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
x3 = (xindex // 768)
x2 = xindex % 768
tmp0 = tl.load(in_ptr0 + (x0 + (393216*r1)), None)
tmp4 = tl.load(in_ptr1 + (x3), None, eviction_policy='evict_last')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.sum(tmp1, 1)[:, None]
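# tmp3 sums the incoming grad over the 16 batch rows; the atomic_add below scatters the result
# into the (512, 768) embedding grad table, zeroing rows whose index equals -1.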
tmp5 = tl.full([XBLOCK, 1], 512, tl.int32)
tmp6 = tmp4 + tmp5
tmp7 = tmp4 < 0
tmp8 = tl.where(tmp7, tmp6, tmp4)
tmp9 = tl.full([1, 1], -1, tl.int64)
tmp10 = tmp4 == tmp9
tmp11 = 0.0
tmp12 = tl.where(tmp10, tmp11, tmp3)
tl.atomic_add(out_ptr1 + (x2 + (768*tmp8)), tmp12, None, sem='relaxed')
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_2 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.run(*args, 393216, 16, grid=grid(393216), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.benchmark_all_configs(*args, 393216, 16, grid=grid(393216))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.026742784
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2 = args
args.clear()
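# Fail fast if any saved activation deviates from the shape/stride the compiled graph assumed.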
assert_size_stride(primals_4, (768, ), (1, ))
assert_size_stride(primals_14, (768, ), (1, ))
assert_size_stride(primals_20, (768, ), (1, ))
assert_size_stride(primals_30, (768, ), (1, ))
assert_size_stride(primals_36, (768, ), (1, ))
assert_size_stride(primals_46, (768, ), (1, ))
assert_size_stride(primals_52, (768, ), (1, ))
assert_size_stride(primals_62, (768, ), (1, ))
assert_size_stride(primals_68, (768, ), (1, ))
assert_size_stride(primals_78, (768, ), (1, ))
assert_size_stride(primals_84, (768, ), (1, ))
assert_size_stride(primals_94, (768, ), (1, ))
assert_size_stride(primals_100, (768, ), (1, ))
assert_size_stride(primals_110, (768, ), (1, ))
assert_size_stride(primals_116, (768, ), (1, ))
assert_size_stride(primals_126, (768, ), (1, ))
assert_size_stride(primals_132, (768, ), (1, ))
assert_size_stride(primals_142, (768, ), (1, ))
assert_size_stride(primals_148, (768, ), (1, ))
assert_size_stride(primals_158, (768, ), (1, ))
assert_size_stride(primals_164, (768, ), (1, ))
assert_size_stride(primals_174, (768, ), (1, ))
assert_size_stride(primals_180, (768, ), (1, ))
assert_size_stride(primals_190, (768, ), (1, ))
assert_size_stride(primals_196, (768, ), (1, ))
assert_size_stride(primals_200, (768, ), (1, ))
assert_size_stride(primals_204, (1, 512), (512, 1))
assert_size_stride(primals_205, (1, 512), (512, 1))
assert_size_stride(primals_206, (16, 512), (512, 1))
assert_size_stride(primals_207, (16, 512), (512, 1))
assert_size_stride(mul_1, (16, 512, 768), (393216, 768, 1))
assert_size_stride(gt, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view, (8192, 768), (768, 1))
assert_size_stride(permute_default_66, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_67, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_68, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_129, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_130, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_131, (), ())
assert_size_stride(getitem_132, (), ())
assert_size_stride(view_16, (8192, 768), (768, 1))
assert_size_stride(gt_2, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_9, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_18, (8192, 768), (768, 1))
assert_size_stride(addmm_4, (8192, 3072), (3072, 1))
assert_size_stride(view_20, (8192, 3072), (3072, 1))
assert_size_stride(gt_3, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_16, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_22, (8192, 768), (768, 1))
assert_size_stride(permute_default_60, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_61, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_62, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_122, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_123, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_124, (), ())
assert_size_stride(getitem_125, (), ())
assert_size_stride(view_38, (8192, 768), (768, 1))
assert_size_stride(gt_5, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_22, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_40, (8192, 768), (768, 1))
assert_size_stride(addmm_10, (8192, 3072), (3072, 1))
assert_size_stride(view_42, (8192, 3072), (3072, 1))
assert_size_stride(gt_6, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_29, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_44, (8192, 768), (768, 1))
assert_size_stride(permute_default_54, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_55, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_56, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_115, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_116, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_117, (), ())
assert_size_stride(getitem_118, (), ())
assert_size_stride(view_60, (8192, 768), (768, 1))
assert_size_stride(gt_8, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_35, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_62, (8192, 768), (768, 1))
assert_size_stride(addmm_16, (8192, 3072), (3072, 1))
assert_size_stride(view_64, (8192, 3072), (3072, 1))
assert_size_stride(gt_9, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_42, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_66, (8192, 768), (768, 1))
assert_size_stride(permute_default_48, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_49, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_50, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_108, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_109, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_110, (), ())
assert_size_stride(getitem_111, (), ())
assert_size_stride(view_82, (8192, 768), (768, 1))
assert_size_stride(gt_11, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_48, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_84, (8192, 768), (768, 1))
assert_size_stride(addmm_22, (8192, 3072), (3072, 1))
assert_size_stride(view_86, (8192, 3072), (3072, 1))
assert_size_stride(gt_12, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_55, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_88, (8192, 768), (768, 1))
assert_size_stride(permute_default_42, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_43, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_44, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_101, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_102, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_103, (), ())
assert_size_stride(getitem_104, (), ())
assert_size_stride(view_104, (8192, 768), (768, 1))
assert_size_stride(gt_14, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_61, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_106, (8192, 768), (768, 1))
assert_size_stride(addmm_28, (8192, 3072), (3072, 1))
assert_size_stride(view_108, (8192, 3072), (3072, 1))
assert_size_stride(gt_15, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_68, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_110, (8192, 768), (768, 1))
assert_size_stride(permute_default_36, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_37, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_38, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_94, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_95, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_96, (), ())
assert_size_stride(getitem_97, (), ())
assert_size_stride(view_126, (8192, 768), (768, 1))
assert_size_stride(gt_17, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_74, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_128, (8192, 768), (768, 1))
assert_size_stride(addmm_34, (8192, 3072), (3072, 1))
assert_size_stride(view_130, (8192, 3072), (3072, 1))
assert_size_stride(gt_18, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_81, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_132, (8192, 768), (768, 1))
assert_size_stride(permute_default_30, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_31, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_32, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_87, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_88, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_89, (), ())
assert_size_stride(getitem_90, (), ())
assert_size_stride(view_148, (8192, 768), (768, 1))
assert_size_stride(gt_20, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_87, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_150, (8192, 768), (768, 1))
assert_size_stride(addmm_40, (8192, 3072), (3072, 1))
assert_size_stride(view_152, (8192, 3072), (3072, 1))
assert_size_stride(gt_21, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_94, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_154, (8192, 768), (768, 1))
assert_size_stride(permute_default_24, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_25, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_26, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_80, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_81, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_82, (), ())
assert_size_stride(getitem_83, (), ())
assert_size_stride(view_170, (8192, 768), (768, 1))
assert_size_stride(gt_23, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_100, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_172, (8192, 768), (768, 1))
assert_size_stride(addmm_46, (8192, 3072), (3072, 1))
assert_size_stride(view_174, (8192, 3072), (3072, 1))
assert_size_stride(gt_24, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_107, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_176, (8192, 768), (768, 1))
assert_size_stride(permute_default_18, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_19, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_20, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_73, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_74, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_75, (), ())
assert_size_stride(getitem_76, (), ())
assert_size_stride(view_192, (8192, 768), (768, 1))
assert_size_stride(gt_26, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_113, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_194, (8192, 768), (768, 1))
assert_size_stride(addmm_52, (8192, 3072), (3072, 1))
assert_size_stride(view_196, (8192, 3072), (3072, 1))
assert_size_stride(gt_27, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_120, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_198, (8192, 768), (768, 1))
assert_size_stride(permute_default_12, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_13, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_14, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_66, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_67, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_68, (), ())
assert_size_stride(getitem_69, (), ())
assert_size_stride(view_214, (8192, 768), (768, 1))
assert_size_stride(gt_29, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_126, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_216, (8192, 768), (768, 1))
assert_size_stride(addmm_58, (8192, 3072), (3072, 1))
assert_size_stride(view_218, (8192, 3072), (3072, 1))
assert_size_stride(gt_30, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_133, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_220, (8192, 768), (768, 1))
assert_size_stride(permute_default_6, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_7, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_8, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_59, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_60, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_61, (), ())
assert_size_stride(getitem_62, (), ())
assert_size_stride(view_236, (8192, 768), (768, 1))
assert_size_stride(gt_32, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_139, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_238, (8192, 768), (768, 1))
assert_size_stride(addmm_64, (8192, 3072), (3072, 1))
assert_size_stride(view_240, (8192, 3072), (3072, 1))
assert_size_stride(gt_33, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_146, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_242, (8192, 768), (768, 1))
assert_size_stride(permute_default, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_1, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_2, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_52, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_53, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_54, (), ())
assert_size_stride(getitem_55, (), ())
assert_size_stride(view_258, (8192, 768), (768, 1))
assert_size_stride(gt_35, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_152, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_260, (8192, 768), (768, 1))
assert_size_stride(addmm_70, (8192, 3072), (3072, 1))
assert_size_stride(view_262, (8192, 3072), (3072, 1))
assert_size_stride(gt_36, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_159, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_264, (8192, 768), (768, 1))
assert_size_stride(addmm_72, (8192, 768), (768, 1))
assert_size_stride(getitem_51, (16, 512, 1), (512, 1, 1))
assert_size_stride(rsqrt_25, (16, 512, 1), (512, 1, 1))
assert_size_stride(view_266, (8192, 768), (768, 1))
assert_size_stride(view_267, (16, 512, 30522), (15630336, 30528, 1))
assert_size_stride(amax_12, (8192, 1), (1, 1))
assert_size_stride(log, (8192, 1), (1, 1))
assert_size_stride(convert_element_type_510, (), ())
assert_size_stride(permute_134, (30522, 768), (768, 1))
assert_size_stride(permute_138, (768, 768), (768, 1))
assert_size_stride(div_27, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_142, (768, 3072), (3072, 1))
assert_size_stride(permute_146, (3072, 768), (768, 1))
assert_size_stride(div_28, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_150, (768, 768), (768, 1))
assert_size_stride(permute_162, (768, 768), (768, 1))
assert_size_stride(permute_167, (768, 768), (768, 1))
assert_size_stride(permute_171, (768, 768), (768, 1))
assert_size_stride(div_30, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_175, (768, 3072), (3072, 1))
assert_size_stride(permute_179, (3072, 768), (768, 1))
assert_size_stride(div_31, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_183, (768, 768), (768, 1))
assert_size_stride(permute_195, (768, 768), (768, 1))
assert_size_stride(permute_200, (768, 768), (768, 1))
assert_size_stride(permute_204, (768, 768), (768, 1))
assert_size_stride(div_33, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_208, (768, 3072), (3072, 1))
assert_size_stride(permute_212, (3072, 768), (768, 1))
assert_size_stride(div_34, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_216, (768, 768), (768, 1))
assert_size_stride(permute_228, (768, 768), (768, 1))
assert_size_stride(permute_233, (768, 768), (768, 1))
assert_size_stride(permute_237, (768, 768), (768, 1))
assert_size_stride(div_36, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_241, (768, 3072), (3072, 1))
assert_size_stride(permute_245, (3072, 768), (768, 1))
assert_size_stride(div_37, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_249, (768, 768), (768, 1))
assert_size_stride(permute_261, (768, 768), (768, 1))
assert_size_stride(permute_266, (768, 768), (768, 1))
assert_size_stride(permute_270, (768, 768), (768, 1))
assert_size_stride(div_39, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_274, (768, 3072), (3072, 1))
assert_size_stride(permute_278, (3072, 768), (768, 1))
assert_size_stride(div_40, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_282, (768, 768), (768, 1))
assert_size_stride(permute_294, (768, 768), (768, 1))
assert_size_stride(permute_299, (768, 768), (768, 1))
assert_size_stride(permute_303, (768, 768), (768, 1))
assert_size_stride(div_42, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_307, (768, 3072), (3072, 1))
assert_size_stride(permute_311, (3072, 768), (768, 1))
assert_size_stride(div_43, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_315, (768, 768), (768, 1))
assert_size_stride(permute_327, (768, 768), (768, 1))
assert_size_stride(permute_332, (768, 768), (768, 1))
assert_size_stride(permute_336, (768, 768), (768, 1))
assert_size_stride(div_45, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_340, (768, 3072), (3072, 1))
assert_size_stride(permute_344, (3072, 768), (768, 1))
assert_size_stride(div_46, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_348, (768, 768), (768, 1))
assert_size_stride(permute_360, (768, 768), (768, 1))
assert_size_stride(permute_365, (768, 768), (768, 1))
assert_size_stride(permute_369, (768, 768), (768, 1))
assert_size_stride(div_48, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_373, (768, 3072), (3072, 1))
assert_size_stride(permute_377, (3072, 768), (768, 1))
assert_size_stride(div_49, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_381, (768, 768), (768, 1))
assert_size_stride(permute_393, (768, 768), (768, 1))
assert_size_stride(permute_398, (768, 768), (768, 1))
assert_size_stride(permute_402, (768, 768), (768, 1))
assert_size_stride(div_51, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_406, (768, 3072), (3072, 1))
assert_size_stride(permute_410, (3072, 768), (768, 1))
assert_size_stride(div_52, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_414, (768, 768), (768, 1))
assert_size_stride(permute_426, (768, 768), (768, 1))
assert_size_stride(permute_431, (768, 768), (768, 1))
assert_size_stride(permute_435, (768, 768), (768, 1))
assert_size_stride(div_54, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_439, (768, 3072), (3072, 1))
assert_size_stride(permute_443, (3072, 768), (768, 1))
assert_size_stride(div_55, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_447, (768, 768), (768, 1))
assert_size_stride(permute_459, (768, 768), (768, 1))
assert_size_stride(permute_464, (768, 768), (768, 1))
assert_size_stride(permute_468, (768, 768), (768, 1))
assert_size_stride(div_57, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_472, (768, 3072), (3072, 1))
assert_size_stride(permute_476, (3072, 768), (768, 1))
assert_size_stride(div_58, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_480, (768, 768), (768, 1))
assert_size_stride(permute_492, (768, 768), (768, 1))
assert_size_stride(permute_497, (768, 768), (768, 1))
assert_size_stride(permute_501, (768, 768), (768, 1))
assert_size_stride(div_60, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_505, (768, 3072), (3072, 1))
assert_size_stride(permute_509, (3072, 768), (768, 1))
assert_size_stride(div_61, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_513, (768, 768), (768, 1))
assert_size_stride(permute_525, (768, 768), (768, 1))
assert_size_stride(permute_530, (768, 768), (768, 1))
assert_size_stride(permute_534, (768, 768), (768, 1))
assert_size_stride(div_63, (16, 512, 1), (512, 1, 1))
assert_size_stride(tangents_1, (), ())
assert_size_stride(tangents_2, (16, 512, 30522), (15627264, 30522, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((16, 512, 30522), (15630336, 30528, 1), torch.float16)
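# The logits-grad buffer keeps the padded row stride (30528 vs. vocab size 30522) so fp16 rows
# stay 16-byte aligned for the GEMMs below.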
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward]
stream0 = get_raw_stream(0)
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_0.run(primals_207, tangents_1, convert_element_type_510, tangents_2, view_267, amax_12, log, buf1, 8192, 30522, grid=grid(8192), stream=stream0)
del amax_12
del convert_element_type_510
del log
del primals_207
del tangents_1
del tangents_2
del view_267
buf2 = empty_strided_cuda((8192, 30528), (30528, 1), torch.float16)
# Source Nodes: [], Original ATen: []
triton_poi_fused_1.run(buf1, buf2, 250085376, grid=grid(250085376), stream=stream0)
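# Repack the logits grad into a dense (8192, 30528) fp16 matrix (8192*30528 = 250085376 elements)
# so it can feed the extern mm below.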
buf3 = empty_strided_cuda((30528, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: []
triton_poi_fused_2.run(permute_134, buf3, 23445504, grid=grid(23445504), stream=stream0)
del permute_134
buf4 = empty_strided_cuda((8192, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(buf2, buf3, out=buf4)
del buf2
del buf3
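# buf4 = dY @ W is the grad flowing back into the MLM head input; the next mm forms the decoder
# weight grad dW = dY^T @ X via the (30522, 8192) reinterpret of buf1.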
buf5 = empty_strided_cuda((30522, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1, (30522, 8192), (1, 30528), 0), view_266, out=buf5)
del view_266
buf8 = empty_strided_cuda((30522, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_red_fused__to_copy_sum_3.run(buf1, buf8, 30522, 8192, grid=grid(30522), stream=stream0)
del buf1
buf7 = empty_strided_cuda((30522, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(buf5, buf7, 23440896, grid=grid(23440896), stream=stream0)
del buf5
buf16 = empty_strided_cuda((8192, 768), (768, 1), torch.float16)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view]
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_5.run(buf4, primals_200, addmm_72, getitem_51, rsqrt_25, buf16, 8192, 768, grid=grid(8192), stream=stream0)
del primals_200
buf11 = empty_strided_cuda((768, 64), (1, 768), torch.float32)
buf13 = empty_strided_cuda((768, 64), (1, 768), torch.float32)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_6.run(buf4, addmm_72, getitem_51, rsqrt_25, buf11, buf13, 49152, 128, grid=grid(49152), stream=stream0)
del addmm_72
del getitem_51
del rsqrt_25
buf12 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf11, buf12, 768, 64, grid=grid(768), stream=stream0)
buf14 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf13, buf14, 768, 64, grid=grid(768), stream=stream0)
buf17 = buf4; del buf4 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(buf16, permute_138, out=buf17)
del permute_138
buf18 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf16, (768, 8192), (1, 768), 0), view_264, out=buf18)
del view_264
buf19 = reinterpret_tensor(buf13, (1, 768, 64), (49152, 1, 768), 0); del buf13 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_8.run(buf16, buf19, 49152, 128, grid=grid(49152), stream=stream0)
buf22 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf19, buf22, 768, 64, grid=grid(768), stream=stream0)
buf21 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf18, buf21, 589824, grid=grid(589824), stream=stream0)
buf25 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32)
buf30 = reinterpret_tensor(buf16, (16, 512, 768), (393216, 768, 1), 0); del buf16 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_11.run(buf17, primals_196, mul_159, div_27, gt_36, buf25, buf30, 8192, 768, grid=grid(8192), stream=stream0)
del div_27
del gt_36
del primals_196
buf26 = reinterpret_tensor(buf19, (768, 64), (1, 768), 0); del buf19 # reuse
buf28 = buf11; del buf11 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_red_fused__to_copy_native_layer_norm_backward_12.run(buf17, mul_159, buf26, buf28, 49152, 128, grid=grid(49152), stream=stream0)
del mul_159
buf27 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf26, buf27, 768, 64, grid=grid(768), stream=stream0)
buf29 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf28, buf29, 768, 64, grid=grid(768), stream=stream0)
buf31 = empty_strided_cuda((8192, 3072), (3072, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf30, (8192, 768), (768, 1), 0), permute_142, out=buf31)
del permute_142
buf32 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf30, (768, 8192), (1, 768), 0), view_262, out=buf32)
del view_262
buf33 = reinterpret_tensor(buf28, (1, 768, 64), (49152, 1, 768), 0); del buf28 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf30, buf33, 49152, 128, grid=grid(49152), stream=stream0)
buf36 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf33, buf36, 768, 64, grid=grid(768), stream=stream0)
buf35 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf32, buf35, 2359296, grid=grid(2359296), stream=stream0)
buf37 = reinterpret_tensor(buf31, (16, 512, 3072), (1572864, 3072, 1), 0); del buf31 # reuse
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf37, addmm_70, 25165824, grid=grid(25165824), stream=stream0)
del addmm_70
buf38 = reinterpret_tensor(buf30, (8192, 768), (768, 1), 0); del buf30 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf37, (8192, 3072), (3072, 1), 0), permute_146, out=buf38)
del permute_146
buf39 = reinterpret_tensor(buf32, (3072, 768), (768, 1), 0); del buf32 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf37, (3072, 8192), (1, 3072), 0), view_260, out=buf39)
del view_260
buf40 = empty_strided_cuda((1, 3072, 32), (98304, 1, 3072), torch.float32)
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf37, buf40, 98304, 256, grid=grid(98304), stream=stream0)
buf43 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf40, buf43, 3072, 32, grid=grid(3072), stream=stream0)
buf42 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf39, buf42, 2359296, grid=grid(2359296), stream=stream0)
buf46 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32)
buf51 = reinterpret_tensor(buf17, (16, 512, 768), (393216, 768, 1), 0); del buf17 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf25, buf38, primals_190, mul_152, div_28, gt_35, buf46, buf51, 8192, 768, grid=grid(8192), stream=stream0)
del div_28
del gt_35
del primals_190
buf47 = reinterpret_tensor(buf33, (768, 64), (1, 768), 0); del buf33 # reuse
buf49 = buf26; del buf26 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf25, buf38, mul_152, buf47, buf49, 49152, 128, grid=grid(49152), stream=stream0)
del mul_152
buf48 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf47, buf48, 768, 64, grid=grid(768), stream=stream0)
buf50 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf49, buf50, 768, 64, grid=grid(768), stream=stream0)
buf52 = buf38; del buf38 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf51, (8192, 768), (768, 1), 0), permute_150, out=buf52)
del permute_150
buf53 = buf18; del buf18 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf51, (768, 8192), (1, 768), 0), view_258, out=buf53)
del view_258
buf54 = reinterpret_tensor(buf49, (1, 768, 64), (49152, 1, 768), 0); del buf49 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf51, buf54, 49152, 128, grid=grid(49152), stream=stream0)
buf57 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf54, buf57, 768, 64, grid=grid(768), stream=stream0)
buf56 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf53, buf56, 589824, grid=grid(589824), stream=stream0)
buf58 = reinterpret_tensor(buf51, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf51 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf52, buf58, 6291456, grid=grid(6291456), stream=stream0)
del buf52
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf59 = aten._scaled_dot_product_flash_attention_backward.default(buf58, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, None, None, 512, 512, 0.1, False, getitem_54, getitem_55, scale=0.125)
del getitem_52
del getitem_53
del getitem_54
del getitem_55
del permute_default
del permute_default_1
del permute_default_2
buf60 = buf59[0]
buf61 = buf59[1]
buf62 = buf59[2]
del buf59
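# buf60, buf61, buf62 hold grad_query, grad_key and grad_value returned by the flash-attention backward.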
buf63 = reinterpret_tensor(buf58, (8192, 768), (768, 1), 0); del buf58 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf62, (8192, 768), (768, 1), 0), permute_162, out=buf63)
del permute_162
buf64 = buf53; del buf53 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf62, (768, 8192), (1, 768), 0), view_242, out=buf64)
buf65 = buf54; del buf54 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf62, buf65, 49152, 128, grid=grid(49152), stream=stream0)
buf68 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf65, buf68, 768, 64, grid=grid(768), stream=stream0)
buf67 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf64, buf67, 589824, grid=grid(589824), stream=stream0)
buf69 = reinterpret_tensor(buf62, (8192, 768), (768, 1), 0); del buf62 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf61, (8192, 768), (768, 1), 0), permute_167, out=buf69)
del permute_167
buf70 = buf64; del buf64 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf61, (768, 8192), (1, 768), 0), view_242, out=buf70)
buf71 = buf65; del buf65 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf61, buf71, 49152, 128, grid=grid(49152), stream=stream0)
buf74 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf71, buf74, 768, 64, grid=grid(768), stream=stream0)
buf73 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf70, buf73, 589824, grid=grid(589824), stream=stream0)
buf75 = reinterpret_tensor(buf61, (8192, 768), (768, 1), 0); del buf61 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf60, (8192, 768), (768, 1), 0), permute_171, out=buf75)
del permute_171
buf76 = buf70; del buf70 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf60, (768, 8192), (1, 768), 0), view_242, out=buf76)
del view_242
buf77 = buf71; del buf71 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf60, buf77, 49152, 128, grid=grid(49152), stream=stream0)
buf80 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf77, buf80, 768, 64, grid=grid(768), stream=stream0)
buf79 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf76, buf79, 589824, grid=grid(589824), stream=stream0)
buf84 = buf25; del buf25 # reuse
buf89 = reinterpret_tensor(buf60, (16, 512, 768), (393216, 768, 1), 0); del buf60 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf46, buf63, buf69, buf75, primals_180, mul_146, div_30, gt_33, buf84, buf89, 8192, 768, grid=grid(8192), stream=stream0)
del div_30
del gt_33
del primals_180
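# From here the same per-layer backward pattern repeats for the remaining encoder layers:
# MLP mm + gelu backward, LayerNorm backward, flash-attention backward, then the Q/K/V projection grads.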
buf85 = reinterpret_tensor(buf77, (768, 64), (1, 768), 0); del buf77 # reuse
buf87 = buf47; del buf47 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf46, buf63, buf69, buf75, mul_146, buf85, buf87, 49152, 128, grid=grid(49152), stream=stream0)
del buf63
del buf69
del mul_146
buf86 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf85, buf86, 768, 64, grid=grid(768), stream=stream0)
buf88 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf87, buf88, 768, 64, grid=grid(768), stream=stream0)
buf90 = reinterpret_tensor(buf37, (8192, 3072), (3072, 1), 0); del buf37 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf89, (8192, 768), (768, 1), 0), permute_175, out=buf90)
del permute_175
buf91 = reinterpret_tensor(buf39, (768, 3072), (3072, 1), 0); del buf39 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf89, (768, 8192), (1, 768), 0), view_240, out=buf91)
del view_240
buf92 = reinterpret_tensor(buf87, (1, 768, 64), (49152, 1, 768), 0); del buf87 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf89, buf92, 49152, 128, grid=grid(49152), stream=stream0)
buf95 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf92, buf95, 768, 64, grid=grid(768), stream=stream0)
buf94 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf91, buf94, 2359296, grid=grid(2359296), stream=stream0)
buf96 = reinterpret_tensor(buf90, (16, 512, 3072), (1572864, 3072, 1), 0); del buf90 # reuse
# Source Nodes: [hidden_states_84], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf96, addmm_64, 25165824, grid=grid(25165824), stream=stream0)
del addmm_64
buf97 = reinterpret_tensor(buf89, (8192, 768), (768, 1), 0); del buf89 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf96, (8192, 3072), (3072, 1), 0), permute_179, out=buf97)
del permute_179
buf98 = reinterpret_tensor(buf91, (3072, 768), (768, 1), 0); del buf91 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf96, (3072, 8192), (1, 3072), 0), view_238, out=buf98)
del view_238
buf99 = buf40; del buf40 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf96, buf99, 98304, 256, grid=grid(98304), stream=stream0)
buf102 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf99, buf102, 3072, 32, grid=grid(3072), stream=stream0)
buf101 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf98, buf101, 2359296, grid=grid(2359296), stream=stream0)
buf105 = buf46; del buf46 # reuse
buf110 = reinterpret_tensor(buf75, (16, 512, 768), (393216, 768, 1), 0); del buf75 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf84, buf97, primals_174, mul_139, div_31, gt_32, buf105, buf110, 8192, 768, grid=grid(8192), stream=stream0)
del div_31
del gt_32
del primals_174
buf106 = reinterpret_tensor(buf92, (768, 64), (1, 768), 0); del buf92 # reuse
buf108 = buf85; del buf85 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf84, buf97, mul_139, buf106, buf108, 49152, 128, grid=grid(49152), stream=stream0)
del mul_139
buf107 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf106, buf107, 768, 64, grid=grid(768), stream=stream0)
buf109 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf108, buf109, 768, 64, grid=grid(768), stream=stream0)
buf111 = buf97; del buf97 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf110, (8192, 768), (768, 1), 0), permute_183, out=buf111)
del permute_183
buf112 = buf76; del buf76 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf110, (768, 8192), (1, 768), 0), view_236, out=buf112)
del view_236
buf113 = reinterpret_tensor(buf108, (1, 768, 64), (49152, 1, 768), 0); del buf108 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf110, buf113, 49152, 128, grid=grid(49152), stream=stream0)
buf116 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf113, buf116, 768, 64, grid=grid(768), stream=stream0)
buf115 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf112, buf115, 589824, grid=grid(589824), stream=stream0)
buf117 = reinterpret_tensor(buf110, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf110 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf111, buf117, 6291456, grid=grid(6291456), stream=stream0)
del buf111
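# buf117 is the upstream gradient reshaped to (batch=16, heads=12, seq=512,
# head_dim=64) for the flash-attention backward fallback below
# (dropout_p=0.1, scale=0.125 = 1/sqrt(64)).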
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf118 = aten._scaled_dot_product_flash_attention_backward.default(buf117, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, None, None, 512, 512, 0.1, False, getitem_61, getitem_62, scale=0.125)
del getitem_59
del getitem_60
del getitem_61
del getitem_62
del permute_default_6
del permute_default_7
del permute_default_8
buf119 = buf118[0]
buf120 = buf118[1]
buf121 = buf118[2]
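# _scaled_dot_product_flash_attention_backward returns (grad_query,
# grad_key, grad_value), so buf119/buf120/buf121 are dQ/dK/dV. The three
# mm pairs below backprop them through the value, key, and query
# projections, which all share the saved input view_220.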
del buf118
buf122 = reinterpret_tensor(buf117, (8192, 768), (768, 1), 0); del buf117 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf121, (8192, 768), (768, 1), 0), permute_195, out=buf122)
del permute_195
buf123 = buf112; del buf112 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf121, (768, 8192), (1, 768), 0), view_220, out=buf123)
buf124 = buf113; del buf113 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf121, buf124, 49152, 128, grid=grid(49152), stream=stream0)
buf127 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf124, buf127, 768, 64, grid=grid(768), stream=stream0)
buf126 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf123, buf126, 589824, grid=grid(589824), stream=stream0)
buf128 = reinterpret_tensor(buf121, (8192, 768), (768, 1), 0); del buf121 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf120, (8192, 768), (768, 1), 0), permute_200, out=buf128)
del permute_200
buf129 = buf123; del buf123 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf120, (768, 8192), (1, 768), 0), view_220, out=buf129)
buf130 = buf124; del buf124 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf120, buf130, 49152, 128, grid=grid(49152), stream=stream0)
buf133 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf130, buf133, 768, 64, grid=grid(768), stream=stream0)
buf132 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf129, buf132, 589824, grid=grid(589824), stream=stream0)
buf134 = reinterpret_tensor(buf120, (8192, 768), (768, 1), 0); del buf120 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf119, (8192, 768), (768, 1), 0), permute_204, out=buf134)
del permute_204
buf135 = buf129; del buf129 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf119, (768, 8192), (1, 768), 0), view_220, out=buf135)
del view_220
buf136 = buf130; del buf130 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf119, buf136, 49152, 128, grid=grid(49152), stream=stream0)
buf139 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf136, buf139, 768, 64, grid=grid(768), stream=stream0)
buf138 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf135, buf138, 589824, grid=grid(589824), stream=stream0)
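# Attention backward for this layer is complete. The fused kernel below adds
# the residual-stream gradient (buf105) to the three projection input-grads
# (buf122/buf128/buf134), applies dropout + LayerNorm backward, and produces
# buf143 (residual stream) and buf148 (gradient into what appears to be the
# preceding layer's FFN output dense) -- i.e., this is the layer boundary.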
buf143 = buf84; del buf84 # reuse
buf148 = reinterpret_tensor(buf119, (16, 512, 768), (393216, 768, 1), 0); del buf119 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf105, buf122, buf128, buf134, primals_164, mul_133, div_33, gt_30, buf143, buf148, 8192, 768, grid=grid(8192), stream=stream0)
del div_33
del gt_30
del primals_164
buf144 = reinterpret_tensor(buf136, (768, 64), (1, 768), 0); del buf136 # reuse
buf146 = buf106; del buf106 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf105, buf122, buf128, buf134, mul_133, buf144, buf146, 49152, 128, grid=grid(49152), stream=stream0)
del buf122
del buf128
del mul_133
buf145 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf144, buf145, 768, 64, grid=grid(768), stream=stream0)
buf147 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf146, buf147, 768, 64, grid=grid(768), stream=stream0)
buf149 = reinterpret_tensor(buf96, (8192, 3072), (3072, 1), 0); del buf96 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf148, (8192, 768), (768, 1), 0), permute_208, out=buf149)
del permute_208
buf150 = reinterpret_tensor(buf98, (768, 3072), (3072, 1), 0); del buf98 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf148, (768, 8192), (1, 768), 0), view_218, out=buf150)
del view_218
buf151 = reinterpret_tensor(buf146, (1, 768, 64), (49152, 1, 768), 0); del buf146 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf148, buf151, 49152, 128, grid=grid(49152), stream=stream0)
buf154 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf151, buf154, 768, 64, grid=grid(768), stream=stream0)
buf153 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf150, buf153, 2359296, grid=grid(2359296), stream=stream0)
buf155 = reinterpret_tensor(buf149, (16, 512, 3072), (1572864, 3072, 1), 0); del buf149 # reuse
# Source Nodes: [hidden_states_76], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf155, addmm_58, 25165824, grid=grid(25165824), stream=stream0)
del addmm_58
buf156 = reinterpret_tensor(buf148, (8192, 768), (768, 1), 0); del buf148 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf155, (8192, 3072), (3072, 1), 0), permute_212, out=buf156)
del permute_212
buf157 = reinterpret_tensor(buf150, (3072, 768), (768, 1), 0); del buf150 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf155, (3072, 8192), (1, 3072), 0), view_216, out=buf157)
del view_216
buf158 = buf99; del buf99 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf155, buf158, 98304, 256, grid=grid(98304), stream=stream0)
buf161 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf158, buf161, 3072, 32, grid=grid(3072), stream=stream0)
buf160 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf157, buf160, 2359296, grid=grid(2359296), stream=stream0)
buf164 = buf105; del buf105 # reuse
buf169 = reinterpret_tensor(buf134, (16, 512, 768), (393216, 768, 1), 0); del buf134 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf143, buf156, primals_158, mul_126, div_34, gt_29, buf164, buf169, 8192, 768, grid=grid(8192), stream=stream0)
del div_34
del gt_29
del primals_158
buf165 = reinterpret_tensor(buf151, (768, 64), (1, 768), 0); del buf151 # reuse
buf167 = buf144; del buf144 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf143, buf156, mul_126, buf165, buf167, 49152, 128, grid=grid(49152), stream=stream0)
del mul_126
buf166 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf165, buf166, 768, 64, grid=grid(768), stream=stream0)
buf168 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf167, buf168, 768, 64, grid=grid(768), stream=stream0)
buf170 = buf156; del buf156 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf169, (8192, 768), (768, 1), 0), permute_216, out=buf170)
del permute_216
buf171 = buf135; del buf135 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf169, (768, 8192), (1, 768), 0), view_214, out=buf171)
del view_214
buf172 = reinterpret_tensor(buf167, (1, 768, 64), (49152, 1, 768), 0); del buf167 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf169, buf172, 49152, 128, grid=grid(49152), stream=stream0)
buf175 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf172, buf175, 768, 64, grid=grid(768), stream=stream0)
buf174 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf171, buf174, 589824, grid=grid(589824), stream=stream0)
buf176 = reinterpret_tensor(buf169, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf169 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf170, buf176, 6291456, grid=grid(6291456), stream=stream0)
del buf170
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf177 = aten._scaled_dot_product_flash_attention_backward.default(buf176, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, None, None, 512, 512, 0.1, False, getitem_68, getitem_69, scale=0.125)
del getitem_66
del getitem_67
del getitem_68
del getitem_69
del permute_default_12
del permute_default_13
del permute_default_14
buf178 = buf177[0]
buf179 = buf177[1]
buf180 = buf177[2]
del buf177
buf181 = reinterpret_tensor(buf176, (8192, 768), (768, 1), 0); del buf176 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf180, (8192, 768), (768, 1), 0), permute_228, out=buf181)
del permute_228
buf182 = buf171; del buf171 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf180, (768, 8192), (1, 768), 0), view_198, out=buf182)
buf183 = buf172; del buf172 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf180, buf183, 49152, 128, grid=grid(49152), stream=stream0)
buf186 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf183, buf186, 768, 64, grid=grid(768), stream=stream0)
buf185 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf182, buf185, 589824, grid=grid(589824), stream=stream0)
buf187 = reinterpret_tensor(buf180, (8192, 768), (768, 1), 0); del buf180 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf179, (8192, 768), (768, 1), 0), permute_233, out=buf187)
del permute_233
buf188 = buf182; del buf182 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf179, (768, 8192), (1, 768), 0), view_198, out=buf188)
buf189 = buf183; del buf183 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf179, buf189, 49152, 128, grid=grid(49152), stream=stream0)
buf192 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf189, buf192, 768, 64, grid=grid(768), stream=stream0)
buf191 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf188, buf191, 589824, grid=grid(589824), stream=stream0)
buf193 = reinterpret_tensor(buf179, (8192, 768), (768, 1), 0); del buf179 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf178, (8192, 768), (768, 1), 0), permute_237, out=buf193)
del permute_237
buf194 = buf188; del buf188 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf178, (768, 8192), (1, 768), 0), view_198, out=buf194)
del view_198
buf195 = buf189; del buf189 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf178, buf195, 49152, 128, grid=grid(49152), stream=stream0)
buf198 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf195, buf198, 768, 64, grid=grid(768), stream=stream0)
buf197 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf194, buf197, 589824, grid=grid(589824), stream=stream0)
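# Layer boundary: same fused residual-add + dropout + LayerNorm backward as above.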
buf202 = buf143; del buf143 # reuse
buf207 = reinterpret_tensor(buf178, (16, 512, 768), (393216, 768, 1), 0); del buf178 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf164, buf181, buf187, buf193, primals_148, mul_120, div_36, gt_27, buf202, buf207, 8192, 768, grid=grid(8192), stream=stream0)
del div_36
del gt_27
del primals_148
buf203 = reinterpret_tensor(buf195, (768, 64), (1, 768), 0); del buf195 # reuse
buf205 = buf165; del buf165 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf164, buf181, buf187, buf193, mul_120, buf203, buf205, 49152, 128, grid=grid(49152), stream=stream0)
del buf181
del buf187
del mul_120
buf204 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf203, buf204, 768, 64, grid=grid(768), stream=stream0)
buf206 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf205, buf206, 768, 64, grid=grid(768), stream=stream0)
buf208 = reinterpret_tensor(buf155, (8192, 3072), (3072, 1), 0); del buf155 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf207, (8192, 768), (768, 1), 0), permute_241, out=buf208)
del permute_241
buf209 = reinterpret_tensor(buf157, (768, 3072), (3072, 1), 0); del buf157 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf207, (768, 8192), (1, 768), 0), view_196, out=buf209)
del view_196
buf210 = reinterpret_tensor(buf205, (1, 768, 64), (49152, 1, 768), 0); del buf205 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf207, buf210, 49152, 128, grid=grid(49152), stream=stream0)
buf213 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf210, buf213, 768, 64, grid=grid(768), stream=stream0)
buf212 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf209, buf212, 2359296, grid=grid(2359296), stream=stream0)
buf214 = reinterpret_tensor(buf208, (16, 512, 3072), (1572864, 3072, 1), 0); del buf208 # reuse
# Source Nodes: [hidden_states_68], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf214, addmm_52, 25165824, grid=grid(25165824), stream=stream0)
del addmm_52
buf215 = reinterpret_tensor(buf207, (8192, 768), (768, 1), 0); del buf207 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf214, (8192, 3072), (3072, 1), 0), permute_245, out=buf215)
del permute_245
buf216 = reinterpret_tensor(buf209, (3072, 768), (768, 1), 0); del buf209 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf214, (3072, 8192), (1, 3072), 0), view_194, out=buf216)
del view_194
buf217 = buf158; del buf158 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf214, buf217, 98304, 256, grid=grid(98304), stream=stream0)
buf220 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf217, buf220, 3072, 32, grid=grid(3072), stream=stream0)
buf219 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf216, buf219, 2359296, grid=grid(2359296), stream=stream0)
buf223 = buf164; del buf164 # reuse
buf228 = reinterpret_tensor(buf193, (16, 512, 768), (393216, 768, 1), 0); del buf193 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf202, buf215, primals_142, mul_113, div_37, gt_26, buf223, buf228, 8192, 768, grid=grid(8192), stream=stream0)
del div_37
del gt_26
del primals_142
buf224 = reinterpret_tensor(buf210, (768, 64), (1, 768), 0); del buf210 # reuse
buf226 = buf203; del buf203 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf202, buf215, mul_113, buf224, buf226, 49152, 128, grid=grid(49152), stream=stream0)
del mul_113
buf225 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf224, buf225, 768, 64, grid=grid(768), stream=stream0)
buf227 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf226, buf227, 768, 64, grid=grid(768), stream=stream0)
buf229 = buf215; del buf215 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf228, (8192, 768), (768, 1), 0), permute_249, out=buf229)
del permute_249
buf230 = buf194; del buf194 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf228, (768, 8192), (1, 768), 0), view_192, out=buf230)
del view_192
buf231 = reinterpret_tensor(buf226, (1, 768, 64), (49152, 1, 768), 0); del buf226 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf228, buf231, 49152, 128, grid=grid(49152), stream=stream0)
buf234 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf231, buf234, 768, 64, grid=grid(768), stream=stream0)
buf233 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf230, buf233, 589824, grid=grid(589824), stream=stream0)
buf235 = reinterpret_tensor(buf228, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf228 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf229, buf235, 6291456, grid=grid(6291456), stream=stream0)
del buf229
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf236 = aten._scaled_dot_product_flash_attention_backward.default(buf235, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, None, None, 512, 512, 0.1, False, getitem_75, getitem_76, scale=0.125)
del getitem_73
del getitem_74
del getitem_75
del getitem_76
del permute_default_18
del permute_default_19
del permute_default_20
buf237 = buf236[0]
buf238 = buf236[1]
buf239 = buf236[2]
del buf236
buf240 = reinterpret_tensor(buf235, (8192, 768), (768, 1), 0); del buf235 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf239, (8192, 768), (768, 1), 0), permute_261, out=buf240)
del permute_261
buf241 = buf230; del buf230 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf239, (768, 8192), (1, 768), 0), view_176, out=buf241)
buf242 = buf231; del buf231 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf239, buf242, 49152, 128, grid=grid(49152), stream=stream0)
buf245 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf242, buf245, 768, 64, grid=grid(768), stream=stream0)
buf244 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf241, buf244, 589824, grid=grid(589824), stream=stream0)
buf246 = reinterpret_tensor(buf239, (8192, 768), (768, 1), 0); del buf239 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf238, (8192, 768), (768, 1), 0), permute_266, out=buf246)
del permute_266
buf247 = buf241; del buf241 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf238, (768, 8192), (1, 768), 0), view_176, out=buf247)
buf248 = buf242; del buf242 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf238, buf248, 49152, 128, grid=grid(49152), stream=stream0)
buf251 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf248, buf251, 768, 64, grid=grid(768), stream=stream0)
buf250 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf247, buf250, 589824, grid=grid(589824), stream=stream0)
buf252 = reinterpret_tensor(buf238, (8192, 768), (768, 1), 0); del buf238 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf237, (8192, 768), (768, 1), 0), permute_270, out=buf252)
del permute_270
buf253 = buf247; del buf247 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf237, (768, 8192), (1, 768), 0), view_176, out=buf253)
del view_176
buf254 = buf248; del buf248 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf237, buf254, 49152, 128, grid=grid(49152), stream=stream0)
buf257 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf254, buf257, 768, 64, grid=grid(768), stream=stream0)
buf256 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf253, buf256, 589824, grid=grid(589824), stream=stream0)
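# Layer boundary: same fused residual-add + dropout + LayerNorm backward as above.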
buf261 = buf202; del buf202 # reuse
buf266 = reinterpret_tensor(buf237, (16, 512, 768), (393216, 768, 1), 0); del buf237 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf223, buf240, buf246, buf252, primals_132, mul_107, div_39, gt_24, buf261, buf266, 8192, 768, grid=grid(8192), stream=stream0)
del div_39
del gt_24
del primals_132
buf262 = reinterpret_tensor(buf254, (768, 64), (1, 768), 0); del buf254 # reuse
buf264 = buf224; del buf224 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf223, buf240, buf246, buf252, mul_107, buf262, buf264, 49152, 128, grid=grid(49152), stream=stream0)
del buf240
del buf246
del mul_107
buf263 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf262, buf263, 768, 64, grid=grid(768), stream=stream0)
buf265 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf264, buf265, 768, 64, grid=grid(768), stream=stream0)
buf267 = reinterpret_tensor(buf214, (8192, 3072), (3072, 1), 0); del buf214 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf266, (8192, 768), (768, 1), 0), permute_274, out=buf267)
del permute_274
buf268 = reinterpret_tensor(buf216, (768, 3072), (3072, 1), 0); del buf216 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf266, (768, 8192), (1, 768), 0), view_174, out=buf268)
del view_174
buf269 = reinterpret_tensor(buf264, (1, 768, 64), (49152, 1, 768), 0); del buf264 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf266, buf269, 49152, 128, grid=grid(49152), stream=stream0)
buf272 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf269, buf272, 768, 64, grid=grid(768), stream=stream0)
buf271 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf268, buf271, 2359296, grid=grid(2359296), stream=stream0)
buf273 = reinterpret_tensor(buf267, (16, 512, 3072), (1572864, 3072, 1), 0); del buf267 # reuse
# Source Nodes: [hidden_states_60], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf273, addmm_46, 25165824, grid=grid(25165824), stream=stream0)
del addmm_46
buf274 = reinterpret_tensor(buf266, (8192, 768), (768, 1), 0); del buf266 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf273, (8192, 3072), (3072, 1), 0), permute_278, out=buf274)
del permute_278
buf275 = reinterpret_tensor(buf268, (3072, 768), (768, 1), 0); del buf268 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf273, (3072, 8192), (1, 3072), 0), view_172, out=buf275)
del view_172
buf276 = buf217; del buf217 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf273, buf276, 98304, 256, grid=grid(98304), stream=stream0)
buf279 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf276, buf279, 3072, 32, grid=grid(3072), stream=stream0)
buf278 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf275, buf278, 2359296, grid=grid(2359296), stream=stream0)
buf282 = buf223; del buf223 # reuse
buf287 = reinterpret_tensor(buf252, (16, 512, 768), (393216, 768, 1), 0); del buf252 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf261, buf274, primals_126, mul_100, div_40, gt_23, buf282, buf287, 8192, 768, grid=grid(8192), stream=stream0)
del div_40
del gt_23
del primals_126
buf283 = reinterpret_tensor(buf269, (768, 64), (1, 768), 0); del buf269 # reuse
buf285 = buf262; del buf262 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf261, buf274, mul_100, buf283, buf285, 49152, 128, grid=grid(49152), stream=stream0)
del mul_100
buf284 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf283, buf284, 768, 64, grid=grid(768), stream=stream0)
buf286 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf285, buf286, 768, 64, grid=grid(768), stream=stream0)
buf288 = buf274; del buf274 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf287, (8192, 768), (768, 1), 0), permute_282, out=buf288)
del permute_282
buf289 = buf253; del buf253 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf287, (768, 8192), (1, 768), 0), view_170, out=buf289)
del view_170
buf290 = reinterpret_tensor(buf285, (1, 768, 64), (49152, 1, 768), 0); del buf285 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf287, buf290, 49152, 128, grid=grid(49152), stream=stream0)
buf293 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf290, buf293, 768, 64, grid=grid(768), stream=stream0)
buf292 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf289, buf292, 589824, grid=grid(589824), stream=stream0)
buf294 = reinterpret_tensor(buf287, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf287 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf288, buf294, 6291456, grid=grid(6291456), stream=stream0)
del buf288
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf295 = aten._scaled_dot_product_flash_attention_backward.default(buf294, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, None, None, 512, 512, 0.1, False, getitem_82, getitem_83, scale=0.125)
del getitem_80
del getitem_81
del getitem_82
del getitem_83
del permute_default_24
del permute_default_25
del permute_default_26
buf296 = buf295[0]
buf297 = buf295[1]
buf298 = buf295[2]
del buf295
buf299 = reinterpret_tensor(buf294, (8192, 768), (768, 1), 0); del buf294 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf298, (8192, 768), (768, 1), 0), permute_294, out=buf299)
del permute_294
buf300 = buf289; del buf289 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf298, (768, 8192), (1, 768), 0), view_154, out=buf300)
buf301 = buf290; del buf290 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf298, buf301, 49152, 128, grid=grid(49152), stream=stream0)
buf304 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf301, buf304, 768, 64, grid=grid(768), stream=stream0)
buf303 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf300, buf303, 589824, grid=grid(589824), stream=stream0)
buf305 = reinterpret_tensor(buf298, (8192, 768), (768, 1), 0); del buf298 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf297, (8192, 768), (768, 1), 0), permute_299, out=buf305)
del permute_299
buf306 = buf300; del buf300 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf297, (768, 8192), (1, 768), 0), view_154, out=buf306)
buf307 = buf301; del buf301 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf297, buf307, 49152, 128, grid=grid(49152), stream=stream0)
buf310 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf307, buf310, 768, 64, grid=grid(768), stream=stream0)
buf309 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf306, buf309, 589824, grid=grid(589824), stream=stream0)
buf311 = reinterpret_tensor(buf297, (8192, 768), (768, 1), 0); del buf297 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf296, (8192, 768), (768, 1), 0), permute_303, out=buf311)
del permute_303
buf312 = buf306; del buf306 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf296, (768, 8192), (1, 768), 0), view_154, out=buf312)
del view_154
buf313 = buf307; del buf307 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf296, buf313, 49152, 128, grid=grid(49152), stream=stream0)
buf316 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf313, buf316, 768, 64, grid=grid(768), stream=stream0)
buf315 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf312, buf315, 589824, grid=grid(589824), stream=stream0)
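# Layer boundary: same fused residual-add + dropout + LayerNorm backward as above.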
buf320 = buf261; del buf261 # reuse
buf325 = reinterpret_tensor(buf296, (16, 512, 768), (393216, 768, 1), 0); del buf296 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf282, buf299, buf305, buf311, primals_116, mul_94, div_42, gt_21, buf320, buf325, 8192, 768, grid=grid(8192), stream=stream0)
del div_42
del gt_21
del primals_116
buf321 = reinterpret_tensor(buf313, (768, 64), (1, 768), 0); del buf313 # reuse
buf323 = buf283; del buf283 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf282, buf299, buf305, buf311, mul_94, buf321, buf323, 49152, 128, grid=grid(49152), stream=stream0)
del buf299
del buf305
del mul_94
buf322 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf321, buf322, 768, 64, grid=grid(768), stream=stream0)
buf324 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf323, buf324, 768, 64, grid=grid(768), stream=stream0)
buf326 = reinterpret_tensor(buf273, (8192, 3072), (3072, 1), 0); del buf273 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf325, (8192, 768), (768, 1), 0), permute_307, out=buf326)
del permute_307
buf327 = reinterpret_tensor(buf275, (768, 3072), (3072, 1), 0); del buf275 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf325, (768, 8192), (1, 768), 0), view_152, out=buf327)
del view_152
buf328 = reinterpret_tensor(buf323, (1, 768, 64), (49152, 1, 768), 0); del buf323 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf325, buf328, 49152, 128, grid=grid(49152), stream=stream0)
buf331 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf328, buf331, 768, 64, grid=grid(768), stream=stream0)
buf330 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf327, buf330, 2359296, grid=grid(2359296), stream=stream0)
buf332 = reinterpret_tensor(buf326, (16, 512, 3072), (1572864, 3072, 1), 0); del buf326 # reuse
# Source Nodes: [hidden_states_52], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf332, addmm_40, 25165824, grid=grid(25165824), stream=stream0)
del addmm_40
buf333 = reinterpret_tensor(buf325, (8192, 768), (768, 1), 0); del buf325 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf332, (8192, 3072), (3072, 1), 0), permute_311, out=buf333)
del permute_311
buf334 = reinterpret_tensor(buf327, (3072, 768), (768, 1), 0); del buf327 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf332, (3072, 8192), (1, 3072), 0), view_150, out=buf334)
del view_150
buf335 = buf276; del buf276 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf332, buf335, 98304, 256, grid=grid(98304), stream=stream0)
buf338 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf335, buf338, 3072, 32, grid=grid(3072), stream=stream0)
buf337 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf334, buf337, 2359296, grid=grid(2359296), stream=stream0)
buf341 = buf282; del buf282 # reuse
buf346 = reinterpret_tensor(buf311, (16, 512, 768), (393216, 768, 1), 0); del buf311 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf320, buf333, primals_110, mul_87, div_43, gt_20, buf341, buf346, 8192, 768, grid=grid(8192), stream=stream0)
del div_43
del gt_20
del primals_110
buf342 = reinterpret_tensor(buf328, (768, 64), (1, 768), 0); del buf328 # reuse
buf344 = buf321; del buf321 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf320, buf333, mul_87, buf342, buf344, 49152, 128, grid=grid(49152), stream=stream0)
del mul_87
buf343 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf342, buf343, 768, 64, grid=grid(768), stream=stream0)
buf345 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf344, buf345, 768, 64, grid=grid(768), stream=stream0)
buf347 = buf333; del buf333 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf346, (8192, 768), (768, 1), 0), permute_315, out=buf347)
del permute_315
buf348 = buf312; del buf312 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf346, (768, 8192), (1, 768), 0), view_148, out=buf348)
del view_148
buf349 = reinterpret_tensor(buf344, (1, 768, 64), (49152, 1, 768), 0); del buf344 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf346, buf349, 49152, 128, grid=grid(49152), stream=stream0)
buf352 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf349, buf352, 768, 64, grid=grid(768), stream=stream0)
buf351 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf348, buf351, 589824, grid=grid(589824), stream=stream0)
buf353 = reinterpret_tensor(buf346, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf346 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf347, buf353, 6291456, grid=grid(6291456), stream=stream0)
del buf347
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf354 = aten._scaled_dot_product_flash_attention_backward.default(buf353, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, None, None, 512, 512, 0.1, False, getitem_89, getitem_90, scale=0.125)
del getitem_87
del getitem_88
del getitem_89
del getitem_90
del permute_default_30
del permute_default_31
del permute_default_32
buf355 = buf354[0]
buf356 = buf354[1]
buf357 = buf354[2]
del buf354
buf358 = reinterpret_tensor(buf353, (8192, 768), (768, 1), 0); del buf353 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf357, (8192, 768), (768, 1), 0), permute_327, out=buf358)
del permute_327
buf359 = buf348; del buf348 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf357, (768, 8192), (1, 768), 0), view_132, out=buf359)
buf360 = buf349; del buf349 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf357, buf360, 49152, 128, grid=grid(49152), stream=stream0)
buf363 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf360, buf363, 768, 64, grid=grid(768), stream=stream0)
buf362 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf359, buf362, 589824, grid=grid(589824), stream=stream0)
buf364 = reinterpret_tensor(buf357, (8192, 768), (768, 1), 0); del buf357 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf356, (8192, 768), (768, 1), 0), permute_332, out=buf364)
del permute_332
buf365 = buf359; del buf359 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf356, (768, 8192), (1, 768), 0), view_132, out=buf365)
buf366 = buf360; del buf360 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf356, buf366, 49152, 128, grid=grid(49152), stream=stream0)
buf369 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf366, buf369, 768, 64, grid=grid(768), stream=stream0)
buf368 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf365, buf368, 589824, grid=grid(589824), stream=stream0)
buf370 = reinterpret_tensor(buf356, (8192, 768), (768, 1), 0); del buf356 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf355, (8192, 768), (768, 1), 0), permute_336, out=buf370)
del permute_336
buf371 = buf365; del buf365 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf355, (768, 8192), (1, 768), 0), view_132, out=buf371)
del view_132
buf372 = buf366; del buf366 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf355, buf372, 49152, 128, grid=grid(49152), stream=stream0)
buf375 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf372, buf375, 768, 64, grid=grid(768), stream=stream0)
buf374 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf371, buf374, 589824, grid=grid(589824), stream=stream0)
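# Layer boundary: same fused residual-add + dropout + LayerNorm backward as above.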
buf379 = buf320; del buf320 # reuse
buf384 = reinterpret_tensor(buf355, (16, 512, 768), (393216, 768, 1), 0); del buf355 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf341, buf358, buf364, buf370, primals_100, mul_81, div_45, gt_18, buf379, buf384, 8192, 768, grid=grid(8192), stream=stream0)
del div_45
del gt_18
del primals_100
buf380 = reinterpret_tensor(buf372, (768, 64), (1, 768), 0); del buf372 # reuse
buf382 = buf342; del buf342 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf341, buf358, buf364, buf370, mul_81, buf380, buf382, 49152, 128, grid=grid(49152), stream=stream0)
del buf358
del buf364
del mul_81
buf381 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf380, buf381, 768, 64, grid=grid(768), stream=stream0)
buf383 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf382, buf383, 768, 64, grid=grid(768), stream=stream0)
buf385 = reinterpret_tensor(buf332, (8192, 3072), (3072, 1), 0); del buf332 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf384, (8192, 768), (768, 1), 0), permute_340, out=buf385)
del permute_340
buf386 = reinterpret_tensor(buf334, (768, 3072), (3072, 1), 0); del buf334 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf384, (768, 8192), (1, 768), 0), view_130, out=buf386)
del view_130
buf387 = reinterpret_tensor(buf382, (1, 768, 64), (49152, 1, 768), 0); del buf382 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf384, buf387, 49152, 128, grid=grid(49152), stream=stream0)
buf390 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf387, buf390, 768, 64, grid=grid(768), stream=stream0)
buf389 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf386, buf389, 2359296, grid=grid(2359296), stream=stream0)
buf391 = reinterpret_tensor(buf385, (16, 512, 3072), (1572864, 3072, 1), 0); del buf385 # reuse
# Source Nodes: [hidden_states_44], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf391, addmm_34, 25165824, grid=grid(25165824), stream=stream0)
del addmm_34
buf392 = reinterpret_tensor(buf384, (8192, 768), (768, 1), 0); del buf384 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf391, (8192, 3072), (3072, 1), 0), permute_344, out=buf392)
del permute_344
buf393 = reinterpret_tensor(buf386, (3072, 768), (768, 1), 0); del buf386 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf391, (3072, 8192), (1, 3072), 0), view_128, out=buf393)
del view_128
buf394 = buf335; del buf335 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf391, buf394, 98304, 256, grid=grid(98304), stream=stream0)
buf397 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf394, buf397, 3072, 32, grid=grid(3072), stream=stream0)
buf396 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf393, buf396, 2359296, grid=grid(2359296), stream=stream0)
buf400 = buf341; del buf341 # reuse
buf405 = reinterpret_tensor(buf370, (16, 512, 768), (393216, 768, 1), 0); del buf370 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf379, buf392, primals_94, mul_74, div_46, gt_17, buf400, buf405, 8192, 768, grid=grid(8192), stream=stream0)
del div_46
del gt_17
del primals_94
buf401 = reinterpret_tensor(buf387, (768, 64), (1, 768), 0); del buf387 # reuse
buf403 = buf380; del buf380 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf379, buf392, mul_74, buf401, buf403, 49152, 128, grid=grid(49152), stream=stream0)
del mul_74
buf402 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf401, buf402, 768, 64, grid=grid(768), stream=stream0)
buf404 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf403, buf404, 768, 64, grid=grid(768), stream=stream0)
buf406 = buf392; del buf392 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf405, (8192, 768), (768, 1), 0), permute_348, out=buf406)
del permute_348
buf407 = buf371; del buf371 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf405, (768, 8192), (1, 768), 0), view_126, out=buf407)
del view_126
buf408 = reinterpret_tensor(buf403, (1, 768, 64), (49152, 1, 768), 0); del buf403 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf405, buf408, 49152, 128, grid=grid(49152), stream=stream0)
buf411 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf408, buf411, 768, 64, grid=grid(768), stream=stream0)
buf410 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf407, buf410, 589824, grid=grid(589824), stream=stream0)
buf412 = reinterpret_tensor(buf405, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf405 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf406, buf412, 6291456, grid=grid(6291456), stream=stream0)
del buf406
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf413 = aten._scaled_dot_product_flash_attention_backward.default(buf412, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, None, None, 512, 512, 0.1, False, getitem_96, getitem_97, scale=0.125)
del getitem_94
del getitem_95
del getitem_96
del getitem_97
del permute_default_36
del permute_default_37
del permute_default_38
buf414 = buf413[0]
buf415 = buf413[1]
buf416 = buf413[2]
del buf413
buf417 = reinterpret_tensor(buf412, (8192, 768), (768, 1), 0); del buf412 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf416, (8192, 768), (768, 1), 0), permute_360, out=buf417)
del permute_360
buf418 = buf407; del buf407 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf416, (768, 8192), (1, 768), 0), view_110, out=buf418)
buf419 = buf408; del buf408 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf416, buf419, 49152, 128, grid=grid(49152), stream=stream0)
buf422 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf419, buf422, 768, 64, grid=grid(768), stream=stream0)
buf421 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf418, buf421, 589824, grid=grid(589824), stream=stream0)
buf423 = reinterpret_tensor(buf416, (8192, 768), (768, 1), 0); del buf416 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf415, (8192, 768), (768, 1), 0), permute_365, out=buf423)
del permute_365
buf424 = buf418; del buf418 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf415, (768, 8192), (1, 768), 0), view_110, out=buf424)
buf425 = buf419; del buf419 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf415, buf425, 49152, 128, grid=grid(49152), stream=stream0)
buf428 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf425, buf428, 768, 64, grid=grid(768), stream=stream0)
buf427 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf424, buf427, 589824, grid=grid(589824), stream=stream0)
buf429 = reinterpret_tensor(buf415, (8192, 768), (768, 1), 0); del buf415 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf414, (8192, 768), (768, 1), 0), permute_369, out=buf429)
del permute_369
buf430 = buf424; del buf424 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf414, (768, 8192), (1, 768), 0), view_110, out=buf430)
del view_110
buf431 = buf425; del buf425 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf414, buf431, 49152, 128, grid=grid(49152), stream=stream0)
buf434 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf431, buf434, 768, 64, grid=grid(768), stream=stream0)
buf433 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf430, buf433, 589824, grid=grid(589824), stream=stream0)
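# Layer boundary: same fused residual-add + dropout + LayerNorm backward as above.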
buf438 = buf379; del buf379 # reuse
buf443 = reinterpret_tensor(buf414, (16, 512, 768), (393216, 768, 1), 0); del buf414 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf400, buf417, buf423, buf429, primals_84, mul_68, div_48, gt_15, buf438, buf443, 8192, 768, grid=grid(8192), stream=stream0)
del div_48
del gt_15
del primals_84
buf439 = reinterpret_tensor(buf431, (768, 64), (1, 768), 0); del buf431 # reuse
buf441 = buf401; del buf401 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf400, buf417, buf423, buf429, mul_68, buf439, buf441, 49152, 128, grid=grid(49152), stream=stream0)
del buf417
del buf423
del mul_68
buf440 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf439, buf440, 768, 64, grid=grid(768), stream=stream0)
buf442 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf441, buf442, 768, 64, grid=grid(768), stream=stream0)
buf444 = reinterpret_tensor(buf391, (8192, 3072), (3072, 1), 0); del buf391 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf443, (8192, 768), (768, 1), 0), permute_373, out=buf444)
del permute_373
buf445 = reinterpret_tensor(buf393, (768, 3072), (3072, 1), 0); del buf393 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf443, (768, 8192), (1, 768), 0), view_108, out=buf445)
del view_108
buf446 = reinterpret_tensor(buf441, (1, 768, 64), (49152, 1, 768), 0); del buf441 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf443, buf446, 49152, 128, grid=grid(49152), stream=stream0)
buf449 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf446, buf449, 768, 64, grid=grid(768), stream=stream0)
buf448 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf445, buf448, 2359296, grid=grid(2359296), stream=stream0)
buf450 = reinterpret_tensor(buf444, (16, 512, 3072), (1572864, 3072, 1), 0); del buf444 # reuse
# Source Nodes: [hidden_states_36], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf450, addmm_28, 25165824, grid=grid(25165824), stream=stream0)
del addmm_28
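# NOTE (annotation): the gelu backward kernel above recomputes the activation
# derivative from the saved pre-activation (addmm_28) rather than storing
# gelu'(x). For the exact (erf-based) GELU this is presumably
#   d/dx gelu(x) = Phi(x) + x * phi(x),
# with Phi/phi the standard normal CDF/PDF; grad_in = grad_out * that factor,
# applied in place to the (16, 512, 3072) buffer (25165824 elements).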
buf451 = reinterpret_tensor(buf443, (8192, 768), (768, 1), 0); del buf443 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf450, (8192, 3072), (3072, 1), 0), permute_377, out=buf451)
del permute_377
buf452 = reinterpret_tensor(buf445, (3072, 768), (768, 1), 0); del buf445 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf450, (3072, 8192), (1, 3072), 0), view_106, out=buf452)
del view_106
buf453 = buf394; del buf394 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf450, buf453, 98304, 256, grid=grid(98304), stream=stream0)
buf456 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf453, buf456, 3072, 32, grid=grid(3072), stream=stream0)
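# NOTE (annotation): bias gradients use a split reduction sized to the GPU:
# the 8192 x 3072 grad is first reduced 256-to-1 into a 3072 x 32 partial
# buffer (98304 = 3072 * 32 programs), then a second kernel folds the 32
# partials per column into the final (3072,) fp32 bias gradient.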
buf455 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf452, buf455, 2359296, grid=grid(2359296), stream=stream0)
buf459 = buf400; del buf400 # reuse
buf464 = reinterpret_tensor(buf429, (16, 512, 768), (393216, 768, 1), 0); del buf429 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf438, buf451, primals_78, mul_61, div_49, gt_14, buf459, buf464, 8192, 768, grid=grid(8192), stream=stream0)
del div_49
del gt_14
del primals_78
buf460 = reinterpret_tensor(buf446, (768, 64), (1, 768), 0); del buf446 # reuse
buf462 = buf439; del buf439 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf438, buf451, mul_61, buf460, buf462, 49152, 128, grid=grid(49152), stream=stream0)
del mul_61
buf461 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf460, buf461, 768, 64, grid=grid(768), stream=stream0)
buf463 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf462, buf463, 768, 64, grid=grid(768), stream=stream0)
buf465 = buf451; del buf451 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf464, (8192, 768), (768, 1), 0), permute_381, out=buf465)
del permute_381
buf466 = buf430; del buf430 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf464, (768, 8192), (1, 768), 0), view_104, out=buf466)
del view_104
buf467 = reinterpret_tensor(buf462, (1, 768, 64), (49152, 1, 768), 0); del buf462 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf464, buf467, 49152, 128, grid=grid(49152), stream=stream0)
buf470 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf467, buf470, 768, 64, grid=grid(768), stream=stream0)
buf469 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf466, buf469, 589824, grid=grid(589824), stream=stream0)
buf471 = reinterpret_tensor(buf464, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf464 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf465, buf471, 6291456, grid=grid(6291456), stream=stream0)
del buf465
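# NOTE (annotation): the fallback below is the flash-attention backward. Its
# positional arguments appear to be (grad_out, query, key, value, out,
# logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal,
# philox_seed, philox_offset); here max_q = max_k = 512, dropout_p = 0.1,
# is_causal = False, and scale = 0.125 = 1/sqrt(64) for the 64-dim heads.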
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf472 = aten._scaled_dot_product_flash_attention_backward.default(buf471, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, None, None, 512, 512, 0.1, False, getitem_103, getitem_104, scale=0.125)
del getitem_101
del getitem_102
del getitem_103
del getitem_104
del permute_default_42
del permute_default_43
del permute_default_44
buf473 = buf472[0]
buf474 = buf472[1]
buf475 = buf472[2]
del buf472
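# NOTE (annotation): buf473, buf474 and buf475 are presumably grad_query,
# grad_key and grad_value respectively; each feeds one of the three
# projection backwards below against the shared saved activation view_88.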
buf476 = reinterpret_tensor(buf471, (8192, 768), (768, 1), 0); del buf471 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf475, (8192, 768), (768, 1), 0), permute_393, out=buf476)
del permute_393
buf477 = buf466; del buf466 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf475, (768, 8192), (1, 768), 0), view_88, out=buf477)
buf478 = buf467; del buf467 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf475, buf478, 49152, 128, grid=grid(49152), stream=stream0)
buf481 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf478, buf481, 768, 64, grid=grid(768), stream=stream0)
buf480 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf477, buf480, 589824, grid=grid(589824), stream=stream0)
buf482 = reinterpret_tensor(buf475, (8192, 768), (768, 1), 0); del buf475 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf474, (8192, 768), (768, 1), 0), permute_398, out=buf482)
del permute_398
buf483 = buf477; del buf477 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf474, (768, 8192), (1, 768), 0), view_88, out=buf483)
buf484 = buf478; del buf478 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf474, buf484, 49152, 128, grid=grid(49152), stream=stream0)
buf487 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf484, buf487, 768, 64, grid=grid(768), stream=stream0)
buf486 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf483, buf486, 589824, grid=grid(589824), stream=stream0)
buf488 = reinterpret_tensor(buf474, (8192, 768), (768, 1), 0); del buf474 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf473, (8192, 768), (768, 1), 0), permute_402, out=buf488)
del permute_402
buf489 = buf483; del buf483 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf473, (768, 8192), (1, 768), 0), view_88, out=buf489)
del view_88
buf490 = buf484; del buf484 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf473, buf490, 49152, 128, grid=grid(49152), stream=stream0)
buf493 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf490, buf493, 768, 64, grid=grid(768), stream=stream0)
buf492 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf489, buf492, 589824, grid=grid(589824), stream=stream0)
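# NOTE (annotation): from here the layer pattern simply repeats for each
# remaining encoder layer. Peak memory stays flat because buffers are
# recycled: reinterpret_tensor(...) reuses a dead buffer's storage under a
# new shape/stride, and the del statements release fp16 activations as soon
# as their last consumer has run.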
buf497 = buf438; del buf438 # reuse
buf502 = reinterpret_tensor(buf473, (16, 512, 768), (393216, 768, 1), 0); del buf473 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf459, buf476, buf482, buf488, primals_68, mul_55, div_51, gt_12, buf497, buf502, 8192, 768, grid=grid(8192), stream=stream0)
del div_51
del gt_12
del primals_68
buf498 = reinterpret_tensor(buf490, (768, 64), (1, 768), 0); del buf490 # reuse
buf500 = buf460; del buf460 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf459, buf476, buf482, buf488, mul_55, buf498, buf500, 49152, 128, grid=grid(49152), stream=stream0)
del buf476
del buf482
del mul_55
buf499 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf498, buf499, 768, 64, grid=grid(768), stream=stream0)
buf501 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf500, buf501, 768, 64, grid=grid(768), stream=stream0)
buf503 = reinterpret_tensor(buf450, (8192, 3072), (3072, 1), 0); del buf450 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf502, (8192, 768), (768, 1), 0), permute_406, out=buf503)
del permute_406
buf504 = reinterpret_tensor(buf452, (768, 3072), (3072, 1), 0); del buf452 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf502, (768, 8192), (1, 768), 0), view_86, out=buf504)
del view_86
buf505 = reinterpret_tensor(buf500, (1, 768, 64), (49152, 1, 768), 0); del buf500 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf502, buf505, 49152, 128, grid=grid(49152), stream=stream0)
buf508 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf505, buf508, 768, 64, grid=grid(768), stream=stream0)
buf507 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf504, buf507, 2359296, grid=grid(2359296), stream=stream0)
buf509 = reinterpret_tensor(buf503, (16, 512, 3072), (1572864, 3072, 1), 0); del buf503 # reuse
# Source Nodes: [hidden_states_28], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf509, addmm_22, 25165824, grid=grid(25165824), stream=stream0)
del addmm_22
buf510 = reinterpret_tensor(buf502, (8192, 768), (768, 1), 0); del buf502 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf509, (8192, 3072), (3072, 1), 0), permute_410, out=buf510)
del permute_410
buf511 = reinterpret_tensor(buf504, (3072, 768), (768, 1), 0); del buf504 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf509, (3072, 8192), (1, 3072), 0), view_84, out=buf511)
del view_84
buf512 = buf453; del buf453 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf509, buf512, 98304, 256, grid=grid(98304), stream=stream0)
buf515 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf512, buf515, 3072, 32, grid=grid(3072), stream=stream0)
buf514 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf511, buf514, 2359296, grid=grid(2359296), stream=stream0)
buf518 = buf459; del buf459 # reuse
buf523 = reinterpret_tensor(buf488, (16, 512, 768), (393216, 768, 1), 0); del buf488 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf497, buf510, primals_62, mul_48, div_52, gt_11, buf518, buf523, 8192, 768, grid=grid(8192), stream=stream0)
del div_52
del gt_11
del primals_62
buf519 = reinterpret_tensor(buf505, (768, 64), (1, 768), 0); del buf505 # reuse
buf521 = buf498; del buf498 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf497, buf510, mul_48, buf519, buf521, 49152, 128, grid=grid(49152), stream=stream0)
del mul_48
buf520 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf519, buf520, 768, 64, grid=grid(768), stream=stream0)
buf522 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf521, buf522, 768, 64, grid=grid(768), stream=stream0)
buf524 = buf510; del buf510 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf523, (8192, 768), (768, 1), 0), permute_414, out=buf524)
del permute_414
buf525 = buf489; del buf489 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf523, (768, 8192), (1, 768), 0), view_82, out=buf525)
del view_82
buf526 = reinterpret_tensor(buf521, (1, 768, 64), (49152, 1, 768), 0); del buf521 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf523, buf526, 49152, 128, grid=grid(49152), stream=stream0)
buf529 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf526, buf529, 768, 64, grid=grid(768), stream=stream0)
buf528 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf525, buf528, 589824, grid=grid(589824), stream=stream0)
buf530 = reinterpret_tensor(buf523, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf523 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf524, buf530, 6291456, grid=grid(6291456), stream=stream0)
del buf524
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf531 = aten._scaled_dot_product_flash_attention_backward.default(buf530, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, None, None, 512, 512, 0.1, False, getitem_110, getitem_111, scale=0.125)
del getitem_108
del getitem_109
del getitem_110
del getitem_111
del permute_default_48
del permute_default_49
del permute_default_50
buf532 = buf531[0]
buf533 = buf531[1]
buf534 = buf531[2]
del buf531
buf535 = reinterpret_tensor(buf530, (8192, 768), (768, 1), 0); del buf530 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf534, (8192, 768), (768, 1), 0), permute_426, out=buf535)
del permute_426
buf536 = buf525; del buf525 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf534, (768, 8192), (1, 768), 0), view_66, out=buf536)
buf537 = buf526; del buf526 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf534, buf537, 49152, 128, grid=grid(49152), stream=stream0)
buf540 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf537, buf540, 768, 64, grid=grid(768), stream=stream0)
buf539 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf536, buf539, 589824, grid=grid(589824), stream=stream0)
buf541 = reinterpret_tensor(buf534, (8192, 768), (768, 1), 0); del buf534 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf533, (8192, 768), (768, 1), 0), permute_431, out=buf541)
del permute_431
buf542 = buf536; del buf536 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf533, (768, 8192), (1, 768), 0), view_66, out=buf542)
buf543 = buf537; del buf537 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf533, buf543, 49152, 128, grid=grid(49152), stream=stream0)
buf546 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf543, buf546, 768, 64, grid=grid(768), stream=stream0)
buf545 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf542, buf545, 589824, grid=grid(589824), stream=stream0)
buf547 = reinterpret_tensor(buf533, (8192, 768), (768, 1), 0); del buf533 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf532, (8192, 768), (768, 1), 0), permute_435, out=buf547)
del permute_435
buf548 = buf542; del buf542 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf532, (768, 8192), (1, 768), 0), view_66, out=buf548)
del view_66
buf549 = buf543; del buf543 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf532, buf549, 49152, 128, grid=grid(49152), stream=stream0)
buf552 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf549, buf552, 768, 64, grid=grid(768), stream=stream0)
buf551 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf548, buf551, 589824, grid=grid(589824), stream=stream0)
buf556 = buf497; del buf497 # reuse
buf561 = reinterpret_tensor(buf532, (16, 512, 768), (393216, 768, 1), 0); del buf532 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf518, buf535, buf541, buf547, primals_52, mul_42, div_54, gt_9, buf556, buf561, 8192, 768, grid=grid(8192), stream=stream0)
del div_54
del gt_9
del primals_52
buf557 = reinterpret_tensor(buf549, (768, 64), (1, 768), 0); del buf549 # reuse
buf559 = buf519; del buf519 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf518, buf535, buf541, buf547, mul_42, buf557, buf559, 49152, 128, grid=grid(49152), stream=stream0)
del buf535
del buf541
del mul_42
buf558 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf557, buf558, 768, 64, grid=grid(768), stream=stream0)
buf560 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf559, buf560, 768, 64, grid=grid(768), stream=stream0)
buf562 = reinterpret_tensor(buf509, (8192, 3072), (3072, 1), 0); del buf509 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf561, (8192, 768), (768, 1), 0), permute_439, out=buf562)
del permute_439
buf563 = reinterpret_tensor(buf511, (768, 3072), (3072, 1), 0); del buf511 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf561, (768, 8192), (1, 768), 0), view_64, out=buf563)
del view_64
buf564 = reinterpret_tensor(buf559, (1, 768, 64), (49152, 1, 768), 0); del buf559 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf561, buf564, 49152, 128, grid=grid(49152), stream=stream0)
buf567 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf564, buf567, 768, 64, grid=grid(768), stream=stream0)
buf566 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf563, buf566, 2359296, grid=grid(2359296), stream=stream0)
buf568 = reinterpret_tensor(buf562, (16, 512, 3072), (1572864, 3072, 1), 0); del buf562 # reuse
# Source Nodes: [hidden_states_20], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf568, addmm_16, 25165824, grid=grid(25165824), stream=stream0)
del addmm_16
buf569 = reinterpret_tensor(buf561, (8192, 768), (768, 1), 0); del buf561 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf568, (8192, 3072), (3072, 1), 0), permute_443, out=buf569)
del permute_443
buf570 = reinterpret_tensor(buf563, (3072, 768), (768, 1), 0); del buf563 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf568, (3072, 8192), (1, 3072), 0), view_62, out=buf570)
del view_62
buf571 = buf512; del buf512 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf568, buf571, 98304, 256, grid=grid(98304), stream=stream0)
buf574 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf571, buf574, 3072, 32, grid=grid(3072), stream=stream0)
buf573 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf570, buf573, 2359296, grid=grid(2359296), stream=stream0)
buf577 = buf518; del buf518 # reuse
buf582 = reinterpret_tensor(buf547, (16, 512, 768), (393216, 768, 1), 0); del buf547 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf556, buf569, primals_46, mul_35, div_55, gt_8, buf577, buf582, 8192, 768, grid=grid(8192), stream=stream0)
del div_55
del gt_8
del primals_46
buf578 = reinterpret_tensor(buf564, (768, 64), (1, 768), 0); del buf564 # reuse
buf580 = buf557; del buf557 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf556, buf569, mul_35, buf578, buf580, 49152, 128, grid=grid(49152), stream=stream0)
del mul_35
buf579 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf578, buf579, 768, 64, grid=grid(768), stream=stream0)
buf581 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf580, buf581, 768, 64, grid=grid(768), stream=stream0)
buf583 = buf569; del buf569 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf582, (8192, 768), (768, 1), 0), permute_447, out=buf583)
del permute_447
buf584 = buf548; del buf548 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf582, (768, 8192), (1, 768), 0), view_60, out=buf584)
del view_60
buf585 = reinterpret_tensor(buf580, (1, 768, 64), (49152, 1, 768), 0); del buf580 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf582, buf585, 49152, 128, grid=grid(49152), stream=stream0)
buf588 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf585, buf588, 768, 64, grid=grid(768), stream=stream0)
buf587 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf584, buf587, 589824, grid=grid(589824), stream=stream0)
buf589 = reinterpret_tensor(buf582, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf582 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf583, buf589, 6291456, grid=grid(6291456), stream=stream0)
del buf583
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf590 = aten._scaled_dot_product_flash_attention_backward.default(buf589, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, None, None, 512, 512, 0.1, False, getitem_117, getitem_118, scale=0.125)
del getitem_115
del getitem_116
del getitem_117
del getitem_118
del permute_default_54
del permute_default_55
del permute_default_56
buf591 = buf590[0]
buf592 = buf590[1]
buf593 = buf590[2]
del buf590
buf594 = reinterpret_tensor(buf589, (8192, 768), (768, 1), 0); del buf589 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf593, (8192, 768), (768, 1), 0), permute_459, out=buf594)
del permute_459
buf595 = buf584; del buf584 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf593, (768, 8192), (1, 768), 0), view_44, out=buf595)
buf596 = buf585; del buf585 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf593, buf596, 49152, 128, grid=grid(49152), stream=stream0)
buf599 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf596, buf599, 768, 64, grid=grid(768), stream=stream0)
buf598 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf595, buf598, 589824, grid=grid(589824), stream=stream0)
buf600 = reinterpret_tensor(buf593, (8192, 768), (768, 1), 0); del buf593 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf592, (8192, 768), (768, 1), 0), permute_464, out=buf600)
del permute_464
buf601 = buf595; del buf595 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf592, (768, 8192), (1, 768), 0), view_44, out=buf601)
buf602 = buf596; del buf596 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf592, buf602, 49152, 128, grid=grid(49152), stream=stream0)
buf605 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf602, buf605, 768, 64, grid=grid(768), stream=stream0)
buf604 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf601, buf604, 589824, grid=grid(589824), stream=stream0)
buf606 = reinterpret_tensor(buf592, (8192, 768), (768, 1), 0); del buf592 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf591, (8192, 768), (768, 1), 0), permute_468, out=buf606)
del permute_468
buf607 = buf601; del buf601 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf591, (768, 8192), (1, 768), 0), view_44, out=buf607)
del view_44
buf608 = buf602; del buf602 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf591, buf608, 49152, 128, grid=grid(49152), stream=stream0)
buf611 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf608, buf611, 768, 64, grid=grid(768), stream=stream0)
buf610 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf607, buf610, 589824, grid=grid(589824), stream=stream0)
buf615 = buf556; del buf556 # reuse
buf620 = reinterpret_tensor(buf591, (16, 512, 768), (393216, 768, 1), 0); del buf591 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf577, buf594, buf600, buf606, primals_36, mul_29, div_57, gt_6, buf615, buf620, 8192, 768, grid=grid(8192), stream=stream0)
del div_57
del gt_6
del primals_36
buf616 = reinterpret_tensor(buf608, (768, 64), (1, 768), 0); del buf608 # reuse
buf618 = buf578; del buf578 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf577, buf594, buf600, buf606, mul_29, buf616, buf618, 49152, 128, grid=grid(49152), stream=stream0)
del buf594
del buf600
del mul_29
buf617 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf616, buf617, 768, 64, grid=grid(768), stream=stream0)
buf619 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf618, buf619, 768, 64, grid=grid(768), stream=stream0)
buf621 = reinterpret_tensor(buf568, (8192, 3072), (3072, 1), 0); del buf568 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf620, (8192, 768), (768, 1), 0), permute_472, out=buf621)
del permute_472
buf622 = reinterpret_tensor(buf570, (768, 3072), (3072, 1), 0); del buf570 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf620, (768, 8192), (1, 768), 0), view_42, out=buf622)
del view_42
buf623 = reinterpret_tensor(buf618, (1, 768, 64), (49152, 1, 768), 0); del buf618 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf620, buf623, 49152, 128, grid=grid(49152), stream=stream0)
buf626 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf623, buf626, 768, 64, grid=grid(768), stream=stream0)
buf625 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf622, buf625, 2359296, grid=grid(2359296), stream=stream0)
buf627 = reinterpret_tensor(buf621, (16, 512, 3072), (1572864, 3072, 1), 0); del buf621 # reuse
# Source Nodes: [hidden_states_12], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf627, addmm_10, 25165824, grid=grid(25165824), stream=stream0)
del addmm_10
buf628 = reinterpret_tensor(buf620, (8192, 768), (768, 1), 0); del buf620 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf627, (8192, 3072), (3072, 1), 0), permute_476, out=buf628)
del permute_476
buf629 = reinterpret_tensor(buf622, (3072, 768), (768, 1), 0); del buf622 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf627, (3072, 8192), (1, 3072), 0), view_40, out=buf629)
del view_40
buf630 = buf571; del buf571 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf627, buf630, 98304, 256, grid=grid(98304), stream=stream0)
buf633 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf630, buf633, 3072, 32, grid=grid(3072), stream=stream0)
buf632 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf629, buf632, 2359296, grid=grid(2359296), stream=stream0)
buf636 = buf577; del buf577 # reuse
buf641 = reinterpret_tensor(buf606, (16, 512, 768), (393216, 768, 1), 0); del buf606 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf615, buf628, primals_30, mul_22, div_58, gt_5, buf636, buf641, 8192, 768, grid=grid(8192), stream=stream0)
del div_58
del gt_5
del primals_30
buf637 = reinterpret_tensor(buf623, (768, 64), (1, 768), 0); del buf623 # reuse
buf639 = buf616; del buf616 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf615, buf628, mul_22, buf637, buf639, 49152, 128, grid=grid(49152), stream=stream0)
del mul_22
buf638 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf637, buf638, 768, 64, grid=grid(768), stream=stream0)
buf640 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf639, buf640, 768, 64, grid=grid(768), stream=stream0)
buf642 = buf628; del buf628 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf641, (8192, 768), (768, 1), 0), permute_480, out=buf642)
del permute_480
buf643 = buf607; del buf607 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf641, (768, 8192), (1, 768), 0), view_38, out=buf643)
del view_38
buf644 = reinterpret_tensor(buf639, (1, 768, 64), (49152, 1, 768), 0); del buf639 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf641, buf644, 49152, 128, grid=grid(49152), stream=stream0)
buf647 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf644, buf647, 768, 64, grid=grid(768), stream=stream0)
buf646 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf643, buf646, 589824, grid=grid(589824), stream=stream0)
buf648 = reinterpret_tensor(buf641, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf641 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf642, buf648, 6291456, grid=grid(6291456), stream=stream0)
del buf642
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf649 = aten._scaled_dot_product_flash_attention_backward.default(buf648, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, None, None, 512, 512, 0.1, False, getitem_124, getitem_125, scale=0.125)
del getitem_122
del getitem_123
del getitem_124
del getitem_125
del permute_default_60
del permute_default_61
del permute_default_62
buf650 = buf649[0]
buf651 = buf649[1]
buf652 = buf649[2]
del buf649
buf653 = reinterpret_tensor(buf648, (8192, 768), (768, 1), 0); del buf648 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf652, (8192, 768), (768, 1), 0), permute_492, out=buf653)
del permute_492
buf654 = buf643; del buf643 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf652, (768, 8192), (1, 768), 0), view_22, out=buf654)
buf655 = buf644; del buf644 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf652, buf655, 49152, 128, grid=grid(49152), stream=stream0)
buf658 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf655, buf658, 768, 64, grid=grid(768), stream=stream0)
buf657 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf654, buf657, 589824, grid=grid(589824), stream=stream0)
buf659 = reinterpret_tensor(buf652, (8192, 768), (768, 1), 0); del buf652 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf651, (8192, 768), (768, 1), 0), permute_497, out=buf659)
del permute_497
buf660 = buf654; del buf654 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf651, (768, 8192), (1, 768), 0), view_22, out=buf660)
buf661 = buf655; del buf655 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf651, buf661, 49152, 128, grid=grid(49152), stream=stream0)
buf664 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf661, buf664, 768, 64, grid=grid(768), stream=stream0)
buf663 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf660, buf663, 589824, grid=grid(589824), stream=stream0)
buf665 = reinterpret_tensor(buf651, (8192, 768), (768, 1), 0); del buf651 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf650, (8192, 768), (768, 1), 0), permute_501, out=buf665)
del permute_501
buf666 = buf660; del buf660 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf650, (768, 8192), (1, 768), 0), view_22, out=buf666)
del view_22
buf667 = buf661; del buf661 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf650, buf667, 49152, 128, grid=grid(49152), stream=stream0)
buf670 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf667, buf670, 768, 64, grid=grid(768), stream=stream0)
buf669 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf666, buf669, 589824, grid=grid(589824), stream=stream0)
buf674 = buf615; del buf615 # reuse
buf679 = reinterpret_tensor(buf650, (16, 512, 768), (393216, 768, 1), 0); del buf650 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_23.run(buf636, buf653, buf659, buf665, primals_20, mul_16, div_60, gt_3, buf674, buf679, 8192, 768, grid=grid(8192), stream=stream0)
del div_60
del gt_3
del primals_20
buf675 = reinterpret_tensor(buf667, (768, 64), (1, 768), 0); del buf667 # reuse
buf677 = buf637; del buf637 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_24.run(buf636, buf653, buf659, buf665, mul_16, buf675, buf677, 49152, 128, grid=grid(49152), stream=stream0)
del buf653
del buf659
del mul_16
buf676 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf675, buf676, 768, 64, grid=grid(768), stream=stream0)
buf678 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf677, buf678, 768, 64, grid=grid(768), stream=stream0)
buf680 = reinterpret_tensor(buf627, (8192, 3072), (3072, 1), 0); del buf627 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf679, (8192, 768), (768, 1), 0), permute_505, out=buf680)
del permute_505
buf681 = reinterpret_tensor(buf629, (768, 3072), (3072, 1), 0); del buf629 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf679, (768, 8192), (1, 768), 0), view_20, out=buf681)
del view_20
buf682 = reinterpret_tensor(buf677, (1, 768, 64), (49152, 1, 768), 0); del buf677 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf679, buf682, 49152, 128, grid=grid(49152), stream=stream0)
buf685 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf682, buf685, 768, 64, grid=grid(768), stream=stream0)
buf684 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_14.run(buf681, buf684, 2359296, grid=grid(2359296), stream=stream0)
buf686 = reinterpret_tensor(buf680, (16, 512, 3072), (1572864, 3072, 1), 0); del buf680 # reuse
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_15.run(buf686, addmm_4, 25165824, grid=grid(25165824), stream=stream0)
del addmm_4
buf687 = reinterpret_tensor(buf679, (8192, 768), (768, 1), 0); del buf679 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf686, (8192, 3072), (3072, 1), 0), permute_509, out=buf687)
del permute_509
buf688 = reinterpret_tensor(buf681, (3072, 768), (768, 1), 0); del buf681 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf686, (3072, 8192), (1, 3072), 0), view_18, out=buf688)
del view_18
buf689 = buf630; del buf630 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_16.run(buf686, buf689, 98304, 256, grid=grid(98304), stream=stream0)
del buf686
buf692 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_17.run(buf689, buf692, 3072, 32, grid=grid(3072), stream=stream0)
del buf689
buf691 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(buf688, buf691, 2359296, grid=grid(2359296), stream=stream0)
del buf688
buf695 = buf636; del buf636 # reuse
buf700 = reinterpret_tensor(buf665, (16, 512, 768), (393216, 768, 1), 0); del buf665 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_19.run(buf674, buf687, primals_14, mul_9, div_61, gt_2, buf695, buf700, 8192, 768, grid=grid(8192), stream=stream0)
del div_61
del gt_2
del primals_14
buf696 = reinterpret_tensor(buf682, (768, 64), (1, 768), 0); del buf682 # reuse
buf698 = buf675; del buf675 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_20.run(buf674, buf687, mul_9, buf696, buf698, 49152, 128, grid=grid(49152), stream=stream0)
del mul_9
buf697 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf696, buf697, 768, 64, grid=grid(768), stream=stream0)
buf699 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf698, buf699, 768, 64, grid=grid(768), stream=stream0)
buf701 = buf687; del buf687 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf700, (8192, 768), (768, 1), 0), permute_513, out=buf701)
del permute_513
buf702 = buf666; del buf666 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf700, (768, 8192), (1, 768), 0), view_16, out=buf702)
del view_16
buf703 = reinterpret_tensor(buf698, (1, 768, 64), (49152, 1, 768), 0); del buf698 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_13.run(buf700, buf703, 49152, 128, grid=grid(49152), stream=stream0)
buf706 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf703, buf706, 768, 64, grid=grid(768), stream=stream0)
buf705 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf702, buf705, 589824, grid=grid(589824), stream=stream0)
buf707 = reinterpret_tensor(buf700, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf700 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_21.run(buf701, buf707, 6291456, grid=grid(6291456), stream=stream0)
del buf701
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf708 = aten._scaled_dot_product_flash_attention_backward.default(buf707, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, None, None, 512, 512, 0.1, False, getitem_131, getitem_132, scale=0.125)
del getitem_129
del getitem_130
del getitem_131
del getitem_132
del permute_default_66
del permute_default_67
del permute_default_68
buf709 = buf708[0]
buf710 = buf708[1]
buf711 = buf708[2]
del buf708
buf712 = reinterpret_tensor(buf707, (8192, 768), (768, 1), 0); del buf707 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf711, (8192, 768), (768, 1), 0), permute_525, out=buf712)
del permute_525
buf713 = buf702; del buf702 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf711, (768, 8192), (1, 768), 0), view, out=buf713)
buf714 = buf703; del buf703 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf711, buf714, 49152, 128, grid=grid(49152), stream=stream0)
buf717 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf714, buf717, 768, 64, grid=grid(768), stream=stream0)
buf716 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf713, buf716, 589824, grid=grid(589824), stream=stream0)
buf718 = reinterpret_tensor(buf711, (8192, 768), (768, 1), 0); del buf711 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf710, (8192, 768), (768, 1), 0), permute_530, out=buf718)
del permute_530
buf719 = buf713; del buf713 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf710, (768, 8192), (1, 768), 0), view, out=buf719)
buf720 = buf714; del buf714 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf710, buf720, 49152, 128, grid=grid(49152), stream=stream0)
buf723 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf720, buf723, 768, 64, grid=grid(768), stream=stream0)
buf722 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf719, buf722, 589824, grid=grid(589824), stream=stream0)
buf724 = reinterpret_tensor(buf710, (8192, 768), (768, 1), 0); del buf710 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf709, (8192, 768), (768, 1), 0), permute_534, out=buf724)
del permute_534
buf725 = buf719; del buf719 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf709, (768, 8192), (1, 768), 0), view, out=buf725)
del view
buf726 = buf720; del buf720 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_22.run(buf709, buf726, 49152, 128, grid=grid(49152), stream=stream0)
del buf709
buf729 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_9.run(buf726, buf729, 768, 64, grid=grid(768), stream=stream0)
buf728 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(buf725, buf728, 589824, grid=grid(589824), stream=stream0)
del buf725
buf741 = empty_strided_cuda((2, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_25.run(buf741, 1536, grid=grid(1536), stream=stream0)
buf743 = empty_strided_cuda((30522, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_26.run(buf743, 23440896, grid=grid(23440896), stream=stream0)
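# NOTE (annotation): embedding backward starts from zero-filled fp32 grad
# tables, provided by the two kernels above: (2, 768) for the token-type
# table (1536 elements) and (30522, 768) for the word-embedding table
# (23,440,896 elements). The fused kernel below then scatter-adds the
# sequence gradient into both tables using the saved index tensors
# (primals_204 / primals_206).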
buf730 = buf695; del buf695 # reuse
buf733 = buf674; del buf674 # reuse
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward]
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_27.run(buf730, buf712, buf718, buf724, gt, primals_4, mul_1, div_63, primals_204, primals_206, buf733, buf741, buf743, 8192, 768, grid=grid(8192), stream=stream0)
del buf712
del buf718
del buf724
del div_63
del gt
del primals_204
del primals_206
del primals_4
buf734 = reinterpret_tensor(buf726, (768, 64), (1, 768), 0); del buf726 # reuse
buf736 = buf696; del buf696 # reuse
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_red_fused_native_layer_norm_backward_28.run(buf730, mul_1, buf734, buf736, 49152, 128, grid=grid(49152), stream=stream0)
del buf730
del mul_1
buf735 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf734, buf735, 768, 64, grid=grid(768), stream=stream0)
del buf734
buf737 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_7.run(buf736, buf737, 768, 64, grid=grid(768), stream=stream0)
del buf736
buf739 = empty_strided_cuda((512, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_29.run(buf739, 393216, grid=grid(393216), stream=stream0)
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum]
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_30.run(buf733, primals_205, buf739, 393216, 16, grid=grid(393216), stream=stream0)
del buf733
del primals_205
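# NOTE (annotation): the position-embedding gradient is handled separately:
# buf739 is a zero-filled (512, 768) table, and the kernel above reduces the
# per-sample gradients over the batch dimension (16) for each of the 512
# positions, indexed by primals_205.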
return (buf743, buf741, buf739, buf735, buf737, buf728, buf729, buf722, buf723, buf716, buf717, buf705, buf706, buf697, buf699, buf691, buf692, buf684, buf685, buf676, buf678, buf669, buf670, buf663, buf664, buf657, buf658, buf646, buf647, buf638, buf640, buf632, buf633, buf625, buf626, buf617, buf619, buf610, buf611, buf604, buf605, buf598, buf599, buf587, buf588, buf579, buf581, buf573, buf574, buf566, buf567, buf558, buf560, buf551, buf552, buf545, buf546, buf539, buf540, buf528, buf529, buf520, buf522, buf514, buf515, buf507, buf508, buf499, buf501, buf492, buf493, buf486, buf487, buf480, buf481, buf469, buf470, buf461, buf463, buf455, buf456, buf448, buf449, buf440, buf442, buf433, buf434, buf427, buf428, buf421, buf422, buf410, buf411, buf402, buf404, buf396, buf397, buf389, buf390, buf381, buf383, buf374, buf375, buf368, buf369, buf362, buf363, buf351, buf352, buf343, buf345, buf337, buf338, buf330, buf331, buf322, buf324, buf315, buf316, buf309, buf310, buf303, buf304, buf292, buf293, buf284, buf286, buf278, buf279, buf271, buf272, buf263, buf265, buf256, buf257, buf250, buf251, buf244, buf245, buf233, buf234, buf225, buf227, buf219, buf220, buf212, buf213, buf204, buf206, buf197, buf198, buf191, buf192, buf185, buf186, buf174, buf175, buf166, buf168, buf160, buf161, buf153, buf154, buf145, buf147, buf138, buf139, buf132, buf133, buf126, buf127, buf115, buf116, buf107, buf109, buf101, buf102, buf94, buf95, buf86, buf88, buf79, buf80, buf73, buf74, buf67, buf68, buf56, buf57, buf48, buf50, buf42, buf43, buf35, buf36, buf27, buf29, buf21, buf22, buf12, buf14, buf7, buf8, None, None, None, None, )
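# NOTE (annotation): the tuple above is the gradient list AOTAutograd
# expects, ordered to match the forward's flat parameter list: the three
# embedding tables first, then layer-norm and attention/FFN weight/bias
# gradients per layer, with the trailing Nones presumably standing in for
# the non-differentiable integer inputs (token, position and label ids).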

def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
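    # NOTE (annotation): the benchmark harness rebuilds representative inputs
    # with rand_strided(shape, stride, device, dtype): fp32 for parameters and
    # saved layer-norm stats, fp16 for saved matmul activations, bool for the
    # saved dropout keep-masks, and int64 for the token/position/label ids.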
primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_14 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_20 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_30 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_36 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_46 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_52 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_62 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_68 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_78 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_84 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_94 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_100 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_110 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_116 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_126 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_132 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_142 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_148 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_158 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_164 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_174 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_180 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_190 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_196 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_200 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_204 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_205 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_206 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_207 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
mul_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
gt = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
view = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_67 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_68 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_129 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_130 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_131 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_132 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_16 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_18 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_4 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_20 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_16 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_22 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_60 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_61 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_62 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_122 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_123 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_124 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_125 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_38 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_22 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_40 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_10 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_42 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_44 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_54 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_55 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_56 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_115 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_116 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_117 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_118 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_60 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_62 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_16 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_42 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_66 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_48 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_49 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_50 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_108 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_109 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_110 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_111 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_82 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_11 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_48 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_84 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_22 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_86 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_12 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_55 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_88 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_42 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_43 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_44 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_101 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_102 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_103 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_104 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_104 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_14 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_61 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_106 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_28 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_108 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_15 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_68 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_110 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_36 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_37 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_38 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_94 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_95 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_96 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_97 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_126 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_17 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_74 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_128 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_34 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_130 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_18 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_81 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_132 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_30 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_31 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_32 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_87 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_88 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_89 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_90 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_148 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_20 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_87 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_150 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_40 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_152 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_21 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_94 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_154 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_24 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_25 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_26 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_80 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_81 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_82 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_83 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_170 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_23 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_100 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_172 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_46 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_174 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_24 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_107 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_176 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_18 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_19 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_20 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_73 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_74 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_75 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_76 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_192 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_26 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_113 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_194 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_52 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_196 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_27 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_120 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_198 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_12 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_13 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_14 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_67 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_68 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_69 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_214 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_126 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_216 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_58 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_218 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_30 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_133 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_220 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_6 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_7 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_8 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_59 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_60 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_61 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_62 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_236 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_32 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_139 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_238 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_240 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_33 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_146 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_242 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_1 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_2 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_52 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_53 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_54 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_55 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_258 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_152 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_260 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_70 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_262 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_36 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_159 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_264 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_72 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
getitem_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_25 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
view_266 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
view_267 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
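# amax_12 and log are the row-wise max and log-sum-exp saved by the forward
# log_softmax over the 30522-way vocabulary logits; the fused backward kernel
# likely uses them to reconstruct the softmax without recomputing the forward.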
amax_12 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
log = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
convert_element_type_510 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
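# The permute_* tensors below are transposed weight matrices for the backward
# matmuls, and the div_* tensors hold the saved per-row 1/std terms from each
# LayerNorm (again inferred from shapes and Inductor's naming scheme).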
permute_134 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_138 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_27 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_142 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_146 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_28 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_150 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_162 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_167 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_171 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_30 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_175 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_179 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_31 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_183 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_195 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_200 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_204 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_33 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_208 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_212 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_34 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_216 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_228 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_233 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_237 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_36 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_241 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_245 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_37 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_249 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_261 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_266 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_270 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_39 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_274 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_278 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_40 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_282 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_294 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_299 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_303 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_42 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_307 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_311 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_43 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_315 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_327 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_332 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_336 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_45 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_340 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_344 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_46 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_348 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_360 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_365 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_369 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_48 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_373 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_377 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_49 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_381 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_393 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_398 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_402 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_406 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_410 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_52 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_414 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_426 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_431 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_435 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_54 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_439 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_443 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_55 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_447 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_459 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_464 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_468 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_57 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_472 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_476 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_58 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_480 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_492 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_497 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_501 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_60 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_505 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_509 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_61 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_513 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_525 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_530 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_534 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_63 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
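# Incoming gradients for the two forward outputs: tangents_1 is the gradient
# of the scalar masked-LM loss, tangents_2 the gradient of the
# (16, 512, 30522) prediction logits.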
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
tangents_2 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16)
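# print_performance calls fn `times` times per measurement, repeats the
# measurement `repeat` times, and prints the observed runtime.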
fn = lambda: call([primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, 
permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2])
return print_performance(fn, times=times, repeat=repeat)
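
# compiled_module_main hooks this benchmark into Inductor's standard
# entry point (argument parsing, optional profiling) under the model name
# 'BertForMaskedLM'; running the generated file directly, e.g.
#   python output_code.py   # filename illustrative
# prints the measured backward-pass time.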
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('BertForMaskedLM', benchmark_compiled_module)