Last active
April 24, 2024 13:28
-
-
Save mobicham/cb07c1eff443ad0918c49ab7bb03e269 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Smoke test: load Llama-2-7B through transformers with on-the-fly HQQ
# quantization, switch HQQ linear layers to the ATEN backend, run one forward
# pass on a dummy batch of token ids, print the logits, then free the result.
import os  # unused in this script; kept from the original import line

import torch
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.core.utils import cleanup
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

cache_path = '.'               # passed as cache_dir to from_pretrained
compute_dtype = torch.float16  # passed as torch_dtype to from_pretrained
device = 'cuda:0'              # device the dummy input ids are created on

###################################################################################################
model_id = "meta-llama/Llama-2-7b-hf"

# Basic: every linear layer uses the same quantization config.
# axis=0 is used by default.
# NOTE(review): the ATEN backend selected below is believed to support only
# axis=0 quantization -- keep these in sync if either is changed.
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)

# Alternative: each type of linear layer (referred to as a "linear tag") uses
# its own quantization parameters via dynamic_config.
# q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
# q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
# quant_config = HqqConfig(dynamic_config={
#     'self_attn.q_proj':q4_config,
#     'self_attn.k_proj':q4_config,
#     'self_attn.v_proj':q4_config,
#     'self_attn.o_proj':q4_config,
#     'mlp.gate_proj':q3_config,
#     'mlp.up_proj'  :q3_config,
#     'mlp.down_proj':q3_config,
# })

#####################################################################################################
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_path,
    torch_dtype=compute_dtype,
    device_map="auto",  # alternatively pass `device` to pin the model to one GPU
    low_cpu_mem_usage=True,
    quantization_config=quant_config,
)

# Set backend: applied through the HQQLinear class rather than per layer,
# so it is called once after the model has been loaded.
HQQLinear.set_backend(HQQBackend.ATEN)

# Forward: a single sequence of 1024 token ids (all zeros), gradients disabled.
with torch.no_grad():
    out = model(torch.zeros([1, 1024], device=device, dtype=torch.int32)).logits
print(out)

# Drop the logits tensor, then call hqq's cleanup helper
# (presumably releases cached GPU memory -- confirm in hqq.core.utils).
del out
cleanup()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment