import torch, time
import numpy as np
from tqdm import tqdm
import gc

def cleanup():
    # Release cached GPU memory and collect garbage between runs
    torch.cuda.empty_cache()
    gc.collect()

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1214-L1224
def hf_loglikelihood(logits, labels, vocab_size):
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    return loss
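
# Minimal sanity check (a sketch, not in the original gist): with the default
# 'mean' reduction, hf_loglikelihood should match cross-entropy computed by
# hand on the shifted tensors. Shapes here are illustrative assumptions.
def _check_hf_loglikelihood(vocab_size=32000):
    logits = torch.randn(1, 8, vocab_size)          # (batch, seq_len, vocab)
    labels = torch.randint(0, vocab_size, (1, 8))   # (batch, seq_len)
    ref = torch.nn.functional.cross_entropy(
        logits[:, :-1, :].reshape(-1, vocab_size), labels[:, 1:].reshape(-1)
    )
    assert torch.allclose(hf_loglikelihood(logits, labels, vocab_size), ref)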

# Adapted from https://huggingface.co/transformers/v4.2.2/perplexity.html
def eval_wikitext2(model, tokenizer, max_length=1024, stride=512, verbose=True):
    model.eval()

    # Pre-tokenized wikitext-2 test split (Llama-2 tokenizer)
    encodings = torch.load('encodings_wiki_test_llama2.pt')
    vocab_size = 32000  # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L24

    encodings['input_ids'] = encodings['input_ids'].to('cuda')

    lls, t = [], []
    for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings['input_ids'].size(1))
        trg_len = end_loc - i
        input_ids = encodings['input_ids'][:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # ignore context tokens in the loss

        t1 = time.time()
        with torch.no_grad():
            # Equivalent to: log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
            logits = model(input_ids).logits
            log_likelihood = hf_loglikelihood(logits=logits, labels=target_ids, vocab_size=vocab_size) * trg_len
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2 - t1)

        lls.append(log_likelihood)
        del input_ids, target_ids

    ppl = np.round(float(torch.exp(torch.stack(lls).sum() / end_loc)), 4)
    pred_time = np.round(np.mean(t), 3)
    if verbose:
        print('perplexity', ppl)
        print('time', str(pred_time) + ' sec')

    del encodings
    cleanup()
    return {'perplexity': ppl, 'prediction_time': pred_time}
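
# Usage sketch (not part of the original gist; the model and dataset ids below
# are the standard Hugging Face ones this script appears to assume). It builds
# the cached 'encodings_wiki_test_llama2.pt' from the wikitext-2 test split
# with the Llama-2 tokenizer, then runs the evaluation.
if __name__ == '__main__':
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_id = 'meta-llama/Llama-2-7b-hf'
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Cache the tokenized test split in the format eval_wikitext2() expects.
    test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    enc = tokenizer('\n\n'.join(test['text']), return_tensors='pt')
    torch.save({'input_ids': enc['input_ids']}, 'encodings_wiki_test_llama2.pt')

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')
    print(eval_wikitext2(model, tokenizer))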