fast nomic onnx embedder
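
Two files: a fire CLI driver that maps a Hugging Face dataset through the embedder, and nomic_embedder_onnx.py, the ONNX-backed TextEmbedder it imports. A minimal invocation sketch for the driver, assuming it is saved as embed_dataset.py (that file name, and the dataset/column shown, are illustrative, not from the gist):

    python embed_dataset.py --dataset_name user/my-dataset --col_name text --push_to_hub False
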
import logging
from pathlib import Path

import datasets
import fire
import transformers
from datasets import load_dataset

from nomic_embedder_onnx import TextEmbedder

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def shut_up():
    """Silence the verbose datasets/transformers loggers."""
    datasets.utils.logging.set_verbosity(logging.ERROR)
    transformers.utils.logging.set_verbosity(logging.ERROR)


shut_up()


def encode_texts(examples, embedder, col_name="text"):
    """
    Encode a batch of texts and return the embeddings.

    Parameters:
    - examples: A batch of examples from ``dataset.map`` (a dict of lists).
    - embedder: An instance of TextEmbedder.
    - col_name: Name of the text column to encode.

    Returns:
    - A dictionary with the embeddings, to be added as a new column.
    """
    embeddings = embedder.encode(
        examples[col_name], return_list=True, disable_progress=True
    )
    return {f"{col_name}-embedding": embeddings}


def process_dataset(
    dataset_name: str,
    config_name: str = "default",
    split="train",
    col_name="text",
    batch_size=4,
    output_dir="output",
    push_to_hub=True,
):
    """
    Process the dataset to encode texts and save the embeddings.

    Parameters:
    - dataset_name: The name of the dataset.
    - config_name: The dataset config to load.
    - split: The dataset split to use.
    - col_name: The text column to embed.
    - batch_size: Batch size for processing.
    - output_dir: Directory to save the processed dataset.
    - push_to_hub: Whether to push the processed dataset to the HF Hub.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize the TextEmbedder with the desired parameters
    embedder = TextEmbedder(batch_size=batch_size, use_io_binding=False)
    print(embedder)

    # Load the dataset
    logging.info(
        f"Loading dataset: {dataset_name} config: {config_name}, split: {split}"
    )
    dataset = load_dataset(dataset_name, config_name, split=split)

    # Process the dataset in batches
    logging.info("Processing dataset...")
    dataset = dataset.map(
        lambda examples: encode_texts(examples, embedder, col_name=col_name),
        batched=True,
        batch_size=batch_size,
    )

    if push_to_hub:
        logging.info("Pushing processed dataset to HF Hub...")
        dataset.push_to_hub(
            dataset_name,
            private=True,
            config_name=f"embeddings-{col_name}-nomic_text_v1",
            commit_message=f"add text embeddings on column {col_name}",
        )
    else:
        output_path = (
            output_dir / f"{dataset_name.replace('/', '_')}_{split}_embeddings"
        )
        logging.info(f"Saving processed dataset to {output_path}")
        dataset.save_to_disk(str(output_path))

    logging.info("Dataset processing completed.")


if __name__ == "__main__":
    fire.Fire(process_dataset)
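
nomic_embedder_onnx.py (the module imported above):
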
# pip install -U "optimum[onnxruntime]"
import torch
import torch.nn.functional as F
from optimum.onnxruntime import ORTModelForFeatureExtraction
from tqdm.auto import trange
from transformers import AutoTokenizer


class TextEmbedder:
    def __init__(
        self,
        model_name="nomic-ai/nomic-embed-text-v1",
        tokenizer_name="bert-base-uncased",
        model_file_name="onnx/model_quantized.onnx",
        batch_size=8,
        use_io_binding=False,
    ):
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.batch_size = batch_size
        self.use_io_binding = use_io_binding
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name or self.model_name, model_max_length=8192
        )
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_name,
            trust_remote_code=True,
            file_name=model_file_name,  # Use "onnx/model.onnx" for the unquantized version
            rotary_scaling_factor=2,
            use_io_binding=use_io_binding,
        )

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        # Masked mean over the token axis: padding positions are zeroed out
        # by the attention mask before averaging.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode_batch(self, sentences_batch):
        encoded_input = self.tokenizer(
            sentences_batch, padding="longest", truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
        # L2-normalize so cosine similarity reduces to a dot product
        return F.normalize(embeddings, p=2, dim=1)

    def encode(
        self,
        sentences,
        return_list=False,
        disable_progress=False,
    ):
        # Ensure sentences is a list
        if isinstance(sentences, str):
            sentences = [sentences]  # Wrap a single sentence into a list
        all_embeddings = []
        for i in trange(0, len(sentences), self.batch_size, disable=disable_progress):
            batch_sentences = sentences[i : i + self.batch_size]
            batch_embeddings = self.encode_batch(batch_sentences)
            all_embeddings.append(batch_embeddings)
        all_embeddings = torch.cat(all_embeddings, dim=0)
        return all_embeddings.cpu().tolist() if return_list else all_embeddings.cpu()

    def __str__(self):
        return (
            f"TextEmbedder(model_name={self.model_name}, batch_size={self.batch_size},\n"
            f"\ttokenizer_name={self.tokenizer_name}, use_io_binding={self.use_io_binding})"
        )

    def __repr__(self):
        return self.__str__()

    def __call__(self, sentences, **kwargs):
        return self.encode(sentences, **kwargs)


# Example usage
if __name__ == "__main__":
    sentences = [
        "What is TSNE?",
        "Who is Laurens van der Maaten?",
        "Short sentence.",
        "Another example.",
        "Yet another query.",
        "What is machine learning?",
        "Tell me about OpenAI.",
        "Explain deep learning.",
    ]
    embedder = TextEmbedder(batch_size=3)
    embeddings = embedder.encode(sentences)
    print(embeddings)
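    # Illustrative follow-on (not in the original example): encode() returns
    # L2-normalized vectors, so pairwise cosine similarity is a dot product.
    similarity = embeddings @ embeddings.T
    print(similarity[0, 1].item())  # cosine similarity of the first two sentences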