Gist by @indiejoseph, created December 19, 2023
bart onnx model: a Triton Python backend model.py serving a BART seq2seq ONNX model for translation
# Copyright 2022, Lefebvre Dalloz Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is copy-pasted into the generated Triton model folder to perform the
translation step (tokenization, generation and decoding) inside the Python backend.
"""
# noinspection DuplicatedCode
import json
import logging
import os
from typing import Dict, List

import numpy as np

from inference.translation_pipeline import TranslationPipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import BertTokenizerFast, Pipeline
logger = logging.getLogger(__name__)
try:
    # noinspection PyUnresolvedReferences
    import triton_python_backend_utils as pb_utils
except ImportError:
    pass  # triton_python_backend_utils exists only inside the Triton Python backend.
class TritonPythonModel:
    tokenizer: BertTokenizerFast
    model: ORTModelForSeq2SeqLM
    pipe: Pipeline

    def initialize(self, args: Dict[str, str]) -> None:
        """
        Load the tokenizer, the ONNX seq2seq model and the translation pipeline
        :param args: arguments from the Triton config file
        """
        # more variables in https://github.com/triton-inference-server/python_backend/blob/main/src/python.cc
        path: str = os.path.join(args["model_repository"], args["model_version"])
        self.tokenizer = BertTokenizerFast.from_pretrained(path)
        self.model = ORTModelForSeq2SeqLM.from_pretrained(path, use_cache=False)
        self.model_config = model_config = json.loads(args["model_config"])
        self.pipe = TranslationPipeline(model=self.model, tokenizer=self.tokenizer)

        # Get the OUTPUT tensor configuration
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")

        # Convert the Triton type string to the matching numpy dtype
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
    def execute(self, requests) -> "List[pb_utils.InferenceResponse]":
        """
        Parse and translate the text of each request
        :param requests: 1 or more requests received by the Triton server
        :return: translated text as output tensors
        """
        responses = []
        output0_dtype = self.output0_dtype

        # for loop over batched requests (batching is disabled in our case)
        for request in requests:
            # binary input data decoded back to strings
            inputs = [
                t.decode("UTF-8")
                for t in pb_utils.get_input_tensor_by_name(request, "TEXT")
                .as_numpy()
                .tolist()
            ]
            output = self.pipe(inputs)
            logger.info(output)
            out_0 = np.array(
                [o["translation_text"].encode("UTF-8") for o in output]
            ).astype(np.object_)
            out_tensor_0 = pb_utils.Tensor("OUTPUT", out_0.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0]
            )
            responses.append(inference_response)
        return responses
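
For context, a Triton Python backend model.py like this one is normally paired with a config.pbtxt in the same model repository. The gist does not include that file, so the following is only a minimal sketch of what it could look like given the TEXT input and OUTPUT output names used above; the model name "bart_translation" is a hypothetical placeholder, not taken from the gist.

name: "bart_translation"
backend: "python"
max_batch_size: 0
input [
  {
    name: "TEXT"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }
]

Once deployed, a client could exercise the model roughly as follows. This is a sketch using the tritonclient HTTP API; the model name "bart_translation" and the server address are assumptions carried over from the config sketch above.

import numpy as np
import tritonclient.http as httpclient

# Connect to a locally running Triton server (address is an assumption)
client = httpclient.InferenceServerClient(url="localhost:8000")

# Triton STRING/BYTES tensors travel as numpy object arrays of bytes,
# matching the .decode("UTF-8") the server-side execute() performs
texts = np.array(["text to translate".encode("UTF-8")], dtype=np.object_)
text_input = httpclient.InferInput("TEXT", list(texts.shape), "BYTES")
text_input.set_data_from_numpy(texts)

# "bart_translation" is the hypothetical model name from the config sketch
result = client.infer(model_name="bart_translation", inputs=[text_input])
print(result.as_numpy("OUTPUT"))  # UTF-8 encoded translations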