Training TensorFlow models in C++

Python is the primary language in which TensorFlow models are typically developed and trained. TensorFlow does have bindings for other programming languages. These bindings provide the low-level primitives required to build a more complete API; however, they lack much of the higher-level API richness of the Python bindings, particularly for defining the model structure.

This file demonstrates taking a model (a TensorFlow graph) created by a Python program and running the training loop in C++.

The model

The model is a trivial one that tries to learn the function f(x) = W * x + b, where W and b are model parameters. The training data is constructed so that the "true" value of W is 3 and that of b is 2, i.e., f(x) = 3 * x + 2.
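To show what "learning" W and b means here, the sketch below writes out by hand the update that the graph's train op performs for this model: one step of gradient descent on the mean squared error. It is plain C++ with no TensorFlow dependency. `LinearParams` and `GradientStep` are hypothetical names. The sketch also simplifies two things. It starts from W = b = 0, whereas TensorFlow's dense layer initializes its kernel randomly. Its gradients are derived manually, whereas TensorFlow computes them automatically.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Parameters of the linear model f(x) = W * x + b (hypothetical stand-in
// for the variables created by tf.layers.dense).
struct LinearParams {
  float W = 0.0f;
  float b = 0.0f;
};

// One gradient-descent step on the mean squared error
//   loss = mean((W * x + b - y)^2)
// using the hand-derived gradients
//   dloss/dW = mean(2 * (W * x + b - y) * x)
//   dloss/db = mean(2 * (W * x + b - y)).
void GradientStep(LinearParams& p, const std::vector<float>& xs,
                  const std::vector<float>& ys, float learning_rate) {
  assert(xs.size() == ys.size() && !xs.empty());
  float grad_W = 0.0f;
  float grad_b = 0.0f;
  for (std::size_t i = 0; i < xs.size(); ++i) {
    const float err = p.W * xs[i] + p.b - ys[i];
    grad_W += 2.0f * err * xs[i];
    grad_b += 2.0f * err;
  }
  const float n = static_cast<float>(xs.size());
  p.W -= learning_rate * grad_W / n;
  p.b -= learning_rate * grad_b / n;
}
```

Suppose `GradientStep` is called repeatedly on inputs in [0, 1], with targets 3 * x + 2 and a learning rate such as 0.1. Then W and b converge toward 3 and 2.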


  • Python code that constructs the model and saves the computational graph in a file called graph.pb. The other files assume that this script has been run once.
  • C++ code that loads the model, optionally restores model weights saved in a checkpoint, trains for a few steps, and writes the updated model weights to a checkpoint.
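The C++ program is meant to be run more than once. The first run initializes the weights. Each later run restores the previous checkpoint and continues training. The sketch below models only that control flow, in plain C++. `Params`, `InitOrRestore`, `Save` and `RunOnce` are hypothetical names. The plain-text file is an assumed stand-in for a checkpoint and does not imitate TensorFlow's checkpoint format.

```cpp
#include <cassert>
#include <cstddef>
#include <filesystem>
#include <fstream>
#include <limits>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical stand-in for the model variables W and b.
struct Params {
  float W = 0.0f;
  float b = 0.0f;
};

// Analogous to choosing between Model::Init and Model::Restore: read the
// checkpoint if a previous run wrote one, otherwise keep the initial values.
Params InitOrRestore(const fs::path& ckpt) {
  Params p;
  std::ifstream in(ckpt);
  if (in) {
    in >> p.W >> p.b;
  }
  return p;
}

// Analogous to Model::Checkpoint, but writes plain text. Printing with
// max_digits10 significant digits lets each float be read back exactly.
void Save(const Params& p, const fs::path& ckpt) {
  fs::create_directories(ckpt.parent_path());
  std::ofstream out(ckpt);
  out.precision(std::numeric_limits<float>::max_digits10);
  out << p.W << ' ' << p.b << '\n';
}

// One program run: init or restore, `steps` gradient-descent steps on
// y = 3 * x + 2 with learning rate 0.1, then save.
Params RunOnce(const fs::path& ckpt, int steps) {
  assert(steps >= 0);
  Params p = InitOrRestore(ckpt);
  const std::vector<float> xs = {0.0f, 0.5f, 1.0f};
  const float n = static_cast<float>(xs.size());
  for (int s = 0; s < steps; ++s) {
    float grad_W = 0.0f;
    float grad_b = 0.0f;
    for (float x : xs) {
      const float err = p.W * x + p.b - (3.0f * x + 2.0f);
      grad_W += 2.0f * err * x;
      grad_b += 2.0f * err;
    }
    p.W -= 0.1f * grad_W / n;
    p.b -= 0.1f * grad_b / n;
  }
  Save(p, ckpt);
  return p;
}
```

Two runs of 50 steps each end in the same state as one run of 100 steps. This is because the second run picks up exactly where the first one stopped.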


  • The Python APIs for TensorFlow include other conveniences for training (such as MonitoredSession and tf.estimator.Estimator), which make it easier to configure checkpointing, evaluation loops, etc. The examples here are deliberately simple and focus on basic model training only.
  • In this example, we use placeholders and feed dictionaries to feed input, but in a real program you probably want to use TensorFlow's input pipeline APIs to construct a pipeline that provides training data to the model.
  • Although not demonstrated here, summaries for TensorBoard can also be produced by running the summary operations.
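The sketch below is not TensorFlow's input pipeline API. It shows only a smaller idea: move batch generation out of the training loop and into its own function. That function uses `<random>` with an explicitly seeded generator, so runs are reproducible. `Batch` and `MakeBatch` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// A batch of (x, y) training pairs for the target function y = 3 * x + 2.
struct Batch {
  std::vector<float> inputs;
  std::vector<float> targets;
};

// Hypothetical helper: draws `batch_size` inputs uniformly from [0, 1] and
// computes the matching targets. The caller owns the generator, so the same
// seed always yields the same sequence of batches.
Batch MakeBatch(std::mt19937& rng, std::size_t batch_size) {
  assert(batch_size > 0);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  Batch batch;
  batch.inputs.reserve(batch_size);
  batch.targets.reserve(batch_size);
  for (std::size_t i = 0; i < batch_size; ++i) {
    const float x = uniform(rng);
    batch.inputs.push_back(x);
    batch.targets.push_back(3.0f * x + 2.0f);
  }
  return batch;
}
```

The C++ training loop could call `MakeBatch(rng, 10)` in place of its inner `std::rand()` loop. It would then pass `batch.inputs` and `batch.targets` to `model.RunTrainStep`.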

See Also

import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.layers, tf.train)
# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
# tf.train.Saver.__init__ adds operations to the graph to save
# and restore variables.
saver_def = tf.train.Saver().as_saver_def()
print('Run this operation to initialize variables : ', init.name)
print('Run this operation for a train step : ', train_op.name)
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)
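# Write the graph out to a file. SerializeToString() returns bytes, so the
# file must be opened in binary mode ('wb'); opening it with 'w' fails on
# Python 3 with "TypeError: write() argument must be str, not bytes".
with open('graph.pb', 'wb') as f:
  f.write(tf.get_default_graph().as_graph_def().SerializeToString())
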
// Example of training the model created by the Python script (graph.pb) in a C++ program.
// See also
#include <sys/stat.h>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

using std::string;
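class Model {
 public:
  // Loads the GraphDef written by the Python script and creates a session for it.
  explicit Model(const string& graph_def_filename) {
    tensorflow::GraphDef graph_def;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                            graph_def_filename, &graph_def));
    session_.reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_CHECK_OK(session_->Create(graph_def));
  }

  // Runs the "init" op, which initializes all variables.
  void Init() { TF_CHECK_OK(session_->Run({}, {}, {"init"}, nullptr)); }

  // Restores variable values from the checkpoint at checkpoint_prefix.
  void Restore(const string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/restore_all");
  }

  // Prints the model's prediction for each input in batch.
  void Predict(const std::vector<float>& batch) {
    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(batch)}}, {"output"}, {},
                              &out_tensors));
    for (size_t i = 0; i < batch.size(); ++i) {
      std::cout << "\t x = " << batch[i]
                << ", predicted y = " << out_tensors[0].flat<float>()(i)
                << "\n";
    }
  }

  // Runs the "train" op once on a batch of (input, target) pairs.
  void RunTrainStep(const std::vector<float>& input_batch,
                    const std::vector<float>& target_batch) {
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(input_batch)},
                               {"target", MakeTensor(target_batch)}},
                              {}, {"train"}, nullptr));
  }

  // Saves the current variable values to a checkpoint at checkpoint_prefix.
  void Checkpoint(const string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/control_dependency");
  }

 private:
  // Packs a batch of scalars into a [batch_size, 1, 1] float tensor.
  tensorflow::Tensor MakeTensor(const std::vector<float>& batch) {
    tensorflow::Tensor t(tensorflow::DT_FLOAT,
                         tensorflow::TensorShape({(int)batch.size(), 1, 1}));
    for (size_t i = 0; i < batch.size(); ++i) {
      t.flat<float>()(i) = batch[i];
    }
    return t;
  }

  // Feeds the checkpoint prefix to the saver's filename tensor ("save/Const")
  // and runs op_name (the saver's save or restore op).
  void SaveOrRestore(const string& checkpoint_prefix, const string& op_name) {
    tensorflow::Tensor t(tensorflow::DT_STRING, tensorflow::TensorShape());
    t.scalar<string>()() = checkpoint_prefix;
    TF_CHECK_OK(session_->Run({{"save/Const", t}}, {}, {op_name}, nullptr));
  }

  std::unique_ptr<tensorflow::Session> session_;
};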
// Returns true if dir exists and is a directory.
bool DirectoryExists(const string& dir) {
  struct stat buf;
  return stat(dir.c_str(), &buf) == 0 && S_ISDIR(buf.st_mode);
}
int main(int argc, char* argv[]) {
  // The graph.pb file written by the Python script (assumed to be in the working directory).
  const string graph_def_filename = "graph.pb";
  // Point this at a writable directory on your machine.
  const string checkpoint_dir = "/usr/local/google/home/ashankar/tmp/gist/checkpoints";
  const string checkpoint_prefix = checkpoint_dir + "/checkpoint";
  // Restore only if a previous run has already written a checkpoint.
  bool restore = DirectoryExists(checkpoint_dir);

  // Set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  std::cout << "Loading graph\n";
  Model model(graph_def_filename);

  if (!restore) {
    std::cout << "Initializing model weights\n";
    model.Init();
  } else {
    std::cout << "Restoring model weights from checkpoint\n";
    model.Restore(checkpoint_prefix);
  }

  const std::vector<float> testdata({1.0, 2.0, 3.0});
  std::cout << "Initial predictions\n";
  model.Predict(testdata);

  std::cout << "Training for a few steps\n";
  for (int i = 0; i < 200; ++i) {
    // Each step draws a fresh batch of 10 inputs in [0, 1] with targets 3 * x + 2.
    std::vector<float> train_inputs, train_targets;
    for (int j = 0; j < 10; j++) {
      train_inputs.push_back(static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX));
      train_targets.push_back(3 * train_inputs.back() + 2);
    }
    model.RunTrainStep(train_inputs, train_targets);
  }

  std::cout << "Updated predictions\n";
  model.Predict(testdata);

  std::cout << "Saving checkpoint\n";
  model.Checkpoint(checkpoint_prefix);

  return 0;
}
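MakeTensor fills a [batch_size, 1, 1] tensor through its flat view, writing batch element i at flat index i. Predict reads the [batch_size, 1, 1] output the same way. This works because TensorFlow stores tensor elements in row-major (C) order, so element (i, 0, 0) sits at offset i. The sketch below computes row-major offsets for any shape, independent of TensorFlow. `RowMajorOffset` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: flat offset of `index` in a row-major tensor of the
// given `shape` (the last dimension varies fastest).
std::size_t RowMajorOffset(const std::vector<std::size_t>& shape,
                           const std::vector<std::size_t>& index) {
  assert(shape.size() == index.size());
  std::size_t offset = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    assert(index[d] < shape[d]);
    // Shift the offset accumulated so far by this dimension, then add the index.
    offset = offset * shape[d] + index[d];
  }
  return offset;
}
```

For the [batch_size, 1, 1] tensors used here, every index has the form (i, 0, 0), so the offset is just i. For a shape such as [2, 3, 4], index (1, 2, 3) maps to 1 * 12 + 2 * 4 + 3 = 23.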


@MoonXu0722 MoonXu0722 commented Jul 15, 2019

TypeError: write() argument must be str, not bytes



@lqzmforer lqzmforer commented Jul 26, 2019

TypeError: write() argument must be str, not bytes




@Ashwa2001 Ashwa2001 commented Jun 26, 2020

TypeError: write() argument must be str, not bytes

Try this:
with open('graph.pb', 'wb') as f:
