Skip to content

Instantly share code, notes, and snippets.

@callmemaze
Created September 9, 2022 13:51
Show Gist options
  • Save callmemaze/fd39a691d46b74020fcf93fe66a2aa72 to your computer and use it in GitHub Desktop.
Save callmemaze/fd39a691d46b74020fcf93fe66a2aa72 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Standard library\n",
"import warnings\n",
"\n",
"# Third-party: model persistence, data frames, model and search utilities\n",
"import joblib\n",
"import pandas as pd\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Ignore FutureWarning and DeprecationWarning for the rest of the session.\n",
"for _category in (FutureWarning, DeprecationWarning):\n",
"    warnings.filterwarnings('ignore', category=_category)\n",
"\n",
"# Training data: the labels file is parsed with no header row (header=None).\n",
"tr_features = pd.read_csv('../../../train_features.csv')\n",
"tr_labels = pd.read_csv(\n",
"    '../../../train_labels.csv',\n",
"    header=None,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
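"### Hyperparameter tuning\n",
"\n",
"Run a 5-fold cross-validated grid search over `n_estimators`, `max_depth` and `learning_rate` for a `GradientBoostingClassifier`, then print the mean validation score (+/- two standard deviations) of every parameter combination."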
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def print_results(results):\n",
"    \"\"\"Print the best parameter combination and the CV score of every grid point.\n",
"\n",
"    results: a fitted search object such as GridSearchCV, exposing\n",
"    `best_params_` and `cv_results_` (with the keys 'mean_test_score',\n",
"    'std_test_score' and 'params').\n",
"\n",
"    Each grid line shows the mean test score and +/- two standard deviations\n",
"    of the test score, both rounded to 3 decimals, followed by the parameters.\n",
"    \"\"\"\n",
"    print('BEST PARAMS: {}\\n'.format(results.best_params_))\n",
"\n",
"    means = results.cv_results_['mean_test_score']\n",
"    stds = results.cv_results_['std_test_score']\n",
"    # The loop body must be nested under the `for`; at equal indentation the cell fails to parse.\n",
"    for mean, std, params in zip(means, stds, results.cv_results_['params']):\n",
"        print('{} (+/-{}) for {}'.format(round(mean, 3), round(std * 2, 3), params))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEST PARAMS: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 500}\n",
"\n",
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 5}\n",
"0.796 (+/-0.116) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 50}\n",
"0.796 (+/-0.116) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 250}\n",
"0.811 (+/-0.118) for {'learning_rate': 0.01, 'max_depth': 1, 'n_estimators': 500}\n",
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 5}\n",
"0.811 (+/-0.071) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 50}\n",
"0.83 (+/-0.076) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 250}\n",
"0.841 (+/-0.079) for {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 500}\n",
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 5}\n",
"0.818 (+/-0.051) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 50}\n",
"0.82 (+/-0.039) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 250}\n",
"0.83 (+/-0.044) for {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 500}\n",
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 5}\n",
"0.818 (+/-0.054) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 50}\n",
"0.822 (+/-0.041) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 250}\n",
"0.801 (+/-0.023) for {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 500}\n",
"0.624 (+/-0.005) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 5}\n",
"0.801 (+/-0.055) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 50}\n",
"0.801 (+/-0.024) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 250}\n",
"0.783 (+/-0.026) for {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 500}\n",
"0.796 (+/-0.116) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 5}\n",
"0.815 (+/-0.12) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 50}\n",
"0.818 (+/-0.112) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 250}\n",
"0.828 (+/-0.093) for {'learning_rate': 0.1, 'max_depth': 1, 'n_estimators': 500}\n",
"0.813 (+/-0.073) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 5}\n",
"0.835 (+/-0.082) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 50}\n",
"0.831 (+/-0.038) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 250}\n",
"0.811 (+/-0.03) for {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 500}\n",
"0.815 (+/-0.053) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 5}\n",
"0.826 (+/-0.018) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 50}\n",
"0.803 (+/-0.048) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 250}\n",
"0.807 (+/-0.053) for {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 500}\n",
"0.822 (+/-0.056) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 5}\n",
"0.8 (+/-0.015) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 50}\n",
"0.794 (+/-0.044) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 250}\n",
"0.801 (+/-0.066) for {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 500}\n",
"0.8 (+/-0.042) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 5}\n",
"0.788 (+/-0.041) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 50}\n",
"0.79 (+/-0.024) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 250}\n",
"0.79 (+/-0.049) for {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 500}\n",
"0.818 (+/-0.1) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 5}\n",
"0.83 (+/-0.078) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 50}\n",
"0.828 (+/-0.069) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 250}\n",
"0.818 (+/-0.082) for {'learning_rate': 1, 'max_depth': 1, 'n_estimators': 500}\n",
"0.82 (+/-0.063) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 5}\n",
"0.794 (+/-0.038) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 50}\n",
"0.796 (+/-0.039) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 250}\n",
"0.801 (+/-0.048) for {'learning_rate': 1, 'max_depth': 3, 'n_estimators': 500}\n",
"0.805 (+/-0.042) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 5}\n",
"0.811 (+/-0.078) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 50}\n",
"0.809 (+/-0.074) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 250}\n",
"0.803 (+/-0.081) for {'learning_rate': 1, 'max_depth': 5, 'n_estimators': 500}\n",
"0.783 (+/-0.013) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 5}\n",
"0.787 (+/-0.051) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 50}\n",
"0.796 (+/-0.032) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 250}\n",
"0.79 (+/-0.052) for {'learning_rate': 1, 'max_depth': 7, 'n_estimators': 500}\n",
"0.785 (+/-0.034) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 5}\n",
"0.77 (+/-0.033) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 50}\n",
"0.8 (+/-0.053) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 250}\n",
"0.801 (+/-0.043) for {'learning_rate': 1, 'max_depth': 9, 'n_estimators': 500}\n",
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 5}\n",
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 50}\n",
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 250}\n",
"0.204 (+/-0.116) for {'learning_rate': 10, 'max_depth': 1, 'n_estimators': 500}\n",
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 5}\n",
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 50}\n",
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 250}\n",
"0.311 (+/-0.192) for {'learning_rate': 10, 'max_depth': 3, 'n_estimators': 500}\n",
"0.552 (+/-0.363) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 5}\n",
"0.444 (+/-0.316) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 50}\n",
"0.397 (+/-0.204) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 250}\n",
"0.397 (+/-0.197) for {'learning_rate': 10, 'max_depth': 5, 'n_estimators': 500}\n",
"0.599 (+/-0.171) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 5}\n",
"0.588 (+/-0.188) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 50}\n",
"0.616 (+/-0.133) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 250}\n",
"0.618 (+/-0.159) for {'learning_rate': 10, 'max_depth': 7, 'n_estimators': 500}\n",
"0.7 (+/-0.124) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 5}\n",
"0.717 (+/-0.128) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 50}\n",
"0.7 (+/-0.124) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 250}\n",
"0.704 (+/-0.11) for {'learning_rate': 10, 'max_depth': 9, 'n_estimators': 500}\n",
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 5}\n",
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 50}\n",
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 250}\n",
"0.376 (+/-0.005) for {'learning_rate': 100, 'max_depth': 1, 'n_estimators': 500}\n",
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 5}\n",
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 50}\n",
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 250}\n",
"0.29 (+/-0.104) for {'learning_rate': 100, 'max_depth': 3, 'n_estimators': 500}\n",
"0.373 (+/-0.181) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 5}\n",
"0.375 (+/-0.174) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 50}\n",
"0.369 (+/-0.176) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 250}\n",
"0.375 (+/-0.173) for {'learning_rate': 100, 'max_depth': 5, 'n_estimators': 500}\n",
"0.551 (+/-0.126) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 5}\n",
"0.547 (+/-0.13) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 50}\n",
"0.584 (+/-0.117) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 250}\n",
"0.562 (+/-0.135) for {'learning_rate': 100, 'max_depth': 7, 'n_estimators': 500}\n",
"0.635 (+/-0.063) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 5}\n",
"0.674 (+/-0.079) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 50}\n",
"0.652 (+/-0.061) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 250}\n",
"0.663 (+/-0.108) for {'learning_rate': 100, 'max_depth': 9, 'n_estimators': 500}\n"
]
}
],
"source": [
"# Exhaustive grid search over boosting rounds, tree depth and learning rate.\n",
"# random_state pins the random feature permutation used at each split, so\n",
"# Restart & Run All reproduces the same scores.\n",
"gb = GradientBoostingClassifier(random_state=42)\n",
"parameters = {\n",
"    'n_estimators': [5, 50, 250, 500],\n",
"    'max_depth': [1, 3, 5, 7, 9],\n",
"    'learning_rate': [0.01, 0.1, 1, 10, 100]\n",
"}\n",
"\n",
"cv = GridSearchCV(gb, parameters, cv=5)\n",
"# ravel() flattens the label frame into a 1-D array (assumes a single label column).\n",
"cv.fit(tr_features, tr_labels.values.ravel())\n",
"\n",
"print_results(cv)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
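"### Write out the pickled model\n",
"\n",
"Persist the best estimator found by the grid search with `joblib.dump` so it can be reloaded later."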
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../../../GB_model.pkl']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the winning estimator; joblib.dump returns the list of files it wrote.\n",
"best_model = cv.best_estimator_\n",
"joblib.dump(best_model, '../../../GB_model.pkl')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment