Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
GridSearchCV.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "GridSearchCV.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kiwamizamurai/ea0cc420e5ed48dbc54d5108af502f5e/gridsearchcv.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "i-5BAxA1ALXT",
"colab_type": "code",
"outputId": "f1887139-7081-4a47-b1e3-8f74f920843c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 472
}
},
"source": [
"!pip install lightgbm\n",
"!pip install hyperopt"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: lightgbm in /usr/local/lib/python3.6/dist-packages (2.2.3)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from lightgbm) (0.21.3)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from lightgbm) (1.3.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from lightgbm) (1.17.4)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->lightgbm) (0.14.0)\n",
"Requirement already satisfied: hyperopt in /usr/local/lib/python3.6/dist-packages (0.1.2)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from hyperopt) (0.16.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.12.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.17.4)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from hyperopt) (4.28.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from hyperopt) (2.4)\n",
"Requirement already satisfied: pymongo in /usr/local/lib/python3.6/dist-packages (from hyperopt) (3.9.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.3.2)\n",
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->hyperopt) (4.4.1)\n",
"Requirement already satisfied: lightgbm in /usr/local/lib/python3.6/dist-packages (2.2.3)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from lightgbm) (1.3.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from lightgbm) (1.17.4)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from lightgbm) (0.21.3)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->lightgbm) (0.14.0)\n",
"Requirement already satisfied: hyperopt in /usr/local/lib/python3.6/dist-packages (0.1.2)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from hyperopt) (4.28.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.12.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from hyperopt) (2.4)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.3.2)\n",
"Requirement already satisfied: pymongo in /usr/local/lib/python3.6/dist-packages (from hyperopt) (3.9.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from hyperopt) (1.17.4)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from hyperopt) (0.16.0)\n",
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->hyperopt) (4.4.1)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AJgtWiOLBHSi",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"from scipy.stats import randint as sp_randint\n",
"from scipy.stats import uniform as sp_uniform\n",
" \n",
"from sklearn.datasets import load_boston\n",
"from sklearn.model_selection import (cross_val_score, train_test_split, GridSearchCV, RandomizedSearchCV)\n",
"from sklearn.metrics import r2_score\n",
" \n",
"from lightgbm.sklearn import LGBMRegressor"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4utCgEUnBQBZ",
"colab_type": "code",
"outputId": "51eed18a-ed80-4900-9442-43eedd8be6c4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 82
}
},
"source": [
"boston = load_boston()\n",
"X, y = boston.data, boston.target\n",
" \n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
" \n",
"print(X_train.shape, ' train samples shape')\n",
"print(X_test.shape, ' test samples shape')"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"(404, 13) train samples shape\n",
"(102, 13) test samples shape\n",
"(404, 13) train samples shape\n",
"(102, 13) test samples shape\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CkfYbe6ABR4L",
"colab_type": "code",
"colab": {}
},
"source": [
"hyper_space = {'n_estimators': [1000, 1500, 2000],\n",
" 'max_depth': [4, 5, -1],\n",
" 'num_leaves': [15, 31, 63],\n",
" 'subsample': [0.6, 0.7, 1.0],\n",
" 'colsample_bytree': [0.6, 0.7, 1.0]}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "0np7sRK7B3bt",
"colab_type": "text"
},
"source": [
"まずハイパーパラメータの総数が\n",
"$$ 3 \\times 3 \\times 3 \\times 3 \\times 3 = 243 $$\n",
"なので243回fitします"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7cQccdwNBVNp",
"colab_type": "code",
"colab": {}
},
"source": [
"est = LGBMRegressor(boosting='gbdt', n_jobs=-1, random_state=1)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YCk-0aOcCT-e",
"colab_type": "text"
},
"source": [
"さらに、CV=2なので \n",
"\n",
"$$ 243 \\times 2 = 486 $$\n",
"\n",
"回、合計fitします\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "oBNtTwfHCVBi",
"colab_type": "code",
"outputId": "37f871f8-fdb5-4cfa-afff-d2e829f9265a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 180
}
},
"source": [
"gs = GridSearchCV(est, hyper_space, scoring='r2', cv=2, verbose=1)\n",
"gs_results = gs.fit(X_train, y_train)\n",
"print(\"BEST PARAMETERS: \" + str(gs_results.best_params_))\n",
"print(\"BEST CV SCORE: \" + str(gs_results.best_score_))"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"Fitting 2 folds for each of 243 candidates, totalling 486 fits\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
"[Parallel(n_jobs=1)]: Done 486 out of 486 | elapsed: 2.3min finished\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"BEST PARAMETERS: {'colsample_bytree': 0.6, 'max_depth': -1, 'n_estimators': 1000, 'num_leaves': 15, 'subsample': 0.6}\n",
"BEST CV SCORE: 0.8094512224346279\n",
"Fitting 2 folds for each of 243 candidates, totalling 486 fits\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
"[Parallel(n_jobs=1)]: Done 486 out of 486 | elapsed: 2.3min finished\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"BEST PARAMETERS: {'colsample_bytree': 0.6, 'max_depth': -1, 'n_estimators': 1000, 'num_leaves': 15, 'subsample': 0.6}\n",
"BEST CV SCORE: 0.8094512224346279\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tDZbpCuMCf07",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "0bc05df7-335b-452a-d7c8-d6bccc74f8c5"
},
"source": [
"# Predict (after fitting GridSearchCV is an estimator with best parameters)\n",
"y_pred = gs.predict(X_test)\n",
" \n",
"# Score\n",
"score = r2_score(y_test, y_pred)\n",
"print(\"R2 SCORE ON TEST DATA: {}\".format(score))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"R2 SCORE ON TEST DATA: 0.8763833391368616\n",
"R2 SCORE ON TEST DATA: 0.8763833391368616\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "kTfPalA7C0ko",
"colab_type": "code",
"colab": {}
},
"source": [
"hyper_space = {'n_estimators': sp_randint(1000, 2500),\n",
" 'max_depth': [4, 5, 8, -1],\n",
" 'num_leaves': [15, 31, 63, 127],\n",
" 'subsample': sp_uniform(0.6, 0.4),\n",
" 'colsample_bytree': sp_uniform(0.6, 0.4)}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BI8ds3DeC0oi",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 115
},
"outputId": "8f9953d3-fd19-4d20-e328-1960be9a56db"
},
"source": [
"rs = RandomizedSearchCV(est, hyper_space, n_iter=60, scoring='r2', cv=4, verbose=1, random_state=2018)\n",
"rs_results = rs.fit(X_train, y_train)\n",
"print(\"BEST PARAMETERS: \" + str(rs_results.best_params_))\n",
"print(\"BEST CV SCORE: \" + str(rs_results.best_score_))\n",
" \n",
"# Predict (after fitting RandomizedSearchCV is an estimator with best parameters)\n",
"y_pred = rs.predict(X_test)\n",
" \n",
"# Score\n",
"score = r2_score(y_test, y_pred)\n",
"print(\"R2 SCORE ON TEST DATA: {}\".format(score))"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"Fitting 4 folds for each of 60 candidates, totalling 240 fits\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
"[Parallel(n_jobs=1)]: Done 240 out of 240 | elapsed: 1.6min finished\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"BEST PARAMETERS: {'subsample': 0.6, 'num_leaves': 31, 'n_estimators': 1000, 'max_depth': 4, 'colsample_bytree': 1.0}\n",
"BEST CV SCORE: 0.8573797585710996\n",
"R2 SCORE ON TEST DATA: 0.8737735714358015\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.