Skip to content

Instantly share code, notes, and snippets.

@FrancescoConti
Created April 20, 2021 07:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancescoConti/e73a1ace0f437de72f445e63b0412b3e to your computer and use it in GitHub Desktop.
Save FrancescoConti/e73a1ace0f437de72f445e63b0412b3e to your computer and use it in GitHub Desktop.
Small test that runs an ONNX model exported by NEMO with onnxruntime
#
# Copyright (C) 2021 University of Bologna
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Francesco Conti <f.conti@unibo.it>
import onnx
import onnxruntime as rt
import numpy as np
from onnx import numpy_helper
# Name of an intermediate tensor to expose as an extra graph output and fetch
# (e.g. '208'); set to None to fetch the model's real (first) output instead.
INVESTIGATE_NODE = '490'
MODEL_NAME = 'model_best.onnx'
DATA_NAME = 'test.npy'
# Path of the modified copy of the model that exposes INVESTIGATE_NODE.
MOD_MODEL_NAME = MODEL_NAME + 'mod'

# test.npy holds a pickled object (hence allow_pickle=True and .item()).
# NOTE(review): unpickling can execute arbitrary code -- only load test files
# you produced yourself.
data = np.load(DATA_NAME, allow_pickle=True).item()

if INVESTIGATE_NODE is None:
    model_path = MODEL_NAME
else:
    model = onnx.load(MODEL_NAME)
    # Fail early with a clear message if the requested tensor does not exist
    # in the graph, instead of an opaque error from the runtime later on.
    known_names = {name for node in model.graph.node for name in node.output}
    known_names.update(inp.name for inp in model.graph.input)
    known_names.update(init.name for init in model.graph.initializer)
    if INVESTIGATE_NODE not in known_names:
        raise ValueError(
            "tensor %r not found in graph of %s" % (INVESTIGATE_NODE, MODEL_NAME)
        )
    # Expose the intermediate tensor as an additional graph output. Only the
    # name is set; no type information is attached. Skip the append when the
    # tensor is already an output, so the graph does not get a duplicate.
    if all(out.name != INVESTIGATE_NODE for out in model.graph.output):
        interm = onnx.ValueInfoProto()
        interm.name = INVESTIGATE_NODE
        model.graph.output.append(interm)
    onnx.save(model, MOD_MODEL_NAME)
    model_path = MOD_MODEL_NAME

# Pin the CPU provider. It gives deterministic reference numbers, and
# onnxruntime >= 1.9 builds with several execution providers refuse to create
# a session unless `providers` is passed explicitly.
sess = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
# Select the fetched output by name rather than by position, so that models
# with more than one original output still return the requested tensor.
if INVESTIGATE_NODE is None:
    output_name = sess.get_outputs()[0].name
else:
    output_name = INVESTIGATE_NODE
# Cast to float32 and prepend an axis of size 1.
# NOTE(review): assumes data['input'] is stored without a batch axis -- confirm
# against how test.npy was produced.
x = np.expand_dims(data['input'].astype(np.float32), axis=0)
# sess.run returns a list holding one array per requested output name.
res = sess.run([output_name], {input_name: x})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment