-
-
Save pszemraj/b9d3b951229c0eec3973ff7efe2adb18 to your computer and use it in GitHub Desktop.
Fast Nomic ONNX embedder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from pathlib import Path | |
import datasets | |
import fire | |
import transformers | |
from datasets import load_dataset | |
from nomic_embedder_onnx import TextEmbedder | |
# Root logger configuration for this script: INFO and above, each record
# prefixed with a timestamp and its level name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
def shut_up():
    """Raise the `datasets` and `transformers` library loggers to ERROR level.

    This hides their INFO/WARNING output while the script runs.
    """
    for noisy_lib in (datasets, transformers):
        noisy_lib.utils.logging.set_verbosity(logging.ERROR)


shut_up()
def encode_texts(examples, embedder, col_name="embeddings"):
    """
    Embed the texts of one batch, for use with ``Dataset.map(batched=True)``.

    Parameters:
    - examples: The batch mapping passed in by ``map``; the texts are read
      from ``examples[col_name]``.
    - embedder: Object exposing ``encode(texts, return_list=..., disable_progress=...)``,
      e.g. a TextEmbedder.
    - col_name: Name of the column holding the texts to embed.

    Returns:
    - A one-key dict ``{"<col_name>-embedding": vectors}`` so that ``map``
      adds the embeddings as a new column next to the source column.
    """
    texts = examples[col_name]
    output_key = f"{col_name}-embedding"
    # Per-batch progress bars would flood the output of Dataset.map, so they
    # are disabled; lists (not tensors) are what Arrow columns accept.
    vectors = embedder.encode(texts, return_list=True, disable_progress=True)
    return {output_key: vectors}
def process_dataset(
    dataset_name: str,
    config_name: str = "default",
    split="train",
    col_name="text",
    batch_size=4,
    output_dir="output",
    push_to_hub=True,
):
    """
    Embed a text column of a dataset, then push it to the Hub or save it to disk.

    Parameters:
    - dataset_name: The name of the dataset to load. When pushing, this is also
      the repo the processed dataset is pushed to.
    - config_name: Dataset configuration to load.
    - split: The dataset split to use.
    - col_name: Column holding the texts to embed. The embeddings are added
      as a new column named ``"<col_name>-embedding"``.
    - batch_size: Batch size for ``Dataset.map`` and for the embedder.
    - output_dir: Directory to save the processed dataset in. It is only
      created and used when ``push_to_hub`` is false.
    - push_to_hub: If true, push to ``dataset_name`` on the HF Hub with
      ``private=True`` under the config ``"embeddings-<col_name>-nomic_text_v1"``.
      Otherwise save to ``output_dir``.
    """
    output_dir = Path(output_dir)
    if not push_to_hub:
        # Create the target up front so a bad path fails before the
        # (expensive) embedding pass rather than after it.
        output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize the TextEmbedder with desired parameters
    embedder = TextEmbedder(batch_size=batch_size, use_io_binding=False)
    logging.info("Using embedder: %s", embedder)

    logging.info(
        "Loading dataset: %s config: %s, split: %s", dataset_name, config_name, split
    )
    dataset = load_dataset(dataset_name, config_name, split=split)

    logging.info("Processing dataset...")
    dataset = dataset.map(
        lambda examples: encode_texts(examples, embedder, col_name=col_name),
        batched=True,
        batch_size=batch_size,
    )

    if push_to_hub:
        logging.info("Pushing processed dataset to HF Hub...")
        dataset.push_to_hub(
            dataset_name,
            private=True,
            # f-string so the config name actually contains the embedded column
            # (previously the literal text "{col_name}" was pushed).
            config_name=f"embeddings-{col_name}-nomic_text_v1",
            commit_message=f"add text embeddings on column {col_name}",
        )
    else:
        output_path = (
            output_dir / f"{dataset_name.replace('/', '_')}_{split}_embeddings"
        )
        logging.info("Saving processed dataset to %s", output_path)
        dataset.save_to_disk(str(output_path))
    logging.info("Dataset processing completed.")
if __name__ == "__main__":
    # Expose process_dataset's parameters as command-line flags via python-fire.
    fire.Fire(component=process_dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install -U "optimum[onnxruntime]" | |
import torch | |
import torch.nn.functional as F | |
from optimum.onnxruntime import ORTModelForFeatureExtraction | |
from tqdm.auto import trange | |
from transformers import AutoTokenizer | |
class TextEmbedder:
    """
    Text embedder backed by an ONNX export of a Nomic text-embedding model.

    Texts are tokenized and run through an ``ORTModelForFeatureExtraction``.
    The token embeddings are mean-pooled over positions with a non-zero
    attention mask, and the result is L2-normalized.

    NOTE(review): the nomic-embed-text model card asks for a task prefix
    (e.g. ``"search_query: "`` / ``"search_document: "``) on every input.
    This class does not add one, so callers presumably must. Confirm.
    """

    def __init__(
        self,
        model_name="nomic-ai/nomic-embed-text-v1",
        tokenizer_name="bert-base-uncased",
        model_file_name="onnx/model_quantized.onnx",
        batch_size=8,
        use_io_binding=False,
    ):
        """
        Load the tokenizer and the ONNX model.

        Parameters:
        - model_name: Hub id of the repository containing the ONNX file.
        - tokenizer_name: Tokenizer to load. Falls back to ``model_name`` when
          falsy.
        - model_file_name: Path of the ONNX file inside the repository. Use
          ``"onnx/model.onnx"`` for the unquantized version.
        - batch_size: Number of texts per forward pass in ``encode``.
        - use_io_binding: Forwarded to ``ORTModelForFeatureExtraction``.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.batch_size = batch_size
        self.use_io_binding = use_io_binding
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name or self.model_name, model_max_length=8192
        )
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_name,
            trust_remote_code=True,
            file_name=model_file_name,  # Use "onnx/model.onnx" for unquantized version
            rotary_scaling_factor=2,
            use_io_binding=use_io_binding,
        )

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """
        Average token embeddings over the positions selected by the mask.

        Parameters:
        - model_output: Model output whose first element holds the token
          embeddings. It must be 3-D (batch, seq, hidden), because the mask
          is expanded to its size.
        - attention_mask: 2-D (batch, seq) mask; padding positions are 0.

        Returns:
        - A (batch, hidden) tensor. The denominator is clamped so rows with
          an all-zero mask do not divide by zero.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode_batch(self, sentences_batch):
        """Tokenize one batch, run the model, then mean-pool and L2-normalize."""
        encoded_input = self.tokenizer(
            sentences_batch, padding="longest", truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
        return F.normalize(embeddings, p=2, dim=1)

    def encode(
        self,
        sentences,
        return_list=False,
        disable_progress=False,
    ):
        """
        Embed one or many texts in chunks of ``self.batch_size``.

        Parameters:
        - sentences: A single string or a sliceable sequence of strings.
        - return_list: Return nested Python lists instead of a tensor.
        - disable_progress: Hide the tqdm progress bar.

        Returns:
        - A 2-D CPU tensor with one row per input, or the same data as nested
          lists. Empty input yields ``[]``, or a ``(0, dim)`` tensor, where
          ``dim`` is ``model.config.hidden_size`` when available and 0
          otherwise.
        """
        # Ensure sentences is a list
        if isinstance(sentences, str):
            sentences = [sentences]  # Wrap a single sentence into a list
        if len(sentences) == 0:
            # torch.cat([]) raises, so short-circuit before the batching loop.
            if return_list:
                return []
            config = getattr(self.model, "config", None)
            dim = getattr(config, "hidden_size", None) or 0
            return torch.empty((0, dim))
        all_embeddings = []
        for i in trange(0, len(sentences), self.batch_size, disable=disable_progress):
            batch_sentences = sentences[i : i + self.batch_size]
            all_embeddings.append(self.encode_batch(batch_sentences))
        stacked = torch.cat(all_embeddings, dim=0).cpu()
        return stacked.tolist() if return_list else stacked

    def __str__(self):
        return (
            f"TextEmbedder(model_name={self.model_name}, batch_size={self.batch_size},\n"
            f"\ttokenizer_name={self.tokenizer_name}, use_io_binding={self.use_io_binding})"
        )

    def __repr__(self):
        return self.__str__()

    def __call__(self, sentences, **kwargs):
        """Shorthand for ``encode(sentences, **kwargs)``."""
        return self.encode(sentences, **kwargs)
# Example usage: embed a handful of short queries, three per forward pass,
# and print the resulting tensor.
if __name__ == "__main__":
    sentences = [
        "What is TSNE?",
        "Who is Laurens van der Maaten?",
        "Short sentence.",
        "Another example.",
        "Yet another query.",
        "What is machine learning?",
        "Tell me about OpenAI.",
        "Explain deep learning.",
    ]
    embedder = TextEmbedder(batch_size=3)
    # Calling the instance directly forwards to encode().
    print(embedder(sentences))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment