@nogawanogawa
Created July 7, 2021 13:47
xfeat_sample_3.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "xfeat_sample_3.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOv7qlRz1M6FuVtcevVn1GE",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nogawanogawa/6298cd61010044e6a9b5ea1e44cb6bab/xfeat_sample_3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ofldzv_uKFJ1",
"outputId": "3c3d181d-0f2e-4464-df57-1617e6770a68"
},
"source": [
"!pip install xfeat "
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: xfeat in /usr/local/lib/python3.7/dist-packages (0.1.1)\n",
"Requirement already satisfied: lightgbm in /usr/local/lib/python3.7/dist-packages (from xfeat) (2.2.3)\n",
"Requirement already satisfied: ml-metrics in /usr/local/lib/python3.7/dist-packages (from xfeat) (0.1.4)\n",
"Requirement already satisfied: optuna>=1.3.0 in /usr/local/lib/python3.7/dist-packages (from xfeat) (2.8.0)\n",
"Requirement already satisfied: pyarrow in /usr/local/lib/python3.7/dist-packages (from xfeat) (3.0.0)\n",
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from xfeat) (3.13)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from xfeat) (0.22.2.post1)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from lightgbm->xfeat) (1.19.5)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from lightgbm->xfeat) (1.4.1)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ml-metrics->xfeat) (1.1.5)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (4.41.1)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (20.9)\n",
"Requirement already satisfied: cliff in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (3.8.0)\n",
"Requirement already satisfied: colorlog in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (5.0.1)\n",
"Requirement already satisfied: alembic in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (1.6.5)\n",
"Requirement already satisfied: sqlalchemy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (1.4.18)\n",
"Requirement already satisfied: cmaes>=0.8.2 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (0.8.2)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->xfeat) (1.0.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ml-metrics->xfeat) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->ml-metrics->xfeat) (2018.9)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->optuna>=1.3.0->xfeat) (2.4.7)\n",
"Requirement already satisfied: pbr!=2.1.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (5.6.0)\n",
"Requirement already satisfied: stevedore>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (3.3.0)\n",
"Requirement already satisfied: cmd2>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (2.1.2)\n",
"Requirement already satisfied: PrettyTable>=0.7.2 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (2.1.0)\n",
"Requirement already satisfied: Mako in /usr/local/lib/python3.7/dist-packages (from alembic->optuna>=1.3.0->xfeat) (1.1.4)\n",
"Requirement already satisfied: python-editor>=0.3 in /usr/local/lib/python3.7/dist-packages (from alembic->optuna>=1.3.0->xfeat) (1.0.4)\n",
"Requirement already satisfied: greenlet!=0.4.17; python_version >= \"3\" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (1.1.0)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (4.5.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->ml-metrics->xfeat) (1.15.0)\n",
"Requirement already satisfied: wcwidth>=0.1.7 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (0.2.5)\n",
"Requirement already satisfied: pyperclip>=1.6 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (1.8.2)\n",
"Requirement already satisfied: attrs>=16.3.0 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (21.2.0)\n",
"Requirement already satisfied: typing-extensions; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (3.7.4.3)\n",
"Requirement already satisfied: colorama>=0.3.7 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (0.4.4)\n",
"Requirement already satisfied: MarkupSafe>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from Mako->alembic->optuna>=1.3.0->xfeat) (2.0.1)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (3.4.1)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "w8c4kGAuLEtv"
},
"source": [
"from functools import partial\n",
"\n",
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"import pandas as pd\n",
"import numpy as np\n",
"import lightgbm as lgb\n",
"import optuna\n",
"\n",
"from xfeat import ArithmeticCombinations, Pipeline\n",
"from xfeat import GBDTFeatureExplorer\n",
"from xfeat import SelectNumerical"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
},
"id": "1BCab-CzLNAx",
"outputId": "aa6c7a70-b5d1-4d84-93aa-0b960aaa3113"
},
"source": [
"data = fetch_california_housing()\n",
"df = pd.DataFrame(\n",
" data=data.data,\n",
" columns=data.feature_names)\n",
"df.head()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge AveRooms ... AveOccup Latitude Longitude\n",
"0 8.3252 41.0 6.984127 ... 2.555556 37.88 -122.23\n",
"1 8.3014 21.0 6.238137 ... 2.109842 37.86 -122.22\n",
"2 7.2574 52.0 8.288136 ... 2.802260 37.85 -122.24\n",
"3 5.6431 52.0 5.817352 ... 2.547945 37.85 -122.25\n",
"4 3.8462 52.0 6.281853 ... 2.181467 37.85 -122.25\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
},
"id": "c6BKkkfFLQ-1",
"outputId": "c3ed9a10-9b34-4374-ffc4-f8b4399c9df4"
},
"source": [
"SelectNumerical().fit_transform(df).head()"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge AveRooms ... AveOccup Latitude Longitude\n",
"0 8.3252 41.0 6.984127 ... 2.555556 37.88 -122.23\n",
"1 8.3014 21.0 6.238137 ... 2.109842 37.86 -122.22\n",
"2 7.2574 52.0 8.288136 ... 2.802260 37.85 -122.24\n",
"3 5.6431 52.0 5.817352 ... 2.547945 37.85 -122.25\n",
"4 3.8462 52.0 6.281853 ... 2.181467 37.85 -122.25\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
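{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SelectNumerical` keeps only the numerical columns. Every California housing feature is already a float, so the output above is identical to `df`. As a rough sketch, assuming the selection is dtype-based, plain pandas gives the same result with `select_dtypes`. This is not xfeat's implementation, and its exact dtype rules may differ."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Rough pandas analogue of SelectNumerical (a sketch, not xfeat's code)\n",
"numeric_df = df.select_dtypes(include=[np.number])\n",
"# All 8 columns are float64 here, so nothing is dropped and the order is kept\n",
"print(numeric_df.columns.tolist() == df.columns.tolist())"
],
"execution_count": null,
"outputs": []
},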
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 223
},
"id": "-8crzvzlLS5i",
"outputId": "c51b2871-cab1-4b94-f483-6bc86c8e9eda"
},
"source": [
"encoder = Pipeline([\n",
" SelectNumerical(),\n",
" ArithmeticCombinations(\n",
" drop_origin=False,\n",
" operator=\"+\",\n",
" r=3,\n",
" output_suffix=\"\",\n",
" ),\n",
"])\n",
"encoder.fit_transform(df).head()"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" <th>MedIncHouseAgeAveRooms</th>\n",
" <th>MedIncHouseAgeAveBedrms</th>\n",
" <th>MedIncHouseAgePopulation</th>\n",
" <th>MedIncHouseAgeAveOccup</th>\n",
" <th>MedIncHouseAgeLatitude</th>\n",
" <th>MedIncHouseAgeLongitude</th>\n",
" <th>MedIncAveRoomsAveBedrms</th>\n",
" <th>MedIncAveRoomsPopulation</th>\n",
" <th>MedIncAveRoomsAveOccup</th>\n",
" <th>MedIncAveRoomsLatitude</th>\n",
" <th>MedIncAveRoomsLongitude</th>\n",
" <th>MedIncAveBedrmsPopulation</th>\n",
" <th>MedIncAveBedrmsAveOccup</th>\n",
" <th>MedIncAveBedrmsLatitude</th>\n",
" <th>MedIncAveBedrmsLongitude</th>\n",
" <th>MedIncPopulationAveOccup</th>\n",
" <th>MedIncPopulationLatitude</th>\n",
" <th>MedIncPopulationLongitude</th>\n",
" <th>MedIncAveOccupLatitude</th>\n",
" <th>MedIncAveOccupLongitude</th>\n",
" <th>MedIncLatitudeLongitude</th>\n",
" <th>HouseAgeAveRoomsAveBedrms</th>\n",
" <th>HouseAgeAveRoomsPopulation</th>\n",
" <th>HouseAgeAveRoomsAveOccup</th>\n",
" <th>HouseAgeAveRoomsLatitude</th>\n",
" <th>HouseAgeAveRoomsLongitude</th>\n",
" <th>HouseAgeAveBedrmsPopulation</th>\n",
" <th>HouseAgeAveBedrmsAveOccup</th>\n",
" <th>HouseAgeAveBedrmsLatitude</th>\n",
" <th>HouseAgeAveBedrmsLongitude</th>\n",
" <th>HouseAgePopulationAveOccup</th>\n",
" <th>HouseAgePopulationLatitude</th>\n",
" <th>HouseAgePopulationLongitude</th>\n",
" <th>HouseAgeAveOccupLatitude</th>\n",
" <th>HouseAgeAveOccupLongitude</th>\n",
" <th>HouseAgeLatitudeLongitude</th>\n",
" <th>AveRoomsAveBedrmsPopulation</th>\n",
" <th>AveRoomsAveBedrmsAveOccup</th>\n",
" <th>AveRoomsAveBedrmsLatitude</th>\n",
" <th>AveRoomsAveBedrmsLongitude</th>\n",
" <th>AveRoomsPopulationAveOccup</th>\n",
" <th>AveRoomsPopulationLatitude</th>\n",
" <th>AveRoomsPopulationLongitude</th>\n",
" <th>AveRoomsAveOccupLatitude</th>\n",
" <th>AveRoomsAveOccupLongitude</th>\n",
" <th>AveRoomsLatitudeLongitude</th>\n",
" <th>AveBedrmsPopulationAveOccup</th>\n",
" <th>AveBedrmsPopulationLatitude</th>\n",
" <th>AveBedrmsPopulationLongitude</th>\n",
" <th>AveBedrmsAveOccupLatitude</th>\n",
" <th>AveBedrmsAveOccupLongitude</th>\n",
" <th>AveBedrmsLatitudeLongitude</th>\n",
" <th>PopulationAveOccupLatitude</th>\n",
" <th>PopulationAveOccupLongitude</th>\n",
" <th>PopulationLatitudeLongitude</th>\n",
" <th>AveOccupLatitudeLongitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" <td>56.309327</td>\n",
" <td>50.349010</td>\n",
" <td>371.3252</td>\n",
" <td>51.880756</td>\n",
" <td>87.2052</td>\n",
" <td>-72.9048</td>\n",
" <td>16.333137</td>\n",
" <td>337.309327</td>\n",
" <td>17.864883</td>\n",
" <td>53.189327</td>\n",
" <td>-106.920673</td>\n",
" <td>331.349010</td>\n",
" <td>11.904565</td>\n",
" <td>47.229010</td>\n",
" <td>-112.880990</td>\n",
" <td>332.880756</td>\n",
" <td>368.2052</td>\n",
" <td>208.0952</td>\n",
" <td>48.760756</td>\n",
" <td>-111.349244</td>\n",
" <td>-76.0248</td>\n",
" <td>49.007937</td>\n",
" <td>369.984127</td>\n",
" <td>50.539683</td>\n",
" <td>85.864127</td>\n",
" <td>-74.245873</td>\n",
" <td>364.023810</td>\n",
" <td>44.579365</td>\n",
" <td>79.903810</td>\n",
" <td>-80.206190</td>\n",
" <td>365.555556</td>\n",
" <td>400.88</td>\n",
" <td>240.77</td>\n",
" <td>81.435556</td>\n",
" <td>-78.674444</td>\n",
" <td>-43.35</td>\n",
" <td>330.007937</td>\n",
" <td>10.563492</td>\n",
" <td>45.887937</td>\n",
" <td>-114.222063</td>\n",
" <td>331.539683</td>\n",
" <td>366.864127</td>\n",
" <td>206.754127</td>\n",
" <td>47.419683</td>\n",
" <td>-112.690317</td>\n",
" <td>-77.365873</td>\n",
" <td>325.579365</td>\n",
" <td>360.903810</td>\n",
" <td>200.793810</td>\n",
" <td>41.459365</td>\n",
" <td>-118.650635</td>\n",
" <td>-83.326190</td>\n",
" <td>362.435556</td>\n",
" <td>202.325556</td>\n",
" <td>237.65</td>\n",
" <td>-81.794444</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" <td>35.539537</td>\n",
" <td>30.273280</td>\n",
" <td>2430.3014</td>\n",
" <td>31.411242</td>\n",
" <td>67.1614</td>\n",
" <td>-92.9186</td>\n",
" <td>15.511418</td>\n",
" <td>2415.539537</td>\n",
" <td>16.649379</td>\n",
" <td>52.399537</td>\n",
" <td>-107.680463</td>\n",
" <td>2410.273280</td>\n",
" <td>11.383122</td>\n",
" <td>47.133280</td>\n",
" <td>-112.946720</td>\n",
" <td>2411.411242</td>\n",
" <td>2447.1614</td>\n",
" <td>2287.0814</td>\n",
" <td>48.271242</td>\n",
" <td>-111.808758</td>\n",
" <td>-76.0586</td>\n",
" <td>28.210018</td>\n",
" <td>2428.238137</td>\n",
" <td>29.347979</td>\n",
" <td>65.098137</td>\n",
" <td>-94.981863</td>\n",
" <td>2422.971880</td>\n",
" <td>24.081722</td>\n",
" <td>59.831880</td>\n",
" <td>-100.248120</td>\n",
" <td>2424.109842</td>\n",
" <td>2459.86</td>\n",
" <td>2299.78</td>\n",
" <td>60.969842</td>\n",
" <td>-99.110158</td>\n",
" <td>-63.36</td>\n",
" <td>2408.210018</td>\n",
" <td>9.319859</td>\n",
" <td>45.070018</td>\n",
" <td>-115.009982</td>\n",
" <td>2409.347979</td>\n",
" <td>2445.098137</td>\n",
" <td>2285.018137</td>\n",
" <td>46.207979</td>\n",
" <td>-113.872021</td>\n",
" <td>-78.121863</td>\n",
" <td>2404.081722</td>\n",
" <td>2439.831880</td>\n",
" <td>2279.751880</td>\n",
" <td>40.941722</td>\n",
" <td>-119.138278</td>\n",
" <td>-83.388120</td>\n",
" <td>2440.969842</td>\n",
" <td>2280.889842</td>\n",
" <td>2316.64</td>\n",
" <td>-82.250158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" <td>67.545536</td>\n",
" <td>60.330846</td>\n",
" <td>555.2574</td>\n",
" <td>62.059660</td>\n",
" <td>97.1074</td>\n",
" <td>-62.9826</td>\n",
" <td>16.618982</td>\n",
" <td>511.545536</td>\n",
" <td>18.347795</td>\n",
" <td>53.395536</td>\n",
" <td>-106.694464</td>\n",
" <td>504.330846</td>\n",
" <td>11.133106</td>\n",
" <td>46.180846</td>\n",
" <td>-113.909154</td>\n",
" <td>506.059660</td>\n",
" <td>541.1074</td>\n",
" <td>381.0174</td>\n",
" <td>47.909660</td>\n",
" <td>-112.180340</td>\n",
" <td>-77.1326</td>\n",
" <td>61.361582</td>\n",
" <td>556.288136</td>\n",
" <td>63.090395</td>\n",
" <td>98.138136</td>\n",
" <td>-61.951864</td>\n",
" <td>549.073446</td>\n",
" <td>55.875706</td>\n",
" <td>90.923446</td>\n",
" <td>-69.166554</td>\n",
" <td>550.802260</td>\n",
" <td>585.85</td>\n",
" <td>425.76</td>\n",
" <td>92.652260</td>\n",
" <td>-67.437740</td>\n",
" <td>-32.39</td>\n",
" <td>505.361582</td>\n",
" <td>12.163842</td>\n",
" <td>47.211582</td>\n",
" <td>-112.878418</td>\n",
" <td>507.090395</td>\n",
" <td>542.138136</td>\n",
" <td>382.048136</td>\n",
" <td>48.940395</td>\n",
" <td>-111.149605</td>\n",
" <td>-76.101864</td>\n",
" <td>499.875706</td>\n",
" <td>534.923446</td>\n",
" <td>374.833446</td>\n",
" <td>41.725706</td>\n",
" <td>-118.364294</td>\n",
" <td>-83.316554</td>\n",
" <td>536.652260</td>\n",
" <td>376.562260</td>\n",
" <td>411.61</td>\n",
" <td>-81.587740</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" <td>63.460452</td>\n",
" <td>58.716159</td>\n",
" <td>615.6431</td>\n",
" <td>60.191045</td>\n",
" <td>95.4931</td>\n",
" <td>-64.6069</td>\n",
" <td>12.533511</td>\n",
" <td>569.460452</td>\n",
" <td>14.008397</td>\n",
" <td>49.310452</td>\n",
" <td>-110.789548</td>\n",
" <td>564.716159</td>\n",
" <td>9.264105</td>\n",
" <td>44.566159</td>\n",
" <td>-115.533841</td>\n",
" <td>566.191045</td>\n",
" <td>601.4931</td>\n",
" <td>441.3931</td>\n",
" <td>46.041045</td>\n",
" <td>-114.058955</td>\n",
" <td>-78.7569</td>\n",
" <td>58.890411</td>\n",
" <td>615.817352</td>\n",
" <td>60.365297</td>\n",
" <td>95.667352</td>\n",
" <td>-64.432648</td>\n",
" <td>611.073059</td>\n",
" <td>55.621005</td>\n",
" <td>90.923059</td>\n",
" <td>-69.176941</td>\n",
" <td>612.547945</td>\n",
" <td>647.85</td>\n",
" <td>487.75</td>\n",
" <td>92.397945</td>\n",
" <td>-67.702055</td>\n",
" <td>-32.40</td>\n",
" <td>564.890411</td>\n",
" <td>9.438356</td>\n",
" <td>44.740411</td>\n",
" <td>-115.359589</td>\n",
" <td>566.365297</td>\n",
" <td>601.667352</td>\n",
" <td>441.567352</td>\n",
" <td>46.215297</td>\n",
" <td>-113.884703</td>\n",
" <td>-78.582648</td>\n",
" <td>561.621005</td>\n",
" <td>596.923059</td>\n",
" <td>436.823059</td>\n",
" <td>41.471005</td>\n",
" <td>-118.628995</td>\n",
" <td>-83.326941</td>\n",
" <td>598.397945</td>\n",
" <td>438.297945</td>\n",
" <td>473.60</td>\n",
" <td>-81.852055</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" <td>62.128053</td>\n",
" <td>56.927281</td>\n",
" <td>620.8462</td>\n",
" <td>58.027667</td>\n",
" <td>93.6962</td>\n",
" <td>-66.4038</td>\n",
" <td>11.209134</td>\n",
" <td>575.128053</td>\n",
" <td>12.309520</td>\n",
" <td>47.978053</td>\n",
" <td>-112.121947</td>\n",
" <td>569.927281</td>\n",
" <td>7.108748</td>\n",
" <td>42.777281</td>\n",
" <td>-117.322719</td>\n",
" <td>571.027667</td>\n",
" <td>606.6962</td>\n",
" <td>446.5962</td>\n",
" <td>43.877667</td>\n",
" <td>-116.222333</td>\n",
" <td>-80.5538</td>\n",
" <td>59.362934</td>\n",
" <td>623.281853</td>\n",
" <td>60.463320</td>\n",
" <td>96.131853</td>\n",
" <td>-63.968147</td>\n",
" <td>618.081081</td>\n",
" <td>55.262548</td>\n",
" <td>90.931081</td>\n",
" <td>-69.168919</td>\n",
" <td>619.181467</td>\n",
" <td>654.85</td>\n",
" <td>494.75</td>\n",
" <td>92.031467</td>\n",
" <td>-68.068533</td>\n",
" <td>-32.40</td>\n",
" <td>572.362934</td>\n",
" <td>9.544402</td>\n",
" <td>45.212934</td>\n",
" <td>-114.887066</td>\n",
" <td>573.463320</td>\n",
" <td>609.131853</td>\n",
" <td>449.031853</td>\n",
" <td>46.313320</td>\n",
" <td>-113.786680</td>\n",
" <td>-78.118147</td>\n",
" <td>568.262548</td>\n",
" <td>603.931081</td>\n",
" <td>443.831081</td>\n",
" <td>41.112548</td>\n",
" <td>-118.987452</td>\n",
" <td>-83.318919</td>\n",
" <td>605.031467</td>\n",
" <td>444.931467</td>\n",
" <td>480.60</td>\n",
" <td>-82.218533</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge ... PopulationLatitudeLongitude AveOccupLatitudeLongitude\n",
"0 8.3252 41.0 ... 237.65 -81.794444\n",
"1 8.3014 21.0 ... 2316.64 -82.250158\n",
"2 7.2574 52.0 ... 411.61 -81.587740\n",
"3 5.6431 52.0 ... 473.60 -81.852055\n",
"4 3.8462 52.0 ... 480.60 -82.218533\n",
"\n",
"[5 rows x 64 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
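{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `r=3`, `ArithmeticCombinations` adds one column for each 3-column combination of the inputs. The 8 original columns plus C(8, 3) = 56 sums give the 64 columns above. Because `output_suffix` is empty, each new name is simply the source names concatenated. For example, row 0 of `MedIncHouseAgeAveRooms` is 8.3252 + 41.0 + 6.984127 = 56.309327. The cell below is a hand-written pandas sketch of the same idea, for illustration only. It is not xfeat's code, and `sum_combinations` is a hypothetical helper introduced here."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from itertools import combinations\n",
"\n",
"import pandas as pd\n",
"\n",
"def sum_combinations(frame, r):\n",
"    # Hypothetical helper: one new column per r-column combination,\n",
"    # named by concatenating the source column names (empty suffix)\n",
"    new_cols = {''.join(cols): frame[list(cols)].sum(axis=1)\n",
"                for cols in combinations(frame.columns, r)}\n",
"    # Keep the original columns, like drop_origin=False\n",
"    return pd.concat([frame, pd.DataFrame(new_cols, index=frame.index)], axis=1)\n",
"\n",
"manual = sum_combinations(df, r=3)\n",
"print(manual.shape[1])  # 8 + 56 = 64\n",
"print(manual['MedIncHouseAgeAveRooms'].iloc[0])"
],
"execution_count": null,
"outputs": []
},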
{
"cell_type": "code",
"metadata": {
"id": "NPPO4qieLULL"
},
"source": [
"def evaluate_dataframe(df, y):\n",
" # The same 50/50 split (random_state=1) is reused later in feature_selection()\n",
" X_train, X_test, y_train, y_test = train_test_split(df.values, y, test_size=0.5, random_state=1)\n",
" # Fit on log1p(target), so RMSE on this scale is the RMSLE of the original target\n",
" y_train = np.log1p(y_train)\n",
"\n",
" params = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"rmse\",\n",
" \"learning_rate\": 0.1,\n",
" \"verbosity\": -1,\n",
" }\n",
" train_set = lgb.Dataset(X_train, label=y_train)\n",
" # Cross-validate on the training half (LightGBM's default nfold=5)\n",
" scores = lgb.cv(params, train_set, num_boost_round=100, stratified=False, seed=1)\n",
" rmsle_score = scores[\"rmse-mean\"][-1]\n",
" print(f\" - CV RMSLE: {rmsle_score:.6f}\")\n",
"\n",
" # Refit on the full training half and score the held-out half\n",
" booster = lgb.train(params, train_set, num_boost_round=100)\n",
" y_pred = booster.predict(X_test)\n",
" test_rmsle_score = rmse(np.log1p(y_test), y_pred)\n",
" print(f\" - test RMSLE: {test_rmsle_score:.6f}\")\n",
"\n",
"def rmse(y_true, y_pred):\n",
" return np.sqrt(mean_squared_error(y_true, y_pred))"
],
"execution_count": 6,
"outputs": []
},
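{
"cell_type": "markdown",
"metadata": {},
"source": [
"`evaluate_dataframe` fits on `np.log1p(y)` and reports RMSE on that scale. This equals the RMSLE (root mean squared logarithmic error) of the original target. The check below uses made-up toy values to confirm that three ways of writing the metric agree. Predictions made in log space can be mapped back to the original scale with `np.expm1`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sklearn.metrics import mean_squared_error, mean_squared_log_error\n",
"\n",
"# Toy values, invented for illustration only\n",
"y_true = np.array([0.5, 1.2, 3.4, 5.0])\n",
"y_pred = np.array([0.6, 1.0, 3.0, 5.2])\n",
"\n",
"# 1) RMSLE written out explicitly\n",
"by_hand = np.sqrt(np.mean((np.log1p(y_pred) - np.log1p(y_true)) ** 2))\n",
"# 2) RMSE on log1p-transformed values, which is what evaluate_dataframe() reports\n",
"via_rmse = np.sqrt(mean_squared_error(np.log1p(y_true), np.log1p(y_pred)))\n",
"# 3) scikit-learn's mean squared log error, square-rooted\n",
"via_msle = np.sqrt(mean_squared_log_error(y_true, y_pred))\n",
"\n",
"assert np.isclose(by_hand, via_rmse) and np.isclose(by_hand, via_msle)\n",
"print(by_hand)"
],
"execution_count": null,
"outputs": []
},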
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bKgOf5KBLiqR",
"outputId": "8e733691-13dc-4bd5-d74b-63f8f7610ccf"
},
"source": [
"print(\"Before adding interaction features:\")\n",
"evaluate_dataframe(df, data.target)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Before adding interaction features:\n",
" - CV RMSEL: 0.143823\n",
" - test RMSEL: 0.140610\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RBuRF431LlTL"
},
"source": [
"def feature_engineering(df):\n",
" cols = df.columns.tolist()\n",
"\n",
" encoder = Pipeline([\n",
" ArithmeticCombinations(input_cols=cols,\n",
" drop_origin=False,\n",
" operator=\"+\",\n",
" r=2,\n",
" output_suffix=\"_plus\"),\n",
" ArithmeticCombinations(input_cols=cols,\n",
" drop_origin=False,\n",
" operator=\"*\",\n",
" r=2,\n",
" output_suffix=\"_mul\"),\n",
" ArithmeticCombinations(input_cols=cols,\n",
" drop_origin=False,\n",
" operator=\"-\",\n",
" r=2,\n",
" output_suffix=\"_minus\"),\n",
" ArithmeticCombinations(input_cols=cols,\n",
" drop_origin=False,\n",
" operator=\"+\",\n",
" r=3,\n",
" output_suffix=\"_plus\"),\n",
" ])\n",
" return encoder.fit_transform(df)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wsk5WE83LoD7",
"outputId": "fcae48b0-645f-4e9e-c81f-d14636d6221d"
},
"source": [
"print(\"After adding interaction features:\")\n",
"df = feature_engineering(df)\n",
"evaluate_dataframe(df, data.target)\n"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"After adding interaction features:\n",
" - CV RMSEL: 0.140239\n",
" - test RMSEL: 0.137046\n"
],
"name": "stdout"
}
]
},
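{
"cell_type": "markdown",
"metadata": {},
"source": [
"`feature_engineering` chains four `ArithmeticCombinations` stages. Each stage receives `input_cols=cols`, which is the eight original columns, and `drop_origin=False`. The generated features should therefore accumulate without being combined with each other again. Assume that each stage adds one column per combination and tags it with its `output_suffix`. Under that assumption, the expected counts are C(8, 2) = 28 each for `_mul` and `_minus`, and 28 + 56 = 84 for `_plus`, because the r=2 and r=3 sums share that suffix. The cell below counts the columns actually produced so that this assumption can be checked."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Count the generated features per suffix (run after the cell above)\n",
"print('total columns:', df.shape[1])\n",
"for suffix in ['_plus', '_mul', '_minus']:\n",
"    n_cols = sum(col.endswith(suffix) for col in df.columns)\n",
"    print(suffix, n_cols)"
],
"execution_count": null,
"outputs": []
},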
{
"cell_type": "code",
"metadata": {
"id": "y5B5IE5mLseS"
},
"source": [
"def objective(df, selector, trial):\n",
" selector.set_trial(trial)\n",
" selector.fit(df)\n",
" input_cols = selector.get_selected_cols()\n",
"\n",
" params = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"rmse\",\n",
" \"learning_rate\": 0.1,\n",
" \"verbosity\": -1,\n",
" }\n",
"\n",
" # Evaluate with selected columns\n",
" train_set = lgb.Dataset(df[input_cols], label=df[\"target\"])\n",
" scores = lgb.cv(params, train_set, num_boost_round=100, stratified=False, seed=1)\n",
" rmsle_score = scores[\"rmse-mean\"][-1]\n",
" return rmsle_score\n",
"\n",
"def feature_selection(df, y):\n",
" input_cols = df.columns.tolist()\n",
" n_before_selection = len(input_cols)\n",
"\n",
" df[\"target\"] = np.log1p(y)\n",
" df_train, _ = train_test_split(df, test_size=0.5, random_state=1)\n",
"\n",
" params = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"rmse\",\n",
" \"learning_rate\": 0.1,\n",
" \"verbosity\": -1,\n",
" }\n",
" fit_params = {\n",
" \"num_boost_round\": 100,\n",
" }\n",
" selector = GBDTFeatureExplorer(input_cols=input_cols,\n",
" target_col=\"target\",\n",
" fit_once=True,\n",
" threshold_range=(0.6, 1.0),\n",
" lgbm_params=params,\n",
" lgbm_fit_kwargs=fit_params)\n",
"\n",
" study = optuna.create_study(direction=\"minimize\")\n",
" study.optimize(partial(objective, df_train, selector), n_trials=20)\n",
"\n",
" selector.from_trial(study.best_trial)\n",
" selected_cols = selector.get_selected_cols()\n",
" print(f\" - {n_before_selection - len(selected_cols)} features are removed.\")\n",
"\n",
" return df[selected_cols]\n"
],
"execution_count": 10,
"outputs": []
},
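{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GBDTFeatureExplorer` lets Optuna tune the selector's `threshold` within `threshold_range`. Judging by its name and the `lgbm_params` it takes, it selects features using LightGBM feature importance. The xfeat internals are not shown in this notebook, so the cell below is a simplified stand-in and not xfeat's algorithm. It ranks features by LightGBM gain importance on the training half and keeps a fixed fraction of them. `top_features_by_gain` and `keep_ratio` are hypothetical names introduced here."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def top_features_by_gain(X, y, keep_ratio, params, num_boost_round=100):\n",
"    # Hypothetical helper: rank columns by total gain and keep the top fraction\n",
"    booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=num_boost_round)\n",
"    gain = pd.Series(booster.feature_importance(importance_type='gain'), index=X.columns)\n",
"    n_keep = max(1, int(keep_ratio * X.shape[1]))\n",
"    return gain.sort_values(ascending=False).index[:n_keep].tolist()\n",
"\n",
"# Same training half and log1p target that feature_selection() uses\n",
"X_train, _, y_train, _ = train_test_split(df, np.log1p(data.target), test_size=0.5, random_state=1)\n",
"params = {'objective': 'regression', 'metric': 'rmse', 'learning_rate': 0.1, 'verbosity': -1}\n",
"kept = top_features_by_gain(X_train, y_train, keep_ratio=0.9, params=params)\n",
"print(len(kept), 'of', df.shape[1], 'features kept')"
],
"execution_count": null,
"outputs": []
},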
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zYr3OK6tLyi_",
"outputId": "41ac14e2-e141-4a2c-fc32-e6d0bb1d85b4"
},
"source": [
"print(\"After applying GBDTFeatureSelector:\")\n",
"df = feature_selection(df, data.target)\n",
"evaluate_dataframe(df, data.target)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[32m[I 2021-07-07 13:42:31,873]\u001b[0m A new study created in memory with name: no-name-864366cc-1513-48f1-bf2e-520bcf0dfdfd\u001b[0m\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"After applying GBDTFeatureSelector:\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\u001b[32m[I 2021-07-07 13:42:44,986]\u001b[0m Trial 0 finished with value: 0.13979310476290702 and parameters: {'GBDTFeatureSelector.threshold': 0.6323723743056863}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:42:55,737]\u001b[0m Trial 1 finished with value: 0.1400646078199391 and parameters: {'GBDTFeatureSelector.threshold': 0.7144803627038809}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:43:09,875]\u001b[0m Trial 2 finished with value: 0.1398197395831723 and parameters: {'GBDTFeatureSelector.threshold': 0.9273308702240257}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:43:22,438]\u001b[0m Trial 3 finished with value: 0.140407080489603 and parameters: {'GBDTFeatureSelector.threshold': 0.8311188333026339}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:43:36,347]\u001b[0m Trial 4 finished with value: 0.13982575551305265 and parameters: {'GBDTFeatureSelector.threshold': 0.9166904887328864}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:43:48,320]\u001b[0m Trial 5 finished with value: 0.14026843282987683 and parameters: {'GBDTFeatureSelector.threshold': 0.7966576382646842}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:43:57,994]\u001b[0m Trial 6 finished with value: 0.14019537142311078 and parameters: {'GBDTFeatureSelector.threshold': 0.6483096443284599}. Best is trial 0 with value: 0.13979310476290702.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:44:12,164]\u001b[0m Trial 7 finished with value: 0.13976970670197986 and parameters: {'GBDTFeatureSelector.threshold': 0.9362222845986482}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:44:21,332]\u001b[0m Trial 8 finished with value: 0.139868083859196 and parameters: {'GBDTFeatureSelector.threshold': 0.6066806228284278}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:44:31,102]\u001b[0m Trial 9 finished with value: 0.14019537142311078 and parameters: {'GBDTFeatureSelector.threshold': 0.6466075876661166}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:44:46,223]\u001b[0m Trial 10 finished with value: 0.13993060595933537 and parameters: {'GBDTFeatureSelector.threshold': 0.9891604741450514}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:44:58,336]\u001b[0m Trial 11 finished with value: 0.14026843282987683 and parameters: {'GBDTFeatureSelector.threshold': 0.7987916251476358}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:45:13,527]\u001b[0m Trial 12 finished with value: 0.14019327941546061 and parameters: {'GBDTFeatureSelector.threshold': 0.9996359199564369}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:45:26,800]\u001b[0m Trial 13 finished with value: 0.13983160378966167 and parameters: {'GBDTFeatureSelector.threshold': 0.8729919664652268}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:45:37,914]\u001b[0m Trial 14 finished with value: 0.13999552399082213 and parameters: {'GBDTFeatureSelector.threshold': 0.7354805314376343}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:45:48,734]\u001b[0m Trial 15 finished with value: 0.1400646078199391 and parameters: {'GBDTFeatureSelector.threshold': 0.7141932953817927}. Best is trial 7 with value: 0.13976970670197986.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:46:03,348]\u001b[0m Trial 16 finished with value: 0.13964761866999725 and parameters: {'GBDTFeatureSelector.threshold': 0.9548019754279418}. Best is trial 16 with value: 0.13964761866999725.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:46:17,991]\u001b[0m Trial 17 finished with value: 0.13964761866999725 and parameters: {'GBDTFeatureSelector.threshold': 0.9573316355560226}. Best is trial 16 with value: 0.13964761866999725.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:46:32,909]\u001b[0m Trial 18 finished with value: 0.13988513329571153 and parameters: {'GBDTFeatureSelector.threshold': 0.9685334625000206}. Best is trial 16 with value: 0.13964761866999725.\u001b[0m\n",
"\u001b[32m[I 2021-07-07 13:46:46,275]\u001b[0m Trial 19 finished with value: 0.13996029128482462 and parameters: {'GBDTFeatureSelector.threshold': 0.8825238432565827}. Best is trial 16 with value: 0.13964761866999725.\u001b[0m\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
" - 18 features are removed.\n",
" - CV RMSEL: 0.139960\n",
" - test RMSEL: 0.136711\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uhMo0XgTLzDg"
},
"source": [
""
],
"execution_count": 11,
"outputs": []
}
]
}