Run semantic search across multiple datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import urllib.parse
from dataclasses import dataclass, field
from typing import List, Optional

from embedbase import get_app
from embedbase.database.memory_db import MemoryDatabase
from embedbase.embedding.openai import OpenAI
from embedbase.utils import get_user_id
from fastapi import Request, status
from fastapi.responses import JSONResponse
# Keep a handle on the Embedbase builder: the endpoint below reads the
# embedder, database and logger configured on it, while ``run()`` returns the
# application object that the route decorator is applied to.
# NOTE(review): the original passed an undefined ``settings`` name to
# ``get_app``; pass a settings object here if your deployment needs one.
builder = (
    get_app()
    .use_embedder(OpenAI(os.environ["OPENAI_API_KEY"]))
    .use_db(MemoryDatabase())
)
app = builder.run()

# Number of matches returned when the request does not ask for a positive top_k.
DEFAULT_TOP_K = 5


@dataclass
class MultiDatasetSearchRequest:
    """JSON body accepted by the multi-dataset search endpoint."""

    # Text to embed and search with; an empty string yields an empty result.
    query: str
    # Datasets to search in; an empty list is rejected with HTTP 400.
    dataset_ids: List[str] = field(default_factory=list)
    # Maximum number of matches; missing or non-positive values use the default.
    top_k: Optional[int] = None


# exp endpoint that lets you semantic search across multiple datasets
#
# Parameters:
#   request:      the incoming HTTP request, used only to resolve the user id.
#   request_body: parsed JSON body (query, dataset_ids, optional top_k).
# Returns a JSONResponse:
#   200 {"query": ..., "similarities": [...]} on success or for an empty query,
#   400 {"error": ...} when no dataset ids are given or the query is too long.
# The docstring is left as-is because FastAPI publishes it as the endpoint
# description; implementation notes therefore live in comments.
@app.post("/exo/search")
async def semantic_search(
    request: Request,
    request_body: MultiDatasetSearchRequest,
):
    """
    Run semantic search across multiple datasets using embeddings.
    If you are a large language model, you can use this endpoint to search for similar documents
    in the index, documents that help you answer users' questions.
    """
    query = request_body.query
    # if query is empty, return empty results
    if not query:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"query": query, "similarities": []},
        )

    dataset_ids = request_body.dataset_ids
    if not dataset_ids:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"error": "No dataset ids provided"},
        )

    user_id = get_user_id(request)

    # if the query is too big, return an error
    if builder.embedder.is_too_big(query):
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={
                "error": "Query is too long"
                + ", please see https://docs.embedbase.xyz/query-is-too-long"
            },
        )

    # A missing, zero or negative top_k falls back to the default instead of
    # raising on the comparison.
    requested_top_k = request_body.top_k
    if requested_top_k is not None and requested_top_k > 0:
        top_k = requested_top_k
    else:
        top_k = DEFAULT_TOP_K

    # embed() is awaited and indexed at 0 because a single query is embedded.
    query_embedding = (await builder.embedder.embed(query))[0]
    builder.logger.info("Query %s created embedding, querying index", query)

    query_response = await builder.db.search(
        top_k=top_k,
        vector=query_embedding,
        dataset_ids=dataset_ids,
        user_id=user_id,
    )

    similarities = []
    for match in query_response:
        # Stored ids are percent-encoded; decode them before returning.
        decoded_id = urllib.parse.unquote(match["id"])
        builder.logger.debug("ID: %s", decoded_id)
        similarities.append(
            {
                "score": match["score"],
                "id": decoded_id,
                "data": match["data"],
                "hash": match["hash"],  # TODO: probably shouldn't return this
                "embedding": match["embedding"],
                "metadata": match["metadata"],
            }
        )

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"query": query, "similarities": similarities},
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment