Last active
August 22, 2023 12:28
-
-
Save bjascob/7cece1eb654301925f79eb567af097f8 to your computer and use it in GitHub Desktop.
Bart token level perplexity.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import logging | |
import statistics | |
from tqdm import tqdm | |
import torch | |
from torch.utils.data import DataLoader, Dataset | |
import datasets | |
from datasets import load_dataset | |
import transformers | |
from transformers import AutoTokenizer | |
from transformers import DataCollatorForLanguageModeling | |
class BartPerplexityTester:
    """Measure a BART model's token-level (masked-LM) perplexity on WikiText-2.

    The test text is tokenized, split into fixed-length samples, randomly masked by
    ``DataCollatorForLanguageModeling`` and scored by the model's cross-entropy loss
    on the masked positions only.
    """

    def __init__(self, model, num_test_chars=None, device='cuda'):
        """
        Args:
            model: a BART model that returns ``.loss`` when called with ``labels``
                (e.g. ``AutoModelForMaskedLM.from_pretrained('facebook/bart-base')``).
            num_test_chars: keep only this many leading characters of the corpus;
                ``None`` keeps all of it (roughly, 10K chars ~= 2246 tokens).
            device: torch device string that the model and batches are moved to.
        """
        self.device = torch.device(device)
        self.model = model.to(self.device)
        # bart-base and bart-large use the same tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
        self.text = self.load_text(num_test_chars)

    def load_text(self, num_test_chars):
        """Return the WikiText-2 (raw) test split joined into a single string.

        Args:
            num_test_chars: truncate the joined text to this many characters;
                ``None`` returns the full text.
        """
        # Silence the "Reusing dataset wikitext ..." messages from the datasets library
        logging.getLogger('datasets').setLevel(logging.ERROR)
        articles = [a['text'] for a in load_dataset("wikitext", "wikitext-2-raw-v1", split="test")]
        return ' '.join(articles)[:num_test_chars]

    def run_test(self, seq_len=None, num_test_chars=None, batch_size=8, mlm_prob=0.15):
        """Return the masked-token perplexity of the model on the loaded text.

        Args:
            seq_len: number of text tokens per sample, not counting the bos/eos
                tokens added here. Defaults to the tokenizer's ``model_max_length``
                minus 2, so that a sample plus its bos/eos still fits the model.
            num_test_chars: optionally truncate the loaded text further to this
                many characters before tokenizing (``None`` uses all of it).
            batch_size: number of samples per forward pass.
            mlm_prob: probability with which the collator selects a token for masking.

        Returns:
            ``exp`` of the mean cross-entropy per masked token, or ``inf`` if that
            overflows.

        Raises:
            ValueError: if the text is too short for a single ``seq_len`` sample,
                if ``seq_len`` is not positive, or if no token was masked at all.
        """
        text = self.text[:num_test_chars]
        # Tokenize. verbose=False suppresses the 'token sequence too long for model' warning
        tok_ids = self.tokenizer(text, add_special_tokens=False, verbose=False).input_ids
        if seq_len is None:
            # Reserve two positions for the bos/eos added below: a sample of exactly
            # model_max_length tokens plus two special tokens is longer than the model accepts.
            seq_len = self.tokenizer.model_max_length - 2
        bos = self.tokenizer.bos_token_id  # = 0
        eos = self.tokenizer.eos_token_id  # = 2
        dst = self.model.config.decoder_start_token_id  # = 2 (same as eos token id)
        input_ids, decoder_ids = self._build_examples(tok_ids, seq_len, bos, eos, dst)
        if not input_ids:
            raise ValueError('Text has only {:,} tokens, fewer than seq_len={:,}'.format(
                len(tok_ids), seq_len))
        print('Loaded {:,} samples of length {:,} tokens'.format(len(input_ids), seq_len))

        # The collator randomly masks the input_ids and creates 'labels' holding -100
        # for every non-masked position, so only masked tokens contribute to the loss.
        dataset = EvalDataset(input_ids, decoder_ids)
        collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=mlm_prob)
        dataloader = DataLoader(dataset, collate_fn=collator, batch_size=batch_size)

        print('Testing')
        self.model.eval()
        batch_losses, batch_counts = [], []
        with torch.no_grad():
            for batch in tqdm(dataloader, ncols=100, disable=False):
                labels = batch['labels'].to(self.device)
                n_scored = int((labels != -100).sum().item())
                if n_scored == 0:
                    # Nothing was masked in this batch; its mean loss would be NaN.
                    continue
                outputs = self.model(input_ids=batch['input_ids'].to(self.device),
                                     labels=labels,
                                     decoder_input_ids=batch['decoder_input_ids'].to(self.device))
                # Assumes outputs.loss is the mean over label positions != -100
                # (Hugging Face convention); weight it by that count below.
                batch_losses.append(outputs.loss.item())
                batch_counts.append(n_scored)
        return self._perplexity_from_batches(batch_losses, batch_counts)

    @staticmethod
    def _build_examples(tok_ids, seq_len, bos, eos, dst):
        """Split tok_ids into full-length chunks and build encoder/decoder inputs.

        Each chunk of exactly ``seq_len`` ids becomes ``[bos] + chunk + [eos]``. A
        trailing chunk shorter than ``seq_len`` is discarded. The decoder input is
        the encoder input shifted right by one with ``dst`` prepended, mirroring
        ``shift_tokens_right``.

        Returns:
            ``(input_ids, decoder_ids)``, two lists of equal length.
        """
        if seq_len <= 0:
            raise ValueError('seq_len must be positive, got {}'.format(seq_len))
        input_ids = [[bos] + tok_ids[start:start + seq_len] + [eos]
                     for start in range(0, len(tok_ids) - seq_len + 1, seq_len)]
        decoder_ids = [[dst] + ids[:-1] for ids in input_ids]
        return input_ids, decoder_ids

    @staticmethod
    def _perplexity_from_batches(batch_losses, batch_counts):
        """Return exp of the token-weighted mean of per-batch mean losses.

        Weighting by the number of scored tokens makes the result the true
        per-token average, even when batches contain different numbers of
        masked tokens.
        """
        total = sum(batch_counts)
        if total == 0:
            raise ValueError('No masked tokens were scored; cannot compute perplexity')
        mean_loss = sum(loss * count for loss, count in zip(batch_losses, batch_counts)) / total
        try:
            return math.exp(mean_loss)
        except OverflowError:
            return float('inf')
def chunk(lst, n):
    """Lazily yield consecutive slices of ``lst``, each ``n`` items long.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
class EvalDataset(Dataset):
    """Map-style dataset of pre-tokenized encoder/decoder id sequences.

    Each item is a dict with ``'input_ids'`` and ``'decoder_input_ids'`` entries,
    in the format the language-modeling collator receives in ``run_test``.
    """

    def __init__(self, input_ids, decoder_input_ids):
        """
        Args:
            input_ids: sequence of encoder token-id lists.
            decoder_input_ids: sequence of decoder token-id lists, one per entry
                of ``input_ids``.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(input_ids) != len(decoder_input_ids):
            raise ValueError(
                'input_ids and decoder_input_ids must have the same length '
                f'({len(input_ids)} != {len(decoder_input_ids)})')
        self.input_ids = input_ids
        self.decoder_input_ids = decoder_input_ids

    def __getitem__(self, index):
        return {'input_ids': self.input_ids[index],
                'decoder_input_ids': self.decoder_input_ids[index]}

    def __len__(self):
        return len(self.input_ids)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
from transformers import AutoModelForMaskedLM, set_seed | |
from bart_token_level_perplexity import BartPerplexityTester | |
# Use seq_len=256 as a standard for testing.
if __name__ == '__main__':
    import torch  # only used to choose the device below

    # Fall back to the CPU when no GPU is present instead of failing in model.to('cuda:0').
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model_name = 'facebook/bart-base'
    # Masking is a random process, so results vary between runs unless a seed is
    # set first, e.g. with set_seed(0).
    print(f'Loading model {model_name}')
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    print('Loading tester with corpus and tokenizer')
    tester = BartPerplexityTester(model, device=device)
    # Note that sequence length is in tokens.
    # Don't set seq_len > 800 or perplexity scores will jump.
    # run_test() prints its own 'Testing' banner, so none is printed here.
    ppl = tester.run_test(seq_len=256, batch_size=8)
    print()
    print('Model perplexity is %.2f' % ppl)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @bjascob! That's a cool Tester class and I think you could even make it more performant.
If you're using the `AutoTokenizer` class, the argument `return_overflowing_tokens` will automatically break the sequence into chunks of `max_length` size, so there's no need for the loop between lines 32 and 35. I also took a look at the documentation for the `BartForConditionalGeneration` class and found that the `decoder_input_ids` docstring contains the following: "For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting the input_ids to the right for denoising pre-training following the paper."
I'm fairly new to the field, but I'm sure you're measuring the model's ability to understand the context of your corpus by relying on its capability to correctly un-mask the `<mask>` tokens. If that is the case, there's no need to provide the `decoder_input_ids`, since the `DataCollatorForLanguageModeling` class handles the creation of the `labels` by replicating the `input_ids` and masking the actual `input_ids`. So you also don't need to handle inserting the `BOS_token` and `EOS_token`. I'd change the `run_test()` method (lines 29 to 51) to the following (code snippet not preserved in this copy), and readapt the `EvalDataset` class (lines 78 to 90) just to exclude the `decoder_input_ids` (code snippet not preserved in this copy).

One last input to your reply:
IMO this is due to the fact that this specific dataset is divided into rows, where each row holds one paragraph of a Wikipedia article. I ran a quick EDA, and most useful rows have roughly 800–900 characters. After tokenization, these would perhaps produce complete sequences of about 256 tokens, whereas longer sequences would pull in other, unrelated articles and confuse the model. But again, I didn't go deep into the analysis.