Triton matmul crashes

Two standalone repro scripts (based on triton.ops.matmul, in the style TorchInductor generates): a plain matmul kernel and a fused mm + mm kernel, both of which crash when launched.
import torch
import triton.language as tl
import triton

torch.set_float32_matmul_precision("high")

@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = True
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B

    M = 1024
    N = 1953
    K = 2048
    stride_am = 1
    stride_ak = 1024
    stride_bk = 7920
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (1953*idx_m)
    tl.store(out_ptr0 + (xindex + tl.zeros(mask.shape, tl.int32)), acc, mask)

BLOCK_M = 64
BLOCK_N = 64
a = torch.empty_strided((1024, 2048), (1, 1024), device="cuda")
b = torch.empty_strided((2048, 1953), (7920, 1), device="cuda")
out = torch.empty(1024, 1953, device="cuda")
triton_mm[triton.cdiv(1024, BLOCK_M) * triton.cdiv(1953, BLOCK_N), 1, 1](
    a, b, out,
    num_stages=2,
    num_warps=4,
)
torch.cuda.synchronize()
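For reference, the kernel above is expected to compute a plain A @ B matmul over the same shapes. A minimal eager-mode sketch of that reference computation (the name ref and the use of torch.mm are illustrative, not part of the original repro):

import torch

# Same shapes and strides as the repro above; the tensors are uninitialized,
# so this only illustrates the intended computation, not a numeric check.
a = torch.empty_strided((1024, 2048), (1, 1024), device="cuda")
b = torch.empty_strided((2048, 1953), (7920, 1), device="cuda")
ref = torch.mm(a, b)  # what `out` should hold if triton_mm ran correctly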
import torch
import triton.language as tl
import triton

torch.set_float32_matmul_precision("high")

@triton.jit
def triton_mm_plus_mm(arg_A, arg_B, arg_C, arg_D, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = True
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B
    C = arg_C
    D = arg_D

    M = 2048
    N = 1536
    K1 = 64
    # K2 = 64
    stride_am = 64
    stride_ak = 1
    stride_bk = 1536
    stride_bn = 1
    stride_cm = 64
    stride_ck = 1
    stride_dk = 1536
    stride_dn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k1 in range(K1, 0, -BLOCK_K):
        # First matmul with A @ B
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

        # Splitting this into two loops causes an internal triton LLVM error
        # https://github.com/openai/triton/issues/967
        # for k2 in range(K2, 0, -BLOCK_K):
        k2 = k1

        # Second matmul with C @ D
        if EVEN_K:
            c = tl.load(C)
            d = tl.load(D)
        else:
            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
        C += BLOCK_K * stride_ck
        D += BLOCK_K * stride_dk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (1536*idx_m)
    tl.store(out_ptr0 + (xindex + tl.zeros(mask.shape, tl.int32)), acc, mask)

m, n, k = 2048, 1536, 256
mm_grid = (triton.cdiv(m, k) * triton.cdiv(n, k), 1, 1)
triton_mm_plus_mm[mm_grid](
    torch.randn(m, k).cuda(),
    torch.randn(k, n).cuda(),
    torch.randn(m, k).cuda(),
    torch.randn(k, n).cuda(),
    torch.empty(m, n).cuda(),
    num_stages=2,
    num_warps=4,
)
torch.cuda.synchronize()
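For reference, triton_mm_plus_mm fuses two matmuls into one accumulator, i.e. it is meant to produce A @ B + C @ D. A minimal eager-mode sketch of that reference computation, reusing the launch shapes from the repro (variable names here are illustrative; note the kernel itself hard-codes K1 = 64, so its constants do not line up exactly with these k = 256 tensors):

import torch

m, n, k = 2048, 1536, 256
a, b = torch.randn(m, k).cuda(), torch.randn(k, n).cuda()
c, d = torch.randn(m, k).cuda(), torch.randn(k, n).cuda()
# Reference for what the fused kernel should accumulate into out_ptr0.
ref = a @ b + c @ d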