## pip install autoawq==0.1.7
## pip install transformers==4.36.2
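## AWQ 4-bit quantization for Mixtral-8x7B: a Mixtral-aware AutoAWQ config,
## quantizer, and fused inference modules, with support for excluding modules
## (such as the MoE router "gate") from quantization.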
import functools
import gc
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from tqdm import tqdm
from awq.models.base import BaseAWQForCausalLM
from awq.models._config import AwqConfig
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.quantizer import AwqQuantizer
from awq.quantize.scale import apply_clip, apply_scale
from awq.utils import fused_utils
from awq.utils.calib_data import get_calib_dataset
from awq.utils.module import (
    append_str_prefix,
    get_named_linears,
    get_op_name,
    set_op_by_name,
)
from awq.utils.utils import clear_memory
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldmixtralDecoderLayer,
    MixtralForCausalLM as OldmixtralForCausalLM,
    MoeModelOutputWithPast,
)
def _exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        if not any(key in name for key in modules_to_not_convert):
            filtered_layers[name] = linear_layer
    return filtered_layers
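# AwqConfig extended with `modules_to_not_convert`, so the excluded modules
# (e.g. the MoE router "gate") are recorded in the saved quantization config.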
class MixtralAwqConfig(AwqConfig):
    def __init__(
        self,
        *args,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.modules_to_not_convert = modules_to_not_convert

    def to_dict(self):
        return {
            **super().to_dict(),
            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def to_transformers_dict(self):
        return {
            **super().to_transformers_dict(),
            "modules_to_not_convert": self.modules_to_not_convert,
        }
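# Quantizer that (1) skips modules listed in `modules_to_not_convert`, (2) uses
# `prepare_inputs_for_generation` to build complete decoder-layer kwargs during
# calibration, and (3) also hooks `block_sparse_moe` so the MoE block input is
# available when searching scales for every expert's w1/w3 projections.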
class MixtralAwqQuantizer(AwqQuantizer):
    def __init__(self, *args, modules_to_not_convert: List = [], **kwargs):
        super().__init__(*args)
        self.modules_to_not_convert = modules_to_not_convert

    def init_quant(self, n_samples=128, seqlen=512):
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        modules[0] = modules[0].cuda()
        self.awq_model.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            #### <MODIFIED>
            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)
                inps.append(hidden_states)
                #### </MODIFIED>
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass

        #### <MODIFIED>
        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")
        #### </MODIFIED>

        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")

        return modules, layer_kwargs, inps
    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                self.modules[i] = self.modules[i].cuda()
                common_device = next(self.modules[i].parameters()).device

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            #### <MODIFIED>
            # Filter out the linear layers we do not want to quantize
            named_linears = _exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            #### </MODIFIED>

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer) for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 4]: Quantize weights
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()
    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        #### <MODIFIED>
        named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
        #### </MODIFIED>

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        self.inps = layer(self.inps, **self.module_kwargs)[0]

        for h in handles:
            h.remove()

        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat
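# AutoAWQ causal-LM wrapper for Mixtral: wires in the config/quantizer above,
# excludes the configured modules when replacing nn.Linear with WQLinear, and
# exposes the layer-scaling plan and fusion entry points used by AutoAWQ.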
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
    ):
        self.quant_config = MixtralAwqConfig.from_dict(quant_config)

        #### <MODIFIED>
        quantizer = MixtralAwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
        )
        #### </MODIFIED>
        quantizer.quantize()
        self.is_quantized = True

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            #### <MODIFIED>
            # Filter out the linear layers we do not want to quantize
            named_linears = _exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )
            #### </MODIFIED>

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()
    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            from huggingface_hub import snapshot_download

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        #### <MODIFIED>
        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = MixtralAwqConfig.from_pretrained(model_path)
        #### </MODIFIED>

        from transformers import AutoConfig

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config
    @staticmethod
    def fuse_layers(model: OldmixtralForCausalLM):
        fuser = MixtralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldmixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldmixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldmixtralDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # moe: scale every expert's w1/w3 against the block_sparse_moe input
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[
                    w for expert in module.block_sparse_moe.experts for w in [expert.w1, expert.w3]
                ],
                inp=input_feat["block_sparse_moe"],
                module2inspect=module.block_sparse_moe,
            )
        )

        # expert
        for i, expert in enumerate(module.block_sparse_moe.experts):
            layers.append(
                dict(
                    prev_op=expert.w3,
                    layers=[expert.w2],
                    inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
                )
            )

        return layers
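# Fused inference modules: each expert's w1/w2/w3 are wrapped in a QuantFusedMLP,
# and the routing weight is applied to the expert output outside the fused MLP.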
class MixtralMLP(nn.Module):
    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__()
        self.fused_mlp = QuantFusedMLP(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)

    def forward(self, hidden_states, routing_weights):
        return routing_weights * self.fused_mlp(hidden_states)


class MixtralBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out, past_key_value=past_key_value, attention_mask=attention_mask
        )
        h = hidden_states.to(attn_output.device) + attn_output
        moe_output, _ = self.moe.forward(self.norm_2(h))
        out = h + moe_output
        return out, None, past_key_value
class MixtralModel(nn.Module):
    """
    Fused Mixtral model: token embedding, a stack of MixtralBlock layers, and the
    final RMSNorm. Mirrors AutoAWQ's Llama-like fused models but returns a
    MoeModelOutputWithPast.
    """

    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return MoeModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
            router_logits=(),
        )
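# Rewrites a loaded (quantized) MixtralForCausalLM in place: fuses q/k/v into a
# single projection, wraps each expert in MixtralMLP, and replaces model.model
# with the fused MixtralModel above.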
class MixtralFuser:
    def __init__(self, model: OldmixtralForCausalLM):
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldmixtralDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldmixtralDecoderLayer
        for module in tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fused_utils.fuse_qkv(
                module, module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj
            )
            # Adapt to mixture of experts
            for i in range(len(module.block_sparse_moe.experts)):
                mlp = MixtralMLP(
                    gate_proj=module.block_sparse_moe.experts[i].w1,
                    down_proj=module.block_sparse_moe.experts[i].w2,
                    up_proj=module.block_sparse_moe.experts[i].w3,
                )
                module.block_sparse_moe.experts[i] = mlp
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                MixtralBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    moe=module.block_sparse_moe,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_new_tokens,
                )
            )

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
def quantize():
    from transformers import AutoTokenizer

    MODEL = "./huggingface/Mixtral-8x7B-Instruct-v0.1"
    CONFIG = {
        "zero_point": True,
        "q_group_size": 128,
        "w_bit": 4,
        "version": "GEMM",
        "modules_to_not_convert": ["gate"],
    }

    model = MixtralAWQForCausalLM.from_pretrained(
        MODEL, "mixtral", low_cpu_mem_usage=True, safetensors=True, device_map="cpu"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    # Quantize
    model.quantize(tokenizer, quant_config=CONFIG)

    # Save quantized model
    output = MODEL.replace("huggingface", "awq")
    model.save_quantized(output)
    tokenizer.save_pretrained(output)


if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer, TextStreamer

    MODEL = "./shared/awq/Mixtral-8x7B-Instruct-v0.1/"
    model = MixtralAWQForCausalLM.from_quantized(
        MODEL,
        "mixtral",
        fuse_layers=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    inputs = tokenizer("Hi, my name is", return_tensors="pt")
    model.generate(
        **{k: v.cuda() for k, v in inputs.items()},
        streamer=TextStreamer(tokenizer),
        max_new_tokens=20,
    )
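## ---------------------------------------------------------------------------
## Second listing in this gist: a variant of the same module in which
## MixtralAwqQuantizer.quantize() skips the AWQ scale/clip search and only
## calls _apply_quant, and get_layers_for_scaling() returns an empty list.
## ---------------------------------------------------------------------------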
## pip install autoawq==0.1.7
import functools
import gc
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from tqdm import tqdm
from awq.models.base import BaseAWQForCausalLM
from awq.models._config import AwqConfig
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.quantizer import AwqQuantizer
from awq.quantize.scale import apply_clip, apply_scale
from awq.utils import fused_utils
from awq.utils.calib_data import get_calib_dataset
from awq.utils.module import (
    append_str_prefix,
    get_named_linears,
    get_op_name,
    set_op_by_name,
)
from awq.utils.utils import clear_memory
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldmixtralDecoderLayer,
    MixtralForCausalLM as OldmixtralForCausalLM,
    MoeModelOutputWithPast,
)


def _exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        if not any(key in name for key in modules_to_not_convert):
            filtered_layers[name] = linear_layer
    return filtered_layers


class MixtralAwqConfig(AwqConfig):
    def __init__(
        self,
        *args,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.modules_to_not_convert = modules_to_not_convert

    def to_dict(self):
        return {
            **super().to_dict(),
            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def to_transformers_dict(self):
        return {
            **super().to_transformers_dict(),
            "modules_to_not_convert": self.modules_to_not_convert,
        }


class MixtralAwqQuantizer(AwqQuantizer):
    def __init__(self, *args, modules_to_not_convert: List = [], **kwargs):
        super().__init__(*args)
        self.modules_to_not_convert = modules_to_not_convert

    def init_quant(self, n_samples=128, seqlen=512):
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        modules[0] = modules[0].cuda()
        self.awq_model.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            #### <MODIFIED>
            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)
                inps.append(hidden_states)
                #### </MODIFIED>
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass

        #### <MODIFIED>
        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")
        #### </MODIFIED>

        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")

        return modules, layer_kwargs, inps
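    # NOTE: this quantize() applies the quantized weights directly and skips the
    # AWQ scale/clip search.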
    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            #### <MODIFIED>
            # Filter out the linear layers we do not want to quantize
            named_linears = _exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            #### </MODIFIED>

            self._apply_quant(self.modules[i], named_linears)
            clear_memory()
    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        #### <MODIFIED>
        named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
        #### </MODIFIED>

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        self.inps = layer(self.inps, **self.module_kwargs)[0]

        for h in handles:
            h.remove()

        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
    ):
        self.quant_config = MixtralAwqConfig.from_dict(quant_config)

        #### <MODIFIED>
        quantizer = MixtralAwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            modules_to_not_convert=self.quant_config.modules_to_not_convert,
        )
        #### </MODIFIED>
        quantizer.quantize()
        self.is_quantized = True

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            #### <MODIFIED>
            # Filter out the linear layers we do not want to quantize
            named_linears = _exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )
            #### </MODIFIED>

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()
    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            from huggingface_hub import snapshot_download

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

        #### <MODIFIED>
        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config = MixtralAwqConfig.from_pretrained(model_path)
        #### </MODIFIED>

        from transformers import AutoConfig

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    @staticmethod
    def fuse_layers(model: OldmixtralForCausalLM):
        fuser = MixtralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldmixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldmixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldmixtralDecoderLayer, input_feat, module_kwargs):
        layers = []
        return layers
class MixtralMLP(nn.Module):
    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__()
        self.fused_mlp = QuantFusedMLP(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)

    def forward(self, hidden_states, routing_weights):
        return routing_weights * self.fused_mlp(hidden_states)


class MixtralBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out, past_key_value=past_key_value, attention_mask=attention_mask
        )
        h = hidden_states.to(attn_output.device) + attn_output
        moe_output, _ = self.moe.forward(self.norm_2(h))
        out = h + moe_output
        return out, None, past_key_value
class MixtralModel(nn.Module):
    """
    Fused Mixtral model: token embedding, a stack of MixtralBlock layers, and the
    final RMSNorm. Mirrors AutoAWQ's Llama-like fused models but returns a
    MoeModelOutputWithPast.
    """

    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return MoeModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
            router_logits=(),
        )
class MixtralFuser:
    def __init__(self, model: OldmixtralForCausalLM):
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldmixtralDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldmixtralDecoderLayer
        for module in tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fused_utils.fuse_qkv(
                module, module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj
            )
            # Adapt to mixture of experts
            for i in range(len(module.block_sparse_moe.experts)):
                mlp = MixtralMLP(
                    gate_proj=module.block_sparse_moe.experts[i].w1,
                    down_proj=module.block_sparse_moe.experts[i].w2,
                    up_proj=module.block_sparse_moe.experts[i].w3,
                )
                module.block_sparse_moe.experts[i] = mlp
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                MixtralBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    moe=module.block_sparse_moe,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_new_tokens,
                )
            )

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
def quantize():
    from transformers import AutoTokenizer

    MODEL = "./shared/huggingface/Mixtral-8x7B-Instruct-v0.1"
    CONFIG = {
        "zero_point": True,
        "q_group_size": 128,
        "w_bit": 4,
        "version": "GEMM",
        "modules_to_not_convert": ["gate"],
    }

    model = MixtralAWQForCausalLM.from_pretrained(
        MODEL, "mixtral", low_cpu_mem_usage=True, safetensors=True, device_map="cpu"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    # Quantize
    model.quantize(tokenizer, quant_config=CONFIG)

    # Save quantized model
    output = MODEL.replace("huggingface", "awq")
    model.save_quantized(output)
    tokenizer.save_pretrained(output)


if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer, TextStreamer

    MODEL = "./shared/awq/Mixtral-8x7B-Instruct-v0.1/"
    model = MixtralAWQForCausalLM.from_quantized(
        MODEL,
        "mixtral",
        fuse_layers=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    inputs = tokenizer("안녕하세요, 저는 ", return_tensors="pt")
    model.generate(
        **{k: v.cuda() for k, v in inputs.items()},
        streamer=TextStreamer(tokenizer),
        max_new_tokens=20,
    )