# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/3i/c3iwze52kpy2zykepgw7shz2ksspmjdafig5gxepuphlkw5djj2f.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_1
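# Hedged reading of the generated kernel below: it zero-fills the (8192, 30522)
# fp32 gradient buffer for the NLL loss. The row stride of 30528 is the vocab
# size 30522 rounded up to a multiple of 64, presumably for alignment; only the
# 30522 valid columns are written here. A rough eager-mode equivalent
# (illustrative only, not the generated path):
#   grad_logprobs = torch.zeros(8192, 30522, device='cuda')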
triton_poi_fused_nll_loss_backward_nll_loss_forward_0 = async_compile.triton('triton_poi_fused_nll_loss_backward_nll_loss_forward_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[268435456],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_nll_loss_forward_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.000144896},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_nll_loss_backward_nll_loss_forward_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 250036224
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex % 30522
x1 = (xindex // 30522)
tmp0 = 0.0
tl.store(out_ptr0 + (x0 + (30528*x1)), tmp0, None)
def get_args():
arg_0 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_nll_loss_backward_nll_loss_forward_0.run(*args, 250036224, grid=grid(250036224), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_nll_loss_backward_nll_loss_forward_0.benchmark_all_configs(*args, 250036224, grid=grid(250036224))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 1.000144896
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_shunting/yt/cytwgbywfm7dtcvhkxpg6hkvllgzhv7o7bxsxaihptlzwjox3k7d.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_1
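# Hedged reading: this scatter builds the one-hot part of nll_loss_backward.
# For each of the 8192 token positions it writes -1.0 at the target's column;
# -100 is PyTorch's ignore_index, and ignored positions are redirected to
# column 0 via tl.where (the next kernel multiplies their contribution by
# zero). Illustrative eager-mode shape of the write:
#   cols = torch.where(targets != -100, targets, 0)
#   grad_logprobs[torch.arange(8192), cols.flatten()] = -1.0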
triton_poi_fused_nll_loss_backward_nll_loss_forward_1 = async_compile.triton('triton_poi_fused_nll_loss_backward_nll_loss_forward_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[8192],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_nll_loss_forward_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 9.8304e-05},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_nll_loss_backward_nll_loss_forward_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 8192
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tl.full([1], -100, tl.int64)
tmp2 = tmp0 != tmp1
tmp3 = tl.full([1], 0, tl.int64)
tmp4 = tl.where(tmp2, tmp0, tmp3)
tmp5 = tl.full([XBLOCK], 30522, tl.int32)
tmp6 = tmp4 + tmp5
tmp7 = tmp4 < 0
tmp8 = tl.where(tmp7, tmp6, tmp4)
tl.device_assert((0 <= tmp8) & (tmp8 < 30522), "index out of bounds: 0 <= tmp8 < 30522")
tmp10 = -1.0
tl.store(out_ptr0 + (tmp8 + (30528*x0)), tmp10, None)
def get_args():
arg_0 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_1 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_nll_loss_backward_nll_loss_forward_1.run(*args, 8192, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_nll_loss_backward_nll_loss_forward_1.benchmark_all_configs(*args, 8192, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 9.8304e-05
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/2u/c2u2begnomazgjf7whgacopubkixewbtygsefgr73rsj6xnr6zj7.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_2
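# Hedged reading: this fuses nll_loss_backward, _log_softmax_backward_data, and
# an add. Loop 1 reduces sum_g = sum_v(onehot * g) per row, where
# g = grad_loss / num_valid when target != -100 and 0 otherwise. Loop 2
# recomputes log_softmax from the saved logits plus per-row max and logsumexp,
# applies the log_softmax backward rule, and adds the other incoming fp16
# gradient. Illustrative eager-mode shape of the math:
#   grad_logprobs = onehot * g[:, None]
#   probs = (logits - row_max - logsumexp).exp()
#   out = other_grad + grad_logprobs - probs * grad_logprobs.sum(-1, keepdim=True)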
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2 = async_compile.triton('triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[8192, 32768],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: '*fp16', 6: '*fp32', 7: '*fp32', 8: '*fp16', 9: 'i32', 10: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.500348424}
)
@triton.jit
def triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 8192
rnumel = 30522
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
tmp4 = tl.load(in_ptr2 + (0))
tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK])
tmp6 = tl.load(in_ptr3 + (0))
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
_tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.full([1, 1], -100, tl.int64)
tmp3 = tmp1 != tmp2
tmp8 = tmp5 / tmp7
tmp9 = 0.0
tmp10 = tl.where(tmp3, tmp8, tmp9)
tmp11 = tmp0 * tmp10
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
tmp14 = _tmp13 + tmp12
_tmp13 = tl.where(rmask, tmp14, _tmp13)
tmp13 = tl.sum(_tmp13, 1)[:, None]
tmp19 = tl.load(in_ptr2 + (0))
tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
tmp21 = tl.load(in_ptr3 + (0))
tmp22 = tl.broadcast_to(tmp21, [XBLOCK, RBLOCK])
tmp29 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp31 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last')
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp15 = tl.load(in_ptr4 + (r1 + (30522*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp16 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0)
tmp27 = tl.load(in_ptr5 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp17 = tl.full([1, 1], -100, tl.int64)
tmp18 = tmp1 != tmp17
tmp23 = tmp20 / tmp22
tmp24 = 0.0
tmp25 = tl.where(tmp18, tmp23, tmp24)
tmp26 = tmp16 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp30 = tmp28 - tmp29
tmp32 = tmp30 - tmp31
tmp33 = tmp32.to(tl.float32)
tmp34 = tmp33.to(tl.float32)
tmp35 = tl_math.exp(tmp34)
tmp36 = tmp35 * tmp13
tmp37 = tmp26 - tmp36
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp15 + tmp38
tl.store(out_ptr1 + (r1 + (30528*x0)), tmp39, rmask)
def get_args():
arg_0 = rand_strided((8192, 30522), (30528, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_2 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16)
arg_5 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
arg_6 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
arg_8 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.run(*args, 8192, 30522, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.benchmark_all_configs(*args, 8192, 30522, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 1.500348424
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ny/cny7ookgvrcyfphkgshx7jmxx4kzb2nhzoiwqlawy3qkjrc5jc3x.py
# Source Nodes: [], Original ATen: []
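# Hedged reading: this kernel materializes a clean padded copy of the fp16
# gradient, widening each row from 30522 to 30528 columns; the masked load
# returns 0.0 for the 6 pad columns, so the subsequent GEMM can safely run
# over the aligned 30528-wide layout.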
triton_poi_fused_3 = async_compile.triton('triton_poi_fused_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[268435456],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.0002432},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 250085376
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex % 30528
x2 = xindex
tmp0 = x0
tmp1 = tl.full([1], 30522, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
tl.store(out_ptr0 + (x2), tmp3, None)
def get_args():
arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((8192, 30528), (30528, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_3.run(*args, 250085376, grid=grid(250085376), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_3.benchmark_all_configs(*args, 250085376, grid=grid(250085376))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 1.0002432
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kl/ckldmag4lbnoqqb2xa3zz7bszzgzh6df4wqv2lgxnhv7xzrvm2a3.py
# Source Nodes: [], Original ATen: []
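# Hedged reading: the companion padding kernel for the decoder weight. It
# copies the (30522, 768) fp16 matrix into a (30528, 768) buffer; rows at or
# beyond 30522 fail the `x1 < 30522` mask and are stored as zeros.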
triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0937728},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 23445504
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x1 = (xindex // 768)
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 30522, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = tl.load(in_ptr0 + (x2), tmp2, other=0.0).to(tl.float32)
tl.store(out_ptr0 + (x2), tmp3, None)
def get_args():
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((30528, 768), (768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(*args, 23445504, grid=grid(23445504), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_4.benchmark_all_configs(*args, 23445504, grid=grid(23445504))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.0937728
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kv/ckvnavggszj55y6z3joj3ix2bejtrydnigqdoidqzr33bzb5ewym.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
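# Hedged reading: a column-wise reduction over all 8192 token rows of the fp16
# gradient (padded row stride 30528), accumulated in fp32. This looks like the
# bias gradient of the 30522-way decoder projection; illustrative equivalent:
#   bias_grad = grad.float().sum(dim=0)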
triton_red_fused__to_copy_sum_5 = async_compile.triton('triton_red_fused__to_copy_sum_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[32768, 8192],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.500194536}
)
@triton.jit
def triton_red_fused__to_copy_sum_5(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 30522
rnumel = 8192
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (30528*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tmp4 = tmp2.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp4, xmask)
def get_args():
arg_0 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((30522,), (1,), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_sum_5.run(*args, 30522, 8192, grid=grid(30522), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_sum_5.benchmark_all_configs(*args, 30522, 8192, grid=grid(30522))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.500194536
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/l6/cl6ssrtjbepd5tssmygz54usro7zta7vm56aayideyxsk3wu3wsb.py
# Source Nodes: [], Original ATen: [aten._to_copy]
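# Hedged reading: a plain fp16 -> fp32 upcast of the (30522, 768) decoder
# weight gradient, presumably so the leaf gradient matches the parameter's
# fp32 dtype. Roughly: weight_grad.float().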
triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.140645376},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 23440896
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
def get_args():
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_6.run(*args, 23440896, grid=grid(23440896), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_6.benchmark_all_configs(*args, 23440896, grid=grid(23440896))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.140645376
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/mc/cmcvv4a23higyc5z6e7nkc7q4ndfid6u7bhzw6yni6f5ofdbjoue.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
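# Hedged reading: a persistent per-row kernel fusing the input-gradient half of
# native_layer_norm_backward with gelu backward. Per 768-wide row it reduces
# sum(g*gamma) and sum(g*gamma*x_hat), where x_hat = (gelu(x) - mean) * invstd
# is recomputed from the saved pre-gelu activation, then computes
#   dx_ln = invstd/768 * (768*g*gamma - sum(g*gamma) - x_hat*sum(g*gamma*x_hat))
#   out   = dx_ln * gelu'(x)
# with gelu'(x) = 0.5*(1 + erf(x/sqrt(2))) + x * exp(-x**2/2) / sqrt(2*pi),
# which matches the 0.3989422804014327 constant below.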
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7 = async_compile.triton('triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.037817344}
)
@triton.jit
def triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp18 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp20 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp1 * tmp2
tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
tmp6 = tl.where(rmask, tmp4, 0)
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
tmp9 = tmp8.to(tl.float32)
tmp10 = 0.5
tmp11 = tmp9 * tmp10
tmp12 = 0.7071067811865476
tmp13 = tmp9 * tmp12
tmp14 = libdevice.erf(tmp13)
tmp15 = 1.0
tmp16 = tmp14 + tmp15
tmp17 = tmp11 * tmp16
tmp19 = tmp17 - tmp18
tmp21 = tmp19 * tmp20
tmp22 = tmp3 * tmp21
tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
tmp25 = tl.where(rmask, tmp23, 0)
tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
tmp27 = 0.0013020833333333333
tmp28 = tmp20 * tmp27
tmp29 = 768.0
tmp30 = tmp3 * tmp29
tmp31 = tmp30 - tmp7
tmp32 = tmp21 * tmp26
tmp33 = tmp31 - tmp32
tmp34 = tmp28 * tmp33
tmp35 = tmp16 * tmp10
tmp36 = tmp9 * tmp9
tmp37 = -0.5
tmp38 = tmp36 * tmp37
tmp39 = tl_math.exp(tmp38)
tmp40 = 0.3989422804014327
tmp41 = tmp39 * tmp40
tmp42 = tmp9 * tmp41
tmp43 = tmp35 + tmp42
tmp44 = tmp34 * tmp43
tmp45 = tmp44.to(tl.float32)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp45, rmask)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.037817344
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xn/cxn73hmq2otzylcbqlg6bpgfz5pcbvnul2dl43ea3uixl37qrhue.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
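# Hedged reading: stage 1 of a split reduction for the layer-norm weight and
# bias gradients. Each program reduces a 128-row slice, recomputing
# x_hat = (gelu(x) - mean) * invstd, and emits (768, 64) fp32 partials:
# out_ptr0 ~ partial sum(g * x_hat) (gamma grad), out_ptr1 ~ partial sum(g)
# (beta grad).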
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8 = async_compile.triton('triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025624576}
)
@triton.jit
def triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp18 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp21 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp12 = tl.load(in_ptr2 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
tmp14 = tl.load(in_ptr3 + (r2 + (128*x1)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 0.5
tmp5 = tmp3 * tmp4
tmp6 = 0.7071067811865476
tmp7 = tmp3 * tmp6
tmp8 = libdevice.erf(tmp7)
tmp9 = 1.0
tmp10 = tmp8 + tmp9
tmp11 = tmp5 * tmp10
tmp13 = tmp11 - tmp12
tmp15 = tmp13 * tmp14
tmp16 = tmp1 * tmp15
tmp17 = tl.broadcast_to(tmp16, [XBLOCK, RBLOCK])
tmp19 = _tmp18 + tmp17
_tmp18 = tl.where(rmask, tmp19, _tmp18)
tmp20 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp22 = _tmp21 + tmp20
_tmp21 = tl.where(rmask, tmp22, _tmp21)
tmp18 = tl.sum(_tmp18, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp18, None)
tmp21 = tl.sum(_tmp21, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp21, None)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.025624576
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xb/cxbiab5uy5kj4w6v3uoxvyncbfkyqo4xmh7csh6vz5yz4dut6xo2.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
# hidden_states_97 => add_100, convert_element_type_498, erf_12, mul_161, mul_162, mul_163
# hidden_states_98 => mul_164, sub_38
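# Hedged reading: stage 2 of the split reduction above; it sums the 64 partial
# columns of a (768, 64) buffer down to a final (768,) gradient. The same
# kernel appears to be reusable for both the gamma and beta partials.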
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9 = async_compile.triton('triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[1024, 64],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 768
rnumel = 64
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(xmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp4, xmask)
def get_args():
arg_0 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(*args, 768, 64, grid=grid(768), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.benchmark_all_configs(*args, 768, 64, grid=grid(768))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.00019968
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xw/cxwo2kn3atkuyrqgmindh7m4zzqaakdmurrxxe3trabcblnihnwq.py
# Source Nodes: [], Original ATen: [aten.sum]
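# Hedged reading: stage 1 of a split sum over the 8192 token rows of an fp16
# (8192, 768) gradient, emitting (1, 768, 64) fp32 partials; this is the usual
# two-stage pattern for a linear layer's bias gradient.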
triton_red_fused_sum_10 = async_compile.triton('triton_red_fused_sum_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_10(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_10.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.01277952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/bw/cbwoo26iavmb6dt37knpfmallrwdgfbg4bmorajo736d3byifyha.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
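# Hedged reading: stage 2 of the split sum above; it reduces the (1, 768, 64)
# fp32 partials to the final (768,) bias gradient (the trailing
# .to(tl.float32) is a no-op here since the accumulator is already fp32).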
triton_per_fused__to_copy_sum_11 = async_compile.triton('triton_per_fused__to_copy_sum_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[1024, 64],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00019968}
)
@triton.jit
def triton_per_fused__to_copy_sum_11(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 768
rnumel = 64
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r1)), xmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(xmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None]
tmp5 = tmp4.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp5, xmask)
def get_args():
arg_0 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_sum_11.run(*args, 768, 64, grid=grid(768), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_sum_11.benchmark_all_configs(*args, 768, 64, grid=grid(768))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.00019968
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/mz/cmzinemkxyqageuju6a26rj4aqmkt2l7ysoziwfzbw253z3tb63k.py
# Source Nodes: [], Original ATen: [aten._to_copy]
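# Hedged reading: fp16 -> fp32 upcast of a (768, 768) weight gradient, the
# same dtype-promotion pattern as triton_poi_fused__to_copy_6 above.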
triton_poi_fused__to_copy_12 = async_compile.triton('triton_poi_fused__to_copy_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[1048576],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_12(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 589824
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_12.run(*args, 589824, grid=grid(589824), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_12.benchmark_all_configs(*args, 589824, grid=grid(589824))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.003538944
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/l4/cl4sftqv3rdsorngubndfo524ezgoqyc7gyppiqa5qu2yltoipgv.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
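# Hedged reading: per-row layer-norm input gradient (the same reduction pattern
# as the `..._7` kernel above, but with x_hat and the per-row scale read
# precomputed rather than rebuilt through gelu), stored twice: out_ptr2 keeps
# the fp32 layer-norm gradient, and out_ptr3 additionally applies dropout
# backward, mask * 1/(1-p) with p = 0.1 (hence the 1.1111... scale), cast to
# fp16.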
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13 = async_compile.triton('triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*i1', 5: '*fp32', 6: '*fp16', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.081824768}
)
@triton.jit
def triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp8 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0)
tmp14 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp22 = tl.load(in_ptr4 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp1 * tmp2
tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
tmp6 = tl.where(rmask, tmp4, 0)
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
tmp9 = tmp3 * tmp8
tmp10 = tl.broadcast_to(tmp9, [RBLOCK])
tmp12 = tl.where(rmask, tmp10, 0)
tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))
tmp15 = 768.0
tmp16 = tmp3 * tmp15
tmp17 = tmp16 - tmp7
tmp18 = tmp8 * tmp13
tmp19 = tmp17 - tmp18
tmp20 = tmp14 * tmp19
tmp21 = tmp20.to(tl.float32)
tmp23 = tmp22.to(tl.float32)
tmp24 = 1.1111111111111112
tmp25 = tmp23 * tmp24
tmp26 = tmp21 * tmp25
tl.store(out_ptr2 + (r1 + (768*x0)), tmp20, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp26, rmask)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.081824768
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/pc/cpcxsiinpg3g3z4rnd4e54gqzv3auqdg34x7kolps2lll75bhrdh.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
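# [editor's note] First stage of a split reduction for the layer-norm parameter
# gradients: each of the 768*64 programs sums a 128-row slice of dy * xhat
# (gamma-grad partials, out_ptr0) and of dy (beta-grad partials, out_ptr1);
# the 64 partials per channel are presumably collapsed by a later kernel.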
triton_red_fused__to_copy_native_layer_norm_backward_14 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_backward_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_backward_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.038141952}
)
@triton.jit
def triton_red_fused__to_copy_native_layer_norm_backward_14(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp1 * tmp2
tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
tmp6 = _tmp5 + tmp4
_tmp5 = tl.where(rmask, tmp6, _tmp5)
tmp7 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp9 = _tmp8 + tmp7
_tmp8 = tl.where(rmask, tmp9, _tmp8)
tmp5 = tl.sum(_tmp5, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp5, None)
tmp8 = tl.sum(_tmp8, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp8, None)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_backward_14.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_native_layer_norm_backward_14.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.038141952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tk/ctkue5wyr5xsiml26vprfqtgmp2eaql2enslkqgq3exemgpikz3o.py
# Source Nodes: [], Original ATen: [aten.sum]
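# [editor's note] Same split-sum pattern: reduces the (16, 512, 768) fp16
# gradient over 64 chunks of 128 rows into (768, 64) fp32 partials, likely the
# first stage of a linear-layer bias gradient.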
triton_red_fused_sum_15 = async_compile.triton('triton_red_fused_sum_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_15', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_15(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_15.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_15.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.01277952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/us/cusyhkdtuhzsuipmuax2gi4pj3qgfxdgv62fqhekapxgkifowy6q.py
# Source Nodes: [], Original ATen: [aten._to_copy]
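# [editor's note] Plain dtype cast: upcasts the fp16 (768, 3072) weight
# gradient to fp32, presumably so the optimizer accumulates in full precision.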
triton_poi_fused__to_copy_16 = async_compile.triton('triton_poi_fused__to_copy_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_16', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_16(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_16.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/im/cimmrvcpkky3hw3retgbay4kloi52mxpkldjifj3bt6hbcrvgqtv.py
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
# hidden_states_92 => add_96, convert_element_type_485, erf_11, mul_155
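# [editor's note] This kernel applies the exact-GELU derivative,
# d/dx gelu(x) = Phi(x) + x * phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
# (0.7071... = 1/sqrt(2)) and phi(x) = exp(-x^2 / 2) / sqrt(2*pi)
# (0.3989... = 1/sqrt(2*pi)). The incoming fp16 gradient is upcast, multiplied
# by the derivative at the saved pre-GELU input, and written back in place.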
triton_poi_fused_gelu_gelu_backward_17 = async_compile.triton('triton_poi_fused_gelu_gelu_backward_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_gelu_backward_17', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.150994944},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_gelu_gelu_backward_17(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 25165824
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 0.7071067811865476
tmp5 = tmp3 * tmp4
tmp6 = libdevice.erf(tmp5)
tmp7 = 1.0
tmp8 = tmp6 + tmp7
tmp9 = 0.5
tmp10 = tmp8 * tmp9
tmp11 = tmp3 * tmp3
tmp12 = -0.5
tmp13 = tmp11 * tmp12
tmp14 = tl_math.exp(tmp13)
tmp15 = 0.3989422804014327
tmp16 = tmp14 * tmp15
tmp17 = tmp3 * tmp16
tmp18 = tmp10 + tmp17
tmp19 = tmp1 * tmp18
tmp20 = tmp19.to(tl.float32)
tl.store(in_out_ptr0 + (x0), tmp20, None)
def get_args():
arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_gelu_gelu_backward_17.run(*args, 25165824, grid=grid(25165824), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_gelu_gelu_backward_17.benchmark_all_configs(*args, 25165824, grid=grid(25165824))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.150994944
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/le/cleh7tp3rcdkx6xu5hdckexqubzkgpkgi7yarvwkzdkc76jba35y.py
# Source Nodes: [], Original ATen: [aten.sum]
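# [editor's note] Split-sum over the (16, 512, 3072) intermediate gradient:
# 32 chunks of 256 rows reduced into (3072, 32) fp32 partials, likely the first
# stage of the MLP up-projection's bias gradient.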
triton_red_fused_sum_18 = async_compile.triton('triton_red_fused_sum_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[131072, 256],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_sum_18(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 98304
rnumel = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 3072
x1 = (xindex // 3072)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (3072*r2) + (786432*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 512, 3072), (1572864, 3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_18.run(*args, 98304, 256, grid=grid(98304), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_18.benchmark_all_configs(*args, 98304, 256, grid=grid(98304))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.050724864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/uv/cuvj2aolby6dsvhktohp2xnbrouwrm3i5iwofuf5cf2gxqrzwi4p.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
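# [editor's note] Second reduction stage: collapses the 32 partials per channel
# produced by triton_red_fused_sum_18 into the final (3072,) bias gradient. The
# trailing .to(tl.float32) is a no-op here since the partials are already fp32.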
triton_per_fused__to_copy_sum_19 = async_compile.triton('triton_per_fused__to_copy_sum_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[4096, 32],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_sum_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.000405504}
)
@triton.jit
def triton_per_fused__to_copy_sum_19(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 3072
rnumel = 32
RBLOCK: tl.constexpr = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (3072*r1)), xmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.where(xmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None]
tmp5 = tmp4.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp5, xmask)
def get_args():
arg_0 = rand_strided((1, 3072, 32), (98304, 1, 3072), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_sum_19.run(*args, 3072, 32, grid=grid(3072), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_sum_19.benchmark_all_configs(*args, 3072, 32, grid=grid(3072))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.000405504
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/dp/cdp3sl7e4pey7rtyji754dhbkoxfiwn6srdtuu3ygal2f76wqy6g.py
# Source Nodes: [], Original ATen: [aten._to_copy]
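# [editor's note] Same fp16 -> fp32 upcast as triton_poi_fused__to_copy_16,
# here for the (3072, 768) weight gradient of the other MLP projection.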
triton_poi_fused__to_copy_20 = async_compile.triton('triton_poi_fused__to_copy_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_20', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_20(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_20.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_20.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/u2/cu2zu355cpgyqkbajs47g6dc25icfvkaqnjfftpy4rwsoyxwd6df.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
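# [editor's note] Same layer-norm + dropout backward pattern as kernel _13,
# except the incoming gradient is first accumulated from two branches (a
# residual fp32 gradient plus an upcast fp16 gradient, tmp3 = tmp0 + tmp2)
# before the row reductions.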
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*i1', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 6, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.106990592}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0)
tmp16 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp24 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp3 * tmp4
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp8, 0))
tmp11 = tmp5 * tmp10
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask, tmp12, 0)
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp17 = 768.0
tmp18 = tmp5 * tmp17
tmp19 = tmp18 - tmp9
tmp20 = tmp10 * tmp15
tmp21 = tmp19 - tmp20
tmp22 = tmp16 * tmp21
tmp23 = tmp22.to(tl.float32)
tmp25 = tmp24.to(tl.float32)
tmp26 = 1.1111111111111112
tmp27 = tmp25 * tmp26
tmp28 = tmp23 * tmp27
tl.store(out_ptr2 + (r1 + (768*x0)), tmp22, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.106990592
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kf/ckfk5xfnycnrpmf5qszpqq2qcsb7kwgpnosphd6ynan6dvn3d5gp.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
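# [editor's note] Split-reduction companion to kernel _21: it recomputes the
# two-branch gradient sum and reduces dy * xhat and dy into (768, 64) partials
# for that layer norm's gamma and beta gradients.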
triton_red_fused__to_copy_add_native_layer_norm_backward_22 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_22', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_22', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.063307776}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp3 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(rmask, tmp8, _tmp7)
tmp9 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask, tmp11, _tmp10)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp7, None)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp10, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_add_native_layer_norm_backward_22.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.063307776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/fy/cfyrxilgpaoobuleeggh5n6i3vfcj6pqju5o3najjzmc764wrhlw.py
# Source Nodes: [], Original ATen: [aten.clone]
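# [editor's note] Layout clone for attention backward: the read index
# x0 + 64*x2 + 768*x1 + 393216*x3 splits the 768 hidden channels into
# 12 heads x 64 dims and swaps the head and sequence axes, i.e. it permutes a
# contiguous (batch, seq, head, dim) view into (batch, head, seq, dim).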
triton_poi_fused_clone_23 = async_compile.triton('triton_poi_fused_clone_23', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_23', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.025165824},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_23(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 6291456
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex % 64
x1 = (xindex // 64) % 512
x2 = (xindex // 32768) % 12
x3 = (xindex // 393216)
x4 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (393216*x3)), None).to(tl.float32)
tl.store(out_ptr0 + (x4), tmp0, None)
def get_args():
arg_0 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((16, 12, 512, 64), (393216, 32768, 64, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_clone_23.run(*args, 6291456, grid=grid(6291456), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_clone_23.benchmark_all_configs(*args, 6291456, grid=grid(6291456))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.025165824
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/lj/cljluzoqrqvi7wumengog7jpxwcpskgi6tjet3lc7j2all5nibdx.py
# Source Nodes: [], Original ATen: [aten.sum]
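# [editor's note] Split-sum over the per-head attention gradient viewed back as
# (batch*seq, 768); plausibly the first stage of one of the q/k/v projection
# bias gradients.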
triton_red_fused_sum_24 = async_compile.triton('triton_red_fused_sum_24', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.01277952}
)
@triton.jit
def triton_red_fused_sum_24(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(rmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
def get_args():
arg_0 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((1, 768, 64), (49152, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_sum_24.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_sum_24.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.01277952
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/pg/cpgss3xinmgtgberwusouop6ykgkwstr6e2bi33zume6biol33tv.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
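# [editor's note] Variant of kernel _21 with a four-way accumulation: the
# residual fp32 gradient plus three upcast fp16 gradients (plausibly the input
# gradients of the query, key and value projections) are summed before the
# same layer-norm and dropout backward math.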
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25 = async_compile.triton('triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*i1', 8: '*fp32', 9: '*fp16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.132156416}
)
@triton.jit
def triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp16 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
tmp22 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp30 = tl.load(in_ptr7 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp9 * tmp10
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask, tmp12, 0)
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp17 = tmp11 * tmp16
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp20 = tl.where(rmask, tmp18, 0)
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp23 = 768.0
tmp24 = tmp11 * tmp23
tmp25 = tmp24 - tmp15
tmp26 = tmp16 * tmp21
tmp27 = tmp25 - tmp26
tmp28 = tmp22 * tmp27
tmp29 = tmp28.to(tl.float32)
tmp31 = tmp30.to(tl.float32)
tmp32 = 1.1111111111111112
tmp33 = tmp31 * tmp32
tmp34 = tmp29 * tmp33
tl.store(out_ptr3 + (r1 + (768*x0)), tmp28, rmask)
tl.store(out_ptr4 + (r1 + (768*x0)), tmp34, rmask)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.132156416
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/kk/ckkxhfownaqlblc6owmbx2ivrxr2v5fia4yyt5hpkdsas4vq57lh.py
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
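# [editor's note] Split-reduction companion to kernel _25: it recomputes the
# four-way gradient sum and reduces it into (768, 64) partials for the
# attention block layer norm's gamma and beta gradients.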
triton_red_fused__to_copy_add_native_layer_norm_backward_26 = async_compile.triton('triton_red_fused__to_copy_add_native_layer_norm_backward_26', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_backward_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.0884736}
)
@triton.jit
def triton_red_fused__to_copy_add_native_layer_norm_backward_26(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp16 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp9 * tmp10
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
tmp14 = _tmp13 + tmp12
_tmp13 = tl.where(rmask, tmp14, _tmp13)
tmp15 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp17 = _tmp16 + tmp15
_tmp16 = tl.where(rmask, tmp17, _tmp16)
tmp13 = tl.sum(_tmp13, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp13, None)
tmp16 = tl.sum(_tmp16, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp16, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused__to_copy_add_native_layer_norm_backward_26.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.0884736
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/xt/cxtwhwavzlx26uc42l5ttjvmznisiph6donsz4lua3ibmgh7atnd.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
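# [editor's note] Zero-fills the (2, 768) gradient buffer for the token-type
# embedding table ahead of embedding_dense_backward's scatter-add.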
triton_poi_fused_embedding_dense_backward_27 = async_compile.triton('triton_poi_fused_embedding_dense_backward_27', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[2048],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_27', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6.144e-06},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_27(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1536
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
def get_args():
arg_0 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_27.run(*args, 1536, grid=grid(1536), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_27.benchmark_all_configs(*args, 1536, grid=grid(1536))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 6.144e-06
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/ux/cuxomnrffakbliaucrregsixhu7biugelgox2sqnovnhnqnsgirh.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
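# [editor's note] Same zero-fill, here for the (30522, 768) word-embedding
# gradient buffer (30522 is the BERT WordPiece vocabulary size). The xmask
# guard is emitted because xnumel need not be a multiple of the launch block.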
triton_poi_fused_embedding_dense_backward_28 = async_compile.triton('triton_poi_fused_embedding_dense_backward_28', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_28', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.093763584},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_28(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 23440896
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
def get_args():
arg_0 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_28.run(*args, 23440896, grid=grid(23440896), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_28.benchmark_all_configs(*args, 23440896, grid=grid(23440896))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.093763584
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
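# Note: same zero-fill pattern as above, but for a (30522, 768) buffer --
# 30522 matching the BERT WordPiece vocabulary size -- presumably the word
# embedding gradient that the fused kernel below accumulates into via
# tl.atomic_add.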
# kernel path: /tmp/torchinductor_shunting/ux/cuxgjhibr4qqfqusc6n5la3fhval4hxzobfn4dzyazk5dckxavok.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward]
# masked_lm_loss => full_default_2
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*i1', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*i64', 9: '*i64', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: 'i32', 14: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29', 'mutated_arg_names': ['in_out_ptr0', 'out_ptr3', 'out_ptr4'], 'no_x_dim': True, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.169980928}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 8192
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
x2 = xindex % 512
tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp10 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask, other=0.0).to(tl.int1)
tmp15 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp21 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask, other=0.0)
tmp27 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp34 = tl.load(in_ptr7 + (x2), None, eviction_policy='evict_last')
tmp43 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp3 + tmp5
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp11 = tmp10.to(tl.float32)
tmp12 = 1.1111111111111112
tmp13 = tmp11 * tmp12
tmp14 = tmp9 * tmp13
tmp16 = tmp14 * tmp15
tmp17 = tl.broadcast_to(tmp16, [RBLOCK])
tmp19 = tl.where(rmask, tmp17, 0)
tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
tmp22 = tmp16 * tmp21
tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
tmp25 = tl.where(rmask, tmp23, 0)
tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
tmp28 = 768.0
tmp29 = tmp16 * tmp28
tmp30 = tmp29 - tmp20
tmp31 = tmp21 * tmp26
tmp32 = tmp30 - tmp31
tmp33 = tmp27 * tmp32
tmp35 = tl.full([RBLOCK], 2, tl.int32)
tmp36 = tmp34 + tmp35
tmp37 = tmp34 < 0
tmp38 = tl.where(tmp37, tmp36, tmp34)
tmp39 = tl.full([1], -1, tl.int64)
tmp40 = tmp34 == tmp39
tmp41 = 0.0
tmp42 = tl.where(tmp40, tmp41, tmp33)
tmp44 = tl.full([RBLOCK], 30522, tl.int32)
tmp45 = tmp43 + tmp44
tmp46 = tmp43 < 0
tmp47 = tl.where(tmp46, tmp45, tmp43)
tmp48 = tl.full([1], 0, tl.int64)
tmp49 = tmp43 == tmp48
tmp50 = tl.where(tmp49, tmp41, tmp33)
tl.store(in_out_ptr0 + (r1 + (768*x0)), tmp14, rmask)
tl.store(out_ptr2 + (r1 + (768*x0)), tmp33, rmask)
tl.atomic_add(out_ptr3 + (tl.broadcast_to(r1 + (768*tmp38), [RBLOCK])), tmp42, rmask, sem='relaxed')
tl.atomic_add(out_ptr4 + (tl.broadcast_to(r1 + (768*tmp47), [RBLOCK])), tmp50, rmask, sem='relaxed')
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
arg_8 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_9 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_10 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_11 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_12 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.169980928
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
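# Note: the persistent reduction above fuses several backward ops per token
# row: it sums four incoming gradient branches (tmp0/tmp1/tmp4/tmp7), applies
# the dropout mask scaled by 1.1111111111111112 (i.e. 1/(1-p) with p = 0.1),
# runs the layer-norm backward reduction over the 768 features, and scatters
# the result into the token-type (out_ptr3, 2 rows) and word (out_ptr4,
# 30522 rows) embedding gradients with relaxed atomic adds, zeroing
# contributions where the index equals the embedding's padding value
# (-1 and 0 respectively).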
# kernel path: /tmp/torchinductor_shunting/wc/cwcm7d2opg7dvfg3iqa43co7xxrrnilj7flh2kmutnq2zqi4vlfi.py
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_red_fused_native_layer_norm_backward_30 = async_compile.triton('triton_red_fused_native_layer_norm_backward_30', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[65536, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_30', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.050724864}
)
@triton.jit
def triton_red_fused_native_layer_norm_backward_30(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 49152
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768)
_tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x3 = xindex
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r2 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (x0 + (768*r2) + (98304*x1)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
tmp5 = _tmp4 + tmp3
_tmp4 = tl.where(rmask, tmp5, _tmp4)
tmp6 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(rmask, tmp8, _tmp7)
tmp4 = tl.sum(_tmp4, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp4, None)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr1 + (x3), tmp7, None)
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((768, 64), (1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_native_layer_norm_backward_30.run(*args, 49152, 128, grid=grid(49152), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_native_layer_norm_backward_30.benchmark_all_configs(*args, 49152, 128, grid=grid(49152))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.050724864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
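# Note: this reduction produces the layer-norm weight/bias gradient partials.
# The 8192 token rows are split into 64 chunks of 128 (xnumel = 768 * 64,
# rnumel = 128), yielding two (768, 64) partial-sum tensors that a follow-up
# reduction kernel collapses into the final (768,) gradients.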
# kernel path: /tmp/torchinductor_shunting/oo/coouyldiohu5nljnik4bjbkcy56erkcsydxpoco2pnqxwiivtdx2.py
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_31 = async_compile.triton('triton_poi_fused_embedding_dense_backward_31', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[524288],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_31', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.001572864},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_31(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 393216
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, None)
def get_args():
arg_0 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_31.run(*args, 393216, grid=grid(393216), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_embedding_dense_backward_31.benchmark_all_configs(*args, 393216, grid=grid(393216))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.001572864
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
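# Note: another zero-fill, this time for a (512, 768) buffer -- presumably the
# position embedding gradient (one row per position in the 512-token
# sequence), populated by the atomic-add scatter in the next kernel.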
# kernel path: /tmp/torchinductor_shunting/my/cmybittjc2jf2j2woo7yukrwbfrxkubzfhfznkvzq3czd6munu5v.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum]
# masked_lm_loss => full_default_2
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32 = async_compile.triton('triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[524288, 16],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32', 'mutated_arg_names': ['out_ptr1'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.026742784}
)
@triton.jit
def triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 393216
rnumel = 16
RBLOCK: tl.constexpr = 16
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
x3 = (xindex // 768)
x2 = xindex % 768
tmp0 = tl.load(in_ptr0 + (x0 + (393216*r1)), None)
tmp4 = tl.load(in_ptr1 + (x3), None, eviction_policy='evict_last')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.sum(tmp1, 1)[:, None]
tmp5 = tl.full([XBLOCK, 1], 512, tl.int32)
tmp6 = tmp4 + tmp5
tmp7 = tmp4 < 0
tmp8 = tl.where(tmp7, tmp6, tmp4)
tmp9 = tl.full([1, 1], -1, tl.int64)
tmp10 = tmp4 == tmp9
tmp11 = 0.0
tmp12 = tl.where(tmp10, tmp11, tmp3)
tl.atomic_add(out_ptr1 + (x2 + (768*tmp8)), tmp12, None, sem='relaxed')
def get_args():
arg_0 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
arg_2 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.run(*args, 393216, 16, grid=grid(393216), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.benchmark_all_configs(*args, 393216, 16, grid=grid(393216))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.026742784
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
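# Note: the kernel above first sums the position-embedding gradient over the
# batch dimension (rnumel = 16), then scatters each of the 512 rows into the
# (512, 768) buffer with a relaxed atomic add, masking rows whose position id
# equals -1 (the padding_idx check emitted for embedding_dense_backward).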
async_compile.wait(globals())
del async_compile
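# The call() below replays the whole backward graph: it unpacks the weights
# and saved forward activations, validates their shapes/strides, then
# interleaves the Triton kernels compiled above with extern matmul kernels.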
def call(args):
primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, permute_406, permute_410, \
div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2 = args
args.clear()
assert_size_stride(primals_4, (768, ), (1, ))
assert_size_stride(primals_14, (768, ), (1, ))
assert_size_stride(primals_20, (768, ), (1, ))
assert_size_stride(primals_30, (768, ), (1, ))
assert_size_stride(primals_36, (768, ), (1, ))
assert_size_stride(primals_46, (768, ), (1, ))
assert_size_stride(primals_52, (768, ), (1, ))
assert_size_stride(primals_62, (768, ), (1, ))
assert_size_stride(primals_68, (768, ), (1, ))
assert_size_stride(primals_78, (768, ), (1, ))
assert_size_stride(primals_84, (768, ), (1, ))
assert_size_stride(primals_94, (768, ), (1, ))
assert_size_stride(primals_100, (768, ), (1, ))
assert_size_stride(primals_110, (768, ), (1, ))
assert_size_stride(primals_116, (768, ), (1, ))
assert_size_stride(primals_126, (768, ), (1, ))
assert_size_stride(primals_132, (768, ), (1, ))
assert_size_stride(primals_142, (768, ), (1, ))
assert_size_stride(primals_148, (768, ), (1, ))
assert_size_stride(primals_158, (768, ), (1, ))
assert_size_stride(primals_164, (768, ), (1, ))
assert_size_stride(primals_174, (768, ), (1, ))
assert_size_stride(primals_180, (768, ), (1, ))
assert_size_stride(primals_190, (768, ), (1, ))
assert_size_stride(primals_196, (768, ), (1, ))
assert_size_stride(primals_200, (768, ), (1, ))
assert_size_stride(primals_204, (1, 512), (512, 1))
assert_size_stride(primals_205, (1, 512), (512, 1))
assert_size_stride(primals_206, (16, 512), (512, 1))
assert_size_stride(primals_207, (16, 512), (512, 1))
assert_size_stride(mul_1, (16, 512, 768), (393216, 768, 1))
assert_size_stride(gt, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view, (8192, 768), (768, 1))
assert_size_stride(permute_default_66, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_67, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_68, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_129, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_130, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_131, (), ())
assert_size_stride(getitem_132, (), ())
assert_size_stride(view_16, (8192, 768), (768, 1))
assert_size_stride(gt_2, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_9, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_18, (8192, 768), (768, 1))
assert_size_stride(addmm_4, (8192, 3072), (3072, 1))
assert_size_stride(view_20, (8192, 3072), (3072, 1))
assert_size_stride(gt_3, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_16, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_22, (8192, 768), (768, 1))
assert_size_stride(permute_default_60, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_61, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_62, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_122, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_123, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_124, (), ())
assert_size_stride(getitem_125, (), ())
assert_size_stride(view_38, (8192, 768), (768, 1))
assert_size_stride(gt_5, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_22, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_40, (8192, 768), (768, 1))
assert_size_stride(addmm_10, (8192, 3072), (3072, 1))
assert_size_stride(view_42, (8192, 3072), (3072, 1))
assert_size_stride(gt_6, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_29, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_44, (8192, 768), (768, 1))
assert_size_stride(permute_default_54, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_55, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_56, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_115, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_116, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_117, (), ())
assert_size_stride(getitem_118, (), ())
assert_size_stride(view_60, (8192, 768), (768, 1))
assert_size_stride(gt_8, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_35, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_62, (8192, 768), (768, 1))
assert_size_stride(addmm_16, (8192, 3072), (3072, 1))
assert_size_stride(view_64, (8192, 3072), (3072, 1))
assert_size_stride(gt_9, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_42, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_66, (8192, 768), (768, 1))
assert_size_stride(permute_default_48, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_49, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_50, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_108, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_109, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_110, (), ())
assert_size_stride(getitem_111, (), ())
assert_size_stride(view_82, (8192, 768), (768, 1))
assert_size_stride(gt_11, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_48, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_84, (8192, 768), (768, 1))
assert_size_stride(addmm_22, (8192, 3072), (3072, 1))
assert_size_stride(view_86, (8192, 3072), (3072, 1))
assert_size_stride(gt_12, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_55, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_88, (8192, 768), (768, 1))
assert_size_stride(permute_default_42, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_43, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_44, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_101, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_102, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_103, (), ())
assert_size_stride(getitem_104, (), ())
assert_size_stride(view_104, (8192, 768), (768, 1))
assert_size_stride(gt_14, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_61, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_106, (8192, 768), (768, 1))
assert_size_stride(addmm_28, (8192, 3072), (3072, 1))
assert_size_stride(view_108, (8192, 3072), (3072, 1))
assert_size_stride(gt_15, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_68, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_110, (8192, 768), (768, 1))
assert_size_stride(permute_default_36, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_37, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_38, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_94, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_95, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_96, (), ())
assert_size_stride(getitem_97, (), ())
assert_size_stride(view_126, (8192, 768), (768, 1))
assert_size_stride(gt_17, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_74, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_128, (8192, 768), (768, 1))
assert_size_stride(addmm_34, (8192, 3072), (3072, 1))
assert_size_stride(view_130, (8192, 3072), (3072, 1))
assert_size_stride(gt_18, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_81, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_132, (8192, 768), (768, 1))
assert_size_stride(permute_default_30, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_31, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_32, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_87, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_88, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_89, (), ())
assert_size_stride(getitem_90, (), ())
assert_size_stride(view_148, (8192, 768), (768, 1))
assert_size_stride(gt_20, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_87, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_150, (8192, 768), (768, 1))
assert_size_stride(addmm_40, (8192, 3072), (3072, 1))
assert_size_stride(view_152, (8192, 3072), (3072, 1))
assert_size_stride(gt_21, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_94, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_154, (8192, 768), (768, 1))
assert_size_stride(permute_default_24, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_25, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_26, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_80, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_81, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_82, (), ())
assert_size_stride(getitem_83, (), ())
assert_size_stride(view_170, (8192, 768), (768, 1))
assert_size_stride(gt_23, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_100, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_172, (8192, 768), (768, 1))
assert_size_stride(addmm_46, (8192, 3072), (3072, 1))
assert_size_stride(view_174, (8192, 3072), (3072, 1))
assert_size_stride(gt_24, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_107, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_176, (8192, 768), (768, 1))
assert_size_stride(permute_default_18, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_19, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_20, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_73, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_74, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_75, (), ())
assert_size_stride(getitem_76, (), ())
assert_size_stride(view_192, (8192, 768), (768, 1))
assert_size_stride(gt_26, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_113, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_194, (8192, 768), (768, 1))
assert_size_stride(addmm_52, (8192, 3072), (3072, 1))
assert_size_stride(view_196, (8192, 3072), (3072, 1))
assert_size_stride(gt_27, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_120, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_198, (8192, 768), (768, 1))
assert_size_stride(permute_default_12, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_13, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_14, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_66, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_67, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_68, (), ())
assert_size_stride(getitem_69, (), ())
assert_size_stride(view_214, (8192, 768), (768, 1))
assert_size_stride(gt_29, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_126, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_216, (8192, 768), (768, 1))
assert_size_stride(addmm_58, (8192, 3072), (3072, 1))
assert_size_stride(view_218, (8192, 3072), (3072, 1))
assert_size_stride(gt_30, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_133, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_220, (8192, 768), (768, 1))
assert_size_stride(permute_default_6, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_7, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_8, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_59, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_60, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_61, (), ())
assert_size_stride(getitem_62, (), ())
assert_size_stride(view_236, (8192, 768), (768, 1))
assert_size_stride(gt_32, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_139, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_238, (8192, 768), (768, 1))
assert_size_stride(addmm_64, (8192, 3072), (3072, 1))
assert_size_stride(view_240, (8192, 3072), (3072, 1))
assert_size_stride(gt_33, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_146, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_242, (8192, 768), (768, 1))
assert_size_stride(permute_default, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_1, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(permute_default_2, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_52, (16, 12, 512, 64), (393216, 64, 768, 1))
assert_size_stride(getitem_53, (16, 12, 512), (6144, 512, 1))
assert_size_stride(getitem_54, (), ())
assert_size_stride(getitem_55, (), ())
assert_size_stride(view_258, (8192, 768), (768, 1))
assert_size_stride(gt_35, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_152, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_260, (8192, 768), (768, 1))
assert_size_stride(addmm_70, (8192, 3072), (3072, 1))
assert_size_stride(view_262, (8192, 3072), (3072, 1))
assert_size_stride(gt_36, (16, 512, 768), (393216, 768, 1))
assert_size_stride(mul_159, (16, 512, 768), (393216, 768, 1))
assert_size_stride(view_264, (8192, 768), (768, 1))
assert_size_stride(addmm_72, (8192, 768), (768, 1))
assert_size_stride(getitem_51, (16, 512, 1), (512, 1, 1))
assert_size_stride(rsqrt_25, (16, 512, 1), (512, 1, 1))
assert_size_stride(view_266, (8192, 768), (768, 1))
assert_size_stride(view_267, (16, 512, 30522), (15630336, 30528, 1))
assert_size_stride(amax_12, (8192, 1), (1, 1))
assert_size_stride(log, (8192, 1), (1, 1))
assert_size_stride(convert_element_type_510, (), ())
assert_size_stride(permute_134, (30522, 768), (768, 1))
assert_size_stride(permute_138, (768, 768), (768, 1))
assert_size_stride(div_27, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_142, (768, 3072), (3072, 1))
assert_size_stride(permute_146, (3072, 768), (768, 1))
assert_size_stride(div_28, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_150, (768, 768), (768, 1))
assert_size_stride(permute_162, (768, 768), (768, 1))
assert_size_stride(permute_167, (768, 768), (768, 1))
assert_size_stride(permute_171, (768, 768), (768, 1))
assert_size_stride(div_30, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_175, (768, 3072), (3072, 1))
assert_size_stride(permute_179, (3072, 768), (768, 1))
assert_size_stride(div_31, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_183, (768, 768), (768, 1))
assert_size_stride(permute_195, (768, 768), (768, 1))
assert_size_stride(permute_200, (768, 768), (768, 1))
assert_size_stride(permute_204, (768, 768), (768, 1))
assert_size_stride(div_33, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_208, (768, 3072), (3072, 1))
assert_size_stride(permute_212, (3072, 768), (768, 1))
assert_size_stride(div_34, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_216, (768, 768), (768, 1))
assert_size_stride(permute_228, (768, 768), (768, 1))
assert_size_stride(permute_233, (768, 768), (768, 1))
assert_size_stride(permute_237, (768, 768), (768, 1))
assert_size_stride(div_36, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_241, (768, 3072), (3072, 1))
assert_size_stride(permute_245, (3072, 768), (768, 1))
assert_size_stride(div_37, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_249, (768, 768), (768, 1))
assert_size_stride(permute_261, (768, 768), (768, 1))
assert_size_stride(permute_266, (768, 768), (768, 1))
assert_size_stride(permute_270, (768, 768), (768, 1))
assert_size_stride(div_39, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_274, (768, 3072), (3072, 1))
assert_size_stride(permute_278, (3072, 768), (768, 1))
assert_size_stride(div_40, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_282, (768, 768), (768, 1))
assert_size_stride(permute_294, (768, 768), (768, 1))
assert_size_stride(permute_299, (768, 768), (768, 1))
assert_size_stride(permute_303, (768, 768), (768, 1))
assert_size_stride(div_42, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_307, (768, 3072), (3072, 1))
assert_size_stride(permute_311, (3072, 768), (768, 1))
assert_size_stride(div_43, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_315, (768, 768), (768, 1))
assert_size_stride(permute_327, (768, 768), (768, 1))
assert_size_stride(permute_332, (768, 768), (768, 1))
assert_size_stride(permute_336, (768, 768), (768, 1))
assert_size_stride(div_45, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_340, (768, 3072), (3072, 1))
assert_size_stride(permute_344, (3072, 768), (768, 1))
assert_size_stride(div_46, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_348, (768, 768), (768, 1))
assert_size_stride(permute_360, (768, 768), (768, 1))
assert_size_stride(permute_365, (768, 768), (768, 1))
assert_size_stride(permute_369, (768, 768), (768, 1))
assert_size_stride(div_48, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_373, (768, 3072), (3072, 1))
assert_size_stride(permute_377, (3072, 768), (768, 1))
assert_size_stride(div_49, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_381, (768, 768), (768, 1))
assert_size_stride(permute_393, (768, 768), (768, 1))
assert_size_stride(permute_398, (768, 768), (768, 1))
assert_size_stride(permute_402, (768, 768), (768, 1))
assert_size_stride(div_51, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_406, (768, 3072), (3072, 1))
assert_size_stride(permute_410, (3072, 768), (768, 1))
assert_size_stride(div_52, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_414, (768, 768), (768, 1))
assert_size_stride(permute_426, (768, 768), (768, 1))
assert_size_stride(permute_431, (768, 768), (768, 1))
assert_size_stride(permute_435, (768, 768), (768, 1))
assert_size_stride(div_54, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_439, (768, 3072), (3072, 1))
assert_size_stride(permute_443, (3072, 768), (768, 1))
assert_size_stride(div_55, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_447, (768, 768), (768, 1))
assert_size_stride(permute_459, (768, 768), (768, 1))
assert_size_stride(permute_464, (768, 768), (768, 1))
assert_size_stride(permute_468, (768, 768), (768, 1))
assert_size_stride(div_57, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_472, (768, 3072), (3072, 1))
assert_size_stride(permute_476, (3072, 768), (768, 1))
assert_size_stride(div_58, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_480, (768, 768), (768, 1))
assert_size_stride(permute_492, (768, 768), (768, 1))
assert_size_stride(permute_497, (768, 768), (768, 1))
assert_size_stride(permute_501, (768, 768), (768, 1))
assert_size_stride(div_60, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_505, (768, 3072), (3072, 1))
assert_size_stride(permute_509, (3072, 768), (768, 1))
assert_size_stride(div_61, (16, 512, 1), (512, 1, 1))
assert_size_stride(permute_513, (768, 768), (768, 1))
assert_size_stride(permute_525, (768, 768), (768, 1))
assert_size_stride(permute_530, (768, 768), (768, 1))
assert_size_stride(permute_534, (768, 768), (768, 1))
assert_size_stride(div_63, (16, 512, 1), (512, 1, 1))
assert_size_stride(tangents_1, (), ())
assert_size_stride(tangents_2, (16, 512, 30522), (15627264, 30522, 1))
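# All inputs validated; the backward schedule below runs entirely on cuda:0.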
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((8192, 30522), (30528, 1), torch.float32)
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward]
stream0 = get_raw_stream(0)
triton_poi_fused_nll_loss_backward_nll_loss_forward_0.run(buf0, 250036224, grid=grid(250036224), stream=stream0)
# Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward]
triton_poi_fused_nll_loss_backward_nll_loss_forward_1.run(primals_207, buf0, 8192, grid=grid(8192), stream=stream0)
buf3 = empty_strided_cuda((16, 512, 30522), (15630336, 30528, 1), torch.float16)
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax_backward_data, aten.add, aten.nll_loss_backward, aten.nll_loss_forward]
triton_red_fused__log_softmax_backward_data_add_nll_loss_backward_nll_loss_forward_2.run(buf0, primals_207, tangents_1, convert_element_type_510, tangents_2, view_267, amax_12, log, buf3, 8192, 30522, grid=grid(8192), stream=stream0)
del amax_12
del buf0
del convert_element_type_510
del log
del primals_207
del tangents_1
del tangents_2
del view_267
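# Note: buf3 is 30522 wide but padded to a row stride of 30528 (a multiple of
# 64), and the two copies below (buf4 and buf5) materialize fully padded fp16
# operands so the large vocab-projection matmul (extern_kernels.mm into buf6)
# runs on 64-aligned shapes.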
buf4 = empty_strided_cuda((8192, 30528), (30528, 1), torch.float16)
# Source Nodes: [], Original ATen: []
triton_poi_fused_3.run(buf3, buf4, 250085376, grid=grid(250085376), stream=stream0)
buf5 = empty_strided_cuda((30528, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: []
triton_poi_fused_4.run(permute_134, buf5, 23445504, grid=grid(23445504), stream=stream0)
del permute_134
buf6 = empty_strided_cuda((8192, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(buf4, buf5, out=buf6)
del buf4
del buf5
buf7 = empty_strided_cuda((30522, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf3, (30522, 8192), (1, 30528), 0), view_266, out=buf7)
del view_266
buf10 = empty_strided_cuda((30522, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_red_fused__to_copy_sum_5.run(buf3, buf10, 30522, 8192, grid=grid(30522), stream=stream0)
del buf3
buf9 = empty_strided_cuda((30522, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_6.run(buf7, buf9, 23440896, grid=grid(23440896), stream=stream0)
del buf7
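# Note: the section below is presumably the MLM-head transform backward (the
# generated source-node comments name hidden_states_97/98): a fused
# GELU + layer-norm backward, followed by the usual three-way split of the
# linear layer's gradient -- input grad via mm, weight grad via a transposed
# mm, bias grad via a two-stage column sum.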
buf18 = empty_strided_cuda((8192, 768), (768, 1), torch.float16)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.gelu_backward, aten.native_layer_norm, aten.native_layer_norm_backward, aten.view]
triton_per_fused__to_copy_gelu_gelu_backward_native_layer_norm_native_layer_norm_backward_view_7.run(buf6, primals_200, addmm_72, getitem_51, rsqrt_25, buf18, 8192, 768, grid=grid(8192), stream=stream0)
del primals_200
buf13 = empty_strided_cuda((768, 64), (1, 768), torch.float32)
buf15 = empty_strided_cuda((768, 64), (1, 768), torch.float32)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_red_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_8.run(buf6, addmm_72, getitem_51, rsqrt_25, buf13, buf15, 49152, 128, grid=grid(49152), stream=stream0)
del addmm_72
del getitem_51
del rsqrt_25
buf14 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten._to_copy, aten.gelu, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf13, buf14, 768, 64, grid=grid(768), stream=stream0)
buf16 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf15, buf16, 768, 64, grid=grid(768), stream=stream0)
buf19 = buf6; del buf6 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(buf18, permute_138, out=buf19)
del permute_138
buf20 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf18, (768, 8192), (1, 768), 0), view_264, out=buf20)
del view_264
buf21 = reinterpret_tensor(buf15, (1, 768, 64), (49152, 1, 768), 0); del buf15 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_10.run(buf18, buf21, 49152, 128, grid=grid(49152), stream=stream0)
buf24 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf21, buf24, 768, 64, grid=grid(768), stream=stream0)
buf23 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf20, buf23, 589824, grid=grid(589824), stream=stream0)
buf27 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32)
buf32 = reinterpret_tensor(buf18, (16, 512, 768), (393216, 768, 1), 0); del buf18 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_native_dropout_backward_native_layer_norm_backward_13.run(buf19, primals_196, mul_159, div_27, gt_36, buf27, buf32, 8192, 768, grid=grid(8192), stream=stream0)
del div_27
del gt_36
del primals_196
buf28 = reinterpret_tensor(buf21, (768, 64), (1, 768), 0); del buf21 # reuse
buf30 = buf13; del buf13 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_red_fused__to_copy_native_layer_norm_backward_14.run(buf19, mul_159, buf28, buf30, 49152, 128, grid=grid(49152), stream=stream0)
del mul_159
buf29 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf28, buf29, 768, 64, grid=grid(768), stream=stream0)
buf31 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf30, buf31, 768, 64, grid=grid(768), stream=stream0)
buf33 = empty_strided_cuda((8192, 3072), (3072, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf32, (8192, 768), (768, 1), 0), permute_142, out=buf33)
del permute_142
buf34 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf32, (768, 8192), (1, 768), 0), view_262, out=buf34)
del view_262
buf35 = reinterpret_tensor(buf30, (1, 768, 64), (49152, 1, 768), 0); del buf30 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf32, buf35, 49152, 128, grid=grid(49152), stream=stream0)
buf38 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf35, buf38, 768, 64, grid=grid(768), stream=stream0)
buf37 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf34, buf37, 2359296, grid=grid(2359296), stream=stream0)
buf39 = reinterpret_tensor(buf33, (16, 512, 3072), (1572864, 3072, 1), 0); del buf33 # reuse
# Source Nodes: [hidden_states_92], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf39, addmm_70, 25165824, grid=grid(25165824), stream=stream0)
del addmm_70
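# Note: the GELU backward above recomputes the activation derivative in place
# from the saved pre-activation (addmm_70), which avoids having saved the
# post-GELU activation from the forward pass.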
buf40 = reinterpret_tensor(buf32, (8192, 768), (768, 1), 0); del buf32 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf39, (8192, 3072), (3072, 1), 0), permute_146, out=buf40)
del permute_146
buf41 = reinterpret_tensor(buf34, (3072, 768), (768, 1), 0); del buf34 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf39, (3072, 8192), (1, 3072), 0), view_260, out=buf41)
del view_260
buf42 = empty_strided_cuda((1, 3072, 32), (98304, 1, 3072), torch.float32)
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf39, buf42, 98304, 256, grid=grid(98304), stream=stream0)
buf45 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf42, buf45, 3072, 32, grid=grid(3072), stream=stream0)
buf44 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf41, buf44, 2359296, grid=grid(2359296), stream=stream0)
buf48 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.float32)
buf53 = reinterpret_tensor(buf19, (16, 512, 768), (393216, 768, 1), 0); del buf19 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf27, buf40, primals_190, mul_152, div_28, gt_35, buf48, buf53, 8192, 768, grid=grid(8192), stream=stream0)
del div_28
del gt_35
del primals_190
buf49 = reinterpret_tensor(buf35, (768, 64), (1, 768), 0); del buf35 # reuse
buf51 = buf28; del buf28 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf27, buf40, mul_152, buf49, buf51, 49152, 128, grid=grid(49152), stream=stream0)
del mul_152
buf50 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf49, buf50, 768, 64, grid=grid(768), stream=stream0)
buf52 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf51, buf52, 768, 64, grid=grid(768), stream=stream0)
buf54 = buf40; del buf40 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf53, (8192, 768), (768, 1), 0), permute_150, out=buf54)
del permute_150
buf55 = buf20; del buf20 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf53, (768, 8192), (1, 768), 0), view_258, out=buf55)
del view_258
buf56 = reinterpret_tensor(buf51, (1, 768, 64), (49152, 1, 768), 0); del buf51 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf53, buf56, 49152, 128, grid=grid(49152), stream=stream0)
buf59 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf56, buf59, 768, 64, grid=grid(768), stream=stream0)
buf58 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf55, buf58, 589824, grid=grid(589824), stream=stream0)
buf60 = reinterpret_tensor(buf53, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf53 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf54, buf60, 6291456, grid=grid(6291456), stream=stream0)
del buf54
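# Note: attention backward is delegated to the fused flash-attention backward
# op below (dropout_p = 0.1; scale = 0.125 = 1/sqrt(64), consistent with 12
# heads of dim 64). getitem_52/getitem_53 appear to be the saved forward
# output and logsumexp, getitem_54/getitem_55 the philox seed/offset; the call
# returns the gradients for q, k and v, unpacked into buf62, buf63 and buf64.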
# Source Nodes: [], Original ATen: [aten.clone]
buf61 = aten._scaled_dot_product_flash_attention_backward.default(buf60, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, None, None, 512, 512, 0.1, False, getitem_54, getitem_55, scale=0.125)
del getitem_52
del getitem_53
del getitem_54
del getitem_55
del permute_default
del permute_default_1
del permute_default_2
buf62 = buf61[0]
buf63 = buf61[1]
buf64 = buf61[2]
del buf61
buf65 = reinterpret_tensor(buf60, (8192, 768), (768, 1), 0); del buf60 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf64, (8192, 768), (768, 1), 0), permute_162, out=buf65)
del permute_162
buf66 = buf55; del buf55 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf64, (768, 8192), (1, 768), 0), view_242, out=buf66)
buf67 = buf56; del buf56 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf64, buf67, 49152, 128, grid=grid(49152), stream=stream0)
buf70 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf67, buf70, 768, 64, grid=grid(768), stream=stream0)
buf69 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf66, buf69, 589824, grid=grid(589824), stream=stream0)
buf71 = reinterpret_tensor(buf64, (8192, 768), (768, 1), 0); del buf64 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf63, (8192, 768), (768, 1), 0), permute_167, out=buf71)
del permute_167
buf72 = buf66; del buf66 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf63, (768, 8192), (1, 768), 0), view_242, out=buf72)
buf73 = buf67; del buf67 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf63, buf73, 49152, 128, grid=grid(49152), stream=stream0)
buf76 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf73, buf76, 768, 64, grid=grid(768), stream=stream0)
buf75 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf72, buf75, 589824, grid=grid(589824), stream=stream0)
buf77 = reinterpret_tensor(buf63, (8192, 768), (768, 1), 0); del buf63 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf62, (8192, 768), (768, 1), 0), permute_171, out=buf77)
del permute_171
buf78 = buf72; del buf72 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf62, (768, 8192), (1, 768), 0), view_242, out=buf78)
del view_242
buf79 = buf73; del buf73 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf62, buf79, 49152, 128, grid=grid(49152), stream=stream0)
buf82 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf79, buf82, 768, 64, grid=grid(768), stream=stream0)
buf81 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf78, buf81, 589824, grid=grid(589824), stream=stream0)
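# Fused residual epilogue: sum the carried gradient (buf48) with the three
# q/k/v projection input-gradients from the attention block handled above,
# then apply dropout backward (boolean mask gt_33) and layer-norm backward
# (weight primals_180; mul_146/div_30 are the saved normalization terms) in
# a single kernel. buf86 carries the residual gradient onward; buf91 feeds
# the MLP backward of this layer.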
buf86 = buf27; del buf27 # reuse
buf91 = reinterpret_tensor(buf62, (16, 512, 768), (393216, 768, 1), 0); del buf62 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf48, buf65, buf71, buf77, primals_180, mul_146, div_30, gt_33, buf86, buf91, 8192, 768, grid=grid(8192), stream=stream0)
del div_30
del gt_33
del primals_180
buf87 = reinterpret_tensor(buf79, (768, 64), (1, 768), 0); del buf79 # reuse
buf89 = buf49; del buf49 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf48, buf65, buf71, buf77, mul_146, buf87, buf89, 49152, 128, grid=grid(49152), stream=stream0)
del buf65
del buf71
del mul_146
buf88 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf87, buf88, 768, 64, grid=grid(768), stream=stream0)
buf90 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf89, buf90, 768, 64, grid=grid(768), stream=stream0)
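# From here the same per-layer sequence repeats for each remaining encoder
# layer, working backward: MLP down-projection backward, GELU backward,
# up-projection backward, the second fused dropout/layer-norm epilogue,
# attention output projection, flash-attention backward, and the three
# q/k/v projection backwards.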
buf92 = reinterpret_tensor(buf39, (8192, 3072), (3072, 1), 0); del buf39 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf91, (8192, 768), (768, 1), 0), permute_175, out=buf92)
del permute_175
buf93 = reinterpret_tensor(buf41, (768, 3072), (3072, 1), 0); del buf41 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf91, (768, 8192), (1, 768), 0), view_240, out=buf93)
del view_240
buf94 = reinterpret_tensor(buf89, (1, 768, 64), (49152, 1, 768), 0); del buf89 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf91, buf94, 49152, 128, grid=grid(49152), stream=stream0)
buf97 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf94, buf97, 768, 64, grid=grid(768), stream=stream0)
buf96 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf93, buf96, 2359296, grid=grid(2359296), stream=stream0)
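# GELU backward recomputes the derivative from the saved pre-activation
# (addmm_64) and multiplies it into the incoming gradient in place on buf98,
# rather than having stored the activation gradient during the forward.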
buf98 = reinterpret_tensor(buf92, (16, 512, 3072), (1572864, 3072, 1), 0); del buf92 # reuse
# Source Nodes: [hidden_states_84], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf98, addmm_64, 25165824, grid=grid(25165824), stream=stream0)
del addmm_64
buf99 = reinterpret_tensor(buf91, (8192, 768), (768, 1), 0); del buf91 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf98, (8192, 3072), (3072, 1), 0), permute_179, out=buf99)
del permute_179
buf100 = reinterpret_tensor(buf93, (3072, 768), (768, 1), 0); del buf93 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf98, (3072, 8192), (1, 3072), 0), view_238, out=buf100)
del view_238
buf101 = buf42; del buf42 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf98, buf101, 98304, 256, grid=grid(98304), stream=stream0)
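# Same split-reduction trick for the 3072-wide intermediate bias:
# 98304 = 3072 * 32 partial sums, collapsed to a (3072,) fp32 vector below.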
buf104 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf101, buf104, 3072, 32, grid=grid(3072), stream=stream0)
buf103 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf100, buf103, 2359296, grid=grid(2359296), stream=stream0)
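# Second fused epilogue of the layer: add the gradient at the MLP input
# (buf99, through the up-projection) to the carried residual (buf86), then
# dropout backward (gt_32) and the attention-output layer-norm backward
# (primals_174; saved terms mul_139/div_31). buf107 carries the residual
# on; buf112 feeds the attention output-projection backward.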
buf107 = buf48; del buf48 # reuse
buf112 = reinterpret_tensor(buf77, (16, 512, 768), (393216, 768, 1), 0); del buf77 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf86, buf99, primals_174, mul_139, div_31, gt_32, buf107, buf112, 8192, 768, grid=grid(8192), stream=stream0)
del div_31
del gt_32
del primals_174
buf108 = reinterpret_tensor(buf94, (768, 64), (1, 768), 0); del buf94 # reuse
buf110 = buf87; del buf87 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf86, buf99, mul_139, buf108, buf110, 49152, 128, grid=grid(49152), stream=stream0)
del mul_139
buf109 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf108, buf109, 768, 64, grid=grid(768), stream=stream0)
buf111 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf110, buf111, 768, 64, grid=grid(768), stream=stream0)
buf113 = buf99; del buf99 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf112, (8192, 768), (768, 1), 0), permute_183, out=buf113)
del permute_183
buf114 = buf78; del buf78 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf112, (768, 8192), (1, 768), 0), view_236, out=buf114)
del view_236
buf115 = reinterpret_tensor(buf110, (1, 768, 64), (49152, 1, 768), 0); del buf110 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf112, buf115, 49152, 128, grid=grid(49152), stream=stream0)
buf118 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf115, buf118, 768, 64, grid=grid(768), stream=stream0)
buf117 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf114, buf117, 589824, grid=grid(589824), stream=stream0)
buf119 = reinterpret_tensor(buf112, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf112 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf113, buf119, 6291456, grid=grid(6291456), stream=stream0)
del buf113
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf120 = aten._scaled_dot_product_flash_attention_backward.default(buf119, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, None, None, 512, 512, 0.1, False, getitem_61, getitem_62, scale=0.125)
del getitem_59
del getitem_60
del getitem_61
del getitem_62
del permute_default_6
del permute_default_7
del permute_default_8
buf121 = buf120[0]
buf122 = buf120[1]
buf123 = buf120[2]
del buf120
buf124 = reinterpret_tensor(buf119, (8192, 768), (768, 1), 0); del buf119 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf123, (8192, 768), (768, 1), 0), permute_195, out=buf124)
del permute_195
buf125 = buf114; del buf114 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf123, (768, 8192), (1, 768), 0), view_220, out=buf125)
buf126 = buf115; del buf115 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf123, buf126, 49152, 128, grid=grid(49152), stream=stream0)
buf129 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf126, buf129, 768, 64, grid=grid(768), stream=stream0)
buf128 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf125, buf128, 589824, grid=grid(589824), stream=stream0)
buf130 = reinterpret_tensor(buf123, (8192, 768), (768, 1), 0); del buf123 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf122, (8192, 768), (768, 1), 0), permute_200, out=buf130)
del permute_200
buf131 = buf125; del buf125 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf122, (768, 8192), (1, 768), 0), view_220, out=buf131)
buf132 = buf126; del buf126 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf122, buf132, 49152, 128, grid=grid(49152), stream=stream0)
buf135 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf132, buf135, 768, 64, grid=grid(768), stream=stream0)
buf134 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf131, buf134, 589824, grid=grid(589824), stream=stream0)
buf136 = reinterpret_tensor(buf122, (8192, 768), (768, 1), 0); del buf122 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf121, (8192, 768), (768, 1), 0), permute_204, out=buf136)
del permute_204
buf137 = buf131; del buf131 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf121, (768, 8192), (1, 768), 0), view_220, out=buf137)
del view_220
buf138 = buf132; del buf132 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf121, buf138, 49152, 128, grid=grid(49152), stream=stream0)
buf141 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf138, buf141, 768, 64, grid=grid(768), stream=stream0)
buf140 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf137, buf140, 589824, grid=grid(589824), stream=stream0)
buf145 = buf86; del buf86 # reuse
buf150 = reinterpret_tensor(buf121, (16, 512, 768), (393216, 768, 1), 0); del buf121 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf107, buf124, buf130, buf136, primals_164, mul_133, div_33, gt_30, buf145, buf150, 8192, 768, grid=grid(8192), stream=stream0)
del div_33
del gt_30
del primals_164
buf146 = reinterpret_tensor(buf138, (768, 64), (1, 768), 0); del buf138 # reuse
buf148 = buf108; del buf108 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf107, buf124, buf130, buf136, mul_133, buf146, buf148, 49152, 128, grid=grid(49152), stream=stream0)
del buf124
del buf130
del mul_133
buf147 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf146, buf147, 768, 64, grid=grid(768), stream=stream0)
buf149 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf148, buf149, 768, 64, grid=grid(768), stream=stream0)
buf151 = reinterpret_tensor(buf98, (8192, 3072), (3072, 1), 0); del buf98 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf150, (8192, 768), (768, 1), 0), permute_208, out=buf151)
del permute_208
buf152 = reinterpret_tensor(buf100, (768, 3072), (3072, 1), 0); del buf100 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf150, (768, 8192), (1, 768), 0), view_218, out=buf152)
del view_218
buf153 = reinterpret_tensor(buf148, (1, 768, 64), (49152, 1, 768), 0); del buf148 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf150, buf153, 49152, 128, grid=grid(49152), stream=stream0)
buf156 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf153, buf156, 768, 64, grid=grid(768), stream=stream0)
buf155 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf152, buf155, 2359296, grid=grid(2359296), stream=stream0)
buf157 = reinterpret_tensor(buf151, (16, 512, 3072), (1572864, 3072, 1), 0); del buf151 # reuse
# Source Nodes: [hidden_states_76], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf157, addmm_58, 25165824, grid=grid(25165824), stream=stream0)
del addmm_58
buf158 = reinterpret_tensor(buf150, (8192, 768), (768, 1), 0); del buf150 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf157, (8192, 3072), (3072, 1), 0), permute_212, out=buf158)
del permute_212
buf159 = reinterpret_tensor(buf152, (3072, 768), (768, 1), 0); del buf152 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf157, (3072, 8192), (1, 3072), 0), view_216, out=buf159)
del view_216
buf160 = buf101; del buf101 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf157, buf160, 98304, 256, grid=grid(98304), stream=stream0)
buf163 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf160, buf163, 3072, 32, grid=grid(3072), stream=stream0)
buf162 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf159, buf162, 2359296, grid=grid(2359296), stream=stream0)
buf166 = buf107; del buf107 # reuse
buf171 = reinterpret_tensor(buf136, (16, 512, 768), (393216, 768, 1), 0); del buf136 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf145, buf158, primals_158, mul_126, div_34, gt_29, buf166, buf171, 8192, 768, grid=grid(8192), stream=stream0)
del div_34
del gt_29
del primals_158
buf167 = reinterpret_tensor(buf153, (768, 64), (1, 768), 0); del buf153 # reuse
buf169 = buf146; del buf146 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf145, buf158, mul_126, buf167, buf169, 49152, 128, grid=grid(49152), stream=stream0)
del mul_126
buf168 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf167, buf168, 768, 64, grid=grid(768), stream=stream0)
buf170 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf169, buf170, 768, 64, grid=grid(768), stream=stream0)
buf172 = buf158; del buf158 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf171, (8192, 768), (768, 1), 0), permute_216, out=buf172)
del permute_216
buf173 = buf137; del buf137 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf171, (768, 8192), (1, 768), 0), view_214, out=buf173)
del view_214
buf174 = reinterpret_tensor(buf169, (1, 768, 64), (49152, 1, 768), 0); del buf169 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf171, buf174, 49152, 128, grid=grid(49152), stream=stream0)
buf177 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf174, buf177, 768, 64, grid=grid(768), stream=stream0)
buf176 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf173, buf176, 589824, grid=grid(589824), stream=stream0)
buf178 = reinterpret_tensor(buf171, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf171 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf172, buf178, 6291456, grid=grid(6291456), stream=stream0)
del buf172
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf179 = aten._scaled_dot_product_flash_attention_backward.default(buf178, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, None, None, 512, 512, 0.1, False, getitem_68, getitem_69, scale=0.125)
del getitem_66
del getitem_67
del getitem_68
del getitem_69
del permute_default_12
del permute_default_13
del permute_default_14
buf180 = buf179[0]
buf181 = buf179[1]
buf182 = buf179[2]
del buf179
buf183 = reinterpret_tensor(buf178, (8192, 768), (768, 1), 0); del buf178 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf182, (8192, 768), (768, 1), 0), permute_228, out=buf183)
del permute_228
buf184 = buf173; del buf173 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf182, (768, 8192), (1, 768), 0), view_198, out=buf184)
buf185 = buf174; del buf174 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf182, buf185, 49152, 128, grid=grid(49152), stream=stream0)
buf188 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf185, buf188, 768, 64, grid=grid(768), stream=stream0)
buf187 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf184, buf187, 589824, grid=grid(589824), stream=stream0)
buf189 = reinterpret_tensor(buf182, (8192, 768), (768, 1), 0); del buf182 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf181, (8192, 768), (768, 1), 0), permute_233, out=buf189)
del permute_233
buf190 = buf184; del buf184 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf181, (768, 8192), (1, 768), 0), view_198, out=buf190)
buf191 = buf185; del buf185 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf181, buf191, 49152, 128, grid=grid(49152), stream=stream0)
buf194 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf191, buf194, 768, 64, grid=grid(768), stream=stream0)
buf193 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf190, buf193, 589824, grid=grid(589824), stream=stream0)
buf195 = reinterpret_tensor(buf181, (8192, 768), (768, 1), 0); del buf181 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf180, (8192, 768), (768, 1), 0), permute_237, out=buf195)
del permute_237
buf196 = buf190; del buf190 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf180, (768, 8192), (1, 768), 0), view_198, out=buf196)
del view_198
buf197 = buf191; del buf191 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf180, buf197, 49152, 128, grid=grid(49152), stream=stream0)
buf200 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf197, buf200, 768, 64, grid=grid(768), stream=stream0)
buf199 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf196, buf199, 589824, grid=grid(589824), stream=stream0)
buf204 = buf145; del buf145 # reuse
buf209 = reinterpret_tensor(buf180, (16, 512, 768), (393216, 768, 1), 0); del buf180 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf166, buf183, buf189, buf195, primals_148, mul_120, div_36, gt_27, buf204, buf209, 8192, 768, grid=grid(8192), stream=stream0)
del div_36
del gt_27
del primals_148
buf205 = reinterpret_tensor(buf197, (768, 64), (1, 768), 0); del buf197 # reuse
buf207 = buf167; del buf167 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf166, buf183, buf189, buf195, mul_120, buf205, buf207, 49152, 128, grid=grid(49152), stream=stream0)
del buf183
del buf189
del mul_120
buf206 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf205, buf206, 768, 64, grid=grid(768), stream=stream0)
buf208 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf207, buf208, 768, 64, grid=grid(768), stream=stream0)
buf210 = reinterpret_tensor(buf157, (8192, 3072), (3072, 1), 0); del buf157 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf209, (8192, 768), (768, 1), 0), permute_241, out=buf210)
del permute_241
buf211 = reinterpret_tensor(buf159, (768, 3072), (3072, 1), 0); del buf159 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf209, (768, 8192), (1, 768), 0), view_196, out=buf211)
del view_196
buf212 = reinterpret_tensor(buf207, (1, 768, 64), (49152, 1, 768), 0); del buf207 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf209, buf212, 49152, 128, grid=grid(49152), stream=stream0)
buf215 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf212, buf215, 768, 64, grid=grid(768), stream=stream0)
buf214 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf211, buf214, 2359296, grid=grid(2359296), stream=stream0)
buf216 = reinterpret_tensor(buf210, (16, 512, 3072), (1572864, 3072, 1), 0); del buf210 # reuse
# Source Nodes: [hidden_states_68], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf216, addmm_52, 25165824, grid=grid(25165824), stream=stream0)
del addmm_52
buf217 = reinterpret_tensor(buf209, (8192, 768), (768, 1), 0); del buf209 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf216, (8192, 3072), (3072, 1), 0), permute_245, out=buf217)
del permute_245
buf218 = reinterpret_tensor(buf211, (3072, 768), (768, 1), 0); del buf211 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf216, (3072, 8192), (1, 3072), 0), view_194, out=buf218)
del view_194
buf219 = buf160; del buf160 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf216, buf219, 98304, 256, grid=grid(98304), stream=stream0)
buf222 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf219, buf222, 3072, 32, grid=grid(3072), stream=stream0)
buf221 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf218, buf221, 2359296, grid=grid(2359296), stream=stream0)
buf225 = buf166; del buf166 # reuse
buf230 = reinterpret_tensor(buf195, (16, 512, 768), (393216, 768, 1), 0); del buf195 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf204, buf217, primals_142, mul_113, div_37, gt_26, buf225, buf230, 8192, 768, grid=grid(8192), stream=stream0)
del div_37
del gt_26
del primals_142
buf226 = reinterpret_tensor(buf212, (768, 64), (1, 768), 0); del buf212 # reuse
buf228 = buf205; del buf205 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf204, buf217, mul_113, buf226, buf228, 49152, 128, grid=grid(49152), stream=stream0)
del mul_113
buf227 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf226, buf227, 768, 64, grid=grid(768), stream=stream0)
buf229 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf228, buf229, 768, 64, grid=grid(768), stream=stream0)
buf231 = buf217; del buf217 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf230, (8192, 768), (768, 1), 0), permute_249, out=buf231)
del permute_249
buf232 = buf196; del buf196 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf230, (768, 8192), (1, 768), 0), view_192, out=buf232)
del view_192
buf233 = reinterpret_tensor(buf228, (1, 768, 64), (49152, 1, 768), 0); del buf228 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf230, buf233, 49152, 128, grid=grid(49152), stream=stream0)
buf236 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf233, buf236, 768, 64, grid=grid(768), stream=stream0)
buf235 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf232, buf235, 589824, grid=grid(589824), stream=stream0)
buf237 = reinterpret_tensor(buf230, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf230 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf231, buf237, 6291456, grid=grid(6291456), stream=stream0)
del buf231
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf238 = aten._scaled_dot_product_flash_attention_backward.default(buf237, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, None, None, 512, 512, 0.1, False, getitem_75, getitem_76, scale=0.125)
del getitem_73
del getitem_74
del getitem_75
del getitem_76
del permute_default_18
del permute_default_19
del permute_default_20
buf239 = buf238[0]
buf240 = buf238[1]
buf241 = buf238[2]
del buf238
buf242 = reinterpret_tensor(buf237, (8192, 768), (768, 1), 0); del buf237 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf241, (8192, 768), (768, 1), 0), permute_261, out=buf242)
del permute_261
buf243 = buf232; del buf232 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf241, (768, 8192), (1, 768), 0), view_176, out=buf243)
buf244 = buf233; del buf233 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf241, buf244, 49152, 128, grid=grid(49152), stream=stream0)
buf247 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf244, buf247, 768, 64, grid=grid(768), stream=stream0)
buf246 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf243, buf246, 589824, grid=grid(589824), stream=stream0)
buf248 = reinterpret_tensor(buf241, (8192, 768), (768, 1), 0); del buf241 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf240, (8192, 768), (768, 1), 0), permute_266, out=buf248)
del permute_266
buf249 = buf243; del buf243 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf240, (768, 8192), (1, 768), 0), view_176, out=buf249)
buf250 = buf244; del buf244 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf240, buf250, 49152, 128, grid=grid(49152), stream=stream0)
buf253 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf250, buf253, 768, 64, grid=grid(768), stream=stream0)
buf252 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf249, buf252, 589824, grid=grid(589824), stream=stream0)
buf254 = reinterpret_tensor(buf240, (8192, 768), (768, 1), 0); del buf240 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf239, (8192, 768), (768, 1), 0), permute_270, out=buf254)
del permute_270
buf255 = buf249; del buf249 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf239, (768, 8192), (1, 768), 0), view_176, out=buf255)
del view_176
buf256 = buf250; del buf250 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf239, buf256, 49152, 128, grid=grid(49152), stream=stream0)
buf259 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf256, buf259, 768, 64, grid=grid(768), stream=stream0)
buf258 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf255, buf258, 589824, grid=grid(589824), stream=stream0)
buf263 = buf204; del buf204 # reuse
buf268 = reinterpret_tensor(buf239, (16, 512, 768), (393216, 768, 1), 0); del buf239 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf225, buf242, buf248, buf254, primals_132, mul_107, div_39, gt_24, buf263, buf268, 8192, 768, grid=grid(8192), stream=stream0)
del div_39
del gt_24
del primals_132
buf264 = reinterpret_tensor(buf256, (768, 64), (1, 768), 0); del buf256 # reuse
buf266 = buf226; del buf226 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf225, buf242, buf248, buf254, mul_107, buf264, buf266, 49152, 128, grid=grid(49152), stream=stream0)
del buf242
del buf248
del mul_107
buf265 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf264, buf265, 768, 64, grid=grid(768), stream=stream0)
buf267 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf266, buf267, 768, 64, grid=grid(768), stream=stream0)
buf269 = reinterpret_tensor(buf216, (8192, 3072), (3072, 1), 0); del buf216 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf268, (8192, 768), (768, 1), 0), permute_274, out=buf269)
del permute_274
buf270 = reinterpret_tensor(buf218, (768, 3072), (3072, 1), 0); del buf218 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf268, (768, 8192), (1, 768), 0), view_174, out=buf270)
del view_174
buf271 = reinterpret_tensor(buf266, (1, 768, 64), (49152, 1, 768), 0); del buf266 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf268, buf271, 49152, 128, grid=grid(49152), stream=stream0)
buf274 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf271, buf274, 768, 64, grid=grid(768), stream=stream0)
buf273 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf270, buf273, 2359296, grid=grid(2359296), stream=stream0)
buf275 = reinterpret_tensor(buf269, (16, 512, 3072), (1572864, 3072, 1), 0); del buf269 # reuse
# Source Nodes: [hidden_states_60], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf275, addmm_46, 25165824, grid=grid(25165824), stream=stream0)
del addmm_46
buf276 = reinterpret_tensor(buf268, (8192, 768), (768, 1), 0); del buf268 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf275, (8192, 3072), (3072, 1), 0), permute_278, out=buf276)
del permute_278
buf277 = reinterpret_tensor(buf270, (3072, 768), (768, 1), 0); del buf270 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf275, (3072, 8192), (1, 3072), 0), view_172, out=buf277)
del view_172
buf278 = buf219; del buf219 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf275, buf278, 98304, 256, grid=grid(98304), stream=stream0)
buf281 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf278, buf281, 3072, 32, grid=grid(3072), stream=stream0)
buf280 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf277, buf280, 2359296, grid=grid(2359296), stream=stream0)
buf284 = buf225; del buf225 # reuse
buf289 = reinterpret_tensor(buf254, (16, 512, 768), (393216, 768, 1), 0); del buf254 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf263, buf276, primals_126, mul_100, div_40, gt_23, buf284, buf289, 8192, 768, grid=grid(8192), stream=stream0)
del div_40
del gt_23
del primals_126
buf285 = reinterpret_tensor(buf271, (768, 64), (1, 768), 0); del buf271 # reuse
buf287 = buf264; del buf264 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf263, buf276, mul_100, buf285, buf287, 49152, 128, grid=grid(49152), stream=stream0)
del mul_100
buf286 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf285, buf286, 768, 64, grid=grid(768), stream=stream0)
buf288 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf287, buf288, 768, 64, grid=grid(768), stream=stream0)
buf290 = buf276; del buf276 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf289, (8192, 768), (768, 1), 0), permute_282, out=buf290)
del permute_282
buf291 = buf255; del buf255 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf289, (768, 8192), (1, 768), 0), view_170, out=buf291)
del view_170
buf292 = reinterpret_tensor(buf287, (1, 768, 64), (49152, 1, 768), 0); del buf287 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf289, buf292, 49152, 128, grid=grid(49152), stream=stream0)
buf295 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf292, buf295, 768, 64, grid=grid(768), stream=stream0)
buf294 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf291, buf294, 589824, grid=grid(589824), stream=stream0)
buf296 = reinterpret_tensor(buf289, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf289 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf290, buf296, 6291456, grid=grid(6291456), stream=stream0)
del buf290
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf297 = aten._scaled_dot_product_flash_attention_backward.default(buf296, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, None, None, 512, 512, 0.1, False, getitem_82, getitem_83, scale=0.125)
del getitem_80
del getitem_81
del getitem_82
del getitem_83
del permute_default_24
del permute_default_25
del permute_default_26
buf298 = buf297[0]
buf299 = buf297[1]
buf300 = buf297[2]
del buf297
buf301 = reinterpret_tensor(buf296, (8192, 768), (768, 1), 0); del buf296 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf300, (8192, 768), (768, 1), 0), permute_294, out=buf301)
del permute_294
buf302 = buf291; del buf291 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf300, (768, 8192), (1, 768), 0), view_154, out=buf302)
buf303 = buf292; del buf292 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf300, buf303, 49152, 128, grid=grid(49152), stream=stream0)
buf306 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf303, buf306, 768, 64, grid=grid(768), stream=stream0)
buf305 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf302, buf305, 589824, grid=grid(589824), stream=stream0)
buf307 = reinterpret_tensor(buf300, (8192, 768), (768, 1), 0); del buf300 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf299, (8192, 768), (768, 1), 0), permute_299, out=buf307)
del permute_299
buf308 = buf302; del buf302 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf299, (768, 8192), (1, 768), 0), view_154, out=buf308)
buf309 = buf303; del buf303 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf299, buf309, 49152, 128, grid=grid(49152), stream=stream0)
buf312 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf309, buf312, 768, 64, grid=grid(768), stream=stream0)
buf311 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf308, buf311, 589824, grid=grid(589824), stream=stream0)
buf313 = reinterpret_tensor(buf299, (8192, 768), (768, 1), 0); del buf299 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf298, (8192, 768), (768, 1), 0), permute_303, out=buf313)
del permute_303
buf314 = buf308; del buf308 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf298, (768, 8192), (1, 768), 0), view_154, out=buf314)
del view_154
buf315 = buf309; del buf309 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf298, buf315, 49152, 128, grid=grid(49152), stream=stream0)
buf318 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf315, buf318, 768, 64, grid=grid(768), stream=stream0)
buf317 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf314, buf317, 589824, grid=grid(589824), stream=stream0)
buf322 = buf263; del buf263 # reuse
buf327 = reinterpret_tensor(buf298, (16, 512, 768), (393216, 768, 1), 0); del buf298 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf284, buf301, buf307, buf313, primals_116, mul_94, div_42, gt_21, buf322, buf327, 8192, 768, grid=grid(8192), stream=stream0)
del div_42
del gt_21
del primals_116
buf323 = reinterpret_tensor(buf315, (768, 64), (1, 768), 0); del buf315 # reuse
buf325 = buf285; del buf285 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf284, buf301, buf307, buf313, mul_94, buf323, buf325, 49152, 128, grid=grid(49152), stream=stream0)
del buf301
del buf307
del mul_94
buf324 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf323, buf324, 768, 64, grid=grid(768), stream=stream0)
buf326 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf325, buf326, 768, 64, grid=grid(768), stream=stream0)
buf328 = reinterpret_tensor(buf275, (8192, 3072), (3072, 1), 0); del buf275 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf327, (8192, 768), (768, 1), 0), permute_307, out=buf328)
del permute_307
buf329 = reinterpret_tensor(buf277, (768, 3072), (3072, 1), 0); del buf277 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf327, (768, 8192), (1, 768), 0), view_152, out=buf329)
del view_152
buf330 = reinterpret_tensor(buf325, (1, 768, 64), (49152, 1, 768), 0); del buf325 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf327, buf330, 49152, 128, grid=grid(49152), stream=stream0)
buf333 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf330, buf333, 768, 64, grid=grid(768), stream=stream0)
buf332 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf329, buf332, 2359296, grid=grid(2359296), stream=stream0)
buf334 = reinterpret_tensor(buf328, (16, 512, 3072), (1572864, 3072, 1), 0); del buf328 # reuse
# Source Nodes: [hidden_states_52], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf334, addmm_40, 25165824, grid=grid(25165824), stream=stream0)
del addmm_40
buf335 = reinterpret_tensor(buf327, (8192, 768), (768, 1), 0); del buf327 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf334, (8192, 3072), (3072, 1), 0), permute_311, out=buf335)
del permute_311
buf336 = reinterpret_tensor(buf329, (3072, 768), (768, 1), 0); del buf329 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf334, (3072, 8192), (1, 3072), 0), view_150, out=buf336)
del view_150
buf337 = buf278; del buf278 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf334, buf337, 98304, 256, grid=grid(98304), stream=stream0)
buf340 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf337, buf340, 3072, 32, grid=grid(3072), stream=stream0)
buf339 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf336, buf339, 2359296, grid=grid(2359296), stream=stream0)
buf343 = buf284; del buf284 # reuse
buf348 = reinterpret_tensor(buf313, (16, 512, 768), (393216, 768, 1), 0); del buf313 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf322, buf335, primals_110, mul_87, div_43, gt_20, buf343, buf348, 8192, 768, grid=grid(8192), stream=stream0)
del div_43
del gt_20
del primals_110
buf344 = reinterpret_tensor(buf330, (768, 64), (1, 768), 0); del buf330 # reuse
buf346 = buf323; del buf323 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf322, buf335, mul_87, buf344, buf346, 49152, 128, grid=grid(49152), stream=stream0)
del mul_87
buf345 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf344, buf345, 768, 64, grid=grid(768), stream=stream0)
buf347 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf346, buf347, 768, 64, grid=grid(768), stream=stream0)
buf349 = buf335; del buf335 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf348, (8192, 768), (768, 1), 0), permute_315, out=buf349)
del permute_315
buf350 = buf314; del buf314 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf348, (768, 8192), (1, 768), 0), view_148, out=buf350)
del view_148
buf351 = reinterpret_tensor(buf346, (1, 768, 64), (49152, 1, 768), 0); del buf346 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf348, buf351, 49152, 128, grid=grid(49152), stream=stream0)
buf354 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf351, buf354, 768, 64, grid=grid(768), stream=stream0)
buf353 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf350, buf353, 589824, grid=grid(589824), stream=stream0)
buf355 = reinterpret_tensor(buf348, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf348 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf349, buf355, 6291456, grid=grid(6291456), stream=stream0)
del buf349
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf356 = aten._scaled_dot_product_flash_attention_backward.default(buf355, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, None, None, 512, 512, 0.1, False, getitem_89, getitem_90, scale=0.125)
del getitem_87
del getitem_88
del getitem_89
del getitem_90
del permute_default_30
del permute_default_31
del permute_default_32
buf357 = buf356[0]
buf358 = buf356[1]
buf359 = buf356[2]
del buf356
buf360 = reinterpret_tensor(buf355, (8192, 768), (768, 1), 0); del buf355 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf359, (8192, 768), (768, 1), 0), permute_327, out=buf360)
del permute_327
buf361 = buf350; del buf350 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf359, (768, 8192), (1, 768), 0), view_132, out=buf361)
buf362 = buf351; del buf351 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf359, buf362, 49152, 128, grid=grid(49152), stream=stream0)
buf365 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf362, buf365, 768, 64, grid=grid(768), stream=stream0)
buf364 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf361, buf364, 589824, grid=grid(589824), stream=stream0)
buf366 = reinterpret_tensor(buf359, (8192, 768), (768, 1), 0); del buf359 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf358, (8192, 768), (768, 1), 0), permute_332, out=buf366)
del permute_332
buf367 = buf361; del buf361 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf358, (768, 8192), (1, 768), 0), view_132, out=buf367)
buf368 = buf362; del buf362 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf358, buf368, 49152, 128, grid=grid(49152), stream=stream0)
buf371 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf368, buf371, 768, 64, grid=grid(768), stream=stream0)
buf370 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf367, buf370, 589824, grid=grid(589824), stream=stream0)
buf372 = reinterpret_tensor(buf358, (8192, 768), (768, 1), 0); del buf358 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf357, (8192, 768), (768, 1), 0), permute_336, out=buf372)
del permute_336
buf373 = buf367; del buf367 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf357, (768, 8192), (1, 768), 0), view_132, out=buf373)
del view_132
buf374 = buf368; del buf368 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf357, buf374, 49152, 128, grid=grid(49152), stream=stream0)
buf377 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf374, buf377, 768, 64, grid=grid(768), stream=stream0)
buf376 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf373, buf376, 589824, grid=grid(589824), stream=stream0)
buf381 = buf322; del buf322 # reuse
buf386 = reinterpret_tensor(buf357, (16, 512, 768), (393216, 768, 1), 0); del buf357 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf343, buf360, buf366, buf372, primals_100, mul_81, div_45, gt_18, buf381, buf386, 8192, 768, grid=grid(8192), stream=stream0)
del div_45
del gt_18
del primals_100
buf382 = reinterpret_tensor(buf374, (768, 64), (1, 768), 0); del buf374 # reuse
buf384 = buf344; del buf344 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf343, buf360, buf366, buf372, mul_81, buf382, buf384, 49152, 128, grid=grid(49152), stream=stream0)
del buf360
del buf366
del mul_81
buf383 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf382, buf383, 768, 64, grid=grid(768), stream=stream0)
buf385 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf384, buf385, 768, 64, grid=grid(768), stream=stream0)
buf387 = reinterpret_tensor(buf334, (8192, 3072), (3072, 1), 0); del buf334 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf386, (8192, 768), (768, 1), 0), permute_340, out=buf387)
del permute_340
buf388 = reinterpret_tensor(buf336, (768, 3072), (3072, 1), 0); del buf336 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf386, (768, 8192), (1, 768), 0), view_130, out=buf388)
del view_130
buf389 = reinterpret_tensor(buf384, (1, 768, 64), (49152, 1, 768), 0); del buf384 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf386, buf389, 49152, 128, grid=grid(49152), stream=stream0)
buf392 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf389, buf392, 768, 64, grid=grid(768), stream=stream0)
buf391 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf388, buf391, 2359296, grid=grid(2359296), stream=stream0)
buf393 = reinterpret_tensor(buf387, (16, 512, 3072), (1572864, 3072, 1), 0); del buf387 # reuse
# Source Nodes: [hidden_states_44], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf393, addmm_34, 25165824, grid=grid(25165824), stream=stream0)
del addmm_34
buf394 = reinterpret_tensor(buf386, (8192, 768), (768, 1), 0); del buf386 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf393, (8192, 3072), (3072, 1), 0), permute_344, out=buf394)
del permute_344
buf395 = reinterpret_tensor(buf388, (3072, 768), (768, 1), 0); del buf388 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf393, (3072, 8192), (1, 3072), 0), view_128, out=buf395)
del view_128
buf396 = buf337; del buf337 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf393, buf396, 98304, 256, grid=grid(98304), stream=stream0)
buf399 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf396, buf399, 3072, 32, grid=grid(3072), stream=stream0)
buf398 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf395, buf398, 2359296, grid=grid(2359296), stream=stream0)
buf402 = buf343; del buf343 # reuse
buf407 = reinterpret_tensor(buf372, (16, 512, 768), (393216, 768, 1), 0); del buf372 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf381, buf394, primals_94, mul_74, div_46, gt_17, buf402, buf407, 8192, 768, grid=grid(8192), stream=stream0)
del div_46
del gt_17
del primals_94
buf403 = reinterpret_tensor(buf389, (768, 64), (1, 768), 0); del buf389 # reuse
buf405 = buf382; del buf382 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf381, buf394, mul_74, buf403, buf405, 49152, 128, grid=grid(49152), stream=stream0)
del mul_74
buf404 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf403, buf404, 768, 64, grid=grid(768), stream=stream0)
buf406 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf405, buf406, 768, 64, grid=grid(768), stream=stream0)
buf408 = buf394; del buf394 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf407, (8192, 768), (768, 1), 0), permute_348, out=buf408)
del permute_348
buf409 = buf373; del buf373 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf407, (768, 8192), (1, 768), 0), view_126, out=buf409)
del view_126
buf410 = reinterpret_tensor(buf405, (1, 768, 64), (49152, 1, 768), 0); del buf405 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf407, buf410, 49152, 128, grid=grid(49152), stream=stream0)
buf413 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf410, buf413, 768, 64, grid=grid(768), stream=stream0)
buf412 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf409, buf412, 589824, grid=grid(589824), stream=stream0)
buf414 = reinterpret_tensor(buf407, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf407 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf408, buf414, 6291456, grid=grid(6291456), stream=stream0)
del buf408
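# Note: the clone kernel above appears to repack the flat (8192, 768)
# projection gradient into the (16 batch, 12 heads, 512 seq, 64 head_dim)
# layout that the flash-attention backward fallback below expects.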
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf415 = aten._scaled_dot_product_flash_attention_backward.default(buf414, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, None, None, 512, 512, 0.1, False, getitem_96, getitem_97, scale=0.125)
del getitem_94
del getitem_95
del getitem_96
del getitem_97
del permute_default_36
del permute_default_37
del permute_default_38
buf416 = buf415[0]
buf417 = buf415[1]
buf418 = buf415[2]
del buf415
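# Note: the flash-attention backward fallback returns the (grad_query,
# grad_key, grad_value) triple unpacked above. The positional arguments
# appear to line up with the ATen signature: the saved q/k/v permutes, the
# forward output and logsumexp, cumulative sequence lengths (None for dense
# batches), max_q = max_k = 512, dropout_p = 0.1, is_causal = False, and the
# saved philox seed/offset, with scale = 0.125 = 1/sqrt(64) matching the
# 64-dim heads.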
buf419 = reinterpret_tensor(buf414, (8192, 768), (768, 1), 0); del buf414 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf418, (8192, 768), (768, 1), 0), permute_360, out=buf419)
del permute_360
buf420 = buf409; del buf409 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf418, (768, 8192), (1, 768), 0), view_110, out=buf420)
buf421 = buf410; del buf410 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf418, buf421, 49152, 128, grid=grid(49152), stream=stream0)
buf424 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf421, buf424, 768, 64, grid=grid(768), stream=stream0)
buf423 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf420, buf423, 589824, grid=grid(589824), stream=stream0)
buf425 = reinterpret_tensor(buf418, (8192, 768), (768, 1), 0); del buf418 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf417, (8192, 768), (768, 1), 0), permute_365, out=buf425)
del permute_365
buf426 = buf420; del buf420 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf417, (768, 8192), (1, 768), 0), view_110, out=buf426)
buf427 = buf421; del buf421 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf417, buf427, 49152, 128, grid=grid(49152), stream=stream0)
buf430 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf427, buf430, 768, 64, grid=grid(768), stream=stream0)
buf429 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf426, buf429, 589824, grid=grid(589824), stream=stream0)
buf431 = reinterpret_tensor(buf417, (8192, 768), (768, 1), 0); del buf417 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf416, (8192, 768), (768, 1), 0), permute_369, out=buf431)
del permute_369
buf432 = buf426; del buf426 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf416, (768, 8192), (1, 768), 0), view_110, out=buf432)
del view_110
buf433 = buf427; del buf427 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf416, buf433, 49152, 128, grid=grid(49152), stream=stream0)
buf436 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf433, buf436, 768, 64, grid=grid(768), stream=stream0)
buf435 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf432, buf435, 589824, grid=grid(589824), stream=stream0)
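# Note: the three mm pairs above backpropagate through the value, key and
# query projections in turn (buf418, buf417, buf416); all three
# weight-gradient mms consume the same saved input (view_110), since q, k and
# v are projected from the shared layer input.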
buf440 = buf381; del buf381 # reuse
buf445 = reinterpret_tensor(buf416, (16, 512, 768), (393216, 768, 1), 0); del buf416 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf402, buf419, buf425, buf431, primals_84, mul_68, div_48, gt_15, buf440, buf445, 8192, 768, grid=grid(8192), stream=stream0)
del div_48
del gt_15
del primals_84
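# Note: kernel 25 appears to fuse the residual accumulation of the three
# projection input-gradients with the dropout backward (the saved gt_* boolean
# mask rescales surviving elements) and the layer-norm input gradient (div_*
# holding the saved inverse-std terms), producing both the accumulated fp32
# gradient (buf440) and the low-precision input to the next backward step
# (buf445).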
buf441 = reinterpret_tensor(buf433, (768, 64), (1, 768), 0); del buf433 # reuse
buf443 = buf403; del buf403 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf402, buf419, buf425, buf431, mul_68, buf441, buf443, 49152, 128, grid=grid(49152), stream=stream0)
del buf419
del buf425
del mul_68
buf442 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf441, buf442, 768, 64, grid=grid(768), stream=stream0)
buf444 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf443, buf444, 768, 64, grid=grid(768), stream=stream0)
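# The same per-layer backward pattern (output layer norm, MLP, attention
# output projection, flash-attention backward, q/k/v projections) repeats
# below for each remaining encoder layer down to layer 0, cycling through the
# saved activations and reusing the same scratch buffers.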
buf446 = reinterpret_tensor(buf393, (8192, 3072), (3072, 1), 0); del buf393 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf445, (8192, 768), (768, 1), 0), permute_373, out=buf446)
del permute_373
buf447 = reinterpret_tensor(buf395, (768, 3072), (3072, 1), 0); del buf395 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf445, (768, 8192), (1, 768), 0), view_108, out=buf447)
del view_108
buf448 = reinterpret_tensor(buf443, (1, 768, 64), (49152, 1, 768), 0); del buf443 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf445, buf448, 49152, 128, grid=grid(49152), stream=stream0)
buf451 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf448, buf451, 768, 64, grid=grid(768), stream=stream0)
buf450 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf447, buf450, 2359296, grid=grid(2359296), stream=stream0)
buf452 = reinterpret_tensor(buf446, (16, 512, 3072), (1572864, 3072, 1), 0); del buf446 # reuse
# Source Nodes: [hidden_states_36], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf452, addmm_28, 25165824, grid=grid(25165824), stream=stream0)
del addmm_28
buf453 = reinterpret_tensor(buf445, (8192, 768), (768, 1), 0); del buf445 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf452, (8192, 3072), (3072, 1), 0), permute_377, out=buf453)
del permute_377
buf454 = reinterpret_tensor(buf447, (3072, 768), (768, 1), 0); del buf447 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf452, (3072, 8192), (1, 3072), 0), view_106, out=buf454)
del view_106
buf455 = buf396; del buf396 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf452, buf455, 98304, 256, grid=grid(98304), stream=stream0)
buf458 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf455, buf458, 3072, 32, grid=grid(3072), stream=stream0)
buf457 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf454, buf457, 2359296, grid=grid(2359296), stream=stream0)
buf461 = buf402; del buf402 # reuse
buf466 = reinterpret_tensor(buf431, (16, 512, 768), (393216, 768, 1), 0); del buf431 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf440, buf453, primals_78, mul_61, div_49, gt_14, buf461, buf466, 8192, 768, grid=grid(8192), stream=stream0)
del div_49
del gt_14
del primals_78
buf462 = reinterpret_tensor(buf448, (768, 64), (1, 768), 0); del buf448 # reuse
buf464 = buf441; del buf441 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf440, buf453, mul_61, buf462, buf464, 49152, 128, grid=grid(49152), stream=stream0)
del mul_61
buf463 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf462, buf463, 768, 64, grid=grid(768), stream=stream0)
buf465 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf464, buf465, 768, 64, grid=grid(768), stream=stream0)
buf467 = buf453; del buf453 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf466, (8192, 768), (768, 1), 0), permute_381, out=buf467)
del permute_381
buf468 = buf432; del buf432 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf466, (768, 8192), (1, 768), 0), view_104, out=buf468)
del view_104
buf469 = reinterpret_tensor(buf464, (1, 768, 64), (49152, 1, 768), 0); del buf464 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf466, buf469, 49152, 128, grid=grid(49152), stream=stream0)
buf472 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf469, buf472, 768, 64, grid=grid(768), stream=stream0)
buf471 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf468, buf471, 589824, grid=grid(589824), stream=stream0)
buf473 = reinterpret_tensor(buf466, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf466 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf467, buf473, 6291456, grid=grid(6291456), stream=stream0)
del buf467
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf474 = aten._scaled_dot_product_flash_attention_backward.default(buf473, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, None, None, 512, 512, 0.1, False, getitem_103, getitem_104, scale=0.125)
del getitem_101
del getitem_102
del getitem_103
del getitem_104
del permute_default_42
del permute_default_43
del permute_default_44
buf475 = buf474[0]
buf476 = buf474[1]
buf477 = buf474[2]
del buf474
buf478 = reinterpret_tensor(buf473, (8192, 768), (768, 1), 0); del buf473 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf477, (8192, 768), (768, 1), 0), permute_393, out=buf478)
del permute_393
buf479 = buf468; del buf468 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf477, (768, 8192), (1, 768), 0), view_88, out=buf479)
buf480 = buf469; del buf469 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf477, buf480, 49152, 128, grid=grid(49152), stream=stream0)
buf483 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf480, buf483, 768, 64, grid=grid(768), stream=stream0)
buf482 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf479, buf482, 589824, grid=grid(589824), stream=stream0)
buf484 = reinterpret_tensor(buf477, (8192, 768), (768, 1), 0); del buf477 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf476, (8192, 768), (768, 1), 0), permute_398, out=buf484)
del permute_398
buf485 = buf479; del buf479 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf476, (768, 8192), (1, 768), 0), view_88, out=buf485)
buf486 = buf480; del buf480 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf476, buf486, 49152, 128, grid=grid(49152), stream=stream0)
buf489 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf486, buf489, 768, 64, grid=grid(768), stream=stream0)
buf488 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf485, buf488, 589824, grid=grid(589824), stream=stream0)
buf490 = reinterpret_tensor(buf476, (8192, 768), (768, 1), 0); del buf476 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf475, (8192, 768), (768, 1), 0), permute_402, out=buf490)
del permute_402
buf491 = buf485; del buf485 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf475, (768, 8192), (1, 768), 0), view_88, out=buf491)
del view_88
buf492 = buf486; del buf486 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf475, buf492, 49152, 128, grid=grid(49152), stream=stream0)
buf495 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf492, buf495, 768, 64, grid=grid(768), stream=stream0)
buf494 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf491, buf494, 589824, grid=grid(589824), stream=stream0)
buf499 = buf440; del buf440 # reuse
buf504 = reinterpret_tensor(buf475, (16, 512, 768), (393216, 768, 1), 0); del buf475 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf461, buf478, buf484, buf490, primals_68, mul_55, div_51, gt_12, buf499, buf504, 8192, 768, grid=grid(8192), stream=stream0)
del div_51
del gt_12
del primals_68
buf500 = reinterpret_tensor(buf492, (768, 64), (1, 768), 0); del buf492 # reuse
buf502 = buf462; del buf462 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf461, buf478, buf484, buf490, mul_55, buf500, buf502, 49152, 128, grid=grid(49152), stream=stream0)
del buf478
del buf484
del mul_55
buf501 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf500, buf501, 768, 64, grid=grid(768), stream=stream0)
buf503 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf502, buf503, 768, 64, grid=grid(768), stream=stream0)
buf505 = reinterpret_tensor(buf452, (8192, 3072), (3072, 1), 0); del buf452 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf504, (8192, 768), (768, 1), 0), permute_406, out=buf505)
del permute_406
buf506 = reinterpret_tensor(buf454, (768, 3072), (3072, 1), 0); del buf454 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf504, (768, 8192), (1, 768), 0), view_86, out=buf506)
del view_86
buf507 = reinterpret_tensor(buf502, (1, 768, 64), (49152, 1, 768), 0); del buf502 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf504, buf507, 49152, 128, grid=grid(49152), stream=stream0)
buf510 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf507, buf510, 768, 64, grid=grid(768), stream=stream0)
buf509 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf506, buf509, 2359296, grid=grid(2359296), stream=stream0)
buf511 = reinterpret_tensor(buf505, (16, 512, 3072), (1572864, 3072, 1), 0); del buf505 # reuse
# Source Nodes: [hidden_states_28], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf511, addmm_22, 25165824, grid=grid(25165824), stream=stream0)
del addmm_22
buf512 = reinterpret_tensor(buf504, (8192, 768), (768, 1), 0); del buf504 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf511, (8192, 3072), (3072, 1), 0), permute_410, out=buf512)
del permute_410
buf513 = reinterpret_tensor(buf506, (3072, 768), (768, 1), 0); del buf506 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf511, (3072, 8192), (1, 3072), 0), view_84, out=buf513)
del view_84
buf514 = buf455; del buf455 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf511, buf514, 98304, 256, grid=grid(98304), stream=stream0)
buf517 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf514, buf517, 3072, 32, grid=grid(3072), stream=stream0)
buf516 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf513, buf516, 2359296, grid=grid(2359296), stream=stream0)
buf520 = buf461; del buf461 # reuse
buf525 = reinterpret_tensor(buf490, (16, 512, 768), (393216, 768, 1), 0); del buf490 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf499, buf512, primals_62, mul_48, div_52, gt_11, buf520, buf525, 8192, 768, grid=grid(8192), stream=stream0)
del div_52
del gt_11
del primals_62
buf521 = reinterpret_tensor(buf507, (768, 64), (1, 768), 0); del buf507 # reuse
buf523 = buf500; del buf500 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf499, buf512, mul_48, buf521, buf523, 49152, 128, grid=grid(49152), stream=stream0)
del mul_48
buf522 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf521, buf522, 768, 64, grid=grid(768), stream=stream0)
buf524 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf523, buf524, 768, 64, grid=grid(768), stream=stream0)
buf526 = buf512; del buf512 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf525, (8192, 768), (768, 1), 0), permute_414, out=buf526)
del permute_414
buf527 = buf491; del buf491 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf525, (768, 8192), (1, 768), 0), view_82, out=buf527)
del view_82
buf528 = reinterpret_tensor(buf523, (1, 768, 64), (49152, 1, 768), 0); del buf523 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf525, buf528, 49152, 128, grid=grid(49152), stream=stream0)
buf531 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf528, buf531, 768, 64, grid=grid(768), stream=stream0)
buf530 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf527, buf530, 589824, grid=grid(589824), stream=stream0)
buf532 = reinterpret_tensor(buf525, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf525 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf526, buf532, 6291456, grid=grid(6291456), stream=stream0)
del buf526
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf533 = aten._scaled_dot_product_flash_attention_backward.default(buf532, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, None, None, 512, 512, 0.1, False, getitem_110, getitem_111, scale=0.125)
del getitem_108
del getitem_109
del getitem_110
del getitem_111
del permute_default_48
del permute_default_49
del permute_default_50
buf534 = buf533[0]
buf535 = buf533[1]
buf536 = buf533[2]
del buf533
buf537 = reinterpret_tensor(buf532, (8192, 768), (768, 1), 0); del buf532 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf536, (8192, 768), (768, 1), 0), permute_426, out=buf537)
del permute_426
buf538 = buf527; del buf527 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf536, (768, 8192), (1, 768), 0), view_66, out=buf538)
buf539 = buf528; del buf528 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf536, buf539, 49152, 128, grid=grid(49152), stream=stream0)
buf542 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf539, buf542, 768, 64, grid=grid(768), stream=stream0)
buf541 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf538, buf541, 589824, grid=grid(589824), stream=stream0)
buf543 = reinterpret_tensor(buf536, (8192, 768), (768, 1), 0); del buf536 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf535, (8192, 768), (768, 1), 0), permute_431, out=buf543)
del permute_431
buf544 = buf538; del buf538 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf535, (768, 8192), (1, 768), 0), view_66, out=buf544)
buf545 = buf539; del buf539 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf535, buf545, 49152, 128, grid=grid(49152), stream=stream0)
buf548 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf545, buf548, 768, 64, grid=grid(768), stream=stream0)
buf547 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf544, buf547, 589824, grid=grid(589824), stream=stream0)
buf549 = reinterpret_tensor(buf535, (8192, 768), (768, 1), 0); del buf535 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf534, (8192, 768), (768, 1), 0), permute_435, out=buf549)
del permute_435
buf550 = buf544; del buf544 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf534, (768, 8192), (1, 768), 0), view_66, out=buf550)
del view_66
buf551 = buf545; del buf545 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf534, buf551, 49152, 128, grid=grid(49152), stream=stream0)
buf554 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf551, buf554, 768, 64, grid=grid(768), stream=stream0)
buf553 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf550, buf553, 589824, grid=grid(589824), stream=stream0)
buf558 = buf499; del buf499 # reuse
buf563 = reinterpret_tensor(buf534, (16, 512, 768), (393216, 768, 1), 0); del buf534 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf520, buf537, buf543, buf549, primals_52, mul_42, div_54, gt_9, buf558, buf563, 8192, 768, grid=grid(8192), stream=stream0)
del div_54
del gt_9
del primals_52
buf559 = reinterpret_tensor(buf551, (768, 64), (1, 768), 0); del buf551 # reuse
buf561 = buf521; del buf521 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf520, buf537, buf543, buf549, mul_42, buf559, buf561, 49152, 128, grid=grid(49152), stream=stream0)
del buf537
del buf543
del mul_42
buf560 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf559, buf560, 768, 64, grid=grid(768), stream=stream0)
buf562 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf561, buf562, 768, 64, grid=grid(768), stream=stream0)
buf564 = reinterpret_tensor(buf511, (8192, 3072), (3072, 1), 0); del buf511 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf563, (8192, 768), (768, 1), 0), permute_439, out=buf564)
del permute_439
buf565 = reinterpret_tensor(buf513, (768, 3072), (3072, 1), 0); del buf513 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf563, (768, 8192), (1, 768), 0), view_64, out=buf565)
del view_64
buf566 = reinterpret_tensor(buf561, (1, 768, 64), (49152, 1, 768), 0); del buf561 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf563, buf566, 49152, 128, grid=grid(49152), stream=stream0)
buf569 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf566, buf569, 768, 64, grid=grid(768), stream=stream0)
buf568 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf565, buf568, 2359296, grid=grid(2359296), stream=stream0)
buf570 = reinterpret_tensor(buf564, (16, 512, 3072), (1572864, 3072, 1), 0); del buf564 # reuse
# Source Nodes: [hidden_states_20], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf570, addmm_16, 25165824, grid=grid(25165824), stream=stream0)
del addmm_16
buf571 = reinterpret_tensor(buf563, (8192, 768), (768, 1), 0); del buf563 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf570, (8192, 3072), (3072, 1), 0), permute_443, out=buf571)
del permute_443
buf572 = reinterpret_tensor(buf565, (3072, 768), (768, 1), 0); del buf565 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf570, (3072, 8192), (1, 3072), 0), view_62, out=buf572)
del view_62
buf573 = buf514; del buf514 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf570, buf573, 98304, 256, grid=grid(98304), stream=stream0)
buf576 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf573, buf576, 3072, 32, grid=grid(3072), stream=stream0)
buf575 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf572, buf575, 2359296, grid=grid(2359296), stream=stream0)
buf579 = buf520; del buf520 # reuse
buf584 = reinterpret_tensor(buf549, (16, 512, 768), (393216, 768, 1), 0); del buf549 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf558, buf571, primals_46, mul_35, div_55, gt_8, buf579, buf584, 8192, 768, grid=grid(8192), stream=stream0)
del div_55
del gt_8
del primals_46
buf580 = reinterpret_tensor(buf566, (768, 64), (1, 768), 0); del buf566 # reuse
buf582 = buf559; del buf559 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf558, buf571, mul_35, buf580, buf582, 49152, 128, grid=grid(49152), stream=stream0)
del mul_35
buf581 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf580, buf581, 768, 64, grid=grid(768), stream=stream0)
buf583 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf582, buf583, 768, 64, grid=grid(768), stream=stream0)
buf585 = buf571; del buf571 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf584, (8192, 768), (768, 1), 0), permute_447, out=buf585)
del permute_447
buf586 = buf550; del buf550 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf584, (768, 8192), (1, 768), 0), view_60, out=buf586)
del view_60
buf587 = reinterpret_tensor(buf582, (1, 768, 64), (49152, 1, 768), 0); del buf582 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf584, buf587, 49152, 128, grid=grid(49152), stream=stream0)
buf590 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf587, buf590, 768, 64, grid=grid(768), stream=stream0)
buf589 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf586, buf589, 589824, grid=grid(589824), stream=stream0)
buf591 = reinterpret_tensor(buf584, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf584 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf585, buf591, 6291456, grid=grid(6291456), stream=stream0)
del buf585
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf592 = aten._scaled_dot_product_flash_attention_backward.default(buf591, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, None, None, 512, 512, 0.1, False, getitem_117, getitem_118, scale=0.125)
del getitem_115
del getitem_116
del getitem_117
del getitem_118
del permute_default_54
del permute_default_55
del permute_default_56
buf593 = buf592[0]
buf594 = buf592[1]
buf595 = buf592[2]
del buf592
buf596 = reinterpret_tensor(buf591, (8192, 768), (768, 1), 0); del buf591 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf595, (8192, 768), (768, 1), 0), permute_459, out=buf596)
del permute_459
buf597 = buf586; del buf586 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf595, (768, 8192), (1, 768), 0), view_44, out=buf597)
buf598 = buf587; del buf587 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf595, buf598, 49152, 128, grid=grid(49152), stream=stream0)
buf601 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf598, buf601, 768, 64, grid=grid(768), stream=stream0)
buf600 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf597, buf600, 589824, grid=grid(589824), stream=stream0)
buf602 = reinterpret_tensor(buf595, (8192, 768), (768, 1), 0); del buf595 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf594, (8192, 768), (768, 1), 0), permute_464, out=buf602)
del permute_464
buf603 = buf597; del buf597 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf594, (768, 8192), (1, 768), 0), view_44, out=buf603)
buf604 = buf598; del buf598 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf594, buf604, 49152, 128, grid=grid(49152), stream=stream0)
buf607 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf604, buf607, 768, 64, grid=grid(768), stream=stream0)
buf606 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf603, buf606, 589824, grid=grid(589824), stream=stream0)
buf608 = reinterpret_tensor(buf594, (8192, 768), (768, 1), 0); del buf594 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf593, (8192, 768), (768, 1), 0), permute_468, out=buf608)
del permute_468
buf609 = buf603; del buf603 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf593, (768, 8192), (1, 768), 0), view_44, out=buf609)
del view_44
buf610 = buf604; del buf604 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf593, buf610, 49152, 128, grid=grid(49152), stream=stream0)
buf613 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf610, buf613, 768, 64, grid=grid(768), stream=stream0)
buf612 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf609, buf612, 589824, grid=grid(589824), stream=stream0)
buf617 = buf558; del buf558 # reuse
buf622 = reinterpret_tensor(buf593, (16, 512, 768), (393216, 768, 1), 0); del buf593 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf579, buf596, buf602, buf608, primals_36, mul_29, div_57, gt_6, buf617, buf622, 8192, 768, grid=grid(8192), stream=stream0)
del div_57
del gt_6
del primals_36
buf618 = reinterpret_tensor(buf610, (768, 64), (1, 768), 0); del buf610 # reuse
buf620 = buf580; del buf580 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf579, buf596, buf602, buf608, mul_29, buf618, buf620, 49152, 128, grid=grid(49152), stream=stream0)
del buf596
del buf602
del mul_29
buf619 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf618, buf619, 768, 64, grid=grid(768), stream=stream0)
buf621 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf620, buf621, 768, 64, grid=grid(768), stream=stream0)
buf623 = reinterpret_tensor(buf570, (8192, 3072), (3072, 1), 0); del buf570 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf622, (8192, 768), (768, 1), 0), permute_472, out=buf623)
del permute_472
buf624 = reinterpret_tensor(buf572, (768, 3072), (3072, 1), 0); del buf572 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf622, (768, 8192), (1, 768), 0), view_42, out=buf624)
del view_42
buf625 = reinterpret_tensor(buf620, (1, 768, 64), (49152, 1, 768), 0); del buf620 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf622, buf625, 49152, 128, grid=grid(49152), stream=stream0)
buf628 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf625, buf628, 768, 64, grid=grid(768), stream=stream0)
buf627 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf624, buf627, 2359296, grid=grid(2359296), stream=stream0)
buf629 = reinterpret_tensor(buf623, (16, 512, 3072), (1572864, 3072, 1), 0); del buf623 # reuse
# Source Nodes: [hidden_states_12], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf629, addmm_10, 25165824, grid=grid(25165824), stream=stream0)
del addmm_10
buf630 = reinterpret_tensor(buf622, (8192, 768), (768, 1), 0); del buf622 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf629, (8192, 3072), (3072, 1), 0), permute_476, out=buf630)
del permute_476
buf631 = reinterpret_tensor(buf624, (3072, 768), (768, 1), 0); del buf624 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf629, (3072, 8192), (1, 3072), 0), view_40, out=buf631)
del view_40
buf632 = buf573; del buf573 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf629, buf632, 98304, 256, grid=grid(98304), stream=stream0)
buf635 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf632, buf635, 3072, 32, grid=grid(3072), stream=stream0)
buf634 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf631, buf634, 2359296, grid=grid(2359296), stream=stream0)
buf638 = buf579; del buf579 # reuse
buf643 = reinterpret_tensor(buf608, (16, 512, 768), (393216, 768, 1), 0); del buf608 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf617, buf630, primals_30, mul_22, div_58, gt_5, buf638, buf643, 8192, 768, grid=grid(8192), stream=stream0)
del div_58
del gt_5
del primals_30
buf639 = reinterpret_tensor(buf625, (768, 64), (1, 768), 0); del buf625 # reuse
buf641 = buf618; del buf618 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf617, buf630, mul_22, buf639, buf641, 49152, 128, grid=grid(49152), stream=stream0)
del mul_22
buf640 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf639, buf640, 768, 64, grid=grid(768), stream=stream0)
buf642 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf641, buf642, 768, 64, grid=grid(768), stream=stream0)
buf644 = buf630; del buf630 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf643, (8192, 768), (768, 1), 0), permute_480, out=buf644)
del permute_480
buf645 = buf609; del buf609 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf643, (768, 8192), (1, 768), 0), view_38, out=buf645)
del view_38
buf646 = reinterpret_tensor(buf641, (1, 768, 64), (49152, 1, 768), 0); del buf641 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf643, buf646, 49152, 128, grid=grid(49152), stream=stream0)
buf649 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf646, buf649, 768, 64, grid=grid(768), stream=stream0)
buf648 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf645, buf648, 589824, grid=grid(589824), stream=stream0)
buf650 = reinterpret_tensor(buf643, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf643 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf644, buf650, 6291456, grid=grid(6291456), stream=stream0)
del buf644
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf651 = aten._scaled_dot_product_flash_attention_backward.default(buf650, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, None, None, 512, 512, 0.1, False, getitem_124, getitem_125, scale=0.125)
del getitem_122
del getitem_123
del getitem_124
del getitem_125
del permute_default_60
del permute_default_61
del permute_default_62
buf652 = buf651[0]
buf653 = buf651[1]
buf654 = buf651[2]
del buf651
buf655 = reinterpret_tensor(buf650, (8192, 768), (768, 1), 0); del buf650 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf654, (8192, 768), (768, 1), 0), permute_492, out=buf655)
del permute_492
buf656 = buf645; del buf645 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf654, (768, 8192), (1, 768), 0), view_22, out=buf656)
buf657 = buf646; del buf646 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf654, buf657, 49152, 128, grid=grid(49152), stream=stream0)
buf660 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf657, buf660, 768, 64, grid=grid(768), stream=stream0)
buf659 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf656, buf659, 589824, grid=grid(589824), stream=stream0)
buf661 = reinterpret_tensor(buf654, (8192, 768), (768, 1), 0); del buf654 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf653, (8192, 768), (768, 1), 0), permute_497, out=buf661)
del permute_497
buf662 = buf656; del buf656 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf653, (768, 8192), (1, 768), 0), view_22, out=buf662)
buf663 = buf657; del buf657 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf653, buf663, 49152, 128, grid=grid(49152), stream=stream0)
buf666 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf663, buf666, 768, 64, grid=grid(768), stream=stream0)
buf665 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf662, buf665, 589824, grid=grid(589824), stream=stream0)
buf667 = reinterpret_tensor(buf653, (8192, 768), (768, 1), 0); del buf653 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf652, (8192, 768), (768, 1), 0), permute_501, out=buf667)
del permute_501
buf668 = buf662; del buf662 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf652, (768, 8192), (1, 768), 0), view_22, out=buf668)
del view_22
buf669 = buf663; del buf663 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf652, buf669, 49152, 128, grid=grid(49152), stream=stream0)
buf672 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf669, buf672, 768, 64, grid=grid(768), stream=stream0)
buf671 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf668, buf671, 589824, grid=grid(589824), stream=stream0)
buf676 = buf617; del buf617 # reuse
buf681 = reinterpret_tensor(buf652, (16, 512, 768), (393216, 768, 1), 0); del buf652 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_25.run(buf638, buf655, buf661, buf667, primals_20, mul_16, div_60, gt_3, buf676, buf681, 8192, 768, grid=grid(8192), stream=stream0)
del div_60
del gt_3
del primals_20
buf677 = reinterpret_tensor(buf669, (768, 64), (1, 768), 0); del buf669 # reuse
buf679 = buf639; del buf639 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_26.run(buf638, buf655, buf661, buf667, mul_16, buf677, buf679, 49152, 128, grid=grid(49152), stream=stream0)
del buf655
del buf661
del mul_16
buf678 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf677, buf678, 768, 64, grid=grid(768), stream=stream0)
buf680 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf679, buf680, 768, 64, grid=grid(768), stream=stream0)
buf682 = reinterpret_tensor(buf629, (8192, 3072), (3072, 1), 0); del buf629 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf681, (8192, 768), (768, 1), 0), permute_505, out=buf682)
del permute_505
buf683 = reinterpret_tensor(buf631, (768, 3072), (3072, 1), 0); del buf631 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf681, (768, 8192), (1, 768), 0), view_20, out=buf683)
del view_20
buf684 = reinterpret_tensor(buf679, (1, 768, 64), (49152, 1, 768), 0); del buf679 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf681, buf684, 49152, 128, grid=grid(49152), stream=stream0)
buf687 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf684, buf687, 768, 64, grid=grid(768), stream=stream0)
buf686 = empty_strided_cuda((768, 3072), (3072, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_16.run(buf683, buf686, 2359296, grid=grid(2359296), stream=stream0)
buf688 = reinterpret_tensor(buf682, (16, 512, 3072), (1572864, 3072, 1), 0); del buf682 # reuse
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu, aten.gelu_backward]
triton_poi_fused_gelu_gelu_backward_17.run(buf688, addmm_4, 25165824, grid=grid(25165824), stream=stream0)
del addmm_4
buf689 = reinterpret_tensor(buf681, (8192, 768), (768, 1), 0); del buf681 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf688, (8192, 3072), (3072, 1), 0), permute_509, out=buf689)
del permute_509
buf690 = reinterpret_tensor(buf683, (3072, 768), (768, 1), 0); del buf683 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf688, (3072, 8192), (1, 3072), 0), view_18, out=buf690)
del view_18
buf691 = buf632; del buf632 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_18.run(buf688, buf691, 98304, 256, grid=grid(98304), stream=stream0)
del buf688
buf694 = empty_strided_cuda((3072, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_19.run(buf691, buf694, 3072, 32, grid=grid(3072), stream=stream0)
del buf691
buf693 = empty_strided_cuda((3072, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_20.run(buf690, buf693, 2359296, grid=grid(2359296), stream=stream0)
del buf690
buf697 = buf638; del buf638 # reuse
buf702 = reinterpret_tensor(buf667, (16, 512, 768), (393216, 768, 1), 0); del buf667 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_dropout_backward, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_native_dropout_backward_native_layer_norm_backward_21.run(buf676, buf689, primals_14, mul_9, div_61, gt_2, buf697, buf702, 8192, 768, grid=grid(8192), stream=stream0)
del div_61
del gt_2
del primals_14
buf698 = reinterpret_tensor(buf684, (768, 64), (1, 768), 0); del buf684 # reuse
buf700 = buf677; del buf677 # reuse
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_red_fused__to_copy_add_native_layer_norm_backward_22.run(buf676, buf689, mul_9, buf698, buf700, 49152, 128, grid=grid(49152), stream=stream0)
del mul_9
buf699 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf698, buf699, 768, 64, grid=grid(768), stream=stream0)
buf701 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf700, buf701, 768, 64, grid=grid(768), stream=stream0)
buf703 = buf689; del buf689 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf702, (8192, 768), (768, 1), 0), permute_513, out=buf703)
del permute_513
buf704 = buf668; del buf668 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf702, (768, 8192), (1, 768), 0), view_16, out=buf704)
del view_16
buf705 = reinterpret_tensor(buf700, (1, 768, 64), (49152, 1, 768), 0); del buf700 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_15.run(buf702, buf705, 49152, 128, grid=grid(49152), stream=stream0)
buf708 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf705, buf708, 768, 64, grid=grid(768), stream=stream0)
buf707 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf704, buf707, 589824, grid=grid(589824), stream=stream0)
buf709 = reinterpret_tensor(buf702, (16, 12, 512, 64), (393216, 32768, 64, 1), 0); del buf702 # reuse
# Source Nodes: [], Original ATen: [aten.clone]
triton_poi_fused_clone_23.run(buf703, buf709, 6291456, grid=grid(6291456), stream=stream0)
del buf703
# Source Nodes: [], Original ATen: [aten._scaled_dot_product_flash_attention_backward]
buf710 = aten._scaled_dot_product_flash_attention_backward.default(buf709, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, None, None, 512, 512, 0.1, False, getitem_131, getitem_132, scale=0.125)
del getitem_129
del getitem_130
del getitem_131
del getitem_132
del permute_default_66
del permute_default_67
del permute_default_68
buf711 = buf710[0]
buf712 = buf710[1]
buf713 = buf710[2]
del buf710
buf714 = reinterpret_tensor(buf709, (8192, 768), (768, 1), 0); del buf709 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf713, (8192, 768), (768, 1), 0), permute_525, out=buf714)
del permute_525
buf715 = buf704; del buf704 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf713, (768, 8192), (1, 768), 0), view, out=buf715)
buf716 = buf705; del buf705 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf713, buf716, 49152, 128, grid=grid(49152), stream=stream0)
buf719 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf716, buf719, 768, 64, grid=grid(768), stream=stream0)
buf718 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf715, buf718, 589824, grid=grid(589824), stream=stream0)
buf720 = reinterpret_tensor(buf713, (8192, 768), (768, 1), 0); del buf713 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf712, (8192, 768), (768, 1), 0), permute_530, out=buf720)
del permute_530
buf721 = buf715; del buf715 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf712, (768, 8192), (1, 768), 0), view, out=buf721)
buf722 = buf716; del buf716 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf712, buf722, 49152, 128, grid=grid(49152), stream=stream0)
buf725 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf722, buf725, 768, 64, grid=grid(768), stream=stream0)
buf724 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf721, buf724, 589824, grid=grid(589824), stream=stream0)
buf726 = reinterpret_tensor(buf712, (8192, 768), (768, 1), 0); del buf712 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf711, (8192, 768), (768, 1), 0), permute_534, out=buf726)
del permute_534
buf727 = buf721; del buf721 # reuse
# Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf711, (768, 8192), (1, 768), 0), view, out=buf727)
del view
buf728 = buf722; del buf722 # reuse
# Source Nodes: [], Original ATen: [aten.sum]
triton_red_fused_sum_24.run(buf711, buf728, 49152, 128, grid=grid(49152), stream=stream0)
del buf711
buf731 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy, aten.sum]
triton_per_fused__to_copy_sum_11.run(buf728, buf731, 768, 64, grid=grid(768), stream=stream0)
buf730 = empty_strided_cuda((768, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(buf727, buf730, 589824, grid=grid(589824), stream=stream0)
del buf727
buf743 = empty_strided_cuda((2, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_27.run(buf743, 1536, grid=grid(1536), stream=stream0)
buf745 = empty_strided_cuda((30522, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_28.run(buf745, 23440896, grid=grid(23440896), stream=stream0)
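# buf743 (2 x 768, token-type table) and buf745 (30522 x 768, word-embedding
# table) are zero-filled here (1536 = 2*768, 23440896 = 30522*768); the fused
# kernel below scatter-adds the per-token gradients into them.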
buf732 = buf697; del buf697 # reuse
buf735 = buf676; del buf676 # reuse
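# A single fused Triton kernel handles the tail of the backward in one pass
# over the (8192, 768) activations: it sums the three projection input
# gradients (buf714/buf720/buf726), applies dropout and layer-norm backward,
# and scatters the result into the embedding gradient tables using the saved
# index tensors (primals_204, primals_206).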
# Source Nodes: [masked_lm_loss], Original ATen: [aten._to_copy, aten.add, aten.embedding_dense_backward, aten.native_dropout_backward, aten.native_layer_norm_backward, aten.nll_loss_forward]
triton_per_fused__to_copy_add_embedding_dense_backward_native_dropout_backward_native_layer_norm_backward_nll_loss_forward_29.run(buf732, buf714, buf720, buf726, gt, primals_4, mul_1, div_63, primals_204, primals_206, buf735, buf743, buf745, 8192, 768, grid=grid(8192), stream=stream0)
del buf714
del buf720
del buf726
del div_63
del gt
del primals_204
del primals_206
del primals_4
buf736 = reinterpret_tensor(buf728, (768, 64), (1, 768), 0); del buf728 # reuse
buf738 = buf698; del buf698 # reuse
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_red_fused_native_layer_norm_backward_30.run(buf732, mul_1, buf736, buf738, 49152, 128, grid=grid(49152), stream=stream0)
del buf732
del mul_1
buf737 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf736, buf737, 768, 64, grid=grid(768), stream=stream0)
del buf736
buf739 = empty_strided_cuda((768, ), (1, ), torch.float32)
# Source Nodes: [], Original ATen: [aten.native_layer_norm_backward]
triton_per_fused__to_copy_gelu_native_layer_norm_native_layer_norm_backward_9.run(buf738, buf739, 768, 64, grid=grid(768), stream=stream0)
del buf738
buf741 = empty_strided_cuda((512, 768), (768, 1), torch.float32)
# Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
triton_poi_fused_embedding_dense_backward_31.run(buf741, 393216, grid=grid(393216), stream=stream0)
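# buf741 (512 x 768) is the position-embedding gradient: zero-filled above,
# then filled below by summing buf735 over the 16 batch elements, indexed by
# primals_205 (presumably the saved position ids).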
# Source Nodes: [masked_lm_loss], Original ATen: [aten.embedding_dense_backward, aten.nll_loss_forward, aten.sum]
triton_per_fused_embedding_dense_backward_nll_loss_forward_sum_32.run(buf735, primals_205, buf741, 393216, 16, grid=grid(393216), stream=stream0)
del buf735
del primals_205
return (buf745, buf743, buf741, buf737, buf739, buf730, buf731, buf724, buf725, buf718, buf719, buf707, buf708, buf699, buf701, buf693, buf694, buf686, buf687, buf678, buf680, buf671, buf672, buf665, buf666, buf659, buf660, buf648, buf649, buf640, buf642, buf634, buf635, buf627, buf628, buf619, buf621, buf612, buf613, buf606, buf607, buf600, buf601, buf589, buf590, buf581, buf583, buf575, buf576, buf568, buf569, buf560, buf562, buf553, buf554, buf547, buf548, buf541, buf542, buf530, buf531, buf522, buf524, buf516, buf517, buf509, buf510, buf501, buf503, buf494, buf495, buf488, buf489, buf482, buf483, buf471, buf472, buf463, buf465, buf457, buf458, buf450, buf451, buf442, buf444, buf435, buf436, buf429, buf430, buf423, buf424, buf412, buf413, buf404, buf406, buf398, buf399, buf391, buf392, buf383, buf385, buf376, buf377, buf370, buf371, buf364, buf365, buf353, buf354, buf345, buf347, buf339, buf340, buf332, buf333, buf324, buf326, buf317, buf318, buf311, buf312, buf305, buf306, buf294, buf295, buf286, buf288, buf280, buf281, buf273, buf274, buf265, buf267, buf258, buf259, buf252, buf253, buf246, buf247, buf235, buf236, buf227, buf229, buf221, buf222, buf214, buf215, buf206, buf208, buf199, buf200, buf193, buf194, buf187, buf188, buf176, buf177, buf168, buf170, buf162, buf163, buf155, buf156, buf147, buf149, buf140, buf141, buf134, buf135, buf128, buf129, buf117, buf118, buf109, buf111, buf103, buf104, buf96, buf97, buf88, buf90, buf81, buf82, buf75, buf76, buf69, buf70, buf58, buf59, buf50, buf52, buf44, buf45, buf37, buf38, buf29, buf31, buf23, buf24, buf14, buf16, buf9, buf10, None, None, None, None, )
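
# The tuple above returns the input gradients in the order AOTAutograd
# expects: embedding tables first (word, token-type, position), then the
# layer-norm and projection weight/bias gradients for each encoder layer,
# with trailing Nones for the non-differentiable int64 index inputs.

# benchmark_compiled_module below rebuilds random inputs with the exact
# shapes, strides and dtypes of the saved activations and times `call` via
# print_performance, so the file can be benchmarked standalone.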
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
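# rand_strided(shape, stride, device=..., dtype=...) allocates a randomly
# filled tensor with exactly the requested layout, so the benchmark sees the
# same memory-access patterns as the real saved activations.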
primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_14 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_20 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_30 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_36 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_46 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_52 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_62 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_68 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_78 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_84 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_94 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_100 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_110 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_116 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_126 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_132 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_142 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_148 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_158 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_164 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_174 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_180 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_190 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_196 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_200 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_204 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_205 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_206 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
primals_207 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)
mul_1 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
gt = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
view = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_67 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_68 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_129 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_130 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_131 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_132 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_16 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_2 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_18 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_4 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_20 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_3 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_16 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_22 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_60 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_61 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_62 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_122 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_123 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_124 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_125 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_38 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_22 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_40 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_10 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_42 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_6 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_44 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_54 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_55 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_56 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_115 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_116 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_117 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_118 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_60 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_8 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_62 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_16 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_42 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_66 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_48 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_49 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_50 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_108 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_109 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_110 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_111 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_82 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_11 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_48 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_84 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_22 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_86 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_12 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_55 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_88 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_42 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_43 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_44 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_101 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_102 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_103 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_104 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_104 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_14 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_61 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_106 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_28 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_108 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_15 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_68 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_110 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_36 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_37 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_38 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_94 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_95 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_96 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_97 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_126 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_17 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_74 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_128 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_34 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_130 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_18 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_81 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_132 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_30 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_31 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_32 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_87 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_88 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_89 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_90 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_148 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_20 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_87 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_150 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_40 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_152 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_21 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_94 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_154 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_24 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_25 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_26 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_80 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_81 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_82 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_83 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_170 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_23 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_100 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_172 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_46 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_174 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_24 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_107 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_176 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_18 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_19 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_20 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_73 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_74 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_75 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_76 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_192 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_26 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_113 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_194 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_52 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_196 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_27 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_120 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_198 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_12 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_13 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_14 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_66 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_67 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_68 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_69 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_214 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_29 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_126 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_216 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_58 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_218 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_30 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_133 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_220 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default_6 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_7 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_8 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_59 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_60 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_61 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_62 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_236 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_32 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_139 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_238 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_64 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_240 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_33 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_146 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_242 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_default = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_1 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
permute_default_2 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_52 = rand_strided((16, 12, 512, 64), (393216, 64, 768, 1), device='cuda:0', dtype=torch.float16)
getitem_53 = rand_strided((16, 12, 512), (6144, 512, 1), device='cuda:0', dtype=torch.float32)
getitem_54 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
getitem_55 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
view_258 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
gt_35 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_152 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_260 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_70 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
view_262 = rand_strided((8192, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
gt_36 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)
mul_159 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)
view_264 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
addmm_72 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
getitem_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_25 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
view_266 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)
view_267 = rand_strided((16, 512, 30522), (15630336, 30528, 1), device='cuda:0', dtype=torch.float16)
amax_12 = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
log = rand_strided((8192, 1), (1, 1), device='cuda:0', dtype=torch.float32)
convert_element_type_510 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
permute_134 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_138 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_27 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_142 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_146 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_28 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_150 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_162 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_167 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_171 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_30 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_175 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_179 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_31 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_183 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_195 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_200 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_204 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_33 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_208 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_212 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_34 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_216 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_228 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_233 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_237 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_36 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_241 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_245 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_37 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_249 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_261 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_266 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_270 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_39 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_274 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_278 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_40 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_282 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_294 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_299 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_303 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_42 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_307 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_311 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_43 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_315 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_327 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_332 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_336 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_45 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_340 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_344 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_46 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_348 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_360 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_365 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_369 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_48 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_373 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_377 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_49 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_381 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_393 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_398 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_402 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_51 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_406 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_410 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_52 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_414 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_426 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_431 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_435 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_54 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_439 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_443 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_55 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_447 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_459 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_464 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_468 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_57 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_472 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_476 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_58 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_480 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_492 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_497 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_501 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_60 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_505 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
permute_509 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_61 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
permute_513 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_525 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_530 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
permute_534 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
div_63 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
tangents_2 = rand_strided((16, 512, 30522), (15627264, 30522, 1), device='cuda:0', dtype=torch.float16)
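# tangents_1 and tangents_2 are the incoming gradients: the scalar grad of
# masked_lm_loss (1.0 when .backward() is called on the loss) and the grad of
# the (16, 512, 30522) prediction logits.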
fn = lambda: call([primals_4, primals_14, primals_20, primals_30, primals_36, primals_46, primals_52, primals_62, primals_68, primals_78, primals_84, primals_94, primals_100, primals_110, primals_116, primals_126, primals_132, primals_142, primals_148, primals_158, primals_164, primals_174, primals_180, primals_190, primals_196, primals_200, primals_204, primals_205, primals_206, primals_207, mul_1, gt, view, permute_default_66, permute_default_67, permute_default_68, getitem_129, getitem_130, getitem_131, getitem_132, view_16, gt_2, mul_9, view_18, addmm_4, view_20, gt_3, mul_16, view_22, permute_default_60, permute_default_61, permute_default_62, getitem_122, getitem_123, getitem_124, getitem_125, view_38, gt_5, mul_22, view_40, addmm_10, view_42, gt_6, mul_29, view_44, permute_default_54, permute_default_55, permute_default_56, getitem_115, getitem_116, getitem_117, getitem_118, view_60, gt_8, mul_35, view_62, addmm_16, view_64, gt_9, mul_42, view_66, permute_default_48, permute_default_49, permute_default_50, getitem_108, getitem_109, getitem_110, getitem_111, view_82, gt_11, mul_48, view_84, addmm_22, view_86, gt_12, mul_55, view_88, permute_default_42, permute_default_43, permute_default_44, getitem_101, getitem_102, getitem_103, getitem_104, view_104, gt_14, mul_61, view_106, addmm_28, view_108, gt_15, mul_68, view_110, permute_default_36, permute_default_37, permute_default_38, getitem_94, getitem_95, getitem_96, getitem_97, view_126, gt_17, mul_74, view_128, addmm_34, view_130, gt_18, mul_81, view_132, permute_default_30, permute_default_31, permute_default_32, getitem_87, getitem_88, getitem_89, getitem_90, view_148, gt_20, mul_87, view_150, addmm_40, view_152, gt_21, mul_94, view_154, permute_default_24, permute_default_25, permute_default_26, getitem_80, getitem_81, getitem_82, getitem_83, view_170, gt_23, mul_100, view_172, addmm_46, view_174, gt_24, mul_107, view_176, permute_default_18, permute_default_19, permute_default_20, getitem_73, getitem_74, getitem_75, getitem_76, view_192, gt_26, mul_113, view_194, addmm_52, view_196, gt_27, mul_120, view_198, permute_default_12, permute_default_13, permute_default_14, getitem_66, getitem_67, getitem_68, getitem_69, view_214, gt_29, mul_126, view_216, addmm_58, view_218, gt_30, mul_133, view_220, permute_default_6, permute_default_7, permute_default_8, getitem_59, getitem_60, getitem_61, getitem_62, view_236, gt_32, mul_139, view_238, addmm_64, view_240, gt_33, mul_146, view_242, permute_default, permute_default_1, permute_default_2, getitem_52, getitem_53, getitem_54, getitem_55, view_258, gt_35, mul_152, view_260, addmm_70, view_262, gt_36, mul_159, view_264, addmm_72, getitem_51, rsqrt_25, view_266, view_267, amax_12, log, convert_element_type_510, permute_134, permute_138, div_27, permute_142, permute_146, div_28, permute_150, permute_162, permute_167, permute_171, div_30, permute_175, permute_179, div_31, permute_183, permute_195, permute_200, permute_204, div_33, permute_208, permute_212, div_34, permute_216, permute_228, permute_233, permute_237, div_36, permute_241, permute_245, div_37, permute_249, permute_261, permute_266, permute_270, div_39, permute_274, permute_278, div_40, permute_282, permute_294, permute_299, permute_303, div_42, permute_307, permute_311, div_43, permute_315, permute_327, permute_332, permute_336, div_45, permute_340, permute_344, div_46, permute_348, permute_360, permute_365, permute_369, div_48, permute_373, permute_377, div_49, permute_381, permute_393, permute_398, permute_402, div_51, 
permute_406, permute_410, div_52, permute_414, permute_426, permute_431, permute_435, div_54, permute_439, permute_443, div_55, permute_447, permute_459, permute_464, permute_468, div_57, permute_472, permute_476, div_58, permute_480, permute_492, permute_497, permute_501, div_60, permute_505, permute_509, div_61, permute_513, permute_525, permute_530, permute_534, div_63, tangents_1, tangents_2])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('BertForMaskedLM', benchmark_compiled_module)
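
# A minimal sketch (not part of the generated output) of how a dump like this
# is obtained, assuming a HuggingFace BertForMaskedLM checkpoint is available:
# run the script below with TORCH_LOGS="output_code" (or TORCH_COMPILE_DEBUG=1)
# and Inductor prints/saves the generated wrapper code for each graph,
# including this '0_backward' module.
#
#   import torch
#   from transformers import BertForMaskedLM
#
#   model = BertForMaskedLM.from_pretrained("bert-base-uncased").cuda()
#   compiled = torch.compile(model)
#   input_ids = torch.randint(0, 30522, (16, 512), device="cuda")
#   # autocast matches the fp16 saved activations seen above
#   with torch.autocast("cuda", dtype=torch.float16):
#       out = compiled(input_ids=input_ids, labels=input_ids)
#   out.loss.backward()  # compiling the backward emits this file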