Make Mistral 7B 0.1 GPTQ work
First file (appears to be AutoGPTQ's auto_gptq/modeling/_const.py):
from packaging.version import parse as parse_version

from torch import device

from ..utils.import_utils import compare_transformers_version

CPU = device("cpu")
CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = [
    "bloom",
    "gptj",
    "gpt2",
    "gpt_neox",
    "opt",
    "moss",
    "gpt_bigcode",
    "codegen",
    "RefinedWebModel",
    "RefinedWeb",
    "baichuan",
    "internlm",
    "qwen",
    "mistral",
]
if compare_transformers_version("v4.28.0", op="ge"):
    SUPPORTED_MODELS.append("llama")

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
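Note (not part of the gist): the change in this first file is the "mistral" entry in SUPPORTED_MODELS. Since transformers only ships a Mistral architecture from roughly v4.34.0 onward, a slightly safer variant would gate that entry behind a version check, mirroring how "llama" is handled above. A minimal sketch, assuming the v4.34.0 cutoff is correct:

# Variant sketch (assumption: MistralForCausalLM is available in transformers >= v4.34.0):
# append "mistral" only when the installed transformers knows the architecture,
# instead of listing it unconditionally in SUPPORTED_MODELS.
if compare_transformers_version("v4.34.0", op="ge"):
    SUPPORTED_MODELS.append("mistral")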
Second file (appears to be auto_gptq/modeling/auto.py):
from inspect import signature
from typing import Dict, Optional, Union

from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM

GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM,
    "RefinedWeb": RWGPTQForCausalLM,
    "baichuan": BaiChuanGPTQForCausalLM,
    "internlm": InternLMGPTQForCausalLM,
    "qwen": QwenGPTQForCausalLM,
}


class AutoGPTQForCausalLM:
    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is designed to be instantiated\n"
            "using `AutoGPTQForCausalLM.from_pretrained` if you want to quantize a pretrained model, or\n"
            "using `AutoGPTQForCausalLM.from_quantized` if you want to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        **model_init_kwargs
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(
            pretrained_model_name_or_path, trust_remote_code
        )
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            trust_remote_code=trust_remote_code,
            **model_init_kwargs
        )

    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        inject_fused_attention: bool = True,
        inject_fused_mlp: bool = True,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
        disable_exllama: bool = False,
        **kwargs
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
        # The Mistral fix: reuse the Llama GPTQ wrapper for "mistral" checkpoints,
        # since no dedicated Mistral class exists in this version.
        if model_type == "mistral":
            model_type = "llama"
        quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
        # A static list of kwargs needed for huggingface_hub
        huggingface_kwargs = [
            "cache_dir",
            "force_download",
            "proxies",
            "resume_download",
            "local_files_only",
            "use_auth_token",
            "revision",
            "subfolder",
            "_raise_exceptions_for_missing_entries",
            "_commit_hash",
        ]
        # TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
        keywords = {
            key: kwargs[key]
            for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
            if key in kwargs
        }
        return quant_func(
            model_name_or_path=model_name_or_path,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            low_cpu_mem_usage=low_cpu_mem_usage,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            trainable=trainable,
            disable_exllama=disable_exllama,
            **keywords
        )


__all__ = ["AutoGPTQForCausalLM"]
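The second file's patch is the small remap in from_quantized: a "mistral" checkpoint is routed through the Llama GPTQ wrapper. With both patched files dropped into an installed auto_gptq/modeling/ package, loading a quantized Mistral should work through the usual entry point. A minimal usage sketch, not part of the gist (the repo id, prompt format, and generation settings are only examples):

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

# Example GPTQ-quantized Mistral 7B 0.1 checkpoint; substitute your own repo id or local path.
model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoGPTQForCausalLM.from_quantized(
    model_id,
    device="cuda:0",
    use_safetensors=True,   # most published GPTQ checkpoints ship .safetensors weights
    trust_remote_code=False,
)

prompt = "[INST] Explain GPTQ quantization in one sentence. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))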