Skip to content

Instantly share code, notes, and snippets.

@jdh8
Created July 8, 2019 10:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jdh8/b3bac3389d2abc82e85f30afc4bf131d to your computer and use it in GitHub Desktop.
Save jdh8/b3bac3389d2abc82e85f30afc4bf131d to your computer and use it in GitHub Desktop.
ONNX operator inference via a single-node model constructed on the fly
import numpy
import onnx
import onnxruntime
def infer(operator, *inputs, outputs=1, **attributes):
    """Run a single ONNX operator on the given inputs and return its outputs.

    A one-node ONNX graph is built around *operator*, wrapped in a model,
    serialized, and executed with an onnxruntime ``InferenceSession``.

    Args:
        operator: ONNX operator type name, e.g. ``"Add"`` or ``"MatMul"``.
            Also used as the graph name.
        *inputs: Operands, in the order the operator expects them. Anything
            ``numpy.asarray`` accepts is allowed (arrays, scalars, nested
            lists). Each array's dtype picks the declared ONNX element type,
            and its shape becomes the declared input shape.
        outputs: Number of node outputs to create and fetch.
        **attributes: Node attributes, forwarded to ``onnx.helper.make_node``.

    Returns:
        The result of ``InferenceSession.run`` for the requested outputs, one
        entry per output, in order.

    Raises:
        KeyError: If an input's numpy dtype has no entry in the ONNX type table.

    NOTE(review): ``onnx.helper.make_model`` stamps the model with the opset
    and IR version of the installed ``onnx`` package. If the installed
    onnxruntime is older, it may reject the model. Confirm that the two
    packages are compatible.
    """
    # numpy scalar type -> ONNX TensorProto element type.
    tags = {
        numpy.void: onnx.TensorProto.UNDEFINED,
        numpy.float32: onnx.TensorProto.FLOAT,
        numpy.uint8: onnx.TensorProto.UINT8,
        numpy.int8: onnx.TensorProto.INT8,
        numpy.uint16: onnx.TensorProto.UINT16,
        numpy.int16: onnx.TensorProto.INT16,
        numpy.int32: onnx.TensorProto.INT32,
        numpy.int64: onnx.TensorProto.INT64,
        numpy.bytes_: onnx.TensorProto.STRING,
        numpy.str_: onnx.TensorProto.STRING,
        numpy.bool_: onnx.TensorProto.BOOL,
        numpy.float16: onnx.TensorProto.FLOAT16,
        numpy.float64: onnx.TensorProto.DOUBLE,
        numpy.uint32: onnx.TensorProto.UINT32,
        numpy.uint64: onnx.TensorProto.UINT64,
        numpy.complex64: onnx.TensorProto.COMPLEX64,
        numpy.complex128: onnx.TensorProto.COMPLEX128,
    }

    # Accept scalars and lists as well as arrays. ndarray inputs pass through
    # asarray without conversion, so the previous behavior is preserved.
    arrays = [numpy.asarray(x) for x in inputs]

    itensors = []
    for k, array in enumerate(arrays):
        try:
            elem_type = tags[array.dtype.type]
        except KeyError:
            raise KeyError(
                f"no ONNX tensor type for numpy dtype {array.dtype}"
            ) from None
        itensors.append(
            onnx.helper.make_tensor_value_info(f"%i{k}", elem_type, array.shape)
        )

    # Outputs are declared without type or shape information. Presumably
    # onnxruntime infers them from the node (TODO confirm for exotic operators).
    otensors = [
        onnx.helper.make_empty_tensor_value_info(f"%o{k}") for k in range(outputs)
    ]
    input_names = [t.name for t in itensors]
    output_names = [t.name for t in otensors]

    node = onnx.helper.make_node(operator, input_names, output_names, **attributes)
    graph = onnx.helper.make_graph([node], operator, itensors, otensors)
    model = onnx.helper.make_model(graph)

    # Since onnxruntime 1.9, builds with more than one execution provider
    # refuse to create a session unless `providers` is given explicitly.
    sess = onnxruntime.InferenceSession(
        model.SerializeToString(),
        providers=onnxruntime.get_available_providers(),
    )
    return sess.run(output_names, dict(zip(input_names, arrays)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment