Skip to content

Instantly share code, notes, and snippets.

@jameshfisher
Created March 8, 2021 15:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jameshfisher/f99ad86fc23d2ae7c856ee2f2ec89cd8 to your computer and use it in GitHub Desktop.
Plot a TensorFlow graph with Graphviz (DOT)
import tensorflow as tf
try:
# pydot-ng is a fork of pydot that is better maintained.
import pydot_ng as pydot
except ImportError:
# pydotplus is an improved version of pydot
try:
import pydotplus as pydot
except ImportError:
# Fall back on pydot if necessary.
try:
import pydot
except ImportError:
pydot = None
def add_edge(dot, src, dst, **kwargs):
    """Add a ``src -> dst`` edge to *dot* unless ``dot.get_edge`` already finds one.

    Extra keyword arguments are forwarded to ``pydot.Edge`` as edge
    attributes (e.g. ``label``, ``ltail``). Returns None.
    """
    existing = dot.get_edge(src, dst)
    if existing:
        # An edge between these endpoints is already present; keep the first.
        return
    edge = pydot.Edge(src, dst, **kwargs)
    dot.add_edge(edge)
# Characters deleted from shape text: Graphviz record labels parse "<name>"
# as a port declaration, so stray angle brackets would corrupt the label.
_SHAPE_LABEL_DELETE = str.maketrans('', '', '<>')


def format_shape(shape):
    """Return ``str(shape)`` with every ``<`` and ``>`` removed.

    The result is embedded in Graphviz record labels next to ``<in{pos}>`` /
    ``<out{pos}>`` port fields, where angle brackets delimit port names.
    NOTE(review): a TensorShape of unknown rank is believed to print as
    ``<unknown>``, which is presumably what this guards against -- confirm
    for the TensorFlow version in use.

    Example: ``format_shape((None, 3)) == '(None, 3)'``.
    """
    return str(shape).translate(_SHAPE_LABEL_DELETE)
# Private Operation attributes that hold a control-flow op's nested graphs.
# add_graph_to_dot draws an edge from each such subgraph's output node to
# the op that owns it.
subgraph_attrs = [
    # StatelessIf: one graph per branch.
    '_true_graph',
    '_false_graph',
    # StatelessWhile: loop condition and loop body.
    '_cond_graph',
    '_body_graph',
    # TODO: what other attrs refer to subgraphs?
]
def _port_labels(port_prefix, tensors):
    """Return one record-label field per tensor, joined with ``|``.

    Each field is ``<{port_prefix}{pos}>{dtype name} {shape}``: the ``<...>``
    part names a Graphviz record port (``in0``, ``out2``, ...) that edges can
    attach to; the rest is the visible text.
    """
    return '|'.join(
        f"<{port_prefix}{pos}>{tensor.dtype.name} {format_shape(tensor.shape)}"
        for pos, tensor in enumerate(tensors)
    )


def add_graph_to_dot(graph, dot):
    """Draw *graph*, and recursively its library functions, into *dot*.

    Adds to *dot* (mutated in place):
      * a ``graphinput_<id>`` record node with an ``in<pos>`` port per
        ``graph.inputs`` tensor, and a ``graphoutput_<id>`` node with an
        ``out<pos>`` port per ``graph.outputs`` tensor;
      * one pydot cluster per entry of ``graph._functions``, filled by a
        recursive call on that function's graph;
      * one record node per operation, named ``str(id(op))``, with
        ``in<pos>`` / ``out<pos>`` ports;
      * edges: producing op -> consuming op, subgraph output -> owning op
        (for the attributes listed in ``subgraph_attrs``), graph input ->
        Placeholder, and op output -> graph output.

    Args:
      graph: a TensorFlow graph exposing ``inputs``, ``outputs``,
        ``_functions`` and ``get_operations()`` -- presumably a FuncGraph
        (the example passes ``ConcreteFunction.graph``).
      dot: the pydot graph or cluster to draw into.
    """
    graph_id = id(graph)

    # Graph-level input/output nodes, one record port per tensor.
    graph_input_labels = _port_labels("in", graph.inputs)
    graphinput = pydot.Node(f"graphinput_{graph_id}", label=f'Graph inputs: |{graph_input_labels}')
    dot.add_node(graphinput)
    graph_output_labels = _port_labels("out", graph.outputs)
    graphoutput = pydot.Node(f"graphoutput_{graph_id}", label=f'Graph outputs: |{graph_output_labels}')
    dot.add_node(graphoutput)

    # Nested function graphs are drawn recursively inside their own cluster.
    # ``_functions`` is a private TensorFlow attribute and may change between
    # versions.
    for f_name, f in graph._functions.items():
        # pydot.Cluster prepends "cluster_" to its name, which is how Graphviz
        # knows to draw a border around the subgraph.
        cluster = pydot.Cluster(str(id(f.graph)), label=f_name)
        dot.add_subgraph(cluster)
        add_graph_to_dot(f.graph, cluster)

    ops = graph.get_operations()

    # Add one record node per op first; edges are added afterwards.
    for op in ops:
        # For our purposes a Placeholder _does_ have an input: it comes from
        # the graph inputs. Its own outputs are used to describe that input.
        in_tensors = op.outputs if op.type == 'Placeholder' else op.inputs
        input_labels = _port_labels("in", in_tensors)
        output_labels = _port_labels("out", op.outputs)
        label = f"{op.name}: {op.type}\n|{{inputs:|outputs:}}|{{{{{input_labels}}}|{{{output_labels}}}}}"
        dot.add_node(pydot.Node(str(id(op)), label=label))

    # Now add edges.
    for op in ops:
        op_id = id(op)
        try:
            # Don't show the tensors; just draw arrows between operations,
            # from the producer's out-port to this op's in-port.
            for pos, input_tensor in enumerate(op.inputs):
                add_edge(
                    dot,
                    f"{id(input_tensor.op)}:out{input_tensor.value_index}",
                    f"{op_id}:in{pos}",
                )
        except Exception as err:
            # Best effort: the original author hit an exception here for
            # _OperationWithOutputs (a TensorFlow bug?). Report it and skip the
            # remaining input edges of this op instead of aborting the plot.
            print(f"Could not get inputs for {op}: {err!r}")

        # Control-flow ops keep their nested graphs in private attributes;
        # link each subgraph's output node to the op. ``ltail`` clips the edge
        # at the subgraph's cluster border (requires compound=true on the root).
        for subgraph_attr in subgraph_attrs:
            if hasattr(op, subgraph_attr):
                subgraph = getattr(op, subgraph_attr)
                add_edge(
                    dot,
                    f"graphoutput_{id(subgraph)}",
                    str(op_id),
                    ltail=f"cluster_{id(subgraph)}",
                    label=subgraph_attr,
                )

    # Feed each graph input into the op that receives it. Always port in0:
    # this assumes every graph input is produced by a single-output
    # Placeholder, whose node has one "input" port per output tensor.
    for pos, input_tensor in enumerate(graph.inputs):
        add_edge(
            dot,
            f"graphinput_{graph_id}:in{pos}",
            f"{id(input_tensor.op)}:in0",
        )

    # Route each graph output from the out-port of the op that produces it.
    for pos, output_tensor in enumerate(graph.outputs):
        add_edge(
            dot,
            f"{id(output_tensor.op)}:out{output_tensor.value_index}",
            f"graphoutput_{graph_id}:out{pos}",
        )
def graph_to_dot(graph):
    """Return a ``pydot.Dot`` drawing of *graph* and its nested function graphs.

    Nodes default to the ``record`` shape so that labels can declare the
    ``<in{pos}>`` / ``<out{pos}>`` ports used by the edges.

    Raises:
      ImportError: if none of pydot_ng, pydotplus or pydot could be imported
        (the import fallback at the top of this file then leaves ``pydot`` as
        None).
    """
    if pydot is None:
        raise ImportError(
            "graph_to_dot requires pydot_ng, pydotplus or pydot; none could be imported"
        )
    dot = pydot.Dot()
    dot.set('rankdir', 'TB')        # ranks flow top to bottom
    dot.set('concentrate', 'true')  # merge parallel edges where possible
    dot.set('dpi', 96)
    dot.set_node_defaults(shape='record')  # enables <port> fields in labels
    # compound=true is needed for the ltail= used on subgraph edges.
    dot.set('compound', 'true')  # https://stackoverflow.com/a/2012106/229792
    dot.set('newrank', 'true')   # single global ranking across clusters
    add_graph_to_dot(graph, dot)
    return dot
def plot_graph(graph, path='./graph.png', fmt='png'):
    """Render *graph* with Graphviz and write it to *path*.

    The DOT source is printed to stdout first, which helps with debugging.

    Args:
      graph: the TensorFlow graph to draw (see add_graph_to_dot).
      path: output file; defaults to ``./graph.png``.
      fmt: output format passed to ``pydot.Dot.write``; defaults to ``'png'``.

    Returns:
      The ``pydot.Dot`` that was written. Callers can render it again in
      other formats without rebuilding it.
    """
    dot = graph_to_dot(graph)
    print(dot)
    dot.write(path, format=fmt)
    return dot
### EXAMPLE
def py_func(x):
    """Example function to trace with ``tf.function`` and plot.

    When a scalar drawn by ``tf.random.uniform(())`` is below 0.5, *x* is
    squared. The result is then cast to float32, and ``2*x + 5`` is returned.

    NOTE(review): under ``tf.function`` this tensor-dependent ``if`` is
    presumably converted by AutoGraph into a conditional op whose branch
    graphs appear as clusters in the plot -- confirm the op type (If vs
    StatelessIf) for the TensorFlow version in use.
    """
    if tf.random.uniform(()) < 0.5:
        x = x*x
    x = tf.cast(x, 'float32')
    return 2*x + 5
if __name__ == "__main__":
    # Trace py_func for a scalar integer input (tf.constant(3)) to obtain a
    # concrete graph, then plot it (prints the DOT source, writes ./graph.png).
    # Guarded so that importing this module neither traces TensorFlow nor
    # writes files.
    tf_func = tf.function(py_func)
    tf_concrete_func = tf_func.get_concrete_function(tf.constant(3))
    my_graph = tf_concrete_func.graph
    plot_graph(my_graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment