-
-
Save daxiongshu/254bacfa734a4f3a0e13624f255fe319 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "0ad58189-96c6-403b-a544-167f250492b5", | |
"metadata": {}, | |
"source": [ | |
"### In this notebook, we train two xgboost models to predict whether or not a customer will default in the future.\n", | |
"- The first xgboost model is trained on the given numerical features as is.\n", | |
"- The second xgboost model is trained on the given numerical features plus the features generated by the autoregressive RNN.\n", | |
"- Save the second xgboost model and write a configuration file for Triton inference." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "54bcbfa3-0eec-4fcd-81d6-6260500c917f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pin the job to GPU 0; set before cudf/cupy/xgboost are imported in the next cell.\n", | |
"import os\n", | |
"\n", | |
"os.environ.update(CUDA_VISIBLE_DEVICES='0')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "c06cfa80-14c8-4b30-a463-ec43de8f5e1b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('23.04.00', '1.7.1')" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Standard library\n", | |
"import gc\n", | |
"\n", | |
"# Third-party (GPU data/ML stack)\n", | |
"import cudf\n", | |
"import cupy\n", | |
"import numpy as np\n", | |
"import xgboost as xgb\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"# Local: competition evaluation metric\n", | |
"from utils import amex_metric_np\n", | |
"\n", | |
"# Record library versions for reproducibility\n", | |
"cudf.__version__, xgb.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "e88971c2-5409-463d-91d6-cf4e42d0d881", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Directory holding train.parquet and train_labels.csv.\n", | |
"# Override with the AMEX_DATA_PATH environment variable instead of editing the notebook.\n", | |
"PATH = os.environ.get('AMEX_DATA_PATH', '/raid/data/ml/kaggle/amex')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "49398da7-c11d-45a3-8594-741c01e2bf94", | |
"metadata": {}, | |
"source": [ | |
"# Data preprocessing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "7ec73ed7-b457-49af-913a-c79e7f425f38", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(458913, 2)\n", | |
"CPU times: user 1.23 s, sys: 1.35 s, total: 2.59 s\n", | |
"Wall time: 2.58 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>customer_ID</th>\n", | |
" <th>target</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>00000fd6641609c6ece5454664794f0340ad84dddce9a2...</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>00001b22f846c82c51f6e3958ccd81970162bae8b007e8...</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>000041bdba6ecadd89a52d11886e8eaaec9325906c9723...</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>00007889e4fcd2614b6cbe7f8f3d2e5c728eca32d9eb8a...</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" customer_ID target\n", | |
"0 0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f... 0\n", | |
"1 00000fd6641609c6ece5454664794f0340ad84dddce9a2... 0\n", | |
"2 00001b22f846c82c51f6e3958ccd81970162bae8b007e8... 0\n", | |
"3 000041bdba6ecadd89a52d11886e8eaaec9325906c9723... 0\n", | |
"4 00007889e4fcd2614b6cbe7f8f3d2e5c728eca32d9eb8a... 0" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# train: several rows per customer; trainl: one target per customer_ID\n", | |
"train = cudf.read_parquet(os.path.join(PATH, 'train.parquet'))\n", | |
"trainl = cudf.read_csv(os.path.join(PATH, 'train_labels.csv'))\n", | |
"print(trainl.shape)\n", | |
"trainl.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a586ed4d-f157-45c1-8a8f-e0b623aae80f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0 340085\n", | |
"1 118828\n", | |
"Name: target, dtype: int32" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Class balance of the target label\n", | |
"trainl.target.value_counts()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "1a695ca4-ba9f-448c-b338-7b13733c5ec6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(5531451, 191)\n", | |
"CPU times: user 38.5 ms, sys: 294 ms, total: 332 ms\n", | |
"Wall time: 482 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>customer_ID</th>\n", | |
" <th>S_2</th>\n", | |
" <th>P_2</th>\n", | |
" <th>D_39</th>\n", | |
" <th>B_1</th>\n", | |
" <th>B_2</th>\n", | |
" <th>R_1</th>\n", | |
" <th>S_3</th>\n", | |
" <th>D_41</th>\n", | |
" <th>B_3</th>\n", | |
" <th>...</th>\n", | |
" <th>D_137</th>\n", | |
" <th>D_138</th>\n", | |
" <th>D_139</th>\n", | |
" <th>D_140</th>\n", | |
" <th>D_141</th>\n", | |
" <th>D_142</th>\n", | |
" <th>D_143</th>\n", | |
" <th>D_144</th>\n", | |
" <th>D_145</th>\n", | |
" <th>target</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0044e52546003bf8dc7e0234a971910ea411d932148813...</td>\n", | |
" <td>2018-02-05</td>\n", | |
" <td>0.162061</td>\n", | |
" <td>91</td>\n", | |
" <td>0.104386</td>\n", | |
" <td>0.002240</td>\n", | |
" <td>1.003340</td>\n", | |
" <td>0.20222196</td>\n", | |
" <td>1.599497</td>\n", | |
" <td>0.153965</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.007211</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0044e52546003bf8dc7e0234a971910ea411d932148813...</td>\n", | |
" <td>2018-03-08</td>\n", | |
" <td>0.209617</td>\n", | |
" <td>31</td>\n", | |
" <td>0.109365</td>\n", | |
" <td>0.166125</td>\n", | |
" <td>1.256646</td>\n", | |
" <td>0.198744327</td>\n", | |
" <td>1.608763</td>\n", | |
" <td>0.156176</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.001472</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0044f11b5431d326feefaf34642a86e8a38d8b215522b2...</td>\n", | |
" <td>2017-03-25</td>\n", | |
" <td>0.913065</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000341</td>\n", | |
" <td>0.818835</td>\n", | |
" <td>0.002579</td>\n", | |
" <td><NA></td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.005560</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.002861</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0044f11b5431d326feefaf34642a86e8a38d8b215522b2...</td>\n", | |
" <td>2017-04-25</td>\n", | |
" <td>0.915826</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000356</td>\n", | |
" <td>0.813991</td>\n", | |
" <td>0.007165</td>\n", | |
" <td><NA></td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.004210</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.000775</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0044f11b5431d326feefaf34642a86e8a38d8b215522b2...</td>\n", | |
" <td>2017-05-26</td>\n", | |
" <td>0.914364</td>\n", | |
" <td>0</td>\n", | |
" <td>0.001560</td>\n", | |
" <td>1.007190</td>\n", | |
" <td>0.000820</td>\n", | |
" <td><NA></td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.008570</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.005072</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 191 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" customer_ID S_2 P_2 \\\n", | |
"0 0044e52546003bf8dc7e0234a971910ea411d932148813... 2018-02-05 0.162061 \n", | |
"1 0044e52546003bf8dc7e0234a971910ea411d932148813... 2018-03-08 0.209617 \n", | |
"2 0044f11b5431d326feefaf34642a86e8a38d8b215522b2... 2017-03-25 0.913065 \n", | |
"3 0044f11b5431d326feefaf34642a86e8a38d8b215522b2... 2017-04-25 0.915826 \n", | |
"4 0044f11b5431d326feefaf34642a86e8a38d8b215522b2... 2017-05-26 0.914364 \n", | |
"\n", | |
" D_39 B_1 B_2 R_1 S_3 D_41 B_3 ... \\\n", | |
"0 91 0.104386 0.002240 1.003340 0.20222196 1.599497 0.153965 ... \n", | |
"1 31 0.109365 0.166125 1.256646 0.198744327 1.608763 0.156176 ... \n", | |
"2 0 0.000341 0.818835 0.002579 <NA> 0.000000 0.005560 ... \n", | |
"3 0 0.000356 0.813991 0.007165 <NA> 0.000000 0.004210 ... \n", | |
"4 0 0.001560 1.007190 0.000820 <NA> 0.000000 0.008570 ... \n", | |
"\n", | |
" D_137 D_138 D_139 D_140 D_141 D_142 D_143 D_144 D_145 target \n", | |
"0 -1 -1 0 0 0.0 <NA> 0 0.007211 0 1 \n", | |
"1 -1 -1 0 0 0.0 <NA> 0 0.001472 0 1 \n", | |
"2 -1 -1 0 0 0.0 <NA> 0 0.002861 0 0 \n", | |
"3 -1 -1 0 0 0.0 <NA> 0 0.000775 0 0 \n", | |
"4 -1 -1 0 0 0.0 <NA> 0 0.005072 0 0 \n", | |
"\n", | |
"[5 rows x 191 columns]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Attach each customer's label to every one of that customer's rows\n", | |
"train = train.merge(trainl, how='left', on='customer_ID')\n", | |
"print(train.shape)\n", | |
"train.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a9752ac0-e607-41ab-85bc-13f1976e964f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 354 ms, sys: 178 ms, total: 533 ms\n", | |
"Wall time: 525 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"# Parse the statement date so rows can be ordered chronologically\n", | |
"train['S_2'] = cudf.to_datetime(train['S_2'])\n", | |
"# Integer code per customer; used for the train/validation split below\n", | |
"codes, _ = train.customer_ID.factorize()\n", | |
"train['cid'] = codes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "469063ce-d3f0-48ee-944b-dc048f821dd1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Verify target distribution is consistent across tr and va\n", | |
"0.2493462806763533 0.24835027876095958\n" | |
] | |
} | |
], | |
"source": [ | |
"# Hold out every 4th customer (by cid) as the validation set\n", | |
"mask = (train['cid'] % 4) == 0\n", | |
"tr = train.loc[~mask]\n", | |
"va = train.loc[mask]\n", | |
"print(\"Verify target distribution is consistent across tr and va\")\n", | |
"print(tr['target'].mean(), va['target'].mean())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cb6953a1-72aa-4121-8a75-fe98d0e9207d", | |
"metadata": {}, | |
"source": [ | |
"### Utility Functions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "4d472896-986a-4f2f-85e7-39f3e21ed6f9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_cat_cols():\n", | |
"    \"\"\"Return the categorical column names as a fresh list on every call.\"\"\"\n", | |
"    cat_cols = [\n", | |
"        'B_30', 'B_38',\n", | |
"        'D_114', 'D_116', 'D_117', 'D_120', 'D_126',\n", | |
"        'D_63', 'D_64', 'D_66', 'D_68',\n", | |
"    ]\n", | |
"    return cat_cols\n", | |
"\n", | |
"def preprocess(df):\n", | |
"    \"\"\"Keep only the most recent row (largest S_2) of each customer.\n", | |
"\n", | |
"    Rows are ordered by (cid, S_2) so that keep='last' retains each cid's latest row.\n", | |
"    The result is re-sorted by cid explicitly (rather than relying on drop_duplicates\n", | |
"    preserving order) and given a fresh 0..n-1 index.\n", | |
"    \"\"\"\n", | |
"    latest = (\n", | |
"        df.sort_values(['cid', 'S_2'])\n", | |
"          .drop_duplicates('cid', keep='last')\n", | |
"    )\n", | |
"    return latest.sort_values('cid').reset_index(drop=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "d166fa47-5206-4e96-bd20-a0b2b209b7f0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(344184, 192)\n", | |
"CPU times: user 207 ms, sys: 740 ms, total: 947 ms\n", | |
"Wall time: 1.01 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>customer_ID</th>\n", | |
" <th>S_2</th>\n", | |
" <th>P_2</th>\n", | |
" <th>D_39</th>\n", | |
" <th>B_1</th>\n", | |
" <th>B_2</th>\n", | |
" <th>R_1</th>\n", | |
" <th>S_3</th>\n", | |
" <th>D_41</th>\n", | |
" <th>B_3</th>\n", | |
" <th>...</th>\n", | |
" <th>D_138</th>\n", | |
" <th>D_139</th>\n", | |
" <th>D_140</th>\n", | |
" <th>D_141</th>\n", | |
" <th>D_142</th>\n", | |
" <th>D_143</th>\n", | |
" <th>D_144</th>\n", | |
" <th>D_145</th>\n", | |
" <th>target</th>\n", | |
" <th>cid</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>00000fd6641609c6ece5454664794f0340ad84dddce9a2...</td>\n", | |
" <td>2018-03-25</td>\n", | |
" <td>0.880519</td>\n", | |
" <td>6</td>\n", | |
" <td>0.034684</td>\n", | |
" <td>1.004028</td>\n", | |
" <td>0.006911</td>\n", | |
" <td>0.165509477</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.005068</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.003169</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>00001b22f846c82c51f6e3958ccd81970162bae8b007e8...</td>\n", | |
" <td>2018-03-12</td>\n", | |
" <td>0.880875</td>\n", | |
" <td>0</td>\n", | |
" <td>0.004284</td>\n", | |
" <td>0.812649</td>\n", | |
" <td>0.006450</td>\n", | |
" <td><NA></td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.007196</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.000834</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>000041bdba6ecadd89a52d11886e8eaaec9325906c9723...</td>\n", | |
" <td>2018-03-29</td>\n", | |
" <td>0.621776</td>\n", | |
" <td>0</td>\n", | |
" <td>0.012564</td>\n", | |
" <td>1.006183</td>\n", | |
" <td>0.007829</td>\n", | |
" <td>0.287765533</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.009937</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.005560</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>000084e5023181993c2e1b665ac88dbb1ce9ef621ec537...</td>\n", | |
" <td>2018-03-19</td>\n", | |
" <td>0.824061</td>\n", | |
" <td>0</td>\n", | |
" <td>0.007853</td>\n", | |
" <td>1.001713</td>\n", | |
" <td>0.006885</td>\n", | |
" <td>0.395739734</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.006134</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.006943</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>000098081fde4fd64bc4d503a5d6f86a0aedc425c96f52...</td>\n", | |
" <td>2018-03-12</td>\n", | |
" <td>0.477116</td>\n", | |
" <td>0</td>\n", | |
" <td>0.009413</td>\n", | |
" <td>1.009217</td>\n", | |
" <td>0.007775</td>\n", | |
" <td>0.267036825</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.125927</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.0</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.003703</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 192 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" customer_ID S_2 P_2 \\\n", | |
"0 00000fd6641609c6ece5454664794f0340ad84dddce9a2... 2018-03-25 0.880519 \n", | |
"1 00001b22f846c82c51f6e3958ccd81970162bae8b007e8... 2018-03-12 0.880875 \n", | |
"2 000041bdba6ecadd89a52d11886e8eaaec9325906c9723... 2018-03-29 0.621776 \n", | |
"3 000084e5023181993c2e1b665ac88dbb1ce9ef621ec537... 2018-03-19 0.824061 \n", | |
"4 000098081fde4fd64bc4d503a5d6f86a0aedc425c96f52... 2018-03-12 0.477116 \n", | |
"\n", | |
" D_39 B_1 B_2 R_1 S_3 D_41 B_3 ... D_138 \\\n", | |
"0 6 0.034684 1.004028 0.006911 0.165509477 0.0 0.005068 ... -1 \n", | |
"1 0 0.004284 0.812649 0.006450 <NA> 0.0 0.007196 ... -1 \n", | |
"2 0 0.012564 1.006183 0.007829 0.287765533 0.0 0.009937 ... -1 \n", | |
"3 0 0.007853 1.001713 0.006885 0.395739734 0.0 0.006134 ... -1 \n", | |
"4 0 0.009413 1.009217 0.007775 0.267036825 0.0 0.125927 ... -1 \n", | |
"\n", | |
" D_139 D_140 D_141 D_142 D_143 D_144 D_145 target cid \n", | |
"0 0 0 0.0 <NA> 0 0.003169 0 0 1 \n", | |
"1 0 0 0.0 <NA> 0 0.000834 0 0 2 \n", | |
"2 0 0 0.0 <NA> 0 0.005560 0 0 3 \n", | |
"3 0 0 0.0 <NA> 0 0.006943 0 0 5 \n", | |
"4 0 0 0.0 <NA> 0 0.003703 0 0 6 \n", | |
"\n", | |
"[5 rows x 192 columns]" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"# Collapse the training rows to one (latest) row per customer\n", | |
"tr = tr.pipe(preprocess)\n", | |
"print(tr.shape)\n", | |
"tr.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "37854e55-40fe-42a8-9ea4-828a67e53f9a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(114729, 192)\n", | |
"CPU times: user 76.4 ms, sys: 453 ms, total: 529 ms\n", | |
"Wall time: 536 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>customer_ID</th>\n", | |
" <th>S_2</th>\n", | |
" <th>P_2</th>\n", | |
" <th>D_39</th>\n", | |
" <th>B_1</th>\n", | |
" <th>B_2</th>\n", | |
" <th>R_1</th>\n", | |
" <th>S_3</th>\n", | |
" <th>D_41</th>\n", | |
" <th>B_3</th>\n", | |
" <th>...</th>\n", | |
" <th>D_138</th>\n", | |
" <th>D_139</th>\n", | |
" <th>D_140</th>\n", | |
" <th>D_141</th>\n", | |
" <th>D_142</th>\n", | |
" <th>D_143</th>\n", | |
" <th>D_144</th>\n", | |
" <th>D_145</th>\n", | |
" <th>target</th>\n", | |
" <th>cid</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...</td>\n", | |
" <td>2018-03-13</td>\n", | |
" <td>0.934745</td>\n", | |
" <td>0</td>\n", | |
" <td>0.009382</td>\n", | |
" <td>1.007647</td>\n", | |
" <td>0.006104</td>\n", | |
" <td>0.135021254</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.007174</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.002970</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>00007889e4fcd2614b6cbe7f8f3d2e5c728eca32d9eb8a...</td>\n", | |
" <td>2018-03-30</td>\n", | |
" <td>0.871900</td>\n", | |
" <td>0</td>\n", | |
" <td>0.007679</td>\n", | |
" <td>0.815746</td>\n", | |
" <td>0.001247</td>\n", | |
" <td><NA></td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.005528</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.006944</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0000f99513770170a1aba690daeeb8a96da4a39f11fc27...</td>\n", | |
" <td>2018-03-01</td>\n", | |
" <td>0.424624</td>\n", | |
" <td>18</td>\n", | |
" <td>0.979303</td>\n", | |
" <td>0.029291</td>\n", | |
" <td>0.008500</td>\n", | |
" <td>0.152607679</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.155846</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0.876028</td>\n", | |
" <td>0.184614226</td>\n", | |
" <td>1</td>\n", | |
" <td>0.003350</td>\n", | |
" <td>8</td>\n", | |
" <td>1</td>\n", | |
" <td>8</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0001812036f1558332e5c0880ecbad70b13a6f28ab04a8...</td>\n", | |
" <td>2018-03-10</td>\n", | |
" <td>0.424076</td>\n", | |
" <td>8</td>\n", | |
" <td>0.917384</td>\n", | |
" <td>0.029441</td>\n", | |
" <td>0.257114</td>\n", | |
" <td>0.153415367</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.972654</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.009148</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>12</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0002d381bdd8048d76719042cf1eb63caf53b636f8aacd...</td>\n", | |
" <td>2018-03-19</td>\n", | |
" <td>1.004771</td>\n", | |
" <td>0</td>\n", | |
" <td>0.009469</td>\n", | |
" <td>0.810357</td>\n", | |
" <td>0.009299</td>\n", | |
" <td>0.16986692</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.008108</td>\n", | |
" <td>...</td>\n", | |
" <td>-1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td><NA></td>\n", | |
" <td>0</td>\n", | |
" <td>0.003928</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>16</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 192 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" customer_ID S_2 P_2 \\\n", | |
"0 0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f... 2018-03-13 0.934745 \n", | |
"1 00007889e4fcd2614b6cbe7f8f3d2e5c728eca32d9eb8a... 2018-03-30 0.871900 \n", | |
"2 0000f99513770170a1aba690daeeb8a96da4a39f11fc27... 2018-03-01 0.424624 \n", | |
"3 0001812036f1558332e5c0880ecbad70b13a6f28ab04a8... 2018-03-10 0.424076 \n", | |
"4 0002d381bdd8048d76719042cf1eb63caf53b636f8aacd... 2018-03-19 1.004771 \n", | |
"\n", | |
" D_39 B_1 B_2 R_1 S_3 D_41 B_3 ... D_138 \\\n", | |
"0 0 0.009382 1.007647 0.006104 0.135021254 0.0 0.007174 ... -1 \n", | |
"1 0 0.007679 0.815746 0.001247 <NA> 0.0 0.005528 ... -1 \n", | |
"2 18 0.979303 0.029291 0.008500 0.152607679 0.0 1.155846 ... -1 \n", | |
"3 8 0.917384 0.029441 0.257114 0.153415367 0.0 0.972654 ... -1 \n", | |
"4 0 0.009469 0.810357 0.009299 0.16986692 0.0 0.008108 ... -1 \n", | |
"\n", | |
" D_139 D_140 D_141 D_142 D_143 D_144 D_145 target cid \n", | |
"0 0 0 0.000000 <NA> 0 0.002970 0 0 0 \n", | |
"1 0 0 0.000000 <NA> 0 0.006944 0 0 4 \n", | |
"2 1 0 0.876028 0.184614226 1 0.003350 8 1 8 \n", | |
"3 0 0 0.000000 <NA> 0 0.009148 0 1 12 \n", | |
"4 0 0 0.000000 <NA> 0 0.003928 0 0 16 \n", | |
"\n", | |
"[5 rows x 192 columns]" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"# Collapse the validation rows to one (latest) row per customer\n", | |
"va = va.pipe(preprocess)\n", | |
"print(va.shape)\n", | |
"va.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "1e8c71b6-bcc5-482e-ae8b-067425171bb0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((344184, 177), (344184,), (114729, 177), (114729,))" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Exclude ids, the label, the date, object-dtype columns (e.g. customer_ID) and categoricals\n", | |
"not_used = [c for c in tr.columns if c in ['cid', 'target', 'S_2'] or tr[c].dtype == 'O']\n", | |
"not_used += get_cat_cols()\n", | |
"\n", | |
"# Cast each whole frame once instead of reassigning it column by column\n", | |
"X_train = tr.drop(not_used, axis=1).astype('float32')\n", | |
"y_train = tr['target']\n", | |
"\n", | |
"X_test = va.drop(not_used, axis=1).astype('float32')\n", | |
"y_test = va['target']\n", | |
"\n", | |
"# Both splits must expose identical feature columns in identical order\n", | |
"assert list(X_train.columns) == list(X_test.columns), 'train/validation feature columns differ'\n", | |
"\n", | |
"X_train.shape, y_train.shape, X_test.shape, y_test.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "550b30cb-c8b3-46bf-aaa3-6b747b00ce07", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Index(['P_2', 'D_39', 'B_1', 'B_2', 'R_1', 'S_3', 'D_41', 'B_3', 'D_42',\n", | |
" 'D_43',\n", | |
" ...\n", | |
" 'D_136', 'D_137', 'D_138', 'D_139', 'D_140', 'D_141', 'D_142', 'D_143',\n", | |
" 'D_144', 'D_145'],\n", | |
" dtype='object', length=177)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Inspect the feature columns fed to the first model\n", | |
"X_train.columns" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "879be56d-e76d-49ba-9fb2-459a2e5cefde", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1209" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# train, tr and va are not used past this point; drop the references to release their memory\n", | |
"del train\n", | |
"del tr, va\n", | |
"gc.collect()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ccd72d75-846f-41f6-89ac-7b674189ee5a", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"# Train the 1st xgboost model with the given features only" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "57909e61-3435-4f09-b8b2-6d2f129f2193", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_xgb_model(max_depth=7, num_trees=1000, early_stopping_rounds=10, min_child_weight=50):\n", | |
"    \"\"\"Build a GPU XGBClassifier that early-stops on the AMEX metric.\n", | |
"\n", | |
"    Args:\n", | |
"        max_depth: maximum tree depth.\n", | |
"        num_trees: upper bound on boosting rounds (n_estimators).\n", | |
"        early_stopping_rounds: rounds without improvement of amex_metric_np on the\n", | |
"            first eval_set entry ('validation_0') before training stops.\n", | |
"        min_child_weight: minimum child weight passed to xgboost.\n", | |
"\n", | |
"    Returns:\n", | |
"        An unfitted xgb.XGBClassifier; fit() must be given eval_set=[(X_val, y_val)]\n", | |
"        so the early-stopping callback has a 'validation_0' metric to monitor.\n", | |
"    \"\"\"\n", | |
"    early_stop = xgb.callback.EarlyStopping(\n", | |
"        rounds=early_stopping_rounds,\n", | |
"        maximize=True,  # a larger amex metric is better\n", | |
"        metric_name='amex_metric_np',\n", | |
"        data_name='validation_0',\n", | |
"    )\n", | |
"    # use_label_encoder is intentionally omitted: it is deprecated in xgboost 1.7.0.\n", | |
"    return xgb.XGBClassifier(\n", | |
"        tree_method='gpu_hist',\n", | |
"        predictor='gpu_predictor',\n", | |
"        enable_categorical=False,\n", | |
"        objective='binary:logistic',\n", | |
"        max_depth=max_depth,\n", | |
"        n_estimators=num_trees,\n", | |
"        min_child_weight=min_child_weight,\n", | |
"        eval_metric=amex_metric_np,\n", | |
"        callbacks=[early_stop],\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "d3d620e6-4e80-4775-a9ce-f1a1a8f5ccdc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/raid/data/mambaforge/envs/rapids-23.04/lib/python3.10/site-packages/xgboost/sklearn.py:1421: UserWarning: `use_label_encoder` is deprecated in 1.7.0.\n", | |
" warnings.warn(\"`use_label_encoder` is deprecated in 1.7.0.\")\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[0]\tvalidation_0-logloss:0.52306\tvalidation_0-amex_metric_np:0.71127\n", | |
"[53]\tvalidation_0-logloss:0.22668\tvalidation_0-amex_metric_np:0.77974\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.779758" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = get_xgb_model()\n", | |
"model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=100)\n", | |
"# Best validation amex metric reached before early stopping\n", | |
"model.best_score" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1b1e03ac-26b4-49c4-86da-e5be6b8345f4", | |
"metadata": {}, | |
"source": [ | |
"### The evaluation metric for this competition is the mean of two measures of rank ordering: the Normalized Gini Coefficient and the default rate captured at 4%. The larger the metric, the more accurately the model predicts future default. See the competition's [evaluation description and analysis](https://www.kaggle.com/competitions/amex-default-prediction/overview/evaluation) to learn more about this metric. For now, all we care about is that **a larger metric is better!**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "21119007-a254-4a22-8080-af7102ad4bd8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"136" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Drop the baseline model before training the second one\n", | |
"del model\n", | |
"gc.collect()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a2a7a820-c8b4-49a4-94e9-fb61fa64953a", | |
"metadata": {}, | |
"source": [ | |
"# Add RNN features and train xgboost again" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "61201a7f-23a5-4557-94fa-30a5f82f58cd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 4.09 ms, sys: 1.31 s, total: 1.31 s\n", | |
"Wall time: 1.31 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(458913, 13, 177)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Features generated by the autoregressive RNN, produced outside this notebook.\n", | |
"# Assumed to hold one row per customer in cid order -- TODO confirm against the RNN notebook.\n", | |
"rnn_feas = np.load(file='rnn_feas.npy')\n", | |
"rnn_feas.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "df040a31-6ca8-4d94-a5d9-6d9ec76c9c6e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Same every-4th-customer split as above, but by row position. This relies on row i of\n", | |
"# rnn_feas belonging to the customer with cid == i (TODO: confirm with the RNN notebook).\n", | |
"n_customers = rnn_feas.shape[0]\n", | |
"assert n_customers == len(X_train) + len(X_test), 'rnn_feas rows do not match the number of customers'\n", | |
"mask = np.arange(n_customers) % 4 == 0\n", | |
"assert (~mask).sum() == len(X_train) and mask.sum() == len(X_test), 'positional split disagrees with cid split'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "3d1fe68e-3536-45ec-a415-3ecc6816b98d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((344184, 13, 177), (114729, 13, 177))" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Split the RNN features with the positional mask built above\n", | |
"tr_rnn, va_rnn = rnn_feas[~mask], rnn_feas[mask]\n", | |
"tr_rnn.shape, va_rnn.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "14870821-fb2b-4fa4-91c9-a7071d30f8a1", | |
"metadata": {}, | |
"source": [ | |
"For simplicity, we use only the last generated profile as new features." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "48860986-3879-476d-9de2-9398a1dba2bb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((344184, 177), (114729, 177))" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tr_rnn = tr_rnn[:,-1,:]\n", | |
"va_rnn = va_rnn[:,-1,:]\n", | |
"tr_rnn.shape, va_rnn.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "d7854aaf-b8be-4066-ac99-072af0e93f37", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((344184, 177), (114729, 177))" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tr_rnn_df = cudf.DataFrame(tr_rnn,columns=[f'rnn_{i}' for i in range(tr_rnn.shape[1])])\n", | |
"va_rnn_df = cudf.DataFrame(va_rnn,columns=[f'rnn_{i}' for i in range(tr_rnn.shape[1])])\n", | |
"tr_rnn_df.shape, va_rnn_df.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "f52c3f6e-4066-4357-a296-9b8ead421a1c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train = cudf.concat([X_train,tr_rnn_df],axis=1)\n", | |
"X_test = cudf.concat([X_test,va_rnn_df],axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "7ef459a7-5156-4e95-ae09-45297267c77f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((344184, 354), (344184,), (114729, 354), (114729,))" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_train.shape, y_train.shape, X_test.shape, y_test.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "3919ecd2-254c-4297-b8ac-b9e5df88d1e2", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/raid/data/mambaforge/envs/rapids-23.04/lib/python3.10/site-packages/xgboost/sklearn.py:1421: UserWarning: `use_label_encoder` is deprecated in 1.7.0.\n", | |
" warnings.warn(\"`use_label_encoder` is deprecated in 1.7.0.\")\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[0]\tvalidation_0-logloss:0.51646\tvalidation_0-amex_metric_np:0.74595\n", | |
"[49]\tvalidation_0-logloss:0.22555\tvalidation_0-amex_metric_np:0.78227\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.78302" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = get_xgb_model()\n", | |
"model.fit(\n", | |
" X_train,\n", | |
" y_train,\n", | |
" eval_set=[(X_test, y_test)],\n", | |
" verbose=100\n", | |
" )\n", | |
"model.best_score" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b271040e-d9fd-466b-aa9f-282bdd006354", | |
"metadata": {}, | |
"source": [ | |
"### We got 0.002 improvement by adding the future profile features! That's significant improvements for default detection! It could move the rank up by hundreds of places in the [competition](https://www.kaggle.com/competitions/amex-default-prediction/leaderboard)!!" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6ae19bb4-2f61-4c9e-8494-0eacdc4fd5d9", | |
"metadata": {}, | |
"source": [ | |
"# Save the model and write config.pbtxt for triton inference" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "7ae93b49-792d-4200-aaee-6a80aa4bff53", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"model_dir = 'amex_xgb'\n", | |
"Path(f'{model_dir}/1').mkdir(parents=True, exist_ok=True)\n", | |
"features = X_test.shape[1]\n", | |
"MAX_MEMORY_BYTES = 60_000_000\n", | |
"num_classes = y_test.unique().shape[0]\n", | |
"bytes_per_sample = (features + num_classes) * 4\n", | |
"max_batch_size = MAX_MEMORY_BYTES // bytes_per_sample" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "d85d27e4-0a10-4601-b29e-ce2d04fb4333", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def generate_config(model_dir, max_batch_size, features, deployment_type='gpu', storage_type='AUTO'):\n", | |
" if deployment_type.lower() == 'cpu':\n", | |
" instance_kind = 'KIND_CPU'\n", | |
" else:\n", | |
" instance_kind = 'KIND_GPU'\n", | |
"\n", | |
" config_text = f\"\"\"backend: \"fil\"\n", | |
"max_batch_size: {max_batch_size}\n", | |
"input [ \n", | |
" {{ \n", | |
" name: \"input__0\"\n", | |
" data_type: TYPE_FP32\n", | |
" dims: [ {features} ] \n", | |
" }} \n", | |
"]\n", | |
"output [\n", | |
" {{\n", | |
" name: \"output__0\"\n", | |
" data_type: TYPE_FP32\n", | |
" dims: [ {num_classes} ]\n", | |
" }}\n", | |
"]\n", | |
"instance_group [{{ kind: {instance_kind} }}]\n", | |
"parameters [\n", | |
" {{\n", | |
" key: \"model_type\"\n", | |
" value: {{ string_value: \"xgboost\" }}\n", | |
" }},\n", | |
" {{\n", | |
" key: \"predict_proba\"\n", | |
" value: {{ string_value: \"true\" }}\n", | |
" }},\n", | |
" {{\n", | |
" key: \"output_class\"\n", | |
" value: {{ string_value: \"true\" }}\n", | |
" }},\n", | |
" {{\n", | |
" key: \"threshold\"\n", | |
" value: {{ string_value: \"0.5\" }}\n", | |
" }},\n", | |
" {{\n", | |
" key: \"storage_type\"\n", | |
" value: {{ string_value: \"{storage_type}\" }}\n", | |
" }}\n", | |
"]\n", | |
"\n", | |
"dynamic_batching {{\n", | |
" max_queue_delay_microseconds: 100\n", | |
"}}\"\"\"\n", | |
" config_path = os.path.join(model_dir, 'config.pbtxt')\n", | |
" with open(config_path, 'w') as file_:\n", | |
" file_.write(config_text)\n", | |
"\n", | |
" return config_path" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "5525b9ac-42c0-4d37-ae2f-b7824465d134", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'amex_xgb/config.pbtxt'" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"generate_config(model_dir, max_batch_size, features)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "6a227e2e-1981-48c2-aa47-5e732c2e6892", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/raid/data/mambaforge/envs/rapids-23.04/lib/python3.10/site-packages/xgboost/sklearn.py:787: UserWarning: eval_metric is not saved in Scikit-Learn meta.\n", | |
" warnings.warn(\n", | |
"/raid/data/mambaforge/envs/rapids-23.04/lib/python3.10/site-packages/xgboost/sklearn.py:787: UserWarning: callbacks is not saved in Scikit-Learn meta.\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"model.save_model(f'{model_dir}/1/xgboost.model')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.10" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"state": {}, | |
"version_major": 2, | |
"version_minor": 0 | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment