awq_hqq_test.py
#Settings
######################################################################################
import gc, torch
import numpy as np
from tqdm import tqdm

hf_auth = None #HuggingFace token
cache_path = '' #cache directory to store data

#Choose a model
model_id = "meta-llama/Llama-2-7b-hf"
#model_id = "meta-llama/Llama-2-13b-hf"
#model_id = "meta-llama/Llama-2-70b-hf"

#HQQ Quantize
######################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer

model = HQQModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)

from hqq.core.quantize import *
quant_config = BaseQuantizeConfig(nbits=4, group_size=128, quant_scale=False, quant_zero=False)
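#4-bit weights, group-size 128; scales and zero-points stay in fp16 (quant_scale/quant_zero=False)
#so they can later be repacked into the AWQ layers.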

################################################################################################
# #HQQ Orig
# model.quantize_model(quant_config=quant_config)
# HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
################################################################################################

#AWQ Patching
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from hqq.core.quantize import *
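
#patch_linear() quantizes each nn.Linear with HQQ along the input axis (axis=1), then repacks the
#resulting scales / zero-points into an AWQ WQLinear_GEMM / WQLinear_GEMV layer so that inference
#runs on the AWQ CUDA kernels.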
def patch_linear(linear_layer, quant_config):
    if(quant_config):
        w_quant_params = quant_config['weight_quant_params']
        w_quant_params.update({'axis':1}) #AWQ groups along the input dimension

        hqq_layer = HQQLinear(linear_layer, quant_config=quant_config, del_orig=False)
        w_shape = hqq_layer.meta['shape']

        W_ref = linear_layer.weight.data.clone()

        ##########################################################
        version = 'GEMV' #'GEMM' or 'GEMV'
        device = hqq_layer.device
        w_shape = hqq_layer.meta['shape']
        nbits = w_quant_params['nbits']
        group_size = w_quant_params['group_size']
        max_int = 2 ** nbits - 1
        min_int = 0

        scales = hqq_layer.meta['scale'].clone()
        zeros = torch.round(hqq_layer.meta['zero'].clone())

        #w = linear_layer.weight.data
        w = W_ref.half().clone().to('cuda').reshape(-1, group_size)

        #Fake-quantize the weights with the HQQ scales/zeros (same formula as the original AWQ)
        linear_layer.weight.data = ((torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales).reshape(w_shape) #Original AWQ

        #Need to re-write AWQ to make it work with this logic
        scales = scales.reshape([w_shape[0], -1])
        zeros = zeros.reshape([w_shape[0], -1])

        if version == 'GEMM':
            scales = scales.t().contiguous()
            zeros = zeros.t().contiguous()
            q_linear_module = WQLinear_GEMM
        elif version == 'GEMV':
            q_linear_module = WQLinear_GEMV

        awq_layer = q_linear_module.from_linear(linear_layer, w_bit=nbits, group_size=group_size, init_only=False, scales=scales, zeros=zeros)
        new_layer = awq_layer
        ##########################################################
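
        #Reference: AWQ's pseudo_quantize_tensor path, kept commented out for comparison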
        # def pseudo_quantize_tensor(w, w_bit, group_size, get_scale_zp=False):
        #     org_w_shape = w.shape
        #     if group_size > 0:
        #         assert org_w_shape[-1] % group_size == 0
        #         w = w.reshape(-1, group_size)
        #     assert w.dim() == 2
        #     # zero point quantization
        #     max_val = w.amax(dim=1, keepdim=True)
        #     min_val = w.amin(dim=1, keepdim=True)
        #     max_int = 2 ** w_bit - 1
        #     min_int = 0
        #     scales = (max_val - min_val).clamp(min=1e-5) / max_int
        #     zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        #     assert torch.isnan(scales).sum() == 0
        #     assert torch.isnan(w).sum() == 0
        #     w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        #     assert torch.isnan(w).sum() == 0
        #     w = w.reshape(org_w_shape)
        #     if get_scale_zp:
        #         return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        #     else:
        #         return w

        # linear_layer.weight.data = W_ref.half().clone().to('cuda')
        # linear_layer.weight.data, scales, zeros = pseudo_quantize_tensor(linear_layer.weight.data, w_bit=w_quant_params['nbits'], group_size=w_quant_params['group_size'], get_scale_zp=True)
        # awq_layer = WQLinear_GEMV.from_linear(linear_layer, w_bit=w_quant_params['nbits'], group_size=w_quant_params['group_size'], init_only=False, scales=scales, zeros=zeros)
        # awq_layer = awq_layer.to('cuda')
        ###########################################################

        #Sanity check: compare fp16, HQQ and AWQ outputs on a dummy input
        # x = 0.01*torch.ones((1, 256, 4096), dtype=torch.float16, device='cuda')
        # with torch.no_grad():
        #     y_ref = linear_layer(x)
        #     y_hqq = hqq_layer(x)
        #     y_awq = awq_layer(x)
        # print(y_ref)
        # print(y_awq)
        #################

        del w, W_ref, linear_layer, hqq_layer
        cleanup()
    else:
        new_layer = linear_layer.half().cuda()

    return new_layer

linear_tags = model.base_class.get_linear_tags()
patch_params = dict([(tag, quant_config) for tag in linear_tags])
model.base_class.patch_model(model, lambda l: l.half().cuda(), patch_linear, patch_params, verbose=True)
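
#At this point every Linear module listed in linear_tags carries AWQ-packed weights built from the
#HQQ scales/zero-points, and the remaining modules were simply cast to fp16 on CUDA.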

# from awq.models.llama import LlamaFuser
# model.max_seq_len = 1024
# fuser = LlamaFuser(model)
# fuser.fuse_transformer()

##########################################################################
model.eval()

#Warmup
for i in range(10):
    with torch.no_grad():
        out = model(torch.ones((1, 1024), dtype=torch.int32, device='cuda'))
    del out
    cleanup()

#Eval time
import time
t = []
for i in tqdm(range(100)):
    t1 = time.time()
    with torch.no_grad():
        out = model(torch.ones((1, 1024), dtype=torch.int32, device='cuda'))
    torch.cuda.synchronize()
    t2 = time.time()
    t.append(t2 - t1)

print(np.mean(t)) #0.254
del out
cleanup()

from eval_model import eval_wikitext2
eval_wikitext2(model, tokenizer, verbose=True)
# {'perplexity': 5.3757, 'prediction_time': 0.499} : Original AWQ
#GEMV: 0.499 / GEMM: 0.572