|
#!/usr/bin/env python |
|
|
|
""" |
|
This example shows how to load the owl-generated computation graph definition directly into TensorFlow and execute it.

To see how TensorFlow itself builds the graph, un-comment Section 1, comment out Section 2, and then run the script again.
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import math |
|
import os |
|
import sys |
|
import time |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets |
|
|
|
from google.protobuf import text_format |
|
from tensorflow.python.framework import graph_io |
|
|
|
# Number of label classes; also the width of every one-hot label row.
NUM_CLASSES = 10
# Side length, in pixels, of one square input image.
IMAGE_SIZE = 28
# Total number of pixels in a single image.
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
|
|
|
# Directory holding the MNIST input data; rooted at $TEST_TMPDIR when that
# variable is set, otherwise at /tmp.
input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                              'tensorflow/mnist/input_data')
# The second positional argument is `fake_data` in the TF 1.x signature.
# NOTE(review): presumably fetches the data set into input_data_dir when it is
# not already there - confirm against the installed TensorFlow version.
data_sets = read_data_sets(input_data_dir, False)
# Number of examples fed per evaluation step.
batch_size = 100
# Base name shared by the .pbtxt input, the .pb meta graph and the checkpoint.
filename = 'test_cgraph_cnn'
# $PWD is not guaranteed to be set (os.getenv then returns None, which makes
# os.path.join raise TypeError), so fall back to the real working directory.
checkpoint_file = os.path.join(os.getenv("PWD") or os.getcwd(),
                               filename + '.ckpt')
# Binary MetaGraphDef file, relative to the current working directory.
meta_file = filename + '.pb'
|
|
|
|
|
def get_one_hot(targets, nb_classes):
    """Convert integer class labels into a one-hot ``uint8`` matrix.

    Args:
        targets: Integer class indices (scalar, sequence or array of any
            shape). They are flattened before encoding.
        nb_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with one row per flattened
        target and ``nb_classes`` columns; row ``i`` holds a single 1 in the
        column given by the ``i``-th target.

    Raises:
        IndexError: If a label is out of range for ``nb_classes`` or the
            (non-empty) targets are not of an integer type.
    """
    indices = np.asarray(targets).reshape(-1)
    if indices.size == 0:
        # An empty list becomes a float64 array, which NumPy refuses to use
        # as an index; return the empty encoding explicitly instead.
        return np.zeros((0, nb_classes), dtype=np.uint8)
    # Build the identity directly as uint8 rather than casting a float64 copy.
    return np.eye(nb_classes, dtype=np.uint8)[indices]
|
|
|
|
|
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Build the feed dict for one batch drawn from ``data_set``.

    Args:
        data_set: Object exposing ``next_batch(batch_size, fake_data)`` that
            returns an ``(images, labels)`` pair, e.g. ``data_sets.test``.
        images_pl: Placeholder that receives the image batch.
        labels_pl: Placeholder that receives the one-hot label batch.

    Returns:
        A dict mapping ``images_pl`` to the images reshaped to
        ``(-1, IMAGE_SIZE, IMAGE_SIZE, 1)`` and ``labels_pl`` to the labels
        one-hot encoded over ``NUM_CLASSES`` columns.
    """
    # The batch size comes from the module-level `batch_size` setting.
    images_feed, labels_feed = data_set.next_batch(batch_size, False)
    return {
        # Use the shared image-size constant instead of a literal 28.
        images_pl: np.reshape(images_feed, (-1, IMAGE_SIZE, IMAGE_SIZE, 1)),
        labels_pl: get_one_hot(labels_feed, NUM_CLASSES),
    }
|
|
|
|
|
## Section 1. Build Graph and save
#
# Disabled reference implementation: everything below sits inside a bare
# string literal, so none of it runs until the surrounding triple quotes are
# removed. The placeholders are named 'xt' / 'yt' and the meta graph is
# exported to `meta_file` because Section 3 imports `meta_file` and looks up
# the tensors 'xt:0' and 'yt:0'.

"""

def cnn_model(images):
    # images already preprocessed, so a dummy lambda layer here
    lamb = images / (1.)
    conv1 = tf.layers.conv2d(inputs=lamb, filters=32, kernel_size=[5, 5],
                             padding="same", activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    flatten = tf.layers.flatten(inputs=pool1)
    dense = tf.layers.dense(inputs=flatten, units=1024, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=dense, units=10)
    return logits


def loss(logits, labels):
    # softmax cross-entropy between the logits and the one-hot labels
    labels = tf.to_int64(labels)
    return tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)


def placeholder_inputs(batch_size):
    # names must match the tensors Section 3 fetches ('xt:0' and 'yt:0')
    images_placeholder = tf.placeholder(
        tf.float32, shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 1),
        name='xt')
    labels_placeholder = tf.placeholder(
        tf.int32, shape=(batch_size, NUM_CLASSES), name='yt')
    return images_placeholder, labels_placeholder


with tf.Graph().as_default():
    # context-managed so the session is closed once the graph is saved
    with tf.Session() as sess:
        images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
        logits = cnn_model(images_placeholder)
        # bound to a new name so the loss() function is not overwritten
        loss_op = loss(logits, labels_placeholder)

        # Section 3 retrieves the loss tensor through this collection
        tf.add_to_collection("result", loss_op)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)
        saver.save(sess, checkpoint_file)
        # write the binary MetaGraphDef to the path Section 3 imports
        saver.export_meta_graph(meta_file)

"""
|
|
|
|
|
### Section 2. Load Graph from pbtxt file

# Read the text-format MetaGraphDef; the file handle is released before the
# (potentially slow) protobuf parse and the write below.
with open(filename + '.pbtxt', 'r') as f:
    file_content = f.read()

metagraph_def = tf.MetaGraphDef()
text_format.Merge(file_content, metagraph_def)

# Re-serialise in binary form at `meta_file`, the path Section 3 imports.
graph_io.write_graph(metagraph_def,
                     os.path.dirname(meta_file),
                     os.path.basename(meta_file),
                     as_text=False)
|
|
|
|
|
### Section 3. Execute inference on NN

with tf.Graph().as_default():
    # Context-managed so the session's resources are released when the block
    # exits, including when an error is raised part-way through.
    with tf.Session() as sess:
        # Rebuild the graph stored in the binary MetaGraphDef file.
        saver = tf.train.import_meta_graph(meta_file)
        graph = tf.get_default_graph()

        # Input tensors are looked up by the names used in the imported graph.
        images_placeholder = graph.get_tensor_by_name('xt:0')
        labels_placeholder = graph.get_tensor_by_name('yt:0')
        # The loss tensor is the first entry of the "result" collection.
        loss = tf.get_collection("result")[0]

        # No checkpoint is restored here: the variables are freshly
        # initialised, so the printed loss is that of the initial weights.
        init = tf.global_variables_initializer()
        sess.run(init)

        feed_dict = fill_feed_dict(data_sets.test, images_placeholder,
                                   labels_placeholder)
        print(sess.run(loss, feed_dict=feed_dict))
|
|