@ManZzup
Created May 31, 2023 10:25
Crude implementation of hybrid (kNN and text-similarity) search for Elasticsearch in LangChain
import uuid
from typing import Any, Dict, Iterable, List, Optional

from langchain.docstore.document import Document
from langchain.vectorstores import ElasticVectorSearch


def _default_text_mapping(dim: int) -> Dict:
    """Index mapping: a full-text field plus an indexed dense vector for kNN."""
    return {
        "properties": {
            "text": {"type": "text"},
            "vector": {
                "type": "dense_vector",
                "dims": dim,
                "index": True,
                "similarity": "l2_norm",
            },
        }
    }
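
# For illustration (not part of the original gist): the mapping above targets the
# Elasticsearch 8.x `dense_vector` field type with indexing enabled, which is what
# lets the kNN half of the hybrid query run. Creating an index with it directly
# might look like this (hypothetical index name and embedding dimension):
#
#   from elasticsearch import Elasticsearch
#   es = Elasticsearch("http://localhost:9200")
#   es.indices.create(index="hybrid-demo", mappings=_default_text_mapping(dim=384))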


class CustomElasticSearchVectorStore(ElasticVectorSearch):
    """Custom ES vector store that implements a hybrid (kNN + BM25) search."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        refresh_indices: bool = True,
        batch_size: int = 50,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            refresh_indices: Whether to refresh the Elasticsearch index afterwards.
            batch_size: Number of documents sent per bulk request.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        try:
            from elasticsearch.exceptions import NotFoundError
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ValueError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )
        # Materialize the iterable: it is traversed twice (embedding, then indexing).
        texts = list(texts)
        requests = []
        ids = []
        embeddings = self.embedding.embed_documents(texts)
        dim = len(embeddings[0])
        mapping = _default_text_mapping(dim)

        # Check to see if the index already exists; create it with the mapping if not.
        try:
            self.client.indices.get(index=self.index_name)
        except NotFoundError:
            # TODO: would be nice to create the index before embedding,
            # just to save the expensive step for last
            self.client.indices.create(index=self.index_name, mappings=mapping)

        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            _id = str(uuid.uuid4())
            requests.append(
                {
                    "_op_type": "index",
                    "_index": self.index_name,
                    "vector": embeddings[i],
                    "text": text,
                    "metadata": metadata,
                    "_id": _id,
                }
            )
            ids.append(_id)

        # Loop through the requests batch-wise.
        pointer = 0
        while pointer < len(requests):
            bulk(self.client, requests[pointer : pointer + batch_size])
            pointer += batch_size

        if refresh_indices:
            self.client.indices.refresh(index=self.index_name)
        return ids

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query, combining text and vector relevance.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Currently unused; accepted for interface compatibility.

        Returns:
            List of Documents most similar to the query.
        """
        embedding = self.embedding.embed_query(query)
        # Lexical (BM25) half of the hybrid query.
        match_query = {
            "match": {
                "text": {
                    "query": query,
                    "boost": 0.5,
                }
            }
        }
        # Approximate-kNN half; Elasticsearch sums the boosted scores of both halves.
        knn = {
            "field": "vector",
            "query_vector": embedding,
            "k": k,
            "num_candidates": 50,
            "boost": 0.5,
        }
        response = self.client.search(
            index=self.index_name,
            query=match_query,
            knn=knn,
            size=k,
            source=["text", "metadata"],
        )
        return [
            Document(
                page_content=hit["_source"]["text"],
                metadata=hit["_source"]["metadata"],
            )
            for hit in response["hits"]["hits"]
        ]
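

# A minimal usage sketch (not part of the original gist). Assumptions: a local
# Elasticsearch 8.x node at http://localhost:9200, LangChain's HuggingFaceEmbeddings
# (needs `pip install sentence-transformers`), and a throwaway index name; swap in
# your own URL, embedding implementation, and index.
if __name__ == "__main__":
    from langchain.embeddings import HuggingFaceEmbeddings

    store = CustomElasticSearchVectorStore(
        elasticsearch_url="http://localhost:9200",
        index_name="hybrid-demo",
        embedding=HuggingFaceEmbeddings(),
    )
    store.add_texts(
        [
            "Elasticsearch supports approximate kNN search over dense vectors.",
            "BM25 is a classic lexical ranking function.",
        ],
        metadatas=[{"source": "notes"}, {"source": "notes"}],
    )
    # Each hit is scored as the boosted sum of its BM25 and kNN scores.
    for doc in store.similarity_search("lexical versus vector ranking", k=2):
        print(doc.page_content)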