Created
November 10, 2023 05:07
-
-
Save lalanikarim/489a6c49f241ce8169d596c5e8f990f0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
)

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from surrealdb import Surreal


class SurrealDBStore(VectorStore):
    """LangChain vector store that keeps texts and embeddings in SurrealDB.

    Every text is written as one record with a ``text`` field and an
    ``embedding`` field. Searches rank records with SurrealDB's
    ``vector::similarity::cosine`` and drop hits whose similarity is below
    ``score_threshold``.

    The synchronous methods (``add_texts``, ``similarity_search``,
    ``from_texts``) run their async counterparts through ``asyncio.run``,
    which raises ``RuntimeError`` when an event loop is already running in
    the current thread; call the ``a``-prefixed coroutines in that case.
    """

    def __init__(
        self,
        dburl: str,
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> None:
        """Create a store bound to one SurrealDB namespace/database/table.

        Args:
            dburl: Connection URL handed to ``Surreal``.
            embeddings_function: Embedding model used for documents and
                queries. When ``None`` a ``HuggingFaceEmbeddings()`` instance
                is created here, so the model is only loaded when actually
                needed rather than when the module is imported.
            ns: SurrealDB namespace selected before each operation.
            db: SurrealDB database selected before each operation.
            collection: Table the records are written to and queried from.
            **kwargs: ``score_threshold`` (default ``0.7``) sets the minimum
                cosine similarity a record needs to be returned by a search.
        """
        self.collection = collection
        self.ns = ns
        self.db = db
        self.sdb = Surreal(dburl)
        if embeddings_function is None:
            embeddings_function = HuggingFaceEmbeddings()
        self.embeddings_function = embeddings_function
        self.score_threshold = kwargs.get("score_threshold", 0.7)

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding model, or ``None`` if it is not an ``Embeddings``."""
        # The attribute set in __init__ is ``embeddings_function`` (plural).
        if isinstance(self.embeddings_function, Embeddings):
            return self.embeddings_function
        return None

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and insert one record per text.

        Args:
            texts: Texts to store. Any iterable is accepted, including
                one-shot generators.
            metadatas: Accepted for interface compatibility; this
                implementation does not store it.
            **kwargs: Unused.

        Returns:
            The ``id`` of each created record, in input order. An empty
            input yields an empty list without touching the database.
        """
        # Materialise once: the texts are needed both for embedding and for
        # the insert loop, and a generator would be exhausted after the first.
        text_list = list(texts)
        if not text_list:
            return []

        vectors = self.embeddings_function.embed_documents(text_list)
        ids: List[str] = []
        async with self.sdb:
            await self.sdb.use(self.ns, self.db)
            for text, vector in zip(text_list, vectors):
                record = await self.sdb.create(
                    self.collection,
                    {
                        "text": text,
                        "embedding": vector,
                    },
                )
                ids.append(record[0]["id"])
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Blocking wrapper around :meth:`aadd_texts` using ``asyncio.run``."""
        return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs))

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return stored texts most similar to ``query``.

        Args:
            query: Text to embed and compare against stored embeddings.
            k: Maximum number of records fetched from the database.
            **kwargs: Unused.

        Returns:
            Documents ordered by descending cosine similarity, each with the
            record id in ``metadata["id"]``. The threshold filter is applied
            after the ``LIMIT $k`` query, so fewer than ``k`` documents may
            be returned.
        """
        query_embedding = self.embeddings_function.embed_query(query)
        async with self.sdb:
            await self.sdb.use(self.ns, self.db)
            results = await self.sdb.query(
                f"select id, text, vector::similarity::cosine(embedding,$embedding) as similarity from {self.collection} order by similarity desc LIMIT $k",
                {
                    "embedding": query_embedding,
                    "k": k,
                },
            )
        return [
            Document(
                page_content=result["text"],
                metadata={"id": result["id"]},
            )
            for result in results[0]["result"]
            if result["similarity"] >= self.score_threshold
        ]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Blocking wrapper around :meth:`asimilarity_search` using ``asyncio.run``."""
        return asyncio.run(self.asimilarity_search(query, k, **kwargs))

    @classmethod
    async def afrom_texts(
        cls,
        dburl: str,
        texts: List[str],
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> 'SurrealDBStore':
        """Build a store and insert ``texts`` into it (async).

        Arguments are forwarded to the constructor; ``texts`` is then added
        with :meth:`aadd_texts`. Returns the new store.
        """
        sdbs = cls(dburl, embeddings_function, ns, db, collection, **kwargs)
        await sdbs.aadd_texts(texts)
        return sdbs

    @classmethod
    def from_texts(
        cls,
        sdb: str,
        texts: List[str],
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> 'SurrealDBStore':
        """Build a store and insert ``texts`` into it (blocking).

        ``sdb`` is the connection URL passed as the constructor's ``dburl``.
        The remaining arguments are forwarded to the constructor; ``texts``
        is then added with :meth:`add_texts`. Returns the new store.
        """
        sdbs = cls(sdb, embeddings_function, ns, db, collection, **kwargs)
        sdbs.add_texts(texts)
        return sdbs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment