Skip to content

Instantly share code, notes, and snippets.

@usstq
Last active May 4, 2023 05:51
Show Gist options
  • Save usstq/610071c2009f7fbcb7806f28c4a82370 to your computer and use it in GitHub Desktop.
Save usstq/610071c2009f7fbcb7806f28c4a82370 to your computer and use it in GitHub Desktop.
A CLI tool for inspecting ONNX models
#!/usr/bin/python3
import onnx
import onnx.numpy_helper
import sys
def get_value_info(m, name):
    """Return a printable ``%name[TYPE, shape]`` description of tensor *name*.

    *name* may carry the leading ``%`` that ``onnx.helper.printable_graph``
    puts in front of value names.  ``graph.value_info`` is searched first,
    then ``graph.input`` and ``graph.output``, so graph inputs/outputs are
    annotated too and not only intermediate tensors.

    Args:
        m: a loaded ``onnx.ModelProto``.
        name: exact tensor name, with or without a leading ``%``.

    Returns:
        ``onnx.helper.printable_value_info`` of the first matching entry, or
        *name* exactly as given when nothing matches.
    """
    bare = name[1:] if name.startswith("%") else name
    graph = m.graph
    # value_info usually describes only intermediate tensors; the types of
    # graph inputs/outputs are stored in graph.input / graph.output instead.
    for infos in (graph.value_info, graph.input, graph.output):
        for info in infos:
            if info.name == bare:
                return onnx.helper.printable_value_info(info)
    # Return the name unchanged (keeping any '%') so a caller that splices
    # the result back into printable_graph text keeps a consistent prefix;
    # printable_value_info output also starts with '%'.
    return name
def read_attr_value(m, fuzzy_name):
    """Return, as a string, the value stored on the node producing *fuzzy_name*.

    Intended for ``Constant`` nodes.  A ``value`` tensor attribute is
    converted with ``onnx.numpy_helper.to_array``.  The scalar/list forms
    (``value_float``, ``value_ints``, ``value_string``, ...) are decoded with
    ``onnx.helper.get_attribute_value``.

    Args:
        m: a loaded ``onnx.ModelProto``.
        fuzzy_name: exact output tensor name, optionally prefixed with ``%``
            (despite the parameter name, matching is exact).

    Returns:
        ``str`` of the decoded value; ``"None"`` when the producing node has
        no ``value``/``value_*`` attribute; ``"?"`` when no node outputs
        the tensor.
    """
    name = fuzzy_name[1:] if fuzzy_name.startswith("%") else fuzzy_name
    for node in m.graph.node:
        if name not in node.output:
            continue
        value = None
        for attr in node.attribute:
            if attr.name == "value":
                value = onnx.numpy_helper.to_array(attr.t)
            elif attr.name.startswith("value_"):
                # Constant may carry its payload as value_float(s),
                # value_int(s) or value_string(s) instead of a tensor.
                value = onnx.helper.get_attribute_value(attr)
        # Only the first producing node is consulted.
        return f"{value}"
    return "?"
def read_const(m, fuzzy_name):
    """Print what the model knows about the tensor or node *fuzzy_name*.

    The name (optionally prefixed with ``%``; matching is exact) is looked
    up in this order:

    1. type info in ``graph.value_info`` / ``graph.input`` / ``graph.output``,
       printed as ``value_info: ...``; the search continues afterwards;
    2. node names: the whole NodeProto is printed and the function returns;
    3. node outputs: the producing node's ``value``/``value_*`` attribute
       (as found on ``Constant`` nodes) is printed, ``None`` if absent, and
       the function returns;
    4. graph initializers: the tensor, its dtype and its shape are printed.

    If none of these match, a "not found" line is printed instead of
    printing nothing.

    Args:
        m: a loaded ``onnx.ModelProto``.
        fuzzy_name: node, output or initializer name.
    """
    name = fuzzy_name[1:] if fuzzy_name.startswith("%") else fuzzy_name
    graph = m.graph
    found = False

    for infos in (graph.value_info, graph.input, graph.output):
        for info in infos:
            if info.name == name:
                print(f"value_info: {onnx.helper.printable_value_info(info)}")
                found = True

    for node in graph.node:
        if node.name == name:
            print(f"{node}")
            return
        if name in node.output:
            o_idx = list(node.output).index(name)
            value = None
            for attr in node.attribute:
                if attr.name == "value":
                    value = onnx.numpy_helper.to_array(attr.t)
                elif attr.name.startswith("value_"):
                    # Constant may store scalars/lists in value_* attributes.
                    value = onnx.helper.get_attribute_value(attr)
            print(f"value: {node.name}:{o_idx}({name}) {value}")
            return

    for tensor_proto in graph.initializer:
        if tensor_proto.name == name:
            tensor = onnx.numpy_helper.to_array(tensor_proto)
            print(f"initializer: {name}={tensor} \n{tensor.dtype} {tensor.shape}")
            found = True

    if not found:
        print(f"{fuzzy_name}: not found in the model")
def _annotate_graph_line(m, line):
    """Return one ``printable_graph`` line with output types and Constant values added."""
    lhs, sep, rhs = line.partition(" = ")
    if not sep:
        # Not a node line (header, inputs, initializers, return, braces).
        return line
    indent = lhs[: len(lhs) - len(lhs.lstrip())]
    # A node with several outputs prints them as "%a, %b = Op(...)"; annotate
    # each output separately instead of looking up "%a, %b" as one name.
    out_names = [n.strip() for n in lhs.split(",")]
    annotated = ", ".join(get_value_info(m, n).strip() for n in out_names)
    text = f"{indent}{annotated}{sep}{rhs}"
    if rhs.startswith("Constant"):
        text += f" value={read_attr_value(m, out_names[0])}"
    return text


def _print_annotated_graph(m):
    """Run shape inference on *m* and print its graph with every node line annotated."""
    m = onnx.shape_inference.infer_shapes(m)
    for line in onnx.helper.printable_graph(m.graph).splitlines():
        print(_annotate_graph_line(m, line))


def main(argv=None):
    """Command-line entry point.

    ``inspect.py MODEL`` prints the graph with node outputs annotated with
    their inferred type/shape and ``Constant`` outputs with their value.
    ``inspect.py MODEL NAME...`` prints what the model knows about each NAME
    (node, output or initializer).

    Args:
        argv: full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        Process exit status (0).
    """
    if argv is None:
        argv = sys.argv
    if len(argv) == 1:
        print("inspect.py onnx_model [node_name | output_name | initializer_name] ...")
        print(" You can check printable graph or any constant/initializer")
        return 0

    m = onnx.load(argv[1])
    if len(argv) == 2:
        _print_annotated_graph(m)
    else:
        for name in argv[2:]:
            read_const(m, name)
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment