Skip to content

Instantly share code, notes, and snippets.

@jamescalam
Last active December 23, 2021 01:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamescalam/133a83c32642ea27a9f648bcc9297003 to your computer and use it in GitHub Desktop.
Save jamescalam/133a83c32642ea27a9f648bcc9297003 to your computer and use it in GitHub Desktop.
import argparse
import datasets
from sentence_transformers import (
InputExample,
SentenceTransformer
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers.cross_encoder import CrossEncoder
def parse_args(argv=None):
    """Parse command-line arguments for the STS-B evaluation script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model`` (str, path or hub id of the
        model to evaluate) and ``cross_encoder`` (bool, force cross-encoder
        loading).
    """
    parser = argparse.ArgumentParser("Eval")
    parser.add_argument('model', help='Path to a model to be evaluated')
    parser.add_argument(
        '--cross-encoder',
        action='store_true',
        help=("Load the model as a CrossEncoder even if its path does not "
              "contain 'cross-encoder'"),
    )
    return parser.parse_args(argv)


def is_cross_encoder(model_path, force=False):
    """Return True if *model_path* should be loaded as a CrossEncoder.

    Keeps the original naming heuristic (any path containing the substring
    ``'cross-encoder'``); ``force=True`` overrides it for models whose path
    does not follow that convention.
    """
    return force or 'cross-encoder' in model_path


def load_dev_set():
    """Load the GLUE STS-B validation split as a list of ``InputExample`` pairs.

    Each example holds ``[sentence1, sentence2]`` as texts and the row's
    ``label`` converted to ``float``.
    """
    dev = datasets.load_dataset('glue', 'stsb', split='validation')
    return [
        InputExample(
            texts=[row['sentence1'], row['sentence2']],
            label=float(row['label']),
        )
        for row in dev
    ]


def scalar_score(result, primary_metric=None):
    """Reduce an evaluator's return value to a single number.

    Args:
        result: What calling the evaluator returned. Older sentence-transformers
            releases return a float; newer ones (NOTE(review): confirm exact
            version) return a dict of named metrics.
        primary_metric: Key to pick from *result* when it is a dict.

    Returns:
        *result* unchanged if it is not a dict, otherwise
        ``result[primary_metric]``.

    Raises:
        ValueError: *result* is a dict and *primary_metric* is missing from it.
    """
    if not isinstance(result, dict):
        return result
    if primary_metric is None or primary_metric not in result:
        raise ValueError(
            f'cannot pick a score from metrics {sorted(result)}; '
            f'primary metric is {primary_metric!r}'
        )
    return result[primary_metric]


def main(argv=None):
    """Evaluate a model on STS-B validation and print its correlation score."""
    args = parse_args(argv)
    dev_set = load_dev_set()
    if is_cross_encoder(args.model, force=args.cross_encoder):
        evaluator = CECorrelationEvaluator.from_input_examples(dev_set)
        model = CrossEncoder(args.model)
    else:
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
            dev_set, write_csv=False
        )
        model = SentenceTransformer(args.model)
    # getattr: older evaluators have no primary_metric and return a float.
    score = scalar_score(
        evaluator(model), getattr(evaluator, 'primary_metric', None)
    )
    print(f'SCORE: {round(score, 3)}')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment