# Run server: nvidia-docker run --rm --name trtserver -p 8000:8000 -p 8001:8001 -v `pwd`:/models nvcr.io/nvidia/tritonserver:20.03.1-py3 trtserver --model-store=/models --api-version=2
# Run client: nvidia-docker run -it -v `pwd`:/data --rm --net=host triton:20.03.1
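# Example client invocation (a sketch only: "client.py" is a placeholder for this
# file's name and "mymodel" for a model in the model store; the flags match the
# argparse definitions below):
#   python client.py -m mymodel -u localhost:8000 -i HTTP 1.jpg
#   python client.py -m mymodel -u localhost:8001 -i gRPC --streaming 1.jpg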
#!/usr/bin/env python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import cv2
import sys
from functools import partial
import os
import tritongrpcclient
import tritongrpcclient.model_config_pb2 as mc
import tritonhttpclient
from tritonclientutils.utils import triton_to_np_dtype
from tritonclientutils.utils import InferenceServerException
if sys.version_info >= (3, 0):
    import queue
else:
    import Queue as queue
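

# UserData wraps a thread-safe queue that the async/streaming callbacks push
# (result, error) tuples into, so the main thread can drain completed requests.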
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


# Callback function used for async_stream_infer() and gRPC async_infer()
def completion_callback(user_data, result, error):
    # Hand the result (and any error) back to the caller through the queue.
    user_data._completed_requests.put((result, error))
FLAGS = None
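

# The two helpers below pull the model's max batch size and the input/output
# tensor names out of what the server returns: parse_model_grpc() works on the
# protobuf objects from the gRPC client, parse_model_http() on the plain dicts
# from the HTTP client.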
def parse_model_grpc(model_metadata, model_config):
    input_metadatas = model_metadata.inputs
    input_configs = model_config.input
    output_metadatas = model_metadata.outputs
    input_names = [input_metadata.name for input_metadata in input_metadatas]
    output_names = [output_metadata.name for output_metadata in output_metadatas]
    return (model_config.max_batch_size, input_names, output_names)


def parse_model_http(model_metadata, model_config):
    input_metadatas = model_metadata['inputs']
    input_config = model_config['input']
    output_metadatas = model_metadata['outputs']
    input_names = [input_metadata['name'] for input_metadata in input_metadatas]
    output_names = [output_metadata['name'] for output_metadata in output_metadatas]
    max_batch_size = 0
    if 'max_batch_size' in model_config:
        max_batch_size = model_config['max_batch_size']
    return (max_batch_size, input_names, output_names)
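

# preprocess() converts the image to float32, subtracts the per-channel mean
# (104, 117, 123) and transposes HWC -> CHW. The returned [w, h, w, h] scale is
# presumably meant for mapping normalized detection boxes back to pixel
# coordinates in a later step; it is not used further in this gist.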
def preprocess(img_raw):
    image_data = []
    img = np.float32(img_raw)
    im_height, im_width, _ = img.shape
    scale = [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]
    img -= (104, 117, 123)
    img = img.transpose(2, 0, 1)
    return scale, img


def postprocess(results, output_names, batch_size):
    """
    Post-process results: print each requested output array and verify that
    the batch dimension matches the requested batch size.
    """
    for output_name in output_names:
        output_array = results.as_numpy(output_name)
        print(output_name, "output_array: ", output_array)
        if len(output_array) != batch_size:
            raise Exception("expected {} results, got {}".format(
                batch_size, len(output_array)))
        # for results in output_array:
        #     for result in results:
        #         if output_array.dtype.type == np.bytes_:
        #             cls = "".join(chr(x) for x in result).split(':')
        #         else:
        #             cls = result.split(':')
        #         print("    {} = {}".format(cls[0], cls[1]))
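

# requestGenerator() wraps the batched image data in InferInput objects and
# builds an InferRequestedOutput for every output tensor, using the gRPC or
# HTTP client classes depending on FLAGS.protocol.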
def requestGenerator(batched_image_data, input_names, output_names, dtype, FLAGS):
    # Set the input data
    inputs = []
    if FLAGS.protocol.lower() == "grpc":
        for input_name in input_names:
            inputs.append(
                tritongrpcclient.InferInput(input_name, batched_image_data.shape,
                                            dtype))
        inputs[0].set_data_from_numpy(batched_image_data)
    else:
        for input_name in input_names:
            inputs.append(
                tritonhttpclient.InferInput(input_name, batched_image_data.shape,
                                            dtype))
        inputs[0].set_data_from_numpy(batched_image_data, binary_data=False)
    outputs = []
    if FLAGS.protocol.lower() == "grpc":
        for output_name in output_names:
            outputs.append(
                tritongrpcclient.InferRequestedOutput(output_name,
                                                      class_count=FLAGS.classes))
    else:
        for output_name in output_names:
            outputs.append(
                tritonhttpclient.InferRequestedOutput(output_name,
                                                      binary_data=False,
                                                      class_count=FLAGS.classes))
    yield inputs, outputs, FLAGS.model_name, FLAGS.model_version


def augments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-a',
                        '--async',
                        dest="async_set",
                        action="store_true",
                        required=False,
                        default=False,
                        help='Use asynchronous inference API')
    parser.add_argument('--streaming',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Use streaming inference API. ' +
                        'The flag is only available with gRPC protocol.')
    parser.add_argument('-m',
                        '--model-name',
                        type=str,
                        required=True,
                        help='Name of model')
    parser.add_argument(
        '-x',
        '--model-version',
        type=str,
        required=False,
        default="",
        help='Version of model. Default is to use latest version.')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        required=False,
                        default=1,
                        help='Batch size. Default is 1.')
    parser.add_argument('-c',
                        '--classes',
                        type=int,
                        required=False,
                        default=1,
                        help='Number of class results to report. Default is 1.')
    parser.add_argument(
        '-s',
        '--scaling',
        type=str,
        choices=['NONE', 'INCEPTION', 'VGG'],
        required=False,
        default='NONE',
        help='Type of scaling to apply to image pixels. Default is NONE.')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8000',
                        help='Inference server URL. Default is localhost:8000.')
    parser.add_argument('-i',
                        '--protocol',
                        type=str,
                        required=False,
                        default='HTTP',
                        help='Protocol (HTTP/gRPC) used to communicate with ' +
                        'the inference service. Default is HTTP.')
    parser.add_argument('image_filename',
                        type=str,
                        nargs='?',
                        default='1.jpg',
                        help='Input image / Input folder.')
    return parser.parse_args()


def init_model(FLAGS):
    if FLAGS.streaming and FLAGS.protocol.lower() != "grpc":
        raise Exception("Streaming is only allowed with gRPC protocol")
    try:
        if FLAGS.protocol.lower() == "grpc":
            # Create gRPC client for communicating with the server
            triton_client = tritongrpcclient.InferenceServerClient(
                url=FLAGS.url, verbose=FLAGS.verbose)
        else:
            # Create HTTP client for communicating with the server
            triton_client = tritonhttpclient.InferenceServerClient(
                url=FLAGS.url, verbose=FLAGS.verbose)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)
    # Make sure the model matches our requirements, and get some
    # properties of the model that we need for preprocessing
    try:
        model_metadata = triton_client.get_model_metadata(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the metadata: " + str(e))
        sys.exit(1)
    try:
        model_config = triton_client.get_model_config(
            model_name=FLAGS.model_name, model_version=FLAGS.model_version)
    except InferenceServerException as e:
        print("failed to retrieve the config: " + str(e))
        sys.exit(1)
    if FLAGS.protocol.lower() == "grpc":
        max_batch_size, input_name, output_name = parse_model_grpc(
            model_metadata, model_config.config)
    else:
        max_batch_size, input_name, output_name = parse_model_http(
            model_metadata, model_config)
    return triton_client, max_batch_size, input_name, output_name
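

# inferencing() sends one batch to the server through the streaming, async or
# synchronous API depending on FLAGS. user_data and async_requests are passed
# in from processer() so that callback results and pending HTTP requests can
# be collected there; the updated sent_count is returned to the caller.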
def inferencing(triton_client, batched_image_data, input_name, output_name,
                dtype, FLAGS, sent_count, responses, user_data, async_requests):
    try:
        for inputs, outputs, model_name, model_version in requestGenerator(
                batched_image_data, input_name, output_name, dtype, FLAGS):
            sent_count += 1
            if FLAGS.streaming:
                triton_client.async_stream_infer(
                    FLAGS.model_name,
                    inputs,
                    request_id=str(sent_count),
                    model_version=FLAGS.model_version,
                    outputs=outputs)
            elif FLAGS.async_set:
                if FLAGS.protocol.lower() == "grpc":
                    triton_client.async_infer(
                        FLAGS.model_name,
                        inputs,
                        partial(completion_callback, user_data),
                        request_id=str(sent_count),
                        model_version=FLAGS.model_version,
                        outputs=outputs)
                else:
                    async_requests.append(
                        triton_client.async_infer(
                            FLAGS.model_name,
                            inputs,
                            request_id=str(sent_count),
                            model_version=FLAGS.model_version,
                            outputs=outputs))
            else:
                responses.append(
                    triton_client.infer(FLAGS.model_name,
                                        inputs,
                                        request_id=str(sent_count),
                                        model_version=FLAGS.model_version,
                                        outputs=outputs))
    except InferenceServerException as e:
        print("inference failed: " + str(e))
        if FLAGS.streaming:
            triton_client.stop_stream()
        sys.exit(1)
    return responses, sent_count
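

# processer() batches the preprocessed images, sends them through inferencing()
# and gathers the responses: completed callbacks for gRPC streaming/async,
# pending request handles for HTTP async, or the list built up synchronously.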
def processer(triton_client, FLAGS, image_data, max_batch_size, input_name, output_name):
    requests = []
    responses = []
    result_filenames = []
    request_ids = []
    image_idx = 0
    last_request = False
    user_data = UserData()
    # Holds the handles to the ongoing HTTP async requests.
    async_requests = []
    sent_count = 0
    dtype = "FP32"
    if FLAGS.streaming:
        triton_client.start_stream(partial(completion_callback, user_data))
    while not last_request:
        input_filenames = []
        repeated_image_data = []
        for idx in range(FLAGS.batch_size):
            input_filenames.append(idx)
            repeated_image_data.append(image_data[image_idx])
            image_idx = (image_idx + 1) % len(image_data)
            if image_idx == 0:
                last_request = True
        if max_batch_size > 0:
            batched_image_data = np.stack(repeated_image_data, axis=0)
        else:
            batched_image_data = np.array(repeated_image_data)
        # Send request
        responses, sent_count = inferencing(triton_client, batched_image_data,
                                            input_name, output_name, dtype,
                                            FLAGS, sent_count, responses,
                                            user_data, async_requests)
    if FLAGS.streaming:
        triton_client.stop_stream()
    if FLAGS.protocol.lower() == "grpc":
        if FLAGS.streaming or FLAGS.async_set:
            processed_count = 0
            while processed_count < sent_count:
                (results, error) = user_data._completed_requests.get()
                processed_count += 1
                if error is not None:
                    print("inference failed: " + str(error))
                    sys.exit(1)
                responses.append(results)
    else:
        if FLAGS.async_set:
            # Collect results from the ongoing async requests
            # for HTTP Async requests.
            for async_request in async_requests:
                responses.append(async_request.get_result())
    return responses


if __name__ == '__main__':
    FLAGS = augments()
    triton_client, max_batch_size, input_name, output_name = init_model(FLAGS)
    img_path = FLAGS.image_filename
    img_raw = cv2.imread(img_path)
    # Preprocess the image into input data according to model requirements
    scale, img = preprocess(img_raw)
    image_data = []
    image_data.append(img)
    responses = processer(triton_client, FLAGS, image_data, max_batch_size,
                          input_name, output_name)
    for response in responses:
        if FLAGS.protocol.lower() == "grpc":
            this_id = response.get_response().id
        else:
            this_id = response.get_response()["id"]
        print("Request {}, batch size {}".format(this_id, FLAGS.batch_size))
        postprocess(response, output_name, FLAGS.batch_size)
    print("PASS")