@HDCharles
Last active March 1, 2024 17:00
Script for comparing the performance of several linear Triton kernels across several shapes.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton import Config
from torch._inductor import config
from torch import _dynamo
aten = torch.ops.aten
def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': 8},
                               num_stages=num_stages, num_warps=num_warps))
    return configs
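# For reference: the sweep above yields 5 * 2 * 2 * 4 = 80 io-bound configs, which are
# appended to the hand-picked `config_list` below before autotuning.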
config_list = [
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    # good for int8
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
] + get_configs_io_bound()  # taken from inductor
@triton.autotune(
    configs=[
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=1),
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=1),
    ] + config_list,
    key=['M', 'K', 'N'],
)
@triton.jit
def int8_weight_only_linear_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
    # by to get the element one row down (X has M rows).
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Y it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
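    # Worked example of the grouped ordering above (illustrative numbers, not from the
    # benchmark): with M = N = 256, BLOCK_SIZE_M = BLOCK_SIZE_N = 64 and GROUP_SIZE_M = 8,
    # num_pid_m = num_pid_n = 4 and group_size_m = min(4, 8) = 4, so pids 0..3 map to
    # output blocks (0,0)..(3,0) and pids 4..7 to (0,1)..(3,1): consecutive programs walk
    # down a narrow column of Y blocks, reusing the same W tiles from L2 before moving right.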
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_xm = tl.max_contiguous((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M)
    offs_wn = tl.max_contiguous((pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_w = BLOCK_SIZE_K * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension, upcasting the int8 weights to bf16 for the dot.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    # NOTE: this loads a single scalar scale value (s[0]); per-channel scales would need
    # per-column loads like `b_ptrs` above.
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s + b)
    # y = accumulator
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)
def int8_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    K, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    int8_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
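# Hedged smoke test (an illustrative sketch that the benchmark never calls): run the int8
# weight-only kernel once on a small shape and compare it against plain bf16 math. A uniform
# per-tensor scale is used here because the kernel above loads only a single scale element;
# the shapes and the helper name are assumptions, not part of the original script.
def _int8_wo_linear_smoke_test(M=64, K=256, N=256):
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    w_int8 = torch.randint(-128, 127, (K, N), dtype=torch.int8, device='cuda')
    bias = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    scale = torch.full((N,), 0.01, dtype=torch.bfloat16, device='cuda')
    # column-major weight, matching the benchmark's w_noncontiguous=1 path
    y = int8_weight_only_linear(x, w_int8.t().contiguous().t(), bias, scale)
    y_ref = (x @ w_int8.to(torch.bfloat16)) * scale[0] + bias
    print("max abs diff:", (y - y_ref).abs().max().item())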
@triton.autotune(
    configs=[
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**14}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**13}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**12}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**11}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**10}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**8}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 16}),
    ],
    key=['M', 'K', 'N'],
)
@triton.jit  # naming note: `r` refers to k, `x` refers to n
def int8_weight_only_linear_bs_1_kernel(x_ptr, w_ptr, b_ptr, s_ptr, y_ptr, M, N, K, N_BLOCK: tl.constexpr, K_BLOCK: tl.constexpr):
    noffset = tl.program_id(0) * N_BLOCK  # N_BLOCK is always 1
    nindex = noffset + tl.arange(0, N_BLOCK)[None, :]  # nindex is a single index
    nmask = nindex < N
    kbase = tl.arange(0, K_BLOCK)[:, None]  # this is a chunk of k values
    n0 = nindex
    acc = tl.full([K_BLOCK, N_BLOCK], 0, tl.float32)  # should probably be bfloat16
    for koffset in range(0, K, K_BLOCK):
        kindex = koffset + kbase
        kmask = kindex < K
        # w is indexed as k + K*n, i.e. it is assumed to be laid out column-major
        # (strides (1, K)), which matches the benchmark's w_noncontiguous=1 path below.
        # kmask guards the K tail and configs where K_BLOCK > K.
        x = tl.load(x_ptr + (kindex), kmask, other=0.0, eviction_policy='evict_last').to(tl.float32)
        w = tl.load(w_ptr + (kindex + (K * n0)), kmask & nmask, other=0.0, eviction_policy='evict_first')
        x_fp32 = x.to(tl.float32)  # not needed?
        w_fp32 = w.to(tl.float32)  # would be faster to multiply in bf16?
        xw_fp32 = x_fp32 * w_fp32
        xw = tl.broadcast_to(xw_fp32, [K_BLOCK, N_BLOCK])
        hold = acc + xw
        acc = tl.where(nmask, hold, acc)  # if every thread has its own column, probably not needed
    xw_sum = tl.sum(acc, 0)[:, None]
    scale = tl.load(s_ptr + (n0), nmask, eviction_policy='evict_last').to(tl.float32)  # feels like these should be evicted first, only x is reused
    bias = tl.load(b_ptr + (n0), nmask, eviction_policy='evict_last').to(tl.float32)
    xw_sum_fp32 = xw_sum.to(tl.float32)
    y = xw_sum_fp32 * scale
    y_final = y + bias
    tl.store(y_ptr + (n0), y_final, nmask)
def int8_weight_only_linear_bs_1(x, w, b, s):
    # Check constraints.
    assert x.shape[0] == 1
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    M, K = x.shape
    K, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(N, META['N_BLOCK']),
    )
    int8_weight_only_linear_bs_1_kernel[grid](
        x, w, b, s, y,
        M, N, K,
    )
    return y
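# Hedged usage sketch (not called by the benchmark): the bs=1 kernel above indexes the weight
# as `k + K * n`, so it expects w stored column-major (strides (1, K)), which is what
# `w.t().contiguous().t()` produces; it also assumes x is a contiguous (1, K) row. The shapes
# and the helper name here are illustrative assumptions.
def _int8_wo_linear_bs1_example(K=4096, N=4096):
    x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
    w_int8 = torch.randint(-128, 127, (K, N), dtype=torch.int8, device='cuda').t().contiguous().t()
    bias = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    scale = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    return int8_weight_only_linear_bs_1(x, w_int8, bias, scale)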
@triton.autotune(
    configs=[
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] + config_list,
    key=['M', 'K', 'N'],
)
@triton.jit
def uint4x2_weight_only_linear_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Matrix dimensions
    M, N, K,  # x is (M, K) and w is (K//2, N), with two 4-bit values packed into each uint8
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
    # by to get the element one row down (X has M rows).
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul Y = (X @ unpack(W)) * s + b.
    X has shape (M, K), W has shape (K//2, N) of packed uint4x2, and Y has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of Y it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers; each packed byte is
    # addressed twice (offs_k // 2) and the nibble is selected by `w_shifts` below.
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_xm = tl.max_contiguous((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M)
    offs_wn = tl.max_contiguous((pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] // 2 * stride_wk + offs_wn[None, :] * stride_wn)
    w_shifts = (offs_k % 2) * 4
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_w = BLOCK_SIZE_K // 2 * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # Unpack: even k -> low nibble, odd k -> high nibble, then shift to signed [-8, 7].
        w = ((w >> w_shifts[:, None]) & 0xF) - 8
        # We accumulate along the K dimension.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    # NOTE: as in the int8 kernel, this loads a single scalar scale value (s[0]).
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s) + b
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)
def uint4x2_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0] * 2, "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    _, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    uint4x2_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
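# Hedged packing sketch (the benchmark just uses random uint8 data): one way to pack a bf16
# weight into the uint4x2 layout the kernel above expects -- two 4-bit values per byte, even k
# in the low nibble and odd k in the high nibble, stored with a +8 offset so the kernel's
# `((w >> shift) & 0xF) - 8` recovers signed values in [-8, 7]. The symmetric per-column
# quantization here is an illustrative assumption, and note that the kernel only applies a
# single scalar scale element.
def _pack_uint4x2_reference(w_bf16):
    scale = (w_bf16.abs().amax(dim=0) / 7.0).clamp(min=1e-6)       # per-column scale, shape (N,)
    q = torch.round(w_bf16 / scale).clamp(-8, 7).to(torch.int8)    # signed 4-bit values
    q = (q + 8).to(torch.uint8)                                    # unsigned nibbles in [0, 15]
    packed = q[::2, :] | (q[1::2, :] << 4)                         # (K//2, N) packed bytes
    return packed, scale.to(torch.bfloat16)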
# run options
run_tinygemm = False
run_uint4x2 = False
run_int8_kernels = True
run_all_contiguous_options = False
run_bs_1 = False
print(f"settings: run_tinygemm={run_tinygemm}, run_uint4x2={run_uint4x2}, run_int8_kernels={run_int8_kernels}, run_all_contiguous_options={run_all_contiguous_options}, run_bs_1={run_bs_1}")
quantiles = [0.5, 0.2, 0.8]
result = {}
for K in [2**6, 2**8, 2**10, 2**12]:
    N = K
    for M in [K, 1] if run_bs_1 else [K]:
        result[(M, K, N)] = {}
        result[(M, K, N)]["cublas linear"] = {}
        result[(M, K, N)]["int8 linear"] = {}
        result[(M, K, N)]["int8 wo lin bs1"] = {}
        result[(M, K, N)]["uint4x2 linear"] = {}
        result[(M, K, N)]["int4 tinygemm"] = {}
        for x_noncontiguous in [0, 1] if run_all_contiguous_options else [0]:
            for w_noncontiguous in [0, 1] if run_all_contiguous_options else [1]:
                print(f">>shape={(M,K,N)}, x_noncontiguous={x_noncontiguous}, w_noncontiguous={w_noncontiguous}")
                x = torch.randn(M, K).to('cuda').to(torch.bfloat16)
                w_bf16 = torch.randn(K, N, dtype=torch.bfloat16).cuda()
                bias = torch.randn(N, dtype=torch.bfloat16).cuda()
                if x_noncontiguous:
                    x = x.t().contiguous().t()
                if w_noncontiguous:
                    w_bf16 = w_bf16.t().contiguous().t()
                print("running cublas linear (ideally want something faster than this)")
                try:
                    torch.nn.functional.linear(x, w_bf16, bias)
                    torch.cuda.synchronize()
                    result[M, K, N]["cublas linear"][(x_noncontiguous, w_noncontiguous)] = triton.testing.do_bench(lambda: torch.nn.functional.linear(x, w_bf16, bias), quantiles=quantiles)[0]
                except:
                    print("err")
                    pass
                torch.cuda.synchronize()
                if run_tinygemm:
                    print("running int4 tinygemm (fast for bs=1, slow for bs>1)")
                    try:
                        import torchao
                        # find a feasible inner_k_tiles and groupsize for the tinygemm kernel
                        inner_k_tiles = min(K // 32, 8)  # inner_k_tiles must be one of [2, 4, 8]
                        groupsize = min(128, K // inner_k_tiles)  # groupsize in [32, 64, 128, 256] and needs K % (groupsize * inner_k_tiles) == 0
                        w_int4, scales_and_zeros = torchao.quantization.groupwise_affine_quantize_tensor(w_bf16, 4, groupsize)
                        w_int4pack = aten._convert_weight_to_int4pack(w_int4.contiguous(), inner_k_tiles)
                        del w_int4
                        result[M, K, N]["int4 tinygemm"][(x_noncontiguous, w_noncontiguous)] = triton.testing.do_bench(lambda: aten._weight_int4pack_mm(x.contiguous(), w_int4pack, groupsize, scales_and_zeros), quantiles=quantiles)[0]
                        del scales_and_zeros, w_int4pack
                    except:
                        print("err")
                        pass
                del w_bf16
                w_int8 = torch.randint(-128, 127, (K, N), dtype=torch.int8).cuda()
                if w_noncontiguous:
                    w_int8 = w_int8.t().contiguous().t()
                scale = torch.randn(N, dtype=torch.bfloat16).cuda()
                torch._dynamo.reset()
                if run_int8_kernels:
                    print("running int8 wo lin bs1 (fast but limited)")
                    try:
                        int8_weight_only_linear_bs_1(x, w_int8, bias, scale)
                        torch.cuda.synchronize()
                        result[M, K, N]["int8 wo lin bs1"][(x_noncontiguous, w_noncontiguous)] = triton.testing.do_bench(lambda: int8_weight_only_linear_bs_1(x, w_int8, bias, scale), quantiles=quantiles)[0]
                    except:
                        print("err" if M == 1 else "not supported for bs>1")
                        pass
                torch.cuda.synchronize()
                if run_int8_kernels:
                    print("running int8 linear (slow but general, ideally this could be optimized to run faster than cublas for bs>1)")
                    try:
                        int8_weight_only_linear(x, w_int8, bias, scale)
                        torch.cuda.synchronize()
                        result[M, K, N]["int8 linear"][(x_noncontiguous, w_noncontiguous)] = triton.testing.do_bench(lambda: int8_weight_only_linear(x, w_int8, bias, scale), quantiles=quantiles)[0]
                    except:
                        print("err")
                        pass
                torch.cuda.synchronize()
                del w_int8
                w_uint4x2 = torch.randint(0, 255, (K // 2, N), dtype=torch.uint8).cuda()
                if w_noncontiguous:
                    w_uint4x2 = w_uint4x2.t().contiguous().t()
                if run_uint4x2:
                    print("running uint4x2 linear (slow, ideally this could be optimized to run faster than cublas for bs>1)")
                    try:
                        assert w_noncontiguous == 1
                        uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)
                        torch.cuda.synchronize()
                        result[M, K, N]["uint4x2 linear"][(x_noncontiguous, w_noncontiguous)] = triton.testing.do_bench(lambda: uint4x2_weight_only_linear(x, w_uint4x2, bias, scale), quantiles=quantiles)[0]
                    except:
                        print("err" if w_noncontiguous == 1 else "not supported when w_noncontiguous==0")
                        pass
                torch.cuda.synchronize()
                del w_uint4x2, scale, bias
caches = {"int8 linear": int8_weight_only_linear_kernel.cache, "uint4x2 linear": uint4x2_weight_only_linear_kernel.cache, "int8 wo lin bs1": int8_weight_only_linear_bs_1_kernel.cache}
for M, K, N in result.keys():
    print(f"shape=({M},{K})x({K},{N})")
    print("| X . W | X . Wt | Xt . W | Xt .Wt | model | config")
    for name in result[(M, K, N)].keys():
        r = result[(M, K, N)][name]
        if len(r) == 0:
            continue
        used_config = None
        if name in caches:
            cache = caches[name]
            for key, config in cache.items():
                if key[0] == M and key[1] == K and key[2] == N:
                    used_config = config
                    break
        print(f"| {(r[(0,0)] if (0,0) in r else 0):2.4f} | {(r[(0,1)] if (0,1) in r else 0):2.4f} | {(r[(1,0)] if (1,0) in r else 0):2.4f} | {(r[(1,1)] if (1,1) in r else 0):2.4f} | {name:<15} | {used_config}")
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# install: pip install torchao
# using torch compiled from 2.1.0a0+git9c2122d or nightly
# using cuda 12.0 on A100 GPU
"""
shape=(64,64)x(64,64)
| X . W | X . Wt | Xt . W | Xt .Wt | model | config
| 0.0102 | 0.0096 | 0.0108 | 0.0108 | cublas linear | None
| 0.1055 | 0.0998 | 0.1144 | 0.0963 | int8 linear | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, enable_warp_specialization: False, enable_persistent: False
shape=(256,256)x(256,256)
| X . W | X . Wt | Xt . W | Xt .Wt | model | config
| 0.0112 | 0.0126 | 0.0136 | 0.0136 | cublas linear | None
| 0.0992 | 0.0964 | 0.1205 | 0.1067 | int8 linear | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 6, enable_warp_specialization: False, enable_persistent: False
shape=(1024,1024)x(1024,1024)
| X . W | X . Wt | Xt . W | Xt .Wt | model | config
| 0.0303 | 0.0303 | 0.0297 | 0.0292 | cublas linear | None
| 0.1036 | 0.1069 | 0.1110 | 0.1050 | int8 linear | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, enable_warp_specialization: False, enable_persistent: False
shape=(4096,4096)x(4096,4096)
| X . W | X . Wt | Xt . W | Xt .Wt | model | config
| 0.6448 | 0.6414 | 0.6423 | 0.6402 | cublas linear | None
| 1.2970 | 1.1251 | 1.8512 | 1.4334 | int8 linear | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, enable_warp_specialization: False, enable_persistent: False
"""