Last active
January 24, 2024 18:08
-
-
Save s3nh/cfbbf43f5e9e3cfe8c3e4e2f0d550b80 to your computer and use it in GitHub Desktop.
custom_embedding_fn_chroma.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from chromadb import EmbeddingFunction | |
from typing import List, Dict, Union | |
from typing import Any, TypeVar | |
# Task-specific instruction prefixes.
#
# Each task name maps to two prefixes: "query" for the text used to search,
# and "key" for the text being indexed. The table below lists one row per
# task as (task, query prefix, key prefix); the comprehension turns each row
# into the nested {"query": ..., "key": ...} mapping. Every prefix ends with
# a trailing space so the raw text can be appended directly.
INSTRUCTIONS = {
    task: {"query": query_prefix, "key": key_prefix}
    for task, query_prefix, key_prefix in (
        (
            "qa",
            "Represent this query for retrieving relevant documents: ",
            "Represent this document for retrieval: ",
        ),
        (
            "icl",
            "Convert this example into vector to look for useful examples: ",
            "Convert this example into vector for retrieval: ",
        ),
        (
            "chat",
            "Embed this dialogue to find useful historical dialogues: ",
            "Embed this historical dialogue for retrieval: ",
        ),
        (
            "lrlm",
            "Embed this text chunk for finding useful historical chunks: ",
            "Embed this historical text chunk for retrieval: ",
        ),
        (
            "tool",
            "Transform this user request for fetching helpful tool descriptions: ",
            "Transform this tool description for retrieval: ",
        ),
        (
            "convsearch",
            "Encode this query and context for searching relevant passages: ",
            "Encode this passage for retrieval: ",
        ),
    )
}
class CustomCFG:
    """Static settings for the embedding model and its tokenizer.

    All values are plain class attributes; they are meant to be read
    straight off the class (e.g. ``CustomCFG.padding``) rather than from
    an instance.
    """

    # --- model selection / loading ---
    # Path handed to the model loader; its text is also inspected to pick
    # the pooling strategy and to decide whether an instruction is prepended.
    model_name: str = "../assets/llm-embedder"
    # Presumably restricts loading to files already on disk — confirm
    # against the loader call.
    local_files_only: bool = True
    # When true the retriever is converted to half precision.
    model_half: bool = False

    # --- tokenizer arguments ---
    max_length: int = 512
    padding: bool = True
    truncation: bool = True
    return_tensors: str = "pt"

    # --- batching ---
    chunk_size: int = 16
    # NOTE(review): the trailing space in the value is kept exactly as
    # written; confirm it is intentional.
    pad_token: str = "PAD "

    # --- task ---
    # Key into the module-level INSTRUCTIONS table.
    instruction: str = "convsearch"
class CustomEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a locally loaded retriever model.

    Calling an instance chunks the incoming texts, embeds the chunks and
    returns the embeddings converted with ``.tolist()``.

    NOTE(review): several pieces this class relies on are not present in the
    visible source and must exist before it can run end to end:

    * ``load_tokenizer`` is called from ``__init__`` but is not defined here.
    * ``AutoModel`` is not imported by the visible import block (presumably
      ``transformers.AutoModel`` — confirm and import it).
    * ``batch_chunk``, ``process_once`` and ``batch_processing`` are stubs
      whose bodies are elided; ``__call__`` passes each of the first and
      last one positional argument that the stub signatures do not accept.
    """

    def __init__(self):
        """Pick the compute device, then load the tokenizer and the model."""
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        # NOTE(review): ``load_tokenizer`` is not defined in the visible
        # class body; this line fails unless it is supplied elsewhere.
        self.tokenizer = self.load_tokenizer()
        self.retriever = self.load_retriever()
        self.starter: str = 'Represent this sentence for searching relevant passages :'

    def __call__(self, texts: Union[str, List]):
        """Embed ``texts`` and return the embeddings as nested Python lists.

        Args:
            texts: A single string or a list of texts to embed.

        Returns:
            The result of ``batch_processing`` converted with ``.tolist()``.
        """
        chunk_text = self.batch_chunk(texts)
        embeddings = self.batch_processing(chunk_text)
        return embeddings.tolist()

    def tokenize(self, texts):
        """Tokenize ``texts`` with the configured options and move to the device.

        ``max_length`` is deliberately not passed here (the original call did
        not pass it either), so truncation uses whatever limit the tokenizer
        itself applies — presumably the model maximum; confirm.
        """
        return self.tokenizer(
            texts,
            padding=CustomCFG.padding,
            truncation=CustomCFG.truncation,
            return_tensors=CustomCFG.return_tensors,
        ).to(self.device)

    def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average the hidden states over dim 1, ignoring masked-out positions.

        Positions where ``attention_mask`` is zero are zeroed before summing,
        and the sum is divided by the number of unmasked positions per row.
        """
        # ``torch.Tensor`` is spelled out because the bare name ``Tensor`` is
        # never imported; the previous annotation raised NameError as soon
        # as the class body was executed.
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def load_retriever(self):
        """Load the model named by ``CustomCFG.model_name`` in eval mode.

        Returns:
            The model moved to ``self.device``; converted to half precision
            first when ``CustomCFG.model_half`` is set.
        """
        # NOTE(review): ``AutoModel`` is not imported in the visible import
        # block — presumably ``from transformers import AutoModel``.
        # SECURITY: ``trust_remote_code=True`` allows code shipped with the
        # checkpoint to run; only point ``model_name`` at trusted models.
        retriever = AutoModel.from_pretrained(
            pretrained_model_name_or_path=CustomCFG.model_name,
            # Honour the config switch instead of a hard-coded True (the
            # configured default is the same value).
            local_files_only=CustomCFG.local_files_only,
            trust_remote_code=True,
        )
        retriever = retriever.to(self.device)
        if CustomCFG.model_half:
            # NOTE(review): this changes the process-wide default dtype, not
            # just this model's; kept because other code may rely on it.
            torch.set_default_dtype(torch.half)
            retriever = retriever.half()
        return retriever.eval()

    def process(self):
        """Embed the texts currently stored in ``self.input_text``.

        For an ``llm-embedder`` model the configured task's ``"key"``
        instruction is prepended to every text first. The pooling strategy is
        chosen from the model name: first-token embedding with L2
        normalisation for ``xge`` and ``llm-embedder`` models, masked average
        pooling for everything else.

        Returns:
            A tensor with one embedding per input text.
        """
        uses_llm_embedder = 'llm-embedder' in CustomCFG.model_name
        if uses_llm_embedder:
            instruction = INSTRUCTIONS[CustomCFG.instruction]
            # NOTE(review): this rewrites ``self.input_text`` in place, so
            # calling ``process`` twice without resetting it prefixes twice.
            self.input_text = [instruction["key"] + query for query in self.input_text]

        batch_dict = self.tokenize(self.input_text)
        with torch.no_grad():
            outputs = self.retriever(**batch_dict)

        # The branches are mutually exclusive. Previously the 'xge' result
        # fell through to the trailing ``else`` of the next ``if`` and was
        # overwritten by the average-pooled embedding.
        if 'xge' in CustomCFG.model_name:
            embeddings = outputs[0][:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        elif uses_llm_embedder:
            embeddings = outputs.last_hidden_state[:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        else:
            embeddings = self.average_pool(
                last_hidden_states=outputs.last_hidden_state.to(self.device),
                attention_mask=batch_dict['attention_mask'].to(self.device),
            )
        return embeddings

    # NOTE(review): the three methods below are stubs whose bodies are elided
    # in this source. ``__call__`` invokes ``batch_chunk(texts)`` and
    # ``batch_processing(chunk_text)``, so the real implementations must
    # accept one positional argument besides ``self``.
    def batch_chunk(self):
        ...

    def process_once(self):
        ...

    def batch_processing(self):
        ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment