import os
import tqdm
import gzip
import torch
import tarfile
import logging
from datetime import datetime
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

# Print logging.info messages to stdout (otherwise the download messages below stay silent)
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

model_name_or_path = "cross-encoder/ms-marco-MiniLM-L-6-v2"
model = CrossEncoder(model_name_or_path, num_labels=1, max_length=512)
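
# --- Not part of the original gist: a quick illustration of what the cross-encoder does ---
# CrossEncoder.predict takes (query, passage) pairs and returns one relevance score per pair;
# this is the same pairwise scoring that CERerankingEvaluator relies on further below.
# The example texts here are invented purely for illustration.
example_scores = model.predict([
    ("how many people live in berlin", "Berlin has a population of around 3.7 million."),
    ("how many people live in berlin", "The Eiffel Tower is located in Paris."),
])
print(example_scores)  # the first (relevant) pair should receive the higher score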
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)
# Read the MS MARCO passage collection into a dict: pid -> passage text
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as finput:
    for line in finput:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage
# Read the train queries into a dict: qid -> query text
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as finput:
    for line in finput:
        qid, query = line.strip().split("\t")
        queries[qid] = query
dev_samples = {}

# We use 200 random queries from the train set for evaluation during training.
# Each query has at least one relevant and up to 200 irrelevant (negative) passages.
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are randomly
# shuffled versions of qidpidtriples.train.full.2.tsv.gz from the MS MARCO website.
# The train-eval split contains 500 random queries held out for evaluation during training.
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download " + os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as finput:
    for line in finput:
        qid, pos_id, neg_id = line.strip().split()

        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])

            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])
# Evaluate the pretrained cross-encoder: CERerankingEvaluator reranks each dev query's
# positive + negative passages and reports MRR@10
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')
print(evaluator(model))
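
# --- Not part of the original gist: a sketch of how fine-tuning could continue from here ---
# The otherwise unused imports above (InputExample, DataLoader, tqdm, datetime) suggest this
# gist was trimmed from a training script. The block below is a minimal, illustrative
# continuation: it assumes the full triples file msmarco-qidpidtriples.rnd-shuf.train.tsv.gz
# (mentioned in the comment above) is hosted under the same sbert.net/datasets path, and the
# batch size, warmup steps, and output path are placeholders, not tuned settings. It is
# disabled by default so the script still only evaluates, as the original gist does.
do_train = False
if do_train:
    train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz')
    if not os.path.exists(train_filepath):
        logging.info("Download " + os.path.basename(train_filepath))
        util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath)

    # Turn each (qid, positive pid, negative pid) triple into two labelled pairs
    train_samples = []
    with gzip.open(train_filepath, 'rt') as finput:
        for line in tqdm.tqdm(finput):
            qid, pos_id, neg_id = line.strip().split()
            if qid in dev_samples:
                continue  # keep the dev queries out of the training data
            train_samples.append(InputExample(texts=[queries[qid], corpus[pos_id]], label=1))
            train_samples.append(InputExample(texts=[queries[qid], corpus[neg_id]], label=0))

    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=32)
    model.fit(train_dataloader=train_dataloader,
              evaluator=evaluator,
              epochs=1,
              evaluation_steps=5000,
              warmup_steps=5000,
              output_path='output/cross-encoder-' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
              use_amp=True)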