Last active
June 25, 2023 13:34
-
-
Save Norod/c029595f8f85e26418ffcd58599dfb7e to your computer and use it in GitHub Desktop.
Example of using MarianMTModel to perform line-by-line text file translation on CPU or GPU
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import textwrap | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianTokenizer, MarianMTModel | |
def chunk_examples(txt_list, width):
    """Wrap every sentence in *txt_list* to at most *width* characters.

    Returns a single flat list of chunks, in input order. Note that
    ``textwrap.wrap`` drops empty strings, so blank lines do not appear
    in the output.
    """
    return [
        piece
        for sentence in txt_list
        for piece in textwrap.wrap(sentence, width=width)
    ]
# Translate a text file line by line with MarianMT (en -> he), writing the
# translated chunks to a sibling output file. Runs on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative model: NLLB-200. Language codes per
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
# src = "eng_Latn"  # source language
# trg = "heb_Hebr"  # target language
# model_name = "facebook/nllb-200-distilled-600M"
# tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=src)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

src = "en"  # source language
trg = "he"  # target language
model_name = f"tiedeman/opus-mt-{src}-{trg}"
model = MarianMTModel.from_pretrained(model_name).to(device)
tokenizer = MarianTokenizer.from_pretrained(model_name)

input_file_name = 'TinyStoriesV2-GPT4-valid.txt'

with open(input_file_name, 'r', encoding='UTF-8') as input_file:
    # Strip trailing whitespace/newlines in one pass over the file.
    lines = [line.rstrip() for line in input_file]

# Wrap long lines to <=128 chars so each chunk stays well inside the
# model's sequence-length budget.
chunked_lines = chunk_examples(txt_list=lines, width=128)
lines = None  # release the unwrapped lines before the long translation loop
number_of_chunked_lines = len(chunked_lines)
print(f"number_of_chunked_lines = {number_of_chunked_lines}")

current_chunked_line = 0
# FIX: open the output with an explicit UTF-8 encoding (Hebrew output would
# raise UnicodeEncodeError under a non-UTF-8 platform default) and manage it
# with a context manager so it is closed even if translation raises mid-run.
with open('TinyStoriesV2-GPT4-valid_heb.txt', 'w', encoding='UTF-8') as output_file:
    for line in chunked_lines:
        current_chunked_line += 1
        # Pass dataset document separators through untranslated.
        if line == "<|endoftext|>":
            output_file.write("<|endoftext|>" + "\n")
            continue
        inputs = tokenizer(line, return_tensors="pt").to(device)
        # Inference only: no_grad avoids building autograd state.
        # (The NLLB variant above would additionally need
        # forced_bos_token_id=tokenizer.lang_code_to_id[trg].)
        with torch.no_grad():
            translated_tokens = model.generate(**inputs)
        decoded_lines = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        for decoded_line in decoded_lines:
            output_file.write(decoded_line + "\n")
        # Periodic progress print + flush so partial output survives
        # an interruption of this long-running job.
        if current_chunked_line % 25 == 0 or current_chunked_line == 1:
            print(f"{current_chunked_line}")
            output_file.flush()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment