@calebmadrigal
Last active April 2, 2019 20:29
sklearn to onnx
Prediction for [1, 1, 1, 1]: [[0.16937022 0.83062978]]
Prediction for [1, 2, 3, 4]: [[0.25071094 0.74928906]]
The maximum opset needed by this model is only 6.
The maximum opset needed by this model is only 1.
INPUT NAME: float_input
OUTPUTS: ['output_label', 'output_probability']
Prediction with ONNX for [1. 1. 1. 1.]: [[{0: 0.16937017440795898, 1: 0.830629825592041}]]
Prediction with ONNX for [1. 2. 3. 4.]: [[{0: 0.2507109045982361, 1: 0.7492890954017639}]]
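The ONNX probabilities and the scikit-learn probabilities come back in different forms. skl2onnx appends a ZipMap node by default, so ONNX returns a list of `{class: probability}` dictionaries. `predict_proba` returns a 2-D array. The ONNX model also computes in float32, so the values agree only to about 1e-7. The sketch below compares the two using the values printed above. The helper name `zipmap_to_array` is hypothetical and not part of either library:

```python
import numpy as np


def zipmap_to_array(zipmap_rows, classes):
    """Turn ZipMap-style output [{class: prob, ...}, ...] into a 2-D array.

    Columns follow the order given in `classes`, matching predict_proba.
    """
    return np.array([[row[c] for c in classes] for row in zipmap_rows])


# Values copied from the printed output above.
sklearn_proba = np.array([[0.16937022, 0.83062978],
                          [0.25071094, 0.74928906]])
onnx_rows = [{0: 0.16937017440795898, 1: 0.830629825592041},
             {0: 0.2507109045982361, 1: 0.7492890954017639}]

onnx_proba = zipmap_to_array(onnx_rows, classes=[0, 1])

# float32 inference is not bit-identical to float64, so compare with a tolerance.
np.testing.assert_allclose(onnx_proba, sklearn_proba, atol=1e-6)
print('max abs difference:', np.abs(onnx_proba - sklearn_proba).max())
```

In the script, `pred_onnx[0]` is the list of row dictionaries returned by one `sess.run` call. It could be passed as `zipmap_to_array(pred_onnx[0], clf.classes_)`. Another option is to convert with `options={id(clf): {'zipmap': False}}`, which makes ONNX return a plain probability array instead of dictionaries.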
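from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Synthetic binary classification data: 1000 samples, 4 features.
X, y = make_classification(n_samples=1000, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
# The default n_estimators changed from 10 to 100 in scikit-learn 0.22,
# so the exact probabilities depend on the installed version.
clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

# Class probabilities from scikit-learn, for comparison with ONNX below.
tests = [[1, 1, 1, 1], [1, 2, 3, 4]]
for t in tests:
    result = clf.predict_proba([t])
    print('Prediction for {}: {}'.format(t, result))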
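
# ONNX conversion
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# One input named 'float_input': a float32 batch of exactly 1 row with 4 features.
# (FloatTensorType([None, 4]) would accept any batch size.)
initial_type = [('float_input', FloatTensorType([1, 4]))]
onx = convert_sklearn(clf, initial_types=initial_type)

# Despite the file name, this model is trained on synthetic data, not iris.
with open("iris_clf.onnx", "wb") as f:
    f.write(onx.SerializeToString())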

# Run via ONNX Runtime
import numpy as np
import onnxruntime as rt

# Listing providers explicitly is required since onnxruntime 1.9
# on builds with more than one execution provider.
sess = rt.InferenceSession("iris_clf.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
print('INPUT NAME: {}'.format(input_name))
outputs = sess.get_outputs()
output_names = [o.name for o in outputs]
print('OUTPUTS: {}'.format(output_names))
prob_name = output_names[1]  # 'output_probability'

for t in tests:
    # The input was declared as float32 with shape (1, 4), so wrap the row in a batch.
    x_input = np.asarray([t], dtype=np.float32)
    pred_onnx = sess.run([prob_name], {input_name: x_input})
    print('Prediction with ONNX for {}: {}'.format(x_input[0], pred_onnx))
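The script fetches only `output_probability`. The other output, `output_label`, should hold the class with the highest probability in each row. scikit-learn's forest `predict` picks a label from `predict_proba` the same way. The sketch below shows that relationship using the probabilities printed above. `labels_from_proba` is a hypothetical helper:

```python
import numpy as np


def labels_from_proba(proba, classes):
    """For each row, return the class whose probability is largest."""
    proba = np.asarray(proba)
    return np.asarray(classes)[np.argmax(proba, axis=1)]


# predict_proba values printed above for the two test rows.
proba = np.array([[0.16937022, 0.83062978],
                  [0.25071094, 0.74928906]])
labels = labels_from_proba(proba, classes=[0, 1])  # classes would be clf.classes_
print(labels)  # [1 1]
```

With the real session, `sess.run(None, {input_name: x_input})` returns all outputs in the order of `output_names`. Its first element could therefore be checked against `clf.predict`.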