#awq_hqq_test.py
#Settings
######################################################################################
import torch
import numpy as np
from tqdm import tqdm

hf_auth    = None #HuggingFace token
cache_path = ''   #cache directory to store data
#Choose a model
model_id = "meta-llama/Llama-2-7b-hf"
#model_id = "meta-llama/Llama-2-13b-hf"
#model_id = "meta-llama/Llama-2-70b-hf"
#HQQ Quantize
######################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
model = HQQModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
from hqq.core.quantize import *
quant_config = BaseQuantizeConfig(nbits=4, group_size=128, quant_scale=False, quant_zero=False)
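#quant_config: 4-bit weights with group_size=128 (patch_linear() below forces axis=1, i.e. grouping
#along the input dimension). quant_scale/quant_zero are left False so that hqq_layer.meta['scale']
#and meta['zero'] stay as plain tensors, which the AWQ patching below reuses directly.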
################################################################################################
# #HQQ Orig
# model.quantize_model(quant_config=quant_config)
# HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
################################################################################################
#Awq Patching
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from hqq.core.quantize import *
def patch_linear(linear_layer, quant_config):
    if(quant_config):
        w_quant_params = quant_config['weight_quant_params']
        w_quant_params.update({'axis':1}) #group along the input dimension, as AWQ expects

        hqq_layer = HQQLinear(linear_layer, quant_config=quant_config, del_orig=False)
        w_shape   = hqq_layer.meta['shape']
        W_ref     = linear_layer.weight.data.clone()

        ##########################################################
        version    = 'GEMV' #'GEMM' or 'GEMV'
        device     = hqq_layer.device
        w_shape    = hqq_layer.meta['shape']
        nbits      = w_quant_params['nbits']
        group_size = w_quant_params['group_size']
        max_int    = 2 ** nbits - 1
        min_int    = 0

        scales = hqq_layer.meta['scale'].clone()
        zeros  = torch.round(hqq_layer.meta['zero'].clone())

        #w = linear_layer.weight.data
        w = W_ref.half().clone().to('cuda').reshape(-1, group_size)
        linear_layer.weight.data = ((torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales).reshape(w_shape) #Original AWQ
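        #The line above fake-quantizes the weights with the asymmetric scheme used throughout:
        #   q     = clamp(round(w / scale) + zero, 0, 2**nbits - 1)
        #   w_deq = (q - zero) * scale
        #e.g. with scale=0.5, zero=8, nbits=4: w=1.3 -> q=clamp(round(2.6)+8, 0, 15)=11 -> w_deq=(11-8)*0.5=1.5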
        #Need to re-write AWQ to make it work with this logic
        scales = scales.reshape([w_shape[0], -1])
        zeros  = zeros.reshape([w_shape[0], -1])

        if version == 'GEMM':
            scales = scales.t().contiguous()
            zeros  = zeros.t().contiguous()
            q_linear_module = WQLinear_GEMM
        elif version == 'GEMV':
            q_linear_module = WQLinear_GEMV

        awq_layer = q_linear_module.from_linear(linear_layer, w_bit=nbits, group_size=group_size, init_only=False, scales=scales, zeros=zeros)
        new_layer = awq_layer
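        #Layout note (as implemented here): scales/zeros are reshaped to [out_features, n_groups];
        #the GEMM kernel additionally wants them transposed to [n_groups, out_features], while GEMV
        #keeps them as-is. from_linear(init_only=False) then packs the int4 weights for that kernel.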
        ##########################################################
        # def pseudo_quantize_tensor(w, w_bit, group_size, get_scale_zp=False):
        #     org_w_shape = w.shape
        #     if group_size > 0:
        #         assert org_w_shape[-1] % group_size == 0
        #         w = w.reshape(-1, group_size)
        #     assert w.dim() == 2
        #     # zero point quantization
        #     max_val = w.amax(dim=1, keepdim=True)
        #     min_val = w.amin(dim=1, keepdim=True)
        #     max_int = 2 ** w_bit - 1
        #     min_int = 0
        #     scales  = (max_val - min_val).clamp(min=1e-5) / max_int
        #     zeros   = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        #     assert torch.isnan(scales).sum() == 0
        #     assert torch.isnan(w).sum() == 0
        #     w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        #     assert torch.isnan(w).sum() == 0
        #     w = w.reshape(org_w_shape)
        #     if get_scale_zp:
        #         return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        #     else:
        #         return w
        # linear_layer.weight.data = W_ref.half().clone().to('cuda')
        # linear_layer.weight.data, scales, zeros = pseudo_quantize_tensor(linear_layer.weight.data, w_bit=w_quant_params['nbits'], group_size=w_quant_params['group_size'], get_scale_zp=True)
        # awq_layer = WQLinear_GEMV.from_linear(linear_layer, w_bit=w_quant_params['nbits'], group_size=w_quant_params['group_size'], init_only=False, scales=scales, zeros=zeros)
        # awq_layer = awq_layer.to('cuda')
        ###########################################################################
        # x = 0.01*torch.ones((1, 256, 4096), dtype=torch.float16, device='cuda')
        # with torch.no_grad():
        #     y_ref = linear_layer(x)
        #     y_hqq = hqq_layer(x)
        #     y_awq = awq_layer(x)
        # print(y_ref)
        # print(y_awq)
        #################
        del w, W_ref, linear_layer, hqq_layer
        cleanup()
    else:
        new_layer = linear_layer.half().cuda()

    return new_layer
linear_tags = model.base_class.get_linear_tags()
patch_params = dict([(tag, quant_config) for tag in linear_tags])
model.base_class.patch_model(model, lambda l: l.half().cuda(), patch_linear, patch_params, verbose=True)
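#patch_model() applies the first callback (cast to fp16 + move to CUDA) to the non-linear modules
#and calls patch_linear() on every linear layer listed in linear_tags, passing the matching entry
#of patch_params as its quant_config.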
# from awq.models.llama import LlamaFuser
# model.max_seq_len = 1024
# fuser = LlamaFuser(model)
# fuser.fuse_transformer()
##########################################################################
model.eval()
#Warmup
for i in range(10):
    with torch.no_grad():
        out = model(torch.ones((1, 1024), dtype=torch.int32, device='cuda'))
del out
cleanup()
#Eval time
import time
t = []
for i in tqdm(range(100)):
    t1 = time.time()
    with torch.no_grad():
        out = model(torch.ones((1, 1024), dtype=torch.int32, device='cuda'))
    torch.cuda.synchronize()
    t2 = time.time()
    t.append(t2 - t1)
print(np.mean(t)) #0.254

del out
cleanup()
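#Note: the printed mean is seconds per forward pass over a dummy 1x1024-token batch
#(0.254 s in the run noted in the comment above).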
from eval_model import eval_wikitext2
eval_wikitext2(model, tokenizer, verbose=True)
#{'perplexity': 5.3757, 'prediction_time': 0.499} : Original AWQ
#GEMV: 0.499 / GEMM: 0.572