Created
September 22, 2023 00:26
-
-
Save shunting314/59aeafd297ed8ff03aa12030a2dd41ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
from torch._inductor.ir import ReductionHint | |
from torch._inductor.ir import TileHint | |
from torch._inductor.triton_heuristics import AutotuneHint, reduction | |
from torch._inductor.utils import instance_descriptor | |
from torch._inductor import triton_helpers | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream | |
import torch | |
from torch._inductor.triton_heuristics import grid | |
@reduction(
    size_hints=[512, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: '*fp32', 7: '*fp32', 8: '*fp16', 9: '*fp16', 10: '*fp16', 11: '*fp32', 12: '*fp32', 13: '*fp16', 14: '*fp32', 15: '*fp32', 16: '*fp16', 17: '*fp16', 18: '*fp16', 19: '*fp32', 20: '*fp32', 21: '*fp16', 22: '*fp32', 23: '*fp32', 24: '*fp16', 25: '*fp16', 26: '*fp16', 27: '*fp32', 28: '*fp32', 29: '*fp16', 30: '*fp32', 31: '*fp32', 32: '*fp16', 33: '*fp16', 34: '*fp16', 35: '*fp32', 36: '*fp32', 37: '*fp16', 38: '*fp32', 39: '*fp32', 40: '*fp16', 41: '*fp16', 42: '*fp16', 43: '*fp32', 44: '*fp32', 45: '*fp16', 46: '*fp32', 47: '*fp32', 48: '*fp16', 49: '*fp16', 50: '*fp16', 51: '*fp32', 52: '*fp32', 53: '*fp16', 54: '*fp32', 55: '*fp32', 56: '*fp16', 57: '*fp16', 58: '*fp16', 59: '*fp32', 60: '*fp32', 61: '*fp16', 62: '*fp32', 63: '*fp32', 64: '*fp16', 65: '*fp16', 66: '*fp16', 67: '*fp32', 68: '*fp32', 69: '*fp16', 70: '*fp32', 71: '*fp32', 72: '*fp16', 73: '*fp16', 74: '*fp16', 75: '*fp32', 76: '*fp32', 77: '*fp16', 78: '*fp32', 79: '*fp32', 80: '*fp16', 81: '*fp16', 82: '*fp16', 83: '*fp32', 84: '*fp32', 85: '*fp16', 86: '*fp32', 87: '*fp32', 88: '*fp16', 89: '*fp16', 90: '*fp16', 91: '*fp32', 92: '*fp32', 93: '*fp16', 94: '*fp32', 95: '*fp32', 96: '*fp16', 97: '*fp32', 98: '*fp32', 99: '*fp16', 100: '*fp32', 101: '*i1', 102: '*fp32', 103: '*fp32', 104: '*fp32', 105: '*fp32', 106: '*fp32', 107: '*fp32', 108: '*fp32', 109: '*fp32', 110: '*fp32', 111: '*fp32', 112: '*fp32', 113: '*fp32', 114: '*fp32', 115: '*fp32', 116: '*fp32', 117: '*fp32', 118: '*fp32', 119: '*fp32', 120: '*fp32', 121: '*fp32', 122: '*fp32', 123: '*fp32', 124: '*fp32', 125: '*fp32', 126: '*fp32', 127: '*fp32', 128: '*fp32', 129: '*fp32', 130: '*fp32', 131: '*fp32', 132: '*fp32', 133: '*fp32', 134: '*fp32', 135: '*fp32', 136: '*fp32', 137: '*fp32', 138: '*fp32', 139: '*fp16', 140: 'i32', 141: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(),
    'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(140, 141))]}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, in_ptr18, in_ptr19, in_ptr20, in_ptr21, in_ptr22, in_ptr23, in_ptr24, in_ptr25, in_ptr26, in_ptr27, in_ptr28, in_ptr29, in_ptr30, in_ptr31, in_ptr32, in_ptr33, in_ptr34, in_ptr35, in_ptr36, in_ptr37, in_ptr38, in_ptr39, in_ptr40, in_ptr41, in_ptr42, in_ptr43, in_ptr44, in_ptr45, in_ptr46, in_ptr47, in_ptr48, in_ptr49, in_ptr50, in_ptr51, in_ptr52, in_ptr53, in_ptr54, in_ptr55, in_ptr56, in_ptr57, in_ptr58, in_ptr59, in_ptr60, in_ptr61, in_ptr62, in_ptr63, in_ptr64, in_ptr65, in_ptr66, in_ptr67, in_ptr68, in_ptr69, in_ptr70, in_ptr71, in_ptr72, in_ptr73, in_ptr74, in_ptr75, in_ptr76, in_ptr77, in_ptr78, in_ptr79, in_ptr80, in_ptr81, in_ptr82, in_ptr83, in_ptr84, in_ptr85, in_ptr86, in_ptr87, in_ptr88, in_ptr89, in_ptr90, in_ptr91, in_ptr92, in_ptr93, in_ptr94, in_ptr95, in_ptr96, in_ptr97, in_ptr98, in_ptr99, in_ptr100, in_ptr101, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, out_ptr10, out_ptr11, out_ptr12, out_ptr13, out_ptr14, out_ptr15, out_ptr16, out_ptr17, out_ptr18, out_ptr19, out_ptr20, out_ptr21, out_ptr22, out_ptr23, out_ptr24, out_ptr25, out_ptr26, out_ptr27, out_ptr28, out_ptr29, out_ptr30, out_ptr31, out_ptr32, out_ptr33, out_ptr34, out_ptr35, out_ptr38, out_ptr39, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    """Fused pointwise + row-reduction kernel (machine-generated; edit with care).

    Indexing: ``x0`` selects one of 512 rows and ``r1`` one of 2560 positions
    within a row; full-size tensors are addressed as ``r1 + 2560*x0``. Each
    program instance owns ``XBLOCK`` rows and walks the r axis in
    ``RBLOCK``-wide chunks, three times.

    Arguments (element types come from ``meta['signature']`` on the decorator):
        in_ptr* loaded at ``x0`` only: one value per row, read once up front.
        in_ptr* loaded at ``r1 + 2560*x0``: full (x, r) inputs.
        in_ptr100: one value per r position (indexed by ``r1`` alone).
        in_ptr101: the single ``'*i1'`` input, read over (x, r) in the last pass.
        out_ptr0..out_ptr35, out_ptr38, out_ptr39: (x, r) outputs. The
            parameter list has no out_ptr36 / out_ptr37.
        xnumel, rnumel: immediately overwritten with the constants 512 and
            2560, so the values passed by the caller are ignored.
        XBLOCK, RBLOCK: compile-time tile sizes.

    Passes over the r axis:
        1. Build a running sum of the (x, r) inputs, store ``(sum - a) * b``
           for pairs of per-row values ``a``/``b`` plus selected running sums,
           and accumulate ``sum_r(in_ptr99 * in_ptr100)``.
        2. Accumulate ``sum_r(in_ptr99 * in_ptr100 * out_ptr35)``, re-reading
           the out_ptr35 values written by pass 1.
        3. Combine both row sums into out_ptr38 and write out_ptr38's value
           scaled by ``in_ptr101 * 1.1111111111111112`` to out_ptr39.

    NOTE(review): the ``(sum - a) * b`` pattern looks like layer-norm style
    normalisation (mean / reciprocal std), pass 3 looks like a layer-norm
    backward, and the 1.1111... factor (numerically 1/0.9) looks like dropout
    rescaling. None of that is stated in this file - confirm against the
    model that produced the kernel.
    """
    # Problem sizes are baked in as constants; they shadow the runtime arguments.
    xnumel = 512
    rnumel = 2560
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex

    # Per-row values ([XBLOCK, 1]); loaded once and reused across every r chunk.
    tmp7 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp9 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp14 = tl.load(in_ptr6 + (x0), xmask, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last')
    tmp27 = tl.load(in_ptr11 + (x0), xmask, eviction_policy='evict_last')
    tmp29 = tl.load(in_ptr12 + (x0), xmask, eviction_policy='evict_last')
    tmp34 = tl.load(in_ptr14 + (x0), xmask, eviction_policy='evict_last')
    tmp36 = tl.load(in_ptr15 + (x0), xmask, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr19 + (x0), xmask, eviction_policy='evict_last')
    tmp49 = tl.load(in_ptr20 + (x0), xmask, eviction_policy='evict_last')
    tmp54 = tl.load(in_ptr22 + (x0), xmask, eviction_policy='evict_last')
    tmp56 = tl.load(in_ptr23 + (x0), xmask, eviction_policy='evict_last')
    tmp67 = tl.load(in_ptr27 + (x0), xmask, eviction_policy='evict_last')
    tmp69 = tl.load(in_ptr28 + (x0), xmask, eviction_policy='evict_last')
    tmp74 = tl.load(in_ptr30 + (x0), xmask, eviction_policy='evict_last')
    tmp76 = tl.load(in_ptr31 + (x0), xmask, eviction_policy='evict_last')
    tmp87 = tl.load(in_ptr35 + (x0), xmask, eviction_policy='evict_last')
    tmp89 = tl.load(in_ptr36 + (x0), xmask, eviction_policy='evict_last')
    tmp94 = tl.load(in_ptr38 + (x0), xmask, eviction_policy='evict_last')
    tmp96 = tl.load(in_ptr39 + (x0), xmask, eviction_policy='evict_last')
    tmp107 = tl.load(in_ptr43 + (x0), xmask, eviction_policy='evict_last')
    tmp109 = tl.load(in_ptr44 + (x0), xmask, eviction_policy='evict_last')
    tmp114 = tl.load(in_ptr46 + (x0), xmask, eviction_policy='evict_last')
    tmp116 = tl.load(in_ptr47 + (x0), xmask, eviction_policy='evict_last')
    tmp127 = tl.load(in_ptr51 + (x0), xmask, eviction_policy='evict_last')
    tmp129 = tl.load(in_ptr52 + (x0), xmask, eviction_policy='evict_last')
    tmp134 = tl.load(in_ptr54 + (x0), xmask, eviction_policy='evict_last')
    tmp136 = tl.load(in_ptr55 + (x0), xmask, eviction_policy='evict_last')
    tmp147 = tl.load(in_ptr59 + (x0), xmask, eviction_policy='evict_last')
    tmp149 = tl.load(in_ptr60 + (x0), xmask, eviction_policy='evict_last')
    tmp154 = tl.load(in_ptr62 + (x0), xmask, eviction_policy='evict_last')
    tmp156 = tl.load(in_ptr63 + (x0), xmask, eviction_policy='evict_last')
    tmp167 = tl.load(in_ptr67 + (x0), xmask, eviction_policy='evict_last')
    tmp169 = tl.load(in_ptr68 + (x0), xmask, eviction_policy='evict_last')
    tmp174 = tl.load(in_ptr70 + (x0), xmask, eviction_policy='evict_last')
    tmp176 = tl.load(in_ptr71 + (x0), xmask, eviction_policy='evict_last')
    tmp187 = tl.load(in_ptr75 + (x0), xmask, eviction_policy='evict_last')
    tmp189 = tl.load(in_ptr76 + (x0), xmask, eviction_policy='evict_last')
    tmp194 = tl.load(in_ptr78 + (x0), xmask, eviction_policy='evict_last')
    tmp196 = tl.load(in_ptr79 + (x0), xmask, eviction_policy='evict_last')
    tmp207 = tl.load(in_ptr83 + (x0), xmask, eviction_policy='evict_last')
    tmp209 = tl.load(in_ptr84 + (x0), xmask, eviction_policy='evict_last')
    tmp214 = tl.load(in_ptr86 + (x0), xmask, eviction_policy='evict_last')
    tmp216 = tl.load(in_ptr87 + (x0), xmask, eviction_policy='evict_last')
    tmp227 = tl.load(in_ptr91 + (x0), xmask, eviction_policy='evict_last')
    tmp229 = tl.load(in_ptr92 + (x0), xmask, eviction_policy='evict_last')
    tmp234 = tl.load(in_ptr94 + (x0), xmask, eviction_policy='evict_last')
    tmp236 = tl.load(in_ptr95 + (x0), xmask, eviction_policy='evict_last')
    tmp241 = tl.load(in_ptr97 + (x0), xmask, eviction_policy='evict_last')
    tmp243 = tl.load(in_ptr98 + (x0), xmask, eviction_policy='evict_last')

    # ---- Pass 1: pointwise outputs + first row reduction -------------------
    # Element-wise accumulator; collapsed along r by tl.sum after the loop.
    _tmp250 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # Masked-off lanes read 0 (other=0); loads followed by .to() are widened to fp32.
        tmp0 = tl.load(in_ptr0 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp1 = tl.load(in_ptr1 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp11 = tl.load(in_ptr5 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp18 = tl.load(in_ptr8 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp21 = tl.load(in_ptr9 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp24 = tl.load(in_ptr10 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp31 = tl.load(in_ptr13 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp38 = tl.load(in_ptr16 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp41 = tl.load(in_ptr17 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp44 = tl.load(in_ptr18 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp51 = tl.load(in_ptr21 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp58 = tl.load(in_ptr24 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp61 = tl.load(in_ptr25 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp64 = tl.load(in_ptr26 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp71 = tl.load(in_ptr29 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp78 = tl.load(in_ptr32 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp81 = tl.load(in_ptr33 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp84 = tl.load(in_ptr34 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp91 = tl.load(in_ptr37 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp98 = tl.load(in_ptr40 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp101 = tl.load(in_ptr41 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp104 = tl.load(in_ptr42 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp111 = tl.load(in_ptr45 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp118 = tl.load(in_ptr48 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp121 = tl.load(in_ptr49 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp124 = tl.load(in_ptr50 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp131 = tl.load(in_ptr53 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp138 = tl.load(in_ptr56 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp141 = tl.load(in_ptr57 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp144 = tl.load(in_ptr58 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp151 = tl.load(in_ptr61 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp158 = tl.load(in_ptr64 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp161 = tl.load(in_ptr65 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp164 = tl.load(in_ptr66 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp171 = tl.load(in_ptr69 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp178 = tl.load(in_ptr72 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp181 = tl.load(in_ptr73 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp184 = tl.load(in_ptr74 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp191 = tl.load(in_ptr77 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp198 = tl.load(in_ptr80 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp201 = tl.load(in_ptr81 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp204 = tl.load(in_ptr82 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp211 = tl.load(in_ptr85 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp218 = tl.load(in_ptr88 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp221 = tl.load(in_ptr89 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp224 = tl.load(in_ptr90 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp231 = tl.load(in_ptr93 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp238 = tl.load(in_ptr96 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp245 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        # in_ptr100 depends on r only, so it is masked with rmask alone.
        tmp247 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        # Running-sum chain. After each addition the sum is either kept as a
        # running total or turned into `(sum - a) * b` with a pair of per-row
        # values loaded above. The `.to(tl.float32)` on already-fp32 values is
        # redundant but kept exactly as generated.
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp8 = tmp6 - tmp7
        tmp10 = tmp8 * tmp9
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp6 + tmp12
        tmp15 = tmp13 - tmp14
        tmp17 = tmp15 * tmp16
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp13 + tmp19
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp20 + tmp22
        tmp25 = tmp24.to(tl.float32)
        tmp26 = tmp23 + tmp25
        tmp28 = tmp26 - tmp27
        tmp30 = tmp28 * tmp29
        tmp32 = tmp31.to(tl.float32)
        tmp33 = tmp26 + tmp32
        tmp35 = tmp33 - tmp34
        tmp37 = tmp35 * tmp36
        tmp39 = tmp38.to(tl.float32)
        tmp40 = tmp33 + tmp39
        tmp42 = tmp41.to(tl.float32)
        tmp43 = tmp40 + tmp42
        tmp45 = tmp44.to(tl.float32)
        tmp46 = tmp43 + tmp45
        tmp48 = tmp46 - tmp47
        tmp50 = tmp48 * tmp49
        tmp52 = tmp51.to(tl.float32)
        tmp53 = tmp46 + tmp52
        tmp55 = tmp53 - tmp54
        tmp57 = tmp55 * tmp56
        tmp59 = tmp58.to(tl.float32)
        tmp60 = tmp53 + tmp59
        tmp62 = tmp61.to(tl.float32)
        tmp63 = tmp60 + tmp62
        tmp65 = tmp64.to(tl.float32)
        tmp66 = tmp63 + tmp65
        tmp68 = tmp66 - tmp67
        tmp70 = tmp68 * tmp69
        tmp72 = tmp71.to(tl.float32)
        tmp73 = tmp66 + tmp72
        tmp75 = tmp73 - tmp74
        tmp77 = tmp75 * tmp76
        tmp79 = tmp78.to(tl.float32)
        tmp80 = tmp73 + tmp79
        tmp82 = tmp81.to(tl.float32)
        tmp83 = tmp80 + tmp82
        tmp85 = tmp84.to(tl.float32)
        tmp86 = tmp83 + tmp85
        tmp88 = tmp86 - tmp87
        tmp90 = tmp88 * tmp89
        tmp92 = tmp91.to(tl.float32)
        tmp93 = tmp86 + tmp92
        tmp95 = tmp93 - tmp94
        tmp97 = tmp95 * tmp96
        tmp99 = tmp98.to(tl.float32)
        tmp100 = tmp93 + tmp99
        tmp102 = tmp101.to(tl.float32)
        tmp103 = tmp100 + tmp102
        tmp105 = tmp104.to(tl.float32)
        tmp106 = tmp103 + tmp105
        tmp108 = tmp106 - tmp107
        tmp110 = tmp108 * tmp109
        tmp112 = tmp111.to(tl.float32)
        tmp113 = tmp106 + tmp112
        tmp115 = tmp113 - tmp114
        tmp117 = tmp115 * tmp116
        tmp119 = tmp118.to(tl.float32)
        tmp120 = tmp113 + tmp119
        tmp122 = tmp121.to(tl.float32)
        tmp123 = tmp120 + tmp122
        tmp125 = tmp124.to(tl.float32)
        tmp126 = tmp123 + tmp125
        tmp128 = tmp126 - tmp127
        tmp130 = tmp128 * tmp129
        tmp132 = tmp131.to(tl.float32)
        tmp133 = tmp126 + tmp132
        tmp135 = tmp133 - tmp134
        tmp137 = tmp135 * tmp136
        tmp139 = tmp138.to(tl.float32)
        tmp140 = tmp133 + tmp139
        tmp142 = tmp141.to(tl.float32)
        tmp143 = tmp140 + tmp142
        tmp145 = tmp144.to(tl.float32)
        tmp146 = tmp143 + tmp145
        tmp148 = tmp146 - tmp147
        tmp150 = tmp148 * tmp149
        tmp152 = tmp151.to(tl.float32)
        tmp153 = tmp146 + tmp152
        tmp155 = tmp153 - tmp154
        tmp157 = tmp155 * tmp156
        tmp159 = tmp158.to(tl.float32)
        tmp160 = tmp153 + tmp159
        tmp162 = tmp161.to(tl.float32)
        tmp163 = tmp160 + tmp162
        tmp165 = tmp164.to(tl.float32)
        tmp166 = tmp163 + tmp165
        tmp168 = tmp166 - tmp167
        tmp170 = tmp168 * tmp169
        tmp172 = tmp171.to(tl.float32)
        tmp173 = tmp166 + tmp172
        tmp175 = tmp173 - tmp174
        tmp177 = tmp175 * tmp176
        tmp179 = tmp178.to(tl.float32)
        tmp180 = tmp173 + tmp179
        tmp182 = tmp181.to(tl.float32)
        tmp183 = tmp180 + tmp182
        tmp185 = tmp184.to(tl.float32)
        tmp186 = tmp183 + tmp185
        tmp188 = tmp186 - tmp187
        tmp190 = tmp188 * tmp189
        tmp192 = tmp191.to(tl.float32)
        tmp193 = tmp186 + tmp192
        tmp195 = tmp193 - tmp194
        tmp197 = tmp195 * tmp196
        tmp199 = tmp198.to(tl.float32)
        tmp200 = tmp193 + tmp199
        tmp202 = tmp201.to(tl.float32)
        tmp203 = tmp200 + tmp202
        tmp205 = tmp204.to(tl.float32)
        tmp206 = tmp203 + tmp205
        tmp208 = tmp206 - tmp207
        tmp210 = tmp208 * tmp209
        tmp212 = tmp211.to(tl.float32)
        tmp213 = tmp206 + tmp212
        tmp215 = tmp213 - tmp214
        tmp217 = tmp215 * tmp216
        tmp219 = tmp218.to(tl.float32)
        tmp220 = tmp213 + tmp219
        tmp222 = tmp221.to(tl.float32)
        tmp223 = tmp220 + tmp222
        tmp225 = tmp224.to(tl.float32)
        tmp226 = tmp223 + tmp225
        tmp228 = tmp226 - tmp227
        tmp230 = tmp228 * tmp229
        tmp232 = tmp231.to(tl.float32)
        tmp233 = tmp226 + tmp232
        tmp235 = tmp233 - tmp234
        tmp237 = tmp235 * tmp236
        tmp239 = tmp238.to(tl.float32)
        tmp240 = tmp233 + tmp239
        tmp242 = tmp240 - tmp241
        tmp244 = tmp242 * tmp243
        # First row reduction: accumulate in_ptr99 * in_ptr100 element-wise.
        tmp246 = tmp245.to(tl.float32)
        tmp248 = tmp246 * tmp247
        tmp249 = tl.broadcast_to(tmp248, [XBLOCK, RBLOCK])
        tmp251 = _tmp250 + tmp249
        # Only in-range lanes update the accumulator.
        _tmp250 = tl.where(rmask & xmask, tmp251, _tmp250)
        # Write this chunk's pointwise results.
        tl.store(out_ptr0 + (r1 + (2560*x0)), tmp10, rmask & xmask)
        tl.store(out_ptr1 + (r1 + (2560*x0)), tmp17, rmask & xmask)
        tl.store(out_ptr2 + (r1 + (2560*x0)), tmp20, rmask & xmask)
        tl.store(out_ptr3 + (r1 + (2560*x0)), tmp30, rmask & xmask)
        tl.store(out_ptr4 + (r1 + (2560*x0)), tmp37, rmask & xmask)
        tl.store(out_ptr5 + (r1 + (2560*x0)), tmp40, rmask & xmask)
        tl.store(out_ptr6 + (r1 + (2560*x0)), tmp50, rmask & xmask)
        tl.store(out_ptr7 + (r1 + (2560*x0)), tmp57, rmask & xmask)
        tl.store(out_ptr8 + (r1 + (2560*x0)), tmp60, rmask & xmask)
        tl.store(out_ptr9 + (r1 + (2560*x0)), tmp70, rmask & xmask)
        tl.store(out_ptr10 + (r1 + (2560*x0)), tmp77, rmask & xmask)
        tl.store(out_ptr11 + (r1 + (2560*x0)), tmp80, rmask & xmask)
        tl.store(out_ptr12 + (r1 + (2560*x0)), tmp90, rmask & xmask)
        tl.store(out_ptr13 + (r1 + (2560*x0)), tmp97, rmask & xmask)
        tl.store(out_ptr14 + (r1 + (2560*x0)), tmp100, rmask & xmask)
        tl.store(out_ptr15 + (r1 + (2560*x0)), tmp110, rmask & xmask)
        tl.store(out_ptr16 + (r1 + (2560*x0)), tmp117, rmask & xmask)
        tl.store(out_ptr17 + (r1 + (2560*x0)), tmp120, rmask & xmask)
        tl.store(out_ptr18 + (r1 + (2560*x0)), tmp130, rmask & xmask)
        tl.store(out_ptr19 + (r1 + (2560*x0)), tmp137, rmask & xmask)
        tl.store(out_ptr20 + (r1 + (2560*x0)), tmp140, rmask & xmask)
        tl.store(out_ptr21 + (r1 + (2560*x0)), tmp150, rmask & xmask)
        tl.store(out_ptr22 + (r1 + (2560*x0)), tmp157, rmask & xmask)
        tl.store(out_ptr23 + (r1 + (2560*x0)), tmp160, rmask & xmask)
        tl.store(out_ptr24 + (r1 + (2560*x0)), tmp170, rmask & xmask)
        tl.store(out_ptr25 + (r1 + (2560*x0)), tmp177, rmask & xmask)
        tl.store(out_ptr26 + (r1 + (2560*x0)), tmp180, rmask & xmask)
        tl.store(out_ptr27 + (r1 + (2560*x0)), tmp190, rmask & xmask)
        tl.store(out_ptr28 + (r1 + (2560*x0)), tmp197, rmask & xmask)
        tl.store(out_ptr29 + (r1 + (2560*x0)), tmp200, rmask & xmask)
        tl.store(out_ptr30 + (r1 + (2560*x0)), tmp210, rmask & xmask)
        tl.store(out_ptr31 + (r1 + (2560*x0)), tmp217, rmask & xmask)
        tl.store(out_ptr32 + (r1 + (2560*x0)), tmp220, rmask & xmask)
        tl.store(out_ptr33 + (r1 + (2560*x0)), tmp230, rmask & xmask)
        tl.store(out_ptr34 + (r1 + (2560*x0)), tmp237, rmask & xmask)
        tl.store(out_ptr35 + (r1 + (2560*x0)), tmp244, rmask & xmask)
    # Collapse the accumulator along r -> one value per row, shape [XBLOCK, 1].
    tmp250 = tl.sum(_tmp250, 1)[:, None]

    # ---- Pass 2: second row reduction --------------------------------------
    _tmp259 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp252 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp254 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        # Reads back the values this kernel stored to out_ptr35 in pass 1.
        tmp256 = tl.load(out_ptr35 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp253 = tmp252.to(tl.float32)
        tmp255 = tmp253 * tmp254
        tmp257 = tmp255 * tmp256
        tmp258 = tl.broadcast_to(tmp257, [XBLOCK, RBLOCK])
        tmp260 = _tmp259 + tmp258
        _tmp259 = tl.where(rmask & xmask, tmp260, _tmp259)
    tmp259 = tl.sum(_tmp259, 1)[:, None]

    # ---- Pass 3: combine both row sums into the final two outputs ----------
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp263 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp265 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        tmp269 = tl.load(out_ptr35 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        # in_ptr101 is the '*i1' argument; no `other` is given, and its
        # masked-off lanes are never stored.
        tmp274 = tl.load(in_ptr101 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last')
        # 2560.0 equals the row length (rnumel).
        tmp261 = 2560.0
        tmp262 = tmp243 / tmp261
        tmp264 = tmp263.to(tl.float32)
        tmp266 = tmp264 * tmp265
        tmp267 = tmp266 * tmp261
        tmp268 = tmp267 - tmp250
        tmp270 = tmp269 * tmp259
        tmp271 = tmp268 - tmp270
        # tmp272 = (tmp243 / 2560) * (2560 * g - sum_r(g) - out35 * sum_r(g * out35)),
        # where g = in_ptr99 * in_ptr100.
        tmp272 = tmp262 * tmp271
        tmp273 = tmp272.to(tl.float32)
        tmp275 = tmp274.to(tl.float32)
        # 1.1111111111111112 is numerically 1/0.9.
        tmp276 = 1.1111111111111112
        tmp277 = tmp275 * tmp276
        tmp278 = tmp273 * tmp277
        tl.store(out_ptr38 + (r1 + (2560*x0)), tmp272, rmask & xmask)
        # out_ptr39 is declared '*fp16' in the signature while tmp278 is fp32.
        tl.store(out_ptr39 + (r1 + (2560*x0)), tmp278, rmask & xmask)
def get_args(): | |
arg_0 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_1 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_2 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_3 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_4 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_5 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_6 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_7 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_8 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_9 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_10 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_11 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_12 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_13 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_14 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_15 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_16 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_17 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_18 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_19 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_20 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_21 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_22 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_23 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_24 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_25 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_26 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_27 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_28 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_29 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_30 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_31 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_32 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_33 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_34 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_35 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_36 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_37 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_38 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_39 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_40 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_41 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_42 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_43 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_44 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_45 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_46 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_47 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_48 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_49 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_50 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_51 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_52 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_53 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_54 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_55 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_56 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_57 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_58 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_59 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_60 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_61 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_62 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_63 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_64 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_65 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_66 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_67 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_68 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_69 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_70 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_71 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_72 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_73 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_74 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_75 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_76 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_77 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_78 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_79 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_80 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_81 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_82 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_83 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_84 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_85 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_86 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_87 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_88 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_89 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_90 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_91 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_92 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_93 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_94 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_95 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_96 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_97 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_98 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32) | |
arg_99 = rand_strided((512, 2560), (2560, 1), device='cuda:0', dtype=torch.float16) | |
arg_100 = rand_strided((2560,), (1,), device='cuda:0', dtype=torch.float32) | |
arg_101 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.bool) | |
arg_102 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_103 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_104 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_105 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_106 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_107 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_108 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_109 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_110 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_111 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_112 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_113 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_114 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_115 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_116 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_117 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_118 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_119 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_120 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_121 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_122 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_123 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_124 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_125 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_126 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_127 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_128 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_129 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_130 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_131 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_132 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_133 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_134 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_135 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_136 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_137 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_138 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32) | |
arg_139 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16) | |
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12, arg_13, arg_14, arg_15, arg_16, arg_17, arg_18, arg_19, arg_20, arg_21, arg_22, arg_23, arg_24, arg_25, arg_26, arg_27, arg_28, arg_29, arg_30, arg_31, arg_32, arg_33, arg_34, arg_35, arg_36, arg_37, arg_38, arg_39, arg_40, arg_41, arg_42, arg_43, arg_44, arg_45, arg_46, arg_47, arg_48, arg_49, arg_50, arg_51, arg_52, arg_53, arg_54, arg_55, arg_56, arg_57, arg_58, arg_59, arg_60, arg_61, arg_62, arg_63, arg_64, arg_65, arg_66, arg_67, arg_68, arg_69, arg_70, arg_71, arg_72, arg_73, arg_74, arg_75, arg_76, arg_77, arg_78, arg_79, arg_80, arg_81, arg_82, arg_83, arg_84, arg_85, arg_86, arg_87, arg_88, arg_89, arg_90, arg_91, arg_92, arg_93, arg_94, arg_95, arg_96, arg_97, arg_98, arg_99, arg_100, arg_101, arg_102, arg_103, arg_104, arg_105, arg_106, arg_107, arg_108, arg_109, arg_110, arg_111, arg_112, arg_113, arg_114, arg_115, arg_116, arg_117, arg_118, arg_119, arg_120, arg_121, arg_122, arg_123, arg_124, arg_125, arg_126, arg_127, arg_128, arg_129, arg_130, arg_131, arg_132, arg_133, arg_134, arg_135, arg_136, arg_137, arg_138, arg_139, | |
def call(args):
    """Launch the compiled ``triton_`` kernel once on CUDA device 0.

    ``args`` is unpacked as the leading positional arguments of
    ``triton_.run``; the integers 512 and 2560 are appended after them
    (presumably the kernel's two trailing ``i32`` size arguments --
    confirm against the kernel signature).  Nothing is returned.
    """
    device_index = 0
    with torch.cuda._DeviceGuard(device_index):
        torch.cuda.set_device(device_index)
        # Launch on the raw stream reported for this device, with the
        # launch grid built from 512.
        launch_kwargs = {
            "grid": grid(512),
            "stream": get_cuda_stream(device_index),
        }
        triton_.run(*args, 512, 2560, **launch_kwargs)
def benchmark_all_configs(args):
    """Run ``triton_.benchmark_all_configs`` on CUDA device 0.

    ``args`` is unpacked as the leading positional arguments, followed by
    the same trailing integers (512, 2560) and ``grid(512)`` launch grid
    that ``call`` uses.  Whatever ``triton_.benchmark_all_configs``
    returns is handed back to the caller unchanged.
    """
    device_index = 0
    with torch.cuda._DeviceGuard(device_index):
        torch.cuda.set_device(device_index)
        results = triton_.benchmark_all_configs(
            *args, 512, 2560, grid=grid(512)
        )
    return results
if __name__ == '__main__':
    # Script entry point: time one kernel launch and print latency, data
    # volume and the resulting bandwidth figure.
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench

    args = get_args()

    def _launch_once():
        return call(args)

    # The do_bench result is treated as milliseconds below (it is divided
    # by 1e3 to obtain seconds for the GB/s figure).
    ms = do_bench(_launch_once, rep=40, fast_flush=True)

    total_bytes = get_num_bytes(*args, num_in_out_args=0)
    num_gb = total_bytes / 1e9
    elapsed_s = ms / 1e3
    gb_per_s = num_gb / elapsed_s
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment