Last active
November 15, 2023 16:27
-
-
Save hweller1/d6dbd5036ae4366108b534a0f1662a20 to your computer and use it in GitHub Desktop.
relative score fusion using unionWith and group, using the scores yielded from the searchMeta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time

import pymongo
from sentence_transformers import SentenceTransformer

from companies import names  # List of company names from another python file
### DESCRIPTION
"""
Search against the Sphere dataset using vector search results fused with full text search results via relative score fusion.
Dataset: https://ai.meta.com/blog/introducing-sphere-meta-ais-web-scale-corpus-for-better-knowledge-intensive-nlp/
"""
### SETUP
# Connection string for the MongoDB cluster. It is read from the MONGODB_URI
# environment variable when that is set, so real credentials do not have to be
# written into this file; otherwise the placeholder below is used and must be
# replaced by hand before running.
connection_str = os.environ.get("MONGODB_URI", "<mongodb cluster conection str>")
client = pymongo.MongoClient(connection_str)
# Database and collection queried by the aggregation below.
db = client["vector-test"]
coll = db["sphere1mm"]
### CONFIGURATION PARAMETERS
# The two scalars sum to 1 and weight each branch of the fused score; each raw
# score is additionally divided by its normalization constant before weighting.
vector_scalar = 0.9  # Vector search score scaling factor
vector_normalization = 40  # Rough scaling of dot product vector scores
fts_scalar = 1 - vector_scalar  # FTS score scaling factor
fts_normalization = 10  # Rough scaling of full text search scores
k = 10  # Size of the final result list; each search branch fetches k * 2
overrequest_factor = 10  # numCandidates for the vector search is k * overrequest_factor
### QUERY
# One question string per entry of `names`.
queries = [f"What is {x}" for x in names]
model = SentenceTransformer(
    "sentence-transformers/facebook-dpr-question_encoder-single-nq-base"
)
# Encode all queries in a single call; embeddings[i] corresponds to queries[i].
embeddings = model.encode(queries)
# For every query: run the vector search, union it with the full text search
# over the same collection, fuse the two weighted scores per passage, and print
# the best fused hit together with how long the aggregate call took.
for query, embedding in zip(queries, embeddings):
    vector_agg_with_lookup = [
        # --- Vector search branch -------------------------------------------
        {
            "$vectorSearch": {
                "index": "vector",
                "path": "vector",
                "queryVector": embedding.tolist(),
                "numCandidates": k * overrequest_factor,
                "limit": k * 2,
            }
        },
        # $vectorSearch exposes its score under the "vectorSearchScore"
        # metadata key ("searchScore" belongs to the $search stage).
        {"$addFields": {"vs_score": {"$meta": "vectorSearchScore"}}},
        {
            "$project": {
                "vs_score": {"$multiply": ["$vs_score", vector_scalar / vector_normalization]},
                "_id": 1,
                "raw": 1,
            }
        },
        # --- Full text search branch, appended to the vector results --------
        {
            "$unionWith": {
                "coll": "sphere1mm",
                "pipeline": [
                    {
                        "$search": {
                            "index": "fts_sphere",
                            "text": {"query": query, "path": "raw"},
                        }
                    },
                    {"$limit": k * 2},
                    {"$addFields": {"fts_score": {"$meta": "searchScore"}}},
                    {
                        "$project": {
                            "fts_score": {"$multiply": ["$fts_score", fts_scalar / fts_normalization]},
                            "_id": 1,
                            "raw": 1,
                        }
                    },
                ],
            }
        },
        # --- Fusion ----------------------------------------------------------
        # Collapse rows that carry the same passage text so a passage found by
        # both branches ends up with both scores. After this stage the text is
        # the document's _id; there is no separate "raw" field any more.
        {
            "$group": {
                "_id": "$raw",
                "vs_score": {"$max": "$vs_score"},
                "fts_score": {"$max": "$fts_score"},
            }
        },
        # A passage found by only one branch has a null score for the other
        # branch; treat that as 0 so the sum below is well defined.
        {
            "$project": {
                "_id": 1,
                "vs_score": {"$ifNull": ["$vs_score", 0]},
                "fts_score": {"$ifNull": ["$fts_score", 0]},
            }
        },
        {
            "$project": {
                "score": {"$add": ["$fts_score", "$vs_score"]},
                "_id": 1,
                "vs_score": 1,
                "fts_score": 1,
            }
        },
        # Sort BEFORE limiting: limiting first would keep k arbitrary grouped
        # documents instead of the k best fused scores.
        {"$sort": {"score": -1}},
        {"$limit": k},
    ]
    # perf_counter is the monotonic clock intended for measuring elapsed time.
    start = time.perf_counter()
    cursor = coll.aggregate(vector_agg_with_lookup)
    print(f"hybrid unionWith query took {time.perf_counter() - start} seconds \n")
    # An empty result set must not abort the remaining queries.
    result = next(cursor, None)
    if result is None:
        print(f"No hybrid results for query: {query!r}\n")
        continue
    print(
        f"Hybrid score: {result['score']}\n vector score: {result['vs_score']}\n full text score: {result['fts_score']} "
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment