Skip to content

Instantly share code, notes, and snippets.

@ymoslem
Last active January 5, 2022 00:59
Show Gist options
  • Save ymoslem/d85b55d2182cfd2ab5d08bed6c63c713 to your computer and use it in GitHub Desktop.
Save ymoslem/d85b55d2182cfd2ab5d08bed6c63c713 to your computer and use it in GitHub Desktop.
Use the mBART pre-trained multilingual model for translation
#!pip install transformers sentencepiece torch -U -q
# Replace "test_source.txt" with your source file.
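# Change src_lang to the language of your source file. The checkpoint used below
# ("facebook/mbart-large-50-many-to-one-mmt") only translates INTO English (en_XX);
# for other target languages, switch to "facebook/mbart-large-50-many-to-many-mmt"
# and change tgt_lang and the lang_code_to_id key to the target language you need.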
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
from tqdm import tqdm
# Function to split source lines into chunks to avoid out-of-memory errors
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    Feeding the model a bounded number of sentences per call keeps memory
    use in check.

    Args:
        lst: Any sliceable sequence (list, tuple, str, ...).
        n: Maximum number of items per slice.

    Yields:
        Slices of ``lst`` (same type as ``lst``), in their original order.
    """
    starts = range(0, len(lst), n)
    yield from (lst[begin:begin + n] for begin in starts)
# Specify the model and tokenizer.
# src_lang is the language of the input file (French here); this checkpoint
# generates English, hence tgt_lang="en_XX".
MODEL_NAME = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME, src_lang="fr_XX", tgt_lang="en_XX")

# Open the source file. Read it explicitly as UTF-8: the input is non-English
# text, and the platform default encoding (e.g. cp1252 on Windows) can fail on
# or garble accented characters.
with open("test_source.txt", "r", encoding="utf-8") as source_file:
    lines = source_file.readlines()
# One sentence per line, without surrounding whitespace / the trailing newline.
source = [line.strip() for line in lines]

# Show the first sentence as a quick sanity check. Skip it for an empty file,
# where indexing source[0] would raise IndexError.
if source:
    print(source[0], "\n")
# Use the first GPU if one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move the model to the chosen device. Convert it to FP16 with half() only when
# running on the GPU: in many PyTorch versions the CPU lacks float16 kernels for
# the ops used in generation, so a half-precision model fails (or crawls) there.
model = model.to(device)
if device != "cpu":
    model = model.half()
# Translate the source sentences in batches of `chunk_size` to bound memory use.
translations = []
chunk_size = 32
# Total number of batches chunks() will yield, i.e. ceil(len / chunk_size).
# round() was wrong here: it undercounts (33 lines -> round(1.03) == 1, but there
# are 2 batches) and rounds .5 to even, so the progress bar total was off.
tqdm_total = (len(source) + chunk_size - 1) // chunk_size
for source_chunk in tqdm(chunks(source, chunk_size), total=tqdm_total):
    # Tokenize the batch, padding to its longest sentence. Optionally pass
    # max_length= (with truncation=True) if inputs may exceed the model limit.
    model_inputs = tokenizer(source_chunk, padding=True, return_tensors="pt").to(device)
    # Force the first generated token to be the English language code so the
    # decoder produces English output.
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    )
    english = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    translations.extend(english)
# Save the translations to a file, one per line.
# Write UTF-8 explicitly so the output does not depend on the platform's default
# encoding (translations may still contain non-ASCII names or punctuation).
# Mode "w" is enough: the file is only written here, never read back.
# Alternatively, write each batch inside the translation loop instead of
# collecting everything in a list first.
with open("test_target.txt", "w", encoding="utf-8") as target_file:
    target_file.writelines(line.strip() + "\n" for line in translations)
print("Done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment