Created
June 13, 2023 18:11
-
-
Save shunting314/a78997f54b5751f2887f4576956036ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
from torch._inductor.ir import ReductionHint | |
from torch._inductor.ir import TileHint | |
from torch._inductor.triton_heuristics import persistent_reduction | |
from torch._inductor.utils import instance_descriptor | |
from torch._inductor import triton_helpers | |
from torch._dynamo.testing import rand_strided | |
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream | |
import torch | |
from torch._inductor.triton_heuristics import grid | |
# Inductor-generated persistent-reduction Triton kernel. The fused-op name lists
# _to_copy, add, embedding, mul, native_dropout, native_layer_norm,
# native_layer_norm_backward, ne and view.
# Each program instance handles a single row index `xindex` (a scalar, not a
# tl.arange) and reduces across all rnumel=768 columns in one RBLOCK=1024 tile;
# the padding lanes 768..1023 are masked off via `rmask`.
# NOTE(review): since xindex is a scalar, every row is only visited when the
# launcher uses XBLOCK == 1 -- presumably guaranteed by the persistent_reduction
# heuristic; confirm.
@persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    # 'signature' maps each positional argument index to its pointer/scalar
    # dtype; 'mutated_arg_names' records that in_out_ptr0 is written in place.
    meta={'signature': {0: '*i64', 1: '*i64', 2: '*fp32', 3: '*i64', 4: '*fp32', 5: '*fp32', 6: '*i64', 7: '*fp32', 8: '*fp32', 9: '*i1', 10: '*fp32', 11: '*fp16', 12: '*fp32', 13: 'i32', 14: 'i32', 15: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15), equal_to_1=())]}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, out_ptr5, out_ptr6, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Problem sizes are baked in as constants, overriding the runtime
    # xnumel / rnumel arguments.
    xnumel = 8192
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset
    # NOTE(review): xmask is computed but never used -- the x-indexed loads and
    # stores below pass mask None, relying on the grid covering exactly
    # xnumel rows.
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    rmask = rindex < rnumel
    x0 = xindex
    r3 = rindex
    # xindex mod 512, used to index in_ptr2 (presumably the sequence position,
    # given the (16, 512) id tensors built by get_args -- confirm).
    x1 = xindex % 512
    # Per-row scalar loads (in_out_ptr0 and in_ptr0 at x0, in_ptr2 at x1) and
    # per-column vector loads (in_ptr3, in_ptr6, in_ptr7 at r3).
    tmp0 = tl.load(in_out_ptr0 + (x0), None)
    tmp4 = tl.load(in_ptr0 + (x0), None)
    tmp11 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp14 = tl.load(in_ptr2 + (x1), None, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr3 + (r3), rmask, eviction_policy='evict_last', other=0)
    tmp42 = tl.load(in_ptr6 + (r3), rmask, eviction_policy='evict_last', other=0)
    tmp44 = tl.load(in_ptr7 + (r3), rmask, eviction_policy='evict_last', other=0)
    # tmp10 = int64((int32(tmp0) + 0) * int32(tmp4 != 1)) + 1: the value read
    # from in_out_ptr0 is zeroed where in_ptr0 == 1, then offset by 1.
    # Presumably RoBERTa-style position ids with padding_idx == 1 -- confirm.
    tmp1 = tmp0.to(tl.int32)
    tmp2 = 0
    tmp3 = tmp1 + tmp2
    tmp5 = 1
    tmp6 = tmp4 != tmp5
    tmp7 = tmp6.to(tl.int32)
    tmp8 = tmp3 * tmp7
    tmp9 = tmp8.to(tl.int64)
    tmp10 = tmp9 + tmp5
    # Gather row tmp11 of in_ptr1 (a (32005, 768) table in get_args), after a
    # device-side bounds check on the index.
    tmp12 = triton_helpers.promote_to_tensor(tmp11)
    tl.device_assert((0 <= tmp12) & (tmp12 < 32005), "index out of bounds: 0 <= tmp12 < 32005")
    tmp13 = tl.load(in_ptr1 + (r3 + (768*tmp11)), rmask, other=0)
    # The index read from in_ptr2 must be 0 (bound < 1), so the matching row of
    # in_ptr3 is simply the tmp16 vector loaded above using r3 alone.
    tmp15 = triton_helpers.promote_to_tensor(tmp14)
    tl.device_assert((0 <= tmp15) & (tmp15 < 1), "index out of bounds: 0 <= tmp15 < 1")
    tmp17 = tmp13 + tmp16
    # Gather row tmp10 of in_ptr4 (a (514, 768) table in get_args), with a
    # bounds check on the computed index.
    tmp18 = triton_helpers.promote_to_tensor(tmp10)
    tl.device_assert((0 <= tmp18) & (tmp18 < 514), "index out of bounds: 0 <= tmp18 < 514")
    tmp19 = tl.load(in_ptr4 + (r3 + (768*tmp10)), rmask, other=0)
    tmp20 = tmp17 + tmp19
    # LayerNorm statistics over the 768 valid columns; tl.where zeroes the
    # masked padding lanes so they do not contribute to the sums.
    tmp22 = tl.where(rmask, tmp20, 0)
    tmp23 = tl.sum(tmp22, 0)
    tmp24 = 768.0
    tmp25 = tmp23 / tmp24  # mean
    tmp26 = tmp20 - tmp25  # centered values
    tmp27 = tmp26 * tmp26
    tmp29 = tl.where(rmask, tmp27, 0)
    tmp30 = tl.sum(tmp29, 0)  # sum of squared deviations
    # Dropout keep-mask: a tl.rand draw seeded by in_ptr5[load_seed_offset]
    # and offset by the flat element index r3 + 768*x0; kept where > 0.1.
    tmp31 = tl.load(in_ptr5 + load_seed_offset)
    tmp32 = r3 + (768*x0)
    tmp33 = tl.rand(tmp31, (tmp32).to(tl.uint32))
    tmp34 = 0.1
    tmp35 = tmp33 > tmp34
    # rstd = rsqrt(sum_sq_dev / 768 + 1e-5).
    tmp36 = tmp30 / tmp24
    tmp37 = 1e-05
    tmp38 = tmp36 + tmp37
    tmp39 = tl.math.rsqrt(tmp38)
    tmp40 = tmp26 * tmp39  # normalized values, before the affine transform
    tmp41 = tmp35.to(tl.float32)
    # Affine transform with per-column in_ptr6 (multiplier) and in_ptr7
    # (addend), then apply the keep-mask and rescale by 1.111... = 1 / (1 - 0.1).
    tmp43 = tmp40 * tmp42
    tmp45 = tmp43 + tmp44
    tmp46 = tmp41 * tmp45
    tmp47 = 1.1111111111111112
    tmp48 = tmp46 * tmp47
    tmp49 = tmp48.to(tl.float32)
    # rstd / 768 per row (presumably consumed by layer_norm backward, per the
    # fused-op name -- confirm).
    tmp50 = tmp39 / tmp24
    # Outputs: in_out_ptr0 is overwritten in place with tmp10; out_ptr3 gets
    # the keep-mask, out_ptr4 the normalized values, out_ptr5 the dropout
    # output (its pointer is declared *fp16 in the signature), and out_ptr6
    # the per-row rstd / 768.
    tl.store(in_out_ptr0 + (x0), tmp10, None)
    tl.store(out_ptr3 + (r3 + (768*x0)), tmp35, rmask)
    tl.store(out_ptr4 + (r3 + (768*x0)), tmp40, rmask)
    tl.store(out_ptr5 + (r3 + (768*x0)), tmp49, rmask)
    tl.store(out_ptr6 + (x0), tmp50, None)
def get_args():
    """Build the example arguments for the kernel launch.

    Returns a 14-tuple: thirteen tensors on ``cuda:0``, created with
    ``rand_strided`` in the kernel's positional-argument order, followed by
    the integer ``0``, which is passed as ``load_seed_offset``.
    """
    # (size, stride, dtype) per tensor argument, in kernel argument order.
    # Order matters: rand_strided is called sequentially, as in the original.
    tensor_specs = (
        ((16, 512), (512, 1), torch.int64),
        ((16, 512), (512, 1), torch.int64),
        ((32005, 768), (768, 1), torch.float32),
        ((1, 514), (514, 1), torch.int64),
        ((1, 768), (768, 1), torch.float32),
        ((514, 768), (768, 1), torch.float32),
        ((37,), (1,), torch.int64),
        ((768,), (1,), torch.float32),
        ((768,), (1,), torch.float32),
        ((16, 512, 768), (393216, 768, 1), torch.bool),
        ((16, 512, 768), (393216, 768, 1), torch.float32),
        ((8192, 768), (768, 1), torch.float16),
        ((16, 512, 1), (512, 1, 1), torch.float32),
    )
    tensors = [
        rand_strided(size, stride, device='cuda:0', dtype=dtype)
        for size, stride, dtype in tensor_specs
    ]
    load_seed_offset = 0
    return (*tensors, load_seed_offset)
def call(args):
    """Launch the fused kernel once on CUDA device 0.

    ``args`` is expected to be the tuple produced by ``get_args()``. The
    problem sizes xnumel=8192 and rnumel=768 are appended, and the launch
    goes to the stream returned by ``get_cuda_stream(0)``.
    """
    xnumel, rnumel = 8192, 768
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        launch_stream = get_cuda_stream(0)
        kernel = triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1
        kernel.run(*args, xnumel, rnumel, grid=grid(xnumel), stream=launch_stream)
def benchmark_all_configs(args):
    """Delegate to the kernel's ``benchmark_all_configs`` on CUDA device 0.

    ``args`` is expected to be the tuple produced by ``get_args()``. The
    problem sizes xnumel=8192 and rnumel=768 are appended, and the delegate's
    result is returned unchanged.
    """
    xnumel, rnumel = 8192, 768
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        kernel = triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1
        config_results = kernel.benchmark_all_configs(*args, xnumel, rnumel, grid=grid(xnumel))
    return config_results
if __name__ == '__main__':
    # Stand-alone benchmark: time one kernel launch and report an
    # effective-bandwidth estimate.
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench

    args = get_args()
    # Runtime in milliseconds, as reported by triton.testing.do_bench.
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    # get_num_bytes treats the first `num_in_out_args` tensor arguments as both
    # read and written (counted twice). Only in_out_ptr0 (args[0]) is mutated
    # -- it is the sole entry of the kernel's 'mutated_arg_names' -- so the
    # count is 1. The previous value of 2 also double-counted the read-only
    # in_ptr0 (args[1]), inflating the reported GB and GB/s.
    num_gb = get_num_bytes(*args, num_in_out_args=1) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment