# Source: GitHub Gist ezyang/1fbf5bab298a0b35b47d5de6f885deff
# Title: "BERT_pytorch dynamic Triton"
# Author: @ezyang, created November 20, 2022 19:39
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
# Short alias for the ATen operator namespace, for use by the generated code.
aten = torch.ops.aten
# Bound from torch._C._dynamo.guards; presumably checks a tensor's size and
# stride at runtime -- confirm against the torch version in use.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Compiler handle: every triton_fused_* kernel below is declared through
# async_compile.triton(...), and async_compile.wait(globals()) is called on
# it once all kernels have been declared.
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 of in_ptr0
# and out_ptr1 starts at offset 768*x0. For each row, with x = in_ptr0:
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr1[row, r] = in_ptr1[r] * (x - mean) / (std + 1e-06) + in_ptr2[r]
# The per-row mean and std are also stored to in_out_ptr0 / in_out_ptr1.
# This matches a normalization of the form w * (x - mean) / (std + eps) + b.
# Presumably in_ptr1 / in_ptr2 are the weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_1_div_mul_std_sub_mean_add0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp1 = tl.where(xmask & rmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.reshape(tl.sum(_tmp1, 1), [XBLOCK, 1])
    tmp2 = tmp1
    _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp3 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp4 = 768
        tmp5 = tmp2 / tmp4
        tmp6 = tmp3 - tmp5
        tmp7 = tmp6 * tmp6
        _tmp8 = tl.where(xmask & rmask, _tmp8 + tmp7, _tmp8)
    tmp8 = tl.reshape(tl.sum(_tmp8, 1), [XBLOCK, 1])
    tmp9 = 768
    tmp10 = tmp1 / tmp9
    tmp11 = 767
    tmp12 = tmp8 / tmp11
    tmp13 = tl.sqrt(tmp12)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp10, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp13, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp14 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last')
        tmp15 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp21 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last')
        tmp16 = tmp15 - tmp10
        tmp17 = tmp14 * tmp16
        tmp18 = 1e-06
        tmp19 = tmp13 + tmp18
        tmp20 = tmp17 / tmp19
        tmp22 = tmp20 + tmp21
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp22, xmask & rmask)
''')
# Pointwise copy that swaps the middle two axes:
#   contiguous [x3, ks0, 12, 64] input -> contiguous [x3, 12, ks0, 64] output.
# ks0 is a runtime size. The output linear index decomposes as
#   x0 = i % 64, x1 = (i // 64) % ks0, x2 = (i // (64*ks0)) % 12,
#   x3 = i // (768*ks0),
# and the input is read at x0 + 64*x2 + 768*x1 + 768*ks0*x3.
# Presumably this splits a 768-wide hidden dimension into 12 heads of 64 -- confirm.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x1 = (xindex // 64) % ks0
    x2 = (xindex // (64*ks0)) % 12
    x3 = (xindex // (768*ks0))
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (768*ks0*x3)), xmask)
    tl.store(out_ptr0 + (x4 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# 2-D tiled copy that swaps the last two axes:
#   [x1, 128, 768] input -> [x1, 768, 128] output.
# Decomposition: x0 = xindex % 768, x1 = xindex // 768, y2 = yindex.
# The input is read at x0 + 768*y2 + 98304*x1 (98304 = 128*768), and the
# output is written at y2 + 128*xindex.
# NOTE(review): the strides hard-code 128 along y; presumably ynumel == 128.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[2048, 128], tile_hint=TileHint.SQUARE,filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, ynumel, XBLOCK : tl.constexpr, YBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.reshape(tl.arange(0, YBLOCK), [1, YBLOCK])
    ymask = yindex < ynumel
    x0 = xindex % 768
    x1 = (xindex // 768)
    y2 = yindex
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*y2) + (98304*x1)), xmask & ymask)
    tl.store(out_ptr0 + (y2 + (128*x3) + tl.zeros([XBLOCK, YBLOCK], tl.int32)), tmp0, xmask & ymask)
''')
# Row-wise masked, scaled softmax over rows of 128 fp32 values. Score row x5
# of in_ptr1 / out_ptr2 starts at offset 128*x5. Each element becomes
#   v = -1000000000.0 if mask == 0 else score / 8.0
# The boolean mask (in_ptr0, '*i1') is read at r + 128*(x5 // 1536), so one
# mask row is shared by 1536 consecutive score rows. Three passes per row:
#   pass 1: m = max(v)
#   pass 2: s = sum(exp(v - m))
#   pass 3: out_ptr2[row, r] = exp(v - m) / s
# Subtracting the row max keeps exp() from overflowing.
# Presumably 8.0 = sqrt(64) (scaling for 64-wide heads) and 1536 = 12 * 128
# (heads * query positions) -- confirm against the model.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[4096, 128],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*i1', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x4 = (xindex // 1536)
    x5 = xindex
    _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = 0
        tmp2 = tmp0 == tmp1
        tmp3 = -1000000000.0
        tmp5 = 8.0
        tmp6 = tmp4 / tmp5
        tmp7 = tl.where(tmp2, tmp3, tmp6)
        _tmp8 = tl.where(xmask & rmask & (_tmp8 < tmp7), tmp7, _tmp8)
    tmp8 = tl.reshape(tl.max(_tmp8, 1), [XBLOCK, 1])
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp9 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = 0
        tmp11 = tmp9 == tmp10
        tmp12 = -1000000000.0
        tmp14 = 8.0
        tmp15 = tmp13 / tmp14
        tmp16 = tl.where(tmp11, tmp12, tmp15)
        tmp17 = tmp16 - tmp8
        tmp18 = tl.exp(tmp17)
        _tmp19 = tl.where(xmask & rmask, _tmp19 + tmp18, _tmp19)
    tmp19 = tl.reshape(tl.sum(_tmp19, 1), [XBLOCK, 1])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp20 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp21 = 0
        tmp22 = tmp20 == tmp21
        tmp23 = -1000000000.0
        tmp25 = 8.0
        tmp26 = tmp24 / tmp25
        tmp27 = tl.where(tmp22, tmp23, tmp26)
        tmp28 = tmp27 - tmp8
        tmp29 = tl.exp(tmp28)
        tmp30 = tmp29 / tmp19
        tl.store(out_ptr2 + (r2 + (128*x5) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp30, xmask & rmask)
''')
# Pointwise copy that swaps axes 1 and 2:
#   contiguous [x6, 128, 12, 64] input -> contiguous [x6, 12, 128, 64] output.
# Output index decomposition:
#   x0 = i % 64, x4 = (i // 64) % 128, x5 = (i // 8192) % 12, x6 = i // 98304,
# and the input is read at x0 + 64*x5 + 768*x4 + 98304*x6.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x4 = (xindex // 64) % 128
    x5 = (xindex // 8192) % 12
    x6 = (xindex // 98304)
    x7 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x5) + (768*x4) + (98304*x6)), xmask)
    tl.store(out_ptr0 + (x7 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# Pointwise copy that swaps axes 1 and 2:
#   contiguous [x5, 12, 128, 64] input -> contiguous [x5, 128, 12, 64] output.
# This is the inverse of the [*, 128, 12, 64] -> [*, 12, 128, 64] reorder.
# Output index decomposition:
#   x0 = i % 64, x1 = (i // 64) % 12, x4 = (i // 768) % 128, x5 = i // 98304,
# and the input is read at x0 + 64*x4 + 8192*x1 + 98304*x5.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x1 = (xindex // 64) % 12
    x4 = (xindex // 768) % 128
    x5 = (xindex // 98304)
    x6 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x4) + (8192*x1) + (98304*x5)), xmask)
    tl.store(out_ptr0 + (x6 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 starts at
# offset 768*x0. For each row, with x = in_ptr0 + in_ptr1 (elementwise):
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr1[row, r] = in_ptr2[r] * (x - mean) / (std + 1e-06) + in_ptr3[r]
# The per-row mean and std are also stored to in_out_ptr0 / in_out_ptr1.
# x is recomputed from both inputs in every pass instead of being stored.
# Presumably in_ptr2 / in_ptr3 are weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp3 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        _tmp3 = tl.where(xmask & rmask, _tmp3 + tmp2, _tmp3)
    tmp3 = tl.reshape(tl.sum(_tmp3, 1), [XBLOCK, 1])
    tmp4 = tmp3
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp5 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp6 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tmp5 + tmp6
        tmp8 = 768
        tmp9 = tmp4 / tmp8
        tmp10 = tmp7 - tmp9
        tmp11 = tmp10 * tmp10
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp11, _tmp12)
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    tmp13 = 768
    tmp14 = tmp3 / tmp13
    tmp15 = 767
    tmp16 = tmp12 / tmp15
    tmp17 = tl.sqrt(tmp16)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp14, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp17, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp18 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last')
        tmp19 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp20 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last')
        tmp21 = tmp19 + tmp20
        tmp22 = tmp21 - tmp14
        tmp23 = tmp18 * tmp22
        tmp24 = 1e-06
        tmp25 = tmp17 + tmp24
        tmp26 = tmp23 / tmp25
        tmp28 = tmp26 + tmp27
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp28, xmask & rmask)
''')
# Elementwise GELU and its derivative over xnumel fp32 values of in_ptr0:
#   out_ptr0 = 0.5 * x * (1 + erf(x / sqrt(2)))
#   out_ptr1 = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x*x / 2) / sqrt(2*pi)
# out_ptr1 is d/dx of out_ptr0 (presumably kept for the backward pass).
# erf is not called directly. It is evaluated with the Abramowitz & Stegun
# 7.1.26 approximation:
#   erf(z) ~= sign(z) * (1 - poly(t) * exp(-z*z)),  t = 1 / (1 + 0.3275911*|z|)
# poly is evaluated in Horner form with coefficients 0.254829592,
# -0.284496736, 1.421413741, -1.453152027, 1.061405429.
# Constants: 0.7071067811865476 = 1/sqrt(2), 0.3989422804014327 = 1/sqrt(2*pi).
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[1048576], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.5
    tmp2 = tmp0 * tmp1
    tmp3 = 0.7071067811865476
    tmp4 = tmp0 * tmp3
    tmp5 = tl.where(tmp4 < 0, -1, 1)
    tmp6 = tl.where(tmp4 == 0, 0, tmp5)
    tmp7 = 1.0
    tmp8 = tl.abs(tmp4)
    tmp9 = 0.3275911
    tmp10 = tmp8 * tmp9
    tmp11 = tmp10 + tmp7
    tmp12 = 1 / tmp11
    tmp13 = tmp12 * tmp7
    tmp14 = 1.061405429
    tmp15 = tmp13 * tmp14
    tmp16 = -1.453152027
    tmp17 = tmp15 + tmp16
    tmp18 = tmp17 * tmp13
    tmp19 = 1.421413741
    tmp20 = tmp18 + tmp19
    tmp21 = tmp20 * tmp13
    tmp22 = -0.284496736
    tmp23 = tmp21 + tmp22
    tmp24 = tmp23 * tmp13
    tmp25 = 0.254829592
    tmp26 = tmp24 + tmp25
    tmp27 = tmp26 * tmp13
    tmp28 = -tmp8
    tmp29 = tmp28 * tmp8
    tmp30 = tl.exp(tmp29)
    tmp31 = tmp27 * tmp30
    tmp32 = tmp7 - tmp31
    tmp33 = tmp6 * tmp32
    tmp34 = 1
    tmp35 = tmp33 + tmp34
    tmp36 = tmp2 * tmp35
    tmp37 = tmp35 * tmp1
    tmp38 = tmp0 * tmp0
    tmp39 = -0.5
    tmp40 = tmp38 * tmp39
    tmp41 = tl.exp(tmp40)
    tmp42 = 0.3989422804014327
    tmp43 = tmp41 * tmp42
    tmp44 = tmp0 * tmp43
    tmp45 = tmp37 + tmp44
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp36, xmask)
    tl.store(out_ptr1 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp45, xmask)
''')
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 starts at
# offset 768*x0. For each row, with x = in_ptr0 + in_ptr1 + in_ptr2
# (elementwise, left to right):
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr1[row, r] = in_ptr3[r] * (x - mean) / (std + 1e-06) + in_ptr4[r]
# The per-row mean and std are also stored to in_out_ptr0 / in_out_ptr1.
# x is recomputed from the inputs in every pass instead of being stored.
# Presumably in_ptr3 / in_ptr4 are weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp5 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        _tmp5 = tl.where(xmask & rmask, _tmp5 + tmp4, _tmp5)
    tmp5 = tl.reshape(tl.sum(_tmp5, 1), [XBLOCK, 1])
    tmp6 = tmp5
    _tmp16 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp7 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = 768
        tmp13 = tmp6 / tmp12
        tmp14 = tmp11 - tmp13
        tmp15 = tmp14 * tmp14
        _tmp16 = tl.where(xmask & rmask, _tmp16 + tmp15, _tmp16)
    tmp16 = tl.reshape(tl.sum(_tmp16, 1), [XBLOCK, 1])
    tmp17 = 768
    tmp18 = tmp5 / tmp17
    tmp19 = 767
    tmp20 = tmp16 / tmp19
    tmp21 = tl.sqrt(tmp20)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp18, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp21, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp22 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last')
        tmp23 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp26 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp33 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp23 + tmp24
        tmp27 = tmp25 + tmp26
        tmp28 = tmp27 - tmp18
        tmp29 = tmp22 * tmp28
        tmp30 = 1e-06
        tmp31 = tmp21 + tmp30
        tmp32 = tmp29 / tmp31
        tmp34 = tmp32 + tmp33
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp34, xmask & rmask)
''')
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 starts at
# offset 768*x0. For each row, with x = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3
# (elementwise, left to right):
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr1[row, r] = in_ptr4[r] * (x - mean) / (std + 1e-06) + in_ptr5[r]
# The per-row mean and std are also stored to in_out_ptr0 / in_out_ptr1.
# x is recomputed from the inputs in every pass instead of being stored.
# Presumably in_ptr4 / in_ptr5 are weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: 'i32', 10: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        _tmp7 = tl.where(xmask & rmask, _tmp7 + tmp6, _tmp7)
    tmp7 = tl.reshape(tl.sum(_tmp7, 1), [XBLOCK, 1])
    tmp8 = tmp7
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp14 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp11 = tmp9 + tmp10
        tmp13 = tmp11 + tmp12
        tmp15 = tmp13 + tmp14
        tmp16 = 768
        tmp17 = tmp8 / tmp16
        tmp18 = tmp15 - tmp17
        tmp19 = tmp18 * tmp18
        _tmp20 = tl.where(xmask & rmask, _tmp20 + tmp19, _tmp20)
    tmp20 = tl.reshape(tl.sum(_tmp20, 1), [XBLOCK, 1])
    tmp21 = 768
    tmp22 = tmp7 / tmp21
    tmp23 = 767
    tmp24 = tmp20 / tmp23
    tmp25 = tl.sqrt(tmp24)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp25, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp26 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp28 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp32 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp39 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp29 = tmp27 + tmp28
        tmp31 = tmp29 + tmp30
        tmp33 = tmp31 + tmp32
        tmp34 = tmp33 - tmp22
        tmp35 = tmp26 * tmp34
        tmp36 = 1e-06
        tmp37 = tmp25 + tmp36
        tmp38 = tmp35 / tmp37
        tmp40 = tmp38 + tmp39
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp40, xmask & rmask)
''')
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 starts at
# offset 768*x0.
#   pass 0: out_ptr0 = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3 + in_ptr4
#           (elementwise, left to right; materialised and re-read below)
# Then, per row, with x = out_ptr0:
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr2[row, r] = in_ptr5[r] * (x - mean) / (std + 1e-06) + in_ptr6[r]
# The per-row mean and std are also stored to in_out_ptr0 / in_out_ptr1.
# Each program re-reads only the out_ptr0 rows that it wrote in pass 0.
# Presumably in_ptr5 / in_ptr6 are weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: 'i32', 12: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tl.load(in_ptr4 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tl.store(out_ptr0 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp8, xmask & rmask)
    _tmp10 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp10 = tl.where(xmask & rmask, _tmp10 + tmp9, _tmp10)
    tmp10 = tl.reshape(tl.sum(_tmp10, 1), [XBLOCK, 1])
    tmp11 = tmp10
    _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp12 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = 768
        tmp14 = tmp11 / tmp13
        tmp15 = tmp12 - tmp14
        tmp16 = tmp15 * tmp15
        _tmp17 = tl.where(xmask & rmask, _tmp17 + tmp16, _tmp17)
    tmp17 = tl.reshape(tl.sum(_tmp17, 1), [XBLOCK, 1])
    tmp18 = 768
    tmp19 = tmp10 / tmp18
    tmp20 = 767
    tmp21 = tmp17 / tmp20
    tmp22 = tl.sqrt(tmp21)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp23 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp24 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr6 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp24 - tmp19
        tmp26 = tmp23 * tmp25
        tmp27 = 1e-06
        tmp28 = tmp22 + tmp27
        tmp29 = tmp26 / tmp28
        tmp31 = tmp29 + tmp30
        tl.store(out_ptr2 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp31, xmask & rmask)
''')
# Fused reduction over rows of 768 contiguous fp32 values. Row x0 starts at
# offset 768*x0.
#   pass 0: in_out_ptr0 = in_out_ptr0 + in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3
#           (in place, elementwise, left to right)
# Then, per row, with x = the updated in_out_ptr0:
#   pass 1: mean = sum(x) / 768
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)   (767 = 768 - 1: unbiased)
#   pass 3: out_ptr1[row, r] = in_ptr4[r] * (x - mean) / (std + 1e-06) + in_ptr5[r]
# The per-row mean and std are stored to in_out_ptr1 / in_out_ptr2.
# Presumably in_ptr4 / in_ptr5 are weight / bias vectors -- confirm.
# NOTE(review): the 768 row stride is hard-coded although rnumel is passed at
# runtime; the kernel presumably assumes rnumel == 768.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_out_ptr2, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tl.store(in_out_ptr0 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp8, xmask & rmask)
    _tmp10 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp10 = tl.where(xmask & rmask, _tmp10 + tmp9, _tmp10)
    tmp10 = tl.reshape(tl.sum(_tmp10, 1), [XBLOCK, 1])
    tmp11 = tmp10
    _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp12 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = 768
        tmp14 = tmp11 / tmp13
        tmp15 = tmp12 - tmp14
        tmp16 = tmp15 * tmp15
        _tmp17 = tl.where(xmask & rmask, _tmp17 + tmp16, _tmp17)
    tmp17 = tl.reshape(tl.sum(_tmp17, 1), [XBLOCK, 1])
    tmp18 = 768
    tmp19 = tmp10 / tmp18
    tmp20 = 767
    tmp21 = tmp17 / tmp20
    tmp22 = tl.sqrt(tmp21)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)
    tl.store(in_out_ptr2 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp23 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp24 - tmp19
        tmp26 = tmp23 * tmp25
        tmp27 = 1e-06
        tmp28 = tmp22 + tmp27
        tmp29 = tmp26 / tmp28
        tmp31 = tmp29 + tmp30
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp31, xmask & rmask)
''')
# Elementwise in-place accumulation over xnumel fp32 values:
#   in_out_ptr0[i] = (((in_out_ptr0[i] + in_ptr0[i]) + in_ptr1[i]) + in_ptr2[i]) + in_ptr3[i]
# The fp32 additions happen in exactly this left-to-right order.
# NOTE: the kernel source string is parsed as Python by async_compile.triton /
# triton.jit, so its indentation is significant.
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr0 + (x0), xmask)
    tmp3 = tl.load(in_ptr1 + (x0), xmask)
    tmp5 = tl.load(in_ptr2 + (x0), xmask)
    tmp7 = tl.load(in_ptr3 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp8, xmask)
''')
# Wait for all async_compile.triton(...) kernels declared above. wait() is
# passed this module's globals(), presumably so it can rebind each pending
# kernel name to its compiled kernel -- confirm against
# torch._inductor.codecache.AsyncCompile.
async_compile.wait(globals())
# The compiler handle is not needed after this point, so drop the module binding.
del async_compile
def call(args):
primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, 
primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194 = args
args.clear()
primals_193_size = primals_193.size()
s0 = primals_193_size[0]
s1 = primals_193_size[1]
buf0 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf3 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf1 = buf0; del buf0 # reuse
buf4 = buf3; del buf3 # reuse
buf5 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel = s0*s1
stream0 = get_cuda_stream(0)
triton_fused_add_1_div_mul_std_sub_mean_add0.run(buf1, buf4, primals_193, primals_1, primals_2, buf5, triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel, 768, grid=grid(triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel), stream=stream0)
buf6 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_50, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_49, (768, 768), (1, 768)), beta=1, alpha=1, out=buf6)
del primals_50
buf7 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_52, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_51, (768, 768), (1, 768)), beta=1, alpha=1, out=buf7)
del primals_52
buf8 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_54, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_53, (768, 768), (1, 768)), beta=1, alpha=1, out=buf8)
del primals_54
buf9 = as_strided(buf5, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf5 # reuse
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel = 768*s0*s1
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11.run(buf6, buf9, s1, triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel, grid=grid(triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel), stream=stream0)
buf10 = as_strided(buf6, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf6 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf7, buf10, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf11 = empty_strided((12*s0, 128, 128), (16384, 128, 1), device='cuda', dtype=torch.float32)
aten.bmm.out(as_strided(buf9, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf10, (12*s0, 64, 128), (8192, 128, 1)), out=buf11)
buf14 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf11, buf14, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
buf15 = as_strided(buf7, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf7 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf8, buf15, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf16 = as_strided(buf8, (12*s0, 128, 64), (8192, 64, 1)); del buf8 # reuse
aten.bmm.out(as_strided(buf14, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf15, (12*s0, 128, 64), (8192, 64, 1)), out=buf16)
buf17 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf16, buf17, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf18 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_56, as_strided(buf17, (128*s0, 768), (768, 1)), as_strided(primals_55, (768, 768), (1, 768)), beta=1, alpha=1, out=buf18)
del primals_56
buf19 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf22 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf20 = buf19; del buf19 # reuse
buf23 = buf22; del buf22 # reuse
buf24 = as_strided(buf17, (s0, 128, 768), (98304, 768, 1)); del buf17 # reuse
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf20, buf23, primals_193, buf18, primals_3, primals_4, buf24, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
buf25 = empty_strided((128*s0, 3072), (3072, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_58, as_strided(buf24, (128*s0, 768), (768, 1)), as_strided(primals_57, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf25)
del primals_58
buf26 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf353 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf25, buf26, buf353, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf27 = as_strided(buf24, (128*s0, 768), (768, 1)); del buf24 # reuse
aten.addmm.out(primals_60, as_strided(buf26, (128*s0, 3072), (3072, 1)), as_strided(primals_59, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf27)
del primals_60
buf28 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf31 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf29 = buf28; del buf28 # reuse
buf32 = buf31; del buf31 # reuse
buf33 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf29, buf32, primals_193, buf18, buf27, primals_5, primals_6, buf33, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
buf34 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_62, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_61, (768, 768), (1, 768)), beta=1, alpha=1, out=buf34)
del primals_62
buf35 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_64, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_63, (768, 768), (1, 768)), beta=1, alpha=1, out=buf35)
del primals_64
buf36 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_66, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_65, (768, 768), (1, 768)), beta=1, alpha=1, out=buf36)
del primals_66
buf37 = as_strided(buf33, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf33 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf34, buf37, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf38 = as_strided(buf34, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf34 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf35, buf38, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf39 = buf11; del buf11 # reuse
aten.bmm.out(as_strided(buf37, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf38, (12*s0, 64, 128), (8192, 128, 1)), out=buf39)
buf42 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf39, buf42, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
buf43 = as_strided(buf35, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf35 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf36, buf43, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf44 = as_strided(buf36, (12*s0, 128, 64), (8192, 64, 1)); del buf36 # reuse
aten.bmm.out(as_strided(buf42, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf43, (12*s0, 128, 64), (8192, 64, 1)), out=buf44)
buf45 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf44, buf45, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf46 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_68, as_strided(buf45, (128*s0, 768), (768, 1)), as_strided(primals_67, (768, 768), (1, 768)), beta=1, alpha=1, out=buf46)
del primals_68
buf47 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf50 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf48 = buf47; del buf47 # reuse
buf51 = buf50; del buf50 # reuse
buf52 = as_strided(buf45, (s0, 128, 768), (98304, 768, 1)); del buf45 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf48, buf51, primals_193, buf18, buf27, buf46, primals_7, primals_8, buf52, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
buf53 = buf25; del buf25 # reuse
aten.addmm.out(primals_70, as_strided(buf52, (128*s0, 768), (768, 1)), as_strided(primals_69, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf53)
del primals_70
buf54 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf352 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf53, buf54, buf352, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf55 = as_strided(buf52, (128*s0, 768), (768, 1)); del buf52 # reuse
aten.addmm.out(primals_72, as_strided(buf54, (128*s0, 3072), (3072, 1)), as_strided(primals_71, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf55)
del primals_72
buf56 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
buf57 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf60 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf58 = buf57; del buf57 # reuse
buf61 = buf60; del buf60 # reuse
buf62 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel = 128*s0
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410.run(buf58, buf61, primals_193, buf18, buf27, buf46, buf55, primals_9, primals_10, buf56, buf62, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel), stream=stream0)
buf63 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_74, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_73, (768, 768), (1, 768)), beta=1, alpha=1, out=buf63)
del primals_74
buf64 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_76, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_75, (768, 768), (1, 768)), beta=1, alpha=1, out=buf64)
del primals_76
buf65 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_78, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_77, (768, 768), (1, 768)), beta=1, alpha=1, out=buf65)
del primals_78
buf66 = as_strided(buf62, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf62 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf63, buf66, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf67 = as_strided(buf63, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf63 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf64, buf67, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf68 = buf39; del buf39 # reuse
aten.bmm.out(as_strided(buf66, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf67, (12*s0, 64, 128), (8192, 128, 1)), out=buf68)
buf71 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf68, buf71, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
buf72 = as_strided(buf64, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf64 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf65, buf72, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf73 = as_strided(buf65, (12*s0, 128, 64), (8192, 64, 1)); del buf65 # reuse
aten.bmm.out(as_strided(buf71, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf72, (12*s0, 128, 64), (8192, 64, 1)), out=buf73)
buf74 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf73, buf74, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf75 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_80, as_strided(buf74, (128*s0, 768), (768, 1)), as_strided(primals_79, (768, 768), (1, 768)), beta=1, alpha=1, out=buf75)
del primals_80
buf76 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf79 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf77 = buf76; del buf76 # reuse
buf80 = buf79; del buf79 # reuse
buf81 = as_strided(buf74, (s0, 128, 768), (98304, 768, 1)); del buf74 # reuse
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf77, buf80, buf56, buf75, primals_11, primals_12, buf81, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
buf82 = buf53; del buf53 # reuse
aten.addmm.out(primals_82, as_strided(buf81, (128*s0, 768), (768, 1)), as_strided(primals_81, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf82)
del primals_82
buf83 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf351 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf82, buf83, buf351, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf84 = as_strided(buf81, (128*s0, 768), (768, 1)); del buf81 # reuse
aten.addmm.out(primals_84, as_strided(buf83, (128*s0, 3072), (3072, 1)), as_strided(primals_83, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf84)
del primals_84
buf85 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf88 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf86 = buf85; del buf85 # reuse
buf89 = buf88; del buf88 # reuse
buf90 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf86, buf89, buf56, buf75, buf84, primals_13, primals_14, buf90, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
buf91 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_86, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_85, (768, 768), (1, 768)), beta=1, alpha=1, out=buf91)
del primals_86
buf92 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_88, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_87, (768, 768), (1, 768)), beta=1, alpha=1, out=buf92)
del primals_88
buf93 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_90, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_89, (768, 768), (1, 768)), beta=1, alpha=1, out=buf93)
del primals_90
buf94 = as_strided(buf90, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf90 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf91, buf94, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf95 = as_strided(buf91, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf91 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf92, buf95, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf96 = buf68; del buf68 # reuse
aten.bmm.out(as_strided(buf94, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf95, (12*s0, 64, 128), (8192, 128, 1)), out=buf96)
buf99 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf96, buf99, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
buf100 = as_strided(buf92, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf92 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf93, buf100, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf101 = as_strided(buf93, (12*s0, 128, 64), (8192, 64, 1)); del buf93 # reuse
aten.bmm.out(as_strided(buf99, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf100, (12*s0, 128, 64), (8192, 64, 1)), out=buf101)
buf102 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf101, buf102, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf103 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_92, as_strided(buf102, (128*s0, 768), (768, 1)), as_strided(primals_91, (768, 768), (1, 768)), beta=1, alpha=1, out=buf103)
del primals_92
buf104 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf107 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf105 = buf104; del buf104 # reuse
buf108 = buf107; del buf107 # reuse
buf109 = as_strided(buf102, (s0, 128, 768), (98304, 768, 1)); del buf102 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf105, buf108, buf56, buf75, buf84, buf103, primals_15, primals_16, buf109, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
buf110 = buf82; del buf82 # reuse
aten.addmm.out(primals_94, as_strided(buf109, (128*s0, 768), (768, 1)), as_strided(primals_93, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf110)
del primals_94
buf111 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf350 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf110, buf111, buf350, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf112 = as_strided(buf109, (128*s0, 768), (768, 1)); del buf109 # reuse
aten.addmm.out(primals_96, as_strided(buf111, (128*s0, 3072), (3072, 1)), as_strided(primals_95, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf112)
del primals_96
buf113 = buf56; del buf56 # reuse
buf114 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf117 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf115 = buf114; del buf114 # reuse
buf118 = buf117; del buf117 # reuse
buf119 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf113, buf115, buf118, buf75, buf84, buf103, buf112, primals_17, primals_18, buf119, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0)
buf120 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_98, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_97, (768, 768), (1, 768)), beta=1, alpha=1, out=buf120)
del primals_98
buf121 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_100, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_99, (768, 768), (1, 768)), beta=1, alpha=1, out=buf121)
del primals_100
buf122 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_102, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_101, (768, 768), (1, 768)), beta=1, alpha=1, out=buf122)
del primals_102
buf123 = as_strided(buf119, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf119 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
# --- Attention core: buf120 / buf121 / buf122 -> buf132 ---------------------
# Shapes fit multi-head self-attention with 12 heads of 64 (12*64 = 768),
# 128 positions and batch s0 -- presumably BERT_pytorch attention per the gist
# title; the fused kernel sources are not in this chunk, so TODO confirm.
# Lay buf120 out in buf123 as (s0, 12, 128, 64) for the score bmm; this
# launch's xnumel variable was assigned before this chunk.
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf120, buf123, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf124 takes over buf120's storage (no copy) as (s0, 12, 64, 128): the
# right-hand operand layout of the score bmm.
buf124 = as_strided(buf120, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf120 # reuse
# 2-D launch of 768*s0 by 128, covering buf124's 98304*s0 elements.
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf121, buf124, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: buf125 takes over buf96's storage and receives the
# (12*s0, 128, 128) batched product buf123 @ buf124.
buf125 = buf96; del buf96 # reuse
aten.bmm.out(as_strided(buf123, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf124, (12*s0, 64, 128), (8192, 128, 1)), out=buf125)
# One 128-element row per (batch, head, query) = 1536*s0 rows. The kernel
# name (amax/sub/div) suggests a softmax; primals_194 is an extra input,
# presumably mask-related -- TODO confirm against the kernel source.
buf128 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf125, buf128, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# Lay buf122 out as (s0, 12, 128, 64) in buf129 (buf121's storage) for the
# second bmm.
buf129 = as_strided(buf121, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf121 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf122, buf129, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf128 @ buf129, written into buf122's storage.
buf130 = as_strided(buf122, (12*s0, 128, 64), (8192, 64, 1)); del buf122 # reuse
aten.bmm.out(as_strided(buf128, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf129, (12*s0, 128, 64), (8192, 64, 1)), out=buf130)
# buf131 is position-major (s0, 128, 12, 64), so the addmm below can read it
# as a contiguous (128*s0, 768) matrix; the kernel presumably transposes
# buf130's head-major result into it.
buf131 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf130, buf131, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection: buf132 = primals_104 + buf131 @ W, where W is
# primals_103 viewed with strides (1, 768) -- its transpose if stored
# row-major.
buf132 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_104, as_strided(buf131, (128*s0, 768), (768, 1)), as_strided(primals_103, (768, 768), (1, 768)), beta=1, alpha=1, out=buf132)
# Drop this name's reference to the added (bias) term once consumed.
del primals_104
# --- Normalize, feed-forward, normalize: -> buf147 --------------------------
# Each fused norm kernel writes two (s0, 128, 1) buffers: one value per
# (batch, position) row. The kernel names (mean/sub/div) suggest
# normalization statistics -- TODO confirm against the kernel source.
buf133 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf136 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
# Renames only: storage is handed over without a copy.
buf134 = buf133; del buf133 # reuse
buf137 = buf136; del buf136 # reuse
# buf138 takes over buf131's storage as (s0, 128, 768).
buf138 = as_strided(buf131, (s0, 128, 768), (98304, 768, 1)); del buf131 # reuse
# 128*s0 rows, each of 768 elements (the trailing 768 argument below).
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
# Reads buf113 and buf132 (presumably combined as a residual sum) plus
# primals_19 / primals_20 (presumably the norm's scale and shift); writes
# buf138 and the per-row buffers buf134 / buf137.
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf134, buf137, buf113, buf132, primals_19, primals_20, buf138, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
# Up-projection 768 -> 3072 into buf139 (buf110's storage); primals_105 is
# read through a (1, 768)-strided (768, 3072) view.
buf139 = buf110; del buf110 # reuse
aten.addmm.out(primals_106, as_strided(buf138, (128*s0, 768), (768, 1)), as_strided(primals_105, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf139)
del primals_106
buf140 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
# buf349 is a second output of the next kernel and is not read in this
# chunk (presumably kept for later use -- TODO confirm).
buf349 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
# Elementwise over buf139's 393216*s0 elements -> buf140 and buf349;
# presumably the feed-forward activation.
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf139, buf140, buf349, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
# Down-projection 3072 -> 768 into buf141 (buf138's storage).
buf141 = as_strided(buf138, (128*s0, 768), (768, 1)); del buf138 # reuse
aten.addmm.out(primals_108, as_strided(buf140, (128*s0, 3072), (3072, 1)), as_strided(primals_107, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf141)
del primals_108
buf142 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf145 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf143 = buf142; del buf142 # reuse
buf146 = buf145; del buf145 # reuse
buf147 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
# Reads buf113, buf132 and buf141 (presumably summed in-kernel as the
# residual stream rather than read from a materialized sum) with
# primals_21 / primals_22; writes the normalized rows into buf147.
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf143, buf146, buf113, buf132, buf141, primals_21, primals_22, buf147, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
# --- Attention sub-layer: buf147 -> buf160 ----------------------------------
# Three independent projections of buf147 viewed as (128*s0, 768). Each is
# bias + input @ W, with W the weight viewed with strides (1, 768) -- its
# transpose if stored row-major. Presumably query, key and value.
buf148 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_110, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_109, (768, 768), (1, 768)), beta=1, alpha=1, out=buf148)
del primals_110
buf149 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_112, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_111, (768, 768), (1, 768)), beta=1, alpha=1, out=buf149)
del primals_112
buf150 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_114, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_113, (768, 768), (1, 768)), beta=1, alpha=1, out=buf150)
del primals_114
# buf147 is dead after the three projections; its storage becomes buf151,
# which receives buf148 laid out as (s0, 12, 128, 64).
buf151 = as_strided(buf147, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf147 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf148, buf151, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf152 (buf148's storage) receives buf149 as (s0, 12, 64, 128), the
# right-hand operand of the score bmm; 2-D launch of 768*s0 by 128.
buf152 = as_strided(buf148, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf148 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf149, buf152, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: (12*s0, 128, 128) = buf151 @ buf152, in buf125's storage.
buf153 = buf125; del buf125 # reuse
aten.bmm.out(as_strided(buf151, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf152, (12*s0, 64, 128), (8192, 128, 1)), out=buf153)
# One 128-element row per (batch, head, query) = 1536*s0 rows. The name
# (amax/sub/div) suggests a softmax; primals_194 is an extra input,
# presumably mask-related -- TODO confirm against the kernel source.
buf156 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf153, buf156, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# buf150 laid out as (s0, 12, 128, 64) in buf157 (buf149's storage).
buf157 = as_strided(buf149, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf149 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf150, buf157, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf156 @ buf157, written into buf150's storage.
buf158 = as_strided(buf150, (12*s0, 128, 64), (8192, 64, 1)); del buf150 # reuse
aten.bmm.out(as_strided(buf156, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf157, (12*s0, 128, 64), (8192, 64, 1)), out=buf158)
# Position-major (s0, 128, 12, 64) copy, readable as (128*s0, 768).
buf159 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf158, buf159, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection into buf160.
buf160 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_116, as_strided(buf159, (128*s0, 768), (768, 1)), as_strided(primals_115, (768, 768), (1, 768)), beta=1, alpha=1, out=buf160)
del primals_116
# --- Normalize, feed-forward, accumulate + normalize: -> buf176 -------------
# Per-row (s0, 128, 1) outputs of the first fused norm kernel; the renames
# hand storage over without a copy.
buf161 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf164 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf162 = buf161; del buf161 # reuse
buf165 = buf164; del buf164 # reuse
# buf166 takes over buf159's storage as (s0, 128, 768).
buf166 = as_strided(buf159, (s0, 128, 768), (98304, 768, 1)); del buf159 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
# Reads buf113, buf132, buf141, buf160 (presumably summed as the residual
# stream) and primals_23 / primals_24; writes 128*s0 rows of 768 into buf166
# plus the per-row buffers.
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf162, buf165, buf113, buf132, buf141, buf160, primals_23, primals_24, buf166, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
# Feed-forward: up-projection 768 -> 3072 into buf167 (buf139's storage).
buf167 = buf139; del buf139 # reuse
aten.addmm.out(primals_118, as_strided(buf166, (128*s0, 768), (768, 1)), as_strided(primals_117, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf167)
del primals_118
buf168 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
# Second output of the elementwise kernel; not read in this chunk.
buf348 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf167, buf168, buf348, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
# Down-projection 3072 -> 768 into buf169 (buf166's storage).
buf169 = as_strided(buf166, (128*s0, 768), (768, 1)); del buf166 # reuse
aten.addmm.out(primals_120, as_strided(buf168, (128*s0, 3072), (3072, 1)), as_strided(primals_119, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf169)
del primals_120
# buf170 takes over buf113's storage (no copy) and is passed first, together
# with buf132, buf141, buf160 and buf169. Presumably the kernel accumulates
# those into buf170 in place, then normalizes into buf176 -- kernel source
# not in view, TODO confirm.
buf170 = buf113; del buf113 # reuse
buf171 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf174 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf172 = buf171; del buf171 # reuse
buf175 = buf174; del buf174 # reuse
buf176 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf170, buf172, buf175, buf132, buf141, buf160, buf169, primals_25, primals_26, buf176, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0)
# --- Attention sub-layer: buf176 -> buf189 ----------------------------------
# Three independent projections of buf176 viewed as (128*s0, 768); each is
# bias + input @ (weight viewed with strides (1, 768)). Presumably query,
# key and value.
buf177 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_122, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_121, (768, 768), (1, 768)), beta=1, alpha=1, out=buf177)
del primals_122
buf178 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_124, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_123, (768, 768), (1, 768)), beta=1, alpha=1, out=buf178)
del primals_124
buf179 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_126, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_125, (768, 768), (1, 768)), beta=1, alpha=1, out=buf179)
del primals_126
# buf176's storage becomes buf180: buf177 laid out as (s0, 12, 128, 64).
buf180 = as_strided(buf176, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf176 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf177, buf180, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf181 (buf177's storage) receives buf178 as (s0, 12, 64, 128), the
# right-hand operand of the score bmm.
buf181 = as_strided(buf177, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf177 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf178, buf181, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: (12*s0, 128, 128) = buf180 @ buf181, in buf153's storage.
buf182 = buf153; del buf153 # reuse
aten.bmm.out(as_strided(buf180, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf181, (12*s0, 64, 128), (8192, 128, 1)), out=buf182)
# 1536*s0 rows of 128; the name suggests a softmax, and primals_194 is
# presumably mask-related -- TODO confirm.
buf185 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf182, buf185, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# buf179 laid out as (s0, 12, 128, 64) in buf186 (buf178's storage).
buf186 = as_strided(buf178, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf178 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf179, buf186, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf185 @ buf186, written into buf179's storage.
buf187 = as_strided(buf179, (12*s0, 128, 64), (8192, 64, 1)); del buf179 # reuse
aten.bmm.out(as_strided(buf185, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf186, (12*s0, 128, 64), (8192, 64, 1)), out=buf187)
# Position-major (s0, 128, 12, 64) copy, readable as (128*s0, 768).
buf188 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf187, buf188, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection into buf189.
buf189 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_128, as_strided(buf188, (128*s0, 768), (768, 1)), as_strided(primals_127, (768, 768), (1, 768)), beta=1, alpha=1, out=buf189)
del primals_128
# --- Normalize, feed-forward, normalize: -> buf204 --------------------------
# Per-row (s0, 128, 1) outputs of the first fused norm kernel; the renames
# hand storage over without a copy.
buf190 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf193 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf191 = buf190; del buf190 # reuse
buf194 = buf193; del buf193 # reuse
# buf195 takes over buf188's storage as (s0, 128, 768).
buf195 = as_strided(buf188, (s0, 128, 768), (98304, 768, 1)); del buf188 # reuse
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
# Reads buf170 and buf189 (presumably combined as a residual sum) plus
# primals_27 / primals_28; writes 128*s0 rows of 768 into buf195.
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf191, buf194, buf170, buf189, primals_27, primals_28, buf195, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
# Feed-forward: up-projection 768 -> 3072 into buf196 (buf167's storage).
buf196 = buf167; del buf167 # reuse
aten.addmm.out(primals_130, as_strided(buf195, (128*s0, 768), (768, 1)), as_strided(primals_129, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf196)
del primals_130
buf197 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
# Second output of the elementwise kernel; not read in this chunk.
buf347 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf196, buf197, buf347, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
# Down-projection 3072 -> 768 into buf198 (buf195's storage).
buf198 = as_strided(buf195, (128*s0, 768), (768, 1)); del buf195 # reuse
aten.addmm.out(primals_132, as_strided(buf197, (128*s0, 3072), (3072, 1)), as_strided(primals_131, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf198)
del primals_132
buf199 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf202 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf200 = buf199; del buf199 # reuse
buf203 = buf202; del buf202 # reuse
buf204 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
# Reads buf170, buf189 and buf198 (presumably summed as the residual
# stream) with primals_29 / primals_30; writes normalized rows to buf204.
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf200, buf203, buf170, buf189, buf198, primals_29, primals_30, buf204, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
# --- Attention sub-layer: buf204 -> buf217 ----------------------------------
# Three independent projections of buf204 viewed as (128*s0, 768); each is
# bias + input @ (weight viewed with strides (1, 768)). Presumably query,
# key and value.
buf205 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_134, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_133, (768, 768), (1, 768)), beta=1, alpha=1, out=buf205)
del primals_134
buf206 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_136, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_135, (768, 768), (1, 768)), beta=1, alpha=1, out=buf206)
del primals_136
buf207 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_138, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_137, (768, 768), (1, 768)), beta=1, alpha=1, out=buf207)
del primals_138
# buf204's storage becomes buf208: buf205 laid out as (s0, 12, 128, 64).
buf208 = as_strided(buf204, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf204 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf205, buf208, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf209 (buf205's storage) receives buf206 as (s0, 12, 64, 128), the
# right-hand operand of the score bmm.
buf209 = as_strided(buf205, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf205 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf206, buf209, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: (12*s0, 128, 128) = buf208 @ buf209, in buf182's storage.
buf210 = buf182; del buf182 # reuse
aten.bmm.out(as_strided(buf208, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf209, (12*s0, 64, 128), (8192, 128, 1)), out=buf210)
# 1536*s0 rows of 128; the name suggests a softmax, and primals_194 is
# presumably mask-related -- TODO confirm.
buf213 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf210, buf213, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# buf207 laid out as (s0, 12, 128, 64) in buf214 (buf206's storage).
buf214 = as_strided(buf206, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf206 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf207, buf214, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf213 @ buf214, written into buf207's storage.
buf215 = as_strided(buf207, (12*s0, 128, 64), (8192, 64, 1)); del buf207 # reuse
aten.bmm.out(as_strided(buf213, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf214, (12*s0, 128, 64), (8192, 64, 1)), out=buf215)
# Position-major (s0, 128, 12, 64) copy, readable as (128*s0, 768).
buf216 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf215, buf216, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection into buf217.
buf217 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_140, as_strided(buf216, (128*s0, 768), (768, 1)), as_strided(primals_139, (768, 768), (1, 768)), beta=1, alpha=1, out=buf217)
del primals_140
# --- Normalize, feed-forward, accumulate + normalize: -> buf233 -------------
# Per-row (s0, 128, 1) outputs of the first fused norm kernel; the renames
# hand storage over without a copy.
buf218 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf221 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf219 = buf218; del buf218 # reuse
buf222 = buf221; del buf221 # reuse
# buf223 takes over buf216's storage as (s0, 128, 768).
buf223 = as_strided(buf216, (s0, 128, 768), (98304, 768, 1)); del buf216 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
# Reads buf170, buf189, buf198, buf217 (presumably summed as the residual
# stream) and primals_31 / primals_32; writes 128*s0 rows of 768 to buf223.
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf219, buf222, buf170, buf189, buf198, buf217, primals_31, primals_32, buf223, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
# Feed-forward: up-projection 768 -> 3072 into buf224 (buf196's storage).
buf224 = buf196; del buf196 # reuse
aten.addmm.out(primals_142, as_strided(buf223, (128*s0, 768), (768, 1)), as_strided(primals_141, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf224)
del primals_142
buf225 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
# Second output of the elementwise kernel; not read in this chunk.
buf346 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf224, buf225, buf346, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
# Down-projection 3072 -> 768 into buf226 (buf223's storage).
buf226 = as_strided(buf223, (128*s0, 768), (768, 1)); del buf223 # reuse
aten.addmm.out(primals_144, as_strided(buf225, (128*s0, 3072), (3072, 1)), as_strided(primals_143, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf226)
del primals_144
# buf227 takes over buf170's storage (no copy) and is passed first, with
# buf189, buf198, buf217 and buf226. Presumably the kernel accumulates those
# into buf227 in place, then normalizes into buf233 -- kernel source not in
# view, TODO confirm.
buf227 = buf170; del buf170 # reuse
buf228 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf231 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf229 = buf228; del buf228 # reuse
buf232 = buf231; del buf231 # reuse
buf233 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf227, buf229, buf232, buf189, buf198, buf217, buf226, primals_33, primals_34, buf233, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0)
# --- Attention sub-layer: buf233 -> buf246 ----------------------------------
# Three independent projections of buf233 viewed as (128*s0, 768); each is
# bias + input @ (weight viewed with strides (1, 768)). Presumably query,
# key and value.
buf234 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_146, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_145, (768, 768), (1, 768)), beta=1, alpha=1, out=buf234)
del primals_146
buf235 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_148, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_147, (768, 768), (1, 768)), beta=1, alpha=1, out=buf235)
del primals_148
buf236 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_150, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_149, (768, 768), (1, 768)), beta=1, alpha=1, out=buf236)
del primals_150
# buf233's storage becomes buf237: buf234 laid out as (s0, 12, 128, 64).
buf237 = as_strided(buf233, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf233 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf234, buf237, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf238 (buf234's storage) receives buf235 as (s0, 12, 64, 128), the
# right-hand operand of the score bmm.
buf238 = as_strided(buf234, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf234 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf235, buf238, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: (12*s0, 128, 128) = buf237 @ buf238, in buf210's storage.
buf239 = buf210; del buf210 # reuse
aten.bmm.out(as_strided(buf237, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf238, (12*s0, 64, 128), (8192, 128, 1)), out=buf239)
# 1536*s0 rows of 128; the name suggests a softmax, and primals_194 is
# presumably mask-related -- TODO confirm.
buf242 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf239, buf242, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# buf236 laid out as (s0, 12, 128, 64) in buf243 (buf235's storage).
buf243 = as_strided(buf235, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf235 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf236, buf243, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf242 @ buf243, written into buf236's storage.
buf244 = as_strided(buf236, (12*s0, 128, 64), (8192, 64, 1)); del buf236 # reuse
aten.bmm.out(as_strided(buf242, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf243, (12*s0, 128, 64), (8192, 64, 1)), out=buf244)
# Position-major (s0, 128, 12, 64) copy, readable as (128*s0, 768).
buf245 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf244, buf245, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection into buf246.
buf246 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_152, as_strided(buf245, (128*s0, 768), (768, 1)), as_strided(primals_151, (768, 768), (1, 768)), beta=1, alpha=1, out=buf246)
del primals_152
# --- Normalize, feed-forward, normalize: -> buf261 --------------------------
# Per-row (s0, 128, 1) outputs of the first fused norm kernel; the renames
# hand storage over without a copy.
buf247 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf250 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf248 = buf247; del buf247 # reuse
buf251 = buf250; del buf250 # reuse
# buf252 takes over buf245's storage as (s0, 128, 768).
buf252 = as_strided(buf245, (s0, 128, 768), (98304, 768, 1)); del buf245 # reuse
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
# Reads buf227 and buf246 (presumably combined as a residual sum) plus
# primals_35 / primals_36; writes 128*s0 rows of 768 into buf252.
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf248, buf251, buf227, buf246, primals_35, primals_36, buf252, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
# Feed-forward: up-projection 768 -> 3072 into buf253 (buf224's storage).
buf253 = buf224; del buf224 # reuse
aten.addmm.out(primals_154, as_strided(buf252, (128*s0, 768), (768, 1)), as_strided(primals_153, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf253)
del primals_154
buf254 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
# Second output of the elementwise kernel; not read in this chunk.
buf345 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf253, buf254, buf345, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
# Down-projection 3072 -> 768 into buf255 (buf252's storage).
buf255 = as_strided(buf252, (128*s0, 768), (768, 1)); del buf252 # reuse
aten.addmm.out(primals_156, as_strided(buf254, (128*s0, 3072), (3072, 1)), as_strided(primals_155, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf255)
del primals_156
buf256 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf259 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf257 = buf256; del buf256 # reuse
buf260 = buf259; del buf259 # reuse
buf261 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
# Reads buf227, buf246 and buf255 (presumably summed as the residual
# stream) with primals_37 / primals_38; writes normalized rows to buf261.
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf257, buf260, buf227, buf246, buf255, primals_37, primals_38, buf261, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
# --- Attention sub-layer: buf261 -> buf274 ----------------------------------
# Three independent projections of buf261 viewed as (128*s0, 768); each is
# bias + input @ (weight viewed with strides (1, 768)). Presumably query,
# key and value.
buf262 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_158, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_157, (768, 768), (1, 768)), beta=1, alpha=1, out=buf262)
del primals_158
buf263 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_160, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_159, (768, 768), (1, 768)), beta=1, alpha=1, out=buf263)
del primals_160
buf264 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_162, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_161, (768, 768), (1, 768)), beta=1, alpha=1, out=buf264)
del primals_162
# buf261's storage becomes buf265: buf262 laid out as (s0, 12, 128, 64).
buf265 = as_strided(buf261, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf261 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf262, buf265, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# buf266 (buf262's storage) receives buf263 as (s0, 12, 64, 128), the
# right-hand operand of the score bmm.
buf266 = as_strided(buf262, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf262 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf263, buf266, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
# Scores: (12*s0, 128, 128) = buf265 @ buf266, in buf239's storage.
buf267 = buf239; del buf239 # reuse
aten.bmm.out(as_strided(buf265, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf266, (12*s0, 64, 128), (8192, 128, 1)), out=buf267)
# 1536*s0 rows of 128; the name suggests a softmax, and primals_194 is
# presumably mask-related -- TODO confirm.
buf270 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf267, buf270, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
# buf264 laid out as (s0, 12, 128, 64) in buf271 (buf263's storage).
buf271 = as_strided(buf263, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf263 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf264, buf271, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
# (12*s0, 128, 64) = buf270 @ buf271, written into buf264's storage.
buf272 = as_strided(buf264, (12*s0, 128, 64), (8192, 64, 1)); del buf264 # reuse
aten.bmm.out(as_strided(buf270, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf271, (12*s0, 128, 64), (8192, 64, 1)), out=buf272)
# Position-major (s0, 128, 12, 64) copy, readable as (128*s0, 768).
buf273 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf272, buf273, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
# Output projection into buf274.
buf274 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_164, as_strided(buf273, (128*s0, 768), (768, 1)), as_strided(primals_163, (768, 768), (1, 768)), beta=1, alpha=1, out=buf274)
del primals_164
buf275 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf278 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf276 = buf275; del buf275 # reuse
buf279 = buf278; del buf278 # reuse
buf280 = as_strided(buf273, (s0, 128, 768), (98304, 768, 1)); del buf273 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf276, buf279, buf227, buf246, buf255, buf274, primals_39, primals_40, buf280, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
buf281 = buf253; del buf253 # reuse
aten.addmm.out(primals_166, as_strided(buf280, (128*s0, 768), (768, 1)), as_strided(primals_165, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf281)
del primals_166
buf282 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf344 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf281, buf282, buf344, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf283 = as_strided(buf280, (128*s0, 768), (768, 1)); del buf280 # reuse
aten.addmm.out(primals_168, as_strided(buf282, (128*s0, 3072), (3072, 1)), as_strided(primals_167, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf283)
del primals_168
buf284 = buf227; del buf227 # reuse
buf285 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf288 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf286 = buf285; del buf285 # reuse
buf289 = buf288; del buf288 # reuse
buf290 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf284, buf286, buf289, buf246, buf255, buf274, buf283, primals_41, primals_42, buf290, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0)
buf291 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_170, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_169, (768, 768), (1, 768)), beta=1, alpha=1, out=buf291)
del primals_170
buf292 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_172, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_171, (768, 768), (1, 768)), beta=1, alpha=1, out=buf292)
del primals_172
buf293 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_174, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_173, (768, 768), (1, 768)), beta=1, alpha=1, out=buf293)
del primals_174
buf294 = as_strided(buf290, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf290 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf291, buf294, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf295 = as_strided(buf291, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf291 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf292, buf295, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf296 = buf267; del buf267 # reuse
aten.bmm.out(as_strided(buf294, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf295, (12*s0, 64, 128), (8192, 128, 1)), out=buf296)
buf299 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf296, buf299, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
buf300 = as_strided(buf292, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf292 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf293, buf300, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf301 = as_strided(buf293, (12*s0, 128, 64), (8192, 64, 1)); del buf293 # reuse
aten.bmm.out(as_strided(buf299, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf300, (12*s0, 128, 64), (8192, 64, 1)), out=buf301)
buf302 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf301, buf302, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf303 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_176, as_strided(buf302, (128*s0, 768), (768, 1)), as_strided(primals_175, (768, 768), (1, 768)), beta=1, alpha=1, out=buf303)
del primals_176
buf304 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf307 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf305 = buf304; del buf304 # reuse
buf308 = buf307; del buf307 # reuse
buf309 = as_strided(buf302, (s0, 128, 768), (98304, 768, 1)); del buf302 # reuse
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf305, buf308, buf284, buf303, primals_43, primals_44, buf309, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0)
buf310 = buf281; del buf281 # reuse
aten.addmm.out(primals_178, as_strided(buf309, (128*s0, 768), (768, 1)), as_strided(primals_177, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf310)
del primals_178
buf311 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf343 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf310, buf311, buf343, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
buf312 = as_strided(buf309, (128*s0, 768), (768, 1)); del buf309 # reuse
aten.addmm.out(primals_180, as_strided(buf311, (128*s0, 3072), (3072, 1)), as_strided(primals_179, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf312)
del primals_180
buf313 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf316 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf314 = buf313; del buf313 # reuse
buf317 = buf316; del buf316 # reuse
buf318 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf314, buf317, buf284, buf303, buf312, primals_45, primals_46, buf318, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0)
buf319 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_182, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_181, (768, 768), (1, 768)), beta=1, alpha=1, out=buf319)
del primals_182
buf320 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_184, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_183, (768, 768), (1, 768)), beta=1, alpha=1, out=buf320)
del primals_184
buf321 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_186, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_185, (768, 768), (1, 768)), beta=1, alpha=1, out=buf321)
del primals_186
buf322 = as_strided(buf318, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf318 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf319, buf322, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf323 = as_strided(buf319, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf319 # reuse
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf320, buf323, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0)
buf324 = buf296; del buf296 # reuse
aten.bmm.out(as_strided(buf322, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf323, (12*s0, 64, 128), (8192, 128, 1)), out=buf324)
buf327 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32)
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf324, buf327, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0)
del buf324
buf328 = as_strided(buf320, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf320 # reuse
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf321, buf328, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0)
buf329 = as_strided(buf321, (12*s0, 128, 64), (8192, 64, 1)); del buf321 # reuse
aten.bmm.out(as_strided(buf327, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf328, (12*s0, 128, 64), (8192, 64, 1)), out=buf329)
buf330 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf329, buf330, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0)
buf331 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32)
aten.addmm.out(primals_188, as_strided(buf330, (128*s0, 768), (768, 1)), as_strided(primals_187, (768, 768), (1, 768)), beta=1, alpha=1, out=buf331)
del primals_188
buf332 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf335 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32)
buf333 = buf332; del buf332 # reuse
buf336 = buf335; del buf335 # reuse
buf337 = as_strided(buf330, (s0, 128, 768), (98304, 768, 1)); del buf330 # reuse
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf333, buf336, buf284, buf303, buf312, buf331, primals_47, primals_48, buf337, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0)
buf338 = buf310; del buf310 # reuse
aten.addmm.out(primals_190, as_strided(buf337, (128*s0, 768), (768, 1)), as_strided(primals_189, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf338)
del primals_190
buf339 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
buf342 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32)
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf338, buf339, buf342, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0)
del buf338
buf340 = as_strided(buf337, (128*s0, 768), (768, 1)); del buf337 # reuse
aten.addmm.out(primals_192, as_strided(buf339, (128*s0, 3072), (3072, 1)), as_strided(primals_191, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf340)
del primals_192
buf341 = buf284; del buf284 # reuse
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel = 98304*s0
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212.run(buf341, buf303, buf312, buf331, buf340, triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel, grid=grid(triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel), stream=stream0)
return (buf341, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_193, primals_194, buf1, buf4, buf9, buf10, buf14, buf15, buf16, buf18, buf20, buf23, as_strided(buf26, (128*s0, 3072), (3072, 1)), buf27, buf29, buf32, buf37, buf38, buf42, buf43, buf44, buf46, buf48, buf51, as_strided(buf54, (128*s0, 3072), (3072, 1)), buf55, buf58, buf61, buf66, buf67, buf71, buf72, buf73, buf75, buf77, buf80, as_strided(buf83, (128*s0, 3072), (3072, 1)), buf84, buf86, buf89, buf94, buf95, buf99, buf100, buf101, buf103, buf105, buf108, as_strided(buf111, (128*s0, 3072), (3072, 1)), buf112, buf115, buf118, buf123, buf124, buf128, buf129, buf130, buf132, buf134, buf137, as_strided(buf140, (128*s0, 3072), (3072, 1)), buf141, buf143, buf146, buf151, buf152, buf156, buf157, buf158, buf160, buf162, buf165, as_strided(buf168, (128*s0, 3072), (3072, 1)), buf169, buf172, buf175, buf180, buf181, buf185, buf186, buf187, buf189, buf191, buf194, as_strided(buf197, (128*s0, 3072), (3072, 1)), buf198, buf200, buf203, buf208, buf209, buf213, buf214, buf215, buf217, buf219, buf222, as_strided(buf225, (128*s0, 3072), (3072, 1)), buf226, buf229, buf232, buf237, buf238, buf242, buf243, buf244, buf246, buf248, buf251, as_strided(buf254, (128*s0, 3072), (3072, 1)), buf255, buf257, buf260, buf265, buf266, buf270, buf271, buf272, buf274, buf276, buf279, as_strided(buf282, (128*s0, 3072), (3072, 1)), buf283, buf286, buf289, buf294, buf295, buf299, buf300, buf301, buf303, buf305, buf308, 
as_strided(buf311, (128*s0, 3072), (3072, 1)), buf312, buf314, buf317, buf322, buf323, buf327, buf328, buf329, buf331, buf333, buf336, as_strided(buf339, (128*s0, 3072), (3072, 1)), as_strided(primals_191, (768, 3072), (3072, 1)), buf342, as_strided(primals_189, (3072, 768), (768, 1)), as_strided(primals_187, (768, 768), (768, 1)), as_strided(primals_185, (768, 768), (768, 1)), as_strided(primals_183, (768, 768), (768, 1)), as_strided(primals_181, (768, 768), (768, 1)), as_strided(primals_179, (768, 3072), (3072, 1)), buf343, as_strided(primals_177, (3072, 768), (768, 1)), as_strided(primals_175, (768, 768), (768, 1)), as_strided(primals_173, (768, 768), (768, 1)), as_strided(primals_171, (768, 768), (768, 1)), as_strided(primals_169, (768, 768), (768, 1)), as_strided(primals_167, (768, 3072), (3072, 1)), buf344, as_strided(primals_165, (3072, 768), (768, 1)), as_strided(primals_163, (768, 768), (768, 1)), as_strided(primals_161, (768, 768), (768, 1)), as_strided(primals_159, (768, 768), (768, 1)), as_strided(primals_157, (768, 768), (768, 1)), as_strided(primals_155, (768, 3072), (3072, 1)), buf345, as_strided(primals_153, (3072, 768), (768, 1)), as_strided(primals_151, (768, 768), (768, 1)), as_strided(primals_149, (768, 768), (768, 1)), as_strided(primals_147, (768, 768), (768, 1)), as_strided(primals_145, (768, 768), (768, 1)), as_strided(primals_143, (768, 3072), (3072, 1)), buf346, as_strided(primals_141, (3072, 768), (768, 1)), as_strided(primals_139, (768, 768), (768, 1)), as_strided(primals_137, (768, 768), (768, 1)), as_strided(primals_135, (768, 768), (768, 1)), as_strided(primals_133, (768, 768), (768, 1)), as_strided(primals_131, (768, 3072), (3072, 1)), buf347, as_strided(primals_129, (3072, 768), (768, 1)), as_strided(primals_127, (768, 768), (768, 1)), as_strided(primals_125, (768, 768), (768, 1)), as_strided(primals_123, (768, 768), (768, 1)), as_strided(primals_121, (768, 768), (768, 1)), as_strided(primals_119, (768, 3072), (3072, 1)), buf348, 
as_strided(primals_117, (3072, 768), (768, 1)), as_strided(primals_115, (768, 768), (768, 1)), as_strided(primals_113, (768, 768), (768, 1)), as_strided(primals_111, (768, 768), (768, 1)), as_strided(primals_109, (768, 768), (768, 1)), as_strided(primals_107, (768, 3072), (3072, 1)), buf349, as_strided(primals_105, (3072, 768), (768, 1)), as_strided(primals_103, (768, 768), (768, 1)), as_strided(primals_101, (768, 768), (768, 1)), as_strided(primals_99, (768, 768), (768, 1)), as_strided(primals_97, (768, 768), (768, 1)), as_strided(primals_95, (768, 3072), (3072, 1)), buf350, as_strided(primals_93, (3072, 768), (768, 1)), as_strided(primals_91, (768, 768), (768, 1)), as_strided(primals_89, (768, 768), (768, 1)), as_strided(primals_87, (768, 768), (768, 1)), as_strided(primals_85, (768, 768), (768, 1)), as_strided(primals_83, (768, 3072), (3072, 1)), buf351, as_strided(primals_81, (3072, 768), (768, 1)), as_strided(primals_79, (768, 768), (768, 1)), as_strided(primals_77, (768, 768), (768, 1)), as_strided(primals_75, (768, 768), (768, 1)), as_strided(primals_73, (768, 768), (768, 1)), as_strided(primals_71, (768, 3072), (3072, 1)), buf352, as_strided(primals_69, (3072, 768), (768, 1)), as_strided(primals_67, (768, 768), (768, 1)), as_strided(primals_65, (768, 768), (768, 1)), as_strided(primals_63, (768, 768), (768, 1)), as_strided(primals_61, (768, 768), (768, 1)), as_strided(primals_59, (768, 3072), (3072, 1)), buf353, as_strided(primals_57, (3072, 768), (768, 1)), as_strided(primals_55, (768, 768), (768, 1)), as_strided(primals_53, (768, 768), (768, 1)), as_strided(primals_51, (768, 768), (768, 1)), as_strided(primals_49, (768, 768), (768, 1)), s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 
128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
primals_1 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_2 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_3 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_4 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_5 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_6 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_7 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_8 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_9 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_10 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_11 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_12 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_13 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_14 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_15 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_16 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_17 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_18 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_19 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_20 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_21 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_22 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_23 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_24 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_25 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_26 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_27 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_28 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_29 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_30 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_31 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_32 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_33 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_34 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_35 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_36 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_37 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_38 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_39 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_40 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_41 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_42 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_43 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_44 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_45 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_46 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_47 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_48 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_49 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_50 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_51 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_52 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_53 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_54 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_55 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_56 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_57 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_58 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_59 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_60 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_61 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_62 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_63 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_64 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_65 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_66 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_67 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_68 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_69 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
# ---------------------------------------------------------------------------
# Benchmark inputs (continued). Each primals_N is built with
# rand_strided(<size>, <stride>, device='cuda', dtype=...). Every tensor in this
# chunk is float32 with a contiguous row-major stride, except primals_194 (bool,
# broadcast strides).
# NOTE(review): primals_1 .. primals_69, rand_strided, print_performance and
# call are defined outside this chunk; the notes below are read off this chunk only.
# Statement order matters: presumably rand_strided draws from the global RNG,
# so reordering these lines would change the generated values — confirm.
# ---------------------------------------------------------------------------
# Tail of the preceding 12-tensor group; its start lies before this chunk.
primals_70 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_71 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_72 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# From primals_73 through primals_192 the shapes repeat in groups of 12:
#   4 x [(768, 768) matrix, (768,) vector], then (3072, 768) + (3072,),
#   then (768, 3072) + (768,).
# Presumably one BERT encoder layer per group (attention projections followed
# by the two feed-forward matmuls), given the gist title "BERT_pytorch" — confirm.
# Group: primals_73 .. primals_84
primals_73 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_74 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_75 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_76 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_77 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_78 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_79 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_80 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_81 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_82 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_83 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_84 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_85 .. primals_96
primals_85 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_86 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_87 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_88 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_89 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_90 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_91 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_92 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_93 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_94 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_95 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_96 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_97 .. primals_108
primals_97 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_98 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_99 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_100 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_101 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_102 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_103 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_104 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_105 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_106 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_107 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_108 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_109 .. primals_120
primals_109 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_110 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_111 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_112 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_113 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_114 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_115 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_116 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_117 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_118 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_119 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_120 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_121 .. primals_132
primals_121 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_122 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_123 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_124 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_125 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_126 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_127 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_128 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_129 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_130 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_131 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_132 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_133 .. primals_144
primals_133 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_134 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_135 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_136 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_137 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_138 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_139 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_140 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_141 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_142 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_143 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_144 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_145 .. primals_156
primals_145 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_146 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_147 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_148 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_149 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_150 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_151 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_152 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_153 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_154 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_155 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_156 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_157 .. primals_168
primals_157 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_158 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_159 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_160 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_161 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_162 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_163 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_164 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_165 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_166 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_167 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_168 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_169 .. primals_180
primals_169 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_170 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_171 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_172 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_173 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_174 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_175 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_176 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_177 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_178 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_179 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_180 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# Group: primals_181 .. primals_192 (last repeating group)
primals_181 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_182 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_183 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_184 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_185 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_186 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_187 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_188 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
primals_189 = rand_strided((3072, 768), (768, 1), device='cuda', dtype=torch.float32)
primals_190 = rand_strided((3072, ), (1, ), device='cuda', dtype=torch.float32)
primals_191 = rand_strided((768, 3072), (3072, 1), device='cuda', dtype=torch.float32)
primals_192 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float32)
# (2, 128, 768) contiguous float32. Presumably the (batch, seq_len, hidden)
# activation input rather than a parameter — confirm.
primals_193 = rand_strided((2, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32)
# (2, 1, 128, 128) bool with stride (128, 0, 0, 1). Dims 1 and 2 have stride 0,
# so they are broadcast views and only 2 * 128 elements are actually stored.
# Presumably an attention mask expanded from a (2, 128) padding mask — confirm.
primals_194 = rand_strided((2, 1, 128, 128), (128, 0, 0, 1), device='cuda', dtype=torch.bool)
# Benchmark call(...) via print_performance (both defined outside this chunk).
# The lambda builds a fresh argument list on every invocation.
# NOTE(review): presumably because call() consumes/clears its argument list —
# confirm against call() before hoisting the list out of the lambda.
# All inputs are passed positionally in primals_1 .. primals_194 order. That order
# presumably has to match the unpacking inside call(), so do not reorder it.
print_performance(lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, 
primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.