Skip to content

Instantly share code, notes, and snippets.

@serihiro
Created September 22, 2019 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save serihiro/efd7411419627aea8f1a57033a678bf8 to your computer and use it in GitHub Desktop.
Save serihiro/efd7411419627aea8f1a57033a678bf8 to your computer and use it in GitHub Desktop.
Simple ONNX runtime (supports only MNIST / Gemm + Relu)
import argparse
import onnx
import numpy as np
from onnx import numpy_helper
from PIL import Image
def relu(x: np.ndarray) -> np.ndarray:
    """Apply the rectified linear unit element-wise, i.e. ``max(x, 0)``.

    Negative entries become zero and non-negative entries pass through
    unchanged; the result has the same shape as ``x``.
    """
    floor = 0
    activated = np.maximum(x, floor)
    return activated
def gemm(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the affine map ``w.dot(x) + b``.

    For a 1-D ``x`` and 2-D ``w`` this equals ``x @ w.T + b``, which is ONNX
    ``Gemm`` with ``transB=1`` and ``alpha = beta = 1``. No other Gemm
    attributes are applied here, so callers must only use it for nodes of
    that form.

    Args:
        x: Input vector (ONNX Gemm input ``A``).
        w: Weight matrix applied on the left of ``x`` (ONNX Gemm input ``B``).
        b: Bias added to the product (ONNX Gemm input ``C``).
    """
    weighted = w.dot(x)
    return weighted + b
# Gemm attribute values for which ONNX Gemm reduces to ``gemm``'s
# ``w.dot(x) + b`` (weights stored as (out_features, in_features)).
_SUPPORTED_GEMM_ATTRIBUTES = {'alpha': 1.0, 'beta': 1.0, 'transA': 0, 'transB': 1}
# Values the ONNX spec assumes when a Gemm node omits an attribute.
_ONNX_GEMM_DEFAULTS = {'alpha': 1.0, 'beta': 1.0, 'transA': 0, 'transB': 0}


def _check_gemm_attributes(node) -> None:
    """Raise AttributeError unless the Gemm ``node`` is computable by ``gemm``."""
    actual = dict(_ONNX_GEMM_DEFAULTS)
    for attr in node.attribute:
        # alpha/beta are FLOAT attributes, transA/transB are INT attributes.
        if attr.name in ('alpha', 'beta'):
            actual[attr.name] = attr.f
        elif attr.name in ('transA', 'transB'):
            actual[attr.name] = attr.i
    if actual != _SUPPORTED_GEMM_ATTRIBUTES:
        raise AttributeError(
            f'Gemm node {node.name!r} uses unsupported attributes {actual}; '
            f'only {_SUPPORTED_GEMM_ATTRIBUTES} is supported')


def infer(model_path: str, input_key: str, img: np.ndarray) -> np.ndarray:
    """Run a Gemm/Relu-only ONNX model on one input and return its output.

    The model is loaded from ``model_path``. ``img`` is bound to the graph
    input named ``input_key``, and every initializer tensor is converted to a
    NumPy array. Nodes are then evaluated in the order they appear in the
    graph.

    Supported operators:

    * ``Gemm``: evaluated with ``gemm`` (``w.dot(x) + b``). Nodes whose
      ``alpha``/``beta``/``transA``/``transB`` differ from
      ``_SUPPORTED_GEMM_ATTRIBUTES`` are rejected instead of being silently
      mis-computed. A missing optional bias input ``C`` is treated as zero.
    * ``Relu``: evaluated with ``relu``.

    Args:
        model_path: Path to the ``.onnx`` file.
        input_key: Name of the graph input that ``img`` feeds.
        img: Input value, e.g. a flattened image.

    Returns:
        The value of the graph's first declared output, or the last node's
        output if the graph declares none.

    Raises:
        AttributeError: If an unsupported op type is used, or a Gemm node has
            attributes that ``gemm`` does not implement.
        ValueError: If the graph contains no nodes.
    """
    graph = onnx.load(model_path).graph

    variables = {input_key: img}
    for data in graph.initializer:
        # If an initializer shares the input's name, keep the caller's value.
        if data.name not in variables:
            variables[data.name] = numpy_helper.to_array(data)

    if not graph.node:
        raise ValueError(f'The graph in {model_path!r} contains no nodes')

    for node in graph.node:
        if node.op_type == 'Gemm':
            _check_gemm_attributes(node)
            # Gemm's third input (C) is optional; an absent or empty name means no bias.
            has_bias = len(node.input) > 2 and node.input[2]
            bias = variables[node.input[2]] if has_bias else 0
            output = gemm(variables[node.input[0]], variables[node.input[1]], bias)
        elif node.op_type == 'Relu':
            output = relu(variables[node.input[0]])
        else:
            raise AttributeError(
                f'A not supported op_type is used: {node.op_type}')
        variables[node.output[0]] = output

    # Return what the graph declares as its output rather than whichever
    # node happened to run last.
    output_name = graph.output[0].name if graph.output else node.output[0]
    return variables[output_name]
if __name__ == '__main__':
    # Command-line entry point: run the model on one image and print the
    # raw output vector produced by ``infer``.
    parser = argparse.ArgumentParser(
        description='Run a Gemm/Relu-only ONNX model on a single image.')
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='path to the .onnx model file')
    parser.add_argument('--image', '-i', type=str, required=True,
                        help='image whose pixel data flattens to exactly 784 '
                             'values (e.g. a 28x28 grayscale digit)')
    parser.add_argument('--input_key', type=str, default='Input_0',
                        help='name of the graph input to feed '
                             '(default: %(default)s)')
    args = parser.parse_args()

    # Use a context manager so the image file handle is released
    # deterministically once the pixels have been read.
    with Image.open(args.image) as image:
        pixels = np.array(image, dtype='float32')
    # Flatten into the 784-element vector the model's first Gemm consumes.
    img = pixels.reshape(784)

    scores = infer(model_path=args.model, input_key=args.input_key, img=img)
    print(scores)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment