Quickly test how a Masked LM will do on texts: the script masks tokens with DataCollatorForLanguageModeling (mlm_probability=0.15) and reports prediction accuracy on the masked positions.
import argparse
from itertools import chain

import evaluate
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--text_file", type=str)
    parser.add_argument("--text_column", type=str)
    parser.add_argument("--max_seq_length", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--max_batches", type=int)
    args = parser.parse_args()

    # Infer the dataset loader ("csv", "json", "text", ...) from the file extension.
    dataset = load_dataset(args.text_file.split(".")[-1], data_files=args.text_file, split="train")
    print(dataset)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    def tokenize_function(examples):
        return tokenizer(examples[args.text_column], return_special_tokens_mask=True)

    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could add padding instead if the model supported it.
        # You can customize this part to your needs.
        if total_length >= args.max_seq_length:
            total_length = (total_length // args.max_seq_length) * args.max_seq_length
        # Split into chunks of max_seq_length.
        result = {
            k: [t[i : i + args.max_seq_length] for i in range(0, total_length, args.max_seq_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    dataset = dataset.map(tokenize_function, batched=True, num_proc=2, remove_columns=dataset.column_names)
    dataset = dataset.map(group_texts, batched=True, num_proc=2, remove_columns=dataset.column_names)
    # Keep only enough chunks to fill max_batches batches.
    dataset = dataset.select(range(min(len(dataset), args.max_batches * args.batch_size)))

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but logits always come first.
            logits = logits[0]
        return logits.argmax(dim=-1)

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        # preds have the same shape as the labels, after the argmax(-1) has been calculated
        # by preprocess_logits_for_metrics.
        labels = labels.reshape(-1)
        preds = preds.reshape(-1)
        # Score only the masked positions; the collator sets unmasked labels to -100.
        mask = labels != -100
        labels = labels[mask]
        preds = preds[mask]
        return metric.compute(predictions=preds, references=labels)

    model = AutoModelForMaskedLM.from_pretrained(args.model_name)

    training_args = TrainingArguments(".", per_device_eval_batch_size=args.batch_size, log_level="error")

    trainer = Trainer(
        model,
        args=training_args,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    results = trainer.predict(dataset)
    print(results.metrics)
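
A usage sketch: the script filename and every argument value below are placeholders, not part of the gist; substitute your own model checkpoint and data file (csv, json, or plain text).

python test_mlm.py \
  --model_name bert-base-uncased \
  --text_file texts.csv \
  --text_column text \
  --max_seq_length 128 \
  --batch_size 16 \
  --max_batches 10

The printed metrics include the accuracy computed over the masked tokens only.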