Created
September 9, 2022 13:51
-
-
Save callmemaze/fd39a691d46b74020fcf93fe66a2aa72 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import joblib\n", | |
"import pandas as pd\n", | |
"from sklearn.ensemble import GradientBoostingClassifier\n", | |
"from sklearn.model_selection import GridSearchCV\n", | |
"import warnings\n", | |
"warnings.filterwarnings('ignore', category=FutureWarning)\n", | |
"warnings.filterwarnings('ignore', category=DeprecationWarning)\n", | |
"\n", | |
"features_csv = '../../../train_features.csv'\n", | |
"labels_csv = '../../../train_labels.csv'\n", | |
"\n", | |
"# The labels file is read without a header row (header=None).\n", | |
"tr_features = pd.read_csv(features_csv)\n", | |
"tr_labels = pd.read_csv(labels_csv, header=None)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Hyperparameter tuning\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def print_results(results):\n", | |
"    \"\"\"Print the best hyperparameters and the score of every candidate.\n", | |
"\n", | |
"    Args:\n", | |
"        results: object exposing `best_params_` and a `cv_results_` mapping with\n", | |
"            'mean_test_score', 'std_test_score' and 'params' entries (this\n", | |
"            notebook passes its fitted GridSearchCV).\n", | |
"\n", | |
"    Returns:\n", | |
"        None. Everything is written to stdout.\n", | |
"    \"\"\"\n", | |
"    print('BEST PARAMS: {}\\n'.format(results.best_params_))\n", | |
"\n", | |
"    means = results.cv_results_['mean_test_score']\n", | |
"    stds = results.cv_results_['std_test_score']\n", | |
"    for mean, std, params in zip(means, stds, results.cv_results_['params']):\n", | |
"        # The spread is shown as two standard deviations, rounded to 3 places.\n", | |
"        print('{} (+/-{}) for {}'.format(round(mean, 3), round(std * 2, 3), params))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"BEST PARAMS: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 500}\n", | |
"\n", | |
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 5}\n", | |
"0.796 (+/-0.116) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 50}\n", | |
"0.796 (+/-0.116) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 250}\n", | |
"0.811 (+/-0.118) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 500}\n", | |
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 5}\n", | |
"0.811 (+/-0.071) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 50}\n", | |
"0.83 (+/-0.076) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 250}\n", | |
"0.841 (+/-0.079) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 500}\n", | |
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 5}\n", | |
"0.818 (+/-0.051) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 50}\n", | |
"0.82 (+/-0.039) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 250}\n", | |
"0.83 (+/-0.044) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 500}\n", | |
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 5}\n", | |
"0.818 (+/-0.054) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 50}\n", | |
"0.822 (+/-0.041) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 250}\n", | |
"0.801 (+/-0.023) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 500}\n", | |
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 5}\n", | |
"0.801 (+/-0.055) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 50}\n", | |
"0.801 (+/-0.024) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 250}\n", | |
"0.783 (+/-0.026) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 500}\n", | |
"0.796 (+/-0.116) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 5}\n", | |
"0.815 (+/-0.12) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 50}\n", | |
"0.818 (+/-0.112) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 250}\n", | |
"0.828 (+/-0.093) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 500}\n", | |
"0.813 (+/-0.073) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 5}\n", | |
"0.835 (+/-0.082) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 50}\n", | |
"0.831 (+/-0.038) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 250}\n", | |
"0.811 (+/-0.03) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 500}\n", | |
"0.815 (+/-0.053) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 5}\n", | |
"0.826 (+/-0.018) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 50}\n", | |
"0.803 (+/-0.048) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 250}\n", | |
"0.807 (+/-0.053) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 500}\n", | |
"0.822 (+/-0.056) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 5}\n", | |
"0.8 (+/-0.015) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 50}\n", | |
"0.794 (+/-0.044) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 250}\n", | |
"0.801 (+/-0.066) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 500}\n", | |
"0.8 (+/-0.042) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 5}\n", | |
"0.788 (+/-0.041) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 50}\n", | |
"0.79 (+/-0.024) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 250}\n", | |
"0.79 (+/-0.049) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 500}\n", | |
"0.818 (+/-0.1) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 5}\n", | |
"0.83 (+/-0.078) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 50}\n", | |
"0.828 (+/-0.069) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 250}\n", | |
"0.818 (+/-0.082) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 500}\n", | |
"0.82 (+/-0.063) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 5}\n", | |
"0.794 (+/-0.038) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 50}\n", | |
"0.796 (+/-0.039) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 250}\n", | |
"0.801 (+/-0.048) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 500}\n", | |
"0.805 (+/-0.042) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 5}\n", | |
"0.811 (+/-0.078) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 50}\n", | |
"0.809 (+/-0.074) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 250}\n", | |
"0.803 (+/-0.081) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 500}\n", | |
"0.783 (+/-0.013) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 5}\n", | |
"0.787 (+/-0.051) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 50}\n", | |
"0.796 (+/-0.032) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 250}\n", | |
"0.79 (+/-0.052) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 500}\n", | |
"0.785 (+/-0.034) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 5}\n", | |
"0.77 (+/-0.033) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 50}\n", | |
"0.8 (+/-0.053) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 250}\n", | |
"0.801 (+/-0.043) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 500}\n", | |
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 5}\n", | |
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 50}\n", | |
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 250}\n", | |
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 500}\n", | |
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 5}\n", | |
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 50}\n", | |
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 250}\n", | |
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 500}\n", | |
"0.552 (+/-0.363) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 5}\n", | |
"0.444 (+/-0.316) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 50}\n", | |
"0.397 (+/-0.204) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 250}\n", | |
"0.397 (+/-0.197) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 500}\n", | |
"0.599 (+/-0.171) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 5}\n", | |
"0.588 (+/-0.188) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 50}\n", | |
"0.616 (+/-0.133) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 250}\n", | |
"0.618 (+/-0.159) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 500}\n", | |
"0.7 (+/-0.124) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 5}\n", | |
"0.717 (+/-0.128) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 50}\n", | |
"0.7 (+/-0.124) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 250}\n", | |
"0.704 (+/-0.11) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 500}\n", | |
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 5}\n", | |
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 50}\n", | |
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 250}\n", | |
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 500}\n", | |
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 5}\n", | |
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 50}\n", | |
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 250}\n", | |
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 500}\n", | |
"0.373 (+/-0.181) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 5}\n", | |
"0.375 (+/-0.174) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 50}\n", | |
"0.369 (+/-0.176) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 250}\n", | |
"0.375 (+/-0.173) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 500}\n", | |
"0.551 (+/-0.126) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 5}\n", | |
"0.547 (+/-0.13) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 50}\n", | |
"0.584 (+/-0.117) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 250}\n", | |
"0.562 (+/-0.135) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 500}\n", | |
"0.635 (+/-0.063) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 5}\n", | |
"0.674 (+/-0.079) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 50}\n", | |
"0.652 (+/-0.061) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 250}\n", | |
"0.663 (+/-0.108) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 500}\n" | |
] | |
} | |
], | |
"source": [ | |
"# Seed the estimator so repeated runs of the search are comparable.\n", | |
"gb = GradientBoostingClassifier(random_state=42)\n", | |
"parameters = {\n", | |
"    'n_estimators': [5, 50, 250, 500],\n", | |
"    'max_depth': [1, 3, 5, 7, 9],\n", | |
"    'learning_rate': [0.01, 0.1, 1, 10, 100]\n", | |
"}\n", | |
"\n", | |
"# 4 x 5 x 5 = 100 candidates, each scored on 5 folds (500 fits);\n", | |
"# n_jobs=-1 runs those fits in parallel on all available cores.\n", | |
"cv = GridSearchCV(gb, parameters, cv=5, n_jobs=-1)\n", | |
"cv.fit(tr_features, tr_labels.values.ravel())\n", | |
"\n", | |
"print_results(cv)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Write out pickled model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['../../../GB_model.pkl']" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Persist the best estimator found by the grid search; the value returned by\n", | |
"# joblib.dump is left as the cell's displayed result.\n", | |
"model_path = '../../../GB_model.pkl'\n", | |
"joblib.dump(cv.best_estimator_, model_path)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment