Last active
March 22, 2022 13:16
-
-
Save oliver-batey/6f3ea8c77792a18490702476374a9cad to your computer and use it in GitHub Desktop.
Example of a grid search in scikit-learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.model_selection import GridSearchCV
# params is a list of parameter grids. Each grid is a dictionary: the keys
# name a hyperparameter (nested step names are joined with double
# underscores) and the values are lists of candidate settings to search over.
params = [
    {
        "transform__txt__max_features": [None, 100, 10],
        "transform__num__selector__attribute_names": [
            ["n_words"],
            ["mean_word_length"],
            ["n_words", "mean_word_length"],
        ],
    }
]
# Run a 5-fold cross-validated search over every combination in params,
# scoring each candidate by balanced accuracy.
# With an integer cv, GridSearchCV uses stratified folds when the estimator
# is a classifier (presumably the case here, given the scoring metric), and
# with the default refit=True it retrains the pipeline on the full training
# set using the best combination of hyperparameters found.
model = GridSearchCV(
    estimator=pipeline,
    param_grid=params,
    scoring="balanced_accuracy",
    cv=5,
)
model.fit(X_train, y_train)
# Display all of the results of the grid search. cv_results_ is a dictionary
# mapping a result name to an array with one entry per candidate.
print(model.cv_results_)
# Display the mean cross-validated score for each combination of
# hyperparameters. cv_results_ is keyed by name, so look the scores up
# directly; indexing it with the score array itself raises a TypeError.
print(model.cv_results_["mean_test_score"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment