Last active
July 27, 2019 00:13
-
-
Save ryojiysd/1a04e655b1fc9895a179ba11af39975a to your computer and use it in GitHub Desktop.
TensorFlow sample for Dataset and SavedModel (Python and C++)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::Status;
using tensorflow::Session;
using tensorflow::SavedModelBundle;
using tensorflow::SessionOptions;
using tensorflow::RunOptions;
int main(void) | |
{ | |
const std::string export_dir = "./model"; | |
SavedModelBundle bundle; | |
SessionOptions session_options; | |
RunOptions run_options; | |
// Load model from SavedModel | |
Status status = tensorflow::LoadSavedModel(session_options, run_options, export_dir, {tensorflow::kSavedModelTagServe}, &bundle); | |
if (!status.ok()) { | |
std::cout << "Failed to load saved model" << std::endl; | |
std::cout << status.ToString() << std::endl; | |
return -1; | |
} | |
// Create input data | |
Tensor batch_size(tensorflow::DT_INT64, tensorflow::TensorShape()); | |
auto dst = batch_size.flat<long long>().data(); | |
long long bsize = 3L; | |
memcpy(dst, &bsize, sizeof(bsize)); | |
Tensor input(tensorflow::DT_FLOAT, TensorShape({3, 1})); | |
auto input_dst = input.flat<float>().data(); | |
float arr[3] = {2.0, 3.0, 4.0}; | |
memcpy(input_dst, arr, sizeof(arr)); | |
// Initialize the iterator | |
status = bundle.session->Run( | |
{{"input", input}, {"target", input}, {"batch_size", batch_size}}, | |
{}, | |
{"dataset_init"}, | |
nullptr); | |
if (!status.ok()) { | |
std::cout << "Failed to run sesssion (dataset_init)" << std::endl; | |
std::cout << status.ToString() << std::endl; | |
return -1; | |
} | |
// Prediction | |
std::vector<Tensor> outputs; | |
status = bundle.session->Run({}, {"output:0"}, {}, &outputs); | |
if (!status.ok()) { | |
std::cout << "Failed to run sesssion (output:0)" << std::endl; | |
std::cout << status.ToString() << std::endl; | |
return -1; | |
} | |
Tensor a = outputs.at(0); | |
const int out_dim = 3; | |
for (int i = 0; i < out_dim; i++) { | |
std::cout << a.flat<float>()(i) << std::endl; | |
} | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from tensorflow.python.saved_model import tag_constants | |
# Restore the SavedModel written by the training script and run inference on a
# small hand-written batch, printing each input next to its prediction.

## Actual data
test_data = np.array([[2.0], [3.0], [4.0]], dtype='float32')
# The iterator is built from (input, target), so a target must be fed even
# though only the input drives the 'output' tensor.
test_label = test_data

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
    # Load model from SavedModel
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], './model')

    # Initialize the iterator with the whole test set as a single batch
    init_feed = {
        'input:0': test_data,
        'target:0': test_label,
        'batch_size:0': test_data.shape[0],
    }
    sess.run(graph.get_operation_by_name('dataset_init'), feed_dict=init_feed)

    # Prediction
    results = sess.run('output:0')
    for sample, predicted in zip(test_data, results):
        print("{} {}".format(sample, predicted))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
# Train a single dense unit to learn label = 2 * input with a re-initializable
# tf.data iterator, run a quick prediction on test data, then export the graph
# as a SavedModel in ./model (consumed by the Python and C++ predict samples).

EPOCHS = 10
BATCH_SIZE = 10

## Actual data
# Training inputs are uniform in [0, 10); labels follow label = 2 * input.
train = np.array(np.random.sample((100, 1)) * 10, dtype='float32')
train_label = 2 * train
test_data = np.array([[1.0], [2.0], [3.0]], dtype='float32')
# Use the same relation as the training labels so the test targets are meaningful.
test_label = 2 * test_data

# Each sess.run on the iterator consumes one batch, so an epoch needs this many
# runs (integer ceil-division, correct on both Python 2 and 3).
STEPS_PER_EPOCH = (len(train) + BATCH_SIZE - 1) // BATCH_SIZE

# Place holders
batch_size = tf.placeholder(tf.int64, name='batch_size')
x = tf.placeholder(tf.float32, shape=[None, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1], name='target')

# Create dataset and its iterator. The iterator is built from structure (not
# bound to one dataset) so 'dataset_init' can re-feed it with different data.
# Named `iterator` rather than `iter` to avoid shadowing the builtin.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size).repeat()
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
x_, y_ = iterator.get_next()

# make a simple model
prediction = tf.identity(tf.layers.dense(x_, 1), name='output')

# Optimize loss
loss = tf.reduce_mean(tf.square(prediction - y_), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Initialize the iterator using training data
    sess.run(dataset_init_op, feed_dict={x: train, y: train_label, batch_size: BATCH_SIZE})

    # Train for EPOCHS full passes over the data; report the last batch loss per epoch.
    for epoch in range(EPOCHS):
        for _step in range(STEPS_PER_EPOCH):
            _, loss_value = sess.run([train_op, loss])
        print("Epoch: {}, Loss: {:.4f}".format(epoch, loss_value))

    # Initialize the iterator using test data (whole test set as one batch)
    sess.run(dataset_init_op, feed_dict={x: test_data, y: test_label, batch_size: test_data.shape[0]})

    # Prediction
    results = sess.run(prediction)
    for sample, predicted in zip(test_data, results):
        print("{} {}".format(sample, predicted))

    # Save model
    # NOTE(review): SavedModelBuilder (used by simple_save) is expected to refuse
    # an existing export directory — remove ./model before re-running; verify
    # against the installed TF version.
    inputs_dict = {
        "x": x,
        "y": y,
        "batch_size": batch_size
    }
    outputs_dict = {
        "output": prediction
    }
    tf.saved_model.simple_save(sess, './model', inputs_dict, outputs_dict)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment