Skip to content

Instantly share code, notes, and snippets.

@nogawanogawa
Created July 7, 2021 13:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nogawanogawa/6c247f4c1c3f7c64d0b119b4c1625968 to your computer and use it in GitHub Desktop.
Save nogawanogawa/6c247f4c1c3f7c64d0b119b4c1625968 to your computer and use it in GitHub Desktop.
xfeat_sample_2.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "xfeat_sample_2.ipynb",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"authorship_tag": "ABX9TyNAkMWM4ow7uzyEdmWcnz5T",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nogawanogawa/6c247f4c1c3f7c64d0b119b4c1625968/xfeat_sample_2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e1CD2xBfCX7g",
"outputId": "67673932-c342-46fc-a432-d07df0f1943c"
},
"source": [
"# Install into the running kernel's environment and pin the xfeat version this notebook was run with.\n",
"%pip install xfeat==0.1.1"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: xfeat in /usr/local/lib/python3.7/dist-packages (0.1.1)\n",
"Requirement already satisfied: lightgbm in /usr/local/lib/python3.7/dist-packages (from xfeat) (2.2.3)\n",
"Requirement already satisfied: optuna>=1.3.0 in /usr/local/lib/python3.7/dist-packages (from xfeat) (2.8.0)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from xfeat) (0.22.2.post1)\n",
"Requirement already satisfied: ml-metrics in /usr/local/lib/python3.7/dist-packages (from xfeat) (0.1.4)\n",
"Requirement already satisfied: pyarrow in /usr/local/lib/python3.7/dist-packages (from xfeat) (3.0.0)\n",
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from xfeat) (3.13)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from lightgbm->xfeat) (1.19.5)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from lightgbm->xfeat) (1.4.1)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (20.9)\n",
"Requirement already satisfied: cmaes>=0.8.2 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (0.8.2)\n",
"Requirement already satisfied: cliff in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (3.8.0)\n",
"Requirement already satisfied: colorlog in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (5.0.1)\n",
"Requirement already satisfied: sqlalchemy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (1.4.18)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (4.41.1)\n",
"Requirement already satisfied: alembic in /usr/local/lib/python3.7/dist-packages (from optuna>=1.3.0->xfeat) (1.6.5)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->xfeat) (1.0.1)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ml-metrics->xfeat) (1.1.5)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->optuna>=1.3.0->xfeat) (2.4.7)\n",
"Requirement already satisfied: cmd2>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (2.1.2)\n",
"Requirement already satisfied: pbr!=2.1.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (5.6.0)\n",
"Requirement already satisfied: PrettyTable>=0.7.2 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (2.1.0)\n",
"Requirement already satisfied: stevedore>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna>=1.3.0->xfeat) (3.3.0)\n",
"Requirement already satisfied: greenlet!=0.4.17; python_version >= \"3\" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (1.1.0)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (4.5.0)\n",
"Requirement already satisfied: Mako in /usr/local/lib/python3.7/dist-packages (from alembic->optuna>=1.3.0->xfeat) (1.1.4)\n",
"Requirement already satisfied: python-editor>=0.3 in /usr/local/lib/python3.7/dist-packages (from alembic->optuna>=1.3.0->xfeat) (1.0.4)\n",
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.7/dist-packages (from alembic->optuna>=1.3.0->xfeat) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->ml-metrics->xfeat) (2018.9)\n",
"Requirement already satisfied: wcwidth>=0.1.7 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (0.2.5)\n",
"Requirement already satisfied: attrs>=16.3.0 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (21.2.0)\n",
"Requirement already satisfied: colorama>=0.3.7 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (0.4.4)\n",
"Requirement already satisfied: typing-extensions; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (3.7.4.3)\n",
"Requirement already satisfied: pyperclip>=1.6 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna>=1.3.0->xfeat) (1.8.2)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->sqlalchemy>=1.1.0->optuna>=1.3.0->xfeat) (3.4.1)\n",
"Requirement already satisfied: MarkupSafe>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from Mako->alembic->optuna>=1.3.0->xfeat) (2.0.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil->alembic->optuna>=1.3.0->xfeat) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DUluQoZEBnNK"
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import lightgbm as lgb\n",
"import xfeat\n",
"from xfeat import SelectNumerical\n",
"from xfeat import ArithmeticCombinations, Pipeline\n",
"from xfeat import GBDTFeatureSelector\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.datasets import fetch_california_housing"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
},
"id": "Wck5ivz0BxfW",
"outputId": "2c69a9be-8ead-457d-8ab7-d40975f2a211"
},
"source": [
"# Load the California housing features into a DataFrame.\n",
"# `data` stays in scope because later cells read the regression target from `data.target`.\n",
"data = fetch_california_housing()\n",
"df = pd.DataFrame(data.data, columns=data.feature_names)\n",
"df.head()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge AveRooms ... AveOccup Latitude Longitude\n",
"0 8.3252 41.0 6.984127 ... 2.555556 37.88 -122.23\n",
"1 8.3014 21.0 6.238137 ... 2.109842 37.86 -122.22\n",
"2 7.2574 52.0 8.288136 ... 2.802260 37.85 -122.24\n",
"3 5.6431 52.0 5.817352 ... 2.547945 37.85 -122.25\n",
"4 3.8462 52.0 6.281853 ... 2.181467 37.85 -122.25\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203
},
"id": "9Yn2oHRcC9Dq",
"outputId": "39fc8c40-fdf8-4acc-be5e-250f6798b82a"
},
"source": [
"# Every column of this dataset is numeric, so the selector returns all eight columns.\n",
"(\n",
"    SelectNumerical()\n",
"    .fit_transform(df)\n",
"    .head()\n",
")"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge AveRooms ... AveOccup Latitude Longitude\n",
"0 8.3252 41.0 6.984127 ... 2.555556 37.88 -122.23\n",
"1 8.3014 21.0 6.238137 ... 2.109842 37.86 -122.22\n",
"2 7.2574 52.0 8.288136 ... 2.802260 37.85 -122.24\n",
"3 5.6431 52.0 5.817352 ... 2.547945 37.85 -122.25\n",
"4 3.8462 52.0 6.281853 ... 2.181467 37.85 -122.25\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 223
},
"id": "G6-utLXKDEMj",
"outputId": "d27ff95f-ba8f-4f52-f5c6-7522370fafd0"
},
"source": [
"# Demo: keep the original columns and append the sum of every 3-column combination.\n",
"# The new columns are named by concatenating the source column names (empty suffix).\n",
"encoder = Pipeline([\n",
"    SelectNumerical(),\n",
"    ArithmeticCombinations(operator=\"+\", r=3, drop_origin=False, output_suffix=\"\"),\n",
"])\n",
"encoder.fit_transform(df).head()"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MedInc</th>\n",
" <th>HouseAge</th>\n",
" <th>AveRooms</th>\n",
" <th>AveBedrms</th>\n",
" <th>Population</th>\n",
" <th>AveOccup</th>\n",
" <th>Latitude</th>\n",
" <th>Longitude</th>\n",
" <th>MedIncHouseAgeAveRooms</th>\n",
" <th>MedIncHouseAgeAveBedrms</th>\n",
" <th>MedIncHouseAgePopulation</th>\n",
" <th>MedIncHouseAgeAveOccup</th>\n",
" <th>MedIncHouseAgeLatitude</th>\n",
" <th>MedIncHouseAgeLongitude</th>\n",
" <th>MedIncAveRoomsAveBedrms</th>\n",
" <th>MedIncAveRoomsPopulation</th>\n",
" <th>MedIncAveRoomsAveOccup</th>\n",
" <th>MedIncAveRoomsLatitude</th>\n",
" <th>MedIncAveRoomsLongitude</th>\n",
" <th>MedIncAveBedrmsPopulation</th>\n",
" <th>MedIncAveBedrmsAveOccup</th>\n",
" <th>MedIncAveBedrmsLatitude</th>\n",
" <th>MedIncAveBedrmsLongitude</th>\n",
" <th>MedIncPopulationAveOccup</th>\n",
" <th>MedIncPopulationLatitude</th>\n",
" <th>MedIncPopulationLongitude</th>\n",
" <th>MedIncAveOccupLatitude</th>\n",
" <th>MedIncAveOccupLongitude</th>\n",
" <th>MedIncLatitudeLongitude</th>\n",
" <th>HouseAgeAveRoomsAveBedrms</th>\n",
" <th>HouseAgeAveRoomsPopulation</th>\n",
" <th>HouseAgeAveRoomsAveOccup</th>\n",
" <th>HouseAgeAveRoomsLatitude</th>\n",
" <th>HouseAgeAveRoomsLongitude</th>\n",
" <th>HouseAgeAveBedrmsPopulation</th>\n",
" <th>HouseAgeAveBedrmsAveOccup</th>\n",
" <th>HouseAgeAveBedrmsLatitude</th>\n",
" <th>HouseAgeAveBedrmsLongitude</th>\n",
" <th>HouseAgePopulationAveOccup</th>\n",
" <th>HouseAgePopulationLatitude</th>\n",
" <th>HouseAgePopulationLongitude</th>\n",
" <th>HouseAgeAveOccupLatitude</th>\n",
" <th>HouseAgeAveOccupLongitude</th>\n",
" <th>HouseAgeLatitudeLongitude</th>\n",
" <th>AveRoomsAveBedrmsPopulation</th>\n",
" <th>AveRoomsAveBedrmsAveOccup</th>\n",
" <th>AveRoomsAveBedrmsLatitude</th>\n",
" <th>AveRoomsAveBedrmsLongitude</th>\n",
" <th>AveRoomsPopulationAveOccup</th>\n",
" <th>AveRoomsPopulationLatitude</th>\n",
" <th>AveRoomsPopulationLongitude</th>\n",
" <th>AveRoomsAveOccupLatitude</th>\n",
" <th>AveRoomsAveOccupLongitude</th>\n",
" <th>AveRoomsLatitudeLongitude</th>\n",
" <th>AveBedrmsPopulationAveOccup</th>\n",
" <th>AveBedrmsPopulationLatitude</th>\n",
" <th>AveBedrmsPopulationLongitude</th>\n",
" <th>AveBedrmsAveOccupLatitude</th>\n",
" <th>AveBedrmsAveOccupLongitude</th>\n",
" <th>AveBedrmsLatitudeLongitude</th>\n",
" <th>PopulationAveOccupLatitude</th>\n",
" <th>PopulationAveOccupLongitude</th>\n",
" <th>PopulationLatitudeLongitude</th>\n",
" <th>AveOccupLatitudeLongitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.3252</td>\n",
" <td>41.0</td>\n",
" <td>6.984127</td>\n",
" <td>1.023810</td>\n",
" <td>322.0</td>\n",
" <td>2.555556</td>\n",
" <td>37.88</td>\n",
" <td>-122.23</td>\n",
" <td>56.309327</td>\n",
" <td>50.349010</td>\n",
" <td>371.3252</td>\n",
" <td>51.880756</td>\n",
" <td>87.2052</td>\n",
" <td>-72.9048</td>\n",
" <td>16.333137</td>\n",
" <td>337.309327</td>\n",
" <td>17.864883</td>\n",
" <td>53.189327</td>\n",
" <td>-106.920673</td>\n",
" <td>331.349010</td>\n",
" <td>11.904565</td>\n",
" <td>47.229010</td>\n",
" <td>-112.880990</td>\n",
" <td>332.880756</td>\n",
" <td>368.2052</td>\n",
" <td>208.0952</td>\n",
" <td>48.760756</td>\n",
" <td>-111.349244</td>\n",
" <td>-76.0248</td>\n",
" <td>49.007937</td>\n",
" <td>369.984127</td>\n",
" <td>50.539683</td>\n",
" <td>85.864127</td>\n",
" <td>-74.245873</td>\n",
" <td>364.023810</td>\n",
" <td>44.579365</td>\n",
" <td>79.903810</td>\n",
" <td>-80.206190</td>\n",
" <td>365.555556</td>\n",
" <td>400.88</td>\n",
" <td>240.77</td>\n",
" <td>81.435556</td>\n",
" <td>-78.674444</td>\n",
" <td>-43.35</td>\n",
" <td>330.007937</td>\n",
" <td>10.563492</td>\n",
" <td>45.887937</td>\n",
" <td>-114.222063</td>\n",
" <td>331.539683</td>\n",
" <td>366.864127</td>\n",
" <td>206.754127</td>\n",
" <td>47.419683</td>\n",
" <td>-112.690317</td>\n",
" <td>-77.365873</td>\n",
" <td>325.579365</td>\n",
" <td>360.903810</td>\n",
" <td>200.793810</td>\n",
" <td>41.459365</td>\n",
" <td>-118.650635</td>\n",
" <td>-83.326190</td>\n",
" <td>362.435556</td>\n",
" <td>202.325556</td>\n",
" <td>237.65</td>\n",
" <td>-81.794444</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.3014</td>\n",
" <td>21.0</td>\n",
" <td>6.238137</td>\n",
" <td>0.971880</td>\n",
" <td>2401.0</td>\n",
" <td>2.109842</td>\n",
" <td>37.86</td>\n",
" <td>-122.22</td>\n",
" <td>35.539537</td>\n",
" <td>30.273280</td>\n",
" <td>2430.3014</td>\n",
" <td>31.411242</td>\n",
" <td>67.1614</td>\n",
" <td>-92.9186</td>\n",
" <td>15.511418</td>\n",
" <td>2415.539537</td>\n",
" <td>16.649379</td>\n",
" <td>52.399537</td>\n",
" <td>-107.680463</td>\n",
" <td>2410.273280</td>\n",
" <td>11.383122</td>\n",
" <td>47.133280</td>\n",
" <td>-112.946720</td>\n",
" <td>2411.411242</td>\n",
" <td>2447.1614</td>\n",
" <td>2287.0814</td>\n",
" <td>48.271242</td>\n",
" <td>-111.808758</td>\n",
" <td>-76.0586</td>\n",
" <td>28.210018</td>\n",
" <td>2428.238137</td>\n",
" <td>29.347979</td>\n",
" <td>65.098137</td>\n",
" <td>-94.981863</td>\n",
" <td>2422.971880</td>\n",
" <td>24.081722</td>\n",
" <td>59.831880</td>\n",
" <td>-100.248120</td>\n",
" <td>2424.109842</td>\n",
" <td>2459.86</td>\n",
" <td>2299.78</td>\n",
" <td>60.969842</td>\n",
" <td>-99.110158</td>\n",
" <td>-63.36</td>\n",
" <td>2408.210018</td>\n",
" <td>9.319859</td>\n",
" <td>45.070018</td>\n",
" <td>-115.009982</td>\n",
" <td>2409.347979</td>\n",
" <td>2445.098137</td>\n",
" <td>2285.018137</td>\n",
" <td>46.207979</td>\n",
" <td>-113.872021</td>\n",
" <td>-78.121863</td>\n",
" <td>2404.081722</td>\n",
" <td>2439.831880</td>\n",
" <td>2279.751880</td>\n",
" <td>40.941722</td>\n",
" <td>-119.138278</td>\n",
" <td>-83.388120</td>\n",
" <td>2440.969842</td>\n",
" <td>2280.889842</td>\n",
" <td>2316.64</td>\n",
" <td>-82.250158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.2574</td>\n",
" <td>52.0</td>\n",
" <td>8.288136</td>\n",
" <td>1.073446</td>\n",
" <td>496.0</td>\n",
" <td>2.802260</td>\n",
" <td>37.85</td>\n",
" <td>-122.24</td>\n",
" <td>67.545536</td>\n",
" <td>60.330846</td>\n",
" <td>555.2574</td>\n",
" <td>62.059660</td>\n",
" <td>97.1074</td>\n",
" <td>-62.9826</td>\n",
" <td>16.618982</td>\n",
" <td>511.545536</td>\n",
" <td>18.347795</td>\n",
" <td>53.395536</td>\n",
" <td>-106.694464</td>\n",
" <td>504.330846</td>\n",
" <td>11.133106</td>\n",
" <td>46.180846</td>\n",
" <td>-113.909154</td>\n",
" <td>506.059660</td>\n",
" <td>541.1074</td>\n",
" <td>381.0174</td>\n",
" <td>47.909660</td>\n",
" <td>-112.180340</td>\n",
" <td>-77.1326</td>\n",
" <td>61.361582</td>\n",
" <td>556.288136</td>\n",
" <td>63.090395</td>\n",
" <td>98.138136</td>\n",
" <td>-61.951864</td>\n",
" <td>549.073446</td>\n",
" <td>55.875706</td>\n",
" <td>90.923446</td>\n",
" <td>-69.166554</td>\n",
" <td>550.802260</td>\n",
" <td>585.85</td>\n",
" <td>425.76</td>\n",
" <td>92.652260</td>\n",
" <td>-67.437740</td>\n",
" <td>-32.39</td>\n",
" <td>505.361582</td>\n",
" <td>12.163842</td>\n",
" <td>47.211582</td>\n",
" <td>-112.878418</td>\n",
" <td>507.090395</td>\n",
" <td>542.138136</td>\n",
" <td>382.048136</td>\n",
" <td>48.940395</td>\n",
" <td>-111.149605</td>\n",
" <td>-76.101864</td>\n",
" <td>499.875706</td>\n",
" <td>534.923446</td>\n",
" <td>374.833446</td>\n",
" <td>41.725706</td>\n",
" <td>-118.364294</td>\n",
" <td>-83.316554</td>\n",
" <td>536.652260</td>\n",
" <td>376.562260</td>\n",
" <td>411.61</td>\n",
" <td>-81.587740</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5.6431</td>\n",
" <td>52.0</td>\n",
" <td>5.817352</td>\n",
" <td>1.073059</td>\n",
" <td>558.0</td>\n",
" <td>2.547945</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" <td>63.460452</td>\n",
" <td>58.716159</td>\n",
" <td>615.6431</td>\n",
" <td>60.191045</td>\n",
" <td>95.4931</td>\n",
" <td>-64.6069</td>\n",
" <td>12.533511</td>\n",
" <td>569.460452</td>\n",
" <td>14.008397</td>\n",
" <td>49.310452</td>\n",
" <td>-110.789548</td>\n",
" <td>564.716159</td>\n",
" <td>9.264105</td>\n",
" <td>44.566159</td>\n",
" <td>-115.533841</td>\n",
" <td>566.191045</td>\n",
" <td>601.4931</td>\n",
" <td>441.3931</td>\n",
" <td>46.041045</td>\n",
" <td>-114.058955</td>\n",
" <td>-78.7569</td>\n",
" <td>58.890411</td>\n",
" <td>615.817352</td>\n",
" <td>60.365297</td>\n",
" <td>95.667352</td>\n",
" <td>-64.432648</td>\n",
" <td>611.073059</td>\n",
" <td>55.621005</td>\n",
" <td>90.923059</td>\n",
" <td>-69.176941</td>\n",
" <td>612.547945</td>\n",
" <td>647.85</td>\n",
" <td>487.75</td>\n",
" <td>92.397945</td>\n",
" <td>-67.702055</td>\n",
" <td>-32.40</td>\n",
" <td>564.890411</td>\n",
" <td>9.438356</td>\n",
" <td>44.740411</td>\n",
" <td>-115.359589</td>\n",
" <td>566.365297</td>\n",
" <td>601.667352</td>\n",
" <td>441.567352</td>\n",
" <td>46.215297</td>\n",
" <td>-113.884703</td>\n",
" <td>-78.582648</td>\n",
" <td>561.621005</td>\n",
" <td>596.923059</td>\n",
" <td>436.823059</td>\n",
" <td>41.471005</td>\n",
" <td>-118.628995</td>\n",
" <td>-83.326941</td>\n",
" <td>598.397945</td>\n",
" <td>438.297945</td>\n",
" <td>473.60</td>\n",
" <td>-81.852055</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.8462</td>\n",
" <td>52.0</td>\n",
" <td>6.281853</td>\n",
" <td>1.081081</td>\n",
" <td>565.0</td>\n",
" <td>2.181467</td>\n",
" <td>37.85</td>\n",
" <td>-122.25</td>\n",
" <td>62.128053</td>\n",
" <td>56.927281</td>\n",
" <td>620.8462</td>\n",
" <td>58.027667</td>\n",
" <td>93.6962</td>\n",
" <td>-66.4038</td>\n",
" <td>11.209134</td>\n",
" <td>575.128053</td>\n",
" <td>12.309520</td>\n",
" <td>47.978053</td>\n",
" <td>-112.121947</td>\n",
" <td>569.927281</td>\n",
" <td>7.108748</td>\n",
" <td>42.777281</td>\n",
" <td>-117.322719</td>\n",
" <td>571.027667</td>\n",
" <td>606.6962</td>\n",
" <td>446.5962</td>\n",
" <td>43.877667</td>\n",
" <td>-116.222333</td>\n",
" <td>-80.5538</td>\n",
" <td>59.362934</td>\n",
" <td>623.281853</td>\n",
" <td>60.463320</td>\n",
" <td>96.131853</td>\n",
" <td>-63.968147</td>\n",
" <td>618.081081</td>\n",
" <td>55.262548</td>\n",
" <td>90.931081</td>\n",
" <td>-69.168919</td>\n",
" <td>619.181467</td>\n",
" <td>654.85</td>\n",
" <td>494.75</td>\n",
" <td>92.031467</td>\n",
" <td>-68.068533</td>\n",
" <td>-32.40</td>\n",
" <td>572.362934</td>\n",
" <td>9.544402</td>\n",
" <td>45.212934</td>\n",
" <td>-114.887066</td>\n",
" <td>573.463320</td>\n",
" <td>609.131853</td>\n",
" <td>449.031853</td>\n",
" <td>46.313320</td>\n",
" <td>-113.786680</td>\n",
" <td>-78.118147</td>\n",
" <td>568.262548</td>\n",
" <td>603.931081</td>\n",
" <td>443.831081</td>\n",
" <td>41.112548</td>\n",
" <td>-118.987452</td>\n",
" <td>-83.318919</td>\n",
" <td>605.031467</td>\n",
" <td>444.931467</td>\n",
" <td>480.60</td>\n",
" <td>-82.218533</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MedInc HouseAge ... PopulationLatitudeLongitude AveOccupLatitudeLongitude\n",
"0 8.3252 41.0 ... 237.65 -81.794444\n",
"1 8.3014 21.0 ... 2316.64 -82.250158\n",
"2 7.2574 52.0 ... 411.61 -81.587740\n",
"3 5.6431 52.0 ... 473.60 -81.852055\n",
"4 3.8462 52.0 ... 480.60 -82.218533\n",
"\n",
"[5 rows x 64 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vm99k0voGU_D"
},
"source": [
"def evaluate_dataframe(df, y):\n",
"    \"\"\"Print the cross-validated and held-out RMSE of a LightGBM regressor fitted on log1p(y).\n",
"\n",
"    The rows are split 50/50 with a fixed random_state. The training half is\n",
"    scored with ``lgb.cv`` and then used to fit a booster that is evaluated on\n",
"    the other half. Both scores are RMSE measured on ``log1p`` of the target,\n",
"    i.e. RMSLE.\n",
"\n",
"    Args:\n",
"        df: Feature DataFrame; only ``df.values`` is used.\n",
"        y: Target values aligned row-for-row with ``df``.\n",
"\n",
"    Returns:\n",
"        None. The two scores are printed.\n",
"\n",
"    Raises:\n",
"        KeyError: If the ``lgb.cv`` result has no ``rmse-mean`` entry.\n",
"    \"\"\"\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        df.values, y, test_size=0.5, random_state=1\n",
"    )\n",
"    y_train = np.log1p(y_train)\n",
"\n",
"    params = {\n",
"        \"objective\": \"regression\",\n",
"        \"metric\": \"rmse\",\n",
"        \"learning_rate\": 0.1,\n",
"        \"verbosity\": -1,\n",
"    }\n",
"    train_set = lgb.Dataset(X_train, label=y_train)\n",
"    scores = lgb.cv(params, train_set, num_boost_round=100, stratified=False, seed=1)\n",
"    # LightGBM < 4.0 names the CV metric \"rmse-mean\"; 4.0+ prefixes the key\n",
"    # (\"valid rmse-mean\"), so match on the suffix instead of the exact key.\n",
"    rmse_mean_key = next((key for key in scores if key.endswith(\"rmse-mean\")), None)\n",
"    if rmse_mean_key is None:\n",
"        raise KeyError(f\"no 'rmse-mean' entry in lgb.cv results: {list(scores)}\")\n",
"    rmsle_score = scores[rmse_mean_key][-1]\n",
"    # The \"RMSEL\" label is kept verbatim so the printed text matches the recorded outputs.\n",
"    print(f\" - CV RMSEL: {rmsle_score:.6f}\")\n",
"\n",
"    booster = lgb.train(params, train_set, num_boost_round=100)\n",
"    y_pred = booster.predict(X_test)\n",
"    # The booster was trained on log1p(y), so compare against log1p of the held-out target.\n",
"    test_rmsle_score = rmse(np.log1p(y_test), y_pred)\n",
"    print(f\" - test RMSEL: {test_rmsle_score:.6f}\")\n",
"\n",
"def rmse(y_true, y_pred):\n",
"    \"\"\"Return the square root of the mean squared error between `y_true` and `y_pred`.\"\"\"\n",
"    mse = mean_squared_error(y_true, y_pred)\n",
"    return np.sqrt(mse)\n"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QDE4PzRAGXT4",
"outputId": "8774db65-341b-4de4-e128-aaa90dca251b"
},
"source": [
"# Baseline: score the raw features before any interaction features are added.\n",
"print(\"Before adding interaction features:\")\n",
"evaluate_dataframe(df, y=data.target)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Before adding interaction features:\n",
" - CV RMSEL: 0.143823\n",
" - test RMSEL: 0.140610\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OKS3SpOvGZfj"
},
"source": [
"def feature_engineering(df):\n",
"    \"\"\"Append arithmetic combinations of the columns of `df` and return the result.\n",
"\n",
"    Four xfeat ``ArithmeticCombinations`` steps run in order. Each one takes\n",
"    the original columns of ``df`` as ``input_cols`` and keeps them\n",
"    (``drop_origin=False``): pairwise \"+\", \"*\" and \"-\", then triple-wise \"+\".\n",
"\n",
"    Args:\n",
"        df: DataFrame whose columns are all combined.\n",
"\n",
"    Returns:\n",
"        The DataFrame produced by fitting and applying the pipeline to ``df``.\n",
"    \"\"\"\n",
"    cols = df.columns.tolist()\n",
"\n",
"    # (operator, combination size r, suffix appended to the generated column names)\n",
"    combination_specs = [\n",
"        (\"+\", 2, \"_plus\"),\n",
"        (\"*\", 2, \"_mul\"),\n",
"        (\"-\", 2, \"_minus\"),\n",
"        (\"+\", 3, \"_plus\"),\n",
"    ]\n",
"    steps = [\n",
"        ArithmeticCombinations(\n",
"            input_cols=cols,\n",
"            drop_origin=False,\n",
"            operator=op,\n",
"            r=size,\n",
"            output_suffix=suffix,\n",
"        )\n",
"        for op, size, suffix in combination_specs\n",
"    ]\n",
"    return Pipeline(steps).fit_transform(df)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Nzb6IjrnJ6uQ",
"outputId": "fcc08084-c0bb-432d-c39f-8917a6111c25"
},
"source": [
"print(\"After adding interaction features:\")\n",
"# Rebuild the raw frame from `data` instead of reading `df`, so re-running this cell\n",
"# does not stack interaction features on top of already-engineered columns.\n",
"df = feature_engineering(pd.DataFrame(data.data, columns=data.feature_names))\n",
"evaluate_dataframe(df, data.target)\n"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"After adding interaction features:\n",
" - CV RMSEL: 0.140239\n",
" - test RMSEL: 0.137046\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_Chquf60KDa8"
},
"source": [
"def feature_selection(df, y):\n",
"    \"\"\"Return the columns of `df` kept by xfeat's ``GBDTFeatureSelector``.\n",
"\n",
"    The selector is fitted against ``log1p(y)`` on the training half of a\n",
"    50/50 split (``random_state=1``), so the held-out half does not influence\n",
"    which columns are kept. The number of removed columns is printed.\n",
"\n",
"    Args:\n",
"        df: Feature DataFrame; every column is a selection candidate. It is\n",
"            not modified.\n",
"        y: Target values aligned row-for-row with ``df``.\n",
"\n",
"    Returns:\n",
"        ``df`` restricted to the selected columns.\n",
"    \"\"\"\n",
"    input_cols = df.columns.tolist()\n",
"    n_before_selection = len(input_cols)\n",
"\n",
"    # Attach the target on a copy: assigning into `df` would add a \"target\"\n",
"    # column to the caller's DataFrame as a side effect.\n",
"    df_with_target = df.assign(target=np.log1p(y))\n",
"    df_train, _ = train_test_split(df_with_target, test_size=0.5, random_state=1)\n",
"\n",
"    params = {\n",
"        \"objective\": \"regression\",\n",
"        \"metric\": \"rmse\",\n",
"        \"learning_rate\": 0.1,\n",
"        \"verbosity\": -1,\n",
"    }\n",
"    fit_params = {\n",
"        \"num_boost_round\": 100,\n",
"    }\n",
"\n",
"    selector = GBDTFeatureSelector(\n",
"        input_cols=input_cols,\n",
"        target_col=\"target\",\n",
"        threshold=0.95,\n",
"        lgbm_params=params,\n",
"        lgbm_fit_kwargs=fit_params,\n",
"    )\n",
"\n",
"    selector.fit(df_train)\n",
"    selected_cols = selector.get_selected_cols()\n",
"    print(f\" - {n_before_selection - len(selected_cols)} features are removed.\")\n",
"\n",
"    return df[selected_cols]"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dK_PhUVYKC6V",
"outputId": "9009463e-d3be-4e42-93b2-465cc6485440"
},
"source": [
"print(\"After applying GBDTFeatureSelector:\")\n",
"# NOTE: `df` is rebound to the selected columns, so re-running this cell on its own\n",
"# selects again from the already-reduced frame.\n",
"df = feature_selection(df, y=data.target)\n",
"evaluate_dataframe(df, y=data.target)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"After applying GBDTFeatureSelector:\n",
" - 8 features are removed.\n",
" - CV RMSEL: 0.139648\n",
" - test RMSEL: 0.137029\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "zktd-cypLZZo"
},
"source": [
""
],
"execution_count": 11,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment