Skip to content

Instantly share code, notes, and snippets.

@lukmanr
Created October 17, 2018 19:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukmanr/3d00fb6e7c81a7d28bfaa8a69a6bc11b to your computer and use it in GitHub Desktop.
Save lukmanr/3d00fb6e7c81a7d28bfaa8a69a6bc11b to your computer and use it in GitHub Desktop.
TF Model Optimization 8
def convert_graph_def_to_saved_model(export_dir, graph_filepath,
                                     output_tensor_names=None):
    """Convert a serialized GraphDef file into a TensorFlow SavedModel.

    The graph is imported into a fresh session. Every ``Placeholder`` node
    becomes a signature input, keyed by its node name and bound to the
    ``<node name>:0`` tensor. The requested output tensors become the
    signature outputs.

    Args:
      export_dir: Directory the SavedModel is written to. WARNING: if it
        already exists it is deleted recursively first.
      graph_filepath: Path to the serialized GraphDef. It is loaded with
        ``get_graph_def_from_file``, which is defined elsewhere in this module.
      output_tensor_names: Optional mapping from signature output key to
        tensor name in the graph. Defaults to
        ``{'class_ids': 'head/predictions/class_ids:0'}``, the output this
        function originally hard-coded.
    """
    if output_tensor_names is None:
        output_tensor_names = {'class_ids': 'head/predictions/class_ids:0'}

    # simple_save (via SavedModelBuilder) will not write into an existing
    # export directory, so any previous export is removed first.
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph_def = get_graph_def_from_file(graph_filepath)

    with tf.Session(graph=tf.Graph()) as session:
        # name='' imports nodes without the default "import/" prefix, so the
        # tensor names looked up below match the names stored in graph_def.
        tf.import_graph_def(graph_def, name='')
        graph = session.graph

        inputs = {
            node.name: graph.get_tensor_by_name('{}:0'.format(node.name))
            for node in graph_def.node
            if node.op == 'Placeholder'
        }
        outputs = {
            key: graph.get_tensor_by_name(tensor_name)
            for key, tensor_name in output_tensor_names.items()
        }

        tf.saved_model.simple_save(
            session,
            export_dir,
            inputs=inputs,
            outputs=outputs,
        )
    print('Optimized graph converted to SavedModel!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment