Skip to content

Instantly share code, notes, and snippets.

@victornoel victornoel/
Created Feb 14, 2020

What would you like to do?
import os
import warnings
from typing import List, Tuple
import onnxruntime as rt
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from scipy.stats import randint as sp_randint, uniform as sp_uniform
from xgboost import XGBRegressor
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
import onnxmltools.convert.common.data_types as onnxtypes
from skl2onnx import update_registered_converter, convert_sklearn
from skl2onnx.common.shape_calculator import calculate_linear_regressor_output_shapes
# Ask interpreters started from this one (presumably e.g. parallel search
# workers) to ignore warnings: PYTHONWARNINGS is read at interpreter start-up,
# so it does not affect the current process.
os.environ.update(PYTHONWARNINGS='ignore')
# Within this process, drop FutureWarning noise from the ML libraries.
warnings.simplefilter('ignore', FutureWarning)
def test():
# End-to-end check: fill missing values, tune an XGBRegressor with
# RandomizedSearchCV, wrap the tuned regressor together with the
# ColumnTransformer from _column_tranformer_fitted_from_df in a Pipeline,
# convert it to ONNX, and compare sklearn vs onnxruntime predictions on the
# held-out split.
# NOTE(review): this copy is damaged - indentation is lost and several lines
# are missing (CSV source, param_distributions braces, the search.fit(...)
# call, the Pipeline close / model.fit, most of the converter registration,
# the convert_sklearn(...) arguments). Restore them before running.
data = pd.read_csv((''
# NOTE(review): CSV path/URL is truncated above; the column names below look
# like the Titanic dataset - confirm the source.
#data['female'] = data['sex'] == 'female'
#data.drop('sex', axis=1, inplace=True)
#data = data[['age', 'fare', 'female', 'embarked', 'pclass', 'survived']]
# Keep only the feature columns used by the model plus the target 'survived'.
data = data[['age', 'fare', 'sex', 'embarked', 'pclass', 'survived']]
# Fill missing values per dtype: floats -> 0., int64 -> 0, object -> 'N/A'.
# NOTE(review): `data[col].fillna(..., inplace=True)` is chained assignment;
# under pandas Copy-on-Write it may not update `data` at all - prefer
# `data[col] = data[col].fillna(...)`.
for col in data:
dtype = data[col].dtype
if dtype in ['float64', 'float32']:
data[col].fillna(0., inplace=True)
if dtype in ['int64']:
data[col].fillna(0, inplace=True)
elif dtype in ['O']:
data[col].fillna('N/A', inplace=True)
# Features are every column except the target.
full_df = data.drop('survived', axis=1)
full_labels = data['survived']
# 80/20 split. NOTE(review): no random_state, so results vary run to run.
train_df, test_df, train_labels, test_labels = train_test_split(
full_df, full_labels, test_size=.2
# The preprocessor is built from the full feature frame (before the split).
col_transformer = _column_tranformer_fitted_from_df(full_df)
# Randomized hyper-parameter search over XGBoost settings (10 draws, 5-fold CV).
# NOTE(review): the `{` / `}` around the distributions and the start of the
# fit call on the following lines appear to have been lost.
# NOTE(review): XGBoost's sklearn API names this parameter `verbosity`;
# `verbose=0` is presumably forwarded as an extra kwarg - confirm.
search = RandomizedSearchCV(
XGBRegressor(verbose=0, objective='reg:squarederror'),
"colsample_bytree": sp_uniform(),
"gamma": sp_uniform(.1, 1),
'learning_rate': sp_uniform(.1, .6),
'max_depth': sp_randint(10, 30),
'min_child_weight': sp_uniform(0, 3),
'n_estimators': range(10, 75),
cv=5, n_iter=10, n_jobs=-1
), train_labels)
# Rebuild a fresh regressor with the best parameters found by the search.
regressor = XGBRegressor(verbose=0, objective='reg:squarederror', **search.best_params_)
# Preprocessing + regressor as one sklearn Pipeline (this is what gets converted).
model = Pipeline(
steps=[('preprocessor', col_transformer),
('regressor', regressor)]
# NOTE(review): the line below looks like the surviving fragment of an
# update_registered_converter(...) call registering the XGBoost converter.
XGBRegressor, 'XGBRegressor',
# NOTE(review): convert_sklearn arguments are missing (presumably the model
# and the schema from _convert_dataframe_schema).
onnx = convert_sklearn(
session = rt.InferenceSession(onnx.SerializeToString())
# Predict the held-out rows with both runtimes.
pred_skl = model.predict(test_df)
pred_onx = _predict(session, test_df)
# Five largest absolute differences between the two predictions (ascending
# sort, last 5). NOTE(review): `diff` is not used in the visible code.
diff = np.sort(np.abs(np.squeeze(pred_skl) - np.squeeze(pred_onx)))[-5:]
# Label range, printed to give a scale for the differences.
print('min(Y)-max(Y):', min(test_labels), max(test_labels))
def _column_tranformer_fitted_from_df(data: pd.DataFrame) -> ColumnTransformer:
    """Build a per-column preprocessing ColumnTransformer and fit it on *data*.

    Every column of *data* gets its own transformer, chosen from its dtype:

    * float64 / float32 / int64 -> ``MinMaxScaler``
    * bool                      -> passed through unchanged
    * object                    -> ``OneHotEncoder(drop='first')``

    Args:
        data: Frame whose columns define the transformer and are used to fit it.

    Returns:
        The ColumnTransformer, already fitted on *data*.

    Raises:
        ValueError: If a column has a dtype not listed above.
    """

    def transformer_for_column(column: pd.Series):
        if column.dtype in ['float64', 'float32', 'int64']:
            return MinMaxScaler()
        if column.dtype in ['bool']:
            return 'passthrough'
        if column.dtype in ['O']:
            # drop='first' drops the first category of the feature, so n
            # categories produce n - 1 indicator columns.
            return OneHotEncoder(drop='first')
        # The original f-string had an empty `{}` placeholder (a SyntaxError);
        # report the offending column's name.
        raise ValueError(f'Unexpected column dtype for {column.name}:{column.dtype}')

    # One (name, transformer, [column]) entry per column, named after the column.
    transformers = [
        (col, transformer_for_column(data[col]), [col]) for col in data.columns
    ]
    return ColumnTransformer(transformers).fit(data)
def _convert_dataframe_schema(data: pd.DataFrame) -> List[Tuple[str, onnxtypes.DataType]]:
    """Describe the columns of *data* as ONNX graph inputs, one per column.

    Each column becomes an input named after the column, with shape
    ``[None, 1]`` (any number of rows, one value per row). The tensor type
    is chosen from the column dtype:

    * float64 / float32 -> ``FloatTensorType``
    * int64             -> ``Int64TensorType``
    * bool              -> ``BooleanTensorType``
    * object            -> ``StringTensorType``

    Args:
        data: Frame whose column names and dtypes define the schema.

    Returns:
        ``(column_name, tensor_type)`` pairs in column order.

    Raises:
        ValueError: If a column has any other dtype.
    """

    def type_for_column(column: pd.Series):
        if column.dtype in ['float64', 'float32']:
            # onnx does not really support float64 (DoubleTensorType does not
            # work with TreeEnsembleRegressor), so float64 columns are declared
            # as float32; values fed at inference time must be cast to match.
            return onnxtypes.FloatTensorType([None, 1])
        if column.dtype in ['int64']:
            return onnxtypes.Int64TensorType([None, 1])
        if column.dtype in ['bool']:
            return onnxtypes.BooleanTensorType([None, 1])
        if column.dtype in ['O']:
            return onnxtypes.StringTensorType([None, 1])
        # The original f-string had an empty `{}` placeholder (a SyntaxError);
        # report the offending column's name.
        raise ValueError(f'Unexpected column dtype for {column.name}:{column.dtype}')

    return [(col, type_for_column(data[col])) for col in data.columns]
def _predict(session: rt.InferenceSession, data: pd.DataFrame) -> pd.Series:
    """Score *data* with an ONNX Runtime session and return the first output.

    Every column of *data* is fed as its own graph input, keyed by the column
    name, as an ``(n_rows, 1)`` array. float64 columns are cast to float32
    first. This assumes the model's float inputs were declared as 32-bit
    ``FloatTensorType``, as ``_convert_dataframe_schema`` does.

    Args:
        session: Inference session whose input names match the column names
            of *data*.
        data: Rows to score.

    Returns:
        The first session output flattened to 1-D, indexed like *data*.
    """

    def _correctly_typed_column(column: pd.Series) -> pd.Series:
        # Only float64 needs converting; every other dtype is passed as-is.
        if column.dtype in ['float64']:
            return column.astype(np.float32)
        return column

    def _correctly_shaped_values(values):
        # (n,) -> (n, 1): one value per row, matching the [None, 1] inputs.
        return values.reshape((values.shape[0], 1))

    inputs = {
        c: _correctly_shaped_values(_correctly_typed_column(data[c]).values)
        for c in data.columns
    }
    # Passing None as the output names makes onnxruntime return every output;
    # the first one is taken as the prediction.
    outputs = session.run(None, inputs)
    return pd.Series(outputs[0].reshape(-1), index=data.index)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.