@asimshankar
Last active November 13, 2024 20:16
Training TensorFlow models in C

Python is the primary language in which TensorFlow models are typically developed and trained. TensorFlow does have bindings for other programming languages. These bindings provide the low-level primitives required to build a more complete API; however, they lack much of the higher-level API richness of the Python bindings, particularly for defining the model structure.

This gist demonstrates taking a model (a TensorFlow graph) created by a Python program and running the training loop in a C program.

The model

The model is a trivial one that tries to learn the function f(x) = W * x + b, where W and b are model parameters. The training data is constructed so that the "true" value of W is 3 and that of b is 2, i.e., f(x) = 3 * x + 2.
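
To make concrete what the train operation does on each step, here is a minimal plain-C sketch (it does not call TensorFlow and is not part of the gist) of mini-batch gradient descent on the same mean squared error loss, using the same learning rate (0.01), batch size (10), and input sampling as train.c. The function name train_linear and the starting point W = 0, b = 0 are assumptions made for illustration; TensorFlow initializes W with its own default initializer.

```c
#include <assert.h>
#include <stdlib.h>

#define TRAIN_BATCH_SIZE 10 /* same batch size as NextBatchForTraining */

/* Plain-C mini-batch gradient descent on
 *   loss = mean((W * x + b - y)^2),  with y = 3 * x + 2.
 * Each step applies the update var -= learning_rate * gradient, which is
 * what a plain gradient descent optimizer does; no TensorFlow involved. */
void train_linear(float* w, float* b, int steps, float learning_rate) {
  assert(w != NULL && b != NULL && steps >= 0);
  for (int step = 0; step < steps; ++step) {
    float grad_w = 0.0f, grad_b = 0.0f;
    for (int i = 0; i < TRAIN_BATCH_SIZE; ++i) {
      /* Inputs sampled uniformly from [0, 1], as in NextBatchForTraining. */
      float x = (float)rand() / (float)RAND_MAX;
      float y = 3.0f * x + 2.0f;
      float err = (*w) * x + (*b) - y;
      /* d(loss)/dW and d(loss)/db, averaged over the batch. */
      grad_w += 2.0f * err * x / TRAIN_BATCH_SIZE;
      grad_b += 2.0f * err / TRAIN_BATCH_SIZE;
    }
    *w -= learning_rate * grad_w;
    *b -= learning_rate * grad_b;
  }
}
```

Calling train_linear(&w, &b, 10000, 0.01f) from w = 0, b = 0 drives w and b close to 3 and 2. Because the training targets contain no noise, the gradient is zero at W = 3, b = 2 for every batch, so the iterates settle there instead of hovering around it. With inputs in [0, 1] and a learning rate of 0.01, progress along the slowest direction is gradual: by a rough estimate, 200 steps (the count train.c runs per invocation) removes only part of the initial error, which is one reason resuming from the checkpoint is useful, since successive runs of train.c continue from where the previous one stopped.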

Files

  • model.py: Python code that constructs a model and saves the computational graph in a file called graph.pb. All other files assume that model.py has been run once.
  • train.c: C code that loads the model, optionally loads model weights saved in a checkpoint, trains for a few steps, and writes the updated model weights to a checkpoint.
  • train.c.sh: Trivial script to compile and run train.c

Noteworthy

  • The Python APIs for TensorFlow include other conveniences for training (such as MonitoredSession and tf.train.Estimator) that make it easier to configure checkpointing, evaluation loops, etc. The examples here aren't that sophisticated and focus on basic model training only.
  • In this example, we use placeholders and feed dictionaries to feed input, but in a real example you probably want to use the tf.data API to construct an input pipeline that provides training data to the model.
  • Although not demonstrated here, summaries for TensorBoard can also be produced by executing the summary operations.

See Also

model.py

import tensorflow as tf

# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')

# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')

# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')

init = tf.global_variables_initializer()

# tf.train.Saver.__init__ adds operations to the graph to save
# and restore variables.
saver_def = tf.train.Saver().as_saver_def()

print('Run this operation to initialize variables : ', init.name)
print('Run this operation for a train step : ', train_op.name)
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)

# Write the graph out to a file. SerializeToString() returns bytes, so open
# the file in binary mode.
with open('graph.pb', 'wb') as f:
  f.write(tf.get_default_graph().as_graph_def().SerializeToString())

train.c

// Example of training the model created by model.py using the TensorFlow C API.
//
// To run, use train.c.sh in this directory.
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <unistd.h>

#include <tensorflow/c/c_api.h>

typedef struct model_t {
  TF_Graph* graph;
  TF_Session* session;
  TF_Status* status;

  TF_Output input, target, output;

  TF_Operation *init_op, *train_op, *save_op, *restore_op;
  TF_Output checkpoint_file;
} model_t;

int ModelCreate(model_t* model, const char* graph_def_filename);
void ModelDestroy(model_t* model);
int ModelInit(model_t* model);
int ModelPredict(model_t* model, float* batch, int batch_size);
int ModelRunTrainStep(model_t* model);
enum SaveOrRestore { SAVE, RESTORE };
int ModelCheckpoint(model_t* model, const char* checkpoint_prefix, int type);

int Okay(TF_Status* status);
TF_Buffer* ReadFile(const char* filename);
TF_Tensor* ScalarStringTensor(const char* data, TF_Status* status);
int DirectoryExists(const char* dirname);

int main(int argc, char** argv) {
  const char* graph_def_filename = "graph.pb";
  const char* checkpoint_prefix = "./checkpoints/checkpoint";
  int restore = DirectoryExists("checkpoints");

  model_t model;
  printf("Loading graph\n");
  if (!ModelCreate(&model, graph_def_filename)) return 1;
  if (restore) {
    printf(
        "Restoring weights from checkpoint (remove the checkpoints directory "
        "to reset)\n");
    if (!ModelCheckpoint(&model, checkpoint_prefix, RESTORE)) return 1;
  } else {
    printf("Initializing model weights\n");
    if (!ModelInit(&model)) return 1;
  }

  float testdata[3] = {1.0, 2.0, 3.0};
  printf("Initial predictions\n");
  if (!ModelPredict(&model, &testdata[0], 3)) return 1;

  printf("Training for a few steps\n");
  for (int i = 0; i < 200; ++i) {
    if (!ModelRunTrainStep(&model)) return 1;
  }

  printf("Updated predictions\n");
  if (!ModelPredict(&model, &testdata[0], 3)) return 1;

  printf("Saving checkpoint\n");
  if (!ModelCheckpoint(&model, checkpoint_prefix, SAVE)) return 1;

  ModelDestroy(&model);
  return 0;
}

int ModelCreate(model_t* model, const char* graph_def_filename) {
  model->status = TF_NewStatus();
  model->graph = TF_NewGraph();

  {
    // Create the session.
    TF_SessionOptions* opts = TF_NewSessionOptions();
    model->session = TF_NewSession(model->graph, opts, model->status);
    TF_DeleteSessionOptions(opts);
    if (!Okay(model->status)) return 0;
  }

  TF_Graph* g = model->graph;

  {
    // Import the graph.
    TF_Buffer* graph_def = ReadFile(graph_def_filename);
    if (graph_def == NULL) return 0;
    printf("Read GraphDef of %zu bytes\n", graph_def->length);
    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(g, graph_def, opts, model->status);
    TF_DeleteImportGraphDefOptions(opts);
    TF_DeleteBuffer(graph_def);
    if (!Okay(model->status)) return 0;
  }

  // Handles to the interesting operations in the graph.
  model->input.oper = TF_GraphOperationByName(g, "input");
  model->input.index = 0;
  model->target.oper = TF_GraphOperationByName(g, "target");
  model->target.index = 0;
  model->output.oper = TF_GraphOperationByName(g, "output");
  model->output.index = 0;

  model->init_op = TF_GraphOperationByName(g, "init");
  model->train_op = TF_GraphOperationByName(g, "train");
  model->save_op = TF_GraphOperationByName(g, "save/control_dependency");
  model->restore_op = TF_GraphOperationByName(g, "save/restore_all");

  model->checkpoint_file.oper = TF_GraphOperationByName(g, "save/Const");
  model->checkpoint_file.index = 0;
  return 1;
}

void ModelDestroy(model_t* model) {
  TF_DeleteSession(model->session, model->status);
  Okay(model->status);
  TF_DeleteGraph(model->graph);
  TF_DeleteStatus(model->status);
}

int ModelInit(model_t* model) {
  const TF_Operation* init_op[1] = {model->init_op};
  TF_SessionRun(model->session, NULL,
                /* No inputs */
                NULL, NULL, 0,
                /* No outputs */
                NULL, NULL, 0,
                /* Just the init operation */
                init_op, 1,
                /* No metadata */
                NULL, model->status);
  return Okay(model->status);
}

int ModelCheckpoint(model_t* model, const char* checkpoint_prefix, int type) {
  TF_Tensor* t = ScalarStringTensor(checkpoint_prefix, model->status);
  if (!Okay(model->status)) {
    TF_DeleteTensor(t);
    return 0;
  }
  TF_Output inputs[1] = {model->checkpoint_file};
  TF_Tensor* input_values[1] = {t};
  const TF_Operation* op[1] = {type == SAVE ? model->save_op
                                            : model->restore_op};
  TF_SessionRun(model->session, NULL, inputs, input_values, 1,
                /* No outputs */
                NULL, NULL, 0,
                /* The operation */
                op, 1, NULL, model->status);
  TF_DeleteTensor(t);
  return Okay(model->status);
}

int ModelPredict(model_t* model, float* batch, int batch_size) {
  // batch consists of 1x1 matrices.
  const int64_t dims[3] = {batch_size, 1, 1};
  const size_t nbytes = batch_size * sizeof(float);
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  memcpy(TF_TensorData(t), batch, nbytes);

  TF_Output inputs[1] = {model->input};
  TF_Tensor* input_values[1] = {t};
  TF_Output outputs[1] = {model->output};
  TF_Tensor* output_values[1] = {NULL};

  TF_SessionRun(model->session, NULL, inputs, input_values, 1, outputs,
                output_values, 1,
                /* No target operations to run */
                NULL, 0, NULL, model->status);
  TF_DeleteTensor(t);
  if (!Okay(model->status)) return 0;

  if (TF_TensorByteSize(output_values[0]) != nbytes) {
    fprintf(stderr,
            "ERROR: Expected predictions tensor to have %zu bytes, has %zu\n",
            nbytes, TF_TensorByteSize(output_values[0]));
    TF_DeleteTensor(output_values[0]);
    return 0;
  }
  float* predictions = (float*)malloc(nbytes);
  memcpy(predictions, TF_TensorData(output_values[0]), nbytes);
  TF_DeleteTensor(output_values[0]);

  printf("Predictions:\n");
  for (int i = 0; i < batch_size; ++i) {
    printf("\t x = %f, predicted y = %f\n", batch[i], predictions[i]);
  }
  free(predictions);
  return 1;
}

void NextBatchForTraining(TF_Tensor** inputs_tensor,
                          TF_Tensor** targets_tensor) {
#define BATCH_SIZE 10
  float inputs[BATCH_SIZE] = {0};
  float targets[BATCH_SIZE] = {0};
  for (int i = 0; i < BATCH_SIZE; ++i) {
    inputs[i] = (float)rand() / (float)RAND_MAX;
    targets[i] = 3.0 * inputs[i] + 2.0;
  }
  const int64_t dims[] = {BATCH_SIZE, 1, 1};
  size_t nbytes = BATCH_SIZE * sizeof(float);
  *inputs_tensor = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  *targets_tensor = TF_AllocateTensor(TF_FLOAT, dims, 3, nbytes);
  memcpy(TF_TensorData(*inputs_tensor), inputs, nbytes);
  memcpy(TF_TensorData(*targets_tensor), targets, nbytes);
#undef BATCH_SIZE
}

int ModelRunTrainStep(model_t* model) {
  TF_Tensor *x, *y;
  NextBatchForTraining(&x, &y);
  TF_Output inputs[2] = {model->input, model->target};
  TF_Tensor* input_values[2] = {x, y};
  const TF_Operation* train_op[1] = {model->train_op};
  TF_SessionRun(model->session, NULL, inputs, input_values, 2,
                /* No outputs */
                NULL, NULL, 0, train_op, 1, NULL, model->status);
  TF_DeleteTensor(x);
  TF_DeleteTensor(y);
  return Okay(model->status);
}

int Okay(TF_Status* status) {
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "ERROR: %s\n", TF_Message(status));
    return 0;
  }
  return 1;
}

TF_Buffer* ReadFile(const char* filename) {
  int fd = open(filename, O_RDONLY);
  if (fd < 0) {
    perror("failed to open file");
    return NULL;
  }
  struct stat st;
  if (fstat(fd, &st) != 0) {
    perror("failed to read file");
    close(fd);
    return NULL;
  }
  char* data = (char*)malloc(st.st_size);
  if (data == NULL) {
    fprintf(stderr, "failed to allocate %lld bytes\n", (long long)st.st_size);
    close(fd);
    return NULL;
  }
  ssize_t nread = read(fd, data, st.st_size);
  close(fd);  // The descriptor is no longer needed on any path below.
  if (nread < 0) {
    perror("failed to read file");
    free(data);
    return NULL;
  }
  if (nread != st.st_size) {
    fprintf(stderr, "read %zd bytes, expected to read %lld\n", nread,
            (long long)st.st_size);
    free(data);
    return NULL;
  }
  TF_Buffer* ret = TF_NewBufferFromString(data, st.st_size);
  free(data);
  return ret;
}

TF_Tensor* ScalarStringTensor(const char* str, TF_Status* status) {
  size_t nbytes = 8 + TF_StringEncodedSize(strlen(str));
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, NULL, 0, nbytes);
  // Use char* so that pointer arithmetic below is standard C (not void*).
  char* data = (char*)TF_TensorData(t);
  memset(data, 0, 8);  // 8-byte offset of first string.
  TF_StringEncode(str, strlen(str), data + 8, nbytes - 8, status);
  return t;
}

int DirectoryExists(const char* dirname) {
  struct stat buf;
  return stat(dirname, &buf) == 0;
}
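
ScalarStringTensor relies on the TF_STRING tensor layout of the TensorFlow 1.x C API: a table of 8-byte offsets (one per element) followed by the strings encoded with TF_StringEncode, which writes a varint length prefix and then the raw bytes. The sketch below reproduces that layout for a single scalar in plain C, without calling TensorFlow, to show where 8 + TF_StringEncodedSize(strlen(str)) comes from. The helpers varint_length and encode_scalar_string are hypothetical names for illustration only, not library calls, and newer TensorFlow releases may not provide these string functions at all, as the compiler warning in the comments at the end of this page suggests.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: number of bytes in the base-128 varint encoding of n
 * (7 data bits per byte, high bit set on every byte except the last). */
size_t varint_length(uint64_t n) {
  size_t len = 1;
  while (n >= 0x80) {
    n >>= 7;
    ++len;
  }
  return len;
}

/* Hypothetical helper sketching a scalar TF_STRING buffer (TF 1.x layout):
 *   [8-byte offset = 0][varint(strlen(s))][the bytes of s]
 * Returns the number of bytes written, or 0 if cap is too small. */
size_t encode_scalar_string(const char* s, unsigned char* out, size_t cap) {
  assert(s != NULL && out != NULL);
  size_t len = strlen(s);
  size_t total = 8 + varint_length(len) + len;
  if (total > cap) return 0;
  memset(out, 0, 8); /* offset of the only string, as in ScalarStringTensor */
  unsigned char* p = out + 8;
  uint64_t n = len;
  while (n >= 0x80) {
    *p++ = (unsigned char)(n | 0x80); /* low 7 bits plus continuation bit */
    n >>= 7;
  }
  *p++ = (unsigned char)n;
  memcpy(p, s, len);
  return total;
}
```

For the checkpoint prefix used in train.c, "./checkpoints/checkpoint" (24 bytes), the buffer is 8 + 1 + 24 = 33 bytes: eight zero bytes for the offset, one length byte (24 fits in a single varint byte), then the string itself.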

train.c.sh

#!/bin/bash
set -e

# Compile and execute train.c

# Requires the TensorFlow C library and headers, so download those.
if [[ ! -d "clib" ]]
then
  echo "Downloading TensorFlow C library into clib"
  mkdir clib
  curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.4.0.tar.gz" | tar -C clib -xz
fi

gcc -std=c99 -I clib/include -L clib/lib train.c -ltensorflow -ltensorflow_framework
LD_LIBRARY_PATH=clib/lib ./a.out
@jj-aggarwal

Hey, thank you for the gist. I was looking at it and found it interesting, but when I tried it, I got this warning. Can you tell me whether it is my machine, or has someone else fixed it?

implicit declaration of function ‘TF_StringEncodedSize’; did you mean ‘TF_StringGetSize’? [-Wimplicit-function-declaration]

@darknovismc

Hi, I've built my project with TensorFlow 2.8.0, so maybe the developers changed something in the recent version.
I recommend replacing the tensorflow folder in my project with the header files for the current TensorFlow version (or going back to the 2.8.0 lib/dll).
