Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shunting314/a788db61dfe92c0e41c0b1761f39b7d2 to your computer and use it in GitHub Desktop.
Save shunting314/a788db61dfe92c0e41c0b1761f39b7d2 to your computer and use it in GitHub Desktop.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
# Shorthand for the torch.ops.aten operator namespace (not referenced elsewhere
# in this module as shown).
aten = torch.ops.aten
# Size/stride checker from the dynamo guards extension; call() applies it to
# each input tensor before launching the kernel.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Compiler that the kernel definition below is submitted to via
# async_compile.triton(...); it is drained by async_compile.wait(globals())
# and then deleted further down the module.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/r6/cr6y7ss5jrskx4cgogjsoqzsx6y7qmngyqgxz3r4zzc4g57wwrk4.py
# Source Nodes: [add, add_1, add_2, sum_1], Original ATen: [aten.add, aten.sum]
# add => add
# add_1 => add_1
# add_2 => add_2
# sum_1 => sum_1
#
# Fused add + sum reduction. For each output index x in [0, 262144), with
# x0 = x % 512 and x1 = x // 512, the kernel sums over r in [0, 512):
#     in_ptr0[x + 262144*r]
#   + in_ptr1[x1 + 512*x0 + 262144*r]
#   + in_ptr2[x1 + 512*x0 + 262144*r]
#   + in_ptr3[x1 + 512*x0 + 262144*r]
# and stores the total to out_ptr0[x]. call() feeds it four (512, 512, 512)
# inputs with stride (262144, 512, 1) and a (512, 512) output buffer, so that
# out[i, j] = sum_r a0[r, i, j] + a1[r, j, i] + a2[r, j, i] + a3[r, j, i].
# NOTE(review): stepping x (i.e. x0) moves the in_ptr1..3 addresses by 512
# elements but the in_ptr0 address by 1, so three of the four loads are
# strided across the XBLOCK dimension. Profile to check whether this dominates.
#
# The string below is the source of a separate Python/Triton module that
# async_compile compiles. Its indentation is syntactically significant (the
# decorator arguments, the jit function body, the reduction loop, and the
# helper/benchmark functions), so it must be preserved exactly.
triton_red_fused_add_sum_0 = async_compile.triton('triton_red_fused_add_sum_0', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
@reduction(
    size_hints=[262144, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}
)
@triton.jit
def triton_red_fused_add_sum_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 262144
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x3 = xindex
    x0 = xindex % 512
    x1 = (xindex // 512)
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x3 + (262144*r2)), rmask, other=0)
        tmp1 = tl.load(in_ptr1 + (x1 + (512*x0) + (262144*r2)), rmask, other=0)
        tmp3 = tl.load(in_ptr2 + (x1 + (512*x0) + (262144*r2)), rmask, other=0)
        tmp5 = tl.load(in_ptr3 + (x1 + (512*x0) + (262144*r2)), rmask, other=0)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(rmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp7, None)
def get_args():
    arg_0 = rand_strided((512, 512, 512), (262144, 512, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((512, 512, 512), (262144, 512, 1), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((512, 512, 512), (262144, 512, 1), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((512, 512, 512), (262144, 512, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2, arg_3, arg_4,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_cuda_stream(0)
        triton_red_fused_add_sum_0.run(*args, 262144, 512, grid=grid(262144), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_add_sum_0.benchmark_all_configs(*args, 262144, 512, grid=grid(262144))
if __name__ == '__main__':
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = get_num_bytes(*args, num_in_out_args=0) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# Block until the kernel submitted through async_compile above is ready.
# wait() receives this module's globals(), presumably so it can rebind the
# compiled kernel objects in place (TODO confirm against AsyncCompile.wait).
async_compile.wait(globals())
# The compiler handle is not used after this point, so drop it.
del async_compile
def call(args):
    """Run the compiled graph on four CUDA input tensors.

    ``args`` must be a list of exactly four tensors. Each one is passed to
    ``assert_size_stride`` with size (512, 512, 512) and stride
    (262144, 512, 1). The list is emptied right after unpacking, and the local
    names are deleted once the kernel is launched, so after the launch this
    function holds no references to the inputs.

    Returns a 1-tuple containing a newly allocated (512, 512) float32 CUDA
    buffer that ``triton_red_fused_add_sum_0`` writes into.
    """
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    # Empty the caller's list so it no longer references the inputs.
    args.clear()

    in_size, in_stride = (512, 512, 512), (262144, 512, 1)
    assert_size_stride(arg0_1, in_size, in_stride)
    assert_size_stride(arg1_1, in_size, in_stride)
    assert_size_stride(arg2_1, in_size, in_stride)
    assert_size_stride(arg3_1, in_size, in_stride)

    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        result = empty_strided(
            (512, 512), (512, 1), device='cuda', dtype=torch.float32
        )
        # Source Nodes: [add, add_1, add_2, sum_1], Original ATen: [aten.add, aten.sum]
        launch_stream = get_cuda_stream(0)
        triton_red_fused_add_sum_0.run(
            arg0_1, arg1_1, arg2_1, arg3_1, result,
            262144, 512,
            grid=grid(262144),
            stream=launch_stream,
        )
        # Release the local input references immediately after the launch.
        del arg0_1, arg1_1, arg2_1, arg3_1
        return (result,)
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on random inputs and return the timing result.

    Creates four (512, 512, 512) float32 tensors with stride
    (262144, 512, 1) on ``cuda:0`` using ``rand_strided``. It then hands a
    runner, along with ``times`` and ``repeat``, to ``print_performance`` and
    returns whatever that returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    shape, stride = (512, 512, 512), (262144, 512, 1)
    arg0_1, arg1_1, arg2_1, arg3_1 = (
        rand_strided(shape, stride, device='cuda:0', dtype=torch.float32)
        for _ in range(4)
    )

    def run_once():
        # call() empties the list it receives, so build a new one per run.
        return call([arg0_1, arg1_1, arg2_1, arg3_1])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Direct execution benchmarks this module through inductor's shared
    # wrapper-benchmark driver. The first argument is the literal string
    # 'None', not the None object.
    from torch._inductor import wrapper_benchmark

    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment