Created: May 15, 2024 15:03
Save schuhiti/9b632ec3a707c0cff107eb7a48f757c0 to your computer and use it in GitHub Desktop.
ローカル環境でAI翻訳 (NLLB200 Model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dependencies for the local NLLB-200 translator (transformers + Gradio UI).
pip install transformers
# PyTorch is required even on CPU-only machines: the translator builds
# PyTorch tensors (return_tensors="pt") and runs the model with them.
# torchvision / torchaudio are not used by these scripts.
# For a GPU build, use the install command from pytorch.org instead.
pip install torch
pip install gradio
pip install nltk
# Download the NLTK sentence-tokenizer data used by nltk.sent_tokenize.
# NLTK 3.9+ loads "punkt_tab"; older releases load "punkt". Fetching both
# works with either version (the second command may report "not found" on
# old NLTK releases, which is harmless).
python -m nltk.downloader punkt
python -m nltk.downloader punkt_tab
# Equivalent from the Python interpreter (what the original author did):
# $ python3
# >>> import nltk
# >>> nltk.download('punkt')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import nltk | |
import translate_with_nllb200 as translator | |
# 英文を文単位に分割する | |
def split_sentences(text): | |
return nltk.sent_tokenize(text) | |
def translation(txt):
    """Translate *txt* sentence by sentence, streaming partial results.

    This is a generator, so Gradio updates the output textbox with every
    yielded value while the remaining sentences are still being translated.

    Args:
        txt: Source (English) text from the input textbox.

    Yields:
        The translations finished so far, one sentence per line. When the
        input contains no sentences (empty, whitespace-only, or None), a
        single empty string is yielded so the output box is cleared.
    """
    sentences = split_sentences(txt) if txt else []
    if not sentences:
        # Without any yield, Gradio may leave the previous translation on
        # screen; emit an empty result explicitly.
        yield ""
        return

    translated = []
    for sentence in sentences:
        translated.append(translator.translate(sentence))
        # Re-emit the accumulated text after each sentence so the user sees progress.
        yield "\n".join(translated)
# Web UI (Gradio): a source-text box wired to `translation`, whose yielded
# strings are shown in the result box as they arrive.
_source_box = gr.Textbox(label="Text", lines=3)
_result_box = gr.Textbox(label="Translation", lines=6)

demo = gr.Interface(
    fn=translation,
    inputs=[_source_box],
    outputs=[_result_box],
    allow_flagging="never",  # hide the "Flag" button
)
demo.launch()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Uncomment when using CUDA or another torch device:
#import torch
# Uncomment when using DirectML:
#import torch_directml

# Single source of truth for the checkpoint so the tokenizer and the model
# can never be loaded from different repositories by accident.
MODEL_NAME = "facebook/nllb-200-3.3B"

# The input language is fixed to English (NLLB language code "eng_Latn").
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Optional accelerator selection (uncomment together with the imports above).
# NOTE: every branch uses `torch`, so `import torch` is needed even when only
# DirectML is wanted; drop the torch_directml branch if it is not installed.
#if torch.cuda.is_available():
#    device = torch.device('cuda')
#elif torch_directml.is_available():
#    device = torch_directml.device(torch_directml.default_device())
#else:
#    device = torch.device('cpu')
#model.to(device)
def translate(article, tgt_lang="jpn_Jpan", max_length=80):
    """Translate one piece of English text with the NLLB-200 model.

    Args:
        article: Source text. The module-level tokenizer was created with
            src_lang="eng_Latn", so the text is treated as English.
        tgt_lang: NLLB language code of the target language
            (default "jpn_Jpan", Japanese).
        max_length: Maximum length, in tokens, of the generated sequence.
            Raise it if long sentences come back truncated.

    Returns:
        The decoded translation of the single input sequence.
    """
    # Put the encoded inputs on the same device as the model, so that the
    # optional model.to(device) setup works too (no-op on CPU).
    inputs = tokenizer(article, return_tensors="pt").to(model.device)
    # convert_tokens_to_ids works for both slow and fast NLLB tokenizers;
    # lang_code_to_id is deprecated and absent in newer transformers releases.
    target_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    translated_tokens = model.generate(
        **inputs, forced_bos_token_id=target_id, max_length=max_length
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# Reference (参考):
# https://economylife.net/nllb200-windows-local/
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment