{
"cells": [
{
"cell_type": "markdown",
"id": "24138c5f-e4d7-4105-8dc9-fe59f04f2f35",
"metadata": {},
"source": [
"# Introduction\n",
"\n",
"Second part of credit card fraud data model training."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "da3c715a-151e-4da7-b321-462c4c1f91f8",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)\n",
"warnings.simplefilter(action='ignore', category=UserWarning)\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import sklearn.model_selection\n",
"import sklearn.tree\n",
"import sklearn.svm\n",
"import sklearn.metrics\n",
"import sklearn.pipeline\n",
"import sklearn.preprocessing\n",
"import sklearn.compose\n",
"import mlflow\n",
"import mlflow.models\n",
"import imblearn.over_sampling\n",
"\n",
"random_state = 42\n",
"\n",
"# Comment out the line below to use local MLflow.\n",
"mlflow.set_tracking_uri('http://localhost:5000')"
]
},
{
"cell_type": "markdown",
"id": "a5c08ffd-1cd4-45f5-81c8-67195f73398b",
"metadata": {},
"source": [
"# Data preprocessing\n",
"\n",
"Let's say that after some time we get the idea that new features can be extracted from the original dataset.\n",
"\n",
"Using the date of birth and the transaction date, we can calculate how old a customer was at the time of the transaction and then check whether the new variable correlates with the predicted one. From the transaction date itself we can keep just the hour and minute, since the full date does not seem very useful."
]
},
{
"cell_type": "markdown",
"id": "be166fba-bc18-4f7a-8842-0fcbaf361442",
"metadata": {},
"source": [
"## Data preparation\n",
"\n",
"Copied from the first notebook."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "61376e51-bb5e-4236-ab73-5e91afacbf84",
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv(\"fraud_data.csv\")\n",
"data = data.drop('trans_num', axis='columns', errors='ignore')\n",
"data = data[(data['is_fraud'] == '0') | (data['is_fraud'] == '1')]\n",
"data = data.map(lambda x: x.strip('\"') if isinstance(x, str) else x)\n",
"data = data.copy()\n",
"data['trans_date_trans_time'] = data['trans_date_trans_time'].apply(lambda x: pd.to_datetime(x, dayfirst=True))\n",
"data['dob'] = data['dob'].apply(lambda x: pd.to_datetime(x, dayfirst=True))\n",
"data = data.astype({\n",
" 'merchant': 'category',\n",
" 'category': 'category',\n",
" 'city': 'category',\n",
" 'state': 'category',\n",
" 'job': 'category',\n",
" 'is_fraud': 'int',\n",
"})\n",
"data = data.astype({\n",
" 'is_fraud': 'boolean',\n",
"})"
]
},
{
"cell_type": "markdown",
"id": "81469c9a-ddc4-4fd9-95d5-bc1e745d9611",
"metadata": {},
"source": [
"## Age variable\n",
"\n",
"To calculate the age of a customer, we simply subtract the date of birth from the transaction date. The age is approximated in years (whole days divided by 365), although an integer value should be just as good."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ee1e8f51-80d7-4077-9369-29f031328d20",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>trans_date_trans_time</th>\n",
" <th>dob</th>\n",
" <th>age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2019-01-04 00:58:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.208219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2019-01-04 15:06:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.208219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2019-01-04 22:37:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.208219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2019-01-04 23:06:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.208219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2019-01-04 23:59:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.208219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14441</th>\n",
" <td>2019-01-22 00:37:00</td>\n",
" <td>1976-10-18</td>\n",
" <td>42.290411</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14442</th>\n",
" <td>2019-01-22 00:41:00</td>\n",
" <td>1956-09-01</td>\n",
" <td>62.432877</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14443</th>\n",
" <td>2019-01-22 00:42:00</td>\n",
" <td>1973-05-16</td>\n",
" <td>45.717808</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14444</th>\n",
" <td>2019-01-22 00:48:00</td>\n",
" <td>1939-11-09</td>\n",
" <td>79.257534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14445</th>\n",
" <td>2019-01-22 00:55:00</td>\n",
" <td>1950-09-15</td>\n",
" <td>68.400000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>14444 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" trans_date_trans_time dob age\n",
"0 2019-01-04 00:58:00 1939-11-09 79.208219\n",
"1 2019-01-04 15:06:00 1939-11-09 79.208219\n",
"2 2019-01-04 22:37:00 1939-11-09 79.208219\n",
"3 2019-01-04 23:06:00 1939-11-09 79.208219\n",
"4 2019-01-04 23:59:00 1939-11-09 79.208219\n",
"... ... ... ...\n",
"14441 2019-01-22 00:37:00 1976-10-18 42.290411\n",
"14442 2019-01-22 00:41:00 1956-09-01 62.432877\n",
"14443 2019-01-22 00:42:00 1973-05-16 45.717808\n",
"14444 2019-01-22 00:48:00 1939-11-09 79.257534\n",
"14445 2019-01-22 00:55:00 1950-09-15 68.400000\n",
"\n",
"[14444 rows x 3 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"age = data.apply(lambda row: (row['trans_date_trans_time'] - row['dob']).days / 365.0, axis=1)\n",
"data[['trans_date_trans_time', 'dob']].assign(age=age)"
]
},
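{
"cell_type": "markdown",
"id": "sketch-age-vectorized-md",
"metadata": {},
"source": [
"The row-wise `apply` above can become slow on larger frames. As a minimal sketch, the same value can be computed with vectorized datetime arithmetic. The `toy` frame and its values are made up for illustration, and the sketch assumes both columns already hold datetimes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-age-vectorized-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy frame standing in for `data`; the values are illustrative only.\n",
"toy = pd.DataFrame({\n",
"    'trans_date_trans_time': pd.to_datetime(['2019-01-04 00:58:00', '2019-01-22 00:55:00']),\n",
"    'dob': pd.to_datetime(['1939-11-09', '1950-09-15']),\n",
"})\n",
"\n",
"# Subtracting two datetime columns yields a timedelta column;\n",
"# `.dt.days` gives whole days, which are divided by 365 as in the cell above.\n",
"toy_age = (toy['trans_date_trans_time'] - toy['dob']).dt.days / 365.0\n",
"toy_age"
]
},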
{
"cell_type": "markdown",
"id": "542b97ab-9f43-4ade-9501-2d8b90387e0f",
"metadata": {},
"source": [
"Show some statistics of the age variable."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ff063723-6134-4a19-87f7-8f781b42985a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 14444.000000\n",
"mean 48.122891\n",
"std 17.266586\n",
"min 17.446575\n",
"25% 34.356164\n",
"50% 45.961644\n",
"75% 59.306849\n",
"max 93.375342\n",
"dtype: float64"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"age.describe()"
]
},
{
"cell_type": "markdown",
"id": "472db5dc-5aa5-4863-bc7a-7319033532d1",
"metadata": {},
"source": [
"Looks good, add to the data."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "80aed9d6-5218-45c4-9fa9-3e01e998de49",
"metadata": {},
"outputs": [],
"source": [
"data = data.drop('age', axis='columns', errors='ignore') # remove just in case the cell is executed again \n",
"data.insert(data.shape[1] - 1, 'age', age)"
]
},
{
"cell_type": "markdown",
"id": "a98a307c-2c68-4ee3-86c3-5ace413b86ad",
"metadata": {},
"source": [
"## Transaction time variable\n",
"\n",
"Now extract just the time of the transaction, encoded as the total number of minutes since midnight."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "75f4ffed-90ac-4fa1-ad63-78727fd44887",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>trans_date_trans_time</th>\n",
" <th>trans_time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2019-01-04 00:58:00</td>\n",
" <td>58</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2019-01-04 15:06:00</td>\n",
" <td>906</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2019-01-04 22:37:00</td>\n",
" <td>1357</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2019-01-04 23:06:00</td>\n",
" <td>1386</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2019-01-04 23:59:00</td>\n",
" <td>1439</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14441</th>\n",
" <td>2019-01-22 00:37:00</td>\n",
" <td>37</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14442</th>\n",
" <td>2019-01-22 00:41:00</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14443</th>\n",
" <td>2019-01-22 00:42:00</td>\n",
" <td>42</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14444</th>\n",
" <td>2019-01-22 00:48:00</td>\n",
" <td>48</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14445</th>\n",
" <td>2019-01-22 00:55:00</td>\n",
" <td>55</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>14444 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" trans_date_trans_time trans_time\n",
"0 2019-01-04 00:58:00 58\n",
"1 2019-01-04 15:06:00 906\n",
"2 2019-01-04 22:37:00 1357\n",
"3 2019-01-04 23:06:00 1386\n",
"4 2019-01-04 23:59:00 1439\n",
"... ... ...\n",
"14441 2019-01-22 00:37:00 37\n",
"14442 2019-01-22 00:41:00 41\n",
"14443 2019-01-22 00:42:00 42\n",
"14444 2019-01-22 00:48:00 48\n",
"14445 2019-01-22 00:55:00 55\n",
"\n",
"[14444 rows x 2 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def extract_trans_time(row):\n",
"    # Minutes elapsed since midnight for the transaction timestamp.\n",
"    trans_date_time = row['trans_date_trans_time']\n",
"    return trans_date_time.hour * 60 + trans_date_time.minute\n",
"\n",
"trans_time = data.apply(extract_trans_time, axis=1)\n",
"data[['trans_date_trans_time']].assign(trans_time=trans_time)"
]
},
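{
"cell_type": "markdown",
"id": "sketch-trans-time-vectorized-md",
"metadata": {},
"source": [
"A sketch of the same extraction without a row-wise function, using the `.dt` accessor on a datetime column. The timestamps in `toy_times` are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-trans-time-vectorized-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical timestamps, made up for illustration.\n",
"toy_times = pd.Series(pd.to_datetime(['2019-01-04 00:58:00', '2019-01-04 15:06:00', '2019-01-04 23:59:00']))\n",
"\n",
"# `.dt.hour` and `.dt.minute` operate on the whole column at once.\n",
"toy_trans_time = toy_times.dt.hour * 60 + toy_times.dt.minute\n",
"toy_trans_time.tolist()  # [58, 906, 1439]"
]
},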
{
"cell_type": "markdown",
"id": "b30668e6-4c42-4785-b689-26d34a2e03de",
"metadata": {},
"source": [
"Looks good, add to the data."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3acdc473-4670-4a00-b829-f184e4fb9f17",
"metadata": {},
"outputs": [],
"source": [
"data = data.drop('trans_time', axis='columns', errors='ignore') # remove just in case the cell is executed again \n",
"data.insert(data.shape[1] - 1, 'trans_time', trans_time)"
]
},
{
"cell_type": "markdown",
"id": "9d0b2d4d-03fa-47cf-a1a1-c9544c72c54b",
"metadata": {},
"source": [
"Drop irrelevant columns."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "30e5a982-cd5d-483f-907d-54febc5db95b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>merchant</th>\n",
" <th>category</th>\n",
" <th>amt</th>\n",
" <th>city</th>\n",
" <th>state</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>city_pop</th>\n",
" <th>job</th>\n",
" <th>merch_lat</th>\n",
" <th>merch_long</th>\n",
" <th>age</th>\n",
" <th>trans_time</th>\n",
" <th>is_fraud</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Stokes, Christiansen and Sipes</td>\n",
" <td>grocery_net</td>\n",
" <td>14.37</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>65.654142</td>\n",
" <td>-164.722603</td>\n",
" <td>79.208219</td>\n",
" <td>58</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Predovic Inc</td>\n",
" <td>shopping_net</td>\n",
" <td>966.11</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>65.468863</td>\n",
" <td>-165.473127</td>\n",
" <td>79.208219</td>\n",
" <td>906</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Wisozk and Sons</td>\n",
" <td>misc_pos</td>\n",
" <td>49.61</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>65.347667</td>\n",
" <td>-165.914542</td>\n",
" <td>79.208219</td>\n",
" <td>1357</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Murray-Smitham</td>\n",
" <td>grocery_pos</td>\n",
" <td>295.26</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>64.445035</td>\n",
" <td>-166.080207</td>\n",
" <td>79.208219</td>\n",
" <td>1386</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Friesen Lt</td>\n",
" <td>health_fitness</td>\n",
" <td>18.17</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>65.447094</td>\n",
" <td>-165.446843</td>\n",
" <td>79.208219</td>\n",
" <td>1439</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14441</th>\n",
" <td>Hudson-Grady</td>\n",
" <td>shopping_pos</td>\n",
" <td>122.00</td>\n",
" <td>Athena</td>\n",
" <td>OR</td>\n",
" <td>45.8289</td>\n",
" <td>-118.4971</td>\n",
" <td>1302</td>\n",
" <td>Dealer</td>\n",
" <td>46.442439</td>\n",
" <td>-118.524214</td>\n",
" <td>42.290411</td>\n",
" <td>37</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14442</th>\n",
" <td>Nienow, Ankunding and Collie</td>\n",
" <td>misc_pos</td>\n",
" <td>9.07</td>\n",
" <td>Gardiner</td>\n",
" <td>OR</td>\n",
" <td>43.7857</td>\n",
" <td>-124.1437</td>\n",
" <td>260</td>\n",
" <td>Engineer, maintenance</td>\n",
" <td>42.901265</td>\n",
" <td>-124.995317</td>\n",
" <td>62.432877</td>\n",
" <td>41</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14443</th>\n",
" <td>Pacocha-O'Reilly</td>\n",
" <td>grocery_pos</td>\n",
" <td>104.84</td>\n",
" <td>Alva</td>\n",
" <td>WY</td>\n",
" <td>44.6873</td>\n",
" <td>-104.4414</td>\n",
" <td>110</td>\n",
" <td>Administrator, local government</td>\n",
" <td>45.538062</td>\n",
" <td>-104.542117</td>\n",
" <td>45.717808</td>\n",
" <td>42</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14444</th>\n",
" <td>Bins, Balistreri and Beatty</td>\n",
" <td>shopping_pos</td>\n",
" <td>268.16</td>\n",
" <td>Wales</td>\n",
" <td>AK</td>\n",
" <td>64.7556</td>\n",
" <td>-165.6723</td>\n",
" <td>145</td>\n",
" <td>Administrator, education</td>\n",
" <td>64.081462</td>\n",
" <td>-165.898698</td>\n",
" <td>79.257534</td>\n",
" <td>48</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14445</th>\n",
" <td>Daugherty-Thompson</td>\n",
" <td>food_dining</td>\n",
" <td>50.09</td>\n",
" <td>Unionville</td>\n",
" <td>MO</td>\n",
" <td>40.4815</td>\n",
" <td>-92.9951</td>\n",
" <td>3805</td>\n",
" <td>Investment banker, corporate</td>\n",
" <td>40.387243</td>\n",
" <td>-92.224871</td>\n",
" <td>68.400000</td>\n",
" <td>55</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>14444 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
" merchant category amt city \\\n",
"0 Stokes, Christiansen and Sipes grocery_net 14.37 Wales \n",
"1 Predovic Inc shopping_net 966.11 Wales \n",
"2 Wisozk and Sons misc_pos 49.61 Wales \n",
"3 Murray-Smitham grocery_pos 295.26 Wales \n",
"4 Friesen Lt health_fitness 18.17 Wales \n",
"... ... ... ... ... \n",
"14441 Hudson-Grady shopping_pos 122.00 Athena \n",
"14442 Nienow, Ankunding and Collie misc_pos 9.07 Gardiner \n",
"14443 Pacocha-O'Reilly grocery_pos 104.84 Alva \n",
"14444 Bins, Balistreri and Beatty shopping_pos 268.16 Wales \n",
"14445 Daugherty-Thompson food_dining 50.09 Unionville \n",
"\n",
" state lat long city_pop job \\\n",
"0 AK 64.7556 -165.6723 145 Administrator, education \n",
"1 AK 64.7556 -165.6723 145 Administrator, education \n",
"2 AK 64.7556 -165.6723 145 Administrator, education \n",
"3 AK 64.7556 -165.6723 145 Administrator, education \n",
"4 AK 64.7556 -165.6723 145 Administrator, education \n",
"... ... ... ... ... ... \n",
"14441 OR 45.8289 -118.4971 1302 Dealer \n",
"14442 OR 43.7857 -124.1437 260 Engineer, maintenance \n",
"14443 WY 44.6873 -104.4414 110 Administrator, local government \n",
"14444 AK 64.7556 -165.6723 145 Administrator, education \n",
"14445 MO 40.4815 -92.9951 3805 Investment banker, corporate \n",
"\n",
" merch_lat merch_long age trans_time is_fraud \n",
"0 65.654142 -164.722603 79.208219 58 True \n",
"1 65.468863 -165.473127 79.208219 906 True \n",
"2 65.347667 -165.914542 79.208219 1357 True \n",
"3 64.445035 -166.080207 79.208219 1386 True \n",
"4 65.447094 -165.446843 79.208219 1439 True \n",
"... ... ... ... ... ... \n",
"14441 46.442439 -118.524214 42.290411 37 False \n",
"14442 42.901265 -124.995317 62.432877 41 False \n",
"14443 45.538062 -104.542117 45.717808 42 False \n",
"14444 64.081462 -165.898698 79.257534 48 False \n",
"14445 40.387243 -92.224871 68.400000 55 False \n",
"\n",
"[14444 rows x 14 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = data.drop(['trans_date_trans_time', 'dob'], axis='columns', errors='ignore') \n",
"data"
]
},
{
"cell_type": "markdown",
"id": "d7f84860-41a2-494c-8f52-e3edce977d15",
"metadata": {},
"source": [
"We could also try:\n",
"- extracting more features from `trans_date_trans_time`, e.g. day of the week\n",
"- encoding `lat` and `long` as a single feature, with both values rounded so that nearby locations fall into the same area"
]
},
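{
"cell_type": "markdown",
"id": "sketch-extra-features-md",
"metadata": {},
"source": [
"A rough sketch of both ideas on a made-up frame. The column names `day_of_week` and `area` and the grid size of one decimal place are assumptions chosen for this example, not something taken from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-extra-features-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical rows; the coordinates mimic the format of `lat` and `long` above.\n",
"toy = pd.DataFrame({\n",
"    'trans_date_trans_time': pd.to_datetime(['2019-01-04 00:58:00', '2019-01-22 00:55:00']),\n",
"    'lat': [64.7556, 40.4815],\n",
"    'long': [-165.6723, -92.9951],\n",
"})\n",
"\n",
"# Day of the week as an integer: Monday is 0, Sunday is 6.\n",
"toy['day_of_week'] = toy['trans_date_trans_time'].dt.dayofweek\n",
"\n",
"# Round both coordinates to one decimal place (an arbitrary grid size for this sketch)\n",
"# and join them into a single categorical area label.\n",
"toy['area'] = (toy['lat'].round(1).astype(str) + '_' + toy['long'].round(1).astype(str)).astype('category')\n",
"toy[['day_of_week', 'area']]"
]
},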
{
"cell_type": "markdown",
"id": "d33932e3-2dbf-4964-88a2-5ec88385aa8a",
"metadata": {},
"source": [
"Finally, visualise how strongly the new variables are correlated with the fact that the transaction is fraudulent."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e90144af-ed42-4415-ba2a-f81de56827a6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: >"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 6))\n",
"data.corr(numeric_only=True)['is_fraud'].drop(['is_fraud']).sort_values(ascending = False).plot(kind='bar', grid=True, rot=0)"
]
},
{
"cell_type": "markdown",
"id": "a0c7702c-c176-460e-9d5e-2b5a2e841bf3",
"metadata": {},
"source": [
"## Balance the dataset\n",
"\n",
"Prepare train and test data."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "edd9c2f5-3faf-4f8e-9d16-dec5538c58c8",
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = sklearn.model_selection.train_test_split(data, random_state=random_state)\n",
"train_data_input = train_data.drop('is_fraud', axis='columns')\n",
"test_data_input = test_data.drop('is_fraud', axis='columns')\n",
"train_data_output = train_data['is_fraud']\n",
"test_data_output = test_data['is_fraud']"
]
},
{
"cell_type": "markdown",
"id": "b782191f-7929-4f40-9025-675191d72a01",
"metadata": {},
"source": [
"Since there are far fewer fraudulent transactions, we oversample them to prepare a balanced training dataset."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7ec781d5-51e5-4ad7-ab68-5d3997dde8cf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"is_fraud\n",
"False 9456\n",
"True 1377\n",
"Name: count, dtype: Int64"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Original.\n",
"train_data_output.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1e99c8ad-c3e0-4227-8215-385aa2616f7d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"is_fraud\n",
"False 9456\n",
"True 9456\n",
"Name: count, dtype: Int64"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"over_sampler = imblearn.over_sampling.RandomOverSampler(random_state=random_state)\n",
"train_data_input, train_data_output = over_sampler.fit_resample(train_data_input, train_data_output)\n",
"# Over sampled.\n",
"train_data_output.value_counts()"
]
},
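{
"cell_type": "markdown",
"id": "sketch-oversampling-md",
"metadata": {},
"source": [
"To make the effect of `RandomOverSampler` concrete, here is a small sketch on made-up data: minority rows are duplicated at random until both classes have the same size. The `toy_X` and `toy_y` names and values are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-oversampling-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import imblearn.over_sampling\n",
"\n",
"# Hypothetical imbalanced toy data: 8 legitimate rows and 2 fraudulent ones.\n",
"toy_X = pd.DataFrame({'amt': [10.0, 12.5, 8.0, 30.0, 22.0, 5.5, 18.0, 9.9, 950.0, 870.0]})\n",
"toy_y = pd.Series([False] * 8 + [True] * 2, name='is_fraud')\n",
"\n",
"toy_sampler = imblearn.over_sampling.RandomOverSampler(random_state=42)\n",
"toy_X_resampled, toy_y_resampled = toy_sampler.fit_resample(toy_X, toy_y)\n",
"\n",
"# Both classes now have 8 rows; the extra fraudulent rows are copies of the original two.\n",
"toy_y_resampled.value_counts()"
]
},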
{
"cell_type": "markdown",
"id": "9ad62bca-7432-41a3-8227-e7312c8b3dcb",
"metadata": {},
"source": [
"# Model training - second experiment"
]
},
{
"cell_type": "markdown",
"id": "ac4c9de5-8f7d-4d58-b3d1-c28bbd8f0aa5",
"metadata": {},
"source": [
"First, we can train the model using the same decision tree with the same parameters as in the first experiment and see if the accuracy improves."
]
},
{
"cell_type": "markdown",
"id": "d367bd24-620b-40c6-b464-532964d4ec35",
"metadata": {},
"source": [
"Train the decision tree classifier on the new preprocessed dataset."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "41d3dee4-d7cd-434a-b73b-ab6c4038fa8a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For parameters: {'criterion': 'gini', 'min_impurity_decrease': 0.0}, accuracy: 0.6271263437784352, precision: 0.8323734485387088 and recall: 0.8332871780670175 was achieved\n",
"For parameters: {'criterion': 'gini', 'min_impurity_decrease': 0.05}, accuracy: 0.5, precision: 0.758071375035556 and recall: 0.8706729437828856 was achieved\n",
"For parameters: {'criterion': 'gini', 'min_impurity_decrease': 0.3}, accuracy: 0.5, precision: 0.758071375035556 and recall: 0.8706729437828856 was achieved\n",
"For parameters: {'criterion': 'entropy', 'min_impurity_decrease': 0.0}, accuracy: 0.6381936157924275, precision: 0.8373859087442732 and recall: 0.8382719468291332 was achieved\n",
"For parameters: {'criterion': 'entropy', 'min_impurity_decrease': 0.05}, accuracy: 0.5, precision: 0.758071375035556 and recall: 0.8706729437828856 was achieved\n",
"For parameters: {'criterion': 'entropy', 'min_impurity_decrease': 0.3}, accuracy: 0.5, precision: 0.758071375035556 and recall: 0.8706729437828856 was achieved\n"
]
}
],
"source": [
"parameter_grid = {\n",
"    'criterion': ['gini', 'entropy'],\n",
"    'min_impurity_decrease': [0.0, 0.05, 0.3]\n",
"}\n",
"\n",
"mlflow.set_experiment(\"Credit card fraud - decision tree and preprocessed dataset\")\n",
"for parameters in sklearn.model_selection.ParameterGrid(parameter_grid):\n",
"    criterion = parameters['criterion']\n",
"    min_impurity_decrease = parameters['min_impurity_decrease']\n",
"\n",
"    # For each parameter combination, we record a new run within a single experiment.\n",
"    with mlflow.start_run():\n",
"        clr = sklearn.pipeline.make_pipeline(\n",
"            sklearn.compose.make_column_transformer(\n",
"                (sklearn.preprocessing.OneHotEncoder(sparse_output=True), ['merchant', 'category', 'city', 'state', 'job']),\n",
"            ),\n",
"            sklearn.tree.DecisionTreeClassifier(criterion=criterion, min_impurity_decrease=min_impurity_decrease, random_state=random_state)\n",
"        )\n",
"\n",
"        clr.fit(train_data_input, train_data_output)\n",
"        test_data_predicted = clr.predict(test_data_input)\n",
"\n",
"        accuracy = sklearn.metrics.balanced_accuracy_score(test_data_output, test_data_predicted)\n",
"        precision = sklearn.metrics.precision_score(test_data_output, test_data_predicted, average='weighted')\n",
"        recall = sklearn.metrics.recall_score(test_data_output, test_data_predicted, average='weighted')\n",
"\n",
"        print(f\"For parameters: {parameters}, accuracy: {accuracy}, precision: {precision} and recall: {recall} was achieved\")\n",
"\n",
"        # Log all the classifier parameters.\n",
"        mlflow.log_param(\"criterion\", criterion)\n",
"        mlflow.log_param(\"min_impurity_decrease\", min_impurity_decrease)\n",
"        # Log metrics of the trained classifier.\n",
"        mlflow.log_metric(\"accuracy\", accuracy)\n",
"        mlflow.log_metric(\"precision\", precision)\n",
"        mlflow.log_metric(\"recall\", recall)\n",
"        # And log the trained classifier itself.\n",
"        model_signature = mlflow.models.infer_signature(model_input=train_data_input.iloc[:1], model_output=test_data_predicted[:1])\n",
"        model_signature.outputs = mlflow.types.schema.Schema([mlflow.types.schema.ColSpec(\"double\")])\n",
"        mlflow.sklearn.log_model(clr, artifact_path=\"credit-card-fraud-tree-v2\", signature=model_signature)"
]
},
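{
"cell_type": "markdown",
"id": "sketch-constant-predictor-md",
"metadata": {},
"source": [
"The runs with `min_impurity_decrease` of 0.05 and 0.3 all report a balanced accuracy of exactly 0.5 together with identical precision and recall. That pattern is consistent with a classifier that predicts the same class for every row. The sketch below illustrates this on made-up labels; the 90/10 class split is an assumption for the example, not the real class ratio."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-constant-predictor-code",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"import sklearn.metrics\n",
"\n",
"# Made-up labels: 9 legitimate transactions and 1 fraudulent one.\n",
"toy_true = [False] * 9 + [True]\n",
"# A degenerate classifier that never predicts fraud.\n",
"toy_pred = [False] * 10\n",
"\n",
"with warnings.catch_warnings():\n",
"    # Precision for the never-predicted class is undefined; scikit-learn warns and uses 0.\n",
"    warnings.simplefilter('ignore')\n",
"    toy_balanced = sklearn.metrics.balanced_accuracy_score(toy_true, toy_pred)\n",
"    toy_precision = sklearn.metrics.precision_score(toy_true, toy_pred, average='weighted')\n",
"    toy_recall = sklearn.metrics.recall_score(toy_true, toy_pred, average='weighted')\n",
"\n",
"# Balanced accuracy is 0.5, weighted recall equals the majority share (0.9)\n",
"# and weighted precision equals the majority share squared (about 0.81).\n",
"toy_balanced, toy_precision, toy_recall"
]
},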
{
"cell_type": "markdown",
"id": "041d96e1-781b-405c-b1c9-9e3be8361de4",
"metadata": {},
"source": [
"Well, we can see that the metrics have not improved. Perhaps this is because the extracted features are not highly correlated with the fact that the transaction is fraudulent."
]
},
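{
"cell_type": "markdown",
"id": "sketch-remainder-passthrough-md",
"metadata": {},
"source": [
"One more thing worth checking: `make_column_transformer` drops every column that is not listed in a transformer unless `remainder` is set, so a pipeline that lists only the categorical columns never passes numeric columns such as `age` and `trans_time` to the classifier. A sketch of the difference on a made-up frame (the `toy` data is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-remainder-passthrough-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import sklearn.compose\n",
"import sklearn.preprocessing\n",
"\n",
"# Hypothetical frame with one categorical and two numeric columns.\n",
"toy = pd.DataFrame({\n",
"    'category': ['grocery_pos', 'misc_pos', 'grocery_pos'],\n",
"    'age': [79.2, 42.3, 68.4],\n",
"    'trans_time': [58, 37, 55],\n",
"})\n",
"\n",
"# Default behaviour: the unlisted columns ('age', 'trans_time') are dropped.\n",
"toy_dropping = sklearn.compose.make_column_transformer(\n",
"    (sklearn.preprocessing.OneHotEncoder(), ['category']),\n",
")\n",
"# With remainder='passthrough' the unlisted columns are appended unchanged.\n",
"toy_keeping = sklearn.compose.make_column_transformer(\n",
"    (sklearn.preprocessing.OneHotEncoder(), ['category']),\n",
"    remainder='passthrough',\n",
")\n",
"\n",
"# 2 one-hot columns versus 2 one-hot columns plus 2 numeric columns.\n",
"toy_dropping.fit_transform(toy).shape, toy_keeping.fit_transform(toy).shape"
]
},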
{
"cell_type": "markdown",
"id": "eb56490c-a1c2-48ff-ab16-187eb223572a",
"metadata": {},
"source": [
"# Model training - third experiment"
]
},
{
"cell_type": "markdown",
"id": "9a49f5cb-93d9-4062-8b7a-4bf1ee98d7fb",
"metadata": {},
"source": [
"This time we will use a different algorithm: a classifier based on the support vector machine (SVM) method."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "d4ca530b-7824-4461-93cf-81018281aa89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For parameters: {'kernel': 'rbf', 'nu': 0.0005}, accuracy: 0.5135086170728651, precision: 0.7811636255218812 and recall: 0.48463029631680976 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.002}, accuracy: 0.5501873661670236, precision: 0.7946433158611632 and recall: 0.6056494045970645 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.01}, accuracy: 0.6217481651601092, precision: 0.823948954784356 and recall: 0.7953475491553587 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.1}, accuracy: 0.6040536067476339, precision: 0.850064804642781 and recall: 0.8756577125450014 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.2}, accuracy: 0.6138373762470646, precision: 0.8479969200478853 and recall: 0.8720576017723622 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.5}, accuracy: 0.6714757997286562, precision: 0.8428159779907852 and recall: 0.8089171974522293 was achieved\n",
"For parameters: {'kernel': 'rbf', 'nu': 0.9}, accuracy: 0.6489952651050777, precision: 0.8355923587441608 and recall: 0.6253115480476322 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.0005}, accuracy: 0.4066785720123576, precision: 0.6994054459708862 and recall: 0.2763777346995292 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.002}, accuracy: 0.4394594782352845, precision: 0.7323394091294531 and recall: 0.3223483799501523 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.01}, accuracy: 0.3991035574371632, precision: 0.7080492661564639 and recall: 0.3187482691775131 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.1}, accuracy: 0.42447903896344485, precision: 0.7183875511255063 and recall: 0.3010246469122127 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.2}, accuracy: 0.40098947861669143, precision: 0.7065800469215918 and recall: 0.3093325948490722 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.5}, accuracy: 0.6883527169796928, precision: 0.8463934043808079 and recall: 0.7144835225699252 was achieved\n",
"For parameters: {'kernel': 'linear', 'nu': 0.9}, accuracy: 0.7351326206471931, precision: 0.8651077799101538 and recall: 0.7197452229299363 was achieved\n"
]
}
],
"source": [
"parameter_grid = {\n",
"    'nu': [0.0005, 0.002, 0.01, 0.1, 0.2, 0.5, 0.9],\n",
"    'kernel': ['rbf', 'linear'],\n",
"}\n",
"\n",
"mlflow.set_experiment(\"Credit card fraud - SVM\")\n",
"for parameters in sklearn.model_selection.ParameterGrid(parameter_grid):\n",
"    nu = parameters['nu']\n",
"    kernel = parameters['kernel']\n",
"\n",
"    # For each parameter combination, we record a new run within a single experiment.\n",
"    with mlflow.start_run():\n",
"        clr = sklearn.pipeline.make_pipeline(\n",
"            sklearn.compose.make_column_transformer(\n",
"                (sklearn.preprocessing.OneHotEncoder(sparse_output=True), ['merchant', 'category', 'city', 'state', 'job']),\n",
"            ),\n",
"            sklearn.preprocessing.StandardScaler(with_mean=False),  # it is recommended to scale all features before training the SVM model\n",
"            sklearn.svm.NuSVC(nu=nu, kernel=kernel, random_state=random_state)\n",
"        )\n",
"\n",
"        clr.fit(train_data_input, train_data_output)\n",
"        test_data_predicted = clr.predict(test_data_input)\n",
"\n",
"        accuracy = sklearn.metrics.balanced_accuracy_score(test_data_output, test_data_predicted)\n",
"        precision = sklearn.metrics.precision_score(test_data_output, test_data_predicted, average='weighted')\n",
"        recall = sklearn.metrics.recall_score(test_data_output, test_data_predicted, average='weighted')\n",
"\n",
"        print(f\"For parameters: {parameters}, accuracy: {accuracy}, precision: {precision} and recall: {recall} was achieved\")\n",
"\n",
"        # Log all the classifier parameters.\n",
"        mlflow.log_param(\"nu\", nu)\n",
"        mlflow.log_param(\"kernel\", kernel)\n",
"        # Log metrics of the trained classifier.\n",
"        mlflow.log_metric(\"accuracy\", accuracy)\n",
"        mlflow.log_metric(\"precision\", precision)\n",
"        mlflow.log_metric(\"recall\", recall)\n",
"        # And log the trained classifier itself.\n",
"        model_signature = mlflow.models.infer_signature(model_input=train_data_input.iloc[:1], model_output=test_data_predicted[:1])\n",
"        model_signature.outputs = mlflow.types.schema.Schema([mlflow.types.schema.ColSpec(\"double\")])\n",
"        mlflow.sklearn.log_model(clr, artifact_path=\"credit-card-fraud-svm\", signature=model_signature)"
]
},
{
"cell_type": "markdown",
"id": "ea571801-2049-4627-992c-268fe94329a2",
"metadata": {},
"source": [
"The SVM algorithm achieved better metrics. Assume that precision is more important to us, as we want to immediately block transactions that appear fraudulent in order to protect our customers. Compare the runs in the MLflow UI to select the best model."
]
}
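,
{
"cell_type": "markdown",
"id": "sketch-search-runs-md",
"metadata": {},
"source": [
"Besides the UI, the runs can also be compared programmatically. The sketch below assumes the tracking URI configured at the top of this notebook, the experiment name used in the third experiment, and an MLflow version whose `mlflow.search_runs` accepts `experiment_names`. It lists the runs with the highest logged precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-search-runs-code",
"metadata": {},
"outputs": [],
"source": [
"import mlflow\n",
"\n",
"# Assumes the SVM experiment above has already been run against the same tracking server.\n",
"svm_runs = mlflow.search_runs(\n",
"    experiment_names=['Credit card fraud - SVM'],\n",
"    order_by=['metrics.precision DESC'],\n",
"    max_results=5,\n",
")\n",
"# Each row is one run; metrics and parameters appear as prefixed columns.\n",
"svm_runs[['run_id', 'params.kernel', 'params.nu', 'metrics.precision', 'metrics.recall']]"
]
}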
],
"metadata": {
"kernelspec": {
"display_name": "nu-ml-demo-20240523",
"language": "python",
"name": "nu-ml-demo-20240523"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}