@tomo-makes
Last active December 14, 2019 03:40
Trying AutoGluon, following @kirikei's article
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "@kirikei記事にそって、AutoGluon使ってみた",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tomo-makes/d388f40632d553778e59ab5552b5cdcb/-kirikei-autogluon.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BwPFzJD8ZrxX",
"colab_type": "text"
},
"source": [
"# AutoGluonを試す\n",
"\n",
"[データとfitだけで始めるAutoML - AutoGluon使ってみた - - Qiita](https://qiita.com/kirikei/items/f879eb2cfbaf3d37ee0f)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TSGrRaR8fRCe",
"colab_type": "text"
},
"source": [
"### 前提\n",
"\n",
"- GPUあり、Python3ランタイムを選ぶ\n",
" - ただし、 `mxnet-cu100` をインストールに選ばなければGPUなしでも動くと思われる"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pwkN1h0SaAxa",
"colab_type": "text"
},
"source": [
"## インストール"
]
},
{
"cell_type": "code",
"metadata": {
"id": "tq7qZ2_kZoWT",
"colab_type": "code",
"outputId": "638d058a-b38b-4eb8-9d73-7dc565554aca",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"!pip install --upgrade mxnet-cu100"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting mxnet-cu100\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/84/d098e0607ee6207448b6af65315f5d45946b49e4f48160eade6cdd64ce4e/mxnet_cu100-1.5.1.post0-py2.py3-none-manylinux1_x86_64.whl (540.1MB)\n",
"\u001b[K |████████████████████████████████| 540.1MB 29kB/s \n",
"\u001b[?25hCollecting graphviz<0.9.0,>=0.8.1\n",
" Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl\n",
"Requirement already satisfied, skipping upgrade: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (1.17.4)\n",
"Requirement already satisfied, skipping upgrade: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from mxnet-cu100) (2.21.0)\n",
"Requirement already satisfied, skipping upgrade: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (1.24.3)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2019.11.28)\n",
"Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->mxnet-cu100) (2.8)\n",
"Installing collected packages: graphviz, mxnet-cu100\n",
" Found existing installation: graphviz 0.10.1\n",
" Uninstalling graphviz-0.10.1:\n",
" Successfully uninstalled graphviz-0.10.1\n",
"Successfully installed graphviz-0.8.4 mxnet-cu100-1.5.1.post0\n",
"Collecting autogluon\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/4f/11/94e9c1cb329c852c7611e5e53061969b2eb49afa38a270dc8356a0aca02f/autogluon-0.0.3-py3-none-any.whl (316kB)\n",
"\u001b[K |████████████████████████████████| 317kB 37.8MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.17.4)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from autogluon) (3.1.2)\n",
"Collecting tqdm>=4.38.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/7f/32/5144caf0478b1f26bd9d97f510a47336cf4ac0f96c6bc3b5af20d4173920/tqdm-4.40.2-py2.py3-none-any.whl (55kB)\n",
"\u001b[K |████████████████████████████████| 61kB 9.7MB/s \n",
"\u001b[?25hCollecting ConfigSpace<=0.4.10\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/de/4e8e4f26332fc65404f52baa112defbf822b6738b60bfa6b2993f5c60933/ConfigSpace-0.4.10.tar.gz (882kB)\n",
"\u001b[K |████████████████████████████████| 890kB 25.8MB/s \n",
"\u001b[?25hCollecting scikit-learn==0.21.2\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/85/04/49633f490f726da6e454fddc8e938bbb5bfed2001681118d3814c219b723/scikit_learn-0.21.2-cp36-cp36m-manylinux1_x86_64.whl (6.7MB)\n",
"\u001b[K |████████████████████████████████| 6.7MB 56.2MB/s \n",
"\u001b[?25hCollecting distributed>=2.6.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/c1/45/1c741b907b7aa60a7a951e3cb107f1c3cda40fc368c2ff21a0bb58029aee/distributed-2.9.0-py3-none-any.whl (569kB)\n",
"\u001b[K |████████████████████████████████| 573kB 53.8MB/s \n",
"\u001b[?25hCollecting gluoncv>=0.5.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/66/1e/e9096a9742764c9427bb7e408bc867b8761d9a2a830514f1997d1e51abaf/gluoncv-0.5.0-py2.py3-none-any.whl (511kB)\n",
"\u001b[K |████████████████████████████████| 512kB 64.1MB/s \n",
"\u001b[?25hCollecting lightgbm==2.3.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/05/ec/756f13b25258e0aa6ec82d98504e01523814f95fc70718407419b8520e1d/lightgbm-2.3.0-py2.py3-none-manylinux1_x86_64.whl (1.3MB)\n",
"\u001b[K |████████████████████████████████| 1.3MB 24.5MB/s \n",
"\u001b[?25hCollecting scikit-optimize\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f4/44/60f82c97d1caa98752c7da2c1681cab5c7a390a0fdd3a55fac672b321cac/scikit_optimize-0.5.2-py2.py3-none-any.whl (74kB)\n",
"\u001b[K |████████████████████████████████| 81kB 12.3MB/s \n",
"\u001b[?25hCollecting catboost\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/f6/733fe7cca5d0d882e1a708ad59da2510416cc2e4fa54e17c7a5082f67811/catboost-0.20.1-cp36-none-manylinux1_x86_64.whl (63.6MB)\n",
"\u001b[K |████████████████████████████████| 63.6MB 49.3MB/s \n",
"\u001b[?25hRequirement already satisfied: scipy>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.3.3)\n",
"Requirement already satisfied: tornado in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.5.3)\n",
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.4)\n",
"Collecting pandas==0.24.2\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/19/74/e50234bc82c553fecdbd566d8650801e3fe2d6d8c8d940638e3d8a7c5522/pandas-0.24.2-cp36-cp36m-manylinux1_x86_64.whl (10.1MB)\n",
"\u001b[K |████████████████████████████████| 10.1MB 22.4MB/s \n",
"\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.21.0)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (from autogluon) (5.4.8)\n",
"Collecting gluonnlp==0.8.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/6e/2d/40c2ad37d74e5e9030064d73542b7df0f7df7ba98d47932874033cf03d79/gluonnlp-0.8.1.tar.gz (236kB)\n",
"\u001b[K |████████████████████████████████| 245kB 53.6MB/s \n",
"\u001b[?25hRequirement already satisfied: cython in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.29.14)\n",
"Collecting boto3==1.9.187\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/e7/f6/29b7254c051801ac182710bda0032c6d4bd20ea914c5f10d6eb3fc47e669/boto3-1.9.187-py2.py3-none-any.whl (128kB)\n",
"\u001b[K |████████████████████████████████| 133kB 57.9MB/s \n",
"\u001b[?25hCollecting paramiko>=2.5.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/06/1e/1e08baaaf6c3d3df1459fd85f0e7d2d6aa916f33958f151ee1ecc9800971/paramiko-2.7.1-py2.py3-none-any.whl (206kB)\n",
"\u001b[K |████████████████████████████████| 215kB 60.7MB/s \n",
"\u001b[?25hRequirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (0.10.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.6.1)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.4.5)\n",
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (3.6.6)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.2->autogluon) (0.14.1)\n",
"Requirement already satisfied: cloudpickle>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.2.2)\n",
"Requirement already satisfied: sortedcontainers!=2.0.0,!=2.0.1 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.1.0)\n",
"Collecting dask>=2.7.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bf/b3/9175539d5a43b0bb1fe6d9729613a9639dd78a0e13c3fa7fc1ba702e56fa/dask-2.9.0-py3-none-any.whl (770kB)\n",
"\u001b[K |████████████████████████████████| 778kB 54.7MB/s \n",
"\u001b[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.5.6)\n",
"Requirement already satisfied: zict>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.0.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (3.13)\n",
"Requirement already satisfied: toolz>=0.7.4 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.10.0)\n",
"Requirement already satisfied: tblib in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.6.0)\n",
"Requirement already satisfied: click>=6.6 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (7.0)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv>=0.5.0->autogluon) (4.3.0)\n",
"Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (4.1.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (1.12.0)\n",
"Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas==0.24.2->autogluon) (2018.9)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (3.0.4)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2.8)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2019.11.28)\n",
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.2.1)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.9.4)\n",
"Collecting botocore<1.13.0,>=1.12.187\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8e/7b/88f10115b4748f86be6b7b1d8761ba5023fccf6e6cbe762e368f63eddcf9/botocore-1.12.253-py2.py3-none-any.whl (5.7MB)\n",
"\u001b[K |████████████████████████████████| 5.8MB 57.5MB/s \n",
"\u001b[?25hCollecting bcrypt>=3.1.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8b/1d/82826443777dd4a624e38a08957b975e75df859b381ae302cfd7a30783ed/bcrypt-3.1.7-cp34-abi3-manylinux1_x86_64.whl (56kB)\n",
"\u001b[K |████████████████████████████████| 61kB 9.8MB/s \n",
"\u001b[?25hCollecting pynacl>=1.0.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/27/15/2cd0a203f318c2240b42cd9dd13c931ddd61067809fee3479f44f086103e/PyNaCl-1.3.0-cp34-abi3-manylinux1_x86_64.whl (759kB)\n",
"\u001b[K |████████████████████████████████| 768kB 67.3MB/s \n",
"\u001b[?25hCollecting cryptography>=2.5\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/ca/9a/7cece52c46546e214e10811b36b2da52ce1ea7fa203203a629b8dfadad53/cryptography-2.8-cp34-abi3-manylinux2010_x86_64.whl (2.3MB)\n",
"\u001b[K |████████████████████████████████| 2.3MB 64.4MB/s \n",
"\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->autogluon) (42.0.2)\n",
"Requirement already satisfied: heapdict in /usr/local/lib/python3.6/dist-packages (from zict>=0.1.3->distributed>=2.6.0->autogluon) (1.0.1)\n",
"Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv>=0.5.0->autogluon) (0.46)\n",
"Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly->catboost->autogluon) (1.3.3)\n",
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.187->boto3==1.9.187->autogluon) (0.15.2)\n",
"Requirement already satisfied: cffi>=1.1 in /usr/local/lib/python3.6/dist-packages (from bcrypt>=3.1.3->paramiko>=2.5.0->autogluon) (1.13.2)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi>=1.1->bcrypt>=3.1.3->paramiko>=2.5.0->autogluon) (2.19)\n",
"Building wheels for collected packages: ConfigSpace, gluonnlp\n",
" Building wheel for ConfigSpace (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for ConfigSpace: filename=ConfigSpace-0.4.10-cp36-cp36m-linux_x86_64.whl size=2712287 sha256=815a05501e2523793a5913ab06d99fe34ae3e3aa4579b2b8cda2eb51821d2322\n",
" Stored in directory: /root/.cache/pip/wheels/75/83/cb/28dd42bac69c8867d485138030daa83841c7f84afe68b2fdf7\n",
" Building wheel for gluonnlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gluonnlp: filename=gluonnlp-0.8.1-cp36-none-any.whl size=293520 sha256=5f08fda536ea3c6253f7c86b3499b2dcace59f20e4da64122b3730d099f0b444\n",
" Stored in directory: /root/.cache/pip/wheels/3e/e7/3e/9cdf8ad7fce112fde2f4a52604045e5dd80f84d645bedb70c7\n",
"Successfully built ConfigSpace gluonnlp\n",
"\u001b[31mERROR: google-colab 1.0.0 has requirement pandas~=0.25.0; python_version >= \"3.0\", but you'll have pandas 0.24.2 which is incompatible.\u001b[0m\n",
"\u001b[31mERROR: distributed 2.9.0 has requirement tornado>=5, but you'll have tornado 4.5.3 which is incompatible.\u001b[0m\n",
"Installing collected packages: tqdm, ConfigSpace, scikit-learn, dask, distributed, gluoncv, lightgbm, scikit-optimize, pandas, catboost, gluonnlp, botocore, boto3, bcrypt, pynacl, cryptography, paramiko, autogluon\n",
" Found existing installation: tqdm 4.28.1\n",
" Uninstalling tqdm-4.28.1:\n",
" Successfully uninstalled tqdm-4.28.1\n",
" Found existing installation: scikit-learn 0.21.3\n",
" Uninstalling scikit-learn-0.21.3:\n",
" Successfully uninstalled scikit-learn-0.21.3\n",
" Found existing installation: dask 1.1.5\n",
" Uninstalling dask-1.1.5:\n",
" Successfully uninstalled dask-1.1.5\n",
" Found existing installation: distributed 1.25.3\n",
" Uninstalling distributed-1.25.3:\n",
" Successfully uninstalled distributed-1.25.3\n",
" Found existing installation: lightgbm 2.2.3\n",
" Uninstalling lightgbm-2.2.3:\n",
" Successfully uninstalled lightgbm-2.2.3\n",
" Found existing installation: pandas 0.25.3\n",
" Uninstalling pandas-0.25.3:\n",
" Successfully uninstalled pandas-0.25.3\n",
" Found existing installation: botocore 1.13.36\n",
" Uninstalling botocore-1.13.36:\n",
" Successfully uninstalled botocore-1.13.36\n",
" Found existing installation: boto3 1.10.36\n",
" Uninstalling boto3-1.10.36:\n",
" Successfully uninstalled boto3-1.10.36\n",
"Successfully installed ConfigSpace-0.4.10 autogluon-0.0.3 bcrypt-3.1.7 boto3-1.9.187 botocore-1.12.253 catboost-0.20.1 cryptography-2.8 dask-2.9.0 distributed-2.9.0 gluoncv-0.5.0 gluonnlp-0.8.1 lightgbm-2.3.0 pandas-0.24.2 paramiko-2.7.1 pynacl-1.3.0 scikit-learn-0.21.2 scikit-optimize-0.5.2 tqdm-4.40.2\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.colab-display-data+json": {
"pip_warning": {
"packages": [
"pandas",
"tqdm"
]
}
}
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AxEJ6a9vdNPE",
"colab_type": "code",
"outputId": "c3a01f6d-f655-4d22-d11c-e3e1fe310ada",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"!pip install autogluon"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: autogluon in /usr/local/lib/python3.6/dist-packages (0.0.3)\n",
"Requirement already satisfied: ConfigSpace<=0.4.10 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.4.10)\n",
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.4)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.6/dist-packages (from autogluon) (5.4.8)\n",
"Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.17.4)\n",
"Requirement already satisfied: scikit-learn==0.21.2 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.21.2)\n",
"Requirement already satisfied: tornado in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.5.3)\n",
"Requirement already satisfied: tqdm>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (4.40.2)\n",
"Requirement already satisfied: gluoncv>=0.5.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.5.0)\n",
"Requirement already satisfied: gluonnlp==0.8.1 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.8.1)\n",
"Requirement already satisfied: scikit-optimize in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.5.2)\n",
"Requirement already satisfied: lightgbm==2.3.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.3.0)\n",
"Requirement already satisfied: distributed>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.9.0)\n",
"Requirement already satisfied: catboost in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.20.1)\n",
"Requirement already satisfied: boto3==1.9.187 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.9.187)\n",
"Requirement already satisfied: scipy>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from autogluon) (1.3.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.21.0)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from autogluon) (3.1.2)\n",
"Requirement already satisfied: cython in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.29.14)\n",
"Requirement already satisfied: paramiko>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from autogluon) (2.7.1)\n",
"Requirement already satisfied: pandas==0.24.2 in /usr/local/lib/python3.6/dist-packages (from autogluon) (0.24.2)\n",
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (3.6.6)\n",
"Requirement already satisfied: pyparsing in /usr/local/lib/python3.6/dist-packages (from ConfigSpace<=0.4.10->autogluon) (2.4.5)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.2->autogluon) (0.14.1)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from gluoncv>=0.5.0->autogluon) (4.3.0)\n",
"Requirement already satisfied: cloudpickle>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.2.2)\n",
"Requirement already satisfied: dask>=2.7.0 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.9.0)\n",
"Requirement already satisfied: toolz>=0.7.4 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.10.0)\n",
"Requirement already satisfied: msgpack in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (0.5.6)\n",
"Requirement already satisfied: sortedcontainers!=2.0.0,!=2.0.1 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (2.1.0)\n",
"Requirement already satisfied: tblib in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.6.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (3.13)\n",
"Requirement already satisfied: zict>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (1.0.0)\n",
"Requirement already satisfied: click>=6.6 in /usr/local/lib/python3.6/dist-packages (from distributed>=2.6.0->autogluon) (7.0)\n",
"Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (4.1.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from catboost->autogluon) (1.12.0)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.9.4)\n",
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (0.2.1)\n",
"Requirement already satisfied: botocore<1.13.0,>=1.12.187 in /usr/local/lib/python3.6/dist-packages (from boto3==1.9.187->autogluon) (1.12.253)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (1.24.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2.8)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->autogluon) (2019.11.28)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (1.1.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->autogluon) (2.6.1)\n",
"Requirement already satisfied: cryptography>=2.5 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (2.8)\n",
"Requirement already satisfied: pynacl>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (1.3.0)\n",
"Requirement already satisfied: bcrypt>=3.1.3 in /usr/local/lib/python3.6/dist-packages (from paramiko>=2.5.0->autogluon) (3.1.7)\n",
"Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas==0.24.2->autogluon) (2018.9)\n",
"Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->gluoncv>=0.5.0->autogluon) (0.46)\n",
"Requirement already satisfied: heapdict in /usr/local/lib/python3.6/dist-packages (from zict>=0.1.3->distributed>=2.6.0->autogluon) (1.0.1)\n",
"Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly->catboost->autogluon) (1.3.3)\n",
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.187->boto3==1.9.187->autogluon) (0.15.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->autogluon) (42.0.2)\n",
"Requirement already satisfied: cffi!=1.11.3,>=1.8 in /usr/local/lib/python3.6/dist-packages (from cryptography>=2.5->paramiko>=2.5.0->autogluon) (1.13.2)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi!=1.11.3,>=1.8->cryptography>=2.5->paramiko>=2.5.0->autogluon) (2.19)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uyUG_yY6aCQ-",
"colab_type": "text"
},
"source": [
"## テーブルデータで試す\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1aTzzwWRdg5p",
"colab_type": "text"
},
"source": [
"国税調査のデータadultデータセットを用いて分類問題と回帰問題を解く。"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HonOkoU0Z5iI",
"colab_type": "code",
"outputId": "706ab950-fa18-4b5a-dcf7-ecc5e1057955",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
}
},
"source": [
"import autogluon as ag\n",
"from autogluon import TabularPrediction as task\n",
"\n",
"train_data = task.Dataset(file_path='https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')\n",
"train_data = train_data.head(2000) # subsample 2000 data points for faster demo"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv | Columns = 15 / 15 | Rows = 39073 -> 39073\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pI4yRw38Z56v",
"colab_type": "code",
"outputId": "013946ff-c350-4d4a-a93d-92b69de4556d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 377
}
},
"source": [
"train_data.head()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>178478</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Never-married</td>\n",
" <td>Tech-support</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>23</td>\n",
" <td>State-gov</td>\n",
" <td>61743</td>\n",
" <td>5th-6th</td>\n",
" <td>3</td>\n",
" <td>Never-married</td>\n",
" <td>Transport-moving</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>35</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>46</td>\n",
" <td>Private</td>\n",
" <td>376789</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Never-married</td>\n",
" <td>Other-service</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>15</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>55</td>\n",
" <td>?</td>\n",
" <td>200235</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>?</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>36</td>\n",
" <td>Private</td>\n",
" <td>224541</td>\n",
" <td>7th-8th</td>\n",
" <td>4</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>El-Salvador</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt ... hours-per-week native-country class\n",
"0 25 Private 178478 ... 40 United-States <=50K\n",
"1 23 State-gov 61743 ... 35 United-States <=50K\n",
"2 46 Private 376789 ... 15 United-States <=50K\n",
"3 55 ? 200235 ... 50 United-States >50K\n",
"4 36 Private 224541 ... 40 El-Salvador <=50K\n",
"\n",
"[5 rows x 15 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TseD-HCGaIdB",
"colab_type": "text"
},
"source": [
"### 分類\n",
"\n",
"classカラムの分類,すなわち収入が50k以上か未満かの二値分類を行う。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f8I_vW_aaue6",
"colab_type": "text"
},
"source": [
"クラスを確認する。"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xuW0zRSdZ5_9",
"colab_type": "code",
"outputId": "b67794f5-98e5-4394-9f08-d591abcb0aa2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"label_column = 'class'\n",
"print(\"Summary of class variable: \\n\", train_data[label_column].describe())"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Summary of class variable: \n",
" count 2000\n",
"unique 2\n",
"top <=50K\n",
"freq 1551\n",
"Name: class, dtype: object\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdFkEEbav0b",
"colab_type": "text"
},
"source": [
"とりあえずAutoGluonを実行する。"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fhT0RKuoacps",
"colab_type": "code",
"outputId": "2fa7082b-c172-4017-ff5a-c7b60f727c4b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 785
}
},
"source": [
"result_dir = 'agModels-predictClass' # specifies folder where to store trained models\n",
"predictor = task.fit(train_data=train_data, label=label_column, output_directory=result_dir)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Beginning AutoGluon training ...\n",
"Preprocessing data ...\n",
"Here are the first 10 unique label values in your data: [' <=50K' ' >50K']\n",
"AutoGluon infers your prediction problem is: binary (because only two unique label-values observed)\n",
"If this is wrong, please specify `problem_type` argument in fit() instead (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n",
"\n",
"Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K\n",
"\tData preprocessing and feature engineering runtime = 0.08s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: accuracy\n",
"To change this, specify the eval_metric argument of fit()\n",
"Fitting model: RandomForestClassifierGini ...\n",
"\t0.72s\t = Training runtime\n",
"\t0.8725\t = Validation accuracy score\n",
"Fitting model: RandomForestClassifierEntr ...\n",
"\t0.8s\t = Training runtime\n",
"\t0.875\t = Validation accuracy score\n",
"Fitting model: ExtraTreesClassifierGini ...\n",
"\t0.6s\t = Training runtime\n",
"\t0.8475\t = Validation accuracy score\n",
"Fitting model: ExtraTreesClassifierEntr ...\n",
"\t0.6s\t = Training runtime\n",
"\t0.8525\t = Validation accuracy score\n",
"Fitting model: KNeighborsClassifierUnif ...\n",
"\t0.01s\t = Training runtime\n",
"\t0.78\t = Validation accuracy score\n",
"Fitting model: KNeighborsClassifierDist ...\n",
"\t0.01s\t = Training runtime\n",
"\t0.7325\t = Validation accuracy score\n",
"Fitting model: LightGBMClassifier ...\n",
"\t0.39s\t = Training runtime\n",
"\t0.8775\t = Validation accuracy score\n",
"Fitting model: CatboostClassifier ...\n",
"\t2.38s\t = Training runtime\n",
"\t0.8725\t = Validation accuracy score\n",
"Fitting model: NeuralNetClassifier ...\n",
"\t15.16s\t = Training runtime\n",
"\t0.8675\t = Validation accuracy score\n",
"Fitting model: LightGBMClassifierCustom ...\n",
"\t0.72s\t = Training runtime\n",
"\t0.855\t = Validation accuracy score\n",
"Fitting model: weighted_ensemble_l1 ...\n",
"\t0.45s\t = Training runtime\n",
"\t0.88\t = Validation accuracy score\n",
"AutoGluon training complete, total runtime = 24.09s ...\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LPucLx6ha429",
"colab_type": "text"
},
"source": [
"Download the test data."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rHfXqg_EagA7",
"colab_type": "code",
"outputId": "00eed8c8-1eaa-4eb0-cbc3-70c3bf3b8271",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
}
},
"source": [
"test_data = task.Dataset(file_path='https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')\n",
"y_test = test_data[label_column]  # values to predict\n",
"test_data_nolab = test_data.drop(labels=[label_column], axis=1)  # drop the label column before predicting"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv | Columns = 15 / 15 | Rows = 9769 -> 9769\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tTNcSYRa7sg",
"colab_type": "text"
},
"source": [
"Evaluate the model."
]
},
{
"cell_type": "code",
"metadata": {
"id": "23IKFo1gagUS",
"colab_type": "code",
"outputId": "07474158-ede6-4350-ba6a-05aa63869c4f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 663
}
},
"source": [
"# Optional: shows how to reload a previously trained predictor from disk\n",
"predictor = task.load(result_dir)\n",
"\n",
"y_pred = predictor.predict(test_data_nolab)\n",
"print(\"Predictions: \", y_pred)\n",
"perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Evaluation: accuracy on test data: 0.856280\n",
"Evaluations on test data:\n",
"{\n",
" \"accuracy\": 0.8562800696079435,\n",
" \"accuracy_score\": 0.8562800696079435,\n",
" \"balanced_accuracy_score\": 0.7625358844305661,\n",
" \"matthews_corrcoef\": 0.5769073882148189,\n",
" \"f1_score\": 0.8562800696079435\n",
"}\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Predictions: [' <=50K' ' <=50K' ' >50K' ... ' <=50K' ' <=50K' ' <=50K']\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Detailed (per-class) classification report:\n",
"{\n",
" \" <=50K\": {\n",
" \"precision\": 0.8791222570532915,\n",
" \"recall\": 0.9409475238223057,\n",
" \"f1-score\": 0.9089848308051341,\n",
" \"support\": 7451\n",
" },\n",
" \" >50K\": {\n",
" \"precision\": 0.7547380156075808,\n",
" \"recall\": 0.5841242450388265,\n",
" \"f1-score\": 0.6585603112840467,\n",
" \"support\": 2318\n",
" },\n",
" \"accuracy\": 0.8562800696079435,\n",
" \"macro avg\": {\n",
" \"precision\": 0.8169301363304362,\n",
" \"recall\": 0.7625358844305661,\n",
" \"f1-score\": 0.7837725710445904,\n",
" \"support\": 9769\n",
" },\n",
" \"weighted avg\": {\n",
" \"precision\": 0.8496082155269166,\n",
" \"recall\": 0.8562800696079435,\n",
" \"f1-score\": 0.8495638014009084,\n",
" \"support\": 9769\n",
" }\n",
"}\n"
],
"name": "stderr"
}
]
},
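{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `f1_score` reported above is identical to `accuracy` (0.8563). That is what a micro-averaged F1 gives for single-label classification; that AutoGluon micro-averages here is an inference from the matching values, not something the log states.\n",
"\n",
"The cell below is a minimal sketch, not AutoGluon code. It assumes confusion-matrix counts reconstructed from the per-class report (support × recall, rounded to whole rows) and recomputes the metrics by hand. The variable names are for illustration only. The F1 for the ` >50K` class alone comes out much lower than the headline figure."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"# Counts reconstructed as support x recall from the per-class report above\n",
"# (derived by hand, not printed by AutoGluon).\n",
"tp, fn = 1354, 964   # true ' >50K' rows: predicted ' >50K' / predicted ' <=50K'\n",
"tn, fp = 7011, 440   # true ' <=50K' rows: predicted ' <=50K' / predicted ' >50K'\n",
"total = tp + fn + tn + fp\n",
"\n",
"accuracy = (tp + tn) / total\n",
"precision = tp / (tp + fp)\n",
"recall = tp / (tp + fn)\n",
"f1_positive = 2 * precision * recall / (precision + recall)\n",
"\n",
"# Micro-averaging pools the counts over both classes: every misclassified row\n",
"# is a false positive for one class and a false negative for the other.\n",
"micro_tp = tp + tn\n",
"micro_fp = fp + fn\n",
"micro_fn = fn + fp\n",
"f1_micro = 2 * micro_tp / (2 * micro_tp + micro_fp + micro_fn)\n",
"\n",
"print(f'accuracy    = {accuracy:.4f}')\n",
"print(f'F1 ( >50K)  = {f1_positive:.4f}')\n",
"print(f'F1 (micro)  = {f1_micro:.4f}')"
],
"execution_count": null,
"outputs": []
},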
{
"cell_type": "markdown",
"metadata": {
"id": "lD0LNVMJaqRI",
"colab_type": "text"
},
"source": [
"Review a summary of the model search."
]
},
{
"cell_type": "code",
"metadata": {
"id": "tAjW4yaaaliD",
"colab_type": "code",
"outputId": "0f4a0147-1c09-456b-f293-5ec47cce6fe1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 258
}
},
"source": [
"results = predictor.fit_summary()\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"*** Summary of fit() ***\n",
"Number of models trained: 11\n",
"Types of models trained: \n",
"{'KNNModel', 'TabularNeuralNetModel', 'CatboostModel', 'LGBModel', 'WeightedEnsembleModel', 'RFModel'}\n",
"Validation performance of individual models: {'RandomForestClassifierGini': 0.8725, 'RandomForestClassifierEntr': 0.875, 'ExtraTreesClassifierGini': 0.8475, 'ExtraTreesClassifierEntr': 0.8525, 'KNeighborsClassifierUnif': 0.78, 'KNeighborsClassifierDist': 0.7325, 'LightGBMClassifier': 0.8775, 'CatboostClassifier': 0.8725, 'NeuralNetClassifier': 0.8675, 'LightGBMClassifierCustom': 0.855, 'weighted_ensemble_l1': 0.88}\n",
"Best model (based on validation performance): weighted_ensemble_l1\n",
"Hyperparameter-tuning used: False\n",
"Bagging used: False \n",
"Stack-ensembling used: False \n",
"User-specified hyperparameters:\n",
"{'NN': {'num_epochs': 500}, 'GBM': {'num_boost_round': 10000}, 'CAT': {'iterations': 10000}, 'RF': {'n_estimators': 300}, 'XT': {'n_estimators': 300}, 'KNN': {}, 'custom': ['GBM']}\n",
"Plot summary of models saved to file: SummaryOfModels.html\n",
"*** End of fit() summary ***\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SYQIMvyDbEBS",
"colab_type": "text"
},
"source": [
"Check how AutoGluon interpreted the task and the columns."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dcVTZ4qjbDM4",
"colab_type": "code",
"outputId": "9f480390-c6b8-40c8-8a26-df74c3d64aaf",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
}
},
"source": [
"print(\"AutoGluon infers problem type is: \", predictor.problem_type)\n",
"print(\"AutoGluon categorized the features as: \", predictor.feature_types)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"AutoGluon infers problem type is: binary\n",
"AutoGluon categorized the features as: {'nlp': [], 'vectorizers': [], 'int': ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week'], 'object': ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']}\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5C2lmOHUaMt7",
"colab_type": "text"
},
"source": [
"### Regression\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vChVHi2bbM3k",
"colab_type": "text"
},
"source": [
"Run a regression with age as the target variable."
]
},
{
"cell_type": "code",
"metadata": {
"id": "An2PitqfaNBr",
"colab_type": "code",
"outputId": "8c0f3880-e5a3-4ef5-d356-b8a4d72b0d32",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
}
},
"source": [
"age_column = 'age'\n",
"print(\"Summary of age variable: \\n\", train_data[age_column].describe())"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Summary of age variable: \n",
" count 2000.000000\n",
"mean 38.375000\n",
"std 13.781195\n",
"min 17.000000\n",
"25% 27.000000\n",
"50% 37.000000\n",
"75% 47.000000\n",
"max 90.000000\n",
"Name: age, dtype: float64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_wkdvWk6eIoh",
"colab_type": "text"
},
"source": [
"Train a regressor for age, then evaluate it on the test data with `evaluate()`.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xxT2HQOkeFAr",
"colab_type": "code",
"outputId": "844fea81-f834-406a-de24-a47d661ee43f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
}
},
"source": [
"predictor_age = task.fit(train_data=train_data, \n",
" output_directory=\"agModels-predictAge\", \n",
" label=age_column,\n",
" problem_type='regression')\n",
"performance = predictor_age.evaluate(test_data)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Beginning AutoGluon training ...\n",
"Preprocessing data ...\n",
"\tData preprocessing and feature engineering runtime = 0.08s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: root_mean_squared_error\n",
"To change this, specify the eval_metric argument of fit()\n",
"Fitting model: RandomForestRegressorMSE ...\n",
"NumExpr defaulting to 2 threads.\n",
"\t1.51s\t = Training runtime\n",
"\t-11.08\t = Validation root_mean_squared_error score\n",
"Fitting model: ExtraTreesRegressorMSE ...\n",
"\t1.1s\t = Training runtime\n",
"\t-11.2898\t = Validation root_mean_squared_error score\n",
"Fitting model: KNeighborsRegressorUnif ...\n",
"\t0.01s\t = Training runtime\n",
"\t-14.5251\t = Validation root_mean_squared_error score\n",
"Fitting model: KNeighborsRegressorDist ...\n",
"\t0.01s\t = Training runtime\n",
"\t-15.2104\t = Validation root_mean_squared_error score\n",
"Fitting model: LightGBMRegressor ...\n",
"\t0.34s\t = Training runtime\n",
"\t-10.311\t = Validation root_mean_squared_error score\n",
"Fitting model: CatboostRegressor ...\n",
"\t2.16s\t = Training runtime\n",
"\t-10.4003\t = Validation root_mean_squared_error score\n",
"Fitting model: NeuralNetRegressor ...\n",
"\t13.07s\t = Training runtime\n",
"\t-10.6561\t = Validation root_mean_squared_error score\n",
"Fitting model: LightGBMRegressorCustom ...\n",
"\t0.61s\t = Training runtime\n",
"\t-10.5824\t = Validation root_mean_squared_error score\n",
"Fitting model: weighted_ensemble_l1 ...\n",
"\t0.43s\t = Training runtime\n",
"\t-10.193\t = Validation root_mean_squared_error score\n",
"AutoGluon training complete, total runtime = 21.12s ...\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Predictive performance on given dataset: root_mean_squared_error = 10.27827195078937\n"
],
"name": "stdout"
}
]
},
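{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The validation scores in the log above are negative. As far as I understand, AutoGluon reports every metric so that higher is better, so an error metric such as RMSE appears with its sign flipped, while `evaluate()` prints the positive value (10.28 on the test set).\n",
"\n",
"The cell below is a minimal sketch that uses the scores copied from the log, converts them back to RMSE in years, and ranks the models. The names `val_scores` and `rmse` are only for illustration."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"# Validation scores copied from the training log above\n",
"# (higher is better, so RMSE appears with a flipped sign).\n",
"val_scores = {\n",
"    'RandomForestRegressorMSE': -11.08,\n",
"    'ExtraTreesRegressorMSE': -11.2898,\n",
"    'KNeighborsRegressorUnif': -14.5251,\n",
"    'KNeighborsRegressorDist': -15.2104,\n",
"    'LightGBMRegressor': -10.311,\n",
"    'CatboostRegressor': -10.4003,\n",
"    'NeuralNetRegressor': -10.6561,\n",
"    'LightGBMRegressorCustom': -10.5824,\n",
"    'weighted_ensemble_l1': -10.193,\n",
"}\n",
"\n",
"# Negate to recover RMSE, then list the models from best to worst.\n",
"rmse = {name: -score for name, score in val_scores.items()}\n",
"for name, value in sorted(rmse.items(), key=lambda item: item[1]):\n",
"    print(f'{name:<26} {value:7.4f}')"
],
"execution_count": null,
"outputs": []
},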
{
"cell_type": "markdown",
"metadata": {
"id": "XNGO3Ds-ePk5",
"colab_type": "text"
},
"source": [
"Check the results."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pu3yKie4ePsk",
"colab_type": "code",
"outputId": "ede76091-30ce-4302-f27e-0cd50dece0a4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 258
}
},
"source": [
"results = predictor_age.fit_summary()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"*** Summary of fit() ***\n",
"Number of models trained: 9\n",
"Types of models trained: \n",
"{'KNNModel', 'TabularNeuralNetModel', 'CatboostModel', 'LGBModel', 'WeightedEnsembleModel', 'RFModel'}\n",
"Validation performance of individual models: {'RandomForestRegressorMSE': -11.079997378910637, 'ExtraTreesRegressorMSE': -11.289767107134379, 'KNeighborsRegressorUnif': -14.525081755363722, 'KNeighborsRegressorDist': -15.210358581022946, 'LightGBMRegressor': -10.31100588616801, 'CatboostRegressor': -10.400331629210777, 'NeuralNetRegressor': -10.656088855066407, 'LightGBMRegressorCustom': -10.582398535859294, 'weighted_ensemble_l1': -10.193049558344967}\n",
"Best model (based on validation performance): weighted_ensemble_l1\n",
"Hyperparameter-tuning used: False\n",
"Bagging used: False \n",
"Stack-ensembling used: False \n",
"User-specified hyperparameters:\n",
"{'NN': {'num_epochs': 500}, 'GBM': {'num_boost_round': 10000}, 'CAT': {'iterations': 10000}, 'RF': {'n_estimators': 300}, 'XT': {'n_estimators': 300}, 'KNN': {}, 'custom': ['GBM']}\n",
"Plot summary of models saved to file: SummaryOfModels.html\n",
"*** End of fit() summary ***\n"
],
"name": "stdout"
}
]
}
]
}