Skip to content

Instantly share code, notes, and snippets.

@qinjian623
Last active November 8, 2023 08:21
Show Gist options
  • Star 38 You must be signed in to star a gist
  • Fork 11 You must be signed in to fork a gist
  • Save qinjian623/6aa777037534c1c1dccbb66f832e93b8 to your computer and use it in GitHub Desktop.
Save qinjian623/6aa777037534c1c1dccbb66f832e93b8 to your computer and use it in GitHub Desktop.
ONNX file to PyTorch model
import onnx
import struct
import torch
import torch.nn as nn
import torchvision as tv
import warnings
# enum DataType {
# UNDEFINED = 0;
# // Basic types.
# FLOAT = 1; // float
# UINT8 = 2; // uint8_t
# INT8 = 3; // int8_t
# UINT16 = 4; // uint16_t
# INT16 = 5; // int16_t
# INT32 = 6; // int32_t
# INT64 = 7; // int64_t
# STRING = 8; // string
# BOOL = 9; // bool
#
# // IEEE754 half-precision floating-point format (16 bits wide).
# // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
# FLOAT16 = 10;
#
# DOUBLE = 11;
# UINT32 = 12;
# UINT64 = 13;
# COMPLEX64 = 14; // complex with float32 real and imaginary components
# COMPLEX128 = 15; // complex with float64 real and imaginary components
#
# // Non-IEEE floating-point format based on IEEE754 single-precision
# // floating-point number truncated to 16 bits.
# // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
# BFLOAT16 = 16;
#
# // Future extensions go here.
# }
# TODO more types maybe?
# Maps ONNX TensorProto.DataType codes to a ``struct`` format character and the
# byte width of one element; used to decode little-endian ``raw_data`` buffers.
data_type_tab = {
    1: ['f', 4],   # FLOAT
    2: ['B', 1],   # UINT8
    3: ['b', 1],   # INT8
    4: ['H', 2],   # UINT16
    5: ['h', 2],   # INT16
    6: ['i', 4],   # INT32
    7: ['q', 8],   # INT64
    9: ['?', 1],   # BOOL (stored as one byte per element)
    10: ['e', 2],  # FLOAT16
    11: ['d', 8],  # DOUBLE
    12: ['I', 4],  # UINT32
    13: ['Q', 8],  # UINT64
}
def empty(x):
    """Identity helper: hand back *x* unchanged."""
    result = x
    return result
# TODO pytorch only accepts 2-value list for padding.
def _slim422(l4):
assert len(l4) == 4
p0, p1 = l4[::2]
if l4[0] == 0: # TODO bad code
p0 = l4[2] // 2
if l4[2] == 1:
p0 = 1
if l4[1] == 0: # TODO bad code
p1 = l4[3] // 2
if l4[3] == 1:
p1 = 1
return p0, p1
def _check_attr(attrs, map):
for attr in attrs:
if attr.name not in map:
warnings.warn("Missing {} in parser's attr_map.".format(attr.name))
def unpack_weights(initializer):
    """Decode ONNX initializer tensors into a ``{name: torch.Tensor}`` dict.

    Non-empty ``raw_data`` (little-endian bytes) is decoded through
    ``data_type_tab``.  When ``raw_data`` is empty, the typed repeated field
    is used instead: ``float_data`` (FLOAT), ``int32_data`` (INT32),
    ``int64_data`` (INT64) or ``double_data`` (DOUBLE).  Tensors with
    non-empty ``dims`` are viewed to that shape.

    Args:
        initializer: iterable of ONNX TensorProto-like objects.

    Returns:
        dict mapping each tensor's ``name`` to a ``torch.Tensor``.

    Raises:
        NotImplementedError: if a tensor's data type cannot be decoded.
    """
    # Typed repeated field holding the values when raw_data is empty.
    typed_fields = {
        1: 'float_data',    # FLOAT
        6: 'int32_data',    # INT32
        7: 'int64_data',    # INT64
        11: 'double_data',  # DOUBLE
    }
    ret = {}
    for tensor in initializer:
        dtype = tensor.data_type
        shape = list(tensor.dims)
        if len(tensor.raw_data) != 0:
            if dtype not in data_type_tab:
                raise NotImplementedError(
                    "This data type {} is not supported yet.".format(dtype))
            fmt, size = data_type_tab[dtype]
            data_list = struct.unpack('<' + fmt * (len(tensor.raw_data) // size), tensor.raw_data)
        elif dtype in typed_fields:
            data_list = list(getattr(tensor, typed_fields[dtype]))
        else:
            # Raise instead of warning: continuing would leave data_list
            # undefined or, worse, reuse the previous tensor's values.
            raise NotImplementedError(
                "No-raw-data type {} not supported yet.".format(dtype))
        t = torch.tensor(data_list)
        if len(shape) != 0:
            t = t.view(*shape)
        ret[tensor.name] = t
    return ret
def rebuild_lrn(node, weights):
    """Build ``nn.LocalResponseNorm`` from an ONNX LRN node.

    ONNX ``bias`` is passed as PyTorch ``k``; ``size`` is read as an int
    attribute and every other attribute as a float.  Attributes the node
    omits fall back to PyTorch's defaults (alpha=1e-4, beta=0.75, k=1).

    Returns:
        ``(module, node.input, node.output)``.
    """
    rebuild_lrn.lrn_attr_map = {
        'size': 'size',
        'alpha': 'alpha',
        'beta': 'beta',
        'bias': 'k',
    }
    name_map = rebuild_lrn.lrn_attr_map
    kwargs = {
        name_map[att.name]: (att.i if att.name == 'size' else att.f)
        for att in node.attribute
    }
    return nn.LocalResponseNorm(**kwargs), node.input, node.output
def rebuild_conv(node, weights):
    """Build ``nn.Conv2d`` from an ONNX Conv node and load its parameters.

    The weight (input 1) and optional bias (input 2) are looked up in
    *weights* and copied into the module, so only the data input stays a
    runtime dependency.  ONNX ``pads`` are reduced via ``_slim422``.

    Returns:
        ``(conv, node.input[:1], node.output)``.
    """
    rebuild_conv.conv_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "group": "groups",
        "dilations": "dilation",
    }
    assert len(node.output) == 1
    with_bias = len(node.input) == 3
    if with_bias:
        bias = weights[node.input[2]]
    weight = weights[node.input[1]]
    out_channels, in_channels = weight.shape[0], weight.shape[1]

    attr_map = rebuild_conv.conv_attr_map
    kwargs = {}
    for att in node.attribute:
        key = attr_map[att.name]
        kwargs[key] = att.i if att.name == 'group' else list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])

    # ONNX stores the weight as (out, in / groups, kH, kW); Conv2d wants the
    # full input-channel count.
    in_channels *= kwargs.get('groups', 1)
    conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
    conv.weight.data = weight
    if with_bias:
        conv.bias.data = bias
    return conv, node.input[:1], node.output
def rebuild_dropout(node, weights):
    """Build an element-wise ``nn.Dropout`` from an ONNX Dropout node.

    ONNX Dropout zeroes individual elements, which is what ``nn.Dropout``
    does (``nn.Dropout2d`` would zero whole channels in training mode).
    The ``ratio`` attribute is looked up by name and defaults to 0.5 when
    absent.  In opset 12+ ``ratio`` is an input instead of an attribute;
    that input is not read here.

    Returns:
        ``(module, node.input[:1], node.output)`` -- only the data input is
        forwarded, since the module takes a single tensor.
    """
    ratio = 0.5
    for att in node.attribute:
        if att.name == 'ratio':
            ratio = att.f
    return nn.Dropout(p=ratio), node.input[:1], node.output
def rebuild_batchnormalization(node, weights):
    """Build ``nn.BatchNorm2d`` from an ONNX BatchNormalization node.

    Inputs 1-4 name the scale, bias, running mean and running variance in
    *weights*; they are copied into the module.  The ``epsilon`` and
    ``momentum`` attributes are forwarded to the module; any other attribute
    only triggers a warning via ``_check_attr``.

    Returns:
        ``(bn, node.input[:1], node.output)``.
    """
    rebuild_batchnormalization.bn_attr_map = {
        "epsilon": "eps",
        "momentum": "momentum",
    }
    assert len(node.input) == 5
    assert len(node.output) == 1
    weight = weights[node.input[1]]
    bias = weights[node.input[2]]
    running_mean = weights[node.input[3]]
    running_var = weights[node.input[4]]
    dim = weight.shape[0]
    kwargs = {}
    _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
    for att in node.attribute:
        if att.name in rebuild_batchnormalization.bn_attr_map:
            kwargs[rebuild_batchnormalization.bn_attr_map[att.name]] = att.f
    if "momentum" in kwargs:
        # ONNX:    running = running * momentum + batch * (1 - momentum)
        # PyTorch: running = running * (1 - momentum) + batch * momentum
        kwargs["momentum"] = 1.0 - kwargs["momentum"]
    # kwargs must reach the constructor; otherwise a model-specific epsilon
    # is silently replaced by PyTorch's default.
    bn = nn.BatchNorm2d(num_features=dim, **kwargs)
    bn.weight.data = weight
    bn.bias.data = bias
    bn.running_mean.data = running_mean
    bn.running_var.data = running_var
    return bn, node.input[:1], node.output
def rebuild_relu(node, weights):
    """Map an ONNX Relu node onto a stateless ``nn.ReLU`` module."""
    activation = nn.ReLU()
    return activation, node.input, node.output
def rebuild_maxpool(node, weights):
    """Build ``nn.MaxPool2d`` from an ONNX MaxPool node.

    Int-list attributes (``pads``, ``strides``, ``kernel_shape``,
    ``dilations``) are renamed to their PyTorch keywords.  ``ceil_mode`` is
    an int flag and is passed as a bool.  ONNX ``pads`` are reduced via
    ``_slim422``.  Any other attribute raises ``KeyError``.

    Returns:
        ``(module, node.input, node.output)``.
    """
    rebuild_maxpool.mp_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "dilations": "dilation",
        "ceil_mode": "ceil_mode",
    }
    kwargs = {}
    for att in node.attribute:
        key = rebuild_maxpool.mp_attr_map[att.name]
        # ceil_mode is a scalar int attribute; the rest are int lists.
        kwargs[key] = bool(att.i) if att.name == "ceil_mode" else list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    mp = nn.MaxPool2d(**kwargs)
    return mp, node.input, node.output
def rebuild_add(node, weights):
    """Map an ONNX Add node onto the ``+`` operator."""
    def add(a, b):
        total = a + b
        return total
    return add, node.input, node.output
def rebuild_globalaveragepool(node, weights):
    """Map ONNX GlobalAveragePool onto ``nn.AdaptiveAvgPool2d`` with a 1x1 output."""
    return nn.AdaptiveAvgPool2d((1, 1)), node.input, node.output
def rebuild_transpose(node, weights):
    """Map ONNX Transpose onto ``Tensor.permute`` using the node's first attribute (``perm``)."""
    order = node.attribute[0].ints

    def transpose(x):
        return x.permute(*order)

    return transpose, node.input, node.output
def rebuild_flatten(node, weights):
    """Map ONNX Flatten onto a reshape to a 2-D tensor.

    ONNX Flatten(axis=d) always yields shape
    ``(prod(x.shape[:d]), prod(x.shape[d:]))``.  ``axis`` is read from the
    node's first attribute and defaults to 1.  A negative axis counts from
    the end (allowed since opset 11).

    Returns:
        ``(flatten_fn, node.input, node.output)``.
    """
    d = node.attribute[0].i if len(node.attribute) != 0 else 1

    def flatten(x):
        axis = d if d >= 0 else d + x.dim()
        # torch.Size.numel() of an empty Size is 1, so axis=0 gives (1, N).
        leading = x.shape[:axis].numel()
        trailing = x.shape[axis:].numel()
        # reshape (not view) so non-contiguous inputs, e.g. after Transpose, work.
        return x.reshape(leading, trailing)

    return flatten, node.input, node.output
