Skip to content

Instantly share code, notes, and snippets.

@s3nh
Last active January 24, 2024 18:08
Show Gist options
  • Save s3nh/cfbbf43f5e9e3cfe8c3e4e2f0d550b80 to your computer and use it in GitHub Desktop.
Save s3nh/cfbbf43f5e9e3cfe8c3e4e2f0d550b80 to your computer and use it in GitHub Desktop.
custom_embedding_fn_chroma.py
import torch
from chromadb import EmbeddingFunction
from typing import List, Dict, Union
from typing import Any, TypeVar
# Task name -> prompt prefixes that are prepended to the raw text before it is
# embedded.  Each task carries two prefixes: "query" for the text used to
# search, and "key" for the text being stored for retrieval.
INSTRUCTIONS = {
    "qa": {
        "query": "Represent this query for retrieving relevant documents: ",
        "key": "Represent this document for retrieval: ",
    },
    "icl": {
        "query": "Convert this example into vector to look for useful examples: ",
        "key": "Convert this example into vector for retrieval: ",
    },
    "chat": {
        "query": "Embed this dialogue to find useful historical dialogues: ",
        "key": "Embed this historical dialogue for retrieval: ",
    },
    "lrlm": {
        "query": "Embed this text chunk for finding useful historical chunks: ",
        "key": "Embed this historical text chunk for retrieval: ",
    },
    "tool": {
        "query": "Transform this user request for fetching helpful tool descriptions: ",
        "key": "Transform this tool description for retrieval: ",
    },
    "convsearch": {
        "query": "Encode this query and context for searching relevant passages: ",
        "key": "Encode this passage for retrieval: ",
    },
}
class CustomCFG:
    """Class-level settings for the embedding function.

    The values are read directly as class attributes
    (e.g. ``CustomCFG.padding``) rather than through an instance.
    """

    # Model location handed to ``from_pretrained``.  The string is also
    # substring-matched to choose the instruction prefix and pooling mode.
    model_name: str = '../assets/llm-embedder'
    # Intended to restrict model loading to files already on disk.
    local_files_only: bool = True
    # Intended per-input token limit -- TODO confirm it is actually applied
    # by the tokenizer call.
    max_length: int = 512
    # Options forwarded to the tokenizer call.
    padding: bool = True
    truncation: bool = True
    return_tensors: str = 'pt'
    # Presumably the number of texts embedded per batch -- confirm against
    # the batching code.
    chunk_size: int = 16
    # NOTE(review): the trailing space is preserved from the original;
    # its purpose is unclear.
    pad_token: str = "PAD "
    # When True the loaded model is cast to half precision.
    model_half: bool = False
    # Key into INSTRUCTIONS selecting the task-specific prompt pair.
    instruction: str = 'convsearch'
