Skip to content

Instantly share code, notes, and snippets.

@sappho192
Last active January 7, 2024 01:48
Show Gist options
  • Save sappho192/98ce84364ccb04cd77c8a1ef144869f8 to your computer and use it in GitHub Desktop.
Save sappho192/98ce84364ccb04cd77c8a1ef144869f8 to your computer and use it in GitHub Desktop.
Transformers EncoderDecoder language model on Optimum OnnxRuntime
# Dependencies (pip takes space-separated names):
# pip install transformers optimum onnx onnxruntime fugashi unidic-lite
from transformers import BertJapaneseTokenizer,PreTrainedTokenizerFast
from optimum.onnxruntime import ORTModelForSeq2SeqLM
# Japanese BERT encoder checkpoint and Korean GPT-2 decoder checkpoint.
encoder_model_name = "cl-tohoku/bert-base-japanese-v2"
decoder_model_name = "skt/kogpt2-base-v2"
# Alternative: load tokenizers from local directories instead of the Hub.
# encoder_model_name = "./src_tokenizer"
# decoder_model_name = "./trg_tokenizer"
src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)
trg_tokenizer = PreTrainedTokenizerFast.from_pretrained(decoder_model_name)
# Loads an already-exported ONNX model from ./onnx. To export from PyTorch
# weights on the fly, pass from_transformers=True (export=True in newer
# optimum versions) — not done here, so ./onnx must already contain the export.
model = ORTModelForSeq2SeqLM.from_pretrained("./onnx")
# Sample Japanese inputs used by the smoke test below.
text = "ギルガメッシュ討伐戦"
text2 = "ギルガメッシュ討伐戦に行ってきます。一緒に行きましょうか?"
def translate(text_src: str) -> str:
    """Translate a Japanese sentence with the ONNX encoder-decoder model.

    Args:
        text_src: Japanese input sentence.

    Returns:
        The decoded target string, with the first and last generated
        tokens (special tokens) stripped.
    """
    # Tokenize to PyTorch tensors; the seq2seq ONNX model here needs only
    # input_ids, so skip attention_mask / token_type_ids.
    embeddings = src_tokenizer(
        text_src,
        return_attention_mask=False,
        return_token_type_ids=False,
        return_tensors='pt',
    )
    # [0, 1:-1]: take the single generated sequence and drop its leading and
    # trailing special tokens before decoding.
    output = model.generate(**embeddings)[0, 1:-1]
    return trg_tokenizer.decode(output.cpu())
# Smoke-test the translator on the two sample sentences.
print(translate(text))
print(translate(text2))
# cx_Freeze build script: freezes infer_onnx.py into a standalone executable.
from cx_Freeze import setup, Executable
import sys

# Raise the recursion limit — presumably needed because cx_Freeze's module
# analysis recurses deeply through transformers' import graph; confirm if
# builds start failing with a different dependency set.
sys.setrecursionlimit(5000)

# Packages cx_Freeze must bundle wholesale, and ones to leave out.
build_options = {
    "packages": ["transformers", "fugashi", "unidic_lite", "optimum"],
    "excludes": ["matplotlib"],
}

# The entry-point script to freeze.
executables = [Executable("infer_onnx.py")]

setup(
    name='FFJaKo',
    version='0.1',
    author="sappho192",
    description="FFXIV Ja-Ko Translator",
    options=dict(build_exe=build_options),
    executables=executables,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment