# BERT_pytorch dynamic Triton
#
# Source: GitHub gist ezyang/1fbf5bab298a0b35b47d5de6f885deff
# (it can be saved locally and used with GitHub Desktop).
#
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters (see
# GitHub's documentation on bidirectional Unicode characters).
# Module preamble for the generated wrapper.
# Several names bound here (c_void_p, c_long, random, empty_strided,
# as_strided, device, aten, assert_size_stride, grid, get_cuda_stream) are not
# used in this part of the file; presumably call() (outside this view) uses them.
# Statement order is kept as generated: AsyncCompile() is instantiated before
# triton is imported.
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Compiler handle: each kernel below is registered via async_compile.triton(<source>).
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# Row-wise normalization of in_ptr0 (row stride 768); each program handles
# XBLOCK rows and walks each row in RBLOCK-wide chunks over three passes:
#   pass 1: mean = sum(x) / 768                    -> in_out_ptr0[row]
#   pass 2: std  = sqrt(sum((x - mean)**2) / 767)  -> in_out_ptr1[row]
#   pass 3: out_ptr1[row, r] = in_ptr1[r] * (x - mean) / (std + 1e-06) + in_ptr2[r]
# Presumably BERT_pytorch's LayerNorm (a_2 * (x - mean) / (std + eps) + b_2,
# with unbiased std) -- TODO confirm against the model source.
# NOTE(review): the row stride and the 768/767 divisors are baked in, so this
# assumes rnumel == 768 even though rnumel is passed at runtime.
triton_fused_add_1_div_mul_std_sub_mean_add0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp1 = tl.where(xmask & rmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.reshape(tl.sum(_tmp1, 1), [XBLOCK, 1])
    tmp2 = tmp1
    _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp3 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp4 = 768
        tmp5 = tmp2 / tmp4
        tmp6 = tmp3 - tmp5
        tmp7 = tmp6 * tmp6
        _tmp8 = tl.where(xmask & rmask, _tmp8 + tmp7, _tmp8)
    tmp8 = tl.reshape(tl.sum(_tmp8, 1), [XBLOCK, 1])
    tmp9 = 768
    tmp10 = tmp1 / tmp9
    tmp11 = 767
    tmp12 = tmp8 / tmp11
    tmp13 = tl.sqrt(tmp12)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp10, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp13, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp14 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last')
        tmp15 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp21 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last')
        tmp16 = tmp15 - tmp10
        tmp17 = tmp14 * tmp16
        tmp18 = 1e-06
        tmp19 = tmp13 + tmp18
        tmp20 = tmp17 / tmp19
        tmp22 = tmp20 + tmp21
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp22, xmask & rmask)
''')
# Pointwise copy that swaps dims 1 and 2. The input is read as a contiguous
# [N, ks0, 12, 64] tensor (strides 768*ks0, 768, 64, 1). The output is written
# as a contiguous [N, 12, ks0, 64] tensor. ks0 is a runtime size; presumably
# this splits a 768-wide hidden dim into 12 heads x 64 -- TODO confirm.
# NOTE(review): meta lists arg 2 (ks0) in divisible_by_16, so this compiled
# variant is specialized for ks0 being a multiple of 16.
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x1 = (xindex // 64) % ks0
    x2 = (xindex // (64*ks0)) % 12
    x3 = (xindex // (768*ks0))
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (768*x1) + (768*ks0*x3)), xmask)
    tl.store(out_ptr0 + (x4 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# 2-D tiled transpose of the last two dims. The input is read as a contiguous
# [N, 128, 768] tensor (offset x0 + 768*y + 98304*n, with 98304 = 128*768).
# The output is written as a contiguous [N, 768, 128] tensor (offset
# y + 128*(n*768 + x0)). The x axis spans N*768 and the y axis spans 128.
# NOTE(review): the 128 extent is hard-coded rather than passed as a runtime size.
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[2048, 128], tile_hint=TileHint.SQUARE,filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, ynumel, XBLOCK : tl.constexpr, YBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.reshape(tl.arange(0, YBLOCK), [1, YBLOCK])
    ymask = yindex < ynumel
    x0 = xindex % 768
    x1 = (xindex // 768)
    y2 = yindex
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (768*y2) + (98304*x1)), xmask & ymask)
    tl.store(out_ptr0 + (y2 + (128*x3) + tl.zeros([XBLOCK, YBLOCK], tl.int32)), tmp0, xmask & ymask)
''')
# Masked, scaled softmax over rows of in_ptr1 (row stride 128):
#   s = in_ptr1[row, r] / 8.0, replaced by -1e9 wherever the boolean
#       in_ptr0[row // 1536, r] == 0
#   out_ptr2[row, r] = exp(s - max(s)) / sum(exp(s - max(s)))
# There are three passes per row: the running max, the sum of shifted
# exponentials, and the normalized write. Subtracting the max keeps exp() from
# overflowing.
# Presumably these are attention scores: 8.0 = sqrt(64 head dim), and one mask
# row is shared by 1536 = 12 heads x 128 query rows -- TODO confirm.
# NOTE(review): 128 and 1536 are hard-coded, so this assumes sequence length 128.
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[4096, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*i1', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x4 = (xindex // 1536)
    x5 = xindex
    _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = 0
        tmp2 = tmp0 == tmp1
        tmp3 = -1000000000.0
        tmp5 = 8.0
        tmp6 = tmp4 / tmp5
        tmp7 = tl.where(tmp2, tmp3, tmp6)
        _tmp8 = tl.where(xmask & rmask & (_tmp8 < tmp7), tmp7, _tmp8)
    tmp8 = tl.reshape(tl.max(_tmp8, 1), [XBLOCK, 1])
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp9 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = 0
        tmp11 = tmp9 == tmp10
        tmp12 = -1000000000.0
        tmp14 = 8.0
        tmp15 = tmp13 / tmp14
        tmp16 = tl.where(tmp11, tmp12, tmp15)
        tmp17 = tmp16 - tmp8
        tmp18 = tl.exp(tmp17)
        _tmp19 = tl.where(xmask & rmask, _tmp19 + tmp18, _tmp19)
    tmp19 = tl.reshape(tl.sum(_tmp19, 1), [XBLOCK, 1])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp20 = tl.load(in_ptr0 + (r2 + (128*x4)), xmask & rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_ptr1 + (r2 + (128*x5)), xmask & rmask, eviction_policy='evict_last')
        tmp21 = 0
        tmp22 = tmp20 == tmp21
        tmp23 = -1000000000.0
        tmp25 = 8.0
        tmp26 = tmp24 / tmp25
        tmp27 = tl.where(tmp22, tmp23, tmp26)
        tmp28 = tmp27 - tmp8
        tmp29 = tl.exp(tmp28)
        tmp30 = tmp29 / tmp19
        tl.store(out_ptr2 + (r2 + (128*x5) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp30, xmask & rmask)
''')
# Pointwise copy that swaps dims 1 and 2. The input is read as a contiguous
# [N, 128, 12, 64] tensor (strides 98304, 768, 64, 1). The output is written
# as a contiguous [N, 12, 128, 64] tensor.
# NOTE(review): the 128 extent is hard-coded here rather than passed as a runtime size.
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x4 = (xindex // 64) % 128
    x5 = (xindex // 8192) % 12
    x6 = (xindex // 98304)
    x7 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x5) + (768*x4) + (98304*x6)), xmask)
    tl.store(out_ptr0 + (x7 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# Pointwise copy that swaps dims 1 and 2 back. The input is read as a
# contiguous [N, 12, 128, 64] tensor (strides 98304, 8192, 64, 1). The output
# is written as a contiguous [N, 128, 12, 64] tensor, i.e. N x 128 rows of
# 768. Presumably this merges attention heads -- TODO confirm.
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 64
    x1 = (xindex // 64) % 12
    x4 = (xindex // 768) % 128
    x5 = (xindex // 98304)
    x6 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*x4) + (8192*x1) + (98304*x5)), xmask)
    tl.store(out_ptr0 + (x6 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)
''')
# Add-then-normalize over rows of 768, with x = in_ptr0 + in_ptr1 elementwise:
#   mean = sum(x) / 768                    -> in_out_ptr0[row]
#   std  = sqrt(sum((x - mean)**2) / 767)  -> in_out_ptr1[row]
#   out_ptr1[row, r] = in_ptr2[r] * (x - mean) / (std + 1e-06) + in_ptr3[r]
# The sum x is recomputed in each of the three passes instead of being stored.
# Presumably this is a residual add followed by LayerNorm -- TODO confirm.
# NOTE(review): it assumes rnumel == 768 (the stride and divisors are hard-coded).
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp3 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        _tmp3 = tl.where(xmask & rmask, _tmp3 + tmp2, _tmp3)
    tmp3 = tl.reshape(tl.sum(_tmp3, 1), [XBLOCK, 1])
    tmp4 = tmp3
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp5 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp6 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tmp5 + tmp6
        tmp8 = 768
        tmp9 = tmp4 / tmp8
        tmp10 = tmp7 - tmp9
        tmp11 = tmp10 * tmp10
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp11, _tmp12)
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    tmp13 = 768
    tmp14 = tmp3 / tmp13
    tmp15 = 767
    tmp16 = tmp12 / tmp15
    tmp17 = tl.sqrt(tmp16)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp14, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp17, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp18 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last')
        tmp19 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp20 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last')
        tmp21 = tmp19 + tmp20
        tmp22 = tmp21 - tmp14
        tmp23 = tmp18 * tmp22
        tmp24 = 1e-06
        tmp25 = tmp17 + tmp24
        tmp26 = tmp23 / tmp25
        tmp28 = tmp26 + tmp27
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp28, xmask & rmask)
''')
# Elementwise erf-based GELU and its derivative, for x = in_ptr0[i]:
#   out_ptr0[i] = 0.5 * x * (1 + erf(x / sqrt(2)))
#   out_ptr1[i] = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x**2 / 2) / sqrt(2*pi)
# erf(z) is evaluated as sign(z) * (1 - poly(t) * exp(-z**2)) with
# t = 1 / (1 + 0.3275911*|z|), which is the Abramowitz & Stegun 7.1.26 rational
# approximation (the polynomial coefficients are tmp14..tmp25).
# 0.3989422804014327 = 1/sqrt(2*pi). out_ptr1 is d/dx of out_ptr0;
# presumably it is saved for the backward pass -- TODO confirm.
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[1048576], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.5
    tmp2 = tmp0 * tmp1
    tmp3 = 0.7071067811865476
    tmp4 = tmp0 * tmp3
    tmp5 = tl.where(tmp4 < 0, -1, 1)
    tmp6 = tl.where(tmp4 == 0, 0, tmp5)
    tmp7 = 1.0
    tmp8 = tl.abs(tmp4)
    tmp9 = 0.3275911
    tmp10 = tmp8 * tmp9
    tmp11 = tmp10 + tmp7
    tmp12 = 1 / tmp11
    tmp13 = tmp12 * tmp7
    tmp14 = 1.061405429
    tmp15 = tmp13 * tmp14
    tmp16 = -1.453152027
    tmp17 = tmp15 + tmp16
    tmp18 = tmp17 * tmp13
    tmp19 = 1.421413741
    tmp20 = tmp18 + tmp19
    tmp21 = tmp20 * tmp13
    tmp22 = -0.284496736
    tmp23 = tmp21 + tmp22
    tmp24 = tmp23 * tmp13
    tmp25 = 0.254829592
    tmp26 = tmp24 + tmp25
    tmp27 = tmp26 * tmp13
    tmp28 = -tmp8
    tmp29 = tmp28 * tmp8
    tmp30 = tl.exp(tmp29)
    tmp31 = tmp27 * tmp30
    tmp32 = tmp7 - tmp31
    tmp33 = tmp6 * tmp32
    tmp34 = 1
    tmp35 = tmp33 + tmp34
    tmp36 = tmp2 * tmp35
    tmp37 = tmp35 * tmp1
    tmp38 = tmp0 * tmp0
    tmp39 = -0.5
    tmp40 = tmp38 * tmp39
    tmp41 = tl.exp(tmp40)
    tmp42 = 0.3989422804014327
    tmp43 = tmp41 * tmp42
    tmp44 = tmp0 * tmp43
    tmp45 = tmp37 + tmp44
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp36, xmask)
    tl.store(out_ptr1 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp45, xmask)
''')
# Add-then-normalize over rows of 768, with x = in_ptr0 + in_ptr1 + in_ptr2
# elementwise:
#   mean = sum(x) / 768                    -> in_out_ptr0[row]
#   std  = sqrt(sum((x - mean)**2) / 767)  -> in_out_ptr1[row]
#   out_ptr1[row, r] = in_ptr3[r] * (x - mean) / (std + 1e-06) + in_ptr4[r]
# NOTE(review): it assumes rnumel == 768 (the stride and divisors are hard-coded).
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp5 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        _tmp5 = tl.where(xmask & rmask, _tmp5 + tmp4, _tmp5)
    tmp5 = tl.reshape(tl.sum(_tmp5, 1), [XBLOCK, 1])
    tmp6 = tmp5
    _tmp16 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp7 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp9 = tmp7 + tmp8
        tmp11 = tmp9 + tmp10
        tmp12 = 768
        tmp13 = tmp6 / tmp12
        tmp14 = tmp11 - tmp13
        tmp15 = tmp14 * tmp14
        _tmp16 = tl.where(xmask & rmask, _tmp16 + tmp15, _tmp16)
    tmp16 = tl.reshape(tl.sum(_tmp16, 1), [XBLOCK, 1])
    tmp17 = 768
    tmp18 = tmp5 / tmp17
    tmp19 = 767
    tmp20 = tmp16 / tmp19
    tmp21 = tl.sqrt(tmp20)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp18, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp21, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp22 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last')
        tmp23 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp26 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp33 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp23 + tmp24
        tmp27 = tmp25 + tmp26
        tmp28 = tmp27 - tmp18
        tmp29 = tmp22 * tmp28
        tmp30 = 1e-06
        tmp31 = tmp21 + tmp30
        tmp32 = tmp29 / tmp31
        tmp34 = tmp32 + tmp33
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp34, xmask & rmask)
''')
# Add-then-normalize over rows of 768, with
# x = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3 elementwise:
#   mean = sum(x) / 768                    -> in_out_ptr0[row]
#   std  = sqrt(sum((x - mean)**2) / 767)  -> in_out_ptr1[row]
#   out_ptr1[row, r] = in_ptr4[r] * (x - mean) / (std + 1e-06) + in_ptr5[r]
# NOTE(review): it assumes rnumel == 768 (the stride and divisors are hard-coded).
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: 'i32', 10: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        _tmp7 = tl.where(xmask & rmask, _tmp7 + tmp6, _tmp7)
    tmp7 = tl.reshape(tl.sum(_tmp7, 1), [XBLOCK, 1])
    tmp8 = tmp7
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp14 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp11 = tmp9 + tmp10
        tmp13 = tmp11 + tmp12
        tmp15 = tmp13 + tmp14
        tmp16 = 768
        tmp17 = tmp8 / tmp16
        tmp18 = tmp15 - tmp17
        tmp19 = tmp18 * tmp18
        _tmp20 = tl.where(xmask & rmask, _tmp20 + tmp19, _tmp20)
    tmp20 = tl.reshape(tl.sum(_tmp20, 1), [XBLOCK, 1])
    tmp21 = 768
    tmp22 = tmp7 / tmp21
    tmp23 = 767
    tmp24 = tmp20 / tmp23
    tmp25 = tl.sqrt(tmp24)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp25, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp26 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp28 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp32 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp39 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp29 = tmp27 + tmp28
        tmp31 = tmp29 + tmp30
        tmp33 = tmp31 + tmp32
        tmp34 = tmp33 - tmp22
        tmp35 = tmp26 * tmp34
        tmp36 = 1e-06
        tmp37 = tmp25 + tmp36
        tmp38 = tmp35 / tmp37
        tmp40 = tmp38 + tmp39
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp40, xmask & rmask)
''')
# Add, materialize, then normalize over rows of 768:
#   pass 1: out_ptr0[row, r] = in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3 + in_ptr4
#   pass 2: mean = sum(out_ptr0 row) / 768                -> in_out_ptr0[row]
#   pass 3: std  = sqrt(sum((x - mean)**2) / 767)         -> in_out_ptr1[row]
#   pass 4: out_ptr2[row, r] = in_ptr5[r] * (x - mean) / (std + 1e-06) + in_ptr6[r]
# Unlike the other add-then-normalize kernels, this one writes the sum to
# out_ptr0 once and reads it back in the later passes.
# NOTE(review): it assumes rnumel == 768 (the stride and divisors are hard-coded).
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: 'i32', 12: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tl.load(in_ptr4 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tl.store(out_ptr0 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp8, xmask & rmask)
    _tmp10 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp10 = tl.where(xmask & rmask, _tmp10 + tmp9, _tmp10)
    tmp10 = tl.reshape(tl.sum(_tmp10, 1), [XBLOCK, 1])
    tmp11 = tmp10
    _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp12 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = 768
        tmp14 = tmp11 / tmp13
        tmp15 = tmp12 - tmp14
        tmp16 = tmp15 * tmp15
        _tmp17 = tl.where(xmask & rmask, _tmp17 + tmp16, _tmp17)
    tmp17 = tl.reshape(tl.sum(_tmp17, 1), [XBLOCK, 1])
    tmp18 = 768
    tmp19 = tmp10 / tmp18
    tmp20 = 767
    tmp21 = tmp17 / tmp20
    tmp22 = tl.sqrt(tmp21)
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp23 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp24 = tl.load(out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr6 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp24 - tmp19
        tmp26 = tmp23 * tmp25
        tmp27 = 1e-06
        tmp28 = tmp22 + tmp27
        tmp29 = tmp26 / tmp28
        tmp31 = tmp29 + tmp30
        tl.store(out_ptr2 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp31, xmask & rmask)
''')
# In-place accumulate, then normalize over rows of 768:
#   pass 1: in_out_ptr0[row, r] += in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3  (in place)
#   pass 2: mean = sum(in_out_ptr0 row) / 768             -> in_out_ptr1[row]
#   pass 3: std  = sqrt(sum((x - mean)**2) / 767)         -> in_out_ptr2[row]
#   pass 4: out_ptr1[row, r] = in_ptr4[r] * (x - mean) / (std + 1e-06) + in_ptr5[r]
# NOTE(review): it assumes rnumel == 768 (the stride and divisors are hard-coded).
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_out_ptr1, in_out_ptr2, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp3 = tl.load(in_ptr1 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr2 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp7 = tl.load(in_ptr3 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tl.store(in_out_ptr0 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp8, xmask & rmask)
    _tmp10 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp9 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        _tmp10 = tl.where(xmask & rmask, _tmp10 + tmp9, _tmp10)
    tmp10 = tl.reshape(tl.sum(_tmp10, 1), [XBLOCK, 1])
    tmp11 = tmp10
    _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp12 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp13 = 768
        tmp14 = tmp11 / tmp13
        tmp15 = tmp12 - tmp14
        tmp16 = tmp15 * tmp15
        _tmp17 = tl.where(xmask & rmask, _tmp17 + tmp16, _tmp17)
    tmp17 = tl.reshape(tl.sum(_tmp17, 1), [XBLOCK, 1])
    tmp18 = 768
    tmp19 = tmp10 / tmp18
    tmp20 = 767
    tmp21 = tmp17 / tmp20
    tmp22 = tl.sqrt(tmp21)
    tl.store(in_out_ptr1 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)
    tl.store(in_out_ptr2 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp22, xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp23 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last')
        tmp24 = tl.load(in_out_ptr0 + (r1 + (768*x0)), xmask & rmask, eviction_policy='evict_last')
        tmp30 = tl.load(in_ptr5 + (r1), rmask, eviction_policy='evict_last')
        tmp25 = tmp24 - tmp19
        tmp26 = tmp23 * tmp25
        tmp27 = 1e-06
        tmp28 = tmp22 + tmp27
        tmp29 = tmp26 / tmp28
        tmp31 = tmp29 + tmp30
        tl.store(out_ptr1 + (r1 + (768*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp31, xmask & rmask)
''')
# Elementwise in-place accumulation over xnumel contiguous elements:
#   in_out_ptr0[i] = in_out_ptr0[i] + in_ptr0[i] + in_ptr1[i] + in_ptr2[i] + in_ptr3[i]
# The additions are performed left to right, in the order written.
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[262144], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]})
@triton.jit
def kernel(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr0 + (x0), xmask)
    tmp3 = tl.load(in_ptr1 + (x0), xmask)
    tmp5 = tl.load(in_ptr2 + (x0), xmask)
    tmp7 = tl.load(in_ptr3 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp8, xmask)
