Skip to content

Instantly share code, notes, and snippets.

@antoinebrl
Last active November 21, 2022 14:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antoinebrl/aac76f89395bcb548a26d42ed8d82761 to your computer and use it in GitHub Desktop.
Save antoinebrl/aac76f89395bcb548a26d42ed8d82761 to your computer and use it in GitHub Desktop.
ONNX surgery - Add input normalization
import urllib

import numpy as np
import onnx
import onnx.checker
import onnx.compose
import onnx.helper as oh
import timm
import torch
import torchvision
from onnx.onnx_pb import ValueInfoProto
from PIL import Image
from torch import nn
# Fetch the pretrained model and the preprocessing config timm resolves for it.
model = timm.create_model("mobilenetv3_small_100", pretrained=True, exportable=True)
model.eval()
config = timm.data.resolve_data_config({}, model=model)

# Export ONNX file.
# `device` was previously referenced without ever being defined (NameError).
# Export on the CPU and keep the model and the dummy input on the same device.
device = torch.device("cpu")
model = model.to(device)
# Trace at the model's own input resolution rather than a hard-coded 224x224:
# only the batch axis is declared dynamic below, so H and W get baked in.
channels, height, width = config["input_size"]
dummy_input = torch.rand(1, channels, height, width, device=device)
# ONNX only gained a native HardSwish operator in opset 14. This export targets
# opset 13, so hardswish has to be expressed with older operators; timm's
# `exportable=True` above requests export-friendly layer implementations.
torch.onnx.export(
    model,
    args=dummy_input,
    f="model-timm.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=13,
    verbose=False,
    dynamic_axes={"input": {0: "N"}, "output": {0: "N"}},
)
def create_constant_node(values: np.ndarray, *, name: str = ""):
    """Wrap an array in an ONNX ``Constant`` node.

    Args:
        values: The constant's data. Its dtype selects the ONNX tensor element
            type and its shape becomes the tensor dims. Any array-like accepted
            by ``np.asarray`` also works.
        name: Used as the tensor name, the node name, and the node's single
            output edge, so other nodes consume the constant by this name.
            NOTE(review): the empty default yields an unnamed output, which
            downstream nodes cannot reference — callers should pass a name.

    Returns:
        The ``onnx.NodeProto`` for the ``Constant`` node.
    """
    values = np.asarray(values)

    # `onnx.mapping` is deprecated since onnx 1.13 in favour of
    # helper.np_dtype_to_tensor_dtype; keep a fallback for older onnx releases.
    np_to_onnx_dtype = getattr(oh, "np_dtype_to_tensor_dtype", None)
    if np_to_onnx_dtype is not None:
        data_type = np_to_onnx_dtype(values.dtype)
    else:
        data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[values.dtype]

    tensor = oh.make_tensor(
        name=name,
        data_type=data_type,
        dims=values.shape,
        vals=values.flatten(),
    )
    return oh.make_node("Constant", inputs=[], outputs=[name], value=tensor, name=name)
# ONNX surgery: build a small graph that normalizes the raw image input.
net = onnx.load("model-timm.onnx")

# The preprocessing graph consumes exactly what the exported model consumes.
input_node = ValueInfoProto()
input_node.CopyFrom(net.graph.input[0])
input_node.type.denotation = "IMAGE"

# Its output keeps the same type/shape but is a normalized tensor, not an
# image: copy the exported input directly (not `input_node`) so the IMAGE
# denotation set above is not propagated to it.
output_node = ValueInfoProto()
output_node.CopyFrom(net.graph.input[0])
output_node.name = "preproc-output"

# Subtract the per-channel mean, shaped (1, C, 1, 1) to broadcast over NCHW.
# The values come from the data config timm resolved for this model.
mean = np.asarray(config["mean"], dtype=np.float32).reshape(1, -1, 1, 1)
mean_node = create_constant_node(mean, name="mean")
# Wire edges through node outputs, not node names: a node's name and its output
# edge name are independent in ONNX.
sub_node = oh.make_node(
    "Sub",
    inputs=[input_node.name, mean_node.output[0]],
    outputs=["centered"],
    name="centered",
)

# Divide by the per-channel standard deviation, same broadcasting layout.
scale = np.asarray(config["std"], dtype=np.float32).reshape(1, -1, 1, 1)
scale_node = create_constant_node(scale, name="scale")
div_node = oh.make_node(
    "Div",
    inputs=[sub_node.output[0], scale_node.output[0]],
    outputs=[output_node.name],
)

# Build the standalone preprocessing model. It reuses the exported model's
# opset imports and IR version so onnx.compose.merge_models accepts the pair.
preproc_graph = oh.make_graph(
    nodes=[mean_node, sub_node, scale_node, div_node],
    name="preproc",
    inputs=[input_node],
    outputs=[output_node],
)
preproc_net = oh.make_model(
    preproc_graph,
    producer_name="lynxai",
    opset_imports=net.opset_import,
    ir_version=net.ir_version,
)
# To inspect the preprocessing graph on its own:
# onnx.save(preproc_net, "preproc-model.onnx")
# Merge both networks: the preprocessing output feeds the exported model's input.
model_meta = {
    "producer_name": net.producer_name + "+lynxai",
    "producer_version": net.producer_version,
    "model_version": net.model_version,
    "domain": net.domain,
    "doc_string": net.doc_string,
}
# Prefixing every name of the preprocessing graph avoids collisions with edge
# and initializer names from the exported model.
preproc_prefix = "preproc/"
merged_model = onnx.compose.merge_models(
    preproc_net,
    net,
    prefix1=preproc_prefix,
    # merge_models applies prefix1 to io_map itself, so pass unprefixed names.
    io_map=[(output_node.name, input_node.name)],
    **model_meta,
)

# prefix1 also renames the preprocessing graph's *input*. Without this fix-up,
# the merged model would expose "preproc/input" instead of the input name
# declared at export time. The exported model's own input of that name was
# consumed by io_map, so restoring the public name cannot collide.
prefixed_input = preproc_prefix + input_node.name
merged_graph = merged_model.graph
for value_info in [*merged_graph.input, *merged_graph.value_info]:
    if value_info.name == prefixed_input:
        value_info.name = input_node.name
for node in merged_graph.node:
    for idx, edge in enumerate(list(node.input)):
        if edge == prefixed_input:
            node.input[idx] = input_node.name

# Fail here, rather than at inference time, if the surgery produced an invalid graph.
onnx.checker.check_model(merged_model)
onnx.save(merged_model, "merged.onnx")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment