Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shunting314/4ea8ac8fbd5c0de9addfc5368ae501d2 to your computer and use it in GitHub Desktop.
Save shunting314/4ea8ac8fbd5c0de9addfc5368ae501d2 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Module-level short aliases for operator namespaces and runtime helpers that
# generated wrapper code can reference by name (their consumers are not in
# the visible part of this file).
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Private helpers from torch._C._dynamo.guards. Judging by their names,
# assert_size_stride checks a tensor's sizes/strides and the empty_strided_* /
# reinterpret_tensor helpers allocate or re-view buffers -- TODO confirm
# against the torch version that generated this file.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Single shared AsyncCompile instance; the `<kernel> = async_compile.triton(...)`
# assignments that follow all go through it.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/6h/c6h2ojxcczebudxcelwtfphsrsjbqw2mvt7vssj6gbcs6rhuzl5j.py
# Source Nodes: [x], Original ATen: [aten._to_copy]
# x => convert_element_type_1
#
# Pointwise kernel over 589824 elements. It loads in_ptr0 (declared '*fp32'
# in triton_meta), applies .to(tl.float32), and stores into out_ptr0
# (declared '*fp16'). get_args() sizes both buffers as (768, 3, 16, 16).
# Loads and stores use mask=None and xmask is all-True, so the launch grid must
# never produce an index >= 589824. Presumably this holds because 589824 is a
# multiple of every XBLOCK the heuristics choose -- TODO confirm.
# NOTE: the string passed to async_compile.triton is Python source (the Triton
# kernel plus a standalone benchmark harness). Its indentation is significant:
# with bodies flattened to column 0 it fails to compile with IndentationError.
triton_poi_fused__to_copy_0 = async_compile.triton('triton_poi_fused__to_copy_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 589824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 3, 16, 16), (768, 256, 16, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768, 3, 16, 16), (768, 256, 16, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_0.run(*args, 589824, grid=grid(589824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_0.benchmark_all_configs(*args, 589824, grid=grid(589824))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.003538944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/p3/cp36una3lrbo6zpaqu4xa7zubedzetoocavmc6nppfg3ysxcg3cl.py
# Source Nodes: [x], Original ATen: [aten._to_copy]
# x => convert_element_type_2
#
# Pointwise kernel over 1204224 elements. It loads in_ptr0 (declared '*fp32'
# in triton_meta), applies .to(tl.float32), and stores into out_ptr0
# (declared '*fp16'). get_args() sizes both buffers as (8, 3, 224, 224).
# Loads and stores use mask=None (xmask is all-True).
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[2097152],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.007225344},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1204224
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((8, 3, 224, 224), (150528, 50176, 224, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((8, 3, 224, 224), (150528, 50176, 224, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_1.run(*args, 1204224, grid=grid(1204224), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_1.benchmark_all_configs(*args, 1204224, grid=grid(1204224))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.007225344
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/ag/cagd2lhqxez5xsijsysmarruzc3zmx2hr5vfoz6ia7xgntfwzjsp.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convolution
#
# Pointwise kernel over 768 elements. It loads in_ptr0 (declared '*fp32'),
# applies .to(tl.float32), and stores into out_ptr0 (declared '*fp16').
# get_args() sizes both as (768,). Unlike the unmasked copy kernels, loads and
# stores here are guarded by xmask = xindex < xnumel.
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_convolution_2 = async_compile.triton('triton_poi_fused__to_copy_convolution_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[1024],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 4.608e-06},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_convolution_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 768
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)


def get_args():
    arg_0 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_convolution_2.run(*args, 768, grid=grid(768), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_convolution_2.benchmark_all_configs(*args, 768, grid=grid(768))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 4.608e-06
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/az/cazvyaejmrd6osrlhz4gvssnihci4mb7rifvtjgf2fct4zwitpy3.py
# Source Nodes: [qkv, x_3, x_5], Original ATen: [aten._to_copy, aten.cat, aten.native_layer_norm]
# qkv => convert_element_type_5
# x_3 => cat
# x_5 => add, add_1, mul, mul_1, rsqrt, sub, var_mean
#
# Looped reduction over xnumel=1576 rows of rnumel=768 elements
# (x0 = xindex % 197, x1 = xindex // 197, x3 = row id).
#   Loop 1: build each row of the concatenation (tmp15).
#     - For x0 < 1 the value is in_ptr0[r2].
#     - Otherwise it is in_ptr1[196*r2 + 150528*x1 + (x0-1) % 196] + in_ptr2[r2].
#     Each row is stored to out_ptr0 (row stride 768) while Welford
#     mean/m2/weight are accumulated.
#   Between loops: the mean goes to out_ptr1, and rsqrt(m2/768 + 1e-06) is
#     written in place to in_out_ptr0.
#   Loop 2: re-read out_ptr0, compute (v - mean) * rstd * in_ptr3[r2] + in_ptr4[r2],
#     and store to out_ptr2 (declared '*fp16').
# NOTE: the embedded string is Python source whose indentation is
# significant. In particular, the two `for roffset` loop bodies must stay
# nested. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_red_fused__to_copy_cat_native_layer_norm_3 = async_compile.triton('triton_red_fused__to_copy_cat_native_layer_norm_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.reduction(
    size_hints=[2048, 1024],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp16', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp16', 9: 'i32', 10: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 10), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_cat_native_layer_norm_3', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 2, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.00970032}
)
@triton.jit
def triton_red_fused__to_copy_cat_native_layer_norm_3(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1576
    rnumel = 768
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 197
    x1 = (xindex // 197)
    x3 = xindex
    tmp17_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp17_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp17_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = x0
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 1, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r2, [XBLOCK, RBLOCK])), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0)
        tmp6 = tmp0 >= tmp3
        tmp7 = tl.full([1, 1], 197, tl.int64)
        tmp8 = tmp0 < tmp7
        tmp9 = tl.load(in_ptr1 + ((196*r2) + (150528*x1) + (((-1) + x0) % 196)), rmask & tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr2 + (tl.broadcast_to(r2, [XBLOCK, RBLOCK])), rmask & tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp11 = tmp9 + tmp10
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
        tmp14 = tl.where(tmp6, tmp12, tmp13)
        tmp15 = tl.where(tmp4, tmp5, tmp14)
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
        tmp17_mean_next, tmp17_m2_next, tmp17_weight_next = triton_helpers.welford_reduce(
            tmp16, tmp17_mean, tmp17_m2, tmp17_weight, roffset == 0
        )
        tmp17_mean = tl.where(rmask & xmask, tmp17_mean_next, tmp17_mean)
        tmp17_m2 = tl.where(rmask & xmask, tmp17_m2_next, tmp17_m2)
        tmp17_weight = tl.where(rmask & xmask, tmp17_weight_next, tmp17_weight)
        tl.store(out_ptr0 + (r2 + (768*x3)), tmp15, rmask & xmask)
    tmp17_tmp, tmp18_tmp, tmp19_tmp = triton_helpers.welford(
        tmp17_mean, tmp17_m2, tmp17_weight, 1
    )
    tmp17 = tmp17_tmp[:, None]
    tmp18 = tmp18_tmp[:, None]
    tmp19 = tmp19_tmp[:, None]
    tl.store(out_ptr1 + (x3), tmp17, xmask)
    tmp20 = 768.0
    tmp21 = tmp18 / tmp20
    tmp22 = 1e-06
    tmp23 = tmp21 + tmp22
    tmp24 = libdevice.rsqrt(tmp23)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x3), tmp24, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp25 = tl.load(out_ptr0 + (r2 + (768*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp28 = tl.load(in_ptr3 + (r2), rmask, eviction_policy='evict_last', other=0.0)
        tmp30 = tl.load(in_ptr4 + (r2), rmask, eviction_policy='evict_last', other=0.0)
        tmp26 = tmp25 - tmp17
        tmp27 = tmp26 * tmp24
        tmp29 = tmp27 * tmp28
        tmp31 = tmp29 + tmp30
        tmp32 = tmp31.to(tl.float32)
        tl.store(out_ptr2 + (r2 + (768*x3)), tmp32, rmask & xmask)


def get_args():
    arg_0 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((1, 1, 768), (768, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((8, 768, 14, 14), (150528, 196, 14, 1), device='cuda:0', dtype=torch.float16)
    arg_3 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float16)
    arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
    arg_7 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
    arg_8 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_cat_native_layer_norm_3.run(*args, 1576, 768, grid=grid(1576), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_cat_native_layer_norm_3.benchmark_all_configs(*args, 1576, 768, grid=grid(1576))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.00970032
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/3p/c3po5sv55ujxuuqlw53cq4m4aku6wua2pjih2qi3cdv4jto2d3et.py
# Source Nodes: [qkv], Original ATen: [aten._to_copy]
# qkv => convert_element_type_4
#
# Pointwise kernel over 1769472 elements. It loads in_ptr0 (declared '*fp32'),
# applies .to(tl.float32), and stores into out_ptr0 (declared '*fp16').
# get_args() sizes both as (2304, 768). Loads and stores use mask=None
# (xmask is all-True).
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_4 = async_compile.triton('triton_poi_fused__to_copy_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[2097152],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.010616832},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1769472
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_4.run(*args, 1769472, grid=grid(1769472), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_4.benchmark_all_configs(*args, 1769472, grid=grid(1769472))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.010616832
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/vz/cvzui4yman7mph2bugexbe2hxymkjridm7ww5hlcxylhkqj63gpd.py
# Source Nodes: [qkv, qkv_bias], Original ATen: [aten._to_copy, aten.cat]
# qkv => convert_element_type_3
# qkv_bias => cat_1
#
# Pointwise kernel over 2304 elements that concatenates three 768-element
# inputs into out_ptr0 (declared '*fp16'):
#   - indices [0, 768) come from in_ptr0,
#   - [768, 1536) from in_ptr1,
#   - [1536, 2304) from in_ptr2.
# Each load is masked to its own range (other=0.0), and tl.where selects the
# matching value. Loads and stores are additionally guarded by
# xmask = xindex < xnumel.
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_cat_5 = async_compile.triton('triton_poi_fused__to_copy_cat_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[4096],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_cat_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.3824e-05},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_cat_5(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 768, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1], 1536, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tmp6 & tmp8
    tmp10 = tl.load(in_ptr1 + ((-768) + x0), tmp9 & xmask, eviction_policy='evict_last', other=0.0)
    tmp11 = tmp0 >= tmp7
    tmp12 = tl.full([1], 2304, tl.int64)
    tmp13 = tmp0 < tmp12
    tmp14 = tl.load(in_ptr2 + ((-1536) + x0), tmp11 & xmask, eviction_policy='evict_last', other=0.0)
    tmp15 = tl.where(tmp9, tmp10, tmp14)
    tmp16 = tl.where(tmp4, tmp5, tmp15)
    tmp17 = tmp16.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp17, xmask)


def get_args():
    arg_0 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((2304,), (1,), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, arg_3,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_cat_5.run(*args, 2304, grid=grid(2304), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_cat_5.benchmark_all_configs(*args, 2304, grid=grid(2304))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.3824e-05
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/jm/cjmgmq5xjedr5q57j3plezr536tvgwrie5fy5563n4zawjob2fzu.py
# Source Nodes: [x_6], Original ATen: [aten._to_copy, aten.constant_pad_nd]
# x_6 => constant_pad_nd, convert_element_type_9
#
# Pointwise gather-and-pad kernel over 472800 = 12 * 197 * 200 elements, with
# x0 = column (0..199), x1 = row (0..196), x2 = xindex // 39400 (0..11).
#   - For x0 < 197 it reads an int64 index from in_ptr0[x0 + 197*x1].
#   - Negative indices are wrapped by adding 732.
#   - device_assert checks 0 <= idx < 732.
#   - It then gathers in_ptr1[x2 + 12*idx].
# Columns 197..199 are filled with 0.0 (the padding). Results are stored at
# out_ptr0 + x3 + 39424*x2, matching the (1, 12, 197, 200) fp16 output
# strides in get_args().
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_constant_pad_nd_6 = async_compile.triton('triton_poi_fused__to_copy_constant_pad_nd_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[524288],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_constant_pad_nd_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.001291208},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_constant_pad_nd_6(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 472800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 200
    x1 = (xindex // 200) % 197
    x2 = (xindex // 39400)
    x3 = xindex % 39400
    tmp0 = x0
    tmp1 = tl.full([1], 197, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0 + (197*x1)), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
    tmp4 = tl.full([XBLOCK], 732, tl.int32)
    tmp5 = tmp3 + tmp4
    tmp6 = tmp3 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp3)
    tl.device_assert(((0 <= tl.broadcast_to(tmp7, [XBLOCK])) & (tl.broadcast_to(tmp7, [XBLOCK]) < 732)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp7, [XBLOCK]) < 732")
    tmp9 = tl.load(in_ptr1 + (x2 + (12*tmp7)), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tl.full(tmp10.shape, 0.0, tmp10.dtype)
    tmp12 = tl.where(tmp2, tmp10, tmp11)
    tl.store(out_ptr0 + (x3 + (39424*x2)), tmp12, xmask)


def get_args():
    arg_0 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((1, 12, 197, 200), (473088, 39424, 200, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_constant_pad_nd_6.run(*args, 472800, grid=grid(472800), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_constant_pad_nd_6.benchmark_all_configs(*args, 472800, grid=grid(472800))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.001291208
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/mm/cmmjnvmzgwnwovglcakbfgji3peq4t76kumq5egzestihfrbnnlu.py
# Source Nodes: [x_8], Original ATen: [aten._to_copy]
# x_8 => convert_element_type_11
#
# Pointwise kernel over 589824 elements. It loads in_ptr0 (declared '*fp32'),
# applies .to(tl.float32), and stores into out_ptr0 (declared '*fp16').
# get_args() sizes both as (768, 768). Loads and stores use mask=None
# (xmask is all-True).
# NOTE: the embedded string is Python source whose indentation is
# significant. With bodies flattened to column 0 it fails to compile with
# IndentationError.
triton_poi_fused__to_copy_7 = async_compile.triton('triton_poi_fused__to_copy_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.003538944},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 589824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)


def get_args():
    arg_0 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float16)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_7.run(*args, 589824, grid=grid(589824), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused__to_copy_7.benchmark_all_configs(*args, 589824, grid=grid(589824))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.003538944
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/ea/ceaaligcfdxmpe6npb227nmhntpxgwdrtdgp2yik66ba2xyp3gqz.py
# Source Nodes: [mul, x_10, x_11, x_12], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm]
# mul => mul_2
# x_10 => add_2
# x_11 => add_3, add_4, mul_3, mul_4, rsqrt_1, sub_1, var_mean_1
# x_12 => convert_element_type_17
#
# Registers a persistent-reduction Triton kernel with `async_compile`. The
# second argument is the kernel's complete source text and must stay byte-exact.
# Per row x (1576 rows, one program per row: XBLOCK = 1, grid(1576)), over the
# 768 columns r (RBLOCK = 1024, masked by rmask = r < 768):
#   v                     = in_ptr0[768*x + r] + in_ptr1[r] * float(in_ptr2[768*x + r])
#   mean                  = sum(v) / 768
#   rstd                  = rsqrt(sum((v - mean)**2) / 768 + 1e-06)
#   in_out_ptr0[x]        <- rstd
#   out_ptr0[x]           <- mean
#   out_ptr1[768*x + r]   <- (v - mean) * rstd * in_ptr3[r] + in_ptr4[r]
# triton_meta declares in_ptr2 and out_ptr1 as '*fp16' and every other pointer as '*fp32'.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_per_fused__to_copy_add_mul_native_layer_norm_8 = async_compile.triton('triton_per_fused__to_copy_add_mul_native_layer_norm_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[2048, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp16', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_native_layer_norm_8', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.009711072}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_native_layer_norm_8(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, rnumel):
xnumel = 1576
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp31 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tmp0 + tmp4
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = tl.broadcast_to(tmp6, [RBLOCK])
tmp11 = tl.where(rmask, tmp9, 0)
tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp13 = tl.full([1], 768, tl.int32)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 / tmp14
tmp16 = tmp6 - tmp15
tmp17 = tmp16 * tmp16
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp20 = tl.where(rmask, tmp18, 0)
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp22 = 768.0
tmp23 = tmp21 / tmp22
tmp24 = 1e-06
tmp25 = tmp23 + tmp24
tmp26 = libdevice.rsqrt(tmp25)
tmp27 = tmp5 - tmp15
tmp28 = tmp27 * tmp26
tmp30 = tmp28 * tmp29
tmp32 = tmp30 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.debug_barrier()
tl.store(in_out_ptr0 + (x0), tmp26, None)
tl.store(out_ptr1 + (r1 + (768*x0)), tmp33, rmask)
tl.store(out_ptr0 + (x0), tmp15, None)
def get_args():
arg_0 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_native_layer_norm_8.run(*args, 1576, 768, grid=grid(1576), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_mul_native_layer_norm_8.benchmark_all_configs(*args, 1576, 768, grid=grid(1576))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.009711072
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/4x/c4x23rhimeds4ff7dm4zylzuepeejun5l25jer4np6dgcwgua7al.py
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
# x_12 => convert_element_type_16
#
# Registers a pointwise Triton kernel with `async_compile`. The second argument
# is the kernel's complete source text and must stay byte-exact.
# It copies 2359296 (= 3072 * 768) elements: out_ptr0[i] <- in_ptr0[i].
# triton_meta declares in_ptr0 as '*fp32' and out_ptr0 as '*fp16', and get_args
# builds float32 / float16 (3072, 768) tensors to match. The in-kernel
# `.to(tl.float32)` does nothing to fp32 input, so the narrowing to fp16
# presumably happens implicitly in tl.store.
# The load and store pass no mask (None). Presumably xnumel is a multiple of
# every XBLOCK the autotuner tries; confirm this if the shape ever changes.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_poi_fused__to_copy_9 = async_compile.triton('triton_poi_fused__to_copy_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_9(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_9.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_9.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/ov/covixf22fdb6vexsalpvegfv23f3ykfz4otdxjxqmihdlny4bq7y.py
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
# x_12 => convert_element_type_15
#
# Registers a pointwise Triton kernel with `async_compile`. The second argument
# is the kernel's complete source text and must stay byte-exact.
# It copies 3072 elements: out_ptr0[i] <- in_ptr0[i].
# triton_meta declares in_ptr0 as '*fp32' and out_ptr0 as '*fp16', and get_args
# builds float32 / float16 (3072,) tensors to match. The in-kernel
# `.to(tl.float32)` does nothing to fp32 input, so the narrowing to fp16
# presumably happens implicitly in tl.store.
# Unlike the larger copy kernels, this one masks its load and store with
# xmask = xindex < xnumel, so lanes past element 3072 are inactive.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_poi_fused__to_copy_10 = async_compile.triton('triton_poi_fused__to_copy_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4096],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.8432e-05},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_10(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 3072
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
def get_args():
arg_0 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((3072,), (1,), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_10.run(*args, 3072, grid=grid(3072), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_10.benchmark_all_configs(*args, 3072, grid=grid(3072))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 1.8432e-05
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/nm/cnm2xzscica7x3ivq4mli6qlklaikjcpmfawojdp4i6xgchvu2i7.py
# Source Nodes: [x_13], Original ATen: [aten.gelu]
# x_13 => add_5, convert_element_type_21, convert_element_type_22, erf, mul_5, mul_6, mul_7
#
# Registers a pointwise Triton kernel with `async_compile`. The second argument
# is the kernel's complete source text and must stay byte-exact.
# It applies the erf form of GELU to 4841472 (= 1576 * 3072) elements, computing in fp32:
#   x           = float(in_ptr0[i])
#   out_ptr0[i] <- (x * 0.5) * (erf(x * 0.7071067811865476) + 1.0)
# The constant 0.7071067811865476 equals 1/sqrt(2). triton_meta declares both
# pointers as '*fp16', and the load and store pass no mask (None).
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_poi_fused_gelu_11 = async_compile.triton('triton_poi_fused_gelu_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.019365888},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_gelu_11(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4841472
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = 0.5
tmp3 = tmp1 * tmp2
tmp4 = 0.7071067811865476
tmp5 = tmp1 * tmp4
tmp6 = libdevice.erf(tmp5)
tmp7 = 1.0
tmp8 = tmp6 + tmp7
tmp9 = tmp3 * tmp8
tmp10 = tmp9.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp10, None)
def get_args():
arg_0 = rand_strided((1576, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
arg_1 = rand_strided((8, 197, 3072), (605184, 3072, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused_gelu_11.run(*args, 4841472, grid=grid(4841472), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused_gelu_11.benchmark_all_configs(*args, 4841472, grid=grid(4841472))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.019365888
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/7x/c7xwoxk7qmiadkvs57zm3btkcxhzxhe2gpzlr6vxui6fiuvu7huf.py
# Source Nodes: [x_16], Original ATen: [aten._to_copy]
# x_16 => convert_element_type_24
#
# Registers a pointwise Triton kernel with `async_compile`. The second argument
# is the kernel's complete source text and must stay byte-exact.
# It copies 2359296 (= 768 * 3072) elements: out_ptr0[i] <- in_ptr0[i].
# triton_meta declares in_ptr0 as '*fp32' and out_ptr0 as '*fp16', and get_args
# builds float32 / float16 (768, 3072) tensors to match. The in-kernel
# `.to(tl.float32)` does nothing to fp32 input, so the narrowing to fp16
# presumably happens implicitly in tl.store.
# The load and store pass no mask (None). Presumably xnumel is a multiple of
# every XBLOCK the autotuner tries; confirm this if the shape ever changes.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_poi_fused__to_copy_12 = async_compile.triton('triton_poi_fused__to_copy_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014155776},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_12(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2359296
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_12.run(*args, 2359296, grid=grid(2359296), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_12.benchmark_all_configs(*args, 2359296, grid=grid(2359296))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014155776
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/mp/cmpv2vlxq4sq5ppnyiaai2afqa3qqyefomy5hp7zkxsr2d2gwvqy.py
# Source Nodes: [mul, mul_1, qkv_2, x_10, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
# mul => mul_2
# mul_1 => mul_8
# qkv_2 => convert_element_type_30
# x_10 => add_2
# x_18 => add_6
# x_19 => add_7, add_8, mul_10, mul_9, rsqrt_2, sub_2, var_mean_2
#
# Registers a persistent-reduction Triton kernel with `async_compile`. The
# second argument is the kernel's complete source text and must stay byte-exact.
# Per row x (1576 rows, one program per row: XBLOCK = 1, grid(1576)), over the
# 768 columns r (RBLOCK = 1024, masked by rmask = r < 768), with i = 768*x + r:
#   v              = (in_ptr0[i] + in_ptr1[r] * float(in_ptr2[i]))
#                    + in_ptr3[r] * float(in_ptr4[i])
#   mean           = sum(v) / 768
#   rstd           = rsqrt(sum((v - mean)**2) / 768 + 1e-06)
#   out_ptr0[i]    <- v
#   out_ptr3[i]    <- (v - mean) * rstd
#   out_ptr4[i]    <- (v - mean) * rstd * in_ptr5[r] + in_ptr6[r]
#   out_ptr5[x]    <- rstd * 0.0013020833333333333   (the constant is 1/768)
# triton_meta declares in_ptr2, in_ptr4 and out_ptr4 as '*fp16' and every other
# pointer as '*fp32'. Going by the fused op names, out_ptr3 / out_ptr5 are
# presumably saved for native_layer_norm_backward; confirm against the caller.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13 = async_compile.triton('triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[2048, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp16', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp16', 10: '*fp32', 11: 'i32', 12: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 7, 'num_reduction': 4, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.021805216}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, out_ptr3, out_ptr4, out_ptr5, xnumel, rnumel):
xnumel = 1576
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp7 = tl.load(in_ptr4 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp34 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp36 = tl.load(in_ptr6 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 * tmp8
tmp10 = tmp5 + tmp9
tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
tmp13 = tl.where(rmask, tmp11, 0)
tmp14 = tl.broadcast_to(tmp11, [RBLOCK])
tmp16 = tl.where(rmask, tmp14, 0)
tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
tmp18 = tl.full([1], 768, tl.int32)
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp17 / tmp19
tmp21 = tmp11 - tmp20
tmp22 = tmp21 * tmp21
tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
tmp25 = tl.where(rmask, tmp23, 0)
tmp26 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
tmp27 = tmp10 - tmp20
tmp28 = 768.0
tmp29 = tmp26 / tmp28
tmp30 = 1e-06
tmp31 = tmp29 + tmp30
tmp32 = libdevice.rsqrt(tmp31)
tmp33 = tmp27 * tmp32
tmp35 = tmp33 * tmp34
tmp37 = tmp35 + tmp36
tmp38 = tmp37.to(tl.float32)
tmp39 = 0.0013020833333333333
tmp40 = tmp32 * tmp39
tl.store(out_ptr0 + (r1 + (768*x0)), tmp10, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp33, rmask)
tl.store(out_ptr4 + (r1 + (768*x0)), tmp38, rmask)
tl.store(out_ptr5 + (x0), tmp40, None)
def get_args():
arg_0 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_5 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_8 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_9 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float16)
arg_10 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(*args, 1576, 768, grid=grid(1576), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.benchmark_all_configs(*args, 1576, 768, grid=grid(1576))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.021805216
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/dy/cdysvbb2rbxcu7v6o5qs4gjpjbax6ne62kohihulkf4b7ondcqcb.py
# Source Nodes: [mul_2, x_24, x_25, x_26], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
# mul_2 => mul_11
# x_24 => add_9
# x_25 => add_10, add_11, mul_12, mul_13, rsqrt_3, sub_3, var_mean_3
# x_26 => convert_element_type_42
#
# Registers a persistent-reduction Triton kernel with `async_compile`. The
# second argument is the kernel's complete source text and must stay byte-exact.
# Per row x (1576 rows, one program per row: XBLOCK = 1, grid(1576)), over the
# 768 columns r (RBLOCK = 1024, masked by rmask = r < 768), with i = 768*x + r:
#   v              = in_ptr0[i] + in_ptr1[r] * float(in_ptr2[i])
#   mean           = sum(v) / 768
#   rstd           = rsqrt(sum((v - mean)**2) / 768 + 1e-06)
#   out_ptr2[i]    <- (v - mean) * rstd
#   out_ptr3[i]    <- (v - mean) * rstd * in_ptr3[r] + in_ptr4[r]
#   out_ptr4[x]    <- rstd * 0.0013020833333333333   (the constant is 1/768)
# triton_meta declares in_ptr2 and out_ptr3 as '*fp16' and every other pointer
# as '*fp32'. Going by the fused op names, out_ptr2 / out_ptr4 are presumably
# saved for native_layer_norm_backward; confirm against the caller.
# The get_args / call / benchmark_all_configs / __main__ part of the string is a
# standalone benchmark harness (rand_strided inputs timed with do_bench).
# NOTE(review): in this copy the lines inside each `def` of the string are not
# indented, which would fail to parse when compiled. This is presumably a
# paste artifact; verify it against the generated file.
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14 = async_compile.triton('triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[2048, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp16', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.014539936}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 1576
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp31 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tmp0 + tmp4
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = tl.broadcast_to(tmp6, [RBLOCK])
tmp11 = tl.where(rmask, tmp9, 0)
tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp13 = tl.full([1], 768, tl.int32)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 / tmp14
tmp16 = tmp6 - tmp15
tmp17 = tmp16 * tmp16
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp20 = tl.where(rmask, tmp18, 0)
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp22 = tmp5 - tmp15
tmp23 = 768.0
tmp24 = tmp21 / tmp23
tmp25 = 1e-06
tmp26 = tmp24 + tmp25
tmp27 = libdevice.rsqrt(tmp26)
tmp28 = tmp22 * tmp27
tmp30 = tmp28 * tmp29
tmp32 = tmp30 + tmp31
tmp33 = tmp32.to(tl.float32)
tmp34 = 0.0013020833333333333
tmp35 = tmp27 * tmp34
tl.store(out_ptr2 + (r1 + (768*x0)), tmp28, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp33, rmask)
tl.store(out_ptr4 + (x0), tmp35, None)
def get_args():
arg_0 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_6 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float16)
arg_7 = rand_strided((8, 197, 1), (197, 1, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(*args, 1576, 768, grid=grid(1576), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.benchmark_all_configs(*args, 1576, 768, grid=grid(1576))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.014539936
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/z6/cz6y24wvc3o3aw2kjuvbxfybimutqelk3q7u67kd5lnhvv4bwaty.py
# Source Nodes: [x_174], Original ATen: [aten.mean]
# x_174 => mean
#
# First stage of a split reduction for x_174 (aten.mean). Each of the 12288
# outputs is indexed by x0 = channel in [0, 768), x1 = split in [0, 2) and
# x2 = batch in [0, 8). Each output sums 98 rows of
#     in_ptr0 + in_ptr1[x0] * float(in_ptr2) + in_ptr3[x0] * float(in_ptr4)
# Index arithmetic for the summed rows:
#   - The rows are 1 + r3 + 98 * x1 of the 197-row dimension, i.e. rows 1..196.
#   - The constant `768 +` offset skips row 0.
#   - 75264 = 98 * 768 selects the split; 151296 = 197 * 768 selects the batch.
# The partial sums are written to out_ptr0, shaped (8, 768, 2) with strides
# (1536, 1, 768) per get_args. No division by the row count happens here.
# NOTE(review): the in_ptr1 / in_ptr3 per-channel scaling looks like scaled
# residual adds fused in front of the mean. Confirm against the model.
# The string below is the literal Triton source handed to async_compile.triton.
# Regenerate it from the compiler rather than editing it by hand.
triton_red_fused_mean_15 = async_compile.triton('triton_red_fused_mean_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.reduction(
size_hints=[16384, 128],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp16', 3: '*fp32', 4: '*fp16', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mean_15', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.009689088}
)
@triton.jit
def triton_red_fused_mean_15(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 12288
rnumel = 98
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex % 768
x1 = (xindex // 768) % 2
x2 = (xindex // 1536)
tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
tmp6 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
_tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
x4 = xindex
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r3 = rindex
tmp0 = tl.load(in_ptr0 + (768 + x0 + (768*r3) + (75264*x1) + (151296*x2)), rmask, eviction_policy='evict_first', other=0.0)
tmp2 = tl.load(in_ptr2 + (768 + x0 + (768*r3) + (75264*x1) + (151296*x2)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr4 + (768 + x0 + (768*r3) + (75264*x1) + (151296*x2)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 * tmp8
tmp10 = tmp5 + tmp9
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
tmp13 = _tmp12 + tmp11
_tmp12 = tl.where(rmask, tmp13, _tmp12)
tmp12 = tl.sum(_tmp12, 1)[:, None]
tl.store(out_ptr0 + (x4), tmp12, None)
def get_args():
arg_0 = rand_strided((8, 197, 768), (151296, 768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((1576, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_5 = rand_strided((8, 768, 2), (1536, 1, 768), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_red_fused_mean_15.run(*args, 12288, 98, grid=grid(12288), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_red_fused_mean_15.benchmark_all_configs(*args, 12288, 98, grid=grid(12288))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.009689088
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/zj/czjvdvt56qavkjrfu2xvatcbj5l7v7m2meg7m4nxcj5f3ckzwznk.py
# Source Nodes: [x_174], Original ATen: [aten.mean]
# x_174 => mean
#
# Second stage of the x_174 mean: a persistent reduction (RBLOCK = rnumel = 2)
# over two partial sums.
# Layout:
#   - in_ptr0 is the (8, 768, 2) fp32 buffer with strides (1536, 1, 768).
#     That layout matches the output of triton_red_fused_mean_15.
#   - Each of the 6144 outputs is indexed by x0 = channel and x1 = batch.
#     It adds its two partials and stores the total to out_ptr0, shaped (8, 768).
# The value stored here is still a sum, not a mean.
# NOTE(review): the division by the row count is presumably done by
# triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17,
# which divides its input by 196.0. The call site is outside this view.
# The string below is the literal Triton source handed to async_compile.triton.
# Regenerate it from the compiler rather than editing it by hand.
triton_per_fused_mean_16 = async_compile.triton('triton_per_fused_mean_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8192, 2],
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mean_16', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 7.3728e-05}
)
@triton.jit
def triton_per_fused_mean_16(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 6144
rnumel = 2
RBLOCK: tl.constexpr = 2
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
r2 = rindex
x0 = xindex % 768
x1 = (xindex // 768)
x3 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (768*r2) + (1536*x1)), None)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.sum(tmp1, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp3, None)
def get_args():
arg_0 = rand_strided((8, 768, 2), (1536, 1, 768), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((8, 768), (768, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused_mean_16.run(*args, 6144, 2, grid=grid(6144), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused_mean_16.benchmark_all_configs(*args, 6144, 2, grid=grid(6144))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 7.3728e-05
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/e6/ce6drhxrzyszsx2gi4yauegb6x2g7jpkqoxbfdf5d7b62jgrijem.py
# Source Nodes: [x_174, x_175, x_177], Original ATen: [aten._to_copy, aten.mean, aten.native_layer_norm, aten.native_layer_norm_backward]
# x_174 => mean
# x_175 => add_84, add_85, mul_108, mul_109, rsqrt_24, sub_24, var_mean_24
# x_177 => convert_element_type_305
#
# Persistent reduction with one program per row (xnumel = 8, XBLOCK = 1)
# over 768 columns. RBLOCK = 1024, so columns >= 768 are masked off.
# Per row it:
#   1. divides in_ptr0 by 196.0 to finish the x_174 mean;
#   2. layer-normalizes that result (x_175):
#      - mean and biased variance (sum of squares / 768) over the 768 columns;
#      - eps = 1e-06, rstd = rsqrt(var + eps);
#   3. stores the normalized, pre-affine value to out_ptr2 (fp32, 8 x 768);
#   4. stores normalized * in_ptr1 + in_ptr2 to out_ptr3, an fp16 (8, 768)
#      buffer;
#   5. stores rstd * 0.0013020833333333333 (= rstd / 768) to out_ptr4
#      (fp32, (8, 1)).
# NOTE(review): out_ptr2 and out_ptr4 look like values saved for
# native_layer_norm_backward, which is named in Source Nodes. Confirm at the
# call site.
# The string below is the literal Triton source handed to async_compile.triton.
# Regenerate it from the compiler rather than editing it by hand.
triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17 = async_compile.triton('triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.persistent_reduction(
size_hints=[8, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 3, 'num_reduction': 4, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6.7616e-05}
)
@triton.jit
def triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17(in_ptr0, in_ptr1, in_ptr2, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel):
xnumel = 8
XBLOCK: tl.constexpr = 1
rnumel = 768
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0)
tmp26 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp28 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = 196.0
tmp2 = tmp0 / tmp1
tmp3 = tl.broadcast_to(tmp2, [RBLOCK])
tmp5 = tl.where(rmask, tmp3, 0)
tmp6 = tl.broadcast_to(tmp3, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp8, 0))
tmp10 = tl.full([1], 768, tl.int32)
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp9 / tmp11
tmp13 = tmp3 - tmp12
tmp14 = tmp13 * tmp13
tmp15 = tl.broadcast_to(tmp14, [RBLOCK])
tmp17 = tl.where(rmask, tmp15, 0)
tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
tmp19 = tmp2 - tmp12
tmp20 = 768.0
tmp21 = tmp18 / tmp20
tmp22 = 1e-06
tmp23 = tmp21 + tmp22
tmp24 = libdevice.rsqrt(tmp23)
tmp25 = tmp19 * tmp24
tmp27 = tmp25 * tmp26
tmp29 = tmp27 + tmp28
tmp30 = tmp29.to(tl.float32)
tmp31 = 0.0013020833333333333
tmp32 = tmp24 * tmp31
tl.store(out_ptr2 + (r1 + (768*x0)), tmp25, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp30, rmask)
tl.store(out_ptr4 + (x0), tmp32, None)
def get_args():
arg_0 = rand_strided((8, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_2 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)
arg_3 = rand_strided((8, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((8, 768), (768, 1), device='cuda:0', dtype=torch.float16)
arg_5 = rand_strided((8, 1), (1, 1), device='cuda:0', dtype=torch.float32)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17.run(*args, 8, 768, grid=grid(8), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17.benchmark_all_configs(*args, 8, 768, grid=grid(8))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 6.7616e-05
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/df/cdfjuwpmvjiwqelaxmxlh3nqhhnfxr4hyv2zclcoucsav3fv522y.py
# Source Nodes: [x_177], Original ATen: [aten._to_copy]
# x_177 => convert_element_type_304
#
# Pointwise fp32 -> fp16 copy of 768000 elements, a (1000, 768) contiguous
# tensor per get_args.
# Loads and stores are unmasked (xmask is all-True and mask=None). Correctness
# therefore depends on 768000 being a multiple of the selected XBLOCK.
# NOTE(review): the (1000, 768) shape matches the assert on primals_198 in
# call(). The launch site is outside this view, so confirm which tensor it
# converts.
# The string below is the literal Triton source handed to async_compile.triton.
# Regenerate it from the compiler rather than editing it by hand.
triton_poi_fused__to_copy_18 = async_compile.triton('triton_poi_fused__to_copy_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[1048576],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.004608},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_18(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 768000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
def get_args():
arg_0 = rand_strided((1000, 768), (768, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((1000, 768), (768, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_18.run(*args, 768000, grid=grid(768000), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_18.benchmark_all_configs(*args, 768000, grid=grid(768000))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 0.004608
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/tmp51wq3e92/ld/cldz75njkdkrz7nbaacrkxnckl2rxmxcngf6ibflnyc2pgyenbqt.py
# Source Nodes: [x_177], Original ATen: [aten._to_copy]
# x_177 => convert_element_type_303
#
# Pointwise fp32 -> fp16 copy of a 1000-element vector.
# Loads and stores are masked with xindex < xnumel, so the last, partially
# filled block stays in bounds.
# NOTE(review): the (1000,) shape matches the assert on primals_199 in call().
# The launch site is outside this view, so confirm which tensor it converts.
# The string below is the literal Triton source handed to async_compile.triton.
# Regenerate it from the compiler rather than editing it by hand.
triton_poi_fused__to_copy_19 = async_compile.triton('triton_poi_fused__to_copy_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
@triton_heuristics.pointwise(
size_hints=[1024],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 6e-06},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_19(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
def get_args():
arg_0 = rand_strided((1000,), (1,), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((1000,), (1,), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1,
def call(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(*args, 1000, grid=grid(1000), stream=stream0)
def benchmark_all_configs(args):
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
return triton_poi_fused__to_copy_19.benchmark_all_configs(*args, 1000, grid=grid(1000))
if __name__ == '__main__':
from triton.testing import do_bench
args = get_args()
ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
num_gb = 6e-06
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# Wait for every kernel scheduled above through async_compile.triton.
# NOTE(review): wait() receives globals(), presumably so it can rebind each
# kernel name to its compiled object. Confirm in torch._inductor.async_compile.
async_compile.wait(globals())
# Drop the module-level AsyncCompile handle once the kernels are resolved.
del async_compile
def call(args):
primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, 
primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224 = args
args.clear()
assert_size_stride(primals_1, (1, 1, 768), (768, 768, 1))
assert_size_stride(primals_2, (768, ), (1, ))
assert_size_stride(primals_3, (768, ), (1, ))
assert_size_stride(primals_4, (768, ), (1, ))
assert_size_stride(primals_5, (768, ), (1, ))
assert_size_stride(primals_6, (768, ), (1, ))
assert_size_stride(primals_7, (2304, 768), (768, 1))
assert_size_stride(primals_8, (732, 12), (12, 1))
assert_size_stride(primals_9, (768, ), (1, ))
assert_size_stride(primals_10, (768, ), (1, ))
assert_size_stride(primals_11, (768, ), (1, ))
assert_size_stride(primals_12, (768, ), (1, ))
assert_size_stride(primals_13, (768, ), (1, ))
assert_size_stride(primals_14, (768, ), (1, ))
assert_size_stride(primals_15, (768, ), (1, ))
assert_size_stride(primals_16, (768, ), (1, ))
assert_size_stride(primals_17, (2304, 768), (768, 1))
assert_size_stride(primals_18, (732, 12), (12, 1))
assert_size_stride(primals_19, (768, ), (1, ))
assert_size_stride(primals_20, (768, ), (1, ))
assert_size_stride(primals_21, (768, ), (1, ))
assert_size_stride(primals_22, (768, ), (1, ))
assert_size_stride(primals_23, (768, ), (1, ))
assert_size_stride(primals_24, (768, ), (1, ))
assert_size_stride(primals_25, (768, ), (1, ))
assert_size_stride(primals_26, (768, ), (1, ))
assert_size_stride(primals_27, (2304, 768), (768, 1))
assert_size_stride(primals_28, (732, 12), (12, 1))
assert_size_stride(primals_29, (768, ), (1, ))
assert_size_stride(primals_30, (768, ), (1, ))
assert_size_stride(primals_31, (768, ), (1, ))
assert_size_stride(primals_32, (768, ), (1, ))
assert_size_stride(primals_33, (768, ), (1, ))
assert_size_stride(primals_34, (768, ), (1, ))
assert_size_stride(primals_35, (768, ), (1, ))
assert_size_stride(primals_36, (768, ), (1, ))
assert_size_stride(primals_37, (2304, 768), (768, 1))
assert_size_stride(primals_38, (732, 12), (12, 1))
assert_size_stride(primals_39, (768, ), (1, ))
assert_size_stride(primals_40, (768, ), (1, ))
assert_size_stride(primals_41, (768, ), (1, ))
assert_size_stride(primals_42, (768, ), (1, ))
assert_size_stride(primals_43, (768, ), (1, ))
assert_size_stride(primals_44, (768, ), (1, ))
assert_size_stride(primals_45, (768, ), (1, ))
assert_size_stride(primals_46, (768, ), (1, ))
assert_size_stride(primals_47, (2304, 768), (768, 1))
assert_size_stride(primals_48, (732, 12), (12, 1))
assert_size_stride(primals_49, (768, ), (1, ))
assert_size_stride(primals_50, (768, ), (1, ))
assert_size_stride(primals_51, (768, ), (1, ))
assert_size_stride(primals_52, (768, ), (1, ))
assert_size_stride(primals_53, (768, ), (1, ))
assert_size_stride(primals_54, (768, ), (1, ))
assert_size_stride(primals_55, (768, ), (1, ))
assert_size_stride(primals_56, (768, ), (1, ))
assert_size_stride(primals_57, (2304, 768), (768, 1))
assert_size_stride(primals_58, (732, 12), (12, 1))
assert_size_stride(primals_59, (768, ), (1, ))
assert_size_stride(primals_60, (768, ), (1, ))
assert_size_stride(primals_61, (768, ), (1, ))
assert_size_stride(primals_62, (768, ), (1, ))
assert_size_stride(primals_63, (768, ), (1, ))
assert_size_stride(primals_64, (768, ), (1, ))
assert_size_stride(primals_65, (768, ), (1, ))
assert_size_stride(primals_66, (768, ), (1, ))
assert_size_stride(primals_67, (2304, 768), (768, 1))
assert_size_stride(primals_68, (732, 12), (12, 1))
assert_size_stride(primals_69, (768, ), (1, ))
assert_size_stride(primals_70, (768, ), (1, ))
assert_size_stride(primals_71, (768, ), (1, ))
assert_size_stride(primals_72, (768, ), (1, ))
assert_size_stride(primals_73, (768, ), (1, ))
assert_size_stride(primals_74, (768, ), (1, ))
assert_size_stride(primals_75, (768, ), (1, ))
assert_size_stride(primals_76, (768, ), (1, ))
assert_size_stride(primals_77, (2304, 768), (768, 1))
assert_size_stride(primals_78, (732, 12), (12, 1))
assert_size_stride(primals_79, (768, ), (1, ))
assert_size_stride(primals_80, (768, ), (1, ))
assert_size_stride(primals_81, (768, ), (1, ))
assert_size_stride(primals_82, (768, ), (1, ))
assert_size_stride(primals_83, (768, ), (1, ))
assert_size_stride(primals_84, (768, ), (1, ))
assert_size_stride(primals_85, (768, ), (1, ))
assert_size_stride(primals_86, (768, ), (1, ))
assert_size_stride(primals_87, (2304, 768), (768, 1))
assert_size_stride(primals_88, (732, 12), (12, 1))
assert_size_stride(primals_89, (768, ), (1, ))
assert_size_stride(primals_90, (768, ), (1, ))
assert_size_stride(primals_91, (768, ), (1, ))
assert_size_stride(primals_92, (768, ), (1, ))
assert_size_stride(primals_93, (768, ), (1, ))
assert_size_stride(primals_94, (768, ), (1, ))
assert_size_stride(primals_95, (768, ), (1, ))
assert_size_stride(primals_96, (768, ), (1, ))
assert_size_stride(primals_97, (2304, 768), (768, 1))
assert_size_stride(primals_98, (732, 12), (12, 1))
assert_size_stride(primals_99, (768, ), (1, ))
assert_size_stride(primals_100, (768, ), (1, ))
assert_size_stride(primals_101, (768, ), (1, ))
assert_size_stride(primals_102, (768, ), (1, ))
assert_size_stride(primals_103, (768, ), (1, ))
assert_size_stride(primals_104, (768, ), (1, ))
assert_size_stride(primals_105, (768, ), (1, ))
assert_size_stride(primals_106, (768, ), (1, ))
assert_size_stride(primals_107, (2304, 768), (768, 1))
assert_size_stride(primals_108, (732, 12), (12, 1))
assert_size_stride(primals_109, (768, ), (1, ))
assert_size_stride(primals_110, (768, ), (1, ))
assert_size_stride(primals_111, (768, ), (1, ))
assert_size_stride(primals_112, (768, ), (1, ))
assert_size_stride(primals_113, (768, ), (1, ))
assert_size_stride(primals_114, (768, ), (1, ))
assert_size_stride(primals_115, (768, ), (1, ))
assert_size_stride(primals_116, (768, ), (1, ))
assert_size_stride(primals_117, (2304, 768), (768, 1))
assert_size_stride(primals_118, (732, 12), (12, 1))
assert_size_stride(primals_119, (768, ), (1, ))
assert_size_stride(primals_120, (768, ), (1, ))
assert_size_stride(primals_121, (768, ), (1, ))
assert_size_stride(primals_122, (768, ), (1, ))
assert_size_stride(primals_123, (768, ), (1, ))
assert_size_stride(primals_124, (768, 3, 16, 16), (768, 256, 16, 1))
assert_size_stride(primals_125, (768, ), (1, ))
assert_size_stride(primals_126, (768, 768), (768, 1))
assert_size_stride(primals_127, (768, ), (1, ))
assert_size_stride(primals_128, (3072, 768), (768, 1))
assert_size_stride(primals_129, (3072, ), (1, ))
assert_size_stride(primals_130, (768, 3072), (3072, 1))
assert_size_stride(primals_131, (768, ), (1, ))
assert_size_stride(primals_132, (768, 768), (768, 1))
assert_size_stride(primals_133, (768, ), (1, ))
assert_size_stride(primals_134, (3072, 768), (768, 1))
assert_size_stride(primals_135, (3072, ), (1, ))
assert_size_stride(primals_136, (768, 3072), (3072, 1))
assert_size_stride(primals_137, (768, ), (1, ))
assert_size_stride(primals_138, (768, 768), (768, 1))
assert_size_stride(primals_139, (768, ), (1, ))
assert_size_stride(primals_140, (3072, 768), (768, 1))
assert_size_stride(primals_141, (3072, ), (1, ))
assert_size_stride(primals_142, (768, 3072), (3072, 1))
assert_size_stride(primals_143, (768, ), (1, ))
assert_size_stride(primals_144, (768, 768), (768, 1))
assert_size_stride(primals_145, (768, ), (1, ))
assert_size_stride(primals_146, (3072, 768), (768, 1))
assert_size_stride(primals_147, (3072, ), (1, ))
assert_size_stride(primals_148, (768, 3072), (3072, 1))
assert_size_stride(primals_149, (768, ), (1, ))
assert_size_stride(primals_150, (768, 768), (768, 1))
assert_size_stride(primals_151, (768, ), (1, ))
assert_size_stride(primals_152, (3072, 768), (768, 1))
assert_size_stride(primals_153, (3072, ), (1, ))
assert_size_stride(primals_154, (768, 3072), (3072, 1))
assert_size_stride(primals_155, (768, ), (1, ))
assert_size_stride(primals_156, (768, 768), (768, 1))
assert_size_stride(primals_157, (768, ), (1, ))
assert_size_stride(primals_158, (3072, 768), (768, 1))
assert_size_stride(primals_159, (3072, ), (1, ))
assert_size_stride(primals_160, (768, 3072), (3072, 1))
assert_size_stride(primals_161, (768, ), (1, ))
assert_size_stride(primals_162, (768, 768), (768, 1))
assert_size_stride(primals_163, (768, ), (1, ))
assert_size_stride(primals_164, (3072, 768), (768, 1))
assert_size_stride(primals_165, (3072, ), (1, ))
assert_size_stride(primals_166, (768, 3072), (3072, 1))
assert_size_stride(primals_167, (768, ), (1, ))
assert_size_stride(primals_168, (768, 768), (768, 1))
assert_size_stride(primals_169, (768, ), (1, ))
assert_size_stride(primals_170, (3072, 768), (768, 1))
assert_size_stride(primals_171, (3072, ), (1, ))
assert_size_stride(primals_172, (768, 3072), (3072, 1))
assert_size_stride(primals_173, (768, ), (1, ))
assert_size_stride(primals_174, (768, 768), (768, 1))
assert_size_stride(primals_175, (768, ), (1, ))
assert_size_stride(primals_176, (3072, 768), (768, 1))
assert_size_stride(primals_177, (3072, ), (1, ))
assert_size_stride(primals_178, (768, 3072), (3072, 1))
assert_size_stride(primals_179, (768, ), (1, ))
assert_size_stride(primals_180, (768, 768), (768, 1))
assert_size_stride(primals_181, (768, ), (1, ))
assert_size_stride(primals_182, (3072, 768), (768, 1))
assert_size_stride(primals_183, (3072, ), (1, ))
assert_size_stride(primals_184, (768, 3072), (3072, 1))
assert_size_stride(primals_185, (768, ), (1, ))
assert_size_stride(primals_186, (768, 768), (768, 1))
assert_size_stride(primals_187, (768, ), (1, ))
assert_size_stride(primals_188, (3072, 768), (768, 1))
assert_size_stride(primals_189, (3072, ), (1, ))
assert_size_stride(primals_190, (768, 3072), (3072, 1))
assert_size_stride(primals_191, (768, ), (1, ))
assert_size_stride(primals_192, (768, 768), (768, 1))
assert_size_stride(primals_193, (768, ), (1, ))
assert_size_stride(primals_194, (3072, 768), (768, 1))
assert_size_stride(primals_195, (3072, ), (1, ))
assert_size_stride(primals_196, (768, 3072), (3072, 1))
assert_size_stride(primals_197, (768, ), (1, ))
assert_size_stride(primals_198, (1000, 768), (768, 1))
assert_size_stride(primals_199, (1000, ), (1, ))
assert_size_stride(primals_200, (768, ), (1, ))
assert_size_stride(primals_201, (197, 197), (197, 1))
assert_size_stride(primals_202, (768, ), (1, ))
assert_size_stride(primals_203, (197, 197), (197, 1))
assert_size_stride(primals_204, (768, ), (1, ))
assert_size_stride(primals_205, (197, 197), (197, 1))
assert_size_stride(primals_206, (768, ), (1, ))
assert_size_stride(primals_207, (197, 197), (197, 1))
assert_size_stride(primals_208, (768, ), (1, ))
assert_size_stride(primals_209, (197, 197), (197, 1))
assert_size_stride(primals_210, (768, ), (1, ))
assert_size_stride(primals_211, (197, 197), (197, 1))
assert_size_stride(primals_212, (768, ), (1, ))
assert_size_stride(primals_213, (197, 197), (197, 1))
assert_size_stride(primals_214, (768, ), (1, ))
assert_size_stride(primals_215, (197, 197), (197, 1))
assert_size_stride(primals_216, (768, ), (1, ))
assert_size_stride(primals_217, (197, 197), (197, 1))
assert_size_stride(primals_218, (768, ), (1, ))
assert_size_stride(primals_219, (197, 197), (197, 1))
assert_size_stride(primals_220, (768, ), (1, ))
assert_size_stride(primals_221, (197, 197), (197, 1))
assert_size_stride(primals_222, (768, ), (1, ))
assert_size_stride(primals_223, (197, 197), (197, 1))
assert_size_stride(primals_224, (8, 3, 224, 224), (150528, 50176, 224, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((768, 3, 16, 16), (768, 256, 16, 1), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_0.run(primals_124, buf0, 589824, grid=grid(589824), stream=stream0)
del primals_124
buf1 = empty_strided_cuda((8, 3, 224, 224), (150528, 50176, 224, 1), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_1.run(primals_224, buf1, 1204224, grid=grid(1204224), stream=stream0)
del primals_224
buf2 = empty_strided_cuda((768, ), (1, ), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
triton_poi_fused__to_copy_convolution_2.run(primals_125, buf2, 768, grid=grid(768), stream=stream0)
del primals_125
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
buf3 = extern_kernels.convolution(buf1, buf0, stride=(16, 16), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
assert_size_stride(buf3, (8, 768, 14, 14), (150528, 196, 14, 1))
buf4 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf5 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
buf6 = empty_strided_cuda((8, 197, 1), (197, 1, 1600), torch.float32)
buf8 = reinterpret_tensor(buf6, (8, 197, 1), (197, 1, 1), 0); del buf6 # reuse
buf9 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
# Source Nodes: [qkv, x_3, x_5], Original ATen: [aten._to_copy, aten.cat, aten.native_layer_norm]
triton_red_fused__to_copy_cat_native_layer_norm_3.run(buf8, primals_1, buf3, buf2, primals_3, primals_4, buf4, buf5, buf9, 1576, 768, grid=grid(1576), stream=stream0)
del buf3
del primals_1
del primals_4
buf10 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_7, buf10, 1769472, grid=grid(1769472), stream=stream0)
del primals_7
buf11 = empty_strided_cuda((2304, ), (1, ), torch.float16)
# Source Nodes: [qkv, qkv_bias], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_5, primals_200, primals_6, buf11, 2304, grid=grid(2304), stream=stream0)
del primals_200
del primals_5
del primals_6
buf12 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv, qkv_bias], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf11, reinterpret_tensor(buf9, (1576, 768), (768, 1), 0), reinterpret_tensor(buf10, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf12)
buf13 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_6], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_201, primals_8, buf13, 472800, grid=grid(472800), stream=stream0)
del primals_8
# Source Nodes: [x_6], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf14 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf13, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf15 = buf14[0]
buf16 = buf14[1]
buf17 = buf14[2]
buf18 = buf14[3]
del buf14
buf19 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_8], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_126, buf19, 589824, grid=grid(589824), stream=stream0)
del primals_126
buf20 = buf2; del buf2 # reuse
# Source Nodes: [x_8], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_127, buf20, 768, grid=grid(768), stream=stream0)
del primals_127
buf21 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_8], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf20, reinterpret_tensor(buf15, (1576, 768), (768, 1), 0), reinterpret_tensor(buf19, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf21)
buf22 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
buf23 = empty_strided_cuda((8, 197, 1), (197, 1, 1600), torch.float32)
buf25 = reinterpret_tensor(buf23, (8, 197, 1), (197, 1, 1), 0); del buf23 # reuse
buf26 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
# Source Nodes: [mul, x_10, x_11, x_12], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm]
triton_per_fused__to_copy_add_mul_native_layer_norm_8.run(buf25, buf4, primals_2, buf21, primals_10, primals_11, buf22, buf26, 1576, 768, grid=grid(1576), stream=stream0)
del primals_11
buf27 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_128, buf27, 2359296, grid=grid(2359296), stream=stream0)
del primals_128
buf28 = empty_strided_cuda((3072, ), (1, ), torch.float16)
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_129, buf28, 3072, grid=grid(3072), stream=stream0)
del primals_129
buf29 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_12], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf28, reinterpret_tensor(buf26, (1576, 768), (768, 1), 0), reinterpret_tensor(buf27, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf29)
buf30 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_13], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf29, buf30, 4841472, grid=grid(4841472), stream=stream0)
buf31 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_16], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_130, buf31, 2359296, grid=grid(2359296), stream=stream0)
del primals_130
buf32 = buf20; del buf20 # reuse
# Source Nodes: [x_16], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_131, buf32, 768, grid=grid(768), stream=stream0)
del primals_131
buf33 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_16], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf32, reinterpret_tensor(buf30, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf31, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf33)
buf34 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf38 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf39 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf396 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul, mul_1, qkv_2, x_10, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf4, primals_2, buf21, primals_9, buf33, primals_13, primals_14, buf34, buf38, buf39, buf396, 1576, 768, grid=grid(1576), stream=stream0)
del primals_14
buf40 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_2], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_17, buf40, 1769472, grid=grid(1769472), stream=stream0)
del primals_17
buf41 = buf11; del buf11 # reuse
# Source Nodes: [qkv_2, qkv_bias_1], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_15, primals_202, primals_16, buf41, 2304, grid=grid(2304), stream=stream0)
del primals_15
del primals_16
del primals_202
buf42 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_2, qkv_bias_1], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf41, reinterpret_tensor(buf39, (1576, 768), (768, 1), 0), reinterpret_tensor(buf40, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf42)
buf43 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_20], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_203, primals_18, buf43, 472800, grid=grid(472800), stream=stream0)
del primals_18
# Source Nodes: [x_20], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf44 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf43, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf45 = buf44[0]
buf46 = buf44[1]
buf47 = buf44[2]
buf48 = buf44[3]
del buf44
buf49 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_132, buf49, 589824, grid=grid(589824), stream=stream0)
del primals_132
buf50 = buf32; del buf32 # reuse
# Source Nodes: [x_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_133, buf50, 768, grid=grid(768), stream=stream0)
del primals_133
buf51 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_22], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf50, reinterpret_tensor(buf45, (1576, 768), (768, 1), 0), reinterpret_tensor(buf49, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf51)
buf55 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf56 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf395 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_2, x_24, x_25, x_26], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf34, primals_12, buf51, primals_20, primals_21, buf55, buf56, buf395, 1576, 768, grid=grid(1576), stream=stream0)
del primals_21
buf57 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_26], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_134, buf57, 2359296, grid=grid(2359296), stream=stream0)
del primals_134
buf58 = buf28; del buf28 # reuse
# Source Nodes: [x_26], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_135, buf58, 3072, grid=grid(3072), stream=stream0)
del primals_135
buf59 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_26], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf58, reinterpret_tensor(buf56, (1576, 768), (768, 1), 0), reinterpret_tensor(buf57, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf59)
buf60 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_27], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf59, buf60, 4841472, grid=grid(4841472), stream=stream0)
buf61 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_30], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_136, buf61, 2359296, grid=grid(2359296), stream=stream0)
del primals_136
buf62 = buf50; del buf50 # reuse
# Source Nodes: [x_30], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_137, buf62, 768, grid=grid(768), stream=stream0)
del primals_137
buf63 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_30], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf62, reinterpret_tensor(buf60, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf61, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf63)
buf64 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf68 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf69 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf394 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_2, mul_3, qkv_4, x_24, x_32, x_33], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf34, primals_12, buf51, primals_19, buf63, primals_23, primals_24, buf64, buf68, buf69, buf394, 1576, 768, grid=grid(1576), stream=stream0)
del primals_24
buf70 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_4], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_27, buf70, 1769472, grid=grid(1769472), stream=stream0)
del primals_27
buf71 = buf41; del buf41 # reuse
# Source Nodes: [qkv_4, qkv_bias_2], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_25, primals_204, primals_26, buf71, 2304, grid=grid(2304), stream=stream0)
del primals_204
del primals_25
del primals_26
buf72 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_4, qkv_bias_2], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf71, reinterpret_tensor(buf69, (1576, 768), (768, 1), 0), reinterpret_tensor(buf70, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf72)
buf73 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_34], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_205, primals_28, buf73, 472800, grid=grid(472800), stream=stream0)
del primals_28
# Source Nodes: [x_34], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf74 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf73, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf75 = buf74[0]
buf76 = buf74[1]
buf77 = buf74[2]
buf78 = buf74[3]
del buf74
buf79 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_36], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_138, buf79, 589824, grid=grid(589824), stream=stream0)
del primals_138
buf80 = buf62; del buf62 # reuse
# Source Nodes: [x_36], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_139, buf80, 768, grid=grid(768), stream=stream0)
del primals_139
buf81 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_36], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf80, reinterpret_tensor(buf75, (1576, 768), (768, 1), 0), reinterpret_tensor(buf79, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf81)
buf85 = buf34; del buf34 # reuse
buf86 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf393 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_4, x_38, x_39, x_40], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf64, primals_22, buf81, primals_30, primals_31, buf85, buf86, buf393, 1576, 768, grid=grid(1576), stream=stream0)
del primals_31
buf87 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_40], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_140, buf87, 2359296, grid=grid(2359296), stream=stream0)
del primals_140
buf88 = buf58; del buf58 # reuse
# Source Nodes: [x_40], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_141, buf88, 3072, grid=grid(3072), stream=stream0)
del primals_141
buf89 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_40], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf88, reinterpret_tensor(buf86, (1576, 768), (768, 1), 0), reinterpret_tensor(buf87, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf89)
buf90 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_41], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf89, buf90, 4841472, grid=grid(4841472), stream=stream0)
buf91 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_44], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_142, buf91, 2359296, grid=grid(2359296), stream=stream0)
del primals_142
buf92 = buf80; del buf80 # reuse
# Source Nodes: [x_44], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_143, buf92, 768, grid=grid(768), stream=stream0)
del primals_143
buf93 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_44], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf92, reinterpret_tensor(buf90, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf91, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf93)
buf94 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf98 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf99 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf392 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_4, mul_5, qkv_6, x_38, x_46, x_47], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf64, primals_22, buf81, primals_29, buf93, primals_33, primals_34, buf94, buf98, buf99, buf392, 1576, 768, grid=grid(1576), stream=stream0)
del primals_34
buf100 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_6], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_37, buf100, 1769472, grid=grid(1769472), stream=stream0)
del primals_37
buf101 = buf71; del buf71 # reuse
# Source Nodes: [qkv_6, qkv_bias_3], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_35, primals_206, primals_36, buf101, 2304, grid=grid(2304), stream=stream0)
del primals_206
del primals_35
del primals_36
buf102 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_6, qkv_bias_3], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf101, reinterpret_tensor(buf99, (1576, 768), (768, 1), 0), reinterpret_tensor(buf100, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf102)
buf103 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_48], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_207, primals_38, buf103, 472800, grid=grid(472800), stream=stream0)
del primals_38
# Source Nodes: [x_48], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf104 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf103, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf105 = buf104[0]
buf106 = buf104[1]
buf107 = buf104[2]
buf108 = buf104[3]
del buf104
buf109 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_50], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_144, buf109, 589824, grid=grid(589824), stream=stream0)
del primals_144
buf110 = buf92; del buf92 # reuse
# Source Nodes: [x_50], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_145, buf110, 768, grid=grid(768), stream=stream0)
del primals_145
buf111 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_50], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf110, reinterpret_tensor(buf105, (1576, 768), (768, 1), 0), reinterpret_tensor(buf109, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf111)
buf115 = buf64; del buf64 # reuse
buf116 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf391 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_6, x_52, x_53, x_54], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf94, primals_32, buf111, primals_40, primals_41, buf115, buf116, buf391, 1576, 768, grid=grid(1576), stream=stream0)
del primals_41
buf117 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_54], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_146, buf117, 2359296, grid=grid(2359296), stream=stream0)
del primals_146
buf118 = buf88; del buf88 # reuse
# Source Nodes: [x_54], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_147, buf118, 3072, grid=grid(3072), stream=stream0)
del primals_147
buf119 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_54], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf118, reinterpret_tensor(buf116, (1576, 768), (768, 1), 0), reinterpret_tensor(buf117, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf119)
buf120 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_55], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf119, buf120, 4841472, grid=grid(4841472), stream=stream0)
buf121 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_58], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_148, buf121, 2359296, grid=grid(2359296), stream=stream0)
del primals_148
buf122 = buf110; del buf110 # reuse
# Source Nodes: [x_58], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_149, buf122, 768, grid=grid(768), stream=stream0)
del primals_149
buf123 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_58], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf122, reinterpret_tensor(buf120, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf121, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf123)
buf124 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf128 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf129 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf390 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_6, mul_7, qkv_8, x_52, x_60, x_61], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf94, primals_32, buf111, primals_39, buf123, primals_43, primals_44, buf124, buf128, buf129, buf390, 1576, 768, grid=grid(1576), stream=stream0)
del primals_44
buf130 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_8], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_47, buf130, 1769472, grid=grid(1769472), stream=stream0)
del primals_47
buf131 = buf101; del buf101 # reuse
# Source Nodes: [qkv_8, qkv_bias_4], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_45, primals_208, primals_46, buf131, 2304, grid=grid(2304), stream=stream0)
del primals_208
del primals_45
del primals_46
buf132 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_8, qkv_bias_4], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf131, reinterpret_tensor(buf129, (1576, 768), (768, 1), 0), reinterpret_tensor(buf130, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf132)
buf133 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_62], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_209, primals_48, buf133, 472800, grid=grid(472800), stream=stream0)
del primals_48
# Source Nodes: [x_62], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf134 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf133, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf135 = buf134[0]
buf136 = buf134[1]
buf137 = buf134[2]
buf138 = buf134[3]
del buf134
buf139 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_64], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_150, buf139, 589824, grid=grid(589824), stream=stream0)
del primals_150
buf140 = buf122; del buf122 # reuse
# Source Nodes: [x_64], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_151, buf140, 768, grid=grid(768), stream=stream0)
del primals_151
buf141 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_64], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf140, reinterpret_tensor(buf135, (1576, 768), (768, 1), 0), reinterpret_tensor(buf139, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf141)
buf145 = buf94; del buf94 # reuse
buf146 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf389 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_8, x_66, x_67, x_68], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf124, primals_42, buf141, primals_50, primals_51, buf145, buf146, buf389, 1576, 768, grid=grid(1576), stream=stream0)
del primals_51
buf147 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_68], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_152, buf147, 2359296, grid=grid(2359296), stream=stream0)
del primals_152
buf148 = buf118; del buf118 # reuse
# Source Nodes: [x_68], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_153, buf148, 3072, grid=grid(3072), stream=stream0)
del primals_153
buf149 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_68], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf148, reinterpret_tensor(buf146, (1576, 768), (768, 1), 0), reinterpret_tensor(buf147, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf149)
buf150 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_69], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf149, buf150, 4841472, grid=grid(4841472), stream=stream0)
buf151 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_72], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_154, buf151, 2359296, grid=grid(2359296), stream=stream0)
del primals_154
buf152 = buf140; del buf140 # reuse
# Source Nodes: [x_72], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_155, buf152, 768, grid=grid(768), stream=stream0)
del primals_155
buf153 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_72], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf152, reinterpret_tensor(buf150, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf151, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf153)
buf154 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf158 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf159 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf388 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_8, mul_9, qkv_10, x_66, x_74, x_75], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf124, primals_42, buf141, primals_49, buf153, primals_53, primals_54, buf154, buf158, buf159, buf388, 1576, 768, grid=grid(1576), stream=stream0)
del primals_54
buf160 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_10], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_57, buf160, 1769472, grid=grid(1769472), stream=stream0)
del primals_57
buf161 = buf131; del buf131 # reuse
# Source Nodes: [qkv_10, qkv_bias_5], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_55, primals_210, primals_56, buf161, 2304, grid=grid(2304), stream=stream0)
del primals_210
del primals_55
del primals_56
buf162 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_10, qkv_bias_5], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf161, reinterpret_tensor(buf159, (1576, 768), (768, 1), 0), reinterpret_tensor(buf160, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf162)
buf163 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_76], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_211, primals_58, buf163, 472800, grid=grid(472800), stream=stream0)
del primals_58
# Source Nodes: [x_76], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf164 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf163, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf165 = buf164[0]
buf166 = buf164[1]
buf167 = buf164[2]
buf168 = buf164[3]
del buf164
buf169 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_78], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_156, buf169, 589824, grid=grid(589824), stream=stream0)
del primals_156
buf170 = buf152; del buf152 # reuse
# Source Nodes: [x_78], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_157, buf170, 768, grid=grid(768), stream=stream0)
del primals_157
buf171 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_78], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf170, reinterpret_tensor(buf165, (1576, 768), (768, 1), 0), reinterpret_tensor(buf169, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf171)
buf175 = buf124; del buf124 # reuse
buf176 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf387 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_10, x_80, x_81, x_82], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf154, primals_52, buf171, primals_60, primals_61, buf175, buf176, buf387, 1576, 768, grid=grid(1576), stream=stream0)
del primals_61
buf177 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_82], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_158, buf177, 2359296, grid=grid(2359296), stream=stream0)
del primals_158
buf178 = buf148; del buf148 # reuse
# Source Nodes: [x_82], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_159, buf178, 3072, grid=grid(3072), stream=stream0)
del primals_159
buf179 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_82], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf178, reinterpret_tensor(buf176, (1576, 768), (768, 1), 0), reinterpret_tensor(buf177, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf179)
buf180 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_83], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf179, buf180, 4841472, grid=grid(4841472), stream=stream0)
buf181 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_86], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_160, buf181, 2359296, grid=grid(2359296), stream=stream0)
del primals_160
buf182 = buf170; del buf170 # reuse
# Source Nodes: [x_86], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_161, buf182, 768, grid=grid(768), stream=stream0)
del primals_161
buf183 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_86], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf182, reinterpret_tensor(buf180, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf181, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf183)
buf184 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf188 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf189 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf386 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_10, mul_11, qkv_12, x_80, x_88, x_89], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf154, primals_52, buf171, primals_59, buf183, primals_63, primals_64, buf184, buf188, buf189, buf386, 1576, 768, grid=grid(1576), stream=stream0)
del primals_64
buf190 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_12], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_67, buf190, 1769472, grid=grid(1769472), stream=stream0)
del primals_67
buf191 = buf161; del buf161 # reuse
# Source Nodes: [qkv_12, qkv_bias_6], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_65, primals_212, primals_66, buf191, 2304, grid=grid(2304), stream=stream0)
del primals_212
del primals_65
del primals_66
buf192 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_12, qkv_bias_6], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf191, reinterpret_tensor(buf189, (1576, 768), (768, 1), 0), reinterpret_tensor(buf190, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf192)
buf193 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_90], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_213, primals_68, buf193, 472800, grid=grid(472800), stream=stream0)
del primals_68
# Source Nodes: [x_90], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf194 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf193, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf195 = buf194[0]
buf196 = buf194[1]
buf197 = buf194[2]
buf198 = buf194[3]
del buf194
buf199 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_92], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_162, buf199, 589824, grid=grid(589824), stream=stream0)
del primals_162
buf200 = buf182; del buf182 # reuse
# Source Nodes: [x_92], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_163, buf200, 768, grid=grid(768), stream=stream0)
del primals_163
buf201 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_92], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf200, reinterpret_tensor(buf195, (1576, 768), (768, 1), 0), reinterpret_tensor(buf199, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf201)
buf205 = buf154; del buf154 # reuse
buf206 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf385 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_12, x_94, x_95, x_96], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf184, primals_62, buf201, primals_70, primals_71, buf205, buf206, buf385, 1576, 768, grid=grid(1576), stream=stream0)
del primals_71
buf207 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_96], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_164, buf207, 2359296, grid=grid(2359296), stream=stream0)
del primals_164
buf208 = buf178; del buf178 # reuse
# Source Nodes: [x_96], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_165, buf208, 3072, grid=grid(3072), stream=stream0)
del primals_165
buf209 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_96], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf208, reinterpret_tensor(buf206, (1576, 768), (768, 1), 0), reinterpret_tensor(buf207, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf209)
buf210 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_97], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf209, buf210, 4841472, grid=grid(4841472), stream=stream0)
buf211 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_100], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_166, buf211, 2359296, grid=grid(2359296), stream=stream0)
del primals_166
buf212 = buf200; del buf200 # reuse
# Source Nodes: [x_100], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_167, buf212, 768, grid=grid(768), stream=stream0)
del primals_167
buf213 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_100], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf212, reinterpret_tensor(buf210, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf211, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf213)
buf214 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf218 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf219 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf384 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_12, mul_13, qkv_14, x_102, x_103, x_94], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf184, primals_62, buf201, primals_69, buf213, primals_73, primals_74, buf214, buf218, buf219, buf384, 1576, 768, grid=grid(1576), stream=stream0)
del primals_74
buf220 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_14], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_77, buf220, 1769472, grid=grid(1769472), stream=stream0)
del primals_77
buf221 = buf191; del buf191 # reuse
# Source Nodes: [qkv_14, qkv_bias_7], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_75, primals_214, primals_76, buf221, 2304, grid=grid(2304), stream=stream0)
del primals_214
del primals_75
del primals_76
buf222 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_14, qkv_bias_7], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf221, reinterpret_tensor(buf219, (1576, 768), (768, 1), 0), reinterpret_tensor(buf220, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf222)
buf223 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_104], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_215, primals_78, buf223, 472800, grid=grid(472800), stream=stream0)
del primals_78
# Source Nodes: [x_104], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf224 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf223, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf225 = buf224[0]
buf226 = buf224[1]
buf227 = buf224[2]
buf228 = buf224[3]
del buf224
buf229 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_106], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_168, buf229, 589824, grid=grid(589824), stream=stream0)
del primals_168
buf230 = buf212; del buf212 # reuse
# Source Nodes: [x_106], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_169, buf230, 768, grid=grid(768), stream=stream0)
del primals_169
buf231 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_106], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf230, reinterpret_tensor(buf225, (1576, 768), (768, 1), 0), reinterpret_tensor(buf229, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf231)
buf235 = buf184; del buf184 # reuse
buf236 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf383 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_14, x_108, x_109, x_110], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf214, primals_72, buf231, primals_80, primals_81, buf235, buf236, buf383, 1576, 768, grid=grid(1576), stream=stream0)
del primals_81
buf237 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_110], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_170, buf237, 2359296, grid=grid(2359296), stream=stream0)
del primals_170
buf238 = buf208; del buf208 # reuse
# Source Nodes: [x_110], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_171, buf238, 3072, grid=grid(3072), stream=stream0)
del primals_171
buf239 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_110], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf238, reinterpret_tensor(buf236, (1576, 768), (768, 1), 0), reinterpret_tensor(buf237, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf239)
buf240 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_111], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf239, buf240, 4841472, grid=grid(4841472), stream=stream0)
buf241 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_114], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_172, buf241, 2359296, grid=grid(2359296), stream=stream0)
del primals_172
buf242 = buf230; del buf230 # reuse
# Source Nodes: [x_114], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_173, buf242, 768, grid=grid(768), stream=stream0)
del primals_173
buf243 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_114], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf242, reinterpret_tensor(buf240, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf241, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf243)
buf244 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf248 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf249 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf382 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_14, mul_15, qkv_16, x_108, x_116, x_117], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf214, primals_72, buf231, primals_79, buf243, primals_83, primals_84, buf244, buf248, buf249, buf382, 1576, 768, grid=grid(1576), stream=stream0)
del primals_84
buf250 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_16], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_87, buf250, 1769472, grid=grid(1769472), stream=stream0)
del primals_87
buf251 = buf221; del buf221 # reuse
# Source Nodes: [qkv_16, qkv_bias_8], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_85, primals_216, primals_86, buf251, 2304, grid=grid(2304), stream=stream0)
del primals_216
del primals_85
del primals_86
buf252 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_16, qkv_bias_8], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf251, reinterpret_tensor(buf249, (1576, 768), (768, 1), 0), reinterpret_tensor(buf250, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf252)
buf253 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_118], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_217, primals_88, buf253, 472800, grid=grid(472800), stream=stream0)
del primals_88
# Source Nodes: [x_118], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf254 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf253, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf255 = buf254[0]
buf256 = buf254[1]
buf257 = buf254[2]
buf258 = buf254[3]
del buf254
buf259 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_120], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_174, buf259, 589824, grid=grid(589824), stream=stream0)
del primals_174
buf260 = buf242; del buf242 # reuse
# Source Nodes: [x_120], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_175, buf260, 768, grid=grid(768), stream=stream0)
del primals_175
buf261 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_120], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf260, reinterpret_tensor(buf255, (1576, 768), (768, 1), 0), reinterpret_tensor(buf259, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf261)
buf265 = buf214; del buf214 # reuse
buf266 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf381 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_16, x_122, x_123, x_124], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf244, primals_82, buf261, primals_90, primals_91, buf265, buf266, buf381, 1576, 768, grid=grid(1576), stream=stream0)
del primals_91
buf267 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_124], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_176, buf267, 2359296, grid=grid(2359296), stream=stream0)
del primals_176
buf268 = buf238; del buf238 # reuse
# Source Nodes: [x_124], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_177, buf268, 3072, grid=grid(3072), stream=stream0)
del primals_177
buf269 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_124], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf268, reinterpret_tensor(buf266, (1576, 768), (768, 1), 0), reinterpret_tensor(buf267, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf269)
buf270 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_125], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf269, buf270, 4841472, grid=grid(4841472), stream=stream0)
buf271 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_128], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_178, buf271, 2359296, grid=grid(2359296), stream=stream0)
del primals_178
buf272 = buf260; del buf260 # reuse
# Source Nodes: [x_128], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_179, buf272, 768, grid=grid(768), stream=stream0)
del primals_179
buf273 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_128], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf272, reinterpret_tensor(buf270, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf271, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf273)
buf274 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf278 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf279 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf380 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_16, mul_17, qkv_18, x_122, x_130, x_131], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf244, primals_82, buf261, primals_89, buf273, primals_93, primals_94, buf274, buf278, buf279, buf380, 1576, 768, grid=grid(1576), stream=stream0)
del primals_94
buf280 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_18], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_97, buf280, 1769472, grid=grid(1769472), stream=stream0)
del primals_97
buf281 = buf251; del buf251 # reuse
# Source Nodes: [qkv_18, qkv_bias_9], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_95, primals_218, primals_96, buf281, 2304, grid=grid(2304), stream=stream0)
del primals_218
del primals_95
del primals_96
buf282 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_18, qkv_bias_9], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf281, reinterpret_tensor(buf279, (1576, 768), (768, 1), 0), reinterpret_tensor(buf280, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf282)
buf283 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_132], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_219, primals_98, buf283, 472800, grid=grid(472800), stream=stream0)
del primals_98
# Source Nodes: [x_132], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf284 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf283, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf285 = buf284[0]
buf286 = buf284[1]
buf287 = buf284[2]
buf288 = buf284[3]
del buf284
buf289 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_134], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_180, buf289, 589824, grid=grid(589824), stream=stream0)
del primals_180
buf290 = buf272; del buf272 # reuse
# Source Nodes: [x_134], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_181, buf290, 768, grid=grid(768), stream=stream0)
del primals_181
buf291 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_134], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf290, reinterpret_tensor(buf285, (1576, 768), (768, 1), 0), reinterpret_tensor(buf289, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf291)
buf295 = buf244; del buf244 # reuse
buf296 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf379 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_18, x_136, x_137, x_138], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf274, primals_92, buf291, primals_100, primals_101, buf295, buf296, buf379, 1576, 768, grid=grid(1576), stream=stream0)
del primals_101
buf297 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_138], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_182, buf297, 2359296, grid=grid(2359296), stream=stream0)
del primals_182
buf298 = buf268; del buf268 # reuse
# Source Nodes: [x_138], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_183, buf298, 3072, grid=grid(3072), stream=stream0)
del primals_183
buf299 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_138], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf298, reinterpret_tensor(buf296, (1576, 768), (768, 1), 0), reinterpret_tensor(buf297, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf299)
buf300 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_139], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf299, buf300, 4841472, grid=grid(4841472), stream=stream0)
buf301 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_142], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_184, buf301, 2359296, grid=grid(2359296), stream=stream0)
del primals_184
buf302 = buf290; del buf290 # reuse
# Source Nodes: [x_142], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_185, buf302, 768, grid=grid(768), stream=stream0)
del primals_185
buf303 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_142], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf302, reinterpret_tensor(buf300, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf301, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf303)
buf304 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf308 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf309 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf378 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_18, mul_19, qkv_20, x_136, x_144, x_145], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf274, primals_92, buf291, primals_99, buf303, primals_103, primals_104, buf304, buf308, buf309, buf378, 1576, 768, grid=grid(1576), stream=stream0)
del primals_104
buf310 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_20], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_107, buf310, 1769472, grid=grid(1769472), stream=stream0)
del primals_107
buf311 = buf281; del buf281 # reuse
# Source Nodes: [qkv_20, qkv_bias_10], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_105, primals_220, primals_106, buf311, 2304, grid=grid(2304), stream=stream0)
del primals_105
del primals_106
del primals_220
buf312 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_20, qkv_bias_10], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf311, reinterpret_tensor(buf309, (1576, 768), (768, 1), 0), reinterpret_tensor(buf310, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf312)
buf313 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_146], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_221, primals_108, buf313, 472800, grid=grid(472800), stream=stream0)
del primals_108
# Source Nodes: [x_146], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf314 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf313, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf315 = buf314[0]
buf316 = buf314[1]
buf317 = buf314[2]
buf318 = buf314[3]
del buf314
buf319 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_148], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_186, buf319, 589824, grid=grid(589824), stream=stream0)
del primals_186
buf320 = buf302; del buf302 # reuse
# Source Nodes: [x_148], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_187, buf320, 768, grid=grid(768), stream=stream0)
del primals_187
buf321 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_148], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf320, reinterpret_tensor(buf315, (1576, 768), (768, 1), 0), reinterpret_tensor(buf319, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf321)
buf325 = buf274; del buf274 # reuse
buf326 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf377 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_20, x_150, x_151, x_152], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf304, primals_102, buf321, primals_110, primals_111, buf325, buf326, buf377, 1576, 768, grid=grid(1576), stream=stream0)
del primals_111
buf327 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_152], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_188, buf327, 2359296, grid=grid(2359296), stream=stream0)
del primals_188
buf328 = buf298; del buf298 # reuse
# Source Nodes: [x_152], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_189, buf328, 3072, grid=grid(3072), stream=stream0)
del primals_189
buf329 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_152], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf328, reinterpret_tensor(buf326, (1576, 768), (768, 1), 0), reinterpret_tensor(buf327, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf329)
buf330 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_153], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf329, buf330, 4841472, grid=grid(4841472), stream=stream0)
buf331 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_156], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_190, buf331, 2359296, grid=grid(2359296), stream=stream0)
del primals_190
buf332 = buf320; del buf320 # reuse
# Source Nodes: [x_156], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_191, buf332, 768, grid=grid(768), stream=stream0)
del primals_191
buf333 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_156], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf332, reinterpret_tensor(buf330, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf331, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf333)
buf334 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf338 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float32)
buf339 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf376 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_20, mul_21, qkv_22, x_150, x_158, x_159], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_13.run(buf304, primals_102, buf321, primals_109, buf333, primals_113, primals_114, buf334, buf338, buf339, buf376, 1576, 768, grid=grid(1576), stream=stream0)
del primals_114
buf340 = empty_strided_cuda((2304, 768), (768, 1), torch.float16)
# Source Nodes: [qkv_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_4.run(primals_117, buf340, 1769472, grid=grid(1769472), stream=stream0)
del primals_117
buf341 = buf311; del buf311 # reuse
# Source Nodes: [qkv_22, qkv_bias_11], Original ATen: [aten._to_copy, aten.cat]
triton_poi_fused__to_copy_cat_5.run(primals_115, primals_222, primals_116, buf341, 2304, grid=grid(2304), stream=stream0)
del primals_115
del primals_116
del primals_222
buf342 = empty_strided_cuda((1576, 2304), (2304, 1), torch.float16)
# Source Nodes: [qkv_22, qkv_bias_11], Original ATen: [aten._to_copy, aten.addmm, aten.cat]
extern_kernels.addmm(buf341, reinterpret_tensor(buf339, (1576, 768), (768, 1), 0), reinterpret_tensor(buf340, (768, 2304), (1, 768), 0), alpha=1, beta=1, out=buf342)
del buf341
buf343 = empty_strided_cuda((1, 12, 197, 200), (473088, 39424, 200, 1), torch.float16)
# Source Nodes: [x_160], Original ATen: [aten._to_copy, aten.constant_pad_nd]
triton_poi_fused__to_copy_constant_pad_nd_6.run(primals_223, primals_118, buf343, 472800, grid=grid(472800), stream=stream0)
del primals_118
# Source Nodes: [x_160], Original ATen: [aten._scaled_dot_product_efficient_attention]
buf344 = aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(buf343, (8, 12, 197, 197), (0, 39424, 200, 1), 0), True)
buf345 = buf344[0]
buf346 = buf344[1]
buf347 = buf344[2]
buf348 = buf344[3]
del buf344
buf349 = empty_strided_cuda((768, 768), (768, 1), torch.float16)
# Source Nodes: [x_162], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(primals_192, buf349, 589824, grid=grid(589824), stream=stream0)
del primals_192
buf350 = buf332; del buf332 # reuse
# Source Nodes: [x_162], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_193, buf350, 768, grid=grid(768), stream=stream0)
del primals_193
buf351 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_162], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf350, reinterpret_tensor(buf345, (1576, 768), (768, 1), 0), reinterpret_tensor(buf349, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf351)
buf355 = buf304; del buf304 # reuse
buf356 = empty_strided_cuda((8, 197, 768), (151296, 768, 1), torch.float16)
buf375 = empty_strided_cuda((8, 197, 1), (197, 1, 1), torch.float32)
# Source Nodes: [mul_22, x_164, x_165, x_166], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_add_mul_native_layer_norm_native_layer_norm_backward_14.run(buf334, primals_112, buf351, primals_120, primals_121, buf355, buf356, buf375, 1576, 768, grid=grid(1576), stream=stream0)
del primals_121
buf357 = empty_strided_cuda((3072, 768), (768, 1), torch.float16)
# Source Nodes: [x_166], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_9.run(primals_194, buf357, 2359296, grid=grid(2359296), stream=stream0)
del primals_194
buf358 = buf328; del buf328 # reuse
# Source Nodes: [x_166], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_10.run(primals_195, buf358, 3072, grid=grid(3072), stream=stream0)
del primals_195
buf359 = empty_strided_cuda((1576, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_166], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf358, reinterpret_tensor(buf356, (1576, 768), (768, 1), 0), reinterpret_tensor(buf357, (768, 3072), (1, 768), 0), alpha=1, beta=1, out=buf359)
del buf358
buf360 = empty_strided_cuda((8, 197, 3072), (605184, 3072, 1), torch.float16)
# Source Nodes: [x_167], Original ATen: [aten.gelu]
triton_poi_fused_gelu_11.run(buf359, buf360, 4841472, grid=grid(4841472), stream=stream0)
buf361 = empty_strided_cuda((768, 3072), (3072, 1), torch.float16)
# Source Nodes: [x_170], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_12.run(primals_196, buf361, 2359296, grid=grid(2359296), stream=stream0)
del primals_196
buf362 = buf350; del buf350 # reuse
# Source Nodes: [x_170], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_convolution_2.run(primals_197, buf362, 768, grid=grid(768), stream=stream0)
del primals_197
buf363 = empty_strided_cuda((1576, 768), (768, 1), torch.float16)
# Source Nodes: [x_170], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf362, reinterpret_tensor(buf360, (1576, 3072), (3072, 1), 0), reinterpret_tensor(buf361, (3072, 768), (1, 3072), 0), alpha=1, beta=1, out=buf363)
del buf362
buf364 = empty_strided_cuda((8, 768, 2), (1536, 1, 768), torch.float32)
# Source Nodes: [x_174], Original ATen: [aten.mean]
triton_red_fused_mean_15.run(buf334, primals_112, buf351, primals_119, buf363, buf364, 12288, 98, grid=grid(12288), stream=stream0)
del buf334
buf365 = empty_strided_cuda((8, 768), (768, 1), torch.float32)
# Source Nodes: [x_174], Original ATen: [aten.mean]
triton_per_fused_mean_16.run(buf364, buf365, 6144, 2, grid=grid(6144), stream=stream0)
del buf364
buf369 = empty_strided_cuda((8, 768), (768, 1), torch.float32)
buf370 = empty_strided_cuda((8, 768), (768, 1), torch.float16)
buf374 = empty_strided_cuda((8, 1), (1, 1), torch.float32)
# Source Nodes: [x_174, x_175, x_177], Original ATen: [aten._to_copy, aten.mean, aten.native_layer_norm, aten.native_layer_norm_backward]
triton_per_fused__to_copy_mean_native_layer_norm_native_layer_norm_backward_17.run(buf365, primals_122, primals_123, buf369, buf370, buf374, 8, 768, grid=grid(8), stream=stream0)
del buf365
del primals_123
buf371 = empty_strided_cuda((1000, 768), (768, 1), torch.float16)
# Source Nodes: [x_177], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_18.run(primals_198, buf371, 768000, grid=grid(768000), stream=stream0)
del primals_198
buf372 = empty_strided_cuda((1000, ), (1, ), torch.float16)
# Source Nodes: [x_177], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_19.run(primals_199, buf372, 1000, grid=grid(1000), stream=stream0)
del primals_199
buf373 = empty_strided_cuda((8, 1000), (1000, 1), torch.float16)
# Source Nodes: [x_177], Original ATen: [aten._to_copy, aten.addmm]
extern_kernels.addmm(buf372, buf370, reinterpret_tensor(buf371, (768, 1000), (1, 768), 0), alpha=1, beta=1, out=buf373)
del buf372
return (buf373, primals_2, primals_3, primals_9, primals_10, primals_12, primals_13, primals_19, primals_20, primals_22, primals_23, primals_29, primals_30, primals_32, primals_33, primals_39, primals_40, primals_42, primals_43, primals_49, primals_50, primals_52, primals_53, primals_59, primals_60, primals_62, primals_63, primals_69, primals_70, primals_72, primals_73, primals_79, primals_80, primals_82, primals_83, primals_89, primals_90, primals_92, primals_93, primals_99, primals_100, primals_102, primals_103, primals_109, primals_110, primals_112, primals_113, primals_119, primals_120, primals_122, buf0, buf1, buf4, buf5, buf8, reinterpret_tensor(buf9, (1576, 768), (768, 1), 0), reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf12, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_201, (38809, ), (1, ), 0), buf13, buf15, buf16, buf17, buf18, buf21, buf22, buf25, reinterpret_tensor(buf26, (1576, 768), (768, 1), 0), buf29, reinterpret_tensor(buf30, (1576, 3072), (3072, 1), 0), buf33, buf38, reinterpret_tensor(buf39, (1576, 768), (768, 1), 0), reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf42, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_203, (38809, ), (1, ), 0), buf43, buf45, buf46, buf47, buf48, buf51, buf55, reinterpret_tensor(buf56, (1576, 768), (768, 1), 0), buf59, reinterpret_tensor(buf60, (1576, 3072), (3072, 1), 0), buf63, buf68, reinterpret_tensor(buf69, (1576, 768), (768, 1), 0), reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf72, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_205, (38809, ), (1, ), 0), buf73, buf75, buf76, 
buf77, buf78, buf81, buf85, reinterpret_tensor(buf86, (1576, 768), (768, 1), 0), buf89, reinterpret_tensor(buf90, (1576, 3072), (3072, 1), 0), buf93, buf98, reinterpret_tensor(buf99, (1576, 768), (768, 1), 0), reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf102, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_207, (38809, ), (1, ), 0), buf103, buf105, buf106, buf107, buf108, buf111, buf115, reinterpret_tensor(buf116, (1576, 768), (768, 1), 0), buf119, reinterpret_tensor(buf120, (1576, 3072), (3072, 1), 0), buf123, buf128, reinterpret_tensor(buf129, (1576, 768), (768, 1), 0), reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf132, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_209, (38809, ), (1, ), 0), buf133, buf135, buf136, buf137, buf138, buf141, buf145, reinterpret_tensor(buf146, (1576, 768), (768, 1), 0), buf149, reinterpret_tensor(buf150, (1576, 3072), (3072, 1), 0), buf153, buf158, reinterpret_tensor(buf159, (1576, 768), (768, 1), 0), reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf162, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_211, (38809, ), (1, ), 0), buf163, buf165, buf166, buf167, buf168, buf171, buf175, reinterpret_tensor(buf176, (1576, 768), (768, 1), 0), buf179, reinterpret_tensor(buf180, (1576, 3072), (3072, 1), 0), buf183, buf188, reinterpret_tensor(buf189, (1576, 768), (768, 1), 0), reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf192, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), 
reinterpret_tensor(primals_213, (38809, ), (1, ), 0), buf193, buf195, buf196, buf197, buf198, buf201, buf205, reinterpret_tensor(buf206, (1576, 768), (768, 1), 0), buf209, reinterpret_tensor(buf210, (1576, 3072), (3072, 1), 0), buf213, buf218, reinterpret_tensor(buf219, (1576, 768), (768, 1), 0), reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf222, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_215, (38809, ), (1, ), 0), buf223, buf225, buf226, buf227, buf228, buf231, buf235, reinterpret_tensor(buf236, (1576, 768), (768, 1), 0), buf239, reinterpret_tensor(buf240, (1576, 3072), (3072, 1), 0), buf243, buf248, reinterpret_tensor(buf249, (1576, 768), (768, 1), 0), reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf252, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_217, (38809, ), (1, ), 0), buf253, buf255, buf256, buf257, buf258, buf261, buf265, reinterpret_tensor(buf266, (1576, 768), (768, 1), 0), buf269, reinterpret_tensor(buf270, (1576, 3072), (3072, 1), 0), buf273, buf278, reinterpret_tensor(buf279, (1576, 768), (768, 1), 0), reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf282, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_219, (38809, ), (1, ), 0), buf283, buf285, buf286, buf287, buf288, buf291, buf295, reinterpret_tensor(buf296, (1576, 768), (768, 1), 0), buf299, reinterpret_tensor(buf300, (1576, 3072), (3072, 1), 0), buf303, buf308, reinterpret_tensor(buf309, (1576, 768), (768, 1), 0), reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 
768), reinterpret_tensor(buf312, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_221, (38809, ), (1, ), 0), buf313, buf315, buf316, buf317, buf318, buf321, buf325, reinterpret_tensor(buf326, (1576, 768), (768, 1), 0), buf329, reinterpret_tensor(buf330, (1576, 3072), (3072, 1), 0), buf333, buf338, reinterpret_tensor(buf339, (1576, 768), (768, 1), 0), reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 0), reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 768), reinterpret_tensor(buf342, (8, 12, 197, 64), (453888, 64, 2304, 1), 1536), reinterpret_tensor(primals_223, (38809, ), (1, ), 0), buf343, buf345, buf346, buf347, buf348, buf351, buf355, reinterpret_tensor(buf356, (1576, 768), (768, 1), 0), buf359, reinterpret_tensor(buf360, (1576, 3072), (3072, 1), 0), buf363, buf369, buf370, reinterpret_tensor(buf371, (1000, 768), (768, 1), 0), buf374, reinterpret_tensor(buf361, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf357, (3072, 768), (768, 1), 0), buf375, reinterpret_tensor(buf349, (768, 768), (768, 1), 0), reinterpret_tensor(buf340, (2304, 768), (768, 1), 0), buf376, reinterpret_tensor(buf331, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf327, (3072, 768), (768, 1), 0), buf377, reinterpret_tensor(buf319, (768, 768), (768, 1), 0), reinterpret_tensor(buf310, (2304, 768), (768, 1), 0), buf378, reinterpret_tensor(buf301, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf297, (3072, 768), (768, 1), 0), buf379, reinterpret_tensor(buf289, (768, 768), (768, 1), 0), reinterpret_tensor(buf280, (2304, 768), (768, 1), 0), buf380, reinterpret_tensor(buf271, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf267, (3072, 768), (768, 1), 0), buf381, reinterpret_tensor(buf259, (768, 768), (768, 1), 0), reinterpret_tensor(buf250, (2304, 768), (768, 1), 0), buf382, reinterpret_tensor(buf241, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf237, (3072, 768), (768, 1), 0), buf383, reinterpret_tensor(buf229, (768, 768), 
(768, 1), 0), reinterpret_tensor(buf220, (2304, 768), (768, 1), 0), buf384, reinterpret_tensor(buf211, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf207, (3072, 768), (768, 1), 0), buf385, reinterpret_tensor(buf199, (768, 768), (768, 1), 0), reinterpret_tensor(buf190, (2304, 768), (768, 1), 0), buf386, reinterpret_tensor(buf181, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf177, (3072, 768), (768, 1), 0), buf387, reinterpret_tensor(buf169, (768, 768), (768, 1), 0), reinterpret_tensor(buf160, (2304, 768), (768, 1), 0), buf388, reinterpret_tensor(buf151, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf147, (3072, 768), (768, 1), 0), buf389, reinterpret_tensor(buf139, (768, 768), (768, 1), 0), reinterpret_tensor(buf130, (2304, 768), (768, 1), 0), buf390, reinterpret_tensor(buf121, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf117, (3072, 768), (768, 1), 0), buf391, reinterpret_tensor(buf109, (768, 768), (768, 1), 0), reinterpret_tensor(buf100, (2304, 768), (768, 1), 0), buf392, reinterpret_tensor(buf91, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf87, (3072, 768), (768, 1), 0), buf393, reinterpret_tensor(buf79, (768, 768), (768, 1), 0), reinterpret_tensor(buf70, (2304, 768), (768, 1), 0), buf394, reinterpret_tensor(buf61, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf57, (3072, 768), (768, 1), 0), buf395, reinterpret_tensor(buf49, (768, 768), (768, 1), 0), reinterpret_tensor(buf40, (2304, 768), (768, 1), 0), buf396, reinterpret_tensor(buf31, (768, 3072), (3072, 1), 0), reinterpret_tensor(buf27, (3072, 768), (768, 1), 0), reinterpret_tensor(buf19, (768, 768), (768, 1), 0), reinterpret_tensor(buf10, (2304, 768), (768, 1), 0), )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
primals_1 = rand_strided((1, 1, 768), (768, 768, 1), device='cuda:0', dtype=torch.float32)
primals_2 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_3 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_5 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_6 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_7 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_8 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_9 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_10 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_11 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_12 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_13 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_14 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_15 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_16 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_17 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_18 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_19 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_20 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_21 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_22 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_23 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_24 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_25 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_26 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_27 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_28 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_29 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_30 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_31 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_32 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_33 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_34 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_35 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_36 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_37 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_38 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_39 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_40 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_41 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_42 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_43 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_44 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_45 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_46 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_47 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_48 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_49 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_50 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_51 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_52 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_53 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_54 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_55 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_56 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_57 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_58 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_59 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_60 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_61 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_62 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_63 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_64 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_65 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_66 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_67 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_68 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_69 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_70 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_71 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_72 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_73 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_74 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_75 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_76 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_77 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_78 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_79 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_80 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_81 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_82 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_83 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_84 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_85 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_86 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_87 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_88 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_89 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_90 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_91 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_92 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_93 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_94 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_95 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_96 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_97 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_98 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_99 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_100 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_101 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_102 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_103 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_104 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_105 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_106 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_107 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_108 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_109 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_110 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_111 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_112 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_113 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_114 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_115 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_116 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_117 = rand_strided((2304, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_118 = rand_strided((732, 12), (12, 1), device='cuda:0', dtype=torch.float32)
primals_119 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_120 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_121 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_122 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_123 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_124 = rand_strided((768, 3, 16, 16), (768, 256, 16, 1), device='cuda:0', dtype=torch.float32)
primals_125 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_126 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_127 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_128 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_129 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_130 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_131 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_132 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_133 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_134 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_135 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_136 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_137 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_138 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_139 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_140 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_141 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_142 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_143 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_144 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_145 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_146 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_147 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_148 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_149 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_150 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_151 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_152 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_153 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_154 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_155 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_156 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_157 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_158 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_159 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_160 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_161 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_162 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_163 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_164 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_165 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_166 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_167 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_168 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_169 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_170 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_171 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_172 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_173 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_174 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_175 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_176 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_177 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_178 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_179 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_180 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_181 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_182 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_183 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_184 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_185 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_186 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_187 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_188 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_189 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_190 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_191 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_192 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_193 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_194 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_195 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_196 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.float32)
primals_197 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_198 = rand_strided((1000, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_199 = rand_strided((1000, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_200 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_201 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_202 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_203 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_204 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_205 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_206 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_207 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_208 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_209 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_210 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_211 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_212 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_213 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_214 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_215 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_216 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_217 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_218 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_219 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_220 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_221 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_222 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_223 = rand_strided((197, 197), (197, 1), device='cuda:0', dtype=torch.int64)
primals_224 = rand_strided((8, 3, 224, 224), (150528, 50176, 224, 1), device='cuda:0', dtype=torch.float32)
fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, 
primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point. Only runs when this file is executed directly.
    # The harness import lives inside the guard, so importing this module as a
    # library does not import torch._inductor.wrapper_benchmark.
    import torch._inductor.wrapper_benchmark as wrapper_benchmark

    # Hand the benchmark driver (`benchmark_compiled_module`, defined elsewhere
    # in this file) to Inductor's shared benchmark harness. The string is the
    # label passed through to the harness; presumably it is the model name.
    wrapper_benchmark.compiled_module_main(
        'beit_base_patch16_224',
        benchmark_compiled_module,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment