Last active
December 14, 2019 03:40
Trying AutoGluon, following @kirikei's article
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Trying AutoGluon, following @kirikei's article",
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/tomo-makes/d388f40632d553778e59ab5552b5cdcb/-kirikei-autogluon.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BwPFzJD8ZrxX", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Trying AutoGluon\n",
"\n",
"[AutoML that starts with just data and fit - Trying AutoGluon - Qiita](https://qiita.com/kirikei/items/f879eb2cfbaf3d37ee0f)\n"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TSGrRaR8fRCe", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Prerequisites\n",
"\n",
"- Select a Python 3 runtime with a GPU\n",
"  - That said, it will probably also run without a GPU as long as `mxnet-cu100` is not the package you install"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "pwkN1h0SaAxa", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Installation"
] | |
}, | |
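{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The prerequisite above suggests that the GPU is optional when `mxnet-cu100` is not installed. As a rough sketch (an addition, not from the original article), the cell below picks a package name depending on whether a GPU seems to be attached. It assumes that a working `nvidia-smi` command indicates a usable GPU, and that the CPU-only `mxnet` package is an acceptable substitute, which this notebook has not verified."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"import subprocess\n",
"\n",
"\n",
"def gpu_runtime_attached():\n",
"    # Heuristic (assumption): a GPU is usable if `nvidia-smi` exists and exits with status 0.\n",
"    try:\n",
"        result = subprocess.run(\n",
"            ['nvidia-smi'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL\n",
"        )\n",
"    except OSError:\n",
"        # The command is missing or cannot be executed, so assume there is no GPU.\n",
"        return False\n",
"    return result.returncode == 0\n",
"\n",
"\n",
"# `mxnet-cu100` is the CUDA 10.0 build installed below; `mxnet` is the CPU-only build on PyPI.\n",
"mxnet_package = 'mxnet-cu100' if gpu_runtime_attached() else 'mxnet'\n",
"print('pip install --upgrade', mxnet_package)"
],
"execution_count": null,
"outputs": []
},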
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tq7qZ2_kZoWT", | |
"colab_type": "code", | |
"outputId": "638d058a-b38b-4eb8-9d73-7dc565554aca", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"!pip install --upgrade mxnet-cu100" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting mxnet-cu100\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/84/d098e0607ee6207448b6af65315f5d45946b49e4f48160eade6cdd64ce4e/mxnet_cu100-1.5.1.post0-py2.py3-none-manylinux1_x86_64.whl (540.1MB)\n", | |
"\u001b[K |████████████████████████████████| 540.1MB 29kB/s \n", | |
"\u001b[?25hCollecting graphviz<0.9.0,>=0.8.1\n", | |
" Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl\n", | |
"Requirement already satisfied, skipping upgrade: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.17.4)\n", | |
"Requirement already satisfied, skipping upgrade: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n", | |
"Requirement already satisfied, skipping upgrade: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n", | |
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.11.28)\n", | |
"Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n", | |
"Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n", | |
"Installing collected packages: graphviz, mxnet-cu100\n", | |
" Found existing installation: graphviz 0.10.1\n", | |
" Uninstalling graphviz-0.10.1:\n", | |
" Successfully uninstalled graphviz-0.10.1\n", | |
"Successfully installed graphviz-0.8.4 mxnet-cu100-1.5.1.post0\n", | |
"Collecting autogluon\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/4f/11/94e9c1cb329c852c7611e5e53061969b2eb49afa38a270dc8356a0aca02f/autogluon-0.0.3-py3-none-any.whl (316kB)\n", | |
"\u001b[K |████████████████████████████████| 317kB 37.8MB/s \n", | |
"\u001b[?25hRequirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.17.4)\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from autogluon) (3.1.2)\n", | |
"Collecting tqdm>=4.38.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/7f/32/5144caf0478b1f26bd9d97f510a47336cf4ac0f96c6bc3b5af20d4173920/tqdm-4.40.2-py2.py3-none-any.whl (55kB)\n", | |
"\u001b[K |████████████████████████████████| 61kB 9.7MB/s \n", | |
"\u001b[?25hCollecting ConfigSpace<=0.4.10\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/de/4e8e4f26332fc65404f52baa112defbf822b6738b60bfa6b2993f5c60933/ConfigSpace-0.4.10.tar.gz (882kB)\n", | |
"\u001b[K |████████████████████████████████| 890kB 25.8MB/s \n", | |
"\u001b[?25hCollecting scikit-learn==0.21.2\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/85/04/49633f490f726da6e454fddc8e938bbb5bfed2001681118d3814c219b723/scikit_learn-0.21.2-cp36-cp36m-manylinux1_x86_64.whl (6.7MB)\n", | |
"\u001b[K |████████████████████████████████| 6.7MB 56.2MB/s \n", | |
"\u001b[?25hCollecting distributed>=2.6.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/c1/45/1c741b907b7aa60a7a951e3cb107f1c3cda40fc368c2ff21a0bb58029aee/distributed-2.9.0-py3-none-any.whl (569kB)\n", | |
"\u001b[K |████████████████████████████████| 573kB 53.8MB/s \n", | |
"\u001b[?25hCollecting gluoncv>=0.5.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/66/1e/e9096a9742764c9427bb7e408bc867b8761d9a2a830514f1997d1e51abaf/gluoncv-0.5.0-py2.py3-none-any.whl (511kB)\n", | |
"\u001b[K |████████████████████████████████| 512kB 64.1MB/s \n", | |
"\u001b[?25hCollecting lightgbm==2.3.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/05/ec/756f13b25258e0aa6ec82d98504e01523814f95fc70718407419b8520e1d/lightgbm-2.3.0-py2.py3-none-manylinux1_x86_64.whl (1.3MB)\n", | |
"\u001b[K |████████████████████████████████| 1.3MB 24.5MB/s \n", | |
"\u001b[?25hCollecting scikit-optimize\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f4/44/60f82c97d1caa98752c7da2c1681cab5c7a390a0fdd3a55fac672b321cac/scikit_optimize-0.5.2-py2.py3-none-any.whl (74kB)\n", | |
"\u001b[K |████████████████████████████████| 81kB 12.3MB/s \n", | |
"\u001b[?25hCollecting catboost\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/f6/733fe7cca5d0d882e1a708ad59da2510416cc2e4fa54e17c7a5082f67811/catboost-0.20.1-cp36-none-manylinux1_x86_64.whl (63.6MB)\n", | |
"\u001b[K |████████████████████████████████| 63.6MB 49.3MB/s \n", | |
"\u001b[?25hRequirement already satisfied: scipy>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.3.3)\n", | |
"Requirement already satisfied: tornado in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.5.3)\n", | |
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.4)\n", | |
"Collecting pandas==0.24.2\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/19/74/e50234bc82c553fecdbd566d8650801e3fe2d6d8c8d940638e3d8a7c5522/pandas-0.24.2-cp36-cp36m-manylinux1_x86_64.whl (10.1MB)\n", | |
"\u001b[K |████████████████████████████████| 10.1MB 22.4MB/s \n", | |
"\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.21.0)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (from autogluon) (5.4.8)\n", | |
"Collecting gluonnlp==0.8.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/6e/2d/40c2ad37d74e5e9030064d73542b7df0f7df7ba98d47932874033cf03d79/gluonnlp-0.8.1.tar.gz (236kB)\n", | |
"\u001b[K |████████████████████████████████| 245kB 53.6MB/s \n", | |
"\u001b[?25hRequirement already satisfied: cython in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.29.14)\n", | |
"Collecting boto3==1.9.187\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/e7/f6/29b7254c051801ac182710bda0032c6d4bd20ea914c5f10d6eb3fc47e669/boto3-1.9.187-py2.py3-none-any.whl (128kB)\n", | |
"\u001b[K |████████████████████████████████| 133kB 57.9MB/s \n", | |
"\u001b[?25hCollecting paramiko>=2.5.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/06/1e/1e08baaaf6c3d3df1459fd85f0e7d2d6aa916f33958f151ee1ecc9800971/paramiko-2.7.1-py2.py3-none-any.whl (206kB)\n", | |
"\u001b[K |████████████████████████████████| 215kB 60.7MB/s \n", | |
"\u001b[?25hRequirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (1.1.0)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (0.10.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.6.1)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.4.5)\n", | |
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (3.6.6)\n", | |
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.2->autogluon) (0.14.1)\n", | |
"Requirement already satisfied: cloudpickle>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.2.2)\n", | |
"Requirement already satisfied: sortedcontainers!=2.0.0,!=2.0.1 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.1.0)\n", | |
"Collecting dask>=2.7.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bf/b3/9175539d5a43b0bb1fe6d9729613a9639dd78a0e13c3fa7fc1ba702e56fa/dask-2.9.0-py3-none-any.whl (770kB)\n", | |
"\u001b[K |████████████████████████████████| 778kB 54.7MB/s \n", | |
"\u001b[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.5.6)\n", | |
"Requirement already satisfied: zict>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.0.0)\n", | |
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (3.13)\n", | |
"Requirement already satisfied: toolz>=0.7.4 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.10.0)\n", | |
"Requirement already satisfied: tblib in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.6.0)\n", | |
"Requirement already satisfied: click>=6.6 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (7.0)\n", | |
"Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv>=0.5.0->autogluon) (4.3.0)\n", | |
"Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (4.1.1)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (1.12.0)\n", | |
"Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas==0.24.2->autogluon) (2018.9)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (3.0.4)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2.8)\n", | |
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (1.24.3)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2019.11.28)\n", | |
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.2.1)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.9.4)\n", | |
"Collecting botocore<1.13.0,>=1.12.187\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8e/7b/88f10115b4748f86be6b7b1d8761ba5023fccf6e6cbe762e368f63eddcf9/botocore-1.12.253-py2.py3-none-any.whl (5.7MB)\n", | |
"\u001b[K |████████████████████████████████| 5.8MB 57.5MB/s \n", | |
"\u001b[?25hCollecting bcrypt>=3.1.3\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8b/1d/82826443777dd4a624e38a08957b975e75df859b381ae302cfd7a30783ed/bcrypt-3.1.7-cp34-abi3-manylinux1_x86_64.whl (56kB)\n", | |
"\u001b[K |████████████████████████████████| 61kB 9.8MB/s \n", | |
"\u001b[?25hCollecting pynacl>=1.0.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/27/15/2cd0a203f318c2240b42cd9dd13c931ddd61067809fee3479f44f086103e/PyNaCl-1.3.0-cp34-abi3-manylinux1_x86_64.whl (759kB)\n", | |
"\u001b[K |████████████████████████████████| 768kB 67.3MB/s \n", | |
"\u001b[?25hCollecting cryptography>=2.5\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/ca/9a/7cece52c46546e214e10811b36b2da52ce1ea7fa203203a629b8dfadad53/cryptography-2.8-cp34-abi3-manylinux2010_x86_64.whl (2.3MB)\n", | |
"\u001b[K |████████████████████████████████| 2.3MB 64.4MB/s \n", | |
"\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->autogluon) (42.0.2)\n", | |
"Requirement already satisfied: heapdict in /usr/local/lib/python3.6/dist-packages (from zict>=0.1.3->distributed>=2.6.0->autogluon) (1.0.1)\n", | |
"Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv>=0.5.0->autogluon) (0.46)\n", | |
"Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly->catboost->autogluon) (1.3.3)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.187->boto3==1.9.187->autogluon) (0.15.2)\n", | |
"Requirement already satisfied: cffi>=1.1 in /usr/local/lib/python3.6/dist-packages (from bcrypt>=3.1.3->paramiko>=2.5.0->autogluon) (1.13.2)\n", | |
"Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi>=1.1->bcrypt>=3.1.3->paramiko>=2.5.0->autogluon) (2.19)\n", | |
"Building wheels for collected packages: ConfigSpace, gluonnlp\n", | |
" Building wheel for ConfigSpace (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for ConfigSpace: filename=ConfigSpace-0.4.10-cp36-cp36m-linux_x86_64.whl size=2712287 sha256=815a05501e2523793a5913ab06d99fe34ae3e3aa4579b2b8cda2eb51821d2322\n", | |
" Stored in directory: /root/.cache/pip/wheels/75/83/cb/28dd42bac69c8867d485138030daa83841c7f84afe68b2fdf7\n", | |
" Building wheel for gluonnlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for gluonnlp: filename=gluonnlp-0.8.1-cp36-none-any.whl size=293520 sha256=5f08fda536ea3c6253f7c86b3499b2dcace59f20e4da64122b3730d099f0b444\n", | |
" Stored in directory: /root/.cache/pip/wheels/3e/e7/3e/9cdf8ad7fce112fde2f4a52604045e5dd80f84d645bedb70c7\n", | |
"Successfully built ConfigSpace gluonnlp\n", | |
"\u001b[31mERROR: google-colab 1.0.0 has requirement pandas~=0.25.0; python_version >= \"3.0\", but you'll have pandas 0.24.2 which is incompatible.\u001b[0m\n", | |
"\u001b[31mERROR: distributed 2.9.0 has requirement tornado>=5, but you'll have tornado 4.5.3 which is incompatible.\u001b[0m\n", | |
"Installing collected packages: tqdm, ConfigSpace, scikit-learn, dask, distributed, gluoncv, lightgbm, scikit-optimize, pandas, catboost, gluonnlp, botocore, boto3, bcrypt, pynacl, cryptography, paramiko, autogluon\n", | |
" Found existing installation: tqdm 4.28.1\n", | |
" Uninstalling tqdm-4.28.1:\n", | |
" Successfully uninstalled tqdm-4.28.1\n", | |
" Found existing installation: scikit-learn 0.21.3\n", | |
" Uninstalling scikit-learn-0.21.3:\n", | |
" Successfully uninstalled scikit-learn-0.21.3\n", | |
" Found existing installation: dask 1.1.5\n", | |
" Uninstalling dask-1.1.5:\n", | |
" Successfully uninstalled dask-1.1.5\n", | |
" Found existing installation: distributed 1.25.3\n", | |
" Uninstalling distributed-1.25.3:\n", | |
" Successfully uninstalled distributed-1.25.3\n", | |
" Found existing installation: lightgbm 2.2.3\n", | |
" Uninstalling lightgbm-2.2.3:\n", | |
" Successfully uninstalled lightgbm-2.2.3\n", | |
" Found existing installation: pandas 0.25.3\n", | |
" Uninstalling pandas-0.25.3:\n", | |
" Successfully uninstalled pandas-0.25.3\n", | |
" Found existing installation: botocore 1.13.36\n", | |
" Uninstalling botocore-1.13.36:\n", | |
" Successfully uninstalled botocore-1.13.36\n", | |
" Found existing installation: boto3 1.10.36\n", | |
" Uninstalling boto3-1.10.36:\n", | |
" Successfully uninstalled boto3-1.10.36\n", | |
"Successfully installed ConfigSpace-0.4.10 autogluon-0.0.3 bcrypt-3.1.7 boto3-1.9.187 botocore-1.12.253 catboost-0.20.1 cryptography-2.8 dask-2.9.0 distributed-2.9.0 gluoncv-0.5.0 gluonnlp-0.8.1 lightgbm-2.3.0 pandas-0.24.2 paramiko-2.7.1 pynacl-1.3.0 scikit-learn-0.21.2 scikit-optimize-0.5.2 tqdm-4.40.2\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.colab-display-data+json": { | |
"pip_warning": { | |
"packages": [ | |
"pandas", | |
"tqdm" | |
] | |
} | |
} | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AxEJ6a9vdNPE", | |
"colab_type": "code", | |
"outputId": "c3a01f6d-f655-4d22-d11c-e3e1fe310ada", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"!pip install autogluon" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: autogluon in /usr/local/lib/python3.6/dist-packages (0.0.3)\n", | |
"Requirement already satisfied: ConfigSpace<=0.4.10 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.4.10)\n", | |
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.4)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (from autogluon) (5.4.8)\n", | |
"Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.17.4)\n", | |
"Requirement already satisfied: scikit-learn==0.21.2 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.21.2)\n", | |
"Requirement already satisfied: tornado in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.5.3)\n", | |
"Requirement already satisfied: tqdm>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.40.2)\n", | |
"Requirement already satisfied: gluoncv>=0.5.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.5.0)\n", | |
"Requirement already satisfied: gluonnlp==0.8.1 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.1)\n", | |
"Requirement already satisfied: scikit-optimize in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.5.2)\n", | |
"Requirement already satisfied: lightgbm==2.3.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.3.0)\n", | |
"Requirement already satisfied: distributed>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.9.0)\n", | |
"Requirement already satisfied: catboost in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.20.1)\n", | |
"Requirement already satisfied: boto3==1.9.187 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.9.187)\n", | |
"Requirement already satisfied: scipy>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.3.3)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.21.0)\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from autogluon) (3.1.2)\n", | |
"Requirement already satisfied: cython in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.29.14)\n", | |
"Requirement already satisfied: paramiko>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.7.1)\n", | |
"Requirement already satisfied: pandas==0.24.2 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.24.2)\n", | |
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (3.6.6)\n", | |
"Requirement already satisfied: pyparsing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (2.4.5)\n", | |
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.2->autogluon) (0.14.1)\n", | |
"Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv>=0.5.0->autogluon) (4.3.0)\n", | |
"Requirement already satisfied: cloudpickle>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.2.2)\n", | |
"Requirement already satisfied: dask>=2.7.0 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.9.0)\n", | |
"Requirement already satisfied: toolz>=0.7.4 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.10.0)\n", | |
"Requirement already satisfied: msgpack in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.5.6)\n", | |
"Requirement already satisfied: sortedcontainers!=2.0.0,!=2.0.1 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.1.0)\n", | |
"Requirement already satisfied: tblib in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.6.0)\n", | |
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (3.13)\n", | |
"Requirement already satisfied: zict>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.0.0)\n", | |
"Requirement already satisfied: click>=6.6 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (7.0)\n", | |
"Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (4.1.1)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (1.12.0)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.9.4)\n", | |
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.2.1)\n", | |
"Requirement already satisfied: botocore<1.13.0,>=1.12.187 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (1.12.253)\n", | |
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (1.24.3)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2.8)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2019.11.28)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (0.10.0)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (1.1.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.6.1)\n", | |
"Requirement already satisfied: cryptography>=2.5 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (2.8)\n", | |
"Requirement already satisfied: pynacl>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (1.3.0)\n", | |
"Requirement already satisfied: bcrypt>=3.1.3 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (3.1.7)\n", | |
"Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas==0.24.2->autogluon) (2018.9)\n", | |
"Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv>=0.5.0->autogluon) (0.46)\n", | |
"Requirement already satisfied: heapdict in /usr/local/lib/python3.6/dist-packages (from zict>=0.1.3->distributed>=2.6.0->autogluon) (1.0.1)\n", | |
"Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly->catboost->autogluon) (1.3.3)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.187->boto3==1.9.187->autogluon) (0.15.2)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->autogluon) (42.0.2)\n", | |
"Requirement already satisfied: cffi!=1.11.3,>=1.8 in /usr/local/lib/python3.6/dist-packages (from cryptography>=2.5->paramiko>=2.5.0->autogluon) (1.13.2)\n", | |
"Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi!=1.11.3,>=1.8->cryptography>=2.5->paramiko>=2.5.0->autogluon) (2.19)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
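{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The install log above reports that pandas was downgraded to 0.24.2, which conflicts with the version `google-colab` requires, and Colab flags `pandas` and `tqdm` in a `pip_warning`. Modules that were already imported keep their old version until the runtime is restarted, so a quick check like the following sketch (an addition, not from the original article) can confirm which versions are actually loaded."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"import pandas\n",
"import tqdm\n",
"\n",
"# The install log above shows pandas 0.24.2 and tqdm 4.40.2 being installed.\n",
"# If other versions are printed here, restarting the runtime is probably needed.\n",
"print('pandas', pandas.__version__)\n",
"print('tqdm', tqdm.__version__)"
],
"execution_count": null,
"outputs": []
},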
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "uyUG_yY6aCQ-", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Trying it on tabular data\n",
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "1aTzzwWRdg5p", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Using the adult dataset (census data), we solve a classification problem and a regression problem."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HonOkoU0Z5iI", | |
"colab_type": "code", | |
"outputId": "706ab950-fa18-4b5a-dcf7-ecc5e1057955", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 54 | |
} | |
}, | |
"source": [ | |
"import autogluon as ag\n", | |
"from autogluon import TabularPrediction as task\n", | |
"\n", | |
"train_data = task.Dataset(file_path='https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')\n", | |
"train_data = train_data.head(2000) # subsample 2000 data points for faster demo" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv | Columns = 15 / 15 | Rows = 39073 -> 39073\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pI4yRw38Z56v", | |
"colab_type": "code", | |
"outputId": "013946ff-c350-4d4a-a93d-92b69de4556d", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 377 | |
} | |
}, | |
"source": [ | |
"train_data.head()" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>age</th>\n", | |
" <th>workclass</th>\n", | |
" <th>fnlwgt</th>\n", | |
" <th>education</th>\n", | |
" <th>education-num</th>\n", | |
" <th>marital-status</th>\n", | |
" <th>occupation</th>\n", | |
" <th>relationship</th>\n", | |
" <th>race</th>\n", | |
" <th>sex</th>\n", | |
" <th>capital-gain</th>\n", | |
" <th>capital-loss</th>\n", | |
" <th>hours-per-week</th>\n", | |
" <th>native-country</th>\n", | |
" <th>class</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>25</td>\n", | |
" <td>Private</td>\n", | |
" <td>178478</td>\n", | |
" <td>Bachelors</td>\n", | |
" <td>13</td>\n", | |
" <td>Never-married</td>\n", | |
" <td>Tech-support</td>\n", | |
" <td>Own-child</td>\n", | |
" <td>White</td>\n", | |
" <td>Female</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>40</td>\n", | |
" <td>United-States</td>\n", | |
" <td><=50K</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>23</td>\n", | |
" <td>State-gov</td>\n", | |
" <td>61743</td>\n", | |
" <td>5th-6th</td>\n", | |
" <td>3</td>\n", | |
" <td>Never-married</td>\n", | |
" <td>Transport-moving</td>\n", | |
" <td>Not-in-family</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>35</td>\n", | |
" <td>United-States</td>\n", | |
" <td><=50K</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>46</td>\n", | |
" <td>Private</td>\n", | |
" <td>376789</td>\n", | |
" <td>HS-grad</td>\n", | |
" <td>9</td>\n", | |
" <td>Never-married</td>\n", | |
" <td>Other-service</td>\n", | |
" <td>Not-in-family</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>15</td>\n", | |
" <td>United-States</td>\n", | |
" <td><=50K</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>55</td>\n", | |
" <td>?</td>\n", | |
" <td>200235</td>\n", | |
" <td>HS-grad</td>\n", | |
" <td>9</td>\n", | |
" <td>Married-civ-spouse</td>\n", | |
" <td>?</td>\n", | |
" <td>Husband</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>50</td>\n", | |
" <td>United-States</td>\n", | |
" <td>>50K</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>36</td>\n", | |
" <td>Private</td>\n", | |
" <td>224541</td>\n", | |
" <td>7th-8th</td>\n", | |
" <td>4</td>\n", | |
" <td>Married-civ-spouse</td>\n", | |
" <td>Handlers-cleaners</td>\n", | |
" <td>Husband</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>40</td>\n", | |
" <td>El-Salvador</td>\n", | |
" <td><=50K</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" age workclass fnlwgt ... hours-per-week native-country class\n", | |
"0 25 Private 178478 ... 40 United-States <=50K\n", | |
"1 23 State-gov 61743 ... 35 United-States <=50K\n", | |
"2 46 Private 376789 ... 15 United-States <=50K\n", | |
"3 55 ? 200235 ... 50 United-States >50K\n", | |
"4 36 Private 224541 ... 40 El-Salvador <=50K\n", | |
"\n", | |
"[5 rows x 15 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TseD-HCGaIdB", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Classification\n",
"\n",
"We classify the `class` column, i.e. perform binary classification of whether income is above 50K (`>50K`) or not (`<=50K`)."
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "f8I_vW_aaue6", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Check the classes."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xuW0zRSdZ5_9", | |
"colab_type": "code", | |
"outputId": "b67794f5-98e5-4394-9f08-d591abcb0aa2", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 119 | |
} | |
}, | |
"source": [ | |
"label_column = 'class'\n", | |
"print(\"Summary of class variable: \\n\", train_data[label_column].describe())" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Summary of class variable: \n", | |
" count 2000\n", | |
"unique 2\n", | |
"top <=50K\n", | |
"freq 1551\n", | |
"Name: class, dtype: object\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
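{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The summary above shows that 1551 of the 2000 training rows carry the top label, so the classes are imbalanced. A minimal sketch for viewing the share of each class, assuming `train_data` behaves like a pandas DataFrame (the `describe()` call above suggests it does). The `str.strip()` call is an added precaution, because the labels in this dataset carry a leading space (for example `' <=50K'`)."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"# Share of each class in the training labels (a sketch, not part of the original article).\n",
"# str.strip() removes the leading space found in the raw label values.\n",
"class_share = train_data[label_column].str.strip().value_counts(normalize=True)\n",
"print(class_share)"
],
"execution_count": null,
"outputs": []
},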
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kUdFkEEbav0b", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Run AutoGluon with its default settings for now."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fhT0RKuoacps", | |
"colab_type": "code", | |
"outputId": "2fa7082b-c172-4017-ff5a-c7b60f727c4b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 785 | |
} | |
}, | |
"source": [ | |
"result_dir = 'agModels-predictClass' # specifies folder where to store trained models\n", | |
"predictor = task.fit(train_data=train_data, label=label_column, output_directory=result_dir)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Beginning AutoGluon training ...\n", | |
"Preprocessing data ...\n", | |
"Here are the first 10 unique label values in your data: [' <=50K' ' >50K']\n", | |
"AutoGluon infers your prediction problem is: binary (because only two unique label-values observed)\n", | |
"If this is wrong, please specify `problem_type` argument in fit() instead (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", | |
"\n", | |
"Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K\n", | |
"\tData preprocessing and feature engineering runtime = 0.08s ...\n", | |
"AutoGluon will gauge predictive performance using evaluation metric: accuracy\n", | |
"To change this, specify the eval_metric argument of fit()\n", | |
"Fitting model: RandomForestClassifierGini ...\n", | |
"\t0.72s\t = Training runtime\n", | |
"\t0.8725\t = Validation accuracy score\n", | |
"Fitting model: RandomForestClassifierEntr ...\n", | |
"\t0.8s\t = Training runtime\n", | |
"\t0.875\t = Validation accuracy score\n", | |
"Fitting model: ExtraTreesClassifierGini ...\n", | |
"\t0.6s\t = Training runtime\n", | |
"\t0.8475\t = Validation accuracy score\n", | |
"Fitting model: ExtraTreesClassifierEntr ...\n", | |
"\t0.6s\t = Training runtime\n", | |
"\t0.8525\t = Validation accuracy score\n", | |
"Fitting model: KNeighborsClassifierUnif ...\n", | |
"\t0.01s\t = Training runtime\n", | |
"\t0.78\t = Validation accuracy score\n", | |
"Fitting model: KNeighborsClassifierDist ...\n", | |
"\t0.01s\t = Training runtime\n", | |
"\t0.7325\t = Validation accuracy score\n", | |
"Fitting model: LightGBMClassifier ...\n", | |
"\t0.39s\t = Training runtime\n", | |
"\t0.8775\t = Validation accuracy score\n", | |
"Fitting model: CatboostClassifier ...\n", | |
"\t2.38s\t = Training runtime\n", | |
"\t0.8725\t = Validation accuracy score\n", | |
"Fitting model: NeuralNetClassifier ...\n", | |
"\t15.16s\t = Training runtime\n", | |
"\t0.8675\t = Validation accuracy score\n", | |
"Fitting model: LightGBMClassifierCustom ...\n", | |
"\t0.72s\t = Training runtime\n", | |
"\t0.855\t = Validation accuracy score\n", | |
"Fitting model: weighted_ensemble_l1 ...\n", | |
"\t0.45s\t = Training runtime\n", | |
"\t0.88\t = Validation accuracy score\n", | |
"AutoGluon training complete, total runtime = 24.09s ...\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "LPucLx6ha429", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Download the test data and separate the label column from the features."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rHfXqg_EagA7", | |
"colab_type": "code", | |
"outputId": "00eed8c8-1eaa-4eb0-cbc3-70c3bf3b8271", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 54 | |
} | |
}, | |
"source": [ | |
"test_data = task.Dataset(file_path='https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')\n", | |
"y_test = test_data[label_column] # values to predict\n", | |
"test_data_nolab = test_data.drop(labels=[label_column], axis=1)  # features only; labels are held out for scoring"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv | Columns = 15 / 15 | Rows = 9769 -> 9769\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5tTNcSYRa7sg", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Evaluate the model on the test data."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "23IKFo1gagUS", | |
"colab_type": "code", | |
"outputId": "07474158-ede6-4350-ba6a-05aa63869c4f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 663 | |
} | |
}, | |
"source": [ | |
"# unnecessary, just demonstrates how to load previously-trained predictor from file\n", | |
"predictor = task.load(result_dir) \n", | |
"\n", | |
"y_pred = predictor.predict(test_data_nolab)\n", | |
"print(\"Predictions: \", y_pred)\n", | |
"perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Evaluation: accuracy on test data: 0.856280\n", | |
"Evaluations on test data:\n", | |
"{\n", | |
" \"accuracy\": 0.8562800696079435,\n", | |
" \"accuracy_score\": 0.8562800696079435,\n", | |
" \"balanced_accuracy_score\": 0.7625358844305661,\n", | |
" \"matthews_corrcoef\": 0.5769073882148189,\n", | |
" \"f1_score\": 0.8562800696079435\n", | |
"}\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predictions: [' <=50K' ' <=50K' ' >50K' ... ' <=50K' ' <=50K' ' <=50K']\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Detailed (per-class) classification report:\n", | |
"{\n", | |
" \" <=50K\": {\n", | |
" \"precision\": 0.8791222570532915,\n", | |
" \"recall\": 0.9409475238223057,\n", | |
" \"f1-score\": 0.9089848308051341,\n", | |
" \"support\": 7451\n", | |
" },\n", | |
" \" >50K\": {\n", | |
" \"precision\": 0.7547380156075808,\n", | |
" \"recall\": 0.5841242450388265,\n", | |
" \"f1-score\": 0.6585603112840467,\n", | |
" \"support\": 2318\n", | |
" },\n", | |
" \"accuracy\": 0.8562800696079435,\n", | |
" \"macro avg\": {\n", | |
" \"precision\": 0.8169301363304362,\n", | |
" \"recall\": 0.7625358844305661,\n", | |
" \"f1-score\": 0.7837725710445904,\n", | |
" \"support\": 9769\n", | |
" },\n", | |
" \"weighted avg\": {\n", | |
" \"precision\": 0.8496082155269166,\n", | |
" \"recall\": 0.8562800696079435,\n", | |
" \"f1-score\": 0.8495638014009084,\n", | |
" \"support\": 9769\n", | |
" }\n", | |
"}\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
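{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Two of the numbers in the report above can be reproduced by hand, which helps when reading it. The sketch below copies the per-class recall and support values from the printed report (nothing is recomputed from the model): balanced accuracy is the unweighted mean of the per-class recalls, and the majority-class baseline is the accuracy of always predicting `<=50K`. The variable names are illustrative. Separately, the summary's `f1_score` equals the accuracy, which is what a micro-averaged F1 would give; that reading is an assumption, since the log does not state the averaging mode."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"# Values copied from the classification report printed above.\n",
"recall = {'<=50K': 0.9409475238223057, '>50K': 0.5841242450388265}\n",
"support = {'<=50K': 7451, '>50K': 2318}\n",
"reported_accuracy = 0.8562800696079435\n",
"\n",
"# Balanced accuracy is the unweighted mean of the per-class recalls.\n",
"balanced_accuracy = sum(recall.values()) / len(recall)\n",
"\n",
"# Accuracy obtained by always predicting the majority class.\n",
"majority_baseline = max(support.values()) / sum(support.values())\n",
"\n",
"print('balanced accuracy :', round(balanced_accuracy, 4))\n",
"print('majority baseline :', round(majority_baseline, 4))\n",
"print('gain over baseline:', round(reported_accuracy - majority_baseline, 4))"
],
"execution_count": null,
"outputs": []
},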
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lD0LNVMJaqRI", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Review a summary of the model search."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tAjW4yaaaliD", | |
"colab_type": "code", | |
"outputId": "0f4a0147-1c09-456b-f293-5ec47cce6fe1", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 258 | |
} | |
}, | |
"source": [ | |
"results = predictor.fit_summary()\n" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"*** Summary of fit() ***\n", | |
"Number of models trained: 11\n", | |
"Types of models trained: \n", | |
"{'KNNModel', 'TabularNeuralNetModel', 'CatboostModel', 'LGBModel', 'WeightedEnsembleModel', 'RFModel'}\n", | |
"Validation performance of individual models: {'RandomForestClassifierGini': 0.8725, 'RandomForestClassifierEntr': 0.875, 'ExtraTreesClassifierGini': 0.8475, 'ExtraTreesClassifierEntr': 0.8525, 'KNeighborsClassifierUnif': 0.78, 'KNeighborsClassifierDist': 0.7325, 'LightGBMClassifier': 0.8775, 'CatboostClassifier': 0.8725, 'NeuralNetClassifier': 0.8675, 'LightGBMClassifierCustom': 0.855, 'weighted_ensemble_l1': 0.88}\n", | |
"Best model (based on validation performance): weighted_ensemble_l1\n", | |
"Hyperparameter-tuning used: False\n", | |
"Bagging used: False \n", | |
"Stack-ensembling used: False \n", | |
"User-specified hyperparameters:\n", | |
"{'NN': {'num_epochs': 500}, 'GBM': {'num_boost_round': 10000}, 'CAT': {'iterations': 10000}, 'RF': {'n_estimators': 300}, 'XT': {'n_estimators': 300}, 'KNN': {}, 'custom': ['GBM']}\n", | |
"Plot summary of models saved to file: SummaryOfModels.html\n", | |
"*** End of fit() summary ***\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SYQIMvyDbEBS", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Check how AutoGluon interpreted the task and the column types."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dcVTZ4qjbDM4", | |
"colab_type": "code", | |
"outputId": "9f480390-c6b8-40c8-8a26-df74c3d64aaf", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 71 | |
} | |
}, | |
"source": [ | |
"print(\"AutoGluon infers problem type is: \", predictor.problem_type)\n", | |
"print(\"AutoGluon categorized the features as: \", predictor.feature_types)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"AutoGluon infers problem type is: binary\n", | |
"AutoGluon categorized the features as: {'nlp': [], 'vectorizers': [], 'int': ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week'], 'object': ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']}\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5C2lmOHUaMt7", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Regression\n"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "vChVHi2bbM3k", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Run a regression with `age` as the target variable."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "An2PitqfaNBr", | |
"colab_type": "code", | |
"outputId": "8c0f3880-e5a3-4ef5-d356-b8a4d72b0d32", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
} | |
}, | |
"source": [ | |
"age_column = 'age'\n", | |
"print(\"Summary of age variable: \\n\", train_data[age_column].describe())" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Summary of age variable: \n", | |
" count 2000.000000\n", | |
"mean 38.375000\n", | |
"std 13.781195\n", | |
"min 17.000000\n", | |
"25% 27.000000\n", | |
"50% 37.000000\n", | |
"75% 47.000000\n", | |
"max 90.000000\n", | |
"Name: age, dtype: float64\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "_wkdvWk6eIoh", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Train a regressor and evaluate it on the test data with `evaluate()`.\n"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xxT2HQOkeFAr", | |
"colab_type": "code", | |
"outputId": "844fea81-f834-406a-de24-a47d661ee43f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 632 | |
} | |
}, | |
"source": [ | |
"predictor_age = task.fit(train_data=train_data, \n", | |
" output_directory=\"agModels-predictAge\", \n", | |
" label=age_column,\n", | |
" problem_type='regression')\n", | |
"performance = predictor_age.evaluate(test_data)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Beginning AutoGluon training ...\n", | |
"Preprocessing data ...\n", | |
"\tData preprocessing and feature engineering runtime = 0.08s ...\n", | |
"AutoGluon will gauge predictive performance using evaluation metric: root_mean_squared_error\n", | |
"To change this, specify the eval_metric argument of fit()\n", | |
"Fitting model: RandomForestRegressorMSE ...\n", | |
"NumExpr defaulting to 2 threads.\n", | |
"\t1.51s\t = Training runtime\n", | |
"\t-11.08\t = Validation root_mean_squared_error score\n", | |
"Fitting model: ExtraTreesRegressorMSE ...\n", | |
"\t1.1s\t = Training runtime\n", | |
"\t-11.2898\t = Validation root_mean_squared_error score\n", | |
"Fitting model: KNeighborsRegressorUnif ...\n", | |
"\t0.01s\t = Training runtime\n", | |
"\t-14.5251\t = Validation root_mean_squared_error score\n", | |
"Fitting model: KNeighborsRegressorDist ...\n", | |
"\t0.01s\t = Training runtime\n", | |
"\t-15.2104\t = Validation root_mean_squared_error score\n", | |
"Fitting model: LightGBMRegressor ...\n", | |
"\t0.34s\t = Training runtime\n", | |
"\t-10.311\t = Validation root_mean_squared_error score\n", | |
"Fitting model: CatboostRegressor ...\n", | |
"\t2.16s\t = Training runtime\n", | |
"\t-10.4003\t = Validation root_mean_squared_error score\n", | |
"Fitting model: NeuralNetRegressor ...\n", | |
"\t13.07s\t = Training runtime\n", | |
"\t-10.6561\t = Validation root_mean_squared_error score\n", | |
"Fitting model: LightGBMRegressorCustom ...\n", | |
"\t0.61s\t = Training runtime\n", | |
"\t-10.5824\t = Validation root_mean_squared_error score\n", | |
"Fitting model: weighted_ensemble_l1 ...\n", | |
"\t0.43s\t = Training runtime\n", | |
"\t-10.193\t = Validation root_mean_squared_error score\n", | |
"AutoGluon training complete, total runtime = 21.12s ...\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Predictive performance on given dataset: root_mean_squared_error = 10.27827195078937\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
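{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To put the test RMSE of about 10.28 years in context, it can be compared with a trivial model that always predicts the mean training age. A minimal sketch, assuming `train_data`, `test_data` and `age_column` from the cells above are still in scope; `rmse` is a helper defined here, not an AutoGluon function."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"import numpy as np\n",
"\n",
"def rmse(y_true, y_pred):\n",
"    # Root mean squared error: sqrt(mean((y_true - y_pred)^2)).\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))\n",
"\n",
"# Baseline: predict the mean training age for every test row.\n",
"y_true = test_data[age_column]\n",
"baseline_pred = np.full(len(y_true), train_data[age_column].mean())\n",
"print('constant-mean baseline RMSE:', rmse(y_true, baseline_pred))"
],
"execution_count": null,
"outputs": []
},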
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XNGO3Ds-ePk5", | |
"colab_type": "text" | |
}, | |
"source": [ | |
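"Check the results. The validation scores are printed with a negative sign (for example `-10.193`), apparently because AutoGluon reports metrics in a higher-is-better form; the RMSE itself is the absolute value."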
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pu3yKie4ePsk", | |
"colab_type": "code", | |
"outputId": "ede76091-30ce-4302-f27e-0cd50dece0a4", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 258 | |
} | |
}, | |
"source": [ | |
"results = predictor_age.fit_summary()" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"*** Summary of fit() ***\n", | |
"Number of models trained: 9\n", | |
"Types of models trained: \n", | |
"{'KNNModel', 'TabularNeuralNetModel', 'CatboostModel', 'LGBModel', 'WeightedEnsembleModel', 'RFModel'}\n", | |
"Validation performance of individual models: {'RandomForestRegressorMSE': -11.079997378910637, 'ExtraTreesRegressorMSE': -11.289767107134379, 'KNeighborsRegressorUnif': -14.525081755363722, 'KNeighborsRegressorDist': -15.210358581022946, 'LightGBMRegressor': -10.31100588616801, 'CatboostRegressor': -10.400331629210777, 'NeuralNetRegressor': -10.656088855066407, 'LightGBMRegressorCustom': -10.582398535859294, 'weighted_ensemble_l1': -10.193049558344967}\n", | |
"Best model (based on validation performance): weighted_ensemble_l1\n", | |
"Hyperparameter-tuning used: False\n", | |
"Bagging used: False \n", | |
"Stack-ensembling used: False \n", | |
"User-specified hyperparameters:\n", | |
"{'NN': {'num_epochs': 500}, 'GBM': {'num_boost_round': 10000}, 'CAT': {'iterations': 10000}, 'RF': {'n_estimators': 300}, 'XT': {'n_estimators': 300}, 'KNN': {}, 'custom': ['GBM']}\n", | |
"Plot summary of models saved to file: SummaryOfModels.html\n", | |
"*** End of fit() summary ***\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dKWhS31qeM-W", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |