Created July 30, 2021 11:56
-
-
Save hugoabonizio/e958c8bcc196bdf602d389b0bd6b6834 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tqdm
import gzip
import torch
import tarfile
import logging
from datetime import datetime
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

# Configure the root logger; without this, the logging.info(...) progress
# messages below are silently dropped (root default level is WARNING and no
# handler is installed). LoggingHandler is imported but was never used.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Pretrained MS MARCO cross-encoder re-ranker: emits one relevance score per
# (query, passage) pair (num_labels=1); pairs are truncated at 512 tokens.
model_name_or_path = "cross-encoder/ms-marco-MiniLM-L-6-v2"
model = CrossEncoder(model_name_or_path, num_labels=1, max_length=512)

# All MS MARCO files are downloaded into / extracted under this folder.
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)
# pid -> passage text for the full MS MARCO passage collection (~8.8M rows).
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)
    # NOTE(review): extractall trusts the archive's member paths; fine for the
    # official MS MARCO tarball, but consider filter='data' on Python 3.12+.
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as finput:
    for line in finput:
        # maxsplit=1: a stray tab inside the passage text must not make the
        # two-value unpack raise ValueError.
        pid, passage = line.strip().split("\t", 1)
        corpus[pid] = passage
# qid -> query text for the MS MARCO training queries.
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as finput:
    for line in finput:
        # maxsplit=1 keeps the unpack safe if a query ever contains a tab.
        qid, query = line.strip().split("\t", 1)
        queries[qid] = query
# qid -> {'query', 'positive' set, 'negative' set} used as the dev set.
dev_samples = {}
# We use 200 random queries from the train set for evaluation during training
# Each query has at least one relevant and up to 200 irrelevant (negative) passages
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz is a randomly
# shuffled version of qidpidtriples.train.full.2.tsv.gz from the MS Marco website
# We extracted in the train-eval split 500 random queries that can be used for evaluation during training
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download " + os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as finput:
    for line in finput:
        # Each row is a (query id, positive passage id, negative passage id)
        # triple separated by whitespace.
        qid, pos_id, neg_id = line.strip().split()

        # Admit a new query only while we are below the dev-query budget.
        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])
            # Negatives are capped per query; positives are not.
            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])
# Score the cross-encoder on the dev queries: for each query the evaluator
# re-ranks its positive + negative passages and reports a ranking metric.
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')
print(evaluator(model))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment