Use the mBART pre-trained multilingual model for translation
#!pip install transformers sentencepiece torch -U -q

# Replace "test_source.txt" with your source file.
# Change src_lang, tgt_lang, and lang_code_to_id to the source and target languages you need.

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
from tqdm import tqdm
# Function to split source lines into chunks to avoid out-of-memory errors
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
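# For example, list(chunks(["a", "b", "c", "d", "e"], 2))
# returns [["a", "b"], ["c", "d"], ["e"]] — the last chunk may be shorter.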
# Specify the model and tokenizer
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt", src_lang="fr_XX", tgt_lang="en_XX")
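# Note: the many-to-one checkpoint translates from any of mBART-50's
# supported languages into English, so tgt_lang is effectively fixed to
# "en_XX" here. For other target languages, the
# "facebook/mbart-large-50-many-to-many-mmt" checkpoint should work with
# the same code, changing the language codes accordingly.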
# Open the source file
with open("test_source.txt", "r") as source_file:
    lines = source_file.readlines()

source = [line.strip() for line in lines]
print(source[0], "\n")
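# Note: readlines() keeps the trailing newline on each line;
# strip() above removes it before tokenization.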
# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Move the model to the GPU and use half() for FP16 inference when possible
# (FP16 generation is not well supported on CPU, so only cast on GPU)
model = model.to(device)
if device.startswith("cuda"):
    model = model.half()
# Translate the source sentences
translations = []
chunk_size = 32
# Ceiling division so tqdm reports the correct number of batches
tqdm_total = (len(source) + chunk_size - 1) // chunk_size

for source_chunk in tqdm(chunks(source, chunk_size), total=tqdm_total):
    model_inputs = tokenizer(source_chunk, padding=True, return_tensors="pt").to(device)  # optional: max_length=
    generated_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
    english = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    translations.extend(english)
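# Optional tweaks (a sketch, not part of the original gist): generate()
# accepts the standard Hugging Face decoding arguments, e.g. num_beams
# for beam search or max_length to cap output length:
# generated_tokens = model.generate(**model_inputs,
#                                   forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
#                                   num_beams=5,
#                                   max_length=200)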
# Save translations to a file.
# Alternatively, integrate this into the previous step, instead of saving to a list first.
with open("test_target.txt", "w+") as target_file:
    for line in translations:
        target_file.write(line.strip() + "\n")

print("Done")