"""Run SentenceTransformer.encode with CTranslate2. | |
WARNING! This script is a proof of concept using LaBSE and can easily break for other models: | |
* the selected model does not have a registered converter in CTranslate2 | |
* the model requires outputs that are not returned by CTranslate2 | |
* the sentence-transformers library changed the loading logic or code structure | |
(this example was tested with sentence-transformers==2.2.2 and transformers==4.29.2) | |
""" | |
import os

import ctranslate2
import numpy as np
import sentence_transformers
import torch


def main():
    sentences = ["This is an example sentence", "Each sentence is converted"]

    model = CT2SentenceTransformer("sentence-transformers/LaBSE")
    embeddings = model.encode(sentences)

    print(embeddings)
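    # (Added note, not part of the original gist.) For LaBSE, embeddings is expected to
    # be a NumPy array of shape (len(sentences), 768); the 768 dimension is an assumption
    # based on the BERT-base architecture used by LaBSE.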


class CT2SentenceTransformer(sentence_transformers.SentenceTransformer):
    """Extension of sentence_transformers.SentenceTransformer using a CTranslate2 model."""

    def __init__(self, *args, compute_type="default", **kwargs):
        super().__init__(*args, **kwargs)
        # Module 0 of the sentence-transformers pipeline is the Transformer that produces
        # the token embeddings; only this module is rerouted to CTranslate2, the following
        # modules (pooling, etc.) are kept unchanged.
        self[0] = CT2Transformer(self[0], compute_type=compute_type)


class CT2Transformer(torch.nn.Module):
    """Wrapper around a sentence_transformers.models.Transformer which routes the forward
    call to a CTranslate2 encoder model.
    """

    def __init__(self, transformer, compute_type="default"):
        super().__init__()
        self.transformer = transformer
        self.compute_type = compute_type
        self.encoder = None

        # Convert to the CTranslate2 model format, if not already done.
        model_dir = transformer.auto_model.config.name_or_path
        self.ct2_model_dir = os.path.join(model_dir, "ct2")
        if not os.path.exists(self.ct2_model_dir):
            converter = ctranslate2.converters.TransformersConverter(model_dir)
            converter.convert(self.ct2_model_dir)
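        # (Added note, not part of the original gist.) convert() keeps the weights in
        # their original precision. The converter also accepts a quantization argument,
        # e.g. converter.convert(self.ct2_model_dir, quantization="int8"), which produces
        # a smaller 8-bit model; whether that precision is acceptable for LaBSE embeddings
        # is an assumption to verify.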

    def children(self):
        # Do not consider the "transformer" attribute as a child module so that it will
        # stay on the CPU.
        return []

    def forward(self, features):
        device = features["input_ids"].device

        if self.encoder is None:
            # The encoder is lazy-loaded to correctly resolve the target device.
            self.encoder = ctranslate2.Encoder(
                self.ct2_model_dir,
                device=device.type,
                device_index=device.index or 0,
                intra_threads=torch.get_num_threads(),
                compute_type=self.compute_type,
            )

        input_ids = features["input_ids"].to(torch.int32)
        length = features["attention_mask"].sum(1, dtype=torch.int32)

        if device.type == "cpu":
            # PyTorch CPU tensors do not implement the Array interface so a roundtrip to
            # Numpy is required for both the input and output.
            input_ids = input_ids.numpy()
            length = length.numpy()

        input_ids = ctranslate2.StorageView.from_array(input_ids)
        length = ctranslate2.StorageView.from_array(length)

        outputs = self.encoder.forward_batch(input_ids, length)

        last_hidden_state = outputs.last_hidden_state
        if device.type == "cpu":
            last_hidden_state = np.array(last_hidden_state)

        features["token_embeddings"] = torch.as_tensor(
            last_hidden_state, device=device
        ).to(torch.float32)

        return features

    def tokenize(self, *args, **kwargs):
        return self.transformer.tokenize(*args, **kwargs)
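

# (Added example, not part of the original gist.) A minimal sanity check comparing the
# CTranslate2-backed embeddings with the ones produced by the stock SentenceTransformer.
# The function name is hypothetical, and small numerical differences between the two
# backends are expected, so only the largest absolute difference is reported.
def compare_with_reference(sentences):
    reference = sentence_transformers.SentenceTransformer("sentence-transformers/LaBSE")
    candidate = CT2SentenceTransformer("sentence-transformers/LaBSE")

    expected = reference.encode(sentences)
    actual = candidate.encode(sentences)

    max_diff = np.max(np.abs(expected - actual))
    print("max abs difference:", max_diff)
    return max_diff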


if __name__ == "__main__":
    main()