Using ONNX in TensorRT
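The listing below is based on NVIDIA's sampleOnnxMNIST.cpp from the TensorRT samples, lightly adapted near the bottom (the ONNX file name and the input/output tensor names are changed). Before the full sample, here is a minimal sketch of the same build flow it implements, assuming the TensorRT 7-era API; the function name is illustrative and error handling is omitted:

// Minimal sketch of the ONNX-to-engine build flow used by the sample below
// (assumption: TensorRT 7-era API, where objects are released with destroy()).
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include <string>

nvinfer1::ICudaEngine* buildEngineFromOnnx(const std::string& onnxPath, nvinfer1::ILogger& logger)
{
    auto builder = nvinfer1::createInferBuilder(logger);
    const auto explicitBatch
        = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(explicitBatch);           // ONNX parsing needs explicit batch
    auto parser = nvonnxparser::createParser(*network, logger);
    parser->parseFromFile(onnxPath.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING));
    auto config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize(16 << 20);                            // 16 MiB, as in the sample
    auto engine = builder->buildEngineWithConfig(*network, *config);  // optimized ICudaEngine
    parser->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();
    return engine;
}

The complete, adapted sample follows.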
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//!
//! sampleOnnxMNIST.cpp
//! This file contains the implementation of the ONNX MNIST sample. It creates the network using
//! the MNIST onnx model.
//! It can be run with the following command line:
//! Command: ./sample_onnx_mnist [-h or --help] [-d=/path/to/data/dir or --datadir=/path/to/data/dir]
//! [--useDLACore=<int>]
//!

#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"

#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
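// Note: argsParser.h, buffers.h, common.h, logger.h, and parserOnnxConfig.h come from the
// TensorRT samples/common directory, so this file is meant to be built inside the TensorRT
// samples tree (or with that directory on the include path).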
const std::string gSampleName = "TensorRT.sample_onnx_mnist";

//! \brief The SampleOnnxMNIST class implements the ONNX MNIST sample
//!
//! \details It creates the network using an ONNX model
//!
class SampleOnnxMNIST
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
        , mEngine(nullptr)
    {
    }

    //!
    //! \brief Builds the network engine
    //!
    bool build();

    //!
    //! \brief Runs the TensorRT inference engine for this sample
    //!
    bool infer();

private:
    samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.

    nvinfer1::Dims mInputDims;  //!< The dimensions of the input to the network.
    nvinfer1::Dims mOutputDims; //!< The dimensions of the output of the network.
    int mNumber{ 0 };           //!< The number to classify

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine used to run the network

    //!
    //! \brief Parses an ONNX model for MNIST and creates a TensorRT network
    //!
    bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
        SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
        SampleUniquePtr<nvonnxparser::IParser>& parser);

    //!
    //! \brief Reads the input and stores the result in a managed buffer
    //!
    bool processInput(const samplesCommon::BufferManager& buffers);

    //!
    //! \brief Classifies digits and verifies the result
    //!
    bool verifyOutput(const samplesCommon::BufferManager& buffers);
};
//!
//! \brief Creates the network, configures the builder and creates the network engine
//!
//! \details This function creates the Onnx MNIST network by parsing the Onnx model and builds
//! the engine that will be used to run MNIST (mEngine)
//!
//! \return Returns true if the engine was created successfully and false otherwise
//!
bool SampleOnnxMNIST::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    //const auto explicitPrecision = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_PRECISION);
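    // The ONNX parser only supports explicit-batch networks, so the kEXPLICIT_BATCH flag
    // above is required when creating the network definition below.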
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }

    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 4);

    assert(network->getNbOutputs() == 1);
    mOutputDims = network->getOutput(0)->getDimensions();
    assert(mOutputDims.nbDims == 2);

    return true;
}
//!
//! \brief Uses an ONNX parser to create the Onnx MNIST Network and marks the
//! output layers
//!
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    auto parsed = parser->parseFromFile(
        locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(), static_cast<int>(gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
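    // The workspace size caps the scratch GPU memory the builder may use while choosing tactics.
    // setMaxBatchSize should have no effect on an explicit-batch network but is kept here as in
    // the original sample. The flags below enable reduced-precision kernels when requested on
    // the command line and supported by the hardware.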
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
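        // The sample ships no INT8 calibrator, so dummy per-tensor scales are set below;
        // INT8 accuracy is therefore not meaningful for a real model without calibration.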
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
//! \details This function is the main execution function of the sample. It allocates the buffer,
//! sets inputs and executes the engine.
//!
bool SampleOnnxMNIST::infer()
{
    // Create RAII buffer manager object
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
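    // The BufferManager (from the samples' buffers.h) allocates a host and a device buffer for
    // every engine binding and handles the host<->device copies used below.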
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    // Read the input data into the managed buffers
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers))
    {
        return false;
    }

    // Memcpy from host input buffers to device input buffers
    buffers.copyInputToDevice();

    bool status = context->executeV2(buffers.getDeviceBindings().data());
    if (!status)
    {
        return false;
    }

    // Memcpy from device output buffers to host output buffers
    buffers.copyOutputToHost();

    // Verify results
    if (!verifyOutput(buffers))
    {
        return false;
    }

    return true;
}
//!
//! \brief Reads the input and stores the result in a managed buffer
//!
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{
    const int inputH = mInputDims.d[2];
    const int inputW = mInputDims.d[3];

    // Read a random digit file
    srand(unsigned(time(nullptr)));
    std::vector<uint8_t> fileData(inputH * inputW);
    mNumber = rand() % 10;
    readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Print an ascii representation
    gLogInfo << "Input:" << std::endl;
    for (int i = 0; i < inputH * inputW; i++)
    {
        gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
    }
    gLogInfo << std::endl;

    float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
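    // Scale the 8-bit pixels to [0, 1] and invert them, so the dark digit strokes become the
    // high values the MNIST-style model expects.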
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
    }

    return true;
}
//!
//! \brief Classifies digits and verifies the result
//!
//! \return whether the classification output matches expectations
//!
bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{
    const int outputSize = mOutputDims.d[1];
    float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
    float val{ 0.0f };
    int idx{ 0 };

    // Calculate Softmax
    float sum{ 0.0f };
    for (int i = 0; i < outputSize; i++)
    {
        output[i] = exp(output[i]);
        sum += output[i];
    }

    gLogInfo << "Output:" << std::endl;
    for (int i = 0; i < outputSize; i++)
    {
        output[i] /= sum;
        val = std::max(val, output[i]);
        if (val == output[i])
        {
            idx = i;
        }
        gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i] << " "
                 << "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5f)), '*') << std::endl;
    }
    gLogInfo << std::endl;

    return idx == mNumber && val > 0.9f;
}
//!
//! \brief Initializes members of the params struct using the command line args
//!
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::OnnxSampleParams params;
    if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else //!< Use the data directory provided by the user
    {
        params.dataDirs = args.dataDirs;
    }
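    // The file name and tensor names below must match the ONNX model actually being deployed;
    // this gist uses "model.onnx" with input "skeleton" and output "action" rather than the
    // names used by the stock MNIST sample.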
    params.onnxFileName = "model.onnx";
    params.inputTensorNames.push_back("skeleton");
    params.batchSize = 1;
    params.outputTensorNames.push_back("action");
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;

    return params;
}
//!
//! \brief Prints the help information for running this sample
//!
void printHelpInfo()
{
    std::cout
        << "Usage: ./sample_onnx_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]"
        << std::endl;
    std::cout << "--help          Display help information" << std::endl;
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "(data/samples/mnist/, data/mnist/)"
              << std::endl;
    std::cout << "--useDLACore=N  Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                 "where n is the number of DLA engines on the platform."
              << std::endl;
    std::cout << "--int8          Run in Int8 mode." << std::endl;
    std::cout << "--fp16          Run in FP16 mode." << std::endl;
}
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    auto sampleTest = gLogger.defineTest(gSampleName, argc, argv);

    gLogger.reportTestStart(sampleTest);

    SampleOnnxMNIST sample(initializeSampleParams(args));

    gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;

    if (!sample.build())
    {
        return gLogger.reportFail(sampleTest);
    }
    if (!sample.infer())
    {
        return gLogger.reportFail(sampleTest);
    }

    return gLogger.reportPass(sampleTest);
}
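Once compiled alongside the TensorRT samples, the program is invoked as described in printHelpInfo(); for example (the binary name and data layout here are carried over from the stock sample and may differ in your build):

./sample_onnx_mnist --datadir=data/mnist/ --fp16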