@benwtrent · Created January 31, 2023
Some helpers for loading and testing BEIR data sets with Elasticsearch
import typing as t
from pathlib import Path

import numpy as np
from elasticsearch import Elasticsearch, helpers
from elasticsearch.helpers import BulkIndexError
from sentence_transformers import SentenceTransformer

import loaders

# Load queries, qrels, etc. and create embeddings for the queries
queries = loaders.load_jsonl(jsonl_path=Path("./data/queries.jsonl"))
# model_id is not defined in the original gist; it is also used below to derive
# Elasticsearch index names, which must be lowercase. An example 384-dim model:
model_id = "minilm-l6-v2"
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="mps")
query_embeddings = embedding_model.encode([d['text'] for d in queries])
query_embeddings = query_embeddings.tolist()
query_and_embeddings = [dict(item, **{'embedding': embedding}) for (item, embedding) in zip(queries, query_embeddings)]
qrels = loaders.load_beir_qrels(qrels_file=Path("./data/qrels/test.tsv"))
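# Assumed setup (not in the original gist): an Elasticsearch client pointed at a
# local cluster; adjust the URL and authentication for your deployment.
es_client = Elasticsearch("http://localhost:9200")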
# Linear quantization of float embeddings into the signed byte range [-128, 127]
def quantize_embeddings(text_and_embeddings: t.List[t.Mapping[str, t.Any]]) -> t.List[t.Mapping[str, t.Any]]:
    quantized_embeddings = np.array([x['embedding'] for x in text_and_embeddings])
    quantized_embeddings = quantized_embeddings * 128
    quantized_embeddings = quantized_embeddings.clip(-128, 127).astype(int).tolist()
    return [dict(item, **{'embedding': embedding}) for (item, embedding) in zip(text_and_embeddings, quantized_embeddings)]
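# A minimal usage sketch: byte-encoded indices require integer query vectors, so the
# query embeddings get the same linear quantization as the documents.
quantized_query_and_embeddings = quantize_embeddings(query_and_embeddings)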
# Create a vector index with an indexed dense_vector field
def create_vector_index(es_client: Elasticsearch, index_name: str, dims: int, field_name: str, encoding_kind: str = "float", similarity: str = "dot_product"):
    mapping = {"type": "dense_vector", "dims": dims, "index": True, "similarity": similarity, "element_type": encoding_kind}
    es_client.indices.create(index=index_name, mappings={"properties": {field_name: mapping}})
# Ingest embeddings; be sure to use the correct ID, as it is matched against the qrels later
def ingest_embeddings(es_client: Elasticsearch, model_id: str, embeddings: t.List[t.Mapping[str, t.Any]], output_quantized: bool = False):
    output_field = "embedding"
    index_name = f"{model_id}_embedding_index"
    if output_quantized:
        index_name += "_quantized"
    try:
        create_vector_index(es_client=es_client, index_name=index_name, field_name=output_field, dims=384, encoding_kind=("byte" if output_quantized else "float"))
    except Exception:
        pass  # index already exists
    actions = [{"_id": d['_id'], "_index": index_name, "embedding": d['embedding']} for d in embeddings]
    try:
        for success, info in helpers.parallel_bulk(client=es_client, actions=actions):
            if not success:
                print('A document failed:', info)
                raise ValueError
    except BulkIndexError as e:
        print(e)
        print(e.errors)
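# A usage sketch under assumed paths: embed the corpus and ingest it twice, once as
# floats and once quantized to bytes. corpus.jsonl is the standard BEIR corpus file.
corpus = loaders.load_jsonl(jsonl_path=Path("./data/corpus.jsonl"))
corpus_embeddings = embedding_model.encode([d['text'] for d in corpus]).tolist()
corpus_and_embeddings = [dict(item, **{'embedding': e}) for (item, e) in zip(corpus, corpus_embeddings)]
ingest_embeddings(es_client, model_id, corpus_and_embeddings)
ingest_embeddings(es_client, model_id, quantize_embeddings(corpus_and_embeddings), output_quantized=True)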
# Measure approximate kNN query latency, summarizing the per-query 'took' times (ms)
def test_query_speed(es_client: Elasticsearch, model_id: str, query_and_embeddings: t.List[t.Mapping[str, t.Any]], output_quantized: bool = False) -> t.Mapping[str, float]:
    index_name = f"{model_id}_embedding_index"
    if output_quantized:
        index_name += "_quantized"
    tooks = []
    for q in query_and_embeddings:
        result = es_client.search(
            index=index_name,
            size=20,
            source=False,
            knn={"k": 20, "num_candidates": 100, "field": "embedding", "query_vector": q['embedding']}
        )
        tooks.append(result['took'])
    tooks = np.array(tooks)
    return {'mean': tooks.mean(), 'max': tooks.max(), 'min': tooks.min(), 'median': np.median(tooks), 'std': np.std(tooks)}
# Measure exact (brute-force) query latency with a script_score scan over all documents
def test_exact_query_speed(es_client: Elasticsearch, model_id: str, query_and_embeddings: t.List[t.Mapping[str, t.Any]], output_quantized: bool = False) -> t.Mapping[str, float]:
    index_name = f"{model_id}_embedding_index"
    if output_quantized:
        index_name += "_quantized"
    tooks = []
    for q in query_and_embeddings:
        result = es_client.search(
            index=index_name,
            size=20,
            source=False,
            query={
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "return sigmoid(1, Math.E, -1.0 * dotProduct(params.query_vector, 'embedding'));",
                        "params": {"query_vector": q['embedding']}
                    }
                }
            }
        )
        tooks.append(result['took'])
    tooks = np.array(tooks)
    return {'mean': tooks.mean(), 'max': tooks.max(), 'min': tooks.min(), 'median': np.median(tooks), 'std': np.std(tooks)}
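# A quick comparison sketch: the approximate kNN search should be much faster than the
# exact script_score scan; both report Elasticsearch's 'took' field in milliseconds.
print("approx knn:", test_query_speed(es_client, model_id, query_and_embeddings))
print("exact scan:", test_exact_query_speed(es_client, model_id, query_and_embeddings))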
# Given knowledge of which IDs should match for a given query (qrels), determine recall
def test_recall(es_client: Elasticsearch, model_id: str, query_and_embeddings: t.List[t.Mapping[str, t.Any]], qrels: t.Mapping[str, t.Any], output_quantized: bool = False, recall_size: int = 100) -> t.Mapping[str, float]:
    index_name = f"{model_id}_embedding_index"
    if output_quantized:
        index_name += "_quantized"
    recalls = []
    num_candidates = min(1000, recall_size * 5)
    for q in query_and_embeddings:
        ratings = [{"_index": index_name, "_id": k, "rating": v} for (k, v) in qrels[q["_id"]].items()]
        result = es_client.rank_eval(
            index=index_name,
            requests=[
                {
                    "id": q["_id"],
                    "request": {"knn": {"k": recall_size, "num_candidates": num_candidates, "field": "embedding", "query_vector": q['embedding']}},
                    "ratings": ratings
                }
            ],
            metric={"recall": {"k": recall_size}}
        )
        recalls.append(result['metric_score'])
    recalls = np.array(recalls)
    return {'mean': recalls.mean(), 'max': recalls.max(), 'min': recalls.min(), 'median': np.median(recalls), 'std': np.std(recalls)}
# Given knowledge of which IDs should match for a given query (qrels), determine NDCG
def test_ndcg(es_client: Elasticsearch, model_id: str, query_and_embeddings: t.List[t.Mapping[str, t.Any]], qrels: t.Mapping[str, t.Any], output_quantized: bool = False, ndcg_size: int = 10) -> t.Mapping[str, float]:
    index_name = f"{model_id}_embedding_index"
    if output_quantized:
        index_name += "_quantized"
    ndcgs = []
    num_candidates = min(1000, ndcg_size * 5)
    for q in query_and_embeddings:
        ratings = [{"_index": index_name, "_id": k, "rating": v} for (k, v) in qrels[q["_id"]].items()]
        result = es_client.rank_eval(
            index=index_name,
            requests=[
                {
                    "id": q["_id"],
                    "request": {"knn": {"k": ndcg_size, "num_candidates": num_candidates, "field": "embedding", "query_vector": q['embedding']}},
                    "ratings": ratings
                }
            ],
            metric={"dcg": {"k": ndcg_size, "normalize": True}}
        )
        ndcgs.append(result['metric_score'])
    ndcgs = np.array(ndcgs)
    return {'mean': ndcgs.mean(), 'max': ndcgs.max(), 'min': ndcgs.min(), 'median': np.median(ndcgs), 'std': np.std(ndcgs)}
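# An end-to-end sketch: evaluate recall and NDCG on the float index and on the
# quantized index (quantized queries go against the quantized index, per the note above).
for quantized in (False, True):
    qs = quantized_query_and_embeddings if quantized else query_and_embeddings
    print(quantized, "recall:", test_recall(es_client, model_id, qs, qrels, output_quantized=quantized))
    print(quantized, "ndcg:", test_ndcg(es_client, model_id, qs, qrels, output_quantized=quantized))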
# loaders.py: helpers for loading BEIR-formatted data files (imported above as `loaders`)
import csv
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterator

import rich.progress

logger = logging.getLogger(__name__)
def load_json(path: Path) -> dict:
    with open(path, "r") as f:
        return json.load(f)

def load_tsv(tsv_file: Path):
    logger.info(f"Loading {tsv_file} ...")
    documents = []
    with rich.progress.open(str(tsv_file), "r", encoding="utf8") as f_read:
        for line in f_read:
            _id, text = line.strip().split("\t")
            documents.append({"_id": _id, "title": "", "text": text})
    return documents

def load_tsv_iterator(tsv_file: Path):
    logger.info(f"Loading {tsv_file} ...")
    with rich.progress.open(str(tsv_file), "r", encoding="utf8") as f_read:
        for line in f_read:
            _id, text = line.strip().split("\t")
            yield {"_id": _id, "title": "", "text": text}

def load_jsonl(jsonl_path: Path):
    logger.info(f"Loading {jsonl_path} ...")
    with rich.progress.open(str(jsonl_path), "r", encoding="utf8") as f_read:
        return [json.loads(line) for line in f_read]

def load_jsonl_iterator(jsonl_path: Path) -> Iterator[dict]:
    logger.info(f"Loading {jsonl_path} ...")
    with rich.progress.open(str(jsonl_path), "r", encoding="utf8") as f_read:
        for line in f_read:
            yield json.loads(line)
def load_beir_qrels(qrels_file: Path) -> Dict[str, Dict[str, int]]:
    logger.info(f"Loading {qrels_file} ...")
    with rich.progress.open(str(qrels_file), "r", encoding="utf8") as f_read:
        reader = csv.reader(f_read, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
        next(reader)  # skip the header row
        qrels = defaultdict(dict)
        for row in reader:
            query_id, corpus_id, score = row[0], row[1], int(row[2])
            qrels[str(query_id)][str(corpus_id)] = score
    return qrels
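# The expected qrels TSV carries a header row, e.g. (tab-separated):
#   query-id    corpus-id    score
#   q1          doc42        1
# which parses to {"q1": {"doc42": 1}, ...}.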
def load_jsonl_qas(qas_path: Path) -> dict:
    logger.info(f"Loading {qas_path} ...")
    qas = defaultdict(dict)
    with rich.progress.open(str(qas_path), "r", encoding="utf8") as f_read:
        for line in f_read:
            obj = json.loads(line)
            for did in obj["answer_pids"]:
                qas[str(obj["qid"])][str(did)] = 1
    return qas

def load_needed_queries(filepath: Path, qrels: dict) -> Dict[str, str]:
    logger.info(f"Loading {filepath} ...")
    queries = {}
    with rich.progress.open(str(filepath), "r", encoding="utf8") as f_read:
        for line in f_read:
            obj = json.loads(line)
            if str(obj["_id"]) in qrels:
                queries[str(obj["_id"])] = f"{obj.get('title', '')} {obj['text']}"
    return queries

def load_needed_docs(filepath: Path, qrels: dict) -> Dict[str, str]:
    logger.info(f"Loading {filepath} ...")
    needed_ids = {x for dpq in qrels.values() for x in dpq.keys()}
    docs = {}
    with rich.progress.open(str(filepath), "r", encoding="utf8") as f_read:
        for line in f_read:
            obj = json.loads(line)
            if str(obj["_id"]) in needed_ids:
                docs[str(obj["_id"])] = f"{obj.get('title', '')} {obj['text']}"
    return docs
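# A usage sketch under assumed file names: restrict an experiment to only the queries
# and documents actually referenced by the qrels, to keep a quick test small.
if __name__ == "__main__":
    qrels = load_beir_qrels(qrels_file=Path("./data/qrels/test.tsv"))
    queries = load_needed_queries(filepath=Path("./data/queries.jsonl"), qrels=qrels)
    docs = load_needed_docs(filepath=Path("./data/corpus.jsonl"), qrels=qrels)
    print(f"{len(queries)} queries, {len(docs)} docs")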