@asimshankar
Last active January 28, 2024 17:24
Training TensorFlow models in C++

Python is the primary language in which TensorFlow models are typically developed and trained. TensorFlow does have bindings for other programming languages. These bindings provide the low-level primitives required to build a more complete API; however, they lack much of the higher-level API richness of the Python bindings, particularly for defining the model structure.

This gist demonstrates taking a model (a TensorFlow graph) created by a Python program and running the training loop in C++.

The model

The model is a trivial one, trying to learn the function f(x) = W * x + b, where W and b are model parameters. The training data is constructed so that the "true" value of W is 3 and that of b is 2, i.e., f(x) = 3 * x + 2.
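To make the training objective concrete, here is a TensorFlow-free sketch of the same computation: full-batch gradient descent on the mean squared error (the loss in model.py) with learning rate 0.01 (the GradientDescentOptimizer setting). Params, GradientStep and Fit are hypothetical names for this illustration only. They are not part of TensorFlow or train.cc, and plain doubles stand in for the graph's variables.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the model's parameters in f(x) = W * x + b.
struct Params {
  double W;
  double b;
};

// One full-batch gradient-descent step on
//   loss = mean((W * x + b - y)^2),
// using dloss/dW = mean(2 * err * x) and dloss/db = mean(2 * err).
Params GradientStep(const Params& p, const std::vector<double>& xs,
                    const std::vector<double>& ys, double learning_rate) {
  assert(xs.size() == ys.size() && !xs.empty());
  double grad_W = 0.0;
  double grad_b = 0.0;
  for (std::size_t i = 0; i < xs.size(); ++i) {
    const double err = p.W * xs[i] + p.b - ys[i];  // prediction - target
    grad_W += 2.0 * err * xs[i];
    grad_b += 2.0 * err;
  }
  const double n = static_cast<double>(xs.size());
  return {p.W - learning_rate * grad_W / n, p.b - learning_rate * grad_b / n};
}

// Applies `steps` gradient steps to `init` on fixed data.
Params Fit(Params init, const std::vector<double>& xs,
           const std::vector<double>& ys, double learning_rate, int steps) {
  for (int s = 0; s < steps; ++s) {
    init = GradientStep(init, xs, ys, learning_rate);
  }
  return init;
}
```

For example, with xs = {0, 0.25, 0.5, 0.75, 1} and ys = 3 * x + 2, Fit({0.0, 0.0}, xs, ys, 0.01, 10000) lands within 0.001 of W = 3 and b = 2, whereas 200 steps leave W at roughly 2. At this learning rate, a single 200-step run of train.cc is therefore unlikely to finish the fit on its own; restoring from the checkpoint lets each run pick up where the previous one stopped.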

Files

  • model.py: Python code that constructs a model and saves the computational graph in a file called graph.pb. All other files assume that model.py has been run once.
  • train.cc: C++ code that loads the model, optionally restores model weights saved in a checkpoint, trains for a few steps, and writes the updated model weights to a checkpoint.
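The init-or-restore flow of train.cc can be illustrated without TensorFlow. The sketch below is hypothetical throughout. A Weights struct stands in for the model's variables, and a two-number text file stands in for a checkpoint (TensorFlow's real checkpoint format is entirely different). It mirrors only the control flow: the first run starts from initial values, and later runs continue from whatever the previous run saved.

```cpp
#include <cassert>
#include <cstdio>
#include <fstream>
#include <limits>
#include <string>

// Hypothetical stand-in for the model's two variables (not a TensorFlow type).
struct Weights {
  double W;
  double b;
};

// Writes the weights as text to `path`. A stand-in for a checkpoint only.
bool SaveWeights(const std::string& path, const Weights& w) {
  std::ofstream out(path);
  // max_digits10 significant digits let a double round-trip through text.
  out.precision(std::numeric_limits<double>::max_digits10);
  out << w.W << ' ' << w.b << '\n';
  out.close();
  return !out.fail();
}

// Restores saved weights if they can be read, otherwise returns `initial`.
// *restored reports which branch was taken.
Weights InitOrRestore(const std::string& path, const Weights& initial,
                      bool* restored) {
  assert(restored != nullptr);
  std::ifstream in(path);
  Weights w{};
  if (in >> w.W >> w.b) {
    *restored = true;
    return w;
  }
  *restored = false;
  return initial;
}

// Deletes the saved weights; returns true if a file was removed.
bool DeleteWeights(const std::string& path) {
  return std::remove(path.c_str()) == 0;
}
```

train.cc makes the same decision by checking whether its checkpoint directory exists, and then runs either the init op or save/restore_all.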

Noteworthy

  • The Python APIs for TensorFlow include other conveniences for training (such as tf.train.MonitoredSession and tf.estimator.Estimator), which make it easier to configure checkpointing, evaluation loops, etc. The examples here aren't that sophisticated and focus on basic model training only.
  • In this example, we use placeholders and feed dictionaries to feed input, but in a real program you probably want to use the tf.data API to construct an input pipeline that provides training data to the model.
  • Although not demonstrated here, summaries for TensorBoard can also be produced by running the summary operations.
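Feeding through placeholders means the C++ side must build tensors that match the placeholders' shape, [None, 1, 1]. TensorFlow tensors are laid out in row-major order, so for that shape the i-th example sits at flat index i. This is why MakeTensor in train.cc can fill a tensor with a single loop over flat<float>(). RowMajorOffset and PackBatch below are hypothetical helpers, and a plain std::vector stands in for tensorflow::Tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major offset of element (i, j, k) in a tensor of shape [n, d1, d2]:
// the last index varies fastest.
std::size_t RowMajorOffset(std::size_t i, std::size_t j, std::size_t k,
                           std::size_t d1, std::size_t d2) {
  assert(j < d1 && k < d2);
  return (i * d1 + j) * d2 + k;
}

// Lays out a batch of scalars as a [batch.size(), 1, 1] buffer, the shape of
// the "input" and "target" placeholders in model.py.
std::vector<float> PackBatch(const std::vector<float>& batch) {
  const std::size_t d1 = 1, d2 = 1;
  std::vector<float> buffer(batch.size() * d1 * d2);
  for (std::size_t i = 0; i < batch.size(); ++i) {
    buffer[RowMajorOffset(i, 0, 0, d1, d2)] = batch[i];
  }
  return buffer;
}
```

With d1 = d2 = 1 the offset reduces to i, so the packed buffer is just the batch in order. A larger trailing shape would need the full formula.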

See Also

model.py

# Written against the TensorFlow 1.x API (see the r1.4 link in train.cc).
import tensorflow as tf

# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')

# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')

# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')

init = tf.global_variables_initializer()

# tf.train.Saver.__init__ adds operations to the graph to save
# and restore variables.
saver_def = tf.train.Saver().as_saver_def()

print('Run this operation to initialize variables : ', init.name)
print('Run this operation for a train step : ', train_op.name)
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)

# Write the graph out to a file. SerializeToString() returns bytes,
# so the file must be opened in binary mode ('wb').
with open('graph.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())
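model.py prints tensor names such as save/control_dependency:0, where the :0 suffix selects the node's first output. Session::Run takes its targets as node names, which is why train.cc runs "save/control_dependency" and "save/restore_all" without a suffix. A small helper that derives node names from the printed tensor names might look like the sketch below. NodeNameFromTensorName is a hypothetical name, not a TensorFlow API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Turns a tensor name like "save/control_dependency:0" into the node name
// "save/control_dependency". Names without a numeric ":<index>" suffix are
// returned unchanged.
std::string NodeNameFromTensorName(const std::string& name) {
  // Control-input names ("^node") are out of scope for this sketch.
  assert(name.empty() || name[0] != '^');
  const std::size_t colon = name.rfind(':');
  // No ':' or nothing after it: treat the name as a node name already.
  if (colon == std::string::npos || colon + 1 == name.size()) return name;
  // Strip the suffix only if it is entirely digits (an output index).
  for (std::size_t i = colon + 1; i < name.size(); ++i) {
    if (name[i] < '0' || name[i] > '9') return name;
  }
  return name.substr(0, colon);
}
```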
train.cc

// Example of training the model created by model.py in a C++ program.
// Written against the TensorFlow 1.x C++ API.
//
// See also
// https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/label_image/main.cc
#include <sys/stat.h>

#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

using tensorflow::string;

// Wraps a Session holding the graph written by model.py.
class Model {
 public:
  explicit Model(const string& graph_def_filename) {
    tensorflow::GraphDef graph_def;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                            graph_def_filename, &graph_def));
    session_.reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_CHECK_OK(session_->Create(graph_def));
  }

  // Runs the variable initializer op ("init").
  void Init() { TF_CHECK_OK(session_->Run({}, {}, {"init"}, nullptr)); }

  // Restores all variables from the checkpoint at checkpoint_prefix.
  void Restore(const string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/restore_all");
  }

  // Prints the model's prediction for each value in batch.
  void Predict(const std::vector<float>& batch) {
    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(batch)}}, {"output"}, {},
                              &out_tensors));
    for (int i = 0; i < static_cast<int>(batch.size()); ++i) {
      std::cout << "\t x = " << batch[i]
                << ", predicted y = " << out_tensors[0].flat<float>()(i)
                << "\n";
    }
  }

  // Runs one step of the "train" op on a batch of inputs and targets.
  void RunTrainStep(const std::vector<float>& input_batch,
                    const std::vector<float>& target_batch) {
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(input_batch)},
                               {"target", MakeTensor(target_batch)}},
                              {}, {"train"}, nullptr));
  }

  // Writes all variables to a checkpoint at checkpoint_prefix.
  void Checkpoint(const string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/control_dependency");
  }

 private:
  // Copies batch into a float tensor of shape [batch.size(), 1, 1], the
  // shape of the "input" and "target" placeholders.
  tensorflow::Tensor MakeTensor(const std::vector<float>& batch) {
    tensorflow::Tensor t(
        tensorflow::DT_FLOAT,
        tensorflow::TensorShape(
            {static_cast<tensorflow::int64>(batch.size()), 1, 1}));
    for (int i = 0; i < static_cast<int>(batch.size()); ++i) {
      t.flat<float>()(i) = batch[i];
    }
    return t;
  }

  // Feeds the checkpoint prefix to the saver's filename tensor ("save/Const")
  // and runs op_name (the save or the restore op).
  void SaveOrRestore(const string& checkpoint_prefix, const string& op_name) {
    tensorflow::Tensor t(tensorflow::DT_STRING, tensorflow::TensorShape());
    t.scalar<string>()() = checkpoint_prefix;
    TF_CHECK_OK(session_->Run({{"save/Const", t}}, {}, {op_name}, nullptr));
  }

  std::unique_ptr<tensorflow::Session> session_;
};

// Returns true if anything (directory or file) exists at dir.
bool DirectoryExists(const string& dir) {
  struct stat buf;
  return stat(dir.c_str(), &buf) == 0;
}

int main(int argc, char* argv[]) {
  // Paths are relative to the directory in which model.py wrote graph.pb;
  // adjust them for your setup.
  const string graph_def_filename = "graph.pb";
  const string checkpoint_dir = "checkpoints";
  const string checkpoint_prefix = checkpoint_dir + "/checkpoint";
  bool restore = DirectoryExists(checkpoint_dir);

  // Set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  std::cout << "Loading graph\n";
  Model model(graph_def_filename);

  if (!restore) {
    std::cout << "Initializing model weights\n";
    model.Init();
  } else {
    std::cout << "Restoring model weights from checkpoint\n";
    model.Restore(checkpoint_prefix);
  }

  const std::vector<float> testdata({1.0, 2.0, 3.0});
  std::cout << "Initial predictions\n";
  model.Predict(testdata);

  std::cout << "Training for a few steps\n";
  for (int i = 0; i < 200; ++i) {
    // Each step uses a fresh batch of 10 examples of y = 3 * x + 2,
    // with x drawn uniformly from [0, 1].
    std::vector<float> train_inputs, train_targets;
    for (int j = 0; j < 10; j++) {
      train_inputs.push_back(static_cast<float>(std::rand()) /
                             static_cast<float>(RAND_MAX));
      train_targets.push_back(3 * train_inputs.back() + 2);
    }
    model.RunTrainStep(train_inputs, train_targets);
  }

  std::cout << "Updated predictions\n";
  model.Predict(testdata);

  std::cout << "Saving checkpoint\n";
  model.Checkpoint(checkpoint_prefix);
  return 0;
}
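train.cc draws its training data with std::rand() and never calls std::srand(), so every run sees the same sequence of batches (the C standard specifies that rand() behaves as if seeded with 1 in that case). If you want explicitly seeded, reproducible batches, a sketch using the C++ <random> header could look like the one below. MakeBatch is a hypothetical helper, and std::uniform_real_distribution samples from [0, 1) rather than the [0, 1] range that std::rand() / RAND_MAX covers.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Hypothetical helper: draws one batch of (input, target) pairs with inputs
// uniform in [0, 1) and targets 3 * x + 2, like the inner loop of train.cc.
std::pair<std::vector<float>, std::vector<float>> MakeBatch(std::mt19937& rng,
                                                            int batch_size) {
  assert(batch_size > 0);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  std::vector<float> inputs;
  std::vector<float> targets;
  inputs.reserve(static_cast<std::size_t>(batch_size));
  targets.reserve(static_cast<std::size_t>(batch_size));
  for (int j = 0; j < batch_size; ++j) {
    inputs.push_back(uniform(rng));
    targets.push_back(3.0f * inputs.back() + 2.0f);
  }
  return {inputs, targets};
}
```

One way to use it is to construct a single std::mt19937 with a fixed seed before the training loop, then call MakeBatch(rng, 10) once per step in place of the inner std::rand() loop.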
@MoonXu0722

TypeError: write() argument must be str, not bytes

@lqzmforer

TypeError: write() argument must be str, not bytes

same

@ashwathkris

TypeError: write() argument must be str, not bytes

Try this:
with open('graph.pb', 'wb') as f:

@prantoran

What if I trained a model in python and want to load it into a c++ application using Bazel? Btw, this gist was really helpful.
