
@thepulkitagarwal
Last active October 25, 2017 17:02
def print_graph_nodes(filename):
    """Print the nodes of a serialized TensorFlow GraphDef (.pb) whose names
    contain 'input', 'output', or 'keras_learning_phase'."""
    import tensorflow as tf
    # tf.GraphDef is the TF 1.x name; under TF 2.x use tf.compat.v1.GraphDef().
    g = tf.GraphDef()
    # Read the binary protobuf and close the file handle when done.
    with open(filename, 'rb') as f:
        g.ParseFromString(f.read())
    print()
    print(filename)
    print("=======================INPUT=========================")
    print([n for n in g.node if 'input' in n.name])
    print("=======================OUTPUT========================")
    print([n for n in g.node if 'output' in n.name])
    print("===================KERAS_LEARNING=====================")
    print([n for n in g.node if 'keras_learning_phase' in n.name])
    print("======================================================")
    print()
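The name filtering in `print_graph_nodes` is plain substring matching and does not depend on TensorFlow itself. The sketch below reproduces that logic so it can be tried without a model file. It assumes only that each node exposes a `.name` attribute. The `Node` stand-in, the `group_nodes` helper, and the node names are all hypothetical, chosen for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for a GraphDef NodeDef: only the .name field is used.
Node = namedtuple("Node", ["name"])


def group_nodes(nodes, keywords=("input", "output", "keras_learning_phase")):
    """Map each keyword to the nodes whose name contains it as a substring,
    mirroring the checks in print_graph_nodes."""
    return {kw: [n for n in nodes if kw in n.name] for kw in keywords}


# Hypothetical node names, not taken from any real model.
nodes = [
    Node("input_1"),
    Node("dense_1/MatMul"),
    Node("output_node0"),
    Node("keras_learning_phase"),
]

groups = group_nodes(nodes)
for kw, matched in groups.items():
    print(kw, [n.name for n in matched])
```

The match is case-sensitive and will also catch any node whose name merely contains the keyword (for example, something like `input_1/read`). The printed lists are therefore a starting point for locating a graph's real input and output tensors, not a definitive answer.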