Skip to content

Instantly share code, notes, and snippets.

@saitodev
Created October 2, 2016 10:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saitodev/3fcfc1b3b5ef05ece1ee47c639280687 to your computer and use it in GitHub Desktop.
Save saitodev/3fcfc1b3b5ef05ece1ee47c639280687 to your computer and use it in GitHub Desktop.
// -*- coding: utf-8 -*-
#include <iostream>
#include <fstream>
#include <memory>
#include <vector>
#include <cassert>
#include <cstdint>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
// Reads a binary GraphDef from `graph_filename` and creates a session that
// holds that graph.
//
// graph_filename: path of the serialized graph to load.
// session:        output; set to the newly created session once the session
//                 object itself has been created.
//
// Returns NotFound when the graph file cannot be read, otherwise the status
// reported by NewSession() or Session::Create().
tensorflow::Status
LoadGraph(tensorflow::string graph_filename,
          std::unique_ptr<tensorflow::Session>* session)
{
  tensorflow::GraphDef graph_def;
  tensorflow::Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_filename, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_filename, "'");
  }

  // Use the Status-returning overload of NewSession(): the pointer-returning
  // overload yields nullptr on failure, which would be dereferenced by the
  // Create() call below.
  tensorflow::Session* raw_session = nullptr;
  tensorflow::Status session_new_status =
      tensorflow::NewSession(tensorflow::SessionOptions(), &raw_session);
  if (!session_new_status.ok()) {
    return session_new_status;
  }
  session->reset(raw_session);

  return (*session)->Create(graph_def);
}
// Reads the gzip-compressed image file `filename` and returns its pixels as a
// DT_FLOAT tensor of shape {10000, 784}.
//
// The first 16 bytes of the decompressed stream are skipped as a header.
// Every following byte is treated as one unsigned pixel value and divided by
// 255, so stored values lie in [0, 1]. Bytes are stored in stream order, i.e.
// the pixel at row h, column w of image n lands at (n, h * 28 + w).
//
// The return type cannot carry an error, so a file that cannot be opened or
// that ends early terminates the process via CHECK. Unlike assert(), CHECK
// stays active when NDEBUG is defined, so a bad file can no longer produce a
// silently corrupted tensor.
tensorflow::Tensor
LoadMnistImages(tensorflow::string filename)
{
  const int N_header = 16;
  const int N_data = 10000;
  const int N_width = 28;
  const int N_height = 28;
  const int N_vec = N_width * N_height;

  auto tensor = tensorflow::Tensor(tensorflow::DT_FLOAT,
                                   tensorflow::TensorShape({N_data, N_vec}));
  auto mat = tensor.tensor<float, 2>();
  mat.setZero();

  std::ifstream fin(filename, std::ios_base::in | std::ios_base::binary);
  CHECK(!fin.fail()) << "Failed to open image file '" << filename << "'";

  boost::iostreams::filtering_istream s;
  s.push(boost::iostreams::gzip_decompressor());
  s.push(fin);

  // Skip the header, but make sure it was really there.
  char header[N_header];
  s.read(header, N_header);
  CHECK(s.gcount() == N_header)
      << "Image file '" << filename << "' is too short to contain a header";

  // Read one whole image per call and verify the byte count, so that a
  // truncated or corrupt stream is detected instead of reusing stale data.
  std::vector<char> pixels(N_vec);
  for (int n = 0; n < N_data; ++n) {
    s.read(pixels.data(), N_vec);
    CHECK(s.gcount() == N_vec)
        << "Image file '" << filename << "' ended while reading image " << n;
    for (int i = 0; i < N_vec; ++i) {
      // Go through uint8_t first: plain char may be signed.
      mat(n, i) = static_cast<float>(static_cast<uint8_t>(pixels[i])) / 255.0;
    }
  }
  return tensor;
}
// Reads the gzip-compressed label file `filename` and returns the labels as a
// one-hot DT_FLOAT tensor of shape {10000, 10}.
//
// The first 8 bytes of the decompressed stream are skipped as a header. Each
// following byte is one label in the range 0..9; row n of the result is all
// zeros except for a 1 in the column given by label n.
//
// The return type cannot carry an error, so a file that cannot be opened,
// that ends early, or that contains a label outside 0..9 terminates the
// process via CHECK. Unlike assert(), CHECK stays active when NDEBUG is
// defined, so an invalid label can never be used as an out-of-bounds index.
tensorflow::Tensor
LoadMnistLabels(tensorflow::string filename)
{
  const int N_header = 8;
  const int N_data = 10000;
  const int N_vec = 10;

  auto tensor = tensorflow::Tensor(tensorflow::DT_FLOAT,
                                   tensorflow::TensorShape({N_data, N_vec}));
  auto mat = tensor.tensor<float, 2>();
  mat.setZero();

  std::ifstream fin(filename, std::ios_base::in | std::ios_base::binary);
  CHECK(!fin.fail()) << "Failed to open label file '" << filename << "'";

  boost::iostreams::filtering_istream s;
  s.push(boost::iostreams::gzip_decompressor());
  s.push(fin);

  // Skip the header, but make sure it was really there.
  char header[N_header];
  s.read(header, N_header);
  CHECK(s.gcount() == N_header)
      << "Label file '" << filename << "' is too short to contain a header";

  // Read all labels in one call and verify that none are missing.
  std::vector<char> labels(N_data);
  s.read(labels.data(), N_data);
  CHECK(s.gcount() == N_data)
      << "Label file '" << filename << "' holds only " << s.gcount()
      << " labels, expected " << N_data;

  for (int n = 0; n < N_data; ++n) {
    // Interpret the byte as unsigned so a single upper-bound test suffices
    // regardless of whether plain char is signed on this platform.
    const int label = static_cast<uint8_t>(labels[n]);
    CHECK(label < N_vec)
        << "Label " << label << " at index " << n << " in '" << filename
        << "' is out of range";
    mat(n, label) = 1.0f;
  }
  return tensor;
}
int main(int argc, char* argv[])
{
tensorflow::string graph_filename = "trained_graph.pb";
tensorflow::string image_filename = "MNIST_data/t10k-images-idx3-ubyte.gz";
tensorflow::string label_filename = "MNIST_data/t10k-labels-idx1-ubyte.gz";
const bool parse_result = tensorflow::ParseFlags(
&argc, argv,
{tensorflow::Flag("graph", &graph_filename),
tensorflow::Flag("image", &image_filename),
tensorflow::Flag("label", &label_filename)});
if (!parse_result) {
LOG(ERROR) << "Error parsing command-line flags.";
return -1;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1];
return -1;
}
std::unique_ptr<tensorflow::Session> session;
auto load_graph_status = LoadGraph(graph_filename, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status.error_message();
return -1;
}
auto x = LoadMnistImages(image_filename);
auto y_ = LoadMnistLabels(label_filename);
std::vector<tensorflow::Tensor> outputs;
auto session_run_status = session->Run({{"x:0", x}, {"y_:0", y_}},
{"accuracy:0"},
{},
&outputs);
if (!session_run_status.ok()) {
LOG(ERROR) << session_run_status.error_message();
return -1;
}
float accuracy = outputs[0].scalar<float>()(0);
std::cout << "accuracy = " << accuracy << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment