import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import reduction
from torch._inductor.utils import instance_descriptor
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
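
# Inductor-generated split reduction for the variance pass of a fused
# _native_batch_norm_legit_functional. Each of the 2496 = 192 * 13 programs
# (192 channels x 13 reduction chunks) accumulates sum((x - mean)^2) in fp32
# over its ~123511-element slice of the 128*112*112 values per channel; the
# 13 partial sums per channel are presumably combined by a follow-up kernel
# that is not part of this gist.
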
@reduction(
    size_hints=[4096, 131072],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={
        'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'},
        'device': 0,
        'constants': {},
        'mutated_arg_names': [],
        'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())],
    },
)
@triton.jit
def triton_red_fused__native_batch_norm_legit_functional_16(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 2496    # 192 channels * 13 reduction chunks
    rnumel = 123511  # ceil(128 * 112 * 112 / 13) elements per chunk
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 192)  # reduction-chunk index (0..12)
    x0 = xindex % 192     # channel index
    _tmp9 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0  # fp32 accumulator
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = r2 + (123511*x1)  # global position within this channel's reduction
        tmp1 = 1605632           # 128 * 112 * 112 elements per channel
        tmp2 = tmp0 < tmp1       # mask the over-allocated tail of the last chunk
        tmp3 = tl.load(in_ptr0 + ((12544*x0) + (2408448*(((r2 + (123511*x1)) // 12544) % 128)) + ((r2 + (123511*x1)) % 12544) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), rmask & tmp2 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tl.load(in_ptr1 + (x0 + tl.zeros([XBLOCK, RBLOCK], tl.int32)), rmask & tmp2 & xmask, eviction_policy='evict_last', other=0)  # per-channel mean
        tmp6 = tmp4 - tmp5
        tmp7 = tmp6 * tmp6  # squared deviation from the mean
        tmp8 = tl.where(tmp2, tmp7, 0)
        _tmp9 = tl.where(rmask & xmask, _tmp9 + tmp8, _tmp9)
    tmp9 = tl.sum(_tmp9, 1)[:, None]  # reduce across the RBLOCK lanes
    tl.store(out_ptr0 + x3, tmp9, xmask)
def get_args():
    # input activations: (N=128, C=192, H=112, W=112), contiguous fp16
    arg_0 = rand_strided((128, 192, 112, 112), (2408448, 12544, 112, 1), device='cuda:0', dtype=torch.float16)
    # per-channel mean, fp32
    arg_1 = rand_strided((1, 192, 1, 1), (192, 1, 192, 192), device='cuda:0', dtype=torch.float32)
    # output buffer: 13 partial sums per channel
    arg_2 = rand_strided((1, 192, 1, 1, 13), (2496, 1, 2496, 2496, 192), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_cuda_stream(0)
        triton_red_fused__native_batch_norm_legit_functional_16.run(*args, 2496, 123511, grid=grid(2496), stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__native_batch_norm_legit_functional_16.benchmark_all_configs(*args, 2496, 123511, grid=grid(2496))
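
# Optional sanity check (not in the original gist): fold the 13 partial sums
# per channel back together and compare against an eager-mode reference.
# `check` is a hypothetical helper name; the tolerances are a guess for fp16
# inputs accumulated in fp32.
def check(args):
    arg_0, arg_1, arg_2 = args
    call(args)
    torch.cuda.synchronize()
    # eager reference: per-channel sum of squared deviations over N, H, W
    ref = ((arg_0.float() - arg_1) ** 2).sum(dim=(0, 2, 3))
    # (1, 192, 1, 1, 13) -> (192,) by summing the 13 chunk partials
    got = arg_2.sum(dim=-1).reshape(-1)
    print(torch.allclose(ref, got, rtol=1e-3, atol=1e-2))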
if __name__ == '__main__':
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]
    # effective bandwidth: bytes touched by all args divided by elapsed time
    num_gb = get_num_bytes(*args, num_in_out_args=0) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")