# Gist by @alexlimh, created March 30, 2023.
# (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
# @manual=//faiss/python:pyfaiss
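#
# Brute-force dense retrieval script for dpr_scale: generates query embeddings with a
# GenerateQueryEmbeddingsTask, loads precomputed passage embeddings (reps_* pickles),
# scores every query against every passage by inner product, and writes the top-k
# results either as a TREC run file or as DPR-style JSON.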
import glob
import json
import os
import pickle
import time
from typing import Dict, List

import faiss
import hydra
import numpy as np
import torch
from omegaconf import open_dict
from pytorch_lightning.trainer import Trainer
from tqdm import trange

from dpr_scale.conf.config import MainConfig
from dpr_scale.datamodule.dpr import CSVDataset, QueryCSVDataset, QueryTRECDataset
def merge_results(
    passages: Dict,
    questions: List,
    top_doc_ids: List,
    scores_list: List,
):
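    """Join retrieved doc ids and scores with their questions and passage text.

    Each merged record has the form (illustrative sketch of the dict built below):
        {"question": str, "answers": list, "id": ..., "ctxs": [{"id", "title", "text", "score"}, ...]}
    """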
    # join passages text with the result ids, their questions
    merged_data = []
    assert len(top_doc_ids) == len(questions) == len(scores_list)
    for i, question, doc_ids, scores in zip(range(len(questions)), questions, top_doc_ids, scores_list):
        ctxs = [
            {
                "id": passages[id]["id"],
                "title": passages[id]["title"],
                "text": passages[id]["text"],
                "score": float(score),
            }
            for id, score in zip(doc_ids, scores)
        ]
        merged_data.append(
            {
                "question": question["question"],
                "answers": question["answers"] if "answers" in question else [],
                "ctxs": ctxs,
                "id": question.get("id", i),
            }
        )
    return merged_data
def build_index(paths):
    index = None
    vectors = []
    for fname in paths:
        with open(fname, "rb") as f:
            vector = pickle.load(f)
        if index is None:
            index = faiss.IndexFlatIP(vector.size()[1])
        print(f"Adding {vector.size()} vectors from {fname}")
        index.add(vector.numpy())
        vectors.append(vector)
    vectors = torch.cat(vectors, 0)
    return index, vectors
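# Note: build_index also returns a faiss.IndexFlatIP, which performs exact (brute-force)
# inner-product search. A minimal sketch of querying it directly, assuming `q` is a
# float32 numpy array of query embeddings (main() below instead scores queries with an
# explicit matmul + top-k over the stacked vectors):
#
#     index, ctx_vectors = build_index(input_paths)
#     scores, doc_ids = index.search(q, 100)  # top-100 passage rows per query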
def sort(scores, topk):
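    """Return the top-`topk` scores per row and their column indices, sorted descending.

    Toy example (illustrative):
    >>> top_scores, top_ids = sort(np.array([[0.1, 0.9, 0.5]]), 2)
    >>> top_ids.tolist(), top_scores.tolist()
    ([[1, 2]], [[0.9, 0.5]])
    """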
    top_ids = np.argpartition(scores, -topk, axis=1)[:, -topk:]  # linear-time partition, but unordered
    top_scores = np.take_along_axis(scores, top_ids, axis=1)
    top_subset_ids = np.argsort(-1.0 * top_scores, axis=1)  # sort the top-k list
    top_scores = np.take_along_axis(top_scores, top_subset_ids, axis=1)
    top_ids = np.take_along_axis(top_ids, top_subset_ids, axis=1)
    return top_scores, top_ids
@hydra.main(config_path="conf", config_name="config")
def main(cfg: MainConfig):
    # Temp patch for datamodule refactoring
    cfg.task.datamodule = None
    cfg.task._target_ = (
        "dpr_scale.task.dpr_eval_task.GenerateQueryEmbeddingsTask"  # hack
    )
    # trainer.fit does some setup, so we need to call it even though no training is done
    with open_dict(cfg):
        cfg.trainer.limit_train_batches = 0
        if "plugins" in cfg.trainer:
            cfg.trainer.pop(
                "plugins"
            )  # remove ddp_sharded, because it breaks during loading
    print(cfg)
    task = hydra.utils.instantiate(cfg.task, _recursive_=False)
    transform = hydra.utils.instantiate(cfg.task.transform)
    datamodule = hydra.utils.instantiate(cfg.datamodule, transform=transform)
    trainer = Trainer(**cfg.trainer)
    trainer.fit(task, datamodule=datamodule)
    trainer.test(task, datamodule=datamodule)
    # index all passages
    input_paths = sorted(glob.glob(os.path.join(cfg.task.ctx_embeddings_dir, "reps_*")))
    index, ctx_vectors = build_index(input_paths)
    # reload question embeddings
    print("Loading question vectors.")
    with open(task.query_emb_output_path, "rb") as f:
        q_repr = pickle.load(f)
    if cfg.use_gpu:
        q_repr, ctx_vectors = q_repr.cuda(), ctx_vectors.cuda()
    else:
        q_repr, ctx_vectors = (
            q_repr.numpy().astype(np.float32),
            ctx_vectors.numpy().astype(np.float32),
        )
print("Retrieving results...")
retrieval_time = 0
sort_time = 0
all_indexes = []
all_scores = []
# scores, indexes = index.search(q_repr.numpy(), 100)
for batch_start in trange(0, len(q_repr), cfg.batch_size):
batch_q_repr = q_repr[batch_start: batch_start + cfg.batch_size]
tic = time.perf_counter()
if cfg.use_gpu:
scores = torch.matmul(batch_q_repr, ctx_vectors.T)
else:
scores = np.matmul(batch_q_repr, ctx_vectors.T)
toc = time.perf_counter()
retrieval_time += toc - tic
tic = time.perf_counter()
if cfg.use_gpu:
scores, indexes = scores.topk(dim=1, k=cfg.topk)
else:
scores, indexes = sort(scores, cfg.topk)
toc = time.perf_counter()
sort_time += toc - tic
scores, indexes = scores.tolist(), indexes.tolist()
all_scores.extend(scores)
all_indexes.extend(indexes)
print(f"Retrieval time:{retrieval_time:.2f}s")
print(f"Sorting time:{sort_time:.2f}s")
    # load questions file
    print(f"Loading questions file {cfg.datamodule.test_path}")
    if "msmarco" in cfg.datamodule.test_path:
        questions = QueryTRECDataset(cfg.datamodule.test_path)
    else:
        questions = QueryCSVDataset(cfg.datamodule.test_path)
    # load all passages
    print(f"Loading passages from {cfg.task.passages}")
    ctxs = CSVDataset(cfg.task.passages)
    # write output file
    print("Merging results...")
    if cfg.datamodule.trec_format:
        trec_data = []
        for question, doc_ids, scores in zip(questions, all_indexes, all_scores):
            topic_id = question["id"]
            for rank, (doc_id, score) in enumerate(zip(doc_ids, scores)):
                trec_data.append(f"{topic_id} Q0 {doc_id} {rank + 1} {score:.6f} dpr-scale\n")
        print(f"Writing output to {cfg.task.output_path}")
        os.makedirs(cfg.task.output_path, exist_ok=True)
        with open(os.path.join(cfg.task.output_path, "retrieval.trec"), "w") as g:
            g.writelines(trec_data)
    else:
        results = merge_results(ctxs, questions, all_indexes, all_scores)
        print(f"Writing output to {cfg.task.output_path}")
        os.makedirs(os.path.dirname(cfg.task.output_path), exist_ok=True)
        with open(cfg.task.output_path, "w") as g:
            g.write(json.dumps(results, indent=4))
            g.write("\n")
if __name__ == "__main__":
    main()
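# Example invocation (a sketch, not from the gist: the script/config names and whether
# the extra keys need Hydra's "+" prefix depend on your dpr_scale setup; the overrides
# simply mirror the cfg fields read above):
#
#   python dense_retriever.py \
#       datamodule.test_path=/path/to/queries.tsv \
#       datamodule.trec_format=False \
#       task.ctx_embeddings_dir=/path/to/ctx_embeddings \
#       task.passages=/path/to/passages.tsv \
#       task.output_path=/path/to/retrieval.json \
#       +topk=100 +batch_size=128 +use_gpu=False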