Skip to content

Instantly share code, notes, and snippets.

@thirdwing
Created February 28, 2019 22:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thirdwing/326623f5e672ed07e857c263ba87a716 to your computer and use it in GitHub Desktop.
Save thirdwing/326623f5e672ed07e857c263ba87a716 to your computer and use it in GitHub Desktop.
Example of using the TensorFlow C API

Valgrind may report some memory as "still reachable" at exit.

This is a known issue; see tensorflow/tensorflow#17739 for details.

Resources such as the thread pool and the memory allocator are shared between sessions. Shutting down a session does not shut down these shared resources, because they may be reused by future sessions. I guess you want an API to shut down the memory allocator/thread pool, which doesn't exist AFAIK.

==28825== LEAK SUMMARY:
==28825== definitely lost: 0 bytes in 0 blocks
==28825== indirectly lost: 0 bytes in 0 blocks
==28825== possibly lost: 69,312 bytes in 248 blocks
==28825== still reachable: 7,865,580 bytes in 134,721 blocks
==28825== of which reachable via heuristic:
==28825== stdstring : 1,756,058 bytes in 44,875 blocks
==28825== suppressed: 0 bytes in 0 blocks
#
# Train a simple MLP on the iris dataset and save it into the folder 'exported'.
#
import os

import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
# Seed shared by TensorFlow's graph-level RNG (used by the random weight
# initializers) and by the train/test split in get_iris_data().
RANDOM_SEED = 42
tf.set_random_seed(RANDOM_SEED)
def init_weights(shape):
    """Create a trainable weight variable of the given shape.

    Initial values are drawn from a normal distribution with stddev 0.1.
    """
    return tf.Variable(tf.random_normal(shape, stddev=0.1))
def forwardprop(X, w_1, w_2):
    """Compute the network's output logits for the input batch ``X``.

    The network is one sigmoid hidden layer followed by a linear output layer.
    The result is deliberately *not* passed through softmax, because
    ``softmax_cross_entropy_with_logits()`` applies softmax internally.
    """
    hidden = tf.nn.sigmoid(tf.matmul(X, w_1))  # the \sigma function
    return tf.matmul(hidden, w_2)  # the \varphi function (raw logits)
def get_iris_data():
    """Load the iris data set and split it into training and test sets.

    Each input row is prefixed with a constant 1.0 bias column, and each
    label is one-hot encoded.

    Returns:
        ``(train_X, test_X, train_y, test_y)`` as produced by
        ``train_test_split``, with 33% of the rows held out for testing and
        the split seeded by ``RANDOM_SEED``.
    """
    iris = datasets.load_iris()
    features = iris["data"]
    labels = iris["target"]

    # Column 0 stays 1.0 and acts as the bias input; the features fill the rest.
    n_rows, n_features = features.shape
    inputs = np.ones((n_rows, n_features + 1))
    inputs[:, 1:] = features

    # One-hot encode: indexing the identity matrix by class id picks row k
    # for class k.
    n_classes = np.unique(labels).size
    targets = np.eye(n_classes)[labels]

    return train_test_split(
        inputs, targets, test_size=0.33, random_state=RANDOM_SEED
    )
# ---------------------------------------------------------------------------
# Build, train and export the model.
#
# Artifacts written into EXPORT_DIR (read back by the C client):
#   graph.pb      binary GraphDef, including the Saver's save/* ops
#   graph.pb_txt  the same GraphDef as text, for inspection
#   model.*       checkpoint files holding the trained variable values
# NOTE(review): uses TensorFlow 1.x graph-mode APIs (tf.placeholder,
# tf.Session, tf.train.Saver); under TF 2.x these live in tf.compat.v1.
# ---------------------------------------------------------------------------
EXPORT_DIR = "./exported"

train_X, test_X, train_y, test_y = get_iris_data()

# Layer sizes
x_size = train_X.shape[1]  # Number of input nodes: 4 features and 1 bias
h_size = 256  # Number of hidden nodes
y_size = train_y.shape[1]  # Number of outcomes (3 iris flowers)

# Symbols. The C client feeds its input batch by the name "input_x".
X = tf.placeholder("float", shape=[None, x_size], name="input_x")
y = tf.placeholder("float", shape=[None, y_size])

# Weight initializations
w_1 = init_weights((x_size, h_size))
w_2 = init_weights((h_size, y_size))

# Forward propagation
yhat = forwardprop(X, w_1, w_2)
# Give the logits a stable name ("yhat") so a C client can fetch them by name.
output = tf.identity(yhat, name="yhat")
predict = tf.argmax(yhat, axis=1, name="output_y")

# Backward propagation
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=yhat))
updates = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

# This op's default name is "init"; the C client runs it by that name.
init = tf.global_variables_initializer()

# Run SGD. The context manager closes the session even if training fails.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(100):
        # Train with each example
        for i in range(len(train_X)):
            sess.run(updates, feed_dict={X: train_X[i: i + 1], y: train_y[i: i + 1]})
        train_accuracy = np.mean(np.argmax(train_y, axis=1) ==
                                 sess.run(predict, feed_dict={X: train_X}))
        test_accuracy = np.mean(np.argmax(test_y, axis=1) ==
                                sess.run(predict, feed_dict={X: test_X}))
        print("Epoch = %d, train accuracy = %.2f%%, test accuracy = %.2f%%"
              % (epoch + 1, 100. * train_accuracy, 100. * test_accuracy))

    # Saver.save() will not create a missing parent directory, and
    # write_graph() only creates its logdir, so create the export directory
    # up front.
    os.makedirs(EXPORT_DIR, exist_ok=True)
    # Creating the Saver adds the save/Const (checkpoint path) and
    # save/restore_all ops. It must exist before the graph is serialized so
    # the C client can restore the weights from the checkpoint.
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(EXPORT_DIR, "model"))
    tf.train.write_graph(sess.graph, EXPORT_DIR, "graph.pb", as_text=False)
    tf.train.write_graph(sess.graph, EXPORT_DIR, "graph.pb_txt", as_text=True)
#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include "c_api.h"
// Deallocator installed on the TF_Buffer that carries the GraphDef bytes.
// Those bytes come from malloc(), so they are released with free(); the
// length argument is not needed.
void free_buffer(void* data, size_t length) {
  (void)length;
  free(data);
}
// Deallocator passed to TF_NewTensor for tensors whose backing buffer was
// malloc-allocated by this program; TensorFlow calls it when the tensor is
// deleted. The length and opaque argument are unused.
void deallocator(void* ptr, size_t len, void* arg) {
  (void)len;
  (void)arg;
  free(ptr);
}
int main() {
FILE* f = fopen("./exported/graph.pb", "rb");
fseek(f, 0, SEEK_END);
long fsize = ftell(f);
fseek(f, 0, SEEK_SET);
void* data = malloc(fsize);
fread(data, fsize, 1, f);
fclose(f);
void* tf_handle = dlopen("libtensorflow.so", RTLD_NOW);
if (!tf_handle) {
std::cerr << "Error: " << dlerror() << std::endl;
return EXIT_FAILURE;
}
// TF_NewBuffer
TF_Buffer* (*TF_NewBuffer)(void);
TF_NewBuffer = (TF_Buffer * (*)(void)) dlsym(tf_handle, "TF_NewBuffer");
// TF_DeleteBuffer
void (*TF_DeleteBuffer)(TF_Buffer*);
TF_DeleteBuffer = (void (*)(TF_Buffer*))dlsym(tf_handle, "TF_DeleteBuffer");
// TF_NewGraph
TF_Graph* (*TF_NewGraph)(void);
TF_NewGraph = (TF_Graph * (*)(void)) dlsym(tf_handle, "TF_NewGraph");
// TF_DeleteGraph
void (*TF_DeleteGraph)(TF_Graph*);
TF_DeleteGraph = (void (*)(TF_Graph*))dlsym(tf_handle, "TF_DeleteGraph");
// TF_NewStatus
TF_Status* (*TF_NewStatus)(void);
TF_NewStatus = (TF_Status * (*)(void)) dlsym(tf_handle, "TF_NewStatus");
// TF_DeleteStatus
void (*TF_DeleteStatus)(TF_Status*);
TF_DeleteStatus = (void (*)(TF_Status*))dlsym(tf_handle, "TF_DeleteStatus");
// TF_NewImportGraphDefOptions
TF_ImportGraphDefOptions* (*TF_NewImportGraphDefOptions)(void);
TF_NewImportGraphDefOptions = (TF_ImportGraphDefOptions * (*)(void))
dlsym(tf_handle, "TF_NewImportGraphDefOptions");
// TF_DeleteImportGraphDefOptions
void (*TF_DeleteImportGraphDefOptions)(TF_ImportGraphDefOptions*);
TF_DeleteImportGraphDefOptions = (void (*)(TF_ImportGraphDefOptions*))dlsym(
tf_handle, "TF_DeleteImportGraphDefOptions");
// TF_GetCode
TF_Code (*TF_GetCode)(const TF_Status*);
TF_GetCode = (TF_Code(*)(const TF_Status*))dlsym(tf_handle, "TF_GetCode");
// TF_Message
const char* (*TF_Message)(const TF_Status* s);
TF_Message =
(const char* (*)(const TF_Status*))dlsym(tf_handle, "TF_Message");
// TF_GraphImportGraphDef
void (*TF_GraphImportGraphDef)(TF_Graph * graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options,
TF_Status* status);
TF_GraphImportGraphDef =
(void (*)(TF_Graph*, const TF_Buffer*, const TF_ImportGraphDefOptions*,
TF_Status*))dlsym(tf_handle, "TF_GraphImportGraphDef");
// TF_NewSessionOptions
TF_SessionOptions* (*TF_NewSessionOptions)(void);
TF_NewSessionOptions =
(TF_SessionOptions * (*)(void)) dlsym(tf_handle, "TF_NewSessionOptions");
// TF_DeleteSessionOptions
void (*TF_DeleteSessionOptions)(TF_SessionOptions*);
TF_DeleteSessionOptions =
(void (*)(TF_SessionOptions*))dlsym(tf_handle, "TF_DeleteSessionOptions");
// TF_NewSession
TF_Session* (*TF_NewSession)(TF_Graph*, const TF_SessionOptions*, TF_Status*);
TF_NewSession =
(TF_Session * (*)(TF_Graph*, const TF_SessionOptions*, TF_Status*))
dlsym(tf_handle, "TF_NewSession");
// TF_CloseSession
void (*TF_CloseSession)(TF_Session*, TF_Status*);
TF_CloseSession =
(void (*)(TF_Session*, TF_Status*))dlsym(tf_handle, "TF_CloseSession");
// TF_DeleteSession
void (*TF_DeleteSession)(TF_Session*, TF_Status*);
TF_DeleteSession =
(void (*)(TF_Session*, TF_Status*))dlsym(tf_handle, "TF_DeleteSession");
// TF_GraphOperationByName
TF_Operation* (*TF_GraphOperationByName)(TF_Graph*, const char*);
TF_GraphOperationByName = (TF_Operation * (*)(TF_Graph*, const char*))
dlsym(tf_handle, "TF_GraphOperationByName");
// TF_OperationOpType
const char* (*TF_OperationOpType)(TF_Operation*);
TF_OperationOpType =
(const char* (*)(TF_Operation*))dlsym(tf_handle, "TF_OperationOpType");
// TF_SessionRun
void (*TF_SessionRun)(TF_Session*, const TF_Buffer*, const TF_Output*,
TF_Tensor* const*, int, const TF_Output*, TF_Tensor**,
int, const TF_Operation* const*, int, TF_Buffer*,
TF_Status*);
TF_SessionRun = (void (*)(
TF_Session*, const TF_Buffer*, const TF_Output*, TF_Tensor* const*, int,
const TF_Output*, TF_Tensor**, int, const TF_Operation* const*, int,
TF_Buffer*, TF_Status*))dlsym(tf_handle, "TF_SessionRun");
// TF_StringEncodedSize
size_t (*TF_StringEncodedSize)(size_t);
TF_StringEncodedSize =
(size_t(*)(size_t))dlsym(tf_handle, "TF_StringEncodedSize");
// TF_StringEncode
size_t (*TF_StringEncode)(const char*, size_t, char*, size_t, TF_Status*);
TF_StringEncode = (size_t(*)(const char*, size_t, char*, size_t,
TF_Status*))dlsym(tf_handle, "TF_StringEncode");
// TF_NewTensor
TF_Tensor* (*TF_NewTensor)(TF_DataType, const int64_t*, int, void*, size_t,
void (*deallocator)(void*, size_t, void*), void*);
TF_NewTensor =
(TF_Tensor * (*)(TF_DataType, const int64_t*, int, void*, size_t,
void (*deallocator)(void*, size_t, void*), void*))
dlsym(tf_handle, "TF_NewTensor");
// TF_TensorData
void* (*TF_TensorData)(const TF_Tensor*);
TF_TensorData =
(void* (*)(const TF_Tensor*))dlsym(tf_handle, "TF_TensorData");
// TF_NumDims
int (*TF_NumDims)(const TF_Tensor*);
TF_NumDims = (int (*)(const TF_Tensor*))dlsym(tf_handle, "TF_NumDims");
// TF_Dim
int64_t (*TF_Dim)(const TF_Tensor*, int);
TF_Dim = (int64_t(*)(const TF_Tensor*, int))dlsym(tf_handle, "TF_Dim");
// TF_DeleteTensor
void (*TF_DeleteTensor)(TF_Tensor*);
TF_DeleteTensor = (void (*)(TF_Tensor*))dlsym(tf_handle, "TF_DeleteTensor");
// TF_OperationNumOutputs
int (*TF_OperationNumOutputs)(TF_Operation*);
TF_OperationNumOutputs =
(int (*)(TF_Operation*))dlsym(tf_handle, "TF_OperationNumOutputs");
TF_Buffer* graph_def = TF_NewBuffer();
graph_def->data = data;
graph_def->length = fsize;
graph_def->data_deallocator = free_buffer;
TF_Graph* graph = TF_NewGraph();
TF_Status* status = TF_NewStatus();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(graph, graph_def, opts, status);
TF_DeleteImportGraphDefOptions(opts);
TF_DeleteBuffer(graph_def);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: Unable to import graph %s\n", TF_Message(status));
return 1;
}
fprintf(stdout, "Successfully imported graph\n");
// create session
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Session* sess = TF_NewSession(graph, opt, status);
TF_DeleteSessionOptions(opt);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: Unable to create session %s\n", TF_Message(status));
return 1;
}
fprintf(stdout, "Successfully created session\n");
// run init operation
const TF_Operation* init_op = TF_GraphOperationByName(graph, "init");
const TF_Operation* const* targets_ptr = &init_op;
TF_SessionRun(sess,
/* RunOptions */ NULL,
/* Input tensors */ NULL, NULL, 0,
/* Output tensors */ NULL, NULL, 0,
/* Target operations */ targets_ptr, 1,
/* RunMetadata */ NULL,
/* Output status */ status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: Unable to run init_op: %s\n", TF_Message(status));
return 1;
}
TF_Operation* checkpoint_op = TF_GraphOperationByName(graph, "save/Const");
const TF_Operation* const restore_op =
TF_GraphOperationByName(graph, "save/restore_all");
const char* checkpoint_path_str = "./exported/model";
size_t checkpoint_path_str_len = strlen(checkpoint_path_str);
size_t encoded_size = TF_StringEncodedSize(checkpoint_path_str_len);
// The format for TF_STRING tensors is:
// start_offset: array[uint64]
// data: byte[...]
size_t total_size = sizeof(int64_t) + encoded_size;
char* input_encoded = (char*)malloc(total_size);
memset(input_encoded, 0, total_size);
TF_StringEncode(checkpoint_path_str, checkpoint_path_str_len,
input_encoded + sizeof(int64_t), encoded_size, status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: something wrong with encoding: %s",
TF_Message(status));
return 1;
}
TF_Tensor* path_tensor = TF_NewTensor(TF_STRING, NULL, 0, input_encoded,
total_size, &deallocator, 0);
TF_Output* run_path = (TF_Output*)malloc(1 * sizeof(TF_Output));
run_path[0].oper = checkpoint_op;
run_path[0].index = 0;
TF_Tensor** run_path_tensors = (TF_Tensor**)malloc(1 * sizeof(TF_Tensor*));
run_path_tensors[0] = path_tensor;
TF_SessionRun(sess,
/* RunOptions */ NULL,
/* Input tensors */ run_path, run_path_tensors, 1,
/* Output tensors */ NULL, NULL, 0,
/* Target operations */ &restore_op, 1,
/* RunMetadata */ NULL,
/* Output status */ status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: Unable to run restore_op: %s\n",
TF_Message(status));
return 1;
}
TF_DeleteTensor(path_tensor);
free(run_path);
free(run_path_tensors);
// gerenate input
TF_Operation* input_op = TF_GraphOperationByName(graph, "input_x");
printf("input_op has %i inputs\n", TF_OperationNumOutputs(input_op));
float* raw_input_data = (float*)malloc(10 * sizeof(float));
raw_input_data[0] = 1.0f;
raw_input_data[1] = 6.1f;
raw_input_data[2] = 2.8f;
raw_input_data[3] = 4.7f;
raw_input_data[4] = 1.2f;
raw_input_data[5] = 1.0f;
raw_input_data[6] = 5.7f;
raw_input_data[7] = 3.8f;
raw_input_data[8] = 1.7f;
raw_input_data[9] = 0.3f;
int64_t* raw_input_dims = (int64_t*)malloc(2 * sizeof(int64_t));
raw_input_dims[0] = 2;
raw_input_dims[1] = 5;
// prepare inputs
TF_Tensor* input_tensor =
TF_NewTensor(TF_FLOAT, raw_input_dims, 2, raw_input_data,
10 * sizeof(float), &deallocator, NULL);
TF_Output* run_inputs = (TF_Output*)malloc(1 * sizeof(TF_Output));
run_inputs[0].oper = input_op;
run_inputs[0].index = 0;
TF_Tensor** run_inputs_tensors = (TF_Tensor**)malloc(1 * sizeof(TF_Tensor*));
run_inputs_tensors[0] = input_tensor;
// prepare outputs
TF_Operation* output_op = TF_GraphOperationByName(graph, "yhat");
printf("output_op has %i outputs\n", TF_OperationNumOutputs(output_op));
printf("TF_OperationOpType %s\n", TF_OperationOpType(output_op));
TF_Output* run_outputs = (TF_Output*)malloc(1 * sizeof(TF_Output));
run_outputs[0].oper = output_op;
run_outputs[0].index = 0;
TF_Tensor** run_output_tensors = (TF_Tensor**)malloc(1 * sizeof(TF_Tensor*));
// run network
TF_SessionRun(sess,
/* RunOptions */ NULL,
/* Input tensors */ run_inputs, run_inputs_tensors, 1,
/* Output tensors */ run_outputs, run_output_tensors, 1,
/* Target operations */ NULL, 0,
/* RunMetadata */ NULL,
/* Output status */ status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "ERROR: Unable to run output_op: %s\n", TF_Message(status));
return 1;
}
printf("output-tensor has %i dims\n", TF_NumDims(run_output_tensors[0]));
std::cout << "dim[0]: " << TF_Dim(run_output_tensors[0], 0) << std::endl;
std::cout << "dim[1]: " << TF_Dim(run_output_tensors[0], 1) << std::endl;
void* output_data = TF_TensorData(run_output_tensors[0]);
printf("output %f\n", ((float*)output_data)[0]);
printf("output %f\n", ((float*)output_data)[1]);
printf("output %f\n", ((float*)output_data)[2]);
printf("output %f\n", ((float*)output_data)[3]);
printf("output %f\n", ((float*)output_data)[4]);
printf("output %f\n", ((float*)output_data)[5]);
// clean up the sess and graph
TF_CloseSession(sess, status);
TF_DeleteSession(sess, status);
TF_DeleteStatus(status);
TF_DeleteGraph(graph);
// cleanup the tensor for input/output
TF_DeleteTensor(input_tensor);
free(raw_input_dims);
free(run_inputs);
free(run_inputs_tensors);
TF_DeleteTensor(run_output_tensors[0]);
free(run_outputs);
free(run_output_tensors);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment