# Gist by @shunting314, created August 23, 2023 05:45
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor import config
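# compile kernels on a single thread (serial AsyncCompile) rather than a worker pool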
config.compile_threads = 1
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
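# This module is TorchInductor output code for a single mixed-dtype matmul
# (fp32 A @ bf16 B -> fp32 out, lowered as "mixed_mm"): one Triton template
# kernel, a call() wrapper that allocates the output and launches it, and a
# small benchmark harness at the bottom of the file.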
# kernel path: /tmp/torchinductor_shunting/wy/cwyjslsgxrszsby4hxharbx4bkta6nvcti3qrytf3mi6bxh3e3jp.py
# Source Nodes: [mm], Original ATen: [aten.mm]
# mm => mixed_mm
triton_unk_fused_mm_0 = async_compile.triton('triton_', '''
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
# from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from collections import namedtuple
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()])
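# NOTE: the namedtuple above stands in for Inductor's instance_descriptor (its
# import is commented out), presumably to keep this dumped kernel source
# self-contained; it carries the divisibility hints Triton uses for
# specialization (divisible_by_16=(0, 1, 2) below marks all three pointer args).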
@template(num_stages=2, num_warps=1, meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 16
    BLOCK_K : tl.constexpr = 16
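    # With BLOCK_M = BLOCK_N = BLOCK_K = 16 and the 8x8x8 problem below, one 16x16
    # tile covers the whole output; EVEN_K is False because K (8) is not a multiple
    # of BLOCK_K, so the K-loop loads are masked.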
    A = arg_A
    B = arg_B
    M = 8
    N = 8
    K = 8
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 8
    stride_ak = 1
    stride_bk = 8
    stride_bn = 1
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
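    # For this size grid_m == grid_n == 1, so the swizzle above reduces to
    # pid_m == pid_n == 0 and a single program instance runs; the re-ordering
    # only pays off on larger launch grids.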
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        # if B_PROLOGUE_CAST_TYPE is not None:
        #     b = b.to(B_PROLOGUE_CAST_TYPE)
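        # The hardcoded cast below is the "mixed" part of mixed_mm: B is loaded as
        # bf16 and upcast to fp32 (B_PROLOGUE_CAST_TYPE) in the prologue, so tl.dot
        # and the accumulation both run in fp32.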
        b = b.to(tl.float32)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (8*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
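    # The store above is the Inductor-generated suffix: xindex linearizes (row, col)
    # for the row-major, (8, 1)-strided output, and mask clips the 16x16 tile to the
    # 8x8 output bounds.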
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch._inductor.kernel.mm_common
meta0 = {'GROUP_M': 8, 'EVEN_K': False, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float32', 'B_PROLOGUE_CAST_TYPE': 'tl.float32', 'BLOCK_M': 16, 'BLOCK_N': 16, 'BLOCK_K': 16}
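# meta0 mirrors the constexpr values baked into the kernel source above; mm_grid(8, 8, meta0)
# in call() uses its BLOCK_M/BLOCK_N to size the launch grid (one program for this 8x8 output).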
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (8, 8), (8, 1))
    assert_size_stride(arg1_1, (8, 8), (8, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [mm], Original ATen: [aten.mm]
        stream0 = get_cuda_stream(0)
        triton_unk_fused_mm_0.run(arg0_1, arg1_1, buf0, grid=torch._inductor.kernel.mm_common.mm_grid(8, 8, meta0), stream=stream0)
        del arg0_1
        del arg1_1
        return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((8, 8), (8, 1), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((8, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
    return print_performance(lambda: call([arg0_1, arg1_1]), times=times, repeat=repeat)
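# A hypothetical reproducer, not part of the generated output: my assumption is that
# compiling an mm whose second operand is upcast from bf16 inside the compiled region,
# with the mixed-mm lowering forced on, yields a kernel like the one above. The
# force_mixed_mm flag name is an assumption about the PyTorch nightly used here.
def _mixed_mm_repro_sketch():
    config.force_mixed_mm = True  # assumption: force the mixed_mm template choice
    a = torch.randn(8, 8, device="cuda", dtype=torch.float32)
    b = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)

    @torch.compile
    def f(x, y):
        # mm with an in-graph cast of the bf16 operand is the pattern mixed_mm matches
        return torch.mm(x, y.to(x.dtype))

    return f(a, b)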
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)