@YashasSamaga
Last active August 15, 2023 02:05
GSoC 2019 | OpenCV | Adding a CUDA backend to the DNN module

DISCLAIMER

This gist documents the Google Summer of Code 2019 project. It is no longer updated and hence does not reflect the current status of the CUDA backend.

For updated details, please see this gist.

Allow the OpenCV's DNN module to work with GPUs

Student: Yashas Samaga B L

Mentor: Davis King

Project Link: https://summerofcode.withgoogle.com/projects/#6021087400296448

Relevant PRs:

Introduction

OpenCV's DNN module offers fast inference on CPUs. It supports inference on GPUs through OpenCL but lacks a CUDA backend. NVIDIA GPUs support OpenCL, but OpenCL does not expose their full capabilities.

This project adds a new CUDA backend for fast inference on NVIDIA GPUs.

How to use?

Build

The CUDA backend requires the CUDA Toolkit and cuDNN (7.5.0 or later) to be installed on the system. The CMake scripts will automatically detect the dependencies when the following options are set:

  • WITH_CUDA
  • WITH_CUDNN

The CUDA backend is enabled by setting the following option:

  • OPENCV_DNN_CUDA

After building, run the accuracy tests ([build dir]/bin/opencv_test_dnn) and the performance tests ([build dir]/bin/opencv_perf_dnn).

Usage

The project adds the following new backends and targets to the existing list.

| Backend | Target |
| --- | --- |
| DNN_BACKEND_CUDA | DNN_TARGET_CUDA |
| DNN_BACKEND_CUDA | DNN_TARGET_CUDA_FP16 |

Support Matrix

The CUDA backend uses OpenCV's CPU backend as a fallback for unsupported layers and partially supported layers with unsupported configurations.

| Blip | Meaning |
| --- | --- |
| ✔️ | fully supported |
| 🔵 | partially supported |
| (no blip) | unsupported |

| Layer | Status |
| --- | --- |
| Activations | ✔️ |
| Batch Normalization | ✔️ |
| Blank Layer | ✔️ |
| Concat Layer | ✔️ |
| Const Layer | ✔️ |
| Convolution 2d | ✔️ |
| Convolution 3d | ✔️ |
| Crop and resize | |
| Crop Layer | ✔️ |
| Detection Output Layer | |
| Deconvolution 2d | 🔵 (most configurations supported) |
| Deconvolution 3d | 🔵 (most configurations supported) |
| Elementwise Layers | ✔️ |
| Eltwise Layer | ✔️ |
| Flatten Layer | ✔️ |
| Fully Connected Layer | ✔️ |
| Input Layer | |
| Interp Layer | ✔️ |
| Local Response Normalization | ✔️ |
| Max Unpooling 2d | ✔️ |
| Max Unpooling 3d | ✔️ |
| MVN Layer | |
| Normalize Layer | 🔵 (L1 and L2 supported) |
| Padding Layer | ✔️ |
| Permute Layer | ✔️ |
| Pooling 2d | 🔵 (max and average supported) |
| Pooling 3d | 🔵 (max and average supported) |
| Prior Box Layer | ✔️ |
| Proposal Layer | |
| Region Layer | ✔️ |
| Reorg Layer | ✔️ |
| Reshape Layer | ✔️ |
| Resize Layer | ✔️ |
| Scale Layer | ✔️ |
| Shift Layer | ✔️ |
| Shuffle Channel Layer | ✔️ |
| Slice Layer | ✔️ |
| Softmax Layer | ✔️ |
| Split Layer | ✔️ |
| LSTM Layer | |

OCV CPU vs IE CPU vs CUDA

CPU: i7 7700HQ

GPU: NVIDIA GTX 1050 Mobile

CPU BLAS Library: MKL 2019.0.4

CUDA Version: 10.1

cuDNN: 7.6.2

Warmup Runs: 3 (forward pass is performed three times before benchmarks)

Benchmark Runs: 10 (the average of ten forward passes is reported)

Test Code: https://gist.github.com/YashasSamaga/71157cf0c3768c497e5e70fb95435596
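The measurement procedure described above (a few untimed warmup runs, then the mean over several timed runs) can be sketched with the standard library alone. This is a minimal illustration, not the linked test code; the name `measure_mean_ms` and its default arguments are assumptions chosen to mirror the settings listed above.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>

// Hypothetical helper (not taken from the linked test code): calls `fn`
// `warmup_runs` times without timing it, then returns the mean wall-clock
// time in milliseconds over `benchmark_runs` timed calls.
template <class F>
double measure_mean_ms(F&& fn, std::size_t warmup_runs = 3, std::size_t benchmark_runs = 10)
{
    assert(benchmark_runs > 0);

    for (std::size_t i = 0; i < warmup_runs; i++)
        fn();  // warmup: lets one-time initialization finish before timing

    using clock = std::chrono::steady_clock;
    std::chrono::duration<double, std::milli> total(0.0);
    for (std::size_t i = 0; i < benchmark_runs; i++)
    {
        const auto start = clock::now();
        fn();  // in practice this would be one forward pass
        total += clock::now() - start;
    }
    return total.count() / static_cast<double>(benchmark_runs);
}
```

In a real benchmark, `fn` would wrap the forward pass of an already configured network, so that model loading and backend initialization stay outside the timed region.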

Batch Size = 1

| Model | CUDA FP32 | Inference Engine CPU | OpenCV CPU |
| --- | --- | --- | --- |
| GoogLeNet | 7.2447ms | 10.4981ms | 17.9176ms |
| DenseNet121 | 12.6324ms | 19.1823ms | 48.0628ms |
| EAST Text Detection | 18.8281ms | 49.0508ms | 88.9429ms |
| ENet | 11.5014ms | Exception | 62.5854ms |
| FastNeuralStyle StarryNight | 27.498ms | 178.309ms | 160.359ms |
| Inception 5h | 7.8546ms | 22.2789ms | 20.3255ms |
| Inception v2 FasterRCNN | 112.736ms | Exception | 374.26ms |
| MobileNet SSD | 58.4751ms | 9.2896ms | 27.3061ms |
| OpenCV Face Detector | 6.9831ms | 8.3981ms | 17.6683ms |
| OpenPose Pose MPI | 160.561ms | 509.446ms | 838.161ms |
| Resnet 50 | 11.3603ms | 28.1529ms | 50.2752ms |
| SqueezeNet | 2.4084ms | 3.2918ms | 5.476ms |
| VGG16 SSD | 70.4117ms | 249.725ms | 360.207ms |
| Yolo v3 | 57.9822ms | 214.629ms | 296.806ms |
| Yolo v2 | 51.5784ms | 193.453ms | 260.19ms |

Batch Size = 10

| Model | CUDA FP32 | Inference Engine CPU | OpenCV CPU |
| --- | --- | --- | --- |
| GoogLeNet | 35.7556ms | 108.946ms | 225.928ms |
| DenseNet121 | 74.9241ms | 295.105ms | 650.924ms |
| EAST Text Detection | 149.58ms | 536.946ms | 1273.93ms |
| FastNeuralStyle StarryNight | 283.173ms | 1966.5ms | 2175.3ms |
| Inception 5h | 36.6225ms | 180.429ms | 233.276ms |
| MobileNet SSD | 277.753ms | 111.872ms | 316.063ms |
| OpenCV Face Detector | 52.4366ms | 95.7866ms | 202.657ms |
| OpenPose Pose MPI | 628.617ms | 5650.05ms | 10683.5ms |
| Resnet 50 | 74.283ms | 230.817ms | 541.308ms |
| SqueezeNet | 15.8144ms | 35.4915ms | 69.4122ms |
| VGG16 SSD | 594.286ms | 2796.23ms | 4661.51ms |
| Yolo v3 | 488.704ms | 2419.8ms | 4209.74ms |
| Yolo v2 | 491.414ms | 2185.47ms | 3788.34ms |

OpenCV CUDA vs OpenCV CPU

CPU: 2x Intel Xeon E5-2640 v4

GPU: 1x NVIDIA GTX 1080 Ti (11 GB)

CPU BLAS Library: OpenBLAS 0.2.20

CUDA Version: 10.0

cuDNN: 7.6.2

Warmup Runs: 3 (forward pass is performed three times before benchmarks)

Benchmark Runs: 10 (the average of ten forward passes is reported)

Test Code: https://gist.github.com/YashasSamaga/71157cf0c3768c497e5e70fb95435596

Backend Comparison

Batch Size = 1

| Model | CUDA FP32 | OpenCV CPU |
| --- | --- | --- |
| GoogLeNet | 4.8824ms | 14.2981ms |
| DenseNet121 | 6.4555ms | 57.8244ms |
| EAST Text Detection | 5.901ms | 67.4301ms |
| ENet | 4.5979ms | 30.2767ms |
| FastNeuralStyle StarryNight | 5.3193ms | 51.3313ms |
| Inception 5h | 4.9487ms | 16.0048ms |
| Inception v2 FasterRCNN | 82.0298ms | 179.245ms |
| MobileNet SSD | 70.9177ms | 23.9348ms |
| OpenCV Face Detector | 4.9288ms | 15.4205ms |
| OpenPose Pose MPI | 30.5954ms | 246.747ms |
| Resnet 50 | 4.5968ms | 45.1153ms |
| SqueezeNet | 1.0888ms | 3.6492ms |
| VGG16 SSD | 23.5926ms | 194.976ms |
| Yolo v3 | 18.0002ms | 141.861ms |
| Yolo v2 | 12.1279ms | 111.642ms |

Batch Size = 10

| Model | CUDA FP32 | OpenCV CPU |
| --- | --- | --- |
| GoogLeNet | 10.149ms | 75.9591ms |
| DenseNet121 | 20.269ms | 312.426ms |
| EAST Text Detection | 32.1556ms | 402.16ms |
| FastNeuralStyle StarryNight | 49.1025ms | 461.095ms |
| Inception 5h | 9.9721ms | 67.9308ms |
| MobileNet SSD | 96.2898ms | 110.783ms |
| OpenCV Face Detector | 22.7501ms | 77.8742ms |
| OpenPose Pose MPI | 118.858ms | 2321.89ms |
| Resnet 50 | 18.4139ms | 229.599ms |
| SqueezeNet | 4.4893ms | 22.3049ms |
| VGG16 SSD | 194.181ms | 1319.67ms |
| Yolo v3 | 122.603ms | 1044.11ms |
| Yolo v2 | 104.072ms | 819.177ms |

Batch Size = 128

| Model | CUDA FP32 | OpenCV CPU |
| --- | --- | --- |
| GoogLeNet | 90.3755ms | 775.769ms |
| DenseNet121 | 199.516ms | 3536.38ms |
| EAST Text Detection | 376.458ms | 7685.72ms |
| FastNeuralStyle StarryNight | 801.778ms | 6607.15ms |
| Inception 5h | 93.4188ms | 771.575ms |
| MobileNet SSD | 1028.93ms | 1110.37ms |
| OpenCV Face Detector | 276.992ms | 977.997ms |
| OpenPose Pose MPI | 1279.26ms | 32159.3ms |
| Resnet 50 | 200.789ms | 1719.92ms |
| SqueezeNet | 55.6244ms | 255.397ms |
| VGG16 SSD | 2969.05ms | 17201ms |
| Yolo v3 | 1564.78ms | 13699.2ms |
| Yolo v2 | 1362.84ms | 11254.9ms |

Images processed per second (CUDA FP32)

| Model | batch size = 1 | batch size = 10 | batch size = 128 |
| --- | --- | --- | --- |
| GoogLeNet | 204 | 985 | 1416 |
| DenseNet121 | 154 | 493 | 641 |
| EAST Text Detection | 169 | 311 | 340 |
| ENet | 217 | Not Applicable | Not Applicable |
| FastNeuralStyle StarryNight | 188 | 204 | 160 |
| Inception 5h | 202 | 1002 | 1370 |
| Inception v2 FasterRCNN | 12 | Not Applicable | Not Applicable |
| MobileNet SSD | 14 | 104 | 124 |
| OpenCV Face Detector | 202 | 440 | 462 |
| OpenPose Pose MPI | 33 | 84 | 100 |
| Resnet 50 | 217 | 540 | 637 |
| SqueezeNet | 918 | 2228 | 2301 |
| VGG16 SSD | 42 | 52 | 43 |
| Yolo v3 | 55 | 82 | 81 |
| Yolo v2 | 82 | 96 | 93 |
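The images-per-second figures follow from the batch latencies in the tables above: throughput is the batch size divided by the time for one batch. A minimal sketch of that conversion is below; the function name `images_per_second` and the choice to round down are assumptions. Most entries in the table match it, but a few differ by one, so treat it as an approximation of how the table was produced.

```cpp
#include <cassert>
#include <cmath>

// Images per second for a batch of `batch_size` images that takes
// `latency_ms` milliseconds to process, rounded down to a whole number.
inline long images_per_second(int batch_size, double latency_ms)
{
    assert(batch_size > 0 && latency_ms > 0.0);
    return static_cast<long>(std::floor(batch_size * 1000.0 / latency_ms));
}
```

For example, the GoogLeNet latencies of 4.8824ms, 10.149ms and 90.3755ms for batch sizes 1, 10 and 128 give 204, 985 and 1416 images per second.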

OpenCV CUDA vs TensorFlow

GPU: NVIDIA GTX 1080 Ti (11 GB)

Batch of 1

| Model | OpenCV CUDA | TensorFlow |
| --- | --- | --- |
| ResNet-50 | 4.5968ms | 7.1163ms |
| EAST Text Detection | 5.901ms | 8.6890ms |

Batch of 10

| Model | OpenCV CUDA | TensorFlow |
| --- | --- | --- |
| ResNet-50 | 18.4139ms | 22.3665ms |
| EAST Text Detection | 32.1556ms | 39.4857ms |

Batch of 128

| Model | OpenCV CUDA | TensorFlow |
| --- | --- | --- |
| ResNet-50 | 200.789ms | 216.3923ms |
| EAST Text Detection | 376.458ms | 421.8292ms |

@sunilsomarajan

sunilsomarajan commented May 4, 2020

@YashasSamaga, some results for a single image inference (Mean of 10 runs) on Intel

Intel with 1080 GTX Ti : Mobilenet SSD V2

python opencv_mobile_ssd.py (4.3.0)
[INFO] setting preferable backend and target to CUDA...
1 0.99846137 person
64 0.46295404 potted plant
64 0.35649517 potted plant
72 0.735632 tv
73 0.60551673 laptop
84 0.8533891 book
84 0.71501225 book
84 0.55131793 book
Time per inference: 65.505838 ms
FPS: 15.265814842071778

python opencv_mobile_ssd.py (4.3.0+ opencv/opencv#16900)
[INFO] setting preferable backend and target to CUDA...
1 0.99846137 person
64 0.4629534 potted plant
64 0.35649383 potted plant
72 0.7356325 tv
73 0.6055173 laptop
84 0.85338897 book
84 0.7150124 book
84 0.55131817 book
Time per inference: 5.713272 ms
FPS: 175.0310476063297

@JulienMaille

I have had a question ever since I started working with inference on CUDA devices: is there a reason why the cuDNN DLL is so big?
It is a real pain to redistribute a 200+ MB file, especially when only a fraction of your users have a compatible GPU.
Is there a way to make it lighter? In comparison, the OpenVINO redistributable is much smaller.

@YashasSamaga
Author

@JulienMaille cuDNN is being broken into pieces in cuDNN 8.0. OpenCV cannot really do anything to make it lighter other than adding replacements for the services cuDNN provides (which would make cuDNN optional).

@JulienMaille

Thanks for your reply. Where can I find information about cuDNN 8.0?

@YashasSamaga
Author

YashasSamaga commented May 4, 2020

@JulienMaille It looks like they have removed it from their release notes page. They had posted release notes for the early-access cuDNN 8.0.x.x weeks ago. I can still find traces on Google (search for "cuDNN 8.0 early access"). When I went through the documentation, cuDNN appeared to be split into six libraries: three for inference and three for training. Each group had something like basic ops, CNN, and advanced (mostly RNN stuff).

@JulienMaille

Hi @YashasSamaga that's also what I found. Any news since the recent GTC show?

@ynioba

ynioba commented May 23, 2020

Hi @YashasSamaga, I have a question: I have had two GPUs since last week. How do I make the second GPU work? Thanks!

@YashasSamaga
Author

@ynioba

There are many ways to make use of multiple GPUs. Here is one which I think is the safest and the least complex solution. It makes use of the fact that the CUDA runtime library maintains a separate CUDA context for each CPU thread.

Suppose you have N devices.

  1. Create N threads.
  2. Assign a CUDA device to each thread by calling cudaSetDevice or cv::cuda::setDevice in that thread. Each thread is now associated with a device.
  3. You can create any number of cv::dnn::Net objects in any of those threads, and each network will use the device associated with its thread for memory and computation.

From opencv/opencv#14827
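The thread-per-device structure described in the steps above can be sketched as follows. To keep the sketch compilable without CUDA or OpenCV, the device binding and the per-thread work are passed in as callbacks: `bind_device` stands in for cudaSetDevice or cv::cuda::setDevice, and `work` is where each thread would create and run its own cv::dnn::Net objects. The name `run_on_devices` and both callback parameters are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <thread>
#include <vector>

// Launches one worker thread per device and waits for all of them.
// `bind_device(id)` is a placeholder for cudaSetDevice(id) or
// cv::cuda::setDevice(id); it is called first inside each thread so that
// everything the thread does afterwards is associated with that device.
inline void run_on_devices(int device_count,
                           const std::function<void(int)>& bind_device,
                           const std::function<void(int)>& work)
{
    assert(device_count > 0);
    std::vector<std::thread> threads;
    threads.reserve(device_count);
    for (int id = 0; id < device_count; id++)
    {
        threads.emplace_back([id, &bind_device, &work] {
            bind_device(id);  // associate this thread with device `id`
            work(id);         // create networks and run inference here
        });
    }
    for (auto& t : threads)
        t.join();
}
```

With real devices, `bind_device` would simply forward to `cv::cuda::setDevice`, and `work` would load a model, select the CUDA backend and target, and process that thread's share of the inputs.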

@ynioba

ynioba commented May 24, 2020

Thanks for your reply, the code works well. I wish you lasting happiness. Thanks again for your reply.

@BackT0TheFuture

BackT0TheFuture commented May 26, 2020

@YashasSamaga
Thanks for your great efforts.
What is an efficient way to use cv::dnn under multi-threaded conditions, e.g. when serving web requests?
Thanks!

@YashasSamaga
Author

@goodtogood I didn't understand your question. Most of the computation is carried out on GPU. You can have multiple instances of cv::dnn::Net for the same GPU. This allows you to extract more from each GPU by reducing GPU idle time (big improvements in some cases like in opencv/opencv#17365 (comment)).

@BackT0TheFuture

@YashasSamaga
Sorry for my unclear question.
I want to set up an inference service via a web app.
Model initialization is known to take a long time.
At first, I tried initializing only one dnn::Net object
and sharing it among threads, but it crashed.
It seems that it is not thread-safe.
Then I made a pool of dnn::Net objects
that are initialized in advance.
A free object is taken from the pool for each inference.
Is this approach correct? Is there a better way to handle this case?
Thanks a lot!

@YashasSamaga
Author

YashasSamaga commented May 26, 2020

@goodtogood That sounds reasonable. You can also try to do inferences in batches if you don't have tight latency requirements. You can have cv::dnn::Net objects initialized for single image inference, batch of two and maybe even four. The throughput increases dramatically as you increase the batch size.

Here are some stats for YOLOv4 on RTX 2080 Ti. The batched inference gives an almost 2x increase in FPS.

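| Input Size | Darknet FP16 (FPS) | OCV FP32 (FPS) | OCV FP16 (FPS) | OCV FP32, batch = 4 (FPS) | OCV FP16, batch = 4 (FPS) |
| --- | --- | --- | --- | --- | --- |
| 320 x 320 | 105.8 | 129.2 | 171.2 | 198 | 384 |
| 416 x 416 | 85.6 | 99.9 | 146 | 139.6 | 260.5 |
| 512 x 512 | 71.8 | 90.3 | 125.6 | 112.8 | 190.5 |
| 608 x 608 | 56.7 | 56 | 103.2 | 68.5 | 133 |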
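The pool of pre-initialized networks discussed above can be sketched as a small blocking object pool. This is an illustration under assumptions, not code from this thread: `ObjectPool` is a hypothetical class, and `T` would be cv::dnn::Net (already loaded and warmed up) in a real service. Each object is handed to at most one thread at a time, which sidesteps the crash seen when sharing a single network between threads.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical fixed-size pool of pre-initialized objects.
template <class T>
class ObjectPool
{
public:
    explicit ObjectPool(std::vector<std::unique_ptr<T>> objects)
        : free_(std::move(objects)) {}

    // Blocks until an object is free, then hands out exclusive ownership.
    std::unique_ptr<T> acquire()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        available_.wait(lock, [this] { return !free_.empty(); });
        std::unique_ptr<T> obj = std::move(free_.back());
        free_.pop_back();
        return obj;
    }

    // Returns an object to the pool and wakes one waiting thread.
    void release(std::unique_ptr<T> obj)
    {
        assert(obj != nullptr);
        {
            std::lock_guard<std::mutex> lock(mutex_);
            free_.push_back(std::move(obj));
        }
        available_.notify_one();
    }

    // Number of objects currently available.
    std::size_t free_count()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        return free_.size();
    }

private:
    std::mutex mutex_;
    std::condition_variable available_;
    std::vector<std::unique_ptr<T>> free_;
};
```

A request handler would call `acquire()`, run one forward pass on the returned object, and then `release()` it. Wrapping the pair in a small RAII guard would make the release exception-safe.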

@BackT0TheFuture

BackT0TheFuture commented May 26, 2020

@YashasSamaga
Thank you so much for the details.
As for FP16, I have tested it using OpenCV's DNN module,
with code copied from your gist link.
I didn't get a significant difference (GPU: RTX 2070S).

YOLOv4 608x608 batch=1
OCV FP32 22fps
OCV FP16 26fps

A little weird! Is this normal?

@YashasSamaga
Author

@goodtogood Can you share the exact code you used? 22 or 26 FPS seems too low for an RTX 2070S.

@YashasSamaga
Author

YashasSamaga commented May 26, 2020

@goodtogood Set nms_threshold=0 in all [yolo] blocks in yolov4.cfg. NMS is carried out on CPU and is very inefficient when done during the inference. It's best to disable NMS and perform it after the inference finishes. You will gain significant additional FPS. You can find example code here: https://gist.github.com/YashasSamaga/e2b19a6807a13046e399f4bc3cca3a49
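Performing NMS after the forward pass, as suggested above, means running a suppression step over the raw boxes yourself. Below is a minimal sketch of generic greedy NMS using only the standard library. It is not the code from the linked gist; the `Box` layout (top-left corner plus size), the function names and the IoU threshold are assumptions. For object detectors it is usually applied per class. OpenCV also ships `cv::dnn::NMSBoxes`, which serves the same purpose.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

struct Box { float x, y, w, h; };  // assumed layout: top-left corner plus size

// Intersection over union of two axis-aligned boxes.
inline float iou(const Box& a, const Box& b)
{
    const float x1 = std::max(a.x, b.x);
    const float y1 = std::max(a.y, b.y);
    const float x2 = std::min(a.x + a.w, b.x + b.w);
    const float y2 = std::min(a.y + a.h, b.y + b.h);
    const float inter = std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1);
    const float uni = a.w * a.h + b.w * b.h - inter;
    return uni > 0.0f ? inter / uni : 0.0f;
}

// Greedy NMS: visit boxes by descending score and keep a box only if it does
// not overlap an already kept box by more than `iou_threshold`.
// Returns the indices of the kept boxes.
inline std::vector<int> nms(const std::vector<Box>& boxes,
                            const std::vector<float>& scores,
                            float iou_threshold)
{
    assert(boxes.size() == scores.size());
    std::vector<int> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return scores[a] > scores[b]; });

    std::vector<int> keep;
    for (int idx : order)
    {
        bool suppressed = false;
        for (int kept : keep)
        {
            if (iou(boxes[idx], boxes[kept]) > iou_threshold)
            {
                suppressed = true;
                break;
            }
        }
        if (!suppressed)
            keep.push_back(idx);
    }
    return keep;
}
```

Filtering out low-confidence boxes before calling `nms` keeps the quadratic loop small, which matters because this step runs on the CPU.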

@BackT0TheFuture

BackT0TheFuture commented May 26, 2020

@YashasSamaga

Driver: 441.22 CUDA: 10.2 CUDNN: 7.6.5
Opencv commit 713577
Windows8.1 64bit VS2019 OCV4.3
It's almost copied from your code, just modified the benchmark num and backend.
thanks!

YOLO v4
[CUDA FP32]
init >> 1329.51ms
inference >> min = 45.596ms, max = 49.184ms, mean = 46.7278ms, stddev = 0.57918ms
[CUDA FP16]
init >> 865.449ms
inference >> min = 37.418ms, max = 43.093ms, mean = 39.4826ms, stddev = 1.24976ms

#include <iostream>
#include <algorithm>
#include <cassert>  // assert
#include <cmath>    // std::sqrt
#include <string>
#include <vector>
#include <chrono>
#include <numeric>

#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include "benchmark.hpp"

#define USE_RANDOM_IMAGES

constexpr auto default_batch_size = 1;

struct mask_type {
    int backend;
    int target;
};

struct config_type {
    std::string name;
    int backend;
    int target;
};

std::vector<config_type> backends = {
    //{"OCV CPU", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_CPU},
    //{"OCV OpenCL", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_OPENCL},
    //{"OCV OpenCL FP16", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_OPENCL_FP16},
    //{"IE CPU", cv::dnn::DNN_BACKEND_INFERENCE_ENGINE, cv::dnn::DNN_TARGET_CPU},

    {"CUDA FP32", cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_TARGET_CUDA},
    {"CUDA FP16", cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_TARGET_CUDA_FP16}
};

std::vector<cv::Mat> image_samples;

template <class T>
auto to_milliseconds(const T& duration) {
    return std::chrono::duration_cast<std::chrono::milliseconds>(duration);
}

template <class T>
auto to_microseconds(const T& duration) {
    return std::chrono::duration_cast<std::chrono::microseconds>(duration);
}

struct perf_result_t
{
    using duration = std::chrono::microseconds;

    duration init_time;
    std::vector<duration> runtimes;
};

template <std::size_t BENCHMARK_RUNS, std::size_t WARMUP_RUNS>
auto run_network(
    const std::string& model, const std::string& config,
    const cv::Mat& blob,
    const std::vector<std::string>& output_names_,
    int backend, int target)
{
    auto net = cv::dnn::readNet(model, config);
    net.setPreferableBackend(backend);
    net.setPreferableTarget(target);

    auto output_names = output_names_;
    if (output_names.empty())
        output_names = net.getUnconnectedOutLayersNames();

    std::vector<cv::Mat> output_mats;
    auto init_time = benchmark([&] {
        net.setInput(blob);
        net.forward(output_mats, output_names);
        });

    for (int i = 0; i < WARMUP_RUNS; i++)
    {
        net.setInput(blob);
        net.forward(output_mats, output_names);
    }

    perf_result_t result;
    result.init_time = init_time;
    result.runtimes.reserve(BENCHMARK_RUNS);

    for (int i = 0; i < BENCHMARK_RUNS; i++)
    {
        net.setInput(blob);
        auto inference_time = benchmark([&] {
            net.forward(output_mats, output_names);
            });

        result.runtimes.push_back(inference_time);
    }

    return result;
}

void bench_network(
    const std::string& model, const std::string& config,
    cv::Size input_size,
    const std::vector<std::string>& output_names = {},
    int count = default_batch_size,
    std::vector<mask_type> mask = {})
{
#ifndef USE_RANDOM_IMAGES
    assert(count <= image_samples.size());
#endif

    std::vector<cv::Mat> images;
    for (int i = 0; i < count; i++)
    {
#ifdef USE_RANDOM_IMAGES
        cv::Mat image(input_size, CV_32FC3);
        cv::randu(image, cv::Scalar(0, 0, 0), cv::Scalar(255, 255, 255));
        images.push_back(image);
#else
        images.push_back(image_samples[i]);
#endif
    }

    cv::Mat blob = cv::dnn::blobFromImages(images, 1.0f, input_size, 0.0f);

    for (auto c : backends) {
        auto backend = c.backend;
        auto target = c.target;

        bool skip = [backend, target, mask] {
            for (auto m : mask) {
                if (m.backend == backend && m.target == target)
                    return true;
                if (m.backend == backend && m.target == -1)
                    return true;
                if (m.backend == -1 && m.target == target)
                    return true;
            }

            return false;
        } ();

        if (skip)
            continue;

        try {
            constexpr int WARMUP_RUNS = 3;
            constexpr int BENCHMARK_RUNS = 400;
            auto result = run_network<BENCHMARK_RUNS, WARMUP_RUNS>(model, config, blob, output_names, backend, target);

            float init_time = to_microseconds(result.init_time).count() / 1000.0;

            std::vector<float> runtimes;
            for (auto r : result.runtimes)
                runtimes.push_back(to_microseconds(r).count() / 1000.0);

            auto sum = std::accumulate(std::begin(runtimes), std::end(runtimes), 0.0f);
            auto squared_sum = std::inner_product(std::begin(runtimes), std::end(runtimes), std::begin(runtimes), 0.0f);

            auto min = *std::min_element(std::begin(runtimes), std::end(runtimes));
            auto max = *std::max_element(std::begin(runtimes), std::end(runtimes));
            auto mean = sum / runtimes.size();
            auto stddev = std::sqrt(squared_sum / runtimes.size() - mean * mean);

            std::cout << '[' << c.name << "]" << '\n'
                << "\tinit >> " << init_time << "ms" << '\n'
                << "\tinference >> " << "min = " << min << "ms, max = " << max << "ms, mean = " << mean << "ms, stddev = " << stddev << "ms" << std::endl;
        }
        catch (const std::exception& ex) {
            std::cout << ex.what() << std::endl;
            return;
        }
    }

    std::cout << std::endl;
}


void bench_yolo_v4()
{
    std::cout << "YOLO v4\n";
    bench_network("./yolov4.cfg", "./yolov4.weights", cv::Size(608, 608));
    std::cout << std::endl;
}


int main()
{    
    bench_yolo_v4();
    return 0;
}

@BackT0TheFuture

BackT0TheFuture commented May 26, 2020

> @goodtogood Set nms_threshold=0 in all [yolo] blocks in yolov4.cfg. NMS is carried out on CPU and is very inefficient when done during the inference. It's best to disable NMS and perform it after the inference finishes. You will gain significant additional FPS. You can find example code here: https://gist.github.com/YashasSamaga/e2b19a6807a13046e399f4bc3cca3a49

YOLO v4
[CUDA FP32]
        init >> 1245.76ms
        inference >> min = 29.934ms, max = 31.181ms, mean = 30.3622ms, stddev = 0.207436ms
[CUDA FP16]
        init >> 876.087ms
        inference >> min = 22.916ms, max = 28.212ms, mean = 24.5076ms, stddev = 1.09143ms

after setting nms_threshold=0 of three [yolo] layers,
the performance improved a lot.
thank you so much for your patient explanation.

@JulienMaille

JulienMaille commented Jun 11, 2020

@YashasSamaga I missed the release of cudnn 8.0.0 RC a week ago.
It seems there are still huge inference DLLs. Not sure which ones we need. Have you tried compiling OpenCV with this release?

@matt-sharp

> @goodtogood That sounds reasonable. You can also try to do inferences in batches if you don't have tight latency requirements. You can have cv::dnn::Net objects initialized for single image inference, batch of two and maybe even four. The throughput increases dramatically as you increase the batch size.

How do we control the batch size? Is it determined by what we feed into blobFromImages?

@YashasSamaga
Author

> How do we control the batch size? Is this what we feed into blobfromimages?

@matt-sharp Yes. Note that changing input shapes will cause reinitialization (which is time-consuming). Fix a batch size and use it throughout (and in case you have just three images but initialized for a batch size of four, pad a dummy zero image and make a batch of four to avoid reinitialization).

You can also initialize multiple networks for different batch sizes if your GPU memory permits. You can have one net object for single image inference, another for a batch of four, and another for eight. You can use all networks simultaneously (if you want) and use a smaller batch size when you do not have enough jobs to populate the bigger batches.
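The padding trick described above (fill a short batch with dummy zero images so the input shape never changes) can be sketched as follows. The sketch uses a plain `std::vector<float>` as a stand-in for a preprocessed image so that it compiles without OpenCV; `Image`, `pad_batch` and `pixels_per_image` are hypothetical names. With OpenCV, the padding entries would be zero-filled cv::Mat objects of the input size, added before calling blobFromImages.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Image = std::vector<float>;  // stand-in for a preprocessed cv::Mat

// Pads `images` with all-zero images until it holds exactly `batch_size`
// entries. Returns the number of real images so that the outputs produced
// for the padding can be ignored after the forward pass.
inline std::size_t pad_batch(std::vector<Image>& images,
                             std::size_t batch_size,
                             std::size_t pixels_per_image)
{
    assert(!images.empty() && images.size() <= batch_size);
    const std::size_t real_count = images.size();
    images.resize(batch_size, Image(pixels_per_image, 0.0f));
    return real_count;
}
```

After the forward pass, only the first `real_count` entries of each output correspond to actual inputs; the rest belong to the dummy images.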

@matt-sharp

@YashasSamaga are there any guidelines for the optimal batch size? Which backend DNN_TARGET_CUDA_FP16 or DNN_TARGET_CUDA_FP32?
I'm using 1 x Tesla v100, CUDA 11.2, CuDNN 7.6.5, YOLOv4, image size 608 x 608.

Is there any benefit to initializing multiple networks for the same model and running in parallel?
Also, is it possible to run batch inference with the high level Detection Model API since we don't feed a blob into this?

@YashasSamaga
Author

YashasSamaga commented Apr 9, 2021

> are there any guidelines for the optimal batch size?

@matt-sharp The performance varies across devices. Latency increases with batch size, and so does throughput. The batch size you use largely depends on your latency requirements: it's a tradeoff between latency and throughput. I'd recommend trying out different batch sizes and choosing one that provides a substantial improvement over the next smaller batch size. Throughput generally increases with batch size, but with diminishing returns.
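One way to turn this advice into a rule of thumb is to measure the latency for a few batch sizes and stop growing the batch once the throughput gain over the previous size drops below a chosen factor. The sketch below is only an illustration of that heuristic: `Measurement`, `pick_batch_size` and the `min_gain` threshold are assumptions, and a real selection would also cap the batch size by the acceptable latency.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Measurement { int batch_size; double latency_ms; };

// Walks measurements sorted by ascending batch size and stops at the first
// batch size whose throughput is less than `min_gain` times that of the
// previous one. Returns the last batch size that was still worth the step.
inline int pick_batch_size(const std::vector<Measurement>& sorted, double min_gain)
{
    assert(!sorted.empty() && min_gain >= 1.0);
    int best = sorted[0].batch_size;
    double prev = sorted[0].batch_size / sorted[0].latency_ms;  // images per ms
    for (std::size_t i = 1; i < sorted.size(); i++)
    {
        const double cur = sorted[i].batch_size / sorted[i].latency_ms;
        if (cur < min_gain * prev)
            break;  // diminishing returns: not enough gain for the added latency
        best = sorted[i].batch_size;
        prev = cur;
    }
    return best;
}
```

With the GoogLeNet numbers from the GTX 1080 Ti tables in this gist (4.8824ms, 10.149ms and 90.3755ms for batch sizes 1, 10 and 128), requiring at least a 1.5x gain per step selects a batch size of 10.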

> Which backend DNN_TARGET_CUDA_FP16 or DNN_TARGET_CUDA_FP32?

FP16 works great for YOLOv4 with practically no loss in detection performance. Your GPU supports FP16 target and I recommend it.

> Is there any benefit to initializing multiple networks for the same model and running in parallel?

Yes, the GPU is used during inference but stays idle during pre/postprocessing on the CPU. If you use multiple threads with one network initialized per thread, the GPU inference workload from one thread will keep the GPU busy while another thread is busy doing the pre/postprocessing. You can also pipeline your entire process into stages: preprocessing, DNN inference, and postprocessing. With a pipeline, you can keep all stages busy by processing a different frame in each stage. For example, you will be preprocessing the next frame while the GPU is computing the forward pass for the current frame, and simultaneously you will be postprocessing the previous frame. This increases throughput.
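The three-stage pipeline described above can be sketched with one thread per stage and a queue between neighbouring stages. This is a minimal standard-library illustration, not code from the CUDA backend: `Channel` and `run_pipeline` are hypothetical names, the queues are unbounded for brevity (a real pipeline would bound them), and `T` must be default-constructible. The `infer` callable is where the forward pass would run.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

// Minimal blocking queue used to hand items from one stage to the next.
template <class T>
class Channel
{
public:
    void push(T value)
    {
        { std::lock_guard<std::mutex> lock(mutex_); items_.push(std::move(value)); }
        ready_.notify_one();
    }

    // Signals that no more items will be pushed.
    void close()
    {
        { std::lock_guard<std::mutex> lock(mutex_); closed_ = true; }
        ready_.notify_all();
    }

    // Blocks for the next item; returns false once closed and drained.
    bool pop(T& out)
    {
        std::unique_lock<std::mutex> lock(mutex_);
        ready_.wait(lock, [this] { return closed_ || !items_.empty(); });
        if (items_.empty())
            return false;
        out = std::move(items_.front());
        items_.pop();
        return true;
    }

private:
    std::mutex mutex_;
    std::condition_variable ready_;
    std::queue<T> items_;
    bool closed_ = false;
};

// Runs pre -> infer -> post as three concurrent stages, so each stage can
// work on a different frame at the same time. Output order matches input
// order because every stage is a single thread.
template <class T, class Pre, class Infer, class Post>
std::vector<T> run_pipeline(const std::vector<T>& frames, Pre pre, Infer infer, Post post)
{
    Channel<T> to_infer, to_post;
    std::vector<T> results;

    std::thread pre_stage([&] {
        for (const T& frame : frames) to_infer.push(pre(frame));
        to_infer.close();
    });
    std::thread infer_stage([&] {
        T item;
        while (to_infer.pop(item)) to_post.push(infer(std::move(item)));
        to_post.close();
    });
    std::thread post_stage([&] {
        T item;
        while (to_post.pop(item)) results.push_back(post(std::move(item)));
    });

    pre_stage.join();
    infer_stage.join();
    post_stage.join();
    return results;
}
```

Only the inference stage touches the network here, so a single cv::dnn::Net would never be used from two threads at once.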

If you're initializing networks with large batch sizes, say 32, then most of the computation is wasted if you have to perform inference on just one image. But reinitializing the network to work with a single image would cost much more time. The idea of initializing multiple networks with different batch sizes is to let you improve latency when you cannot fill an entire batch. If you have just 3 images, you can use the network initialized to process four images, and if you later have to process 8 images, you can use the network initialized to work on batches of eight.
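Choosing which pre-initialized network to use for a given number of pending images comes down to finding the smallest initialized batch size that fits. A minimal sketch, assuming the initialized batch sizes are kept in ascending order (`choose_batch_size` is a hypothetical helper):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// `initialized` lists the batch sizes for which a network has been set up,
// in ascending order. Returns the smallest one that fits `pending` images,
// or the largest one if none fits (the caller then runs several batches).
inline int choose_batch_size(const std::vector<int>& initialized, int pending)
{
    assert(!initialized.empty() && pending > 0);
    assert(std::is_sorted(initialized.begin(), initialized.end()));
    auto it = std::lower_bound(initialized.begin(), initialized.end(), pending);
    return it != initialized.end() ? *it : initialized.back();
}
```

With networks initialized for batch sizes 1, 4 and 8, three pending images map to the batch-of-four network (padded with one dummy image) and eight pending images map to the batch-of-eight network.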

> Also, is it possible to run batch inference with the high level Detection Model API since we don't feed a blob into this?

Batching is not supported in the high level model API. You can track the feature request here: opencv/opencv#17838

@matt-sharp

@YashasSamaga thanks for your reply.

Please can I confirm that the Detection Model API returns bounding box co-ordinates in the form absolute(top, left, width, height) whereas net.forward returns absolute(centre x, centre y, width, height)?

@fengyuentau

> The benchmarks I posted are for MobileNetSSD_deploy.prototxt/MobileNetSSD_deploy.caffemodel which you can find here.
>
> MobileNet is slow with the CUDA backend because of depthwise convolutions. The CUDA backend fully relies on cuDNN for convolutions and cuDNN is very bad at depthwise convolutions.

Has the poor depthwise convolution performance been fixed in cuDNN by now? I just tested a DNN model with depthwise convolutions on a Jetson Nano B01 with CUDA 10.2 and cuDNN 8.2.1. The results turned out to be slower than on the Jetson Nano's SoC (Arm-based, 4-core, 1.5 GHz) and much slower than on the Raspberry Pi 4B's SoC (4-core, 1.5 GHz).

I also tested a model without depthwise convolutions, which seems to meet expectations (~10x faster than the Jetson Nano SoC).

@Algabri

Algabri commented Mar 10, 2022

Do you have an academic paper for these results?
I would like to cite.

@YashasSamaga
Author

@Algabri No, there is no paper that is specific to this project.

@Algabri

Algabri commented Mar 11, 2022

@YashasSamaga , Thanks for your reply.
