# Follow installation from here: https://www.reddit.com/r/LocalLLaMA/comments/11o6o3f/how_to_install_llama_8bit_and_4bit/
import sys
import time
import inspect

import torch
import torch.nn as nn
from torch.nn import functional as F

import transformers
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer

# GPTQ-for-LLaMa must be cloned into repositories/ (see the installation link above)
sys.path.append('repositories/GPTQ-for-LLaMa')
from quant import make_quant

is_triton = False  # not used below

# Recursively collect every Conv2d/Linear submodule of `module`, keyed by its qualified name
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
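
# (Illustrative note, not part of the original gist.) find_layers returns a flat
# {qualified_name: module} dict, e.g. for a toy model:
#   find_layers(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)))
#   -> {'0': Linear(in=4, out=4), '2': Linear(in=4, out=2)}
# For LLaMA the keys are names like 'model.layers.0.self_attn.q_proj', which make_quant
# uses to decide which Linear layers to replace with quantized ones.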

# Build an empty LLaMA model from its config, swap its Linear layers for GPTQ quantized
# layers, then load the 4-bit weights from the checkpoint file
def load_quant(folder, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=None, kernel_switch_threshold=128, eval=True):
    exclude_layers = exclude_layers or ['lm_head']

    def noop(*args, **kwargs):
        pass

    config = AutoConfig.from_pretrained(folder, trust_remote_code=True)

    # Skip the usual random weight initialization; the real weights come from the checkpoint
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()

    # Every Linear/Conv2d except the excluded layers (lm_head) gets quantized
    layers = find_layers(model)
    for name in exclude_layers:
        if name in layers:
            del layers[name]

    # Pass only the kwargs that this version of GPTQ-for-LLaMa's make_quant accepts
    gptq_args = inspect.getfullargspec(make_quant).args
    make_quant_kwargs = {
        'module': model,
        'names': layers,
        'bits': wbits,
    }
    print(f"Loading model quantized to {wbits} bits")
    if 'groupsize' in gptq_args:
        make_quant_kwargs['groupsize'] = groupsize
    if 'faster' in gptq_args:
        make_quant_kwargs['faster'] = faster_kernel
    if 'kernel_switch_threshold' in gptq_args:
        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
    make_quant(**make_quant_kwargs)
    del layers

    model.load_state_dict(torch.load(folder + '/' + checkpoint), strict=False)
    model.seqlen = 2048
    return model
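
# (Sanity check, not in the original gist.) After load_quant, the attention/MLP Linear
# layers have been replaced by GPTQ-for-LLaMa quantized layers, e.g.:
#   m = load_quant(model_dir, checkpoint, 4, 128)
#   print(type(m.model.layers[0].self_attn.q_proj))  # QuantLinear (or similar), not nn.Linear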

# The model-loading wrapper (mirrors the loading code in modules/models.py)
def load_quantized(folder, checkpoint):
    threshold = 128
    model = load_quant(folder, checkpoint, 4, 128, kernel_switch_threshold=threshold)
    model = model.to(torch.device('cuda:0'))
    return model
model_dir = 'models/TheBloke_vicuna-13B-1.1-GPTQ-4bit-128g'
checkpoint = 'vicuna-13B-1.1-GPTQ-4bit-128g.compat.no-act-order.pt'
model = load_quantized(model_dir, checkpoint)
tokenizer = LlamaTokenizer.from_pretrained(model_dir, clean_up_tokenization_spaces=True)
# LLaMA special tokens: <s>=1, </s>=2, pad=0
try:
    tokenizer.eos_token_id = 2
    tokenizer.bos_token_id = 1
    tokenizer.pad_token_id = 0
except Exception:
    pass
print("Model loaded")
model.eval()
prompt = "USER: Why is the sky blue? Answer in one sentence. ASSISTANT:"
tokens = tokenizer(prompt, add_special_tokens=True, return_tensors="pt").to(torch.device('cuda:0'))
top_k = 20
max_tokens = 200
do_sample = True
idx = tokens['input_ids']
n_tokens = 0
token_counter = 0

# Token-by-token generation: take the logits at the last position, keep only the
# top-k candidates, sample (or take the argmax), append the new token, repeat
while True:
    with torch.no_grad():
        logits = model(idx)["logits"][:, -1, :]

    # Top-k filtering: mask out everything below the k-th largest logit
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = -float('Inf')
    probs = F.softmax(logits, dim=-1)

    # either sample from the distribution or take the most likely element
    if do_sample:
        idx_next = torch.multinomial(probs, num_samples=1)
    else:
        _, idx_next = torch.topk(probs, k=1, dim=-1)

    # append sampled index to the running sequence and continue
    idx = torch.cat((idx, idx_next), dim=1)

    # stop on the LLaMA end-of-sequence token (</s>, id 2) or after max_tokens
    if idx_next.item() == 2:
        break
    n_tokens += 1
    if n_tokens > max_tokens:
        break

    # print the partial decode every 5 tokens so progress is visible
    token_counter += 1
    if token_counter >= 5:
        result = tokenizer.batch_decode(idx, skip_special_tokens=True)
        print(result, end='\r')
        token_counter = 0

result = tokenizer.batch_decode(idx, skip_special_tokens=True)
print(result)
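
# (Sketch, not in the original gist.) The manual loop above can normally be replaced by
# transformers' built-in generation with equivalent settings (top-k sampling, stop on
# the </s> token, id 2); left commented out since it assumes model.generate works with
# the GPTQ-patched model:
#
#   output_ids = model.generate(
#       tokens['input_ids'],
#       do_sample=True,
#       top_k=20,
#       max_new_tokens=200,
#       eos_token_id=2,
#       pad_token_id=0,
#   )
#   print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])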