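// plugin.cpp (gist by @penk, 2016-10-12): a QML extension plugin that exposes a
// "TensorFlow" QML type backed by a TensorFlow Inception graph. It decodes a JPEG
// with libjpeg, bilinearly resizes it into the network's input tensor, runs the
// session, and returns the top-5 ImageNet labels as a string.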
#include <fstream>
#include <jpeglib.h>
#include <setjmp.h>
#include <QDebug>
#include <QtQml/qqml.h>
#include <QtQml/QQmlExtensionPlugin>
#include <QDir>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/graph/default_device.h>
#include <tensorflow/core/graph/graph_def_builder.h>
#include <tensorflow/core/lib/core/errors.h>
#include <tensorflow/core/lib/core/stringpiece.h>
#include <tensorflow/core/lib/core/threadpool.h>
#include <tensorflow/core/lib/io/path.h>
#include <tensorflow/core/lib/strings/stringprintf.h>
#include <tensorflow/core/platform/init_main.h>
#include <tensorflow/core/platform/logging.h>
#include <tensorflow/core/platform/types.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/util/command_line_flags.h>
using namespace tensorflow;
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
class TensorFlowModel : public QObject
{
Q_OBJECT
public:
explicit TensorFlowModel(QObject *parent = nullptr) : QObject(parent)
{
// load and initialize the model
qDebug() << "\n\n\n==============================================\n\n\n";
Status load_graph_status = LoadGraph("/data/local/tmp/tensorflow_inception_graph.pb", &session);
qDebug() << "Load graph status:" << QString::fromStdString(load_graph_status.ToString());
}
~TensorFlowModel()
{
}
Q_INVOKABLE QString run(QString image_path) {
// Decode and bilinearly resize the JPEG into a float tensor.
// 224x224 with mean 117 / std 1 matches the tensorflow_inception_graph.pb from the
// TensorFlow Android demo; the commented-out 299/128/128 values are for Inception v3.
std::vector<Tensor> resized_tensors;
Status read_tensor_status =
//ReadTensorFromImageFile(image_path.toStdString(), 299, 299, 128, 128, &resized_tensors);
ReadTensorFromImageFile(image_path.toStdString(), 224, 224, 117, 1, &resized_tensors);
qDebug() << "Read status:" << QString::fromStdString(read_tensor_status.ToString());
if (!read_tensor_status.ok()) {
return QString::fromStdString(read_tensor_status.ToString());
}
const Tensor& resized_tensor = resized_tensors[0];
// Layer names for this graph; Inception v3 uses "Mul" / "softmax" instead.
//string input_layer = "Mul";
string input_layer = "input";
//string output_layer = "softmax";
string output_layer = "output";
// Actually run the image through the model.
std::vector<Tensor> outputs;
qDebug() << "Before session->Run";
Status run_status = session->Run({{input_layer, resized_tensor}}, {output_layer}, {}, &outputs);
qDebug() << "Run status:" << QString::fromStdString(run_status.ToString());
if (!run_status.ok()) {
return QString::fromStdString(run_status.ToString());
}
// Map the output scores to human-readable labels and return the top five.
string labels = "/data/local/tmp/imagenet_comp_graph_label_strings.txt";
result_labels = "";
Status print_status = PrintTopLabels(outputs, labels);
qDebug() << "Label status:" << QString::fromStdString(print_status.ToString());
return result_labels;
}
Status LoadGraph(string graph_file_name, std::unique_ptr<tensorflow::Session>* session)
{
tensorflow::GraphDef graph_def;
Status load_graph_status =
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
tensorflow::SessionOptions options;
tensorflow::ConfigProto& config = options.config;
config.set_intra_op_parallelism_threads(4);
session->reset(tensorflow::NewSession(options));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}
Status ReadTensorFromImageFile(string file_name, const int wanted_height,
const int wanted_width, const float input_mean,
const float input_std, std::vector<Tensor>* out_tensors)
{
std::vector<tensorflow::uint8> image_data;
int image_width;
int image_height;
int image_channels;
TF_RETURN_IF_ERROR(LoadJpegFile(file_name, &image_data, &image_width,
&image_height, &image_channels));
LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height
<< "x" << image_channels;
const int wanted_channels = 3;
if (image_channels < wanted_channels) {
return tensorflow::errors::FailedPrecondition("Image needs to have at least ",
wanted_channels, " channels but only has ",
image_channels);
}
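// Allocate a {1, height, width, 3} float tensor to hold the resized, normalized image.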
tensorflow::Tensor image_tensor(
tensorflow::DT_FLOAT, tensorflow::TensorShape(
{1, wanted_height, wanted_width, wanted_channels}));
auto image_tensor_mapped = image_tensor.tensor<float, 4>();
tensorflow::uint8* in = image_data.data();
float *out = image_tensor_mapped.data();
const size_t image_rowlen = image_width * image_channels;
const float width_scale = static_cast<float>(image_width) / wanted_width;
const float height_scale = static_cast<float>(image_height) / wanted_height;
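// Bilinearly resample the decoded image into the tensor, normalizing each
// channel value as (pixel - input_mean) / input_std.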
for (int y = 0; y < wanted_height; ++y) {
const float in_y = y * height_scale;
const int top_y_index = static_cast<int>(floorf(in_y));
const int bottom_y_index =
std::min(static_cast<int>(ceilf(in_y)), (image_height - 1));
const float y_lerp = in_y - top_y_index;
tensorflow::uint8* in_top_row = in + (top_y_index * image_rowlen);
tensorflow::uint8* in_bottom_row = in + (bottom_y_index * image_rowlen);
float *out_row = out + (y * wanted_width * wanted_channels);
for (int x = 0; x < wanted_width; ++x) {
const float in_x = x * width_scale;
const int left_x_index = static_cast<int>(floorf(in_x));
const int right_x_index =
std::min(static_cast<int>(ceilf(in_x)), (image_width - 1));
// Index into the source image with its own channel count, which may exceed
// the three channels copied into the output tensor.
tensorflow::uint8* in_top_left_pixel =
in_top_row + (left_x_index * image_channels);
tensorflow::uint8* in_top_right_pixel =
in_top_row + (right_x_index * image_channels);
tensorflow::uint8* in_bottom_left_pixel =
in_bottom_row + (left_x_index * image_channels);
tensorflow::uint8* in_bottom_right_pixel =
in_bottom_row + (right_x_index * image_channels);
const float x_lerp = in_x - left_x_index;
float *out_pixel = out_row + (x * wanted_channels);
for (int c = 0; c < wanted_channels; ++c) {
const float top_left((in_top_left_pixel[c] - input_mean) / input_std);
const float top_right((in_top_right_pixel[c] - input_mean) / input_std);
const float bottom_left((in_bottom_left_pixel[c] - input_mean) / input_std);
const float bottom_right((in_bottom_right_pixel[c] - input_mean) / input_std);
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom =
bottom_left + (bottom_right - bottom_left) * x_lerp;
out_pixel[c] = top + (bottom - top) * y_lerp;
}
}
}
out_tensors->push_back(image_tensor);
return Status::OK();
}
Status LoadJpegFile(string file_name, std::vector<tensorflow::uint8>* data,
int* width, int* height, int* channels) {
struct jpeg_decompress_struct cinfo;
FILE * infile;
JSAMPARRAY buffer;
int row_stride;
if ((infile = fopen(file_name.c_str(), "rb")) == NULL) {
LOG(ERROR) << "Can't open " << file_name;
return tensorflow::errors::NotFound("JPEG file ", file_name,
" not found");
}
struct jpeg_error_mgr jerr;
jmp_buf jpeg_jmpbuf;
cinfo.err = jpeg_std_error(&jerr);
cinfo.client_data = &jpeg_jmpbuf;
// Note: no custom error_exit handler is installed (CatchError is not defined in
// this file), so libjpeg's default handler calls exit() on a decode error and the
// setjmp recovery path below is never actually reached.
//jerr.error_exit = CatchError;
if (setjmp(jpeg_jmpbuf)) {
fclose(infile);
return tensorflow::errors::Unknown("JPEG decoding failed");
}
jpeg_create_decompress(&cinfo);
jpeg_stdio_src(&cinfo, infile);
jpeg_read_header(&cinfo, TRUE);
jpeg_start_decompress(&cinfo);
*width = cinfo.output_width;
*height = cinfo.output_height;
*channels = cinfo.output_components;
data->resize((*height) * (*width) * (*channels));
row_stride = cinfo.output_width * cinfo.output_components;
buffer = (*cinfo.mem->alloc_sarray)
((j_common_ptr) &cinfo, JPOOL_IMAGE, row_stride, 1);
while (cinfo.output_scanline < cinfo.output_height) {
tensorflow::uint8* row_address = &((*data)[cinfo.output_scanline * row_stride]);
jpeg_read_scanlines(&cinfo, buffer, 1);
memcpy(row_address, buffer[0], row_stride);
}
jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
fclose(infile);
return Status::OK();
}
Status PrintTopLabels(const std::vector<Tensor>& outputs, string labels_file_name)
{
std::vector<string> labels;
size_t label_count;
Status read_labels_status = ReadLabelsFile(labels_file_name, &labels, &label_count);
if (!read_labels_status.ok()) {
LOG(ERROR) << read_labels_status;
return read_labels_status;
}
const int how_many_labels = std::min(5, static_cast<int>(label_count));
Tensor indices;
Tensor scores;
TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
for (int pos = 0; pos < how_many_labels; ++pos) {
const int label_index = indices_flat(pos);
const float score = scores_flat(pos);
result_labels.append(
QString::fromStdString(labels[label_index]) + " (" +
QString::number(label_index) + "): " + QString::number(score) + "\n"
);
}
return Status::OK();
}
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels, Tensor* out_indices, Tensor* out_scores) {
const Tensor& unsorted_scores_tensor = outputs[0];
auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
std::vector<std::pair<int, float>> scores;
for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
}
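// Sort every (index, score) pair by descending score and keep only the top entries.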
std::sort(scores.begin(), scores.end(),
[](const std::pair<int, float> &left,
const std::pair<int, float> &right) {
return left.second > right.second;
});
scores.resize(how_many_labels);
Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
for (int i = 0; i < how_many_labels; ++i) {
sorted_indices.flat<int>()(i) = scores[i].first;
sorted_scores.flat<float>()(i) = scores[i].second;
}
*out_indices = sorted_indices;
*out_scores = sorted_scores;
return Status::OK();
}
Status ReadLabelsFile(string file_name, std::vector<string>* result, size_t* found_label_count)
{
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Labels file ", file_name, " not found.");
}
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
*found_label_count = result->size();
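// Pad with empty strings so the label count is a multiple of 16, matching the
// size of the graph's output vector (as in TensorFlow's label_image example).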
const int padding = 16;
while (result->size() % padding) {
result->emplace_back();
}
return Status::OK();
}
public:
std::unique_ptr<tensorflow::Session> session;
QString result_labels;
};
class TensorFlowPlugin : public QQmlExtensionPlugin
{
Q_OBJECT
Q_PLUGIN_METADATA(IID "io.dt42.TensorFlow" FILE "tensorflow.json")
public:
void registerTypes(const char *uri)
{
qmlRegisterType<TensorFlowModel>(uri, 1, 0, "TensorFlow");
}
};
#include "plugin.moc"