Skip to content

Instantly share code, notes, and snippets.

@disktnk
Created February 20, 2019 03:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save disktnk/f9a0b0ec12daf8f39d09fc86ade2a08e to your computer and use it in GitHub Desktop.
`pytest`
def test_bn_onnxruntime():
import numpy as np
import onnx
# this code is from ONNX BatchNomalization example
# https://github.com/onnx/onnx/blob/master/onnx/backend/test/case/node/batchnorm.py
def _batchnorm_test_mode(x, s, bias, mean, var, epsilon=1e-5): # type: ignore
dims_x = len(x.shape)
dim_ones = (1,) * (dims_x - 2)
s = s.reshape(-1, *dim_ones)
bias = bias.reshape(-1, *dim_ones)
mean = mean.reshape(-1, *dim_ones)
var = var.reshape(-1, *dim_ones)
return s * (x - mean) / np.sqrt(var + epsilon) + bias
# input size: (1, 2, 1, 3)
x = np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32)
s = np.array([1.0, 1.5]).astype(np.float32)
bias = np.array([0, 1]).astype(np.float32)
mean = np.array([0, 3]).astype(np.float32)
var = np.array([1, 1.5]).astype(np.float32)
y = _batchnorm_test_mode(x, s, bias, mean, var).astype(np.float32)
node = onnx.helper.make_node(
'BatchNormalization',
inputs=['x', 's', 'bias', 'mean', 'var'],
outputs=['y'],
)
inputs = [x, s, bias, mean, var]
outputs = [y]
# this code is from ONNX expect function
# https://github.com/onnx/onnx/blob/master/onnx/backend/test/case/node/__init__.py
def _extract_value_info(arr, name): # type: (np.ndarray, Text) -> onnx.ValueInfoProto
return onnx.helper.make_tensor_value_info(
name=name,
elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype],
shape=arr.shape)
present_inputs = [x for x in node.input if (x != '')]
present_outputs = [x for x in node.output if (x != '')]
inputs_vi = [_extract_value_info(arr, arr_name)
for arr, arr_name in zip(inputs, present_inputs)]
outputs_vi = [_extract_value_info(arr, arr_name)
for arr, arr_name in zip(outputs, present_outputs)]
graph = onnx.helper.make_graph(
nodes=[node],
name='test_batchnorm_example',
inputs=inputs_vi,
outputs=outputs_vi)
opset_version = 8 # 9 is failed by "GENERAL ERROR", 7 and 8 are succeeded
model = onnx.helper.make_model(
graph,
producer_name='backend-test',
opset_imports=[onnx.helper.make_opsetid('', opset_version)]
)
import onnxruntime as rt
sess = rt.InferenceSession(model.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment