Last active
February 28, 2021 11:42
-
-
Save codebreach/21c9b768effdd154c2da0d8a9791fd93 to your computer and use it in GitHub Desktop.
Measuring the Effects of Domain Adaptation on Transformers' Legal Knowledge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from argparse import ArgumentParser | |
from typing import List, Tuple | |
import torch | |
from transformers import (AutoConfig, AutoModelForMaskedLM, AutoTokenizer, | |
pipeline) | |
def predict_and_return(
    model_name_or_path: str, fact: str, num_predictions: int
) -> Tuple[List[str], List[float]]:
    """Fill the masked token in *fact* and return the top candidate fillers.

    Args:
        model_name_or_path: Path to a pretrained model directory or a model
            identifier from huggingface.co/models.
        fact: Sentence containing exactly one mask token ([MASK] for
            BERT-style models, <mask> for RoBERTa-style models).
        num_predictions: How many top-ranked candidates to return.

    Returns:
        A tuple ``(answers, confidence_scores)``: the predicted token
        strings and their probabilities (floats in [0, 1]), both ordered
        best-first as produced by the fill-mask pipeline.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
    # The instantiated model already carries its own config, so the
    # previous separate AutoConfig.from_pretrained load (and passing it
    # to pipeline()) was redundant and has been removed.
    unmasker_pipeline = pipeline(
        "fill-mask",
        model=model,
        tokenizer=tokenizer,
        top_k=num_predictions,
    )
    # NOTE(review): assumes a single mask token in *fact*; with multiple
    # masks the pipeline returns a nested list and this indexing breaks.
    results = unmasker_pipeline(fact)
    answers = [result["token_str"] for result in results]
    confidence_scores = [result["score"] for result in results]
    return (answers, confidence_scores)
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--fact",
        default=None,
        type=str,
        # Required: args.fact is used unconditionally below, so a missing
        # value previously crashed with an opaque TypeError instead of a
        # clear argparse error message.
        required=True,
        help="The fact that is to be posed to the language model. Must contain only one [MASK] token. For RoBERTa, the mask token is <mask>",
    )
    parser.add_argument(
        "--num_predictions",
        default=20,
        type=int,
        help="Number of top predictions to return",
    )
    args = parser.parse_args()
    (predicted_words_for_fact, confidence_scores) = predict_and_return(
        model_name_or_path=args.model_name_or_path,
        fact=args.fact,
        num_predictions=args.num_predictions,
    )
    # RoBERTa-style models mask with <mask>; BERT-style models use [MASK].
    masked_token = "<mask>" if "<mask>" in args.fact else "[MASK]"
    for word, score in zip(predicted_words_for_fact, confidence_scores):
        prediction = args.fact.replace(masked_token, word)
        # score is a probability in [0, 1]; the original printed the raw
        # fraction followed by a literal "%" (e.g. "0.23%" for 23%).
        # Format it as an actual percentage instead.
        print(f"{prediction} - {score:.2%}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment