-
-
Save amankharwal/d36f936b74d19ab6f1a19c947c1676e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, accuracy_score
from sklearn.model_selection import GridSearchCV

# Tune a random-forest classifier with an exhaustive grid search and keep the
# best configuration in `clf`.
# NOTE: `X_train` and `y_train` are not defined in this snippet; they must
# already exist (training features and labels) before this code runs.

# Base estimator whose hyper-parameters are searched below.
clf = RandomForestClassifier()

# Candidate values for each hyper-parameter; every combination is evaluated.
# 'auto' is intentionally absent from max_features: for classifiers it was
# only an alias of 'sqrt', was deprecated in scikit-learn 1.1 and removed in
# 1.3, where it makes every fit that uses it fail.
parameters = {
    'n_estimators': [4, 6, 9],
    'max_features': ['log2', 'sqrt'],
    'criterion': ['entropy', 'gini'],
    'max_depth': [2, 3, 5, 10],
    'min_samples_split': [2, 3, 5],
    'min_samples_leaf': [1, 5, 8],
}

# Plain accuracy is the metric used to rank parameter combinations.
acc_scorer = make_scorer(accuracy_score)

# Run the cross-validated grid search on the training data.
grid_obj = GridSearchCV(clf, parameters, scoring=acc_scorer)
grid_obj = grid_obj.fit(X_train, y_train)

# GridSearchCV defaults to refit=True, so best_estimator_ has already been
# re-trained on the full training set with the winning parameters. Calling
# clf.fit(X_train, y_train) again would only repeat that work (and, with an
# unseeded forest, swap in a different random model).
clf = grid_obj.best_estimator_
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment