-
-
Save waleedkadous/aea1d312d68c9431949442cc562d5f2c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List

from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer
class LocalHuggingFaceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` backed by a locally loaded Sentence Transformer.

    Both embedding methods return plain Python floats, matching the
    ``List[List[float]]`` / ``List[float]`` annotations. This keeps results
    serializable for vector stores that cannot handle NumPy scalars.
    """

    def __init__(self, model_id: str, **model_kwargs):
        """Load the Sentence Transformer model.

        Args:
            model_id: Hugging Face model name or local path handed to
                ``SentenceTransformer``.
            **model_kwargs: Optional extra keyword arguments forwarded to
                ``SentenceTransformer`` (e.g. ``device="cpu"``). When no
                ``device`` is given, sentence-transformers chooses one itself
                (a GPU if available).
        """
        self.model = SentenceTransformer(model_id, **model_kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using a locally running
        Hugging Face Sentence Transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, each a list of Python
            floats. An empty input yields an empty list.
        """
        if not texts:
            # Nothing to encode; avoid calling the model with an empty batch.
            return []
        embeddings = self.model.encode(texts)
        # encode() returns a NumPy array by default; convert each vector to
        # Python floats so the result honours the List[List[float]] contract
        # (same conversion embed_query applies to a single vector).
        return [[float(value) for value in vector] for vector in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a locally running Hugging Face
        Sentence Transformer.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text as a list of Python floats.
        """
        embedding = self.model.encode(text)
        return [float(value) for value in embedding]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.