Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Cut a sub-model from an ONNX model and update its input/output names or shapes
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# pylint: disable=missing-docstring
import argparse
import os
import sys
import timeit
import typing
import onnx
class ONNXCutArgs(typing.NamedTuple):
    """Options for :func:`onnx_cut`, one field per pipeline setting.

    NamedTuple fields are positional, so the field order below is part of
    the interface; append new fields at the end.
    """
    # Path of the source .onnx model.
    input_path: str
    # Path the cut sub-model is written to and then loaded from when do_cut
    # is set (passed to onnx.utils.extract_model).
    output_path: typing.Optional[str]
    # Tensor names where the sub-model starts / ends; also the names that the
    # rename step replaces.
    input_names: typing.List[str]
    output_names: typing.List[str]
    # Replacement names, paired positionally with input_names / output_names.
    # None means onnx_cut generates "input_<i>" / "output_<i>".
    input_names_new: typing.Optional[typing.List[str]]
    output_names_new: typing.Optional[typing.List[str]]
    # Shape applied to the graph inputs by the reshape step (only used when it
    # has exactly 4 entries); also forwarded to onnxsim.
    input_shape: typing.Optional[typing.Tuple[int, int, int, int]]
    # Tensor name -> shape overrides applied by the reshape step.
    specify_shapes: typing.Optional[typing.Dict[str, typing.Tuple[int, int, int, int]]]
    # Pipeline step switches, run in this order by onnx_cut.
    do_cut: bool
    do_rename: bool
    do_reshape: bool
    do_shape_infer: bool
    do_simplify: bool
class ONNXTiming:
    """Minimal stopwatch that reports how long each pipeline step took.

    Call :meth:`beg` with a label before a step and :meth:`end` after it.
    Only one step is tracked at a time: ``beg`` overwrites the previous
    label and start time.
    """

    def beg(self, action):
        """Remember *action* and its start time, and announce it on stdout."""
        self._action = action
        self._t_beg = timeit.default_timer()
        print(f"\nonnx {action} ...")

    def end(self):
        """Print the time elapsed since the last :meth:`beg` call."""
        elapsed = timeit.default_timer() - self._t_beg
        print(f"onnx {self._action} done, cost {elapsed:.3f} s")
def onnx_cut(args: ONNXCutArgs) -> None:
    """Cut, rename, reshape, shape-infer and simplify an ONNX model.

    Each step runs only when its ``args.do_*`` flag is set, in the order
    cut -> load -> rename -> reshape -> shape inference -> simplify -> save.
    The working model is ``args.output_path`` when cutting and
    ``args.input_path`` otherwise. The result is saved as
    ``<working model stem>_final.onnx``, and the rename table as
    ``<working model stem>_rename.csv``.

    Args:
        args: Pipeline options (see :class:`ONNXCutArgs`). Its lists are
            not modified.

    Raises:
        AssertionError: if onnxsim reports that the simplified model could
            not be validated.
    """
    timing = ONNXTiming()

    model_path = args.input_path
    if args.do_cut:
        timing.beg("cut")
        # extract_model writes the sub-model to args.output_path, which then
        # becomes the model we keep working on.
        onnx.utils.extract_model(
            args.input_path, args.output_path, args.input_names, args.output_names)
        timing.end()
        model_path = args.output_path

    timing.beg("load")
    print(f" path={model_path}")
    model = onnx.load(model_path)
    timing.end()

    if args.do_rename:
        timing.beg("rename")
        # Build fresh lists so the caller's args.input_names /
        # args.input_names_new are never extended in place.
        names = list(args.input_names) + list(args.output_names)
        if args.input_names_new is None:
            names_new = [f"input_{i}" for i in range(len(args.input_names))]
        else:
            names_new = list(args.input_names_new)
        if args.output_names_new is None:
            names_new += [f"output_{i}" for i in range(len(args.output_names))]
        else:
            names_new += list(args.output_names_new)
        _onnx_rename(model, names, names_new)
        _onnx_rename_print(model, names, names_new,
                           file=os.path.splitext(model_path)[0] + "_rename.csv")
        timing.end()

    # Forwarded to onnxsim only when the inputs were actually reshaped.
    simp_input_shape = None
    if args.do_reshape:
        if args.input_shape and len(args.input_shape) == 4:
            timing.beg("reshape input")
            _onnx_input_reshape(model, args.input_shape)
            simp_input_shape = args.input_shape
            timing.end()
        if args.specify_shapes:
            timing.beg("reshape specified")
            _onnx_specify_shapes(model, args.specify_shapes)
            timing.end()

    if args.do_shape_infer:
        timing.beg("shape_infer")
        model = onnx.shape_inference.infer_shapes(model)
        timing.end()

    if args.do_simplify:
        timing.beg("simplify")
        import onnxsim  # optional dependency, only needed for this step
        if simp_input_shape:
            model_simp, check = onnxsim.simplify(model, perform_optimization=False,
                                                 input_shapes={None: simp_input_shape})
        else:
            model_simp, check = onnxsim.simplify(model, perform_optimization=False)
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not check:
            raise AssertionError("Simplified ONNX model could not be validated")
        model = model_simp
        timing.end()

    output_path = os.path.splitext(model_path)[0] + "_final.onnx"
    timing.beg("save")
    print(f" path={output_path}")
    onnx.save(model, output_path)
    timing.end()
def _onnx_rename(model, names, names_new):
    """Rename tensors in ``model.graph`` in place.

    Every occurrence of ``names[i]`` is replaced by ``names_new[i]`` in:
    - node inputs and outputs;
    - graph inputs and outputs;
    - ``value_info`` entries;
    - initializers.

    Renaming all of these together keeps shape info and initializers attached
    to the renamed tensor. If an old name is listed more than once, its first
    pairing wins. Nodes inside nested subgraphs (attributes of
    ``graph.node``) are not visited.

    Args:
        model: ONNX ModelProto to modify.
        names: Current tensor names.
        names_new: Replacement names, paired positionally with *names*.
    """
    # dict lookup instead of list.index per edge; setdefault keeps the
    # first pairing for duplicated old names.
    rename = {}
    for old, new in zip(names, names_new):
        rename.setdefault(old, new)
    if not rename:
        return

    graph = model.graph
    for node in graph.node:
        for edges in (node.input, node.output):
            for i, edge in enumerate(edges):
                if edge in rename:
                    edges[i] = rename[edge]
    for entries in (graph.input, graph.output, graph.value_info, graph.initializer):
        for entry in entries:
            if entry.name in rename:
                entry.name = rename[entry.name]
def _onnx_rename_print(model, names, names_new, file="onnx_rename.csv"):
    """Print the old/new name table of a rename and optionally save it as CSV.

    For each ``(old, new)`` pair the table also shows the name of the first
    graph node whose name, inputs or outputs match the new name (``-`` if
    there is none), as returned by ``_onnx_node``.

    Args:
        model: ONNX model after renaming.
        names: Original tensor names.
        names_new: New tensor names, paired positionally with *names*.
        file: CSV path to write (header ``old,new,node``). If falsy, the
            table is only printed.
    """
    # Compute the rows once, then reuse them for stdout and CSV.
    rows = []
    for old, new in zip(names, names_new):
        node = _onnx_node(model, new)
        rows.append((old, new, "-" if node is None else node.name))

    print(" old, new, node")
    for old, new, node_name in rows:
        print(f"{old:>10s}, {new:>10s}, {node_name:>10s}")

    if file:
        import csv
        # Explicit UTF-8 so non-ASCII tensor names don't depend on the locale.
        with open(file, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(("old", "new", "node"))
            writer.writerows(rows)
def _onnx_node(model, name):
    """Return the first graph node matching *name*, or ``None``.

    A node matches when its own name equals *name* or when *name* appears
    among its inputs or outputs. Nodes are checked in ``model.graph.node``
    order.
    """
    return next(
        (candidate for candidate in model.graph.node
         if candidate.name == name
         or name in candidate.input
         or name in candidate.output),
        None,
    )
def _onnx_input_reshape(model, input_shape):
    """Set the shape of every non-initializer graph input to *input_shape*.

    ``update_inputs_outputs_dims`` needs an entry for every graph input and
    output. Graph outputs, and inputs that are backed by an initializer,
    therefore get their current dims read back and passed through, so they
    keep their shapes. The updated graph inputs are printed afterwards.

    Args:
        model: ONNX ModelProto, modified in place.
        input_shape: Sequence of ints applied to each non-initializer input.
    """
    from onnx.tools import update_model_dims

    def _current_dims(value_info):
        # Emit exactly one entry per dimension so indices stay aligned with
        # the proto. dim_value -> int, dim_param -> str, and a dim with
        # neither set -> -1 (dynamic). Skipping unknown dims would shift every
        # later dim onto the wrong index.
        dims = []
        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField("shape"):
            for d in tensor_type.shape.dim:
                if d.HasField("dim_value"):
                    dims.append(d.dim_value)
                elif d.HasField("dim_param"):
                    dims.append(d.dim_param)
                else:
                    dims.append(-1)
        return dims

    graph = model.graph
    # Models with IR version < 4 may list weights in graph.input; those must
    # not be forced to the activation shape.
    initializer_names = {init.name for init in graph.initializer}
    input_dims = {
        info.name: (_current_dims(info) if info.name in initializer_names
                    else list(input_shape))
        for info in graph.input
    }
    output_dims = {info.name: _current_dims(info) for info in graph.output}

    update_model_dims.update_inputs_outputs_dims(model, input_dims, output_dims)
    print(model.graph.input)
def _onnx_specify_shapes(model, specify_shapes):
    """Overwrite the declared shape of named tensors in *model* in place.

    Every ``graph.input``, ``graph.output`` and ``graph.value_info`` entry
    whose name is a key of *specify_shapes* has its type replaced by a
    tensor type with the given static shape. Entries are edited where they
    sit rather than removed and re-appended, so the order of graph
    inputs/outputs is preserved, and so are each entry's name and doc string.
    Names not found anywhere are reported with a warning.

    Args:
        model: ONNX ModelProto to modify.
        specify_shapes: Mapping of tensor name -> shape (sequence of ints).
    """
    from onnx import helper

    graph = model.graph
    for name, shape in specify_shapes.items():
        matches = [info
                   for infos in (graph.input, graph.output, graph.value_info)
                   for info in infos
                   if info.name == name]
        if not matches:
            print(f"warn: specify shape {name} not found")
            continue
        # Element type comes from the last match: value_info takes precedence
        # over outputs, which take precedence over inputs.
        elem_type = matches[-1].type.tensor_type.elem_type
        new_type = helper.make_tensor_value_info(name, elem_type, list(shape)).type
        for info in matches:
            info.type.CopyFrom(new_type)
        print(f" {name}: {shape}")
def _onnx_graph_name_map(graph_prop_list):
    """Index the entries of a repeated graph field by their ``name``.

    If several entries share a name, the last one wins.
    """
    return {entry.name: entry for entry in graph_prop_list}
def _parse_args():
    """Parse and validate the command line.

    After parsing:
    - ``--output`` defaults to ``<input stem>_cut.onnx``;
    - the input path must exist;
    - new input/output name lists, when given, must match the original
      lists in length.

    The resulting options are echoed to stdout.

    Returns:
        argparse.Namespace holding the parsed options.

    Exits the process (argparse error or ``sys.exit``) on invalid input.
    """
    parser = argparse.ArgumentParser()

    def ints_type(string, num=4, sep=","):
        # "1,3,512,512" -> (1, 3, 512, 512); empty string -> None.
        if not string:
            return None
        ints = string.split(sep)
        if len(ints) != num:
            raise argparse.ArgumentTypeError(f"ints_type size must be {num}")
        return tuple(int(x) for x in ints)

    def shapes_type(string):
        # "a:1,3,8,8;b:0:1,1,4,4" -> {"a": (1, 3, 8, 8), "b:0": (1, 1, 4, 4)}.
        # Split on the LAST ':' so tensor names such as "x:0" are accepted.
        if not string:
            return None
        shapes = {}
        for item in string.split(";"):
            item = item.strip()
            if not item:
                continue  # tolerate trailing or doubled ';'
            name, sep, ints_s = item.rpartition(":")
            if not sep or not name:
                raise argparse.ArgumentTypeError(
                    "shapes_type must be like \"name:1,1,512,512;...\"")
            ints = ints_s.split(",")
            if len(ints) != 4:
                raise argparse.ArgumentTypeError(
                    "shapes_type must be like \"name:1,1,512,512\"")
            shapes[name] = tuple(int(x) for x in ints)
        return shapes

    parser.add_argument("-i", "--input", required=True,
                        help="the model input path: %(default)s")
    parser.add_argument("-o", "--output", default=None,
                        help="the model output path for sub: %(default)s")
    parser.add_argument("-in", "--input-names", nargs="+", required=True,
                        help="the model input names for sub: %(default)s")
    parser.add_argument("-on", "--output-names", nargs="+", required=True,
                        help="the model output names for sub: %(default)s")
    parser.add_argument("-inn", "--input-names-new", nargs="+",
                        help="the new names for the sub model inputs: %(default)s")
    parser.add_argument("-onn", "--output-names-new", nargs="+",
                        help="the new names for the sub model outputs: %(default)s")
    parser.add_argument("-is", "--input-shape",
                        type=ints_type, metavar="1,3,512,512", default=None,
                        help="the model input shape NCHW for sub: %(default)s")
    parser.add_argument("-ss", "--specify-shapes",
                        type=shapes_type, metavar="\"name:1,3,512,512;...\"", default=None,
                        help="specify shapes by input/output name of all nodes: %(default)s")
    parser.add_argument("-nc", "--no-cut", action="store_true",
                        help="not cut sub model by input/output names: %(default)s")
    parser.add_argument("-nrn", "--no-rename", action="store_true",
                        help="not rename input/output names to new ones (default: input/output_0 ...): %(default)s")
    parser.add_argument("-nrs", "--no-reshape", action="store_true",
                        help="not reshape input shape, specify shapes: %(default)s")
    parser.add_argument("-nis", "--no-infer-shape", action="store_true",
                        help="not inference shapes of model: %(default)s")
    parser.add_argument("-nsi", "--no-simplify", action="store_true",
                        help="not simplify sub model: %(default)s")

    args = parser.parse_args()

    if not os.path.exists(args.input):
        sys.exit(f"error: model \"{args.input}\" not exists")
    if not args.output:
        args.output = os.path.splitext(args.input)[0] + "_cut.onnx"
    if args.input_names_new is not None and \
            len(args.input_names_new) != len(args.input_names):
        sys.exit("error: the size of new input names must same as original ones")
    if args.output_names_new is not None and \
            len(args.output_names_new) != len(args.output_names):
        sys.exit("error: the size of new output names must same as original ones")

    print("Args")
    print(f" input: {args.input}")
    print(f" output: {args.output}")
    print(f" input_names: {args.input_names}")
    print(f" output_names: {args.output_names}")
    print(f" input_names_new: {args.input_names_new}")
    print(f" output_names_new: {args.output_names_new}")
    print(f" input_shape: {args.input_shape}")
    print(f" specify_shapes: {args.specify_shapes}")
    print(f" no_cut: {args.no_cut}")
    print(f" no_rename: {args.no_rename}")
    print(f" no_reshape: {args.no_reshape}")
    print(f" no_infer_shape: {args.no_infer_shape}")
    print(f" no_simplify: {args.no_simplify}")
    return args
def _main():
    """Command-line entry point: parse argv and run the ONNX cut pipeline."""
    cli = _parse_args()
    cut_args = ONNXCutArgs(
        # paths
        input_path=cli.input,
        output_path=cli.output,
        # sub-model boundary and renaming
        input_names=cli.input_names,
        output_names=cli.output_names,
        input_names_new=cli.input_names_new,
        output_names_new=cli.output_names_new,
        # shape overrides
        input_shape=cli.input_shape,
        specify_shapes=cli.specify_shapes,
        # the CLI exposes opt-out flags; the pipeline takes opt-in switches
        do_cut=not cli.no_cut,
        do_rename=not cli.no_rename,
        do_reshape=not cli.no_reshape,
        do_shape_infer=not cli.no_infer_shape,
        do_simplify=not cli.no_simplify,
    )
    onnx_cut(cut_args)


if __name__ == "__main__":
    _main()
# Is it possible to change input/output layer names of onnx model?
# https://github.com/onnx/onnx/issues/2052
# WinMLDashboard
# https://github.com/microsoft/Windows-Machine-Learning/tree/master/Tools/WinMLDashboard
@ikuokuo
Copy link
Author

ikuokuo commented Jun 2, 2021

ONNX environment setup:

conda create -n onnx python=3.8 -y
conda activate onnx

# ONNX
#  https://github.com/onnx/onnx
conda install -c conda-forge onnx -y

# ONNX Simplifier
#  https://github.com/daquexian/onnx-simplifier
pip install onnx-simplifier

@mayankverk
Copy link

mayankverk commented Jun 30, 2022

Hey, the resulting ONNX file causes a segfault in the onnxsim module. Any ideas why?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment