import torch, os
#Settings
########################################################################################################
#Choose a model
model_id = "meta-llama/Llama-2-7b-hf"
#model_id = "meta-llama/Llama-2-13b-hf"
#model_id = "meta-llama/Llama-2-70b-hf"
#model_id = "meta-llama/Llama-2-7b-chat-hf"
#model_id = "mistralai/Mistral-7B-Instruct-v0.2"
#model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
device = 'cuda'
compute_dtype = torch.bfloat16
use_flash_attn = True
cache_path = '/nas/hicham/tmp/' if(os.path.isdir('/nas/hicham/tmp/')) else '.'
########################################################################################################
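#Note: the fused kernel used below (torch.ops.aten._weight_int4pack_mm) expects bfloat16 activations
#on CUDA in the PyTorch versions this gist targets, hence compute_dtype = torch.bfloat16 above.
#Quick sanity check (assumes a PyTorch build recent enough to ship the int4 ops, e.g. >= 2.2):
assert hasattr(torch.ops.aten, "_weight_int4pack_mm"), "This PyTorch build does not provide the int4 matmul kernel"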
import torch, copy
from torch import uint8, int32, bfloat16, nn, Tensor
from torch.nn import functional as F
from hqq.core.quantize import *
from hqq.core.utils import *
class HQQLinearTorchWeightOnlynt4(torch.nn.Module):
    def __init__(
        self,
        linear_layer: nn.Module | None,
        quant_config: dict,
        del_orig: bool = True,
        compute_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        initialize: bool = True,
        inner_k_tiles=8,
        padding=True,
    ):
        super().__init__()
        self.ready = False
        self.in_gpu = False
        self.bias = None
        self.device = device
        self.compute_dtype = compute_dtype
        self.quant_config = copy.deepcopy(quant_config)
        self.del_orig = del_orig
        weight_quant_params = self.quant_config['weight_quant_params']
        self.groupsize = weight_quant_params['group_size']
        self.nbits = weight_quant_params['nbits']
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert self.nbits in [1, 2, 4], "Unsupported nbits"
        assert self.groupsize in [None, 32, 64, 128, 256], "Unsupported groupsize"
        assert self.inner_k_tiles in [2, 4, 8], "Unsupported tile"
        self.linear_layer = linear_layer
        self.compute_dtype = compute_dtype
        if initialize:
            self.initialize()
    ###################### Initializers ######################
    def initialize_with_hqq_quants(self, W_q, meta, bias=None):
        self.padding = False  #Force padding off, a bit tricky to post-pad with grouping
        self.set_shape(meta['shape'])
        self.process_hqq_quants(W_q, meta)
        self.bias = bias
        self.ready = True
        self.in_gpu = True
        torch.cuda.empty_cache()
        return self

    def initialize(self):
        if self.linear_layer is not None:
            W = self.linear_layer.weight.data
            self.set_shape(W.shape)
            if(self.in_features_diff > 0):
                W = F.pad(W, pad=(0, self.in_features_diff), value=0)
            W_q, meta = self.quantize(W, **self.quant_config)
            self.process_hqq_quants(W_q, meta)
            del W_q, meta
            self.bias = (
                None
                if (self.linear_layer.bias is None)
                else self.linear_layer.bias.to(dtype=self.compute_dtype, device=self.device)
            )
        if self.del_orig:
            del self.linear_layer
        self.ready = True
        self.in_gpu = True
        torch.cuda.empty_cache()
        return self
    ###################### Quantize/packing ######################
    def quantize(self, W: Tensor, weight_quant_params: dict, scale_quant_params: dict | None = None, zero_quant_params: dict | None = None, offload_meta=False):
        W_q, meta = Quantizer.quantize(W, **weight_quant_params, device=self.device, compute_dtype=self.compute_dtype, bitpack=False)
        #TODO: meta quantization (scale_quant_params / zero_quant_params are currently unused)
        return W_q, meta
    #TODO: move these to utils
    @torch.no_grad()
    def reshape_meta_axis1(self, meta_tensor, new_group_size, shape):
        meta_tensor = meta_tensor.repeat([1, shape[1]]).reshape(shape)
        meta_tensor = torch.mean(meta_tensor.reshape([-1, new_group_size]), axis=1, keepdim=True)
        return meta_tensor

    def find_multiple(self, n: int, k: int) -> int:
        if n % k == 0: return n
        return n + k - (n % k)
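    #e.g. find_multiple(4096, 1024) == 4096 (no padding needed), while find_multiple(11008, 1024) == 11264,
    #so a Llama-2-7B MLP projection with in_features=11008 would be padded by 256 columns.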
    def set_shape(self, shape):
        self.shape = shape
        self.in_features = shape[1]
        self.out_features = shape[0]
        self.origin_in_features = self.in_features
        if(self.padding):
            self.in_features = self.find_multiple(self.in_features, 1024)
        self.in_features_diff = self.in_features - self.origin_in_features

    @torch.no_grad()
    def process_hqq_quants(self, W_q, meta):
        scales = meta['scale']
        zeros = meta['zero']
        shape = meta['shape']
        if(meta["packing"] is not None):
            W_q = Quantizer.unpack[meta['packing']](W_q)
        if(self.groupsize is None):
            self.groupsize = 128
            W_q = W_q.reshape([-1, self.groupsize])
            scales = self.reshape_meta_axis1(scales, self.groupsize, shape)
            zeros = self.reshape_meta_axis1(zeros, self.groupsize, shape)
        W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits)
        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, self.inner_k_tiles)
        self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
        del W_q_torch, scales_torch, zeros_torch
        torch.cuda.empty_cache()

    @torch.no_grad()
    def hqq_quants_to_torch_quants(self, W_q: Tensor, scales: Tensor, zeros: Tensor, shape, nbits=4):
        W_q = W_q.to(dtype=self.compute_dtype, device=self.device)
        scales = scales.to(dtype=self.compute_dtype, device=self.device)
        zeros = zeros.to(dtype=self.compute_dtype, device=self.device)
        max_int = 2**nbits - 1
        min_int = 0
        dump = 2 ** (nbits - 1)
        #HQQ -> torch logic
        new_zeros = (scales * dump) - zeros * scales
        min_val = new_zeros - scales * dump
        #group_quantize_tensor_from_qparams
        W_r = (W_q - zeros) * scales
        W_q = W_r.sub(min_val).div(scales).round().clamp_(min_int, max_int).to(torch.int32).reshape(shape).contiguous()
        #group_dequantize_tensor_from_qparams
        #W_r = W_q*scales + min_val
        scales = scales.contiguous().reshape(shape[0], -1)
        new_zeros = new_zeros.contiguous().reshape(shape[0], -1)
        return W_q, scales, new_zeros
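    #Worked example of the HQQ -> torch remapping above (illustrative numbers): with nbits=4,
    #scale=0.1, zero=8, a quantized value W_q=10 dequantizes in HQQ as (10 - 8)*0.1 = 0.2.
    #Torch-side, new_zeros = 0.1*8 - 8*0.1 = 0.0 and min_val = 0.0 - 0.1*8 = -0.8, so the
    #re-quantized value is round((0.2 - (-0.8))/0.1) = 10 and 10*0.1 + (-0.8) = 0.2 recovers
    #the same weight.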
    def pack_scales_and_zeros(self, scales, zeros):
        assert scales.shape == zeros.shape
        assert scales.dtype == torch.bfloat16
        assert zeros.dtype == torch.bfloat16
        return (
            torch.cat(
                [
                    scales.reshape(scales.size(0), scales.size(1), 1),
                    zeros.reshape(zeros.size(0), zeros.size(1), 1),
                ],
                2,
            )
            .transpose(0, 1)
            .contiguous()
        )
    ###################### Forward/matmul ######################
    @torch.jit.ignore()
    def matmul(self, x):
        origin_x_size = x.size()
        x = x.reshape(-1, origin_x_size[-1])
        c = torch.ops.aten._weight_int4pack_mm(x, self.weight_int4pack, self.groupsize, self.scales_and_zeros)
        new_shape = origin_x_size[:-1] + (self.out_features,)
        c = c.reshape(new_shape)
        return c

    #TODO: dequantize without matmul
    def dequantize(self):
        return self.matmul(torch.eye(self.in_features, dtype=self.compute_dtype, device=self.device))[:self.origin_in_features].t()

    #TODO: backward
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.compute_dtype)
        if self.in_features_diff > 0:
            x = F.pad(x, pad=(0, self.in_features_diff))
        out = self.matmul(x)
        if(self.bias is not None):
            out += self.bias
        return out
###################### Patching ######################
def patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4(layer, patch_params):
    new_layer = layer
    if(type(layer) is HQQLinear):
        new_layer = HQQLinearTorchWeightOnlynt4(None, quant_config=layer.quant_config, compute_dtype=layer.compute_dtype, device=layer.device, del_orig=False, initialize=False, padding=False)
        new_layer.initialize_with_hqq_quants(layer.W_q, layer.meta, layer.bias)
    return new_layer

def replace_with_torchInt4(model):
    model.base_class.patch_linearlayers(model, patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4, dict([(k, None) for k in model.base_class.get_linear_tags()]))
    cleanup()

#Force requantize, mainly to check if the padding with int4mm is faster
def patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4_force_requantize(layer, patch_params):
    new_layer = layer
    if(type(layer) is HQQLinear):
        #Create a dummy linear layer to store the dequantized weights
        dummy_linear = torch.nn.Linear(1, 1, bias=False)
        dummy_linear.weight.data = layer.dequantize()
        #Disable the optimizer on already dequantized weights
        quant_config = layer.quant_config
        quant_config['weight_quant_params']['optimize'] = False
        new_layer = HQQLinearTorchWeightOnlynt4(dummy_linear, quant_config=quant_config, compute_dtype=layer.compute_dtype, device=layer.device, del_orig=True, initialize=True, padding=True)
        del layer
        cleanup()
    return new_layer

def replace_with_torchInt4_force_requantize(model):
    model.base_class.patch_linearlayers(model, patch_HQQLinear_to_HQQLinearTorchWeightOnlynt4_force_requantize, dict([(k, None) for k in model.base_class.get_linear_tags()]))
    cleanup()
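#Note: replace_with_torchInt4() reuses the existing HQQ quants as-is (padding disabled), while
#replace_with_torchInt4_force_requantize() dequantizes and re-quantizes each layer so that in_features
#can be padded to a multiple of 1024. They are alternatives; the script below uses replace_with_torchInt4().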
#HQQ Quantize
########################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *
from hqq.core.utils import *
model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="flash_attention_2" if use_flash_attn else "sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
quant_config['weight_quant_params']['axis'] = 1
quant_config['weight_quant_params']['round_zero'] = True
#GPTQ ref score: 5.38
#axis=0 | round_zero=True | group_size=64 = HQQLinear: 5.303
#axis=1 | round_zero=True | group_size=64 = HQQLinear: 5.3363 | AO4bit: 5.3374
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.PYTORCH)
#########################################
#Use fused torch int4mm
replace_with_torchInt4(model)
#########################################
model.config.use_cache = False
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model.forward = torch.compile(model.forward)
#warm-up
with torch.no_grad():
    out = model(torch.ones((1, 1024), dtype=torch.int32, device=device))
del out
########################################################################################################
import time
import numpy as np
t = []
for _ in range(100):
    with torch.no_grad():
        data = torch.randint(0, 100, (1, 1024), dtype=torch.int32, device=device)
        t1 = time.time()
        out = model(data)
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2-t1)
#Average forward latency (seconds) over the last 50 iterations, skipping the first 50 as warm-up
print(np.mean(t[-50:]))
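
########################################################################################################
#Optional sanity check (illustrative only; the prompt and max_new_tokens are arbitrary). Note that
#use_cache was disabled for the benchmark, so generation is slow and each new sequence length may
#trigger a torch.compile recompilation.
with torch.no_grad():
    inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))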