Last active
October 31, 2017 23:06
-
-
Save eggie5/f0838a1289ca851aa5a72593575b7f06 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### We will try to serialize and deserialize a graph that uses the new `get_single_element` function of the Dataset API.
### You will see that it does not deserialize gracefully.
#### Part 1: Build an arbitrary graph using the Dataset API and the new get_single_element function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

# Maximum number of rows batched together by the input pipeline.
BATCH_SIZE = 100
# Each example is a row of this many features.
NUM_FEATURES = 5

# Feed point for the raw examples. The leading (row-count) dimension is left
# open; the tensor is looked up again by this name after the graph is restored.
example = tf.placeholder(
    tf.float32,
    shape=[None, NUM_FEATURES],
    name="input_example",
)

# Model parameters: a [1, NUM_FEATURES] weight row and a [NUM_FEATURES] bias.
w = tf.Variable(
    tf.random_normal([1, NUM_FEATURES], stddev=0.35),
    name="weights",
)
b = tf.Variable(tf.zeros([NUM_FEATURES]), name="biases")
def preprocessing_fn(_in):
    """Per-element map function applied to the input Dataset.

    Returns ``_in + 0.0``, which is numerically the same value. Presumably it
    is a placeholder for a real preprocessing step, so that the pipeline
    contains a map function (the function the restore in Part 3 cannot find).
    """
    shifted = _in + 0.0
    return shifted
# Input pipeline: slice the fed matrix into rows, map each row through
# preprocessing_fn, then group the rows back into batches of up to BATCH_SIZE.
sliced = tf.data.Dataset.from_tensor_slices(example)
mapped = sliced.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
dataset = mapped.batch(BATCH_SIZE)

# new per fb180d5 ~20 days ago
# NOTE(review): get_single_element presumably requires the dataset to yield
# exactly one batch, i.e. between 1 and BATCH_SIZE fed rows -- confirm.
x = tf.contrib.data.get_single_element(dataset)

# Elementwise scale and shift, then a sigmoid. Part 3 fetches this result by
# its auto-generated op name ("Sigmoid:0"), so the sigmoid is left unnamed.
logits = tf.multiply(w, x) + b
z = tf.sigmoid(logits)

init_op = tf.global_variables_initializer()
### Part 2: Export the graph using the SavedModel API
import os
import time
from datetime import datetime

from tensorflow.python.saved_model import builder

with tf.Session() as sess:
    sess.run(init_op)
    # Sanity check: the graph evaluates on a small random batch before export.
    print(sess.run(z, feed_dict={example: np.random.random((2, 5))}))

    # The current Unix timestamp is used as the model version.
    # int(time.time()) replaces datetime.now().strftime("%s"): "%s" is a
    # platform-specific (glibc) extension and is not available everywhere.
    version = int(time.time())
    export_path = os.path.join(tf.compat.as_bytes("_versions"),
                               tf.compat.as_bytes(str(version)))
    # as_str avoids printing the bytes path as b'...' on Python 3.
    print('Exporting trained model to %s' % tf.compat.as_str(export_path))

    # Keep the instance under its own name instead of shadowing the imported
    # `builder` module.
    model_builder = builder.SavedModelBuilder(export_path)
    # NOTE(review): suspected cause of the error shown in Part 3 -- with
    # clear_devices=True this TF version may drop the graph's function library
    # (where the Dataset map function lives). Try clear_devices=False to confirm.
    model_builder.add_meta_graph_and_variables(sess, ["ID_TAG"], clear_devices=True)
    model_builder.save(as_text=True)
    print('Done exporting!')
### Part 3: Restore the graph using the SavedModel API
tf.reset_default_graph()
# NOTE(review): the result is discarded. Accessing tf.contrib presumably
# forces its lazy import so contrib ops get registered before the load --
# confirm whether this is still needed.
dir(tf.contrib)
with tf.Session() as sess:
    # Load from the exact directory Part 2 exported to, instead of rebuilding
    # the path by hand.
    tf.saved_model.loader.load(sess, ["ID_TAG"], tf.compat.as_str(export_path))
    graph = tf.get_default_graph()
    input_t = graph.get_tensor_by_name("input_example:0")
    # "Sigmoid:0" is the auto-generated name of the unnamed tf.sigmoid op
    # built in Part 1; it breaks if that op is renamed or another sigmoid is
    # created before it.
    output = graph.get_tensor_by_name("Sigmoid:0")
    print(sess.run(output, feed_dict={input_t: np.random.random((10, 5))}))

# This throws the following error:
# NotFoundError (see above for traceback): Function tf_map_func_7eefab31 is not defined.
#   [[Node: ParallelMapDataset = ParallelMapDataset[Targuments=[], f=tf_map_func_7eefab31[],
#   output_shapes=[[5]], output_types=[DT_FLOAT],
#   _device="/job:localhost/replica:0/task:0/device:CPU:0"](TensorSliceDataset, num_parallel_calls)]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment