Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created August 8, 2023 22:02
Show Gist options
  • Save shunting314/441e8839d24e1878c313e539b1ebd551 to your computer and use it in GitHub Desktop.
Save shunting314/441e8839d24e1878c313e539b1ebd551 to your computer and use it in GitHub Desktop.
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
@pointwise(
    size_hints=[16384, 16384],
    tile_hint=TileHint.DEFAULT,
    filename=__file__,
    meta={
        'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'},
        'device': 0,
        'device_type': 'cuda',
        'constants': {},
        'mutated_arg_names': [],
        'autotune_hints': set(),
        'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())],
    },
)
@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, ynumel, XBLOCK : tl.constexpr, YBLOCK : tl.constexpr):
    """Tiled element-wise add of two 16384x16384 fp32 buffers.

    Each program handles an (XBLOCK, YBLOCK) tile of (x, y) index pairs.
    For every pair it computes

        out_ptr0[x + 16384*y] = in_ptr1[y + 16384*x] + in_ptr0[x + 16384*y]

    so ``in_ptr1`` is read at the transposed offset relative to ``in_ptr0``
    and ``out_ptr0``.

    An alternative addressing of the same sum reads ``in_ptr0`` and writes
    ``out_ptr0`` at ``y + 16384*x`` and reads ``in_ptr1`` at ``x + 16384*y``.
    NOTE(review): presumably the two variants exist to compare memory access
    patterns; confirm with the author.

    The ``xnumel`` / ``ynumel`` arguments are overwritten with the constant
    16384 below, so the values passed by the caller are not used.
    """
    xnumel = 16384
    ynumel = 16384

    # Column vector of x indices and row vector of y indices; combining
    # them broadcasts to an (XBLOCK, YBLOCK) tile.
    x_idx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    y_idx = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :]
    in_bounds = (x_idx < xnumel) & (y_idx < ynumel)

    # The two linearisations of the (x, y) pair used for addressing.
    y_plus_stride_x = y_idx + (16384 * x_idx)
    x_plus_stride_y = x_idx + (16384 * y_idx)

    lhs = tl.load(in_ptr1 + y_plus_stride_x, in_bounds)
    rhs = tl.load(in_ptr0 + x_plus_stride_y, in_bounds)
    tl.store(out_ptr0 + x_plus_stride_y, lhs + rhs, in_bounds)
def get_args():
    """Build the three kernel buffers as a tuple.

    Every buffer is created by ``rand_strided`` with size (16384, 16384),
    strides (16384, 1), dtype float32, on device 'cuda:0'. The tuple is
    later unpacked positionally into the kernel launch.
    """
    size = (16384, 16384)
    stride = (16384, 1)
    return tuple(
        rand_strided(size, stride, device='cuda:0', dtype=torch.float32)
        for _ in range(3)
    )
def call(args):
    """Launch ``triton_poi_fused_add_0`` once on CUDA device 0.

    ``args`` is unpacked as the leading kernel arguments, followed by the
    two element counts (16384, 16384). Nothing is returned.
    """
    numel = 16384
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        raw_stream = get_cuda_stream(0)
        launch_grid = grid(numel, numel)
        triton_poi_fused_add_0.run(*args, numel, numel, grid=launch_grid, stream=raw_stream)
def benchmark_all_configs(args):
    """Run the kernel's ``benchmark_all_configs`` on CUDA device 0.

    ``args`` is unpacked as the leading kernel arguments, followed by the
    two element counts (16384, 16384). Returns whatever
    ``triton_poi_fused_add_0.benchmark_all_configs`` returns.
    """
    numel = 16384
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        results = triton_poi_fused_add_0.benchmark_all_configs(
            *args, numel, numel, grid=grid(numel, numel)
        )
    return results
if __name__ == '__main__':
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench

    kernel_args = get_args()

    def _launch_once():
        # Callable handed to do_bench; each invocation is one kernel launch.
        return call(kernel_args)

    # The do_bench result is treated as milliseconds: it is divided by 1e3
    # below to obtain seconds for the bandwidth figure.
    ms = do_bench(_launch_once, rep=40, fast_flush=True)

    # Byte count reported by get_num_bytes for the buffers, scaled to GB.
    num_gb = get_num_bytes(*kernel_args, num_in_out_args=0) / 1e9
    seconds = ms / 1e3
    gb_per_s = num_gb / seconds
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment