Skip to content

Instantly share code, notes, and snippets.

@lalanikarim
Created November 10, 2023 05:07
Show Gist options
  • Save lalanikarim/489a6c49f241ce8169d596c5e8f990f0 to your computer and use it in GitHub Desktop.
import asyncio
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.embeddings import HuggingFaceEmbeddings
from surrealdb import Surreal
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
)
class SurrealDBStore(VectorStore):
    """LangChain ``VectorStore`` that keeps documents in a SurrealDB table.

    Each added text becomes one record in ``collection`` with the fields
    ``text``, ``embedding`` and ``metadata``. Similarity search ranks records
    with SurrealQL's ``vector::similarity::cosine`` against the query
    embedding. It drops hits whose similarity is below ``score_threshold``.

    The synchronous methods (``add_texts``, ``similarity_search``,
    ``from_texts``) wrap their async counterparts with ``asyncio.run``. They
    therefore cannot be called from inside an already-running event loop.
    """

    # Shared default embedding model, built on first use. It used to be a
    # default argument, which loaded a model at import time for every
    # signature it appeared in.
    _default_embeddings: ClassVar[Optional[Embeddings]] = None

    def __init__(
        self,
        dburl: str,
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> None:
        """Configure the store; no connection is opened here.

        Args:
            dburl: URL passed to ``Surreal(...)``.
            embeddings_function: Embedding model. When ``None``, a shared
                ``HuggingFaceEmbeddings()`` instance is used.
            ns: SurrealDB namespace selected via ``use``.
            db: SurrealDB database selected via ``use``.
            collection: Table that holds the document records.
            **kwargs: ``score_threshold`` (default ``0.7``) sets the minimum
                cosine similarity a search hit must reach.
        """
        self.collection = collection
        self.ns = ns
        self.db = db
        self.sdb = Surreal(dburl)
        self.embeddings_function = (
            embeddings_function
            if embeddings_function is not None
            else self._get_default_embeddings()
        )
        self.score_threshold = kwargs.get("score_threshold", 0.7)

    @staticmethod
    def _get_default_embeddings() -> Embeddings:
        """Return the shared default ``HuggingFaceEmbeddings``, creating it once."""
        if SurrealDBStore._default_embeddings is None:
            SurrealDBStore._default_embeddings = HuggingFaceEmbeddings()
        return SurrealDBStore._default_embeddings

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding model, or ``None`` if it is not an ``Embeddings``."""
        # Previously read the non-existent ``self.embedding_function``.
        return (
            self.embeddings_function
            if isinstance(self.embeddings_function, Embeddings)
            else None
        )

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and insert one record per text.

        Args:
            texts: Texts to store; any iterable, including one-shot iterators.
            metadatas: Optional per-text metadata dicts, stored under the
                record's ``metadata`` field.
            **kwargs: Accepted for ``VectorStore`` compatibility; unused.

        Returns:
            The SurrealDB record ids, in input order.

        Raises:
            ValueError: If ``metadatas`` is given and its length differs
                from the number of texts.
        """
        # Materialise once: the texts are traversed both for embedding and
        # for record creation, which would exhaust a generator.
        texts = list(texts)
        if metadatas is not None and len(metadatas) != len(texts):
            raise ValueError(
                f"Got {len(metadatas)} metadatas for {len(texts)} texts"
            )
        if not texts:
            return []
        if metadatas is None:
            metadatas = [{} for _ in texts]

        embeddings = self.embeddings_function.embed_documents(texts)
        ids: List[str] = []
        async with self.sdb:
            await self.sdb.use(self.ns, self.db)
            for text, embedding, metadata in zip(texts, embeddings, metadatas):
                record = await self.sdb.create(
                    self.collection,
                    {
                        "text": text,
                        "embedding": embedding,
                        "metadata": metadata,
                    },
                )
                # NOTE(review): assumes ``create`` on a table returns a list
                # of created records, as the original code did; verify
                # against the installed surrealdb client version.
                ids.append(record[0]["id"])
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Synchronous wrapper around :meth:`aadd_texts` (uses ``asyncio.run``)."""
        return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs))

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return up to ``k`` documents most cosine-similar to ``query``.

        Args:
            query: Text to embed and search with.
            k: Maximum number of rows fetched (``LIMIT $k``).
            **kwargs: ``score_threshold`` overrides the instance threshold
                for this call.

        Returns:
            Documents whose similarity is at least the threshold, best first.
            Each document's metadata is the stored metadata plus its ``id``.
        """
        query_embedding = self.embeddings_function.embed_query(query)
        score_threshold = kwargs.get("score_threshold", self.score_threshold)
        async with self.sdb:
            await self.sdb.use(self.ns, self.db)
            results = await self.sdb.query(
                "select id, text, metadata, "
                "vector::similarity::cosine(embedding,$embedding) as similarity "
                f"from {self.collection} order by similarity desc LIMIT $k",
                {
                    "embedding": query_embedding,
                    "k": k,
                },
            )
        return [
            Document(
                page_content=result["text"],
                # Records written before metadata was stored have no such
                # field; fall back to an empty dict.
                metadata={**(result.get("metadata") or {}), "id": result["id"]},
            )
            for result in results[0]["result"]
            # Skip rows with no similarity (e.g. no embedding) instead of
            # raising TypeError on the comparison.
            if result.get("similarity") is not None
            and result["similarity"] >= score_threshold
        ]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Synchronous wrapper around :meth:`asimilarity_search` (uses ``asyncio.run``)."""
        return asyncio.run(self.asimilarity_search(query, k, **kwargs))

    @classmethod
    async def afrom_texts(
        cls,
        dburl: str,
        texts: List[str],
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> 'SurrealDBStore':
        """Build a store and add ``texts`` to it.

        An optional ``metadatas`` keyword is forwarded to :meth:`aadd_texts`.
        Remaining keywords go to the constructor.
        """
        metadatas = kwargs.pop("metadatas", None)
        sdbs = cls(dburl, embeddings_function, ns, db, collection, **kwargs)
        await sdbs.aadd_texts(texts, metadatas)
        return sdbs

    @classmethod
    def from_texts(
        cls,
        sdb: str,
        texts: List[str],
        embeddings_function: Optional[Embeddings] = None,
        ns: str = "langchain",
        db: str = "database",
        collection: str = "documents",
        **kwargs: Any,
    ) -> 'SurrealDBStore':
        """Synchronous counterpart of :meth:`afrom_texts`; ``sdb`` is the DB URL.

        An optional ``metadatas`` keyword is forwarded to :meth:`add_texts`.
        """
        metadatas = kwargs.pop("metadatas", None)
        sdbs = cls(sdb, embeddings_function, ns, db, collection, **kwargs)
        sdbs.add_texts(texts, metadatas)
        return sdbs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment