# Source: GitHub gist by @shunting314
# (gist id 59aeafd297ed8ff03aa12030a2dd41ae), created September 22, 2023 00:26.
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
# Fused reduction kernel generated by TorchInductor (kernel_name
# 'Placeholder.DESCRIPTIVE_NAME' in the meta below). The `reduction` decorator
# receives size_hints [512, 4096] and ReductionHint.INNER; XBLOCK/RBLOCK are
# compile-time tile sizes supplied at launch.
#
# Problem: 512 rows (x0) x 2560 columns (r1). Full-size tensors are addressed
# as `r1 + 2560*x0`; in_ptr100 is indexed by r1 only (one value per column);
# the scalars loaded before the loops are indexed by x0 only (one per row).
# Argument dtypes are listed in meta['signature']: in_ptr101 is '*i1',
# out_ptr39 is '*fp16', and out_ptr0..out_ptr35 / out_ptr38 are '*fp32'.
#
# Each program walks the 2560 columns of its rows three times:
#   pass 1: t = in_ptr0 + in_ptr1 + in_ptr2 + ... (fp16 addends upcast to
#           fp32); stores several partial sums and several (t - a[x0]) * b[x0]
#           values to out_ptr0..out_ptr35, and accumulates
#           s1[x0] = sum_r(in_ptr99 * in_ptr100).
#   pass 2: s2[x0] = sum_r(in_ptr99 * in_ptr100 * out_ptr35), reading back the
#           values pass 1 stored to out_ptr35.
#   pass 3: d = (in_ptr98[x0] / 2560)
#               * (2560 * in_ptr99 * in_ptr100 - s1 - out_ptr35 * s2)
#           out_ptr38 = d;  out_ptr39 = d * in_ptr101 * 1.1111111111111112.
#
# NOTE(review): pass 3 has the form of a LayerNorm backward
# (rstd/N * (N*dy*w - sum(dy*w) - xhat*sum(dy*w*xhat))) followed by a dropout
# backward with scale 1/0.9. The (t - a) * b pairs look like normalizations
# using saved mean/rstd. These roles are inferred from the arithmetic only;
# confirm them against the originating graph.
@reduction(
    size_hints=[512, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: '*fp32', 7: '*fp32', 8: '*fp16', 9: '*fp16', 10: '*fp16', 11: '*fp32', 12: '*fp32', 13: '*fp16', 14: '*fp32', 15: '*fp32', 16: '*fp16', 17: '*fp16', 18: '*fp16', 19: '*fp32', 20: '*fp32', 21: '*fp16', 22: '*fp32', 23: '*fp32', 24: '*fp16', 25: '*fp16', 26: '*fp16', 27: '*fp32', 28: '*fp32', 29: '*fp16', 30: '*fp32', 31: '*fp32', 32: '*fp16', 33: '*fp16', 34: '*fp16', 35: '*fp32', 36: '*fp32', 37: '*fp16', 38: '*fp32', 39: '*fp32', 40: '*fp16', 41: '*fp16', 42: '*fp16', 43: '*fp32', 44: '*fp32', 45: '*fp16', 46: '*fp32', 47: '*fp32', 48: '*fp16', 49: '*fp16', 50: '*fp16', 51: '*fp32', 52: '*fp32', 53: '*fp16', 54: '*fp32', 55: '*fp32', 56: '*fp16', 57: '*fp16', 58: '*fp16', 59: '*fp32', 60: '*fp32', 61: '*fp16', 62: '*fp32', 63: '*fp32', 64: '*fp16', 65: '*fp16', 66: '*fp16', 67: '*fp32', 68: '*fp32', 69: '*fp16', 70: '*fp32', 71: '*fp32', 72: '*fp16', 73: '*fp16', 74: '*fp16', 75: '*fp32', 76: '*fp32', 77: '*fp16', 78: '*fp32', 79: '*fp32', 80: '*fp16', 81: '*fp16', 82: '*fp16', 83: '*fp32', 84: '*fp32', 85: '*fp16', 86: '*fp32', 87: '*fp32', 88: '*fp16', 89: '*fp16', 90: '*fp16', 91: '*fp32', 92: '*fp32', 93: '*fp16', 94: '*fp32', 95: '*fp32', 96: '*fp16', 97: '*fp32', 98: '*fp32', 99: '*fp16', 100: '*fp32', 101: '*i1', 102: '*fp32', 103: '*fp32', 104: '*fp32', 105: '*fp32', 106: '*fp32', 107: '*fp32', 108: '*fp32', 109: '*fp32', 110: '*fp32', 111: '*fp32', 112: '*fp32', 113: '*fp32', 114: '*fp32', 115: '*fp32', 116: '*fp32', 117: '*fp32', 118: '*fp32', 119: '*fp32', 120: '*fp32', 121: '*fp32', 122: '*fp32', 123: '*fp32', 124: '*fp32', 125: '*fp32', 126: '*fp32', 127: '*fp32', 128: '*fp32', 129: '*fp32', 130: '*fp32', 131: '*fp32', 132: '*fp32', 133: '*fp32', 134: '*fp32', 135: '*fp32', 136: '*fp32', 137: '*fp32', 138: '*fp32', 139: '*fp16', 140: 'i32', 141: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 
    'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(140, 141))]}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, in_ptr18, in_ptr19, in_ptr20, in_ptr21, in_ptr22, in_ptr23, in_ptr24, in_ptr25, in_ptr26, in_ptr27, in_ptr28, in_ptr29, in_ptr30, in_ptr31, in_ptr32, in_ptr33, in_ptr34, in_ptr35, in_ptr36, in_ptr37, in_ptr38, in_ptr39, in_ptr40, in_ptr41, in_ptr42, in_ptr43, in_ptr44, in_ptr45, in_ptr46, in_ptr47, in_ptr48, in_ptr49, in_ptr50, in_ptr51, in_ptr52, in_ptr53, in_ptr54, in_ptr55, in_ptr56, in_ptr57, in_ptr58, in_ptr59, in_ptr60, in_ptr61, in_ptr62, in_ptr63, in_ptr64, in_ptr65, in_ptr66, in_ptr67, in_ptr68, in_ptr69, in_ptr70, in_ptr71, in_ptr72, in_ptr73, in_ptr74, in_ptr75, in_ptr76, in_ptr77, in_ptr78, in_ptr79, in_ptr80, in_ptr81, in_ptr82, in_ptr83, in_ptr84, in_ptr85, in_ptr86, in_ptr87, in_ptr88, in_ptr89, in_ptr90, in_ptr91, in_ptr92, in_ptr93, in_ptr94, in_ptr95, in_ptr96, in_ptr97, in_ptr98, in_ptr99, in_ptr100, in_ptr101, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, out_ptr10, out_ptr11, out_ptr12, out_ptr13, out_ptr14, out_ptr15, out_ptr16, out_ptr17, out_ptr18, out_ptr19, out_ptr20, out_ptr21, out_ptr22, out_ptr23, out_ptr24, out_ptr25, out_ptr26, out_ptr27, out_ptr28, out_ptr29, out_ptr30, out_ptr31, out_ptr32, out_ptr33, out_ptr34, out_ptr35, out_ptr38, out_ptr39, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Sizes are specialized into the kernel; these rebind the xnumel/rnumel args.
    xnumel = 512
    rnumel = 2560
    xoffset = tl.program_id(0) * XBLOCK
    # Row indices handled by this program, shaped [XBLOCK, 1].
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Column offsets within one RBLOCK chunk, shaped [1, RBLOCK].
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Per-row scalars (indexed by x0 only), loaded once before the column loops.
    # They are used in pairs as (t - tmp_a) * tmp_b in pass 1; tmp243 (in_ptr98)
    # is reused in pass 3.
    tmp7 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp9 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp14 = tl.load(in_ptr6 + (x0), xmask, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last')
    tmp27 = tl.load(in_ptr11 + (x0), xmask, eviction_policy='evict_last')
    tmp29 = tl.load(in_ptr12 + (x0), xmask, eviction_policy='evict_last')
    tmp34 = tl.load(in_ptr14 + (x0), xmask, eviction_policy='evict_last')
    tmp36 = tl.load(in_ptr15 + (x0), xmask, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr19 + (x0), xmask, eviction_policy='evict_last')
    tmp49 = tl.load(in_ptr20 + (x0), xmask, eviction_policy='evict_last')
    tmp54 = tl.load(in_ptr22 + (x0), xmask, eviction_policy='evict_last')
    tmp56 = tl.load(in_ptr23 + (x0), xmask, eviction_policy='evict_last')
    tmp67 = tl.load(in_ptr27 + (x0), xmask, eviction_policy='evict_last')
    tmp69 = tl.load(in_ptr28 + (x0), xmask, eviction_policy='evict_last')
    tmp74 = tl.load(in_ptr30 + (x0), xmask, eviction_policy='evict_last')
    tmp76 = tl.load(in_ptr31 + (x0), xmask, eviction_policy='evict_last')
    tmp87 = tl.load(in_ptr35 + (x0), xmask, eviction_policy='evict_last')
    tmp89 = tl.load(in_ptr36 + (x0), xmask, eviction_policy='evict_last')
    tmp94 = tl.load(in_ptr38 + (x0), xmask, eviction_policy='evict_last')
    tmp96 = tl.load(in_ptr39 + (x0), xmask, eviction_policy='evict_last')
    tmp107 = tl.load(in_ptr43 + (x0), xmask, eviction_policy='evict_last')
    tmp109 = tl.load(in_ptr44 + (x0), xmask, eviction_policy='evict_last')
    tmp114 = tl.load(in_ptr46 + (x0), xmask, eviction_policy='evict_last')
    tmp116 = tl.load(in_ptr47 + (x0), xmask, eviction_policy='evict_last')
    tmp127 = tl.load(in_ptr51 + (x0), xmask, eviction_policy='evict_last')
    tmp129 = tl.load(in_ptr52 + (x0), xmask, eviction_policy='evict_last')
    tmp134 = tl.load(in_ptr54 + (x0), xmask, eviction_policy='evict_last')
    tmp136 = tl.load(in_ptr55 + (x0), xmask, eviction_policy='evict_last')
    tmp147 = tl.load(in_ptr59 + (x0), xmask, eviction_policy='evict_last')
    tmp149 = tl.load(in_ptr60 + (x0), xmask, eviction_policy='evict_last')
    tmp154 = tl.load(in_ptr62 + (x0), xmask, eviction_policy='evict_last')
    tmp156 = tl.load(in_ptr63 + (x0), xmask, eviction_policy='evict_last')
    tmp167 = tl.load(in_ptr67 + (x0), xmask, eviction_policy='evict_last')
    tmp169 = tl.load(in_ptr68 + (x0), xmask, eviction_policy='evict_last')
    tmp174 = tl.load(in_ptr70 + (x0), xmask, eviction_policy='evict_last')
    tmp176 = tl.load(in_ptr71 + (x0), xmask, eviction_policy='evict_last')
    tmp187 = tl.load(in_ptr75 + (x0), xmask, eviction_policy='evict_last')
    tmp189 = tl.load(in_ptr76 + (x0), xmask, eviction_policy='evict_last')
    tmp194 = tl.load(in_ptr78 + (x0), xmask, eviction_policy='evict_last')
    tmp196 = tl.load(in_ptr79 + (x0), xmask, eviction_policy='evict_last')
    tmp207 = tl.load(in_ptr83 + (x0), xmask, eviction_policy='evict_last')
    tmp209 = tl.load(in_ptr84 + (x0), xmask, eviction_policy='evict_last')
    tmp214 = tl.load(in_ptr86 + (x0), xmask, eviction_policy='evict_last')
    tmp216 = tl.load(in_ptr87 + (x0), xmask, eviction_policy='evict_last')
    tmp227 = tl.load(in_ptr91 + (x0), xmask, eviction_policy='evict_last')
    tmp229 = tl.load(in_ptr92 + (x0), xmask, eviction_policy='evict_last')
    tmp234 = tl.load(in_ptr94 + (x0), xmask, eviction_policy='evict_last')
    tmp236 = tl.load(in_ptr95 + (x0), xmask, eviction_policy='evict_last')
    tmp241 = tl.load(in_ptr97 + (x0), xmask, eviction_policy='evict_last')
    tmp243 = tl.load(in_ptr98 + (x0), xmask, eviction_policy='evict_last')
    # Pass 1 accumulator: per-lane partial sums of in_ptr99 * in_ptr100, reduced
    # across the RBLOCK axis after the loop.
    _tmp250 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # Full-size [row, column] operands; fp16 inputs are upcast to fp32 on load.
        tmp0 = tl.load(in_ptr0 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp1 = tl.load(in_ptr1 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp11 = tl.load(in_ptr5 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp18 = tl.load(in_ptr8 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp21 = tl.load(in_ptr9 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp24 = tl.load(in_ptr10 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp31 = tl.load(in_ptr13 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp38 = tl.load(in_ptr16 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp41 = tl.load(in_ptr17 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp44 = tl.load(in_ptr18 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp51 = tl.load(in_ptr21 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp58 = tl.load(in_ptr24 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp61 = tl.load(in_ptr25 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp64 = tl.load(in_ptr26 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp71 = tl.load(in_ptr29 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp78 = tl.load(in_ptr32 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp81 = tl.load(in_ptr33 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp84 = tl.load(in_ptr34 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp91 = tl.load(in_ptr37 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp98 = tl.load(in_ptr40 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp101 = tl.load(in_ptr41 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp104 = tl.load(in_ptr42 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp111 = tl.load(in_ptr45 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp118 = tl.load(in_ptr48 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp121 = tl.load(in_ptr49 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp124 = tl.load(in_ptr50 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp131 = tl.load(in_ptr53 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp138 = tl.load(in_ptr56 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp141 = tl.load(in_ptr57 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp144 = tl.load(in_ptr58 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp151 = tl.load(in_ptr61 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp158 = tl.load(in_ptr64 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp161 = tl.load(in_ptr65 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp164 = tl.load(in_ptr66 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp171 = tl.load(in_ptr69 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp178 = tl.load(in_ptr72 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp181 = tl.load(in_ptr73 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp184 = tl.load(in_ptr74 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp191 = tl.load(in_ptr77 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp198 = tl.load(in_ptr80 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp201 = tl.load(in_ptr81 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp204 = tl.load(in_ptr82 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp211 = tl.load(in_ptr85 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp218 = tl.load(in_ptr88 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp221 = tl.load(in_ptr89 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp224 = tl.load(in_ptr90 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp231 = tl.load(in_ptr93 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp238 = tl.load(in_ptr96 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp245 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        # Per-column vector (indexed by r1 only), so it is masked by rmask alone.
        tmp247 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        # Running sum t of the full-size operands. At several points the current
        # t is also turned into (t - tmp_a) * tmp_b using the per-row scalars above.
        tmp2 = tmp1.to(tl.float32)
        tmp3 = tmp0 + tmp2
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp8 = tmp6 - tmp7
        tmp10 = tmp8 * tmp9
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp6 + tmp12
        tmp15 = tmp13 - tmp14
        tmp17 = tmp15 * tmp16
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp13 + tmp19
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp20 + tmp22
        tmp25 = tmp24.to(tl.float32)
        tmp26 = tmp23 + tmp25
        tmp28 = tmp26 - tmp27
        tmp30 = tmp28 * tmp29
        tmp32 = tmp31.to(tl.float32)
        tmp33 = tmp26 + tmp32
        tmp35 = tmp33 - tmp34
        tmp37 = tmp35 * tmp36
        tmp39 = tmp38.to(tl.float32)
        tmp40 = tmp33 + tmp39
        tmp42 = tmp41.to(tl.float32)
        tmp43 = tmp40 + tmp42
        tmp45 = tmp44.to(tl.float32)
        tmp46 = tmp43 + tmp45
        tmp48 = tmp46 - tmp47
        tmp50 = tmp48 * tmp49
        tmp52 = tmp51.to(tl.float32)
        tmp53 = tmp46 + tmp52
        tmp55 = tmp53 - tmp54
        tmp57 = tmp55 * tmp56
        tmp59 = tmp58.to(tl.float32)
        tmp60 = tmp53 + tmp59
        tmp62 = tmp61.to(tl.float32)
        tmp63 = tmp60 + tmp62
        tmp65 = tmp64.to(tl.float32)
        tmp66 = tmp63 + tmp65
        tmp68 = tmp66 - tmp67
        tmp70 = tmp68 * tmp69
        tmp72 = tmp71.to(tl.float32)
        tmp73 = tmp66 + tmp72
        tmp75 = tmp73 - tmp74
        tmp77 = tmp75 * tmp76
        tmp79 = tmp78.to(tl.float32)
        tmp80 = tmp73 + tmp79
        tmp82 = tmp81.to(tl.float32)
        tmp83 = tmp80 + tmp82
        tmp85 = tmp84.to(tl.float32)
        tmp86 = tmp83 + tmp85
        tmp88 = tmp86 - tmp87
        tmp90 = tmp88 * tmp89
        tmp92 = tmp91.to(tl.float32)
        tmp93 = tmp86 + tmp92
        tmp95 = tmp93 - tmp94
        tmp97 = tmp95 * tmp96
        tmp99 = tmp98.to(tl.float32)
        tmp100 = tmp93 + tmp99
        tmp102 = tmp101.to(tl.float32)
        tmp103 = tmp100 + tmp102
        tmp105 = tmp104.to(tl.float32)
        tmp106 = tmp103 + tmp105
        tmp108 = tmp106 - tmp107
        tmp110 = tmp108 * tmp109
        tmp112 = tmp111.to(tl.float32)
        tmp113 = tmp106 + tmp112
        tmp115 = tmp113 - tmp114
        tmp117 = tmp115 * tmp116
        tmp119 = tmp118.to(tl.float32)
        tmp120 = tmp113 + tmp119
        tmp122 = tmp121.to(tl.float32)
        tmp123 = tmp120 + tmp122
        tmp125 = tmp124.to(tl.float32)
        tmp126 = tmp123 + tmp125
        tmp128 = tmp126 - tmp127
        tmp130 = tmp128 * tmp129
        tmp132 = tmp131.to(tl.float32)
        tmp133 = tmp126 + tmp132
        tmp135 = tmp133 - tmp134
        tmp137 = tmp135 * tmp136
        tmp139 = tmp138.to(tl.float32)
        tmp140 = tmp133 + tmp139
        tmp142 = tmp141.to(tl.float32)
        tmp143 = tmp140 + tmp142
        tmp145 = tmp144.to(tl.float32)
        tmp146 = tmp143 + tmp145
        tmp148 = tmp146 - tmp147
        tmp150 = tmp148 * tmp149
        tmp152 = tmp151.to(tl.float32)
        tmp153 = tmp146 + tmp152
        tmp155 = tmp153 - tmp154
        tmp157 = tmp155 * tmp156
        tmp159 = tmp158.to(tl.float32)
        tmp160 = tmp153 + tmp159
        tmp162 = tmp161.to(tl.float32)
        tmp163 = tmp160 + tmp162
        tmp165 = tmp164.to(tl.float32)
        tmp166 = tmp163 + tmp165
        tmp168 = tmp166 - tmp167
        tmp170 = tmp168 * tmp169
        tmp172 = tmp171.to(tl.float32)
        tmp173 = tmp166 + tmp172
        tmp175 = tmp173 - tmp174
        tmp177 = tmp175 * tmp176
        tmp179 = tmp178.to(tl.float32)
        tmp180 = tmp173 + tmp179
        tmp182 = tmp181.to(tl.float32)
        tmp183 = tmp180 + tmp182
        tmp185 = tmp184.to(tl.float32)
        tmp186 = tmp183 + tmp185
        tmp188 = tmp186 - tmp187
        tmp190 = tmp188 * tmp189
        tmp192 = tmp191.to(tl.float32)
        tmp193 = tmp186 + tmp192
        tmp195 = tmp193 - tmp194
        tmp197 = tmp195 * tmp196
        tmp199 = tmp198.to(tl.float32)
        tmp200 = tmp193 + tmp199
        tmp202 = tmp201.to(tl.float32)
        tmp203 = tmp200 + tmp202
        tmp205 = tmp204.to(tl.float32)
        tmp206 = tmp203 + tmp205
        tmp208 = tmp206 - tmp207
        tmp210 = tmp208 * tmp209
        tmp212 = tmp211.to(tl.float32)
        tmp213 = tmp206 + tmp212
        tmp215 = tmp213 - tmp214
        tmp217 = tmp215 * tmp216
        tmp219 = tmp218.to(tl.float32)
        tmp220 = tmp213 + tmp219
        tmp222 = tmp221.to(tl.float32)
        tmp223 = tmp220 + tmp222
        tmp225 = tmp224.to(tl.float32)
        tmp226 = tmp223 + tmp225
        tmp228 = tmp226 - tmp227
        tmp230 = tmp228 * tmp229
        tmp232 = tmp231.to(tl.float32)
        tmp233 = tmp226 + tmp232
        tmp235 = tmp233 - tmp234
        tmp237 = tmp235 * tmp236
        tmp239 = tmp238.to(tl.float32)
        tmp240 = tmp233 + tmp239
        # Final running sum normalized with in_ptr97/in_ptr98; stored to
        # out_ptr35 and re-read by passes 2 and 3.
        tmp242 = tmp240 - tmp241
        tmp244 = tmp242 * tmp243
        # Reduction term for pass 1: in_ptr99 * in_ptr100 (column vector broadcast).
        tmp246 = tmp245.to(tl.float32)
        tmp248 = tmp246 * tmp247
        tmp249 = tl.broadcast_to(tmp248, [XBLOCK, RBLOCK])
        tmp251 = _tmp250 + tmp249
        # Masked-out lanes keep their previous partial sum.
        _tmp250 = tl.where(rmask & xmask, tmp251, _tmp250)
        # Write the intermediate partial sums and (t - a) * b values.
        tl.store(out_ptr0 + (r1 + (2560*x0)), tmp10, rmask & xmask)
        tl.store(out_ptr1 + (r1 + (2560*x0)), tmp17, rmask & xmask)
        tl.store(out_ptr2 + (r1 + (2560*x0)), tmp20, rmask & xmask)
        tl.store(out_ptr3 + (r1 + (2560*x0)), tmp30, rmask & xmask)
        tl.store(out_ptr4 + (r1 + (2560*x0)), tmp37, rmask & xmask)
        tl.store(out_ptr5 + (r1 + (2560*x0)), tmp40, rmask & xmask)
        tl.store(out_ptr6 + (r1 + (2560*x0)), tmp50, rmask & xmask)
        tl.store(out_ptr7 + (r1 + (2560*x0)), tmp57, rmask & xmask)
        tl.store(out_ptr8 + (r1 + (2560*x0)), tmp60, rmask & xmask)
        tl.store(out_ptr9 + (r1 + (2560*x0)), tmp70, rmask & xmask)
        tl.store(out_ptr10 + (r1 + (2560*x0)), tmp77, rmask & xmask)
        tl.store(out_ptr11 + (r1 + (2560*x0)), tmp80, rmask & xmask)
        tl.store(out_ptr12 + (r1 + (2560*x0)), tmp90, rmask & xmask)
        tl.store(out_ptr13 + (r1 + (2560*x0)), tmp97, rmask & xmask)
        tl.store(out_ptr14 + (r1 + (2560*x0)), tmp100, rmask & xmask)
        tl.store(out_ptr15 + (r1 + (2560*x0)), tmp110, rmask & xmask)
        tl.store(out_ptr16 + (r1 + (2560*x0)), tmp117, rmask & xmask)
        tl.store(out_ptr17 + (r1 + (2560*x0)), tmp120, rmask & xmask)
        tl.store(out_ptr18 + (r1 + (2560*x0)), tmp130, rmask & xmask)
        tl.store(out_ptr19 + (r1 + (2560*x0)), tmp137, rmask & xmask)
        tl.store(out_ptr20 + (r1 + (2560*x0)), tmp140, rmask & xmask)
        tl.store(out_ptr21 + (r1 + (2560*x0)), tmp150, rmask & xmask)
        tl.store(out_ptr22 + (r1 + (2560*x0)), tmp157, rmask & xmask)
        tl.store(out_ptr23 + (r1 + (2560*x0)), tmp160, rmask & xmask)
        tl.store(out_ptr24 + (r1 + (2560*x0)), tmp170, rmask & xmask)
        tl.store(out_ptr25 + (r1 + (2560*x0)), tmp177, rmask & xmask)
        tl.store(out_ptr26 + (r1 + (2560*x0)), tmp180, rmask & xmask)
        tl.store(out_ptr27 + (r1 + (2560*x0)), tmp190, rmask & xmask)
        tl.store(out_ptr28 + (r1 + (2560*x0)), tmp197, rmask & xmask)
        tl.store(out_ptr29 + (r1 + (2560*x0)), tmp200, rmask & xmask)
        tl.store(out_ptr30 + (r1 + (2560*x0)), tmp210, rmask & xmask)
        tl.store(out_ptr31 + (r1 + (2560*x0)), tmp217, rmask & xmask)
        tl.store(out_ptr32 + (r1 + (2560*x0)), tmp220, rmask & xmask)
        tl.store(out_ptr33 + (r1 + (2560*x0)), tmp230, rmask & xmask)
        tl.store(out_ptr34 + (r1 + (2560*x0)), tmp237, rmask & xmask)
        tl.store(out_ptr35 + (r1 + (2560*x0)), tmp244, rmask & xmask)
    # s1[x0] = sum over columns of in_ptr99 * in_ptr100, kept as shape [XBLOCK, 1].
    tmp250 = tl.sum(_tmp250, 1)[:, None]
    # Pass 2 accumulator: per-lane partial sums of in_ptr99 * in_ptr100 * out_ptr35.
    _tmp259 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp252 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp254 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        # Reads back the values pass 1 stored to out_ptr35.
        tmp256 = tl.load(out_ptr35 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp253 = tmp252.to(tl.float32)
        tmp255 = tmp253 * tmp254
        tmp257 = tmp255 * tmp256
        tmp258 = tl.broadcast_to(tmp257, [XBLOCK, RBLOCK])
        tmp260 = _tmp259 + tmp258
        # Masked-out lanes keep their previous partial sum.
        _tmp259 = tl.where(rmask & xmask, tmp260, _tmp259)
    # s2[x0] = sum over columns of in_ptr99 * in_ptr100 * out_ptr35, shape [XBLOCK, 1].
    tmp259 = tl.sum(_tmp259, 1)[:, None]
    # Pass 3: element-wise combination of s1 (tmp250) and s2 (tmp259); writes
    # out_ptr38 and out_ptr39.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp263 = tl.load(in_ptr99 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp265 = tl.load(in_ptr100 + (r1), rmask, eviction_policy='evict_last', other=0)
        tmp269 = tl.load(out_ptr35 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        # in_ptr101 is '*i1' in the signature. There is no `other=`, so masked-out
        # lanes are unspecified, but they only reach the masked stores below.
        tmp274 = tl.load(in_ptr101 + (r1 + (2560*x0)), rmask & xmask, eviction_policy='evict_last')
        # rnumel as a float; tmp262 = in_ptr98[x0] / 2560.
        tmp261 = 2560.0
        tmp262 = tmp243 / tmp261
        tmp264 = tmp263.to(tl.float32)
        tmp266 = tmp264 * tmp265
        # tmp271 = 2560 * in_ptr99 * in_ptr100 - s1 - out_ptr35 * s2
        tmp267 = tmp266 * tmp261
        tmp268 = tmp267 - tmp250
        tmp270 = tmp269 * tmp259
        tmp271 = tmp268 - tmp270
        tmp272 = tmp262 * tmp271
        tmp273 = tmp272.to(tl.float32)
        tmp275 = tmp274.to(tl.float32)
        # NOTE(review): 1.1111111111111112 == 1 / 0.9, presumably the dropout
        # rescale for p = 0.1 applied through the boolean mask; confirm.
        tmp276 = 1.1111111111111112
        tmp277 = tmp275 * tmp276
        tmp278 = tmp273 * tmp277
        tl.store(out_ptr38 + (r1 + (2560*x0)), tmp272, rmask & xmask)
        # out_ptr39 is declared '*fp16' in meta['signature'], so the fp32 tmp278
        # is narrowed when stored.
        tl.store(out_ptr39 + (r1 + (2560*x0)), tmp278, rmask & xmask)
def get_args():
arg_0 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_1 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_2 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_3 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_4 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_5 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_6 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_7 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_8 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_9 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_10 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_11 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_12 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_13 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_14 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_15 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_16 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_17 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_18 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_19 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_20 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_21 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_22 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_23 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_24 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_25 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_26 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_27 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_28 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_29 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_30 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_31 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_32 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_33 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_34 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_35 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_36 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_37 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_38 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_39 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_40 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_41 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_42 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_43 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_44 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_45 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_46 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_47 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_48 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_49 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_50 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_51 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_52 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_53 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_54 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_55 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_56 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_57 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_58 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_59 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_60 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_61 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_62 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_63 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_64 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_65 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_66 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_67 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_68 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_69 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_70 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_71 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_72 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_73 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_74 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_75 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_76 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_77 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_78 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_79 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_80 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_81 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_82 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_83 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_84 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_85 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_86 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_87 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_88 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_89 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_90 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_91 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_92 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_93 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_94 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_95 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_96 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
arg_97 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_98 = rand_strided((4, 128, 1), (128, 1, 1), device='cuda:0', dtype=torch.float32)
arg_99 = rand_strided((512, 2560), (2560, 1), device='cuda:0', dtype=torch.float16)
arg_100 = rand_strided((2560,), (1,), device='cuda:0', dtype=torch.float32)
arg_101 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.bool)
arg_102 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_103 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_104 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_105 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_106 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_107 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_108 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_109 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_110 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_111 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_112 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_113 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_114 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_115 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_116 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_117 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_118 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_119 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_120 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_121 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_122 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_123 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_124 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_125 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_126 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_127 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_128 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_129 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_130 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_131 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_132 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_133 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_134 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_135 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_136 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_137 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_138 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float32)
arg_139 = rand_strided((4, 128, 2560), (327680, 2560, 1), device='cuda:0', dtype=torch.float16)
return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12, arg_13, arg_14, arg_15, arg_16, arg_17, arg_18, arg_19, arg_20, arg_21, arg_22, arg_23, arg_24, arg_25, arg_26, arg_27, arg_28, arg_29, arg_30, arg_31, arg_32, arg_33, arg_34, arg_35, arg_36, arg_37, arg_38, arg_39, arg_40, arg_41, arg_42, arg_43, arg_44, arg_45, arg_46, arg_47, arg_48, arg_49, arg_50, arg_51, arg_52, arg_53, arg_54, arg_55, arg_56, arg_57, arg_58, arg_59, arg_60, arg_61, arg_62, arg_63, arg_64, arg_65, arg_66, arg_67, arg_68, arg_69, arg_70, arg_71, arg_72, arg_73, arg_74, arg_75, arg_76, arg_77, arg_78, arg_79, arg_80, arg_81, arg_82, arg_83, arg_84, arg_85, arg_86, arg_87, arg_88, arg_89, arg_90, arg_91, arg_92, arg_93, arg_94, arg_95, arg_96, arg_97, arg_98, arg_99, arg_100, arg_101, arg_102, arg_103, arg_104, arg_105, arg_106, arg_107, arg_108, arg_109, arg_110, arg_111, arg_112, arg_113, arg_114, arg_115, arg_116, arg_117, arg_118, arg_119, arg_120, arg_121, arg_122, arg_123, arg_124, arg_125, arg_126, arg_127, arg_128, arg_129, arg_130, arg_131, arg_132, arg_133, arg_134, arg_135, arg_136, arg_137, arg_138, arg_139,
def call(args):
    """Launch the generated Triton kernel ``triton_`` once on CUDA device 0.

    Args:
        args: the sequence of tensors returned by ``get_args()``. They are
            forwarded positionally, followed by the two integer sizes 512
            and 2560 (presumably xnumel / rnumel for this reduction --
            confirm against the kernel signature).

    The launch uses ``grid(512)`` and the current raw CUDA stream of
    device 0. Nothing is returned.
    """
    device_index = 0
    with torch.cuda._DeviceGuard(device_index):
        torch.cuda.set_device(device_index)
        # Raw handle of the device's current stream, passed through to the launcher.
        launch_stream = get_cuda_stream(device_index)
        triton_.run(
            *args,
            512,
            2560,
            grid=grid(512),
            stream=launch_stream,
        )
def benchmark_all_configs(args):
    """Run ``triton_.benchmark_all_configs`` for this kernel on CUDA device 0.

    Args:
        args: the sequence of tensors returned by ``get_args()``. They are
            forwarded positionally, followed by the sizes 512 and 2560, with
            ``grid(512)`` as the launch grid.

    Returns:
        Whatever ``triton_.benchmark_all_configs`` returns, unchanged.
    """
    device_index = 0
    with torch.cuda._DeviceGuard(device_index):
        torch.cuda.set_device(device_index)
        launch_args = (*args, 512, 2560)
        config_timings = triton_.benchmark_all_configs(*launch_args, grid=grid(512))
    return config_timings
if __name__ == '__main__':
    # Standalone entry point: time the kernel launch with Triton's
    # benchmarking helper and report the effective memory bandwidth.
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench

    args = get_args()

    # The result is used below as milliseconds (divided by 1e3 to get seconds).
    elapsed_ms = do_bench(lambda: call(args), rep=40, fast_flush=True)

    # Size of all arguments in GB. num_in_out_args=0 is consistent with the
    # kernel metadata, whose 'mutated_arg_names' list is empty.
    total_gb = get_num_bytes(*args, num_in_out_args=0) / 1e9
    elapsed_s = elapsed_ms / 1e3
    bandwidth_gb_per_s = total_gb / elapsed_s

    print(f"{elapsed_ms:.3f}ms {total_gb:.3f}GB {bandwidth_gb_per_s:.2f}GB/s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment