RTMDet TensorRT C++ Deploy
CMakeLists.txt

cmake_minimum_required(VERSION 3.10)
project(RTMDet)

# Set the C++ standard globally (the original set_property() call ran before
# the target existed, which is a CMake error).
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Find OpenCV
find_package(OpenCV REQUIRED)

# Set the CUDA paths. Update this path to your CUDA installation.
set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda-11.8)
find_package(CUDA REQUIRED)

# Set the TensorRT paths. Update this path to your TensorRT installation.
set(TENSORRT_ROOT /root/workspace/TensorRT)
find_path(TENSORRT_INCLUDE_DIR NvInfer.h PATHS ${TENSORRT_ROOT}/include)
find_library(NVINFER_LIB nvinfer PATHS ${TENSORRT_ROOT}/lib)
find_library(NVONNXPARSER_LIB nvonnxparser PATHS ${TENSORRT_ROOT}/lib)
find_library(NVPLUGIN_LIB nvinfer_plugin PATHS ${TENSORRT_ROOT}/lib)

# Include directories
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
include_directories(${TENSORRT_INCLUDE_DIR})
include_directories(${OpenCV_INCLUDE_DIRS})

# Source files
set(SOURCE_FILES rtmdet.cpp)
add_executable(${PROJECT_NAME} ${SOURCE_FILES})
target_link_libraries(${PROJECT_NAME} ${CUDA_LIBRARIES} ${NVINFER_LIB} ${NVONNXPARSER_LIB} ${NVPLUGIN_LIB} ${OpenCV_LIBS} dl)
rtmdet.cpp

#include <dlfcn.h>

#include <NvInfer.h>
#include <cuda_runtime.h>  // needed for cudaMalloc/cudaMemcpy/cudaFree

#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
// The 80 COCO class names, indexed by the class ids in the "labels" output.
std::vector<std::string> coco_labels = {
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
        "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
        "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
        "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
        "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush",
};
// Minimal TensorRT logger that suppresses INFO-level messages.
class Logger : public nvinfer1::ILogger {
    void log(Severity severity, const char *msg) noexcept override {
        if (severity != Severity::kINFO) {
            std::cout << msg << std::endl;
        }
    }
} gLogger;
// Load a TensorRT plugin library with dlopen so its custom ops register themselves.
void *loadLibrary(const char *libPath) {
    void *handle = dlopen(libPath, RTLD_LAZY);
    if (!handle) {
        std::cerr << "Cannot load library: " << dlerror() << '\n';
        return nullptr;
    }
    return handle;
}
// Read a serialized TensorRT engine file into memory.
std::vector<char> readEngineFile(const std::string &engineFile) {
    std::ifstream file(engineFile, std::ios::binary);
    if (!file) {
        std::cerr << "Error opening engine file: " << engineFile << '\n';
        return {};
    }
    file.seekg(0, std::ios::end);
    size_t fileSize = file.tellg();
    file.seekg(0, std::ios::beg);
    std::vector<char> buffer(fileSize);
    file.read(buffer.data(), fileSize);
    return buffer;
}
// Generate one random BGR color per class for visualization.
std::vector<cv::Vec3b> generate_color_array(int num_colors) {
    std::vector<cv::Vec3b> colors(num_colors);
    for (int i = 0; i < num_colors; ++i) {
        colors[i] = cv::Vec3b(rand() % 256, rand() % 256, rand() % 256);
    }
    return colors;
}
int main() {
    std::string engine_path = "end2end.engine";
    std::string image_path = "demo.jpg";

    // List of paths to your plugin shared object files. The mmdeploy custom
    // ops must be loaded before the engine is deserialized.
    const char *pluginLibs[] = {
            "/root/workspace/rtmdet/custom_plugin/libmmdeploy_tensorrt_ops.so",
    };

    // Load each plugin library
    for (const char *libPath : pluginLibs) {
        if (!loadLibrary(libPath)) {
            return -1;  // Exit if any library fails to load
        }
    }
    // Read the serialized engine
    std::vector<char> engineData = readEngineFile(engine_path);
    if (engineData.empty()) {
        std::cerr << "Failed to read engine file\n";
        return -1;
    }

    // Create a runtime and deserialize the engine (the three-argument
    // deserializeCudaEngine overload was removed in TensorRT 8).
    nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(gLogger);
    if (!runtime) {
        std::cerr << "Failed to create runtime\n";
        return -1;
    }
    nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(engineData.data(), engineData.size());
    if (!engine) {
        std::cerr << "Failed to deserialize CUDA engine\n";
        return -1;
    }
    nvinfer1::IExecutionContext *context = engine->createExecutionContext();
    if (!context) {
        std::cerr << "Failed to create execution context\n";
        return -1;
    }
    // Binding indices for the four I/O tensors of the mmdeploy RTMDet engine.
    const int inputIndex = engine->getBindingIndex("input");
    const int outputDetsIndex = engine->getBindingIndex("dets");
    const int outputLabelsIndex = engine->getBindingIndex("labels");
    const int outputMasksIndex = engine->getBindingIndex("masks");
    if (inputIndex < 0 || outputDetsIndex < 0 || outputLabelsIndex < 0 || outputMasksIndex < 0) {
        std::cerr << "Unexpected binding names in engine\n";
        return -1;
    }

    // Allocate memory for input and output tensors
    const int batchSize = 1;
    const int inputSize = 640;
    const int inputVolume = inputSize * inputSize * 3;
    const int outputDetsSize = 500;        // 100 detections x 5 (x1, y1, x2, y2, score)
    const int outputLabelsSize = 100;      // 100 class ids
    const int outputMasksSize = 40960000;  // 100 masks x 640 x 640
    void *inputBuffer, *outputDetsBuffer, *outputLabelsBuffer, *outputMasksBuffer;
    cudaMalloc(&inputBuffer, batchSize * inputVolume * sizeof(float));
    cudaMalloc(&outputDetsBuffer, batchSize * outputDetsSize * sizeof(float));
    cudaMalloc(&outputLabelsBuffer, batchSize * outputLabelsSize * sizeof(int));
    cudaMalloc(&outputMasksBuffer, batchSize * outputMasksSize * sizeof(float));

    // executeV2 expects the device buffers ordered by binding index.
    void *bindings[4];
    bindings[inputIndex] = inputBuffer;
    bindings[outputDetsIndex] = outputDetsBuffer;
    bindings[outputLabelsIndex] = outputLabelsBuffer;
    bindings[outputMasksIndex] = outputMasksBuffer;
    // Preprocess: resize to the network input size, then normalize with the
    // BGR mean/std from the model's preprocessing config.
    cv::Mat img = cv::imread(image_path);
    cv::resize(img, img, cv::Size(inputSize, inputSize));
    img.convertTo(img, CV_32F);
    img -= cv::Scalar(103.53, 116.28, 123.675);
    img /= cv::Scalar(57.375, 57.12, 58.395);

    // Copy the image to the input buffer in CHW order
    float *hostData = new float[img.rows * img.cols * img.channels()];
    for (int c = 0; c < img.channels(); ++c) {
        for (int i = 0; i < img.rows; ++i) {
            for (int j = 0; j < img.cols; ++j) {
                hostData[c * img.rows * img.cols + i * img.cols + j] = img.at<cv::Vec3f>(i, j)[c];
            }
        }
    }
    cudaMemcpy(inputBuffer, hostData, batchSize * inputVolume * sizeof(float), cudaMemcpyHostToDevice);
    // Inference: executeV2 takes the full bindings array, not just the input buffer.
    if (!context->executeV2(bindings)) {
        std::cerr << "Inference failed\n";
        return -1;
    }
    // Copy the labels back to the host
    std::vector<int> outputLabels(batchSize * outputLabelsSize);
    cudaMemcpy(outputLabels.data(), outputLabelsBuffer, batchSize * outputLabelsSize * sizeof(int),
               cudaMemcpyDeviceToHost);

    // Copy the detections back to the host
    std::vector<float> outputDets(batchSize * outputDetsSize);
    cudaMemcpy(outputDets.data(), outputDetsBuffer, batchSize * outputDetsSize * sizeof(float),
               cudaMemcpyDeviceToHost);

    // Copy the masks back to the host
    std::vector<float> outputMasksData(batchSize * outputMasksSize);
    cudaMemcpy(outputMasksData.data(), outputMasksBuffer, batchSize * outputMasksSize * sizeof(float),
               cudaMemcpyDeviceToHost);

    // Wrap each 640x640 mask as a cv::Mat header over the host buffer (no copy).
    std::vector<cv::Mat> outputMasks(100);
    for (int i = 0; i < 100; ++i) {
        outputMasks[i] = cv::Mat(inputSize, inputSize, CV_32F,
                                 outputMasksData.data() + (i * inputSize * inputSize));
    }
    // Draw the detections on a fresh copy of the original image.
    cv::Mat output_image = cv::imread(image_path);
    std::vector<cv::Vec3b> colors = generate_color_array(static_cast<int>(coco_labels.size()));
    // Boxes are in 640x640 network coordinates; scale them back to the
    // original image size (the original code hardcoded the demo image's ratios).
    const float scaleX = static_cast<float>(output_image.cols) / inputSize;
    const float scaleY = static_cast<float>(output_image.rows) / inputSize;
    for (int index = 0; index < 100; ++index) {
        // dets layout: [x1, y1, x2, y2, score]; skip low-confidence detections.
        if (outputDets[(5 * index) + 4] < 0.3) {
            continue;
        }
        int label = outputLabels[index];

        // Normalize the mask to 0-255 and resize it to the original image size.
        cv::Mat mask = outputMasks[index];
        double minVal, maxVal;
        cv::minMaxLoc(mask, &minVal, &maxVal);  // find minimum and maximum intensities
        mask.convertTo(mask, CV_8U, 255.0 / (maxVal - minVal), -minVal * 255.0 / (maxVal - minVal));
        cv::resize(mask, mask, cv::Size(output_image.cols, output_image.rows));

        // Color every pixel whose mask value exceeds the threshold with the class color.
        for (int i = 0; i < mask.rows; ++i) {
            for (int j = 0; j < mask.cols; ++j) {
                if (mask.at<uchar>(i, j) > 200) {
                    output_image.at<cv::Vec3b>(i, j) = colors[label];
                }
            }
        }

        // Draw rectangle around the object
        cv::Point topLeft(static_cast<int>(outputDets[(5 * index) + 0] * scaleX),
                          static_cast<int>(outputDets[(5 * index) + 1] * scaleY));
        cv::Point bottomRight(static_cast<int>(outputDets[(5 * index) + 2] * scaleX),
                              static_cast<int>(outputDets[(5 * index) + 3] * scaleY));
        cv::rectangle(output_image, topLeft, bottomRight, colors[label], 2);

        // Write the class name (look it up by class id, not detection index)
        cv::putText(output_image, coco_labels[label], topLeft,
                    cv::FONT_HERSHEY_SIMPLEX, 1, colors[label], 2);
    }
    cv::imshow("Output", output_image);
    cv::waitKey(0);  // wait for a key press so the window stays visible

    // Cleanup
    delete[] hostData;
    cudaFree(inputBuffer);
    cudaFree(outputDetsBuffer);
    cudaFree(outputLabelsBuffer);
    cudaFree(outputMasksBuffer);
    delete context;
    delete engine;
    delete runtime;
    return 0;
}

With this C++ program, you can run inference on your images with RTMDet TensorRT models.

From this link you can get the PyTorch models and convert them to ONNX and TensorRT formats.
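If you already have the exported ONNX file, an engine can also be built directly in C++ with nvonnxparser, which the CMakeLists above already links. This is only a minimal sketch, assuming the exported file is named end2end.onnx, that gLogger from rtmdet.cpp is visible, and that libmmdeploy_tensorrt_ops.so has been dlopen'ed first so the custom ops can be parsed; the mmdeploy conversion tools remain the reference path.

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <fstream>

// Sketch: build a serialized engine from an ONNX file (error handling and
// object cleanup omitted for brevity).
bool buildEngineFromOnnx(const char *onnxPath, const char *enginePath) {
    nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(gLogger);
    const auto flags = 1U << static_cast<uint32_t>(
            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    nvinfer1::INetworkDefinition *network = builder->createNetworkV2(flags);
    nvonnxparser::IParser *parser = nvonnxparser::createParser(*network, gLogger);
    if (!parser->parseFromFile(onnxPath,
            static_cast<int>(nvinfer1::ILogger::Severity::kWARNING))) {
        return false;  // parse errors are reported through gLogger
    }
    nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();
    nvinfer1::IHostMemory *serialized = builder->buildSerializedNetwork(*network, *config);
    if (!serialized) return false;
    std::ofstream out(enginePath, std::ios::binary);
    out.write(static_cast<const char *>(serialized->data()), serialized->size());
    return true;
}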

You need the following files to run this program:

  1. TensorRT model: you can create it from the link above. If you are unsure of the engine's tensor names, the sketch after this list prints them.
  2. mmdetection TensorRT plugin: it is created with mmdeploy, and you must load this plugin in the program (the pluginLibs array at the top of main); it is located under mmdeploy/mmdeploy/lib after you compile mmdeploy.
  3. Image: any image.
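The program assumes the engine exposes tensors named input, dets, labels, and masks. If deserialization succeeds but the binding lookups fail, a quick check is to print what the engine actually contains. A minimal sketch, reusing gLogger and the already-deserialized engine pointer from the program above:

// Print every binding's name, shape, and direction to verify the
// "input"/"dets"/"labels"/"masks" names assumed by rtmdet.cpp.
void printBindings(const nvinfer1::ICudaEngine *engine) {
    for (int i = 0; i < engine->getNbBindings(); ++i) {
        nvinfer1::Dims dims = engine->getBindingDimensions(i);
        std::cout << (engine->bindingIsInput(i) ? "input  " : "output ")
                  << engine->getBindingName(i) << " [";
        for (int d = 0; d < dims.nbDims; ++d) {
            std::cout << dims.d[d] << (d + 1 < dims.nbDims ? "x" : "");
        }
        std::cout << "]\n";
    }
}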

Change CUDA_TOOLKIT_ROOT_DIR and TENSORRT_ROOT in the CMakeLists.txt file, then build and run:

mkdir build
cd build
cmake ..
make
./RTMDet