@alexlimh
Created February 10, 2023 20:39
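"""Sparse (impact-index) retrieval with Pyserini's LuceneImpactSearcher, with an
optional "fast rerank" step that rescores the top hits via MaxSim over
precomputed sparse query/document token vectors (sparse_range.pkl, sparse_vec.npz).
Results are written as a TREC run file (retrieval.trec)."""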
import os
import time
import argparse
import numpy as np
import pickle
import collections
import jsonlines
import torch
import glob
import scipy
from functools import partial
from multiprocessing import Pool
from tqdm import tqdm
from pyserini.search.lucene import LuceneImpactSearcher
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap
try:
    from dpr_scale.retriever_ext import scatter as c_scatter
except ImportError:
    raise ImportError(
        'Cannot import scatter module.'
        ' Make sure you have compiled the retriever extension.'
    )


def load_file(path, i):
    data = torch.load(path)
    return (data, i)


def maxsim(entry):
    """MaxSim (late-interaction) rescoring for one query: for each query token,
    take its maximum similarity over a document's token vectors, sum over query
    tokens, and re-sort the candidate documents by the resulting score."""
    q_embed, d_embeds, d_lens, qid, scores, docids = entry
    if len(d_embeds) == 0:
        return qid, scores, docids
    d_embeds = scipy.sparse.vstack(d_embeds).transpose()  # stack doc token vectors (LD x 1000) x D, then transpose to D x (LD x 1000)
    max_scores = (q_embed @ d_embeds).todense()  # LQ x (LD x 1000)
    scores = []
    start = 0
    for d_len in d_lens:
        scores.append(max_scores[:, start:start + d_len].max(1).sum())
        start += d_len
    scores, docids = list(zip(*sorted(list(zip(scores, docids)), key=lambda x: -x[0])))
    return qid, scores, docids


class LuceneMultiTermSearcher(LuceneImpactSearcher):
    def __init__(self, query_path, query_embedding_path, sparse_query_path, sparse_corpus_path, corpus_path, weight_threshold, threads, id2idx_path=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_threshold = weight_threshold
        self.threads = threads
        self.queries = {}
        print("Loading queries...")
        if query_path.split(".")[-1] == "tsv":
            with open(query_path) as f:
                lines = f.readlines()
            for line in lines:
                qid, query = line.strip().split("\t")
                self.queries[qid] = query
        elif query_path.split(".")[-1] == "jsonl":
            with jsonlines.open(query_path) as f:
                for line in f:
                    self.queries[line["question_id"]] = line["question"]
        print("Loading query embeddings...")
        self.query_embeddings = self.query_preprocess(query_embedding_path)
        print("Loading corpus...")
        with open(corpus_path) as f:
            lines = f.readlines()
        self.corpus_len = len(lines) - 1  # exclude header
        self.sparse_q_vecs = None
        self.sparse_vecs = None
        if sparse_query_path is not None and os.path.exists(sparse_query_path):
            self.pool = Pool(threads)
            print("Loading sparse query vectors for fast reranking...")
            sparse_query_range_path = os.path.join(sparse_query_path, "sparse_range.pkl")
            with open(sparse_query_range_path, "rb") as f:
                sparse_q_ranges = pickle.load(f)
            sparse_query_vec_path = os.path.join(sparse_query_path, "sparse_vec.npz")
            sparse_q_vecs = scipy.sparse.load_npz(sparse_query_vec_path)
            sparse_q_vecs_scatter = []
            for start, end in sparse_q_ranges:
                sparse_q_vecs_scatter.append(sparse_q_vecs[start:end])
            if id2idx_path is None:
                self.sparse_q_vecs = {k: v for k, v in zip(list(self.query_embeddings.keys()), sparse_q_vecs_scatter)}
            else:
                with open(id2idx_path, "rb") as f:
                    id2idx = pickle.load(f)
                self.sparse_q_vecs = {k: sparse_q_vecs_scatter[id2idx[k]] for k in self.query_embeddings.keys()}
            print("Loading sparse corpus vectors for fast reranking...")
            sparse_corpus_range_path = os.path.join(sparse_corpus_path, "sparse_range.pkl")
            with open(sparse_corpus_range_path, "rb") as f:
                self.sparse_ranges = pickle.load(f)
            sparse_corpus_vec_path = os.path.join(sparse_corpus_path, "sparse_vec.npz")
            sparse_vecs = scipy.sparse.load_npz(sparse_corpus_vec_path)
            self.sparse_vecs = []
            for start, end in tqdm(self.sparse_ranges):
                self.sparse_vecs.append(sparse_vecs[start:end])

    def query_preprocess(self, embedding_path):
        upper_embeddings = collections.defaultdict(dict)
        with jsonlines.open(embedding_path) as f:
            for line in f:
                if len(line["vector"]) > 0:
                    topic_pos_id = line["id"]
                    splits = topic_pos_id.split("_")
                    pos = splits[-1]
                    topic_id = "_".join(splits[:-1])
                    if topic_id in self.queries:
                        for term, weight in line["vector"].items():
                            if weight > self.weight_threshold:
                                upper_embeddings[topic_id][term] = upper_embeddings[topic_id].get(term, 0) + weight
        return upper_embeddings

    def batch_search(self, topk, threads, batch_size):
        query_lst = JArrayList()
        qid_lst = JArrayList()
        qids = []
        ranking = {}
        iterator = list(self.query_embeddings.items())
        count = 0
        print("Searching...")
        latency = 0
        for i, entry in tqdm(list(enumerate(iterator))):
            qid, vector = entry
            jquery = JHashMap()
            # Keep only query terms above the weight threshold whose IDF exceeds min_idf.
            for token, weight in vector.items():
                if weight > self.weight_threshold and token in self.idf and self.idf[token] > self.min_idf:
                    jquery.put(token, JFloat(weight))
            query_lst.add(jquery)
            qid_lst.add(qid)
            qids.append(qid)
            count += 1
            # Flush the batch: search, optionally rerank, and record the results.
            if count == batch_size or i == len(iterator) - 1:
                tic = time.perf_counter()
                raw_results = self.object.batch_search(query_lst, qid_lst, topk, threads)
                results = {r.getKey(): r.getValue() for r in raw_results.entrySet().toArray()}
                all_scores = []
                all_docids = []
                for qid in qids:
                    hits = results[qid]
                    docids = []
                    scores = []
                    for hit in hits:
                        docids.append(int(hit.docid))
                        scores.append(hit.score)
                    all_scores.append(scores)
                    all_docids.append(docids)
                if self.sparse_vecs is not None:
                    qids, all_scores, all_docids = self.fast_rerank(qids, all_scores, all_docids)
                for qid, scores, docids in zip(qids, all_scores, all_docids):
                    ranking[qid] = (scores, docids)
                toc = time.perf_counter()
                latency += toc - tic
                query_lst = JArrayList()
                qid_lst = JArrayList()
                qids = []
                count = 0
        if self.sparse_vecs is not None:
            self.pool.close()
        print(f"Average search latency {latency/len(iterator)*1000:.2f}ms/query")
        return ranking

    def fast_rerank(self, qids, all_scores, all_docids):
        all_q_embeds = []
        all_d_embeds = []
        all_d_lens = []
        for qid, scores, docids in zip(qids, all_scores, all_docids):
            all_q_embeds.append(self.sparse_q_vecs[qid])
            d_embeds = []
            d_lens = []
            for docid in docids:
                start, end = self.sparse_ranges[int(docid)]
                d_embeds.append(self.sparse_vecs[int(docid)])
                d_lens.append(end - start)
            all_d_embeds.append(d_embeds)
            all_d_lens.append(d_lens)
        entries = list(zip(all_q_embeds, all_d_embeds, all_d_lens, qids, all_scores, all_docids))
        results = self.pool.map(maxsim, entries)
        qids, all_scores, all_docids = list(zip(*results))
        return qids, all_scores, all_docids


def main(args):
    searcher = LuceneMultiTermSearcher(args.query_path,
                                       args.query_embedding_path,
                                       args.sparse_query_vec_path,
                                       args.sparse_corpus_vec_path,
                                       args.corpus_path,
                                       args.weight_threshold,
                                       id2idx_path=args.id2idx_path,
                                       threads=args.threads,
                                       index_dir=args.index,
                                       query_encoder=None,
                                       min_idf=args.min_idf)
    ranking = searcher.batch_search(args.topk, args.threads, args.batch_size)
    i2d = []
    if args.idx2id_path is not None and os.path.exists(args.idx2id_path):
        with open(args.idx2id_path) as f:
            lines = f.readlines()
        for line in lines:
            i2d.append(line.strip())
    trec_results = []
    for topic_id, (top_scores, top_indices) in ranking.items():
        for rank, (score, doc_id) in enumerate(list(zip(top_scores, top_indices))[:args.out_topk]):
            if len(i2d) == 0:
                trec_results.append(f"{topic_id} Q0 {doc_id} {rank+1} {score:.6f} Anserini\n")
            else:
                trec_results.append(f"{topic_id} Q0 {i2d[doc_id]} {rank+1} {score:.6f} Anserini\n")
    print(f"Writing output to {args.output_path}")
    os.makedirs(args.output_path, exist_ok=True)
    with open(os.path.join(args.output_path, "retrieval.trec"), "w") as g:
        g.writelines(trec_results)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Sparse impact retrieval with Pyserini and optional MaxSim reranking.')
    parser.add_argument('--index', required=True, help='Path to the Lucene impact index.')
    parser.add_argument('--query_path', required=True, help='Queries file (.tsv or .jsonl).')
    parser.add_argument('--id2idx_path', default=None, help='Pickle mapping query id to row index in the sparse query vectors.')
    parser.add_argument('--query_embedding_path', required=True, help='JSONL file with per-query term-weight vectors.')
    parser.add_argument('--sparse_query_vec_path', default=None, help='Directory with sparse query vectors (sparse_range.pkl, sparse_vec.npz) for fast reranking.')
    parser.add_argument('--sparse_corpus_vec_path', default=None, help='Directory with sparse corpus vectors (sparse_range.pkl, sparse_vec.npz) for fast reranking.')
    parser.add_argument('--idx2id_path', default=None, help='Text file mapping internal doc indices to document ids, one per line.')
    parser.add_argument('--corpus_path', required=True, help='Corpus file with header; used to count documents.')
    parser.add_argument('--output_path', required=True, help='Output directory for retrieval.trec.')
    parser.add_argument('--topk', type=int, default=1000, help='Number of hits to retrieve per query.')
    parser.add_argument('--out_topk', type=int, default=1000, help='Number of hits to write per query.')
    parser.add_argument('--threads', type=int, default=1, help='Number of search threads.')
    parser.add_argument('--batch_size', type=int, default=128, help='Number of queries per search batch.')
    parser.add_argument('--weight_threshold', type=float, default=0.0, help='Minimum term weight to keep.')
    parser.add_argument('--min_idf', type=float, default=0.0, help='Minimum IDF for a query term to be used in search.')
    args = parser.parse_args()
    main(args)
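
# Example invocation (a sketch only; the script name and all paths below are
# hypothetical placeholders, not part of the original gist):
#
#   python this_script.py \
#       --index /path/to/lucene-impact-index \
#       --query_path queries.tsv \
#       --query_embedding_path query_vectors.jsonl \
#       --corpus_path corpus.tsv \
#       --output_path runs \
#       --sparse_query_vec_path sparse_queries \
#       --sparse_corpus_vec_path sparse_corpus \
#       --topk 1000 --threads 16 --batch_size 128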