@mobicham
Created April 17, 2024 08:52
# pip install git+https://github.com/mobiusml/hqq.git;
# num_threads=12; OMP_NUM_THREADS=$num_threads CUDA_VISIBLE_DEVICES=0 ipython3
##########################################################################################################################################################
import torch, os
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
cache_path = '.'
model_id = "meta-llama/Llama-2-7b-chat-hf"
compute_dtype = torch.float16 #int4 kernel only works with bfloat16
device = 'cuda:0'
##########################################################################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)
model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=0)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
# Set the default dequantization backend based on the quantization axis
if quant_config['weight_quant_params']['axis'] == 0:
    HQQLinear.set_backend(HQQBackend.ATEN)
else:
    HQQLinear.set_backend(HQQBackend.PYTORCH)
##########################################################################################################################################################
def patch_linearlayers(model, fct, patch_param=None):
    # Apply fct to every linear layer tag of the model
    model.base_class.patch_linearlayers(model, fct, dict([(k, patch_param) for k in model.base_class.get_linear_tags()]))

# pip install git+https://github.com/aredden/torch-cublas-hgemm.git
from cublas_ops import *

@torch.jit.ignore()
def zippy_gemv(W, x):
    # Single-token GEMV: W @ x (flattened), reshaped back to (batch, seq_len, out_features)
    out = hgemv_simt(mat=W, vec=x.view(-1)).view(x.shape[:2] + (W.shape[0],))
    return out

def patch_hqq_matmul(layer, patch_param):
    def forward(self, x):
        # Dequantize the weights back to compute_dtype
        W_r = self.dequantize()
        # Generation: one token at a time -> use the GEMV kernel
        if x.shape[0] == x.shape[1] == 1:
            #out = torch.matmul(W_r, x.view(-1)).view(x.shape[:2] + (W_r.shape[0],)) #torch.matmul v2
            out = zippy_gemv(W_r, x)
        # Pre-fill: full matmul
        else:
            out = torch.matmul(x, W_r.T)
        if self.bias is not None:
            out += self.bias
        return out

    if type(layer) is HQQLinear:
        layer.forward = lambda x: forward(layer, x)
    return layer
patch_linearlayers(model, patch_hqq_matmul)
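# Optional sanity check (a sketch, not part of the original gist): compare the patched
# forward of one quantized layer against a plain dequantize + matmul reference.
# Assumes the standard Llama module layout (model.model.layers[...].self_attn.q_proj)
# and that the layer is an HQQLinear after quantize_model().
with torch.no_grad():
    lin = model.model.layers[0].self_attn.q_proj
    W_r = lin.dequantize()
    x_test = torch.randn((1, 1, W_r.shape[1]), dtype=compute_dtype, device=device)
    ref = torch.matmul(x_test, W_r.T)
    if lin.bias is not None:
        ref += lin.bias
    print('max abs diff:', (lin.forward(x_test) - ref).abs().max().item())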
##########################################################################################################################################################
from hqq.utils.generation_hf import HFGenerator
#Generate
gen = HFGenerator(model, tokenizer, do_sample=False, compile_args=None) #skips compilation: slower, but works properly
#gen = HFGenerator(model, tokenizer, do_sample=False) #compiled: much faster, but there's a bug with HF's StaticCache
out = gen.generate("Write an essay about large language models.", max_new_tokens=100, print_tokens=False)
#Uncompiled:
#dequantize() -> torch.matmul: 10.96 tokens/sec
#dequantize() -> zippy gemv: 11.17 tokens/sec
#Compiled:
#dequantize() -> torch.matmul: 64.4 tokens/sec
#dequantize() -> zippy gemv: crash :(
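# Rough throughput measurement (a sketch, not part of the original gist): wall-clock a
# generate() call and estimate tokens/sec as max_new_tokens / elapsed time. This includes
# the pre-fill phase, so it slightly underestimates the pure decoding speed.
import time
torch.cuda.synchronize()
t0 = time.time()
out = gen.generate("Write an essay about large language models.", max_new_tokens=100, print_tokens=False)
torch.cuda.synchronize()
elapsed = time.time() - t0
print('approx. tokens/sec:', 100 / elapsed)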