RTMDet TensorRT Python Deploy

import ctypes

import cv2
import numpy as np
import pycuda.autoinit  # noqa: F401 -- importing creates the CUDA context
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def preprocess_image(image_path, img_size=(640, 640)):
    image = cv2.imread(image_path)
    image = cv2.resize(image, img_size)
    # mmdetection's default mean/std, in BGR channel order to match cv2.imread.
    mean = np.array([103.53, 116.28, 123.675])
    std = np.array([57.375, 57.12, 58.395])
    image = (image - mean) / std
    image = image.transpose(2, 0, 1)  # HWC -> CHW
    return image.astype(np.float32)  # cast down from float64 for the engine input

def load_engine(engine_path):
    # Deserialize a prebuilt TensorRT engine from disk.
    with open(engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    return engine

def allocate_buffers(engine):
    # One page-locked host buffer and one device buffer per engine binding.
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append({'host': host_mem, 'device': device_mem})
        else:
            outputs.append({'host': host_mem, 'device': device_mem})
    return inputs, outputs, bindings, stream
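
# NOTE (an addition, not in the original gist): the get_binding_* and
# max_batch_size calls above are the TensorRT 8.x binding API; they are
# deprecated in 8.5 and removed in TensorRT 10, where
# engine.get_tensor_shape(name), engine.get_tensor_dtype(name) and
# engine.get_tensor_mode(name) replace them.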

def do_inference(context, bindings, inputs, outputs, stream):
    # Copy inputs to the GPU, run the engine, copy outputs back, then block
    # until the stream finishes.
    for inp in inputs:
        cuda.memcpy_htod_async(inp['device'], inp['host'], stream)
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    for out in outputs:
        cuda.memcpy_dtoh_async(out['host'], out['device'], stream)
    stream.synchronize()
    return [out['host'] for out in outputs]

def load_plugin(file_path):
    # Loading the shared library is enough: mmdeploy's custom TensorRT ops
    # register themselves with the plugin registry on load.
    ctypes.CDLL(file_path)
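
# Optional check (an addition, not in the original gist): after load_plugin,
# mmdeploy's custom ops should appear in TensorRT's plugin registry. Listing
# the registered creators makes a missing or incompatible plugin library
# obvious before deserialize_cuda_engine fails with an opaque error.
def list_registered_plugins():
    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
    registry = trt.get_plugin_registry()
    return [creator.name for creator in registry.plugin_creator_list]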

if __name__ == "__main__":
    # The plugin library must be loaded before deserializing the engine,
    # otherwise the custom RTMDet ops cannot be resolved.
    load_plugin("libmmdeploy_tensorrt_ops.so")
    engine_path = 'end2end.engine'
    image_path = 'demo.jpg'

    # Load Engine------------------------
    engine = load_engine(engine_path)
    if engine:
        print("Engine loaded successfully.")
    else:
        print("Failed to load the engine.")
    # /Load Engine-----------------------

    # Pre-process------------------------
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    input_image = preprocess_image(image_path, (640, 640))
    np.copyto(inputs[0]['host'], input_image.ravel())
    # /Pre-process-----------------------

    # Inference--------------------------
    output = do_inference(context, bindings, inputs, outputs, stream)
    # /Inference-------------------------
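
    # Sanity check (an addition, not in the original gist): the post-processing
    # below assumes mmdeploy's end2end output layout -- output[0] holds 100
    # detections as flat [x1, y1, x2, y2, score] rows, output[1] the 100 class
    # labels, and output[2] the 100 mask score maps at 640x640. Printing the
    # flat buffer sizes (expected: 500, 100 and 40960000) catches a different
    # binding order early.
    print("output buffer sizes:", [out.size for out in output])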

    # Post-process-----------------------
    def generate_color_array(num_colors):
        np.random.seed(42)  # For reproducibility
        colors = np.random.randint(0, 256, size=(num_colors, 3), dtype=np.uint8)
        return colors

    colors = generate_color_array(100)
    output_image = cv2.imread(image_path)
    height, width = output_image.shape[:2]
    # Boxes and masks are predicted at the 640x640 network input resolution;
    # scale them back to the original image size.
    scale_x = width / 640
    scale_y = height / 640

    masks = output[2].reshape(100, 640, 640)
    masks = (255 * (masks - np.min(masks)) / np.ptp(masks)).astype(np.uint8)
    for index, mask in enumerate(masks):
        # Detections are sorted by score, so stop at the first one below the
        # confidence threshold.
        if output[0][(5 * index) + 4] < 0.3:
            break
        x1 = int(output[0][(5 * index) + 0] * scale_x)
        y1 = int(output[0][(5 * index) + 1] * scale_y)
        x2 = int(output[0][(5 * index) + 2] * scale_x)
        y2 = int(output[0][(5 * index) + 3] * scale_y)
        mask = cv2.resize(mask, (width, height))
        output_image[mask > 200] = colors[int(output[1][index])]
        cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            output_image,
            f"{output[1][index]}",
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            4,
        )
    cv2.imshow("Output", output_image)
    cv2.waitKey(5000)
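
    # Headless alternative (an addition, not in the original gist): cv2.imshow
    # needs a display server; when running remotely, saving the visualization
    # to disk is a drop-in replacement.
    cv2.imwrite("output.jpg", output_image)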
    # /Post-process----------------------

    del engine

With this Python script, you can run inference on your images with RTMDet TensorRT models.

From this link, you can get the pretrained models and convert them to ONNX and TensorRT models.

You need the following files to run this script:

  1. TensorRT model: you can create it from the link above (or build it yourself; see the sketch after this list).
  2. mmdetection TensorRT plugin: it is built with mmdeploy, and you must load it in the script with load_plugin (it is located under mmdeploy/mmdeploy/lib after you compile mmdeploy).
  3. Image: any image.
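
If you already have the ONNX file (mmdeploy names it end2end.onnx in its work directory), you can also serialize the engine yourself with TensorRT's Python API. This is a minimal sketch, not the exact steps from the link: it assumes a static 640x640 export, the same TensorRT 8.x API the script above uses, and that the mmdeploy plugin library sits next to the script; the paths and the 1 GiB workspace size are placeholders.

    import ctypes
    import tensorrt as trt

    ctypes.CDLL("libmmdeploy_tensorrt_ops.so")  # register the custom ops first
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    with open("end2end.onnx", "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(parser.get_error(0))
    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # build-time scratch memory
    serialized = builder.build_serialized_network(network, config)
    with open("end2end.engine", "wb") as f:
        f.write(serialized)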