Last active
January 18, 2023 00:23
-
-
Save ymoslem/60e1d1dc44fe006f67e130b6ad703c4b to your computer and use it in GitHub Desktop.
Example of using CTranslate2 as a translation inference engine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# First convert your OpenNMT-py or OpenNMT-tf model to a CTranslate2 model. | |
# pip3 install ctranslate2 | |
# • OpenNMT-py: | |
# ct2-opennmt-py-converter --model_path model.pt --output_dir enja_ctranslate2 --quantization int8 | |
# • OpenNMT-tf: | |
# ct2-opennmt-tf-converter --model_path model --output_dir enja_ctranslate2 --src_vocab source.vocab --tgt_vocab target.vocab --model_type TransformerBase --quantization int8 | |
import ctranslate2 | |
import sentencepiece as spm | |
# Set file paths | |
source_file_path = "test.en" | |
target_file_path = "test.ja" | |
sp_source_model_path = "spm_model.en" | |
sp_target_model_path = "spm_model.ja" | |
ct_model_path = "enja_ctranslate2/" | |
# Load the source SentecePiece model | |
sp = spm.SentencePieceProcessor() | |
sp.load(sp_source_model_path) | |
# Open the source file | |
with open(source_file_path, "r") as source: | |
lines = source.readlines() | |
source_sents = [line.strip() for line in lines] | |
# Subword the source sentences | |
source_sents_subworded = sp.encode_as_pieces(source_sents) | |
# Translate the source sentences | |
translator = ctranslate2.Translator(ct_model_path, device="cpu") # or "cuda" for GPU | |
translations = translator.translate_batch(source_sents_subworded, batch_type="tokens", max_batch_size=4096) | |
translations = [translation.hypotheses[0] for translation in translations] | |
# Load the target SentecePiece model | |
sp.load(sp_target_model_path) | |
# Desubword the target sentences | |
translations_desubword = sp.decode(translations) | |
# Save the translations to the a file | |
with open(target_file_path, "w+", encoding="utf-8") as target: | |
for line in translations_desubword: | |
target.write(line.strip() + "\n") | |
print("Done") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment