Run BM25 baseline on MTEB retrieval tasks
"""Evaluate BM25 on MTEB tasks
Usage:
python bm25.py -t <task name> --output_folder=./data/results
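Example (SciFact is one of the MTEB retrieval tasks; any retrieval task name works):
python bm25.py -t SciFact --output_folder=./data/results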
Notes:
- https://github.com/xhluca/bm25s (promising implementation)
- https://github.com/beir-cellar/beir/blob/main/examples/retrieval/evaluation/lexical/evaluate_bm25.py
- https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing#scrollTo=nqotyXuIBPt6
Requirements:
pip install "bm25s[full]" PyStemmer beir
"""
import argparse
import json
import logging
import os
from pathlib import Path
from time import time
from typing import List, Optional, Union
import bm25s
import Stemmer
from beir.retrieval.evaluation import EvaluateRetrieval
from mteb.abstasks import AbsTaskRetrieval
from mteb.evaluation import MTEB
from mteb.evaluation.evaluators.RetrievalEvaluator import DenseRetrievalExactSearch
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
class BM25Search(DenseRetrievalExactSearch):
"""Override dense retrieval with BM25 search"""
def __init__(
self,
previous_results: Optional[str] = None,
stopwords: str = "en",
stemmer_language: Optional[str] = "english",
**kwargs,
):
super().__init__(
model=None,
batch_size=1,
corpus_chunk_size=1,
previous_results=previous_results,
**kwargs,
)
self.stopwords = stopwords
# optional: create a stemmer
self.stemmer = Stemmer.Stemmer(stemmer_language) if stemmer_language else None
def search(
self,
corpus: dict[str, dict[str, str]],
queries: dict[str, Union[str, List[str]]],
top_k: int,
score_function: str,
return_sorted: bool = False,
**kwargs,
) -> dict[str, dict[str, float]]:
logger.info("Encoding Corpus...")
corpus_ids = list(corpus.keys())
corpus_with_ids = [{"doc_id": cid, **corpus[cid]} for cid in corpus_ids]
corpus_texts = [
"\n".join([doc.get("title", ""), doc["text"]]) for doc in corpus_with_ids
]  # concatenate title and text (some corpora may not provide a title)
encoded_corpus = self.encode(corpus_texts)
logger.info(
f"Indexing Corpus... {len(encoded_corpus.ids):,} documents, {len(encoded_corpus.vocab):,} vocab"
)
# Create the BM25 model and index the corpus
retriever = bm25s.BM25()
retriever.index(encoded_corpus)
logger.info("Encoding Queries...")
query_ids = list(queries.keys())
self.results = {qid: {} for qid in query_ids}
queries_texts = [queries[qid] for qid in queries]
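# return_ids=False makes bm25s.tokenize return lists of token strings rather than
# a Tokenized (ids, vocab) object; retriever.retrieve accepts this form directly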
query_token_strs = self.encode(queries_texts, return_ids=False)
logger.info(f"Retrieving Results... {len(queries):,} queries")
queries_results, queries_scores = retriever.retrieve(
query_token_strs, corpus=corpus_with_ids, k=top_k
)
# Iterate over queries
for qi, qid in enumerate(query_ids):
doc_id_to_score = {}
query_results = queries_results[qi]
scores = queries_scores[qi]
# Iterate over results
for ri in range(len(query_results)):
doc = query_results[ri]
score = scores[ri]
doc_id = doc["doc_id"]
doc_id_to_score[doc_id] = float(score)
self.results[qid] = doc_id_to_score
return self.results
def encode(self, texts: List[str], **kwargs):
"""Encode input text as term vectors"""
return bm25s.tokenize(
texts, stopwords=self.stopwords, stemmer=self.stemmer, **kwargs
)
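# Minimal standalone sanity check for BM25Search (a sketch with a made-up toy corpus,
# not part of the MTEB flow; the `score_function` argument is unused by this class):
#
#   searcher = BM25Search()
#   hits = searcher.search(
#       corpus={"d1": {"title": "BM25", "text": "BM25 is a lexical ranking function"}},
#       queries={"q1": "what is bm25"},
#       top_k=1,
#       score_function="bm25",
#   )
#   # hits -> {"q1": {"d1": <bm25 score>}}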
class BM25MTEB(MTEB):
"""Override eval methods from parent class"""
def select_tasks(self, **kwargs):
"""Select the tasks to be evaluated."""
super().select_tasks(**kwargs)
# Get only retrieval tasks
self.tasks = [t for t in self.tasks if isinstance(t, AbsTaskRetrieval)]
def _run_eval(self, task, model, split, output_folder, **kwargs):
if model is not None:
raise ValueError("BM25 does not need a model")
if not isinstance(task, AbsTaskRetrieval):
raise ValueError(
"Only retrieval tasks that inherit `AbsTaskRetrieval` from can be evaluated!"
)
tick = time()
results = self.evaluate_task(task, split, output_folder=output_folder, **kwargs)
tock = time()
return results, tick, tock
def _evaluate_subset(
self,
corpus,
queries,
relevant_docs,
hf_subset: str,
main_score: str,
task_name: str = "task",
k_values=[1, 3, 5, 10, 20, 100, 1000],
**kwargs,
):
start_time = time()
# Retrieve and evaluate with BM25 search
model = BM25Search()
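# BEIR's EvaluateRetrieval delegates to BM25Search.search above with
# top_k = max(its default k_values) and a score_function string that BM25 ignores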
retriever = EvaluateRetrieval(retriever=model)
results = retriever.retrieve(corpus, queries)
end_time = time()
logger.info(
"Time taken to retrieve: {:.2f} seconds".format(end_time - start_time)
)
if kwargs.get("save_predictions", False):
output_folder = Path(kwargs.get("output_folder", "results"))
if not os.path.isdir(output_folder):
os.makedirs(output_folder)
top_k = kwargs.get("top_k", None)
if top_k is not None:
for qid in list(results.keys()):
doc_ids = set(
sorted(
results[qid], key=lambda x: results[qid][x], reverse=True
)[:top_k]
)
results[qid] = {
k: v for k, v in results[qid].items() if k in doc_ids
}
# NOTE: _evaluate_subset is defined on the runner (not the task), so the task
# name is passed in explicitly instead of reading self.metadata_dict
qrels_save_path = (
output_folder / f"{task_name}_{hf_subset}_predictions.json"
)
with open(qrels_save_path, "w") as f:
json.dump(results, f)
ndcg, _map, recall, precision = retriever.evaluate(
relevant_docs,
results,
k_values,
ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
)
mrr = retriever.evaluate_custom(relevant_docs, results, k_values, "mrr")
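# BEIR returns metric dicts keyed like "NDCG@10", "MAP@10", "Recall@10", ...;
# rename them to MTEB-style keys such as "ndcg_at_10" by splitting on "@"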
scores = {
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
}
# task._add_main_score(scores)
scores["main_score"] = scores[main_score]
return scores
def evaluate_task(self, task, split="test", **kwargs):
"""Evaluate a specific task"""
scores = {}
hf_subsets = (
list(task.hf_subsets)
if (task.is_multilingual or task.is_crosslingual)
else ["default"]
)
for hf_subset in hf_subsets:
logger.info(f"Subset: {hf_subset}")
if hf_subset == "default":
corpus, queries, relevant_docs = (
task.corpus[split],
task.queries[split],
task.relevant_docs[split],
)
else:
corpus, queries, relevant_docs = (
task.corpus[hf_subset][split],
task.queries[hf_subset][split],
task.relevant_docs[hf_subset][split],
)
scores[hf_subset] = self._evaluate_subset(
corpus,
queries,
relevant_docs,
hf_subset,
task.metadata.main_score,
task_name=task.metadata.name,
**kwargs,
)
return scores
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--task_types",
nargs="+",
type=str,
default=None,
help="List of task types (Clustering, Retrieval..) to be evaluated. If None, all tasks will be evaluated",
)
parser.add_argument(
"--task_categories",
nargs="+",
type=str,
default=None,
help="List of task categories (s2s, p2p..) to be evaluated. If None, all tasks will be evaluated",
)
parser.add_argument(
"-t",
"--tasks",
nargs="+",
type=str,
default=None,
help="List of tasks to be evaluated. If specified, the other arguments are ignored.",
)
parser.add_argument(
"-l",
"--task-langs",
nargs="*",
type=str,
default=None,
help="List of languages to be evaluated. if not set, all languages will be evaluated.",
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for computation"
)
parser.add_argument(
"--output_folder",
type=str,
default=None,
help="Output directory for results. Will default to results/{model_name} if not set.",
)
parser.add_argument(
"-v", "--verbosity", type=int, default=2, help="Verbosity level"
)
parser.add_argument(
"--co2_tracker",
action="store_true",
help="Enable CO₂ tracker (disabled by default)",
)
## evaluation params
parser.add_argument(
"--eval_splits",
nargs="+",
type=str,
default=None,
help="Evaluation splits to use (train, dev, test..). If None, all splits will be used",
)
## display tasks
parser.add_argument(
"--available_tasks",
action="store_true",
default=False,
help="Display the available tasks",
)
# TODO: check what params are useful to add
args = parser.parse_args()
# set logging based on verbosity level
if args.verbosity == 0:
logging.getLogger("mteb").setLevel(logging.CRITICAL)
elif args.verbosity == 1:
logging.getLogger("mteb").setLevel(logging.WARNING)
elif args.verbosity == 2:
logging.getLogger("mteb").setLevel(logging.INFO)
elif args.verbosity == 3:
logging.getLogger("mteb").setLevel(logging.DEBUG)
logger.info("Running with parameters: %s", args)
if args.available_tasks:
BM25MTEB.mteb_tasks()
return
evaluation = BM25MTEB(
task_categories=args.task_categories,
task_types=args.task_types,
task_langs=args.task_langs,
tasks=args.tasks,
)
evaluation.run(
model=None,
verbosity=args.verbosity,
output_folder=args.output_folder,
eval_splits=args.eval_splits,
co2_tracker=args.co2_tracker,
)
if __name__ == "__main__":
main()