Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Listing operations in frozen .pb TensorFlow graphs in GraphDef format (see comments for SavedModel)
import argparse
import os
import sys
from typing import Iterable
import tensorflow as tf
# Command-line interface: a single positional argument naming the frozen graph.
parser = argparse.ArgumentParser(
    description='List the operations, source operations, operation inputs and '
                'tensors of a frozen TensorFlow graph stored as a binary GraphDef.')
parser.add_argument('file', type=str, help='The file name of the frozen graph.')
args = parser.parse_args()
if not os.path.exists(args.file):
    # ArgumentParser.exit() writes the message verbatim to stderr, so the
    # trailing newline must be supplied explicitly.
    parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))
# Populated below; the explicit None initialisation keeps later checks simple.
graph_def = None
graph = None

# Assuming a `.pb` file in `GraphDef` format.
# See comments on https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.

# TF 2.x removed `tf.gfile` and `tf.GraphDef` from the top-level namespace; the
# TF1 API remains available as `tf.compat.v1`. Older TF 1.x releases have no
# `tf.compat.v1`, so fall back to the top-level module there.
try:
    tf_v1 = tf.compat.v1
except AttributeError:
    tf_v1 = tf

print('Loading graph definition ...', file=sys.stderr)
try:
    with tf_v1.gfile.GFile(args.file, "rb") as f:
        graph_def = tf_v1.GraphDef()
        graph_def.ParseFromString(f.read())
except Exception as e:  # not BaseException: let Ctrl+C / SystemExit propagate
    load_error_message = (
        'Error loading the graph definition: {}\n'
        'Note: this script expects a frozen GraphDef file; if this is a '
        'SavedModel (saved_model.pb), it must be loaded differently.\n'
    ).format(e)
    parser.exit(2, load_error_message)
# Import the parsed GraphDef into a fresh graph. The TF1-style Graph /
# import_graph_def API is taken from `tf.compat.v1` when it exists so that
# this also runs under TF 2.x; older TF 1.x releases use the top-level module.
try:
    tf_v1 = tf.compat.v1
except AttributeError:
    tf_v1 = tf

print('Importing graph ...', file=sys.stderr)
if graph_def is None:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    parser.exit(2, 'Error importing the graph: no graph definition was loaded\n')
try:
    with tf_v1.Graph().as_default() as graph:  # type: tf.Graph
        # Empty name prefix so node names match the GraphDef exactly. The other
        # parameters are left at their defaults; `op_dict` is deprecated.
        tf_v1.import_graph_def(graph_def, name='')
except Exception as e:  # not BaseException: let Ctrl+C / SystemExit propagate
    parser.exit(2, 'Error importing the graph: {}\n'.format(e))
# Section 1: one line per operation with its type, name and output count.
print()
print('Operations:')
assert graph is not None
# `ops` is traversed again by every following section, so it must be
# re-iterable rather than a one-shot iterator.
ops = graph.get_operations()  # type: Iterable[tf.Operation]
operation_line = '- {0:20s} "{1}" ({2} outputs)'
for operation in ops:
    output_count = len(operation.outputs)
    print(operation_line.format(operation.type, operation.name, output_count))
# Section 2: names of operations that consume no tensors (graph sources).
print()
print('Sources (operations without inputs):')
for source_op in (candidate for candidate in ops if len(candidate.inputs) == 0):
    print('- {0}'.format(source_op.name))
# Section 3: for every operation that has inputs, its name followed by a
# comma-separated list of the tensors it consumes.
print()
print('Operation inputs:')
consumers = (candidate for candidate in ops if len(candidate.inputs) != 0)
for consumer in consumers:
    print('- {0:20}'.format(consumer.name))
    input_names = ', '.join(tensor.name for tensor in consumer.inputs)
    print(' {0}'.format(input_names))
# Section 4: every tensor produced by any operation, as shape, dtype and name.
print()
print('Tensors:')
tensor_line = '- {0:20} {1:10} "{2}"'
produced = (tensor for producer in ops for tensor in producer.outputs)
for tensor in produced:
    print(tensor_line.format(str(tensor.shape), tensor.dtype.name, tensor.name))
@fuzzyBatman

This comment has been minimized.

Copy link

@fuzzyBatman fuzzyBatman commented May 21, 2019

Thank you mate!

@peter197321

This comment has been minimized.

Copy link

@peter197321 peter197321 commented Sep 9, 2020

What versions of TensorFlow and Python are needed?

(tf1.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Users\nxa18908\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
ning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
Loading graph definition ...
WARNING:tensorflow:From .\dump_operations.py:20: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

WARNING:tensorflow:From .\dump_operations.py:21: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.

Error loading the graph definition: Wrong wire type in tag.
(tf1.x-cpu) PS C:> conda activate tf2.x-cpu
(tf2.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
Loading graph definition ...
Error loading the graph definition: module 'tensorflow' has no attribute 'gfile'
(tf2.x-cpu) PS C:>

@sunsided

This comment has been minimized.

Copy link
Owner Author

@sunsided sunsided commented Sep 9, 2020

It's been a while, but any Python 3 with TensorFlow >= 1.8 and < 2 should work. I haven't tried it with TF 2, but its upgrade converter might help.

@peter197321

This comment has been minimized.

Copy link

@peter197321 peter197321 commented Sep 9, 2020

@sunsided

This comment has been minimized.

Copy link
Owner Author

@sunsided sunsided commented Sep 9, 2020

I just toyed around with it a bit and here is my best guess at what's happening. 🙂

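First of all, to get rid of the TensorFlow deprecation warnings, try replacing `import tensorflow as tf` with `import tensorflow.compat.v1 as tf`. Under TF 2 this also provides `tf.gfile` and `tf.GraphDef` again, which is what the `module 'tensorflow' has no attribute 'gfile'` error is about.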

I'm assuming the culprit is this: the code above loads a frozen model containing a `GraphDef` (defined in `graph.proto`). However, that approach cannot load graphs stored as a `SavedModel` (defined in `saved_model.proto`).

Specifically, this block assumes the `GraphDef` format, tries to decode the file as such (this is where it blows up with "Wrong wire type in tag"), and then imports the result into a new graph:

import tensorflow.compat.v1 as tf

model_file = "path/to/model/file.pb"

# Read the raw bytes, then decode them as a binary GraphDef protobuf.
with tf.gfile.GFile(model_file, "rb") as model_stream:
    serialized_graph = model_stream.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_graph)

assert graph_def is not None
# Import into a fresh graph, without a name prefix on the imported nodes.
graph = tf.Graph()  # type: tf.Graph
with graph.as_default():
    tf.import_graph_def(graph_def, input_map=None, return_elements=None,
                        name='', op_dict=None, producer_op_list=None)

If you have a SavedModel, however, there's a much quicker way to achieve the same result using `tf.saved_model.loader`. Assuming the `SERVING` tag (`tf.saved_model.tag_constants.SERVING`):

import tensorflow.compat.v1 as tf

# Directory of the SavedModel (not the .pb file itself); the loader expects
# a `saved_model.pb` file inside it.
model_path = "path/to/model"

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    # The tag constants live in `tf.saved_model.tag_constants`; there is no
    # top-level `tf.tag_constants` module.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)

This method is deprecated according to the tf.compat.v1.saved_model.load documentation, but upgrading shouldn't be too hard.

Once you have the graph populated with the model, the rest of the code works as before.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment