triton graph for safe_int_mm
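The listing below is the output code that torch.compile (Inductor) generated for an int8 matmul lowered to aten._int_mm. A rough sketch of how a dump like this can be produced follows; the wrapper name safe_int_mm, the compile mode, and the debug flags are assumptions here, and the cached kernel path and hashes will differ per machine.

import torch

torch._inductor.config.debug = True  # or run with TORCH_LOGS="output_code"

@torch.compile(mode="max-autotune")
def safe_int_mm(a, b):
    # int8 x int8 -> int32 matmul; inductor may lower this to a Triton mm template
    # like the one in the generated module below
    return torch._int_mm(a, b)

a = torch.randint(-128, 128, (8, 17), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 128, (17, 8), dtype=torch.int8, device="cuda")
safe_int_mm(a, b)

The generated module follows.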
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_cdhernandez/va/cvakdsbvsebtiaosp3kolpke7cfvzv5o6jcxjkfli4t2yuz2vum2.py
# Original ATen: aten._int_mm
# aten._int_mm => _int_mm
triton_unk_fused__int_mm_0 = async_compile.triton('triton_', '''
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(num_stages=2, num_warps=1, meta={'signature': {0: '*i8', 1: '*i8', 2: '*i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.int32
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 16
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 8
    N = 8
    K = 17
    stride_am = 17
    stride_ak = 1
    stride_bk = 8
    stride_bn = 1
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
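    # Worked example for this kernel's fixed sizes: M = N = 8 with 16x16 output blocks
    # gives grid_m = grid_n = 1, so the launch grid is a single program and pid = 0 maps
    # to pid_m = 0, pid_n = 0; the L2 swizzle above only matters for larger grids.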
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
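    # Here K = 17 and BLOCK_K = 32, so the loop below runs exactly once; since EVEN_K is
    # False, the masked loads zero-pad the tail of the K dimension before the int8 tiles
    # are multiplied and accumulated into the int32 accumulator.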
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (8*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch._inductor.kernel.mm_common
meta0 = {'GROUP_M': 8, 'EVEN_K': False, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.int32', 'BLOCK_M': 16, 'BLOCK_N': 16, 'BLOCK_K': 32}
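# meta0 mirrors the constexpr block sizes baked into the Triton template above; call()
# below passes it to mm_common.mm_grid, which sizes the launch grid from the
# BLOCK_M/BLOCK_N entries for the 8x8 output.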
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.int32)
        stream0 = get_cuda_stream(0)
        triton_unk_fused__int_mm_0.run(arg0_1, arg1_1, buf0, grid=torch._inductor.kernel.mm_common.mm_grid(8, 8, meta0), stream=stream0)
        del arg0_1
        del arg1_1
        return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((8, 17), (17, 1), device='cuda:0', dtype=torch.int8)
    arg1_1 = rand_strided((17, 8), (8, 1), device='cuda:0', dtype=torch.int8)
    return print_performance(lambda: call([arg0_1, arg1_1]), times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
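A minimal sanity check, not part of the inductor output: assuming a CUDA device is available, call() can be exercised directly with int8 inputs of the shapes used above and its int32 result compared against an integer matmul computed on CPU (where plain integer mm is supported). The helper name below is hypothetical.

def sanity_check_int_mm():
    # Hypothetical helper: drive the compiled kernel directly and check it against a
    # plain int32 matmul reference computed on CPU.
    a = torch.randint(-128, 128, (8, 17), dtype=torch.int8, device='cuda')
    b = torch.randint(-128, 128, (17, 8), dtype=torch.int8, device='cuda')
    out, = call([a, b])  # call() clears the argument list; a and b stay alive here
    ref = a.cpu().int() @ b.cpu().int()
    assert torch.equal(out.cpu(), ref), "compiled _int_mm result does not match reference"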