#pip uninstall torch -y; pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121; #Pytorch nightly
#git clone https://github.com/microsoft/BitBLAS.git && cd BitBLAS && pip install -e . && cd ..; #BitBLAS
#apt-get install libncurses5 -y
#cd BitBLAS/3rdparty/tvm/ && make -j10 && cd ../../..
#num_threads=24; OMP_NUM_THREADS=$num_threads CUDA_VISIBLE_DEVICES=0 ipython3
######################################################################################################################################
#General
######################################################################
import bitblas # import it first, otherwise it complains about some C++ libs
import torch
import time, gc
import numpy as np
device = 'cuda:0'
compute_dtype = torch.float16
def cleanup():
    torch.cuda.empty_cache()
    gc.collect()
def eval_time(fct, tag, params, verbose=False, iters=2000, ref_time=None):
    t = []
    for _ in range(iters):
        t1 = time.time()
        fct(**params)
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2 - t1)
    out = np.round(np.mean(t[-iters//2:]), 8)
    if(verbose):
        ref_str = ""
        if(ref_time is not None):
            x_faster = np.round(ref_time/out, 3)
            ref_str = (' | ' + str(x_faster) + "x faster" + ' ✅') if(x_faster > 1) else '❌'
        print(tag, ':', out, 'sec/iter', ref_str)
    return out
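
# Optional alternative (not in the original gist, a minimal sketch): time with CUDA events instead of
# time.time() + a synchronize per call, which removes most host-side overhead for very short GEMV kernels.
# The name eval_time_cuda_events and its interface are illustrative assumptions.
def eval_time_cuda_events(fct, params, iters=2000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(iters // 2):  # warm-up
        fct(**params)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters // 2):
        fct(**params)
    end.record()
    torch.cuda.synchronize()
    return (start.elapsed_time(end) / 1000.) / (iters // 2)  # elapsed_time() is in ms -> sec/iter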
#Torch
######################################################################
class TorchLinear(torch.nn.Module):
    def __init__(self, weight, bias=None):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        out = torch.matmul(x, self.weight.T)
        if(self.bias is not None):
            out += self.bias
        return out
#BitBLAS
######################################################################
from bitblas import Matmul
class BitBlassLinear(torch.nn.Module):
    def __init__(self, weight, nbits=4, group_size=64, batch_size=1, bias=None, device=device, compute_dtype=compute_dtype):
        super().__init__()

        #In/Out tensors params
        self.compute_dtype = compute_dtype
        self.device = device
        self.dtype_str = str(self.compute_dtype).split('.')[-1]

        #Shapes
        self.batch_size = batch_size
        self.shape = weight.shape
        self.in_features, self.out_features = self.shape[::-1]

        #Bias
        self.bias = bias
        if(self.bias is not None):
            if(type(self.bias) is torch.Tensor):
                self.bias = self.bias.to(dtype=self.compute_dtype, device=self.device)
            if(type(self.bias) is torch.nn.Parameter):
                self.bias.data = self.bias.data.to(dtype=self.compute_dtype, device=self.device)

        #Quant params
        self.group_size = self.in_features if(group_size==-1) else group_size
        self.nbits = nbits
        storage_nbit = 8 # assume int8 storage
        n_float_per_elem = storage_nbit // self.nbits

        matmul_config = bitblas.MatmulConfig(
            M=self.batch_size,
            N=self.out_features,
            K=self.in_features,
            fast_decoding=True,
            A_dtype=self.dtype_str,
            W_dtype=f"uint{self.nbits}",
            accum_dtype=self.dtype_str,
            out_dtype=self.dtype_str,
            layout="nt",
            with_bias=self.bias is not None,
            group_size=self.group_size,
            with_scaling=True,
            with_zeros=True,
            zeros_mode="quantized",
        )

        self.matmul = Matmul(matmul_config)
        self.matmul.hardware_aware_finetune(topk=20)

        #Fake data: todo use asym_quant of weight
        self.qweight = torch.randint(0, 2**self.nbits - 1, size=(self.out_features, self.in_features // n_float_per_elem), dtype=torch.uint8, device=self.device)
        self.scales = torch.randn((self.out_features, self.in_features // self.group_size), dtype=self.compute_dtype, device=self.device)/10.
        self.zeros = torch.randn((self.out_features, self.in_features // self.group_size), dtype=self.compute_dtype, device=self.device)/10.

    def forward(self, x):
        out = torch.empty([x.shape[0], self.out_features], dtype=x.dtype, device=x.device)
        self.matmul.forward(A=x, W=self.qweight, scale=self.scales, zeros=self.zeros, bias=self.bias, output=out)
        return out
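
# Sketch (an assumption, not from the original gist) for the "todo use asym_quant of weight" above:
# plain group-wise asymmetric quantization producing unpacked uint values plus fp16 scales/zero-points.
# Packing into the layout BitBLAS expects would still go through the library's own weight-transform
# utilities; this only illustrates the quantization math. asym_quantize is a hypothetical helper name.
def asym_quantize(weight, nbits=4, group_size=64):
    out_features, in_features = weight.shape
    W = weight.float().reshape(out_features, in_features // group_size, group_size)
    w_min = W.amin(dim=-1, keepdim=True)
    w_max = W.amax(dim=-1, keepdim=True)
    max_q = 2**nbits - 1
    scales = (w_max - w_min).clamp(min=1e-8) / max_q      # per-group scale
    zeros = (-w_min / scales).round().clamp(0, max_q)     # zero-point in the quantized domain
    W_q = (W / scales + zeros).round().clamp(0, max_q).to(torch.uint8)
    # dequantized approximation: (W_q - zeros) * scales
    return W_q.reshape(out_features, in_features), scales.squeeze(-1).half(), zeros.squeeze(-1).half()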
#Eval GEMVs (Decoding only)
######################################################################
shapes = []
for batch_size in [1, 4, 8, 16]:
    for N in [4096, 8192, 11008, 14336, 28672]:
        for M in [4096, 8192, 11008, 14336, 28672]:
            shapes.append({'batch_size': batch_size, 'W_shape': [N, M]})

group_size = 64
verbose = True
results = {}

for shape in shapes:
    batch_size, weight_shape = shape['batch_size'], shape['W_shape']
    weight = torch.randn(weight_shape, dtype=compute_dtype, device=device)/10.
    x = torch.randn([batch_size, weight_shape[1]], dtype=compute_dtype, device=device)
    in_features, out_features = weight_shape[::-1]

    print('---------------------------------------------------------------------------------------')
    print('batch_size', batch_size, ' | ', 'shape', weight_shape)

    linear, tag = TorchLinear(weight=weight), 'torch_fp16'
    results[tag] = eval_time(linear.forward, tag=tag, verbose=verbose, params={"x": x})
    ref_time = results['torch_fp16']

    linear, tag = BitBlassLinear(weight=weight, nbits=4, group_size=group_size, batch_size=batch_size), 'bitblas_4bit_fp16'
    results[tag] = eval_time(linear.forward, tag=tag, verbose=verbose, params={"x": x}, ref_time=ref_time)

    linear, tag = BitBlassLinear(weight=weight, nbits=2, group_size=group_size, batch_size=batch_size), 'bitblas_2bit_fp16'
    results[tag] = eval_time(linear.forward, tag=tag, verbose=verbose, params={"x": x}, ref_time=ref_time)

    #cleanup
    del linear.__dict__
    del linear
    cleanup()
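
# Note (an assumption about the loop above, not part of the original gist): `results` is re-keyed with
# the same three tags on every shape, so only the last configuration survives the loop. A minimal
# sketch that keeps one row per configuration instead (bench_shape is a hypothetical helper name):
def bench_shape(batch_size, weight_shape, group_size=64):
    weight = torch.randn(weight_shape, dtype=compute_dtype, device=device)/10.
    x = torch.randn([batch_size, weight_shape[1]], dtype=compute_dtype, device=device)
    row = {'batch_size': batch_size, 'W_shape': weight_shape}
    row['torch_fp16'] = eval_time(TorchLinear(weight=weight).forward, tag='torch_fp16', params={"x": x})
    for nbits in (4, 2):
        tag = 'bitblas_' + str(nbits) + 'bit_fp16'
        linear = BitBlassLinear(weight=weight, nbits=nbits, group_size=group_size, batch_size=batch_size)
        row[tag] = eval_time(linear.forward, tag=tag, params={"x": x}, ref_time=row['torch_fp16'])
        del linear
    cleanup()
    return row
# all_rows = [bench_shape(s['batch_size'], s['W_shape']) for s in shapes]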
#########################################################################################################
#RTX 4090
#########################################################################################################
batch_size 1 | shape [4096, 4096]
torch_fp16 : 3.42e-05 sec/iter
bitblas_4bit_fp16 : 3.628e-05 sec/iter ❌
bitblas_2bit_fp16 : 3.621e-05 sec/iter ❌
---------------------------------------------------------------------------------------
batch_size 1 | shape [4096, 8192]
torch_fp16 : 6.181e-05 sec/iter
bitblas_4bit_fp16 : 3.734e-05 sec/iter | 1.655x faster ✅
bitblas_2bit_fp16 : 3.622e-05 sec/iter | 1.707x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [4096, 11008]
torch_fp16 : 9.346e-05 sec/iter
bitblas_4bit_fp16 : 3.933e-05 sec/iter | 2.376x faster ✅
bitblas_2bit_fp16 : 3.929e-05 sec/iter | 2.379x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [4096, 14336]
torch_fp16 : 0.00011729 sec/iter
bitblas_4bit_fp16 : 4.165e-05 sec/iter | 2.816x faster ✅
bitblas_2bit_fp16 : 4.103e-05 sec/iter | 2.859x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [4096, 28672]
torch_fp16 : 0.0002366 sec/iter
bitblas_4bit_fp16 : 5.093e-05 sec/iter | 4.646x faster ✅
bitblas_2bit_fp16 : 5.221e-05 sec/iter | 4.532x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [8192, 4096]
torch_fp16 : 6.28e-05 sec/iter
bitblas_4bit_fp16 : 3.62e-05 sec/iter | 1.735x faster ✅
bitblas_2bit_fp16 : 3.653e-05 sec/iter | 1.719x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [8192, 8192]
torch_fp16 : 0.00013772 sec/iter
bitblas_4bit_fp16 : 4.204e-05 sec/iter | 3.276x faster ✅
bitblas_2bit_fp16 : 4.222e-05 sec/iter | 3.262x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [8192, 11008]
torch_fp16 : 0.00018499 sec/iter
bitblas_4bit_fp16 : 4.564e-05 sec/iter | 4.053x faster ✅
bitblas_2bit_fp16 : 4.627e-05 sec/iter | 3.998x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [8192, 14336]
torch_fp16 : 0.00023862 sec/iter
bitblas_4bit_fp16 : 5.121e-05 sec/iter | 4.66x faster ✅
bitblas_2bit_fp16 : 5.095e-05 sec/iter | 4.683x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [8192, 28672]
torch_fp16 : 0.00048317 sec/iter
bitblas_4bit_fp16 : 0.00016411 sec/iter | 2.944x faster ✅
bitblas_2bit_fp16 : 7.261e-05 sec/iter | 6.654x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [11008, 4096]
torch_fp16 : 9.342e-05 sec/iter
bitblas_4bit_fp16 : 3.777e-05 sec/iter | 2.473x faster ✅
bitblas_2bit_fp16 : 3.764e-05 sec/iter | 2.482x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [11008, 8192]
torch_fp16 : 0.00018587 sec/iter
bitblas_4bit_fp16 : 4.554e-05 sec/iter | 4.081x faster ✅
bitblas_2bit_fp16 : 4.538e-05 sec/iter | 4.096x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [11008, 11008]
torch_fp16 : 0.00025034 sec/iter
bitblas_4bit_fp16 : 5.208e-05 sec/iter | 4.807x faster ✅
bitblas_2bit_fp16 : 5.092e-05 sec/iter | 4.916x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [11008, 14336]
torch_fp16 : 0.00032531 sec/iter
bitblas_4bit_fp16 : 0.00011192 sec/iter | 2.907x faster ✅
bitblas_2bit_fp16 : 5.678e-05 sec/iter | 5.729x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [11008, 28672]
torch_fp16 : 0.00065201 sec/iter
bitblas_4bit_fp16 : 0.00021042 sec/iter | 3.099x faster ✅
bitblas_2bit_fp16 : 0.00013528 sec/iter | 4.82x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [14336, 4096]
torch_fp16 : 0.00012206 sec/iter
bitblas_4bit_fp16 : 4.023e-05 sec/iter | 3.034x faster ✅
bitblas_2bit_fp16 : 4.011e-05 sec/iter | 3.043x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [14336, 8192]
torch_fp16 : 0.00024254 sec/iter
bitblas_4bit_fp16 : 5.035e-05 sec/iter | 4.817x faster ✅
bitblas_2bit_fp16 : 4.91e-05 sec/iter | 4.94x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [14336, 11008]
torch_fp16 : 0.00032609 sec/iter
bitblas_4bit_fp16 : 0.00011718 sec/iter | 2.783x faster ✅
bitblas_2bit_fp16 : 5.843e-05 sec/iter | 5.581x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [14336, 14336]
torch_fp16 : 0.00042259 sec/iter
bitblas_4bit_fp16 : 0.00014625 sec/iter | 2.89x faster ✅
bitblas_2bit_fp16 : 6.32e-05 sec/iter | 6.687x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [14336, 28672]
torch_fp16 : 0.0008506 sec/iter
bitblas_4bit_fp16 : 0.00026286 sec/iter | 3.236x faster ✅
bitblas_2bit_fp16 : 0.00016728 sec/iter | 5.085x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [28672, 4096]
torch_fp16 : 0.00024394 sec/iter
bitblas_4bit_fp16 : 5.038e-05 sec/iter | 4.842x faster ✅
bitblas_2bit_fp16 : 4.941e-05 sec/iter | 4.937x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [28672, 8192]
torch_fp16 : 0.00048683 sec/iter
bitblas_4bit_fp16 : 0.0001625 sec/iter | 2.996x faster ✅
bitblas_2bit_fp16 : 6.984e-05 sec/iter | 6.971x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [28672, 11008]
torch_fp16 : 0.00065427 sec/iter
bitblas_4bit_fp16 : 0.00021553 sec/iter | 3.036x faster ✅
bitblas_2bit_fp16 : 0.00016969 sec/iter | 3.856x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [28672, 14336]
torch_fp16 : 0.00085253 sec/iter
bitblas_4bit_fp16 : 0.00026216 sec/iter | 3.252x faster ✅
bitblas_2bit_fp16 : 0.00016371 sec/iter | 5.208x faster ✅
---------------------------------------------------------------------------------------
batch_size 1 | shape [28672, 28672]
torch_fp16 : 0.00170864 sec/iter
bitblas_4bit_fp16 : 0.00049289 sec/iter | 3.467x faster ✅
bitblas_2bit_fp16 : 0.00029746 sec/iter | 5.744x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [4096, 4096]
torch_fp16 : 3.086e-05 sec/iter
bitblas_4bit_fp16 : 4.395e-05 sec/iter ❌
bitblas_2bit_fp16 : 4.406e-05 sec/iter ❌
---------------------------------------------------------------------------------------
batch_size 4 | shape [4096, 8192]
torch_fp16 : 4.451e-05 sec/iter
bitblas_4bit_fp16 : 5.544e-05 sec/iter ❌
bitblas_2bit_fp16 : 5.894e-05 sec/iter ❌
---------------------------------------------------------------------------------------
batch_size 4 | shape [4096, 11008]
torch_fp16 : 0.00011632 sec/iter
bitblas_4bit_fp16 : 6.947e-05 sec/iter | 1.674x faster ✅
bitblas_2bit_fp16 : 6.59e-05 sec/iter | 1.765x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [4096, 14336]
torch_fp16 : 0.00014391 sec/iter
bitblas_4bit_fp16 : 7.542e-05 sec/iter | 1.908x faster ✅
bitblas_2bit_fp16 : 8.007e-05 sec/iter | 1.797x faster ✅
--------------------------------------------------------------------------------------
batch_size 4 | shape [4096, 28672]
torch_fp16 : 0.00027199 sec/iter
bitblas_4bit_fp16 : 0.00011418 sec/iter | 2.382x faster ✅
bitblas_2bit_fp16 : 0.00011311 sec/iter | 2.405x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [8192, 4096]
torch_fp16 : 3.694e-05 sec/iter
bitblas_4bit_fp16 : 5.705e-05 sec/iter ❌
bitblas_2bit_fp16 : 5.708e-05 sec/iter ❌
---------------------------------------------------------------------------------------
batch_size 4 | shape [8192, 8192]
torch_fp16 : 0.0001692 sec/iter
bitblas_4bit_fp16 : 7.537e-05 sec/iter | 2.245x faster ✅
bitblas_2bit_fp16 : 8.153e-05 sec/iter | 2.075x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [8192, 11008]
torch_fp16 : 0.00021019 sec/iter
bitblas_4bit_fp16 : 9.764e-05 sec/iter | 2.153x faster ✅
bitblas_2bit_fp16 : 9.42e-05 sec/iter | 2.231x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [8192, 14336]
torch_fp16 : 0.0002667 sec/iter
bitblas_4bit_fp16 : 0.00011614 sec/iter | 2.296x faster ✅
bitblas_2bit_fp16 : 0.00010946 sec/iter | 2.437x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [8192, 28672]
torch_fp16 : 0.00051254 sec/iter
bitblas_4bit_fp16 : 0.00048294 sec/iter | 1.061x faster ✅
bitblas_2bit_fp16 : 0.00020411 sec/iter | 2.511x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [11008, 4096]
torch_fp16 : 0.0001183 sec/iter
bitblas_4bit_fp16 : 6.981e-05 sec/iter | 1.695x faster ✅
bitblas_2bit_fp16 : 6.971e-05 sec/iter | 1.697x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [11008, 8192]
torch_fp16 : 0.00022187 sec/iter
bitblas_4bit_fp16 : 8.889e-05 sec/iter | 2.496x faster ✅
bitblas_2bit_fp16 : 8.697e-05 sec/iter | 2.551x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [11008, 11008]
torch_fp16 : 0.0002819 sec/iter
bitblas_4bit_fp16 : 0.00014008 sec/iter | 2.012x faster ✅
bitblas_2bit_fp16 : 0.00012171 sec/iter | 2.316x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [11008, 14336]
torch_fp16 : 0.00036405 sec/iter
bitblas_4bit_fp16 : 0.00022627 sec/iter | 1.609x faster ✅
bitblas_2bit_fp16 : 0.00014614 sec/iter | 2.491x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [11008, 28672]
torch_fp16 : 0.00069423 sec/iter
bitblas_4bit_fp16 : 0.00074643 sec/iter ❌
bitblas_2bit_fp16 : 0.00038147 sec/iter | 1.82x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [14336, 4096]
torch_fp16 : 0.00014509 sec/iter
bitblas_4bit_fp16 : 7.343e-05 sec/iter | 1.976x faster ✅
bitblas_2bit_fp16 : 7.383e-05 sec/iter | 1.965x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [14336, 8192]
torch_fp16 : 0.00027772 sec/iter
bitblas_4bit_fp16 : 0.0001098 sec/iter | 2.529x faster ✅
bitblas_2bit_fp16 : 0.00010597 sec/iter | 2.621x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [14336, 11008]
torch_fp16 : 0.00035274 sec/iter
bitblas_4bit_fp16 : 0.0002119 sec/iter | 1.665x faster ✅
bitblas_2bit_fp16 : 0.0001382 sec/iter | 2.552x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [14336, 14336]
torch_fp16 : 0.00045208 sec/iter
bitblas_4bit_fp16 : 0.00048183 sec/iter ❌
bitblas_2bit_fp16 : 0.00017588 sec/iter | 2.57x faster ✅
---------------------------------------------------------------------------------------
batch_size 4 | shape [14336, 28672]
torch_fp16 : 0.00089674 sec/iter
bitblas_4bit_fp16 : 0.00095913 sec/iter ❌
bitblas_2bit_fp16 : 0.00063574 sec/iter | 1.411x faster ✅
----------------------------------------------------------------------------------------
batch_size 4 | shape [28672, 4096]
torch_fp16 : 0.00026862 sec/iter
bitblas_4bit_fp16 : 0.00010704 sec/iter | 2.51x faster ✅
bitblas_2bit_fp16 : 0.000115 sec/iter | 2.336x faster ✅
---------------------------------------------------------------------------------------