Skip to content

Instantly share code, notes, and snippets.

@ohmeow
Created July 16, 2023 21:55
Show Gist options
  • Save ohmeow/4cd9342f7c901d6396a8d489c6fbac44 to your computer and use it in GitHub Desktop.
Save ohmeow/4cd9342f7c901d6396a8d489c6fbac44 to your computer and use it in GitHub Desktop.
Corise Forecasting - Project 4
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Project 4: End-to-End Workflow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the final project I introduced a number of updates to my work in project #3. They include:\n",
"- Shifting heavy data preprocessing to `polars` (huge speed improvements there)\n",
"- Caching datasets, fold train/eval datasets, fold features/models, and fold results\n",
"- Using `fast.ai` for putting datasets into formats suitable for decision trees, GBDTs, and NNs\n",
"- Ensembling and weighting model predictions\n",
"- Inference on new data\n",
"\n",
"Things to improve:\n",
"- Adding a simple FF NN to the mix \n",
"- Adding transformer based TS models to the mix\n",
"- Refactoring everything into an ML pipeline that can be run as a whole\n",
"- Persisting features, fold data, models, and results in something a bit more maintainable (e.g., S3 buckets)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# add imports\n",
"import json, joblib, pickle\n",
"import datetime as dt\n",
"from pathlib import Path\n",
"\n",
"import lightgbm as lgbm\n",
"import pandas as pd\n",
"import polars as pl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from fastai.data.transforms import ColSplitter\n",
"from fastai.metrics import mse, rmse, mae, msle, exp_rmspe\n",
"from fastai.tabular.core import TabularPandas, Categorify, FillMissing, Normalize, add_datepart, add_elapsed_times\n",
"from fastai.tabular.learner import tabular_learner\n",
"from scipy.optimize import minimize\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"\n",
"pl.enable_string_cache(True)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data_path = \"./data\"\n",
"models_path = \"./models\"\n",
"results_path = \"./results\"\n",
"\n",
"level = \"id\"\n",
"forecast_horizon = 28\n",
"use_pd = False\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Build Dataset\n",
"\n",
"\"You can fit the models at any level you want (just make sure both are fit at the same level), but I'd recommend trying out `item_id`. It's a little faster than at the `id` level, and it gives both models a good opportunity to show their diversity.\"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def get_data_by_level(raw_sales_df: pl.DataFrame, level: str = \"id\"):\n",
"    \"\"\"Aggregate raw sales to one row per (`level`, `date`) pair.\n",
"\n",
"    `sales` is summed within each group, while `state_id` and the hierarchy\n",
"    columns listed for `level` below are carried along by taking the first\n",
"    value in the group. Any other `level` only carries `state_id`.\n",
"    \"\"\"\n",
"    # hierarchy columns kept (first value per group) for each listed level\n",
"    carried_cols = {\n",
"        \"id\": [\"item_id\", \"dept_id\", \"cat_id\", \"store_id\"],\n",
"        \"item_id\": [\"dept_id\", \"cat_id\", \"store_id\"],\n",
"        \"dept_id\": [\"cat_id\", \"store_id\"],\n",
"        \"cat_id\": [\"store_id\"],\n",
"    }\n",
"\n",
"    agg_exprs = [pl.col(\"state_id\").first(), pl.col(\"sales\").sum()]\n",
"    agg_exprs += [pl.col(name).first() for name in carried_cols.get(level, [])]\n",
"\n",
"    return raw_sales_df.groupby([level, \"date\"], maintain_order=True).agg(agg_exprs)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def calc_lag(df: pl.DataFrame, shift_length: int, forecast_horizon: int, level: str = \"id\", by_day_of_week: bool = False):\n",
"    \"\"\"Build a lagged copy of `sales` as a new feature column.\n",
"\n",
"    Within each `level` group (and additionally each `Dayofweek` when\n",
"    `by_day_of_week` is set) `sales` is shifted by\n",
"    `forecast_horizon + shift_length` rows.\n",
"\n",
"    Returns a `(frame, feature_name)` tuple where `frame` is a copy of `df`\n",
"    with the new column added under `feature_name`.\n",
"    \"\"\"\n",
"    if by_day_of_week:\n",
"        feature_name = f\"seasonal_lag_{shift_length}_{forecast_horizon}\"\n",
"        partition_cols = [level, \"Dayofweek\"]\n",
"    else:\n",
"        feature_name = f\"lag_{shift_length}_{forecast_horizon}\"\n",
"        partition_cols = [level]\n",
"\n",
"    # NOTE(review): the shift counts rows, not calendar days -- presumably each\n",
"    # group is expected to be date-sorted without gaps; confirm upstream\n",
"    offset = forecast_horizon + shift_length\n",
"    shifted = df.clone().with_columns(shifted_sales=pl.col(\"sales\").shift(offset).over(partition_cols))\n",
"\n",
"    return shifted.rename({\"shifted_sales\": feature_name}), feature_name\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def calc_rolling_agg(\n",
"    df: pl.DataFrame, window_length: int, forecast_horizon: int, agg_func: str = \"mean\", level: str = \"id\", by_day_of_week: bool = False\n",
"):\n",
"    \"\"\"Build a rolling aggregation of `sales` as a new feature column.\n",
"\n",
"    The aggregation runs over `window_length` rows within each `level` group\n",
"    (and additionally each `Dayofweek` when `by_day_of_week` is set). The\n",
"    `date` column of the returned frame is moved forward by\n",
"    `forecast_horizon` days, and the aggregated column replaces `sales`\n",
"    under the generated feature name.\n",
"\n",
"    Args:\n",
"        df: frame holding at least `sales`, `date` and the grouping columns.\n",
"        window_length: rolling window size in rows.\n",
"        forecast_horizon: number of days the `date` column is offset by.\n",
"        agg_func: one of \"mean\", \"std\", \"var\", \"min\", \"max\", \"sum\", \"median\".\n",
"        level: column identifying each series.\n",
"        by_day_of_week: also partition by `Dayofweek`.\n",
"\n",
"    Returns:\n",
"        A `(frame, feature_name)` tuple.\n",
"\n",
"    Raises:\n",
"        ValueError: if `agg_func` is not a supported aggregation.\n",
"    \"\"\"\n",
"    supported = (\"mean\", \"std\", \"var\", \"min\", \"max\", \"sum\", \"median\")\n",
"    if agg_func not in supported:\n",
"        # anything other than \"mean\" used to fall back to std silently while the\n",
"        # column was still named after the requested aggregation\n",
"        raise ValueError(f\"Unsupported agg_func {agg_func!r}; expected one of {supported}\")\n",
"\n",
"    group_cols = [level]\n",
"\n",
"    if not by_day_of_week:\n",
"        feature_name = f\"rolling_{agg_func}_{window_length}_{forecast_horizon}\"\n",
"    else:\n",
"        feature_name = f\"seasonal_rolling_{agg_func}_{window_length}_{forecast_horizon}\"\n",
"        group_cols += [\"Dayofweek\"]\n",
"\n",
"    # dispatch to the matching polars expression method, e.g. `rolling_mean`\n",
"    rolling_method = getattr(pl.col(\"sales\"), f\"rolling_{agg_func}\")\n",
"    rolled_sales = rolling_method(window_length, min_periods=1).over(group_cols)\n",
"\n",
"    return (\n",
"        df.clone()\n",
"        .with_columns(sales=rolled_sales)\n",
"        # the value computed at day D is re-dated to D + forecast_horizon\n",
"        .with_columns(date=pl.col(\"date\").dt.offset_by(f\"{forecast_horizon}d\"))\n",
"        .rename({\"sales\": feature_name})\n",
"    ), feature_name\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def add_new_feature(df: pl.DataFrame, feat_df: pl.DataFrame, feat_name: str, level: str = \"id\"):\n",
"    \"\"\"Left-join one engineered feature column from `feat_df` onto `df`.\n",
"\n",
"    Rows are matched on `date` and the `level` key column; rows of `df`\n",
"    without a match get a null in `feat_name`.\n",
"    \"\"\"\n",
"    # join on `level`: the key used to be hard-coded to \"id\", which fails for any\n",
"    # other level because only [date, level, feat_name] are selected from feat_df\n",
"    return df.join(feat_df[[\"date\", level, feat_name]], on=[\"date\", level], how=\"left\")\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def add_target_features(\n",
"    df: pl.DataFrame,\n",
"    level: str = \"id\",\n",
"    lag_features: list[int] | None = None,\n",
"    seasonal_lag_features: list[int] | None = None,\n",
"    rolling_features: dict[str, list[int]] | None = None,\n",
"    seasonal_rolling_features: dict[str, list[int]] | None = None,\n",
"    forecast_horizon: int | None = None,\n",
"):\n",
"    \"\"\"Add lag and rolling-aggregation features of the target to a copy of `df`.\n",
"\n",
"    Args:\n",
"        df: frame to enrich; it is cloned, not modified.\n",
"        level: column identifying each series.\n",
"        lag_features: shift lengths for plain lags.\n",
"        seasonal_lag_features: shift lengths for lags taken per `Dayofweek`.\n",
"        rolling_features: maps an aggregation name to its window lengths.\n",
"        seasonal_rolling_features: same, computed per `Dayofweek`.\n",
"        forecast_horizon: horizon passed to the feature builders; when omitted\n",
"            the notebook-level `forecast_horizon` is used, as before.\n",
"\n",
"    Returns:\n",
"        The enriched frame.\n",
"    \"\"\"\n",
"    if forecast_horizon is None:\n",
"        # backward compatible: this function always read the notebook-level setting\n",
"        forecast_horizon = globals()[\"forecast_horizon\"]\n",
"\n",
"    df = df.clone()\n",
"\n",
"    # plain lags first, then the per-weekday (seasonal) ones\n",
"    for by_day_of_week, lags in ((False, lag_features or []), (True, seasonal_lag_features or [])):\n",
"        for lag in lags:\n",
"            feat_df, feature_name = calc_lag(df, lag, forecast_horizon, level=level, by_day_of_week=by_day_of_week)\n",
"            df = add_new_feature(df, feat_df, feature_name, level=level)\n",
"\n",
"    # same ordering for the rolling aggregations\n",
"    for by_day_of_week, spec in ((False, rolling_features or {}), (True, seasonal_rolling_features or {})):\n",
"        for agg_func, windows in spec.items():\n",
"            for window in windows:\n",
"                feat_df, feature_name = calc_rolling_agg(\n",
"                    df, window, forecast_horizon, agg_func=agg_func, level=level, by_day_of_week=by_day_of_week\n",
"                )\n",
"                df = add_new_feature(df, feat_df, feature_name, level=level)\n",
"\n",
"    return df\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def build_dataset(\n",
"    data_fpath: str,\n",
"    latest_sales_df: pd.DataFrame | None = None,\n",
"    level: str = \"id\",\n",
"    forecast_horizon: int = 28,\n",
"    lag_features: list[int] | None = None,\n",
"    seasonal_lag_features: list[int] | None = None,\n",
"    rolling_features: dict[str, list[int]] | None = None,\n",
"    seasonal_rolling_features: dict[str, list[int]] | None = None,\n",
"    cache: bool = True,\n",
"    override: bool = False,\n",
"):\n",
"    \"\"\"Build the full feature dataset, or load it from the parquet cache.\n",
"\n",
"    Sales are aggregated to `level`, each series' rows before its first sale are\n",
"    dropped, prices and calendar data are merged in, and elapsed-time,\n",
"    date-part, lag and rolling features are added.\n",
"\n",
"    Args:\n",
"        data_fpath: directory holding `sales_data.parquet`, `prices.parquet` and\n",
"            `calendar.parquet`; the cache file is read from and written to it.\n",
"        latest_sales_df: raw sales to use instead of `sales_data.parquet`.\n",
"        level: hierarchy column the sales are aggregated to.\n",
"        forecast_horizon: horizon encoded in the cache file name.\n",
"        lag_features: shift lengths for plain lags.\n",
"        seasonal_lag_features: shift lengths for per-weekday lags.\n",
"        rolling_features: maps an aggregation name to its window lengths.\n",
"        seasonal_rolling_features: same, computed per weekday.\n",
"        cache: write the built dataset to the cache file.\n",
"        override: rebuild even when the cache file exists.\n",
"\n",
"    Returns:\n",
"        The dataset as a pandas DataFrame.\n",
"    \"\"\"\n",
"    if latest_sales_df is None:\n",
"        raw_sales_df = get_data_by_level(pl.read_parquet(f\"{data_fpath}/sales_data.parquet\"), level=level)\n",
"    else:\n",
"        raw_sales_df = get_data_by_level(pl.from_pandas(latest_sales_df), level=level)\n",
"\n",
"    # NOTE: the cache key only encodes the latest date and the horizon; pass\n",
"    # `override=True` after changing `level` or any of the feature lists\n",
"    max_yyyymmdd = raw_sales_df[\"date\"].max().strftime(\"%Y%m%d\")\n",
"    raw_train_fpath = Path(f\"{data_fpath}/{max_yyyymmdd}_data_fh_{forecast_horizon}.parquet\")\n",
"\n",
"    if not override and raw_train_fpath.exists():\n",
"        return pd.read_parquet(raw_train_fpath)\n",
"\n",
"    # drop records until we start seeing sales (some products don't start showing sales for some time)\n",
"    # we only care about no gaps by product! Partition by `level`: the aggregated\n",
"    # frame only has an \"id\" column when level == \"id\"\n",
"    raw_sales_df = raw_sales_df.filter(pl.col(\"sales\").cumsum().over(level).gt(0))\n",
"\n",
"    # merge prices\n",
"    prices_df = pl.read_parquet(f\"{data_fpath}/prices.parquet\")\n",
"    merged_df = raw_sales_df.join(prices_df, how=\"left\", on=[\"date\", \"store_id\", \"item_id\"])\n",
"\n",
"    # merge calendar\n",
"    calendar_df = pl.read_parquet(f\"{data_fpath}/calendar.parquet\")\n",
"\n",
"    # a day is a holiday when either event slot is filled; requiring both left\n",
"    # single-event days such as New Year flagged as non-holidays\n",
"    calendar_df = calendar_df.with_columns(\n",
"        is_holiday=pl.when((pl.col(\"event_name_1\").is_not_null()) | (pl.col(\"event_name_2\").is_not_null())).then(True).otherwise(False)\n",
"    )\n",
"\n",
"    big_event_1s = [\n",
"        \"SuperBowl\",\n",
"        \"MemorialDay\",\n",
"        \"NewYear\",\n",
"        \"Christmas\",\n",
"        \"Thanksgiving\",\n",
"        \"IndependenceDay\",\n",
"        \"Halloween\",\n",
"        \"LaborDay\",\n",
"        \"Cinco De Mayo\",\n",
"        \"Easter\",\n",
"    ]\n",
"\n",
"    calendar_df = calendar_df.with_columns(is_big_holiday=pl.when(pl.col(\"event_name_1\").is_in(big_event_1s)).then(True).otherwise(False))\n",
"\n",
"    # add in elapsed times for key events (by state); fastai's helper works on\n",
"    # pandas, hence the round trip\n",
"    calendar_df = calendar_df.to_pandas()\n",
"    calendar_df[\"base_field\"] = \"cal\"\n",
"    calendar_df = add_elapsed_times(calendar_df, [\"snap_TX\", \"is_holiday\", \"is_big_holiday\"], \"date\", \"base_field\")\n",
"    calendar_df = calendar_df.drop(\"base_field\", axis=1)\n",
"    calendar_df = pl.from_pandas(calendar_df)\n",
"\n",
"    merged_df = merged_df.join(calendar_df, how=\"left\", on=[\"date\"])\n",
"\n",
"    # add dateparts\n",
"    merged_df = merged_df.to_pandas()\n",
"    merged_df = add_datepart(merged_df, \"date\", prefix=None, drop=False, time=False)\n",
"    merged_df = pl.from_pandas(merged_df)\n",
"\n",
"    # add base target features (lags, rolling averages, etc.)\n",
"    # NOTE(review): `forecast_horizon` is not forwarded in this call; confirm that\n",
"    # add_target_features ends up using the same horizon as the one given here\n",
"    merged_df = add_target_features(\n",
"        merged_df,\n",
"        level=level,\n",
"        lag_features=lag_features or [],\n",
"        seasonal_lag_features=seasonal_lag_features or [],\n",
"        rolling_features=rolling_features or {},\n",
"        seasonal_rolling_features=seasonal_rolling_features or {},\n",
"    )\n",
"\n",
"    merged_df = merged_df.to_pandas()\n",
"    if cache:\n",
"        merged_df.to_parquet(raw_train_fpath)\n",
"\n",
"    return merged_df\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: `TabularPandas`\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10298579\n",
"CPU times: user 24.1 s, sys: 20.1 s, total: 44.2 s\n",
"Wall time: 11.8 s\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>sell_price</th>\n",
" <th>snap_TX</th>\n",
" <th>...</th>\n",
" <th>seasonal_rolling_mean_2_28</th>\n",
" <th>seasonal_rolling_mean_4_28</th>\n",
" <th>seasonal_rolling_mean_12_28</th>\n",
" <th>seasonal_rolling_mean_26_28</th>\n",
" <th>seasonal_rolling_mean_52_28</th>\n",
" <th>seasonal_rolling_std_2_28</th>\n",
" <th>seasonal_rolling_std_4_28</th>\n",
" <th>seasonal_rolling_std_12_28</th>\n",
" <th>seasonal_rolling_std_26_28</th>\n",
" <th>seasonal_rolling_std_52_28</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>FOODS_1_004_TX_1_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_1</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>FOODS_1_004_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>FOODS_1_004_TX_3_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>4</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_3</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>FOODS_1_005_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>1</td>\n",
" <td>FOODS_1_005</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>3.28</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>FOODS_1_009_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>3</td>\n",
" <td>FOODS_1_009</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>2.68</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 71 columns</p>\n",
"</div>"
],
"text/plain": [
" id date state_id sales item_id \\\n",
"0 FOODS_1_004_TX_1_evaluation 2013-01-01 TX 20 FOODS_1_004 \n",
"1 FOODS_1_004_TX_2_evaluation 2013-01-01 TX 20 FOODS_1_004 \n",
"2 FOODS_1_004_TX_3_evaluation 2013-01-01 TX 4 FOODS_1_004 \n",
"3 FOODS_1_005_TX_2_evaluation 2013-01-01 TX 1 FOODS_1_005 \n",
"4 FOODS_1_009_TX_2_evaluation 2013-01-01 TX 3 FOODS_1_009 \n",
"\n",
" dept_id cat_id store_id sell_price snap_TX ... \\\n",
"0 FOODS_1 FOODS TX_1 1.78 True ... \n",
"1 FOODS_1 FOODS TX_2 1.78 True ... \n",
"2 FOODS_1 FOODS TX_3 1.78 True ... \n",
"3 FOODS_1 FOODS TX_2 3.28 True ... \n",
"4 FOODS_1 FOODS TX_2 2.68 True ... \n",
"\n",
" seasonal_rolling_mean_2_28 seasonal_rolling_mean_4_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_12_28 seasonal_rolling_mean_26_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_52_28 seasonal_rolling_std_2_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_4_28 seasonal_rolling_std_12_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_26_28 seasonal_rolling_std_52_28 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
"[5 rows x 71 columns]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"lag_features = [1, 2, 3, 7, 14, 21, 30, 90, 365]\n",
"seasonal_lag_features = [1, 2, 4, 12, 26, 52]\n",
"\n",
"rolling_features = {\"mean\": [7, 14], \"std\": [7, 14]}\n",
"seasonal_rolling_features = {\"mean\": [1, 2, 4, 12, 26, 52], \"std\": [2, 4, 12, 26, 52]} # no std for a window of 1: the sample standard deviation of a single observation is undefined, so the column would be all NaN\n",
"\n",
"df = build_dataset(\n",
" data_path,\n",
" level=\"id\",\n",
" forecast_horizon=28,\n",
" lag_features=lag_features,\n",
" seasonal_lag_features=seasonal_lag_features,\n",
" rolling_features=rolling_features,\n",
" seasonal_rolling_features=seasonal_rolling_features,\n",
")\n",
"\n",
"print(len(df))\n",
"df.head()\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### A. Build Train/Evaluation Splits\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def build_train_test_splits(df, forecast_horizon=28, fold=1):\n",
"    \"\"\"Split `df` into train/eval frames using the `fold`-th horizon window from the end.\n",
"\n",
"    The eval frame covers the `forecast_horizon` days ending\n",
"    `forecast_horizon * (fold - 1)` days before the latest date in `df`; the\n",
"    train frame holds every row dated up to the start of that window. Both\n",
"    frames are copies and get a boolean `is_valid` column.\n",
"    \"\"\"\n",
"    last_date = df[\"date\"].max()\n",
"\n",
"    # distance from the latest date to the end of train / end of eval\n",
"    train_offset = dt.timedelta(days=forecast_horizon * fold)\n",
"    eval_offset = train_offset - dt.timedelta(days=forecast_horizon)\n",
"    train_end = last_date - train_offset\n",
"    eval_end = last_date - eval_offset\n",
"\n",
"    dates = df[\"date\"]\n",
"    in_train = dates <= train_end\n",
"    in_eval = (dates > train_end) & (dates <= eval_end)\n",
"\n",
"    train_df = df.loc[in_train, :].copy()\n",
"    test_df = df.loc[in_eval, :].copy()\n",
"\n",
"    # 'is_valid' lets the two frames be concatenated and split back out later\n",
"    train_df.loc[:, \"is_valid\"] = False\n",
"    test_df.loc[:, \"is_valid\"] = True\n",
"\n",
"    return train_df, test_df\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2013-01-01 00:00:00 2016-05-22 00:00:00\n",
"2013-01-01 00:00:00 2016-03-27 00:00:00\n",
"2016-03-28 00:00:00 2016-04-24 00:00:00\n"
]
}
],
"source": [
"train_df, test_df = build_train_test_splits(df, forecast_horizon=forecast_horizon, fold=2)\n",
"\n",
"print(df.date.min(), df.date.max())\n",
"print(train_df.date.min(), train_df.date.max())\n",
"print(test_df.date.min(), test_df.date.max())\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2013-01-01 00:00:00 2016-05-22 00:00:00\n",
"2013-01-01 00:00:00 2016-04-24 00:00:00\n",
"2016-04-25 00:00:00 2016-05-22 00:00:00\n"
]
}
],
"source": [
"train_df, test_df = build_train_test_splits(df, forecast_horizon=forecast_horizon, fold=1)\n",
"\n",
"print(df.date.min(), df.date.max())\n",
"print(train_df.date.min(), train_df.date.max())\n",
"print(test_df.date.min(), test_df.date.max())\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### B. Add Global/Grouped Aggregations\n",
"\n",
"Notes:\n",
"\n",
"1. \"A common example in retail _(that may help you on the project)_ is taking the average sales at different product hierarchy levels, so averages grouped by category, averages grouped by departments, averages grouped by department and store, etc.\"\n",
"\n",
"2. \"But, **be very careful!** All of these aggregations **can ONLY be calculated on the training set!** Including the validation set leaks information about the validation set to the model while it's training, which is information it certainly won't have in Production.\"\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def add_group_features(train_df, test_df):\n",
"    \"\"\"Attach per-group price statistics computed from the training rows only.\n",
"\n",
"    For each hierarchy column the max and median `sell_price` are computed on\n",
"    `train_df` and left-joined onto both frames, so the eval frame never\n",
"    contributes to the statistics. Takes and returns pandas DataFrames.\n",
"    \"\"\"\n",
"    hierarchy_levels = (\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\")\n",
"    train_pl, test_pl = pl.from_pandas(train_df), pl.from_pandas(test_df)\n",
"\n",
"    for key in hierarchy_levels:\n",
"        price_stats = train_pl.groupby(key, maintain_order=True).agg(\n",
"            [\n",
"                pl.col(\"sell_price\").max().alias(f\"max_price_{key}\"),\n",
"                pl.col(\"sell_price\").median().alias(f\"median_price_{key}\"),\n",
"            ]\n",
"        )\n",
"\n",
"        train_pl = train_pl.join(price_stats, on=key, how=\"left\")\n",
"        test_pl = test_pl.join(price_stats, on=key, how=\"left\")\n",
"\n",
"    return train_pl.to_pandas(), test_pl.to_pandas()\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 20.5 s, sys: 7.62 s, total: 28.1 s\n",
"Wall time: 11.8 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"train_df, test_df = add_group_features(train_df, test_df)\n",
"test_df = test_df.fillna(0)\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['id', 'date', 'state_id', 'sales', 'item_id', 'dept_id', 'cat_id',\n",
" 'store_id', 'sell_price', 'snap_TX', 'event_name_1', 'event_type_1',\n",
" 'event_name_2', 'event_type_2', 'is_holiday', 'is_big_holiday',\n",
" 'Aftersnap_TX', 'Afteris_holiday', 'Afteris_big_holiday',\n",
" 'Beforesnap_TX', 'Beforeis_holiday', 'Beforeis_big_holiday',\n",
" 'snap_TX_bw', 'is_holiday_bw', 'is_big_holiday_bw', 'snap_TX_fw',\n",
" 'is_holiday_fw', 'is_big_holiday_fw', 'Year', 'Month', 'Week', 'Day',\n",
" 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',\n",
" 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start',\n",
" 'Elapsed', 'lag_1_28', 'lag_2_28', 'lag_3_28', 'lag_7_28', 'lag_14_28',\n",
" 'lag_21_28', 'lag_30_28', 'lag_90_28', 'lag_365_28',\n",
" 'seasonal_lag_1_28', 'seasonal_lag_2_28', 'seasonal_lag_4_28',\n",
" 'seasonal_lag_12_28', 'seasonal_lag_26_28', 'seasonal_lag_52_28',\n",
" 'rolling_mean_7_28', 'rolling_mean_14_28', 'rolling_std_7_28',\n",
" 'rolling_std_14_28', 'seasonal_rolling_mean_1_28',\n",
" 'seasonal_rolling_mean_2_28', 'seasonal_rolling_mean_4_28',\n",
" 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28',\n",
" 'seasonal_rolling_mean_52_28', 'seasonal_rolling_std_2_28',\n",
" 'seasonal_rolling_std_4_28', 'seasonal_rolling_std_12_28',\n",
" 'seasonal_rolling_std_26_28', 'seasonal_rolling_std_52_28', 'is_valid',\n",
" 'max_price_id', 'median_price_id', 'max_price_item_id',\n",
" 'median_price_item_id', 'max_price_dept_id', 'median_price_dept_id',\n",
" 'max_price_cat_id', 'median_price_cat_id', 'max_price_store_id',\n",
" 'median_price_store_id'],\n",
" dtype='object')"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df.columns\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### C. Build `TabularPandas` object\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Features:\n",
"['lag_1_28', 'lag_2_28', 'lag_3_28', 'lag_7_28', 'lag_14_28', 'lag_21_28', 'lag_30_28', 'lag_90_28', 'lag_365_28', 'seasonal_lag_1_28', 'seasonal_lag_2_28', 'seasonal_lag_4_28', 'seasonal_lag_12_28', 'seasonal_lag_26_28', 'seasonal_lag_52_28', 'rolling_mean_7_28', 'rolling_mean_14_28', 'rolling_std_7_28', 'rolling_std_14_28', 'seasonal_rolling_mean_1_28', 'seasonal_rolling_mean_2_28', 'seasonal_rolling_mean_4_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28', 'seasonal_rolling_std_2_28', 'seasonal_rolling_std_4_28', 'seasonal_rolling_std_12_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_std_52_28', 'Aftersnap_TX', 'Afteris_holiday', 'Afteris_big_holiday', 'Beforesnap_TX', 'Beforeis_holiday', 'Beforeis_big_holiday', 'max_price_id', 'median_price_id', 'max_price_item_id', 'median_price_item_id', 'max_price_dept_id', 'median_price_dept_id', 'max_price_cat_id', 'median_price_cat_id', 'max_price_store_id', 'median_price_store_id', 'id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'snap_TX', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', 'Month', 'Day', 'Dayofweek', 'Is_month_end', 'Is_month_start', 'is_holiday', 'is_big_holiday', 'snap_TX_bw', 'snap_TX_fw']\n",
"\n",
"10298579\n"
]
}
],
"source": [
"# --- for full dataset ---\n",
"cont_features = [\n",
" \"lag_1_28\",\n",
" \"lag_2_28\",\n",
" \"lag_3_28\",\n",
" \"lag_7_28\",\n",
" \"lag_14_28\",\n",
" \"lag_21_28\",\n",
" \"lag_30_28\",\n",
" \"lag_90_28\",\n",
" \"lag_365_28\",\n",
" \"seasonal_lag_1_28\",\n",
" \"seasonal_lag_2_28\",\n",
" \"seasonal_lag_4_28\",\n",
" \"seasonal_lag_12_28\",\n",
" \"seasonal_lag_26_28\",\n",
" \"seasonal_lag_52_28\",\n",
" \"rolling_mean_7_28\",\n",
" \"rolling_mean_14_28\",\n",
" \"rolling_std_7_28\",\n",
" \"rolling_std_14_28\",\n",
" \"seasonal_rolling_mean_1_28\",\n",
" \"seasonal_rolling_mean_2_28\",\n",
" \"seasonal_rolling_mean_4_28\",\n",
" \"seasonal_rolling_mean_12_28\",\n",
" \"seasonal_rolling_mean_26_28\",\n",
" \"seasonal_rolling_mean_52_28\",\n",
" # \"seasonal_rolling_std_1_28\",\n",
" \"seasonal_rolling_std_2_28\",\n",
" \"seasonal_rolling_std_4_28\",\n",
" \"seasonal_rolling_std_12_28\",\n",
" \"seasonal_rolling_std_26_28\",\n",
" \"seasonal_rolling_std_52_28\",\n",
" \"Aftersnap_TX\",\n",
" \"Afteris_holiday\",\n",
" \"Afteris_big_holiday\",\n",
" \"Beforesnap_TX\",\n",
" \"Beforeis_holiday\",\n",
" \"Beforeis_big_holiday\",\n",
" \"max_price_id\",\n",
" \"median_price_id\",\n",
" \"max_price_item_id\",\n",
" \"median_price_item_id\",\n",
" \"max_price_dept_id\",\n",
" \"median_price_dept_id\",\n",
" \"max_price_cat_id\",\n",
" \"median_price_cat_id\",\n",
" \"max_price_store_id\",\n",
" \"median_price_store_id\",\n",
"]\n",
"\n",
"cat_features = [\n",
" \"id\",\n",
" \"item_id\",\n",
" \"dept_id\",\n",
" \"cat_id\",\n",
" \"store_id\",\n",
" \"snap_TX\",\n",
" \"event_name_1\",\n",
" \"event_type_1\",\n",
" \"event_name_2\",\n",
" \"event_type_2\",\n",
" \"Month\",\n",
" \"Day\",\n",
" \"Dayofweek\",\n",
" \"Is_month_end\",\n",
" \"Is_month_start\",\n",
" \"is_holiday\",\n",
" \"is_big_holiday\",\n",
" \"snap_TX_bw\",\n",
" \"snap_TX_fw\",\n",
"]\n",
"\n",
"features = cont_features + cat_features\n",
"target_attr = \"sales\"\n",
"\n",
"print(f\"Training Features:\\n{features}\\n\")\n",
"print(len(train_df) + len(test_df))\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10298579\n"
]
}
],
"source": [
"combined_df = pd.concat([train_df, test_df])\n",
"print(len(combined_df))\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"splits = ColSplitter(col=\"is_valid\")(combined_df)\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"procs = [Categorify, FillMissing, Normalize]\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"to = TabularPandas(combined_df, procs, cat_names=cat_features, cont_names=cont_features, y_names=target_attr, splits=splits)\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"def build_fold_to(\n",
"    df: pd.DataFrame,\n",
"    cat_features,\n",
"    cont_features,\n",
"    target_attr,\n",
"    forecast_horizon: int = 28,\n",
"    fold: int = 1,\n",
"    data_fpath: str = \"./data\",\n",
"    cache: bool = True,\n",
"    override: bool = False,\n",
"):\n",
"    \"\"\"Build the `TabularPandas` object for one fold, or load it from its pickle cache.\n",
"\n",
"    The cache file name is derived from the latest date in `df`, the horizon\n",
"    and the fold. On a cache miss (or when `override` is set) the fold's\n",
"    train/eval split is built, group features are added, missing values in\n",
"    the eval frame are filled with 0, and the combined frame is wrapped in a\n",
"    `TabularPandas`; the result is pickled when `cache` is set.\n",
"    \"\"\"\n",
"    latest_yyyymmdd = df[\"date\"].max().strftime(\"%Y%m%d\")\n",
"    to_fpath = Path(f\"{data_fpath}/{latest_yyyymmdd}_to_fh_{forecast_horizon}_fold_{fold}.pkl\")\n",
"\n",
"    # cache hit -- only unpickle files written by this notebook\n",
"    if not override and to_fpath.exists():\n",
"        with open(to_fpath, \"rb\") as file:\n",
"            return pickle.load(file)\n",
"\n",
"    fold_train_df, fold_test_df = build_train_test_splits(df, forecast_horizon=forecast_horizon, fold=fold)\n",
"    fold_train_df, fold_test_df = add_group_features(fold_train_df, fold_test_df)\n",
"    fold_test_df = fold_test_df.fillna(0)\n",
"\n",
"    fold_df = pd.concat([fold_train_df, fold_test_df])\n",
"\n",
"    fold_to = TabularPandas(\n",
"        fold_df,\n",
"        [Categorify, FillMissing, Normalize],\n",
"        cat_names=cat_features,\n",
"        cont_names=cont_features,\n",
"        y_names=target_attr,\n",
"        splits=ColSplitter(col=\"is_valid\")(fold_df),\n",
"    )\n",
"\n",
"    if cache:\n",
"        with open(to_fpath, \"wb\") as file:\n",
"            pickle.dump(fold_to, file)\n",
"\n",
"    return fold_to\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 41.6 ms, sys: 2.11 s, total: 2.15 s\n",
"Wall time: 2.87 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"to = build_fold_to(df, cat_features, cont_features, target_attr, forecast_horizon=forecast_horizon, fold= 2, data_fpath=data_path)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>snap_TX</th>\n",
" <th>event_name_1</th>\n",
" <th>event_type_1</th>\n",
" <th>event_name_2</th>\n",
" <th>event_type_2</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Dayofweek</th>\n",
" <th>Is_month_end</th>\n",
" <th>Is_month_start</th>\n",
" <th>is_holiday</th>\n",
" <th>is_big_holiday</th>\n",
" <th>snap_TX_bw</th>\n",
" <th>snap_TX_fw</th>\n",
" <th>lag_1_28_na</th>\n",
" <th>lag_2_28_na</th>\n",
" <th>lag_3_28_na</th>\n",
" <th>lag_7_28_na</th>\n",
" <th>lag_14_28_na</th>\n",
" <th>lag_21_28_na</th>\n",
" <th>lag_30_28_na</th>\n",
" <th>lag_90_28_na</th>\n",
" <th>lag_365_28_na</th>\n",
" <th>seasonal_lag_1_28_na</th>\n",
" <th>seasonal_lag_2_28_na</th>\n",
" <th>seasonal_lag_4_28_na</th>\n",
" <th>seasonal_lag_12_28_na</th>\n",
" <th>seasonal_lag_26_28_na</th>\n",
" <th>seasonal_lag_52_28_na</th>\n",
" <th>rolling_mean_7_28_na</th>\n",
" <th>rolling_mean_14_28_na</th>\n",
" <th>rolling_std_7_28_na</th>\n",
" <th>rolling_std_14_28_na</th>\n",
" <th>seasonal_rolling_mean_1_28_na</th>\n",
" <th>seasonal_rolling_mean_2_28_na</th>\n",
" <th>seasonal_rolling_mean_4_28_na</th>\n",
" <th>seasonal_rolling_mean_12_28_na</th>\n",
" <th>seasonal_rolling_mean_26_28_na</th>\n",
" <th>seasonal_rolling_mean_52_28_na</th>\n",
" <th>seasonal_rolling_std_2_28_na</th>\n",
" <th>seasonal_rolling_std_4_28_na</th>\n",
" <th>seasonal_rolling_std_12_28_na</th>\n",
" <th>seasonal_rolling_std_26_28_na</th>\n",
" <th>seasonal_rolling_std_52_28_na</th>\n",
" <th>lag_1_28</th>\n",
" <th>lag_2_28</th>\n",
" <th>lag_3_28</th>\n",
" <th>lag_7_28</th>\n",
" <th>lag_14_28</th>\n",
" <th>lag_21_28</th>\n",
" <th>lag_30_28</th>\n",
" <th>lag_90_28</th>\n",
" <th>lag_365_28</th>\n",
" <th>seasonal_lag_1_28</th>\n",
" <th>seasonal_lag_2_28</th>\n",
" <th>seasonal_lag_4_28</th>\n",
" <th>seasonal_lag_12_28</th>\n",
" <th>seasonal_lag_26_28</th>\n",
" <th>seasonal_lag_52_28</th>\n",
" <th>rolling_mean_7_28</th>\n",
" <th>rolling_mean_14_28</th>\n",
" <th>rolling_std_7_28</th>\n",
" <th>rolling_std_14_28</th>\n",
" <th>seasonal_rolling_mean_1_28</th>\n",
" <th>seasonal_rolling_mean_2_28</th>\n",
" <th>seasonal_rolling_mean_4_28</th>\n",
" <th>seasonal_rolling_mean_12_28</th>\n",
" <th>seasonal_rolling_mean_26_28</th>\n",
" <th>seasonal_rolling_mean_52_28</th>\n",
" <th>seasonal_rolling_std_2_28</th>\n",
" <th>seasonal_rolling_std_4_28</th>\n",
" <th>seasonal_rolling_std_12_28</th>\n",
" <th>seasonal_rolling_std_26_28</th>\n",
" <th>seasonal_rolling_std_52_28</th>\n",
" <th>Aftersnap_TX</th>\n",
" <th>Afteris_holiday</th>\n",
" <th>Afteris_big_holiday</th>\n",
" <th>Beforesnap_TX</th>\n",
" <th>Beforeis_holiday</th>\n",
" <th>Beforeis_big_holiday</th>\n",
" <th>max_price_id</th>\n",
" <th>median_price_id</th>\n",
" <th>max_price_item_id</th>\n",
" <th>median_price_item_id</th>\n",
" <th>max_price_dept_id</th>\n",
" <th>median_price_dept_id</th>\n",
" <th>max_price_cat_id</th>\n",
" <th>median_price_cat_id</th>\n",
" <th>max_price_store_id</th>\n",
" <th>median_price_store_id</th>\n",
" <th>sales</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>FOODS_1_004_TX_1_evaluation</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_1</td>\n",
" <td>True</td>\n",
" <td>NewYear</td>\n",
" <td>National</td>\n",
" <td>#na#</td>\n",
" <td>#na#</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.428571</td>\n",
" <td>0.428571</td>\n",
" <td>0.690066</td>\n",
" <td>0.699293</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.416667</td>\n",
" <td>0.466667</td>\n",
" <td>0.5</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>0.6742</td>\n",
" <td>0.761577</td>\n",
" <td>0.816497</td>\n",
" <td>0.0</td>\n",
" <td>618.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-124.0</td>\n",
" <td>0.0</td>\n",
" <td>1.96</td>\n",
" <td>1.96</td>\n",
" <td>1.96</td>\n",
" <td>1.96</td>\n",
" <td>12.98</td>\n",
" <td>2.48</td>\n",
" <td>18.98</td>\n",
" <td>2.68</td>\n",
" <td>48.779999</td>\n",
" <td>3.42</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>FOODS_1_004_TX_2_evaluation</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>True</td>\n",
" <td>NewYear</td>\n",
" <td>National</td>\n",
" <td>#na#</td>\n",
" <td>#na#</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.428571</td>\n",
" <td>0.428571</td>\n",
" <td>0.690066</td>\n",
" <td>0.699293</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.416667</td>\n",
" <td>0.466667</td>\n",
" <td>0.5</td>\n",
" <td>0.0</td>\n",
" <td>0.5</td>\n",
" <td>0.6742</td>\n",
" <td>0.761577</td>\n",
" <td>0.816497</td>\n",
" <td>0.0</td>\n",
" <td>618.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-124.0</td>\n",
" <td>0.0</td>\n",
" <td>1.96</td>\n",
" <td>1.78</td>\n",
" <td>1.96</td>\n",
" <td>1.96</td>\n",
" <td>12.98</td>\n",
" <td>2.48</td>\n",
" <td>18.98</td>\n",
" <td>2.68</td>\n",
" <td>29.969999</td>\n",
" <td>3.42</td>\n",
" <td>20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"to.show(2) # display raw categorical labels\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>sell_price</th>\n",
" <th>snap_TX</th>\n",
" <th>...</th>\n",
" <th>seasonal_rolling_mean_2_28_na</th>\n",
" <th>seasonal_rolling_mean_4_28_na</th>\n",
" <th>seasonal_rolling_mean_12_28_na</th>\n",
" <th>seasonal_rolling_mean_26_28_na</th>\n",
" <th>seasonal_rolling_mean_52_28_na</th>\n",
" <th>seasonal_rolling_std_2_28_na</th>\n",
" <th>seasonal_rolling_std_4_28_na</th>\n",
" <th>seasonal_rolling_std_12_28_na</th>\n",
" <th>seasonal_rolling_std_26_28_na</th>\n",
" <th>seasonal_rolling_std_52_28_na</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1.78</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1.78</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>12</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>1.78</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3 rows × 112 columns</p>\n",
"</div>"
],
"text/plain": [
" id date state_id sales item_id dept_id cat_id store_id \\\n",
"0 10 2013-01-01 TX 20 4 1 1 1 \n",
"1 11 2013-01-01 TX 20 4 1 1 2 \n",
"2 12 2013-01-01 TX 4 4 1 1 3 \n",
"\n",
" sell_price snap_TX ... seasonal_rolling_mean_2_28_na \\\n",
"0 1.78 2 ... 2 \n",
"1 1.78 2 ... 2 \n",
"2 1.78 2 ... 2 \n",
"\n",
" seasonal_rolling_mean_4_28_na seasonal_rolling_mean_12_28_na \\\n",
"0 2 2 \n",
"1 2 2 \n",
"2 2 2 \n",
"\n",
" seasonal_rolling_mean_26_28_na seasonal_rolling_mean_52_28_na \\\n",
"0 2 2 \n",
"1 2 2 \n",
"2 2 2 \n",
"\n",
" seasonal_rolling_std_2_28_na seasonal_rolling_std_4_28_na \\\n",
"0 2 2 \n",
"1 2 2 \n",
"2 2 2 \n",
"\n",
" seasonal_rolling_std_12_28_na seasonal_rolling_std_26_28_na \\\n",
"0 2 2 \n",
"1 2 2 \n",
"2 2 2 \n",
"\n",
" seasonal_rolling_std_52_28_na \n",
"0 2 \n",
"1 2 \n",
"2 2 \n",
"\n",
"[3 rows x 112 columns]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.items.head(3) # display numeric representations\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['#na#', False, True]"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.classes[\"is_big_holiday\"] # display categorical levels\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Train\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Features:\n",
"['lag_1_28', 'lag_2_28', 'lag_3_28', 'lag_7_28', 'lag_14_28', 'lag_21_28', 'lag_30_28', 'lag_90_28', 'lag_365_28', 'seasonal_lag_1_28', 'seasonal_lag_2_28', 'seasonal_lag_4_28', 'seasonal_lag_12_28', 'seasonal_lag_26_28', 'seasonal_lag_52_28', 'rolling_mean_7_28', 'rolling_mean_14_28', 'rolling_std_7_28', 'rolling_std_14_28', 'seasonal_rolling_mean_1_28', 'seasonal_rolling_mean_2_28', 'seasonal_rolling_mean_4_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28', 'seasonal_rolling_std_2_28', 'seasonal_rolling_std_4_28', 'seasonal_rolling_std_12_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_std_52_28', 'Aftersnap_TX', 'Afteris_holiday', 'Afteris_big_holiday', 'Beforesnap_TX', 'Beforeis_holiday', 'Beforeis_big_holiday', 'max_price_id', 'median_price_id', 'max_price_item_id', 'median_price_item_id', 'max_price_dept_id', 'median_price_dept_id', 'max_price_cat_id', 'median_price_cat_id', 'max_price_store_id', 'median_price_store_id', 'id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'snap_TX', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', 'Month', 'Day', 'Dayofweek', 'Is_month_end', 'Is_month_start', 'is_holiday', 'is_big_holiday', 'snap_TX_bw', 'snap_TX_fw']\n",
"\n",
"10298579\n"
]
}
],
"source": [
"# --- for full dataset ---\n",
"cont_features = [\n",
" \"lag_1_28\",\n",
" \"lag_2_28\",\n",
" \"lag_3_28\",\n",
" \"lag_7_28\",\n",
" \"lag_14_28\",\n",
" \"lag_21_28\",\n",
" \"lag_30_28\",\n",
" \"lag_90_28\",\n",
" \"lag_365_28\",\n",
" \"seasonal_lag_1_28\",\n",
" \"seasonal_lag_2_28\",\n",
" \"seasonal_lag_4_28\",\n",
" \"seasonal_lag_12_28\",\n",
" \"seasonal_lag_26_28\",\n",
" \"seasonal_lag_52_28\",\n",
" \"rolling_mean_7_28\",\n",
" \"rolling_mean_14_28\",\n",
" \"rolling_std_7_28\",\n",
" \"rolling_std_14_28\",\n",
" \"seasonal_rolling_mean_1_28\",\n",
" \"seasonal_rolling_mean_2_28\",\n",
" \"seasonal_rolling_mean_4_28\",\n",
" \"seasonal_rolling_mean_12_28\",\n",
" \"seasonal_rolling_mean_26_28\",\n",
" \"seasonal_rolling_mean_52_28\",\n",
" # \"seasonal_rolling_std_1_28\",\n",
" \"seasonal_rolling_std_2_28\",\n",
" \"seasonal_rolling_std_4_28\",\n",
" \"seasonal_rolling_std_12_28\",\n",
" \"seasonal_rolling_std_26_28\",\n",
" \"seasonal_rolling_std_52_28\",\n",
" \"Aftersnap_TX\",\n",
" \"Afteris_holiday\",\n",
" \"Afteris_big_holiday\",\n",
" \"Beforesnap_TX\",\n",
" \"Beforeis_holiday\",\n",
" \"Beforeis_big_holiday\",\n",
" \"max_price_id\",\n",
" \"median_price_id\",\n",
" \"max_price_item_id\",\n",
" \"median_price_item_id\",\n",
" \"max_price_dept_id\",\n",
" \"median_price_dept_id\",\n",
" \"max_price_cat_id\",\n",
" \"median_price_cat_id\",\n",
" \"max_price_store_id\",\n",
" \"median_price_store_id\",\n",
"]\n",
"\n",
"cat_features = [\n",
" \"id\",\n",
" \"item_id\",\n",
" \"dept_id\",\n",
" \"cat_id\",\n",
" \"store_id\",\n",
" \"snap_TX\",\n",
" \"event_name_1\",\n",
" \"event_type_1\",\n",
" \"event_name_2\",\n",
" \"event_type_2\",\n",
" \"Month\",\n",
" \"Day\",\n",
" \"Dayofweek\",\n",
" \"Is_month_end\",\n",
" \"Is_month_start\",\n",
" \"is_holiday\",\n",
" \"is_big_holiday\",\n",
" \"snap_TX_bw\",\n",
" \"snap_TX_fw\",\n",
"]\n",
"\n",
"features = cont_features + cat_features\n",
"target_attr = \"sales\"\n",
"\n",
"print(f\"Training Features:\\n{features}\\n\")\n",
"print(len(train_df) + len(test_df))\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10298579\n",
"CPU times: user 21.9 s, sys: 34 s, total: 55.8 s\n",
"Wall time: 17.4 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
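"# build the dataset, then build/load the fold-1 TabularPandas object `to` (build_fold_to may reuse a pickled copy under data_path -- TODO confirm)\n",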
"lag_features = [1, 2, 3, 7, 14, 21, 30, 90, 365]\n",
"seasonal_lag_features = [1, 2, 4, 12, 26, 52]\n",
"\n",
"rolling_features = {\"mean\": [7, 14], \"std\": [7, 14]}\n",
"seasonal_rolling_features = {\"mean\": [1, 2, 4, 12, 26, 52], \"std\": [2, 4, 12, 26, 52]} # no std for a window of 1: it comes out all NaNs, presumably because the sample std (ddof=1) of a single value is undefined\n",
"\n",
"df = build_dataset(\n",
" data_path,\n",
" level=\"id\",\n",
" forecast_horizon=28,\n",
" lag_features=lag_features,\n",
" seasonal_lag_features=seasonal_lag_features,\n",
" rolling_features=rolling_features,\n",
" seasonal_rolling_features=seasonal_rolling_features,\n",
")\n",
"\n",
"print(len(df))\n",
"df.head()\n",
"\n",
"to = build_fold_to(df, cat_features, cont_features, target_attr, forecast_horizon=forecast_horizon, fold= 1, data_fpath=data_path)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
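"# our core evaluation metric: RMSSE (root mean squared scaled error)\n",
"# per id: sqrt(mean squared validation error / mean squared difference between consecutive training sales); returns (per-id score frame, mean of the per-id scores)\n",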
"def rmsse(train_df, val_df, y_pred):\n",
" train_scale_df = (\n",
" train_df.assign(\n",
" scale=train_df.groupby(\"id\").sales.diff() ** 2,\n",
" )\n",
" .groupby(\"id\")\n",
" .scale.mean()\n",
" )\n",
"\n",
" score_df = (\n",
" val_df.assign(squared_error=(val_df.sales - y_pred) ** 2)\n",
" .groupby(\"id\")\n",
" .squared_error.mean()\n",
" .to_frame()\n",
" .merge(train_scale_df, on=\"id\")\n",
" .assign(rmsse=lambda x: np.sqrt(x.squared_error / x.scale))\n",
" )\n",
" mean_score = score_df.rmsse.mean()\n",
"\n",
" return score_df, mean_score\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random Forest\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.2 s, sys: 24.2 s, total: 37.4 s\n",
"Wall time: 43.4 s\n"
]
},
{
"data": {
"text/plain": [
"((10042491, 95), (256088, 95), (10042491,), (256088,))"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"X_train, y_train = to.train.xs, to.train.ys.values.ravel()\n",
"X_test, y_test = to.valid.xs, to.valid.ys.values.ravel()\n",
"\n",
"train_df = X_train.copy()\n",
"train_df[target_attr] = y_train\n",
"\n",
"val_df = X_test.copy()\n",
"val_df[target_attr] = y_test\n",
"\n",
"X_train.shape, X_test.shape, y_train.shape, y_test.shape\n"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2min 22s, sys: 11.2 s, total: 2min 33s\n",
"Wall time: 27.4 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"rf = RandomForestRegressor(n_estimators=40, max_samples=50_000, max_features=0.5, min_samples_leaf=100, n_jobs=-1, random_state=1)\n",
"\n",
"res = rf.fit(X_train, y_train)\n",
"preds = rf.predict(X_test)\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSSE: 0.7558756523129304\n"
]
}
],
"source": [
"score_df, rmsse_val = rmsse(train_df, val_df, preds)\n",
"\n",
"print(f\"RMSSE: {rmsse_val}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cols</th>\n",
" <th>imp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>72</th>\n",
" <td>seasonal_rolling_mean_26_28</td>\n",
" <td>0.266666</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>seasonal_rolling_mean_12_28</td>\n",
" <td>0.250787</td>\n",
" </tr>\n",
" <tr>\n",
" <th>73</th>\n",
" <td>seasonal_rolling_mean_52_28</td>\n",
" <td>0.175208</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65</th>\n",
" <td>rolling_mean_14_28</td>\n",
" <td>0.121139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70</th>\n",
" <td>seasonal_rolling_mean_4_28</td>\n",
" <td>0.084250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>64</th>\n",
" <td>rolling_mean_7_28</td>\n",
" <td>0.047834</td>\n",
" </tr>\n",
" <tr>\n",
" <th>78</th>\n",
" <td>seasonal_rolling_std_52_28</td>\n",
" <td>0.014380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67</th>\n",
" <td>rolling_std_14_28</td>\n",
" <td>0.007560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52</th>\n",
" <td>lag_7_28</td>\n",
" <td>0.005341</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>seasonal_rolling_mean_2_28</td>\n",
" <td>0.004858</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cols imp\n",
"72 seasonal_rolling_mean_26_28 0.266666\n",
"71 seasonal_rolling_mean_12_28 0.250787\n",
"73 seasonal_rolling_mean_52_28 0.175208\n",
"65 rolling_mean_14_28 0.121139\n",
"70 seasonal_rolling_mean_4_28 0.084250\n",
"64 rolling_mean_7_28 0.047834\n",
"78 seasonal_rolling_std_52_28 0.014380\n",
"67 rolling_std_14_28 0.007560\n",
"52 lag_7_28 0.005341\n",
"69 seasonal_rolling_mean_2_28 0.004858"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feat_importance_df = pd.DataFrame({\"cols\": X_train.columns.tolist(), \"imp\": rf.feature_importances_}).sort_values(\"imp\", ascending=False)\n",
"\n",
"feat_importance_df.head(10)\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: ylabel='cols'>"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x700 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"feat_importance_df.head(20).plot(\"cols\", \"imp\", \"barh\", figsize=(12, 7), legend=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10042491, 10), (256088, 10))"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"col_keep = list(feat_importance_df[feat_importance_df.imp > 0.005].cols)\n",
"col_keep = list(set(col_keep + [\"id\"]))\n",
"\n",
"X_train = X_train[col_keep]\n",
"X_test = X_test[col_keep]\n",
"\n",
"X_train.shape, X_test.shape\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### fast.ai\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"to.xs.iloc[:2]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = to.dataloaders(bs=64)\n",
"dls.show_batch()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" learn = tabular_learner(dls, metrics=[mse, rmse, mae, msle, exp_rmspe])\n",
" learn.fit(1, lr=3e-5)\n",
" learn.show_results()\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### lightgbm\n",
"\n",
"Metrics: https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 11.2 s, sys: 18.1 s, total: 29.2 s\n",
"Wall time: 35.6 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"X_train, y_train = to.train.xs, to.train.ys.values.ravel()\n",
"X_test, y_test = to.valid.xs, to.valid.ys.values.ravel()\n",
"\n",
"train_df = X_train.copy()\n",
"train_df[target_attr] = y_train\n",
"\n",
"val_df = X_test.copy()\n",
"val_df[target_attr] = y_test\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.889496 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 7157\n",
"[LightGBM] [Info] Number of data points in the train set: 10042491, number of used features: 94\n",
"[LightGBM] [Info] Start training from score 0.224990\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5331\tvalid_0's tweedie: 12.9396\n",
"[100]\ttraining's tweedie: 12.5095\tvalid_0's tweedie: 12.93\n",
"[150]\ttraining's tweedie: 12.5016\tvalid_0's tweedie: 12.9292\n",
"[200]\ttraining's tweedie: 12.4969\tvalid_0's tweedie: 12.9288\n",
"[250]\ttraining's tweedie: 12.4933\tvalid_0's tweedie: 12.9288\n",
"[300]\ttraining's tweedie: 12.4901\tvalid_0's tweedie: 12.9284\n",
"[350]\ttraining's tweedie: 12.4873\tvalid_0's tweedie: 12.9282\n",
"[400]\ttraining's tweedie: 12.4847\tvalid_0's tweedie: 12.9282\n",
"[450]\ttraining's tweedie: 12.4825\tvalid_0's tweedie: 12.9283\n",
"[500]\ttraining's tweedie: 12.4802\tvalid_0's tweedie: 12.9282\n",
"Early stopping, best iteration is:\n",
"[445]\ttraining's tweedie: 12.4828\tvalid_0's tweedie: 12.9281\n",
"CPU times: user 1h 5min 3s, sys: 1min 35s, total: 1h 6min 38s\n",
"Wall time: 4min 59s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# define hyperparameters\n",
"params = dict(\n",
" objective=\"tweedie\",\n",
" tweedie_variance_power=1.1,\n",
" # metric={'tweedie','rmse', 'mape'},\n",
" learning_rate=0.05, # 0.05 ... 1e-1, 1e-2, 1e-3 (default=0.1)\n",
" min_samples_leaf=150,\n",
" feature_fraction=0.3, # 0.3 ... 0.5, 0.8 (default = 1.0)\n",
" subsample=0.3, # 0.3 ... 0.5, 0.8 (default = 1.0; alias for 'bagging_fraction'; NOTE: only takes effect when bagging_freq > 0)\n",
" deterministic=True,\n",
")\n",
"\n",
"train_dset = lgbm.Dataset(X_train, y_train)\n",
"val_dset = lgbm.Dataset(X_test, y_test, reference=train_dset)\n",
"\n",
"eval_results = {}\n",
"callbacks = [lgbm.early_stopping(100), lgbm.log_evaluation(50), lgbm.record_evaluation(eval_results)]\n",
"\n",
"model = lgbm.train(\n",
" params,\n",
" train_dset,\n",
" valid_sets=[val_dset, train_dset],\n",
" num_boost_round=1000,\n",
" callbacks=callbacks,\n",
")\n",
"\n",
"preds = model.predict(X_test)\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSSE: 0.75242902668365\n"
]
}
],
"source": [
"score_df, rmsse_mean = rmsse(train_df, val_df, preds)\n",
"print(f\"RMSSE: {rmsse_mean}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# visualizing in a plot\n",
"x_ax = range(len(y_test))\n",
"plt.figure(figsize=(12, 6))\n",
"plt.plot(x_ax, y_test, label=\"original\")\n",
"plt.plot(x_ax, preds, label=\"predicted\")\n",
"plt.title(\"Actual vs. predicted sales (validation set)\")\n",
"plt.xlabel(\"Validation row\")\n",
"plt.ylabel(\"Sales\")\n",
"plt.legend(loc=\"best\", fancybox=True, shadow=True)\n",
"plt.grid(True)\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: title={'center': 'Metric during training'}, xlabel='Iterations', ylabel='tweedie'>"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lgbm.plot_metric(eval_results, metric=\"tweedie\")\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: title={'center': 'Feature importance'}, xlabel='Feature importance', ylabel='Features'>"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lgbm.plot_importance(model, max_num_features=20)\n"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('seasonal_rolling_mean_12_28', 0.1939),\n",
" ('rolling_mean_14_28', 0.1933),\n",
" ('seasonal_rolling_mean_52_28', 0.1596),\n",
" ('seasonal_rolling_mean_26_28', 0.1525),\n",
" ('seasonal_rolling_mean_4_28', 0.0635),\n",
" ('seasonal_rolling_std_52_28', 0.0568),\n",
" ('rolling_std_14_28', 0.0484),\n",
" ('rolling_mean_7_28', 0.0336),\n",
" ('seasonal_rolling_std_26_28', 0.0126),\n",
" ('rolling_std_7_28', 0.0076),\n",
" ('seasonal_rolling_std_12_28', 0.0067),\n",
" ('lag_1_28', 0.0059),\n",
" ('median_price_item_id', 0.0051),\n",
" ('median_price_id', 0.0042),\n",
" ('seasonal_rolling_mean_1_28', 0.0038),\n",
" ('max_price_id', 0.0038),\n",
" ('Afteris_holiday', 0.0033),\n",
" ('Month', 0.0032),\n",
" ('max_price_item_id', 0.003),\n",
" ('item_id', 0.0026),\n",
" ('max_price_dept_id', 0.0026),\n",
" ('id', 0.0022),\n",
" ('Beforeis_holiday', 0.0021),\n",
" ('median_price_dept_id', 0.0019),\n",
" ('Dayofweek', 0.0016),\n",
" ('rolling_mean_7_28_na', 0.0015),\n",
" ('lag_90_28', 0.0015),\n",
" ('seasonal_lag_2_28', 0.0012),\n",
" ('lag_30_28', 0.0011),\n",
" ('rolling_mean_14_28_na', 0.001),\n",
" ('seasonal_lag_26_28', 0.001),\n",
" ('seasonal_lag_4_28', 0.0009),\n",
" ('seasonal_lag_52_28', 0.0009),\n",
" ('lag_2_28', 0.0008),\n",
" ('lag_365_28', 0.0008),\n",
" ('seasonal_rolling_mean_2_28', 0.0008),\n",
" ('Afteris_big_holiday', 0.0008),\n",
" ('cat_id', 0.0007),\n",
" ('event_name_1', 0.0007),\n",
" ('Beforeis_big_holiday', 0.0007),\n",
" ('lag_1_28_na', 0.0006),\n",
" ('lag_14_28_na', 0.0006),\n",
" ('lag_21_28_na', 0.0006),\n",
" ('lag_3_28', 0.0006),\n",
" ('dept_id', 0.0005),\n",
" ('Day', 0.0005),\n",
" ('lag_14_28', 0.0005),\n",
" ('seasonal_lag_12_28', 0.0005),\n",
" ('Aftersnap_TX', 0.0005),\n",
" ('snap_TX', 0.0004),\n",
" ('lag_2_28_na', 0.0004),\n",
" ('lag_3_28_na', 0.0004),\n",
" ('seasonal_rolling_mean_2_28_na', 0.0004),\n",
" ('lag_7_28', 0.0004),\n",
" ('lag_21_28', 0.0004),\n",
" ('lag_30_28_na', 0.0003),\n",
" ('seasonal_lag_1_28', 0.0003),\n",
" ('seasonal_rolling_std_4_28', 0.0003),\n",
" ('Beforesnap_TX', 0.0003),\n",
" ('snap_TX_fw', 0.0002),\n",
" ('lag_7_28_na', 0.0002),\n",
" ('lag_90_28_na', 0.0002),\n",
" ('lag_365_28_na', 0.0002),\n",
" ('seasonal_lag_1_28_na', 0.0002),\n",
" ('seasonal_lag_12_28_na', 0.0002),\n",
" ('rolling_std_14_28_na', 0.0002),\n",
" ('seasonal_rolling_mean_12_28_na', 0.0002),\n",
" ('max_price_cat_id', 0.0002),\n",
" ('max_price_store_id', 0.0002),\n",
" ('store_id', 0.0001),\n",
" ('event_type_1', 0.0001),\n",
" ('is_big_holiday', 0.0001),\n",
" ('snap_TX_bw', 0.0001),\n",
" ('seasonal_lag_2_28_na', 0.0001),\n",
" ('seasonal_lag_4_28_na', 0.0001),\n",
" ('seasonal_lag_26_28_na', 0.0001),\n",
" ('seasonal_lag_52_28_na', 0.0001),\n",
" ('rolling_std_7_28_na', 0.0001),\n",
" ('seasonal_rolling_mean_1_28_na', 0.0001),\n",
" ('seasonal_rolling_std_2_28', 0.0001),\n",
" ('median_price_cat_id', 0.0001),\n",
" ('event_name_2', 0.0),\n",
" ('event_type_2', 0.0),\n",
" ('Is_month_end', 0.0),\n",
" ('Is_month_start', 0.0),\n",
" ('is_holiday', 0.0),\n",
" ('seasonal_rolling_mean_4_28_na', 0.0),\n",
" ('seasonal_rolling_mean_26_28_na', 0.0),\n",
" ('seasonal_rolling_mean_52_28_na', 0.0),\n",
" ('seasonal_rolling_std_2_28_na', 0.0),\n",
" ('seasonal_rolling_std_4_28_na', 0.0),\n",
" ('seasonal_rolling_std_12_28_na', 0.0),\n",
" ('seasonal_rolling_std_26_28_na', 0.0),\n",
" ('seasonal_rolling_std_52_28_na', 0.0),\n",
" ('median_price_store_id', 0.0)]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"importance = model.feature_importance(importance_type='gain')\n",
"importance_normalized = np.round(importance / np.sum(importance), 4)\n",
"\n",
"fe_imps = {fe_name: fe_imp for fe_name, fe_imp in zip(model.feature_name(), importance_normalized)}\n",
" \n",
"sorted(fe_imps.items(), key=lambda x: x[1], reverse=True)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10042491, 14), (256088, 14))"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"col_keep = list(set([col for col, imp in fe_imps.items() if imp > 0.005] + [\"id\"]))\n",
"X_train = X_train[col_keep]\n",
"X_test = X_test[col_keep]\n",
"\n",
"train_df = X_train.copy()\n",
"train_df[target_attr] = y_train\n",
"\n",
"val_df = X_test.copy()\n",
"val_df[target_attr] = y_test\n",
"\n",
"X_train.shape, X_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Time Series Cross-Validation (TSCV)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"n_folds = 3\n",
"forecast_horizon = 28\n",
"level = \"id\"\n",
"\n",
"data_path = \"./data\"\n",
"models_path = \"./models\"\n",
"results_path = \"./results\""
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10298579\n",
"CPU times: user 20.8 s, sys: 17.3 s, total: 38.2 s\n",
"Wall time: 10.7 s\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>sell_price</th>\n",
" <th>snap_TX</th>\n",
" <th>...</th>\n",
" <th>seasonal_rolling_mean_2_28</th>\n",
" <th>seasonal_rolling_mean_4_28</th>\n",
" <th>seasonal_rolling_mean_12_28</th>\n",
" <th>seasonal_rolling_mean_26_28</th>\n",
" <th>seasonal_rolling_mean_52_28</th>\n",
" <th>seasonal_rolling_std_2_28</th>\n",
" <th>seasonal_rolling_std_4_28</th>\n",
" <th>seasonal_rolling_std_12_28</th>\n",
" <th>seasonal_rolling_std_26_28</th>\n",
" <th>seasonal_rolling_std_52_28</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>FOODS_1_004_TX_1_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_1</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>FOODS_1_004_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>FOODS_1_004_TX_3_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>4</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_3</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>FOODS_1_005_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>1</td>\n",
" <td>FOODS_1_005</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>3.28</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>FOODS_1_009_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>3</td>\n",
" <td>FOODS_1_009</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>2.68</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 71 columns</p>\n",
"</div>"
],
"text/plain": [
" id date state_id sales item_id \\\n",
"0 FOODS_1_004_TX_1_evaluation 2013-01-01 TX 20 FOODS_1_004 \n",
"1 FOODS_1_004_TX_2_evaluation 2013-01-01 TX 20 FOODS_1_004 \n",
"2 FOODS_1_004_TX_3_evaluation 2013-01-01 TX 4 FOODS_1_004 \n",
"3 FOODS_1_005_TX_2_evaluation 2013-01-01 TX 1 FOODS_1_005 \n",
"4 FOODS_1_009_TX_2_evaluation 2013-01-01 TX 3 FOODS_1_009 \n",
"\n",
" dept_id cat_id store_id sell_price snap_TX ... \\\n",
"0 FOODS_1 FOODS TX_1 1.78 True ... \n",
"1 FOODS_1 FOODS TX_2 1.78 True ... \n",
"2 FOODS_1 FOODS TX_3 1.78 True ... \n",
"3 FOODS_1 FOODS TX_2 3.28 True ... \n",
"4 FOODS_1 FOODS TX_2 2.68 True ... \n",
"\n",
" seasonal_rolling_mean_2_28 seasonal_rolling_mean_4_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_12_28 seasonal_rolling_mean_26_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_52_28 seasonal_rolling_std_2_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_4_28 seasonal_rolling_std_12_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_26_28 seasonal_rolling_std_52_28 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
"[5 rows x 71 columns]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"lag_features = [1, 2, 3, 7, 14, 21, 30, 90, 365]\n",
"seasonal_lag_features = [1, 2, 4, 12, 26, 52]\n",
"\n",
"rolling_features = {\"mean\": [7, 14], \"std\": [7, 14]}\n",
"seasonal_rolling_features = {\"mean\": [1, 2, 4, 12, 26, 52], \"std\": [2, 4, 12, 26, 52]} # a std for 1 results in all NaNs (why???)\n",
"\n",
"df = build_dataset(\n",
" data_path,\n",
" level=\"id\",\n",
" forecast_horizon=28,\n",
" lag_features=lag_features,\n",
" seasonal_lag_features=seasonal_lag_features,\n",
" rolling_features=rolling_features,\n",
" seasonal_rolling_features=seasonal_rolling_features,\n",
")\n",
"\n",
"print(len(df))\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# save the fold train/validation splits to make calculating RMSSE a bit faster and also if further EDA needed\n",
"max_yyyymmdd = df[\"date\"].max().strftime(\"%Y%m%d\")\n",
"\n",
"for i in range(n_folds):\n",
" fold = i + 1\n",
" train_df, val_df = build_train_test_splits(df, forecast_horizon=forecast_horizon, fold=fold)\n",
" \n",
" train_df.to_parquet(Path(f\"{data_path}/{max_yyyymmdd}_fold_{fold}_train.parquet\"))\n",
" val_df.to_parquet(Path(f\"{data_path}/{max_yyyymmdd}_fold_{fold}_val.parquet\"))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['id', 'date', 'state_id', 'sales', 'item_id', 'dept_id', 'cat_id', 'store_id', 'sell_price', 'snap_TX', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', 'is_holiday', 'is_big_holiday', 'Aftersnap_TX', 'Afteris_holiday', 'Afteris_big_holiday', 'Beforesnap_TX', 'Beforeis_holiday', 'Beforeis_big_holiday', 'snap_TX_bw', 'is_holiday_bw', 'is_big_holiday_bw', 'snap_TX_fw', 'is_holiday_fw', 'is_big_holiday_fw', 'Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start', 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start', 'Elapsed', 'lag_1_28', 'lag_2_28', 'lag_3_28', 'lag_7_28', 'lag_14_28', 'lag_21_28', 'lag_30_28', 'lag_90_28', 'lag_365_28', 'seasonal_lag_1_28', 'seasonal_lag_2_28', 'seasonal_lag_4_28', 'seasonal_lag_12_28', 'seasonal_lag_26_28', 'seasonal_lag_52_28', 'rolling_mean_7_28', 'rolling_mean_14_28', 'rolling_std_7_28', 'rolling_std_14_28', 'seasonal_rolling_mean_1_28', 'seasonal_rolling_mean_2_28', 'seasonal_rolling_mean_4_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28', 'seasonal_rolling_std_2_28', 'seasonal_rolling_std_4_28', 'seasonal_rolling_std_12_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_std_52_28']\n"
]
}
],
"source": [
"print(df.columns.tolist())"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"cont_features = [\n",
" \"lag_1_28\",\n",
" \"lag_2_28\",\n",
" \"lag_3_28\",\n",
" \"lag_7_28\",\n",
" \"lag_14_28\",\n",
" \"lag_21_28\",\n",
" \"lag_30_28\",\n",
" \"lag_90_28\",\n",
" \"lag_365_28\",\n",
" \"seasonal_lag_1_28\",\n",
" \"seasonal_lag_2_28\",\n",
" \"seasonal_lag_4_28\",\n",
" \"seasonal_lag_12_28\",\n",
" \"seasonal_lag_26_28\",\n",
" \"seasonal_lag_52_28\",\n",
" \"rolling_mean_7_28\",\n",
" \"rolling_mean_14_28\",\n",
" \"rolling_std_7_28\",\n",
" \"rolling_std_14_28\",\n",
" \"seasonal_rolling_mean_1_28\",\n",
" \"seasonal_rolling_mean_2_28\",\n",
" \"seasonal_rolling_mean_4_28\",\n",
" \"seasonal_rolling_mean_12_28\",\n",
" \"seasonal_rolling_mean_26_28\",\n",
" \"seasonal_rolling_mean_52_28\",\n",
" # \"seasonal_rolling_std_1_28\",\n",
" \"seasonal_rolling_std_2_28\",\n",
" \"seasonal_rolling_std_4_28\",\n",
" \"seasonal_rolling_std_12_28\",\n",
" \"seasonal_rolling_std_26_28\",\n",
" \"seasonal_rolling_std_52_28\",\n",
" \"Aftersnap_TX\",\n",
" \"Afteris_holiday\",\n",
" \"Afteris_big_holiday\",\n",
" \"Beforesnap_TX\",\n",
" \"Beforeis_holiday\",\n",
" \"Beforeis_big_holiday\",\n",
" \"max_price_id\",\n",
" \"median_price_id\",\n",
" \"max_price_item_id\",\n",
" \"median_price_item_id\",\n",
" \"max_price_dept_id\",\n",
" \"median_price_dept_id\",\n",
" \"max_price_cat_id\",\n",
" \"median_price_cat_id\",\n",
" \"max_price_store_id\",\n",
" \"median_price_store_id\",\n",
"]\n",
"\n",
"cat_features = [\n",
" \"id\",\n",
" \"item_id\",\n",
" \"dept_id\",\n",
" \"cat_id\",\n",
" \"store_id\",\n",
" \"snap_TX\",\n",
" \"event_name_1\",\n",
" \"event_type_1\",\n",
" \"event_name_2\",\n",
" \"event_type_2\",\n",
" \"Month\",\n",
" \"Day\",\n",
" \"Dayofweek\",\n",
" \"Is_month_end\",\n",
" \"Is_month_start\",\n",
" \"is_holiday\",\n",
" \"is_big_holiday\",\n",
" \"snap_TX_bw\",\n",
" \"snap_TX_fw\",\n",
"]\n",
"\n",
"features = cont_features + cat_features\n",
"target_attr = \"sales\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random Forest"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEGIN TRAINING :: RANDOM FOREST\n",
"=== Training Started: Fold 1 ===\n",
"Number of Train/Validation Examples: 10042491 | 256088\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-04-24 00:00:00\n",
"Validation Date Range: 2016-04-25 00:00:00 | 2016-05-22 00:00:00\n",
"RMSSE: 0.7557104300927474\n",
"RMSSE: 0.7546597446177328\n",
"=== Training Finished: Fold 1 ===\n",
"Final Features: ['seasonal_rolling_mean_2_28', 'rolling_mean_14_28', 'id', 'seasonal_rolling_mean_4_28', 'rolling_mean_7_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.7546597446177328\n",
"=== Training Started: Fold 2 ===\n",
"Number of Train/Validation Examples: 9786427 | 256064\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-03-27 00:00:00\n",
"Validation Date Range: 2016-03-28 00:00:00 | 2016-04-24 00:00:00\n",
"RMSSE: 0.7338695879512891\n",
"RMSSE: 0.731481537792378\n",
"=== Training Finished: Fold 2 ===\n",
"Final Features: ['seasonal_rolling_mean_2_28', 'rolling_mean_14_28', 'id', 'seasonal_rolling_mean_4_28', 'rolling_mean_7_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.731481537792378\n",
"=== Training Started: Fold 3 ===\n",
"Number of Train/Validation Examples: 9530367 | 256060\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-02-28 00:00:00\n",
"Validation Date Range: 2016-02-29 00:00:00 | 2016-03-27 00:00:00\n",
"RMSSE: 0.7429255771867385\n",
"RMSSE: 0.741228056471335\n",
"=== Training Finished: Fold 3 ===\n",
"Final Features: ['rolling_mean_14_28', 'id', 'seasonal_rolling_mean_4_28', 'rolling_mean_7_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.741228056471335\n",
"CPU times: user 10min 1s, sys: 3min 13s, total: 13min 14s\n",
"Wall time: 6min 3s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"print(\"BEGIN TRAINING :: RANDOM FOREST\")\n",
"for fold_num in range(n_folds):\n",
" fold = fold_num + 1\n",
" print(f\"=== Training Started: Fold {fold} ===\")\n",
" \n",
" to = build_fold_to(df, cat_features, cont_features, target_attr, forecast_horizon=forecast_horizon, fold=fold, data_fpath=data_path)\n",
" \n",
" X_train, y_train = to.train.xs, to.train.ys.values.ravel()\n",
" X_test, y_test = to.valid.xs, to.valid.ys.values.ravel()\n",
"\n",
" train_df = to.train.items\n",
" train_df[target_attr] = y_train\n",
"\n",
" val_df = to.valid.items\n",
" val_df[target_attr] = y_test\n",
" \n",
" print(f\"Number of Train/Validation Examples: {len(X_train)} | {len(X_test)}\")\n",
" print(f\"Train Date Range: {train_df.date.min()} | {train_df.date.max()}\")\n",
" print(f\"Validation Date Range: {val_df.date.min()} | {val_df.date.max()}\")\n",
" \n",
" col_keep = []\n",
" rmsse_val = None\n",
" for run_num in range(2):\n",
" rf = RandomForestRegressor(n_estimators=40, max_samples=50_000, max_features=0.5, min_samples_leaf=100, n_jobs=-1, random_state=9)\n",
" res = rf.fit(X_train, y_train)\n",
" preds = rf.predict(X_test)\n",
" \n",
" score_df, rmsse_val = rmsse(train_df, val_df, preds)\n",
" print(f\"RMSSE: {rmsse_val}\")\n",
" \n",
" if run_num == 0:\n",
" # retrain on most important features\n",
" feat_importance_df = pd.DataFrame({\"cols\": X_train.columns.tolist(), \"imp\": rf.feature_importances_}).sort_values(\"imp\", ascending=False)\n",
" \n",
" col_keep = list(feat_importance_df[feat_importance_df.imp > 0.005].cols)\n",
" col_keep = list(set(col_keep + [\"id\"]))\n",
"\n",
" X_train = X_train[col_keep]\n",
" X_test = X_test[col_keep]\n",
" \n",
" # save model and features\n",
" joblib.dump(rf, Path(models_path)/f\"rf_fold_{fold}.pkl\")\n",
" \n",
" with open(Path(models_path)/f\"rf_fold_{fold}_features.json\", 'w') as json_file:\n",
" json.dump(col_keep, json_file)\n",
" \n",
" # save validation results\n",
" for cat_name in to.cat_names:\n",
" mapping_dict = {idx: value for idx, value in enumerate(to.classes[cat_name])}\n",
" val_df[f\"orig_{cat_name}\"] = val_df[cat_name].map(mapping_dict)\n",
" \n",
" val_df[target_attr] = y_test\n",
" val_df[f\"{target_attr}_preds\"] = preds\n",
" val_df = val_df.merge(score_df.reset_index()[[\"id\", \"rmsse\"]], on=\"id\")\n",
" val_df = val_df.rename(columns={\"rmsse\": \"rmsse_id\"})\n",
" val_df[f\"rmsse_overall\"] = rmsse_val\n",
" \n",
" val_df = (val_df.drop(columns=\"id\").rename(columns={\"orig_id\": \"id\"}))\n",
" val_df[[\"id\", \"date\", target_attr, f\"{target_attr}_preds\", \"rmsse_id\", \"rmsse_overall\"]].to_parquet(Path(results_path)/f\"{max_yyyymmdd}_fold_{fold}_rf_results.parquet\")\n",
" \n",
" print(f\"=== Training Finished: Fold {fold} ===\")\n",
" print(\"Final Features: \", col_keep)\n",
" print(\"Final RMSSE: \", rmsse_val)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# pd.read_parquet(\"./results/20160522_fold_3_rf_results.parquet\").tail(20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### lightgbm"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEGIN TRAINING :: lightGBM\n",
"=== Training Started: Fold 1 ===\n",
"Number of Train/Validation Examples: 10042491 | 256088\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-04-24 00:00:00\n",
"Validation Date Range: 2016-04-25 00:00:00 | 2016-05-22 00:00:00\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 1.023046 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 7157\n",
"[LightGBM] [Info] Number of data points in the train set: 10042491, number of used features: 94\n",
"[LightGBM] [Info] Start training from score 0.224990\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5331\tvalid_0's tweedie: 12.9396\n",
"[100]\ttraining's tweedie: 12.5095\tvalid_0's tweedie: 12.93\n",
"[150]\ttraining's tweedie: 12.5016\tvalid_0's tweedie: 12.9292\n",
"[200]\ttraining's tweedie: 12.4969\tvalid_0's tweedie: 12.9288\n",
"[250]\ttraining's tweedie: 12.4933\tvalid_0's tweedie: 12.9288\n",
"[300]\ttraining's tweedie: 12.4901\tvalid_0's tweedie: 12.9284\n",
"[350]\ttraining's tweedie: 12.4873\tvalid_0's tweedie: 12.9282\n",
"[400]\ttraining's tweedie: 12.4847\tvalid_0's tweedie: 12.9282\n",
"[450]\ttraining's tweedie: 12.4825\tvalid_0's tweedie: 12.9283\n",
"[500]\ttraining's tweedie: 12.4802\tvalid_0's tweedie: 12.9282\n",
"Early stopping, best iteration is:\n",
"[445]\ttraining's tweedie: 12.4828\tvalid_0's tweedie: 12.9281\n",
"RMSSE: 0.75242902668365\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.026562 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 3398\n",
"[LightGBM] [Info] Number of data points in the train set: 10042491, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 0.224990\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5427\tvalid_0's tweedie: 12.9467\n",
"[100]\ttraining's tweedie: 12.5233\tvalid_0's tweedie: 12.9381\n",
"[150]\ttraining's tweedie: 12.5184\tvalid_0's tweedie: 12.9363\n",
"[200]\ttraining's tweedie: 12.5159\tvalid_0's tweedie: 12.9351\n",
"[250]\ttraining's tweedie: 12.5143\tvalid_0's tweedie: 12.9349\n",
"[300]\ttraining's tweedie: 12.5125\tvalid_0's tweedie: 12.9346\n",
"[350]\ttraining's tweedie: 12.5113\tvalid_0's tweedie: 12.935\n",
"[400]\ttraining's tweedie: 12.5102\tvalid_0's tweedie: 12.935\n",
"Early stopping, best iteration is:\n",
"[311]\ttraining's tweedie: 12.5122\tvalid_0's tweedie: 12.9346\n",
"RMSSE: 0.7537823122937269\n",
"=== Training Finished: Fold 1 ===\n",
"Final Features: ['rolling_std_14_28', 'rolling_mean_14_28', 'id', 'seasonal_rolling_std_12_28', 'seasonal_rolling_mean_4_28', 'median_price_item_id', 'rolling_mean_7_28', 'lag_1_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_std_52_28', 'rolling_std_7_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.7537823122937269\n",
"=== Training Started: Fold 2 ===\n",
"Number of Train/Validation Examples: 9786427 | 256064\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-03-27 00:00:00\n",
"Validation Date Range: 2016-03-28 00:00:00 | 2016-04-24 00:00:00\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 1.061464 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 7145\n",
"[LightGBM] [Info] Number of data points in the train set: 9786427, number of used features: 94\n",
"[LightGBM] [Info] Start training from score 0.225917\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5412\tvalid_0's tweedie: 12.218\n",
"[100]\ttraining's tweedie: 12.5175\tvalid_0's tweedie: 12.2064\n",
"[150]\ttraining's tweedie: 12.5096\tvalid_0's tweedie: 12.2056\n",
"[200]\ttraining's tweedie: 12.5048\tvalid_0's tweedie: 12.2048\n",
"[250]\ttraining's tweedie: 12.5009\tvalid_0's tweedie: 12.2036\n",
"[300]\ttraining's tweedie: 12.4975\tvalid_0's tweedie: 12.2027\n",
"[350]\ttraining's tweedie: 12.4947\tvalid_0's tweedie: 12.2023\n",
"[400]\ttraining's tweedie: 12.4923\tvalid_0's tweedie: 12.2017\n",
"[450]\ttraining's tweedie: 12.4896\tvalid_0's tweedie: 12.2015\n",
"[500]\ttraining's tweedie: 12.4874\tvalid_0's tweedie: 12.201\n",
"[550]\ttraining's tweedie: 12.4855\tvalid_0's tweedie: 12.2009\n",
"[600]\ttraining's tweedie: 12.4836\tvalid_0's tweedie: 12.2008\n",
"[650]\ttraining's tweedie: 12.4816\tvalid_0's tweedie: 12.2008\n",
"[700]\ttraining's tweedie: 12.48\tvalid_0's tweedie: 12.2004\n",
"[750]\ttraining's tweedie: 12.4783\tvalid_0's tweedie: 12.2001\n",
"[800]\ttraining's tweedie: 12.4768\tvalid_0's tweedie: 12.2001\n",
"[850]\ttraining's tweedie: 12.4754\tvalid_0's tweedie: 12.1998\n",
"[900]\ttraining's tweedie: 12.4741\tvalid_0's tweedie: 12.1996\n",
"[950]\ttraining's tweedie: 12.4726\tvalid_0's tweedie: 12.1992\n",
"[1000]\ttraining's tweedie: 12.4716\tvalid_0's tweedie: 12.1991\n",
"Did not meet early stopping. Best iteration is:\n",
"[1000]\ttraining's tweedie: 12.4716\tvalid_0's tweedie: 12.1991\n",
"RMSSE: 0.7304551937922373\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.026671 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 3642\n",
"[LightGBM] [Info] Number of data points in the train set: 9786427, number of used features: 15\n",
"[LightGBM] [Info] Start training from score 0.225917\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5458\tvalid_0's tweedie: 12.2178\n",
"[100]\ttraining's tweedie: 12.5274\tvalid_0's tweedie: 12.2072\n",
"[150]\ttraining's tweedie: 12.5229\tvalid_0's tweedie: 12.2061\n",
"[200]\ttraining's tweedie: 12.5204\tvalid_0's tweedie: 12.2059\n",
"[250]\ttraining's tweedie: 12.5184\tvalid_0's tweedie: 12.2055\n",
"[300]\ttraining's tweedie: 12.5167\tvalid_0's tweedie: 12.2052\n",
"[350]\ttraining's tweedie: 12.5153\tvalid_0's tweedie: 12.2049\n",
"[400]\ttraining's tweedie: 12.5141\tvalid_0's tweedie: 12.2048\n",
"[450]\ttraining's tweedie: 12.513\tvalid_0's tweedie: 12.2046\n",
"[500]\ttraining's tweedie: 12.5118\tvalid_0's tweedie: 12.2044\n",
"[550]\ttraining's tweedie: 12.5108\tvalid_0's tweedie: 12.2043\n",
"[600]\ttraining's tweedie: 12.5098\tvalid_0's tweedie: 12.2042\n",
"[650]\ttraining's tweedie: 12.509\tvalid_0's tweedie: 12.2041\n",
"[700]\ttraining's tweedie: 12.5082\tvalid_0's tweedie: 12.204\n",
"[750]\ttraining's tweedie: 12.5074\tvalid_0's tweedie: 12.204\n",
"[800]\ttraining's tweedie: 12.5067\tvalid_0's tweedie: 12.204\n",
"Early stopping, best iteration is:\n",
"[717]\ttraining's tweedie: 12.508\tvalid_0's tweedie: 12.204\n",
"RMSSE: 0.7291672416445893\n",
"=== Training Finished: Fold 2 ===\n",
"Final Features: ['rolling_std_14_28', 'rolling_mean_14_28', 'id', 'seasonal_rolling_std_12_28', 'seasonal_rolling_mean_4_28', 'median_price_item_id', 'rolling_mean_7_28', 'lag_1_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_std_52_28', 'median_price_id', 'rolling_std_7_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.7291672416445893\n",
"=== Training Started: Fold 3 ===\n",
"Number of Train/Validation Examples: 9530367 | 256060\n",
"Train Date Range: 2013-01-01 00:00:00 | 2016-02-28 00:00:00\n",
"Validation Date Range: 2016-02-29 00:00:00 | 2016-03-27 00:00:00\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 1.476439 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 7113\n",
"[LightGBM] [Info] Number of data points in the train set: 9530367, number of used features: 94\n",
"[LightGBM] [Info] Start training from score 0.226259\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5437\tvalid_0's tweedie: 12.4528\n",
"[100]\ttraining's tweedie: 12.5197\tvalid_0's tweedie: 12.4395\n",
"[150]\ttraining's tweedie: 12.5115\tvalid_0's tweedie: 12.4379\n",
"[200]\ttraining's tweedie: 12.5067\tvalid_0's tweedie: 12.4379\n",
"[250]\ttraining's tweedie: 12.5028\tvalid_0's tweedie: 12.4373\n",
"[300]\ttraining's tweedie: 12.4991\tvalid_0's tweedie: 12.4365\n",
"[350]\ttraining's tweedie: 12.4962\tvalid_0's tweedie: 12.4361\n",
"[400]\ttraining's tweedie: 12.4937\tvalid_0's tweedie: 12.4359\n",
"[450]\ttraining's tweedie: 12.4914\tvalid_0's tweedie: 12.436\n",
"[500]\ttraining's tweedie: 12.4893\tvalid_0's tweedie: 12.4359\n",
"[550]\ttraining's tweedie: 12.4873\tvalid_0's tweedie: 12.4356\n",
"[600]\ttraining's tweedie: 12.4854\tvalid_0's tweedie: 12.4356\n",
"[650]\ttraining's tweedie: 12.4836\tvalid_0's tweedie: 12.4355\n",
"[700]\ttraining's tweedie: 12.482\tvalid_0's tweedie: 12.4354\n",
"[750]\ttraining's tweedie: 12.4805\tvalid_0's tweedie: 12.4355\n",
"[800]\ttraining's tweedie: 12.479\tvalid_0's tweedie: 12.4356\n",
"Early stopping, best iteration is:\n",
"[721]\ttraining's tweedie: 12.4814\tvalid_0's tweedie: 12.4353\n",
"RMSSE: 0.740774215864624\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.039919 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 3400\n",
"[LightGBM] [Info] Number of data points in the train set: 9530367, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 0.226259\n",
"Training until validation scores don't improve for 100 rounds\n",
"[50]\ttraining's tweedie: 12.5537\tvalid_0's tweedie: 12.4556\n",
"[100]\ttraining's tweedie: 12.5339\tvalid_0's tweedie: 12.444\n",
"[150]\ttraining's tweedie: 12.5287\tvalid_0's tweedie: 12.4429\n",
"[200]\ttraining's tweedie: 12.5262\tvalid_0's tweedie: 12.4417\n",
"[250]\ttraining's tweedie: 12.5246\tvalid_0's tweedie: 12.4413\n",
"[300]\ttraining's tweedie: 12.5227\tvalid_0's tweedie: 12.4412\n",
"[350]\ttraining's tweedie: 12.5215\tvalid_0's tweedie: 12.4412\n",
"Early stopping, best iteration is:\n",
"[279]\ttraining's tweedie: 12.5236\tvalid_0's tweedie: 12.4409\n",
"RMSSE: 0.739529952210638\n",
"=== Training Finished: Fold 3 ===\n",
"Final Features: ['rolling_std_14_28', 'rolling_mean_14_28', 'id', 'seasonal_rolling_std_12_28', 'seasonal_rolling_mean_4_28', 'median_price_item_id', 'rolling_mean_7_28', 'lag_1_28', 'seasonal_rolling_mean_12_28', 'seasonal_rolling_std_52_28', 'rolling_std_7_28', 'seasonal_rolling_mean_26_28', 'seasonal_rolling_std_26_28', 'seasonal_rolling_mean_52_28']\n",
"Final RMSSE: 0.739529952210638\n",
"CPU times: user 7h 19min 18s, sys: 23min 53s, total: 7h 43min 12s\n",
"Wall time: 50min 48s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"print(\"BEGIN TRAINING :: lightGBM\")\n",
"for fold_num in range(n_folds):\n",
" fold = fold_num + 1\n",
" print(f\"=== Training Started: Fold {fold} ===\")\n",
" \n",
" to = build_fold_to(df, cat_features, cont_features, target_attr, forecast_horizon=forecast_horizon, fold=fold, data_fpath=data_path)\n",
" \n",
" X_train, y_train = to.train.xs, to.train.ys.values.ravel()\n",
" X_test, y_test = to.valid.xs, to.valid.ys.values.ravel()\n",
"\n",
" train_df = to.train.items\n",
" train_df[target_attr] = y_train\n",
"\n",
" val_df = to.valid.items\n",
" val_df[target_attr] = y_test\n",
" \n",
" print(f\"Number of Train/Validation Examples: {len(X_train)} | {len(X_test)}\")\n",
" print(f\"Train Date Range: {train_df.date.min()} | {train_df.date.max()}\")\n",
" print(f\"Validation Date Range: {val_df.date.min()} | {val_df.date.max()}\")\n",
" \n",
" col_keep = []\n",
" rmsse_val = None\n",
" for run_num in range(2):\n",
" # define hyperparameters\n",
" params = dict(\n",
" objective=\"tweedie\",\n",
" tweedie_variance_power=1.1,\n",
" # metric={'tweedie','rmse', 'mape'},\n",
" learning_rate=0.05, # 0.05 ... 1e-1, 1e-2, 1e-3 (default=0.1)\n",
" min_samples_leaf=150,\n",
" feature_fraction=0.3, # 0.3 ... 0.5, 0.8 (default = 1.0)\n",
" subsample=0.3, # 0.3 ... 0.5, 0.8 (default = 1.0; alias for 'subsample')\n",
" deterministic=True,\n",
" )\n",
"\n",
" train_dset = lgbm.Dataset(X_train, y_train)\n",
" val_dset = lgbm.Dataset(X_test, y_test, reference=train_dset)\n",
"\n",
" eval_results = {}\n",
" callbacks = [lgbm.early_stopping(100), lgbm.log_evaluation(50), lgbm.record_evaluation(eval_results)]\n",
"\n",
" model = lgbm.train(\n",
" params,\n",
" train_dset,\n",
" valid_sets=[val_dset, train_dset],\n",
" num_boost_round=1000,\n",
" callbacks=callbacks,\n",
" )\n",
"\n",
" preds = model.predict(X_test)\n",
" \n",
" score_df, rmsse_val = rmsse(train_df, val_df, preds)\n",
" print(f\"RMSSE: {rmsse_val}\")\n",
" \n",
" if run_num == 0:\n",
" # retrain on most important features\n",
" importance = model.feature_importance(importance_type='gain')\n",
" importance_normalized = np.round(importance / np.sum(importance), 4)\n",
"\n",
" fe_imps = {fe_name: fe_imp for fe_name, fe_imp in zip(model.feature_name(), importance_normalized)} \n",
" sorted(fe_imps.items(), key=lambda x: x[1], reverse=True)\n",
" \n",
" col_keep = list(set([col for col, imp in fe_imps.items() if imp > 0.005] + [\"id\"]))\n",
" col_keep = list(set(col_keep + [\"id\"]))\n",
"\n",
" X_train = X_train[col_keep]\n",
" X_test = X_test[col_keep]\n",
" \n",
" # save model and features\n",
" model.save_model(Path(models_path)/f\"lightgbm_fold_{fold}.txt\")\n",
" \n",
" with open(Path(models_path)/f\"lightgbm_fold_{fold}_features.json\", 'w') as json_file:\n",
" json.dump(col_keep, json_file)\n",
" \n",
" # save validation results\n",
" for cat_name in to.cat_names:\n",
" mapping_dict = {idx: value for idx, value in enumerate(to.classes[cat_name])}\n",
" val_df[f\"orig_{cat_name}\"] = val_df[cat_name].map(mapping_dict)\n",
" \n",
" val_df[target_attr] = y_test\n",
" val_df[f\"{target_attr}_preds\"] = preds\n",
" val_df = val_df.merge(score_df.reset_index()[[\"id\", \"rmsse\"]], on=\"id\")\n",
" val_df = val_df.rename(columns={\"rmsse\": \"rmsse_id\"})\n",
" val_df[f\"rmsse_overall\"] = rmsse_val\n",
" \n",
" val_df = (val_df.drop(columns=\"id\").rename(columns={\"orig_id\": \"id\"}))\n",
" val_df[[\"id\", \"date\", target_attr, f\"{target_attr}_preds\", \"rmsse_id\", \"rmsse_overall\"]].to_parquet(Path(results_path)/f\"{max_yyyymmdd}_fold_{fold}_lightgbm_results.parquet\")\n",
" \n",
" \n",
" print(f\"=== Training Finished: Fold {fold} ===\")\n",
" print(\"Final Features: \", col_keep)\n",
" print(\"Final RMSSE: \", rmsse_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Ensemble"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"models = [\"rf\", \"lightgbm\"]"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"def ensemble_metric(weights, train_df, val_df):\n",
" y_hat_avg = np.average(val_df.filter(regex='^model_').values, axis=1, weights=weights)\n",
" # assert y_hat_avg.ndim == 2, 'y_hat_avg has {y_hat_avg.ndim} dimensions, but it must be 2D. Did you calculate a weighted average over the first dimension?'\n",
" # assert y_hat_avg.shape == y.shape, 'y_hat_avg and y must have the same shape. y_hat_avg has shape {y_hat_avg.shape}, but y has shape {y.shape}'\n",
" \n",
" _, rmsse_score = rmsse(train_df, val_df, y_hat_avg)\n",
" return rmsse_score"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fold 1 (rf): RMSSE: 0.7546597446177328\n",
"Fold 1 (lightgbm): RMSSE: 0.7537823122937269\n",
"Fold 1 Inital Blend RMSSE: 0.753705\n",
"Fold 1 Optimised Blend RMSSE: 0.753604\n",
"Fold 1 Optimised Weights: [0.19699333 0.50000008]\n",
"----------------------------------------------------------------------\n",
"rf Optimised Weights: 0.196993\n",
"lightgbm Optimised Weights: 0.500000\n",
"Fold 1 Normalized weights:[0.28263298 0.71736702]\n",
"======================================================================\n",
"Fold 2 (rf): RMSSE: 0.731481537792378\n",
"Fold 2 (lightgbm): RMSSE: 0.7291672416445893\n",
"Fold 2 Inital Blend RMSSE: 0.729646\n",
"Fold 2 Optimised Blend RMSSE: 0.729149\n",
"Fold 2 Optimised Weights: [0.0453946 0.50000091]\n",
"----------------------------------------------------------------------\n",
"rf Optimised Weights: 0.045395\n",
"lightgbm Optimised Weights: 0.500001\n",
"Fold 2 Normalized weights:[0.08323243 0.91676757]\n",
"======================================================================\n",
"Fold 3 (rf): RMSSE: 0.741228056471335\n",
"Fold 3 (lightgbm): RMSSE: 0.739529952210638\n",
"Fold 3 Inital Blend RMSSE: 0.739831\n",
"Fold 3 Optimised Blend RMSSE: 0.739495\n",
"Fold 3 Optimised Weights: [0.06852739 0.4999997 ]\n",
"----------------------------------------------------------------------\n",
"rf Optimised Weights: 0.068527\n",
"lightgbm Optimised Weights: 0.500000\n",
"Fold 3 Normalized weights:[0.12053495 0.87946505]\n",
"======================================================================\n"
]
}
],
"source": [
"for i in range(n_folds):\n",
" fold = i+1\n",
" fold_model_results = {}\n",
" \n",
" train_df = pd.read_parquet(Path(data_path)/f\"{max_yyyymmdd}_fold_{fold}_train.parquet\")\n",
"\n",
" combined_res_df = None\n",
" for m_name in models:\n",
" res_df = pd.read_parquet(Path(results_path)/f\"{max_yyyymmdd}_fold_{fold}_{m_name}_results.parquet\")\n",
" res_df = res_df.sort_values([\"id\", \"date\"])\n",
" \n",
" # score_df, rmsse_score = rmsse(train_df, res_df, res_df[\"sales_preds\"].values)\n",
" print(f\"Fold {fold} ({m_name}): RMSSE: {res_df.iloc[0]['rmsse_overall']}\")\n",
" \n",
" if combined_res_df is None:\n",
" combined_res_df = res_df.copy()\n",
" combined_res_df[f\"model_{m_name}_preds\"] = res_df[\"sales_preds\"]\n",
" combined_res_df = combined_res_df.drop(columns=[\"sales_preds\"])\n",
" else:\n",
" combined_res_df[f\"model_{m_name}_preds\"] = res_df[\"sales_preds\"]\n",
" \n",
" ensemble_metric = partial(ensemble_metric, train_df=train_df, val_df=combined_res_df)\n",
"\n",
" # Our first guess is setting all weights equal to each other, such that they sum up to 1\n",
" init_guess = [1 / len(models)] * len(models) # here will be [0.5, 0.5]\n",
"\n",
" print(f'Fold {fold} Inital Blend RMSSE: {ensemble_metric(init_guess):.6f}')\n",
" \n",
" bnds = [(0, 1) for _ in range(len(models))] # Weights must be between 0 and 1\n",
"\n",
" res_scipy = minimize(\n",
" fun=ensemble_metric, \n",
" x0=init_guess, \n",
" method='Powell', \n",
" bounds=bnds, \n",
" options=dict(maxiter=1_000_000),\n",
" tol=1e-8\n",
" )\n",
"\n",
" print(f'Fold {fold} Optimised Blend RMSSE: {res_scipy.fun:.6f}')\n",
" print(f'Fold {fold} Optimised Weights: {res_scipy.x}')\n",
" print('-' * 70)\n",
"\n",
" oof_names = models\n",
" for n, key in enumerate(oof_names):\n",
" print(f'{key} Optimised Weights: {res_scipy.x[n]:.6f}')\n",
"\n",
" ws = [ res_scipy.x[i] for i in range(len(oof_names))]\n",
"\n",
" # normalize the weights so they sum to 1\n",
" weights = ws / np.sum(ws)\n",
" print(f'Fold {fold} Normalized weights:{weights}')\n",
"\n",
" print('=' * 70)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get latest dataset\n",
"\n",
"We simulate getting the latest dataset by simply grabbing \"enough\" of `sales_data` to calculate our lag, rolling, and grouped aggregations on the `target_attr` we want to predict `forecast_horizon` days out from the end.\n",
"\n",
"**Note:** It might be better to do such calculations in a SQL database, if possible, so that you can just grab the data you want to predict."
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" </tr>\n",
" <tr>\n",
" <th>date</th>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">2013-01-01</th>\n",
" <th>FOODS_1_004_TX_1_evaluation</th>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_1</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>FOODS_1_004_TX_2_evaluation</th>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>TX</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>FOODS_1_004_TX_3_evaluation</th>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_3</td>\n",
" <td>TX</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>FOODS_1_005_TX_2_evaluation</th>\n",
" <td>FOODS_1_005</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>TX</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>FOODS_1_009_TX_2_evaluation</th>\n",
" <td>FOODS_1_009</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>TX</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" item_id dept_id cat_id store_id \\\n",
"date id \n",
"2013-01-01 FOODS_1_004_TX_1_evaluation FOODS_1_004 FOODS_1 FOODS TX_1 \n",
" FOODS_1_004_TX_2_evaluation FOODS_1_004 FOODS_1 FOODS TX_2 \n",
" FOODS_1_004_TX_3_evaluation FOODS_1_004 FOODS_1 FOODS TX_3 \n",
" FOODS_1_005_TX_2_evaluation FOODS_1_005 FOODS_1 FOODS TX_2 \n",
" FOODS_1_009_TX_2_evaluation FOODS_1_009 FOODS_1 FOODS TX_2 \n",
"\n",
" state_id sales \n",
"date id \n",
"2013-01-01 FOODS_1_004_TX_1_evaluation TX 20 \n",
" FOODS_1_004_TX_2_evaluation TX 20 \n",
" FOODS_1_004_TX_3_evaluation TX 4 \n",
" FOODS_1_005_TX_2_evaluation TX 1 \n",
" FOODS_1_009_TX_2_evaluation TX 3 "
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_parquet(f\"{data_path}/sales_data.parquet\")\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Add in the future `forecast_horizon` dates"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"min_date = df.index.get_level_values(\"date\").min()\n",
"# Extend the range `forecast_horizon` days past the last observed date so the future rows get created\n",
"max_date = df.index.get_level_values(\"date\").max() + pd.Timedelta(days=forecast_horizon)\n",
"\n",
"# A daily frequency gives one entry per calendar day between `min_date` and `max_date`\n",
"dates_to_select = pd.date_range(min_date, max_date, freq=\"1D\")\n",
"unique_ids = df.index.get_level_values(\"id\").unique()\n",
"\n",
"# Get all combinations of our dates and ids\n",
"index_to_select = pd.MultiIndex.from_product([dates_to_select, unique_ids], names=[\"date\", \"id\"])\n",
"\n",
"df = df.reindex(index_to_select)\n"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" </tr>\n",
" <tr>\n",
" <th>date</th>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [item_id, dept_id, cat_id, store_id, state_id, sales]\n",
"Index: []"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_ids = [v.split(\"_\") for v in df.loc[pd.isna(df.sales)].index.get_level_values(\"id\").tolist()]\n",
"# split_ids\n",
"\n",
"fill_cols = [(f\"{s[0]}_{s[1]}_{s[2]}\", f\"{s[0]}_{s[1]}\", f\"{s[0]}\", f\"{s[3]}_{s[4]}\", f\"{s[3]}\", 0) for s in split_ids]\n",
"# fill_cols[:5]\n",
"\n",
"df.loc[pd.isna(df.sales), [\"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\", \"sales\"]] = fill_cols\n",
"df.loc[pd.isna(df.sales)]"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2016-06-19 00:00:00\n"
]
}
],
"source": [
"print(df.reset_index().date.max())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build our dataset\n",
"\n",
"Again, we need a dataset that is long enough to calculate the lag, rolling and grouped features to use for the future dates we want to predict"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10554667 2016-06-19 00:00:00\n",
"CPU times: user 6min 39s, sys: 56.8 s, total: 7min 35s\n",
"Wall time: 1min 36s\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>state_id</th>\n",
" <th>sales</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>sell_price</th>\n",
" <th>snap_TX</th>\n",
" <th>...</th>\n",
" <th>seasonal_rolling_mean_2_28</th>\n",
" <th>seasonal_rolling_mean_4_28</th>\n",
" <th>seasonal_rolling_mean_12_28</th>\n",
" <th>seasonal_rolling_mean_26_28</th>\n",
" <th>seasonal_rolling_mean_52_28</th>\n",
" <th>seasonal_rolling_std_2_28</th>\n",
" <th>seasonal_rolling_std_4_28</th>\n",
" <th>seasonal_rolling_std_12_28</th>\n",
" <th>seasonal_rolling_std_26_28</th>\n",
" <th>seasonal_rolling_std_52_28</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>FOODS_1_004_TX_1_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20.0</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_1</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>FOODS_1_004_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>20.0</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>FOODS_1_004_TX_3_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>4.0</td>\n",
" <td>FOODS_1_004</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_3</td>\n",
" <td>1.78</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>FOODS_1_005_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>1.0</td>\n",
" <td>FOODS_1_005</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>3.28</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>FOODS_1_009_TX_2_evaluation</td>\n",
" <td>2013-01-01</td>\n",
" <td>TX</td>\n",
" <td>3.0</td>\n",
" <td>FOODS_1_009</td>\n",
" <td>FOODS_1</td>\n",
" <td>FOODS</td>\n",
" <td>TX_2</td>\n",
" <td>2.68</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 71 columns</p>\n",
"</div>"
],
"text/plain": [
" id date state_id sales item_id \\\n",
"0 FOODS_1_004_TX_1_evaluation 2013-01-01 TX 20.0 FOODS_1_004 \n",
"1 FOODS_1_004_TX_2_evaluation 2013-01-01 TX 20.0 FOODS_1_004 \n",
"2 FOODS_1_004_TX_3_evaluation 2013-01-01 TX 4.0 FOODS_1_004 \n",
"3 FOODS_1_005_TX_2_evaluation 2013-01-01 TX 1.0 FOODS_1_005 \n",
"4 FOODS_1_009_TX_2_evaluation 2013-01-01 TX 3.0 FOODS_1_009 \n",
"\n",
" dept_id cat_id store_id sell_price snap_TX ... \\\n",
"0 FOODS_1 FOODS TX_1 1.78 True ... \n",
"1 FOODS_1 FOODS TX_2 1.78 True ... \n",
"2 FOODS_1 FOODS TX_3 1.78 True ... \n",
"3 FOODS_1 FOODS TX_2 3.28 True ... \n",
"4 FOODS_1 FOODS TX_2 2.68 True ... \n",
"\n",
" seasonal_rolling_mean_2_28 seasonal_rolling_mean_4_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_12_28 seasonal_rolling_mean_26_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_mean_52_28 seasonal_rolling_std_2_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_4_28 seasonal_rolling_std_12_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
" seasonal_rolling_std_26_28 seasonal_rolling_std_52_28 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"\n",
"[5 rows x 71 columns]"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"lag_features = [1, 2, 3, 7, 14, 21, 30, 90, 365]\n",
"seasonal_lag_features = [1, 2, 4, 12, 26, 52]\n",
"\n",
"rolling_features = {\"mean\": [7, 14], \"std\": [7, 14]}\n",
"seasonal_rolling_features = {\"mean\": [1, 2, 4, 12, 26, 52], \"std\": [2, 4, 12, 26, 52]} # a std for 1 results in all NaNs (why???)\n",
"\n",
"df = build_dataset(\n",
" data_path,\n",
" latest_sales_df = df.reset_index(),\n",
" level=\"id\",\n",
" forecast_horizon=forecast_horizon,\n",
" lag_features=lag_features,\n",
" seasonal_lag_features=seasonal_lag_features,\n",
" rolling_features=rolling_features,\n",
" seasonal_rolling_features=seasonal_rolling_features,\n",
" cache=False,\n",
" override=True\n",
")\n",
"\n",
"print(len(df), df.date.max())\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2016-06-19 00:00:00\n"
]
}
],
"source": [
"# during training `inf_df` represented our validation set, but we can use the same function to get our inference data\n",
"# since our dataset is prepared identically to that used in training with the last 28 days set to be predicted\n",
"train_df, inf_df = build_train_test_splits(df, forecast_horizon=forecast_horizon, fold=1)\n",
"train_df, inf_df = add_group_features(train_df, inf_df)\n",
"\n",
"inf_df = inf_df.fillna(0)\n",
"\n",
"# I print out the max date occasionally to make sure I'm predicting on expected dates\n",
"print(inf_df.date.max())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get predictions"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"models = ['rf', 'lightgbm']\n",
"weights = [\n",
" [0.74201649, 0.25798351],\n",
" [0.04333793, 0.95666207],\n",
" [0.14360396, 0.85639604]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"256088\n",
"[array([2.03829893, 1.02116342, 2.65377985, ..., 0.35438653, 1.63700769,\n",
" 0.31277315]), array([2.91690379, 1.84358136, 3.23554852, ..., 0.3585518 , 1.57657679,\n",
" 0.32384546]), array([1.88160358, 1.22535588, 2.048757 , ..., 0.33898918, 1.61421308,\n",
" 0.31590147])]\n"
]
}
],
"source": [
"preds = []\n",
"for i, fold_weights in enumerate(weights):\n",
" fold, fold_preds = i + 1, []\n",
" \n",
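" # grab our tabular object (NOTE(review): the path below always uses `fold_1`, whatever the current fold -- confirm this is intended)\n",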
" to_fpath = Path(f\"{data_path}/{max_yyyymmdd}_to_fh_{forecast_horizon}_fold_1.pkl\")\n",
" with open(to_fpath, \"rb\") as file:\n",
" to = pickle.load(file)\n",
" \n",
" # prepare our dataset for inference\n",
" to_tst = to.new(inf_df)\n",
" to_tst.process()\n",
" to_tst.items.head()\n",
" \n",
" # get predictions for each model in the fold\n",
" for m in models:\n",
" with open(Path(models_path)/f\"{m}_fold_{fold}_features.json\", 'r') as json_file:\n",
" model_features = json.load(json_file)\n",
" \n",
" if m == 'rf':\n",
" X_test = to_tst.xs[model_features]\n",
" inf_model = joblib.load(Path(models_path)/f\"{m}_fold_{fold}.pkl\")\n",
" fold_preds.append(inf_model.predict(X_test))\n",
" elif m == 'lightgbm':\n",
" X_test = to_tst.xs[model_features]\n",
" inf_model = lgbm.Booster(model_file=Path(models_path)/f\"{m}_fold_{fold}.txt\")\n",
" fold_preds.append(inf_model.predict(X_test))\n",
" \n",
" # apply learned fold weights\n",
" preds.append(np.array(fold_weights) @ np.array(fold_preds))\n",
" \n",
"# average folds to get final predictions\n",
"final_preds = np.average(np.array(preds), axis=0)#, weights=weights)\n",
"\n",
"print(len(final_preds))\n",
"print(preds[:5])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"def export(to:TabularPandas, fname='export.pkl', pickle_protocol=2):\n",
" \"Export the contents of `to` without the items\"\n",
" old_to = to\n",
" to = to.new_empty()\n",
" with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" pickle.dump(to, open(Path(fname), 'wb'), protocol=pickle_protocol)\n",
" to = old_to"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"export(to, \"./data/aaaaaaaa.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "corise-forecasting",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment