Created
January 28, 2024 15:47
-
-
Save danzz006/a4e03cd69cc3db45f366d4bfefb51752 to your computer and use it in GitHub Desktop.
Utility script to benchmark models served by the Triton Inference Server. The code is adapted from the YOLOv7 repository: https://github.com/WongKinYiu/yolov7/tree/main
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import numpy as np | |
import sys | |
import cv2 | |
import tritonclient.grpc as grpcclient | |
from tritonclient.utils import InferenceServerException | |
from processing import preprocess, postprocess | |
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS | |
from labels import COCOLabels | |
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor | |
from threading import Event | |
import threading | |
import time | |
# Tensor names expected by the deployed Triton YOLOv7 model:
# one image input, and four detection outputs
# (detection count, bounding boxes, confidence scores, class IDs).
INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
def video_mode(stream_id):
    """Run the video inference loop for one display window.

    Opens the video source given by ``FLAGS.input``, streams frames to the
    Triton server, draws the returned detections, and shows them in a
    window named after the current thread id.  The loop re-opens the
    source when it is exhausted and only exits for good when the user
    presses 'q' in any window (signalled through the shared ``event``).

    Relies on module-level globals: ``FLAGS``, ``event``, ``triton_client``.

    Args:
        stream_id: opaque per-worker value supplied by ``executor.map``
            (previously named ``input``, which shadowed the builtin);
            it is not used inside the function.
    """
    # Single FP32 image input of shape [1, 3, W, H] plus the four
    # detection outputs the model exposes.
    inputs = [grpcclient.InferInput(INPUT_NAMES[0],
                                    [1, 3, FLAGS.width, FLAGS.height],
                                    "FP32")]
    outputs = [grpcclient.InferRequestedOutput(name) for name in OUTPUT_NAMES]

    user_quit = False  # True only when this window saw the 'q' keypress
    while True:
        print("Opening input video stream...")
        # '0' means "default webcam"; cv2.VideoCapture needs it as an int.
        if FLAGS.input == '0':
            FLAGS.input = int(FLAGS.input)
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                # Source exhausted: tell the other workers, then re-open.
                print("failed to fetch next frame")
                event.set()
                break
            # BUGFIX: resize to the configured model dimensions instead of
            # a hard-coded (640, 640), so --width/--height actually work.
            frame = cv2.resize(frame, (FLAGS.width, FLAGS.height))
            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
            inputs[0].set_data_from_numpy(input_image_buffer)

            t1 = time.time()
            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          model_version=FLAGS.model_version,
                                          client_timeout=FLAGS.client_timeout)
            t2 = time.time()
            print(f"Time: {(t2-t1)*1000:.2f}ms", end="\r")

            num_dets = results.as_numpy("num_dets")
            det_boxes = results.as_numpy("det_boxes")
            det_scores = results.as_numpy("det_scores")
            det_classes = results.as_numpy("det_classes")
            detected_objects = postprocess(num_dets, det_boxes, det_scores,
                                           det_classes, frame.shape[1],
                                           frame.shape[0],
                                           [FLAGS.width, FLAGS.height])

            # Draw each detection: box, label background, label text.
            for box in detected_objects:
                label = f"{COCOLabels(box.classID).name}: {box.confidence:.2f}"
                frame = render_box(
                    frame, box.box(),
                    color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, label, normalised_scaling=0.6)
                frame = render_filled_box(
                    frame,
                    (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]),
                    color=(220, 220, 220))
                frame = render_text(frame, label, (box.x1, box.y1),
                                    color=(30, 30, 30), normalised_scaling=0.5)

            # One window per worker thread.
            cv2.imshow(f'{threading.get_ident()}', frame)
            if cv2.waitKey(1) == ord('q'):
                event.set()
                user_quit = True
                break
            if event.is_set():
                break

        # BUGFIX: release exactly once per open (the original also called
        # cap.release() a second time after the loop).
        cap.release()
        if event.is_set():
            if user_quit:
                break
            # Another worker (or end-of-stream) set the event; clear it
            # and loop back to re-open the source.
            event.clear()
if __name__ == '__main__':
    # Command-line interface. Flag names, types, and defaults are kept
    # identical to the original; only incorrect/typo'd help strings are
    # fixed ('--processes' previously reused the model-name help text,
    # '--model-version' claimed "default yolov7(1)", plus "emtpy" and
    # "certicate" typos).
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode',
                        # choices=['dummy', 'image', 'video'],
                        required=False,
                        default='video',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('--input',
                        type=str,
                        required=False,
                        default=r'D:\Python\yolov7\deploy\triton-inference-server\demo.mp4',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('-y',
                        '--model-version',
                        type=str,
                        required=False,
                        default='1',
                        help='Inference model version, default 1')
    parser.add_argument('-d',
                        '--processes',
                        type=int,
                        required=False,
                        default=1,
                        help='Number of parallel inference workers/windows, default 1')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default is none')
    FLAGS = parser.parse_args()

    # Create the gRPC client context for the Triton server.
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit()

    # Health checks: server live, server ready, requested model ready.
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)
    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)
    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata — an "unknown model" error is tolerated here
        # (the readiness check above already gates the common case).
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED : get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED : get_model_metadata")
                sys.exit(1)
        # Model configuration — the returned config name must match.
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)

    # Fan out one video_mode worker (and thus one display window) per
    # requested process; the shared Event coordinates shutdown/re-open.
    windows = FLAGS.processes
    with ThreadPoolExecutor(max_workers=windows) as executor:
        event = Event()
        executor.map(video_mode, range(windows))
        # video_mode('input')
    cv2.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment