# Gist "codegen" by @HDCharles, created July 25, 2023
# (gist id d8a1ff7d52fcafcb7a0d880596b2c0c1).
# GitHub page chrome from the scraped page ("Skip to content", "Show Gist
# options", "Save ... to your computer and use it in GitHub Desktop") has been
# reduced to this comment so the file parses as Python.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
# Shorthand for the ATen operator namespace; not referenced elsewhere in the
# visible code (codegen emits it unconditionally).
aten = torch.ops.aten
# Guard used by call() to check each input's size and stride tuples.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Compiler handle: the kernel below is submitted via async_compile.triton(...),
# and the module later blocks on async_compile.wait(globals()) before deleting it.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_cdhernandez/me/cme6svhcsihecrrpkvmkdyfs65x3wwf3kkkf372kqjwnoi3ltmyn.py
# Original ATen: aten.add, aten.mm, aten.mul
# aten.add => add
# aten.mm => tuned_mixed_dtype_mm
# aten.mul => mul
triton_unk_fused_add_mm_mul_0 = async_compile.triton('triton_unk_fused_add_mm_mul_0', '''
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(num_stages=4, num_warps=8, meta={'signature': {0: '*bf16', 1: '*i8', 2: '*bf16', 3: '*bf16', 4: '*bf16'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]})
@triton.jit
def triton_unk_fused_add_mm_mul_0(arg_A, arg_B, in_ptr2, in_ptr3, out_ptr1):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
B_PROLOGUE_CAST_TYPE : tl.constexpr = tl.bfloat16
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
M = 4096
N = 1280
K = 1280
stride_am = 1280
stride_ak = 1
stride_bk = 1280
stride_bn = 1
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (1280*idx_m)
tmp0 = tl.load(in_ptr2 + (tl.broadcast_to(idx_n, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
tmp2 = tl.load(in_ptr3 + (tl.broadcast_to(idx_n, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
tmp1 = acc * tmp0
tmp3 = tmp1 + tmp2
tl.store(out_ptr1 + (tl.broadcast_to(idx_n + (1280*idx_m), mask.shape)), tmp3, mask)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch._inductor.kernel.mm_common
# Launch-time meta-parameters for the mm template. The values mirror the
# tl.constexpr constants baked into triton_unk_fused_add_mm_mul_0 and are
# handed to torch._inductor.kernel.mm_common.mm_grid when call() launches it.
meta0 = dict(
    GROUP_M=8,
    EVEN_K=True,
    ALLOW_TF32=False,
    ACC_TYPE='tl.float32',
    B_PROLOGUE_CAST_TYPE='tl.bfloat16',
    BLOCK_M=128,
    BLOCK_N=64,
    BLOCK_K=32,
)
# Block on the kernels submitted to async_compile, passing this module's
# globals() (presumably so compiled kernels are bound there; confirm against
# AsyncCompile.wait). The handle is dropped because nothing is compiled after
# this point.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the compiled graph on ``args`` and return a 1-tuple with the result.

    ``args`` is a list of four CUDA tensors, unpacked in order as:

    * ``arg0_1``: 1280 x 1280, launched as the kernel's ``arg_B`` (matmul RHS).
    * ``arg1_1``: length 1280, launched as ``in_ptr2`` (per-column multiplier).
    * ``arg2_1``: length 1280, launched as ``in_ptr3`` (per-column addend).
    * ``arg3_1``: 4096 x 1280, launched as ``arg_A`` (matmul LHS).

    The list is cleared in place, so the caller must not rely on its contents
    afterwards. Returns a tuple holding one freshly allocated, contiguous
    4096 x 1280 bfloat16 tensor.

    NOTE(review): only size and stride are asserted. The kernel signature
    declares ``arg_B`` as '*i8' and the other pointers as '*bf16', so dtypes
    are presumably guaranteed by upstream guards; confirm.
    """
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    # Empty the caller's list so it no longer references the inputs; the local
    # names are deleted right after the kernel launch below.
    args.clear()
    assert_size_stride(arg0_1, (1280, 1280), (1280, 1))
    assert_size_stride(arg1_1, (1280, ), (1, ))
    assert_size_stride(arg2_1, (1280, ), (1, ))
    assert_size_stride(arg3_1, (4096, 1280), (1280, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Output buffer: row-major 4096 x 1280 bf16, written by out_ptr1.
        buf1 = empty_strided((4096, 1280), (1280, 1), device='cuda', dtype=torch.bfloat16)
        stream0 = get_cuda_stream(0)
        # Argument order follows the kernel signature:
        # (arg_A, arg_B, in_ptr2, in_ptr3, out_ptr1).
        triton_unk_fused_add_mm_mul_0.run(arg3_1, arg0_1, arg1_1, arg2_1, buf1, grid=torch._inductor.kernel.mm_common.mm_grid(4096, 1280, meta0), stream=stream0)
        # Drop the last local references to the inputs once the launch is issued.
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        return (buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on randomly initialised inputs with its expected layout.

    Inputs are created on ``cuda:0`` with the sizes, strides and dtypes that
    ``call`` checks and launches with. ``times`` and ``repeat`` are forwarded
    to ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # (size, stride, dtype) for each positional input of ``call``, in order.
    input_specs = (
        ((1280, 1280), (1280, 1), torch.int8),
        ((1280, ), (1, ), torch.bfloat16),
        ((1280, ), (1, ), torch.bfloat16),
        ((4096, 1280), (1280, 1), torch.bfloat16),
    )
    inputs = [
        rand_strided(size, stride, device='cuda:0', dtype=dtype)
        for size, stride, dtype in input_specs
    ]

    def run_once():
        # ``call`` clears the list it receives, so hand it a fresh copy each time.
        return call(list(inputs))

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main

    # When run as a script, hand the benchmark function to compiled_module_main.
    # The first argument is the literal string 'None' as emitted by codegen
    # (presumably a placeholder benchmark name); keep it verbatim.
    compiled_module_main(
        'None',
        benchmark_compiled_module,
    )
# [GitHub page footer] Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.