HQQ Tinygemm vs BitBlas Benchmark
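# Benchmarks HQQ-quantized weights through several matmul paths on a CUDA GPU:
# tinygemm (aten._weight_int4pack_mm), unpack-dequantize + matmul (torch.compile and
# torchao fused variants), and BitBLAS 4-bit / 2-bit kernels, across a range of batch sizes.
# Assumes hqq, torchao, and bitblas builds that expose the ops imported below.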
import torch
import numpy as np
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer, HQQBackend
from hqq.backends.torchao import HQQLinearTorchWeightOnlynt4, patch_hqq_to_aoint4
# from unpack_int4.ops import unpack_int4_packed
import torchao
import bitblas
# unpack_cuda_compiled = torch.compile(torchao.ops.unpack_int4_to_int, mode="default", fullgraph=True)
from bitblas.cache import global_operator_cache, get_database_path
from bitblas.module import BITBLAS_TARGET, BITBLAS_DATABASE_PATH
def _get_or_create_bitblas_operator(config):
    if global_operator_cache.size() == 0:
        global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)

    bitblas_matmul = global_operator_cache.get(config)
    if bitblas_matmul is None:
        # Not in the cache: build the operator, tune it for this hardware, and persist it to the database.
        bitblas_matmul = bitblas.Matmul(config)
        bitblas_matmul.hardware_aware_finetune(topk=20)
        global_operator_cache.add(config, bitblas_matmul)
        global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
        print("BitBLAS Tuning done, appended operator to global_operator_cache.")
    else:
        print("BitBLAS Operator found in global_operator_cache.")
    return bitblas_matmul
def timed(fn):
    """Run fn() once and return (result, elapsed seconds) measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
# @torch.compile(fullgraph=False)
def hqq_quants_to_torch_quants(W_q, scales, zeros, shape, nbits=4):
    """Convert HQQ quantized weights and meta into tinygemm's (W_q, scales, zeros) convention."""
    # W_q = W_q.to(dtype=self.compute_dtype, device=self.device)
    # scales = scales.to(dtype=self.compute_dtype, device=self.device)
    # zeros = zeros.to(dtype=self.compute_dtype, device=self.device)
    max_int = 2**nbits - 1
    min_int = 0
    dump = 2 ** (nbits - 1)

    # HQQ -> torch logic
    new_zeros = (scales * dump) - zeros * scales
    min_val = new_zeros - scales * dump

    # group_quantize_tensor_from_qparams
    W_r = (W_q - zeros) * scales
    W_q = (
        W_r.sub(min_val)
        .div(scales)
        .round()
        .clamp_(min_int, max_int)
        .to(torch.int32)
        .reshape(shape)
        .contiguous()
    )

    # group_dequantize_tensor_from_qparams
    # W_r = W_q*scales + min_val
    scales = scales.contiguous().reshape(shape[0], -1)
    new_zeros = new_zeros.contiguous().reshape(shape[0], -1)
    return W_q, scales, new_zeros
def pack_scales_and_zeros(scales, zeros):
    # Stack scales and zeros into the (n_groups, out_features, 2) layout used by tinygemm.
    return (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        )
        .transpose(0, 1)
        .contiguous()
    )
def reshape_packed(packed_tensor):
    # Undo the inner-tile layout of _convert_weight_to_int4pack into a flat column of packed int32s.
    inner_k_tiles = packed_tensor.size(-1) * 2
    return packed_tensor.permute(0, 1, 3, 2).reshape(packed_tensor.size(0),
                                                     packed_tensor.size(1) * (inner_k_tiles // 2),
                                                     packed_tensor.size(2),
                                                     1).contiguous()


def _unpack_shifting(packed_tensor):
    # Extract the eight 4-bit nibbles from each packed int32.
    return [(packed_tensor >> (i * 4)) & 15 for i in range(8)]
@torch.compile(fullgraph=True)
def unpack_int4_32_pack_fast(packed_tensor, shape):
    """Unpack a tensor produced by aten._convert_weight_to_int4pack back to int32 values in [0, 15]."""
    reshaped_tensor = reshape_packed(packed_tensor)
    unpacked_tensors = _unpack_shifting(reshaped_tensor)

    # use torch.cat
    cat_tensors = [torch.cat(unpacked_tensors[i::4], dim=-1).view(-1, 8) for i in range(4)]
    concatenated = torch.cat(cat_tensors, dim=-1)

    # # pre-allocate
    # concatenated = torch.empty(shape[0]*shape[1]//32, 32, device=reshaped_tensor.device, dtype=reshaped_tensor.dtype)
    # for i in range(4):
    #     concatenated[:, i*8:(i+1)*8] = torch.cat(unpacked_tensors[i::4], dim=-1).view(-1, 8)

    group_size = shape[1] // 32
    chunked_o = concatenated.view(-1, 8).unsqueeze(0).view(concatenated.size(0) // 8, 8, -1).unsqueeze(0)
    res = chunked_o.view(-1, group_size, chunked_o.size(2), chunked_o.size(3)).permute(0, 2, 1, 3).reshape(shape)
    return res
def unpack_scales_and_zeros(scales_and_zeros):
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    # assert scales_and_zeros.dtype == torch.float
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128):
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    w_int32_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    w_dq = (
        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
    )
    return w_dq
@torch.compile(mode="default", fullgraph=True)
def tinygemm_unpack_dequant_matmul_naive(x, weight_int4pack, scales_and_zeros, groupsize, shape):
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack, shape)
return x @ group_dequantize_tensor_from_qparams(unpacked_W_q, *unpack_scales_and_zeros(scales_and_zeros), groupsize=groupsize).T
# @torch.compile(fullgraph=True)
def tinygemm_unpack_dequant_matmul(x, weight_int4pack, scales_and_zeros, groupsize, shape):
inner_k_tiles = weight_int4pack.size(-1) * 2
unpacked_W_q = torchao.ops.dequantize_tensor_core_tiled_layout(weight_int4pack, scales_and_zeros, groupsize, inner_k_tiles)
return x @ unpacked_W_q.T
# Sanity check: unpack_int4_32_pack_fast should invert _convert_weight_to_int4pack for each inner_k_tiles setting.
W_q_torch = torch.randint(0, 16, (8192, 8192), dtype=torch.int32, device="cuda")
weight_int4pack_inner_tile2 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 2)
weight_int4pack_inner_tile4 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 4)
weight_int4pack_inner_tile8 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 8)
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile2, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile4, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile8, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)
GROUP_SIZE = 128
quant_config = BaseQuantizeConfig(nbits=4,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)
in_features = 4096
out_features = 7168
W = torch.randn(out_features, in_features, dtype=torch.bfloat16, device="cuda") # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)
hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.bfloat16, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)
# HQQ to Tinygemm conversion (4-bit).
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']
W_q_torch, scales_torch, zeros_torch = hqq_quants_to_torch_quants(W_q_unpacked, scale, zero, shape)
scales_and_zeros = pack_scales_and_zeros(scales_torch, zeros_torch)
weight_int4pack_inner_tile8 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 8)
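# Optional numerical sanity check (a minimal sketch, not part of the original benchmark):
# compare the tinygemm int4 kernel against an explicit dequantize-then-matmul reference
# built from the same quantized weights. Exact equality is not expected because the fused
# kernel accumulates differently, so only the max abs difference is printed.
x_check = torch.randn(8, in_features, dtype=torch.bfloat16, device="cuda")
ref_W = group_dequantize_tensor_from_qparams(
    W_q_torch, *unpack_scales_and_zeros(scales_and_zeros), groupsize=GROUP_SIZE
)
out_kernel = torch.ops.aten._weight_int4pack_mm(
    x_check, weight_int4pack_inner_tile8, GROUP_SIZE, scales_and_zeros
)
out_ref = x_check @ ref_W.T
print("tinygemm vs dequant reference, max abs diff:", (out_kernel - out_ref).abs().max().item())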
INPUT_SIZES = [4,8,16,32,64,80,96,128,256,512,1024]
BITBLAS_OPT_M = [1, 16, 32, 64, 128, 256, 512]
# BITBLAS_OPT_M = [1]
# HQQ to bitblas conversion (4-bit).
quant_config = BaseQuantizeConfig(nbits=4,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)
W = torch.randn(out_features, in_features, dtype=torch.half, device="cuda") # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)
hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.half, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']
matmul_config = bitblas.MatmulConfig(
    M=BITBLAS_OPT_M,
    N=out_features,
    K=in_features,
    A_dtype="float16",
    W_dtype="uint4",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=GROUP_SIZE,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    # fast_decoding=True,
)
matmul_eng_4bit = _get_or_create_bitblas_operator(matmul_config)
Wq_bitblas_4bit = matmul_eng_4bit.transform_weight(W_q_unpacked.reshape(shape))
meta_shape_bitblas = (hqq_linear.out_features, hqq_linear.in_features // GROUP_SIZE)
scales_bitblas_4bit = scale.view(meta_shape_bitblas)
zeros_bitblas_4bit = zero.view(meta_shape_bitblas)
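# Optional sanity check for the BitBLAS path (a sketch; assumes this hqq version exposes
# HQQLinear.dequantize(), returning the dequantized fp16 weight of shape `shape`).
# As above, only the max abs difference against a plain fp16 matmul is printed.
x_check = torch.randn(8, in_features, dtype=torch.half, device="cuda")
W_dq = hqq_linear.dequantize().reshape(shape)
out_bitblas = matmul_eng_4bit(x_check, Wq_bitblas_4bit,
                              scale=scales_bitblas_4bit, zeros=zeros_bitblas_4bit)
out_ref = x_check @ W_dq.T
print("bitblas 4-bit vs dequant reference, max abs diff:", (out_bitblas - out_ref).abs().max().item())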
# HQQ to bitblas conversion (2-bit).
quant_config = BaseQuantizeConfig(nbits=2,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)
W = torch.randn(out_features, in_features, dtype=torch.half, device="cuda") # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)
hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.half, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)
matmul_config = bitblas.MatmulConfig(
    M=BITBLAS_OPT_M,
    N=out_features,
    K=in_features,
    A_dtype="float16",
    W_dtype="uint2",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=GROUP_SIZE,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    # fast_decoding=True,
)
matmul_eng_2bit = _get_or_create_bitblas_operator(matmul_config)

# Re-extract the quantized weight and meta from the 2-bit HQQLinear
# (the values extracted above still belong to the 4-bit model).
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']

Wq_bitblas_2bit = matmul_eng_2bit.transform_weight(W_q_unpacked.reshape(shape))
meta_shape_bitblas = (hqq_linear.out_features, hqq_linear.in_features // GROUP_SIZE)
scales_bitblas_2bit = scale.view(meta_shape_bitblas)
zeros_bitblas_2bit = zero.view(meta_shape_bitblas)
for bs in INPUT_SIZES:  # think of it as bs x seqlen
    x = torch.randn(bs, in_features, dtype=torch.bfloat16, device="cuda")
    x_fp16 = torch.randn(bs, in_features, dtype=torch.half, device="cuda")
    print(bs)

    # tinygemm matmul time (ms)
    times = []
    for i in range(30):
        tinygemm_out, time = timed(lambda: torch.ops.aten._weight_int4pack_mm(x,
                                                                              weight_int4pack_inner_tile8,
                                                                              GROUP_SIZE,
                                                                              scales_and_zeros))
        if i > 5:  # skip warm-up iterations
            times.append(time * 1000)
    print(f"tinygemm orig matmul: {np.mean(times)}")

    # tinygemm unpack-dequant-matmul time (ms)
    times = []
    for i in range(30):
        unpacked_tinygemm_out, time = timed(lambda: tinygemm_unpack_dequant_matmul(x,
                                                                                   weight_int4pack_inner_tile8,
                                                                                   scales_and_zeros,
                                                                                   groupsize=GROUP_SIZE,
                                                                                   shape=shape))
        if i > 5:
            times.append(time * 1000)
    print(f"tinygemm fused unpack-dequant-matmul: {np.mean(times)}")

    # bitblas 4-bit matmul time (ms)
    times = []
    for i in range(30):
        bitblas_out, time = timed(lambda: matmul_eng_4bit(x_fp16,
                                                          Wq_bitblas_4bit,
                                                          scale=scales_bitblas_4bit,
                                                          zeros=zeros_bitblas_4bit))
        if i > 5:
            times.append(time * 1000)
    print(f"bitblas 4-bit matmul: {np.mean(times)}")

    # bitblas 2-bit matmul time (ms)
    times = []
    for i in range(30):
        bitblas_out, time = timed(lambda: matmul_eng_2bit(x_fp16,
                                                          Wq_bitblas_2bit,
                                                          scale=scales_bitblas_2bit,
                                                          zeros=zeros_bitblas_2bit))
        if i > 5:
            times.append(time * 1000)
    print(f"bitblas 2-bit matmul: {np.mean(times)}")
# Logged output (times in ms):
# BitBLAS Tuning done, appended operator to global_operator_cache.
# BitBLAS Tuning done, appended operator to global_operator_cache.
# 4
# tinygemm orig matmul: 0.038186667331804834
# tinygemm fused unpack-dequant-matmul: 0.27805466825763386
# bitblas 4-bit matmul: 0.0710400016978383
# bitblas 2-bit matmul: 0.10022266364345948
# 8
# tinygemm orig matmul: 0.044628000197311245
# tinygemm fused unpack-dequant-matmul: 0.2707599997520447
# bitblas 4-bit matmul: 0.06920533441007137
# bitblas 2-bit matmul: 0.10069333016872406
# 16
# tinygemm orig matmul: 0.07483733631670475
# tinygemm fused unpack-dequant-matmul: 0.2706800041099389
# bitblas 4-bit matmul: 0.07005466800183058
# bitblas 2-bit matmul: 0.10090666357427835
# 32
# tinygemm orig matmul: 0.12526933290064335
# tinygemm fused unpack-dequant-matmul: 0.26960799594720203
# bitblas 4-bit matmul: 0.09228533475349347
# bitblas 2-bit matmul: 0.08145066661139329
# 64
# tinygemm orig matmul: 0.23359866812825203
# tinygemm fused unpack-dequant-matmul: 0.26743199800451595
# bitblas 4-bit matmul: 0.09467333555221558
# bitblas 2-bit matmul: 0.1003906639913718
# 80
# tinygemm orig matmul: 0.28740132972598076
# tinygemm fused unpack-dequant-matmul: 0.2831786771615346
# bitblas 4-bit matmul: 0.13239066861569881
# bitblas 2-bit matmul: 0.13367333138982454
# 96
# tinygemm orig matmul: 0.345684003084898
# tinygemm fused unpack-dequant-matmul: 0.28288133814930916
# bitblas 4-bit matmul: 0.13734399837752184
# bitblas 2-bit matmul: 0.1394786648452282
# 128
# tinygemm orig matmul: 0.45201199998458225
# tinygemm fused unpack-dequant-matmul: 0.2844146713614464
# bitblas 4-bit matmul: 0.15479466691613197
# bitblas 2-bit matmul: 0.1588479975859324
# 256
# tinygemm orig matmul: 0.8864426662524542
# tinygemm fused unpack-dequant-matmul: 0.3223479986190796
# bitblas 4-bit matmul: 0.2660679966211319
# bitblas 2-bit matmul: 0.24575733207166195
# 512
# tinygemm orig matmul: 1.8986239979664485
# tinygemm fused unpack-dequant-matmul: 0.4362240085999171
# bitblas 4-bit matmul: 0.3694506672521432
# bitblas 2-bit matmul: 0.3685973385969798
# 1024
# tinygemm orig matmul: 3.9093759953975677
# tinygemm fused unpack-dequant-matmul: 0.6917120044430097
# bitblas 4-bit matmul: 0.6495146602392197
# bitblas 2-bit matmul: 0.6729813342293104
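# Summary of this run: tinygemm is fastest at batch sizes <= 8, the BitBLAS 4-bit/2-bit kernels
# take over from roughly batch size 16-32 onward, and the fused unpack-dequant-matmul path scales
# much better than tinygemm, overtaking it around batch size 80 and roughly matching BitBLAS at
# the largest batch sizes.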