Last active
April 20, 2018 17:55
-
-
Save anhquan0412/330494b051f74eacad3917f43e3ba43a to your computer and use it in GitHub Desktop.
Coursera bug in creating lag features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"from itertools import product\n", | |
"import seaborn as sns\n", | |
"import os\n", | |
"import matplotlib.pyplot as plt\n", | |
"import scipy.sparse \n", | |
"import sklearn\n", | |
"%matplotlib inline\n", | |
"data_path = 'data/'\n", | |
"seed=1204\n", | |
"\n", | |
"\n", | |
"def downcast_dtypes(df):\n", | |
" '''\n", | |
" Changes column types in the dataframe: \n", | |
" \n", | |
" `float64` type to `float32`\n", | |
" `int64` type to `int32`\n", | |
" '''\n", | |
" \n", | |
" # Select columns to downcast\n", | |
" float_cols = [c for c in df if df[c].dtype == \"float64\"]\n", | |
" int_cols = [c for c in df if df[c].dtype == \"int64\"]\n", | |
" \n", | |
" # Downcast\n", | |
" df[float_cols] = df[float_cols].astype(np.float32)\n", | |
" df[int_cols] = df[int_cols].astype(np.int32)\n", | |
" \n", | |
" return df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sales = pd.read_csv(os.path.join(data_path, 'sales_train.csv.gz'))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\qtran\\AppData\\Local\\Continuum\\Miniconda3\\lib\\site-packages\\pandas\\core\\groupby.py:4036: FutureWarning: using a dict with renaming is deprecated and will be removed in a future version\n", | |
" return super(DataFrameGroupBy, self).aggregate(arg, *args, **kwargs)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Create \"grid\" with columns\n", | |
"index_cols = ['shop_id', 'item_id', 'date_block_num']\n", | |
"\n", | |
"# For every month we create a grid from all shops/items combinations from that month\n", | |
"grid = [] \n", | |
"for block_num in sales['date_block_num'].unique():\n", | |
" cur_shops = sales.loc[sales['date_block_num'] == block_num, 'shop_id'].unique()\n", | |
" cur_items = sales.loc[sales['date_block_num'] == block_num, 'item_id'].unique()\n", | |
" grid.append(np.array(list(product(*[cur_shops, cur_items, [block_num]])),dtype='int32'))\n", | |
"\n", | |
"# Turn the grid into a dataframe\n", | |
"grid = pd.DataFrame(np.vstack(grid), columns = index_cols,dtype=np.int32)\n", | |
"\n", | |
"# Groupby data to get shop-item-month aggregates to get rid of duplicates\n", | |
"gb = sales.groupby(index_cols,as_index=False).agg({'item_cnt_day':{'target':'sum'}})\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[('shop_id', '') ('item_id', '') ('date_block_num', '')\n", | |
" ('item_cnt_day', 'target')]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Fix column names\n", | |
"print(gb.columns.values)\n", | |
"gb.columns = [col[0] if col[-1]=='' else col[-1] for col in gb.columns.values] " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Join it to the grid\n", | |
"all_data = pd.merge(grid, gb, how='left', on=index_cols).fillna(0)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\qtran\\AppData\\Local\\Continuum\\Miniconda3\\lib\\site-packages\\pandas\\core\\groupby.py:4036: FutureWarning: using a dict with renaming is deprecated and will be removed in a future version\n", | |
" return super(DataFrameGroupBy, self).aggregate(arg, *args, **kwargs)\n" | |
] | |
} | |
], | |
"source": [ | |
"\n", | |
"# Same as above but with shop-month aggregates\n", | |
"gb = sales.groupby(['shop_id', 'date_block_num'],as_index=False).agg({'item_cnt_day':{'target_shop':'sum'}})\n", | |
"gb.columns = [col[0] if col[-1]=='' else col[-1] for col in gb.columns.values]\n", | |
"all_data = pd.merge(all_data, gb, how='left', on=['shop_id', 'date_block_num']).fillna(0)\n", | |
"\n", | |
"\n", | |
"# Same as above but with item-month aggregates\n", | |
"gb = sales.groupby(['item_id', 'date_block_num'],as_index=False).agg({'item_cnt_day':{'target_item':'sum'}})\n", | |
"gb.columns = [col[0] if col[-1]=='' else col[-1] for col in gb.columns.values]\n", | |
"all_data = pd.merge(all_data, gb, how='left', on=['item_id', 'date_block_num']).fillna(0)\n", | |
"\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# get 3 date block num only\n", | |
"all_data_short = all_data.loc[(all_data.date_block_num >=4) & (all_data.date_block_num <=6),:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>shop_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target</th>\n", | |
" <th>target_shop</th>\n", | |
" <th>target_item</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>1497465</th>\n", | |
" <td>59</td>\n", | |
" <td>22114</td>\n", | |
" <td>4</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1374.0</td>\n", | |
" <td>26.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1497466</th>\n", | |
" <td>59</td>\n", | |
" <td>20239</td>\n", | |
" <td>4</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1374.0</td>\n", | |
" <td>113.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1497467</th>\n", | |
" <td>59</td>\n", | |
" <td>20238</td>\n", | |
" <td>4</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1374.0</td>\n", | |
" <td>45.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1497468</th>\n", | |
" <td>59</td>\n", | |
" <td>20785</td>\n", | |
" <td>4</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1374.0</td>\n", | |
" <td>7.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1497469</th>\n", | |
" <td>59</td>\n", | |
" <td>20783</td>\n", | |
" <td>4</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1374.0</td>\n", | |
" <td>8.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" shop_id item_id date_block_num target target_shop target_item\n", | |
"1497465 59 22114 4 1.0 1374.0 26.0\n", | |
"1497466 59 20239 4 0.0 1374.0 113.0\n", | |
"1497467 59 20238 4 1.0 1374.0 45.0\n", | |
"1497468 59 20785 4 0.0 1374.0 7.0\n", | |
"1497469 59 20783 4 0.0 1374.0 8.0" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Observation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target_item</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>27</td>\n", | |
" <td>4</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>27</td>\n", | |
" <td>5</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>27</td>\n", | |
" <td>6</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>28</td>\n", | |
" <td>4</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>28</td>\n", | |
" <td>5</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>28</td>\n", | |
" <td>6</td>\n", | |
" <td>7.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>29</td>\n", | |
" <td>4</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>30</td>\n", | |
" <td>4</td>\n", | |
" <td>50.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>30</td>\n", | |
" <td>5</td>\n", | |
" <td>49.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" item_id date_block_num target_item\n", | |
"0 27 4 2.0\n", | |
"1 27 5 2.0\n", | |
"2 27 6 3.0\n", | |
"3 28 4 4.0\n", | |
"4 28 5 4.0\n", | |
"5 28 6 7.0\n", | |
"6 29 4 1.0\n", | |
"7 29 6 2.0\n", | |
"8 30 4 50.0\n", | |
"9 30 5 49.0" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short.groupby(['item_id','date_block_num'],as_index=False).agg({'target_item': np.mean}).head(10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"There is no item 29 for date_block_num 5 (from now I will use 'block' for date_block_num). Expected behavior is that only target_item_lag_1 features and target_lag_1 of block 6, item 29 are 0" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Coursera provided code" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"index_cols = ['shop_id', 'item_id', 'date_block_num']\n", | |
"cols_to_rename = list(all_data.columns.difference(index_cols))\n", | |
"\n", | |
"\n", | |
"shift_range = [1] # only 1 month shift, for testing purposes\n", | |
"for month_shift in shift_range:\n", | |
" train_shift = all_data_short[index_cols + cols_to_rename].copy()\n", | |
" \n", | |
" train_shift['date_block_num'] = train_shift['date_block_num'] + month_shift\n", | |
" \n", | |
" foo = lambda x: '{}_lag_{}'.format(x, month_shift) if x in cols_to_rename else x\n", | |
" train_shift = train_shift.rename(columns=foo)\n", | |
"\n", | |
" all_data_short = pd.merge(all_data_short, train_shift, on=index_cols, how='left').fillna(0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's check block 6, item 29" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>shop_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target</th>\n", | |
" <th>target_shop</th>\n", | |
" <th>target_item</th>\n", | |
" <th>target_lag_1</th>\n", | |
" <th>target_item_lag_1</th>\n", | |
" <th>target_shop_lag_1</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>771050</th>\n", | |
" <td>28</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>6739.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>779455</th>\n", | |
" <td>27</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>4148.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>787860</th>\n", | |
" <td>25</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7361.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>796265</th>\n", | |
" <td>26</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>2163.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>804670</th>\n", | |
" <td>31</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>1.0</td>\n", | |
" <td>9500.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" shop_id item_id date_block_num target target_shop target_item \\\n", | |
"771050 28 29 6 0.0 6739.0 2.0 \n", | |
"779455 27 29 6 0.0 4148.0 2.0 \n", | |
"787860 25 29 6 0.0 7361.0 2.0 \n", | |
"796265 26 29 6 0.0 2163.0 2.0 \n", | |
"804670 31 29 6 1.0 9500.0 2.0 \n", | |
"\n", | |
" target_lag_1 target_item_lag_1 target_shop_lag_1 \n", | |
"771050 0.0 0.0 0.0 \n", | |
"779455 0.0 0.0 0.0 \n", | |
"787860 0.0 0.0 0.0 \n", | |
"796265 0.0 0.0 0.0 \n", | |
"804670 0.0 0.0 0.0 " | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short[(all_data_short.date_block_num==6) & (all_data_short.item_id==29)].head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's just look at the first row. The entire lag features are 0. Why target_shop_lag_1 is 0 when there is shop 28 in block 5? \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>shop_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target</th>\n", | |
" <th>target_shop</th>\n", | |
" <th>target_item</th>\n", | |
" <th>target_lag_1</th>\n", | |
" <th>target_item_lag_1</th>\n", | |
" <th>target_shop_lag_1</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>400689</th>\n", | |
" <td>28</td>\n", | |
" <td>11496</td>\n", | |
" <td>5</td>\n", | |
" <td>13.0</td>\n", | |
" <td>7056.0</td>\n", | |
" <td>192.0</td>\n", | |
" <td>23.0</td>\n", | |
" <td>323.0</td>\n", | |
" <td>5703.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>400690</th>\n", | |
" <td>28</td>\n", | |
" <td>11244</td>\n", | |
" <td>5</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7056.0</td>\n", | |
" <td>11.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>5703.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>400691</th>\n", | |
" <td>28</td>\n", | |
" <td>11388</td>\n", | |
" <td>5</td>\n", | |
" <td>2.0</td>\n", | |
" <td>7056.0</td>\n", | |
" <td>28.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>17.0</td>\n", | |
" <td>5703.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" shop_id item_id date_block_num target target_shop target_item \\\n", | |
"400689 28 11496 5 13.0 7056.0 192.0 \n", | |
"400690 28 11244 5 0.0 7056.0 11.0 \n", | |
"400691 28 11388 5 2.0 7056.0 28.0 \n", | |
"\n", | |
" target_lag_1 target_item_lag_1 target_shop_lag_1 \n", | |
"400689 23.0 323.0 5703.0 \n", | |
"400690 0.0 7.0 5703.0 \n", | |
"400691 0.0 17.0 5703.0 " | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short[(all_data_short.date_block_num==5) & (all_data_short.shop_id==28)].head(3)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This is because in previous code, all index cols (date_block_num,item_id,shop_id) are used for merging even though we only want to merge certain cols. As we can see there is no sales record for following trio in block 5, hence 0 is filled instead" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>shop_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target</th>\n", | |
" <th>target_shop</th>\n", | |
" <th>target_item</th>\n", | |
" <th>target_lag_1</th>\n", | |
" <th>target_item_lag_1</th>\n", | |
" <th>target_shop_lag_1</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
"Empty DataFrame\n", | |
"Columns: [shop_id, item_id, date_block_num, target, target_shop, target_item, target_lag_1, target_item_lag_1, target_shop_lag_1]\n", | |
"Index: []" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short[(all_data_short.date_block_num==5) & (all_data_short.item_id==29) & (all_data_short.shop_id==28)].head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Fix code by choosing the right columns for merging" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# reinitialize this variable\n", | |
"all_data_short = all_data.loc[(all_data.date_block_num >=4) & (all_data.date_block_num <=6),:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"index_cols = ['shop_id', 'item_id', 'date_block_num']\n", | |
"cols_to_rename = list(all_data.columns.difference(index_cols))\n", | |
"\n", | |
"cols_gb_item = ['target_item']\n", | |
"cols_gb_shop = ['target_shop']\n", | |
"cols_gb_all = ['target']\n", | |
"cols_gb_key=[['item_id'],['shop_id'],['shop_id','item_id']]\n", | |
"cols_gb_value = [cols_gb_item,cols_gb_shop,cols_gb_all]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['item_id']:['target_item']\n", | |
"['shop_id']:['target_shop']\n", | |
"['shop_id', 'item_id']:['target']\n" | |
] | |
} | |
], | |
"source": [ | |
"# We will loop through a key-value pair value like this:\n", | |
"# columns_to_merge: columns_to_generate_lag\n", | |
"for k,v in zip(cols_gb_key,cols_gb_value):\n", | |
" print('{}:{}'.format(k,v))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"shift_range = [1]\n", | |
"for month_shift in shift_range:\n", | |
" for k,v in zip(cols_gb_key,cols_gb_value): \n", | |
" index_col = ['date_block_num'] + k # append date block num for each column to merge\n", | |
" train_shift = all_data_short[index_col + v].copy().drop_duplicates()\n", | |
"\n", | |
" train_shift['date_block_num'] = train_shift['date_block_num'] + month_shift\n", | |
"\n", | |
" foo = lambda x: '{}_lag_{}'.format(x, month_shift) if x in v else x\n", | |
" train_shift = train_shift.rename(columns=foo)\n", | |
" all_data_short = pd.merge(all_data_short, train_shift, on=index_col, how='left').fillna(0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's check block 6 item 29" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>shop_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>date_block_num</th>\n", | |
" <th>target</th>\n", | |
" <th>target_shop</th>\n", | |
" <th>target_item</th>\n", | |
" <th>target_item_lag_1</th>\n", | |
" <th>target_shop_lag_1</th>\n", | |
" <th>target_lag_1</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>771050</th>\n", | |
" <td>28</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>6739.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7056.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>779455</th>\n", | |
" <td>27</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>4148.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>275.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>787860</th>\n", | |
" <td>25</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7361.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>8478.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>796265</th>\n", | |
" <td>26</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>0.0</td>\n", | |
" <td>2163.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>2661.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>804670</th>\n", | |
" <td>31</td>\n", | |
" <td>29</td>\n", | |
" <td>6</td>\n", | |
" <td>1.0</td>\n", | |
" <td>9500.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>10072.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" shop_id item_id date_block_num target target_shop target_item \\\n", | |
"771050 28 29 6 0.0 6739.0 2.0 \n", | |
"779455 27 29 6 0.0 4148.0 2.0 \n", | |
"787860 25 29 6 0.0 7361.0 2.0 \n", | |
"796265 26 29 6 0.0 2163.0 2.0 \n", | |
"804670 31 29 6 1.0 9500.0 2.0 \n", | |
"\n", | |
" target_item_lag_1 target_shop_lag_1 target_lag_1 \n", | |
"771050 0.0 7056.0 0.0 \n", | |
"779455 0.0 275.0 0.0 \n", | |
"787860 0.0 8478.0 0.0 \n", | |
"796265 0.0 2661.0 0.0 \n", | |
"804670 0.0 10072.0 0.0 " | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short[(all_data_short.date_block_num==6) & (all_data_short.item_id==29)].head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This looks right. There is no target 29 in block 5 so target_item_lag_1 and target_lag_1 is 0, but not target_shop_lag_1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Sanity check for of the first 2 target_shop_lag_1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7056.0" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short.loc[(all_data_short.date_block_num==5) & (all_data_short.shop_id==28),'target'].sum()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"275.0" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"all_data_short.loc[(all_data_short.date_block_num==5) & (all_data_short.shop_id==27),'target'].sum()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We are good to go! If we calculate mean of item count for each item or each shop, those mean features can be added to cols_gb_item or cols_gb_shop as well. This fix helps me get a massive boost on leaderboard." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment