Skip to content

Instantly share code, notes, and snippets.

@LukeAI
Created June 16, 2023 22:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LukeAI/336a1fd9ea802d454d883342517a681f to your computer and use it in GitHub Desktop.
Save LukeAI/336a1fd9ea802d454d883342517a681f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import os
import time
import logging
import argparse
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from cuda import cudart
import ctypes
from typing import Optional, List, Tuple
from skimage import io
import cv2
from skimage.transform import resize
# Command-line options; parsed once at import time and read as the global `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--onnx", default="yolo_nas_s_nms.onnx")  # ONNX model path
parser.add_argument("--image", default="forks.png")  # input image path
parser.add_argument("--fp16", action="store_true")  # boolean flag, False unless given
args = parser.parse_args()

# Importing pycuda.autoinit creates a CUDA context; keep a module-level handle to it.
ctx = pycuda.autoinit.context

# Register TensorRT's built-in plugins with the plugin registry
# (presumably required by the NMS op in the exported model — confirm).
trt.init_libnvinfer_plugins(None, "")
def cuda_call(call):
    """Check and unpack the ``(err, *results)`` tuple returned by a cuda-python call.

    Args:
        call: tuple whose first element is the CUDA status code and whose
            remaining elements are the call's outputs.

    Returns:
        The single output if there is exactly one, otherwise the (possibly
        empty) tuple of outputs.

    Raises:
        RuntimeError: if the status code is not success. Previously the status
            was silently discarded, so failures surfaced later as bogus handles.
    """
    err, res = call[0], call[1:]
    # cudaSuccess (runtime API) and CUDA_SUCCESS (driver API) are both 0.
    # Compare by value so both Enum-style status objects and plain ints work.
    if getattr(err, "value", err) != 0:
        raise RuntimeError(f"CUDA call failed with status {err!r}")
    if len(res) == 1:
        res = res[0]
    return res
class HostDeviceMem:
    """Pinned host buffer plus a device buffer of identical byte size.

    ``host`` is a 1-D numpy view of ``size`` elements laid directly over the
    page-locked allocation from ``cudaMallocHost`` (no copy). ``device`` is the
    address returned by ``cudaMalloc``. ``shape`` is stored for callers only
    and does not influence the allocation size.
    """

    def __init__(self, size: int, dtype: np.dtype, shape: Tuple[int, ...]):
        byte_count = size * dtype.itemsize
        pinned = cuda_call(cudart.cudaMallocHost(byte_count))
        c_elem = np.ctypeslib.as_ctypes_type(dtype)
        self.shape = shape
        # Reinterpret the raw pinned address as a typed pointer, then view it as numpy.
        typed_ptr = ctypes.cast(pinned, ctypes.POINTER(c_elem))
        self._host = np.ctypeslib.as_array(typed_ptr, (size,))
        self._device = cuda_call(cudart.cudaMalloc(byte_count))
        self._nbytes = byte_count

    @property
    def host(self) -> np.ndarray:
        """Flat numpy view over the pinned host memory."""
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        """Copy ``arr`` (flattened) into the front of the host buffer.

        Raises ValueError if ``arr`` has more elements than the buffer; numpy
        raises TypeError if ``arr``'s dtype cannot be cast under the 'safe'
        rule. Elements past ``arr.size`` keep their previous contents.
        """
        capacity = self._host.size
        if capacity < arr.size:
            raise ValueError(
                f"Tried to fit an array of size {arr.size} into host memory of size {capacity}"
            )
        np.copyto(self._host[:arr.size], arr.flat, casting='safe')

    @property
    def device(self) -> int:
        """Device address returned by ``cudaMalloc``."""
        return self._device

    @property
    def nbytes(self) -> int:
        """Byte size of each of the two allocations."""
        return self._nbytes

    def __str__(self):
        return "Host:\n{}\nDevice:\n{}\nSize:\n{}\n".format(self.host, self.device, self.nbytes)

    def __repr__(self):
        return str(self)

    def free(self):
        """Release the device buffer, then the pinned host buffer."""
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))
class YoloNAS:
    """TensorRT runner for a YOLO-NAS ONNX model.

    The serialized engine is cached next to the ONNX file as ``<stem>.trt``.
    It is built from the ONNX model when no cache exists and reloaded
    otherwise. Host/device buffers for every I/O tensor are allocated once in
    the constructor and reused by :meth:`infer`.
    """

    trt_logger = trt.Logger()

    def __init__(self, onnx_path: str, fp16=False):
        """Load (or build and cache) the engine, create a context, allocate buffers.

        Args:
            onnx_path: ONNX model path. The engine cache path is derived from it
                by replacing the extension with ``.trt``.
            fp16: set the FP16 builder flag when a new engine is built. Has no
                effect when a cached ``.trt`` file already exists.

        Raises:
            RuntimeError: if the cached engine cannot be deserialized (this can
                happen, for example, when it was produced by a different
                TensorRT version); delete the ``.trt`` file to force a rebuild.
        """
        trt_path = os.path.splitext(onnx_path)[0] + ".trt"
        if not os.path.exists(trt_path):
            print("building engine from ", onnx_path)
            # Build from the path we were given (not the global CLI args) and honour fp16.
            engine = self.build_engine_from_onnx(onnx_path, fp16=fp16)
            with open(trt_path, "wb") as fw:
                fw.write(engine.serialize())
            print("saved engine to ", trt_path)
        else:
            print("reading ", trt_path)
            with open(trt_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
                engine = runtime.deserialize_cuda_engine(f.read())
            if engine is None:
                raise RuntimeError(
                    f"failed to deserialize TensorRT engine from {trt_path}; "
                    "delete it to rebuild from the ONNX model"
                )
        # Keep the engine referenced for as long as its execution context is in use.
        self.engine = engine
        self.context = engine.create_execution_context()
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers(engine)

    def allocate_buffers(self, engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
        """Allocate host/device buffers for every I/O tensor of ``engine``.

        If the engine uses dynamic shapes, pass ``profile_idx`` so that buffers
        are sized for that optimization profile's maximum shape.

        Returns:
            (inputs, outputs, bindings, stream): HostDeviceMem lists for input
            and output tensors, device addresses in engine I/O order, and a
            newly created CUDA stream.

        Raises:
            ValueError: if a tensor has a dynamic dimension and no profile is given.
        """
        inputs = []
        outputs = []
        bindings = []
        stream = cuda_call(cudart.cudaStreamCreate())
        tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
        for binding in tensor_names:
            if profile_idx is None:
                shape = engine.get_tensor_shape(binding)
            else:
                # get_tensor_profile_shape returns (min, opt, max); size for the max.
                shape = engine.get_tensor_profile_shape(binding, profile_idx)[-1]
            # A negative (-1) dimension marks a dynamic axis.
            shape_valid = all(s >= 0 for s in shape)
            if not shape_valid and profile_idx is None:
                raise ValueError(f"Binding {binding} has dynamic shape, but no profile was specified.")
            size = trt.volume(shape)
            if engine.has_implicit_batch_dimension:
                size *= engine.max_batch_size
            dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))
            binding_memory = HostDeviceMem(size, dtype, shape)
            bindings.append(int(binding_memory.device))
            if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
                inputs.append(binding_memory)
            else:
                outputs.append(binding_memory)
        return inputs, outputs, bindings, stream

    def build_engine_from_onnx(self, onnx_file_path, fp16=False):
        """Parse an ONNX file and build a TensorRT engine (explicit batch, 2 GiB workspace).

        Raises:
            FileNotFoundError: if ``onnx_file_path`` does not exist.
            RuntimeError: if ONNX parsing or the engine build fails. Parser
                errors are printed before raising.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(onnx_file_path):
            raise FileNotFoundError(f'cannot find {onnx_file_path}')
        EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(self.trt_logger) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                builder.create_builder_config() as config, \
                trt.OnnxParser(network, self.trt_logger) as parser, \
                trt.Runtime(self.trt_logger) as runtime:
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30)  # 2 GiB
            if fp16:
                print("setting fp16 flag")
                config.set_flag(trt.BuilderFlag.FP16)
            with open(onnx_file_path, 'rb') as fr:
                if not parser.parse(fr.read()):
                    print('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    raise RuntimeError(f"failed to parse ONNX file {onnx_file_path}")
            plan = builder.build_serialized_network(network, config)
            # build_serialized_network returns None on failure; fail here, not later.
            if plan is None:
                raise RuntimeError(f"TensorRT engine build failed for {onnx_file_path}")
            engine = runtime.deserialize_cuda_engine(plan)
        return engine

    def infer(self, image):
        """Run one synchronous inference on ``image`` and return the raw outputs.

        ``image`` is flattened into the first input's host buffer. The returned
        arrays are the outputs' flat host buffers themselves (not copies), so
        the next call to ``infer`` overwrites them.
        """
        self.inputs[0].host = image
        host2device = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
        for inp in self.inputs:
            cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, host2device, self.stream))
        self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream)
        device2host = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
        for out in self.outputs:
            cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, device2host, self.stream))
        # Block until the copies and the inference enqueued on the stream finish.
        cuda_call(cudart.cudaStreamSynchronize(self.stream))
        return [out.host for out in self.outputs]

    def free(self):
        """Release the I/O buffers and CUDA stream created by allocate_buffers.

        Safe to call more than once; do not call :meth:`infer` afterwards.
        """
        for mem in self.inputs + self.outputs:
            mem.free()
        self.inputs, self.outputs, self.bindings = [], [], []
        if self.stream is not None:
            cuda_call(cudart.cudaStreamDestroy(self.stream))
            self.stream = None
def load_image(img_path, size):
    """Load an image and turn it into a float32 CHW array resized to ``size``.

    Grayscale images are replicated to three channels and an alpha channel is
    dropped. Without this, ``resize`` to a 3-channel ``size`` would interpolate
    alpha into the colour channels, and a 2-D grayscale array would crash in
    ``np.rollaxis``.

    Args:
        img_path: path readable by ``skimage.io.imread``.
        size: target output shape; callers pass the model input's ``(C, H, W)``.

    Returns:
        (img_resize, img_raw): the resized float32 array scaled by 1/255, and
        the un-resized image in channel-first layout.
    """
    img = io.imread(img_path)
    if img.ndim == 2:
        # Grayscale -> three identical channels so the HWC -> CHW roll below is valid.
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:
        # RGBA (common for PNG): keep RGB only.
        img = img[..., :3]
    img_raw = np.rollaxis(img, 2, 0)  # HWC -> CHW
    # NOTE(review): /255 assumes 8-bit pixel data; 16-bit images would need a different scale.
    img_resize = resize(img_raw / 255, size, anti_aliasing=True)
    img_resize = img_resize.astype(np.float32)
    return img_resize, img_raw
def main():
    """Run the model from ``--onnx`` on ``--image``, print detections, and display boxes.

    Reads the module-level ``args`` parsed at import time.
    """
    yolo = YoloNAS(args.onnx, fp16=args.fp16)  # previously --fp16 was never forwarded
    # Drop the batch dimension; the rest of this function treats it as (C, H, W).
    input_shape = yolo.inputs[0].shape[1:]
    img_resize, _ = load_image(args.image, input_shape)

    # Warm-up runs so one-time initialisation cost is excluded from the timing below.
    for _ in range(10):
        yolo.infer(img_resize)
    start_time = time.perf_counter()  # monotonic, high-resolution interval timer
    trt_outputs = yolo.infer(img_resize)
    end_time = time.perf_counter()
    print("Total Running time = {:.3f} seconds".format(end_time - start_time))

    # Output layout as indexed here: [0] detection count, [1] flat x0, y0, x1, y1
    # per detection, [2] scores, [3] class ids.
    num_dets = int(trt_outputs[0][0])
    flat_boxes = trt_outputs[1][:num_dets * 4].reshape(-1, 4)
    det_boxes = [
        ((int(round(x0)), int(round(y0))), (int(round(x1)), int(round(y1))))
        for x0, y0, x1, y1 in flat_boxes
    ]
    det_scores = trt_outputs[2][:num_dets]
    det_classes = trt_outputs[3][:num_dets]
    print(det_scores)
    print(det_classes)

    # Show the image at the model's input resolution, where the boxes live.
    img = cv2.imread(args.image)
    # cv2.resize expects dsize=(width, height); input_shape is (C, H, W).
    img = cv2.resize(img, (input_shape[2], input_shape[1]))
    for top_left, bottom_right in det_boxes:
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 3)
    cv2.imshow("bboxes", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment