Last active April 13, 2024 15:50
Training TensorFlow models in C

Python is the primary language in which TensorFlow models are typically developed and trained. TensorFlow does have bindings for other programming languages. These bindings have the low-level primitives that are required to build a more complete API; however, they lack much of the higher-level API richness of the Python bindings, particularly for defining the model structure.

This gist demonstrates taking a model (a TensorFlow graph) created by a Python program and running the training loop in a C program.

The model

The model is a trivial one, trying to learn the function f(x) = W * x + b, where W and b are model parameters. The training data is constructed so that the "true" value of W is 3 and that of b is 2, i.e., f(x) = 3 * x + 2.
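Each run of the graph's train op performs one step of gradient descent on the mean squared error L = mean((W * x + b - y)^2), whose gradients are dL/dW = (2/N) * sum((W * x_i + b - y_i) * x_i) and dL/db = (2/N) * sum(W * x_i + b - y_i). To show what that computes without TensorFlow, here is a minimal plain-C sketch. It assumes the same data scheme as train.c: batches of 10, x drawn uniformly from [0, 1], and y = 3 * x + 2. `train_linear` is a hypothetical helper, not part of the gist. The learning rate, step count and seed used in the test are illustrative and larger than train.c's 0.01 and 200, so that the estimates end up near W = 3 and b = 2.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical helper: fits f(x) = w * x + b to samples of y = 3x + 2 with
 * mini-batch gradient descent on the mean squared error, i.e. the update
 * that tf.train.GradientDescentOptimizer applies to
 * tf.reduce_mean(tf.square(y_ - y)). *w and *b hold the starting values
 * on entry and the trained values on return. */
void train_linear(int steps, double learning_rate, unsigned seed,
                  double* w, double* b) {
  enum { kBatchSize = 10 }; /* same batch size as NextBatchForTraining */
  assert(steps >= 0 && learning_rate > 0.0);
  srand(seed);
  for (int step = 0; step < steps; ++step) {
    double grad_w = 0.0, grad_b = 0.0;
    for (int i = 0; i < kBatchSize; ++i) {
      double x = (double)rand() / (double)RAND_MAX; /* x in [0, 1] */
      double y = 3.0 * x + 2.0;                      /* "true" target */
      double err = (*w * x + *b) - y;                /* prediction error */
      /* Derivatives of err^2 with respect to w and b. */
      grad_w += 2.0 * err * x;
      grad_b += 2.0 * err;
    }
    /* Average over the batch (the "mean" in the loss), then step downhill. */
    *w -= learning_rate * grad_w / kBatchSize;
    *b -= learning_rate * grad_b / kBatchSize;
  }
}
```

The targets are generated without noise, so the exact solution W = 3, b = 2 has zero gradient on every batch. As a result, even this stochastic version settles on that solution rather than jittering around it.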


  • Python code that constructs a model and saves the computational graph in a file called graph.pb. All other files assume that this has been run once.
  • train.c: C code that loads the model, optionally loads model weights saved in a checkpoint, trains a few steps, writes the updated model weights to a checkpoint.
  • Trivial script to compile and run train.c


  • The Python APIs for TensorFlow include other conveniences for training (such as MonitoredSession and tf.train.Estimator), which make it easier to configure checkpointing, evaluation loops, etc. The examples here aren't that sophisticated and focus on basic model training only.
  • In this example, we use placeholders and feed dictionaries to feed input, but in a real application you probably want to use the API to construct an input pipeline for providing training data to the model.
  • Not demonstrated here, but summaries for TensorBoard can also be produced by executing the summary operations.

See Also

import tensorflow as tf
# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
# tf.train.Saver.__init__ adds operations to the graph to save
# and restore variables.
saver_def = tf.train.Saver().as_saver_def()
print('Run this operation to initialize variables : ', init.name)
print('Run this operation for a train step : ', train_op.name)
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)
# Write the graph out to a file.
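with open('graph.pb', 'wb') as f:  # 'wb': the serialized GraphDef is binary
  f.write(tf.get_default_graph().as_graph_def().SerializeToString())
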
// Example of training the model created by the Python script, using the
// TensorFlow C API.
// To run, use the compile-and-run script in this directory.
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <unistd.h>
#include <tensorflow/c/c_api.h>
typedef struct model_t {
  TF_Graph* graph;
  TF_Session* session;
  TF_Status* status;

  TF_Output input, target, output;

  TF_Operation *init_op, *train_op, *save_op, *restore_op;
  TF_Output checkpoint_file;
} model_t;
int ModelCreate(model_t* model, const char* graph_def_filename);
void ModelDestroy(model_t* model);
int ModelInit(model_t* model);
int ModelPredict(model_t* model, float* batch, int batch_size);
int ModelRunTrainStep(model_t* model);
enum SaveOrRestore { SAVE, RESTORE };
int ModelCheckpoint(model_t* model, const char* checkpoint_prefix, int type);
int Okay(TF_Status* status);
TF_Buffer* ReadFile(const char* filename);
TF_Tensor* ScalarStringTensor(const char* data, TF_Status* status);
int DirectoryExists(const char* dirname);
int main(int argc, char** argv) {
  const char* graph_def_filename = "graph.pb";
  const char* checkpoint_prefix = "./checkpoints/checkpoint";
  int restore = DirectoryExists("checkpoints");

  model_t model;
  printf("Loading graph\n");
  if (!ModelCreate(&model, graph_def_filename)) return 1;
  if (restore) {
    printf(
        "Restoring weights from checkpoint (remove the checkpoints directory "
        "to reset)\n");
    if (!ModelCheckpoint(&model, checkpoint_prefix, RESTORE)) return 1;
  } else {
    printf("Initializing model weights\n");
    if (!ModelInit(&model)) return 1;
  }

  float testdata[3] = {1.0, 2.0, 3.0};
  printf("Initial predictions\n");
  if (!ModelPredict(&model, &testdata[0], 3)) return 1;

  printf("Training for a few steps\n");
  for (int i = 0; i < 200; ++i) {
    if (!ModelRunTrainStep(&model)) return 1;
  }

  printf("Updated predictions\n");
  if (!ModelPredict(&model, &testdata[0], 3)) return 1;

  printf("Saving checkpoint\n");
  if (!ModelCheckpoint(&model, checkpoint_prefix, SAVE)) return 1;

  ModelDestroy(&model);
  return 0;
}

int ModelCreate(model_t* model, const char* graph_def_filename) {
  model->status = TF_NewStatus();
  model->graph = TF_NewGraph();

  {
    // Create the session.
    TF_SessionOptions* opts = TF_NewSessionOptions();
    model->session = TF_NewSession(model->graph, opts, model->status);
    TF_DeleteSessionOptions(opts);
    if (!Okay(model->status)) return 0;
  }

  TF_Graph* g = model->graph;

  {
    // Import the graph.
    TF_Buffer* graph_def = ReadFile(graph_def_filename);
    if (graph_def == NULL) return 0;
    printf("Read GraphDef of %zu bytes\n", graph_def->length);
    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(g, graph_def, opts, model->status);
    TF_DeleteImportGraphDefOptions(opts);
    TF_DeleteBuffer(graph_def);
    if (!Okay(model->status)) return 0;
  }

  // Handles to the interesting operations in the graph.
  model->input.oper = TF_GraphOperationByName(g, "input");
  model->input.index = 0;
  model->target.oper = TF_GraphOperationByName(g, "target");
  model->target.index = 0;
  model->output.oper = TF_GraphOperationByName(g, "output");
  model->output.index = 0;

  model->init_op = TF_GraphOperationByName(g, "init");
  model->train_op = TF_GraphOperationByName(g, "train");
  model->save_op = TF_GraphOperationByName(g, "save/control_dependency");
  model->restore_op = TF_GraphOperationByName(g, "save/restore_all");

  model->checkpoint_file.oper = TF_GraphOperationByName(g, "save/Const");
  model->checkpoint_file.index = 0;
  return 1;
}

void ModelDestroy(model_t* model) {
  TF_DeleteSession(model->session, model->status);
  Okay(model->status);
  TF_DeleteGraph(model->graph);
  TF_DeleteStatus(model->status);
}

int ModelInit(model_t* model) {
  const TF_Operation* init_op[1] = {model->init_op};
  TF_SessionRun(model->session, NULL,
                /* No inputs */
                NULL, NULL, 0,
                /* No outputs */
                NULL, NULL, 0,
                /* Just the init operation */
                init_op, 1,
                /* No metadata */
                NULL, model->status);
  return Okay(model->status);
}

int ModelCheckpoint(model_t* model, const char* checkpoint_prefix, int type) {
  TF_Tensor* t = ScalarStringTensor(checkpoint_prefix, model->status);
  if (!Okay(model->status)) {
    TF_DeleteTensor(t);
    return 0;
  }
  TF_Output inputs[1] = {model->checkpoint_file};
  TF_Tensor* input_values[1] = {t};
  const TF_Operation* op[1] = {type == SAVE ? model->save_op
                                            : model->restore_op};
  TF_SessionRun(model->session, NULL, inputs, input_values, 1,
                /* No outputs */
                NULL, NULL, 0,
                /* The operation */
                op, 1, NULL, model->status);
  TF_DeleteTensor(t);
  return Okay(model->status);
}

int ModelPredict(model_t* model, float* batch, int batch_size) {
  // batch consists of 1x1 matrices.
  const int64_t dims[3] = {batch_size, 1, 1};
  const size_t nbytes = batch_size * sizeof(float);
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  memcpy(TF_TensorData(t), batch, nbytes);

  TF_Output inputs[1] = {model->input};
  TF_Tensor* input_values[1] = {t};
  TF_Output outputs[1] = {model->output};
  TF_Tensor* output_values[1] = {NULL};

  TF_SessionRun(model->session, NULL, inputs, input_values, 1, outputs,
                output_values, 1,
                /* No target operations to run */
                NULL, 0, NULL, model->status);
  TF_DeleteTensor(t);
  if (!Okay(model->status)) return 0;

  if (TF_TensorByteSize(output_values[0]) != nbytes) {
    fprintf(stderr,
            "ERROR: Expected predictions tensor to have %zu bytes, has %zu\n",
            nbytes, TF_TensorByteSize(output_values[0]));
    TF_DeleteTensor(output_values[0]);
    return 0;
  }
  float* predictions = (float*)malloc(nbytes);
  memcpy(predictions, TF_TensorData(output_values[0]), nbytes);
  TF_DeleteTensor(output_values[0]);

  for (int i = 0; i < batch_size; ++i) {
    printf("\t x = %f, predicted y = %f\n", batch[i], predictions[i]);
  }
  free(predictions);
  return 1;
}

void NextBatchForTraining(TF_Tensor** inputs_tensor,
                          TF_Tensor** targets_tensor) {
#define BATCH_SIZE 10
  float inputs[BATCH_SIZE] = {0};
  float targets[BATCH_SIZE] = {0};
  // Draw each input uniformly from [0, 1] and compute its "true" target
  // f(x) = 3 * x + 2, which is the function the model should learn.
  for (int i = 0; i < BATCH_SIZE; ++i) {
    inputs[i] = (float)rand() / (float)RAND_MAX;
    targets[i] = 3.0 * inputs[i] + 2.0;
  }
  const int64_t dims[] = {BATCH_SIZE, 1, 1};
  size_t nbytes = BATCH_SIZE * sizeof(float);
  *inputs_tensor = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  *targets_tensor = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  memcpy(TF_TensorData(*inputs_tensor), inputs, nbytes);
  memcpy(TF_TensorData(*targets_tensor), targets, nbytes);
#undef BATCH_SIZE
}

int ModelRunTrainStep(model_t* model) {
  TF_Tensor *x, *y;
  NextBatchForTraining(&x, &y);
  TF_Output inputs[2] = {model->input, model->target};
  TF_Tensor* input_values[2] = {x, y};
  const TF_Operation* train_op[1] = {model->train_op};
  TF_SessionRun(model->session, NULL, inputs, input_values, 2,
                /* No outputs */
                NULL, NULL, 0, train_op, 1, NULL, model->status);
  TF_DeleteTensor(x);
  TF_DeleteTensor(y);
  return Okay(model->status);
}

int Okay(TF_Status* status) {
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "ERROR: %s\n", TF_Message(status));
    return 0;
  }
  return 1;
}

TF_Buffer* ReadFile(const char* filename) {
  int fd = open(filename, 0);
  if (fd < 0) {
    perror("failed to open file: ");
    return NULL;
  }
  struct stat stat;
  if (fstat(fd, &stat) != 0) {
    perror("failed to read file: ");
    close(fd);
    return NULL;
  }
  char* data = (char*)malloc(stat.st_size);
  ssize_t nread = read(fd, data, stat.st_size);
  close(fd);
  if (nread < 0) {
    perror("failed to read file: ");
    free(data);
    return NULL;
  }
  if (nread != stat.st_size) {
    fprintf(stderr, "read %zd bytes, expected to read %zd\n", nread,
            (ssize_t)stat.st_size);
    free(data);
    return NULL;
  }
  // TF_NewBufferFromString copies the bytes, so data can be freed here.
  TF_Buffer* ret = TF_NewBufferFromString(data, stat.st_size);
  free(data);
  return ret;
}

TF_Tensor* ScalarStringTensor(const char* str, TF_Status* status) {
  size_t nbytes = 8 + TF_StringEncodedSize(strlen(str));
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, NULL, 0, nbytes);
  char* data = (char*)TF_TensorData(t);
  memset(data, 0, 8);  // 8-byte offset of first string.
  TF_StringEncode(str, strlen(str), data + 8, nbytes - 8, status);
  return t;
}

int DirectoryExists(const char* dirname) {
  struct stat buf;
  return stat(dirname, &buf) == 0;
}
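ScalarStringTensor sizes its buffer as 8 + TF_StringEncodedSize(strlen(str)). The reason is the TF_STRING layout used by the TensorFlow 1.x C API. A string tensor starts with a table of 8-byte offsets, one per element, so a scalar has a single zero offset. Each string then follows as a varint length prefix plus its raw bytes, and TF_StringEncodedSize(n) is the varint length of n plus n. The minimal sketch below makes that layout concrete without linking TensorFlow. `varint_length`, `write_varint` and `encode_scalar_string` are hypothetical helpers, not TensorFlow functions. Newer TensorFlow releases changed the string representation, so treat this as describing the 1.x format only.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Number of bytes needed to store n as a base-128 varint. */
static size_t varint_length(uint64_t n) {
  size_t len = 1;
  while (n >= 0x80) {
    n >>= 7;
    ++len;
  }
  return len;
}

/* Writes n as a varint: 7 bits per byte, low bits first, with the high bit
 * set on every byte except the last. Returns the number of bytes written. */
static size_t write_varint(uint64_t n, unsigned char* dst) {
  size_t i = 0;
  while (n >= 0x80) {
    dst[i++] = (unsigned char)((n & 0x7F) | 0x80);
    n >>= 7;
  }
  dst[i++] = (unsigned char)n;
  return i;
}

/* Hypothetical helper: builds the byte layout that ScalarStringTensor asks
 * TF_AllocateTensor for (8-byte zero offset, varint length, string bytes).
 * Returns a malloc'd buffer the caller must free; *out_size gets its size. */
unsigned char* encode_scalar_string(const char* str, size_t* out_size) {
  assert(str != NULL && out_size != NULL);
  size_t len = strlen(str);
  size_t nbytes = 8 + varint_length(len) + len;
  unsigned char* buf = malloc(nbytes);
  if (buf == NULL) return NULL;
  memset(buf, 0, 8); /* offset of the first (and only) string */
  size_t prefix = write_varint(len, buf + 8);
  memcpy(buf + 8 + prefix, str, len);
  *out_size = nbytes;
  return buf;
}
```

For the checkpoint prefix used in main, "./checkpoints/checkpoint" (24 bytes), this gives 8 + 1 + 24 = 33 bytes. Strings of 128 bytes or more need a two-byte length prefix.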
#!/bin/bash
set -e
# Compile and execute train.c
# Requires the TensorFlow C library and headers, so download those.
if [[ ! -d "clib" ]]
then
  echo "Downloading TensorFlow C library into clib"
  mkdir clib
  curl -L "" | tar -C clib -xz
fi
gcc -std=c99 -I clib/include -L clib/lib train.c -ltensorflow -ltensorflow_framework
LD_LIBRARY_PATH=clib/lib ./a.out
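Every TF_AllocateTensor call in train.c passes a byte count that has to agree with the tensor's shape. With dims {batch_size, 1, 1} and TF_FLOAT elements, that count is batch_size * sizeof(float). For fixed-size element types the general rule is the product of the dimensions times the element size. The minimal sketch below uses `dense_tensor_bytes` as a hypothetical helper name. As an example, a model whose placeholders were declared with shape [None, 1, 2] would need dims {batch_size, 1, 2} and twice as many bytes per batch.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: number of bytes occupied by a dense tensor of
 * fixed-size elements, i.e. the product of its dimensions times the size of
 * one element. This is the len that TF_AllocateTensor expects for types such
 * as TF_FLOAT (it does not apply to TF_STRING). */
size_t dense_tensor_bytes(const int64_t* dims, int num_dims,
                          size_t elem_size) {
  size_t count = 1; /* a scalar (num_dims == 0) holds one element */
  for (int i = 0; i < num_dims; ++i) {
    assert(dims[i] >= 0);
    count *= (size_t)dims[i];
  }
  return count * elem_size;
}
```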
nikita-astronaut commented Apr 19, 2018

Hey, @asimshankar!

I found this very, very useful: literally the first TensorFlow+C example I could reproduce. There is a problem, though. For me, the code you've posted runs only with GradientDescent. Any other optimizer causes a checkpoint loading error while running ./ Could you perhaps help me?

@nikita_astronaut on Telegram, on Matrix

Where can I find you?

Hi, @asimshankar. Thank you so much! I found this example extremely helpful in my path of learning. However, in line 28, shouldn't it be open('graph.pb', 'wb'), since graph.pb is a binary file? Thanks again for writing this example.

iHaikal commented Apr 21, 2019

Hi, @asimshankar. Thank you for this gist. I have several questions:

  1. How would this scale up to n-dimensional data (in both the Python and C code)? I tried to train a model to predict the AND logic function: I changed the batch dimensions from 1x1 to 1x2 and passed a sample from the AND function, {0.0, 1.0} for example, as the batch parameter of ModelPredict, but t ends up NULL.
  2. What are you trying to achieve in the NextBatchForTraining function? (I know it's preparing the next batch.) The for loop is just baffling me.
  3. How would this scale up to n layers in a feed-forward neural network, or even other types of neural networks? (Would the C code stay more or less the same, or would a huge change be required?)

chaimash commented Sep 7, 2019


thank you!

Is there a way to restore the trained model in Python again?

Eths33 commented Dec 9, 2019

Testing on Windows.
With the Python file I was getting this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-18-2fe5a7458a00> in <module>
     34 # Write the graph out to a file.
     35 with open('graph.pb', 'w') as f:
---> 36   f.write(tf.get_default_graph().as_graph_def().SerializeToString())

TypeError: write() argument must be str, not bytes

Opening as "wb" fixed the error.
Alternatively, saving with the following command also worked:
tf.train.write_graph(tf.get_default_graph(), "./", 'graph.pb', as_text=False)

Then to read in the file I used:

TF_Buffer* ReadFile(const char* filename) {
	FILE *fd;// = fopen(filename, "r");
	fopen_s(&fd, filename, "rb");
	if (fd == NULL) {
		perror("failed to open file: ");
		return NULL;
	}

	fseek(fd, 0, SEEK_END);
	long fsize = ftell(fd);
	fseek(fd, 0, SEEK_SET);  /* same as rewind(f); */

	char *data = (char*)malloc(fsize + 1);
	fread(data, 1, fsize, fd);
	fclose(fd);

	data[fsize] = 0;

	TF_Buffer* ret = TF_NewBufferFromString(data, fsize);
	free(data);
	return ret;
}

AbduElrahmanRezk commented Jan 16, 2020

@asimshankar, can you provide a CNN example with the TensorFlow C API?

Hi, thanks for sharing this! It was extremely helpful!

Do you have any idea how to pass the 'train' operator using TF v2 (Keras) instead of TF v1? I invested more time in it than I wished and was unsuccessful : (

I managed to make inferences in the C API using a pre-trained Keras model but I cannot manage to train in the C API.


Hi, I made a simple C++ wrapper class which works with TensorFlow 2.0+.
It supports batch training on the MNIST image dataset. I encourage everyone to visit my repository on GitHub!

Hey, thank you for the gist. I was looking at it and found it interesting. But when I tried it, I got this warning. Can you tell me whether it's my machine, or has someone else also run into and fixed it?

implicit declaration of function ‘TF_StringEncodedSize’; did you mean ‘TF_StringGetSize’? [-Wimplicit-function-declaration]

Hi, I've built my project with TensorFlow 2.8.0, so maybe the developers changed something in the recent version.
I recommend replacing the tensorflow folder in my project with the header files for your current TensorFlow version (or going back to the 2.8.0 lib/dll).
