-
-
Save bjascob/7cece1eb654301925f79eb567af097f8 to your computer and use it in GitHub Desktop.
import math | |
import logging | |
import statistics | |
from tqdm import tqdm | |
import torch | |
from torch.utils.data import DataLoader, Dataset | |
import datasets | |
from datasets import load_dataset | |
import transformers | |
from transformers import AutoTokenizer | |
from transformers import DataCollatorForLanguageModeling | |
class BartPerplexityTester:
    """Measure the masked-token perplexity of a BART model on WikiText-2.

    The wikitext-2-raw-v1 test split is joined into a single string,
    tokenized, cut into equal-length samples, randomly masked by
    ``DataCollatorForLanguageModeling`` and scored with the model's own loss.
    """

    def __init__(self, model, num_test_chars=None, device='cuda'):
        """Move ``model`` to ``device`` and load the tokenizer and test text.

        Args:
            model: Model to evaluate; it is moved to ``device``.
            num_test_chars: Number of leading corpus characters to keep
                (10K chars = 2246 tokens). Use ``None`` for all of it.
            device: Device string handed to ``torch.device``.
        """
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')  # bart-large is the same
        self.text = self.load_text(num_test_chars)

    def load_text(self, num_test_chars):
        """Return the test split as one space-joined string.

        Args:
            num_test_chars: Keep only this many leading characters
                (10K chars = 2246 tokens). ``None`` keeps everything.

        Returns:
            The (possibly truncated) corpus text.
        """
        # Silences the 'Reusing dataset wikitext,...' message
        logging.getLogger('datasets').setLevel(logging.ERROR)
        articles = [a['text'] for a in load_dataset("wikitext", "wikitext-2-raw-v1", split="test")]
        return ' '.join(articles)[:num_test_chars]

    def run_test(self, seq_len=None, num_test_chars=None, batch_size=8, mlm_prob=0.15):
        """Mask the test text at random and return the resulting perplexity.

        Args:
            seq_len: Tokens per sample, not counting the bos/eos tokens that
                are added to every sample. ``None`` uses the tokenizer's
                ``model_max_length`` minus the two positions bos/eos occupy.
            num_test_chars: Ignored here (the text is cut in ``__init__``);
                kept so existing callers keep working.
            batch_size: Number of samples per evaluation batch.
            mlm_prob: Masking probability handed to the collator.

        Returns:
            ``exp`` of the mean per-batch loss, or ``inf`` if that overflows.

        Raises:
            ValueError: If the text does not fill even one ``seq_len`` sample.
        """
        # Tokenize.  verbose=False eliminates message 'token sequences too long for model'
        tok_ids = self.tokenizer(self.text, add_special_tokens=False, verbose=False).input_ids
        # Split into tokenized sequences all of the same length and discard any short samples at the end
        if seq_len is None:
            # bos and eos are added below, so a sample of model_max_length
            # tokens would end up two positions longer than the maximum.
            seq_len = self.tokenizer.model_max_length - 2
        samples = [c for c in chunk(tok_ids, seq_len) if len(c) == seq_len]
        if not samples:
            raise ValueError(
                'Text has only {:,} tokens, not enough for one sample of {:,} tokens'.format(
                    len(tok_ids), seq_len))
        print('Loaded {:,} samples of length {:,} tokens'.format(len(samples), len(samples[0])))
        # Add bos and eos tokens and create the decoder_input_ids
        # mask_token_id = 50264
        bos = self.tokenizer.bos_token_id               # = 0
        eos = self.tokenizer.eos_token_id               # = 2
        dst = self.model.config.decoder_start_token_id  # = 2 (same as eos token id)
        input_ids = [[bos] + sample + [eos] for sample in samples]
        decoder_ids = [[dst] + iids[:-1] for iids in input_ids]  # shift_tokens_right
        # Put this all into a dataset and create the loader
        # The collator will take care of randomly masking the input_id tokens and creating the
        # 'labels' keys with -100 for any non-masked token
        dataset = EvalDataset(input_ids, decoder_ids)
        collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=mlm_prob)
        dataloader = DataLoader(dataset, collate_fn=collator, batch_size=batch_size)
        # Run evaluation
        print('Testing')
        self.model.eval()
        losses = []
        # The print options never change, so set them once instead of per batch
        torch.set_printoptions(threshold=10000, linewidth=150)
        with torch.no_grad():
            for batch in tqdm(dataloader, ncols=100, disable=False):
                batch_decoder_ids = batch['decoder_input_ids'].to(self.device)
                batch_input_ids = batch['input_ids'].to(self.device)
                batch_labels = batch['labels'].to(self.device)
                outputs = self.model(input_ids=batch_input_ids, labels=batch_labels,
                                     decoder_input_ids=batch_decoder_ids)
                losses.append(outputs.loss.item())
        try:
            perplexity = math.exp(statistics.mean(losses))
        except OverflowError:
            perplexity = float('inf')
        return perplexity
def chunk(lst, n):
    """Yield consecutive slices of ``lst`` that are each ``n`` items long.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    Note that ``n`` is the length of each slice, not the number of slices.

    Args:
        lst: Any sliceable sequence with a length (list, str, tuple, ...).
        n: Number of items per slice.

    Yields:
        ``lst[i:i + n]`` for ``i`` = 0, n, 2n, ...
    """
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
# Container for model data
class EvalDataset(Dataset):
    """Dataset pairing each encoder input with its decoder input.

    Item ``i`` is a dict with the keys ``'input_ids'`` and
    ``'decoder_input_ids'`` holding the ``i``-th entry of each list.
    """

    def __init__(self, input_ids, decoder_input_ids):
        """Store the two parallel lists.

        Args:
            input_ids: One entry per sample.
            decoder_input_ids: One entry per sample, same order and count.

        Raises:
            ValueError: If the two lists differ in length.
        """
        # An explicit check rather than ``assert`` so that it still runs
        # when Python is started with -O.
        if len(input_ids) != len(decoder_input_ids):
            raise ValueError(
                'input_ids and decoder_input_ids differ in length: %d != %d'
                % (len(input_ids), len(decoder_input_ids)))
        self.input_ids = input_ids
        self.decoder_input_ids = decoder_input_ids

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict of both id sequences."""
        return {'input_ids': self.input_ids[index],
                'decoder_input_ids': self.decoder_input_ids[index]}

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)
#!/usr/bin/python3 | |
from transformers import AutoModelForMaskedLM, set_seed | |
from bart_token_level_perplexity import BartPerplexityTester | |
# Use seq_len=256 as a standard for testing.
if __name__ == '__main__':
    # Script entry point: load the model, build the tester and print the
    # perplexity measured on the test corpus.
    device = 'cuda:0'
    model_name = 'facebook/bart-base'
    # Masking is a random process so results will vary unless this is set
    # set_seed(0)
    print('Loading model %s' % model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    print('Loading tester with corpus and tokenizer')
    tester = BartPerplexityTester(model, device=device)
    # Note that sequence length is in tokens
    # Don't set seq_len > 800 or perplexity scores will jump
    print('Testing')
    ppl = tester.run_test(seq_len=256, batch_size=8)
    print()
    print('Model perplexity is %.2f' % ppl)
Hi @bjascob! That's a cool Tester class and I think you could even make it more performant.
If you're using the `AutoTokenizer` class, the argument `return_overflowing_tokens` will automatically break the sequence into chunks of `max_length` size, so there's no need for the loop between lines 32 and 35.
I also took a look at the documentation for the `BartForConditionalGeneration` class and found that the `decoder_input_ids` docstring provides the following snippet:
"For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting the input_ids to the right for denoising pre-training following the paper."
I'm fairly new to the field, but I'm sure you're measuring the model's ability to understand the context of your corpus by harnessing its capability of un-masking the `<mask>` tokens correctly. If this is the case, there's no need to provide the `decoder_input_ids`, since the `DataCollatorForLanguageModeling` class will handle the creation of the `labels` by replicating the `input_ids` and masking the actual `input_ids`. So you also don't need to handle inserting the `BOS_token` and `EOS_token`.
I'd change the `run_test()` method (lines 29 to 51) to the following:
def run_test(self, seq_len=None, num_test_chars=None, batch_size=8, mlm_prob=0.15):
    """Mask the test text at random and return the resulting perplexity.

    The tokenizer adds the special tokens and splits the text into chunks
    itself; ``decoder_input_ids`` are left for the model to create.

    Args:
        seq_len: Passed to the tokenizer as ``max_length``; ``None`` lets the
            tokenizer pick its default.
            NOTE(review): ``max_length`` presumably counts the special tokens
            too, so each chunk would hold ``seq_len - 2`` text tokens. Confirm
            against the tokenizer documentation.
        num_test_chars: Not used by this method.
        batch_size: Number of samples per evaluation batch.
        mlm_prob: Masking probability handed to the collator.

    Returns:
        ``exp`` of the mean per-batch loss, or ``inf`` if that overflows.

    Raises:
        ValueError: If no full-length chunk is left to evaluate.
    """
    # Tokenize.  verbose=False eliminates message 'token sequences too long for model'
    encoding = self.tokenizer(self.text,
                              add_special_tokens=True,         # Automatically add [BOS] and [EOS] tokens
                              return_overflowing_tokens=True,  # Breaks the seq into chunks of `max_length`
                              truncation=True,
                              padding="max_length",
                              max_length=seq_len,
                              verbose=False)
    tok_ids = encoding['input_ids']
    # Checks if last seq was padded
    # Note not sure if excluding the last seq if it's shorter than max_length will provide
    # more accurate results as this should be a comparative test
    # A 0 in the attention mask marks a padded position; this avoids
    # comparing against a hard-coded pad token id.
    if tok_ids and 0 in encoding['attention_mask'][-1]:
        tok_ids = tok_ids[:-1]
    if not tok_ids:
        raise ValueError('Text is too short to fill a single full-length sample')
    # Put this all into a dataset and create the loader
    # The collator will take care of randomly masking the input_id tokens and creating the
    # 'labels' keys with -100 for any non-masked token
    dataset = EvalDataset(tok_ids)
    collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=mlm_prob)
    dataloader = DataLoader(dataset, collate_fn=collator, batch_size=batch_size)
    # Run evaluation
    print('Testing')
    self.model.eval()
    losses = []
    # The print options never change, so set them once instead of per batch
    torch.set_printoptions(threshold=10000, linewidth=150)
    with torch.no_grad():
        for batch in tqdm(dataloader, ncols=100, disable=False):
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)
            outputs = self.model(input_ids=input_ids, labels=labels)
            losses.append(outputs.loss.item())
    try:
        perplexity = math.exp(statistics.mean(losses))
    except OverflowError:
        perplexity = float('inf')
    return perplexity
and readapt the `EvalDataset` class (lines 78 to 90) just to exclude the `decoder_input_ids`:
# Container for model data
class EvalDataset(Dataset):
    """Dataset whose item ``i`` is ``{'input_ids': input_ids[i]}``."""

    def __init__(self, input_ids):
        """Store the list of samples, one entry per sample."""
        self.input_ids = input_ids

    def __getitem__(self, index):
        """Return the sample at ``index`` wrapped in a dict."""
        return {'input_ids': self.input_ids[index]}

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)
One last input to your reply:
> I'm not certain why this is. There could be a bug or conceptual issue with the test code, but it may just be that when the model was pretrained, 256-token sequences were the average size of the input, and so those perform the best.
IMO this is due to the fact that this specific dataset is divided into rows, where each row holds one paragraph of a Wikipedia article. I ran a quick EDA, and most useful rows have roughly 800 to 900 characters, which after going through the tokenizer would perhaps produce complete sequences within 256 tokens, whereas longer sequences would pull in text from other, unrelated articles and confuse the model. But again, I didn't go deep into the analysis.
It may be important to note that the perplexity varies a bit depending on the sequence length you use for testing. The best scores are obtained around 256, and scores get worse quickly as seq_len goes above 800.
I'm not certain why this is. There could be a bug or conceptual issue with the test code, but it may just be that when the model was pretrained, 256-token sequences were the average size of the input, and so those perform the best.