
@guillaumekln
Last active July 22, 2024 13:36
"""Run SentenceTransformer.encode with CTranslate2.
WARNING! This script is a proof of concept using LaBSE and can easily break for other models:
* the selected model does not have a registered converter in CTranslate2
* the model requires outputs that are not returned by CTranslate2
* the sentence-transformers library changed the loading logic or code structure
(this example was tested with sentence-transformers==2.2.2 and transformers==4.29.2)
"""
import os

import ctranslate2
import numpy as np
import sentence_transformers
import torch


def main():
    sentences = ["This is an example sentence", "Each sentence is converted"]

    model = CT2SentenceTransformer("sentence-transformers/LaBSE")
    embeddings = model.encode(sentences)

    print(embeddings)
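

# A SentenceTransformer model is a sequential pipeline of modules: index 0 is
# the Transformer module that produces the token embeddings, and the subclass
# below replaces it with a CTranslate2-backed wrapper.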
class CT2SentenceTransformer(sentence_transformers.SentenceTransformer):
    """Extension of sentence_transformers.SentenceTransformer using a CTranslate2 model."""

    def __init__(self, *args, compute_type="default", **kwargs):
        super().__init__(*args, **kwargs)
        self[0] = CT2Transformer(self[0], compute_type=compute_type)
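
# Note: besides "default", CTranslate2 accepts compute types such as "auto",
# "int8", "float16", or "int8_float16" to trade accuracy for speed and memory
# (see the CTranslate2 quantization documentation).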


class CT2Transformer(torch.nn.Module):
    """Wrapper around a sentence_transformers.models.Transformer which routes the forward
    call to a CTranslate2 encoder model.
    """

    def __init__(self, transformer, compute_type="default"):
        super().__init__()
        self.transformer = transformer
        self.compute_type = compute_type
        self.encoder = None

        # Convert to the CTranslate2 model format, if not already done.
        model_dir = transformer.auto_model.config.name_or_path
        self.ct2_model_dir = os.path.join(model_dir, "ct2")
        if not os.path.exists(self.ct2_model_dir):
            converter = ctranslate2.converters.TransformersConverter(model_dir)
            converter.convert(self.ct2_model_dir)
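
        # `TransformersConverter` accepts either a local directory or a Hugging
        # Face model ID; the converted weights are cached in `ct2_model_dir`, so
        # later runs skip the conversion step.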

    def children(self):
        # Do not consider the "transformer" attribute as a child module so that
        # it will stay on the CPU.
        return []
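
    # Because `children()` is empty, `model.to(device)` does not recurse into
    # the wrapped Hugging Face weights: only the CTranslate2 encoder created
    # lazily in `forward` below ends up allocating device memory.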

    def forward(self, features):
        device = features["input_ids"].device

        if self.encoder is None:
            # The encoder is lazy-loaded to correctly resolve the target device.
            self.encoder = ctranslate2.Encoder(
                self.ct2_model_dir,
                device=device.type,
                device_index=device.index or 0,
                intra_threads=torch.get_num_threads(),
                compute_type=self.compute_type,
            )
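            # (`device.index` is None for an unindexed device such as
            # torch.device("cuda"), hence the fallback to index 0 above.)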

        input_ids = features["input_ids"].to(torch.int32)
        length = features["attention_mask"].sum(1, dtype=torch.int32)

        if device.type == "cpu":
            # PyTorch CPU tensors do not implement the Array interface so a
            # roundtrip to Numpy is required for both the input and output.
            input_ids = input_ids.numpy()
            length = length.numpy()

        input_ids = ctranslate2.StorageView.from_array(input_ids)
        length = ctranslate2.StorageView.from_array(length)
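
        # On GPU, `StorageView.from_array` wraps the tensors through the CUDA
        # array interface, so no host-device copy should be needed here.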

        outputs = self.encoder.forward_batch(input_ids, length)

        last_hidden_state = outputs.last_hidden_state
        if device.type == "cpu":
            last_hidden_state = np.array(last_hidden_state)

        features["token_embeddings"] = torch.as_tensor(
            last_hidden_state, device=device
        ).to(torch.float32)

        return features

    def tokenize(self, *args, **kwargs):
        return self.transformer.tokenize(*args, **kwargs)
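
    # `SentenceTransformer.encode` calls `tokenize` on its first module, so the
    # call is simply delegated to the original Hugging Face tokenizer.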
if __name__ == "__main__":
main()
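
# For reference: with LaBSE, `model.encode(sentences)` returns a NumPy array of
# shape (len(sentences), 768), one 768-dimensional embedding per sentence.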