Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Last active April 13, 2020 13:38
Show Gist options
  • Save thomasjpfan/c56867094d6db9ed79dcf32b34679399 to your computer and use it in GitHub Desktop.
Save thomasjpfan/c56867094d6db9ed79dcf32b34679399 to your computer and use it in GitHub Desktop.
Benchmarking native categorical support
from sklearn.model_selection import cross_validate
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
# Load the OpenML "adult" dataset (data_id=179) as a pandas DataFrame/Series.
X, y = fetch_openml(data_id=179, as_frame=True, return_X_y=True)

# Encoding a categorical target directly is not supported yet (per the
# original note), so replace the target categories with their integer codes.
y = y.cat.codes

# Summarize the feature types. Numerical features are counted with
# ``select_dtypes('number')`` so integer-typed numeric columns are included,
# not only float64 ones.
n_features = X.shape[1]
n_categorical_features = (X.dtypes == 'category').sum()
n_numerical_features = X.select_dtypes(include='number').shape[1]
print(f"Number of features: {n_features}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")
# Baseline model: fill missing categories with the constant 'missing',
# one-hot encode every column of ``category`` dtype, pass every other
# column through unchanged, then fit the gradient-boosting classifier.
cat_pipe = make_pipeline(
    SimpleImputer(strategy='constant', fill_value='missing'),
    OneHotEncoder(sparse=False, handle_unknown='ignore'),
)
cat_selector = make_column_selector(dtype_include='category')
preprocessor = make_column_transformer(
    (cat_pipe, cat_selector),
    remainder='passthrough',
)
hist_one_hot = make_pipeline(
    preprocessor,
    HistGradientBoostingClassifier(random_state=0),
)

# Candidate model: let the estimator consume the pandas categoricals itself.
# NOTE(review): ``categorical='pandas'`` appears to target a development
# branch of scikit-learn; released versions expose ``categorical_features``
# instead -- confirm against the environment this benchmark runs in.
hist_native = HistGradientBoostingClassifier(
    categorical='pandas',
    random_state=0,
)

# Cross-validate both models with cross_validate's defaults. The one-hot
# model runs first, then the native model.
one_hot_result, native_result = (
    cross_validate(model, X, y) for model in (hist_one_hot, hist_native)
)
# Draw one bar chart per metric. Each bar is a model's mean over the CV
# splits, and the error bar is the standard deviation over those splits.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8))
plot_info = [('fit_time', 'Fit times (s)', ax1),
             ('score_time', 'Score times (s)', ax2),
             ('test_score', 'Test Scores (accuracy)', ax3)]

# Keep the tick labels in the same order as ``items`` below, and derive the
# bar positions from them so ticks and bars cannot drift apart.
model_labels = ['Native', "One Hot"]
x, width = np.arange(len(model_labels)), 0.9
for key, title, ax in plot_info:
    items = [native_result[key], one_hot_result[key]]
    ax.bar(x, [np.mean(item) for item in items],
           width,
           yerr=[np.std(item) for item in items],
           color=['b', 'r'])
    # The x-axis is the model, not the split: the splits are aggregated
    # into each bar.
    ax.set(xlabel='Model',
           title=title,
           xticks=x,
           xticklabels=model_labels)
fig.suptitle("Adult dataset")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment