Skip to content

Instantly share code, notes, and snippets.

Last active July 23, 2020 00:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paleomoon/73e4c6c52f5d1889051446fcacd50739 to your computer and use it in GitHub Desktop.
Save paleomoon/73e4c6c52f5d1889051446fcacd50739 to your computer and use it in GitHub Desktop.
Using ONNX in TensorRT
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
//! sampleOnnxMNIST.cpp
//! This file contains the implementation of the ONNX MNIST sample. It creates the network using
//! the MNIST onnx model.
//! It can be run with the following command line:
//! Command: ./sample_onnx_mnist [-h or --help] [-d=/path/to/data/dir or --datadir=/path/to/data/dir]
//! [--useDLACore=<int>]
#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"

#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <vector>
//! Name passed to gLogger.defineTest() in main() to identify this sample in the test report.
const std::string gSampleName = "TensorRT.sample_onnx_mnist";
//! \brief The SampleOnnxMNIST class implements the ONNX MNIST sample
//! \details It creates the network using an ONNX model
class SampleOnnxMNIST
template <typename T>
using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;
SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
: mParams(params)
, mEngine(nullptr)
//! \brief Function builds the network engine
bool build();
//! \brief Runs the TensorRT inference engine for this sample
bool infer();
samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.
nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
nvinfer1::Dims mOutputDims; //!< The dimensions of the output to the network.
int mNumber{ 0 }; //!< The number to classify
std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine used to run the network
//! \brief Parses an ONNX model for MNIST and creates a TensorRT network
bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
SampleUniquePtr<nvonnxparser::IParser>& parser);
//! \brief Reads the input and stores the result in a managed buffer
bool processInput(const samplesCommon::BufferManager& buffers);
//! \brief Classifies digits and verify result
bool verifyOutput(const samplesCommon::BufferManager& buffers);
//! \brief Creates the network, configures the builder and creates the network engine
//! \details This function creates the Onnx MNIST network by parsing the Onnx model and builds
//! the engine that will be used to run MNIST (mEngine)
//! \return Returns true if the engine was created successfully and false otherwise
bool SampleOnnxMNIST::build()
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
if (!builder)
return false;
const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
//const auto explicitPrecision = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_PRECISION);
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
if (!network)
return false;
auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
if (!config)
return false;
auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger.getTRTLogger()));
if (!parser)
return false;
auto constructed = constructNetwork(builder, network, config, parser);
if (!constructed)
return false;
mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
if (!mEngine)
return false;
assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();
assert(mInputDims.nbDims == 4);
assert(network->getNbOutputs() == 1);
mOutputDims = network->getOutput(0)->getDimensions();
assert(mOutputDims.nbDims == 2);
return true;
//!
//! \brief Uses a ONNX parser to create the Onnx MNIST Network and marks the
//!        output layers
//!
//! \param builder Pointer to the engine builder
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//! \param config Pointer to the builder configuration that receives the precision flags
//! \param parser Pointer to the ONNX parser bound to \p network
//!
//! \return Returns true if the model file was parsed successfully and false otherwise
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    const bool parsed = parser->parseFromFile(
        locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(), static_cast<int>(gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    // NOTE(review): no maximum workspace size is configured on `config` in this function —
    // confirm the builder default is sufficient for the target model.

    // The two precision modes are independent options and are requested on the builder config.
    if (mParams.fp16)
    {
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(nvinfer1::BuilderFlag::kINT8);
        // Passes 127.0f for both scale arguments; presumably placeholder dynamic ranges used
        // in place of a calibrator — TODO confirm against samplesCommon::setAllTensorScales.
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
//! \brief Runs the TensorRT inference engine for this sample
//! \details This function is the main execution function of the sample. It allocates the buffer,
//! sets inputs and executes the engine.
bool SampleOnnxMNIST::infer()
// Create RAII buffer manager object
samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
return false;
// Read the input data into the managed buffers
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers))
return false;
// Memcpy from host input buffers to device input buffers
bool status = context->executeV2(buffers.getDeviceBindings().data());
if (!status)
return false;
// Memcpy from device output buffers to host output buffers
// Verify results
if (!verifyOutput(buffers))
return false;
return true;
//!
//! \brief Reads the input and stores the result in a managed buffer
//!
//! \details Picks a digit 0-9 at random, loads "<digit>.pgm" from the data directories, logs an
//!          ASCII rendering of it and writes the inverted, [0,1]-scaled pixels into the host
//!          buffer of the first input tensor. The chosen digit is remembered in mNumber.
//!
//! \param buffers Buffer manager that owns the host input buffer
//!
//! \return Returns true if the input was written, false if the host input buffer is unavailable
//!
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{
    const int inputH = mInputDims.d[2];
    const int inputW = mInputDims.d[3];
    const int pixelCount = inputH * inputW;

    // Read a random digit file. A locally seeded generator is used because rand() is never
    // seeded in this file and would select the same digit on every run.
    std::random_device seedSource;
    std::mt19937 generator(seedSource());
    std::uniform_int_distribution<int> digitPicker(0, 9);
    mNumber = digitPicker(generator);

    std::vector<uint8_t> fileData(pixelCount);
    readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Print an ascii representation
    gLogInfo << "Input:" << std::endl;
    for (int i = 0; i < pixelCount; i++)
    {
        // 255 / 26 == 9, so every pixel value maps onto one of the ten ramp characters.
        gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
    }
    gLogInfo << std::endl;

    float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
    if (hostDataBuffer == nullptr)
    {
        gLogError << "No host buffer found for input tensor " << mParams.inputTensorNames[0] << std::endl;
        return false;
    }

    for (int i = 0; i < pixelCount; i++)
    {
        // Scale to [0,1] and invert the intensity.
        const float scaled = static_cast<float>(fileData[i] / 255.0);
        hostDataBuffer[i] = 1.0f - scaled;
    }

    return true;
}
//!
//! \brief Classifies digits and verify result
//!
//! \details Applies softmax in place to the host output buffer of the first output tensor,
//!          logs one probability line per class and selects the most probable class.
//!
//! \param buffers Buffer manager that owns the host output buffer
//!
//! \return whether the classification output matches expectations, i.e. the most probable
//!         class equals mNumber and its probability exceeds 0.9
//!
bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{
    const int outputSize = mOutputDims.d[1];
    float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
    if (output == nullptr || outputSize <= 0)
    {
        gLogError << "No host output data available for verification" << std::endl;
        return false;
    }

    // Calculate Softmax. Shifting by the largest logit leaves the result unchanged
    // mathematically but keeps exp() from overflowing on large values.
    const float maxLogit = *std::max_element(output, output + outputSize);
    float sum{0.0f};
    for (int i = 0; i < outputSize; i++)
    {
        output[i] = std::exp(output[i] - maxLogit);
        sum += output[i];
    }

    float val{0.0f};
    int idx{0};
    gLogInfo << "Output:" << std::endl;
    for (int i = 0; i < outputSize; i++)
    {
        output[i] /= sum;
        // ">=" keeps the last index among equal maxima.
        if (output[i] >= val)
        {
            val = output[i];
            idx = i;
        }

        // One '*' per 10% of probability, rounded to nearest.
        const int barLength = static_cast<int>(std::floor(output[i] * 10 + 0.5f));
        gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i] << " "
                 << "Class " << i << ": " << std::string(barLength, '*') << std::endl;
    }
    gLogInfo << std::endl;

    return idx == mNumber && val > 0.9f;
}
//!
//! \brief Initializes members of the params struct using the command line args
//!
//! \param args Parsed command line arguments
//!
//! \return Parameters for the sample: data directories, model file name, tensor names,
//!         batch size, DLA core and precision switches
//!
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::OnnxSampleParams params;
    if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
    {
        // Same defaults that printHelpInfo() advertises for --datadir.
        params.dataDirs.push_back("data/samples/mnist/");
        params.dataDirs.push_back("data/mnist/");
    }
    else //!< Use the data directory provided by the user
    {
        params.dataDirs = args.dataDirs;
    }
    params.onnxFileName = "model.onnx";
    // NOTE(review): infer(), processInput() and verifyOutput() require exactly one input and one
    // output tensor name. The names below are those of the stock MNIST ONNX model — confirm
    // they match the tensors in model.onnx.
    params.inputTensorNames.push_back("Input3");
    params.outputTensorNames.push_back("Plus214_Output_0");
    params.batchSize = 1;
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;

    return params;
}
//!
//! \brief Prints the help information for running this sample
//!
//! \details Writes the usage line and a description of each supported option to std::cout.
//!
void printHelpInfo()
{
    std::cout
        << "Usage: ./sample_onnx_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]"
        << std::endl;
    std::cout << "--help Display help information" << std::endl;
    std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "(data/samples/mnist/, data/mnist/)"
              << std::endl;
    std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                 "where n is the number of DLA engines on the platform."
              << std::endl;
    std::cout << "--int8 Run in Int8 mode." << std::endl;
    std::cout << "--fp16 Run in FP16 mode." << std::endl;
}
int main(int argc, char** argv)
samplesCommon::Args args;
bool argsOK = samplesCommon::parseArgs(args, argc, argv);
if (!argsOK)
gLogError << "Invalid arguments" << std::endl;
if (
auto sampleTest = gLogger.defineTest(gSampleName, argc, argv);
SampleOnnxMNIST sample(initializeSampleParams(args));
gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;
if (!
return gLogger.reportFail(sampleTest);
if (!sample.infer())
return gLogger.reportFail(sampleTest);
return gLogger.reportPass(sampleTest);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment