Last active
February 23, 2024 09:40
-
-
Save hweller1/b2c743a97a4a992c263ed1de39dfef02 to your computer and use it in GitHub Desktop.
Perform reciprocal rank fusion using $push to expose rank, and $unionWith and $group to join result sets of vector search and full text search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymongo | |
import time | |
from sentence_transformers import SentenceTransformer | |
from companies import names # list of company names in a separate python file | |
### DESCRIPTION
"""
Search against the Sphere dataset using vector search results fused with full text search results via reciprocal rank fusion.
Dataset: https://ai.meta.com/blog/introducing-sphere-meta-ais-web-scale-corpus-for-better-knowledge-intensive-nlp/
"""

### SETUP
# Placeholder value: substitute the connection string of the cluster to query.
connection_str = "<mongodb cluster conection str>"
client = pymongo.MongoClient(connection_str)

# Database and collection handles the aggregation below is executed against.
db = client["vector-test"]
coll = db["sphere1mm"]
### CONFIGURATION PARAMETERS
# Constants added to a hit's rank in the reciprocal-rank denominator; a larger
# penalty produces smaller scores for that result set.
vector_penalty = 1
full_text_penalty = 10
# Number of fused results kept; each individual search is limited to k * 2.
k = 10
# $vectorSearch is asked to consider k * overrequest_factor candidates.
overrequest_factor = 10

### QUERY
# One "What is <name>" question per entry in the imported list of names.
queries = [f"What is {company}" for company in names]

encoder_name = "sentence-transformers/facebook-dpr-question_encoder-single-nq-base"
model = SentenceTransformer(encoder_name)

# Encode all questions in a single call; embeddings[i] corresponds to queries[i].
embeddings = model.encode(queries)
# For every question: run vector search and full text search, convert each
# result list into reciprocal-rank scores, fuse the two lists with $unionWith
# + $group, and print the timing and score breakdown of the best fused hit.
for query, embedding in zip(queries, embeddings):
    vector_agg_with_lookup = [
        {
            "$vectorSearch": {
                "index": "vector",
                "path": "vector",
                "queryVector": embedding.tolist(),
                "numCandidates": k * overrequest_factor,
                "limit": k * 2,
            }
        },
        # Collect the ordered hits into one array so that $unwind can expose
        # each hit's zero-based position as "rank".
        {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
        {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
        {
            "$addFields": {
                "vs_score": {"$divide": [1.0, {"$add": ["$rank", vector_penalty, 1]}]}
            }
        },
        # After $unwind the original document lives under "docs"; lift its
        # _id and raw text back to the top level. Projecting top-level "raw"
        # here would yield nothing and the final $group would collapse every
        # hit into a single null bucket.
        {"$project": {"vs_score": 1, "_id": "$docs._id", "raw": "$docs.raw"}},
        {
            "$unionWith": {
                "coll": "sphere1mm",
                "pipeline": [
                    {
                        "$search": {
                            "index": "fts_sphere",
                            "text": {"query": query, "path": "raw"},
                        }
                    },
                    {"$limit": k * 2},
                    # Same rank-extraction steps as for the vector results.
                    {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
                    {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
                    {
                        "$addFields": {
                            "fts_score": {
                                "$divide": [
                                    1.0,
                                    {"$add": ["$rank", full_text_penalty, 1]},
                                ]
                            }
                        }
                    },
                    {
                        "$project": {
                            "fts_score": 1,
                            "_id": "$docs._id",
                            "raw": "$docs.raw",
                        }
                    },
                ],
            }
        },
        # Merge the two result sets: a document found by both searches ends
        # up as one group carrying both of its scores.
        {
            "$group": {
                "_id": "$raw",
                "vs_score": {"$max": "$vs_score"},
                "fts_score": {"$max": "$fts_score"},
            }
        },
        # A document found by only one search has no score from the other;
        # treat the missing score as 0 so the sum below is well defined.
        {
            "$project": {
                "_id": 1,
                "raw": 1,
                "vs_score": {"$ifNull": ["$vs_score", 0]},
                "fts_score": {"$ifNull": ["$fts_score", 0]},
            }
        },
        {
            "$project": {
                "raw": 1,
                "score": {"$add": ["$fts_score", "$vs_score"]},
                "_id": 1,
                "vs_score": 1,
                "fts_score": 1,
            }
        },
        # Sort first, then truncate: limiting before sorting would keep k
        # arbitrary documents instead of the k best fused scores.
        {"$sort": {"score": -1}},
        {"$limit": k},
    ]

    start = time.time()
    cursor = coll.aggregate(vector_agg_with_lookup)
    # Fetch only the best hit; None signals an empty result set instead of
    # letting StopIteration abort the whole run.
    result = next(cursor, None)
    print(f"hybrid unionWith query took {time.time() - start} seconds \n")

    if result is None:
        print(f"hybrid unionWith query for {query!r} returned no results \n")
        continue

    print(
        f"Hybrid score: {result['score']}\n vector score: {result['vs_score']}\n full text score: {result['fts_score']} "
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment