-
-
Save szalpal/63d427249faab0f1b9087059ae394d58 to your computer and use it in GitHub Desktop.
Triton Inference Server client code for reproducing a bug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# The MIT License (MIT) | |
# | |
# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy of | |
# this software and associated documentation files (the "Software"), to deal in | |
# the Software without restriction, including without limitation the rights to | |
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | |
# the Software, and to permit persons to whom the Software is furnished to do so, | |
# subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | |
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | |
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | |
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |
import argparse | |
import os | |
import re | |
import sys | |
import numpy as np | |
import tritonclient.grpc | |
from numpy.random import randint | |
# Seed NumPy's global RNG once at import time so that the randomly drawn batch
# sizes and crop parameters are identical from one run to the next.
np.random.seed(seed=100019)
def parse_args():
    """
    Parse the command-line options of this client.

    Returns:
        argparse.Namespace with the attributes ``verbose``, ``url``,
        ``batch_size``, ``n_iter``, ``model_name`` and ``img_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
                        help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
                        help='Inference server URL. Default is localhost:8001.')
    # The batches are drawn with random sizes in [1, batch_size], so this flag
    # is an upper bound rather than a fixed size.
    parser.add_argument('--batch_size', type=int, required=False, default=1,
                        help='Maximum batch size. Each request uses a random batch size '
                             'between 1 and this value.')
    parser.add_argument('--n_iter', type=int, required=False, default=-1,
                        help='Number of iterations, each with up to `batch_size` samples '
                             '(-1: no limit).')
    parser.add_argument('--model_name', type=str, required=False, default="dali_multi_input",
                        help='Model name')
    parser.add_argument('--img_dir', type=str, required=False, default=None,
                        help='Directory, with images that will be broken down into batches and '
                             'inferred. The directory must contain images only')
    return parser.parse_args()
def array_from_list(arrays):
    """
    Convert a list of ndarrays into a single ndarray with ndim+1 dimensions.

    Each array is zero-padded at the end of its first axis up to the longest
    first-axis length in ``arrays``; the padded arrays are then stacked along
    a new leading axis. The dtype of the inputs is preserved.

    Raises:
        ValueError: if ``arrays`` is empty, or if the arrays differ in any
            axis other than the first.
    """
    if len(arrays) == 0:
        raise ValueError("array_from_list needs at least one array")
    max_len = max(arr.shape[0] for arr in arrays)
    padded = []
    for arr in arrays:
        # Pad only the first axis. A bare (before, after) pair passed to np.pad
        # would be applied to *every* axis of a multi-dimensional array.
        pad_width = [(0, max_len - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1)
        padded.append(np.pad(arr, pad_width))
    if any(p.shape != padded[0].shape for p in padded):
        raise ValueError("Arrays must have the same shape")
    return np.stack(padded)
def load_image(img_path: str) -> np.ndarray:
    """
    Load a file as a flat array of its encoded (undecoded) bytes.

    This is the typical way to pass images to the DALI backend: the server
    receives the raw file contents and does the decoding itself.

    Returns:
        1-D ``uint8`` array with one element per byte of the file.
    """
    # Read the bytes straight into a uint8 array instead of building a Python
    # list of ints, then an int64 array, and casting it down.
    return np.fromfile(img_path, dtype=np.uint8)
def load_images(dir_path: str, name_pattern='.', max_images=-1):
    """
    Load every regular file in ``dir_path`` as an encoded image.

    Args:
        dir_path: directory to scan; sub-directories are skipped. If this is
            the path of a single file, only that file is loaded.
        name_pattern: regular expression searched (``re.search``) in each file
            name; only matching files are loaded. The default ``'.'`` matches
            any non-empty name.
        max_images: maximum number of files to load, or -1 for no limit.

    Returns:
        list with one ``load_image`` result per file, in sorted file-name order.

    Raises:
        ValueError: if ``max_images`` is neither positive nor -1.
    """
    if not (max_images > 0 or max_images == -1):
        raise ValueError(f"max_images must be positive or -1, got {max_images}")
    if os.path.isfile(dir_path):
        img_paths = [dir_path]
    else:
        pattern = re.compile(name_pattern)
        # os.listdir returns entries in arbitrary order; sorting keeps the image
        # order (and therefore the seeded random batching) reproducible.
        img_paths = [os.path.join(dir_path, name) for name in sorted(os.listdir(dir_path))
                     if os.path.isfile(os.path.join(dir_path, name))
                     and pattern.search(name) is not None]
    if max_images > 0:
        img_paths = img_paths[:max_images]
    return [load_image(path) for path in img_paths]
def batcher(dataset, max_batch_size, n_iterations=-1):
    """
    Generator that splits ``dataset`` into consecutive batches of random size.

    Each batch size is drawn from ``[1, max_batch_size]`` with NumPy's global
    RNG and clipped to the number of remaining elements.

    Args:
        dataset: any sliceable sequence (list, ndarray, ...).
        max_batch_size: largest allowed batch size; must be >= 1.
        n_iterations: maximum number of batches to yield; a non-positive value
            means "until the dataset is exhausted".
    """
    iter_idx = 0
    data_idx = 0
    while data_idx < len(dataset):
        if 0 < n_iterations <= iter_idx:
            # A plain return ends the generator. Raising StopIteration here
            # would be converted to RuntimeError (PEP 479, Python 3.7+).
            return
        # randint's upper bound is exclusive, so this is in [1, max_batch_size].
        batch_size = min(randint(0, max_batch_size) + 1, len(dataset) - data_idx)
        iter_idx += 1
        yield dataset[data_idx: data_idx + batch_size]
        data_idx += batch_size
def main():
    """
    Run the client against a Triton server.

    Sends the images from ``--img_dir`` to ``--model_name`` in random-sized
    batches, together with the CROP_X / CROP_Y / CROP_WIDTH / CROP_HEIGHT
    scalar inputs. Prints the mean and shape of each ``PREPROCESSED_IMAGE``
    output.

    Exits with status 1 in any of these cases:
      - ``--img_dir`` is missing;
      - the client cannot be created;
      - no images are found;
      - the server's inference statistics do not contain exactly one model.
    """
    FLAGS = parse_args()
    if FLAGS.img_dir is None:
        print("--img_dir is required")
        sys.exit(1)
    try:
        triton_client = tritonclient.grpc.InferenceServerClient(url=FLAGS.url,
                                                                verbose=FLAGS.verbose)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit(1)

    model_name = FLAGS.model_name
    # load_images accepts only a positive limit or -1. Passing
    # batch_size * n_iter unchecked would give e.g. -2 for batch_size=2,
    # n_iter=-1, so a non-positive n_iter is mapped to "no limit".
    max_images = FLAGS.batch_size * FLAGS.n_iter if FLAGS.n_iter > 0 else -1
    image_data = load_images(FLAGS.img_dir, max_images=max_images)
    if not image_data:
        print("No images found in " + FLAGS.img_dir)
        sys.exit(1)
    # Encoded images have different lengths; they are zero-padded and stacked.
    input_data = array_from_list(image_data)

    input_names = ["IMAGE"]
    scalar_names = ["CROP_X", "CROP_Y", "CROP_WIDTH", "CROP_HEIGHT"]
    output_names = ["PREPROCESSED_IMAGE"]
    input_shape = list(input_data.shape)
    outputs = [tritonclient.grpc.InferRequestedOutput(oname) for oname in output_names]

    for batch in batcher(input_data, FLAGS.batch_size):
        print("Input mean before backend processing:", np.mean(batch))
        batch_size = np.shape(batch)[0]
        print("Batch size: ", batch_size)
        # Only the leading (batch) dimension changes between requests.
        input_shape[0] = batch_size
        inputs = [tritonclient.grpc.InferInput(iname, input_shape, "UINT8")
                  for iname in input_names]
        scalar_inputs = [tritonclient.grpc.InferInput(sname, [batch_size, 1], "FP32")
                         for sname in scalar_names]
        for inp in inputs:
            inp.set_data_from_numpy(np.copy(batch))
        # CROP_X / CROP_Y.
        # NOTE(review): randint(0, 1) has an exclusive upper bound, so these are
        # always 0 — confirm that a constant crop offset is intended.
        for scal in scalar_inputs[:2]:
            scal.set_data_from_numpy(randint(0, 1, size=(batch_size, 1)).astype('float32'))
        # CROP_WIDTH / CROP_HEIGHT: random integers in [100, 199].
        for scal in scalar_inputs[2:]:
            scal.set_data_from_numpy(randint(100, 200, size=(batch_size, 1)).astype('float32'))
        results = triton_client.infer(model_name=model_name, inputs=[*inputs, *scalar_inputs],
                                      outputs=outputs)
        for oname in output_names:
            print("\nOutput: ", oname)
            output_data = results.as_numpy(oname)
            print("Output mean after backend processing:", np.mean(output_data))
            print("Output shape: ", np.shape(output_data))

    statistics = triton_client.get_inference_statistics(model_name=model_name)
    if len(statistics.model_stats) != 1:
        print("FAILED: Inference Statistics")
        sys.exit(1)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment