Skip to content

Instantly share code, notes, and snippets.

@r7vme
Created June 23, 2019 21:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save r7vme/95afe5f21b2ffbd2ae7cb6d59f37bfce to your computer and use it in GitHub Desktop.
Save r7vme/95afe5f21b2ffbd2ae7cb6d59f37bfce to your computer and use it in GitHub Desktop.
Add a new TensorFlow operation to an existing (frozen) graph
import tensorflow as tf
def get_graph_def_from_file(graph_filepath):
    """Load a serialized ``GraphDef`` protocol buffer from disk.

    Args:
        graph_filepath: Path to a binary ``.pb`` file containing a serialized
            ``tf.GraphDef`` (e.g. a frozen graph). It is opened with
            ``tf.gfile.GFile``.

    Returns:
        The parsed ``tf.GraphDef``. It is NOT imported into any graph; pass it
        to ``tf.import_graph_def`` to do that.

    Errors raised while opening/reading the file or while parsing the
    protobuf are propagated unchanged.
    """
    # Parsing is plain protobuf deserialization and does not touch any
    # tf.Graph, so no default-graph context is needed around it.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_filepath, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def
# inputs
graph_file = "optimized_tRT.pb"  # existing frozen graph to extend
graph_file_out = "optimized_tRT_modified.pb"  # written into logdir "." below
# Name of the tensor in the existing graph that the new op is applied to.
# Exactly one entry is expected: the result is unpacked into a single `y`.
outputs = ["test_model/model/logits/linear/BiasAdd:0"]

g = tf.Graph()
# Making `g` the default graph is all that is needed: both import_graph_def
# and the new op are added to the current default graph. No tf.Session is
# opened because nothing is ever executed, and creating one would initialize
# devices (in TF1 this may reserve most GPU memory by default).
with g.as_default():
    # name="" imports the nodes without the default "import/" prefix, so the
    # original node names are preserved in the written graph.
    (y,) = tf.import_graph_def(
        get_graph_def_from_file(graph_file),
        name="",
        return_elements=outputs,
    )

    # Add desired operation
    # Examples:
    # y = tf.argmax(y, axis=3, name="argmax_1", output_type=tf.dtypes.int32)
    # NOTE(review): axis=3 assumes the logits are rank-4 with the class
    # dimension last (e.g. NHWC segmentation output) -- confirm for the model.
    # NOTE: if a node named "softmax_1" already exists, TF picks a unique name
    # instead; check y.op.name if downstream code relies on the exact name.
    y = tf.nn.softmax(y, axis=3, name="softmax_1")

# Save graph (binary protobuf) containing the imported nodes plus the new op.
tf.train.write_graph(g, logdir=".", as_text=False, name=graph_file_out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment