import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import reduction
from torch._inductor.utils import instance_descriptor
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch
from torch._inductor.triton_heuristics import grid
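
# The kernel below appears to compute, per channel, partial sums of squared
# deviations from a (1, 192, 1, 1) fp32 tensor (presumably the batch mean),
# i.e. the variance term of a fused _native_batch_norm_legit_functional.
# The N*H*W = 128*112*112 = 1,605,632 elements of each of the 192 channels are
# split into 13 chunks of 123,511 elements, so grid(2496) launches
# 192 * 13 = 2496 programs, each writing one partial sum into the
# (1, 192, 1, 1, 13) output; the `tmp0 < 1605632` check masks the tail of the
# last chunk, since 13 * 123,511 is slightly larger than 1,605,632.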
@reduction(
    size_hints=[4096, 131072],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_red_fused__native_batch_norm_legit_functional_16(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 2496
    rnumel = 123511
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 192)  # reduction chunk index (0..12)
    x0 = xindex % 192     # channel index
    _tmp9 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0  # running accumulator
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = r2 + (123511*x1)
        tmp1 = 1605632  # 128*112*112, valid elements per channel
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + ((12544*x0) + (2408448*(((r2 + (123511*x1)) // 12544) % 128)) + ((r2 + (123511*x1)) % 12544) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), rmask & tmp2 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tl.load(in_ptr1 + (x0 + tl.zeros([XBLOCK, RBLOCK], tl.int32)), rmask & tmp2 & xmask, eviction_policy='evict_last', other=0)
        tmp6 = tmp4 - tmp5   # x - mean
        tmp7 = tmp6 * tmp6   # (x - mean)^2
        tmp8 = tl.where(tmp2, tmp7, 0)
        _tmp9 = tl.where(rmask & xmask, _tmp9 + tmp8, _tmp9)
    tmp9 = tl.sum(_tmp9, 1)[:, None]
    tl.store(out_ptr0 + x3, tmp9, xmask)


def get_args():
    arg_0 = rand_strided((128, 192, 112, 112), (2408448, 12544, 112, 1), device='cuda:0', dtype=torch.float16)
    arg_1 = rand_strided((1, 192, 1, 1), (192, 1, 192, 192), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((1, 192, 1, 1, 13), (2496, 1, 2496, 2496, 192), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1, arg_2,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_cuda_stream(0)
        triton_red_fused__native_batch_norm_legit_functional_16.run(*args, 2496, 123511, grid=grid(2496), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__native_batch_norm_legit_functional_16.benchmark_all_configs(*args, 2496, 123511, grid=grid(2496))
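

# A hedged, eager-mode sketch of the computation the kernel above appears to
# implement; `reference_var_sum` is a hypothetical helper added for
# illustration, not something Inductor emits. Given the same arguments, its
# per-channel result should roughly match arg_2 summed over its last
# (13-partial-sums) dimension after call(args), up to reduction-order effects.
def reference_var_sum():
    x, mean, _ = get_args()
    # Per-channel sum of (x - mean)^2 over N, H, W.
    return ((x.to(torch.float32) - mean) ** 2).sum(dim=(0, 2, 3))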


if __name__ == '__main__':
    from torch._inductor.utils import get_num_bytes
    from triton.testing import do_bench
    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]
    num_gb = get_num_bytes(*args, num_in_out_args=0) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")