Created December 9, 2022 19:32
-
-
Save AlexanderPuckhaber/715e82a753b6766d880c0c3a3be4ba44 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""
.. _l-example-simple-usage:

Load and predict with ONNX Runtime and a very simple model
==========================================================

This example demonstrates how to load a model and compute
the output for an input vector. It also shows how to
retrieve the definition of its inputs and outputs.

This variant also enables ONNX Runtime profiling (optionally driven by a
perf configuration file) and prints the name of the profile file written
when profiling is ended.
"""
# modified from: https://onnxruntime.ai/docs/api/python/auto_examples/plot_load_and_predict.html#sphx-glr-auto-examples-plot-load-and-predict-py

import os

import numpy
import numpy.random
import onnxruntime as rt
from onnxruntime.datasets import get_example

#########################
# Let's load a very simple model.
# The model is available on github `onnx...test_sigmoid <https://github.com/onnx/onnx/blob/main/onnx/backend/test/data/node/test_sigmoid>`_.

example1 = get_example("sigmoid.onnx")

options = rt.SessionOptions()
# Enable ONNX Runtime's built-in profiler; the profile is finalized by the
# explicit end_profiling() call at the end of this script.
options.enable_profiling = True
# Enable perf profiling by pointing the session at "perf_config.json",
# resolved relative to the current working directory (not this script's dir).
# NOTE(review): confirm that the onnxruntime build in use recognizes the
# "session.profiler.perf_config_file_name" key; presumably it is specific to a
# profiling-enabled/custom build.
options.add_session_config_entry(
    "session.profiler.perf_config_file_name",
    os.path.abspath("perf_config.json"),
)
sess = rt.InferenceSession(example1, options, providers=rt.get_available_providers())

#########################
# Let's see the input name and shape.

input_meta = sess.get_inputs()[0]
input_name = input_meta.name
print("input name", input_name)
input_shape = input_meta.shape
print("input shape", input_shape)
input_type = input_meta.type
print("input type", input_type)

#########################
# Let's see the output name and shape.

output_meta = sess.get_outputs()[0]
output_name = output_meta.name
print("output name", output_name)
output_shape = output_meta.shape
print("output shape", output_shape)
output_type = output_meta.type
print("output type", output_type)

#########################
# Let's compute its outputs (or predictions if it is a machine learned model).

# Random float32 input; the (3, 4, 5) shape presumably matches the input shape
# printed above for the sigmoid test model — confirm if the model changes.
x = numpy.random.random((3, 4, 5)).astype(numpy.float32)
res = sess.run([output_name], {input_name: x})
print(res)

#########################
# Stop profiling explicitly and report where the profile was written, instead
# of relying on session teardown to flush it.

profile_file = sess.end_profiling()
print("profile file", profile_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.