@szalpal · Created June 21, 2022 23:45
Tritonserver client code for a bug reproduction
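Example invocation (a sketch, not part of the gist: it assumes a running Triton server exposing gRPC on localhost:8001 and serving the default `dali_multi_input` model; `client.py` is a placeholder, since the gist does not record the file name):

    python client.py --url localhost:8001 --batch_size 4 --n_iter 8 --img_dir /path/to/images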
#!/usr/bin/env python
# The MIT License (MIT)
#
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
import os
import re
import sys
import numpy as np
import tritonclient.grpc
from numpy.random import randint
np.random.seed(100019)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
                        help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
                        help='Inference server URL. Default is localhost:8001.')
    parser.add_argument('--batch_size', type=int, required=False, default=1,
                        help='Maximum batch size')
    parser.add_argument('--n_iter', type=int, required=False, default=-1,
                        help='Number of iterations, each with up to `batch_size` samples')
    parser.add_argument('--model_name', type=str, required=False, default="dali_multi_input",
                        help='Model name')
    parser.add_argument('--img_dir', type=str, required=False, default=None,
                        help='Directory with images that will be broken down into batches and '
                             'inferred. The directory must contain images only')
    return parser.parse_args()


def array_from_list(arrays):
    """
    Convert a list of ndarrays into a single ndarray with ndims incremented by 1.
    Shorter arrays are zero-padded to the length of the longest one.
    """
    max_len = max(arr.shape[0] for arr in arrays)
    arrays = [np.pad(arr, (0, max_len - arr.shape[0])) for arr in arrays]
    for arr in arrays:
        assert arr.shape == arrays[0].shape, "Arrays must have the same shape"
    return np.stack(arrays)
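
# Example (illustration only, not part of the original gist): two ragged
# 1-D arrays are zero-padded to the longest length and stacked:
#   >>> a = np.frombuffer(b"abc", dtype=np.uint8)    # shape (3,)
#   >>> b = np.frombuffer(b"hello", dtype=np.uint8)  # shape (5,)
#   >>> array_from_list([a, b]).shape
#   (2, 5)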


def load_image(img_path: str):
    """
    Load an image as an array of encoded bytes.
    This is the typical way to feed encoded images to the DALI backend,
    which decodes them server-side.
    """
    with open(img_path, "rb") as f:
        img = f.read()
    return np.frombuffer(img, dtype=np.uint8)
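
# Example (illustration only, hypothetical path): for a 2048-byte JPEG file,
#   load_image("/path/to/img.jpg")  ->  uint8 array of shape (2048,)
# holding the raw encoded bytes; the DALI pipeline decodes them on the server.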


def load_images(dir_path: str, name_pattern='.', max_images=-1):
    """
    Load all files in dir_path, treating them as images. Optionally, a regex
    pattern may be applied to the file names, so that only matching files are used.
    """
    assert max_images > 0 or max_images == -1
    images = []
    # Traverse the directory for files (not dirs) and collect their full paths
    path_generator = (os.path.join(dir_path, f) for f in os.listdir(dir_path) if
                      os.path.isfile(os.path.join(dir_path, f)) and
                      re.search(name_pattern, f) is not None)
    img_paths = [dir_path] if os.path.isfile(dir_path) else list(path_generator)
    if 0 < max_images < len(img_paths):
        img_paths = img_paths[:max_images]
    for img in img_paths:
        images.append(load_image(img))
    return images
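
# Example (illustration only, hypothetical directory): load at most 16 JPEG files:
#   load_images("/data/images", name_pattern=r"\.jpe?g$", max_images=16)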


def batcher(dataset, max_batch_size, n_iterations=-1):
    """
    Generator that splits the dataset into batches of random size
    (between 1 and max_batch_size).
    """
    iter_idx = 0
    data_idx = 0
    while data_idx < len(dataset):
        if 0 < n_iterations <= iter_idx:
            # Raising StopIteration inside a generator is a RuntimeError
            # since PEP 479 (Python 3.7+); `return` ends the generator cleanly.
            return
        batch_size = min(randint(0, max_batch_size) + 1, len(dataset) - data_idx)
        iter_idx += 1
        yield dataset[data_idx: data_idx + batch_size]
        data_idx += batch_size
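
# Example (illustration only): with max_batch_size=4, every iteration yields a
# batch of random size in [1, 4], so an 8-sample dataset might come out as
# batches of shapes (3, 5), (1, 5) and (4, 5):
#   for batch in batcher(np.zeros((8, 5)), max_batch_size=4):
#       print(batch.shape)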


def main():
    FLAGS = parse_args()
    try:
        triton_client = tritonclient.grpc.InferenceServerClient(url=FLAGS.url,
                                                                verbose=FLAGS.verbose)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit(1)
    model_name = FLAGS.model_name
    model_version = -1
    # With the default n_iter=-1, load all images; otherwise cap the number of
    # images at batch_size * n_iter (a negative product would trip the assert
    # in load_images).
    max_images = FLAGS.batch_size * FLAGS.n_iter if FLAGS.n_iter > 0 else -1
    image_data = load_images(FLAGS.img_dir, max_images=max_images)
    input_data = array_from_list(image_data)
    # Infer
    outputs = []
    input_names = ["IMAGE"]
    scalar_names = ["CROP_X", "CROP_Y", "CROP_WIDTH", "CROP_HEIGHT"]
    output_names = ["PREPROCESSED_IMAGE"]
    input_shape = list(input_data.shape)
    for oname in output_names:
        outputs.append(tritonclient.grpc.InferRequestedOutput(oname))
    for batch in batcher(input_data, FLAGS.batch_size, n_iterations=FLAGS.n_iter):
        print("Input mean before backend processing:", np.mean(batch))
        batch_size = np.shape(batch)[0]
        print("Batch size: ", batch_size)
        # Initialize the data
        input_shape[0] = batch_size
        # scalars = randint(0, 1024, size=(batch_size, 1), dtype=np.int32)
        inputs = [tritonclient.grpc.InferInput(iname, input_shape, "UINT8") for iname in
                  input_names]
        scalar_inputs = [tritonclient.grpc.InferInput(sname, [batch_size, 1], "FP32") for sname in
                         scalar_names]
        for inp in inputs:
            inp.set_data_from_numpy(np.copy(batch))
        # CROP_X and CROP_Y: randint(0, 1) always yields 0, i.e. crop anchored
        # at the top-left corner
        for scal in scalar_inputs[:2]:
            scal.set_data_from_numpy(randint(0, 1, size=(batch_size, 1)).astype('float32'))
        # CROP_WIDTH and CROP_HEIGHT: random values in [100, 200)
        for scal in scalar_inputs[2:]:
            scal.set_data_from_numpy(randint(100, 200, size=(batch_size, 1)).astype('float32'))
        # Test with outputs
        results = triton_client.infer(model_name=model_name, inputs=[*inputs, *scalar_inputs],
                                      outputs=outputs)
        # Get the output arrays from the results
        for oname in output_names:
            print("\nOutput: ", oname)
            output_data = results.as_numpy(oname)
            print("Output mean after backend processing:", np.mean(output_data))
            print("Output shape: ", np.shape(output_data))
            # expected = np.multiply(batch, 1 if oname == "DALI_unchanged" else scalars,
            #                        dtype=np.int32)
            # if not np.allclose(output_data, expected):
            #     print("Pre/post average does not match")
            #     sys.exit(1)
            # else:
            #     print("pass")
    statistics = triton_client.get_inference_statistics(model_name=model_name)
    if len(statistics.model_stats) != 1:
        print("FAILED: Inference Statistics")
        sys.exit(1)


if __name__ == '__main__':
    main()