Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Last active August 10, 2020 14:53
Show Gist options
  • Save ogrisel/7b3c9e22c83330ed67abb5e62000fa6a to your computer and use it in GitHub Desktop.
Save ogrisel/7b3c9e22c83330ed67abb5e62000fa6a to your computer and use it in GitHub Desktop.
from sklearn.model_selection import cross_validate
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.impute import SimpleImputer
# Load the Adult dataset from OpenML as a DataFrame; nominal columns are
# later routed by their pandas ``category`` dtype.
X, y = fetch_openml(data_id=179, as_frame=True, return_X_y=True)
# Replace the categorical target by its integer codes: per the original
# author, a pandas categorical ``y`` was not supported at the time.
y = y.cat.codes
n_features = X.shape[1]
n_categorical_features = (X.dtypes == 'category').sum()
# Count every numeric dtype (ints and floats of any width), not only float64.
n_numerical_features = X.select_dtypes(include='number').shape[1]
print(f"Number of features: {n_features}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")
# --- One-hot encoding baseline ---------------------------------------------
# OneHotEncoder's dense-output flag was renamed from ``sparse`` to
# ``sparse_output`` in scikit-learn 1.2, and the old name was removed in 1.4.
# Try the new spelling first and fall back to the old one for older releases.
# Dense output is required because HistGradientBoostingClassifier does not
# accept sparse input.
try:
    ohe_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
except TypeError:
    ohe_encoder = OneHotEncoder(sparse=False, handle_unknown='ignore')
ohe_pipe = make_pipeline(
    SimpleImputer(strategy='constant', fill_value='missing'),
    ohe_encoder)
ohe_preprocessor = make_column_transformer(
    (ohe_pipe, make_column_selector(dtype_include='category')),
    remainder='passthrough')

# --- Ordinal encoding baseline ---------------------------------------------
cat_columns = make_column_selector(dtype_include='category')(X)
# Take the levels from the categorical dtype rather than ``unique()``.
# ``unique()`` follows order of appearance and includes NaN for columns with
# missing values, yet NaN never reaches the encoder because the imputer
# replaces it with "missing". That placeholder is appended explicitly so
# OrdinalEncoder never sees an unknown value.
categories = [
    X[column].cat.categories.tolist() + ["missing"]
    for column in cat_columns
]
oe_pipe = make_pipeline(
    SimpleImputer(strategy='constant', fill_value='missing'),
    OrdinalEncoder(categories=categories))
oe_preprocessor = make_column_transformer(
    (oe_pipe, cat_columns),
    remainder='passthrough')
# Three HistGradientBoostingClassifier variants with the same seed:
# one-hot encoded input, ordinal encoded input, and native categorical support.
hist_one_hot = make_pipeline(ohe_preprocessor,
                             HistGradientBoostingClassifier(random_state=0))
hist_oe_hot = make_pipeline(oe_preprocessor,
                            HistGradientBoostingClassifier(random_state=0))
# ``categorical_features="pandas"`` was a prototype spelling. Released
# scikit-learn accepts an array-like mask/indices/names, None, or (since 1.4)
# ``"from_dtype"``, which treats pandas ``category`` columns as categorical.
hist_native = HistGradientBoostingClassifier(categorical_features="from_dtype",
                                             random_state=0)
# Each result dict holds per-split ``fit_time``, ``score_time`` and
# ``test_score`` arrays (default CV splitter and scoring).
one_hot_result = cross_validate(hist_one_hot, X, y)
oe_hot_result = cross_validate(hist_oe_hot, X, y)
native_result = cross_validate(hist_native, X, y)
# Compare the three encodings side by side. Each panel shows the mean +/- std
# across CV splits for one metric, and the same numbers are printed.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8))
plot_info = [('fit_time', 'Fit times (s)', ax1),
             ('score_time', 'Score times (s)', ax2),
             ('test_score', 'Test Scores (accuracy)', ax3)]
# Bar order and labels are identical for every panel, so build them once.
labels = ['Native', "Ordinal", "One Hot"]
x, width = np.arange(len(labels)), 0.9
for key, title, ax in plot_info:
    items = [native_result[key], oe_hot_result[key], one_hot_result[key]]
    means = [np.mean(item) for item in items]
    stds = [np.std(item) for item in items]
    for label, mean, std in zip(labels, means, stds):
        print(f"{label}, {key}: {mean:.3f} +/- {std:.3f}")
    ax.bar(x, means, width, yerr=stds)
    # Each bar is one encoding strategy aggregated over all splits, so the
    # x-axis is the strategy, not a split number.
    ax.set(xlabel='Encoding strategy',
           title=title,
           xticks=x,
           xticklabels=labels)
fig.suptitle("Adult dataset")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment