Created
December 19, 2023 08:18
-
-
Save indiejoseph/ebd5aac333bc636ab43632ba65a01d19 to your computer and use it in GitHub Desktop.
bart onnx model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022, Lefebvre Dalloz Services | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
""" | |
This module is copy-pasted into the generated Triton configuration folder to run the translation pipeline (tokenizer + ONNX seq2seq model).
""" | |
# noinspection DuplicatedCode | |
import os | |
import json | |
from typing import Dict, List | |
import numpy as np | |
from inference.translation_pipeline import TranslationPipeline | |
from optimum.onnxruntime import ORTModelForSeq2SeqLM | |
from transformers import pipeline, BertTokenizerFast, Pipeline | |
logger = logging.getLogger(__name__) | |
try: | |
# noinspection PyUnresolvedReferences | |
import triton_python_backend_utils as pb_utils | |
except ImportError: | |
pass # triton_python_backend_utils exists only inside Triton Python backend. | |
class TritonPythonModel:
    """
    Triton Python-backend model that translates text with an ONNX seq2seq model.

    ``initialize`` loads the tokenizer and the ONNX model from the model version
    folder and wraps them in a ``TranslationPipeline``; ``execute`` runs that
    pipeline on the ``TEXT`` input of every request and returns the translations
    in the ``OUTPUT`` tensor.
    """

    # Quoted, like the ``pb_utils`` annotation on ``execute``, so they are never
    # evaluated when the class object is created.
    tokenizer: "BertTokenizerFast"
    model: "ORTModelForSeq2SeqLM"
    pipe: "Pipeline"

    def initialize(self, args: Dict[str, str]) -> None:
        """
        Load the tokenizer and the ONNX model, then build the translation pipeline.

        :param args: arguments from Triton config file; the keys ``model_repository``,
            ``model_version`` and ``model_config`` (a JSON string) are read here
        """
        # more variables in https://github.com/triton-inference-server/python_backend/blob/main/src/python.cc
        path: str = os.path.join(args["model_repository"], args["model_version"])
        self.tokenizer = BertTokenizerFast.from_pretrained(path)
        self.model = ORTModelForSeq2SeqLM.from_pretrained(path, use_cache=False)
        self.model_config = model_config = json.loads(args["model_config"])
        self.pipe = TranslationPipeline(model=self.model, tokenizer=self.tokenizer)
        # Look up the configuration of the "OUTPUT" tensor
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")
        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )

    def execute(self, requests) -> "List[pb_utils.InferenceResponse]":
        """
        Translate the ``TEXT`` input of each request.

        :param requests: 1 or more requests received by Triton server.
        :return: one inference response per request, in request order; each one
            carries the UTF-8 encoded translations in its ``OUTPUT`` tensor
        """
        responses = []
        output0_dtype = self.output0_dtype
        # for loop for batch requests (disabled in our case)
        for request in requests:
            text_tensor = pb_utils.get_input_tensor_by_name(request, "TEXT")
            # binary data typed back to string
            # NOTE(review): iterating ``tolist()`` directly assumes TEXT is a 1-D
            # tensor of bytes; a 2-D input would yield nested lists -- confirm
            # against the Triton model config.
            inputs = [t.decode("UTF-8") for t in text_tensor.as_numpy().tolist()]
            output = self.pipe(inputs)
            logger.info(output)
            encoded = [o["translation_text"].encode("UTF-8") for o in output]
            # Build the object array directly: going through numpy's fixed-width
            # bytes dtype first would pad every element to the longest string and
            # strip trailing NUL bytes before the cast to object.
            out_0 = np.array(encoded, dtype=np.object_)
            out_tensor_0 = pb_utils.Tensor("OUTPUT", out_0.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0]
            )
            responses.append(inference_response)
        return responses
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment