import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import persistent_reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
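
# TorchInductor-generated persistent-reduction Triton kernel. Each program
# handles a single one of the 8192 (batch 16 x seq 512) rows and reduces over
# the 768 hidden features. Judging by the fused-op name and the argument
# shapes (vocab 32005, 514 positions, padding index 1), this appears to be
# the embedding layer of a RoBERTa-style model: word + token-type + position
# embedding sum, LayerNorm, and dropout, fused together with the position-id
# computation and the statistics saved for native_layer_norm_backward.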
@persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*i64', 1: '*i64', 2: '*fp32', 3: '*i64', 4: '*fp32', 5: '*fp32', 6: '*i64', 7: '*fp32', 8: '*fp32', 9: '*i1', 10: '*fp32', 11: '*fp16', 12: '*fp32', 13: 'i32', 14: 'i32', 15: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15), equal_to_1=())]}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr3, out_ptr4, out_ptr5, out_ptr6, load_seed_offset, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    rmask = rindex < rnumel
    x0 = xindex
    r3 = rindex
    x1 = xindex % 512
    tmp0 = tl.load(in_out_ptr0 + (x0), None)
    tmp4 = tl.load(in_ptr0 + (x0), None)
    tmp11 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp14 = tl.load(in_ptr2 + (x1), None, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr3 + (r3), rmask, eviction_policy='evict_last', other=0)
    tmp42 = tl.load(in_ptr6 + (r3), rmask, eviction_policy='evict_last', other=0)
    tmp44 = tl.load(in_ptr7 + (r3), rmask, eviction_policy='evict_last', other=0)
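    # Position ids: (running index) * (token != padding_idx) + padding_idx,
    # with padding_idx = 1, so padding tokens all map to position 1.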
    tmp1 = tmp0.to(tl.int32)
    tmp2 = 0
    tmp3 = tmp1 + tmp2
    tmp5 = 1
    tmp6 = tmp4 != tmp5
    tmp7 = tmp6.to(tl.int32)
    tmp8 = tmp3 * tmp7
    tmp9 = tmp8.to(tl.int64)
    tmp10 = tmp9 + tmp5
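    # Embedding lookups, each guarded by a bounds assert:
    # word (32005 rows) + token-type (1 row) + position (514 rows).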
    tmp12 = triton_helpers.promote_to_tensor(tmp11)
    tl.device_assert((0 <= tmp12) & (tmp12 < 32005), "index out of bounds: 0 <= tmp12 < 32005")
    tmp13 = tl.load(in_ptr1 + (r3 + (768*tmp11)), rmask, other=0)
    tmp15 = triton_helpers.promote_to_tensor(tmp14)
    tl.device_assert((0 <= tmp15) & (tmp15 < 1), "index out of bounds: 0 <= tmp15 < 1")
    tmp17 = tmp13 + tmp16
    tmp18 = triton_helpers.promote_to_tensor(tmp10)
    tl.device_assert((0 <= tmp18) & (tmp18 < 514), "index out of bounds: 0 <= tmp18 < 514")
    tmp19 = tl.load(in_ptr4 + (r3 + (768*tmp10)), rmask, other=0)
    tmp20 = tmp17 + tmp19
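    # LayerNorm statistics: mean and (biased) variance over the 768 features.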
    tmp22 = tl.where(rmask, tmp20, 0)
    tmp23 = tl.sum(tmp22, 0)
    tmp24 = 768.0
    tmp25 = tmp23 / tmp24
    tmp26 = tmp20 - tmp25
    tmp27 = tmp26 * tmp26
    tmp29 = tl.where(rmask, tmp27, 0)
    tmp30 = tl.sum(tmp29, 0)
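    # Dropout keep-mask with p = 0.1, generated from the Philox seed stored
    # at in_ptr5[load_seed_offset] and the element's flat offset.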
    tmp31 = tl.load(in_ptr5 + load_seed_offset)
    tmp32 = r3 + (768*x0)
    tmp33 = tl.rand(tmp31, (tmp32).to(tl.uint32))
    tmp34 = 0.1
    tmp35 = tmp33 > tmp34
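    # Normalize, apply the LayerNorm weight/bias, then dropout with the
    # 1/(1-p) = 1.111... rescaling of kept elements.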
    tmp36 = tmp30 / tmp24
    tmp37 = 1e-05
    tmp38 = tmp36 + tmp37
    tmp39 = tl.math.rsqrt(tmp38)
    tmp40 = tmp26 * tmp39
    tmp41 = tmp35.to(tl.float32)
    tmp43 = tmp40 * tmp42
    tmp45 = tmp43 + tmp44
    tmp46 = tmp41 * tmp45
    tmp47 = 1.1111111111111112
    tmp48 = tmp46 * tmp47
    tmp49 = tmp48.to(tl.float32)
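    # rstd / hidden_size, saved for native_layer_norm_backward, then write
    # the position ids, dropout mask, normalized activations, and output.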
    tmp50 = tmp39 / tmp24
    tl.store(in_out_ptr0 + (x0), tmp10, None)
    tl.store(out_ptr3 + (r3 + (768*x0)), tmp35, rmask)
    tl.store(out_ptr4 + (r3 + (768*x0)), tmp40, rmask)
    tl.store(out_ptr5 + (r3 + (768*x0)), tmp49, rmask)
    tl.store(out_ptr6 + (x0), tmp50, None)
def get_args():
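    # Dummy inputs matching the shapes, strides, and dtypes the kernel expects.
    # Roles (inferred from how the kernel indexes each pointer) are noted inline;
    # rand_strided zero-fills integer tensors, so the device-side asserts pass.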
    arg_0 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)  # in_out_ptr0: running indices, overwritten with position ids
    arg_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64)  # in_ptr0: input token ids
    arg_2 = rand_strided((32005, 768), (768, 1), device='cuda:0', dtype=torch.float32)  # in_ptr1: word-embedding table
    arg_3 = rand_strided((1, 514), (514, 1), device='cuda:0', dtype=torch.int64)  # in_ptr2: token-type ids
    arg_4 = rand_strided((1, 768), (768, 1), device='cuda:0', dtype=torch.float32)  # in_ptr3: token-type embedding (single type)
    arg_5 = rand_strided((514, 768), (768, 1), device='cuda:0', dtype=torch.float32)  # in_ptr4: position-embedding table
    arg_6 = rand_strided((37,), (1,), device='cuda:0', dtype=torch.int64)  # in_ptr5: dropout seed buffer
    arg_7 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)  # in_ptr6: LayerNorm weight
    arg_8 = rand_strided((768,), (1,), device='cuda:0', dtype=torch.float32)  # in_ptr7: LayerNorm bias
    arg_9 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.bool)  # out_ptr3: dropout keep-mask
    arg_10 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda:0', dtype=torch.float32)  # out_ptr4: normalized activations (pre weight/bias)
    arg_11 = rand_strided((8192, 768), (768, 1), device='cuda:0', dtype=torch.float16)  # out_ptr5: final fp16 output
    arg_12 = rand_strided((16, 512, 1), (512, 1, 1), device='cuda:0', dtype=torch.float32)  # out_ptr6: rstd / 768 for the backward pass
    arg_13 = 0  # load_seed_offset
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, arg_11, arg_12, arg_13,
def call(args):
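    # Launch on the current device/stream: xnumel=8192 rows, rnumel=768 features.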
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_cuda_stream(0)
        triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1.run(*args, 8192, 768, grid=grid(8192), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_per_fused__to_copy_add_embedding_mul_native_dropout_native_layer_norm_native_layer_norm_backward_ne_view_1.benchmark_all_configs(*args, 8192, 768, grid=grid(8192))
if __name__ == '__main__':
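    # Standalone benchmark: time the kernel with do_bench and report the
    # achieved memory bandwidth from the total bytes read and written.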
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = get_num_bytes(*args, num_in_out_args=2) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")