Skip to content

Instantly share code, notes, and snippets.

@tonyreina
Last active September 4, 2019 19:40
Show Gist options
  • Save tonyreina/80763eecdc660e5b358308e9932fa03c to your computer and use it in GitHub Desktop.
Load TensorFlow protobuf
import tensorflow as tf
import argparse
# Command-line interface: the path of the binary graph file plus the names of
# the tensors to expose as the SavedModel's input and output. Tensor names are
# given as "<scope>/<op_name>:<output_index>"; the defaults carry the
# "import/" scope that tf.import_graph_def adds when the graph is imported.
parser = argparse.ArgumentParser(
    description="Loads TensorFlow protobuf and converts it to saved model",
    add_help=True, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--filename", required=True,
                    help="the name and path of the binary TensorFlow protobuf (.pb) file")
parser.add_argument("--input_layer_name", default="import/shuffled_queue:0",
                    help="the name of the input layer - Use Netron or TensorBoard to view the graph")
parser.add_argument("--output_layer_name", default="import/Rank_1/packed:0",
                    help="the name of the output layer - Use Netron or TensorBoard to view the graph")
args = parser.parse_args()
def printTensors(graph):
    """
    Write the name of every operation in ``graph`` to stdout, one per line.

    Useful for discovering the tensor names to pass as --input_layer_name /
    --output_layer_name.

    Args:
        graph: object exposing ``get_operations()`` (e.g. a ``tf.Graph``) whose
            items each have a ``name`` attribute.
    """
    op_names = [operation.name for operation in graph.get_operations()]
    for op_name in op_names:
        print(op_name)
def loadProtobuf(filename, strip_training_nodes=False, protected_nodes=None):
    """
    Loads a binary TensorFlow protobuf file into a ``tf.GraphDef``.

    Args:
        filename: path of the serialized (binary) GraphDef, opened through
            ``tf.gfile.GFile``.
        strip_training_nodes: when True, return the graph pruned by
            ``tf.graph_util.remove_training_nodes``. Defaults to False, which
            matches the previous effective behavior: the earlier code called
            ``remove_training_nodes`` but discarded its result, so the graph
            was returned unpruned.
        protected_nodes: optional list of node names that pruning must keep
            (e.g. Identity nodes used as the model's input/output). Only used
            when ``strip_training_nodes`` is True.

    Returns:
        The parsed ``tf.GraphDef`` (pruned if requested).
    """
    with tf.gfile.GFile(filename, "rb") as f:
        serialized = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    if strip_training_nodes:
        # remove_training_nodes does not modify its argument; it builds and
        # returns a new GraphDef, so the return value must be kept.
        graph_def = tf.graph_util.remove_training_nodes(
            graph_def, protected_nodes=protected_nodes)
    return graph_def
# Parse the binary graph, import it into a fresh graph, and export it as a
# SavedModel whose signature maps the chosen input tensor to the chosen output
# tensor.
export_dir = "saved_model_directory"
graph_def = loadProtobuf(args.filename)
with tf.Graph().as_default() as graph:
    # import_graph_def places every node under the "import/" name scope by
    # default, which is why the default tensor names start with "import/".
    tf.import_graph_def(graph_def)
    # To list every op name when choosing input/output tensors, call:
    #     printTensors(graph)
    x = graph.get_tensor_by_name(args.input_layer_name)
    y = graph.get_tensor_by_name(args.output_layer_name)
    # The session is closed when the block exits. simple_save is called while
    # `graph` is still the default graph because the SavedModel builder
    # exports the default graph's MetaGraph.
    with tf.Session(graph=graph) as sess:
        print("Loaded graph {}".format(args.filename))
        # NOTE: the underlying SavedModelBuilder raises if export_dir already
        # exists and is non-empty.
        tf.saved_model.simple_save(sess,
                                   export_dir,
                                   inputs={args.input_layer_name: x},
                                   outputs={args.output_layer_name: y})
print("Saved model to '{}'".format(export_dir))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment