## pip install autoawq==0.1.7 | |
## pip install transformers==4.36.2 | |
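##
## Patches AutoAWQ for Mixtral-8x7B (MoE): a config/quantizer pair that can skip
## selected modules during quantization (e.g. the MoE routing gate, via
## `modules_to_not_convert`), plus fused attention/MoE modules for inference.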
import functools | |
import gc | |
import os | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
from awq.models.base import BaseAWQForCausalLM | |
from awq.models._config import AwqConfig | |
from awq.modules.fused.attn import QuantAttentionFused | |
from awq.modules.fused.mlp import QuantFusedMLP | |
from awq.modules.fused.norm import FasterTransformerRMSNorm | |
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV | |
from awq.quantize.quantizer import AwqQuantizer | |
from awq.quantize.scale import apply_clip, apply_scale | |
from awq.utils import fused_utils | |
from awq.utils.calib_data import get_calib_dataset | |
from awq.utils.module import ( | |
append_str_prefix, | |
get_named_linears, | |
get_op_name, | |
set_op_by_name, | |
) | |
from awq.utils.utils import clear_memory | |
from transformers.models.mixtral.modeling_mixtral import ( | |
MixtralDecoderLayer as OldmixtralDecoderLayer, | |
MixtralForCausalLM as OldmixtralForCausalLM, | |
MoeModelOutputWithPast, | |
) | |
def _exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): | |
filtered_layers = {} | |
for name, linear_layer in linear_layers.items(): | |
if not any(key in name for key in modules_to_not_convert): | |
filtered_layers[name] = linear_layer | |
return filtered_layers | |
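# Illustrative example (module names as in HF Mixtral): with
# modules_to_not_convert=["gate"], "block_sparse_moe.gate" is dropped from the dict,
# while "block_sparse_moe.experts.0.w1" is kept and will be quantized.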
class MixtralAwqConfig(AwqConfig): | |
def __init__( | |
self, | |
*args, | |
modules_to_not_convert: Optional[List] = None, | |
**kwargs, | |
): | |
super().__init__(*args, **kwargs) | |
self.modules_to_not_convert = modules_to_not_convert | |
def to_dict(self): | |
return { | |
**super().to_dict(), | |
"modules_to_not_convert": self.modules_to_not_convert, | |
} | |
def to_transformers_dict(self): | |
return { | |
**super().to_transformers_dict(), | |
"modules_to_not_convert": self.modules_to_not_convert, | |
} | |
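# The extra `modules_to_not_convert` field is included in both to_dict() and
# to_transformers_dict() so it is persisted in the saved quantization config.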
class MixtralAwqQuantizer(AwqQuantizer): | |
def __init__(self, *args, modules_to_not_convert: List = [], **kwargs): | |
super().__init__(*args) | |
self.modules_to_not_convert = modules_to_not_convert | |
def init_quant(self, n_samples=128, seqlen=512): | |
modules = self.awq_model.get_model_layers(self.model) | |
samples = get_calib_dataset( | |
data=self.calib_data, | |
tokenizer=self.tokenizer, | |
n_samples=n_samples, | |
block_size=seqlen, | |
split=self.split, | |
text_column=self.text_column, | |
) | |
samples = torch.cat(samples, dim=0) | |
inps = [] | |
layer_kwargs = {} | |
modules[0] = modules[0].cuda() | |
self.awq_model.move_embed(self.model, "cuda") | |
# get input and kwargs to layer 0 | |
# with_kwargs is only supported in PyTorch 2.0 | |
# use this Catcher hack for now | |
class Catcher(nn.Module): | |
def __init__(self, module): | |
super().__init__() | |
self.module = module | |
#### <MODIFIED> | |
def forward(self, *args, **kwargs): | |
# assume first input to forward is hidden states | |
if len(args) > 0: | |
hidden_states = args[0] | |
del args | |
else: | |
first_key = list(kwargs.keys())[0] | |
hidden_states = kwargs.pop(first_key) | |
inps.append(hidden_states) | |
#### </MODIFIED> | |
layer_kwargs.update(kwargs) | |
raise ValueError # early exit to break later inference | |
# patch layer 0 to catch input and kwargs | |
modules[0] = Catcher(modules[0]) | |
try: | |
self.model(samples.to(next(self.model.parameters()).device)) | |
except ValueError: # work with early exit | |
pass | |
#### <MODIFIED> | |
# Update the layer kwargs with `prepare_inputs_for_generation` method | |
# that takes care of everything to avoid unexpected errors. | |
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs) | |
# Pop the input_ids as they are not needed at all. | |
layer_kwargs.pop("input_ids") | |
#### </MODIFIED> | |
del samples | |
modules[0] = modules[0].module # restore | |
inps = inps[0] | |
modules[0] = modules[0].cpu() | |
self.awq_model.move_embed(self.model, "cpu") | |
clear_memory() | |
if layer_kwargs.get("attention_mask") is not None: | |
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda") | |
return modules, layer_kwargs, inps | |
def quantize(self): | |
for i in tqdm(range(len(self.modules)), desc="AWQ"): | |
# Move module and inputs to correct device | |
common_device = next(self.modules[i].parameters()).device | |
if common_device is None or str(common_device) == "cpu": | |
self.modules[i] = self.modules[i].cuda() | |
common_device = next(self.modules[i].parameters()).device | |
self.inps = self.inps.to(common_device) | |
# [STEP 1]: Get layer, extract linear modules, extract input features | |
named_linears = get_named_linears(self.modules[i]) | |
#### <MODIFIED> | |
# Filter out the linear layers we don't want to quantize (modules_to_not_convert)
named_linears = _exclude_layers_to_not_quantize( | |
named_linears, self.modules_to_not_convert | |
) | |
#### </MODIFIED> | |
input_feat = self._get_input_feat(self.modules[i], named_linears) | |
clear_memory() | |
# [STEP 2]: Compute and apply scale list | |
module_config: List[Dict] = self.awq_model.get_layers_for_scaling( | |
self.modules[i], input_feat, self.module_kwargs | |
) | |
scales_list = [ | |
self._search_best_scale(self.modules[i], **layer) for layer in module_config | |
] | |
apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat) | |
scales_list = append_str_prefix( | |
scales_list, get_op_name(self.model, self.modules[i]) + "." | |
) | |
# [STEP 3]: Compute and apply clipping list | |
clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat) | |
apply_clip(self.modules[i], clip_list) | |
clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".") | |
# [STEP 4]: Quantize weights | |
self._apply_quant(self.modules[i], named_linears) | |
clear_memory() | |
def _get_input_feat(self, layer, named_linears): | |
# firstly, get input features of all linear layers | |
def cache_input_hook(m, x, y, name, feat_dict): | |
x = x[0] | |
x = x.detach().cpu() | |
feat_dict[name].append(x) | |
input_feat = defaultdict(list) | |
handles = [] | |
#### <MODIFIED> | |
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe} | |
#### </MODIFIED> | |
for name in named_linears: | |
handles.append( | |
named_linears[name].register_forward_hook( | |
functools.partial(cache_input_hook, name=name, feat_dict=input_feat) | |
) | |
) | |
self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu | |
# get output as next layer's input | |
self.inps = layer(self.inps, **self.module_kwargs)[0] | |
for h in handles: | |
h.remove() | |
# now solve for scaling and clipping | |
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()} | |
return input_feat | |
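# Note: hooking `block_sparse_moe` itself (see <MODIFIED> above) records the input shared
# by every expert's w1/w3; get_layers_for_scaling() below reads it as
# input_feat["block_sparse_moe"].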
class MixtralAWQForCausalLM(BaseAWQForCausalLM): | |
layer_type = "MixtralDecoderLayer" | |
max_new_tokens_key = "max_position_embeddings" | |
@torch.no_grad() | |
def quantize( | |
self, | |
tokenizer=None, | |
quant_config={}, | |
calib_data: Union[str, List[str]] = "pileval", | |
split="train", | |
text_column="text", | |
): | |
self.quant_config = MixtralAwqConfig.from_dict(quant_config) | |
#### <MODIFIED> | |
quantizer = MixtralAwqQuantizer( | |
self, | |
self.model, | |
tokenizer, | |
self.quant_config.w_bit, | |
self.quant_config.q_group_size, | |
self.quant_config.version, | |
calib_data, | |
split, | |
text_column, | |
modules_to_not_convert=self.quant_config.modules_to_not_convert, | |
) | |
#### </MODIFIED> | |
quantizer.quantize() | |
self.is_quantized = True | |
def _load_quantized_modules(self, model, quant_config, version): | |
# Real quantization of weights | |
assert quant_config.zero_point, "We only support zero_point quantization now." | |
# Get blocks of model | |
layers = self.get_model_layers(model) | |
for i in tqdm(range(len(layers)), desc="Replacing layers..."): | |
layer = layers[i] | |
# Get every linear layer in a block | |
named_linears = get_named_linears(layer) | |
#### <MODIFIED> | |
# Filter out the linear layers we don't want to quantize (modules_to_not_convert)
named_linears = _exclude_layers_to_not_quantize( | |
named_linears, quant_config.modules_to_not_convert | |
) | |
#### </MODIFIED> | |
# Replace activation functions | |
self._scale_activations(self, layer) | |
# Replace nn.Linear with WQLinear | |
for name, module in named_linears.items(): | |
if version == "GEMM": | |
q_linear_module = WQLinear_GEMM | |
elif version == "GEMV": | |
q_linear_module = WQLinear_GEMV | |
q_linear = q_linear_module.from_linear( | |
module, quant_config.w_bit, quant_config.q_group_size, True  # init_only=True (assumed): allocate the quantized layout, weights loaded afterwards
) | |
q_linear.to(next(layer.parameters()).device) | |
set_op_by_name(layer, name, q_linear) | |
torch.cuda.empty_cache() | |
gc.collect() | |
def _load_config( | |
self, | |
model_path, | |
model_filename, | |
safetensors=True, | |
version="GEMM", | |
trust_remote_code=True, | |
max_new_tokens=4096, | |
**config_kwargs, | |
): | |
# [STEP 1] Download model if path is not a directory | |
if not os.path.isdir(model_path): | |
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"] | |
if safetensors: | |
ignore_patterns.extend(["*.pt*", "*.bin*"]) | |
else: | |
ignore_patterns.append("*.safetensors*") | |
from huggingface_hub import snapshot_download | |
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns) | |
if model_filename != "": | |
model_weights_path = model_path + f"/{model_filename}" | |
else: | |
model_weights_path = model_path | |
#### <MODIFIED> | |
# [STEP 2] Load config and set sequence length | |
# TODO: Create BaseAWQConfig class | |
quant_config = MixtralAwqConfig.from_pretrained(model_path) | |
#### </MODIFIED> | |
from transformers import AutoConfig | |
# Load model config and set max generation length | |
if max_new_tokens is None and hasattr(self, "max_new_tokens_key"): | |
config = AutoConfig.from_pretrained( | |
model_path, trust_remote_code=trust_remote_code, **config_kwargs | |
) | |
config.max_new_tokens = getattr(config, self.max_new_tokens_key) | |
else: | |
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens | |
config = AutoConfig.from_pretrained( | |
model_path, trust_remote_code=trust_remote_code, **config_kwargs | |
) | |
config.max_new_tokens = max_new_tokens | |
return model_weights_path, config, quant_config | |
@staticmethod | |
def fuse_layers(model: OldmixtralForCausalLM): | |
fuser = MixtralFuser(model) | |
fuser.fuse_transformer() | |
@staticmethod | |
def get_model_layers(model: OldmixtralForCausalLM): | |
return model.model.layers | |
@staticmethod | |
def get_act_for_scaling(module): | |
return dict(is_scalable=False) | |
@staticmethod | |
def move_embed(model: OldmixtralForCausalLM, device: str): | |
model.model.embed_tokens = model.model.embed_tokens.to(device) | |
@staticmethod | |
def get_layers_for_scaling(module: OldmixtralDecoderLayer, input_feat, module_kwargs): | |
layers = [] | |
# attention input | |
layers.append( | |
dict( | |
prev_op=module.input_layernorm, | |
layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj], | |
inp=input_feat["self_attn.q_proj"], | |
module2inspect=module.self_attn, | |
kwargs=module_kwargs, | |
) | |
) | |
# attention out | |
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: | |
layers.append( | |
dict( | |
prev_op=module.self_attn.v_proj, | |
layers=[module.self_attn.o_proj], | |
inp=input_feat["self_attn.o_proj"], | |
) | |
) | |
layers.append( | |
dict( | |
prev_op=module.post_attention_layernorm, | |
layers=[ | |
w for expert in module.block_sparse_moe.experts for w in [expert.w1, expert.w3] | |
], | |
inp=input_feat[f"block_sparse_moe"], | |
module2inspect=module.block_sparse_moe, | |
) | |
) | |
# expert | |
for i, expert in enumerate(module.block_sparse_moe.experts): | |
layers.append( | |
dict( | |
prev_op=expert.w3, | |
layers=[expert.w2], | |
inp=input_feat[f"block_sparse_moe.experts.{i}.w2"], | |
) | |
) | |
return layers | |
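# Scaling groups per Mixtral block: qkv against input_layernorm, o_proj against v_proj
# (when shapes match), all experts' w1/w3 against post_attention_layernorm using the
# shared block_sparse_moe input, and each expert's w2 against its own w3.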
class MixtralMLP(nn.Module): | |
def __init__(self, gate_proj, down_proj, up_proj): | |
super().__init__() | |
self.fused_mlp = QuantFusedMLP(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) | |
def forward(self, hidden_states, routing_weights): | |
return routing_weights * self.fused_mlp(hidden_states) | |
class MixtralBlock(nn.Module): | |
def __init__( | |
self, | |
hidden_size, | |
n_heads, | |
n_kv_heads, | |
qkv_layer, | |
o_proj, | |
moe, | |
norm_1, | |
norm_2, | |
dev, | |
max_seq_len, | |
): | |
super().__init__() | |
self.n_heads = n_heads | |
self.n_kv_heads = n_kv_heads | |
self.hidden_size = hidden_size | |
self.norm_1 = norm_1.to(dev) | |
self.attn = QuantAttentionFused( | |
self.hidden_size, | |
self.n_heads, | |
self.n_kv_heads, | |
qkv_layer, | |
o_proj, | |
dev=dev, | |
max_seq_len=max_seq_len, | |
use_alibi=False, | |
).to(dev) | |
self.norm_2 = norm_2.to(dev) | |
self.moe = moe | |
self.device = dev | |
def forward( | |
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None | |
): | |
norm_out = self.norm_1(hidden_states) | |
attn_output, _, past_key_value = self.attn.forward( | |
hidden_states=norm_out, past_key_value=past_key_value, attention_mask=attention_mask | |
) | |
h = hidden_states.to(attn_output.device) + attn_output | |
moe_output, _ = self.moe.forward(self.norm_2(h)) | |
out = h + moe_output | |
return out, None, past_key_value | |
class MixtralModel(nn.Module): | |
""" | |
LlamaLikeModel is intended to be reused across models that have | |
an architecture that closely resembles Llama, e.g. Mistral and Aquila. | |
""" | |
def __init__(self, vocab_size, blocks, embedding, norm): | |
super().__init__() | |
self.vocab_size = vocab_size | |
self.embedding = embedding | |
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks) | |
self.norm = norm | |
self.last_forward_num_tokens = 0 | |
@torch.inference_mode() | |
def forward( | |
self, | |
input_ids: torch.Tensor, | |
attn_bias=None, | |
attention_mask=None, | |
is_causal=None, | |
*args, | |
**kwargs, | |
): | |
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( | |
input_ids, self.last_forward_num_tokens | |
) | |
_bsz, seqlen = input_ids.shape | |
fused_utils.prepare_cache(self.blocks, seqlen) | |
h = self.embedding(input_ids) | |
mask = fused_utils.prepare_attention_mask( | |
seqlen=seqlen, | |
start_pos=self.blocks[0].attn.start_pos, | |
device=input_ids.device, | |
type_as=h, | |
) | |
for layer in self.blocks: | |
h, mask = fused_utils.prepare_correct_devices( | |
layer, | |
h, | |
mask, | |
) | |
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) | |
h = self.norm(h) | |
return MoeModelOutputWithPast( | |
last_hidden_state=h, | |
past_key_values=past_key_value, | |
hidden_states=(), | |
attentions=(), | |
router_logits=(), | |
) | |
class MixtralFuser: | |
def __init__(self, model: OldmixtralForCausalLM): | |
self.model = model | |
self.mixtral_blocks: List[Tuple[str, OldmixtralDecoderLayer]] = [ | |
(name, module) | |
for name, module in self.model.named_modules() | |
if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower() | |
] | |
def fuse_transformer(self): | |
blocks = [] | |
module: OldmixtralDecoderLayer | |
for module in tqdm(self.model.model.layers, desc="Fusing layers..."): | |
device = next(iter(module.state_dict().values())).device | |
qkv = fused_utils.fuse_qkv( | |
module, module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj | |
) | |
# Adapt to mixture of experts | |
for i in range(len(module.block_sparse_moe.experts)): | |
mlp = MixtralMLP( | |
gate_proj=module.block_sparse_moe.experts[i].w1, | |
down_proj=module.block_sparse_moe.experts[i].w2, | |
up_proj=module.block_sparse_moe.experts[i].w3, | |
) | |
module.block_sparse_moe.experts[i] = mlp | |
norm_1 = FasterTransformerRMSNorm( | |
module.input_layernorm.weight, module.input_layernorm.variance_epsilon | |
) | |
norm_2 = FasterTransformerRMSNorm( | |
module.post_attention_layernorm.weight, | |
module.post_attention_layernorm.variance_epsilon, | |
) | |
blocks.append( | |
MixtralBlock( | |
hidden_size=self.model.config.hidden_size, | |
n_heads=self.model.config.num_attention_heads, | |
n_kv_heads=self.model.config.num_key_value_heads, | |
qkv_layer=qkv, | |
o_proj=module.self_attn.o_proj, | |
moe=module.block_sparse_moe, | |
norm_1=norm_1, | |
norm_2=norm_2, | |
dev=device, | |
max_seq_len=self.model.config.max_new_tokens, | |
) | |
) | |
self.model.model = MixtralModel( | |
self.model.config.vocab_size, | |
blocks, | |
self.model.model.embed_tokens, | |
self.model.model.norm, | |
) | |
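# After fusion, model.model is the fused MixtralModel above (QuantAttentionFused +
# per-expert QuantFusedMLP); the original lm_head on the causal-LM wrapper is unchanged.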
def quantize(): | |
from transformers import AutoTokenizer | |
MODEL = "./huggingface/Mixtral-8x7B-Instruct-v0.1" | |
CONFIG = { | |
"zero_point": True, | |
"q_group_size": 128, | |
"w_bit": 4, | |
"version": "GEMM", | |
"modules_to_not_convert": ["gate"], | |
} | |
model = MixtralAWQForCausalLM.from_pretrained( | |
MODEL, "mixtral", low_cpu_mem_usage=True, safetensors=True, device_map="cpu" | |
) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL) | |
# Quantize | |
model.quantize(tokenizer, quant_config=CONFIG) | |
# Save quantized model | |
output = MODEL.replace("huggingface", "awq") | |
model.save_quantized(output) | |
tokenizer.save_pretrained(output) | |
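# Intended flow: call quantize() once to write the AWQ checkpoint (the output path swaps
# "huggingface" for "awq"), then run the __main__ block below to load a quantized
# checkpoint with fuse_layers=True and generate.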
if __name__ == "__main__": | |
import torch | |
from transformers import AutoTokenizer, TextStreamer | |
MODEL = "./shared/awq/Mixtral-8x7B-Instruct-v0.1/" | |
model = MixtralAWQForCausalLM.from_quantized( | |
MODEL, | |
"mixtral", | |
fuse_layers=True, | |
) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL) | |
inputs = tokenizer("Hi, my name is", return_tensors="pt") | |
model.generate( | |
**{k: v.cuda() for k, v in inputs.items()}, | |
streamer=TextStreamer(tokenizer), | |
max_new_tokens=20, | |
) |
## pip install autoawq==0.1.7 | |
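##
## Variant of the script above: here MixtralAwqQuantizer.quantize() applies weight
## quantization directly, without the per-layer scale/clip search, and
## get_layers_for_scaling() returns no scaling groups.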
import functools | |
import gc | |
import os | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
from awq.models.base import BaseAWQForCausalLM | |
from awq.models._config import AwqConfig | |
from awq.modules.fused.attn import QuantAttentionFused | |
from awq.modules.fused.mlp import QuantFusedMLP | |
from awq.modules.fused.norm import FasterTransformerRMSNorm | |
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV | |
from awq.quantize.quantizer import AwqQuantizer | |
from awq.quantize.scale import apply_clip, apply_scale | |
from awq.utils import fused_utils | |
from awq.utils.calib_data import get_calib_dataset | |
from awq.utils.module import ( | |
append_str_prefix, | |
get_named_linears, | |
get_op_name, | |
set_op_by_name, | |
) | |
from awq.utils.utils import clear_memory | |
from transformers.models.mixtral.modeling_mixtral import ( | |
MixtralDecoderLayer as OldmixtralDecoderLayer, | |
MixtralForCausalLM as OldmixtralForCausalLM, | |
MoeModelOutputWithPast, | |
) | |
def _exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): | |
filtered_layers = {} | |
for name, linear_layer in linear_layers.items(): | |
if not any(key in name for key in modules_to_not_convert): | |
filtered_layers[name] = linear_layer | |
return filtered_layers | |
class MixtralAwqConfig(AwqConfig): | |
def __init__( | |
self, | |
*args, | |
modules_to_not_convert: Optional[List] = None, | |
**kwargs, | |
): | |
super().__init__(*args, **kwargs) | |
self.modules_to_not_convert = modules_to_not_convert | |
def to_dict(self): | |
return { | |
**super().to_dict(), | |
"modules_to_not_convert": self.modules_to_not_convert, | |
} | |
def to_transformers_dict(self): | |
return { | |
**super().to_transformers_dict(), | |
"modules_to_not_convert": self.modules_to_not_convert, | |
} | |
class MixtralAwqQuantizer(AwqQuantizer): | |
def __init__(self, *args, modules_to_not_convert: List = [], **kwargs): | |
super().__init__(*args) | |
self.modules_to_not_convert = modules_to_not_convert | |
def init_quant(self, n_samples=128, seqlen=512): | |
modules = self.awq_model.get_model_layers(self.model) | |
samples = get_calib_dataset( | |
data=self.calib_data, | |
tokenizer=self.tokenizer, | |
n_samples=n_samples, | |
block_size=seqlen, | |
split=self.split, | |
text_column=self.text_column, | |
) | |
samples = torch.cat(samples, dim=0) | |
inps = [] | |
layer_kwargs = {} | |
modules[0] = modules[0].cuda() | |
self.awq_model.move_embed(self.model, "cuda") | |
# get input and kwargs to layer 0 | |
# with_kwargs is only supported in PyTorch 2.0 | |
# use this Catcher hack for now | |
class Catcher(nn.Module): | |
def __init__(self, module): | |
super().__init__() | |
self.module = module | |
#### <MODIFIED> | |
def forward(self, *args, **kwargs): | |
# assume first input to forward is hidden states | |
if len(args) > 0: | |
hidden_states = args[0] | |
del args | |
else: | |
first_key = list(kwargs.keys())[0] | |
hidden_states = kwargs.pop(first_key) | |
inps.append(hidden_states) | |
#### </MODIFIED> | |
layer_kwargs.update(kwargs) | |
raise ValueError # early exit to break later inference | |
# patch layer 0 to catch input and kwargs | |
modules[0] = Catcher(modules[0]) | |
try: | |
self.model(samples.to(next(self.model.parameters()).device)) | |
except ValueError: # work with early exit | |
pass | |
#### <MODIFIED> | |
# Update the layer kwargs with `prepare_inputs_for_generation` method | |
# that takes care of everything to avoid unexpected errors. | |
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs) | |
# Pop the input_ids as they are not needed at all. | |
layer_kwargs.pop("input_ids") | |
#### </MODIFIED> | |
del samples | |
modules[0] = modules[0].module # restore | |
inps = inps[0] | |
modules[0] = modules[0].cpu() | |
self.awq_model.move_embed(self.model, "cpu") | |
clear_memory() | |
if layer_kwargs.get("attention_mask") is not None: | |
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda") | |
return modules, layer_kwargs, inps | |
def quantize(self): | |
for i in tqdm(range(len(self.modules)), desc="AWQ"): | |
# [STEP 1]: Get layer, extract linear modules, extract input features | |
named_linears = get_named_linears(self.modules[i]) | |
#### <MODIFIED> | |
# Filter out the linear layers we don't want to quantize (modules_to_not_convert)
named_linears = _exclude_layers_to_not_quantize( | |
named_linears, self.modules_to_not_convert | |
) | |
#### </MODIFIED> | |
self._apply_quant(self.modules[i], named_linears) | |
clear_memory() | |
def _get_input_feat(self, layer, named_linears): | |
# firstly, get input features of all linear layers | |
def cache_input_hook(m, x, y, name, feat_dict): | |
x = x[0] | |
x = x.detach().cpu() | |
feat_dict[name].append(x) | |
input_feat = defaultdict(list) | |
handles = [] | |
#### <MODIFIED> | |
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe} | |
#### </MODIFIED> | |
for name in named_linears: | |
handles.append( | |
named_linears[name].register_forward_hook( | |
functools.partial(cache_input_hook, name=name, feat_dict=input_feat) | |
) | |
) | |
self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu | |
# get output as next layer's input | |
self.inps = layer(self.inps, **self.module_kwargs)[0] | |
for h in handles: | |
h.remove() | |
# now solve for scaling and clipping | |
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()} | |
return input_feat | |
class MixtralAWQForCausalLM(BaseAWQForCausalLM): | |
layer_type = "MixtralDecoderLayer" | |
max_new_tokens_key = "max_position_embeddings" | |
@torch.no_grad() | |
def quantize( | |
self, | |
tokenizer=None, | |
quant_config={}, | |
calib_data: Union[str, List[str]] = "pileval", | |
split="train", | |
text_column="text", | |
): | |
self.quant_config = MixtralAwqConfig.from_dict(quant_config) | |
#### <MODIFIED> | |
quantizer = MixtralAwqQuantizer( | |
self, | |
self.model, | |
tokenizer, | |
self.quant_config.w_bit, | |
self.quant_config.q_group_size, | |
self.quant_config.version, | |
calib_data, | |
split, | |
text_column, | |
modules_to_not_convert=self.quant_config.modules_to_not_convert, | |
) | |
#### </MODIFIED> | |
quantizer.quantize() | |
self.is_quantized = True | |
def _load_quantized_modules(self, model, quant_config, version): | |
# Real quantization of weights | |
assert quant_config.zero_point, "We only support zero_point quantization now." | |
# Get blocks of model | |
layers = self.get_model_layers(model) | |
for i in tqdm(range(len(layers)), desc="Replacing layers..."): | |
layer = layers[i] | |
# Get every linear layer in a block | |
named_linears = get_named_linears(layer) | |
#### <MODIFIED> | |
# Filter out the linear layers we don't want to quantize (modules_to_not_convert)
named_linears = _exclude_layers_to_not_quantize( | |
named_linears, quant_config.modules_to_not_convert | |
) | |
#### </MODIFIED> | |
# Replace activation functions | |
self._scale_activations(self, layer) | |
# Replace nn.Linear with WQLinear | |
for name, module in named_linears.items(): | |
if version == "GEMM": | |
q_linear_module = WQLinear_GEMM | |
elif version == "GEMV": | |
q_linear_module = WQLinear_GEMV | |
q_linear = q_linear_module.from_linear( | |
module, quant_config.w_bit, quant_config.q_group_size, True  # init_only=True (assumed): allocate the quantized layout, weights loaded afterwards
) | |
q_linear.to(next(layer.parameters()).device) | |
set_op_by_name(layer, name, q_linear) | |
torch.cuda.empty_cache() | |
gc.collect() | |
def _load_config( | |
self, | |
model_path, | |
model_filename, | |
safetensors=True, | |
version="GEMM", | |
trust_remote_code=True, | |
max_new_tokens=4096, | |
**config_kwargs, | |
): | |
# [STEP 1] Download model if path is not a directory | |
if not os.path.isdir(model_path): | |
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"] | |
if safetensors: | |
ignore_patterns.extend(["*.pt*", "*.bin*"]) | |
else: | |
ignore_patterns.append("*.safetensors*") | |
from huggingface_hub import snapshot_download | |
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns) | |
if model_filename != "": | |
model_weights_path = model_path + f"/{model_filename}" | |
else: | |
model_weights_path = model_path | |
#### <MODIFIED> | |
# [STEP 2] Load config and set sequence length | |
# TODO: Create BaseAWQConfig class | |
quant_config = MixtralAwqConfig.from_pretrained(model_path) | |
#### </MODIFIED> | |
from transformers import AutoConfig | |
# Load model config and set max generation length | |
if max_new_tokens is None and hasattr(self, "max_new_tokens_key"): | |
config = AutoConfig.from_pretrained( | |
model_path, trust_remote_code=trust_remote_code, **config_kwargs | |
) | |
config.max_new_tokens = getattr(config, self.max_new_tokens_key) | |
else: | |
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens | |
config = AutoConfig.from_pretrained( | |
model_path, trust_remote_code=trust_remote_code, **config_kwargs | |
) | |
config.max_new_tokens = max_new_tokens | |
return model_weights_path, config, quant_config | |
@staticmethod | |
def fuse_layers(model: OldmixtralForCausalLM): | |
fuser = MixtralFuser(model) | |
fuser.fuse_transformer() | |
@staticmethod | |
def get_model_layers(model: OldmixtralForCausalLM): | |
return model.model.layers | |
@staticmethod | |
def get_act_for_scaling(module): | |
return dict(is_scalable=False) | |
@staticmethod | |
def move_embed(model: OldmixtralForCausalLM, device: str): | |
model.model.embed_tokens = model.model.embed_tokens.to(device) | |
@staticmethod | |
def get_layers_for_scaling(module: OldmixtralDecoderLayer, input_feat, module_kwargs): | |
layers = [] | |
return layers | |
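# No scaling groups in this variant; quantize() above calls _apply_quant directly.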
class MixtralMLP(nn.Module): | |
def __init__(self, gate_proj, down_proj, up_proj): | |
super().__init__() | |
self.fused_mlp = QuantFusedMLP(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) | |
def forward(self, hidden_states, routing_weights): | |
return routing_weights * self.fused_mlp(hidden_states) | |
class MixtralBlock(nn.Module): | |
def __init__( | |
self, | |
hidden_size, | |
n_heads, | |
n_kv_heads, | |
qkv_layer, | |
o_proj, | |
moe, | |
norm_1, | |
norm_2, | |
dev, | |
max_seq_len, | |
): | |
super().__init__() | |
self.n_heads = n_heads | |
self.n_kv_heads = n_kv_heads | |
self.hidden_size = hidden_size | |
self.norm_1 = norm_1.to(dev) | |
self.attn = QuantAttentionFused( | |
self.hidden_size, | |
self.n_heads, | |
self.n_kv_heads, | |
qkv_layer, | |
o_proj, | |
dev=dev, | |
max_seq_len=max_seq_len, | |
use_alibi=False, | |
).to(dev) | |
self.norm_2 = norm_2.to(dev) | |
self.moe = moe | |
self.device = dev | |
def forward( | |
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None | |
): | |
norm_out = self.norm_1(hidden_states) | |
attn_output, _, past_key_value = self.attn.forward( | |
hidden_states=norm_out, past_key_value=past_key_value, attention_mask=attention_mask | |
) | |
h = hidden_states.to(attn_output.device) + attn_output | |
moe_output, _ = self.moe.forward(self.norm_2(h)) | |
out = h + moe_output | |
return out, None, past_key_value | |
class MixtralModel(nn.Module): | |
""" | |
LlamaLikeModel is intended to be reused across models that have | |
an architecture that closely resembles Llama, e.g. Mistral and Aquila. | |
""" | |
def __init__(self, vocab_size, blocks, embedding, norm): | |
super().__init__() | |
self.vocab_size = vocab_size | |
self.embedding = embedding | |
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks) | |
self.norm = norm | |
self.last_forward_num_tokens = 0 | |
@torch.inference_mode() | |
def forward( | |
self, | |
input_ids: torch.Tensor, | |
attn_bias=None, | |
attention_mask=None, | |
is_causal=None, | |
*args, | |
**kwargs, | |
): | |
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( | |
input_ids, self.last_forward_num_tokens | |
) | |
_bsz, seqlen = input_ids.shape | |
fused_utils.prepare_cache(self.blocks, seqlen) | |
h = self.embedding(input_ids) | |
mask = fused_utils.prepare_attention_mask( | |
seqlen=seqlen, | |
start_pos=self.blocks[0].attn.start_pos, | |
device=input_ids.device, | |
type_as=h, | |
) | |
for layer in self.blocks: | |
h, mask = fused_utils.prepare_correct_devices( | |
layer, | |
h, | |
mask, | |
) | |
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) | |
h = self.norm(h) | |
return MoeModelOutputWithPast( | |
last_hidden_state=h, | |
past_key_values=past_key_value, | |
hidden_states=(), | |
attentions=(), | |
router_logits=(), | |
) | |
class MixtralFuser: | |
def __init__(self, model: OldmixtralForCausalLM): | |
self.model = model | |
self.mixtral_blocks: List[Tuple[str, OldmixtralDecoderLayer]] = [ | |
(name, module) | |
for name, module in self.model.named_modules() | |
if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower() | |
] | |
def fuse_transformer(self): | |
blocks = [] | |
module: OldmixtralDecoderLayer | |
for module in tqdm(self.model.model.layers, desc="Fusing layers..."): | |
device = next(iter(module.state_dict().values())).device | |
qkv = fused_utils.fuse_qkv( | |
module, module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj | |
) | |
# Adapt to mixture of experts | |
for i in range(len(module.block_sparse_moe.experts)): | |
mlp = MixtralMLP( | |
gate_proj=module.block_sparse_moe.experts[i].w1, | |
down_proj=module.block_sparse_moe.experts[i].w2, | |
up_proj=module.block_sparse_moe.experts[i].w3, | |
) | |
module.block_sparse_moe.experts[i] = mlp | |
norm_1 = FasterTransformerRMSNorm( | |
module.input_layernorm.weight, module.input_layernorm.variance_epsilon | |
) | |
norm_2 = FasterTransformerRMSNorm( | |
module.post_attention_layernorm.weight, | |
module.post_attention_layernorm.variance_epsilon, | |
) | |
blocks.append( | |
MixtralBlock( | |
hidden_size=self.model.config.hidden_size, | |
n_heads=self.model.config.num_attention_heads, | |
n_kv_heads=self.model.config.num_key_value_heads, | |
qkv_layer=qkv, | |
o_proj=module.self_attn.o_proj, | |
moe=module.block_sparse_moe, | |
norm_1=norm_1, | |
norm_2=norm_2, | |
dev=device, | |
max_seq_len=self.model.config.max_new_tokens, | |
) | |
) | |
self.model.model = MixtralModel( | |
self.model.config.vocab_size, | |
blocks, | |
self.model.model.embed_tokens, | |
self.model.model.norm, | |
) | |
def quantize(): | |
from transformers import AutoTokenizer | |
MODEL = "./shared/huggingface/Mixtral-8x7B-Instruct-v0.1" | |
CONFIG = { | |
"zero_point": True, | |
"q_group_size": 128, | |
"w_bit": 4, | |
"version": "GEMM", | |
"modules_to_not_convert": ["gate"], | |
} | |
model = MixtralAWQForCausalLM.from_pretrained( | |
MODEL, "mixtral", low_cpu_mem_usage=True, safetensors=True, device_map="cpu" | |
) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL) | |
# Quantize | |
model.quantize(tokenizer, quant_config=CONFIG) | |
# Save quantized model | |
output = MODEL.replace("huggingface", "awq") | |
model.save_quantized(output) | |
tokenizer.save_pretrained(output) | |
if __name__ == "__main__": | |
import torch | |
from transformers import AutoTokenizer, TextStreamer | |
MODEL = "./shared/awq/Mixtral-8x7B-Instruct-v0.1/" | |
model = MixtralAWQForCausalLM.from_quantized( | |
MODEL, | |
"mixtral", | |
fuse_layers=True, | |
) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL) | |
inputs = tokenizer("안녕하세요, 저는 ", return_tensors="pt") | |
model.generate( | |
**{k: v.cuda() for k, v in inputs.items()}, | |
streamer=TextStreamer(tokenizer), | |
max_new_tokens=20, | |
) |