Utility script to benchmark models running on Triton Inference Server. Adapted from the yolov7 repository: https://github.com/WongKinYiu/yolov7/tree/main
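For reference, a typical invocation (assuming the script is saved as client.py and a yolov7 model is already loaded on a local Triton server) would look like: python client.py --input demo.mp4 --model yolov7 --url localhost:8001 --processes 2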
#!/usr/bin/env python
import argparse
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Event

import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
from labels import COCOLabels

# Tensor names of the YOLOv7 end2end (NMS-included) export.
INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
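# Sketch (not part of the original gist): the --mode help text below mentions a
# 'dummy' mode that sends an empty buffer to verify that inference works, but
# only video mode is implemented. A minimal version could look like this; it
# reuses the global FLAGS and triton_client created in __main__.
def dummy_mode():
    dummy = grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32")
    dummy.set_data_from_numpy(np.zeros([1, 3, FLAGS.width, FLAGS.height], dtype=np.float32))
    results = triton_client.infer(model_name=FLAGS.model,
                                  inputs=[dummy],
                                  outputs=[grpcclient.InferRequestedOutput(n) for n in OUTPUT_NAMES],
                                  model_version=FLAGS.model_version,
                                  client_timeout=FLAGS.client_timeout)
    for name in OUTPUT_NAMES:
        print(f"{name}: shape={results.as_numpy(name).shape}")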
def video_mode(worker_id):
    # One capture, one display window and one set of request tensors per worker thread.
    inputs = []
    outputs = []
    inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
    for name in OUTPUT_NAMES:
        outputs.append(grpcclient.InferRequestedOutput(name))
    _exit_status_ = False
    while True:
        print("Opening input video stream...")
        if FLAGS.input == '0':
            FLAGS.input = int(FLAGS.input)  # '0' selects the default webcam
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)
        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                event.set()
                break
            # Resize to the model input size (was hardcoded 640x640; the flags default to the same).
            frame = cv2.resize(frame, (FLAGS.width, FLAGS.height))
            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
            inputs[0].set_data_from_numpy(input_image_buffer)
            t1 = time.time()
            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          model_version=FLAGS.model_version,
                                          client_timeout=FLAGS.client_timeout)
            t2 = time.time()
            print(f"Time: {(t2 - t1) * 1000:.2f}ms", end="\r")
            num_dets = results.as_numpy("num_dets")
            det_boxes = results.as_numpy("det_boxes")
            det_scores = results.as_numpy("det_scores")
            det_classes = results.as_numpy("det_classes")
            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes,
                                           frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
            for box in detected_objects:
                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
            cv2.imshow(f'{threading.get_ident()}', frame)  # one window per thread
            if cv2.waitKey(1) == ord('q'):
                event.set()
                _exit_status_ = True
                break
            if event.is_set():
                break
        cap.release()
        if event.is_set():
            if _exit_status_:  # 'q' pressed: shut down for good
                break
            event.clear()  # stream ended: loop around and reopen the input
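# Sketch (addition, not in the original gist): the --out and --fps flags parsed
# below are never consumed by video_mode. One way to honour them would be a
# per-worker cv2.VideoWriter; the helper name make_writer is hypothetical.
def make_writer(path, fps, size):
    # mp4v is a widely available fallback codec; match it to your container.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    return cv2.VideoWriter(path, fourcc, fps, size)
# Inside the frame loop, a worker could then do
#     out.write(frame)   # instead of (or in addition to) cv2.imshow(...)
# and call out.release() next to cap.release().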
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode',
                        # choices=['dummy', 'image', 'video'],
                        required=False,
                        default='video',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video. Only video mode is implemented in this script.')
    parser.add_argument('--input',
                        type=str,
                        required=False,
                        default=r'D:\Python\yolov7\deploy\triton-inference-server\demo.mp4',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('-y',
                        '--model-version',
                        type=str,
                        required=False,
                        default='1',
                        help='Inference model version, default 1')
    parser.add_argument('-d',
                        '--processes',
                        type=int,
                        required=False,
                        default=1,
                        help='Number of parallel inference threads (one display window each), default 1')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default none')
    FLAGS = parser.parse_args()

    # Create server context
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit(1)

    # Health check
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)
    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)
    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED : get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED : get_model_metadata")
                sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)
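        # Sketch (addition, not in the original gist): the --model-info help also
        # promises statistics; the gRPC client exposes them via
        # get_inference_statistics, which could be printed in the same style:
        try:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            print(statistics)
        except InferenceServerException as ex:
            print("FAILED : get_inference_statistics")
            print("Got: {}".format(ex.message()))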
    # Launch one worker thread per requested process/window; the shared Event
    # coordinates shutdown across workers.
    windows = FLAGS.processes
    event = Event()
    with ThreadPoolExecutor(max_workers=windows) as executor:
        executor.map(video_mode, range(windows))
    cv2.destroyAllWindows()