Skip to content

Instantly share code, notes, and snippets.

@hweller1
Last active February 23, 2024 09:40
Show Gist options
  • Save hweller1/b2c743a97a4a992c263ed1de39dfef02 to your computer and use it in GitHub Desktop.
Save hweller1/b2c743a97a4a992c263ed1de39dfef02 to your computer and use it in GitHub Desktop.
Perform reciprocal rank fusion using $push to expose rank, and $unionWith and $group to join result sets of vector search and full text search
import pymongo
import time
from sentence_transformers import SentenceTransformer
from companies import names # list of company names in a separate python file
### DESCRIPTION
"""
Search against the Sphere dataset using vector search results fused with full text search results via reciprocal rank fusion.
Dataset: https://ai.meta.com/blog/introducing-sphere-meta-ais-web-scale-corpus-for-better-knowledge-intensive-nlp/
"""

### SETUP
# Placeholder value: replace with the connection string of the target cluster.
connection_str = "<mongodb cluster conection str>"
client = pymongo.MongoClient(connection_str)
db = client["vector-test"]
coll = db["sphere1mm"]

### CONFIGURATION PARAMETERS
# Each penalty is added to a result's rank in the denominator of that
# branch's score, so a larger penalty lowers that branch's contribution.
vector_penalty = 1
full_text_penalty = 10
# Number of fused results to keep; each branch is asked for k * 2 rows.
k = 10
# Multiplier applied to k to obtain numCandidates for the vector search.
overrequest_factor = 10

### QUERY
# One question per company name; all questions are encoded in a single call.
queries = [f"What is {name}" for name in names]
model = SentenceTransformer("sentence-transformers/facebook-dpr-question_encoder-single-nq-base")
embeddings = model.encode(queries)
def _rrf_score_stages(score_field, penalty):
    """Return the stages that turn an ordered result set into rank-based scores.

    The incoming documents are pushed into a single array and unwound again
    so that ``includeArrayIndex`` exposes each document's 0-based position as
    ``rank``. The score written to ``score_field`` is
    ``1 / (rank + penalty + 1)``.

    Args:
        score_field: Name of the output score field (e.g. ``"vs_score"``).
        penalty: Constant added to the rank in the score denominator.

    Returns:
        A list of aggregation stages emitting ``_id``, ``raw`` and
        ``score_field`` for every input document.
    """
    return [
        {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
        {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
        {
            "$addFields": {
                score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}
            }
        },
        # After $unwind the original document lives under "docs", so its
        # fields must be lifted explicitly; projecting a bare "raw": 1 would
        # yield nothing and "_id" would be the None group key.
        {"$project": {score_field: 1, "_id": "$docs._id", "raw": "$docs.raw"}},
    ]


def _build_hybrid_pipeline(
    query,
    query_vector,
    *,
    k,
    overrequest_factor,
    vector_penalty,
    full_text_penalty,
    collection_name="sphere1mm",
):
    """Build the reciprocal-rank-fusion pipeline for one query.

    Args:
        query: Text passed to the full text ``$search`` stage.
        query_vector: Query embedding as a plain list of floats.
        k: Number of fused results to return; each branch fetches ``k * 2``.
        overrequest_factor: Multiplier on ``k`` giving ``numCandidates``.
        vector_penalty: Rank penalty for the vector search branch.
        full_text_penalty: Rank penalty for the full text branch.
        collection_name: Collection the ``$unionWith`` branch searches.

    Returns:
        The aggregation pipeline as a list of stage documents. Each result
        carries ``_id``, ``raw``, ``vs_score``, ``fts_score`` and ``score``.
    """
    full_text_branch = [
        {
            "$search": {
                "index": "fts_sphere",
                "text": {"query": query, "path": "raw"},
            }
        },
        {"$limit": k * 2},
        *_rrf_score_stages("fts_score", full_text_penalty),
    ]
    return [
        {
            "$vectorSearch": {
                "index": "vector",
                "path": "vector",
                "queryVector": query_vector,
                "numCandidates": k * overrequest_factor,
                "limit": k * 2,
            }
        },
        *_rrf_score_stages("vs_score", vector_penalty),
        {"$unionWith": {"coll": collection_name, "pipeline": full_text_branch}},
        # Merge the two branches per document. Grouping on the document _id
        # (rather than on the text) keeps one row per source document, and
        # "raw" is carried through so it is still available afterwards.
        {
            "$group": {
                "_id": "$_id",
                "raw": {"$first": "$raw"},
                "vs_score": {"$max": "$vs_score"},
                "fts_score": {"$max": "$fts_score"},
            }
        },
        # A document found by only one branch has no score from the other.
        {
            "$project": {
                "_id": 1,
                "raw": 1,
                "vs_score": {"$ifNull": ["$vs_score", 0]},
                "fts_score": {"$ifNull": ["$fts_score", 0]},
            }
        },
        {"$addFields": {"score": {"$add": ["$fts_score", "$vs_score"]}}},
        # Sort before limiting so the k best fused scores are kept.
        {"$sort": {"score": -1}},
        {"$limit": k},
    ]


# Run every query, timing the round trip to the first (best) fused result.
for query, embedding in zip(queries, embeddings):
    vector_agg_with_lookup = _build_hybrid_pipeline(
        query,
        embedding.tolist(),
        k=k,
        overrequest_factor=overrequest_factor,
        vector_penalty=vector_penalty,
        full_text_penalty=full_text_penalty,
    )
    start = time.perf_counter()
    cursor = coll.aggregate(vector_agg_with_lookup)
    # A query may match nothing; avoid StopIteration from an empty cursor.
    result = next(cursor, None)
    elapsed = time.perf_counter() - start
    print(f"hybrid unionWith query took {elapsed} seconds \n")
    if result is None:
        print(f"No results for query: {query!r}")
        continue
    print(
        f"Hybrid score: {result['score']}\n vector score: {result['vs_score']}\n full text score: {result['fts_score']} "
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment