@HDCharles
Created July 27, 2023 18:55
more configs benchmarks
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': 8},
num_stages=num_stages, num_warps=num_warps))
return configs
config_list = [
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# good for int8
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
]+get_configs_io_bound()
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs = config_list,
# configs=[
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# ],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, bias_ptr, scale_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
stride_bk, stride_bn,
stride_bias,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers.
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
bias_ptrs = bias_ptr + (offs_bn * stride_bias)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
    # `accumulator` will be converted back to bf16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
scale = tl.load(scale_ptr)
bias = tl.load(bias_ptrs)
c = accumulator.to(tl.bfloat16) * scale + bias
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b, bias, scale):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
# assert a.is_contiguous(), "Matrix A must be contiguous"
# assert b.is_contiguous(), "Matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, bias, scale, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
bias.stride(0),
c.stride(0), c.stride(1),
)
return c
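# Hedged usage sketch (not part of the original gist): shows the calling convention the
# benchmark below assumes for matmul() -- bf16 activations and weights, a bf16 bias of
# shape (N,) and a scalar bf16 scale read by the kernel via tl.load(scale_ptr). Wrapped
# in a helper so it only runs if explicitly invoked; the size D=1024 is illustrative.
def _example_matmul_call(D=1024):
    a = torch.randn(D, D, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(D, D, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(D, dtype=torch.bfloat16, device='cuda')
    scale = torch.tensor(1.0, dtype=torch.bfloat16, device='cuda')
    # returns (a @ b) * scale + bias as a (D, D) bf16 tensor
    return matmul(a, b, bias, scale)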
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=config_list,
# configs = [
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
# ]
key=['M', 'N', 'K'],
)
@triton.jit
def int8_weight_only_linear_kernel(
# Pointers to matrices
x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
# Matrix dimensions
M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_b,
stride_ym, stride_yn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Y it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers.
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
b_ptrs = b_ptr + (offs_wn * stride_b)
step_w = BLOCK_SIZE_K * stride_wk
step_x = BLOCK_SIZE_K * stride_xk
# -----------------------------------------------------------
# Iterate to compute a block of the Y matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
    # `accumulator` will be converted back to bf16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(x, w.to(tl.bfloat16))
# Advance the ptrs to the next K block.
x_ptrs += step_x
w_ptrs += step_w
s = tl.load(s_ptr)
b = tl.load(b_ptrs)
y = (accumulator.to(tl.bfloat16) * s + b)
# y = accumulator
# -----------------------------------------------------------
# Write back the block of the output matrix Y with masks.
offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
tl.store(y_ptrs, y, mask=y_mask)
def int8_weight_only_linear(x, w, b, s):
# Check constraints.
assert x.shape[1] == w.shape[0], "Incompatible dimensions"
# assert x.is_contiguous(), "Matrix x must be contiguous"
# assert w.is_contiguous(), "Matrix w must be contiguous"
M, K = x.shape
K, N = w.shape
assert b.shape[0] == N
# Allocates output.
y = torch.empty((M, N), device=x.device, dtype=x.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
int8_weight_only_linear_kernel[grid](
x, w, b, s, y,
M, N, K,
x.stride(0), x.stride(1),
w.stride(0), w.stride(1),
b.stride(0),
y.stride(0), y.stride(1),
)
return y
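# Hedged sketch (not part of the original gist) of how the (w_int8, scale) pair consumed
# by int8_weight_only_linear might be produced from a bf16 weight. Note the kernel reads
# a single scalar via tl.load(s_ptr), so a per-tensor symmetric scale is assumed here;
# the benchmark below simply uses random int8 weights and a random scale instead.
def _quantize_weight_int8(w_bf16):
    scale = w_bf16.abs().max().to(torch.float32) / 127.0  # per-tensor symmetric scale
    w_int8 = torch.clamp(torch.round(w_bf16.to(torch.float32) / scale), -127, 127).to(torch.int8)
    return w_int8, scale.to(torch.bfloat16)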
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=config_list,
# configs = [
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
# ],
key=['M', 'N', 'K'],
)
@triton.jit
def uint4x2_weight_only_linear_kernel(
# Pointers to matrices
x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
# Matrix dimensions
        M, N, K,  # x is (M, K) and w is (K//2, N), with two 4-bit values packed per uint8
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_b,
stride_ym, stride_yn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Y it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers.
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
w_ptrs = w_ptr + (offs_k[:, None]//2 * stride_wk + offs_wn[None, :] * stride_wn)
w_shifts = (offs_k % 2) * 4
b_ptrs = b_ptr + (offs_wn * stride_b)
step_w = BLOCK_SIZE_K//2 * stride_wk
step_x = BLOCK_SIZE_K * stride_xk
# -----------------------------------------------------------
# Iterate to compute a block of the Y matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
    # `accumulator` will be converted back to bf16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
w = ((w >> w_shifts[:, None]) & 0xF) - 8
# We accumulate along the K dimension.
accumulator += tl.dot(x, w.to(tl.bfloat16))
# Advance the ptrs to the next K block.
x_ptrs += step_x
w_ptrs += step_w
s = tl.load(s_ptr)
b = tl.load(b_ptrs)
y = (accumulator.to(tl.bfloat16) * s)+b
# -----------------------------------------------------------
# Write back the block of the output matrix Y with masks.
offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
tl.store(y_ptrs, y, mask=y_mask)
def uint4x2_weight_only_linear(x, w, b, s):
# Check constraints.
assert x.shape[1] == w.shape[0]*2, "Incompatible dimensions"
# assert x.is_contiguous(), "Matrix x must be contiguous"
# assert w.is_contiguous(), "Matrix w must be contiguous"
M, K = x.shape
_, N = w.shape
assert b.shape[0] == N
# Allocates output.
y = torch.empty((M, N), device=x.device, dtype=x.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
uint4x2_weight_only_linear_kernel[grid](
x, w, b, s, y,
M, N, K,
x.stride(0), x.stride(1),
w.stride(0), w.stride(1),
b.stride(0),
y.stride(0), y.stride(1),
)
return y
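# Hedged packing sketch (not part of the original gist): builds the uint8 layout that
# uint4x2_weight_only_linear_kernel assumes -- even k-indices in the low nibble, odd
# k-indices in the high nibble, signed int4 values in [-8, 7] stored offset by +8
# (the kernel unpacks them with ((w >> shift) & 0xF) - 8). Requires K to be even.
def _pack_int4_rows(w_int4):
    # w_int4: (K, N) integer tensor with values in [-8, 7]; returns a (K//2, N) uint8 tensor
    w_u = (w_int4 + 8).to(torch.uint8)
    return w_u[0::2, :] | (w_u[1::2, :] << 4)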
quantiles = [0.5, 0.2, 0.8] # do_bench returns times at these quantiles (median, 20th, 80th percentile); only the median ([0]) is recorded below
result = {}
for D in [2**8, 2**10, 2**12, 2**14]:
result[D]={}
result[D]["cublas linear"]={}
result[D]["triton matmul"]={}
result[D]["int8 linear"]={}
result[D]["uint4x2 linear"]={}
for t_x in [0,1]:
for t_w in [1,0]:
x = torch.randn(D,D).to('cuda').to(torch.bfloat16)
w_bf16 = torch.randn(D,D, dtype=torch.bfloat16).cuda()
bias = torch.randn(D, dtype=torch.bfloat16).cuda()
if t_x:
x = x.t()
if t_w:
w_bf16 = w_bf16.t()
torch.nn.functional.linear(x, w_bf16, bias)
torch.cuda.synchronize()
result[D]["cublas linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: torch.nn.functional.linear(x, w_bf16, bias), quantiles=quantiles)[0]
torch.cuda.synchronize()
del w_bf16
w_int8 = torch.randint(-128, 127, (D, D), dtype=torch.int8).cuda()
if t_w:
w_int8 = w_int8.t()
scale = torch.randn(D, dtype=torch.bfloat16).cuda()
triton_matmul(x, w_int8)
torch.cuda.synchronize()
# result[D]["bfloat16 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: matmul(x, w_bf16, bias, scale), quantiles=quantiles)[0]
result[D]["triton matmul"][(t_x, t_w)] = triton.testing.do_bench(lambda: triton_matmul(x, w_int8), quantiles=quantiles)[0]
torch.cuda.synchronize()
int8_weight_only_linear(x, w_int8, bias, scale)
torch.cuda.synchronize()
result[D]["int8 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: int8_weight_only_linear(x, w_int8, bias, scale), quantiles=quantiles)[0]
torch.cuda.synchronize()
del w_int8
w_uint4x2 = torch.randint(0, 255, (D//2, D), dtype=torch.uint8).cuda()
if t_w:
w_uint4x2 = torch.randint(0, 255, (D, D//2), dtype=torch.uint8).cuda().t()
uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)
torch.cuda.synchronize()
result[D]["uint4x2 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: uint4x2_weight_only_linear(x, w_uint4x2, bias, scale), quantiles=quantiles)[0]
torch.cuda.synchronize()
del w_uint4x2
caches = {"triton matmul": _kernel.cache, "int8 linear": int8_weight_only_linear_kernel.cache, "uint4x2 linear": uint4x2_weight_only_linear_kernel.cache}
print("X . W | X . W.t() | X.t() . W | X.t() . W.t() | model | (M, N, K) | config |")
for d in result.keys():
for name in result[d].keys():
r = result[d][name]
print(f"| {r[(0,0)]:.4f} | {r[(0,1)]:.4f} | {r[(1,0)]:.4f} | {r[(1,1)]:.4f} | {name} | {(d, d, d)} | {None if name == 'cublas linear' else caches[name][(d, d, d)]} |")
# using triton version triton-nightly 2.1.0.dev20230726014945
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# using torch compiled from 2.1.0a0+git9c2122d
# using cuda 12.1, cudnn 8.9.2 on A100 GPU
# ---------- OUTPUT --------------
# | X . W | X . W.t() | X.t() . W | X.t() . W.t() | model | (M, N, K) | config |
# | 0.0092 | 0.0102 | 0.0113 | 0.0102 | cublas linear | (256, 256, 256) | None |
# | 0.0123 | 0.0102 | 0.0123 | 0.0133 | triton matmul | (256, 256, 256) | BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 0.0123 | 0.0092 | 0.0123 | 0.0123 | int8 linear | (256, 256, 256) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# | 0.0113 | 0.0123 | 0.0123 | 0.0133 | uint4x2 linear | (256, 256, 256) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
# | 0.0246 | 0.0236 | 0.0236 | 0.0236 | cublas linear | (1024, 1024, 1024) | None |
# | 0.0573 | 0.0389 | 0.0563 | 0.0573 | triton matmul | (1024, 1024, 1024) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 0.0389 | 0.0348 | 0.0399 | 0.0379 | int8 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 3 |
# | 0.0440 | 0.0389 | 0.0451 | 0.0410 | uint4x2 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 0.5929 | 0.5827 | 0.5898 | 0.5888 | cublas linear | (4096, 4096, 4096) | None |
# | 1.0793 | 0.9318 | 1.0547 | 1.0291 | triton matmul | (4096, 4096, 4096) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 0.9871 | 0.9467 | 1.0527 | 0.9871 | int8 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 1.0127 | 1.0086 | 1.0793 | 1.0281 | uint4x2 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 36.4964 | 36.0643 | 36.0300 | 35.5374 | cublas linear | (16384, 16384, 16384) | None |
# | 63.9580 | 57.1423 | 63.4399 | 62.9699 | triton matmul | (16384, 16384, 16384) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
# | 65.0670 | 61.5434 | 66.1012 | 62.3503 | int8 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# | 61.3018 | 63.1890 | 63.6826 | 62.2889 | uint4x2 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |