Last active
October 8, 2020 16:48
-
-
Save eopXD/075ff5d4da38b0000d902acf755cdfbc to your computer and use it in GitHub Desktop.
Generate an ONNX model with the ONNX Python API, with its inputs supplied as initializers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# By eopXD (eopxd.com)
# Since this took me some time to figure out, I hope this file helps others.
import numpy as np
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
from onnx import numpy_helper
# Build a one-node ONNX graph, Y = ConvTranspose(X, W, B), where every input
# is also given a concrete value through a graph initializer. Then validate
# the model and save it to "SingleConvTranspose.onnx".

# Input feature map X, shape (1, 1, 3, 3).
xNP = np.array([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]]).astype(np.float32)
# ConvTranspose weight W, shape (C_in=1, C_out=2, kH=3, kW=3), all ones.
wNP = np.ones((1, 2, 3, 3), dtype=np.float32)
# Bias B: one value per output channel (wNP.shape[1]). It must be float32 like
# X and W: ConvTranspose requires X, W and B to share one element type, and a
# plain integer literal such as np.array([0, 0]) would yield an int64/int32
# initializer that contradicts the FLOAT declaration of "B" below.
bNP = np.zeros(wNP.shape[1], dtype=np.float32)
# Expected output Y, shape (1, 2, 5, 5): with an all-ones 3x3 kernel and zero
# bias, each output channel is the "full" 3x3 box-sum of X. It is not fed into
# the graph; it is used below only to declare the output shape, and is kept for
# comparing against a runtime's result.
yNP = np.array([[[[0., 1., 3., 3., 2.],
                  [3., 8., 15., 12., 7.],
                  [9., 21., 36., 27., 15.],
                  [9., 20., 33., 24., 13.],
                  [6., 13., 21., 15., 8.]],
                 [[0., 1., 3., 3., 2.],
                  [3., 8., 15., 12., 7.],
                  [9., 21., 36., 27., 15.],
                  [9., 20., 33., 24., 13.],
                  [6., 13., 21., 15., 8.]]]]).astype(np.float32)

# The single operator node, with default attributes (stride 1, no padding).
ConvTranspose = helper.make_node("ConvTranspose", ["X", "W", "B"], ["Y"])

# Initializers hold the concrete tensor values. Their names must match the
# node's input names.
xInit = numpy_helper.from_array(xNP, "X")
wInit = numpy_helper.from_array(wNP, "W")
bInit = numpy_helper.from_array(bNP, "B")

# Graph input/output declarations. Shapes are derived from the arrays so they
# cannot drift out of sync with the data above.
xTensor = helper.make_tensor_value_info('X', TensorProto.FLOAT, list(xNP.shape))
wTensor = helper.make_tensor_value_info('W', TensorProto.FLOAT, list(wNP.shape))
bTensor = helper.make_tensor_value_info('B', TensorProto.FLOAT, list(bNP.shape))
yTensor = helper.make_tensor_value_info('Y', TensorProto.FLOAT, list(yNP.shape))

# X, W and B are declared as graph inputs and also initialized. The
# initializer values therefore act as defaults that a caller may override at
# run time.
graph_def = helper.make_graph(
    [ConvTranspose],
    'test-model',
    [xTensor, wTensor, bTensor],
    [yTensor],
    [xInit, wInit, bInit],
)

model_def = helper.make_model(graph_def, producer_name='onnx-example')
print(f'The model is:\n{model_def}')
onnx.checker.check_model(model_def)
print('The model is checked!')
onnx.save(model_def, "SingleConvTranspose.onnx")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment