Created
November 14, 2018 10:04
-
-
Save GrimMaple/f64693887452a9c659d3e48a0161098a to your computer and use it in GitHub Desktop.
Batch MNIST Sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! \file sampleMNIST.cpp | |
//! \brief This file contains the implementation of the MNIST sample. | |
//! | |
//! It builds a TensorRT engine by importing a trained MNIST Caffe model. It uses the engine to run | |
//! inference on an input image of a digit. | |
//! It can be run with the following command line: | |
//! Command: ./sample_mnist [-h or --help] [-d=/path/to/data/dir or --datadir=/path/to/data/dir] | |
#include "argsParser.h"
#include "buffers.h"
#include "common.h"

#include "NvCaffeParser.h"
#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
//! Global logger handed to TensorRT; receives build- and run-time diagnostics.
static Logger gLogger;
//! | |
//! \brief The MNISTSampleParams structure groups the additional parameters required by | |
//! the MNIST sample. | |
//! | |
struct MNISTSampleParams : public samplesCommon::CaffeSampleParams
{
    std::string meanFileName; //!< Filename of the Caffe mean-image binaryproto (resolved against dataDirs).
};
//! | |
//! \brief The SampleMNIST class implements the MNIST sample | |
//! | |
//! \details It creates the network using a trained Caffe MNIST classification model | |
//! | |
class SampleMNIST
{
    //! Owning pointer for TensorRT objects, released via their destroy() method.
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    //!
    //! \brief Constructs the sample. \p params supplies model/weights/mean file
    //!        names, data directories and the DLA core id.
    //!
    SampleMNIST(const MNISTSampleParams& params)
        : mParams(params)
    {
    }
    //!
    //! \brief Function builds the network engine
    //!
    bool build();
    //!
    //! \brief This function runs the TensorRT inference engine for this sample
    //!
    bool infer();
    //!
    //! \brief This function can be used to clean up any state created in the sample class
    //!
    bool teardown();

private:
    //!
    //! \brief This function uses a Caffe parser to create the MNIST Network and marks the
    //! output layers
    //!
    void constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser);
    //!
    //! \brief Reads the input and mean data, preprocesses, and stores the result in a managed buffer
    //!
    bool processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const;
    //!
    //! \brief Verifies that the output is correct and prints it
    //!
    bool verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine = nullptr; //!< The TensorRT engine used to run the network
    MNISTSampleParams mParams; //!< The parameters for the sample.
    nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
    SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob> mMeanBlob; //!< The mean blob; must outlive build() because the network's Weights reference its memory.
};
//! | |
//! \brief This function creates the network, configures the builder and creates the network engine | |
//! | |
//! \details This function creates the MNIST network by parsing the caffe model and builds | |
//! the engine that will be used to run MNIST (mEngine) | |
//! | |
//! \return Returns true if the engine was created successfully and false otherwise | |
//! | |
bool SampleMNIST::build() | |
{ | |
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger)); | |
if (!builder) return false; | |
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork()); | |
if (!network) return false; | |
auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser()); | |
if (!parser) return false; | |
constructNetwork(builder, network, parser); | |
builder->setMaxBatchSize(/*mParams.batchSize*/10); | |
builder->setMaxWorkspaceSize(16_MB); | |
builder->allowGPUFallback(true); | |
if (mParams.dlaID > 0) samplesCommon::enableDLA(builder.get(), mParams.dlaID); | |
mEngine = std::move(std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildCudaEngine(*network), samplesCommon::InferDeleter())); | |
if (!mEngine) return false; | |
assert(network->getNbInputs() == 1); | |
mInputDims = network->getInput(0)->getDimensions(); | |
assert(mInputDims.nbDims == 3); | |
return true; | |
} | |
//! | |
//! \brief Reads the input and mean data, preprocesses, and stores the result in a managed buffer | |
//! | |
//!
//! \brief Reads one digit image, prints it, and replicates it across every
//!        batch slot of the host input buffer.
//!
//! \param buffers         Buffer manager owning the host input buffer.
//! \param inputTensorName Name of the network's input binding.
//! \param inputFileIdx    Digit (0-9); "<digit>.pgm" is located in mParams.dataDirs.
//!
bool SampleMNIST::processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];
    const int imageSize = inputH * inputW;
    // Batch size this sample is hard-wired to (must match build()/infer()).
    constexpr int kBatchSize = 10;

    // Read exactly one H x W image. The original used a non-standard VLA sized
    // for 10 images but readPGMFile only ever fills one.
    std::vector<uint8_t> fileData(imageSize);
    readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Print an ASCII-art rendering of the digit.
    std::cout << "\nInput:\n"
              << std::endl;
    for (int i = 0; i < imageSize; i++)
        std::cout << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");

    // Copy the image into every batch slot of the host input buffer.
    // BUG FIX: the original wrote fileData[i % 10], tiling only the first 10
    // pixels across the whole buffer (and reading indeterminate bytes from the
    // unfilled tail of its VLA). The replication stride must be the image size.
    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < imageSize * kBatchSize; i++)
        hostInputBuffer[i] = float(fileData[i % imageSize]);
    return true;
}
//! | |
//! \brief Verifies that the output is correct and prints it | |
//! | |
//!
//! \brief Prints a per-class histogram for each batch entry and checks that
//!        every entry's argmax matches the ground-truth digit.
//!
//! \return true only when all batch entries predict \p groundTruthDigit.
//!
bool SampleMNIST::verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
{
    // Batch size and class count this sample is hard-wired to.
    constexpr int kBatchSize = 10;
    constexpr int kNumClasses = 10;

    const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));

    // Print histogram of the output distribution for each batch entry.
    std::cout << "\nOutput:\n\n";
    bool allCorrect = true;
    for (int j = 0; j < kBatchSize; j++)
    {
        float maxVal{0.0f};
        int maxIdx{0};
        for (int i = 0; i < kNumClasses; i++)
        {
            // BUG FIX: the original compared val against prob[i] (batch entry 0)
            // while taking the max over prob[i + j*10], so idx was wrong for
            // every entry but the first -- and then never used at all.
            const float p = prob[j * kNumClasses + i];
            if (p > maxVal)
            {
                maxVal = p;
                maxIdx = i;
            }
            std::cout << i << ": " << std::string(int(std::floor(p * 10 + 0.5f)), '*') << "\n";
        }
        // Every batch slot holds the same input image, so each entry's argmax
        // should equal the ground-truth digit. The original ignored
        // groundTruthDigit and unconditionally returned true.
        allCorrect = allCorrect && (maxIdx == groundTruthDigit);
        std::cout << std::endl << std::endl << std::endl;
    }
    std::cout << std::endl;
    return allCorrect;
}
//! | |
//! \brief This function uses a caffe parser to create the MNIST Network and marks the | |
//! output layers | |
//! | |
//! \param network Pointer to the network that will be populated with the MNIST network | |
//! | |
//! \param builder Pointer to the engine builder | |
//! | |
void SampleMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser)
{
    // Parse the Caffe deploy prototxt + trained weights into the TensorRT network.
    const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
        locateFile(mParams.prototxtFileName, mParams.dataDirs).c_str(),
        locateFile(mParams.weightsFileName, mParams.dataDirs).c_str(),
        *network,
        nvinfer1::DataType::kFLOAT);
    assert(blobNameToTensor && "failed to parse the Caffe model");

    // Mark the requested blobs as network outputs. find() returns nullptr for
    // an unknown blob name; the original dereferenced it unchecked.
    for (auto& s : mParams.outputTensorNames)
    {
        nvinfer1::ITensor* tensor = blobNameToTensor->find(s.c_str());
        assert(tensor && "output tensor name not found in the Caffe model");
        network->markOutput(*tensor);
    }

    // Add mean subtraction at the beginning of the network:
    // rewire layer 0 to consume (input - mean) instead of the raw input.
    Dims inputDims = network->getInput(0)->getDimensions();

    // Keep the mean blob alive until the engine is built: the Weights struct
    // only references its memory (hence the mMeanBlob member).
    mMeanBlob = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(locateFile(mParams.meanFileName, mParams.dataDirs).c_str()));
    Weights meanWeights{DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};

    auto mean = network->addConstant(Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
    auto meanSub = network->addElementWise(*network->getInput(0), *mean->getOutput(0), ElementWiseOperation::kSUB);
    network->getLayer(0)->setInput(0, *meanSub->getOutput(0));
}
//! | |
//! \brief This function runs the TensorRT inference engine for this sample | |
//! | |
//! \details This function is the main execution function of the sample. It allocates | |
//! the buffer, sets inputs, executes the engine, and verifies the output. | |
//! | |
bool SampleMNIST::infer() | |
{ | |
// Create RAII buffer manager object | |
samplesCommon::BufferManager buffers(mEngine, 10/*mParams.batchSize*/); | |
auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext()); | |
if (!context) return false; | |
// Pick a random digit to try to infer | |
srand(time(NULL)); | |
const int digit = rand() % 10; | |
// Read the input data into the managed buffers | |
// There should be just 1 input tensor | |
assert(mParams.inputTensorNames.size() == 1); | |
if (!processInput(buffers, mParams.inputTensorNames[0], digit)) return false; | |
// Create CUDA stream for the execution of this inference. | |
cudaStream_t stream; | |
CHECK(cudaStreamCreate(&stream)); | |
// Asynchronously copy data from host input buffers to device input buffers | |
buffers.copyInputToDeviceAsync(stream); | |
// Asynchronously enqueue the inference work | |
if (!context->enqueue(/*mParams.batchSize*/ 10, buffers.getDeviceBindings().data(), stream, nullptr)) return false; | |
// Asynchronously copy data from device output buffers to host output buffers | |
buffers.copyOutputToHostAsync(stream); | |
// Wait for the work in the stream to complete | |
cudaStreamSynchronize(stream); | |
// Release stream | |
cudaStreamDestroy(stream); | |
// Check and print the output of the inference | |
// There should be just one output tensor | |
assert(mParams.outputTensorNames.size() == 1); | |
bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit); | |
return outputCorrect; | |
} | |
//! | |
//! \brief This function can be used to clean up any state created in the sample class | |
//! | |
bool SampleMNIST::teardown()
{
    //! Clean up the libprotobuf state now that all Caffe parsing is complete.
    //! \note It is not safe to use any other part of the protocol buffers library after
    //! ShutdownProtobufLibrary() has been called.
    nvcaffeparser1::shutdownProtobufLibrary();
    // Nothing here can fail; always report success.
    return true;
}
//! | |
//! \brief This function initializes members of the params struct using the command line args | |
//! | |
MNISTSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    MNISTSampleParams params;

    // Data directories: user-supplied ones win; otherwise fall back to the
    // standard sample locations (searched in this order).
    if (args.dataDirs.empty())
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else
    {
        params.dataDirs = args.dataDirs;
    }

    // Model, weights and mean-image files (resolved against dataDirs at load time).
    params.prototxtFileName = "mnist.prototxt";
    params.weightsFileName = "mnist.caffemodel";
    params.meanFileName = "mnist_mean.binaryproto";

    // Binding names and execution settings.
    params.inputTensorNames.push_back("data");
    params.outputTensorNames.push_back("prob");
    params.batchSize = 1;
    params.dlaID = args.useDLA;

    return params;
}
//! | |
//! \brief This function prints the help information for running this sample | |
//! | |
//! Writes the command-line usage text for this sample to stdout.
void printHelpInfo()
{
    std::cout << "Usage: ./sample_mnist [-h or --help] [-d or --datadir=<path to data directory>]\n"
              << "--help Display help information\n"
              << "--datadir Specify path to a data directory, overriding the default. This option can be used multiple times to add multiple directories. If no data directories are given, the default is to use (data/samples/mnist/, data/mnist/)"
              << std::endl;
}
int main(int argc, char** argv) | |
{ | |
samplesCommon::Args args; | |
if (!samplesCommon::parseArgs(args, argc, argv)) | |
{ | |
if (args.help) | |
{ | |
printHelpInfo(); | |
return EXIT_SUCCESS; | |
} | |
return EXIT_FAILURE; | |
} | |
MNISTSampleParams params = initializeSampleParams(args); | |
SampleMNIST sample(params); | |
std::cout << "Building and running a GPU inference engine for MNIST" << std::endl; | |
if (!sample.build()) | |
return EXIT_FAILURE; | |
if (!sample.infer()) | |
return EXIT_FAILURE; | |
if (!sample.teardown()) | |
return EXIT_FAILURE; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment