Skip to content

Instantly share code, notes, and snippets.

@marskar
Last active November 18, 2019 03:12
Show Gist options
Save marskar/f9ede6b49ec19219ee9b0a8e9351ccd2 to your computer and use it in GitHub Desktop.
Two-hyperparameter, cross-validated grid search
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, GridSearchCV
from sklearn.ensemble import RandomForestRegressor

# Load the Boston housing regression data into a feature matrix X (13 columns)
# and a target vector y.
try:
    # load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so the
    # import itself fails on current releases.
    from sklearn.datasets import load_boston
except ImportError:
    load_boston = None

if load_boston is not None:
    boston = load_boston()
    X = boston.data
    y = boston.target
else:
    # Fallback suggested by scikit-learn's removal notice: read the original
    # StatLib file, where each record is wrapped over two physical lines
    # (11 values, then the last 2 features plus the target).
    # NOTE: requires network access to lib.stat.cmu.edu.
    raw_df = pd.read_csv("http://lib.stat.cmu.edu/datasets/boston",
                         sep=r"\s+", skiprows=22, header=None)
    X = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
    y = raw_df.values[1::2, 2]
# 5-fold cross-validation with shuffling; the fixed seed makes the fold
# assignment reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Random forest regressor. It is seeded as well: an unseeded forest would make
# the cross-validated scores change on every run even with fixed folds.
rf = RandomForestRegressor(random_state=42)
# Exhaustive search over 19 depths x 18 forest sizes = 342 candidates,
# each evaluated on all 5 folds, in parallel on all cores (n_jobs=-1).
param_grid = {"max_depth": range(1, 20), "n_estimators": range(2, 20)}
gscv = GridSearchCV(rf, param_grid, cv=kf, n_jobs=-1)
gscv.fit(X, y)
# One row per candidate. No `scoring` is passed, so mean_test_score is the
# mean held-out R^2 returned by RandomForestRegressor.score.
cv_df = pd.DataFrame(gscv.cv_results_)
# Reshape into a max_depth (rows) x n_estimators (columns) table of mean scores.
pivoted_df = cv_df.pivot(index="param_max_depth",
                         columns="param_n_estimators",
                         values="mean_test_score").round(3)
# Heatmap-style table. This evaluates to a pandas Styler, which only renders
# when displayed (e.g. as the last expression of a notebook cell).
pivoted_df.style.background_gradient(cmap="nipy_spectral", axis=None)
# Or plot the same table as an annotated seaborn heatmap.
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(20, 15))
# fmt=".3f" matches the .round(3) applied to pivoted_df; seaborn's default
# annotation format (".2g") would show only two significant digits.
sns.heatmap(pivoted_df,
            cmap="nipy_spectral",
            annot=True,
            fmt=".3f",
            annot_kws={"size": 16},
            ax=ax)
ax.set_xlabel('Trees (n)', size=18)
ax.set_ylabel('Levels (max)', size=18)
ax.tick_params(axis="x", labelsize=14)
# Keep the depth labels horizontal.
ax.tick_params(axis="y", labelrotation=0, labelsize=14)
# Needed to display the figure when run as a plain script; harmless inline.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment