Skip to content

Instantly share code, notes, and snippets.

@hweller1
Last active November 15, 2023 16:27
Show Gist options
  • Save hweller1/d6dbd5036ae4366108b534a0f1662a20 to your computer and use it in GitHub Desktop.
Save hweller1/d6dbd5036ae4366108b534a0f1662a20 to your computer and use it in GitHub Desktop.
Relative score fusion using $unionWith and $group, combining the scores that each search stage exposes through $meta
import pymongo
import time
from sentence_transformers import SentenceTransformer
from companies import names # List of company names from another python file
### DESCRIPTION
"""
Search against the Sphere dataset using vector search results fused with full text search results via relative score fusion.
Dataset: https://ai.meta.com/blog/introducing-sphere-meta-ais-web-scale-corpus-for-better-knowledge-intensive-nlp/
"""
### SETUP
# Handles for the cluster, the database and the collection that every query
# below runs against. The connection string is a placeholder and has to be
# replaced with a real cluster URI before the script can reach a server.
connection_str = "<mongodb cluster conection str>"
client = pymongo.MongoClient(connection_str)
db = client.get_database("vector-test")
coll = db.get_collection("sphere1mm")
### CONFIGURATION PARAMETERS
# Relative weight given to the vector search score in the fused score.
vector_scalar = 0.9
# Rough scaling of dot product vector scores.
vector_normalization = 40
# Full text search receives whatever weight the vector score does not use,
# so the two weights always sum to 1.
fts_scalar = 1 - vector_scalar
# Rough scaling of full text search scores.
fts_normalization = 10
# Number of fused results kept per query.
k = 10
# Multiplier applied to k to obtain the vector search candidate count.
overrequest_factor = 10
### QUERY
# Build one question per company name, then encode every question in a single
# call so the loop below only has to pick up the matching embedding.
queries = [f"What is {company}" for company in names]
model = SentenceTransformer(
    "sentence-transformers/facebook-dpr-question_encoder-single-nq-base"
)
embeddings = model.encode(queries)
# Run one hybrid (vector + full text) search per query and report the best
# fused hit.
#
# Each pipeline:
#   1. runs $vectorSearch and scales its score by vector_scalar / vector_normalization,
#   2. unions in a $search (full text) sub-pipeline whose score is scaled by
#      fts_scalar / fts_normalization,
#   3. groups both result sets by the "raw" text so a document found by both
#      searches collapses into one entry carrying both scores,
#   4. adds the two scores, sorts by the sum and keeps the top k.

# The per-source weights do not depend on the query, so compute them once.
vs_weight = vector_scalar / vector_normalization
fts_weight = fts_scalar / fts_normalization

for query, embedding in zip(queries, embeddings):
    vector_agg_with_lookup = [
        {
            "$vectorSearch": {
                "index": "vector",
                "path": "vector",
                "queryVector": embedding.tolist(),
                "numCandidates": k * overrequest_factor,
                # Over-fetch (2k) so the fusion has more than k candidates
                # per source to choose from.
                "limit": k * 2,
            }
        },
        # $vectorSearch exposes its score under "vectorSearchScore";
        # "searchScore" belongs to the $search stage.
        {"$addFields": {"vs_score": {"$meta": "vectorSearchScore"}}},
        {
            "$project": {
                "vs_score": {"$multiply": ["$vs_score", vs_weight]},
                "_id": 1,
                "raw": 1,
            }
        },
        {
            "$unionWith": {
                # Union with the same collection the aggregation runs on.
                "coll": coll.name,
                "pipeline": [
                    {
                        "$search": {
                            "index": "fts_sphere",
                            "text": {"query": query, "path": "raw"},
                        }
                    },
                    {"$limit": k * 2},
                    {"$addFields": {"fts_score": {"$meta": "searchScore"}}},
                    {
                        "$project": {
                            "fts_score": {
                                "$multiply": ["$fts_score", fts_weight]
                            },
                            "_id": 1,
                            "raw": 1,
                        }
                    },
                ],
            }
        },
        {
            # After this stage "_id" holds the raw text; a document returned
            # by only one of the searches has a null score for the other.
            "$group": {
                "_id": "$raw",
                "vs_score": {"$max": "$vs_score"},
                "fts_score": {"$max": "$fts_score"},
            }
        },
        {
            # Treat a missing score as 0 so the sum below is always numeric.
            "$project": {
                "_id": 1,
                "vs_score": {"$ifNull": ["$vs_score", 0]},
                "fts_score": {"$ifNull": ["$fts_score", 0]},
            }
        },
        {"$addFields": {"score": {"$add": ["$fts_score", "$vs_score"]}}},
        # Sort BEFORE limiting: limiting first would keep an arbitrary k
        # documents instead of the k best fused scores.
        {"$sort": {"score": -1}},
        {"$limit": k},
    ]
    start = time.perf_counter()
    cursor = coll.aggregate(vector_agg_with_lookup)
    print(f"hybrid unionWith query took {time.perf_counter() - start} seconds \n")
    # A query can legitimately match nothing; skip it instead of letting
    # StopIteration abort the whole run.
    result = next(cursor, None)
    if result is None:
        print(f"No results for query: {query!r}")
        continue
    print(
        f"Hybrid score: {result['score']}\n vector score: {result['vs_score']}\n full text score: {result['fts_score']} "
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment