Skip to content

Instantly share code, notes, and snippets.

@tsg
Created February 27, 2024 21:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tsg/1088379515bfae7b293efcd78e0148ae to your computer and use it in GitHub Desktop.
Save tsg/1088379515bfae7b293efcd78e0148ae to your computer and use it in GitHub Desktop.
Hybrid Search using Xata
import functools
import sys
import time

from sentence_transformers import SentenceTransformer
from xata.client import XataClient
# Expect the query as the only command-line argument. Validate it before
# constructing the Xata client so a usage error is reported regardless of
# whether the client can be set up.
if len(sys.argv) != 2:
    print("Usage: python hybrid_search.py <query>", file=sys.stderr)
    # sys.exit (not the interactive-only `exit` builtin) works under `python -S`
    # and in frozen/embedded interpreters.
    sys.exit(1)
query = sys.argv[1]

# Shared Xata client used by the search helpers below. It is constructed with no
# arguments, presumably picking up credentials from the environment -- TODO confirm.
xata = XataClient()
@functools.lru_cache(maxsize=1)
def _get_embedding_model():
    """Load the sentence-transformers embedding model once and reuse it."""
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def vector_search(query, size=5):
    """Run a semantic (vector) search over the ``docs`` table.

    Args:
        query: Free-text query; it is embedded with the sentence-transformers
            model and used as the query vector.
        size: Maximum number of records to request (default 5).

    Returns:
        The Xata response object; callers read its ``"records"`` entry.

    Raises:
        RuntimeError: if Xata reports the request as unsuccessful.
    """
    # Embed the *argument*, not sys.argv[1], so the function works for any query.
    vector = _get_embedding_model().encode(query)
    # NOTE(review): the embedding is repeated 4x, presumably to pad the model's
    # output up to the dimension of the "embedding" column. This only gives
    # meaningful similarities if the stored vectors were built the same way --
    # confirm against the ingestion code.
    query_vector = vector.tolist() * 4
    results = xata.data().vector_search("docs", {
        "queryVector": query_vector,
        "column": "embedding",
        "size": size,
    })
    if not results.is_success():
        raise RuntimeError(f"Vector search failed: {results.json()}")
    return results
def keyword_search(query, size=5):
    """Run a full-text (keyword) search over the ``docs`` table.

    Args:
        query: Free-text query passed to Xata's table search.
        size: Page size, i.e. maximum number of records to request (default 5).

    Returns:
        The Xata response object; callers read its ``"records"`` entry.

    Raises:
        RuntimeError: if Xata reports the request as unsuccessful.
    """
    results = xata.data().search_table("docs", {
        "query": query,
        # Passed through unchanged; see Xata's search docs for exact semantics
        # of the fuzziness and prefix options.
        "fuzziness": 1,
        "prefix": "phrase",
        "page": {
            "size": size,
        },
    })
    if not results.is_success():
        raise RuntimeError(f"Keyword search failed: {results.json()}")
    return results
def rerank_with_rrf(keyword_results, vector_results, k=60):
    """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

    Each record contributes ``1 / (k + rank)`` (rank starting at 1) from every
    list it appears in; records are returned sorted by total score, highest
    first. Ties keep first-seen order (keyword list first, then vector list).

    Args:
        keyword_results: Ranked records from the keyword search; each must have
            an ``"id"`` key.
        vector_results: Ranked records from the vector search; each must have an
            ``"id"`` key.
        k: RRF smoothing constant; larger values flatten rank differences.

    Returns:
        A list with one record per distinct id, ordered by RRF score. When an id
        appears in both inputs, the record object from ``vector_results`` is
        returned.
    """
    # One record per id; later (vector) entries overwrite earlier (keyword) ones.
    unique_results = {
        result["id"]: result for result in [*keyword_results, *vector_results]
    }
    scores = dict.fromkeys(unique_results, 0)

    for ranking in (keyword_results, vector_results):
        for rank, result in enumerate(ranking, start=1):
            scores[result["id"]] += 1 / (k + rank)

    # sorted() is stable, so equal scores keep insertion (first-seen) order.
    sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)
    return [unique_results[result_id] for result_id in sorted_ids]
def _print_records(records, with_score=True):
    """Print one tab-separated line per record: id, sentence and optionally the Xata score."""
    for record in records:
        line = f'{record["id"]}\t{record["sentence"]}'
        if with_score:
            line += f'\t{record["xata"]["score"]}'
        print(line)


def main():
    """Run both searches for the CLI query, print each, then print the RRF-fused ranking."""
    # Note: in a real application, it would make sense to run these searches in parallel
    vector_results = vector_search(query)
    print("Semantic search results:")
    _print_records(vector_results["records"])

    keyword_results = keyword_search(query)
    print("\nKeyword search results:")
    _print_records(keyword_results["records"])

    results = rerank_with_rrf(keyword_results["records"], vector_results["records"])
    print("\nReranked results:")
    # The fused list has no single meaningful score, so only id and sentence are shown.
    _print_records(results, with_score=False)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment