# Gist by @mobicham (created April 12, 2024)
# https://gist.github.com/mobicham/7fb59e825fed0831fccf44752cb21214
# pip install transformers accelerate einops
# pip install git+https://github.com/mobiusml/hqq.git
# pip install git+https://github.com/aredden/torch-cublas-hgemm.git
######################################################################################################
import torch, os
#Settings
########################################################################################################
#Choose a model
model_id = "meta-llama/Llama-2-7b-hf"
device = 'cuda'
compute_dtype = torch.float16
use_flash_attn = False #set to True to use flash_attention_2
cache_path = '/nas/hicham/tmp/' if(os.path.isdir('/nas/hicham/tmp/')) else '.'
########################################################################################################
#HQQ Quantize
########################################################################################################
from transformers import AutoModelForCausalLM
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *
from hqq.core.utils import *
model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="flash_attention_2" if use_flash_attn else "sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.PYTORCH)
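#Optional sanity check (not part of the original gist): run one forward pass and print the
#most likely next token for a simple prompt, to confirm the quantized model still behaves
#sensibly. The prompt text is arbitrary.
with torch.no_grad():
    _ids     = tokenizer("The capital of France is", return_tensors="pt").input_ids.to(device)
    _next_id = model(_ids).logits[0, -1].argmax().item()
print('Sanity check - predicted next token:', tokenizer.decode([_next_id]))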
########################################################################################################
from cublas_ops import *
@torch.jit.ignore()
def zippy_matmul(x, W):
    #Half-precision matmul via the cuBLAS hgemm kernels: computes the equivalent of x @ W.T
    #out = cublas_half_matmul_simple(x.contiguous().view(-1, x.shape[-1]), W).reshape([x.shape[0], x.shape[1], W.shape[0]])
    out = cublas_half_matmul_batched_simple(x.contiguous(), W)
    return out
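#Optional numerical check (not part of the original gist): compare the cuBLAS kernel against
#the PyTorch reference on random fp16 tensors. zippy_matmul(x, W) is assumed to compute the
#equivalent of torch.matmul(x, W.T), consistent with the commented-out reshape above; expect
#small fp16 rounding differences. The 4096 sizes below are arbitrary test shapes.
x_test = torch.randn(1, 16, 4096, dtype=compute_dtype, device=device)
W_test = torch.randn(4096, 4096, dtype=compute_dtype, device=device)
print('zippy vs torch.matmul - max abs diff:', (zippy_matmul(x_test, W_test) - torch.matmul(x_test, W_test.T)).abs().max().item())
del x_test, W_test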
def patch_fct(layer, patch_params):
    #Replace the HQQLinear forward with: unpack/dequantize the weights, then run the matmul via the cuBLAS kernel
    if(type(layer) is HQQLinear):
        def new_forward(self, x):
            unpack_fct = Quantizer.unpack[self.meta['packing']]
            W_r = ((unpack_fct(self.W_q, self.compute_dtype) - self.meta['zero'])*self.meta['scale']).view(self.meta['shape'])
            #Matmul
            #out = torch.matmul(x, W_r.T)
            out = zippy_matmul(x, W_r)
            #There's no bias in Llama2-7B
            if(self.bias is not None):
                out += self.bias
            return out
        layer.forward = lambda x: new_forward(layer, x)
    return layer
model.base_class.patch_linearlayers(model, patch_fct, dict([(k, None) for k in model.base_class.get_linear_tags()]))
cleanup()
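#Optional consistency check (not in the original gist): the patched model should predict the
#same next token as the unpatched PYTORCH backend did in the sanity check above (reuses the
#_ids / _next_id variables defined there).
with torch.no_grad():
    _next_id_patched = model(_ids).logits[0, -1].argmax().item()
print('Patched forward predicts same token:', _next_id_patched == _next_id)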
########################################################################################################
#Benchmark
model.config.use_cache = False
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model.forward = torch.compile(model.forward)
batch_size, context_size = 1, 1024
#warm-up
for _ in range(10):
    with torch.no_grad():
        out = model(torch.ones((batch_size, context_size), dtype=torch.int32, device=device)).logits
    del out; cleanup()
import time
import numpy as np
from tqdm import tqdm
t = []
for _ in tqdm(range(100)):
    with torch.no_grad():
        data = torch.randint(0, 100, (batch_size, context_size), dtype=torch.int32, device=device)
        t1 = time.time()
        out = model(data).logits
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2-t1)
    del out; cleanup()
print(np.mean(t))
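#Extra reporting (not in the original gist): spread of the timings and the implied prefill
#throughput, i.e. batch_size*context_size tokens processed per forward pass.
print('std:', np.std(t), '| median:', np.median(t))
print('prefill tokens/sec:', (batch_size * context_size) / np.mean(t))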
#Llama2-7B | batch_size=1 / context_size=1024 / forward pass in secs / with torch.compile / sdpa attention
# GPU: 4090
# --------------------------------------------------
# FP16: No quant - torch.matmul: 0.08854429721832276
# HQQ quantization:
# Int4: dequantize()->torch.matmul(): 0.108595871925354
# Int4: dequantize()->Zippy simple matmul(): 0.14381505489349367
# Int4: dequantize()->Zippy batched matmul(): 0.13050755262374877
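# For reference, these forward-pass times correspond to a prefill throughput of roughly
# 1024/0.0885 ≈ 11.6K tokens/sec for FP16 vs 1024/0.1086 ≈ 9.4K tokens/sec for HQQ int4 + torch.matmul.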