def rebuild_gemm(node, weights):
    """Build ``nn.Linear`` from an ONNX Gemm node.

    ONNX Gemm computes ``Y = alpha * A' @ B' + beta * C``, with
    ``B' = B.T`` when ``transB`` is set.  ``nn.Linear`` stores its weight as
    ``(out_features, in_features)`` and computes ``x @ W.T + b``.  So B is
    used directly when ``transB == 1`` and transposed when ``transB == 0``
    (the ONNX default).  ``alpha`` and ``beta`` are folded into the weight and
    bias.  C (input 2) is optional.

    Returns:
        ``(linear, node.input[:1], node.output)``.
    """
    attrs = {att.name: att for att in node.attribute}
    alpha = attrs['alpha'].f if 'alpha' in attrs else 1.0
    beta = attrs['beta'].f if 'beta' in attrs else 1.0
    trans_a = attrs['transA'].i if 'transA' in attrs else 0
    trans_b = attrs['transB'].i if 'transB' in attrs else 0
    if trans_a:
        warnings.warn("Gemm with transA=1 is not supported yet; input A is used as-is.")

    weight = weights[node.input[1]]
    if not trans_b:
        weight = weight.t()  # (K, N) -> (N, K) view, no copy
    if alpha != 1.0:
        # Out-of-place, so the shared weights dict is left untouched.
        weight = weight * alpha
    out_feats, in_feats = weight.shape[0], weight.shape[1]

    has_bias = len(node.input) > 2 and node.input[2] != ''
    linear = nn.Linear(in_features=in_feats, out_features=out_feats, bias=has_bias)
    linear.weight.data = weight
    if has_bias:
        bias = weights[node.input[2]]
        if beta != 1.0:
            bias = bias * beta
        linear.bias.data = bias
    return linear, node.input[:1], node.output
def rebuild_concat(node, weights):
    """Map ONNX Concat onto ``torch.cat`` along the node's first attribute (``axis``)."""
    axis = node.attribute[0].i

    def concat(*inputs):
        return torch.cat(inputs, axis)

    return concat, node.input, node.output
def rebuild_pad(node, weights):
    """Build ``nn.ConstantPad2d`` from an attribute-style ONNX Pad node.

    ONNX ``pads`` is laid out as ``[d0_begin, ..., dn_begin, d0_end, ...,
    dn_end]``.  ``nn.ConstantPad2d`` takes
    ``(w_begin, w_end, h_begin, h_end)`` for the last two dimensions.
    ``mode`` defaults to ``constant`` and ``value`` to 0.0.

    Raises:
        NotImplementedError: for non-constant modes, padding on any
            dimension other than the last two, or when ``pads`` is not an
            attribute (opset 11+ passes it as an input).
    """
    attrs = {att.name: att for att in node.attribute}
    mode = attrs['mode'].s if 'mode' in attrs else b'constant'
    value = attrs['value'].f if 'value' in attrs else 0.0
    if mode != b'constant':
        raise NotImplementedError("Pad mode {!r} is not supported yet.".format(mode))
    if 'pads' not in attrs:
        raise NotImplementedError("Pad with 'pads' as an input (opset >= 11) is not supported yet.")
    pads = list(attrs['pads'].ints)
    ndim = len(pads) // 2
    begins, ends = pads[:ndim], pads[ndim:]
    if ndim < 2 or any(begins[:-2]) or any(ends[:-2]):
        raise NotImplementedError("Only padding of the last two dimensions is supported yet.")
    h_begin, w_begin = begins[-2:]
    h_end, w_end = ends[-2:]
    pad = nn.ConstantPad2d((w_begin, w_end, h_begin, h_end), value)
    return pad, node.input, node.output
def rebuild_constant(node, weights):
    """Map an ONNX Constant node onto a zero-argument callable.

    The node's first attribute is read as a TensorProto and decoded with
    ``unpack_weights``, so both ``raw_data`` and the typed data fields are
    handled.  The tensor keeps the shape recorded in ``dims``; a tensor
    with no ``dims`` becomes a 0-d (scalar) tensor.

    Returns:
        ``(constant_fn, [], node.output)`` -- the constant has no inputs.
    """
    tensor_proto = node.attribute[0].t
    # unpack_weights returns a one-entry dict for a single tensor.
    (value,) = unpack_weights([tensor_proto]).values()
    if len(tensor_proto.dims) == 0:
        value = value.reshape(())

    def constant():
        # Hand out a fresh copy per call so an in-place op downstream
        # cannot corrupt the stored constant.
        return value.clone()

    return constant, [], node.output
def rebuild_sum(node, weights):
    """Map ONNX Sum (variadic element-wise addition) onto a Python callable.

    Returns:
        ``(sum_fn, node.input, node.output)``.
    """
    def sum_(*inputs):
        ret = inputs[0]
        for i in inputs[1:]:
            # Out-of-place: ``ret += i`` would overwrite inputs[0], a tensor
            # that other graph nodes or the weights dict may still read.
            # It also lets the first input broadcast up to a larger shape.
            ret = ret + i
        return ret

    return sum_, node.input, node.output
def rebuild_shape(node, weights):
    """Map ONNX Shape onto a callable returning ``x.shape`` as a 1-D tensor."""
    def shape(x):
        dims = [size for size in x.shape]
        return torch.tensor(dims)

    return shape, node.input, node.output
def rebuild_gather(node, weights):
    """Map ONNX Gather onto an index-select along ``axis``.

    ONNX Gather follows ``numpy.take`` semantics: the output shape is
    ``x.shape[:axis] + idx.shape + x.shape[axis + 1:]``, and negative
    indices count from the end of the axis.  (``torch.gather`` is a
    different operation: it requires ``idx`` to have the same rank as ``x``.)
    ``axis`` defaults to 0.

    Returns:
        ``(gather_fn, node.input, node.output)``.
    """
    axis = 0
    for att in node.attribute:
        if att.name == 'axis':
            axis = att.i

    def gather(x, idx):
        dim = axis if axis >= 0 else axis + x.dim()
        idx = torch.as_tensor(idx, dtype=torch.long)
        idx = torch.where(idx < 0, idx + x.shape[dim], idx)
        picked = torch.index_select(x, dim, idx.reshape(-1))
        out_shape = tuple(x.shape[:dim]) + tuple(idx.shape) + tuple(x.shape[dim + 1:])
        return picked.reshape(out_shape)

    return gather, node.input, node.output
def _nd_unsqueeze(x, dims):
dims = sorted(dims)
for d in dims:
x = torch.unsqueeze(x, dim=d)
return x
def rebuild_unsqueeze(node, weights):
    """Map an attribute-style ONNX Unsqueeze onto ``_nd_unsqueeze``.

    The axes come from the node's first attribute (``axes``).
    """
    wanted_axes = node.attribute[0].ints

    def unsqueeze(x):
        return _nd_unsqueeze(x, wanted_axes)

    return unsqueeze, node.input, node.output
def rebuild_mul(node, weights):
    """Map an ONNX Mul node onto the ``*`` operator."""
    def mul(a, b):
        product = a * b
        return product

    return mul, node.input, node.output
def rebuild_softmax(node, weights):
    """Map ONNX Softmax onto ``Tensor.softmax`` along the node's ``axis``.

    ``axis`` defaults to 1 when the attribute is absent.  The softmax is
    computed in float64 and the result is cast to float32.

    NOTE(review): for opset < 13, ONNX first coerces the input to 2-D at
    ``axis``.  That only matches this per-axis softmax when the input is
    2-D or ``axis`` is the last dimension.  Opset 13+ uses a default of -1
    instead of 1.  Confirm against the model's opset.

    Returns:
        ``(softmax_fn, node.input, node.output)``.
    """
    axis = 1
    for att in node.attribute:
        if att.name == 'axis':
            axis = att.i

    def f_softmax(x):
        return x.softmax(dim=axis, dtype=torch.double).float()

    return f_softmax, node.input, node.output
def rebuild_reshape(node, weights):
    """Map ONNX Reshape onto ``torch.reshape``.

    A 0 in the target shape copies the corresponding input dimension
    (ONNX semantics).  -1 and every other value are passed through to
    ``torch.reshape``.
    """
    def reshape(x, s):
        source_dims = x.shape
        target = [
            source_dims[pos] if requested == 0 else requested
            for pos, requested in enumerate(s.tolist())
        ]
        return torch.reshape(x, target)

    return reshape, node.input, node.output
def rebuild_averagepool(node, weights):
    """Build ``nn.AvgPool2d`` from an ONNX AveragePool node.

    Int-list attributes (``pads``, ``strides``, ``kernel_shape``) are
    renamed to their PyTorch keywords.  The int flags ``ceil_mode`` and
    ``count_include_pad`` are passed as bools.  ONNX ``pads`` are reduced
    via ``_slim422``.  Any other attribute raises ``KeyError``.

    Returns:
        ``(module, node.input, node.output)``.
    """
    rebuild_averagepool.avg_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "ceil_mode": "ceil_mode",
        "count_include_pad": "count_include_pad",
    }
    # ONNX excludes padded cells from the average unless count_include_pad=1,
    # whereas nn.AvgPool2d includes them by default; start from the ONNX default.
    kwargs = {"count_include_pad": False}
    for att in node.attribute:
        key = rebuild_averagepool.avg_attr_map[att.name]
        if att.name in ("ceil_mode", "count_include_pad"):
            kwargs[key] = bool(att.i)
        else:
            kwargs[key] = list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    ap = nn.AvgPool2d(**kwargs)
    return ap, node.input, node.output
def rebuild_op(node, weights):
    """Dispatch *node* to the module-level ``rebuild_<op_type lowercased>`` function.

    Args:
        node: ONNX NodeProto-like object exposing ``op_type``.
        weights: dict of decoded initializer tensors, passed through.

    Returns:
        Whatever the matching rebuilder returns: ``(op, inputs, outputs)``.

    Raises:
        KeyError: if no rebuilder exists for ``node.op_type``.
    """
    op_type = node.op_type
    builder = globals().get('rebuild_' + op_type.lower())
    if builder is None:
        raise KeyError(
            "Unsupported ONNX op type {!r}: no rebuild_{} function.".format(op_type, op_type.lower()))
    return builder(node, weights)
def construct_pytorch_nodes(graph, weights):
    """Rebuild every node of *graph*, in graph order, via ``rebuild_op``.

    Returns:
        list of ``(op, inputs, outputs)`` tuples, one per ONNX node.
    """
    return [rebuild_op(onnx_node, weights) for onnx_node in graph.node]
def resolve_deps(name, deps, inter_tensors):
    """Compute tensor *name*, and everything it needs, into *inter_tensors*.

    ``deps`` maps a tensor name to ``(op, input_names)``.  A name already
    present in ``inter_tensors`` is treated as available.  Otherwise each
    input is resolved first, left to right and depth-first, and then
    ``op(*inputs)`` is stored under the name.  The walk uses an explicit
    stack, so graph depth is not limited by Python's recursion limit.

    Args:
        name: tensor to compute.
        deps: ``{tensor_name: (op, input_names)}``.
        inter_tensors: dict of known tensors; updated in place.

    Raises:
        KeyError: a required tensor is neither in ``inter_tensors`` nor
            produced by any entry of ``deps``.
        ValueError: ``deps`` contains a dependency cycle.
    """
    stack = [name]
    # Names whose inputs are still being resolved (the active DFS path).
    expanding = set()
    while stack:
        current = stack[-1]
        if current in inter_tensors:
            stack.pop()
            continue
        if current not in deps:
            raise KeyError(
                "Tensor {!r} is neither an input/initializer nor produced by any node.".format(current))
        op, deps_names = deps[current]
        missing = [dep for dep in deps_names if dep not in inter_tensors]
        if missing:
            for dep in missing:
                if dep in expanding or dep == current:
                    raise ValueError("Dependency cycle detected at tensor {!r}.".format(dep))
            expanding.add(current)
            # Reverse so the first input is on top and is resolved first.
            stack.extend(reversed(missing))
            continue
        stack.pop()
        expanding.discard(current)
        inter_tensors[current] = op(*[inter_tensors[dep] for dep in deps_names])
class DependencyModule(nn.Module):
    """``nn.Module`` that evaluates a rebuilt ONNX graph by pulling on its output.

    ``deps`` maps every produced tensor name to ``(op, input_names)``.
    ``forward`` seeds ``inter_tensors`` with the decoded initializers plus
    the network input, then resolves the graph output on demand.
    """

    def __init__(self, onnx_model, input_name=None):
        super(DependencyModule, self).__init__()
        graph = onnx_model.graph
        self.deps = {}
        self.inter_tensors = dict()
        # NOTE(review): plain dict of tensors, not registered buffers, so
        # .to(device) does not move these -- confirm if GPU use is intended.
        self.weights = unpack_weights(graph.initializer)
        rebuilt = construct_pytorch_nodes(graph, self.weights)
        for position, (op, in_names, out_names) in enumerate(rebuilt):
            # Register real modules so eval()/parameters() see them.
            if isinstance(op, nn.Module):
                self.add_module(str(position), op)
            for out_name in out_names:
                self.deps[out_name] = (op, in_names)
        # Only the first graph input and the first graph output are used.
        default_input = graph.input[0].name
        self.output_name = graph.output[0].name
        self.input_name = default_input if input_name is None else input_name

    def forward(self, input):
        # Fresh scratch space per call, pre-filled with the initializers.
        tensors = self.weights.copy()
        tensors[self.input_name] = input
        self.inter_tensors = tensors
        resolve_deps(self.output_name, self.deps, tensors)
        return tensors[self.output_name]
def test_net(original_model, onnx_file):
    """Compare *original_model* with its ONNX re-import on one random batch.

    Prints the wall-clock time of each forward pass and the maximum
    absolute difference between the two outputs.

    Args:
        original_model: the PyTorch model the ONNX file was exported from.
        onnx_file: path to the exported ONNX file.
    """
    import time
    original_model.eval()
    onnx_model = onnx.load(onnx_file)
    reconstruct_model = DependencyModule(onnx_model)
    reconstruct_model.eval()
    sample = torch.randn(3, 3, 224, 224)
    # Inference only: no_grad stops autograd from recording graphs, which
    # would keep every activation alive and skew the timings.
    with torch.no_grad():
        start = time.perf_counter()
        r1 = original_model(sample)
        print("Original:", time.perf_counter() - start)
        start = time.perf_counter()
        r = reconstruct_model(sample)
        print("DependencyModule:", time.perf_counter() - start)
    print("Max error for", onnx_file, ":", (r - r1).abs().max().item())
def main():
    """Validate the converter on three pretrained torchvision models."""
    cases = (
        ("resnet18", "res18.onnx"),
        ("resnet50", "res50.onnx"),
        ("densenet121", "dense121.onnx"),
    )
    for factory_name, onnx_path in cases:
        factory = getattr(tv.models, factory_name)
        test_net(factory(True), onnx_path)


if __name__ == '__main__':
    main()
import mxnet.contrib.onnx as onnx_mxnet
import mxnet as mx
import numpy as np
import torch
import onnx
import onnx2pytorch as oi
from collections import namedtuple
def construct_mxnext_model(onnx_file, test_input):
    """Run *test_input* through the ONNX model imported into MXNet.

    Returns:
        numpy array of the first module output.
    """
    symbol, arg_params, aux_params = onnx_mxnet.import_model(onnx_file)
    # Inputs that are neither arg nor aux parameters are the data inputs.
    data_names = [
        blob for blob in symbol.list_inputs()
        if blob not in arg_params and blob not in aux_params
    ]
    print("Input Blob Names:", data_names)
    module = mx.mod.Module(symbol=symbol, data_names=data_names, context=mx.cpu(), label_names=None)
    print(symbol)
    module.bind(for_training=False, data_shapes=[(data_names[0], test_input.shape)], label_shapes=None)
    module.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True)
    Batch = namedtuple('Batch', ['data'])
    module.forward(Batch([mx.nd.array(test_input)]))
    first_output = module.get_outputs()[0]
    return first_output.asnumpy()
def construct_pytorch_model(onnx_file, test_input):
    """Run *test_input* through the PyTorch re-import of *onnx_file*.

    Returns:
        numpy array of the reconstructed model's output.
    """
    onnx_model = onnx.load(onnx_file)
    # densenet121.onnx names its data input 'data_0', which is not graph.input[0].
    extra = {"input_name": "data_0"} if onnx_file == "densenet121.onnx" else {}
    model = oi.DependencyModule(onnx_model, **extra)
    model.eval()
    batch = torch.from_numpy(test_input).float()
    return model(batch).detach().numpy()
def test_onnx_model(onnx_file):
    """Compare PyTorch and MXNet outputs for *onnx_file* on one random input.

    Prints the max/mean/min absolute difference and the first five values of
    each output's first row.
    """
    print("=" * 80)
    print(onnx_file, ":")
    test_input = np.random.randn(1, 3, 224, 224) / 10
    torch_out = construct_pytorch_model(onnx_file, test_input)
    mx_out = construct_mxnext_model(onnx_file, test_input)
    diff = np.absolute(mx_out - torch_out)
    print(diff.max(), diff.mean(), diff.min())
    print(mx_out[0][:5])
    print(torch_out[0][:5])
def main():
    """Compare PyTorch and MXNet outputs for each model known to convert."""
    ok_onnx_model_files = (
        "googlenet.onnx",  # OK special padding setting case not supported by PyTorch MaxPool. with Softmax()
        "resnet18v2.onnx",  # OK
        "resnet34v2.onnx",  # OK
        "squeezenet1.1.onnx",  # OK
        "mobilenetv2-1.0.onnx",  # OK
        "alex_net.onnx",  # OK but max error is not small enough. with Softmax()
        "densenet121.onnx",  # OK but input_name is 'data_0', not '0' in onnx.graph.input
        "vgg16.onnx",  # OK
        # "inception_v2.onnx",  # TODO wrong output, with Softmax()
        # "inception_v1.onnx",  # TODO Gemm weight shape in runtime
        # "shuffle_net.onnx",  # TODO wrong output, maybe by transpose or Softmax()
    )
    for onnx_path in ok_onnx_model_files:
        test_onnx_model(onnx_path)


if __name__ == '__main__':
    main()
@akawashiro
Copy link

Hi, thank you for publishing excellent source code. What is the license of this code? Can I change and use this for commercial purposes?

@qinjian623
Copy link
Author

qinjian623 commented Aug 30, 2021

Hi, thank you for publishing excellent source code. What is the license of this code? Can I change and use this for commercial purposes?

Hi, @akawashiro .

Sorry for the late reply.

I always prefer the MIT license, and it's compatible with commercial use.

@akawashiro
Copy link

Thank you.

@MonTer998
Copy link

Hi, I'm sorry, but I'm new to machine learning and I have to build a converter from ONNX to PyTorch, so your code is a big help. If it's not too much to ask, could you explain with an example how your code works?

@qinjian623
Copy link
Author

Hi, I'm sorry, but I'm new to machine learning and I have to build a converter from ONNX to PyTorch, so your code is a big help. If it's not too much to ask, could you explain with an example how your code works?

@MonTer998
The code starting here shows a simple usage of this script:
https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8#file-onnx2pytorch_validate-py-L29

@MonTer998
Copy link

@qinjian623 thanks a million

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment