Skip to content

Instantly share code, notes, and snippets.

@Dixhom
Created October 10, 2022 02:29
Show Gist options
  • Save Dixhom/bc5c6de519e48718571f4905bbac59e4 to your computer and use it in GitHub Desktop.
Save Dixhom/bc5c6de519e48718571f4905bbac59e4 to your computer and use it in GitHub Desktop.
tree-model-multicollinearity

What is this?

This is an experiment to check whether multicollinearity (highly correlated features) harms a tree-based machine learning model. It was done purely out of curiosity: I wondered whether multicollinearity affects tree-based models and degrades their performance. If it does, correlated features would need to be removed before building models.

Method

I artificially generated three kinds of datasets, scored each with LightGBM using stratified 5-fold cross-validation, and compared the final scores. Each kind of dataset was generated and scored 100 times.

  1. dataset 1: 10 informative features (useful features for prediction)
  2. dataset 2: 10 informative features + 10 features with random values
  3. dataset 3: 10 informative features + 10 redundant features (random linear combinations of the informative features)

The experiment above was repeated with two class proportions: (0, 1) = (0.5, 0.5) (balanced) and (0, 1) = (0.99, 0.01) (highly imbalanced).

Result

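The score boxplots of the three datasets were almost the same for both class proportions. Under the conditions used in this experiment, adding redundant (multicollinear) features did not degrade the performance of a tree-based model. Caveat: the stored scores (accuracy about 0.59 and recall about 0.19 on balanced data) are far worse than the validation log-loss reached during training (about 0.16) implies. This suggests that only one of the five folds' predictions ended up in the out-of-fold array. All three datasets were affected in the same way, so the comparison is like-for-like, but the absolute values should not be taken at face value.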

Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
"import lightgbm as lgb\n",
"import matplotlib.pyplot as plt "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# make_classification is called with shuffle=False so that the feature columns keep\n",
"# their documented order (informative, then redundant, then random noise) and the\n",
"# names assigned below describe each column's real role. With the default\n",
"# shuffle=True the columns are permuted and the info/rand/red labels would be wrong.\n",
"# Row order does not matter here: the folds are drawn by StratifiedKFold(shuffle=True).\n",
"\n",
"# 10 informative features\n",
"def generate_informative_df(class_weights=[0.5]):\n",
"    \"\"\"Return a DataFrame with 10 informative features and the target column 'y'.\n",
"\n",
"    class_weights is forwarded to make_classification(weights=...).\n",
"    \"\"\"\n",
"    X, y = make_classification(n_samples=10000, n_features=10, n_informative=10, n_redundant=0,\n",
"                               n_classes=2, weights=class_weights, shuffle=False)\n",
"    df = pd.DataFrame(X, columns=[f'info{i}' for i in range(10)])\n",
"    df['y'] = y\n",
"    return df\n",
"\n",
"# 10 informative features + 10 random features\n",
"def generate_informative_random_df(class_weights=[0.5]):\n",
"    \"\"\"Return 10 informative + 10 pure-noise features and the target column 'y'.\n",
"\n",
"    class_weights is forwarded to make_classification(weights=...).\n",
"    \"\"\"\n",
"    X, y = make_classification(n_samples=10000, n_features=20, n_informative=10, n_redundant=0,\n",
"                               n_classes=2, weights=class_weights, shuffle=False)\n",
"    df = pd.DataFrame(X, columns=[f'info{i}' for i in range(10)] + [f'rand{i}' for i in range(10)])\n",
"    df['y'] = y\n",
"    return df\n",
"\n",
"# 10 informative features + 10 redundant features (multicollinearity)\n",
"def generate_informative_redundant_df(class_weights=[0.5]):\n",
"    \"\"\"Return 10 informative + 10 redundant (linear-combination) features and 'y'.\n",
"\n",
"    class_weights is forwarded to make_classification(weights=...).\n",
"    \"\"\"\n",
"    X, y = make_classification(n_samples=10000, n_features=20, n_informative=10, n_redundant=10,\n",
"                               n_classes=2, weights=class_weights, shuffle=False)\n",
"    df = pd.DataFrame(X, columns=[f'info{i}' for i in range(10)] + [f'red{i}' for i in range(10)])\n",
"    df['y'] = y\n",
"    return df"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>info0</th>\n",
" <th>info1</th>\n",
" <th>info2</th>\n",
" <th>info3</th>\n",
" <th>info4</th>\n",
" <th>info5</th>\n",
" <th>info6</th>\n",
" <th>info7</th>\n",
" <th>info8</th>\n",
" <th>info9</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2.773325</td>\n",
" <td>0.889765</td>\n",
" <td>4.426880</td>\n",
" <td>-3.113184</td>\n",
" <td>4.747506</td>\n",
" <td>-1.178276</td>\n",
" <td>-2.529489</td>\n",
" <td>-2.395454</td>\n",
" <td>-0.883241</td>\n",
" <td>3.458592</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.911636</td>\n",
" <td>0.521006</td>\n",
" <td>-0.886217</td>\n",
" <td>2.298344</td>\n",
" <td>-1.322928</td>\n",
" <td>-0.709265</td>\n",
" <td>0.344617</td>\n",
" <td>-2.144182</td>\n",
" <td>-0.286680</td>\n",
" <td>-1.365262</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-0.271112</td>\n",
" <td>0.240984</td>\n",
" <td>-0.993360</td>\n",
" <td>-0.098098</td>\n",
" <td>-0.714097</td>\n",
" <td>0.126976</td>\n",
" <td>1.390515</td>\n",
" <td>-2.398267</td>\n",
" <td>4.365493</td>\n",
" <td>0.799304</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" info0 info1 info2 info3 info4 info5 info6 \\\n",
"0 2.773325 0.889765 4.426880 -3.113184 4.747506 -1.178276 -2.529489 \n",
"1 1.911636 0.521006 -0.886217 2.298344 -1.322928 -0.709265 0.344617 \n",
"2 -0.271112 0.240984 -0.993360 -0.098098 -0.714097 0.126976 1.390515 \n",
"\n",
" info7 info8 info9 y \n",
"0 -2.395454 -0.883241 3.458592 1 \n",
"1 -2.144182 -0.286680 -1.365262 0 \n",
"2 -2.398267 4.365493 0.799304 1 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_informative_df().head(3)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>info0</th>\n",
" <th>info1</th>\n",
" <th>info2</th>\n",
" <th>info3</th>\n",
" <th>info4</th>\n",
" <th>info5</th>\n",
" <th>info6</th>\n",
" <th>info7</th>\n",
" <th>info8</th>\n",
" <th>info9</th>\n",
" <th>...</th>\n",
" <th>rand1</th>\n",
" <th>rand2</th>\n",
" <th>rand3</th>\n",
" <th>rand4</th>\n",
" <th>rand5</th>\n",
" <th>rand6</th>\n",
" <th>rand7</th>\n",
" <th>rand8</th>\n",
" <th>rand9</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.936563</td>\n",
" <td>1.870453</td>\n",
" <td>0.465916</td>\n",
" <td>-1.261388</td>\n",
" <td>2.133361</td>\n",
" <td>1.803505</td>\n",
" <td>2.580378</td>\n",
" <td>-1.595381</td>\n",
" <td>0.679063</td>\n",
" <td>-0.645615</td>\n",
" <td>...</td>\n",
" <td>-2.378653</td>\n",
" <td>4.323241</td>\n",
" <td>-2.079817</td>\n",
" <td>0.081981</td>\n",
" <td>0.040277</td>\n",
" <td>1.587782</td>\n",
" <td>0.094462</td>\n",
" <td>-0.418136</td>\n",
" <td>-1.090805</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.184026</td>\n",
" <td>-2.168789</td>\n",
" <td>0.411869</td>\n",
" <td>-0.942830</td>\n",
" <td>1.797832</td>\n",
" <td>0.030025</td>\n",
" <td>-2.441538</td>\n",
" <td>0.781152</td>\n",
" <td>0.594451</td>\n",
" <td>1.105893</td>\n",
" <td>...</td>\n",
" <td>0.443111</td>\n",
" <td>1.711014</td>\n",
" <td>0.709275</td>\n",
" <td>-4.459221</td>\n",
" <td>-0.652913</td>\n",
" <td>3.639128</td>\n",
" <td>1.089707</td>\n",
" <td>1.391340</td>\n",
" <td>-0.512643</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2.996100</td>\n",
" <td>-4.153566</td>\n",
" <td>1.262521</td>\n",
" <td>-1.465341</td>\n",
" <td>-2.303118</td>\n",
" <td>-0.769839</td>\n",
" <td>1.964011</td>\n",
" <td>3.439831</td>\n",
" <td>-0.543817</td>\n",
" <td>0.434642</td>\n",
" <td>...</td>\n",
" <td>0.139599</td>\n",
" <td>-2.905087</td>\n",
" <td>0.447791</td>\n",
" <td>1.162886</td>\n",
" <td>7.065497</td>\n",
" <td>4.141545</td>\n",
" <td>-0.199840</td>\n",
" <td>-0.675612</td>\n",
" <td>-1.475501</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" info0 info1 info2 info3 info4 info5 info6 \\\n",
"0 1.936563 1.870453 0.465916 -1.261388 2.133361 1.803505 2.580378 \n",
"1 0.184026 -2.168789 0.411869 -0.942830 1.797832 0.030025 -2.441538 \n",
"2 2.996100 -4.153566 1.262521 -1.465341 -2.303118 -0.769839 1.964011 \n",
"\n",
" info7 info8 info9 ... rand1 rand2 rand3 rand4 \\\n",
"0 -1.595381 0.679063 -0.645615 ... -2.378653 4.323241 -2.079817 0.081981 \n",
"1 0.781152 0.594451 1.105893 ... 0.443111 1.711014 0.709275 -4.459221 \n",
"2 3.439831 -0.543817 0.434642 ... 0.139599 -2.905087 0.447791 1.162886 \n",
"\n",
" rand5 rand6 rand7 rand8 rand9 y \n",
"0 0.040277 1.587782 0.094462 -0.418136 -1.090805 0 \n",
"1 -0.652913 3.639128 1.089707 1.391340 -0.512643 0 \n",
"2 7.065497 4.141545 -0.199840 -0.675612 -1.475501 1 \n",
"\n",
"[3 rows x 21 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_informative_random_df().head(3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>info0</th>\n",
" <th>info1</th>\n",
" <th>info2</th>\n",
" <th>info3</th>\n",
" <th>info4</th>\n",
" <th>info5</th>\n",
" <th>info6</th>\n",
" <th>info7</th>\n",
" <th>info8</th>\n",
" <th>info9</th>\n",
" <th>...</th>\n",
" <th>red1</th>\n",
" <th>red2</th>\n",
" <th>red3</th>\n",
" <th>red4</th>\n",
" <th>red5</th>\n",
" <th>red6</th>\n",
" <th>red7</th>\n",
" <th>red8</th>\n",
" <th>red9</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.932563</td>\n",
" <td>3.456379</td>\n",
" <td>-0.507142</td>\n",
" <td>1.473530</td>\n",
" <td>2.651997</td>\n",
" <td>7.762541</td>\n",
" <td>2.374229</td>\n",
" <td>-2.677372</td>\n",
" <td>-0.607400</td>\n",
" <td>-1.481389</td>\n",
" <td>...</td>\n",
" <td>4.553652</td>\n",
" <td>2.596553</td>\n",
" <td>1.600653</td>\n",
" <td>-0.682390</td>\n",
" <td>-5.318137</td>\n",
" <td>-0.888237</td>\n",
" <td>-4.065887</td>\n",
" <td>-0.531629</td>\n",
" <td>-4.780433</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.368720</td>\n",
" <td>-3.069189</td>\n",
" <td>4.671624</td>\n",
" <td>-1.197378</td>\n",
" <td>-2.293645</td>\n",
" <td>-6.662906</td>\n",
" <td>-2.335615</td>\n",
" <td>-3.380183</td>\n",
" <td>0.645226</td>\n",
" <td>-0.509708</td>\n",
" <td>...</td>\n",
" <td>-0.467492</td>\n",
" <td>0.501271</td>\n",
" <td>1.678550</td>\n",
" <td>-2.214877</td>\n",
" <td>-2.199863</td>\n",
" <td>1.706975</td>\n",
" <td>-1.803508</td>\n",
" <td>-3.336058</td>\n",
" <td>1.702677</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.336312</td>\n",
" <td>-3.582316</td>\n",
" <td>5.304764</td>\n",
" <td>-0.455649</td>\n",
" <td>-0.389859</td>\n",
" <td>-3.792496</td>\n",
" <td>-2.612149</td>\n",
" <td>-0.239425</td>\n",
" <td>4.188700</td>\n",
" <td>0.295560</td>\n",
" <td>...</td>\n",
" <td>2.018401</td>\n",
" <td>-0.419778</td>\n",
" <td>-1.394204</td>\n",
" <td>-0.752825</td>\n",
" <td>-0.294768</td>\n",
" <td>3.402908</td>\n",
" <td>3.007048</td>\n",
" <td>-1.435240</td>\n",
" <td>1.028247</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" info0 info1 info2 info3 info4 info5 info6 \\\n",
"0 -0.932563 3.456379 -0.507142 1.473530 2.651997 7.762541 2.374229 \n",
"1 1.368720 -3.069189 4.671624 -1.197378 -2.293645 -6.662906 -2.335615 \n",
"2 0.336312 -3.582316 5.304764 -0.455649 -0.389859 -3.792496 -2.612149 \n",
"\n",
" info7 info8 info9 ... red1 red2 red3 red4 \\\n",
"0 -2.677372 -0.607400 -1.481389 ... 4.553652 2.596553 1.600653 -0.682390 \n",
"1 -3.380183 0.645226 -0.509708 ... -0.467492 0.501271 1.678550 -2.214877 \n",
"2 -0.239425 4.188700 0.295560 ... 2.018401 -0.419778 -1.394204 -0.752825 \n",
"\n",
" red5 red6 red7 red8 red9 y \n",
"0 -5.318137 -0.888237 -4.065887 -0.531629 -4.780433 0 \n",
"1 -2.199863 1.706975 -1.803508 -3.336058 1.702677 0 \n",
"2 -0.294768 3.402908 3.007048 -1.435240 1.028247 0 \n",
"\n",
"[3 rows x 21 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_informative_redundant_df().head(3)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def get_lgb_oof(df, return_proba=False, n_splits=5, early_stopping_rounds=10):\n",
"    \"\"\"Get out-of-fold (OOF) predictions using LightGBM and stratified k-fold CV.\n",
"\n",
"    Every row of df is predicted exactly once, by the model trained on the other folds.\n",
"\n",
"    df: feature columns plus the binary target column 'y'.\n",
"    return_proba: if True, return the predicted probability of class 1 for each row;\n",
"        otherwise (default) return 0.0/1.0 labels thresholded at 0.5.\n",
"    n_splits: number of CV folds.\n",
"    early_stopping_rounds: stop training a fold once the validation logloss has not\n",
"        improved for this many rounds.\n",
"    \"\"\"\n",
"    X = df.drop('y', axis=1).values\n",
"    y = df.y.values\n",
"\n",
"    # LightGBM hyperparameters (the same for every fold)\n",
"    lgbm_params = {\n",
"        'objective': 'binary',\n",
"        'metric': 'binary_logloss',\n",
"        'verbose': -1,\n",
"    }\n",
"\n",
"    # stratified kfold split\n",
"    kf = StratifiedKFold(n_splits=n_splits, shuffle=True)\n",
"    oof = np.zeros(len(y))\n",
"\n",
"    # cv iterate through splits\n",
"    for train_index, eval_index in kf.split(X, y):\n",
"        X_train, X_eval = X[train_index], X[eval_index]\n",
"        y_train, y_eval = y[train_index], y[eval_index]\n",
"\n",
"        # prepare datasets\n",
"        lgb_train = lgb.Dataset(X_train, y_train)\n",
"        lgb_eval = lgb.Dataset(X_eval, y_eval, reference=lgb_train)\n",
"\n",
"        model = lgb.train(lgbm_params, lgb_train,\n",
"                          # validation data for the model\n",
"                          valid_sets=[lgb_eval],\n",
"                          # train up to 10000 rounds\n",
"                          num_boost_round=10000,\n",
"                          # stop once the validation logloss stops improving; the callback\n",
"                          # replaces the early_stopping_rounds= keyword removed in LightGBM 4.0\n",
"                          callbacks=[lgb.early_stopping(stopping_rounds=early_stopping_rounds)])\n",
"\n",
"        # predict this fold's holdout rows; this must stay inside the loop so that\n",
"        # every fold, not just one, contributes to the OOF array\n",
"        oof[eval_index] = model.predict(X_eval, num_iteration=model.best_iteration)\n",
"\n",
"    # return only after all folds have been filled in\n",
"    if return_proba:\n",
"        return oof\n",
"    return (oof > 0.5).astype(float)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def get_major_scores(y_true, y_pred, y_score=None):\n",
"    \"\"\"Get accuracy, precision, recall, f1 and ROC AUC at once as a pandas Series.\n",
"\n",
"    y_pred: hard 0/1 predictions, used for accuracy, precision, recall and f1.\n",
"    y_score: continuous scores (e.g. predicted probabilities) used for ROC AUC.\n",
"        If None, y_pred is used, which reduces the AUC to balanced accuracy.\n",
"    \"\"\"\n",
"    if y_score is None:\n",
"        y_score = y_pred\n",
"    # zero_division=0: with 1% positives a run may predict no positives at all;\n",
"    # report 0 explicitly instead of emitting UndefinedMetricWarning\n",
"    accuracy = accuracy_score(y_true, y_pred)\n",
"    precision = precision_score(y_true, y_pred, zero_division=0)\n",
"    recall = recall_score(y_true, y_pred, zero_division=0)\n",
"    f1 = f1_score(y_true, y_pred, zero_division=0)\n",
"    auc = roc_auc_score(y_true, y_score)\n",
"    return pd.Series(dict(accuracy=accuracy, precision=precision, recall=recall, f1=f1, auc=auc))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def get_scores_from_func(func, class_weights):\n",
"    \"\"\"Generate a dataset with func, get its OOF predictions and score them.\"\"\"\n",
"    df = func(class_weights)\n",
"    oof_proba = get_lgb_oof(df, return_proba=True)\n",
"    # thresholded labels for accuracy/precision/recall/f1; probabilities for ROC AUC\n",
"    oof_label = (oof_proba > 0.5).astype(int)\n",
"    scores = get_major_scores(df.y, oof_label, y_score=oof_proba)\n",
"    return scores\n",
"\n",
"def repeat_tests(n_tests, func, class_weights):\n",
"    \"\"\"Repeat get_scores_from_func n_tests times; returns one row of scores per run.\"\"\"\n",
"    score_list = [get_scores_from_func(func, class_weights) for _ in range(n_tests)]\n",
"    return pd.concat(score_list, axis=1).T"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"%%capture cap\n",
"# generate datasets and perform tests\n",
"n_tests = 100\n",
"info = repeat_tests(n_tests, generate_informative_df, class_weights=[0.5])\n",
"info_rand = repeat_tests(n_tests, generate_informative_random_df, class_weights=[0.5])\n",
"info_red = repeat_tests(n_tests, generate_informative_redundant_df, class_weights=[0.5])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"[1]\\tvalid_0's binary_logloss: 0.640252\", \"Training until validation scores don't improve for 10 rounds\", \"[2]\\tvalid_0's binary_logloss: 0.595524\"]\n",
"['Early stopping, best iteration is:', \"[202]\\tvalid_0's binary_logloss: 0.15658\", '']\n"
]
}
],
"source": [
"# stdout results\n",
"capsplit = cap.stdout.split('\\n')\n",
"print(capsplit[:3])\n",
"print(capsplit[-3:])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# concat data\n",
"info['group'] = 'info'\n",
"info_rand['group'] = 'info + rand'\n",
"info_red['group'] = 'info + redundant'\n",
"stack = pd.concat([info, info_rand, info_red])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>accuracy</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>f1</th>\n",
" <th>auc</th>\n",
" <th>group</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.5911</td>\n",
" <td>0.954092</td>\n",
" <td>0.191238</td>\n",
" <td>0.318614</td>\n",
" <td>0.591020</td>\n",
" <td>info</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.5906</td>\n",
" <td>0.959100</td>\n",
" <td>0.187901</td>\n",
" <td>0.314238</td>\n",
" <td>0.589957</td>\n",
" <td>info</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.5895</td>\n",
" <td>0.935357</td>\n",
" <td>0.191229</td>\n",
" <td>0.317539</td>\n",
" <td>0.589023</td>\n",
" <td>info</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.5932</td>\n",
" <td>0.967904</td>\n",
" <td>0.192961</td>\n",
" <td>0.321774</td>\n",
" <td>0.593280</td>\n",
" <td>info</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.5935</td>\n",
" <td>0.969031</td>\n",
" <td>0.193845</td>\n",
" <td>0.323064</td>\n",
" <td>0.593820</td>\n",
" <td>info</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>0.5948</td>\n",
" <td>0.978873</td>\n",
" <td>0.194444</td>\n",
" <td>0.324441</td>\n",
" <td>0.595121</td>\n",
" <td>info + redundant</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>0.5937</td>\n",
" <td>0.971972</td>\n",
" <td>0.193967</td>\n",
" <td>0.323397</td>\n",
" <td>0.594180</td>\n",
" <td>info + redundant</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>0.5906</td>\n",
" <td>0.956389</td>\n",
" <td>0.188827</td>\n",
" <td>0.315385</td>\n",
" <td>0.590118</td>\n",
" <td>info + redundant</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>0.5917</td>\n",
" <td>0.956436</td>\n",
" <td>0.193007</td>\n",
" <td>0.321197</td>\n",
" <td>0.592099</td>\n",
" <td>info + redundant</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>0.5880</td>\n",
" <td>0.939638</td>\n",
" <td>0.187024</td>\n",
" <td>0.311957</td>\n",
" <td>0.587519</td>\n",
" <td>info + redundant</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>300 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" accuracy precision recall f1 auc group\n",
"0 0.5911 0.954092 0.191238 0.318614 0.591020 info\n",
"1 0.5906 0.959100 0.187901 0.314238 0.589957 info\n",
"2 0.5895 0.935357 0.191229 0.317539 0.589023 info\n",
"3 0.5932 0.967904 0.192961 0.321774 0.593280 info\n",
"4 0.5935 0.969031 0.193845 0.323064 0.593820 info\n",
".. ... ... ... ... ... ...\n",
"95 0.5948 0.978873 0.194444 0.324441 0.595121 info + redundant\n",
"96 0.5937 0.971972 0.193967 0.323397 0.594180 info + redundant\n",
"97 0.5906 0.956389 0.188827 0.315385 0.590118 info + redundant\n",
"98 0.5917 0.956436 0.193007 0.321197 0.592099 info + redundant\n",
"99 0.5880 0.939638 0.187024 0.311957 0.587519 info + redundant\n",
"\n",
"[300 rows x 6 columns]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stack"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1440x1080 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# make a boxplot for each score\n",
"fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(20,15))\n",
"score_names = stack.drop('group', axis=1).columns\n",
"for i, score_name in enumerate(score_names):\n",
" stack.boxplot(column=score_name, by='group', ax=axes[i // 3, i % 3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about an imbalanced dataset?"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"%%capture cap\n",
"n_tests = 100\n",
"info = repeat_tests(n_tests, generate_informative_df, class_weights=[0.99])\n",
"info_rand = repeat_tests(n_tests, generate_informative_random_df, class_weights=[0.99])\n",
"info_red = repeat_tests(n_tests, generate_informative_redundant_df, class_weights=[0.99])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"[1]\\tvalid_0's binary_logloss: 0.065233\", \"Training until validation scores don't improve for 10 rounds\", \"[2]\\tvalid_0's binary_logloss: 0.0623185\"]\n",
"['Early stopping, best iteration is:', \"[32]\\tvalid_0's binary_logloss: 0.0501468\", '']\n"
]
}
],
"source": [
"capsplit = cap.stdout.split('\\n')\n",
"print(capsplit[:3])\n",
"print(capsplit[-3:])"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"info['group'] = 'info'\n",
"info_rand['group'] = 'info + rand'\n",
"info_red['group'] = 'info + redundant'\n",
"stack = pd.concat([info, info_rand, info_red])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1440x1080 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(20,15))\n",
"score_names = stack.drop('group', axis=1).columns\n",
"for i, score_name in enumerate(score_names):\n",
" stack.boxplot(column=score_name, by='group', ax=axes[i // 3, i % 3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Almost the same!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6.13 64-bit ('usr')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "91500301695e1fbb9150d4f634352c3c795d8e09e120b79156f6e946456f2571"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment