Skip to content

Instantly share code, notes, and snippets.

@jimfleming
Last active January 2, 2022 09:44
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save jimfleming/4202e529042c401b17b7 to your computer and use it in GitHub Desktop.
Save jimfleming/4202e529042c401b17b7 to your computer and use it in GitHub Desktop.
Load a saved graph definition from Protobuf.
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
using namespace tensorflow;
int main(int argc, char* argv[]) {
// Initialize a tensorflow session
Session* session;
Status status = NewSession(SessionOptions(), &session);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Read in the protobuf graph we exported
// (The path seems to be relative to the cwd. Keep this in mind
// when using `bazel run` since the cwd isn't where you call
// `bazel run` but from inside a temp folder.)
GraphDef graph_def;
status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Add the graph to the session
status = session->Create(graph_def);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Setup inputs and outputs:
// Our graph doesn't require any inputs, since it specifies default values,
// but we'll change an input to demonstrate.
Tensor a(DT_FLOAT, TensorShape());
a.scalar<float>()() = 3.0;
Tensor b(DT_FLOAT, TensorShape());
b.scalar<float>()() = 2.0;
std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
{ "a", a },
{ "b", b },
};
// The session will initialize the outputs
std::vector<tensorflow::Tensor> outputs;
// Run the session, evaluating our "c" operation from the graph
status = session->Run(inputs, {"c"}, {}, &outputs);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Grab the first output (we only evaluated one graph node: "c")
// and convert the node to a scalar representation.
auto output_c = outputs[0].scalar<float>();
// (There are similar methods for vectors and matrices here:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)
// Print the results
std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>
std::cout << output_c() << "\n"; // 30
// Free any resources used by the session
session->Close();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment