@Norod
Last active June 25, 2023 13:34
Example of using MarianMTModel to perform line-by-line text file translation on CPU or GPU
import torch
import textwrap
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianTokenizer, MarianMTModel


def chunk_examples(txt_list, width):
    """Wrap each input line to at most `width` characters and flatten the result."""
    chunks = []
    for sentence in txt_list:
        chunks += textwrap.wrap(sentence, width=width)
    return chunks


device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative: NLLB-200, which supports many more language pairs.
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
# src = "eng_Latn"  # source language
# trg = "heb_Hebr"  # target language
# model_name = "facebook/nllb-200-distilled-600M"
# tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=src)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

src = "en"  # source language
trg = "he"  # target language
model_name = f"tiedeman/opus-mt-{src}-{trg}"
model = MarianMTModel.from_pretrained(model_name).to(device)
tokenizer = MarianTokenizer.from_pretrained(model_name)

input_file_name = 'TinyStoriesV2-GPT4-valid.txt'
output_file = open('TinyStoriesV2-GPT4-valid_heb.txt', 'w', encoding='UTF-8')

with open(input_file_name, 'r', encoding='UTF-8') as input_file:
    lines = input_file.readlines()

lines = [line.rstrip() for line in lines]
chunked_lines = chunk_examples(txt_list=lines, width=128)
lines = None  # release the raw lines; only the chunks are needed from here on

number_of_chunked_lines = len(chunked_lines)
print(f"number_of_chunked_lines = {number_of_chunked_lines}")

# max_length = max(len(tokenizer.encode(chunk)) for chunk in chunked_lines)
# max_length += 1
# print("Max length: {}".format(max_length))

current_chunked_line = 0
for line in chunked_lines:
    current_chunked_line += 1
    # Pass document separators through untranslated.
    if line == "<|endoftext|>":
        output_file.write("<|endoftext|>" + "\n")
        continue
    inputs = tokenizer(line, return_tensors="pt").to(device)
    translated_tokens = model.generate(**inputs)
    # For NLLB, the target language must be forced as the first generated token:
    # translated_tokens = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id[trg])
    decoded_lines = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
    for decoded_line in decoded_lines:
        output_file.write(decoded_line + "\n")
    # Report progress and flush every 25 chunks (and on the first one).
    if current_chunked_line % 25 == 0 or current_chunked_line == 1:
        print(f"{current_chunked_line}")
        output_file.flush()

output_file.close()
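Translating one chunk per generate() call underutilizes a GPU on long files. Below is a minimal sketch of a batched drop-in replacement for the per-chunk loop above, assuming the same model, tokenizer, device, chunked_lines, and output_file; the batch size of 16 is an illustrative value, not from the original gist.

# Hypothetical batched variant (not part of the original gist): translate
# several chunks per generate() call to better utilize the GPU.
batch_size = 16  # illustrative value; tune to the available memory

pending = []  # chunks buffered for the next batch


def flush_pending():
    """Translate and write out all buffered chunks, preserving their order."""
    if not pending:
        return
    # padding=True pads the batch to a common length so it forms one tensor.
    inputs = tokenizer(pending, return_tensors="pt", padding=True).to(device)
    translated_tokens = model.generate(**inputs)
    for decoded_line in tokenizer.batch_decode(translated_tokens, skip_special_tokens=True):
        output_file.write(decoded_line + "\n")
    pending.clear()


for line in chunked_lines:
    if line == "<|endoftext|>":
        flush_pending()  # keep the separator in its original position
        output_file.write("<|endoftext|>" + "\n")
    else:
        pending.append(line)
        if len(pending) == batch_size:
            flush_pending()

flush_pending()  # translate whatever remains after the last full batch
output_file.close()

Flushing the buffer whenever a <|endoftext|> marker appears keeps the separators at their original positions in the output, at the cost of occasionally translating a partial batch.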