Last active
March 1, 2024 17:00
-
-
Save HDCharles/a7fc12b31702cf963d8453e0da157296 to your computer and use it in GitHub Desktop.
script for comparing performance of several linear triton kernels across several shapes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
import triton | |
import triton.language as tl | |
from triton import Config | |
from torch._inductor import config | |
from torch import _dynamo | |
aten = torch.ops.aten | |
def get_configs_io_bound():
    """Build the grid of small-tile autotune candidates.

    Every combination of the pipeline depths and tile sizes listed below is
    emitted, iterating num_stages outermost, then BLOCK_SIZE_M, BLOCK_SIZE_K
    and BLOCK_SIZE_N innermost.  Tiles with BLOCK_SIZE_N <= 64 use 2 warps,
    wider tiles use 4.  GROUP_SIZE_M is fixed at 8.

    Returns:
        A list of 5 * 2 * 2 * 4 = 80 ``Config`` objects.
    """
    stage_choices = (2, 3, 4, 5, 6)
    m_tiles = (16, 32)
    k_tiles = (32, 64)
    n_tiles = (32, 64, 128, 256)
    return [
        Config(
            {
                'BLOCK_SIZE_M': tile_m,
                'BLOCK_SIZE_N': tile_n,
                'BLOCK_SIZE_K': tile_k,
                'GROUP_SIZE_M': 8,
            },
            num_stages=stages,
            # Narrow N tiles get fewer warps.
            num_warps=2 if tile_n <= 64 else 4,
        )
        for stages in stage_choices
        for tile_m in m_tiles
        for tile_k in k_tiles
        for tile_n in n_tiles
    ]
# Hand-picked autotune candidates (taken from inductor), one row per config:
# (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps).
# GROUP_SIZE_M is 8 for every entry.
_tile_table = [
    (128, 256, 32, 3, 8),
    (256, 128, 32, 3, 8),
    (256, 64, 32, 4, 4),
    (64, 256, 32, 4, 4),
    (128, 128, 32, 4, 4),
    (128, 64, 32, 4, 4),
    (64, 128, 32, 4, 4),
    (128, 32, 32, 4, 4),
    (64, 32, 32, 5, 2),
    # good for int8
    (128, 256, 128, 3, 8),
    (256, 128, 128, 3, 8),
    (256, 64, 128, 4, 4),
    (64, 256, 128, 4, 4),
    (128, 128, 128, 4, 4),
    (128, 64, 64, 4, 4),
    (64, 128, 64, 4, 4),
    (128, 32, 64, 4, 4),
    (64, 32, 64, 5, 2),
    (32, 64, 128, 5, 2),
    (16, 16, 256, 5, 2),
]

