Skip to content

Instantly share code, notes, and snippets.

@palonso
Last active February 11, 2022 08:15
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save palonso/a09a2505947c2d7bd26f22377061b0ab to your computer and use it in GitHub Desktop.
Serialize a TensorFlow graph as a Protobuf file, keeping only the nodes that belong to the `model` namespace.
# TensorFlow 1.x: restore a trained model and serialize it as a binary
# Protobuf graph with the variables converted to constants.
#
# Names expected from the surrounding script:
#   tf           -- the TensorFlow module.
#   model        -- string; `model + '/'` is the checkpoint path handed to the
#                   saver, and `model` is also the directory the graph is
#                   written to.
#   output_graph -- file name of the serialized graph inside `model`.
sess = tf.Session()
try:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    results_folder = model + '/'
    saver.restore(sess, results_folder)

    gd = sess.graph.as_graph_def()

    # Rewrite reference/assignment ops in place: RefSwitch -> Switch,
    # AssignSub -> Sub, AssignAdd -> Add, Assign -> Identity.
    # NOTE(review): presumably needed because these ops expect variable
    # references, which no longer exist once variables become constants.
    for node in gd.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Inputs whose name contains 'moving_' are redirected to the
            # matching '/read' node.
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be
                #         uninitialized.
                # input1: value: The value to be assigned to the variable.
                # Keep only the assigned value as the single input.
                node.input[0] = node.input[1]
                del node.input[1]

    # Select every node whose name contains 'model'.
    # NOTE(review): this is a substring match, not a strict namespace
    # (prefix) test -- confirm that no unrelated node name contains 'model'.
    node_names = [n.name for n in gd.node if 'model' in n.name]

    # Import the pruned graph into a fresh default graph.
    # NOTE(review): `subgraph` is not used by the export below, which
    # receives the full `gd`; confirm whether this import is still needed.
    # The session is closed with try/finally instead of `with tf.Session()`
    # because tf.reset_default_graph() refuses to run while a graph is
    # installed as a nested default, which the `with` form would do.
    subgraph = tf.graph_util.extract_sub_graph(gd, node_names)
    tf.reset_default_graph()
    tf.import_graph_def(subgraph)

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,        # The session is used to retrieve the weights
        gd,          # The graph_def is used to retrieve the nodes
        node_names,  # The output node names select the useful nodes
    )
    tf.io.write_graph(output_graph_def, model, output_graph, as_text=False)
finally:
    # Release the session even when restoring, converting or writing fails.
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment