# pip install git+https://github.com/mobiusml/hqq.git;
# num_threads=12; OMP_NUM_THREADS=$num_threads CUDA_VISIBLE_DEVICES=0 ipython3
##########################################################################################################################################################
import torch, os
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

cache_path    = '.'
model_id      = "meta-llama/Llama-2-7b-chat-hf"
compute_dtype = torch.float16 #float16 is fine for the dequantize + matmul/gemv path below; the fused int4 kernel only works with bfloat16
device        = 'cuda:0'
##########################################################################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)
model     = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=0)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)

#Set the default backend: the optimized ATEN backend only supports axis=0, otherwise fall back to the pure PyTorch backend
if(quant_config['weight_quant_params']['axis']==0):
    HQQLinear.set_backend(HQQBackend.ATEN)
else:
    HQQLinear.set_backend(HQQBackend.PYTORCH)
##########################################################################################################################################################
def patch_linearlayers(model, fct, patch_param=None):
    #Apply fct to every linear layer in the model (one entry per linear tag)
    model.base_class.patch_linearlayers(model, fct, dict([(k, patch_param) for k in model.base_class.get_linear_tags()]))

#pip install git+https://github.com/aredden/torch-cublas-hgemm.git
from cublas_ops import *

@torch.jit.ignore()
def zippy_gemv(W, x):
    #Half-precision GEMV: computes W @ x for a single token and reshapes back to (batch, seq_len, out_features)
    out = hgemv_simt(mat=W, vec=x.view(-1)).view(x.shape[:2] + (W.shape[0],))
    return out
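#Quick sanity check for the custom gemv (a hedged sketch, not part of the original gist): the shapes
#below are arbitrary test values matching Llama-2-7B's hidden size; compare against a plain torch.matmul reference.
_W_test = torch.randn(4096, 4096, dtype=compute_dtype, device=device)
_x_test = torch.randn(1, 1, 4096,  dtype=compute_dtype, device=device)
assert torch.allclose(zippy_gemv(_W_test, _x_test), torch.matmul(_x_test, _W_test.T), atol=1e-2, rtol=1e-2)
del _W_test, _x_test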
def patch_hqq_matmul(layer, patch_param):
    def forward(self, x):
        W_r = self.dequantize()
        #Generation: one token at a time -> use the custom gemv kernel
        if(x.shape[0]==x.shape[1]==1):
            #out = torch.matmul(W_r, x.view(-1)).view(x.shape[:2] + (W_r.shape[0],)) #torch.matmul v2
            out = zippy_gemv(W_r, x)
        #Pre-fill: full prompt -> regular matmul
        else:
            out = torch.matmul(x, W_r.T)
        if(self.bias is not None):
            out += self.bias
        return out

    if(type(layer) is HQQLinear):
        layer.forward = lambda x: forward(layer, x)
    return layer

patch_linearlayers(model, patch_hqq_matmul)
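#Optional check of the patched path (a sketch; the module path below assumes the standard HF Llama
#layout exposed through the HQQ wrapper and is not part of the original gist). Uncomment to run:
#_lin = model.model.layers[0].self_attn.q_proj
#_x   = torch.randn(1, 1, _lin.dequantize().shape[1], dtype=compute_dtype, device=device)
#assert torch.allclose(_lin.forward(_x), torch.matmul(_x, _lin.dequantize().T), atol=1e-2, rtol=1e-2)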
##########################################################################################################################################################
from hqq.utils.generation_hf import HFGenerator

#Generate
gen = HFGenerator(model, tokenizer, do_sample=False, compile_args=None) #skips compilation: slower, but works properly
#gen = HFGenerator(model, tokenizer, do_sample=False)                   #compiled: much faster, but there's a bug with HF's StaticCache

out = gen.generate("Write an essay about large language models.", max_new_tokens=100, print_tokens=False)

#Uncompiled:
# dequantize() -> torch.matmul : 10.96 tokens/sec
# dequantize() -> zippy gemv   : 11.17 tokens/sec
#Compiled:
# dequantize() -> torch.matmul : 64.4 tokens/sec
# dequantize() -> zippy gemv   : crash :(