Skip to content

Instantly share code, notes, and snippets.

@ftnext
Created February 11, 2023 08:14
Show Gist options
  • Save ftnext/90e1efb6d94d3b318bf4f5de6e97da31 to your computer and use it in GitHub Desktop.
Save ftnext/90e1efb6d94d3b318bf4f5de6e97da31 to your computer and use it in GitHub Desktop.
import argparse
import csv
from fairseq.models.roberta import RobertaModel
from tqdm import tqdm
# The original script truncated every encoded pair to 512 tokens before
# prediction; presumably this is the model's maximum input length -- confirm.
MAX_TOKENS = 512

# Name of the classification head passed to ``model.predict``.
CLASSIFICATION_HEAD = "sentence_classification_head"


def read_tsv_rows(path):
    """Yield each data row of the tab-separated file at *path* as a dict.

    The first line of the file is used as the header (``csv.DictReader``).

    Quote handling is disabled (``csv.QUOTE_NONE``): with the csv module's
    default quoting, a field that starts with ``"`` is treated as a quoted
    field and swallows the following tabs and newlines until a closing
    quote, which silently merges rows.  Here every ``"`` is kept as a
    literal character instead.
    """
    # newline="" is what the csv module documentation asks for on files
    # handed to a reader.
    with open(path, encoding="utf8", newline="") as fin:
        yield from csv.DictReader(fin, delimiter="\t", quoting=csv.QUOTE_NONE)


def compute_accuracy(n_correct, n_samples):
    """Return ``n_correct / n_samples``, or ``0.0`` when there are no samples.

    Guarding the division keeps an input file without data rows from
    ending the run with a ZeroDivisionError.
    """
    return n_correct / n_samples if n_samples else 0.0


def decode_label(model, label_index):
    """Map a predicted class index to its label string.

    The index is shifted by ``label_dictionary.nspecial`` before the lookup.
    Values noted by the original author for this model (nspecial == 4):

        model.task.label_dictionary.string([4]) -> 'entailment'
        model.task.label_dictionary.string([5]) -> 'not_entailment'
    """
    dictionary = model.task.label_dictionary
    return dictionary.string([label_index + dictionary.nspecial])


def main():
    """Load the model, classify every row of the TSV file, print accuracy.

    Each row must provide the columns ``question``, ``sentence`` and
    ``label``; a prediction counts as correct when the decoded label string
    equals the ``label`` column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir_path", help="Directory path.")
    parser.add_argument("evaluate_data_path", help="A tsv file.")
    parser.add_argument("--model_name", default="checkpoint_best.pt")
    parser.add_argument("--preprocess_dir_path", default="QNLI-bin")
    args = parser.parse_args()

    model = RobertaModel.from_pretrained(
        args.model_dir_path, args.model_name, args.preprocess_dir_path
    )
    model.eval()

    n_correct, n_samples = 0, 0
    for row in tqdm(read_tsv_rows(args.evaluate_data_path)):
        tokens = model.encode(row["question"], row["sentence"])
        prediction = model.predict(CLASSIFICATION_HEAD, tokens[:MAX_TOKENS])
        predicted_label = decode_label(model, prediction.argmax().item())
        n_correct += int(predicted_label == row["label"])
        n_samples += 1

    print(f"| Accuracy: {compute_accuracy(n_correct, n_samples)}")


if __name__ == "__main__":
    main()

# Results recorded by the original author, before quote handling was
# disabled in the TSV reader:
#   dev   | Accuracy: 0.62350484146573    n_correct, n_samples (3284, 5267)
#   train | Accuracy: 0.6714756641458212  n_correct, n_samples (69256, 103140)
# NOTE(review): these sample counts look lower than the published QNLI split
# sizes, which would be consistent with quoted fields merging rows -- re-run
# to confirm the numbers.
# The test split has no labels, so accuracy cannot be computed on it.
# The number of examples in a file can be roughly gauged with `wc -l`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment