-
-
Save ftnext/90e1efb6d94d3b318bf4f5de6e97da31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import csv | |
from fairseq.models.roberta import RobertaModel | |
from tqdm import tqdm | |
# Encoded inputs are cut to this many tokens before being passed to the model.
# NOTE(review): presumably this matches the model's maximum input length, and
# the cut may drop the trailing end-of-sentence token on long inputs — confirm.
MAX_TOKENS = 512

# Columns the evaluation loop reads from every row of the TSV file.
REQUIRED_COLUMNS = ("question", "sentence", "label")


def read_tsv_rows(lines):
    """Return a ``csv.DictReader`` over tab-separated *lines*.

    Args:
        lines: An open text file (ideally opened with ``newline=""``) or any
            other iterable of lines; the first line is the header.

    Returns:
        A ``csv.DictReader`` yielding one dict per data row.

    Quote processing is disabled on purpose. With the csv module's default
    settings, a field that merely *starts* with a double quote is treated as
    a quoted field, so the following tabs and newlines are swallowed into it
    and several rows silently collapse into one.
    """
    return csv.DictReader(lines, delimiter="\t", quoting=csv.QUOTE_NONE)


def label_for_index(label_dictionary, index):
    """Convert a classification-head output index into its label string.

    The class index is offset by ``label_dictionary.nspecial`` before lookup.
    For the model used by the author ``nspecial == 4``, and an interactive
    check showed ``string([4]) == 'entailment'`` and
    ``string([5]) == 'not_entailment'``.
    """
    return label_dictionary.string([index + label_dictionary.nspecial])


def count_correct(model, rows):
    """Classify every row and count how many predictions match the label.

    Args:
        model: The loaded RoBERTa model (already switched to eval mode).
        rows: Iterable of dicts with "question", "sentence" and "label" keys.

    Returns:
        A ``(n_correct, n_samples)`` tuple.
    """
    label_dictionary = model.task.label_dictionary
    n_correct = 0
    n_samples = 0
    for row in rows:
        tokens = model.encode(row["question"], row["sentence"])
        prediction = model.predict(
            "sentence_classification_head", tokens[:MAX_TOKENS]
        )
        predicted_label = label_for_index(
            label_dictionary, prediction.argmax().item()
        )
        n_correct += int(predicted_label == row["label"])
        n_samples += 1
    return n_correct, n_samples


def main():
    """Parse the command line, evaluate the model and print its accuracy."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir_path", help="Directory path.")
    parser.add_argument("evaluate_data_path", help="A tsv file.")
    parser.add_argument("--model_name", default="checkpoint_best.pt")
    parser.add_argument("--preprocess_dir_path", default="QNLI-bin")
    args = parser.parse_args()

    model = RobertaModel.from_pretrained(
        args.model_dir_path, args.model_name, args.preprocess_dir_path
    )
    model.eval()

    # newline="" lets the csv module do its own line-ending handling.
    with open(args.evaluate_data_path, encoding="utf8", newline="") as fin:
        reader = read_tsv_rows(fin)
        # Fail early with a readable message instead of a KeyError on the
        # first row (e.g. when the file has no "label" column).
        header = reader.fieldnames or []
        missing = [name for name in REQUIRED_COLUMNS if name not in header]
        if missing:
            parser.error(
                f"{args.evaluate_data_path} is missing column(s): "
                f"{', '.join(missing)}"
            )
        n_correct, n_samples = count_correct(model, tqdm(reader))

    # Avoid a ZeroDivisionError when the file contains only a header.
    if n_samples == 0:
        parser.error(f"{args.evaluate_data_path} contains no samples")

    print(f"| Accuracy: {n_correct/n_samples}")


if __name__ == "__main__":
    main()

# Results recorded by the author with the original version of this script,
# which used the csv module's default quote handling (so the sample counts
# may differ from what this version reports):
#   dev   | Accuracy: 0.62350484146573    n_correct, n_samples (3284, 5267)
#   train | Accuracy: 0.6714756641458212  n_correct, n_samples (69256, 103140)
# The test split has no labels, so accuracy cannot be computed for it.
# The number of data rows can be roughly checked with `wc -l`.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment