Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Last active April 12, 2023 17:58
Show Gist options
  • Save thomasjpfan/5c4aff6d17f807801ddef1a73fe6c53b to your computer and use it in GitHub Desktop.
Save thomasjpfan/5c4aff6d17f807801ddef1a73fe6c53b to your computer and use it in GitHub Desktop.
Benchmark for early stopping in hist gradient boosting
"""Benchmark for early stopping with predefined metric strings.
```bash
python bench_hist_early_stopping.py --problem classification
python bench_hist_early_stopping.py --problem regression
```
"""
import argparse
from time import perf_counter
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_regression
from sklearn.datasets import make_classification
# Size of the synthetic dataset generated for the benchmark.
n_samples = 2_000_000
n_features = 20

# Command-line interface: one option selecting which kind of problem to run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--problem",
    choices=["classification", "regression"],
    default="classification",
    type=str,
)
args = parser.parse_args()
def get_estimator_and_data():
    """Build the synthetic dataset and pick the estimator for the benchmark.

    The problem type is read from the module-level ``args.problem``; the
    dataset size comes from the module-level ``n_samples`` and ``n_features``.

    Returns
    -------
    X, y
        Data generated by ``make_classification`` (3 classes, 4 informative
        features) or by ``make_regression``, both with ``random_state=0``.
    Estimator
        ``HistGradientBoostingClassifier`` or ``HistGradientBoostingRegressor``
        (the class itself, not an instance).
    scorer : str
        ``"accuracy"`` for classification, ``"neg_mean_absolute_error"`` for
        regression.

    Raises
    ------
    ValueError
        If ``args.problem`` is neither ``"classification"`` nor
        ``"regression"``.
    """
    if args.problem == "classification":
        X, y = make_classification(
            n_samples,
            n_features=n_features,
            n_classes=3,
            n_informative=4,
            random_state=0,
        )
        return X, y, HistGradientBoostingClassifier, "accuracy"

    if args.problem == "regression":
        X, y = make_regression(n_samples, n_features=n_features, random_state=0)
        return X, y, HistGradientBoostingRegressor, "neg_mean_absolute_error"

    # Fail loudly instead of implicitly returning None, which the caller would
    # only notice as a confusing tuple-unpacking TypeError.
    raise ValueError(f"Unknown problem: {args.problem!r}")
# Dataset, estimator class and scoring string for the selected problem.
X, y, Estimator, scorer = get_estimator_and_data()

# Early stopping is switched on, with 30% of the data requested as the
# validation fraction and the predefined metric string used for scoring.
est = Estimator(
    scoring=scorer,
    early_stopping=True,
    validation_fraction=0.3,
    random_state=0,
)

# Time only the call to fit().
start = perf_counter()
est.fit(X, y)
end = perf_counter()
elapsed = end - start

# Sanity check: the fitted estimator must expose `validation_score_`.
assert hasattr(est, "validation_score_")
print("Runtime:", elapsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment