Skip to content

Instantly share code, notes, and snippets.

@bertmaher
Created March 8, 2023 05:00
Show Gist options
  • Save bertmaher/93302c4f40728d8481873850e84cf47a to your computer and use it in GitHub Desktop.
Save bertmaher/93302c4f40728d8481873850e84cf47a to your computer and use it in GitHub Desktop.
Triton matmul crashes
import torch
import triton.language as tl
import triton
torch.set_float32_matmul_precision("high")
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = True
ACC_TYPE : tl.constexpr = tl.float32
BLOCK_M : tl.constexpr = 64
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
M = 1024
N = 1953
K = 2048
stride_am = 1
stride_ak = 1024
stride_bk = 7920
stride_bn = 1
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (1953*idx_m)
tl.store(out_ptr0 + (xindex + tl.zeros(mask.shape, tl.int32)), acc, mask)
BLOCK_M = 64
BLOCK_N = 64
a = torch.empty_strided((1024, 2048), (1, 1024), device="cuda")
b = torch.empty_strided((2048, 1953), (7920, 1), device="cuda")
out = torch.empty(1024, 1953, device="cuda")
triton_mm[triton.cdiv(1024, BLOCK_M) * triton.cdiv(1953, BLOCK_N), 1, 1](
a, b, out,
num_stages=2,
num_warps=4,
)
torch.cuda.synchronize()
import torch
import triton.language as tl
import triton
torch.set_float32_matmul_precision("high")
@triton.jit
def triton_mm_plus_mm(arg_A, arg_B, arg_C, arg_D, out_ptr0):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = True
ALLOW_TF32 : tl.constexpr = True
ACC_TYPE : tl.constexpr = tl.float32
BLOCK_M : tl.constexpr = 64
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 32
A = arg_A
B = arg_B
C = arg_C
D = arg_D
M = 2048
N = 1536
K1 = 64
# K2 = 64
stride_am = 64
stride_ak = 1
stride_bk = 1536
stride_bn = 1
stride_cm = 64
stride_ck = 1
stride_dk = 1536
stride_dn = 1
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k1 in range(K1, 0, -BLOCK_K):
# First matmul with A @ B
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k1, other=0.)
b = tl.load(B, mask=rk[:, None] < k1, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# Splitting this into two loops causes an internal triton LLVM error
# https://github.com/openai/triton/issues/967
# for k2 in range(K2, 0, -BLOCK_K):
k2 = k1
# Second matmul with C @ D
if EVEN_K:
c = tl.load(C)
d = tl.load(D)
else:
c = tl.load(C, mask=rk[None, :] < k2, other=0.)
d = tl.load(D, mask=rk[:, None] < k2, other=0.)
acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
C += BLOCK_K * stride_ck
D += BLOCK_K * stride_dk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + (1536*idx_m)
tl.store(out_ptr0 + (xindex + tl.zeros(mask.shape, tl.int32)), acc, mask)
m, n, k = 2048, 1536, 256
mm_grid = (triton.cdiv(m, k) * triton.cdiv(n, k), 1, 1)
triton_mm_plus_mm[mm_grid](
torch.randn(m, k).cuda(),
torch.randn(k, n).cuda(),
torch.randn(m, k).cuda(),
torch.randn(k, n).cuda(),
torch.empty(m, n).cuda(),
num_stages=2,
num_warps=4,
)
torch.cuda.synchronize()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment