Data Preprocessing Project - Imbalanced Classes Problem
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Preprocessing Project - Imbalanced Classes Problem\n",
"\n",
"\n",
"Imbalanced classes is one of the major problems in machine learning. In this data preprocessing project, I discuss the imbalanced classes problem. I present Python implementation to deal with this problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table of Contents\n",
"\n",
"\n",
"I have divided this project into various sections which are listed below:-\n",
"\n",
"\n",
"\n",
"\n",
"1.\tIntroduction to imbalanced classes problem\n",
"\n",
"2.\tProblems with imbalanced learning\n",
"\n",
"3.\tExample of imbalanced classes\n",
"\n",
"4.\tApproaches to handle imbalanced classes\n",
"\n",
"5.\tPython implementation to illustrate class imbalance problem\n",
"\n",
"6.\tPrecision - Recall Curve\n",
"\n",
"7. Random over-sampling the minority class\n",
"\n",
"8.\tRandom under-sampling the majority class\n",
"\n",
"9.\tApply tree-based algorithms\n",
"\n",
"10.\tRandom under-sampling and over-sampling with imbalanced-learn\n",
"\n",
"11.\tUnder-sampling : Tomek links\n",
"\n",
"12.\tUnder-sampling : Cluster Centroids\n",
"\n",
"13.\tOver-sampling : SMOTE\n",
"\n",
"14.\tConclusion \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Introduction to imbalanced classes problem\n",
"\n",
"\n",
"Any real world dataset may come along with several problems. The problem of **imbalanced class** is one of them. The problem of imbalanced classes arises when one set of classes dominate over another set of classes. The former is called majority class while the latter is called minority class. It causes the machine learning model to be more biased towards majority class. It causes poor classification of minority classes. Hence, this problem throw the question of “accuracy” out of question. This is a very common problem in machine learning where we have datasets with a disproportionate ratio of observations in each class.\n",
"\n",
"\n",
"**Imbalanced classes problem** is one of the major problems in the field of data science and machine learning. It is very important that we should properly deal with this problem and develop our machine learning model accordingly. If this not done, then we may end up with higher accuracy. But this higher accuracy is meaningless because it comes from a meaningless metric which is not suitable for the dataset in question. Hence, this higher accuracy no longer reliably measures model performance. \n"
]
},
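To make the "disproportionate ratio of observations" concrete, here is a minimal NumPy sketch that reports each class's share of the labels and the majority-to-minority imbalance ratio. The helper name `class_balance` and the toy labels are hypothetical, chosen only for illustration.

```python
import numpy as np

def class_balance(y):
    """Return each class's share of the labels and the majority/minority ratio."""
    labels, counts = np.unique(np.asarray(y), return_counts=True)
    proportions = dict(zip(labels.tolist(), (counts / counts.sum()).tolist()))
    imbalance_ratio = counts.max() / counts.min()
    return proportions, float(imbalance_ratio)

# Hypothetical toy labels: 95 negatives (class 0) and 5 positives (class 1)
y_toy = np.array([0] * 95 + [1] * 5)
proportions, ratio = class_balance(y_toy)
print(proportions)  # {0: 0.95, 1: 0.05}
print(ratio)        # 19.0
```

Applied to the `Class` counts reported later in this notebook (284,315 genuine vs 492 fraudulent transactions), the same ratio is roughly 578 to 1.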
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Problems with imbalanced learning\n",
"\n",
"\n",
"The problem of imbalanced classes is very common and it is bound to happen. For example, in the above example the number of patients who do not have the rare disease is much larger than the number of patients who have the rare disease. So, the model does not correctly classify the patients who have the rare disease. This is where the problem arises.\n",
"\n",
"\n",
"The problem of learning from imbalanced data have new and modern approaches. This learning from imbalanced data is referred to as **imbalanced learning**. \n",
"\n",
"\n",
"Significant problems may arise with imbalanced learning. These are as follows:-\n",
"\n",
"\n",
"1.\tThe class distribution is skewed when the dataset has underrepresented data.\n",
"\n",
"2.\tThe high level of accuracy is simply misleading. In the previous example, it is high because most patients do not \n",
" have the disease not because of the good model. \n",
" \n",
"3.\tThere may be inherent complex characteristics in the dataset. Imbalanced learning from such dataset requires new \n",
" approaches, principles, tools and techniques. But, it cannot guarantee an efficient solution to the business problem.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Example of imbalanced classes\n",
"\n",
"\n",
"The problem of imbalanced classes may appear in many areas including the following:-\n",
"\n",
"\n",
"1.\tDisease detection\n",
"\n",
"2.\tFraud detection\n",
"\n",
"3.\tSpam filtering\n",
"\n",
"4.\tEarthquake prediction\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Approaches to handle imbalanced classes\n",
"\n",
"\n",
"In this section, I will list various approaches to deal with the imbalanced class problem. These approaches may fall under two categories – dataset level approach and algorithmic ensemble techniques approach. The various methods to deal with imbalanced class problem are listed below. I will describe these techniques in more detail in the following sections.\n",
"\n",
"\n",
"1.\tRandom Undersampling methods\n",
"\n",
"2.\tRandom Oversampling methods\n",
"\n",
"3. Tree-based algorithms\n",
"\n",
"4. Resampling with imbalanced-learn\n",
"\n",
"5. Under-sampling : Tomek links\n",
"\n",
"6. Under-sampling : Cluster Centroids\n",
"\n",
"7. Over-sampling : SMOTE\n",
"\n",
"\n",
"I have discussed these methods in detail in the readme document."
]
},
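Random over-sampling and random under-sampling are the simplest dataset-level approaches in the list above. The sketch below is a minimal NumPy version for a binary target, assuming the minority class is the smaller one; the function names `random_oversample` and `random_undersample` are hypothetical. The imbalanced-learn library, which this project turns to later, provides ready-made samplers for the same idea.

```python
import numpy as np

def random_oversample(X, y, minority_label, seed=0):
    """Duplicate random minority rows (with replacement) until both classes have equal size."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X), np.asarray(y)
    minority_idx = np.flatnonzero(y == minority_label)
    majority_idx = np.flatnonzero(y != minority_label)
    extra = rng.choice(minority_idx, size=len(majority_idx) - len(minority_idx), replace=True)
    keep = np.concatenate([majority_idx, minority_idx, extra])
    return X[keep], y[keep]

def random_undersample(X, y, minority_label, seed=0):
    """Keep every minority row plus an equally sized random subset of majority rows."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X), np.asarray(y)
    minority_idx = np.flatnonzero(y == minority_label)
    majority_idx = np.flatnonzero(y != minority_label)
    sampled = rng.choice(majority_idx, size=len(minority_idx), replace=False)
    keep = np.concatenate([sampled, minority_idx])
    return X[keep], y[keep]

# Hypothetical toy data: 8 majority rows (class 0) and 2 minority rows (class 1)
X_toy = np.arange(20).reshape(10, 2)
y_toy = np.array([0] * 8 + [1] * 2)

X_over, y_over = random_oversample(X_toy, y_toy, minority_label=1)
X_under, y_under = random_undersample(X_toy, y_toy, minority_label=1)
print(np.bincount(y_over))   # [8 8]
print(np.bincount(y_under))  # [2 2]
```

Over-sampling keeps all the original information but repeats minority rows, which can encourage overfitting, while under-sampling discards most of the majority rows. Either way, resampling is normally applied to the training data only, after the train/test split.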
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Python implementation to illustrate class imbalance problem\n",
"\n",
"\n",
"\n",
"Now, I will perform Python implementation to illustrate class imbalance problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import Python libraries\n",
"\n",
"I will start off by importing the required Python libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# import Python libraries\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import dataset\n",
"\n",
"\n",
"Now, I will import the dataset with the usual Python `read_csv()` function."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"data = 'C:/datasets/creditcard.csv'\n",
"\n",
"df = pd.read_csv(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset description\n",
"\n",
"\n",
"I have used the **Credit Card Fraud Detecttion** dataset for this project. I have downloaded this project from the Kaggle website. This dataset can be found at the following url-\n",
"\n",
"\n",
"https://www.kaggle.com/mlg-ulb/creditcardfraud\n",
"\n",
"\n",
"This dataset contains transactions made by european credit card holders in September 2013. It represents transactions that occurred in two days. We have 492 fraudulent transactions out of total 284,807 transactions. This dataset is highly unbalanced, the positive class (frauds) account for only 0.172% of all transactions.\n",
"\n",
"\n",
"Feature 'Time' contains the seconds elapsed between each transaction and the first transaction in the dataset. The feature 'Amount' is the transaction Amount, this feature can be used for example-dependant cost-senstive learning. Feature 'Class' is the response variable and it takes value 1 in case of fraud and 0 otherwise. So, our target variable is `Class` variable.\n",
"\n",
"\n",
"\n",
"Given the class imbalance ratio, it is recommended to measure the accuracy using the `Area Under the Precision-Recall Curve (AUPRC)`. Confusion matrix accuracy is not meaningful for unbalanced classification.\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exploratory data analysis\n",
"\n",
"\n",
"Now, I will conduct exploratory data analysis to gain an insight into the dataset."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(284807, 31)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# check shape of dataset\n",
"\n",
"df.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that there are 284,807 instances and 31 columns in the dataset."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Time</th>\n",
" <th>V1</th>\n",
" <th>V2</th>\n",
" <th>V3</th>\n",
" <th>V4</th>\n",
" <th>V5</th>\n",
" <th>V6</th>\n",
" <th>V7</th>\n",
" <th>V8</th>\n",
" <th>V9</th>\n",
" <th>...</th>\n",
" <th>V21</th>\n",
" <th>V22</th>\n",
" <th>V23</th>\n",
" <th>V24</th>\n",
" <th>V25</th>\n",
" <th>V26</th>\n",
" <th>V27</th>\n",
" <th>V28</th>\n",
" <th>Amount</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>-1.359807</td>\n",
" <td>-0.072781</td>\n",
" <td>2.536347</td>\n",
" <td>1.378155</td>\n",
" <td>-0.338321</td>\n",
" <td>0.462388</td>\n",
" <td>0.239599</td>\n",
" <td>0.098698</td>\n",
" <td>0.363787</td>\n",
" <td>...</td>\n",
" <td>-0.018307</td>\n",
" <td>0.277838</td>\n",
" <td>-0.110474</td>\n",
" <td>0.066928</td>\n",
" <td>0.128539</td>\n",
" <td>-0.189115</td>\n",
" <td>0.133558</td>\n",
" <td>-0.021053</td>\n",
" <td>149.62</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.0</td>\n",
" <td>1.191857</td>\n",
" <td>0.266151</td>\n",
" <td>0.166480</td>\n",
" <td>0.448154</td>\n",
" <td>0.060018</td>\n",
" <td>-0.082361</td>\n",
" <td>-0.078803</td>\n",
" <td>0.085102</td>\n",
" <td>-0.255425</td>\n",
" <td>...</td>\n",
" <td>-0.225775</td>\n",
" <td>-0.638672</td>\n",
" <td>0.101288</td>\n",
" <td>-0.339846</td>\n",
" <td>0.167170</td>\n",
" <td>0.125895</td>\n",
" <td>-0.008983</td>\n",
" <td>0.014724</td>\n",
" <td>2.69</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>-1.358354</td>\n",
" <td>-1.340163</td>\n",
" <td>1.773209</td>\n",
" <td>0.379780</td>\n",
" <td>-0.503198</td>\n",
" <td>1.800499</td>\n",
" <td>0.791461</td>\n",
" <td>0.247676</td>\n",
" <td>-1.514654</td>\n",
" <td>...</td>\n",
" <td>0.247998</td>\n",
" <td>0.771679</td>\n",
" <td>0.909412</td>\n",
" <td>-0.689281</td>\n",
" <td>-0.327642</td>\n",
" <td>-0.139097</td>\n",
" <td>-0.055353</td>\n",
" <td>-0.059752</td>\n",
" <td>378.66</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>-0.966272</td>\n",
" <td>-0.185226</td>\n",
" <td>1.792993</td>\n",
" <td>-0.863291</td>\n",
" <td>-0.010309</td>\n",
" <td>1.247203</td>\n",
" <td>0.237609</td>\n",
" <td>0.377436</td>\n",
" <td>-1.387024</td>\n",
" <td>...</td>\n",
" <td>-0.108300</td>\n",
" <td>0.005274</td>\n",
" <td>-0.190321</td>\n",
" <td>-1.175575</td>\n",
" <td>0.647376</td>\n",
" <td>-0.221929</td>\n",
" <td>0.062723</td>\n",
" <td>0.061458</td>\n",
" <td>123.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2.0</td>\n",
" <td>-1.158233</td>\n",
" <td>0.877737</td>\n",
" <td>1.548718</td>\n",
" <td>0.403034</td>\n",
" <td>-0.407193</td>\n",
" <td>0.095921</td>\n",
" <td>0.592941</td>\n",
" <td>-0.270533</td>\n",
" <td>0.817739</td>\n",
" <td>...</td>\n",
" <td>-0.009431</td>\n",
" <td>0.798278</td>\n",
" <td>-0.137458</td>\n",
" <td>0.141267</td>\n",
" <td>-0.206010</td>\n",
" <td>0.502292</td>\n",
" <td>0.219422</td>\n",
" <td>0.215153</td>\n",
" <td>69.99</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" Time V1 V2 V3 V4 V5 V6 V7 \\\n",
"0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n",
"1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n",
"2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n",
"3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n",
"4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n",
"\n",
" V8 V9 ... V21 V22 V23 V24 \\\n",
"0 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 \n",
"1 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 \n",
"2 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 \n",
"3 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 \n",
"4 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 \n",
"\n",
" V25 V26 V27 V28 Amount Class \n",
"0 0.128539 -0.189115 0.133558 -0.021053 149.62 0 \n",
"1 0.167170 0.125895 -0.008983 0.014724 2.69 0 \n",
"2 -0.327642 -0.139097 -0.055353 -0.059752 378.66 0 \n",
"3 0.647376 -0.221929 0.062723 0.061458 123.50 0 \n",
"4 -0.206010 0.502292 0.219422 0.215153 69.99 0 \n",
"\n",
"[5 rows x 31 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# preview of the dataset\n",
"\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `df.head()` function gives the preview of the dataset. We can see that there is a `Class` column in the dataset which is our target variable.\n",
"\n",
"\n",
"I will check the distribution of the `Class` column with the `value_counts()` method as follows:-"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 284315\n",
"1 492\n",
"Name: Class, dtype: int64"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# check the distribution of Class column\n",
"\n",
"df['Class'].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we have 492 fraudulent transactions out of total 284,807 transactions in the dataset. The `Class` column takes value `1 for \n",
"fraudulent transactions` and `0 for non-fraudulent transactions`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, I will find the percentage of labels 0 and 1 within the `Class` column."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0.998273\n",
"1 0.001727\n",
"Name: Class, dtype: float64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# percentage of labels within the Class column\n",
"\n",
"df['Class'].value_counts()/np.float(len(df))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the `Class` column is highly imbalanced. It contains 99.82% labels as `0` and 0.17% labels as `1`. \n",
"\n",
"Now, I will plot the bar plot to confirm this."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0xb5820294a8>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAC15JREFUeJzt3V+Infldx/H3ZxOjF60tmFFq/jSBpmgUYWWIhV640orJCslNkQSKWpbmxijSIkaUVeONthcFIf4JWqsFN8Ze6FAjuahbBDU1s7QuJiE6xGqGiDttlwUpmsZ+vZixHk5Ocp5JTnKSb94vGDjP8/xyzpcwefPMc84zSVUhSerlmXkPIEmaPeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJamhrfN64e3bt9eePXvm9fKS9ER65ZVXvlRVC9PWzS3ue/bsYXl5eV4vL0lPpCT/OmSdl2UkqSHjLkkNGXdJasi4S1JDxl2SGpoa9yQfT/Jakn+8y/Ek+c0kK0leTfL9sx9TkrQZQ87cPwEcvMfxQ8C+ja/jwG8/+FiSpAcxNe5V9dfAV+6x5AjwR7XuIvDWJG+b1YCSpM2bxU1MO4AbI9urG/v+fXxhkuOsn92ze/fuGbz0w7fn5F/Me4RWvvjrPzrvEaSnwizeUM2EfRP/1+2qOlNVi1W1uLAw9e5ZSdJ9mkXcV4FdI9s7gZszeF5J0n2aRdyXgB/f+NTMu4A3quqOSzKSpEdn6jX3JC8BzwHbk6wCvwx8E0BV/Q5wHngeWAG+CnzgYQ0rSRpmatyr6tiU4wX81MwmkiQ9MO9QlaSGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLU0KC4JzmY5FqSlSQnJxzfneTlJJ9P8mqS52c/qiRpqKlxT7IFOA0cAvYDx5LsH1v2S8C5qnoWOAr81qwHlSQNN+TM/QCwUlXXq+oWcBY4MramgG/dePwW4ObsRpQkbdaQuO8Aboxsr27sG/UrwPuTrALngZ+e9ERJjidZTrK8trZ2H+NKkoYYEvdM2Fdj28eAT1TVTuB54JNJ7njuqjpTVYtVtbiwsLD5aSVJgwyJ+yqwa2R7J3dednkBOAdQVX8HfAuwfRYDSpI2b0jcLwH7kuxNso31N0yXxtb8G/AegCTfzXrcve4iSXMyNe5VdRs4AVwArrL+qZjLSU4lObyx7MPAB5P8A/AS8JNVNX7pRpL0iGwdsqiqzrP+RunovhdHHl8B3j3b0SRJ98s7VCWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJamhQXFPcjDJtSQrSU7eZc2PJbmS5HKSP57tmJKkzdg6bUGSLcBp4IeBVeBSkqWqujKyZh/wC8C7q+r1JN/+sAaWJE035Mz9ALBSVder6hZwFjgytuaDwOmqeh2gql6b7ZiSpM0YEvcdwI2R7dWNfaPeCbwzyd8kuZjk4KQnSnI8yXKS5bW1tfubWJI01ZC4Z8K+GtveCuwDngOOAb+X5K13/KGqM1W1WFWLCwsLm51VkjTQkLivArtGtncCNyes+fOq+lpV/QtwjfXYS5LmYEjcLwH7kuxNsg04CiyNrfkz4IcAkmxn/TLN9VkOKkkabmrcq+o2cAK4AFwFzlXV5SSnkhzeWHYB+HKSK8DLwM9V1Zcf1t
CSpHub+lFIgKo6D5wf2/fiyOMCPrTxJUmaM+9QlaSGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLU0KC4JzmY5FqSlSQn77HufUkqyeLsRpQkbdbUuCfZApwGDgH7gWNJ9k9Y92bgZ4DPzXpISdLmDDlzPwCsVNX1qroFnAWOTFj3a8BHgP+a4XySpPswJO47gBsj26sb+74hybPArqr69L2eKMnxJMtJltfW1jY9rCRpmCFxz4R99Y2DyTPAx4APT3uiqjpTVYtVtbiwsDB8SknSpgyJ+yqwa2R7J3BzZPvNwPcCn03yReBdwJJvqkrS/AyJ+yVgX5K9SbYBR4Gl/ztYVW9U1faq2lNVe4CLwOGqWn4oE0uSppoa96q6DZwALgBXgXNVdTnJqSSHH/aAkqTN2zpkUVWdB86P7XvxLmufe/CxJEkPwjtUJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqaFBcU9yMMm1JCtJTk44/qEkV5K8muQzSd4++1ElSUNNjXuSLcBp4BCwHziWZP/Yss8Di1X1fcCngI/MelBJ0nBDztwPACtVdb2qbgFngSOjC6rq5ar66sbmRWDnbMeUJG3GkLjvAG6MbK9u7LubF4C/nHQgyfEky0mW19bWhk8pSdqUIXHPhH01cWHyfmAR+Oik41V1pqoWq2pxYWFh+JSSpE3ZOmDNKrBrZHsncHN8UZL3Ar8I/GBV/fdsxpMk3Y8hZ+6XgH1J9ibZBhwFlkYXJHkW+F3gcFW9NvsxJUmbMTXuVXUbOAFcAK4C56rqcpJTSQ5vLPso8CbgT5N8IcnSXZ5OkvQIDLksQ1WdB86P7Xtx5PF7ZzyXJOkBeIeqJDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGBsU9ycEk15KsJDk54fg3J/mTjeOfS7Jn1oNKkoabGvckW4DTwCFgP3Asyf6xZS8Ar1fVO4CPAb8x60ElScMNOXM/AKxU1fWqugWcBY6MrTkC/OHG408B70mS2Y0pSdqMrQPW7ABujGyvAj9wtzVVdTvJG8C3AV8aXZTkOHB8Y/M/k1y7n6E10XbG/r4fR/FnuqfRE/G9+QR5+5BFQ+I+6Qy87mMNVXUGODPgNbVJSZaranHec0jj/N6cjyGXZVaBXSPbO4Gbd1uTZCvwFuArsxhQkrR5Q+J+CdiXZG+SbcBRYGlszRLwExuP3wf8VVXdceYuSXo0pl6W2biGfgK4AGwBPl5Vl5OcAparagn4feCTSVZYP2M/+jCH1kRe7tLjyu/NOYgn2JLUj3eoSlJDxl2SGjLuktTQkM+56zGT5LtYvyt4B+v3E9wElqrq6lwHk/TY8Mz9CZPk51n/FRAB/p71j6oGeGnSL3WT9HTy0zJPmCT/BHxPVX1tbP824HJV7ZvPZNK9JflAVf3BvOd4Wnjm/uT5OvCdE/a/beOY9Lj61XkP8DTxmvuT52eBzyT5Z/7/F7rtBt4BnJjbVBKQ5NW7HQK+41HO8rTzsswTKMkzrP8q5h2s/6NZBS5V1f/MdTA99ZL8B/AjwOvjh4
C/rapJP3XqIfDM/QlUVV8HLs57DmmCTwNvqqovjB9I8tlHP87TyzN3SWrIN1QlqSHjLkkNGXdJasi4S1JD/wvQ9V7if0uClQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# view the distribution of percentages within the Class column\n",
"\n",
"\n",
"(df['Class'].value_counts()/np.float(len(df))).plot.bar()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above bar plot confirms our finding that the `Class` variable is highly imbalanced. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Misleading accuracy for imbalanced classes\n",
"\n",
"\n",
"Now, I will demonstrate that accuracy is misleading for imbalanced classes. Most of the machine learning algorithms are designed to maximize the overall accuracy by default. But this maximum accuracy is misleading. We can confirm this with the following analysis.\n",
"\n",
"\n",
"I will fit a very simple Logistic Regression model using the default settings. I will train the classifier on the imbalanced dataset."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# declare feature vector and target variable\n",
"\n",
"X = df.drop(['Class'], axis=1)\n",
"y = df['Class']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# import Logistic Regression classifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"\n",
"# instantiate the Logistic Regression classifier\n",
"logreg = LogisticRegression()\n",
"\n",
"\n",
"# fit the classifier to the imbalanced data\n",
"clf = logreg.fit(X, y)\n",
"\n",
"\n",
"# predict on the training data\n",
"y_pred = clf.predict(X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, I have trained the model. I will check its accuracy."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy : 99.90%\n"
]
}
],
"source": [
"# import the accuracy metric\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"\n",
"# print the accuracy\n",
"accuracy = accuracy_score(y_pred, y)\n",
"\n",
"print(\"Accuracy : %.2f%%\" % (accuracy * 100.0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accuracy paradox\n",
"\n",
"\n",
"Thus, our Logistic Regression model for credit card fraud detection has an accuracy of 99.90%. It means that for each 100 transactions it classified, 99.90% were classified as genuine.\n",
"\n",
"\n",
"It does not mean that our model performance is excellent. I have previously shown that our dataset have 99.90% genuine transactions and 0.1% fraudulent transactions. Our Logistic Regression classifier predicted all transactions as genuine. \n",
"Then we have a accuracy of 99.90% because it correctly classified 99.90% transactions as genuine.\n",
"\n",
"\n",
"Thus, this algorithm is 99.90% accurate. But it was horrible at classifying fraudulent transactions. So, we should have other ways to measure the model performance. One such measure is confusion matrix described below."
]
},
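The comparison with a trivial baseline can be checked with a few lines of arithmetic. This sketch uses only the class counts printed by `value_counts()` above and the 99.90% accuracy reported for the model; the variable names are illustrative.

```python
# Class counts from df['Class'].value_counts() above
n_genuine, n_fraud = 284315, 492
n_total = n_genuine + n_fraud

# A trivial rule that labels every transaction as genuine is correct on every
# genuine row and catches none of the frauds
baseline_accuracy = n_genuine / n_total
reported_model_accuracy = 0.9990  # rounded Logistic Regression accuracy printed above

print(f"Always-genuine baseline accuracy: {baseline_accuracy:.2%}")  # 99.83%
print(f"Model gain over baseline: {(reported_model_accuracy - baseline_accuracy) * 100:.2f} percentage points")
```

The gain is well under a tenth of a percentage point, which is why accuracy alone says very little about how well the frauds are detected.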
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix\n",
"\n",
"\n",
"A confusion matrix is a tool for summarizing the performance of a classification algorithm. A confusion matrix will give us a clear picture of classification model performance and the types of errors produced by the model. It gives us a summary of correct and incorrect predictions broken down by each category. The summary is represented in a tabular form.\n",
"\n",
"\n",
"Four types of outcomes are possible while evaluating a classification model performance. These four outcomes are described below:-\n",
"\n",
"\n",
"**True Positives (TP)** – True Positives occur when we predict an observation belongs to a certain class and the observation actually belongs to that class.\n",
"\n",
"\n",
"**True Negatives (TN)** – True Negatives occur when we predict an observation does not belong to a certain class and the observation actually does not belong to that class.\n",
"\n",
"\n",
"**False Positives (FP)** – False Positives occur when we predict an observation belongs to a certain class but the observation actually does not belong to that class. This type of error is called **Type I error.**\n",
"\n",
"\n",
"\n",
"**False Negatives (FN)** – False Negatives occur when we predict an observation does not belong to a certain class but the observation actually belongs to that class. This is a very serious error and it is called **Type II error.**\n",
"\n",
"\n",
"\n",
"These four outcomes are summarized in a confusion matrix given below.\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Confusion matrix:\n",
" [[284240 75]\n",
" [ 203 289]]\n"
]
}
],
"source": [
"# import the metric\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"\n",
"# print the confusion matrix\n",
"cnf_matrix = confusion_matrix(y, y_pred)\n",
"\n",
"\n",
"print('Confusion matrix:\\n', cnf_matrix)\n",
" "
]
},
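In scikit-learn, `confusion_matrix(y_true, y_pred)` puts the actual classes in rows and the predicted classes in columns, so for labels `0` and `1` the layout is `[[TN, FP], [FN, TP]]`, with class `1` (fraud) as the positive class. The NumPy sketch below rebuilds that layout by hand on hypothetical toy labels to make the ordering explicit; `binary_confusion_matrix` is an illustrative helper, not a library function.

```python
import numpy as np

def binary_confusion_matrix(y_true, y_pred):
    """2x2 matrix laid out like scikit-learn's: rows = actual class, columns = predicted class."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tp = np.sum((y_true == 1) & (y_pred == 1))
    return np.array([[tn, fp], [fn, tp]])

# Hypothetical toy labels: 3 actual negatives, 2 actual positives
y_true = [0, 0, 0, 1, 1]
y_pred = [0, 1, 0, 1, 0]
cm = binary_confusion_matrix(y_true, y_pred)
print(cm)
# [[2 1]
#  [1 1]]

# Flattening row by row gives the four outcomes in a fixed order
tn, fp, fn, tp = cm.ravel()
```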
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interpretation of confusion matrix\n",
"\n",
"\n",
"Now, I will interpret the confusion matrix.\n",
"\n",
"\n",
"- Out of the total 284315 transactions which were predicted genuine, the classifier predicted correctly 284240 of them. It means that the classifer predicted 284240 transactions as genuine and they were actually genuine. Also, it predicted 75 transactions as genuine but it were fraudulent. So, we have `284240 True Positives(TP)` and `75 False Positives(FP)`.\n",
"\n",
"\n",
"- Out of the total 492 transactions which were not predicted as genuine, the classifier predicted correctly 289 of them. It means that the classifer did not predict 289 transactions as genuine and they were actually not genuine. SO, they were fraudulent. Also, it did not predict 203 transactions as genuine but they were genuine. So, we have `289 True Negatives(TN)` and `203 False Negatives(FN)`.\n",
"\n",
"\n",
"\n",
"- So, out of all the 284807 transactions, the classifier correctly predicted 284529 of them. Thus, we will get the accuracy of\n",
"`(284240+289)/(284240+289+75+203) = 99.90%.`\n",
"\n",
"\n",
"\n",
"- But this is not the true picture. The confusion matrix allows us to obtain a true picture of the performance of the algorithm. The algorithm tries to predict the fraudulent transactions out of the total transactions. It correctly predicted 289 transactions as fraudulent out of all the 284807 transactions. In this case the accuracy becomes `(289/284807)=0.10%.`\n",
"\n",
"\n",
"\n",
"- Moreover, we have `203+289=492` transactions as fraudulent. The algorithm is correctly classifying 289 of them as fraudulent while it fails to predict 203 transactions which were fraudulent. In this case the accuracy becomes `(289/492)=58.74%.`\n",
"\n",
"\n",
"So, we can conclude that the accuracy of 99.90% is misleading because we have imbalanced classes. We need more subtle way to evaluate the performance of the model.\n",
"\n",
"\n",
"There is another metric called `Classification Report` which helps to evaluate model performance."
]
},
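The figures in the bullets above follow directly from the printed matrix. This sketch copies the four counts from the output and recomputes them, assuming (as in scikit-learn) that rows are actual classes and columns are predicted classes.

```python
import numpy as np

# Confusion matrix printed above: rows = actual class (0 genuine, 1 fraud), columns = predicted class
cnf_matrix = np.array([[284240, 75],
                       [203, 289]])
tn, fp, fn, tp = cnf_matrix.ravel()

accuracy = (tp + tn) / cnf_matrix.sum()
fraud_share_of_all = tp / cnf_matrix.sum()   # detected frauds as a share of all transactions
fraud_detection_rate = tp / (tp + fn)        # share of actual frauds that were caught

print(f"Accuracy: {accuracy:.2%}")                                       # 99.90%
print(f"Detected frauds / all transactions: {fraud_share_of_all:.2%}")   # 0.10%
print(f"Detected frauds / actual frauds: {fraud_detection_rate:.2%}")    # 58.74%
```

The last ratio is the recall of the fraud class, which is why the classification report below shows `0.59` in the recall column for class `1`.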
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Classification report\n",
"\n",
"\n",
"\n",
"**Classification report** is another way to evaluate the classification model performance. It displays the **precision**, **recall**, **f1** and **support** scores for the model. I have described these terms in later sections.\n",
"\n",
"\n",
"\n",
"We can plot a classification report as follows:-"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification Report:\n",
"\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 284315\n",
" 1 0.79 0.59 0.68 492\n",
"\n",
" micro avg 1.00 1.00 1.00 284807\n",
" macro avg 0.90 0.79 0.84 284807\n",
"weighted avg 1.00 1.00 1.00 284807\n",
"\n"
]
}
],
"source": [
"# import the metric\n",
"from sklearn.metrics import classification_report\n",
"\n",
"\n",
"# print classification report\n",
"print(\"Classification Report:\\n\\n\", classification_report(y, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision\n",
"\n",
"\n",
"Precision can be defined as the percentage of correctly predicted positive outcomes out of all the predicted positive outcomes.\n",
"It can be given as the ratio of true positives (TP) to the sum of true and false positives (TP + FP). \n",
"\n",
"\n",
"Mathematically, **precision** can be defined as the ratio of `TP to (TP + FP).`\n",
"\n",
"\n",
"So, precision is more concerned with the positive class than the negative class.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Recall\n",
"\n",
"\n",
"Recall can be defined as the percentage of correctly predicted positive outcomes out of all the actual positive outcomes.\n",
"It can be given as the ratio of true positives (TP) to the sum of true positives and false negatives (TP + FN). **Recall** is also called **Sensitivity**.\n",
"\n",
"\n",
"Mathematically, **recall** can be given as the ratio of `TP to (TP + FN).`\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### f1-score\n",
"\n",
"\n",
"**f1-score** is the weighted harmonic mean of precision and recall. The best possible **f1-score** would be 1.0 and the worst \n",
"would be 0.0. **f1-score** is the harmonic mean of precision and recall. So, **f1-score** is always lower than accuracy measures as they embed precision and recall into their computation. The weighted average of `f1-score` should be used to \n",
"compare classifier models, not global accuracy.\n",
"\n",
"\n"
]
},
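The precision, recall and f1-score formulas can be checked against the classification report above. This minimal sketch (the helper name `precision_recall_f1` is hypothetical) plugs in the fraud-class counts from the confusion matrix, TP = 289, FP = 75 and FN = 203, and then shows how the harmonic mean punishes a lopsided hypothetical case.

```python
def precision_recall_f1(tp, fp, fn):
    """Precision, recall and F1 for the positive class from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of precision and recall
    return precision, recall, f1

# Fraud class (Class = 1) counts from the confusion matrix above
p, r, f1 = precision_recall_f1(tp=289, fp=75, fn=203)
print(f"precision={p:.2f} recall={r:.2f} f1={f1:.2f}")  # precision=0.79 recall=0.59 f1=0.68

# Hypothetical lopsided case: perfect precision but poor recall still gives a low F1
_, _, f1_lopsided = precision_recall_f1(tp=10, fp=0, fn=90)
print(f"{f1_lopsided:.2f}")  # 0.18
```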
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Support\n",
"\n",
"\n",
"**Support** is the actual number of occurrences of the class in our dataset. It classifies `284315 transactions as genuine` and `492 transactions as fraudulent`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ROC Curve\n",
"\n",
"\n",
"Another tool to measure the classification model performance visually is **ROC Curve**. ROC Curve stands for **Receiver Operating Characteristic Curve**. \n",
"\n",
"\n",
"The **ROC Curve** plots the **True Positive Rate (TPR)** against the **False Positive Rate (FPR)** at various threshold levels.\n",
"\n",
"\n",
"**True Positive Rate (TPR)** is also called **Recall**. It is defined as the ratio of `TP to (TP + FN).`\n",
"\n",
"\n",
"\n",
"\n",
"**False Positive Rate (FPR)** is defined as the ratio of `FP to (FP + TN).`\n",
"\n",
"\n",
"\n",
"\n",
"The **Receiver Operating Characteristic Area Under Curve (ROC AUC)** is the area under the ROC curve. The higher it is, the better the model is. \n",
"\n",
"\n",
"In the ROC Curve, we will focus on the TPR (True Positive Rate) and FPR (False Positive Rate) of a single point. This will give us the general performance of the ROC curve which consists of the TPR and FPR at various probability thresholds."
]
},
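The threshold sweep described above can be sketched directly in NumPy. This is a minimal illustration on hypothetical toy scores, and `roc_points` and `trapezoid_auc` are hypothetical helper names; in practice scikit-learn's `roc_curve` and `roc_auc_score` compute the same quantities.

```python
import numpy as np

def roc_points(y_true, scores):
    """FPR and TPR at every distinct score threshold, from strictest to loosest."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    n_pos = np.sum(y_true == 1)
    n_neg = np.sum(y_true == 0)
    fpr, tpr = [0.0], [0.0]  # threshold above every score: nothing is predicted positive
    for t in np.unique(scores)[::-1]:
        predicted_pos = scores >= t
        tpr.append(np.sum(predicted_pos & (y_true == 1)) / n_pos)
        fpr.append(np.sum(predicted_pos & (y_true == 0)) / n_neg)
    return np.array(fpr), np.array(tpr)

def trapezoid_auc(x, y):
    """Area under a piecewise-linear curve using the trapezoidal rule."""
    return float(np.sum((x[1:] - x[:-1]) * (y[1:] + y[:-1]) / 2))

# Hypothetical toy labels and classifier scores
y_toy = [0, 0, 1, 1]
scores_toy = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_points(y_toy, scores_toy)
print(trapezoid_auc(fpr, tpr))  # 0.75
```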
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision - Recall Curve\n",
"\n",
"\n",
"\n",
"Another tool to measure the classification model performance is **Precision-Recall Curve**. It is a useful metric which is used to evaluate a classifier model performance when classes are very imbalanced such as in this case. This **Precision-Recall Curve** shows the trade off between precision and recall.\n",
"\n",
"\n",
"\n",
"In a **Precision-Recall Curve**, we plot **Precision** against **Recall**.\n",
"\n",
"\n",
"**Precision** is defined as the ratio of `TP to (TP + FP).`\n",
"\n",
"\n",
"\n",
"\n",
"**Recall** is defined as the ratio of `TP to (TP + FN).`\n",
"\n",
"\n",
"\n",
"\n",
"The **Precision Recall Area Under Curve (PR AUC)** is the area under the PR curve. The higher it is, the better the model is."
]
},
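scikit-learn's `average_precision_score` summarizes the precision-recall curve as a step-wise sum, AP = sum over n of (R_n - R_{n-1}) * P_n, where P_n and R_n are the precision and recall at the n-th threshold. A minimal NumPy sketch of that formula follows; the helper name `average_precision` and the toy scores are hypothetical, and it assumes at least one positive label.

```python
import numpy as np

def average_precision(y_true, scores):
    """Step-wise average precision: sum over thresholds of (R_n - R_{n-1}) * P_n."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind="mergesort")   # highest score first
    y_sorted, s_sorted = y_true[order], scores[order]
    tp = np.cumsum(y_sorted == 1)
    fp = np.cumsum(y_sorted == 0)
    # one operating point per distinct score: keep the last row of each tied group
    last_of_group = np.append(np.diff(s_sorted) != 0, True)
    tp, fp = tp[last_of_group], fp[last_of_group]
    precision = tp / (tp + fp)
    recall = tp / tp[-1]                             # tp[-1] is the total number of positives
    recall_prev = np.append(0.0, recall[:-1])
    return float(np.sum((recall - recall_prev) * precision))

# Hypothetical toy labels and classifier scores
print(round(average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]), 4))  # 0.8333
```

Unlike a trapezoidal area, this step-wise sum does not interpolate between points, which avoids overestimating the area under a jagged precision-recall curve.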
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Difference between ROC AUC and PR AUC\n",
"\n",
"\n",
"- Precision-Recall does not account for True Negatives (TN) unlike ROC AUC (TN is not a component of either Precision or Recall). \n",
"\n",
"\n",
"- In the cases of class imbalance problem, we have many more negatives than positives. The Precision-Recall curve much better illustrates the difference between algorithms in the class imbalance problem cases where there are lot more negative examples than the positive examples. In these cases of class imbalances, we should use Precision-Recall Curve (PR AUC), otherwise we should use ROC AUC.\n",
"\n",
"\n",
"So, we can conclude that we should use PR AUC for cases where the class imbalance problem occurs. Otherwise, we should use ROC AUC.\n"
]
},
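{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of the true negatives on the two metrics can be illustrated with a scorer that has no skill at all. The sketch below assumes a synthetic label vector with roughly 1% positives and uniformly random scores (both hypothetical). The ROC AUC of such a scorer stays near 0.5 whatever the imbalance, while its average precision (an estimate of PR AUC) falls to roughly the positive rate.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import roc_auc_score, average_precision_score\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n = 100_000\n",
"# hypothetical labels with about 1% positives\n",
"y_true = (rng.random(n) < 0.01).astype(int)\n",
"# a scorer with no skill: uniformly random scores\n",
"random_scores = rng.random(n)\n",
"\n",
"roc_auc = roc_auc_score(y_true, random_scores)\n",
"pr_auc = average_precision_score(y_true, random_scores)\n",
"print('ROC AUC of random scores:', round(roc_auc, 3))\n",
"print('Average precision of random scores:', round(pr_auc, 3))\n",
"print('Positive rate:', round(y_true.mean(), 3))\n",
"```\n",
"\n",
"Because the PR baseline is the positive rate rather than 0.5, a PR AUC should be judged against the prevalence of the minority class."
]
},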
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Precision - Recall Curve \n",
"\n",
"\n",
"In the previous section, we conclude that we should use `Precision-Recall Area Under Curve` for cases where the class imbalance problem exists. Otherwise, we should use `ROC-AUC (Receiver Operating Characteristic Area Under Curve)`.\n",
"\n",
"\n",
"Now, I will compute the `average precision score`. "
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average precision-recall score : 0.47\n"
]
}
],
"source": [
"# compute and print average precision score\n",
"\n",
"from sklearn.metrics import average_precision_score\n",
"\n",
"average_precision = average_precision_score(y_pred, y)\n",
"\n",
"print('Average precision-recall score : {0:0.2f}'.format(average_precision))"
]
},
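{
"cell_type": "markdown",
"metadata": {},
"source": [
"Average precision is computed as `AP = sum_n (R_n - R_{n-1}) * P_n`, where `P_n` and `R_n` are the precision and recall at the n-th threshold. The following is a minimal sketch of that formula (the helper `average_precision_manual` is hypothetical and assumes no tied scores), checked against `average_precision_score` on four toy samples with made-up scores. It also hints at why probabilities work better than hard labels: with 0/1 predictions there is only one non-trivial threshold, so the curve collapses to a single step.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import average_precision_score\n",
"\n",
"def average_precision_manual(y_true, y_score):\n",
"    # sort samples by decreasing score; each prefix is one threshold\n",
"    order = np.argsort(-np.asarray(y_score))\n",
"    y_sorted = np.asarray(y_true)[order]\n",
"    tp = np.cumsum(y_sorted)               # true positives at each cut-off\n",
"    k = np.arange(1, len(y_sorted) + 1)    # samples predicted positive\n",
"    precision = tp / k\n",
"    recall = tp / y_sorted.sum()\n",
"    prev_recall = np.concatenate([[0.0], recall[:-1]])\n",
"    return float(np.sum((recall - prev_recall) * precision))\n",
"\n",
"y_true = np.array([0, 0, 1, 1])\n",
"y_score = np.array([0.1, 0.4, 0.35, 0.8])\n",
"ap_manual = average_precision_manual(y_true, y_score)\n",
"ap_sklearn = average_precision_score(y_true, y_score)\n",
"print(ap_manual, ap_sklearn)\n",
"```\n",
"\n",
"Both values come out as 5/6 (about 0.833) on this toy example."
]
},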
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Precision-Recall Curve` gives us the correct accuracy in this imbalanced dataset case. We can see that we have a very poor accuracy for the model.\n",
"\n",
"\n",
"Now, I will plot the `precision-recall curve`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0xb5847599e8>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import precision_recall_curve \n",
"\n",
"precision, recall, thresholds = precision_recall_curve(y_pred, y)\n",
"\n",
"# create plot\n",
"plt.plot(precision, recall, label='Precision-recall curve')\n",
"plt.xlabel('Precision')\n",
"plt.ylabel('Recall')\n",
"plt.title('Precision-recall curve')\n",
"plt.legend(loc=\"lower left\")"
]
},
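{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows what `precision_recall_curve` returns on a few toy scores (hypothetical values). The true labels come first; `precision` and `recall` have one more element than `thresholds`, because the final point (precision 1, recall 0) has no corresponding threshold. The usual plot puts recall on the x-axis and precision on the y-axis.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import precision_recall_curve\n",
"\n",
"# hypothetical labels and scores\n",
"y_true = np.array([0, 0, 1, 1])\n",
"y_score = np.array([0.1, 0.4, 0.35, 0.8])\n",
"\n",
"# true labels first, scores second\n",
"precision, recall, thresholds = precision_recall_curve(y_true, y_score)\n",
"\n",
"# each threshold gives one (recall, precision) point of the curve\n",
"for p, r, t in zip(precision, recall, thresholds):\n",
"    print(f'threshold >= {t:.2f}: precision={p:.2f}, recall={r:.2f}')\n",
"\n",
"# the extra last point has precision 1 and recall 0\n",
"print('final point:', precision[-1], recall[-1])\n",
"```"
]
},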
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Random over-sampling the minority class\n",
"\n",
"\n",
"\n",
"**Over-sampling** is the process of randomly duplicating observations from the minority class in order to achieve a balanced dataset. So, it replicates the observations from minority class to balance the data. It is also known as **upsampling**. It may result in overfitting due to duplication of data points. \n",
"\n",
"\n",
"The most common way of over-sampling is to resample with replacement. I will proceed as follows:-\n",
"\n",
"\n",
"First, I will import the resampling module from Scikit-Learn."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# import resample module \n",
"\n",
"from sklearn.utils import resample"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, I will create a new dataframe with an oversampled minority class as follows:-\n",
"\n",
"\n",
"1. At first, I will separate observations from Class variable into different DataFrames.\n",
"\n",
"\n",
"2. Now, I will resample the minority class with replacement. I will set the number of samples of minority class to match \n",
" that of the majority class.\n",
"\n",
"\n",
"3. Finally, I will combine the oversampled minority class DataFrame with the original majority class DataFrame."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# separate the minority and majority classes\n",
"df_majority = df[df['Class']==0]\n",
"df_minority = df[df['Class']==1]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# oversample minority class\n",
"\n",
"df_minority_oversampled = resample(df_minority, replace=True, n_samples=284315, random_state=0) "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# combine majority class with oversampled minority class\n",
"\n",
"df_oversampled = pd.concat([df_majority, df_minority_oversampled])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1 284315\n",
"0 284315\n",
"Name: Class, dtype: int64"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# display new class value counts\n",
"\n",
"df_oversampled['Class'].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can see that we have a balanced dataset. The ratio of the two class labels is now 1:1.\n",
"\n",
"Now, I will plot the bar plot of the above two classes."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0xb58476ee48>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAC1BJREFUeJzt3VGInflZx/HvbxPihS29MIPUJNMJNiBRi8Ux61UVXTGhkAiukIDQlcpQaKhaL5qihDbe6Ar2KheNuFKEmq69GutIwGovRLZmVpdKNqQdwmqGgKZ2WRGxadzHi0zbw+lJ5j0zZzKbJ98PBM7/ff+c82wYvrx555yzqSokSb08tdsDSJJmz7hLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWpo72698P79+2thYWG3Xl6SHksvv/zy16tqbrN9uxb3hYUFVldXd+vlJemxlORfh+zztowkNWTcJakh4y5JDRl3SWrIuEtSQ4PinuR4khtJ1pKcm3D+uSR3kryy8efXZz+qJGmoTd8KmWQPcBH4BWAduJpkuapeHdv6uao6uwMzSpKmNOTK/RiwVlU3q+oucBk4tbNjSZK2Y8iHmA4At0bW68DTE/b9cpL3AV8Ffquqbo1vSLIELAHMz89PP+0uWDj3V7s9Qiuv/f77d3uEPj7xjt2eoJdPvLHbE8zUkCv3TDg2/n/V/ktgoareA/wN8JlJT1RVl6pqsaoW5+Y2/fSsJGmLhsR9HTg0sj4I3B7dUFX/WVXf3Fj+MfCTsxlPkrQVQ+J+FTiS5HCSfcBpYHl0Q5J3jixPAtdnN6IkaVqb3nOvqntJzgJXgD3AC1V1LckFYLWqloGPJDkJ3AO+ATy3gzNLkjYx6Fshq2oFWBk7dn7k8ceBj892NEnSVvkJVUlqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDQ2Ke5LjSW4kWUty7iH7nk1SSRZnN6IkaVqbxj3JHuAicAI4CpxJcnTCvrcDHwG+POshJUnTGXLlfgxYq6qbVXUXuAycmrDv94Dngf+d4XySpC0YEvcDwK2R9frGse9I8l7gUFV9YYazSZK2aEjcM+FYfedk8hTwKeC3N32iZCnJapLVO3fuDJ9SkjSVIXFfBw6NrA8Ct0fWbwd+DPhSkteAnwaWJ/1StaouVdViVS3Ozc1tfWpJ0kMNiftV4EiSw0n2AaeB5W+frKo3qmp/VS1U1QLwEnCyqlZ3ZGJJ0qY2jXtV3QPOAleA68CLVXUtyYUkJ3d6QEnS9PYO2VRVK8DK2LHzD9j7s9sfS5K0HX5CVZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8ZdkhoaFPckx5PcSLKW5NyE8x9K8i9JXkny90mOzn5USdJQm8Y9yR7gInACOAqcmRDvz1bVj1fVTwDPA38080klSYMNuXI/BqxV1c2qugtcBk6Nbqiq/xpZfj9QsxtRkjStvQP2HABujazXgafHNyX5MPBRYB/wc5OeKMkSsAQwPz8/7aySpIGGXLlnwrHvuTKvqotV9cPAx4DfnfREVXWpqharanFubm66SSVJgw2J+zpwaGR9ELj9kP2XgV/azlCSpO0ZEverwJEkh5PsA04Dy6MbkhwZWb4f+NrsRpQkTWvTe+5VdS/JWeAKsAd4oa
quJbkArFbVMnA2yTPAt4DXgQ/s5NCSpIcb8gtVqmoFWBk7dn7k8W/MeC5J0jb4CVVJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0NinuS40luJFlLcm7C+Y8meTXJV5J8Mcm7Zj+qJGmoTeOeZA9wETgBHAXOJDk6tu2fgcWqeg/weeD5WQ8qSRpuyJX7MWCtqm5W1V3gMnBqdENV/V1V/c/G8iXg4GzHlCRNY0jcDwC3RtbrG8ce5IPAX29nKEnS9uwdsCcTjtXEjcmvAovAzzzg/BKwBDA/Pz9wREnStIZcua8Dh0bWB4Hb45uSPAP8DnCyqr456Ymq6lJVLVbV4tzc3FbmlSQNMCTuV4EjSQ4n2QecBpZHNyR5L/Bp7of9P2Y/piRpGpvGvaruAWeBK8B14MWqupbkQpKTG9v+EHgb8BdJXkmy/ICnkyQ9AkPuuVNVK8DK2LHzI4+fmfFckqRt8BOqktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1NCguCc5nuRGkrUk5yacf1+Sf0pyL8mzsx9TkjSNTeOeZA9wETgBHAXOJDk6tu3fgOeAz856QEnS9PYO2HMMWKuqmwBJLgOngFe/vaGqXts49+YOzChJmtKQ2zIHgFsj6/WNY1NLspRkNcnqnTt3tvIUkqQBhsQ9E47VVl6sqi5V1WJVLc7NzW3lKSRJAwyJ+zpwaGR9ELi9M+NIkmZhSNyvAkeSHE6yDzgNLO/sWJKk7dg07lV1DzgLXAGuAy9W1bUkF5KcBEjyU0nWgV8BPp3k2k4OLUl6uCHvlqGqVoCVsWPnRx5f5f7tGknSW4CfUJWkhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1NCguCc5nuRGkrUk5yac/74kn9s4/+UkC7MeVJI03KZxT7IHuAicAI4CZ5IcHdv2QeD1qno38CngD2Y9qCRpuCFX7seAtaq6WVV3gcvAqbE9p4DPbDz+PPDzSTK7MSVJ09g7YM8B4NbIeh14+kF7qupekjeAHwC+PropyRKwtLH87yQ3tjK0JtrP2N/3W1H8N92T6LH42eSTj8316LuGbBoS90n/xbWFPVTVJeDSgNfUlJKsVtXibs8hjfNnc3cMuS2zDhwaWR8Ebj9oT5K9wDuAb8xiQEnS9IbE/SpwJMnhJPuA08Dy2J5l4AMbj58F/raqvufKXZL0aGx6W2bjHvpZ4AqwB3ihqq4luQCsVtUy8CfAnyVZ4/4V++mdHFoTebtLb1X+bO6CeIEtSf34CVVJasi4S1JDxl2SGhryPndJGizJj3D/U+sHuP95l9vAclVd39XBnjBeuUuamSQf4/5XlAT4R+6/lTrAn0/60kHtHN8t00ySX6uqP93tOfRkSvJV4Eer6ltjx/cB16rqyO5M9uTxyr2fT+72AHqivQn80ITj79w4p0fEe+6PoSRfedAp4Acf5SzSmN8Evpjka3z3CwfngXcDZ3dtqieQt2UeQ0n+HfhF4PXxU8A/VNWkKyfpkUjyFPe/KvwA938m14GrVfV/uzrYE8Yr98fTF4
C3VdUr4yeSfOnRjyN9V1W9Cby023M86bxyl6SG/IWqJDVk3CWpIeMuSQ0Zd0lq6P8BshdYHBwkBd8AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# view the distribution of percentages within the Class column\n",
"\n",
"\n",
"(df_oversampled['Class'].value_counts()/np.float(len(df_oversampled))).plot.bar()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above bar plot shows that we have a balanced dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, I will train another model using Logistic Regression and check its accuracy, but this time on the balanced dataset."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy : 93.76%\n"
]
}
],
"source": [
"# declare feature vector and target variable\n",
"X1 = df_oversampled.drop(['Class'], axis=1)\n",
"y1 = df_oversampled['Class']\n",
"\n",
"\n",
"# instantiate the Logistic Regression classifier\n",
"logreg1 = LogisticRegression()\n",
"\n",
"\n",
"# fit the classifier to the imbalanced data\n",
"clf1 = logreg1.fit(X1, y1)\n",
"\n",
"\n",
"# predict on the training data\n",
"y1_pred = clf1.predict(X1)\n",
"\n",
"\n",
"# print the accuracy\n",
"accuracy1 = accuracy_score(y1_pred, y1)\n",
"\n",
"print(\"Accuracy : %.2f%%\" % (accuracy1 * 100.0))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a balanced dataset. Although the accuracy is slightly decreased, but it is still quite high and acceptable. \n",
"This accuracy is more meaningful as a performance metric."
]
},
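{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy above is computed on the oversampled training data, where duplicated minority rows are seen during both fitting and scoring. A common safeguard, sketched here on a hypothetical toy DataFrame (`df_toy` is made up), is to split off a test set first and oversample only the training part, so the test set keeps its original class distribution and contains no copies of training rows.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.utils import resample\n",
"\n",
"# hypothetical toy data: 95 majority rows (Class 0) and 5 minority rows (Class 1)\n",
"rng = np.random.default_rng(0)\n",
"df_toy = pd.DataFrame({'feature': rng.normal(size=100),\n",
"                       'Class': [0] * 95 + [1] * 5})\n",
"\n",
"# 1. split first, stratified so the test set keeps the original class ratio\n",
"train, test = train_test_split(df_toy, test_size=0.2,\n",
"                               stratify=df_toy['Class'], random_state=0)\n",
"\n",
"# 2. oversample the minority class within the training split only\n",
"train_major = train[train['Class'] == 0]\n",
"train_minor = train[train['Class'] == 1]\n",
"train_minor_up = resample(train_minor, replace=True,\n",
"                          n_samples=len(train_major), random_state=0)\n",
"train_balanced = pd.concat([train_major, train_minor_up])\n",
"\n",
"print(train_balanced['Class'].value_counts())\n",
"print(test['Class'].value_counts())\n",
"```\n",
"\n",
"A model would then be fitted on `train_balanced` and evaluated on the untouched `test` split."
]
},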
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Random under-sampling the majority class\n",
"\n",
"\n",
"The **under-sampling** methods work with the majority class. In these methods, we randomly eliminate instances of the majority class. It reduces the number of observations from majority class to make the dataset balanced. This method is applicable when the dataset is huge and reducing the number of training samples make the dataset balanced.\n",
"\n",
"\n",
"The most common technique for under-sampling is resampling without replacement.\n",
"\n",
"\n",
"I will proceed exactly as in the case of random over-sampling."
]
},
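{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between sampling with and without replacement matters for under-sampling: with `replace=False`, `resample` draws distinct majority rows, so no row appears twice. A minimal sketch on a hypothetical toy DataFrame:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.utils import resample\n",
"\n",
"# hypothetical toy data: 1000 majority rows (Class 0) and 20 minority rows (Class 1)\n",
"df_toy = pd.DataFrame({'feature': np.arange(1020),\n",
"                       'Class': [0] * 1000 + [1] * 20})\n",
"df_major = df_toy[df_toy['Class'] == 0]\n",
"df_minor = df_toy[df_toy['Class'] == 1]\n",
"\n",
"# replace=False draws distinct rows, so no majority row is used twice\n",
"df_major_down = resample(df_major, replace=False,\n",
"                         n_samples=len(df_minor), random_state=0)\n",
"df_down = pd.concat([df_minor, df_major_down])\n",
"\n",
"print(df_down['Class'].value_counts())\n",
"print('Duplicated majority rows:', df_major_down.index.duplicated().sum())\n",
"```"
]
},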
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# separate the minority and majority classes\n",
"df_majority = df[df['Class']==0]\n",
"df_minority = df[df['Class']==1]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# undersample majority class\n",
"\n",
"df_majority_undersampled = resample(df_majority, replace=True, n_samples=492, random_state=0) "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# combine majority class with oversampled minority class\n",
"\n",
"df_undersampled = pd.concat([df_minority, df_majority_undersampled])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1 492\n",
"0 492\n",
"Name: Class, dtype: int64"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# display new class value counts\n",
"\n",
"df_undersampled['Class'].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can see that the new dataframe `df_undersampled` has fewer observations than the original one `df` and the ratio of the two classes is now 1:1.\n",
"\n",
"Again, I will train a model using Logistic Regression classifier."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy : 93.90%\n"
]
}
],
"source": [
"# declare feature vector and target variable\n",
"X2 = df_undersampled.drop(['Class'], axis=1)\n",
"y2 = df_undersampled['Class']\n",
"\n",
"\n",
"# instantiate the Logistic Regression classifier\n",
"logreg2 = LogisticRegression()\n",
"\n",
"\n",
"# fit the classifier to the imbalanced data\n",
"clf2 = logreg2.fit(X2, y2)\n",
"\n",
"\n",
"# predict on the training data\n",
"y2_pred = clf2.predict(X2)\n",
"\n",
"\n",
"# print the accuracy\n",
"accuracy2 = accuracy_score(y2_pred, y2)\n",
"\n",
"print(\"Accuracy : %.2f%%\" % (accuracy2 * 100.0))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, we can see that we have a slightly decreased accuracy but it is more meaningful now."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Apply Tree-Based Algorithms"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# declare input features (X) and target variable (y)\n",
"X4 = df.drop('Class', axis=1)\n",
"y4 = df['Class']\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# import Random Forest classifier\n",
"from sklearn.ensemble import RandomForestClassifier\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# instantiate the classifier \n",
"clf4 = RandomForestClassifier()\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
" max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None,\n",
" oob_score=False, random_state=None, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# fit the classifier to the training data\n",
"clf4.fit(X4, y4)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# predict on training set\n",
"y4_pred = clf4.predict(X4)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy : 99.99%\n"
]
}
],
"source": [
"# compute and print accuracy\n",
"accuracy4 = accuracy_score(y4_pred, y4)\n",
"print(\"Accuracy : %.2f%%\" % (accuracy4 * 100.0))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ROC-AUC : 0.9999992243516689\n"
]
}
],
"source": [
"# compute and print ROC-AUC\n",
"\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"y4_prob = clf4.predict_proba(X4)\n",
"y4_prob = [p[1] for p in y4_prob]\n",
"print(\"ROC-AUC : \" , roc_auc_score(y4, y4_prob))"
]
},
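{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 99.99% accuracy and the ROC-AUC close to 1.0 above are measured on the same data the forest was trained on, so they mainly show that the forest can fit its training data. The sketch below is an illustration under stated assumptions rather than the notebook's method: it uses synthetic data from `make_classification` (about 1% positives) as a stand-in for the credit card data, holds out a stratified test set, and sets `class_weight='balanced'`, a standard `RandomForestClassifier` option that weights classes inversely to their frequencies.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import average_precision_score, roc_auc_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# synthetic stand-in for an imbalanced dataset (roughly 1% positives)\n",
"X_syn, y_syn = make_classification(n_samples=10000, n_features=10,\n",
"                                   weights=[0.99, 0.01], random_state=0)\n",
"\n",
"# hold out a stratified test set so evaluation uses unseen rows\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_syn, y_syn, test_size=0.3,\n",
"                                          stratify=y_syn, random_state=0)\n",
"\n",
"# class_weight='balanced' weights classes inversely to their frequencies\n",
"rf = RandomForestClassifier(n_estimators=100, class_weight='balanced',\n",
"                            random_state=0)\n",
"rf.fit(X_tr, y_tr)\n",
"\n",
"# probability of the positive class on the training and test sets\n",
"train_auc = roc_auc_score(y_tr, rf.predict_proba(X_tr)[:, 1])\n",
"test_auc = roc_auc_score(y_te, rf.predict_proba(X_te)[:, 1])\n",
"test_ap = average_precision_score(y_te, rf.predict_proba(X_te)[:, 1])\n",
"\n",
"print('Train ROC AUC:', round(train_auc, 4))\n",
"print('Test ROC AUC:', round(test_auc, 4))\n",
"print('Test average precision:', round(test_ap, 4))\n",
"```\n",
"\n",
"The training ROC AUC is typically close to 1.0; the held-out scores are the ones worth reporting."
]
},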
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10.\tRandom under-sampling and over-sampling with imbalanced-learn\n",
"\n",
"\n",
"\n",
"There is a Python library which enable us to handle the imbalanced datasets. It is called **Imbalanced-Learn**. It is a Python library which contains various algorithms to handle the imbalanced datasets. It can be easily installed with the `pip` command. This library contains a `make_imbalance` method to exasperate the level of class imbalance within a given dataset.\n",
"\n",
"\n",
"Now, I will demonstrate the technique of random undersampling and oversampling with imbalanced learn. \n",
"\n",
"\n",
"First of all, I will import the `imbalanced learn` library.\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# import imbalanced learn library\n",
"\n",
"import imblearn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, I will import the `RandomUnderSampler` class. It is a quick and easy way to balance the data by randomly selecting a subset of data for the targeted classes. "
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# import RandomUnderSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"# instantiate the RandomUnderSampler\n",
"rus = RandomUnderSampler(return_indices=True)\n",
"\n",
"\n",
"# fit the RandomUnderSampler to the dataset\n",
"X_rus, y_rus, id_rus = rus.fit_sample(X, y)\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed indices: [194966 158207 239580 47508 164976 4059 244121 4712 129277 195768\n",
" 220169 159815 157385 226067 122412 186077 216500 36600 232308 61840\n",
" 216772 269873 49886 138921 64943 104600 211825 162198 236942 256405\n",
" 116084 165304 254299 217511 91142 67255 2349 132109 227416 75785\n",
" 23316 131322 177311 61790 91798 103220 103526 33083 148175 117300\n",
" 117437 251709 243617 136620 78369 177568 157649 150080 137441 277646\n",
" 988 213741 213602 264759 102266 166026 192696 269500 182970 7029\n",
" 138352 262530 99000 159383 225900 249330 14929 117795 252069 86625\n",
" 249970 58096 109913 195548 30897 8690 22107 261540 111780 105375\n",
" 62971 201607 177552 30981 84358 226572 7675 64315 172103 171021\n",
" 72979 208177 38876 63638 180868 76338 121268 264548 117134 182323\n",
" 254834 166395 235471 204943 9850 232780 83992 20930 54009 198064\n",
" 133443 81674 207050 274408 266475 165966 277813 23758 49157 222434\n",
" 178038 7746 72809 101544 198536 158586 102708 263512 63292 147707\n",
" 31677 6328 117195 33384 17731 59202 46117 2387 42641 195419\n",
" 99068 252335 81636 31663 136341 60555 120632 252749 77847 38097\n",
" 120550 197729 40557 45459 231116 131960 172075 190211 101223 15245\n",
" 29630 12298 41976 214451 248802 269876 19449 10430 123236 12237\n",
" 129734 144573 244632 270905 101667 6710 7139 234437 15658 246551\n",
" 230319 244236 257475 27540 67248 119614 270866 128596 146983 7481\n",
" 127789 89062 98494 7782 171395 13694 262656 240202 72635 265137\n",
" 35247 133304 253329 260581 204381 215806 209839 58584 280688 260875\n",
" 197410 244160 33589 192219 59500 178612 106610 272537 116951 227946\n",
" 241031 147854 235519 129102 54520 275960 37683 17697 136346 214118\n",
" 154999 661 209611 219382 227776 79642 114606 180877 269470 176024\n",
" 236924 228334 272616 241550 85666 138678 114154 116010 201161 202505\n",
" 143135 22461 269580 184552 216170 140159 239925 176351 97376 183893\n",
" 157472 8541 17800 162709 229779 94546 64270 269980 102331 179326\n",
" 181376 212424 19644 136198 193411 48301 42389 125628 96447 65219\n",
" 156953 238063 236407 118678 188004 88127 197928 83981 196266 58068\n",
" 173047 195404 78823 232805 200973 271144 141205 169895 226854 152573\n",
" 137340 249271 204025 101012 206959 184718 148995 19560 60100 93132\n",
" 228612 52144 94759 153881 202918 152068 87795 148397 253826 209436\n",
" 55565 156536 210650 52685 108761 130575 138829 278945 98843 30672\n",
" 74650 25075 37943 88289 232059 86024 124666 129281 168990 260644\n",
" 20605 259212 103199 105397 213227 69835 52662 161449 7192 234743\n",
" 161317 254693 131443 120272 272556 77072 224933 136289 65895 251918\n",
" 138989 156555 166011 69344 151419 228769 138165 72384 108491 249294\n",
" 242158 105805 31188 47452 107703 191593 108528 107788 195917 276293\n",
" 273099 166631 259123 52203 88536 23808 128556 234308 70648 266068\n",
" 1888 159394 60659 85037 272887 140314 152387 181888 153327 42232\n",
" 28043 162227 165238 124602 252553 59639 71697 88871 26193 73080\n",
" 36374 116525 171421 197736 193143 163500 235175 209073 236240 145689\n",
" 61701 169288 163571 134394 113957 225109 140344 1769 232132 2260\n",
" 184603 179724 185192 6070 18868 162478 209752 49614 91240 110724\n",
" 41575 112947 64470 109339 199665 251105 235051 263785 159364 241090\n",
" 273473 30352 232970 168645 234204 99853 116866 201284 144682 272207\n",
" 254527 47779 269575 161781 267340 45439 224250 181340 5375 173659\n",
" 5646 11315 74204 63810 91425 210100 171055 106752 179427 119743\n",
" 18362 30221 541 623 4920 6108 6329 6331 6334 6336\n",
" 6338 6427 6446 6472 6529 6609 6641 6717 6719 6734\n",
" 6774 6820 6870 6882 6899 6903 6971 8296 8312 8335\n",
" 8615 8617 8842 8845 8972 9035 9179 9252 9487 9509\n",
" 10204 10484 10497 10498 10568 10630 10690 10801 10891 10897\n",
" 11343 11710 11841 11880 12070 12108 12261 12369 14104 14170\n",
" 14197 14211 14338 15166 15204 15225 15451 15476 15506 15539\n",
" 15566 15736 15751 15781 15810 16415 16780 16863 17317 17366\n",
" 17407 17453 17480 18466 18472 18773 18809 20198 23308 23422\n",
" 26802 27362 27627 27738 27749 29687 30100 30314 30384 30398\n",
" 30442 30473 30496 31002 33276 39183 40085 40525 41395 41569\n",
" 41943 42007 42009 42473 42528 42549 42590 42609 42635 42674\n",
" 42696 42700 42741 42756 42769 42784 42856 42887 42936 42945\n",
" 42958 43061 43160 43204 43428 43624 43681 43773 44001 44091\n",
" 44223 44270 44556 45203 45732 46909 46918 46998 47802 48094\n",
" 50211 50537 52466 52521 52584 53591 53794 55401 56703 57248\n",
" 57470 57615 58422 58761 59539 61787 63421 63634 64329 64411\n",
" 64460 68067 68320 68522 68633 69498 69980 70141 70589 72757\n",
" 73784 73857 74496 74507 74794 75511 76555 76609 76929 77099\n",
" 77348 77387 77682 79525 79536 79835 79874 79883 80760 81186\n",
" 81609 82400 83053 83297 83417 84543 86155 87354 88258 88307\n",
" 88876 88897 89190 91671 92777 93424 93486 93788 94218 95534\n",
" 95597 96341 96789 96994 99506 100623 101509 102441 102442 102443\n",
" 102444 102445 102446 102782 105178 106679 106998 107067 107637 108258\n",
" 108708 111690 112840 114271 116139 116404 118308 119714 119781 120505\n",
" 120837 122479 123141 123201 123238 123270 123301 124036 124087 124115\n",
" 124176 125342 128479 131272 135718 137705 140786 141257 141258 141259\n",
" 141260 142405 142557 143188 143333 143334 143335 143336 143728 143731\n",
" 144104 144108 144754 145800 146790 147548 147605 149145 149357 149522\n",
" 149577 149587 149600 149869 149874 150601 150644 150647 150654 150660\n",
" 150661 150662 150663 150665 150666 150667 150668 150669 150677 150678\n",
" 150679 150680 150684 150687 150692 150697 150715 150925 151006 151007\n",
" 151008 151009 151011 151103 151196 151462 151519 151730 151807 152019\n",
" 152223 152295 153823 153835 153885 154234 154286 154371 154454 154587\n",
" 154633 154668 154670 154676 154684 154693 154694 154697 154718 154719\n",
" 154720 154960 156988 156990 157585 157868 157871 157918 163149 163586\n",
" 167184 167305 172787 176049 177195 178208 181966 182992 183106 184379\n",
" 189587 189701 189878 190368 191074 191267 191359 191544 191690 192382\n",
" 192529 192584 192687 195383 197586 198868 199896 201098 201601 203324\n",
" 203328 203700 204064 204079 204503 208651 212516 212644 213092 213116\n",
" 214662 214775 215132 215953 215984 218442 219025 219892 220725 221018\n",
" 221041 222133 222419 223366 223572 223578 223618 226814 226877 229712\n",
" 229730 230076 230476 231978 233258 234574 234632 234633 234705 235616\n",
" 235634 235644 237107 237426 238222 238366 238466 239499 239501 240222\n",
" 241254 241445 243393 243547 243699 243749 243848 244004 244333 245347\n",
" 245556 247673 247995 248296 248971 249167 249239 249607 249828 249963\n",
" 250761 251477 251866 251881 251891 251904 252124 252774 254344 254395\n",
" 255403 255556 258403 261056 261473 261925 262560 262826 263080 263274\n",
" 263324 263877 268375 272521 274382 274475 275992 276071 276864 279863\n",
" 280143 280149 281144 281674]\n"
]
}
],
"source": [
"# print the removed indices\n",
"print(\"Removed indices: \", id_rus)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above indices are removed from the original dataset."
]
},
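{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the meaning of these indices concrete, here is a minimal NumPy sketch of random under-sampling that returns the indices of the rows it keeps, the same role `sample_indices_` plays in imbalanced-learn. The helper `random_undersample_indices` and the toy labels are hypothetical, and this is not imbalanced-learn's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def random_undersample_indices(y, random_state=0):\n",
"    # hypothetical helper: keep every row of the smallest class plus an\n",
"    # equal-sized random subset (without replacement) of each other class\n",
"    y = np.asarray(y)\n",
"    rng = np.random.default_rng(random_state)\n",
"    classes, counts = np.unique(y, return_counts=True)\n",
"    n_min = counts.min()\n",
"    kept = [rng.choice(np.flatnonzero(y == c), size=n_min, replace=False)\n",
"            for c in classes]\n",
"    return np.sort(np.concatenate(kept))\n",
"\n",
"y_toy = np.array([0] * 90 + [1] * 10)\n",
"kept = random_undersample_indices(y_toy)\n",
"removed = np.setdiff1d(np.arange(len(y_toy)), kept)\n",
"print(len(kept), 'kept,', len(removed), 'removed')\n",
"```\n",
"\n",
"The removed rows are the complement of the kept indices, which is why `np.setdiff1d` is needed to list them."
]
},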
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, I will demonstrate random oversampling. The process will be the same as random undersampling."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.over_sampling import RandomOverSampler\n",
"\n",
"ros = RandomOverSampler()\n",
"\n",
"X_ros, y_ros = ros.fit_sample(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"283823 new random points generated\n"
]
}
],
"source": [
"print(X_ros.shape[0] - X.shape[0], 'new random points generated')"
]
},
{
"attachments": {
"Tomek%20links.jpg": {
"image/jpeg": ""
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11.\tUnder-sampling : Tomek links\n",
"\n",
"\n",
"Tomek links are defined as the two observations of different classes which are nearest neighbours of each other.\n",
"\n",
"\n",
"The figure below illustrate the concept of Tomek links-\n",
"\n",
"\n",
"\n",
"![Tomek%20links.jpg](attachment:Tomek%20links.jpg)\n",
"\n",
"\n",
"\n",
"We can see in the above image that the Tomek links (circled in green) are given by the pairs of red and blue data points that are nearest neighbors. Most of the classification algorithms face difficulty due to these points. So, I will remove these \n",
"points and increase the separation gap between two classes. Now, the algorithms produce more reliable output.\n",
"\n",
"This technique will not produce a balanced dataset. It will simply clean the dataset by removing the Tomek links. It may result in an easier classification problem. Thus, by removing the Tomek links, we can improve the performance of the classifier even if we don’t have a balanced dataset.\n",
"\n",
"\n",
"So, removing the Tomek links increases the gap between the two classes and thus facilitate the classification process.\n",
"\n",
"\n",
"In the following code, I will use `ratio=majority` to resample the majority class.\n"
]
},
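{
"cell_type": "markdown",
"metadata": {},
"source": [
"The definition of a Tomek link can be written directly in terms of nearest neighbours. The sketch below is a hypothetical helper (`tomek_links`, not imbalanced-learn's code) that finds each point's nearest neighbour with scikit-learn's `NearestNeighbors` and keeps the pairs of opposite classes that are mutual nearest neighbours. It assumes there are no duplicate points.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.neighbors import NearestNeighbors\n",
"\n",
"def tomek_links(X, y):\n",
"    # return index pairs (i, j), i < j, of opposite-class points that are\n",
"    # each other's nearest neighbour (assumes no duplicate points)\n",
"    nn = NearestNeighbors(n_neighbors=2).fit(X)\n",
"    # column 0 is the point itself, column 1 its nearest other point\n",
"    nearest = nn.kneighbors(X, return_distance=False)[:, 1]\n",
"    links = []\n",
"    for i, j in enumerate(nearest):\n",
"        if i < j and nearest[j] == i and y[i] != y[j]:\n",
"            links.append((i, int(j)))\n",
"    return links\n",
"\n",
"# hypothetical one-dimensional toy data\n",
"X_toy = np.array([[0.0], [1.0], [1.2], [3.0], [5.0]])\n",
"y_toy = np.array([0, 0, 1, 0, 1])\n",
"links = tomek_links(X_toy, y_toy)\n",
"print(links)\n",
"```\n",
"\n",
"On this toy data, points 1 (class 0) and 2 (class 1) form the only Tomek link; with `sampling_strategy='majority'` only point 1 would be removed. Points 3 and 4 are not a link, because the nearest neighbour of point 3 is point 2."
]
},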
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.under_sampling import TomekLinks\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"tl = TomekLinks(return_indices=True, ratio='majority')\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"X_tl, y_tl, id_tl = tl.fit_sample(X, y)\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed indexes: [ 0 1 2 ... 284804 284805 284806]\n"
]
}
],
"source": [
"print('Removed indexes:', id_tl)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 12. Under-sampling : Cluster Centroids\n",
"\n",
"\n",
"In this technique, we perform under-sampling by generating centroids based on clustering methods. The dataset will be grouped\n",
"by similarity, in order to preserve information.\n",
"\n",
"In this example, I have passed the {0: 10} dict for the parameter ratio. It preserves 10 elements from the majority class (0), and all minority class (1) ."
]
},
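{
"cell_type": "markdown",
"metadata": {},
"source": [
"The idea behind Cluster Centroids can be sketched with scikit-learn's `KMeans`: fit K-means with as many clusters as majority observations to keep, and use the cluster centres in place of the original majority rows. The helper `cluster_centroid_undersample` and the toy data are hypothetical; imbalanced-learn's `ClusterCentroids` has more options than this sketch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"\n",
"def cluster_centroid_undersample(X_major, n_centroids, random_state=0):\n",
"    # replace the majority rows by the centres of n_centroids K-means\n",
"    # clusters; the centres are new synthetic points, not original rows\n",
"    km = KMeans(n_clusters=n_centroids, n_init=10, random_state=random_state)\n",
"    km.fit(X_major)\n",
"    return km.cluster_centers_\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_major = rng.normal(size=(500, 2))          # hypothetical majority class\n",
"X_minor = rng.normal(loc=3.0, size=(20, 2))  # hypothetical minority class\n",
"\n",
"centroids = cluster_centroid_undersample(X_major, n_centroids=10)\n",
"X_new = np.vstack([centroids, X_minor])\n",
"y_new = np.array([0] * len(centroids) + [1] * len(X_minor))\n",
"print(X_new.shape, np.bincount(y_new))\n",
"```"
]
},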
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.under_sampling import ClusterCentroids"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"cc = ClusterCentroids(ratio={0: 10})"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"X_cc, y_cc = cc.fit_sample(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"284305 New points undersampled under Cluster Centroids\n"
]
}
],
"source": [
"print(X.shape[0] - X_cc.shape[0], 'New points undersampled under Cluster Centroids')"
]
},
{
"attachments": {
"smote.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 13.\tOver-sampling : SMOTE\n",
"\n",
"\n",
"\n",
"In the context of synthetic data generation, there is a powerful and widely used method known as **synthetic minority oversampling technique** or **SMOTE**. Under this technique, artificial data is created based on feature space. \n",
"Artificial data is generated with bootstrapping and k-nearest neighbours algorithm. It works as follows:-\n",
"\n",
"\n",
"1.\tFirst of all, we take the difference between the feature vector (sample) under consideration and its nearest neighbour.\n",
"\n",
"\n",
"2.\tThen we multiply this difference by a random number between 0 and 1.\n",
"\n",
"\n",
"3.\tThen we add this number to the feature vector under consideration.\n",
"\n",
"\n",
"4.\tThus we select a random point along the line segment between two specific features.\n",
"\n",
"\n",
"The concept of **SMOTE** can best be illustrated with the following figure:-\n",
"\n",
"\n",
"![smote.png](attachment:smote.png)\n",
"\n",
"\n",
"So, **SMOTE** generates new observations by interpolation between existing observations in the dataset.\n",
"\n"
]
},
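{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four steps above can be sketched in a few lines of NumPy. This is a simplified, hypothetical helper (`smote_sketch`), not imbalanced-learn's `SMOTE` implementation: for each synthetic point it picks a random minority sample `x_i`, one of its `k` nearest minority neighbours `x_j`, and a random gap in [0, 1), and returns `x_i + gap * (x_j - x_i)`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.neighbors import NearestNeighbors\n",
"\n",
"def smote_sketch(X_minor, n_new, k=5, random_state=0):\n",
"    # simplified SMOTE-style interpolation between minority samples\n",
"    rng = np.random.default_rng(random_state)\n",
"    nn = NearestNeighbors(n_neighbors=k + 1).fit(X_minor)\n",
"    # drop column 0 (the sample itself) to keep the k nearest neighbours\n",
"    neigh = nn.kneighbors(X_minor, return_distance=False)[:, 1:]\n",
"    new_points = np.empty((n_new, X_minor.shape[1]))\n",
"    for n in range(n_new):\n",
"        i = rng.integers(len(X_minor))   # random minority sample\n",
"        j = neigh[i, rng.integers(k)]    # one of its k nearest neighbours\n",
"        diff = X_minor[j] - X_minor[i]   # step 1: difference vector\n",
"        gap = rng.random()               # step 2: random number in [0, 1)\n",
"        new_points[n] = X_minor[i] + gap * diff  # steps 3 and 4\n",
"    return new_points\n",
"\n",
"# hypothetical minority samples inside the unit square\n",
"X_minor = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],\n",
"                    [0.5, 0.5], [0.2, 0.8]])\n",
"synthetic = smote_sketch(X_minor, n_new=10, k=3)\n",
"print(synthetic.round(2))\n",
"```\n",
"\n",
"Because every synthetic point lies on a segment between two minority samples, all of them stay inside the region spanned by the minority class (here the unit square). imbalanced-learn's `SMOTE` uses `k_neighbors=5` by default."
]
},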
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.over_sampling import SMOTE"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"smote = SMOTE(ratio='minority')"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"X_sm, y_sm = smote.fit_sample(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"283823 New points created under SMOTE\n"
]
}
],
"source": [
"print(X_sm.shape[0] - X.shape[0], 'New points created under SMOTE')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 14. Conclusion\n",
"\n",
"\n",
"In this jupyter notebook, I have discussed various approaches to deal with the problem of imbalanced classes. These are `random oversampling`, `random undersampling`, `tree-based algorithms`, `resampling with imbalanced learn library`, `under-sampling : Tomek links`, `under-sampling : Cluster Centroids` and `over-sampling : SMOTE`.\n",
"\n",
"\n",
"Some combination of these approaches will help us to create a better classifier. Simple sampling techniques may handle slight imbalance whereas more advanced methods like ensemble methods are required for extreme imbalances. The most effective technique will vary according to the dataset.\n",
"\n",
"\n",
"So, based on the above discussion, we can conclude that there is no one solution to deal with the imbalanced classes problem. \n",
"We should try out multiple methods to select the best-suited sampling techniques for the dataset in hand. The most effective technique will vary according to the characteristics of the dataset.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}