class CustomEmbeddingFunction(EmbeddingFunction):
    """Embedding function that encodes texts with a locally loaded transformer.

    Calling an instance splits the input into chunks of
    ``CustomCFG.chunk_size`` texts, embeds every chunk with the model and
    returns all embeddings as nested Python lists.

    NOTE(review): ``AutoModel`` and ``AutoTokenizer`` are not imported in the
    visible part of this file; they presumably come from Hugging Face
    ``transformers`` -- confirm and add the import at file level.
    """

    def __init__(self):
        """Pick the device and load the tokenizer and the model."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = self.load_tokenizer()
        self.retriever = self.load_retriever()
        self.starter: str = 'Represent this sentence for searching relevant passages :'
        # Texts embedded by the next ``process()`` call.
        self.input_text: List[str] = []

    def __call__(self, texts: Union[str, List]):
        """Embed ``texts`` (one string or a list of strings).

        Returns:
            A list with one embedding (list of floats) per input text; an
            empty list when no texts are given.
        """
        chunk_text = self.batch_chunk(texts)
        embeddings = self.batch_processing(chunk_text)
        return embeddings.tolist()

    def tokenize(self, texts):
        """Tokenize ``texts`` and move the resulting batch to ``self.device``."""
        # NOTE(review): CustomCFG.max_length is not passed here, so truncation
        # falls back to the tokenizer's own limit -- confirm this is intended.
        batch_dict = self.tokenizer(
            texts,
            padding=CustomCFG.padding,
            truncation=CustomCFG.truncation,
            return_tensors=CustomCFG.return_tensors,
        ).to(self.device)
        return batch_dict

    def average_pool(self, last_hidden_states: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
        """Average the hidden states over dim 1, ignoring masked positions.

        Positions where ``attention_mask`` is zero are zeroed out before the
        sum, and the sum is divided by the per-row count of unmasked
        positions.
        """
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def load_tokenizer(self):
        """Load the tokenizer that belongs to ``CustomCFG.model_name``."""
        return AutoTokenizer.from_pretrained(
            CustomCFG.model_name,
            local_files_only=CustomCFG.local_files_only,
        )

    def load_retriever(self):
        """Load the model, move it to ``self.device`` and put it in eval mode.

        When ``CustomCFG.model_half`` is set the model is cast to half
        precision.
        """
        retriever = AutoModel.from_pretrained(
            pretrained_model_name_or_path=CustomCFG.model_name,
            local_files_only=CustomCFG.local_files_only,
            trust_remote_code=True,
        )
        retriever = retriever.to(self.device)
        if CustomCFG.model_half:
            # NOTE(review): this changes the process-wide default dtype, not
            # just this model; kept from the original -- confirm it is needed.
            torch.set_default_dtype(torch.half)
            retriever = retriever.half()
        return retriever.eval()

    def process(self):
        """Embed the texts currently stored in ``self.input_text``.

        For an ``llm-embedder`` model the "key" instruction of the configured
        task is prepended to every text.  ``self.input_text`` itself is left
        untouched, so calling this method repeatedly does not stack prefixes.

        Returns:
            A tensor with one embedding row per input text.
        """
        model_name = CustomCFG.model_name
        texts = list(self.input_text)
        if 'llm-embedder' in model_name:
            prefix = INSTRUCTIONS[CustomCFG.instruction]["key"]
            texts = [prefix + text for text in texts]

        batch_dict = self.tokenize(texts)
        with torch.no_grad():
            outputs = self.retriever(**batch_dict)

        # Exactly one pooling strategy applies per model; previously the
        # 'xge' result was always overwritten by the fallback branch.
        if 'llm-embedder' in model_name:
            embeddings = outputs.last_hidden_state[:, 0]
            return torch.nn.functional.normalize(embeddings, p=2, dim=1)
        if 'xge' in model_name:
            embeddings = outputs[0][:, 0]
            return torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return self.average_pool(
            last_hidden_states=outputs.last_hidden_state,
            attention_mask=batch_dict['attention_mask'],
        )

    def batch_chunk(self, texts: Union[str, List, None] = None):
        """Split ``texts`` into lists of at most ``CustomCFG.chunk_size`` items.

        A single string is treated as a one-element list; ``None`` or an
        empty input yields an empty list of chunks.
        """
        if texts is None:
            return []
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        size = CustomCFG.chunk_size
        return [texts[start:start + size] for start in range(0, len(texts), size)]

    def process_once(self, texts: Union[List, None] = None):
        """Embed one chunk of texts.

        When ``texts`` is given it replaces ``self.input_text``; otherwise the
        texts already stored there are embedded.
        """
        if texts is not None:
            self.input_text = list(texts)
        return self.process()

    def batch_processing(self, chunks: Union[List, None] = None):
        """Embed every chunk and stack the results along dim 0.

        Returns an empty tensor when there are no chunks, so that the caller's
        ``.tolist()`` produces ``[]``.
        """
        if not chunks:
            return torch.empty(0)
        return torch.cat([self.process_once(chunk) for chunk in chunks], dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment