Skip to content

Instantly share code, notes, and snippets.

@eggie5
Last active October 31, 2017 23:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eggie5/f0838a1289ca851aa5a72593575b7f06 to your computer and use it in GitHub Desktop.
Save eggie5/f0838a1289ca851aa5a72593575b7f06 to your computer and use it in GitHub Desktop.
### We will try to serialize and deserialize a graph that uses the new `get_single_element` function of the Dataset API.
### You will see that it does not deserialize gracefully.
#### Part 1: Build arbitrary graph using Dataset API and new get_single_element function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator
# Shared sizing constant: used further down both as the batch size and as the
# number of parallel calls for the dataset map step.
BATCH_SIZE = 100

# Graph input: float32 rows, where 1 example has 5 features.
example = tf.placeholder(
    dtype=tf.float32,
    shape=[None, 5],
    name="input_example",
)

# Model parameters: a randomly initialised 1x5 weight row and a zeroed bias.
w = tf.Variable(
    tf.random_normal(shape=[1, 5], stddev=0.35),
    name="weights",
)
b = tf.Variable(tf.zeros(shape=[5]), name="biases")
def preprocessing_fn(_in):
    """Return the input with ``0.0`` added to it.

    This is a numerically neutral placeholder for a real preprocessing step;
    it exists so that the dataset pipeline has a function to map over its
    elements.

    Args:
        _in: The element to preprocess; any value that supports ``+ 0.0``.

    Returns:
        The result of ``_in + 0.0``.
    """
    return _in + 0.0
# Input pipeline, one step per line: slice the fed placeholder into rows,
# map the preprocessing function over them, then regroup the rows in batches.
sliced = tf.data.Dataset.from_tensor_slices(example)
preprocessed = sliced.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
dataset = preprocessed.batch(BATCH_SIZE)

# Turn the dataset into a plain tensor inside the graph
# (get_single_element is new per fb180d5, ~20 days ago).
x = tf.contrib.data.get_single_element(dataset)

# Model output: sigmoid(w * x + b).
weighted = tf.multiply(w, x)
z = tf.sigmoid(weighted + b)

init_op = tf.global_variables_initializer()
### Part 2: Export the Graph using SavedModel API
import os
import time
from datetime import datetime

# Imported under an alias so that the `builder` variable assigned below does
# not shadow the module it was created from.
from tensorflow.python.saved_model import builder as saved_model_builder

with tf.Session() as sess:
    sess.run(init_op)
    # Run the freshly built graph once on random data before exporting it.
    print(sess.run(z, feed_dict={example: np.random.random((2, 5))}))

    # export
    # The Unix timestamp (in seconds) names the version directory.
    # time.time() is used instead of strftime("%s") because "%s" is an
    # undocumented, platform-specific directive that is not available
    # everywhere.
    version = int(time.time())
    export_path = os.path.join(
        tf.compat.as_bytes("_versions"),
        tf.compat.as_bytes(str(version)),
    )
    # export_path is bytes; convert it to text so the message reads the same
    # on Python 2 and Python 3.
    print('Exporting trained model to %s' % tf.compat.as_text(export_path))

    builder = saved_model_builder.SavedModelBuilder(export_path)
    builder.add_meta_graph_and_variables(sess, ["ID_TAG"], clear_devices=True)
    builder.save(as_text=True)
    print('Done exporting!')
### Part 3: Restore the graph using SavedModel API
# Start from an empty default graph so the restored graph is the only one.
tf.reset_default_graph()

# NOTE(review): the result of this call is discarded; presumably it is here to
# force the lazily loaded tf.contrib module to import so that its ops are
# registered before the saved graph is loaded -- confirm before removing.
dir(tf.contrib)

with tf.Session() as sess:
    # `version` is the timestamp chosen during the export step.
    tf.saved_model.loader.load(sess, ["ID_TAG"], '_versions/%s' % version)
    graph = tf.get_default_graph()

    # Look the endpoints up by name, since the Python handles from graph
    # construction belong to the graph that was just reset.
    input_t = graph.get_tensor_by_name("input_example:0")
    output = graph.get_tensor_by_name("Sigmoid:0")
    print(sess.run(output, feed_dict={input_t: np.random.random((10, 5))}))

# THROWS THIS ERROR:
# NotFoundError (see above for traceback): Function tf_map_func_7eefab31 is not defined.
# [[Node: ParallelMapDataset = ParallelMapDataset[Targuments=[], f=tf_map_func_7eefab31[],
# output_shapes=[[5]], output_types=[DT_FLOAT],
# _device="/job:localhost/replica:0/task:0/device:CPU:0"](TensorSliceDataset, num_parallel_calls)]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment