import torch, os

#Settings
########################################################################################################
#Choose a model
model_id = "meta-llama/Llama-2-7b-hf"
#model_id = "meta-llama/Llama-2-13b-hf"
#model_id = "meta-llama/Llama-2-70b-hf"
#model_id = "meta-llama/Llama-2-7b-chat-hf"
#model_id = "mistralai/Mistral-7B-Instruct-v0.2"
#model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

device = 'cuda'
compute_dtype = torch.bfloat16
use_flash_attn = True
cache_path = '/nas/hicham/tmp/' if(os.path.isdir('/nas/hicham/tmp/')) else '.'
########################################################################################################
import torch, copy
from torch import uint8, int32, bfloat16, nn, Tensor
from torch.nn import functional as F  #F.pad is used below for weight/input padding
from hqq.core.quantize import *
from hqq.core.utils import *

class HQQLinearTorchWeightOnlynt4(torch.nn.Module):
    def __init__(
        self,
        linear_layer: nn.Module | None,
        quant_config: dict,
        del_orig: bool = True,
        compute_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        initialize: bool = True,
        inner_k_tiles=8,
        padding=True,
    ):
        super().__init__()
        self.ready = False
        self.in_gpu = False
        self.bias = None
        self.device = device
        self.compute_dtype = compute_dtype
        self.quant_config = copy.deepcopy(quant_config)
        self.del_orig = del_orig

        weight_quant_params = self.quant_config['weight_quant_params']
        self.groupsize = weight_quant_params['group_size']
        self.nbits = weight_quant_params['nbits']
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding

        assert self.nbits in [1, 2, 4], "Unsupported nbits"
        assert self.groupsize in [None, 32, 64, 128, 256], "Unsupported groupsize"
        assert self.inner_k_tiles in [2, 4, 8], "Unsupported tile"

        self.linear_layer = linear_layer
        self.compute_dtype = compute_dtype

        if initialize:
            self.initialize()

    ###################### Initializers ######################
    def initialize_with_hqq_quants(self, W_q, meta, bias=None):
        self.padding = False #Force padding off, a bit tricky to post-pad with grouping
        self.set_shape(meta['shape'])
        self.process_hqq_quants(W_q, meta)
        self.bias = bias
        self.ready = True
        self.in_gpu = True
        torch.cuda.empty_cache()
        return self

    def initialize(self):
        if self.linear_layer is not None:
            W = self.linear_layer.weight.data
            self.set_shape(W.shape)

            if(self.in_features_diff > 0):
                W = F.pad(W, pad=(0, self.in_features_diff), value=0)

            W_q, meta = self.quantize(W, **self.quant_config)
            self.process_hqq_quants(W_q, meta)
            del W_q, meta

            self.bias = (
                None
                if (self.linear_layer.bias is None)
                else self.linear_layer.bias.to(dtype=self.compute_dtype, device=self.device)
            )

        if self.del_orig:
            del self.linear_layer

        self.ready = True
        self.in_gpu = True
        torch.cuda.empty_cache()
        return self

    ###################### Quantize/packing ######################
    def quantize(self, W: Tensor, weight_quant_params: dict, scale_quant_params: dict | None = None, zero_quant_params: dict | None = None, offload_meta=False):
        W_q, meta = Quantizer.quantize(W, **weight_quant_params, device=self.device, compute_dtype=self.compute_dtype, bitpack=False)
        #ToDO: meta quantization
        return W_q, meta

    #TODO: move these to utils
    @torch.no_grad()
    def reshape_meta_axis1(self, meta_tensor, new_group_size, shape):
        meta_tensor = meta_tensor.repeat([1, shape[1]]).reshape(shape)
        meta_tensor = torch.mean(meta_tensor.reshape([-1, new_group_size]), axis=1, keepdim=True)
        return meta_tensor

    def find_multiple(self, n: int, k: int) -> int:
        if n % k == 0: return n
        return n + k - (n % k)

    def set_shape(self, shape):
        self.shape = shape
        self.in_features = shape[1]
        self.out_features = shape[0]
        self.origin_in_features = self.in_features
        if(self.padding):
            self.in_features = self.find_multiple(self.in_features, 1024)
        self.in_features_diff = self.in_features - self.origin_in_features

    @torch.no_grad()
    def process_hqq_quants(self, W_q, meta):
        scales = meta['scale']
        zeros = meta['zero']
        shape = meta['shape']

        if(meta["packing"] is not None):
            W_q = Quantizer.unpack[meta['packing']](W_q)

        if(self.groupsize is None):
            self.groupsize = 128
            W_q = W_q.reshape([-1, self.groupsize])
            scales = self.reshape_meta_axis1(scales, self.groupsize, shape)
            zeros = self.reshape_meta_axis1(zeros, self.groupsize, shape)

        W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits)
        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, self.inner_k_tiles)
        self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)

        del W_q_torch, scales_torch, zeros_torch
        torch.cuda.empty_cache()

    @torch.no_grad()
    def hqq_quants_to_torch_quants(self, W_q: Tensor, scales: Tensor, zeros: Tensor, shape, nbits=4):
        W_q = W_q.to(dtype=self.compute_dtype, device=self.device)
        scales = scales.to(dtype=self.compute_dtype, device=self.device)
        zeros = zeros.to(dtype=self.compute_dtype, device=self.device)

        max_int = 2**nbits - 1
        min_int = 0
        dump = 2 ** (nbits - 1)

        #HQQ -> torch logic
        new_zeros = (scales * dump) - zeros*scales
        min_val = new_zeros - scales * dump

        #group_quantize_tensor_from_qparams
        W_r = (W_q - zeros)*scales
        W_q = W_r.sub(min_val).div(scales).round().clamp_(min_int, max_int).to(torch.int32).reshape(shape).contiguous()

        #group_dequantize_tensor_from_qparams
        #W_r = W_q*scales + min_val

        scales = scales.contiguous().reshape(shape[0], -1)
        new_zeros = new_zeros.contiguous().reshape(shape[0], -1)
        return W_q, scales, new_zeros
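
    #Note on the conversion above (my reading of the algebra, not a comment from the original code):
    #HQQ dequantizes as          W_r = (W_q - zero) * scale
    #the torch int4 kernel uses  W_r = W_q * scale + min_val
    #Here new_zeros = scale*2**(nbits-1) - zero*scale, so min_val = new_zeros - scale*2**(nbits-1) = -zero*scale,
    #and re-quantizing W_r with (scale, min_val) recovers the same integer codes up to clamping to [0, 2**nbits - 1].
    #Both conventions therefore reconstruct the same dequantized weights.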

    def pack_scales_and_zeros(self, scales, zeros):
        assert scales.shape == zeros.shape
        assert scales.dtype == torch.bfloat16
        assert zeros.dtype == torch.bfloat16
        return (
            torch.cat(
                [
                    scales.reshape(scales.size(0), scales.size(1), 1),
                    zeros.reshape(zeros.size(0), zeros.size(1), 1),
                ],
                2,
            )
            .transpose(0, 1)
            .contiguous()
        )

    ###################### Forward/matmul ######################
    @torch.jit.ignore()
    def matmul(self, x):
        origin_x_size = x.size()
        x = x.reshape(-1, origin_x_size[-1])
        c = torch.ops.aten._weight_int4pack_mm(x, self.weight_int4pack, self.groupsize, self.scales_and_zeros)
        new_shape = origin_x_size[:-1] + (self.out_features,)
        c = c.reshape(new_shape)
        return c

    #TODO without matmul
    def dequantize(self):
        return self.matmul(torch.eye(self.in_features, dtype=self.compute_dtype, device=self.device))[:self.origin_in_features].t()

    #TODO: backward
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.compute_dtype)
        if self.in_features_diff > 0:
            x = F.pad(x, pad=(0, self.in_features_diff))
        out = self.matmul(x)
        if(self.bias is not None):
            out += self.bias
        return out

###################### Patching ######################
def patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4(layer, patch_params):
    new_layer = layer
    if(type(layer) is HQQLinear):
        new_layer = HQQLinearTorchWeightOnlynt4(None, quant_config=layer.quant_config, compute_dtype=layer.compute_dtype, device=layer.device, del_orig=False, initialize=False, padding=False)
        new_layer.initialize_with_hqq_quants(layer.W_q, layer.meta, layer.bias)
    return new_layer

def replace_with_torchInt4(model):
    model.base_class.patch_linearlayers(model, patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4, dict([(k, None) for k in model.base_class.get_linear_tags()]))
    cleanup()
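
#Hedged usage sketch (not in the original gist): patching a single HQQLinear layer directly,
#assuming HQQLinear takes (linear_layer, quant_config, compute_dtype=..., device=...) as in the hqq library:
#
#   lin    = torch.nn.Linear(4096, 4096, bias=False)
#   hqq_l  = HQQLinear(lin, BaseQuantizeConfig(nbits=4, group_size=64), compute_dtype=torch.bfloat16, device='cuda')
#   int4_l = patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4(hqq_l, None)
#   y      = int4_l(torch.randn((1, 4096), dtype=torch.bfloat16, device='cuda'))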

#Force requantize, mainly to check if the padding with int4mm is faster
def patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4_force_requantize(layer, patch_params):
    new_layer = layer
    if(type(layer) is HQQLinear):
        #Create a dummy linear layer to store the dequantized weights
        dummy_linear = torch.nn.Linear(1, 1, bias=False)
        dummy_linear.weight.data = layer.dequantize()

        #Disable the optimizer on already dequantized weights
        quant_config = layer.quant_config
        quant_config['weight_quant_params']['optimize'] = False

        new_layer = HQQLinearTorchWeightOnlynt4(dummy_linear, quant_config=quant_config, compute_dtype=layer.compute_dtype, device=layer.device, del_orig=True, initialize=True, padding=True)

        del layer
        cleanup()
    return new_layer

def replace_with_torchInt4_force_requantize(model):
    model.base_class.patch_linearlayers(model, patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4_force_requantize, dict([(k, None) for k in model.base_class.get_linear_tags()]))
    cleanup()

#HQQ Quantize
########################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *
from hqq.core.utils import *

model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="flash_attention_2" if use_flash_attn else "sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
quant_config['weight_quant_params']['axis'] = 1
quant_config['weight_quant_params']['round_zero'] = True

#GPTQ ref score: 5.38
#axis=0 | round_zero=True | group_size=64 = HQQLinear: 5.303
#axis=1 | round_zero=True | group_size=64 = HQQLinear: 5.3363 | AO4bit: 5.3374

model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.PYTORCH)

#########################################
#Use fused torch int4mm
replace_with_torchInt4(model)
#########################################

model.config.use_cache = False
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model.forward = torch.compile(model.forward)

#warm-up
with torch.no_grad():
    out = model(torch.ones((1, 1024), dtype=torch.int32, device=device))
del out

#Benchmark: forward pass over a (1, 1024) token batch
########################################################################################################
import time
import numpy as np

t = []
for _ in range(100):
    with torch.no_grad():
        data = torch.randint(0, 100, (1, 1024), dtype=torch.int32, device=device)
        t1 = time.time()
        out = model(data)
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2 - t1)

#Mean forward latency (seconds) over the last 50 iterations
print(np.mean(t[-50:]))
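
#Hypothetical follow-up (not in the original gist): convert the mean latency to prefill throughput,
#since each forward pass processes a (1, 1024) token batch:
#   tokens_per_sec = 1024 / np.mean(t[-50:])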