# Gist: Pozaza/_const.py (secret), created October 6, 2023
# Make Mistral 7B 0.1 GPTQ work: add "mistral" to AutoGPTQ's supported model
# types (_const.py) and route quantized Mistral checkpoints through the Llama
# GPTQ loader (auto.py).

# --- first file: _const.py (auto_gptq/modeling/_const.py) ---

from packaging.version import parse as parse_version
from torch import device
from ..utils.import_utils import compare_transformers_version

CPU = device("cpu")
CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = [
    "bloom",
    "gptj",
    "gpt2",
    "gpt_neox",
    "opt",
    "moss",
    "gpt_bigcode",
    "codegen",
    "RefinedWebModel",
    "RefinedWeb",
    "baichuan",
    "internlm",
    "qwen",
    "mistral",  # added so check_and_get_model_type accepts Mistral checkpoints
]
if compare_transformers_version("v4.28.0", op="ge"):
    SUPPORTED_MODELS.append("llama")
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]


# --- second file: auto.py (auto_gptq/modeling/auto.py) ---

from inspect import signature
from typing import Dict, Optional, Union

from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM

GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM,
    "RefinedWeb": RWGPTQForCausalLM,
    "baichuan": BaiChuanGPTQForCausalLM,
    "internlm": InternLMGPTQForCausalLM,
    "qwen": QwenGPTQForCausalLM,
    # "mistral" is intentionally absent: from_quantized() below remaps it to
    # the Llama loader, since Mistral 7B uses a Llama-style architecture.
}


class AutoGPTQForCausalLM:
    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is designed to be instantiated\n"
            "using `AutoGPTQForCausalLM.from_pretrained` if you want to quantize a pretrained model, or\n"
            "using `AutoGPTQForCausalLM.from_quantized` if you want to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        **model_init_kwargs
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(
            pretrained_model_name_or_path, trust_remote_code
        )
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            trust_remote_code=trust_remote_code,
            **model_init_kwargs
        )
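
    # Usage sketch for quantizing a pretrained Mistral model (model id and
    # calibration data are placeholders; `bits`/`group_size` are standard
    # BaseQuantizeConfig parameters):
    #
    #     quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
    #     model = AutoGPTQForCausalLM.from_pretrained(
    #         "mistralai/Mistral-7B-v0.1", quantize_config
    #     )
    #     model.quantize(calibration_examples)  # list of tokenized samples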
    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        inject_fused_attention: bool = True,
        inject_fused_mlp: bool = True,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
        disable_exllama: bool = False,
        **kwargs
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
        # Mistral 7B 0.1 is architecturally Llama-compatible, so reuse the
        # Llama GPTQ wrapper rather than adding a dedicated Mistral class.
        if model_type == "mistral":
            model_type = "llama"
        quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized

        # A static list of kwargs needed for huggingface_hub
        huggingface_kwargs = [
            "cache_dir",
            "force_download",
            "proxies",
            "resume_download",
            "local_files_only",
            "use_auth_token",
            "revision",
            "subfolder",
            "_raise_exceptions_for_missing_entries",
            "_commit_hash",
        ]

        # TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
        # Forward only the kwargs that the target class's `from_quantized`
        # signature accepts, plus the huggingface_hub kwargs listed above.
        keywords = {
            key: kwargs[key]
            for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
            if key in kwargs
        }
        return quant_func(
            model_name_or_path=model_name_or_path,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            low_cpu_mem_usage=low_cpu_mem_usage,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            trainable=trainable,
            disable_exllama=disable_exllama,
            **keywords
        )
__all__ = ["AutoGPTQForCausalLM"]