''')
# Finalize kernel compilation. This must come after the last
# `async_compile.triton(...)` assignment and before `call` uses the kernels
# (`call` invokes `.run(...)` on those module-level names).
# Passes this module's global namespace to AsyncCompile.wait. Presumably this
# blocks until every kernel queued above has compiled and rebinds those globals
# to the compiled kernels. Confirm against
# torch._inductor.codecache.AsyncCompile.wait.
# NOTE(review): the trailing " | |" on these lines is stray residue from a
# rendered gist and is a syntax error as written. Strip it before running.
async_compile.wait(globals()) | |
# Remove the module-level AsyncCompile instance (created near the top of the file);
# no code executed after this line can reference it.
del async_compile | |
def call(args): | |
primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, 
primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194 = args | |
args.clear() | |
primals_193_size = primals_193.size() | |
s0 = primals_193_size[0] | |
s1 = primals_193_size[1] | |
buf0 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf3 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf1 = buf0; del buf0 # reuse | |
buf4 = buf3; del buf3 # reuse | |
buf5 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel = s0*s1 | |
stream0 = get_cuda_stream(0) | |
triton_fused_add_1_div_mul_std_sub_mean_add0.run(buf1, buf4, primals_193, primals_1, primals_2, buf5, triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel, 768, grid=grid(triton_fused_add_1_div_mul_std_sub_mean_add0_xnumel), stream=stream0) | |
buf6 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_50, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_49, (768, 768), (1, 768)), beta=1, alpha=1, out=buf6) | |
del primals_50 | |
buf7 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_52, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_51, (768, 768), (1, 768)), beta=1, alpha=1, out=buf7) | |
del primals_52 | |
buf8 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_54, as_strided(buf5, (s0*s1, 768), (768, 1)), as_strided(primals_53, (768, 768), (1, 768)), beta=1, alpha=1, out=buf8) | |
del primals_54 | |
buf9 = as_strided(buf5, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf5 # reuse | |
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel = 768*s0*s1 | |
triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11.run(buf6, buf9, s1, triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel, grid=grid(triton_fused_view_1_add_1_addmm_div_mul_permute_sub_mean_add_permute_11_xnumel), stream=stream0) | |
buf10 = as_strided(buf6, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf6 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf7, buf10, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf11 = empty_strided((12*s0, 128, 128), (16384, 128, 1), device='cuda', dtype=torch.float32) | |
aten.bmm.out(as_strided(buf9, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf10, (12*s0, 64, 128), (8192, 128, 1)), out=buf11) | |
buf14 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf11, buf14, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf15 = as_strided(buf7, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf7 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf8, buf15, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf16 = as_strided(buf8, (12*s0, 128, 64), (8192, 64, 1)); del buf8 # reuse | |
aten.bmm.out(as_strided(buf14, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf15, (12*s0, 128, 64), (8192, 64, 1)), out=buf16) | |
buf17 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf16, buf17, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf18 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_56, as_strided(buf17, (128*s0, 768), (768, 1)), as_strided(primals_55, (768, 768), (1, 768)), beta=1, alpha=1, out=buf18) | |
del primals_56 | |
buf19 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf22 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf20 = buf19; del buf19 # reuse | |
buf23 = buf22; del buf22 # reuse | |
buf24 = as_strided(buf17, (s0, 128, 768), (98304, 768, 1)); del buf17 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf20, buf23, primals_193, buf18, primals_3, primals_4, buf24, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf25 = empty_strided((128*s0, 3072), (3072, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_58, as_strided(buf24, (128*s0, 768), (768, 1)), as_strided(primals_57, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf25) | |
del primals_58 | |
buf26 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf353 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf25, buf26, buf353, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf27 = as_strided(buf24, (128*s0, 768), (768, 1)); del buf24 # reuse | |
aten.addmm.out(primals_60, as_strided(buf26, (128*s0, 3072), (3072, 1)), as_strided(primals_59, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf27) | |
del primals_60 | |
buf28 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf31 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf29 = buf28; del buf28 # reuse | |
buf32 = buf31; del buf31 # reuse | |
buf33 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf29, buf32, primals_193, buf18, buf27, primals_5, primals_6, buf33, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf34 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_62, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_61, (768, 768), (1, 768)), beta=1, alpha=1, out=buf34) | |
del primals_62 | |
buf35 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_64, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_63, (768, 768), (1, 768)), beta=1, alpha=1, out=buf35) | |
del primals_64 | |
buf36 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_66, as_strided(buf33, (128*s0, 768), (768, 1)), as_strided(primals_65, (768, 768), (1, 768)), beta=1, alpha=1, out=buf36) | |
del primals_66 | |
buf37 = as_strided(buf33, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf33 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf34, buf37, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf38 = as_strided(buf34, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf34 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf35, buf38, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf39 = buf11; del buf11 # reuse | |
aten.bmm.out(as_strided(buf37, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf38, (12*s0, 64, 128), (8192, 128, 1)), out=buf39) | |
buf42 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf39, buf42, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf43 = as_strided(buf35, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf35 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf36, buf43, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf44 = as_strided(buf36, (12*s0, 128, 64), (8192, 64, 1)); del buf36 # reuse | |
aten.bmm.out(as_strided(buf42, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf43, (12*s0, 128, 64), (8192, 64, 1)), out=buf44) | |
buf45 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf44, buf45, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf46 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_68, as_strided(buf45, (128*s0, 768), (768, 1)), as_strided(primals_67, (768, 768), (1, 768)), beta=1, alpha=1, out=buf46) | |
del primals_68 | |
buf47 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf50 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf48 = buf47; del buf47 # reuse | |
buf51 = buf50; del buf50 # reuse | |
buf52 = as_strided(buf45, (s0, 128, 768), (98304, 768, 1)); del buf45 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf48, buf51, primals_193, buf18, buf27, buf46, primals_7, primals_8, buf52, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf53 = buf25; del buf25 # reuse | |
aten.addmm.out(primals_70, as_strided(buf52, (128*s0, 768), (768, 1)), as_strided(primals_69, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf53) | |
del primals_70 | |
buf54 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf352 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf53, buf54, buf352, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf55 = as_strided(buf52, (128*s0, 768), (768, 1)); del buf52 # reuse | |
aten.addmm.out(primals_72, as_strided(buf54, (128*s0, 3072), (3072, 1)), as_strided(primals_71, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf55) | |
del primals_72 | |
buf56 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
buf57 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf60 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf58 = buf57; del buf57 # reuse | |
buf61 = buf60; del buf60 # reuse | |
buf62 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410.run(buf58, buf61, primals_193, buf18, buf27, buf46, buf55, primals_9, primals_10, buf56, buf62, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_view_35_sub_4_permute_1410_xnumel), stream=stream0) | |
buf63 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_74, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_73, (768, 768), (1, 768)), beta=1, alpha=1, out=buf63) | |
del primals_74 | |
buf64 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_76, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_75, (768, 768), (1, 768)), beta=1, alpha=1, out=buf64) | |
del primals_76 | |
buf65 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_78, as_strided(buf62, (128*s0, 768), (768, 1)), as_strided(primals_77, (768, 768), (1, 768)), beta=1, alpha=1, out=buf65) | |
del primals_78 | |
buf66 = as_strided(buf62, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf62 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf63, buf66, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf67 = as_strided(buf63, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf63 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf64, buf67, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf68 = buf39; del buf39 # reuse | |
aten.bmm.out(as_strided(buf66, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf67, (12*s0, 64, 128), (8192, 128, 1)), out=buf68) | |
buf71 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf68, buf71, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf72 = as_strided(buf64, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf64 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf65, buf72, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf73 = as_strided(buf65, (12*s0, 128, 64), (8192, 64, 1)); del buf65 # reuse | |
aten.bmm.out(as_strided(buf71, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf72, (12*s0, 128, 64), (8192, 64, 1)), out=buf73) | |
buf74 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf73, buf74, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf75 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_80, as_strided(buf74, (128*s0, 768), (768, 1)), as_strided(primals_79, (768, 768), (1, 768)), beta=1, alpha=1, out=buf75) | |
del primals_80 | |
buf76 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf79 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf77 = buf76; del buf76 # reuse | |
buf80 = buf79; del buf79 # reuse | |
buf81 = as_strided(buf74, (s0, 128, 768), (98304, 768, 1)); del buf74 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf77, buf80, buf56, buf75, primals_11, primals_12, buf81, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf82 = buf53; del buf53 # reuse | |
aten.addmm.out(primals_82, as_strided(buf81, (128*s0, 768), (768, 1)), as_strided(primals_81, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf82) | |
del primals_82 | |
buf83 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf351 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf82, buf83, buf351, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf84 = as_strided(buf81, (128*s0, 768), (768, 1)); del buf81 # reuse | |
aten.addmm.out(primals_84, as_strided(buf83, (128*s0, 3072), (3072, 1)), as_strided(primals_83, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf84) | |
del primals_84 | |
buf85 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf88 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf86 = buf85; del buf85 # reuse | |
buf89 = buf88; del buf88 # reuse | |
buf90 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf86, buf89, buf56, buf75, buf84, primals_13, primals_14, buf90, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf91 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_86, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_85, (768, 768), (1, 768)), beta=1, alpha=1, out=buf91) | |
del primals_86 | |
buf92 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_88, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_87, (768, 768), (1, 768)), beta=1, alpha=1, out=buf92) | |
del primals_88 | |
buf93 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_90, as_strided(buf90, (128*s0, 768), (768, 1)), as_strided(primals_89, (768, 768), (1, 768)), beta=1, alpha=1, out=buf93) | |
del primals_90 | |
buf94 = as_strided(buf90, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf90 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf91, buf94, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf95 = as_strided(buf91, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf91 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf92, buf95, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf96 = buf68; del buf68 # reuse | |
aten.bmm.out(as_strided(buf94, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf95, (12*s0, 64, 128), (8192, 128, 1)), out=buf96) | |
buf99 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf96, buf99, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf100 = as_strided(buf92, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf92 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf93, buf100, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf101 = as_strided(buf93, (12*s0, 128, 64), (8192, 64, 1)); del buf93 # reuse | |
aten.bmm.out(as_strided(buf99, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf100, (12*s0, 128, 64), (8192, 64, 1)), out=buf101) | |
buf102 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf101, buf102, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf103 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_92, as_strided(buf102, (128*s0, 768), (768, 1)), as_strided(primals_91, (768, 768), (1, 768)), beta=1, alpha=1, out=buf103) | |
del primals_92 | |
buf104 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf107 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf105 = buf104; del buf104 # reuse | |
buf108 = buf107; del buf107 # reuse | |
buf109 = as_strided(buf102, (s0, 128, 768), (98304, 768, 1)); del buf102 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf105, buf108, buf56, buf75, buf84, buf103, primals_15, primals_16, buf109, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf110 = buf82; del buf82 # reuse | |
aten.addmm.out(primals_94, as_strided(buf109, (128*s0, 768), (768, 1)), as_strided(primals_93, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf110) | |
del primals_94 | |
buf111 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf350 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf110, buf111, buf350, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf112 = as_strided(buf109, (128*s0, 768), (768, 1)); del buf109 # reuse | |
aten.addmm.out(primals_96, as_strided(buf111, (128*s0, 3072), (3072, 1)), as_strided(primals_95, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf112) | |
del primals_96 | |
buf113 = buf56; del buf56 # reuse | |
buf114 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf117 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf115 = buf114; del buf114 # reuse | |
buf118 = buf117; del buf117 # reuse | |
buf119 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0 | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf113, buf115, buf118, buf75, buf84, buf103, buf112, primals_17, primals_18, buf119, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0) | |
buf120 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_98, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_97, (768, 768), (1, 768)), beta=1, alpha=1, out=buf120) | |
del primals_98 | |
buf121 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_100, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_99, (768, 768), (1, 768)), beta=1, alpha=1, out=buf121) | |
del primals_100 | |
buf122 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_102, as_strided(buf119, (128*s0, 768), (768, 1)), as_strided(primals_101, (768, 768), (1, 768)), beta=1, alpha=1, out=buf122) | |
del primals_102 | |
buf123 = as_strided(buf119, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf119 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf120, buf123, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf124 = as_strided(buf120, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf120 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf121, buf124, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf125 = buf96; del buf96 # reuse | |
aten.bmm.out(as_strided(buf123, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf124, (12*s0, 64, 128), (8192, 128, 1)), out=buf125) | |
buf128 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf125, buf128, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf129 = as_strided(buf121, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf121 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf122, buf129, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf130 = as_strided(buf122, (12*s0, 128, 64), (8192, 64, 1)); del buf122 # reuse | |
aten.bmm.out(as_strided(buf128, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf129, (12*s0, 128, 64), (8192, 64, 1)), out=buf130) | |
buf131 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf130, buf131, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf132 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_104, as_strided(buf131, (128*s0, 768), (768, 1)), as_strided(primals_103, (768, 768), (1, 768)), beta=1, alpha=1, out=buf132) | |
del primals_104 | |
buf133 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf136 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf134 = buf133; del buf133 # reuse | |
buf137 = buf136; del buf136 # reuse | |
buf138 = as_strided(buf131, (s0, 128, 768), (98304, 768, 1)); del buf131 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf134, buf137, buf113, buf132, primals_19, primals_20, buf138, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf139 = buf110; del buf110 # reuse | |
aten.addmm.out(primals_106, as_strided(buf138, (128*s0, 768), (768, 1)), as_strided(primals_105, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf139) | |
del primals_106 | |
buf140 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf349 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf139, buf140, buf349, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf141 = as_strided(buf138, (128*s0, 768), (768, 1)); del buf138 # reuse | |
aten.addmm.out(primals_108, as_strided(buf140, (128*s0, 3072), (3072, 1)), as_strided(primals_107, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf141) | |
del primals_108 | |
buf142 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf145 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf143 = buf142; del buf142 # reuse | |
buf146 = buf145; del buf145 # reuse | |
buf147 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf143, buf146, buf113, buf132, buf141, primals_21, primals_22, buf147, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf148 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_110, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_109, (768, 768), (1, 768)), beta=1, alpha=1, out=buf148) | |
del primals_110 | |
buf149 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_112, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_111, (768, 768), (1, 768)), beta=1, alpha=1, out=buf149) | |
del primals_112 | |
buf150 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_114, as_strided(buf147, (128*s0, 768), (768, 1)), as_strided(primals_113, (768, 768), (1, 768)), beta=1, alpha=1, out=buf150) | |
del primals_114 | |
buf151 = as_strided(buf147, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf147 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf148, buf151, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf152 = as_strided(buf148, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf148 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf149, buf152, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf153 = buf125; del buf125 # reuse | |
aten.bmm.out(as_strided(buf151, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf152, (12*s0, 64, 128), (8192, 128, 1)), out=buf153) | |
buf156 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf153, buf156, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf157 = as_strided(buf149, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf149 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf150, buf157, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf158 = as_strided(buf150, (12*s0, 128, 64), (8192, 64, 1)); del buf150 # reuse | |
aten.bmm.out(as_strided(buf156, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf157, (12*s0, 128, 64), (8192, 64, 1)), out=buf158) | |
buf159 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf158, buf159, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf160 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_116, as_strided(buf159, (128*s0, 768), (768, 1)), as_strided(primals_115, (768, 768), (1, 768)), beta=1, alpha=1, out=buf160) | |
del primals_116 | |
buf161 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf164 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf162 = buf161; del buf161 # reuse | |
buf165 = buf164; del buf164 # reuse | |
buf166 = as_strided(buf159, (s0, 128, 768), (98304, 768, 1)); del buf159 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf162, buf165, buf113, buf132, buf141, buf160, primals_23, primals_24, buf166, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf167 = buf139; del buf139 # reuse | |
aten.addmm.out(primals_118, as_strided(buf166, (128*s0, 768), (768, 1)), as_strided(primals_117, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf167) | |
del primals_118 | |
buf168 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf348 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf167, buf168, buf348, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf169 = as_strided(buf166, (128*s0, 768), (768, 1)); del buf166 # reuse | |
aten.addmm.out(primals_120, as_strided(buf168, (128*s0, 3072), (3072, 1)), as_strided(primals_119, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf169) | |
del primals_120 | |
buf170 = buf113; del buf113 # reuse | |
buf171 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf174 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf172 = buf171; del buf171 # reuse | |
buf175 = buf174; del buf174 # reuse | |
buf176 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0 | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf170, buf172, buf175, buf132, buf141, buf160, buf169, primals_25, primals_26, buf176, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0) | |
buf177 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_122, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_121, (768, 768), (1, 768)), beta=1, alpha=1, out=buf177) | |
del primals_122 | |
buf178 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_124, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_123, (768, 768), (1, 768)), beta=1, alpha=1, out=buf178) | |
del primals_124 | |
buf179 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_126, as_strided(buf176, (128*s0, 768), (768, 1)), as_strided(primals_125, (768, 768), (1, 768)), beta=1, alpha=1, out=buf179) | |
del primals_126 | |
buf180 = as_strided(buf176, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf176 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf177, buf180, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf181 = as_strided(buf177, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf177 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf178, buf181, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf182 = buf153; del buf153 # reuse | |
aten.bmm.out(as_strided(buf180, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf181, (12*s0, 64, 128), (8192, 128, 1)), out=buf182) | |
buf185 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf182, buf185, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf186 = as_strided(buf178, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf178 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf179, buf186, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf187 = as_strided(buf179, (12*s0, 128, 64), (8192, 64, 1)); del buf179 # reuse | |
aten.bmm.out(as_strided(buf185, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf186, (12*s0, 128, 64), (8192, 64, 1)), out=buf187) | |
buf188 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf187, buf188, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf189 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_128, as_strided(buf188, (128*s0, 768), (768, 1)), as_strided(primals_127, (768, 768), (1, 768)), beta=1, alpha=1, out=buf189) | |
del primals_128 | |
buf190 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf193 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf191 = buf190; del buf190 # reuse | |
buf194 = buf193; del buf193 # reuse | |
buf195 = as_strided(buf188, (s0, 128, 768), (98304, 768, 1)); del buf188 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf191, buf194, buf170, buf189, primals_27, primals_28, buf195, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf196 = buf167; del buf167 # reuse | |
aten.addmm.out(primals_130, as_strided(buf195, (128*s0, 768), (768, 1)), as_strided(primals_129, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf196) | |
del primals_130 | |
buf197 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf347 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf196, buf197, buf347, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf198 = as_strided(buf195, (128*s0, 768), (768, 1)); del buf195 # reuse | |
aten.addmm.out(primals_132, as_strided(buf197, (128*s0, 3072), (3072, 1)), as_strided(primals_131, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf198) | |
del primals_132 | |
buf199 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf202 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf200 = buf199; del buf199 # reuse | |
buf203 = buf202; del buf202 # reuse | |
buf204 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf200, buf203, buf170, buf189, buf198, primals_29, primals_30, buf204, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf205 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_134, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_133, (768, 768), (1, 768)), beta=1, alpha=1, out=buf205) | |
del primals_134 | |
buf206 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_136, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_135, (768, 768), (1, 768)), beta=1, alpha=1, out=buf206) | |
del primals_136 | |
buf207 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_138, as_strided(buf204, (128*s0, 768), (768, 1)), as_strided(primals_137, (768, 768), (1, 768)), beta=1, alpha=1, out=buf207) | |
del primals_138 | |
buf208 = as_strided(buf204, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf204 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf205, buf208, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf209 = as_strided(buf205, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf205 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf206, buf209, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf210 = buf182; del buf182 # reuse | |
aten.bmm.out(as_strided(buf208, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf209, (12*s0, 64, 128), (8192, 128, 1)), out=buf210) | |
buf213 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf210, buf213, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf214 = as_strided(buf206, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf206 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf207, buf214, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf215 = as_strided(buf207, (12*s0, 128, 64), (8192, 64, 1)); del buf207 # reuse | |
aten.bmm.out(as_strided(buf213, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf214, (12*s0, 128, 64), (8192, 64, 1)), out=buf215) | |
buf216 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf215, buf216, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf217 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_140, as_strided(buf216, (128*s0, 768), (768, 1)), as_strided(primals_139, (768, 768), (1, 768)), beta=1, alpha=1, out=buf217) | |
del primals_140 | |
buf218 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf221 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf219 = buf218; del buf218 # reuse | |
buf222 = buf221; del buf221 # reuse | |
buf223 = as_strided(buf216, (s0, 128, 768), (98304, 768, 1)); del buf216 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf219, buf222, buf170, buf189, buf198, buf217, primals_31, primals_32, buf223, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf224 = buf196; del buf196 # reuse | |
aten.addmm.out(primals_142, as_strided(buf223, (128*s0, 768), (768, 1)), as_strided(primals_141, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf224) | |
del primals_142 | |
buf225 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf346 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf224, buf225, buf346, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf226 = as_strided(buf223, (128*s0, 768), (768, 1)); del buf223 # reuse | |
aten.addmm.out(primals_144, as_strided(buf225, (128*s0, 3072), (3072, 1)), as_strided(primals_143, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf226) | |
del primals_144 | |
buf227 = buf170; del buf170 # reuse | |
buf228 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf231 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf229 = buf228; del buf228 # reuse | |
buf232 = buf231; del buf231 # reuse | |
buf233 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0 | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf227, buf229, buf232, buf189, buf198, buf217, buf226, primals_33, primals_34, buf233, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0) | |
buf234 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_146, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_145, (768, 768), (1, 768)), beta=1, alpha=1, out=buf234) | |
del primals_146 | |
buf235 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_148, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_147, (768, 768), (1, 768)), beta=1, alpha=1, out=buf235) | |
del primals_148 | |
buf236 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_150, as_strided(buf233, (128*s0, 768), (768, 1)), as_strided(primals_149, (768, 768), (1, 768)), beta=1, alpha=1, out=buf236) | |
del primals_150 | |
buf237 = as_strided(buf233, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf233 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf234, buf237, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf238 = as_strided(buf234, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf234 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf235, buf238, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf239 = buf210; del buf210 # reuse | |
aten.bmm.out(as_strided(buf237, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf238, (12*s0, 64, 128), (8192, 128, 1)), out=buf239) | |
buf242 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf239, buf242, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf243 = as_strided(buf235, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf235 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf236, buf243, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf244 = as_strided(buf236, (12*s0, 128, 64), (8192, 64, 1)); del buf236 # reuse | |
aten.bmm.out(as_strided(buf242, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf243, (12*s0, 128, 64), (8192, 64, 1)), out=buf244) | |
buf245 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf244, buf245, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf246 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_152, as_strided(buf245, (128*s0, 768), (768, 1)), as_strided(primals_151, (768, 768), (1, 768)), beta=1, alpha=1, out=buf246) | |
del primals_152 | |
buf247 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf250 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf248 = buf247; del buf247 # reuse | |
buf251 = buf250; del buf250 # reuse | |
buf252 = as_strided(buf245, (s0, 128, 768), (98304, 768, 1)); del buf245 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf248, buf251, buf227, buf246, primals_35, primals_36, buf252, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf253 = buf224; del buf224 # reuse | |
aten.addmm.out(primals_154, as_strided(buf252, (128*s0, 768), (768, 1)), as_strided(primals_153, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf253) | |
del primals_154 | |
buf254 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf345 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf253, buf254, buf345, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf255 = as_strided(buf252, (128*s0, 768), (768, 1)); del buf252 # reuse | |
aten.addmm.out(primals_156, as_strided(buf254, (128*s0, 3072), (3072, 1)), as_strided(primals_155, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf255) | |
del primals_156 | |
buf256 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf259 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf257 = buf256; del buf256 # reuse | |
buf260 = buf259; del buf259 # reuse | |
buf261 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf257, buf260, buf227, buf246, buf255, primals_37, primals_38, buf261, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf262 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_158, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_157, (768, 768), (1, 768)), beta=1, alpha=1, out=buf262) | |
del primals_158 | |
buf263 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_160, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_159, (768, 768), (1, 768)), beta=1, alpha=1, out=buf263) | |
del primals_160 | |
buf264 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_162, as_strided(buf261, (128*s0, 768), (768, 1)), as_strided(primals_161, (768, 768), (1, 768)), beta=1, alpha=1, out=buf264) | |
del primals_162 | |
buf265 = as_strided(buf261, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf261 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf262, buf265, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf266 = as_strided(buf262, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf262 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf263, buf266, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf267 = buf239; del buf239 # reuse | |
aten.bmm.out(as_strided(buf265, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf266, (12*s0, 64, 128), (8192, 128, 1)), out=buf267) | |
buf270 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf267, buf270, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf271 = as_strided(buf263, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf263 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf264, buf271, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf272 = as_strided(buf264, (12*s0, 128, 64), (8192, 64, 1)); del buf264 # reuse | |
aten.bmm.out(as_strided(buf270, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf271, (12*s0, 128, 64), (8192, 64, 1)), out=buf272) | |
buf273 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf272, buf273, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf274 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_164, as_strided(buf273, (128*s0, 768), (768, 1)), as_strided(primals_163, (768, 768), (1, 768)), beta=1, alpha=1, out=buf274) | |
del primals_164 | |
buf275 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf278 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf276 = buf275; del buf275 # reuse | |
buf279 = buf278; del buf278 # reuse | |
buf280 = as_strided(buf273, (s0, 128, 768), (98304, 768, 1)); del buf273 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf276, buf279, buf227, buf246, buf255, buf274, primals_39, primals_40, buf280, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf281 = buf253; del buf253 # reuse | |
aten.addmm.out(primals_166, as_strided(buf280, (128*s0, 768), (768, 1)), as_strided(primals_165, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf281) | |
del primals_166 | |
buf282 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf344 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf281, buf282, buf344, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf283 = as_strided(buf280, (128*s0, 768), (768, 1)); del buf280 # reuse | |
aten.addmm.out(primals_168, as_strided(buf282, (128*s0, 3072), (3072, 1)), as_strided(primals_167, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf283) | |
del primals_168 | |
buf284 = buf227; del buf227 # reuse | |
buf285 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf288 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf286 = buf285; del buf285 # reuse | |
buf289 = buf288; del buf288 # reuse | |
buf290 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel = 128*s0 | |
triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411.run(buf284, buf286, buf289, buf246, buf255, buf274, buf283, primals_41, primals_42, buf290, triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel, 768, grid=grid(triton_fused_permute_19_view_35_view_34_reciprocal_1_view_33_addmm_9_mul_32_permute_20_std_3_add_1411_xnumel), stream=stream0) | |
buf291 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_170, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_169, (768, 768), (1, 768)), beta=1, alpha=1, out=buf291) | |
del primals_170 | |
buf292 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_172, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_171, (768, 768), (1, 768)), beta=1, alpha=1, out=buf292) | |
del primals_172 | |
buf293 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_174, as_strided(buf290, (128*s0, 768), (768, 1)), as_strided(primals_173, (768, 768), (1, 768)), beta=1, alpha=1, out=buf293) | |
del primals_174 | |
buf294 = as_strided(buf290, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf290 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf291, buf294, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf295 = as_strided(buf291, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf291 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf292, buf295, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf296 = buf267; del buf267 # reuse | |
aten.bmm.out(as_strided(buf294, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf295, (12*s0, 64, 128), (8192, 128, 1)), out=buf296) | |
buf299 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf296, buf299, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
buf300 = as_strided(buf292, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf292 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf293, buf300, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf301 = as_strided(buf293, (12*s0, 128, 64), (8192, 64, 1)); del buf293 # reuse | |
aten.bmm.out(as_strided(buf299, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf300, (12*s0, 128, 64), (8192, 64, 1)), out=buf301) | |
buf302 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf301, buf302, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf303 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_176, as_strided(buf302, (128*s0, 768), (768, 1)), as_strided(primals_175, (768, 768), (1, 768)), beta=1, alpha=1, out=buf303) | |
del primals_176 | |
buf304 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf307 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf305 = buf304; del buf304 # reuse | |
buf308 = buf307; del buf307 # reuse | |
buf309 = as_strided(buf302, (s0, 128, 768), (98304, 768, 1)); del buf302 # reuse | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel = 128*s0 | |
triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6.run(buf305, buf308, buf284, buf303, primals_43, primals_44, buf309, triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel, 768, grid=grid(triton_fused_clone_3_add_1_div_view_13_mul_clone_2_bmm_1_permute_8_sub_mean6_xnumel), stream=stream0) | |
buf310 = buf281; del buf281 # reuse | |
aten.addmm.out(primals_178, as_strided(buf309, (128*s0, 768), (768, 1)), as_strided(primals_177, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf310) | |
del primals_178 | |
buf311 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf343 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf310, buf311, buf343, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
buf312 = as_strided(buf309, (128*s0, 768), (768, 1)); del buf309 # reuse | |
aten.addmm.out(primals_180, as_strided(buf311, (128*s0, 3072), (3072, 1)), as_strided(primals_179, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf312) | |
del primals_180 | |
buf313 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf316 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf314 = buf313; del buf313 # reuse | |
buf317 = buf316; del buf316 # reuse | |
buf318 = empty_strided((s0, 128, 768), (98304, 768, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel = 128*s0 | |
triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48.run(buf314, buf317, buf284, buf303, buf312, primals_45, primals_46, buf318, triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel, 768, grid=grid(triton_fused_add_11_clone_3_add_1_div_abs_1_view_13_mul_add_12_clone_2_sub_48_xnumel), stream=stream0) | |
buf319 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_182, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_181, (768, 768), (1, 768)), beta=1, alpha=1, out=buf319) | |
del primals_182 | |
buf320 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_184, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_183, (768, 768), (1, 768)), beta=1, alpha=1, out=buf320) | |
del primals_184 | |
buf321 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_186, as_strided(buf318, (128*s0, 768), (768, 1)), as_strided(primals_185, (768, 768), (1, 768)), beta=1, alpha=1, out=buf321) | |
del primals_186 | |
buf322 = as_strided(buf318, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf318 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf319, buf322, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf323 = as_strided(buf319, (s0, 12, 64, 128), (98304, 8192, 128, 1)); del buf319 # reuse | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel = 768*s0 | |
triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12.run(buf320, buf323, triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_permute_6_sub_addmm_1_mean_permute_3_add_clone_12_xnumel, 128), stream=stream0) | |
buf324 = buf296; del buf296 # reuse | |
aten.bmm.out(as_strided(buf322, (12*s0, 128, 64), (8192, 64, 1)), as_strided(buf323, (12*s0, 64, 128), (8192, 128, 1)), out=buf324) | |
buf327 = empty_strided((s0, 12, 128, 128), (196608, 16384, 128, 1), device='cuda', dtype=torch.float32) | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel = 1536*s0 | |
triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3.run(primals_194, buf324, buf327, triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel, 128, grid=grid(triton_fused_add_1_div_mul_sub_mean_add_div_2_sub_1_div_1_amax3_xnumel), stream=stream0) | |
del buf324 | |
buf328 = as_strided(buf320, (s0, 12, 128, 64), (98304, 8192, 64, 1)); del buf320 # reuse | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel = 98304*s0 | |
triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4.run(buf321, buf328, triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel, grid=grid(triton_fused_add_1_div_mul_clone_2_view_7_permute_5_sub_permute_4_mean_add4_xnumel), stream=stream0) | |
buf329 = as_strided(buf321, (12*s0, 128, 64), (8192, 64, 1)); del buf321 # reuse | |
aten.bmm.out(as_strided(buf327, (12*s0, 128, 128), (16384, 128, 1)), as_strided(buf328, (12*s0, 128, 64), (8192, 64, 1)), out=buf329) | |
buf330 = empty_strided((s0, 128, 12, 64), (98304, 768, 64, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel = 98304*s0 | |
triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15.run(buf329, buf330, triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel, grid=grid(triton_fused_clone_3_add_1_div_where_mul_clone_2_permute_5_bmm_1_permute_6_sum_15_xnumel), stream=stream0) | |
buf331 = empty_strided((128*s0, 768), (768, 1), device='cuda', dtype=torch.float32) | |
aten.addmm.out(primals_188, as_strided(buf330, (128*s0, 768), (768, 1)), as_strided(primals_187, (768, 768), (1, 768)), beta=1, alpha=1, out=buf331) | |
del primals_188 | |
buf332 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf335 = empty_strided((s0, 128, 1), (128, 1, 128*s0), device='cuda', dtype=torch.float32) | |
buf333 = buf332; del buf332 # reuse | |
buf336 = buf335; del buf335 # reuse | |
buf337 = as_strided(buf330, (s0, 128, 768), (98304, 768, 1)); del buf330 # reuse | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel = 128*s0 | |
triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199.run(buf333, buf336, buf284, buf303, buf312, buf331, primals_47, primals_48, buf337, triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel, 768, grid=grid(triton_fused_add_11_add_1_permute_19_div_view_23_mul_add_12_sub_4_permute_14_view_199_xnumel), stream=stream0) | |
buf338 = buf310; del buf310 # reuse | |
aten.addmm.out(primals_190, as_strided(buf337, (128*s0, 768), (768, 1)), as_strided(primals_189, (768, 3072), (1, 768)), beta=1, alpha=1, out=buf338) | |
del primals_190 | |
buf339 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
buf342 = empty_strided((s0, 128, 3072), (393216, 3072, 1), device='cuda', dtype=torch.float32) | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel = 393216*s0 | |
triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87.run(buf338, buf339, buf342, triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel, grid=grid(triton_fused_clone_3_add_1_div_abs_1_view_13_mul_clone_2_bmm_1_mul_631_permute_87_xnumel), stream=stream0) | |
del buf338 | |
buf340 = as_strided(buf337, (128*s0, 768), (768, 1)); del buf337 # reuse | |
aten.addmm.out(primals_192, as_strided(buf339, (128*s0, 3072), (3072, 1)), as_strided(primals_191, (3072, 768), (1, 3072)), beta=1, alpha=1, out=buf340) | |
del primals_192 | |
buf341 = buf284; del buf284 # reuse | |
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel = 98304*s0 | |
triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212.run(buf341, buf303, buf312, buf331, buf340, triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel, grid=grid(triton_fused_clone_19_view_88_bmm_9_div_18__unsafe_view_14_abs_5_lift_fresh_copy_8_clone_17_bmm_8_permute_5212_xnumel), stream=stream0) | |
return (buf341, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_193, primals_194, buf1, buf4, buf9, buf10, buf14, buf15, buf16, buf18, buf20, buf23, as_strided(buf26, (128*s0, 3072), (3072, 1)), buf27, buf29, buf32, buf37, buf38, buf42, buf43, buf44, buf46, buf48, buf51, as_strided(buf54, (128*s0, 3072), (3072, 1)), buf55, buf58, buf61, buf66, buf67, buf71, buf72, buf73, buf75, buf77, buf80, as_strided(buf83, (128*s0, 3072), (3072, 1)), buf84, buf86, buf89, buf94, buf95, buf99, buf100, buf101, buf103, buf105, buf108, as_strided(buf111, (128*s0, 3072), (3072, 1)), buf112, buf115, buf118, buf123, buf124, buf128, buf129, buf130, buf132, buf134, buf137, as_strided(buf140, (128*s0, 3072), (3072, 1)), buf141, buf143, buf146, buf151, buf152, buf156, buf157, buf158, buf160, buf162, buf165, as_strided(buf168, (128*s0, 3072), (3072, 1)), buf169, buf172, buf175, buf180, buf181, buf185, buf186, buf187, buf189, buf191, buf194, as_strided(buf197, (128*s0, 3072), (3072, 1)), buf198, buf200, buf203, buf208, buf209, buf213, buf214, buf215, buf217, buf219, buf222, as_strided(buf225, (128*s0, 3072), (3072, 1)), buf226, buf229, buf232, buf237, buf238, buf242, buf243, buf244, buf246, buf248, buf251, as_strided(buf254, (128*s0, 3072), (3072, 1)), buf255, buf257, buf260, buf265, buf266, buf270, buf271, buf272, buf274, buf276, buf279, as_strided(buf282, (128*s0, 3072), (3072, 1)), buf283, buf286, buf289, buf294, buf295, buf299, buf300, buf301, buf303, buf305, buf308, 
as_strided(buf311, (128*s0, 3072), (3072, 1)), buf312, buf314, buf317, buf322, buf323, buf327, buf328, buf329, buf331, buf333, buf336, as_strided(buf339, (128*s0, 3072), (3072, 1)), as_strided(primals_191, (768, 3072), (3072, 1)), buf342, as_strided(primals_189, (3072, 768), (768, 1)), as_strided(primals_187, (768, 768), (768, 1)), as_strided(primals_185, (768, 768), (768, 1)), as_strided(primals_183, (768, 768), (768, 1)), as_strided(primals_181, (768, 768), (768, 1)), as_strided(primals_179, (768, 3072), (3072, 1)), buf343, as_strided(primals_177, (3072, 768), (768, 1)), as_strided(primals_175, (768, 768), (768, 1)), as_strided(primals_173, (768, 768), (768, 1)), as_strided(primals_171, (768, 768), (768, 1)), as_strided(primals_169, (768, 768), (768, 1)), as_strided(primals_167, (768, 3072), (3072, 1)), buf344, as_strided(primals_165, (3072, 768), (768, 1)), as_strided(primals_163, (768, 768), (768, 1)), as_strided(primals_161, (768, 768), (768, 1)), as_strided(primals_159, (768, 768), (768, 1)), as_strided(primals_157, (768, 768), (768, 1)), as_strided(primals_155, (768, 3072), (3072, 1)), buf345, as_strided(primals_153, (3072, 768), (768, 1)), as_strided(primals_151, (768, 768), (768, 1)), as_strided(primals_149, (768, 768), (768, 1)), as_strided(primals_147, (768, 768), (768, 1)), as_strided(primals_145, (768, 768), (768, 1)), as_strided(primals_143, (768, 3072), (3072, 1)), buf346, as_strided(primals_141, (3072, 768), (768, 1)), as_strided(primals_139, (768, 768), (768, 1)), as_strided(primals_137, (768, 768), (768, 1)), as_strided(primals_135, (768, 768), (768, 1)), as_strided(primals_133, (768, 768), (768, 1)), as_strided(primals_131, (768, 3072), (3072, 1)), buf347, as_strided(primals_129, (3072, 768), (768, 1)), as_strided(primals_127, (768, 768), (768, 1)), as_strided(primals_125, (768, 768), (768, 1)), as_strided(primals_123, (768, 768), (768, 1)), as_strided(primals_121, (768, 768), (768, 1)), as_strided(primals_119, (768, 3072), (3072, 1)), buf348, 
as_strided(primals_117, (3072, 768), (768, 1)), as_strided(primals_115, (768, 768), (768, 1)), as_strided(primals_113, (768, 768), (768, 1)), as_strided(primals_111, (768, 768), (768, 1)), as_strided(primals_109, (768, 768), (768, 1)), as_strided(primals_107, (768, 3072), (3072, 1)), buf349, as_strided(primals_105, (3072, 768), (768, 1)), as_strided(primals_103, (768, 768), (768, 1)), as_strided(primals_101, (768, 768), (768, 1)), as_strided(primals_99, (768, 768), (768, 1)), as_strided(primals_97, (768, 768), (768, 1)), as_strided(primals_95, (768, 3072), (3072, 1)), buf350, as_strided(primals_93, (3072, 768), (768, 1)), as_strided(primals_91, (768, 768), (768, 1)), as_strided(primals_89, (768, 768), (768, 1)), as_strided(primals_87, (768, 768), (768, 1)), as_strided(primals_85, (768, 768), (768, 1)), as_strided(primals_83, (768, 3072), (3072, 1)), buf351, as_strided(primals_81, (3072, 768), (768, 1)), as_strided(primals_79, (768, 768), (768, 1)), as_strided(primals_77, (768, 768), (768, 1)), as_strided(primals_75, (768, 768), (768, 1)), as_strided(primals_73, (768, 768), (768, 1)), as_strided(primals_71, (768, 3072), (3072, 1)), buf352, as_strided(primals_69, (3072, 768), (768, 1)), as_strided(primals_67, (768, 768), (768, 1)), as_strided(primals_65, (768, 768), (768, 1)), as_strided(primals_63, (768, 768), (768, 1)), as_strided(primals_61, (768, 768), (768, 1)), as_strided(primals_59, (768, 3072), (3072, 1)), buf353, as_strided(primals_57, (3072, 768), (768, 1)), as_strided(primals_55, (768, 768), (768, 1)), as_strided(primals_53, (768, 768), (768, 1)), as_strided(primals_51, (768, 768), (768, 1)), as_strided(primals_49, (768, 768), (768, 1)), s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 
128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, s0, 128, 128*s0, 128, 12*s0, 128, s0, 12, 128, 128, 12*s0, 128, s0, 128, 128*s0, 768, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, 128*s0, 12*s0, 128, 128, s0, 128, 128*s0, s0, 128, 128*s0, s0, 128, 128*s0, ) | |
if __name__ == "__main__":
    # Standalone benchmark harness: build random CUDA example inputs with the
    # exact sizes/strides/dtypes listed below (same order as the original
    # primals_1 .. primals_194), then time `call` with print_performance.
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def _example(size, stride, dtype=torch.float32):
        """Return one random CUDA tensor with the given size, stride and dtype."""
        return rand_strided(size, stride, device='cuda', dtype=dtype)

    hidden = 768
    ffn = 3072
    num_layers = 12

    example_inputs = []

    # primals_1 .. primals_48: 48 float32 tensors of shape (768,).
    # NOTE(review): presumably per-layer LayerNorm weight/bias pairs
    # (12 layers x 2 norms x 2) — confirm against the traced graph.
    example_inputs.extend(_example((hidden,), (1,)) for _ in range(4 * num_layers))

    # primals_49 .. primals_192: 12 repeats of one 12-tensor pattern:
    # four (768, 768) matrices each followed by a (768,) vector, then
    # (3072, 768) + (3072,), then (768, 3072) + (768,).
    # NOTE(review): presumably the attention projections and the two
    # feed-forward linears of each encoder layer — confirm against the graph.
    for _ in range(num_layers):
        for _ in range(4):
            example_inputs.append(_example((hidden, hidden), (hidden, 1)))
            example_inputs.append(_example((hidden,), (1,)))
        example_inputs.append(_example((ffn, hidden), (hidden, 1)))
        example_inputs.append(_example((ffn,), (1,)))
        example_inputs.append(_example((hidden, ffn), (ffn, 1)))
        example_inputs.append(_example((hidden,), (1,)))

    # primals_193: contiguous float32 tensor of shape (2, 128, 768);
    # stride (98304, 768, 1), since 128 * 768 == 98304.
    example_inputs.append(_example((2, 128, hidden), (128 * hidden, hidden, 1)))

    # primals_194: bool tensor of shape (2, 1, 128, 128) with stride
    # (128, 0, 0, 1). Dims 1 and 2 have stride 0, so every row along dim 2
    # is the same 128 elements. NOTE(review): presumably an attention mask.
    example_inputs.append(
        _example((2, 1, 128, 128), (128, 0, 0, 1), dtype=torch.bool)
    )

    # Sanity check: the compiled graph takes exactly primals_1 .. primals_194.
    assert len(example_inputs) == 194

    # Pass a fresh list on every timed invocation, exactly as the original
    # lambda did by building a new list literal each time; `call` is defined
    # elsewhere in this file and may consume/clear the list it receives.
    print_performance(lambda: call(list(example_inputs)))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.