@HDCharles
Created August 14, 2023 16:21
microbenchmarks for mixed dtype kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.ops.matmul import matmul as triton_matmul
from triton.ops.matmul import _kernel
from triton import Config
import nvtx
import time

def get_configs_io_bound():
    configs = []
    for num_stages in [3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': 8},
                               num_stages=num_stages, num_warps=num_warps))
    return configs
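
# The sweep above yields 4 * 2 * 2 * 3 = 48 small-block, IO-bound-oriented configs that are
# appended to the hand-picked list below before autotuning.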
config_list = [
    Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
] + get_configs_io_bound()

@triton.autotune(
    configs=config_list,
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, bias_ptr, scale_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_bias,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing C = (A @ B) * scale + bias.
    A (bf16) has shape (M, K), B (fp16 in this benchmark) has shape (K, N) and C (bf16) has
    shape (M, N); scale is a single scalar and bias is per output column.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    bias_ptrs = bias_ptr + (offs_bn * stride_bias)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension; the B block is upcast to bf16 for the dot.
        accumulator += tl.dot(a, b.to(tl.bfloat16))
        # Advance the pointers to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    scale = tl.load(scale_ptr)
    bias = tl.load(bias_ptrs)
    c = accumulator.to(tl.bfloat16) * scale + bias
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b, bias, scale):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    # assert a.is_contiguous(), "Matrix A must be contiguous"
    # assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, bias, scale, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        bias.stride(0),
        c.stride(0), c.stride(1),
    )
    return c
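
# Illustrative usage sketch (defined but never called): the wrapper above expects bf16
# activations and an fp16 weight on CUDA; the kernel upcasts the weight to bf16 inside the
# inner loop. Note that the kernel reads only a single scalar from `scale`
# (`tl.load(scale_ptr)`), even though the benchmark below allocates a length-N tensor.
# The shapes and the helper name `_example_matmul_usage` are illustrative, not part of the benchmark.
def _example_matmul_usage():
    M, K, N = 32, 4096, 4096
    a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(K, N, dtype=torch.float16, device='cuda')
    bias = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    scale = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    return matmul(a, b, bias, scale)  # bf16 tensor of shape (M, N)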

@triton.autotune(
    configs=config_list,
    key=['M', 'N', 'K'],
)
@triton.jit
def int8_weight_only_linear_kernel(
        # Pointers to matrices
        x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        stride_b,
        stride_ym, stride_yn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing Y = (X @ W) * s + b.
    X (bf16) has shape (M, K), W (int8) has shape (K, N) and Y (bf16) has shape (M, N);
    s is a single scalar scale and b is a per-column bias.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of Y it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_w = BLOCK_SIZE_K * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension; the int8 weight block is upcast to bf16 for the dot.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the pointers to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s + b)
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)

def int8_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    K, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    int8_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
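
# Compared with the bf16/fp16 paths, the int8 weight halves the bytes read per weight
# element, which is presumably why this kernel comes out ahead of the cublas bf16 baseline
# at the small-M shapes in the results below; the int8 -> bf16 convert inside the inner
# loop is the overhead being paid for that saving.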

@triton.jit
def unpack(x, s, a, m):
    return ((x >> s) & a) + m
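
# CPU-side mirror of unpack() above, kept as a reference sketch (it is not called anywhere):
# for even k the kernel takes the low nibble and adds 8 (shift 0, mask 0xF, add 8); for odd k
# it takes the high nibble and adds 16 (shift 4, mask 0xF, add 16), matching the
# w_shifts / w_and / w_subs tensors built in the int4x2 kernel below. Whether that convention
# agrees with the packing used in the accuracy test further down is what the `assert False`
# section is probing. The helper name `_unpack_int4x2_reference` is illustrative only.
def _unpack_int4x2_reference(w_packed):
    low = ((w_packed >> 0) & 0xF) + 8    # values the kernel uses at even k indices
    high = ((w_packed >> 4) & 0xF) + 16  # values the kernel uses at odd k indices
    return low, high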

@triton.autotune(
    configs=config_list,
    key=['M', 'N', 'K'],
)
@triton.jit
def int4x2_weight_only_linear_kernel(
        # Pointers to matrices
        x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
        # Matrix dimensions
        M, N, K,  # x is (M, K) and w is (K//2, N): two int4 values are packed into each int8 along K
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        stride_b,
        stride_ym, stride_yn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing Y = (X @ unpack(W)) * s + b.
    X (bf16) has shape (M, K), W (packed int4x2) has shape (K//2, N) and Y (bf16) has shape (M, N);
    s is a single scalar scale and b is a per-column bias.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of Y it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers; consecutive k indices map
    # to the same packed int8 (offs_k // 2) and differ only in which nibble is used.
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] // 2 * stride_wk + offs_wn[None, :] * stride_wn)
    w_shifts = ((offs_k % 2) * 4).to(tl.int8)[:, None]    # 0 for the low nibble, 4 for the high nibble
    w_and = (tl.arange(0, 1) + 0xF).to(tl.int8)[:, None]  # 0xF nibble mask as a [1, 1] int8 block
    w_subs = ((1 + offs_k % 2) * 8).to(tl.int8)[:, None]  # offset applied after masking (8 / 16)
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_w = BLOCK_SIZE_K // 2 * stride_wk
    step_x = BLOCK_SIZE_K * stride_xk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to bf16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of X and W, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        w = unpack(w, w_shifts, w_and, w_subs)
        # We accumulate along the K dimension; the unpacked int4 block is upcast to bf16 for the dot.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the pointers to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator.to(tl.bfloat16) * s) + b
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)

def int4x2_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0] * 2, "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    N = w.shape[1]
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    int4x2_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
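
# Note: relative to the int8 path this kernel again halves the weight bytes read, but it pays
# extra shift/mask/add work per element to unpack the nibbles inside the inner loop. In the
# measurements below that unpack overhead outweighs the bandwidth saving, so int4x2 lands
# behind int8 (and behind the cublas bf16 baseline) on every shape tried here.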

#### testing accuracy
M, K, N = 4, 4, 4
x = torch.randn(M, K, dtype=torch.bfloat16).to('cuda')
w_int4 = torch.randint(-8, 7, (K, N), dtype=torch.int8).cuda()
# pack pairs of int4 values along K: the low nibble stores the even rows as value + 8,
# the high nibble stores the odd rows (matching the unpacking used to build `ref` below)
w_int4x2 = (w_int4[::2] + 8) + (w_int4[1::2] << 4)
# w_int8 = torch.randint(-128, 127, (K//2, N), dtype=torch.int8).cuda()
scale = torch.randn(N, dtype=torch.bfloat16).to('cuda')
bias = torch.randn(N, dtype=torch.bfloat16).to('cuda')
base = torch.mm(x, w_int4.to(torch.bfloat16))
ref = torch.mm(x, torch.cat(((w_int4x2 & 0xF) - 8, w_int4x2 >> 4), 1).reshape(-1, w_int4x2.shape[1]).to(x.dtype))
test = int4x2_weight_only_linear(x, w_int4x2, bias, scale)  # unlike base/ref, this also applies scale and bias
print(base - ref)
print(base - test)
print(ref - test)
assert False  # stop after the accuracy check; remove this line to run the perf sweep below

#### testing perf
quantiles = [0.5, 0.2, 0.8]  # quantiles reported by triton.testing.do_bench; [0] below picks the median
result = {}
w_shapes = [
    [4096, 4096],
    [4096, 12288],
    [4096, 22016],
    [11008, 4096],
]
x_shapes = [32, 64, 192, 192 * 32, 192 * 256]
shapes = [[x] + w for x in x_shapes for w in w_shapes]
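# Each benchmark problem is a full (M, K, N) size: M comes from x_shapes and (K, N) from w_shapes.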
for shape in shapes:
    M, K, N = shape
    print((M, K, N))
    result[(M, K, N)] = {}
    result[(M, K, N)]["cublas bf16 linear"] = {}
    result[(M, K, N)]["fp16 to bf16 matmul"] = {}
    result[(M, K, N)]["int8 to bf16 linear"] = {}
    result[(M, K, N)]["int4x2 to bf16 linear"] = {}
    for t_x in [0]:
        for t_w in [1]:
            x = torch.randn(M, K).to('cuda').to(torch.bfloat16)
            w_bf16 = torch.randn(N, K, dtype=torch.bfloat16).cuda()
            bias = torch.randn(N, dtype=torch.bfloat16).cuda()
            if t_x:
                x = x.t().contiguous().t()
            if t_w:
                w_bf16 = w_bf16.t().contiguous().t()
            torch.nn.functional.linear(x, w_bf16, bias)
            torch.cuda.synchronize()
            result[(M, K, N)]["cublas bf16 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: torch.nn.functional.linear(x, w_bf16, bias), quantiles=quantiles)[0]
            print(result[(M, K, N)]["cublas bf16 linear"][(t_x, t_w)])
            del w_bf16
            w_fp16 = torch.randn(K, N, dtype=torch.float16).cuda()
            scale = torch.randn(N, dtype=torch.bfloat16).cuda()
            matmul(x, w_fp16, bias, scale)
            torch.cuda.synchronize()
            result[(M, K, N)]["fp16 to bf16 matmul"][(t_x, t_w)] = triton.testing.do_bench(lambda: matmul(x, w_fp16, bias, scale), quantiles=quantiles)[0]
            print(result[(M, K, N)]["fp16 to bf16 matmul"][(t_x, t_w)])
            del w_fp16
            w_int8 = torch.randint(-128, 127, (K, N), dtype=torch.int8).cuda()
            if t_w:
                w_int8 = w_int8.t().contiguous().t()
            int8_weight_only_linear(x, w_int8, bias, scale)
            torch.cuda.synchronize()
            result[(M, K, N)]["int8 to bf16 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: int8_weight_only_linear(x, w_int8, bias, scale), quantiles=quantiles)[0]
            print(result[(M, K, N)]["int8 to bf16 linear"][(t_x, t_w)])
            del w_int8
            w_int4x2 = torch.randint(-128, 127, (K // 2, N), dtype=torch.int8).cuda()
            if t_w:
                w_int4x2 = w_int4x2.t().contiguous().t()
            int4x2_weight_only_linear(x, w_int4x2, bias, scale)
            torch.cuda.synchronize()
            result[(M, K, N)]["int4x2 to bf16 linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: int4x2_weight_only_linear(x, w_int4x2, bias, scale), quantiles=quantiles)[0]
            print(result[(M, K, N)]["int4x2 to bf16 linear"][(t_x, t_w)])
            del w_int4x2, x, bias, scale

caches = {
    "fp16 to bf16 matmul": matmul_kernel.cache,
    "int8 to bf16 linear": int8_weight_only_linear_kernel.cache,
    "int4x2 to bf16 linear": int4x2_weight_only_linear_kernel.cache,
}
print(matmul_kernel.cache)
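# Column notes: `time` is the do_bench median in milliseconds, so the `tops/s` column is
# 2 * M * K * N / time_ms, i.e. FLOPs per millisecond (multiply by 1e3 for FLOP/s); `config`
# is the autotuned Triton config, or None for the cublas baseline.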
print("| time | tops/s | model | (M, K, N) | config |")
for shape in result.keys():
    cache_key = (shape[0], shape[2], shape[1])
    for name in result[shape].keys():
        r = result[shape][name]
        s = "|"
        for key in r:
            s += f"{r[key]:.4f} |"
            s += f"{2*shape[0]*shape[2]*shape[1]/r[key]:.2e} | {name} | {shape} | {None if name == 'cublas bf16 linear' else caches[name][cache_key]} |"
        print(s)
    print(" ")
# using triton version triton-nightly 2.1.0.dev20230726014945
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# using torch compiled from 2.1.0a0+git9c2122d
# using cuda 12.1, cudnn 8.9.2 on A100 GPU
# ---------- OUTPUT --------------
# | time | tops/s | model | (M, K, N) | config |
# |0.0891 |3.62e+10 | cublas bf16 linear | (32, 4096, 12288) | None |
# |0.1219 |2.64e+10 | fp16 to bf16 matmul | (32, 4096, 12288) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 3 |
# |0.0717 |4.49e+10 | int8 to bf16 linear | (32, 4096, 12288) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 6 |
# |0.1556 |2.07e+10 | int4x2 to bf16 linear | (32, 4096, 12288) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 6 |
# |0.0481 |2.23e+10 | cublas bf16 linear | (32, 4096, 4096) | None |
# |0.0942 |1.14e+10 | fp16 to bf16 matmul | (32, 4096, 4096) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.0512 |2.10e+10 | int8 to bf16 linear | (32, 4096, 4096) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.1116 |9.62e+09 | int4x2 to bf16 linear | (32, 4096, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.1434 |4.03e+10 | cublas bf16 linear | (32, 4096, 22016) | None |
# |0.1761 |3.28e+10 | fp16 to bf16 matmul | (32, 4096, 22016) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 3 |
# |0.1116 |5.17e+10 | int8 to bf16 linear | (32, 4096, 22016) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.2345 |2.46e+10 | int4x2 to bf16 linear | (32, 4096, 22016) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 3 |
# |0.0901 |3.20e+10 | cublas bf16 linear | (32, 11008, 4096) | None |
# |0.2222 |1.30e+10 | fp16 to bf16 matmul | (32, 11008, 4096) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.1178 |2.45e+10 | int8 to bf16 linear | (32, 11008, 4096) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.2970 |9.72e+09 | int4x2 to bf16 linear | (32, 11008, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.0973 |6.62e+10 | cublas bf16 linear | (64, 4096, 12288) | None |
# |0.1556 |4.14e+10 | fp16 to bf16 matmul | (64, 4096, 12288) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.0983 |6.55e+10 | int8 to bf16 linear | (64, 4096, 12288) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 6 |
# |0.1741 |3.70e+10 | int4x2 to bf16 linear | (64, 4096, 12288) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
# |0.0481 |4.46e+10 | cublas bf16 linear | (64, 4096, 4096) | None |
# |0.1044 |2.06e+10 | fp16 to bf16 matmul | (64, 4096, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 3 |
# |0.0635 |3.38e+10 | int8 to bf16 linear | (64, 4096, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.1229 |1.75e+10 | int4x2 to bf16 linear | (64, 4096, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.1434 |8.05e+10 | cublas bf16 linear | (64, 4096, 22016) | None |
# |0.2017 |5.72e+10 | fp16 to bf16 matmul | (64, 4096, 22016) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 3 |
# |0.1669 |6.92e+10 | int8 to bf16 linear | (64, 4096, 22016) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.2642 |4.37e+10 | int4x2 to bf16 linear | (64, 4096, 22016) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
# |0.0922 |6.26e+10 | cublas bf16 linear | (64, 11008, 4096) | None |
# |0.2478 |2.33e+10 | fp16 to bf16 matmul | (64, 11008, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
# |0.1516 |3.81e+10 | int8 to bf16 linear | (64, 11008, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.3082 |1.87e+10 | int4x2 to bf16 linear | (64, 11008, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.1382 |1.40e+11 | cublas bf16 linear | (192, 4096, 12288) | None |
# |0.2161 |8.95e+10 | fp16 to bf16 matmul | (192, 4096, 12288) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.2222 |8.70e+10 | int8 to bf16 linear | (192, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.2621 |7.37e+10 | int4x2 to bf16 linear | (192, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.0573 |1.12e+11 | cublas bf16 linear | (192, 4096, 4096) | None |
# |0.1413 |4.56e+10 | fp16 to bf16 matmul | (192, 4096, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.0963 |6.69e+10 | int8 to bf16 linear | (192, 4096, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.1556 |4.14e+10 | int4x2 to bf16 linear | (192, 4096, 4096) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
# |0.2621 |1.32e+11 | cublas bf16 linear | (192, 4096, 22016) | None |
# |0.4065 |8.52e+10 | fp16 to bf16 matmul | (192, 4096, 22016) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.3963 |8.74e+10 | int8 to bf16 linear | (192, 4096, 22016) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.5038 |6.87e+10 | int4x2 to bf16 linear | (192, 4096, 22016) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.1741 |9.95e+10 | cublas bf16 linear | (192, 11008, 4096) | None |
# |0.3523 |4.92e+10 | fp16 to bf16 matmul | (192, 11008, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |0.2365 |7.32e+10 | int8 to bf16 linear | (192, 11008, 4096) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 6 |
# |0.4024 |4.30e+10 | int4x2 to bf16 linear | (192, 11008, 4096) | BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
# |2.7740 |2.23e+11 | cublas bf16 linear | (6144, 4096, 12288) | None |
# |4.7012 |1.32e+11 | fp16 to bf16 matmul | (6144, 4096, 12288) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |4.4503 |1.39e+11 | int8 to bf16 linear | (6144, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |5.2439 |1.18e+11 | int4x2 to bf16 linear | (6144, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |1.0148 |2.03e+11 | cublas bf16 linear | (6144, 4096, 4096) | None |
# |1.6502 |1.25e+11 | fp16 to bf16 matmul | (6144, 4096, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |1.5913 |1.30e+11 | int8 to bf16 linear | (6144, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |1.8995 |1.09e+11 | int4x2 to bf16 linear | (6144, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |4.9080 |2.26e+11 | cublas bf16 linear | (6144, 4096, 22016) | None |
# |8.3671 |1.32e+11 | fp16 to bf16 matmul | (6144, 4096, 22016) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |7.9165 |1.40e+11 | int8 to bf16 linear | (6144, 4096, 22016) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |9.3317 |1.19e+11 | int4x2 to bf16 linear | (6144, 4096, 22016) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |2.5856 |2.14e+11 | cublas bf16 linear | (6144, 11008, 4096) | None |
# |4.3668 |1.27e+11 | fp16 to bf16 matmul | (6144, 11008, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |4.1595 |1.33e+11 | int8 to bf16 linear | (6144, 11008, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |5.1236 |1.08e+11 | int4x2 to bf16 linear | (6144, 11008, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |21.5255 |2.30e+11 | cublas bf16 linear | (49152, 4096, 12288) | None |
# |36.8973 |1.34e+11 | fp16 to bf16 matmul | (49152, 4096, 12288) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |34.9880 |1.41e+11 | int8 to bf16 linear | (49152, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |41.0977 |1.20e+11 | int4x2 to bf16 linear | (49152, 4096, 12288) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |7.2008 |2.29e+11 | cublas bf16 linear | (49152, 4096, 4096) | None |
# |12.3955 |1.33e+11 | fp16 to bf16 matmul | (49152, 4096, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |11.7064 |1.41e+11 | int8 to bf16 linear | (49152, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |13.7636 |1.20e+11 | int4x2 to bf16 linear | (49152, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |38.5004 |2.30e+11 | cublas bf16 linear | (49152, 4096, 22016) | None |
# |66.2395 |1.34e+11 | fp16 to bf16 matmul | (49152, 4096, 22016) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |62.6012 |1.42e+11 | int8 to bf16 linear | (49152, 4096, 22016) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |73.5396 |1.21e+11 | int4x2 to bf16 linear | (49152, 4096, 22016) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |18.4074 |2.41e+11 | cublas bf16 linear | (49152, 11008, 4096) | None |
# |34.4018 |1.29e+11 | fp16 to bf16 matmul | (49152, 11008, 4096) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |30.6647 |1.45e+11 | int8 to bf16 linear | (49152, 11008, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
# |38.0856 |1.16e+11 | int4x2 to bf16 linear | (49152, 11008, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |