Skip to content

Instantly share code, notes, and snippets.

@arnaldog12
Last active July 13, 2020 12:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save arnaldog12/5c6494d20a4a5d7b23a01975b66811b1 to your computer and use it in GitHub Desktop.
Save arnaldog12/5c6494d20a4a5d7b23a01975b66811b1 to your computer and use it in GitHub Desktop.
TensorFlow in C++
#pragma once
#ifndef TENSORFLOW_GRAPH_H
#define TENSORFLOW_GRAPH_H
#include "TensorflowUtils.h"
#include "TensorflowPlaceholder.h"
using namespace tensorflow;
using deeplearning::TensorflowUtils;
using deeplearning::TensorflowPlaceholder;
namespace deeplearning
{
class TensorflowGraph
{
private:
Session *session;
public:
TensorflowGraph(std::string metaFile, std::string checkpointFolder, SessionOptions options = SessionOptions())
{
MetaGraphDef graphDef = this->loadGraphFromMetaFile(metaFile);
this->session = this->createSession(graphDef.graph_def(), options);
loadCheckpoint(graphDef, checkpointFolder);
}
TensorflowGraph(std::string protobufFile, SessionOptions options = SessionOptions())
{
GraphDef graphDef = this->loadGraphFromProtobufFile(protobufFile);
this->session = this->createSession(graphDef, options);
}
TensorflowGraph(std::ostringstream& protobufFile, SessionOptions options = SessionOptions())
{
std::string decoded = protobufFile.str();
GraphDef graphDef = this->loadGraphFromString(decoded);
this->session = this->createSession(graphDef, options);
}
~TensorflowGraph()
{
tensorflow::Status status = this->session->Close();
delete this->session;
}
std::vector<std::vector<cv::Mat>> run(TensorflowPlaceholder::tensorDict feedDict, std::vector<std::string> outputTensorNames, std::vector<std::string> targetNodeNames = {})
{
std::vector<Tensor> outputsTensor;
TF_CHECK_OK(session->Run(feedDict, outputTensorNames, targetNodeNames, &outputsTensor));
return TensorflowUtils::tensor2mat(outputsTensor);
}
private:
MetaGraphDef loadGraphFromMetaFile(std::string metaFile)
{
MetaGraphDef graphDef;
TF_CHECK_OK(ReadBinaryProto(Env::Default(), metaFile, &graphDef));
return graphDef;
}
GraphDef loadGraphFromProtobufFile(std::string protobufFile)
{
GraphDef graphDef;
TF_CHECK_OK(ReadBinaryProto(Env::Default(), protobufFile, &graphDef));
return graphDef;
}
GraphDef loadGraphFromString(std::string protobufFile)
{
GraphDef graphDef;
if (!graphDef.ParseFromString(protobufFile)) throw "Nao foi possible carregar o modelo do Tensorflow!";
return graphDef;
}
Session* createSession(GraphDef graphDef, SessionOptions options)
{
Session *session;
TF_CHECK_OK(NewSession(options, &session));
TF_CHECK_OK(session->Create(graphDef));
return session;
}
void loadCheckpoint(MetaGraphDef& graphDef, std::string checkpointFolder)
{
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointFolder;
TF_CHECK_OK(
session->Run(
{ { graphDef.saver_def().filename_tensor_name(), checkpointPathTensor } },
{},
{ graphDef.saver_def().restore_op_name() },
nullptr)
);
}
};
}
#endif
#pragma once
#ifndef TENSORFLOW_PLACEHOLDER_H
#define TENSORFLOW_PLACEHOLDER_H
using namespace tensorflow;
namespace deeplearning
{
class TensorflowPlaceholder
{
public:
typedef std::pair<std::string, Tensor> placeholderType;
typedef std::vector<placeholderType> tensorDict;
static placeholderType tensor(string key, Tensor t)
{
return { key, t };
}
static placeholderType boolean(string key, bool value)
{
Tensor placeholder(DT_BOOL, TensorShape());
placeholder.scalar<bool>()() = value;
return { key, placeholder };
}
};
}
#endif
#pragma once
#ifndef TENSORFLOW_UTILS_H
#define TENSORFLOW_UTILS_H
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

#include "opencv2/core/core.hpp"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
using namespace tensorflow;
/// Output layout selector for TensorflowUtils::mat2tensor.
/// TENSOR_2D flattens every image into one row: { nImages, rows * cols * channels / nImages }.
/// TENSOR_4D (the default branch there) keeps { nImages, rows / nImages, cols, channels }.
/// The numeric values 1 and 3 are kept unchanged; their meaning is not evident from this header.
enum TENSOR_SHAPE
{
	TENSOR_2D = 1,
	TENSOR_4D = 3
};
namespace deeplearning
{
class TensorflowUtils
{
public:
template <class T>
static Tensor mat2tensor(cv::Mat image, tensorflow::DataType type = tensorflow::DT_FLOAT, TENSOR_SHAPE shape = TENSOR_4D, int nImages = 1)
{
T *imageData = (T *)image.data;
TensorShape imageShape;
switch (shape)
{
case TENSOR_2D: imageShape = TensorShape{ nImages, image.rows * image.cols * image.channels() / nImages }; break;
default: imageShape = TensorShape{ nImages, image.rows / nImages, image.cols, image.channels() }; break;
}
Tensor imageTensor = Tensor(type, imageShape);
std::copy_n((char *)imageData, imageShape.num_elements() * sizeof(T), const_cast<char *>(imageTensor.tensor_data().data()));
return imageTensor;
}
template <class T>
static Tensor mat2tensor(std::vector<cv::Mat> images, tensorflow::DataType type = tensorflow::DT_FLOAT, TENSOR_SHAPE shape = TENSOR_4D)
{
cv::Mat imagesConcat;
cv::vconcat(images, imagesConcat);
return mat2tensor<T>(imagesConcat, type, shape, images.size());
}
static std::vector<cv::Mat> tensor2mat(Tensor tensor)
{
TensorShape shape = tensor.shape();
int nDims = shape.dims();
int nImages = shape.dim_size(0);
int width = nDims > 2 ? shape.dim_size(2) : (nDims > 1 ? shape.dim_size(1) : shape.dim_size(0));
int height = nDims > 2 ? shape.dim_size(1) : 1;
int channels = (nDims == 4) ? shape.dim_size(3) : 1;
std::vector<cv::Mat> result;
for (int i = 0; i < nImages; i++)
{
Tensor slice = tensor.Slice(i, i + 1);
assert(slice.IsAligned() == true);
float *outputData = slice.flat<float>().data();
cv::Mat imgOut(cv::Size(width, height), CV_32FC(channels));
std::copy_n((char*)outputData, slice.shape().num_elements() * sizeof(float), (char*)imgOut.data);
result.push_back(imgOut);
}
return result;
}
static std::vector<std::vector<cv::Mat>> tensor2mat(std::vector<Tensor> tensors)
{
std::vector<std::vector<cv::Mat>> results;
for (Tensor t : tensors)
results.push_back(tensor2mat(t));
return results;
}
};
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment