Cut sub-model from an ONNX model, and update its input/output names or shapes
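A possible invocation (a sketch only — "onnx_cut.py", "model.onnx", and the tensor names "data"/"prob" are placeholders; use your own script file name and the real input/output tensor names of your model):

python onnx_cut.py -i model.onnx -o model_cut.onnx \
    -in data -on prob \
    -is 1,3,512,512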
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# pylint: disable=missing-docstring
import argparse
import os
import sys
import timeit
import typing
import onnx
class ONNXCutArgs(typing.NamedTuple):
input_path: str
output_path: typing.Optional[str]
input_names: typing.List[str]
output_names: typing.List[str]
input_names_new: typing.Optional[typing.List[str]]
output_names_new: typing.Optional[typing.List[str]]
input_shape: typing.Optional[typing.Tuple[int, int, int, int]]
specify_shapes: typing.Optional[typing.Dict[str, typing.Tuple[int, int, int, int]]]
do_cut: bool
do_rename: bool
do_reshape: bool
do_shape_infer: bool
do_simplify: bool
class ONNXTiming(object):
def beg(self, action):
self._t_beg = timeit.default_timer()
self._action = action
print(f"\nonnx {action} ...")
def end(self):
t_end = timeit.default_timer()
print(f"onnx {self._action} done, cost {t_end-self._t_beg:.3f} s")
def onnx_cut(args: ONNXCutArgs) -> None:
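    # run the enabled stages in order: cut -> load -> rename -> reshape -> shape infer -> simplify -> save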
timing = ONNXTiming()
model_path = args.input_path
if args.do_cut:
timing.beg("cut")
onnx.utils.extract_model(
args.input_path, args.output_path, args.input_names, args.output_names)
timing.end()
model_path = args.output_path
timing.beg("load")
print(f" path={model_path}")
model = onnx.load(model_path)
timing.end()
if args.do_rename:
timing.beg("rename")
        names = list(args.input_names)  # copy, so the lists held by args are not mutated below
        if args.input_names_new is None:
            name_new = [f"input_{i}" for i, _ in enumerate(args.input_names)]
        else:
            name_new = list(args.input_names_new)
        names.extend(args.output_names)
        if args.output_names_new is None:
            name_new.extend([f"output_{i}" for i, _ in enumerate(args.output_names)])
        else:
            name_new.extend(args.output_names_new)
_onnx_rename(model, names, name_new)
_onnx_rename_print(model, names, name_new,
file=os.path.splitext(model_path)[0] + "_rename.csv")
timing.end()
simp_input_shape = None
if args.do_reshape:
if args.input_shape and len(args.input_shape) == 4:
timing.beg("reshape input")
_onnx_input_reshape(model, args.input_shape)
simp_input_shape = args.input_shape
timing.end()
if args.specify_shapes:
timing.beg("reshape specified")
_onnx_specify_shapes(model, args.specify_shapes)
timing.end()
if args.do_shape_infer:
timing.beg("shape_infer")
model_infer = onnx.shape_inference.infer_shapes(model)
model = model_infer
timing.end()
if args.do_simplify:
timing.beg("simplify")
import onnxsim
if simp_input_shape:
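            # forward the reshaped NCHW shape to the simplifier; a None key selects the
            # model's single input (older onnx-simplifier 0.3.x-style input_shapes API)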
model_simp, check = onnxsim.simplify(model, perform_optimization=False,
input_shapes={None: simp_input_shape})
else:
model_simp, check = onnxsim.simplify(model, perform_optimization=False)
assert check, "Simplified ONNX model could not be validated"
model = model_simp
timing.end()
output_path = os.path.splitext(model_path)[0] + "_final.onnx"
timing.beg("save")
print(f" path={output_path}")
onnx.save(model, output_path)
timing.end()
def _onnx_rename(model, names, names_new):
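    # replace the old tensor names with the new ones wherever they appear:
    # node inputs/outputs and the graph's top-level inputs/outputs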
for node in model.graph.node:
for i, n in enumerate(node.input):
if n in names:
node.input[i] = names_new[names.index(n)]
for i, n in enumerate(node.output):
if n in names:
node.output[i] = names_new[names.index(n)]
for node in model.graph.input:
if node.name in names:
node.name = names_new[names.index(node.name)]
# print(model.graph.input)
for node in model.graph.output:
if node.name in names:
node.name = names_new[names.index(node.name)]
# print(model.graph.output)
def _onnx_rename_print(model, names, names_new, file="onnx_rename.csv"):
if file:
import csv
with open(file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(("old", "new", "node"))
print(" old, new, node")
for n, nn in zip(names, names_new):
node = _onnx_node(model, nn)
writer.writerow((n, nn, '-' if node is None else node.name))
print(f"{n:>10s}, {nn:>10s}, {'-' if node is None else node.name:>10s}")
else:
print(" old, new, node")
for n, nn in zip(names, names_new):
node = _onnx_node(model, nn)
print(f"{n:>10s}, {nn:>10s}, {'-' if node is None else node.name:>10s}")
def _onnx_node(model, name):
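    # return the first node whose name, inputs, or outputs match the given name (None if not found)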
for node in model.graph.node:
if node.name == name:
return node
for input in node.input:
if input == name:
return node
for output in node.output:
if output == name:
return node
return None
def _onnx_input_reshape(model, input_shape):
# set input_dims
input_dims = {}
for input in model.graph.input:
input_dims[input.name] = list(input_shape)
# keep output_dims
output_dims = {}
for output in model.graph.output:
dim = []
tensor_type = output.type.tensor_type
if tensor_type.HasField("shape"):
for d in tensor_type.shape.dim:
if d.HasField("dim_value"):
dim.append(d.dim_value)
elif d.HasField("dim_param"):
dim.append(d.dim_param)
else:
# sys.exit("error: unknown dimension")
continue
output_dims[output.name] = dim
# update
from onnx.tools import update_model_dims
update_model_dims.update_inputs_outputs_dims(
model, input_dims, output_dims)
print(model.graph.input)
# print(model.graph.output)
def _onnx_specify_shapes(model, specify_shapes):
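    # overwrite the shapes of the graph input/output/value_info entries listed in specify_shapes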
graph = model.graph
input_map = _onnx_graph_name_map(graph.input)
output_map = _onnx_graph_name_map(graph.output)
value_info_map = _onnx_graph_name_map(graph.value_info)
all_map = {**input_map, **output_map, **value_info_map}
def _update(info_new, *graph_value_infos):
for infos in graph_value_infos:
for info in infos:
if info.name == info_new.name:
infos.remove(info)
infos.append(info_new)
return True
return False
from onnx import helper
for name, shape in specify_shapes.items():
if name in all_map:
info = all_map[name]
info_new = helper.make_tensor_value_info(
info.name, info.type.tensor_type.elem_type, list(shape))
if _update(info_new, graph.input, graph.output, graph.value_info):
print(f" {name}: {shape}")
else:
print(f"warn: specify shape {name} update failed")
else:
print(f"warn: specify shape {name} not found")
def _onnx_graph_name_map(graph_prop_list):
m = {}
for n in graph_prop_list:
m[n.name] = n
return m
def _parse_args():
parser = argparse.ArgumentParser()
def ints_type(string, num=4, sep=","):
if not string:
return None
ints = string.split(sep)
ints_len = len(ints)
if ints_len != num:
sys.exit(f"error: ints_type size must be {num}")
return tuple([int(x) for x in ints])
def shapes_type(string):
if not string:
return None
shapes = dict()
for s in string.split(";"):
ss = s.split(":")
if 2 != len(ss):
sys.exit(f"error: shapes_type must be like \"name:1,1,512,512;...\"")
name = ss[0]
ints_s = ss[1]
ints = ints_s.split(",")
if 4 != len(ints):
sys.exit(f"error: shapes_type must be like \"name:1,1,512,512\"")
shapes[name] = tuple([int(x) for x in ints])
return shapes
parser.add_argument("-i", "--input", required=True,
help="the model input path: %(default)s")
parser.add_argument("-o", "--output", default=None,
help="the model output path for sub: %(default)s")
parser.add_argument("-in", "--input-names", nargs="+", required=True,
help="the model input names for sub: %(default)s")
parser.add_argument("-on", "--output-names", nargs="+", required=True,
help="the model output names for sub: %(default)s")
parser.add_argument("-inn", "--input-names-new", nargs="+",
help="the model input names for sub: %(default)s")
parser.add_argument("-onn", "--output-names-new", nargs="+",
help="the model output names for sub: %(default)s")
parser.add_argument("-is", "--input-shape",
type=ints_type, metavar="1,3,512,512", default=None,
help="the model input shape NCHW for sub: %(default)s")
parser.add_argument("-ss", "--specify-shapes",
type=shapes_type, metavar="\"name:1,3,512,512;...\"", default=None,
help="specify shapes by input/output name of all nodes: %(default)s")
parser.add_argument("-nc", "--no-cut", action="store_true",
help="not cut sub model by input/output names: %(default)s")
parser.add_argument("-nrn", "--no-rename", action="store_true",
help="not rename input/output names to new ones (default: input/output_0 ...): %(default)s")
parser.add_argument("-nrs", "--no-reshape", action="store_true",
help="not reshape input shape, specify shapes: %(default)s")
parser.add_argument("-nis", "--no-infer-shape", action="store_true",
help="not inference shapes of model: %(default)s")
parser.add_argument("-nsi", "--no-simplify", action="store_true",
help="not simplify sub model: %(default)s")
args = parser.parse_args()
if not os.path.exists(args.input):
sys.exit(f"error: model \"{args.input}\" not exists")
if not args.output:
args.output = os.path.splitext(args.input)[0] + "_cut.onnx"
if args.input_names_new is not None and \
len(args.input_names_new) != len(args.input_names):
sys.exit("error: the size of new input names must same as original ones")
if args.output_names_new is not None and \
len(args.output_names_new) != len(args.output_names):
sys.exit("error: the size of new output names must same as original ones")
print("Args")
print(f" input: {args.input}")
print(f" output: {args.output}")
print(f" input_names: {args.input_names}")
print(f" output_names: {args.output_names}")
print(f" input_names_new: {args.input_names_new}")
print(f" output_names_new: {args.output_names_new}")
print(f" input_shape: {args.input_shape}")
print(f" specify_shapes: {args.specify_shapes}")
print(f" no_cut: {args.no_cut}")
print(f" no_rename: {args.no_rename}")
print(f" no_reshape: {args.no_reshape}")
print(f" no_infer_shape: {args.no_infer_shape}")
print(f" no_simplify: {args.no_simplify}")
return args
def _main():
args = _parse_args()
onnx_cut(ONNXCutArgs(
input_path=args.input,
output_path=args.output,
input_names=args.input_names,
output_names=args.output_names,
input_names_new=args.input_names_new,
output_names_new=args.output_names_new,
input_shape=args.input_shape,
specify_shapes=args.specify_shapes,
do_cut=not args.no_cut,
do_rename=not args.no_rename,
do_reshape=not args.no_reshape,
do_shape_infer=not args.no_infer_shape,
do_simplify=not args.no_simplify
))
if __name__ == "__main__":
_main()
# Is it possible to change input/output layer names of onnx model?
# https://github.com/onnx/onnx/issues/2052
# WinMLDashboard
# https://github.com/microsoft/Windows-Machine-Learning/tree/master/Tools/WinMLDashboard
@ikuokuo commented Jun 2, 2021

ONNX env

conda create -n onnx python=3.8 -y
conda activate onnx

# ONNX
#  https://github.com/onnx/onnx
conda install -c conda-forge onnx -y

# ONNX Simplifier
#  https://github.com/daquexian/onnx-simplifier
pip install onnx-simplifier
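
With the environment in place, the saved model can be sanity-checked with a few lines of Python (a minimal sketch; "model_cut_final.onnx" is a placeholder for whatever path the script prints when saving):

import onnx

# load the exported model and run the official checker
model = onnx.load("model_cut_final.onnx")
onnx.checker.check_model(model)

# list the (possibly renamed) graph inputs and outputs
print("inputs: ", [i.name for i in model.graph.input])
print("outputs:", [o.name for o in model.graph.output])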