@nbroad1881
Last active September 22, 2022 23:35
Quickly test how well a masked language model (MLM) performs on your own texts.
import argparse
from itertools import chain

import evaluate
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="Model checkpoint to evaluate")
    parser.add_argument("--text_file", type=str, required=True, help="Path to a csv/json file of texts")
    parser.add_argument("--text_column", type=str, required=True, help="Name of the column holding the text")
    parser.add_argument("--max_seq_length", type=int, required=True, help="Length of each evaluation chunk")
    parser.add_argument("--batch_size", type=int, required=True, help="Per-device eval batch size")
    parser.add_argument("--max_batches", type=int, required=True, help="Maximum number of eval batches to run")
    args = parser.parse_args()

    # The file extension (csv, json, ...) doubles as the dataset builder name.
    dataset = load_dataset(args.text_file.split(".")[-1], data_files=args.text_file, split="train")
    print(dataset)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    def tokenize_function(examples):
        return tokenizer(examples[args.text_column], return_special_tokens_mask=True)

    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the small remainder; you could pad the final chunk instead if the model
        # supports padding. Customize this part to your needs.
        if total_length >= args.max_seq_length:
            total_length = (total_length // args.max_seq_length) * args.max_seq_length
        # Split into chunks of max_seq_length.
        result = {
            k: [t[i : i + args.max_seq_length] for i in range(0, total_length, args.max_seq_length)]
            for k, t in concatenated_examples.items()
        }
        return result
    dataset = dataset.map(tokenize_function, batched=True, num_proc=2, remove_columns=dataset.column_names)
    dataset = dataset.map(group_texts, batched=True, num_proc=2, remove_columns=dataset.column_names)
    dataset = dataset.select(range(min(len(dataset), args.max_batches * args.batch_size)))
    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but the logits always come first.
            logits = logits[0]
        return logits.argmax(dim=-1)
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        # preds have the same shape as the labels, after the argmax(-1) has been
        # applied by preprocess_logits_for_metrics.
        labels = labels.reshape(-1)
        preds = preds.reshape(-1)
        # Only score the positions that were actually masked (-100 marks ignored tokens).
        mask = labels != -100
        labels = labels[mask]
        preds = preds[mask]
        return metric.compute(predictions=preds, references=labels)
    model = AutoModelForMaskedLM.from_pretrained(args.model_name)

    training_args = TrainingArguments(".", per_device_eval_batch_size=args.batch_size, log_level="error")
    trainer = Trainer(
        model,
        args=training_args,
        # Randomly mask 15% of tokens, as in the original BERT pretraining objective.
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    results = trainer.predict(dataset)
    print(results.metrics)
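
Example invocation, as a sketch: the script filename mlm_quick_eval.py, the data file texts.csv, and the text column name are placeholders rather than part of the gist; any csv or json file whose extension load_dataset can use as a builder name should work.

python mlm_quick_eval.py \
    --model_name bert-base-uncased \
    --text_file texts.csv \
    --text_column text \
    --max_seq_length 512 \
    --batch_size 8 \
    --max_batches 100

The printed results.metrics dict contains the accuracy computed over the masked positions along with the evaluation loss; Trainer.predict prefixes these keys, so expect entries like test_accuracy and test_loss.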