Skip to content

Instantly share code, notes, and snippets.

@mzbac
Created September 6, 2023 10:26
Show Gist options
  • Save mzbac/e6f1a328ab2606cec63f879def6d1458 to your computer and use it in GitHub Desktop.
Save mzbac/e6f1a328ab2606cec63f879def6d1458 to your computer and use it in GitHub Desktop.
import os
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import numpy as np
import torch
import torch.nn as nn
# Directory holding the full-precision (merged) model that will be quantized.
pretrained_model_dir = "./merged_models"
# Output directory that receives the quantized weights and the tokenizer files.
quantized_model_dir = "./models/CodeLlama-34b-guanaco-gptq"
def get_wikitext2(nsamples: int, seed: int, seqlen: int, model: str):
    """Build a calibration set of random token windows from WikiText-2.

    The train and test splits of ``wikitext-2-raw-v1`` are each joined with
    blank lines and tokenized as a single sequence with the tokenizer found
    at ``model``. ``nsamples`` windows of ``seqlen`` consecutive tokens are
    then drawn at random offsets from the tokenized train split.

    Args:
        nsamples: Number of calibration windows to draw.
        seed: Seed applied to the ``random``, NumPy and torch RNGs. Only the
            ``random`` module drives the window offsets.
        seqlen: Length, in tokens, of every calibration window.
        model: Name or path handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(traindataset, testenc)`` tuple. ``traindataset`` is a list of
        ``nsamples`` dicts with keys ``'input_ids'`` and ``'attention_mask'``
        (an all-ones tensor shaped like ``input_ids``); ``testenc`` is the
        tokenizer output for the whole test split.

    Raises:
        ValueError: If samples are requested but the tokenized train split
            has fewer than ``seqlen + 1`` tokens.
    """
    import random

    from datasets import load_dataset
    from transformers import AutoTokenizer

    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Prefer the slow tokenizer and fall back to the fast one when it cannot
    # be loaded. Catch Exception rather than everything so Ctrl-C and
    # SystemExit are not swallowed by the fallback.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)

    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt').to('cpu')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt').to('cpu')

    # Honour the caller's seed for every RNG instead of pinning NumPy and
    # torch to 0; for seed=0 this is identical to the previous behaviour.
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # Largest admissible window start (inclusive bound for random.randint).
    max_start = trainenc.input_ids.shape[1] - seqlen - 1
    if nsamples > 0 and max_start < 0:
        raise ValueError(
            f"tokenized train split has {trainenc.input_ids.shape[1]} tokens, "
            f"need at least {seqlen + 1} for seqlen={seqlen}"
        )

    traindataset = []
    for _ in range(nsamples):
        start = random.randint(0, max_start)
        inp = trainenc.input_ids[:, start:start + seqlen]
        traindataset.append({'input_ids': inp, 'attention_mask': torch.ones_like(inp)})
    return traindataset, testenc
def main():
    """Quantize the model in ``pretrained_model_dir`` with GPTQ and save it.

    Draws 128 WikiText-2 calibration windows of 2048 tokens (seed 0),
    quantizes the model to 4 bits with group size 128 and ``desc_act``
    disabled, then writes the quantized weights (safetensors) and the
    tokenizer to ``quantized_model_dir``.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

    # Only the calibration samples are needed here; the tokenized test split
    # returned alongside them is not used.
    traindataset, _ = get_wikitext2(
        nsamples=128,
        seed=0,
        seqlen=2048,
        model=pretrained_model_dir,
    )

    quantize_config = BaseQuantizeConfig(
        bits=4,
        group_size=128,
        desc_act=False,
    )

    # The empty-string key maps the whole model to the CPU at load time.
    model = AutoGPTQForCausalLM.from_pretrained(
        pretrained_model_dir,
        quantize_config,
        device_map={"": "cpu"},
    )
    model.quantize(traindataset, use_triton=False)

    model.save_quantized(quantized_model_dir, use_safetensors=True)
    # Save the tokenizer into the same directory as the quantized weights.
    tokenizer.save_pretrained(quantized_model_dir)
if __name__ == "__main__":
    import logging

    # Configure INFO-level logging with timestamps before starting the run.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment