Skip to content

Instantly share code, notes, and snippets.

@usstq
Last active September 18, 2021 04:33
Show Gist options
  • Save usstq/81ed82be3b53bd0e6767b77c8cad40b0 to your computer and use it in GitHub Desktop.
Save usstq/81ed82be3b53bd0e6767b77c8cad40b0 to your computer and use it in GitHub Desktop.
Improved net_drawer.py for ONNX model visualization
# SPDX-License-Identifier: Apache-2.0
# A library and utility for drawing ONNX nets. Most of this implementation has
# been borrowed from the caffe2 implementation
# https://github.com/caffe2/caffe2/blob/master/caffe2/python/net_drawer.py
#
# The script takes two required arguments:
# --input: a path to a serialized ModelProto (.pb/.onnx) file.
# --output: a path to write the drawing of the graph to; the Graphviz output
#   format is taken from the file extension (e.g. .svg, .png or .dot).
#
# When writing a .dot file, you can, for example, export it to SVG later
# with the graphviz `dot` utility, like so:
#
# $ dot -Tsvg my_output.dot -o my_output.svg
#
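# Improved by tingqian.li@intel.com:
# - exports the drawing directly (SVG or another Graphviz format)
# - removes separate value (blob) nodes; tensor values are drawn as edges
# - simplifies node labels
# - shows node details in tooltips
# - labels edges with shape-inference results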
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
from collections import defaultdict
import json
from os import path
import numpy
import onnx.numpy_helper
from onnx import shape_inference
from onnx import ModelProto, GraphProto, NodeProto, TensorProto
import pydot # type: ignore
from typing import Text, Any, Callable, Optional, Dict
# Graphviz node attributes for ordinary operator nodes; passed as **kwargs to
# GetOpNodeProducer by GetPydotGraph's default producer and by main().
OP_STYLE = dict(
    shape='box',
    color='lightblue',
    style='filled',
    margin="0.1,0.1",
    height="0.3",
)

# Attributes for "Constant" operator nodes (used instead of OP_STYLE for them).
CONST_STYLE = dict(
    style='dashed',
    shape='Mrecord',
    margin="0.03,0.03",
    height="0.1",
)

# Attributes for the terminal nodes drawn for graph inputs and graph outputs.
INPUT_STYLE = dict(
    shape='box',
    color='gray1',
    style='dotted',
)

# Attributes for initializer nodes and for consumed values with no known source.
INIT_STYLE = dict(shape='box')

# Attributes for value ("blob") nodes; not referenced by the code in this file.
BLOB_STYLE = dict(shape='octagon')
# Type of a node producer: takes an operator and its index in graph.node and
# returns the pydot.Node that represents it.
_NodeProducer = Callable[
    [NodeProto, int],
    pydot.Node,
]
def _escape_label(name): # type: (Text) -> Text
# json.dumps is poor man's escaping
return json.dumps(name)
def _form_and_sanitize_docstring(s):  # type: (Text) -> Text
    """Wrap ``s`` in a ``javascript:alert('...')`` URL for a node link.

    The text is escaped with ``_escape_label`` (JSON string escaping). Any
    apostrophes are then backslash-escaped, the surrounding/escaped double
    quotes are turned into single quotes, and ``<``/``>`` are stripped.

    Args:
        s: the text to show in the alert (an operator's doc_string).

    Returns:
        The ``javascript:`` URL string.
    """
    escaped = _escape_label(s)
    # Escape apostrophes *before* switching the delimiters to single quotes:
    # otherwise text such as "doesn't" terminates the JS string literal early
    # (a syntax error, or injected script for a crafted doc_string).
    escaped = escaped.replace("'", "\\'")
    escaped = escaped.replace('"', '\'').replace('<', '').replace('>', '')
    return 'javascript:alert(' + escaped + ')'
# Graph value name -> ValueInfoProto. Filled in by main() (shape-inference
# results plus the graph's declared inputs and outputs) and read by get_shape().
value_map = {}  # type: Dict[Text, Any]

# TensorProto element-type number -> enum name (e.g. 1 -> "FLOAT"), used by
# get_shape() to print a value's element type. Built from the DataType enum
# itself: scanning dir(TensorProto) for ints also picks up unrelated class
# constants (e.g. INT64_DATA_FIELD_NUMBER == 7, SEGMENT_FIELD_NUMBER == 3) that
# share numbers with real element types and, coming later alphabetically,
# overwrote names such as INT64 and INT8.
elem_type_map = {number: type_name for type_name, number in TensorProto.DataType.items()}  # type: Dict[int, Text]
def get_shape(value_name):
    r"""Return an edge-label string describing a value's element type and shape.

    The value is looked up in the module-level ``value_map``. Results:
      * "?" if ``value_name`` is not in ``value_map``;
      * "" if the value's tensor type has no element type;
      * otherwise "<TYPE>\n(<d0>,<d1>,...)" (the "\n" is a literal
        backslash-n, which DOT renders as a line break). <TYPE> comes from
        ``elem_type_map`` and is omitted, together with the "\n", when the
        number is not in the map. Each dim is its integer size, its symbolic
        name, or "?"; the parenthesised part is "(?)" when no shape is recorded.
    """
    if value_name not in value_map:
        return "?"
    tensor_type = value_map[value_name].type.tensor_type
    if not tensor_type.HasField("elem_type"):
        return ""

    text = ""
    type_name = elem_type_map.get(tensor_type.elem_type)
    if type_name is not None:
        text = type_name + "\\n"
    text += "("

    if tensor_type.HasField("shape"):
        for dim in tensor_type.shape.dim:
            # Comma-separate from the previous dim unless nothing was written since "(".
            if not text.endswith("("):
                text += ","
            if dim.HasField("dim_value"):
                text += str(dim.dim_value)
            elif dim.HasField("dim_param"):
                text += str(dim.dim_param)  # symbolic dimension name
            else:
                text += "?"  # dimension with neither a value nor a name
    else:
        text += "?"
    return text + ")"
#
# The attribute-conversion helper below is based on:
# https://github.com/sassoftware/python-dlpy/blob/master/dlpy/model_conversion/onnx_graph.py
#
def _convert_onnx_attribute_proto(attr_proto):
    '''
    Convert ONNX AttributeProto into Python object.

    Scalars come back as Python numbers (f, i), ``s`` as a UTF-8-decoded str,
    ``t``/``g`` as the TensorProto/GraphProto message itself, and repeated
    fields as lists (strings UTF-8-decoded, tensors/graphs as message lists).
    A repeated attribute with zero elements returns [].

    Raises:
        ValueError: if no supported field is populated (e.g. sparse-tensor or
            type-proto attributes), or if a string is not valid UTF-8
            (UnicodeDecodeError is a ValueError subclass).
    '''
    if attr_proto.HasField('f'):
        return attr_proto.f
    if attr_proto.HasField('i'):
        return attr_proto.i
    if attr_proto.HasField('s'):
        return str(attr_proto.s, 'utf-8')
    if attr_proto.HasField('t'):
        return attr_proto.t  # this is a proto!
    if attr_proto.HasField('g'):
        return attr_proto.g  # subgraph attribute; also a proto
    if attr_proto.floats:
        return list(attr_proto.floats)
    if attr_proto.ints:
        return list(attr_proto.ints)
    if attr_proto.strings:
        return [str(x, 'utf-8') for x in attr_proto.strings]
    if attr_proto.tensors:
        return list(attr_proto.tensors)
    if attr_proto.graphs:
        return list(attr_proto.graphs)
    # A repeated attribute with zero elements (e.g. ints=[]) populates no field
    # at all, so fall back to its declared type to recognise it.
    proto_cls = type(attr_proto)
    if attr_proto.type in (proto_cls.FLOATS, proto_cls.INTS, proto_cls.STRINGS,
                           proto_cls.TENSORS, proto_cls.GRAPHS):
        return []
    raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
def GetOpNodeProducer(embed_docstring=False, **kwargs):  # type: (bool, **Any) -> _NodeProducer
    """Build a callable that renders one ONNX operator as a pydot node.

    Args:
        embed_docstring: when True, each node gets a ``javascript:alert(...)``
            URL (from ``_form_and_sanitize_docstring``) showing the
            operator's doc_string.
        **kwargs: pydot node attributes used for every non-Constant operator;
            Constant operators use CONST_STYLE instead.

    Returns:
        ``ReallyGetOpNode(op, op_id)``, which returns a ``pydot.Node`` named
        ``"<op_type>_<op_id>"``.
    """
    def ReallyGetOpNode(op, op_id):  # type: (NodeProto, int) -> pydot.Node
        node_name = '%s_%d' % (op.op_type, op_id)

        # Tooltip: the node name, then one line per input, output and attribute.
        tooltip = node_name
        for i, input_name in enumerate(op.input):
            tooltip += '\n input' + str(i) + ' ' + input_name
        for i, output_name in enumerate(op.output):
            tooltip += '\n output' + str(i) + ' ' + output_name
        for attr in op.attribute:
            try:
                attr_value = _convert_onnx_attribute_proto(attr)
            except ValueError:
                # An attribute the converter cannot represent should not abort
                # drawing the whole model; show a placeholder instead.
                attr_value = '?'
            tooltip += f"\n {attr.name}:{attr_value}"

        # Label: the node's own name (suffixed with its op type unless the name
        # already contains it), or just the op type for anonymous nodes.
        label = op.op_type
        if op.name:
            label = op.name
            if op.op_type not in op.name:
                label += "/" + op.op_type

        op_style = kwargs
        if op.op_type == "Constant":
            op_style = CONST_STYLE
            for v in op.attribute:
                if v.name == 'value':
                    value = onnx.numpy_helper.to_array(v.t)
                    # 9223372036854775807 is INT64_MAX; shown as "max" for brevity.
                    tooltip = str(value).replace("9223372036854775807", "max")
                    if tooltip:
                        # Values of up to 10 characters are shown in full (the
                        # same limit the initializer nodes use); longer ones are
                        # truncated with "...".
                        if len(tooltip) <= 10:
                            label = tooltip
                        else:
                            label = tooltip[:10] + "..."

        node = pydot.Node(node_name, label=label, tooltip=tooltip, **op_style)
        if embed_docstring:
            url = _form_and_sanitize_docstring(op.doc_string)
            node.set_URL(url)
        return node
    return ReallyGetOpNode
def GetPydotGraph(
    graph,  # type: GraphProto
    name=None,  # type: Optional[Text]
    rankdir='LR',  # type: Text
    node_producer=None,  # type: Optional[_NodeProducer]
    embed_docstring=False,  # type: bool
    title="",  # type: Text
):  # type: (...) -> pydot.Dot
    """Build a pydot graph drawing the operators of an ONNX graph.

    Each operator becomes a node created by ``node_producer``. Graph inputs and
    outputs get terminal nodes (INPUT_STYLE). A consumed value that no operator
    produces and that is not a graph input gets its own node (INIT_STYLE):
    initializers show their contents, anything else is labelled "?:<name>".
    Each consumed value is drawn as an edge from its source node to the
    consumer, labelled with its shape and the "<from>-><to>" slot indices.

    Args:
        graph: the GraphProto to draw.
        name: name given to the pydot graph.
        rankdir: Graphviz rank direction (TB, LR, BT or RL).
        node_producer: ``(NodeProto, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)``.
        embed_docstring: only used to build the default node producer.
        title: text shown at the top of the drawing.

    Returns:
        The populated ``pydot.Dot``.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)
    pydot_graph = pydot.Dot(name, rankdir=rankdir)
    pydot_graph.set("labelloc", "t")
    # `fontsize` sizes the graph's own label (the title). The previous
    # `labelfontsize` only applies to edge head/tail labels, so it had no effect.
    pydot_graph.set("fontsize", 30)
    pydot_graph.set("label", title)

    # value name -> edge source info: {"name", "to", "from", "op_node",
    # "shape", "consumer_cnt"}; "op_node" is the node edges for the value start at.
    pydot_nodes = {}  # type: Dict[Text, Dict[Text, Any]]
    op2node = {}  # type: Dict[int, pydot.Node]
    initializers = {t.name: t for t in graph.initializer}

    def _add_graph_io_node(value_name, kind):  # type: (Text, Text) -> None
        # Terminal node for a graph input or output (kind is "input"/"output").
        io_node = pydot.Node(value_name, label=f'"{kind}:{value_name}"',
                             tooltip=f"{kind} value", **INPUT_STYLE)
        pydot_graph.add_node(io_node)
        pydot_nodes[value_name] = {
            "name": value_name,
            "to": 0,
            "from": 0,
            "op_node": io_node,
            "shape": get_shape(value_name),
            "consumer_cnt": 0,
        }

    def _add_source_node(value_name):  # type: (Text) -> None
        # Node for a consumed value with no producer and no graph-input entry.
        if value_name in initializers:
            tensor = onnx.numpy_helper.to_array(initializers[value_name])
            node_label = str(tensor)
            tooltip = value_name
            if len(node_label) > 10:
                # Too long to show inline: label by name, contents go in the tooltip.
                tooltip = node_label
                node_label = value_name
            if len(tooltip) > 128:
                tooltip = tooltip[:128] + "..."
            src_node = pydot.Node(value_name, label=f'"{node_label}"',
                                  tooltip=tooltip, **INIT_STYLE)
            shape = f'{tensor.shape}'
        else:
            # Substitute "?" as the node id for an empty name (presumably an
            # omitted optional input).
            src_node = pydot.Node(value_name or "?", label=f'"?:{value_name}"',
                                  tooltip="unknown", **INIT_STYLE)
            shape = get_shape(value_name)
        pydot_graph.add_node(src_node)
        pydot_nodes[value_name] = {
            "name": value_name,
            "to": 0,
            "from": 0,
            "op_node": src_node,
            "shape": shape,
            "consumer_cnt": 0,
        }

    # Ordering heuristic: each producer accumulates the input-slot indices at
    # which its outputs are consumed; operators are added in ascending order of
    # that sum (numpy.argsort).
    value2op = {}  # type: Dict[Text, int]
    for op_id, op in enumerate(graph.node):
        for out_name in op.output:
            value2op[out_name] = op_id
    oploc_rank = [0] * len(graph.node)
    for op in graph.node:
        for index, in_name in enumerate(op.input):
            if in_name in value2op:
                oploc_rank[value2op[in_name]] += index

    for op_id in numpy.argsort(oploc_rank):
        op = graph.node[op_id]
        op_node = node_producer(op, op_id)
        pydot_graph.add_node(op_node)
        op2node[op_id] = op_node
        for index, out_name in enumerate(op.output):
            if out_name not in pydot_nodes:
                pydot_nodes[out_name] = {
                    "name": out_name,
                    "to": index,
                    "from": index,
                    "op_node": op_node,
                    "shape": get_shape(out_name),
                    "consumer_cnt": 0,
                }

    # Shapes are looked up by each input/output's own name (previously a stale
    # loop variable from the operator loop was used here).
    for v in graph.input:
        _add_graph_io_node(v.name, "input")
    for v in graph.output:
        _add_graph_io_node(v.name, "output")

    output_names = {v.name for v in graph.output}
    for op_id, op in enumerate(graph.node):
        this_node = op2node[op_id]
        for index, in_name in enumerate(op.input):
            if in_name not in pydot_nodes:
                _add_source_node(in_name)
            pydot_n = pydot_nodes[in_name]
            pydot_n["to"] = index
            pydot_n["consumer_cnt"] += 1
            if pydot_n["op_node"]:
                pydot_graph.add_edge(pydot.Edge(
                    pydot_n["op_node"].get_name(),
                    this_node.get_name() + ":" + str(pydot_n["to"]),
                    label=pydot_n["shape"] + "\n" + str(pydot_n["from"]) + "->" + str(pydot_n["to"]),
                    tooltip="\"" + in_name + "\"",
                ))
        # Edges from this operator to the graph-output nodes it produces.
        for out_name in op.output:
            if out_name in output_names:
                pydot_n = pydot_nodes[out_name]
                pydot_graph.add_edge(pydot.Edge(
                    this_node, pydot_n["op_node"],
                    label=get_shape(out_name) + "\n" + str(pydot_n["from"]) + "->" + str(pydot_n["to"]),
                ))
    return pydot_graph
def main():  # type: () -> None
    """Command-line entry point: draw an ONNX model with Graphviz.

    Reads ``--input`` (a serialized ModelProto) and runs ONNX shape inference
    to fill the module-level ``value_map`` used for edge labels. Writes the
    drawing to ``--output`` in the format named by its file extension.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=Text, required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=Text, required=True,
        help="The output drawing file; its extension (e.g. .svg, .png, .dot) "
             "selects the Graphviz output format (svg if there is none).",
    )
    parser.add_argument(
        # 'TB' is Graphviz's spelling; the former default 'TD' is not a valid
        # rankdir and was silently treated as TB anyway.
        "--rankdir", type=Text, default='TB',
        help="The rank direction of the pydot graph (TB, LR, BT or RL).",
    )
    parser.add_argument(
        "--embed_docstring", action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args()

    model = ModelProto()
    with open(args.input, 'rb') as fid:
        model.ParseFromString(fid.read())

    # Fill value_map for get_shape(): shape-inference results first, then the
    # declared graph inputs/outputs, which win for names present in both.
    inferred_model = shape_inference.infer_shapes(model)
    for v in inferred_model.graph.value_info:
        value_map[v.name] = v
    for v in model.graph.input:
        value_map[v.name] = v
    for v in model.graph.output:
        value_map[v.name] = v

    opset_import = [str(x).strip('\n') for x in model.opset_import]
    title = f"ir_version:{model.ir_version} model_version:{model.model_version} opset_import:{opset_import} producer:{model.producer_name} {model.producer_version}"
    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        node_producer=GetOpNodeProducer(
            embed_docstring=args.embed_docstring,
            **OP_STYLE
        ),
        title=title,
    )

    # Take the format from the file extension only (dots in directory names
    # must not matter), defaulting to SVG when there is no extension.
    out_format = path.splitext(args.output)[1][1:].lower() or "svg"
    pydot_graph.write(args.output, format=out_format)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment