This gist documents the Google Summer of Code project. It is no longer updated and hence does not reflect the current status of the CUDA backend.
For updated details, please see this gist.
Student: Yashas Samaga B L
Mentor: Davis King
Project Link: https://summerofcode.withgoogle.com/projects/#6021087400296448
Relevant PRs:
OpenCV's DNN module has blazing-fast inference capability on CPUs. It supports inference on GPUs using OpenCL but lacks a CUDA backend. NVIDIA GPUs support OpenCL, but OpenCL does not expose their full capabilities.
This project adds a new CUDA backend that can perform lightning-fast inference on NVIDIA GPUs.
The CUDA backend requires CUDA Toolkit and cuDNN (min: 7.5.0) to be installed on the system. The CMake scripts will automatically detect the dependencies when the following options are set:
WITH_CUDA
WITH_CUDNN
The CUDA backend is enabled by setting the following option:
OPENCV_DNN_CUDA
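For reference, a typical configure-and-build invocation might look like the sketch below; the sibling source/build directory layout and the -j8 parallelism are assumptions, and only the three options above are specific to the CUDA backend.

```sh
# run from an empty build directory next to the opencv source tree (assumed layout)
cmake -DWITH_CUDA=ON \
      -DWITH_CUDNN=ON \
      -DOPENCV_DNN_CUDA=ON \
      ../opencv
cmake --build . -j8
```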
After building, run [build dir]/bin/opencv_test_dnn and [build dir]/bin/opencv_perf_dnn.
The project adds the following new backends and targets to the existing list.
Backend | Target |
---|---|
DNN_BACKEND_CUDA | DNN_TARGET_CUDA |
DNN_BACKEND_CUDA | DNN_TARGET_CUDA_FP16 |
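For illustration, here is a minimal sketch of selecting the new backend and target from user code; the model files are placeholders, not from this project.

```cpp
#include <opencv2/dnn.hpp>

int main()
{
    // placeholder model files; any network readable by cv::dnn::readNet works
    cv::dnn::Net net = cv::dnn::readNet("model.weights", "model.cfg");

    // request the new CUDA backend; DNN_TARGET_CUDA_FP16 can be used instead
    // on GPUs with fast FP16 support
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);
    return 0;
}
```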
The CUDA backend uses OpenCV's CPU backend as a fallback for unsupported layers and partially supported layers with unsupported configurations.
Symbol | Meaning |
---|---|
✔️ | fully supported |
🔵 | partially supported |
❌ | unsupported |
Layer | Status |
---|---|
Activations | ✔️ |
Batch Normalization | ✔️ |
Blank Layer | ✔️ |
Concat Layer | ✔️ |
Const Layer | ✔️ |
Convolution 2d | ✔️ |
Convolution 3d | ✔️ |
Crop and resize | ❌ |
Crop Layer | ✔️ |
Detection Output Layer | ❌ |
Deconvolution 2d | 🔵 (most configurations supported) |
Deconvolution 3d | 🔵 (most configurations supported) |
Elementwise Layers | ✔️ |
Eltwise Layer | ✔️ |
Flatten Layer | ✔️ |
Fully Connected Layer | ✔️ |
Input Layer | ❌ |
Interp Layer | ✔️ |
Local Response Normalization | ✔️ |
Max Unpooling 2d | ✔️ |
Max Unpooling 3d | ✔️ |
MVN Layer | ❌ |
Normalize Layer | 🔵 (L1 and L2 supported) |
Padding Layer | ✔️ |
Permute Layer | ✔️ |
Pooling 2d | 🔵 (max and average supported) |
Pooling 3d | 🔵 (max and average supported) |
Prior Box Layer | ✔️ |
Proposal Layer | ❌ |
Region Layer | ✔️ |
Reorg Layer | ✔️ |
Reshape Layer | ✔️ |
Resize Layer | ✔️ |
Scale Layer | ✔️ |
Shift Layer | ✔️ |
Shuffle Channel Layer | ✔️ |
Slice Layer | ✔️ |
Softmax Layer | ✔️ |
Split Layer | ✔️ |
LSTM Layer | ❌ |
CPU: i7 7700HQ
GPU: NVIDIA GTX 1050 Mobile
CPU BLAS Library: MKL 2019.0.4
CUDA Version: 10.1
cuDNN: 7.6.2
Warmup Runs: 3 (forward pass is performed three times before benchmarks)
Benchmark Runs: 10 (the average of ten forward passes is reported)
Test Code: https://gist.github.com/YashasSamaga/71157cf0c3768c497e5e70fb95435596
Model | CUDA FP32 | Inference Engine CPU | OpenCV CPU |
---|---|---|---|
GoogLeNet | 7.2447ms | 10.4981ms | 17.9176ms |
DenseNet121 | 12.6324ms | 19.1823ms | 48.0628ms |
EAST Text Detection | 18.8281ms | 49.0508ms | 88.9429ms |
ENet | 11.5014ms | Exception | 62.5854ms |
FastNeuralStyle StaryNight | 27.498ms | 178.309ms | 160.359ms |
Inception 5h | 7.8546ms | 22.2789ms | 20.3255ms |
Inception v2 FasterRCNN | 112.736ms | Exception | 374.26ms |
MobileNet SSD | 58.4751ms | 9.2896ms | 27.3061ms |
OpenCV Face Detector | 6.9831ms | 8.3981ms | 17.6683ms |
OpenPose Pose MPI | 160.561ms | 509.446ms | 838.161ms |
Resnet 50 | 11.3603ms | 28.1529ms | 50.2752ms |
SqueezeNet | 2.4084ms | 3.2918ms | 5.476ms |
VGG16 SSD | 70.4117ms | 249.725ms | 360.207ms |
Yolo v3 | 57.9822ms | 214.629ms | 296.806ms |
Yolo v2 | 51.5784ms | 193.453ms | 260.19ms |
Model | CUDA FP32 | Inference Engine CPU | OpenCV CPU |
---|---|---|---|
GoogLeNet | 35.7556ms | 108.946ms | 225.928ms |
DenseNet121 | 74.9241ms | 295.105ms | 650.924ms |
EAST Text Detection | 149.58ms | 536.946ms | 1273.93ms |
FastNeuralStyle StaryNight | 283.173ms | 1966.5ms | 2175.3ms |
Inception 5h | 36.6225ms | 180.429ms | 233.276ms |
MobileNet SSD | 277.753ms | 111.872ms | 316.063ms |
OpenCV Face Detector | 52.4366ms | 95.7866ms | 202.657ms |
OpenPose Pose MPI | 628.617ms | 5650.05ms | 10683.5ms |
Resnet 50 | 74.283ms | 230.817ms | 541.308ms |
SqueezeNet | 15.8144ms | 35.4915ms | 69.4122ms |
VGG16 SSD | 594.286ms | 2796.23ms | 4661.51ms |
Yolo v3 | 488.704ms | 2419.8ms | 4209.74ms |
Yolo v2 | 491.414ms | 2185.47ms | 3788.34ms |
CPU: 2x Intel Xeon E5-2640 v4
GPU: 1x NVIDIA GTX 1080 Ti (11 GB)
CPU BLAS Library: OpenBLAS 0.2.20
CUDA Version: 10.0
cuDNN: 7.6.2
Warmup Runs: 3 (forward pass is performed three times before benchmarks)
Benchmark Runs: 10 (the average of ten forward passes is reported)
Test Code: https://gist.github.com/YashasSamaga/71157cf0c3768c497e5e70fb95435596
Model | CUDA FP32 | OpenCV CPU |
---|---|---|
GoogLeNet | 4.8824ms | 14.2981ms |
DenseNet121 | 6.4555ms | 57.8244ms |
EAST Text Detection | 5.901ms | 67.4301ms |
ENet | 4.5979ms | 30.2767ms |
FastNeuralStyle StaryNight | 5.3193ms | 51.3313ms |
Inception 5h | 4.9487ms | 16.0048ms |
Inception v2 FasterRCNN | 82.0298ms | 179.245ms |
MobileNet SSD | 70.9177ms | 23.9348ms |
OpenCV Face Detector | 4.9288ms | 15.4205ms |
OpenPose Pose MPI | 30.5954ms | 246.747ms |
Resnet 50 | 4.5968ms | 45.1153ms |
SqueezeNet | 1.0888ms | 3.6492ms |
VGG16 SSD | 23.5926ms | 194.976ms |
Yolo v3 | 18.0002ms | 141.861ms |
Yolo v2 | 12.1279ms | 111.642ms |
Model | CUDA FP32 | OpenCV CPU |
---|---|---|
GoogLeNet | 10.149ms | 75.9591ms |
DenseNet121 | 20.269ms | 312.426ms |
EAST Text Detection | 32.1556ms | 402.16ms |
FastNeuralStyle StaryNight | 49.1025ms | 461.095ms |
Inception 5h | 9.9721ms | 67.9308ms |
MobileNet SSD | 96.2898ms | 110.783ms |
OpenCV Face Detector | 22.7501ms | 77.8742ms |
OpenPose Pose MPI | 118.858ms | 2321.89ms |
Resnet 50 | 18.4139ms | 229.599ms |
SqueezeNet | 4.4893ms | 22.3049ms |
VGG16 SSD | 194.181ms | 1319.67ms |
Yolo v3 | 122.603ms | 1044.11ms |
Yolo v2 | 104.072ms | 819.177ms |
Model | CUDA FP32 | OpenCV CPU |
---|---|---|
GoogLeNet | 90.3755ms | 775.769ms |
DenseNet121 | 199.516ms | 3536.38ms |
EAST Text Detection | 376.458ms | 7685.72ms |
FastNeuralStyle StaryNight | 801.778ms | 6607.15ms |
Inception 5h | 93.4188ms | 771.575ms |
MobileNet SSD | 1028.93ms | 1110.37ms |
OpenCV Face Detector | 276.992ms | 977.997ms |
OpenPose Pose MPI | 1279.26ms | 32159.3ms |
Resnet 50 | 200.789ms | 1719.92ms |
SqueezeNet | 55.6244ms | 255.397ms |
VGG16 SSD | 2969.05ms | 17201ms |
Yolo v3 | 1564.78ms | 13699.2ms |
Yolo v2 | 1362.84ms | 11254.9ms |
Model | FPS (batch size = 1) | FPS (batch size = 10) | FPS (batch size = 128) |
---|---|---|---|
GoogLeNet | 204 | 985 | 1416 |
DenseNet121 | 154 | 493 | 641 |
EAST Text Detection | 169 | 311 | 340 |
ENet | 217 | Not Applicable | Not Applicable |
FastNeuralStyle StaryNight | 188 | 204 | 160 |
Inception 5h | 202 | 1002 | 1370 |
Inception v2 FasterRCNN | 12 | Not Applicable | Not Applicable |
MobileNet SSD | 14 | 104 | 124 |
OpenCV Face Detector | 202 | 440 | 462 |
OpenPose Pose MPI | 33 | 84 | 100 |
Resnet 50 | 217 | 540 | 637 |
SqueezeNet | 918 | 2228 | 2301 |
VGG16 SSD | 42 | 52 | 43 |
Yolo v3 | 55 | 82 | 81 |
Yolo v2 | 82 | 96 | 93 |
GPU: NVIDIA GTX 1080 Ti (11 GB)
Model | OpenCV CUDA | TensorFlow |
---|---|---|
ResNet-50 | 4.5968ms | 7.1163ms |
EAST Text Detection | 5.901ms | 8.6890ms |
Model | OpenCV CUDA | TensorFlow |
---|---|---|
ResNet-50 | 18.4139ms | 22.3665ms |
EAST Text Detection | 32.1556ms | 39.4857ms |
Model | OpenCV CUDA | TensorFlow |
---|---|---|
ResNet-50 | 200.789ms | 216.3923ms |
EAST Text Detection | 376.458ms | 421.8292ms |
@YashasSamaga, some results for single-image inference (mean of 10 runs) on an Intel CPU with a GTX 1080 Ti: MobileNet SSD v2
python opencv_mobile_ssd.py (4.3.0)
[INFO] setting preferable backend and target to CUDA...
1 0.99846137 person
64 0.46295404 potted plant
64 0.35649517 potted plant
72 0.735632 tv
73 0.60551673 laptop
84 0.8533891 book
84 0.71501225 book
84 0.55131793 book
Time per inference: 65.505838 ms
FPS: 15.265814842071778
python opencv_mobile_ssd.py (4.3.0+ opencv/opencv#16900)
[INFO] setting preferable backend and target to CUDA...
1 0.99846137 person
64 0.4629534 potted plant
64 0.35649383 potted plant
72 0.7356325 tv
73 0.6055173 laptop
84 0.85338897 book
84 0.7150124 book
84 0.55131817 book
Time per inference: 5.713272 ms
FPS: 175.0310476063297
I have a question since I started working with inference on CUDA devices: is there a reason why the cuDNN DLL is so big?
It is a real pain to redistribute 200+ MB, especially when only a fraction of your users have a compatible GPU.
Is there a way to make it lighter? In comparison, the OpenVINO redistributable is much smaller.
@JulienMaile cuDNN is being broken into pieces in cuDNN 8.0. OpenCV cannot really do anything to make it lighter other than adding replacements for the services cuDNN provides (which would make cuDNN optional).
Thanks for your reply. Where can I find information about cuDNN 8.0?
@JulienMaile Looks like they have removed it from their release notes page. They had posted release notes for the early-access cuDNN 8.0.x.x weeks ago. I can find traces on Google (search for "cuDNN 8.0 early access"). When I went through the documentation, they seemed to have split cuDNN into six libraries: three for inference and three for training. Each category had something like basic ops, CNN, and advanced (mostly RNN stuff).
Hi @YashasSamaga that's also what I found. Any news since the recent GTC show?
Hi @YashasSamaga, I have a question since I got a second GPU last week. How do I make the other GPU work? Thanks!
There are many ways to make use of multiple GPUs. Here is one which I think is the safest and the least complex solution. It makes use of the fact that the CUDA runtime library maintains a separate CUDA context for each CPU thread.
Suppose you have N devices.
Create N threads.
Assign a CUDA device to each thread by calling cudaSetDevice or cv::cuda::setDevice in that thread. Each thread is now associated with a device.
You can create any number of cv::dnn::Net objects in any of those threads and the network will use the device associated with that thread for memory and computation.
From opencv/opencv#14827
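A minimal sketch of this thread-per-device approach, assuming placeholder model files and a hypothetical run_on_device worker:

```cpp
#include <opencv2/dnn.hpp>
#include <opencv2/core/cuda.hpp>
#include <thread>
#include <vector>

// one worker per GPU; the CUDA runtime keeps a separate context per CPU thread
void run_on_device(int device_id)
{
    cv::cuda::setDevice(device_id);  // bind this thread to one device

    cv::dnn::Net net = cv::dnn::readNet("model.weights", "model.cfg");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);

    // feed inputs with net.setInput(...) and call net.forward(...) here;
    // memory and computation stay on the device selected above
}

int main()
{
    const int num_devices = cv::cuda::getCudaEnabledDeviceCount();
    std::vector<std::thread> workers;
    for (int i = 0; i < num_devices; i++)
        workers.emplace_back(run_on_device, i);
    for (auto& worker : workers)
        worker.join();
    return 0;
}
```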
Thanks for your reply, the code works well. I wish you all the best; thanks again for your reply.
@YashasSamaga
thanks for your great efforts.
what's the efficient way to use cv::dnn under multi-threaded conditions? e.g. serving web requests
thanks!
@goodtogood I didn't understand your question. Most of the computation is carried out on the GPU. You can have multiple instances of cv::dnn::Net for the same GPU. This allows you to extract more from each GPU by reducing GPU idle time (big improvements in some cases, as in opencv/opencv#17365 (comment)).
@YashasSamaga
sorry for my unclear question.
Actually, I want to set up an inference service via a web app.
It is known that model initialization often costs a lot of time.
At first, I tried initializing only one dnn::Net object and sharing it among threads, but it crashed; it seems it is not thread-safe.
Then I made a pool of dnn::Net objects initialized in advance; a free object is taken from the pool for each inference.
Is this approach correct? Is there a better way to handle this case?
thanks a lot!
@goodtogood That sounds reasonable. You can also try to do inferences in batches if you don't have tight latency requirements. You can have cv::dnn::Net
objects initialized for single image inference, batch of two and maybe even four. The throughput increases dramatically as you increase the batch size.
Here are some stats for YOLOv4 on RTX 2080 Ti. The batched inference gives an almost 2x increase in FPS.
Input Size | Darknet FP16 FPS | OCV FP32 FPS | OCV FP16 FPS | OCV FP32 FPS (batch = 4) | OCV FP16 FPS (batch = 4) |
---|---|---|---|---|---|
320 x 320 | 105.8 | 129.2 | 171.2 | 198 | 384 |
416 x 416 | 85.6 | 99.9 | 146 | 139.6 | 260.5 |
512 x 512 | 71.8 | 90.3 | 125.6 | 112.8 | 190.5 |
608 x 608 | 56.7 | 56 | 103.2 | 68.5 | 133 |
@YashasSamaga
thank you so much for the details.
As for FP16, I've tested it using OpenCV's dnn with code copied from your gist link, but I didn't get a significant difference (GPU: RTX 2070S).
YOLOv4 608x608 batch=1
OCV FP32 22fps
OCV FP16 26fps
a little weird! Is it normal ?
@goodtogood Can you share the exact code you used? 22 or 26 FPS seems too low for an RTX 2070S.
@goodtogood Set nms_threshold=0 in all [yolo] blocks in yolov4.cfg. NMS is carried out on the CPU and is very inefficient when done during inference. It's best to disable NMS and perform it after the inference finishes. You will gain significant additional FPS. You can find example code here: https://gist.github.com/YashasSamaga/e2b19a6807a13046e399f4bc3cca3a49
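As a rough sketch of the post-inference step (the decoding of the raw [yolo] outputs into boxes and confidences is omitted; the linked gist has a complete example), cv::dnn::NMSBoxes can be applied once all detections have been collected. The function name and thresholds below are illustrative.

```cpp
#include <opencv2/dnn.hpp>
#include <vector>

// boxes/confidences are assumed to have been decoded from the raw network outputs
std::vector<int> keep_after_nms(const std::vector<cv::Rect>& boxes,
                                const std::vector<float>& confidences,
                                float conf_threshold = 0.25f,
                                float nms_threshold = 0.4f)
{
    std::vector<int> kept;
    cv::dnn::NMSBoxes(boxes, confidences, conf_threshold, nms_threshold, kept);
    return kept;  // indices of the boxes that survive NMS
}
```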
Driver: 441.22, CUDA: 10.2, cuDNN: 7.6.5
OpenCV commit 713577
Windows 8.1 64-bit, VS2019, OCV 4.3
It's almost copied from your code; I just modified the benchmark run count and the backend list.
thanks!
YOLO v4
[CUDA FP32]
init >> 1329.51ms
inference >> min = 45.596ms, max = 49.184ms, mean = 46.7278ms, stddev = 0.57918ms
[CUDA FP16]
init >> 865.449ms
inference >> min = 37.418ms, max = 43.093ms, mean = 39.4826ms, stddev = 1.24976ms
```cpp
#include <iostream>
#include <algorithm>
#include <vector>
#include <chrono>
#include <numeric>
#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include "benchmark.hpp"
#define USE_RANDOM_IMAGES
constexpr auto default_batch_size = 1;
struct mask_type {
int backend;
int target;
};
struct config_type {
std::string name;
int backend;
int target;
};
std::vector<config_type> backends = {
//{"OCV CPU", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_CPU},
//{"OCV OpenCL", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_OPENCL},
//{"OCV OpenCL FP16", cv::dnn::DNN_BACKEND_OPENCV, cv::dnn::DNN_TARGET_OPENCL_FP16},
//{"IE CPU", cv::dnn::DNN_BACKEND_INFERENCE_ENGINE, cv::dnn::DNN_TARGET_CPU},
{"CUDA FP32", cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_TARGET_CUDA},
{"CUDA FP16", cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_TARGET_CUDA_FP16}
};
std::vector<cv::Mat> image_samples;
template <class T>
auto to_milliseconds(const T& duration) {
return std::chrono::duration_cast<std::chrono::milliseconds>(duration);
}
template <class T>
auto to_microseconds(const T& duration) {
return std::chrono::duration_cast<std::chrono::microseconds>(duration);
}
struct perf_result_t
{
using duration = std::chrono::microseconds;
duration init_time;
std::vector<duration> runtimes;
};
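// Loads the network, measures the first forward pass (which includes backend
// initialization), performs WARMUP_RUNS untimed passes, then times
// BENCHMARK_RUNS individual forward passes.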
template <std::size_t BENCHMARK_RUNS, std::size_t WARMUP_RUNS>
auto run_network(
const std::string& model, const std::string& config,
const cv::Mat& blob,
const std::vector<std::string>& output_names_,
int backend, int target)
{
auto net = cv::dnn::readNet(model, config);
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
auto output_names = output_names_;
if (output_names.empty())
output_names = net.getUnconnectedOutLayersNames();
std::vector<cv::Mat> output_mats;
auto init_time = benchmark([&] {
net.setInput(blob);
net.forward(output_mats, output_names);
});
for (int i = 0; i < WARMUP_RUNS; i++)
{
net.setInput(blob);
net.forward(output_mats, output_names);
}
perf_result_t result;
result.init_time = init_time;
result.runtimes.reserve(BENCHMARK_RUNS);
for (int i = 0; i < BENCHMARK_RUNS; i++)
{
net.setInput(blob);
auto inference_time = benchmark([&] {
net.forward(output_mats, output_names);
});
result.runtimes.push_back(inference_time);
}
return result;
}
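// Builds an input blob (random images when USE_RANDOM_IMAGES is defined) and
// benchmarks every backend/target pair in `backends` that is not excluded by `mask`.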
void bench_network(
const std::string& model, const std::string& config,
cv::Size input_size,
const std::vector<std::string>& output_names = {},
int count = default_batch_size,
std::vector<mask_type> mask = {})
{
#ifndef USE_RANDOM_IMAGES
assert(count <= image_samples.size());
#endif
std::vector<cv::Mat> images;
for (int i = 0; i < count; i++)
{
#ifdef USE_RANDOM_IMAGES
cv::Mat image(input_size, CV_32FC3);
cv::randu(image, cv::Scalar(0, 0, 0), cv::Scalar(255, 255, 255));
images.push_back(image);
#else
images.push_back(image_samples[i]);
#endif
}
cv::Mat blob = cv::dnn::blobFromImages(images, 1.0f, input_size, 0.0f);
for (auto c : backends) {
auto backend = c.backend;
auto target = c.target;
bool skip = [backend, target, mask] {
for (auto m : mask) {
if (m.backend == backend && m.target == target)
return true;
if (m.backend == backend && m.target == -1)
return true;
if (m.backend == -1 && m.target == target)
return true;
}
return false;
} ();
if (skip)
continue;
try {
constexpr int WARMUP_RUNS = 3;
constexpr int BENCHMARK_RUNS = 400;
auto result = run_network<BENCHMARK_RUNS, WARMUP_RUNS>(model, config, blob, output_names, backend, target);
float init_time = to_microseconds(result.init_time).count() / 1000.0;
std::vector<float> runtimes;
for (auto r : result.runtimes)
runtimes.push_back(to_microseconds(r).count() / 1000.0);
auto sum = std::accumulate(std::begin(runtimes), std::end(runtimes), 0.0f);
auto squared_sum = std::inner_product(std::begin(runtimes), std::end(runtimes), std::begin(runtimes), 0.0f);
auto min = *std::min_element(std::begin(runtimes), std::end(runtimes));
auto max = *std::max_element(std::begin(runtimes), std::end(runtimes));
auto mean = sum / runtimes.size();
auto stddev = std::sqrt(squared_sum / runtimes.size() - mean * mean);
std::cout << '[' << c.name << "]" << '\n'
<< "\tinit >> " << init_time << "ms" << '\n'
<< "\tinference >> " << "min = " << min << "ms, max = " << max << "ms, mean = " << mean << "ms, stddev = " << stddev << "ms" << std::endl;
}
catch (const std::exception& ex) {
std::cout << ex.what() << std::endl;
return;
}
}
std::cout << std::endl;
}
void bench_yolo_v4()
{
std::cout << "YOLO v4\n";
bench_network("./yolov4.cfg", "./yolov4.weights", cv::Size(608, 608));
std::cout << std::endl;
}
int main()
{
bench_yolo_v4();
return 0;
}
```
@goodtogood Set nms_threshold=0 in all [yolo] blocks in yolov4.cfg. NMS is carried out on the CPU and is very inefficient when done during inference. It's best to disable NMS and perform it after the inference finishes. You will gain significant additional FPS. You can find example code here: https://gist.github.com/YashasSamaga/e2b19a6807a13046e399f4bc3cca3a49
YOLO v4
[CUDA FP32]
init >> 1245.76ms
inference >> min = 29.934ms, max = 31.181ms, mean = 30.3622ms, stddev = 0.207436ms
[CUDA FP16]
init >> 876.087ms
inference >> min = 22.916ms, max = 28.212ms, mean = 24.5076ms, stddev = 1.09143ms
After setting nms_threshold=0 in all three [yolo] blocks, the performance improved a lot.
Thank you so much for your patient explanation.
@YashasSamaga I missed the release of cuDNN 8.0.0 RC a week ago.
Seems like there are still huge inference DLLs. Not sure which ones we need. Have you tried compiling OpenCV with this release?
@goodtogood That sounds reasonable. You can also try to do inferences in batches if you don't have tight latency requirements. You can have cv::dnn::Net objects initialized for single image inference, batch of two and maybe even four. The throughput increases dramatically as you increase the batch size.
How do we control the batch size? Is this what we feed into blobfromimages?
How do we control the batch size? Is this what we feed into blobfromimages?
@matt-sharp Yes. Note that changing input shapes will cause reinitialization (which is time-consuming). Fix a batch size and use it throughout (and in case you have just three images but initialized for a batch size of four, pad a dummy zero image and make a batch of four to avoid reinitialization).
You can also initialize multiple networks for different batch sizes if your GPU memory permits. You can have one net object for single image inference, another for a batch of four, and another for eight. You can use all networks simultaneously (if you want) and use a smaller batch size when you do not have enough jobs to populate the bigger batches.
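A minimal sketch of keeping the batch size fixed and padding with dummy images; the YOLOv4 model files, the 608x608 input size, and the preprocessing parameters are placeholders.

```cpp
#include <opencv2/dnn.hpp>
#include <vector>

int main()
{
    const int batch_size = 4;
    const cv::Size input_size(608, 608);

    cv::dnn::Net net = cv::dnn::readNet("yolov4.weights", "yolov4.cfg");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);

    // suppose only three frames are ready in this iteration
    std::vector<cv::Mat> frames(3, cv::Mat(input_size, CV_8UC3, cv::Scalar::all(127)));

    // pad with zero images so the blob shape stays [4, 3, 608, 608] and the
    // network is never reinitialized for a different batch size
    while (static_cast<int>(frames.size()) < batch_size)
        frames.push_back(cv::Mat::zeros(input_size, CV_8UC3));

    cv::Mat blob = cv::dnn::blobFromImages(frames, 1 / 255.0, input_size, cv::Scalar(), true, false);
    net.setInput(blob);

    std::vector<cv::Mat> outputs;
    net.forward(outputs, net.getUnconnectedOutLayersNames());
    // detections produced for the padded dummy images are simply discarded
    return 0;
}
```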
@YashasSamaga are there any guidelines for the optimal batch size? Which target should I use: DNN_TARGET_CUDA_FP16 or DNN_TARGET_CUDA (FP32)?
I'm using 1 x Tesla v100, CUDA 11.2, CuDNN 7.6.5, YOLOv4, image size 608 x 608.
Is there any benefit to initializing multiple networks for the same model and running in parallel?
Also, is it possible to run batch inference with the high level Detection Model API since we don't feed a blob into this?
are there any guidelines for the optimal batch size?
@matt-sharp The performance varies across devices. Latency increases with batch size, along with throughput. The batch size you use depends largely on your latency requirements; it's a tradeoff between latency and throughput. I'd recommend trying out different batch sizes and choosing one that provides a substantial improvement over the next smaller batch size. Throughput generally increases with batch size, but with diminishing returns.
Which target should I use: DNN_TARGET_CUDA_FP16 or DNN_TARGET_CUDA (FP32)?
FP16 works great for YOLOv4 with practically no loss in detection performance. Your GPU supports the FP16 target, and I recommend it.
Is there any benefit to initializing multiple networks for the same model and running in parallel?
Yes, the GPU is used during inference but stays idle during pre/postprocessing on CPU. If you use multiple threads with one network initialized per thread, the GPU inference workload from one thread will keep the GPU busy while another thread is busy doing the pre/postprocessing. You can also pipeline your entire process into stages: preprocessing, DNN inference, postprocess. With a pipeline, you can keep all stages of the pipeline busy by processing different frames in each stage. For example, you will be preprocessing the next frame while the GPU is computing the forward pass for the current frame and simultaneously you would be postprocessing the previous frame. This increases throughput.
If you're initializing networks with large batch sizes, say 32, then most of the computation goes to waste if you have to perform inference on just one image, but reinitializing the network to work with a single image would cost much more time. The idea of initializing multiple networks with different batch sizes is to let you improve latency when you cannot fill an entire batch: if you have just 3 images, you can use the network initialized to process four images, and later, if you have to process 8 images, you can use the network initialized to work on batches of eight.
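A rough sketch of what keeping several pre-initialized networks might look like; the BatchedNets class, the batch sizes {1, 4, 8}, and the model files are all assumptions for illustration, not part of the discussion above.

```cpp
#include <opencv2/dnn.hpp>
#include <map>
#include <string>

// one cv::dnn::Net per supported batch size, all kept initialized up front
class BatchedNets {
public:
    BatchedNets(const std::string& weights, const std::string& cfg) {
        for (int batch : {1, 4, 8}) {
            cv::dnn::Net net = cv::dnn::readNet(weights, cfg);
            net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
            net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);
            nets_.emplace(batch, std::move(net));
        }
    }

    // pick the smallest initialized batch size that fits num_images;
    // the caller pads the blob up to that size with dummy images
    cv::dnn::Net& net_for(int num_images) {
        for (auto& entry : nets_)
            if (num_images <= entry.first)
                return entry.second;
        return nets_.rbegin()->second;  // fall back to the largest batch
    }

private:
    std::map<int, cv::dnn::Net> nets_;  // keyed and ordered by batch size
};
```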
Also, is it possible to run batch inference with the high level Detection Model API since we don't feed a blob into this?
Batching is not supported in the high level model API. You can track the feature request here: opencv/opencv#17838
@YashasSamaga thanks for your reply.
Please can I confirm that the Detection Model API returns bounding box co-ordinates in the form absolute(top, left, width, height) whereas net.forward returns absolute(centre x, centre y, width, height)?
The benchmarks I posted are for MobileNetSSD_deploy.prototxt / MobileNetSSD_deploy.caffemodel, which you can find here. MobileNet is slow with the CUDA backend because of depthwise convolutions. The CUDA backend fully relies on cuDNN for convolutions, and cuDNN is very bad at depthwise convolutions.
Has the bad performance of depthwise convolutions been fixed in cuDNN now? I just tested a DNN model with depthwise convolutions on a Jetson Nano B01 with CUDA 10.2 & cuDNN 8.2.1. The results turned out to be slower than the SoC of the Jetson Nano (ARM-based, 4-core, 1.5 GHz), and much slower than the SoC of the Raspberry Pi 4B (4-core, 1.5 GHz).
I also tested a model without depthwise convolutions, which seems to meet expectations (~10x faster than the Jetson Nano SoC).
Do you have an academic paper for these results?
I would like to cite it.
@Algabri No, there is no paper that is specific to this project.
@YashasSamaga , Thanks for your reply.
@YashasSamaga, I pulled in your change on top of 4.3.0. The difference is significant for MobileNet SSD v2 on an Intel CPU with a GTX 1080 Ti. I still have to test on Jetson platforms.