Created
June 2, 2021 13:21
-
-
Save ikuokuo/29c5b7eaf6601b75302162ea28865fb9 to your computer and use it in GitHub Desktop.
Cut a sub-model from an ONNX model, and update its input/output names or shapes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# pylint: disable=missing-docstring | |
import argparse | |
import os | |
import sys | |
import timeit | |
import typing | |
import onnx | |
class ONNXCutArgs(typing.NamedTuple):
    """Options consumed by ``onnx_cut``.

    Only the first four fields are mandatory; the remaining ones default to
    "nothing specified" (``None``) or "step enabled" (``True``), which is what
    the command line produces when no optional flag is given.
    """

    # Path of the model to read.
    input_path: str
    # Path the extracted sub-model is written to (may be None).
    output_path: typing.Optional[str]
    # Tensor names that bound the sub-model on the input side.
    input_names: typing.List[str]
    # Tensor names that bound the sub-model on the output side.
    output_names: typing.List[str]
    # Replacement names for input_names; None means generate "input_<i>".
    input_names_new: typing.Optional[typing.List[str]] = None
    # Replacement names for output_names; None means generate "output_<i>".
    output_names_new: typing.Optional[typing.List[str]] = None
    # 4-element shape applied to every graph input during the reshape step.
    input_shape: typing.Optional[typing.Tuple[int, int, int, int]] = None
    # Mapping of value name -> 4-element shape applied during the reshape step.
    specify_shapes: typing.Optional[
        typing.Dict[str, typing.Tuple[int, int, int, int]]] = None
    # Switches for the individual pipeline steps.
    do_cut: bool = True
    do_rename: bool = True
    do_reshape: bool = True
    do_shape_infer: bool = True
    do_simplify: bool = True
class ONNXTiming(object):
    """Tiny stopwatch that prints a line when a named step starts and ends."""

    def __init__(self):
        # Nothing is being timed until beg() is called.
        self._t_beg = None
        self._action = None

    def beg(self, action):
        """Record the start time of ``action`` and announce it on stdout."""
        self._t_beg = timeit.default_timer()
        self._action = action
        print(f"\nonnx {action} ...")

    def end(self):
        """Print how long the step started by the last ``beg()`` has taken.

        Raises:
            RuntimeError: if ``beg()`` has never been called on this instance.
        """
        if self._t_beg is None:
            raise RuntimeError("ONNXTiming.end() called before beg()")
        t_end = timeit.default_timer()
        print(f"onnx {self._action} done, cost {t_end-self._t_beg:.3f} s")
def onnx_cut(args: ONNXCutArgs) -> None:
    """Run the cut / rename / reshape / infer / simplify pipeline on a model.

    Each step is optional and controlled by the ``do_*`` flags in ``args``.
    The result is saved next to the loaded model with a ``_final.onnx``
    suffix. ``args`` is never modified.

    Args:
        args: pipeline options, see ``ONNXCutArgs``.

    Raises:
        AssertionError: if onnxsim reports that the simplified model could
            not be validated.
    """
    timing = ONNXTiming()

    model_path = args.input_path
    if args.do_cut:
        # Fall back to "<input>_cut.onnx" (the same default the command line
        # uses) so a None output_path is not handed to extract_model.
        cut_path = args.output_path or (
            os.path.splitext(args.input_path)[0] + "_cut.onnx")
        timing.beg("cut")
        onnx.utils.extract_model(
            args.input_path, cut_path, args.input_names, args.output_names)
        timing.end()
        model_path = cut_path

    timing.beg("load")
    print(f" path={model_path}")
    model = onnx.load(model_path)
    timing.end()

    if args.do_rename:
        timing.beg("rename")
        # Build fresh lists: aliasing and extending args.input_names or
        # args.input_names_new would mutate the caller's lists.
        names = list(args.input_names) + list(args.output_names)
        if args.input_names_new is None:
            names_new = [f"input_{i}" for i, _ in enumerate(args.input_names)]
        else:
            names_new = list(args.input_names_new)
        if args.output_names_new is None:
            names_new.extend(
                f"output_{i}" for i, _ in enumerate(args.output_names))
        else:
            names_new.extend(args.output_names_new)
        _onnx_rename(model, names, names_new)
        _onnx_rename_print(model, names, names_new,
                           file=os.path.splitext(model_path)[0] + "_rename.csv")
        timing.end()

    # Shape later forwarded to onnxsim; only set when the inputs were reshaped.
    simp_input_shape = None
    if args.do_reshape:
        if args.input_shape and len(args.input_shape) == 4:
            timing.beg("reshape input")
            _onnx_input_reshape(model, args.input_shape)
            simp_input_shape = args.input_shape
            timing.end()
        if args.specify_shapes:
            timing.beg("reshape specified")
            _onnx_specify_shapes(model, args.specify_shapes)
            timing.end()

    if args.do_shape_infer:
        timing.beg("shape_infer")
        model = onnx.shape_inference.infer_shapes(model)
        timing.end()

    if args.do_simplify:
        timing.beg("simplify")
        # Imported lazily so onnxsim is only required when simplifying.
        import onnxsim
        if simp_input_shape:
            model_simp, check = onnxsim.simplify(
                model, perform_optimization=False,
                input_shapes={None: simp_input_shape})
        else:
            model_simp, check = onnxsim.simplify(
                model, perform_optimization=False)
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if not check:
            raise AssertionError("Simplified ONNX model could not be validated")
        model = model_simp
        timing.end()

    output_path = os.path.splitext(model_path)[0] + "_final.onnx"
    timing.beg("save")
    print(f" path={output_path}")
    onnx.save(model, output_path)
    timing.end()
def _onnx_rename(model, names, names_new):
    """Rename tensors in ``model`` in place.

    Every occurrence of ``names[k]`` among node inputs/outputs and among the
    graph's declared inputs/outputs is replaced by ``names_new[k]``.
    """
    def _swap_entries(entries):
        # `entries` is an indexable sequence of tensor names.
        for idx, old in enumerate(entries):
            if old in names:
                entries[idx] = names_new[names.index(old)]

    def _swap_info_names(infos):
        # `infos` is a sequence of objects carrying a `.name` attribute.
        for info in infos:
            if info.name in names:
                info.name = names_new[names.index(info.name)]

    graph = model.graph
    for node in graph.node:
        _swap_entries(node.input)
        _swap_entries(node.output)
    _swap_info_names(graph.input)
    _swap_info_names(graph.output)
def _onnx_rename_print(model, names, names_new, file="onnx_rename.csv"):
    """Print an old/new/node table for a rename, optionally mirroring it to CSV.

    For each (old, new) pair the node found by ``_onnx_node(model, new)`` is
    reported by name, or ``-`` when no node matches. When ``file`` is truthy
    the same rows are also written to that path as CSV.
    """
    def _report(emit_row=None):
        print(" old, new, node")
        for old, new in zip(names, names_new):
            node = _onnx_node(model, new)
            node_name = '-' if node is None else node.name
            if emit_row is not None:
                emit_row((old, new, node_name))
            print(f"{old:>10s}, {new:>10s}, {node_name:>10s}")

    if not file:
        _report()
        return

    import csv
    with open(file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(("old", "new", "node"))
        _report(writer.writerow)
def _onnx_node(model, name):
    """Return the first graph node related to ``name``, or ``None``.

    A node matches when ``name`` equals the node's own name or appears among
    its inputs or outputs; nodes are examined in graph order.
    """
    matches = (
        node for node in model.graph.node
        if node.name == name or name in node.input or name in node.output
    )
    return next(matches, None)
def _onnx_input_reshape(model, input_shape):
    """Give every graph input the shape ``input_shape``, keeping output dims.

    The current output dimensions are collected first and handed back
    unchanged to ``update_inputs_outputs_dims`` together with the new input
    dimensions. The resulting graph inputs are printed.
    """
    def _current_dims(value_info):
        # Gather dim_value / dim_param entries; dims that have neither set
        # are skipped.
        tensor_type = value_info.type.tensor_type
        dims = []
        if tensor_type.HasField("shape"):
            for d in tensor_type.shape.dim:
                if d.HasField("dim_value"):
                    dims.append(d.dim_value)
                elif d.HasField("dim_param"):
                    dims.append(d.dim_param)
        return dims

    # The same shape is applied to every input.
    input_dims = {
        graph_input.name: list(input_shape)
        for graph_input in model.graph.input
    }
    output_dims = {
        graph_output.name: _current_dims(graph_output)
        for graph_output in model.graph.output
    }

    from onnx.tools import update_model_dims
    update_model_dims.update_inputs_outputs_dims(
        model, input_dims, output_dims)
    print(model.graph.input)
def _onnx_specify_shapes(model, specify_shapes):
    """Overwrite the shapes of named graph inputs/outputs/value_infos in place.

    Args:
        model: the loaded ONNX model; its graph is modified.
        specify_shapes: mapping of value name -> shape.

    A warning is printed for every name that cannot be found. Entries are
    updated where they stand, so the order of ``graph.input`` and
    ``graph.output`` is preserved.
    """
    from onnx import helper

    graph = model.graph
    containers = (graph.input, graph.output, graph.value_info)

    # On name clashes the later container wins (value_info over output over
    # input) when looking up the element type.
    all_map = {}
    for infos in containers:
        all_map.update(_onnx_graph_name_map(infos))

    def _update(info_new):
        # Overwrite the first entry with a matching name. CopyFrom keeps the
        # entry at its position; remove+append would move it to the end and
        # silently reorder the model's inputs/outputs.
        for infos in containers:
            for info in infos:
                if info.name == info_new.name:
                    info.CopyFrom(info_new)
                    return True
        return False

    for name, shape in specify_shapes.items():
        if name not in all_map:
            print(f"warn: specify shape {name} not found")
            continue
        info = all_map[name]
        info_new = helper.make_tensor_value_info(
            info.name, info.type.tensor_type.elem_type, list(shape))
        if _update(info_new):
            print(f" {name}: {shape}")
        else:
            print(f"warn: specify shape {name} update failed")
def _onnx_graph_name_map(graph_prop_list):
    """Index the given objects by their ``name`` attribute.

    If several objects share a name, the last one wins.
    """
    return {prop.name: prop for prop in graph_prop_list}
def _parse_args():
    """Parse and validate the command line, echo it, and return the namespace.

    Exits the process when the input model does not exist, when a shape
    argument is malformed, or when the number of new names does not match
    the number of original names. A missing ``--output`` defaults to
    ``<input>_cut.onnx``.
    """
    def ints_type(string, num=4, sep=","):
        # "1,3,512,512" -> (1, 3, 512, 512); empty text -> None.
        if not string:
            return None
        fields = string.split(sep)
        if len(fields) != num:
            sys.exit(f"error: ints_type size must be {num}")
        return tuple(int(x) for x in fields)

    def shapes_type(string):
        # "a:1,3,8,8;b:1,1,8,8" -> {"a": (1, 3, 8, 8), "b": (1, 1, 8, 8)}.
        if not string:
            return None
        shapes = {}
        for item in string.split(";"):
            parts = item.split(":")
            if len(parts) != 2:
                sys.exit("error: shapes_type must be like \"name:1,1,512,512;...\"")
            name, dims_text = parts
            dims = dims_text.split(",")
            if len(dims) != 4:
                sys.exit("error: shapes_type must be like \"name:1,1,512,512\"")
            shapes[name] = tuple(int(x) for x in dims)
        return shapes

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("-i", "--input", required=True,
        help="the model input path: %(default)s")
    add("-o", "--output", default=None,
        help="the model output path for sub: %(default)s")

    add("-in", "--input-names", nargs="+", required=True,
        help="the model input names for sub: %(default)s")
    add("-on", "--output-names", nargs="+", required=True,
        help="the model output names for sub: %(default)s")
    add("-inn", "--input-names-new", nargs="+",
        help="the model input names for sub: %(default)s")
    add("-onn", "--output-names-new", nargs="+",
        help="the model output names for sub: %(default)s")

    add("-is", "--input-shape",
        type=ints_type, metavar="1,3,512,512", default=None,
        help="the model input shape NCHW for sub: %(default)s")
    add("-ss", "--specify-shapes",
        type=shapes_type, metavar="\"name:1,3,512,512;...\"", default=None,
        help="specify shapes by input/output name of all nodes: %(default)s")

    add("-nc", "--no-cut", action="store_true",
        help="not cut sub model by input/output names: %(default)s")
    add("-nrn", "--no-rename", action="store_true",
        help="not rename input/output names to new ones (default: input/output_0 ...): %(default)s")
    add("-nrs", "--no-reshape", action="store_true",
        help="not reshape input shape, specify shapes: %(default)s")
    add("-nis", "--no-infer-shape", action="store_true",
        help="not inference shapes of model: %(default)s")
    add("-nsi", "--no-simplify", action="store_true",
        help="not simplify sub model: %(default)s")

    args = parser.parse_args()

    if not os.path.exists(args.input):
        sys.exit(f"error: model \"{args.input}\" not exists")

    if not args.output:
        stem, _ = os.path.splitext(args.input)
        args.output = stem + "_cut.onnx"

    # New names, when given, must pair up one-to-one with the original ones.
    inputs_new = args.input_names_new
    if inputs_new is not None and len(inputs_new) != len(args.input_names):
        sys.exit("error: the size of new input names must same as original ones")
    outputs_new = args.output_names_new
    if outputs_new is not None and len(outputs_new) != len(args.output_names):
        sys.exit("error: the size of new output names must same as original ones")

    print("Args")
    print(f" input: {args.input}")
    print(f" output: {args.output}")
    print(f" input_names: {args.input_names}")
    print(f" output_names: {args.output_names}")
    print(f" input_names_new: {args.input_names_new}")
    print(f" output_names_new: {args.output_names_new}")
    print(f" input_shape: {args.input_shape}")
    print(f" specify_shapes: {args.specify_shapes}")
    print(f" no_cut: {args.no_cut}")
    print(f" no_rename: {args.no_rename}")
    print(f" no_reshape: {args.no_reshape}")
    print(f" no_infer_shape: {args.no_infer_shape}")
    print(f" no_simplify: {args.no_simplify}")
    return args
def _main():
    """Command-line entry point: parse the arguments and run ``onnx_cut``."""
    cli = _parse_args()
    # The CLI exposes "no-*" opt-out flags; the pipeline wants positive switches.
    cut_args = ONNXCutArgs(
        input_path=cli.input,
        output_path=cli.output,
        input_names=cli.input_names,
        output_names=cli.output_names,
        input_names_new=cli.input_names_new,
        output_names_new=cli.output_names_new,
        input_shape=cli.input_shape,
        specify_shapes=cli.specify_shapes,
        do_cut=not cli.no_cut,
        do_rename=not cli.no_rename,
        do_reshape=not cli.no_reshape,
        do_shape_infer=not cli.no_infer_shape,
        do_simplify=not cli.no_simplify,
    )
    onnx_cut(cut_args)
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    _main()
# Is it possible to change input/output layer names of onnx model? | |
# https://github.com/onnx/onnx/issues/2052 | |
# WinMLDashboard | |
# https://github.com/microsoft/Windows-Machine-Learning/tree/master/Tools/WinMLDashboard |
Hey, the resulting onnx file is getting seg fault in onnxsim module. Any ideas why?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
ONNX env