# Shared candidate list for the tiled kernels: the table above followed by
# the grid produced by get_configs_io_bound().
config_list = [
    Config(
        {'BLOCK_SIZE_M': tile_m, 'BLOCK_SIZE_N': tile_n, 'BLOCK_SIZE_K': tile_k, 'GROUP_SIZE_M': 8},
        num_stages=stages,
        num_warps=warps,
    )
    for tile_m, tile_n, tile_k, stages, warps in _tile_table
] + get_configs_io_bound()
@triton.autotune(
    configs=[
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=1),
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=1),
    ] + config_list,
    key=['M', 'K', 'N'],
)
@triton.jit
def int8_weight_only_linear_kernel(
    # Pointers to the operands and the output
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Problem size: x is (M, K), w is (K, N), y is (M, N)
    M, N, K,
    # Element strides, e.g. `stride_xm` is the pointer increment that moves
    # one row down in x and `stride_xk` the increment that moves one column
    # to the right.
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters chosen by the autotuner
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of ``y = (x @ w) * s + b``.

    Tiles of w are cast to bfloat16 before `tl.dot`; the products are
    accumulated in float32 and the accumulator is cast to bfloat16 before the
    scale and bias are applied.

    NOTE(review): the scale is read with a plain `tl.load(s_ptr)`, i.e. only
    the first element of `s` is used for every output column, whereas the
    bias is read per column.  Confirm whether a per-column scale was intended.
    """
    # -----------------------------------------------------------
    # Map the linear program id to an output tile (pid_m, pid_n).
    # Programs are grouped GROUP_SIZE_M tile-rows at a time, which is the
    # grouped ordering used to promote L2 reuse.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    group_first_m = (pid // tiles_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M tile-rows.
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    pid_m = group_first_m + (pid % rows_in_group)
    pid_n = (pid % tiles_per_group) // rows_in_group
    # -----------------------------------------------------------
    # Pointers to the first K-block of x ([BLOCK_SIZE_M, BLOCK_SIZE_K]) and
    # of w ([BLOCK_SIZE_K, BLOCK_SIZE_N]).  Row/column indices are wrapped
    # modulo M/N so loads of edge tiles stay inside the tensors; the
    # surplus results are dropped by the store mask at the end.
    k_range = tl.arange(0, BLOCK_SIZE_K)
    rows = tl.max_contiguous((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M)
    cols = tl.max_contiguous((pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N, BLOCK_SIZE_N)
    x_tile_ptrs = x_ptr + (rows[:, None] * stride_xm + k_range[None, :] * stride_xk)
    w_tile_ptrs = w_ptr + (k_range[:, None] * stride_wk + cols[None, :] * stride_wn)
    bias_ptrs = b_ptr + (cols * stride_b)
    # -----------------------------------------------------------
    # Walk along K, accumulating the tile product in float32.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for kb in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Elements past the end of K are loaded as zero.
        k_left = K - kb * BLOCK_SIZE_K
        x_tile = tl.load(x_tile_ptrs, mask=k_range[None, :] < k_left, other=0.0)
        w_tile = tl.load(w_tile_ptrs, mask=k_range[:, None] < k_left, other=0.0)
        acc += tl.dot(x_tile, w_tile.to(tl.bfloat16))
        # Step both operands to the next K-block.
        x_tile_ptrs += BLOCK_SIZE_K * stride_xk
        w_tile_ptrs += BLOCK_SIZE_K * stride_wk
    scale = tl.load(s_ptr)
    bias = tl.load(bias_ptrs)
    out = acc.to(tl.bfloat16) * scale + bias
    # -----------------------------------------------------------
    # Store the tile, masking rows/columns that fall outside y.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    in_bounds = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    out_ptrs = y_ptr + stride_ym * out_rows[:, None] + stride_yn * out_cols[None, :]
    tl.store(out_ptrs, out, mask=in_bounds)
def int8_weight_only_linear(x, w, b, s):
    """Launch `int8_weight_only_linear_kernel` on x (M, K) and w (K, N).

    Args:
        x: 2-D activation tensor of shape (M, K).
        w: 2-D weight tensor of shape (K, N).
        b: bias whose first dimension has N entries.
        s: scale tensor, forwarded to the kernel unchanged.

    Returns:
        A new (M, N) tensor on x's device with x's dtype.

    Raises:
        AssertionError: if the inner dimensions of x and w differ, or if the
            bias length is not N.
    """
    # Neither operand has to be contiguous: the strides are handed to the kernel.
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    M, _ = x.shape
    K, N = w.shape
    assert b.shape[0] == N

    y = torch.empty((M, N), device=x.device, dtype=x.dtype)

    def grid(META):
        # 1-D launch: one program per output tile.
        tiles_m = triton.cdiv(M, META['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    stride_xm, stride_xk = x.stride(0), x.stride(1)
    stride_wk, stride_wn = w.stride(0), w.stride(1)
    stride_ym, stride_yn = y.stride(0), y.stride(1)
    int8_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        b.stride(0),
        stride_ym, stride_yn,
    )
    return y
@triton.autotune(
    configs=[
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**14}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**13}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**12}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**11}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**10}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 2**8}),
        Config({'N_BLOCK': 1, 'K_BLOCK': 16}),
    ],
    key=['M', 'K', 'N'],
)
@triton.jit  # r is k, x is n
def int8_weight_only_linear_bs_1_kernel(x_ptr, w_ptr, b_ptr, s_ptr, y_ptr, M, N, K, N_BLOCK: tl.constexpr, K_BLOCK: tl.constexpr):
    """Single-row linear: ``y[n] = sum_k(x[k] * w[k, n]) * s[n] + b[n]``.

    Each program handles N_BLOCK output columns (1 in every config above) and
    reduces over K in chunks of K_BLOCK, accumulating in float32.

    Addressing used by the kernel (no strides are passed in):
      * x is read at `x_ptr + k`,
      * w is read at `w_ptr + k + K * n`, i.e. the K weights of one output
        column are stored consecutively,
      * s, b and y are read/written at `ptr + n`.

    `M` is not used in the body; it only takes part in the autotune key.
    """
    noffset = tl.program_id(0) * N_BLOCK
    nindex = noffset + tl.arange(0, N_BLOCK)[None, :]  # [1, N_BLOCK]
    nmask = nindex < N
    kbase = tl.arange(0, K_BLOCK)[:, None]  # [K_BLOCK, 1]
    acc = tl.full([K_BLOCK, N_BLOCK], 0, tl.float32)
    for koffset in range(0, K, K_BLOCK):
        kindex = koffset + kbase
        # Guard the tail of K: without this mask a K_BLOCK larger than K (or a
        # K that is not a multiple of K_BLOCK) reads past the end of x and w
        # and adds garbage into the sum.
        kmask = kindex < K
        x = tl.load(x_ptr + kindex, mask=kmask, other=0.0, eviction_policy='evict_last').to(tl.float32)
        w = tl.load(w_ptr + (kindex + K * nindex), mask=kmask & nmask, other=0.0, eviction_policy='evict_first').to(tl.float32)
        # Masked lanes were loaded as 0, so they contribute nothing here.
        acc += x * w
    # Reduce over K; keep the [1, N_BLOCK] layout of `nindex` so the epilogue
    # stays element-wise per column.
    xw_sum = tl.sum(acc, 0)[None, :]
    scale = tl.load(s_ptr + nindex, nmask, eviction_policy='evict_last').to(tl.float32)
    bias = tl.load(b_ptr + nindex, nmask, eviction_policy='evict_last').to(tl.float32)
    y_final = xw_sum * scale + bias
    tl.store(y_ptr + nindex, y_final, nmask)
def int8_weight_only_linear_bs_1(x, w, b, s):
    """Launch `int8_weight_only_linear_bs_1_kernel` for a single-row input.

    Args:
        x: 2-D activation tensor of shape (1, K).
        w: 2-D weight tensor of shape (K, N).
        b: bias whose first dimension has N entries.
        s: scale tensor, forwarded to the kernel unchanged.

    Returns:
        A new (1, N) tensor on x's device with x's dtype.

    Raises:
        AssertionError: if x has more than one row, if the inner dimensions
            of x and w differ, or if the bias length is not N.

    No strides are forwarded to the kernel, so the memory layout it assumes
    for the operands is not checked here.
    """
    assert x.shape[0]==1
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    M, _ = x.shape
    K, N = w.shape
    assert b.shape[0] == N

    y = torch.empty((M, N), device=x.device, dtype=x.dtype)

    def grid(META):
        # 1-D launch: one program per block of N_BLOCK output columns.
        return (triton.cdiv(N, META['N_BLOCK']),)

    int8_weight_only_linear_bs_1_kernel[grid](
        x, w, b, s, y,
        M, N, K,
    )
    return y
@triton.autotune(
    configs=[
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] + config_list,
    key=['M', 'K', 'N'],
)
@triton.jit
def uint4x2_weight_only_linear_kernel(
    # Pointers to the operands and the output
    x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
    # Problem size.  K is the reduction length, i.e. the number of columns
    # of x; the packed weight is indexed with k // 2 along its first
    # dimension, so it holds two k values per stored element.
    M, N, K,
    # Element strides, e.g. `stride_xm` is the pointer increment that moves
    # one row down in x.  `stride_wk` steps over *packed* rows of w.
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_b,
    stride_ym, stride_yn,
    # Meta-parameters chosen by the autotuner
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of ``y = (x @ unpack(w)) * s + b``.

    Unpacking: for reduction index k the stored element at packed row k // 2
    is shifted right by 0 bits (even k) or 4 bits (odd k), masked to 4 bits
    and offset by -8.  The unpacked tile is cast to bfloat16 for `tl.dot`;
    products are accumulated in float32 and cast to bfloat16 before the scale
    and bias are applied.

    NOTE(review): the scale is read with a plain `tl.load(s_ptr)`, i.e. only
    the first element of `s` is used for every output column.  Confirm
    whether a per-column scale was intended.
    """
    # -----------------------------------------------------------
    # Map the linear program id to an output tile (pid_m, pid_n) using the
    # grouped ordering (GROUP_SIZE_M tile-rows per group) for L2 reuse.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    group_first_m = (pid // tiles_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M tile-rows.
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    pid_m = group_first_m + (pid % rows_in_group)
    pid_n = (pid % tiles_per_group) // rows_in_group
    # -----------------------------------------------------------
    # Pointers to the first K-block of x ([BLOCK_SIZE_M, BLOCK_SIZE_K]) and
    # of packed w ([BLOCK_SIZE_K, BLOCK_SIZE_N] pointers; consecutive k
    # pairs point at the same packed element).  Row/column indices are
    # wrapped modulo M/N so edge-tile loads stay inside the tensors.
    k_range = tl.arange(0, BLOCK_SIZE_K)
    rows = tl.max_contiguous((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M)
    cols = tl.max_contiguous((pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N, BLOCK_SIZE_N)
    x_tile_ptrs = x_ptr + (rows[:, None] * stride_xm + k_range[None, :] * stride_xk)
    w_tile_ptrs = w_ptr + ((k_range[:, None] // 2) * stride_wk + cols[None, :] * stride_wn)
    # Bit offset of each k inside its packed element: 0 for even k, 4 for odd k.
    nibble_shift = (k_range % 2) * 4
    bias_ptrs = b_ptr + (cols * stride_b)
    # -----------------------------------------------------------
    # Walk along K, accumulating the tile product in float32.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for kb in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Elements past the end of K are loaded as zero.
        k_left = K - kb * BLOCK_SIZE_K
        x_tile = tl.load(x_tile_ptrs, mask=k_range[None, :] < k_left, other=0.0)
        packed = tl.load(w_tile_ptrs, mask=k_range[:, None] < k_left, other=0.0)
        w_tile = ((packed >> nibble_shift[:, None]) & 0xF) - 8
        acc += tl.dot(x_tile, w_tile.to(tl.bfloat16))
        # A K-block covers BLOCK_SIZE_K columns of x but only half as many
        # packed rows of w.
        x_tile_ptrs += BLOCK_SIZE_K * stride_xk
        w_tile_ptrs += BLOCK_SIZE_K // 2 * stride_wk
    scale = tl.load(s_ptr)
    bias = tl.load(bias_ptrs)
    out = (acc.to(tl.bfloat16) * scale) + bias
    # -----------------------------------------------------------
    # Store the tile, masking rows/columns that fall outside y.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    in_bounds = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    out_ptrs = y_ptr + stride_ym * out_rows[:, None] + stride_yn * out_cols[None, :]
    tl.store(out_ptrs, out, mask=in_bounds)
def uint4x2_weight_only_linear(x, w, b, s):
    """Launch `uint4x2_weight_only_linear_kernel` on x (M, K) and packed w (K/2, N).

    Args:
        x: 2-D activation tensor of shape (M, K).
        w: 2-D packed weight tensor with K / 2 rows and N columns.
        b: bias whose first dimension has N entries.
        s: scale tensor, forwarded to the kernel unchanged.

    Returns:
        A new (M, N) tensor on x's device with x's dtype.

    Raises:
        AssertionError: if x's column count is not twice w's row count, or if
            the bias length is not N.
    """
    # Neither operand has to be contiguous: the strides are handed to the kernel.
    assert x.shape[1] == w.shape[0]*2, "Incompatible dimensions"
    M, K = x.shape
    _, N = w.shape
    assert b.shape[0] == N

    y = torch.empty((M, N), device=x.device, dtype=x.dtype)

    def grid(META):
        # 1-D launch: one program per output tile.
        tiles_m = triton.cdiv(M, META['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    stride_xm, stride_xk = x.stride(0), x.stride(1)
    stride_wk, stride_wn = w.stride(0), w.stride(1)
    stride_ym, stride_yn = y.stride(0), y.stride(1)
    uint4x2_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        b.stride(0),
        stride_ym, stride_yn,
    )
    return y
# Run options: which kernel families the benchmark loop times.
run_tinygemm = False
run_uint4x2 = False
run_int8_kernels = True
# When True, every combination of x / w memory layout is benchmarked.
run_all_contiguous_options = False
# When True, the batch-size-1 shape (M == 1) is benchmarked as well.
run_bs_1 = False
print(f"settings: run_tinygemm={run_tinygemm}, run_uint4x2={run_uint4x2}, run_int8_kernels={run_int8_kernels}, run_all_contiguous_options={run_all_contiguous_options}, run_bs_1={run_bs_1}")
# Quantiles requested from triton.testing.do_bench; the first one (median)
# is what gets recorded.
quantiles = [0.5, 0.2, 0.8]
# result[(M, K, N)][model name][(x_noncontiguous, w_noncontiguous)] -> time
result = {}
def _median_ms(fn):
    """Call `fn` once so compilation/autotuning happens up front, then time it.

    Returns the first entry of `triton.testing.do_bench(fn, quantiles=quantiles)`,
    which corresponds to the first value of the module-level `quantiles` list.
    """
    fn()
    torch.cuda.synchronize()
    return triton.testing.do_bench(fn, quantiles=quantiles)[0]


# Benchmark every enabled kernel on square problems (N == K) for each
# requested batch size and memory layout.  A failing kernel is reported and
# skipped so the remaining measurements still run; only `Exception` is
# caught, so Ctrl-C still aborts the sweep.
for K in [2**6, 2**8, 2**10, 2**12]:
    N = K
    for M in [K, 1] if run_bs_1 else [K]:
        shape = (M, K, N)
        result[shape] = {
            "cublas linear": {},
            "int8 linear": {},
            "int8 wo lin bs1": {},
            "uint4x2 linear": {},
            "int4 tinygemm": {},
        }
        timings = result[shape]
        for x_noncontiguous in [0, 1] if run_all_contiguous_options else [0]:
            for w_noncontiguous in [0, 1] if run_all_contiguous_options else [1]:
                layout = (x_noncontiguous, w_noncontiguous)
                print(f">>shape={(M,K,N)}, x_noncontiguous={x_noncontiguous}, w_noncontiguous={w_noncontiguous}")
                x = torch.randn(M, K).to('cuda').to(torch.bfloat16)
                w_bf16 = torch.randn(K, N, dtype=torch.bfloat16).cuda()
                bias = torch.randn(N, dtype=torch.bfloat16).cuda()
                # `.t().contiguous().t()` keeps the logical shape but stores
                # the data transposed in memory.
                if x_noncontiguous:
                    x = x.t().contiguous().t()
                if w_noncontiguous:
                    w_bf16 = w_bf16.t().contiguous().t()

                print("running cublas linear (ideally want something faster than this)")
                try:
                    # NOTE(review): F.linear treats its weight as
                    # (out_features, in_features), so this computes
                    # x @ w_bf16.T; it only type-checks because N == K.
                    timings["cublas linear"][layout] = _median_ms(
                        lambda: torch.nn.functional.linear(x, w_bf16, bias)
                    )
                except Exception:
                    print("err")
                torch.cuda.synchronize()

                if run_tinygemm:
                    print("running int4 tinygemm (fast for bs=1, slow for bs>1)")
                    try:
                        import torchao
                        # find feasible inner_k_tiles and groupsize for tinygemm kernel
                        inner_k_tiles = min(K // 32, 8)  # [2, 4, 6 or 8] for inner_k_tiles 2
                        groupsize = min(128, K // inner_k_tiles)  # [32, 64, 128, 256] and needs K%(groupsize*inner_k_tiles)==0
                        w_int4, scales_and_zeros = torchao.quantization.groupwise_affine_quantize_tensor(w_bf16, 4, groupsize)
                        w_int4pack = aten._convert_weight_to_int4pack(w_int4.contiguous(), inner_k_tiles)
                        del w_int4
                        timings["int4 tinygemm"][layout] = triton.testing.do_bench(
                            lambda: aten._weight_int4pack_mm(x.contiguous(), w_int4pack, groupsize, scales_and_zeros),
                            quantiles=quantiles,
                        )[0]
                        del scales_and_zeros, w_int4pack
                    except Exception:
                        print("err")

                # Free the bf16 weight before allocating the quantized ones.
                del w_bf16
                w_int8 = torch.randint(-128, 127, (K, N), dtype=torch.int8).cuda()
                if w_noncontiguous:
                    w_int8 = w_int8.t().contiguous().t()
                scale = torch.randn(N, dtype=torch.bfloat16).cuda()
                torch._dynamo.reset()

                if run_int8_kernels:
                    print("running int8 wo lin bs1 (fast but limited)")
                    try:
                        # The wrapper asserts M == 1, so larger batches land
                        # in the except branch below.
                        timings["int8 wo lin bs1"][layout] = _median_ms(
                            lambda: int8_weight_only_linear_bs_1(x, w_int8, bias, scale)
                        )
                    except Exception:
                        print("err" if M==1 else "not supported for bs>1")
                    torch.cuda.synchronize()

                if run_int8_kernels:
                    print("running int8 linear (slow but general, ideally this could be optimized to run faster than cublas for bs>1)")
                    try:
                        timings["int8 linear"][layout] = _median_ms(
                            lambda: int8_weight_only_linear(x, w_int8, bias, scale)
                        )
                    except Exception:
                        print("err")
                    torch.cuda.synchronize()

                del w_int8
                w_uint4x2 = torch.randint(0, 255, (K//2, N), dtype=torch.uint8).cuda()
                if w_noncontiguous:
                    w_uint4x2 = w_uint4x2.t().contiguous().t()

                if run_uint4x2:
                    print("running uint4x2 linear (slow, ideally this could be optimized to run faster than cublas for bs>1)")
                    try:
                        assert w_noncontiguous==1
                        timings["uint4x2 linear"][layout] = _median_ms(
                            lambda: uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)
                        )
                    except Exception:
                        print("err" if w_noncontiguous==1 else "not supported when w_noncontiguous==0")
                    torch.cuda.synchronize()

                del w_uint4x2, scale, bias
# Print one table per benchmarked shape.  Columns are the four
# (x_noncontiguous, w_noncontiguous) layouts; layouts that were not measured
# print as 0.  For the Triton kernels the last column shows the config found
# in the kernel's autotune cache for that shape.
caches = {
    "int8 linear": int8_weight_only_linear_kernel.cache,
    "uint4x2 linear": uint4x2_weight_only_linear_kernel.cache,
    "int8 wo lin bs1": int8_weight_only_linear_bs_1_kernel.cache,
}
for (M, K, N), per_model in result.items():
    print(f"shape=({M},{K})x({K},{N})")
    print("| X . W | X . Wt | Xt . W | Xt .Wt | model | config")
    for name, r in per_model.items():
        if not r:
            # Nothing was measured for this model.
            continue
        used_config = None
        if name in caches:
            # Cache keys start with the autotune key values (M, K, N).
            # The loop variable is deliberately not called `config`, which
            # would rebind the `torch._inductor.config` import.
            for key, cfg in caches[name].items():
                if key[0] == M and key[1] == K and key[2] == N:
                    used_config = cfg
                    break
        print(f"| {r.get((0, 0), 0):2.4f} | {r.get((0, 1), 0):2.4f} | {r.get((1, 0), 0):2.4f} | {r.get((1, 1), 0):2.4f} | {name:<15} | {used_config}")
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly | |
# install: pip install torchao
# using torch compiled from 2.1.0a0+git9c2122d or nightly | |
# using cuda 12.0 on A100 GPU | |
""" | |
shape=(64,64)x(64,64) | |
| X . W | X . Wt | Xt . W | Xt .Wt | model | config | |
| 0.0102 | 0.0096 | 0.0108 | 0.0108 | cublas linear | None | |
| 0.1055 | 0.0998 | 0.1144 | 0.0963 | int8 linear | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, enable_warp_specialization: False, enable_persistent: False | |
shape=(256,256)x(256,256) | |
| X . W | X . Wt | Xt . W | Xt .Wt | model | config | |
| 0.0112 | 0.0126 | 0.0136 | 0.0136 | cublas linear | None | |
| 0.0992 | 0.0964 | 0.1205 | 0.1067 | int8 linear | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 6, enable_warp_specialization: False, enable_persistent: False | |
shape=(1024,1024)x(1024,1024) | |
| X . W | X . Wt | Xt . W | Xt .Wt | model | config | |
| 0.0303 | 0.0303 | 0.0297 | 0.0292 | cublas linear | None | |
| 0.1036 | 0.1069 | 0.1110 | 0.1050 | int8 linear | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, enable_warp_specialization: False, enable_persistent: False | |
shape=(4096,4096)x(4096,4096) | |
| X . W | X . Wt | Xt . W | Xt .Wt | model | config | |
| 0.6448 | 0.6414 | 0.6423 | 0.6402 | cublas linear | None | |
| 1.2970 | 1.1251 | 1.8512 | 1.4334 | int8 linear | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, enable_warp_specialization: False, enable_persistent: False | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment