Skip to content

Instantly share code, notes, and snippets.

@aryan-jadon
Created September 16, 2022 02:12
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save aryan-jadon/115f8fcf8fa20f34a0904fb8d196e2b6 to your computer and use it in GitHub Desktop.
Save aryan-jadon/115f8fcf8fa20f34a0904fb8d196e2b6 to your computer and use it in GitHub Desktop.
Part-1.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6c74f11f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>agency</th>\n",
" <th>sku</th>\n",
" <th>volume</th>\n",
" <th>date</th>\n",
" <th>industry_volume</th>\n",
" <th>soda_volume</th>\n",
" <th>avg_max_temp</th>\n",
" <th>price_regular</th>\n",
" <th>price_actual</th>\n",
" <th>discount</th>\n",
" <th>...</th>\n",
" <th>football_gold_cup</th>\n",
" <th>beer_capital</th>\n",
" <th>music_fest</th>\n",
" <th>discount_in_percent</th>\n",
" <th>timeseries</th>\n",
" <th>time_idx</th>\n",
" <th>month</th>\n",
" <th>log_volume</th>\n",
" <th>avg_volume_by_sku</th>\n",
" <th>avg_volume_by_agency</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>291</th>\n",
" <td>Agency_25</td>\n",
" <td>SKU_03</td>\n",
" <td>0.5076</td>\n",
" <td>2013-01-01</td>\n",
" <td>492612703</td>\n",
" <td>718394219</td>\n",
" <td>25.845238</td>\n",
" <td>1264.162234</td>\n",
" <td>1152.473405</td>\n",
" <td>111.688829</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>8.835008</td>\n",
" <td>228</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>-0.678062</td>\n",
" <td>1225.306376</td>\n",
" <td>99.650400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>871</th>\n",
" <td>Agency_29</td>\n",
" <td>SKU_02</td>\n",
" <td>8.7480</td>\n",
" <td>2015-01-01</td>\n",
" <td>498567142</td>\n",
" <td>762225057</td>\n",
" <td>27.584615</td>\n",
" <td>1316.098485</td>\n",
" <td>1296.804924</td>\n",
" <td>19.293561</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>1.465966</td>\n",
" <td>177</td>\n",
" <td>24</td>\n",
" <td>1</td>\n",
" <td>2.168825</td>\n",
" <td>1634.434615</td>\n",
" <td>11.397086</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19532</th>\n",
" <td>Agency_47</td>\n",
" <td>SKU_01</td>\n",
" <td>4.9680</td>\n",
" <td>2013-09-01</td>\n",
" <td>454252482</td>\n",
" <td>789624076</td>\n",
" <td>30.665957</td>\n",
" <td>1269.250000</td>\n",
" <td>1266.490490</td>\n",
" <td>2.759510</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>0.217413</td>\n",
" <td>322</td>\n",
" <td>8</td>\n",
" <td>9</td>\n",
" <td>1.603017</td>\n",
" <td>2625.472644</td>\n",
" <td>48.295650</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2089</th>\n",
" <td>Agency_53</td>\n",
" <td>SKU_07</td>\n",
" <td>21.6825</td>\n",
" <td>2013-10-01</td>\n",
" <td>480693900</td>\n",
" <td>791658684</td>\n",
" <td>29.197727</td>\n",
" <td>1193.842373</td>\n",
" <td>1128.124395</td>\n",
" <td>65.717978</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>beer_capital</td>\n",
" <td>-</td>\n",
" <td>5.504745</td>\n",
" <td>240</td>\n",
" <td>9</td>\n",
" <td>10</td>\n",
" <td>3.076505</td>\n",
" <td>38.529107</td>\n",
" <td>2511.035175</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9755</th>\n",
" <td>Agency_17</td>\n",
" <td>SKU_02</td>\n",
" <td>960.5520</td>\n",
" <td>2015-03-01</td>\n",
" <td>515468092</td>\n",
" <td>871204688</td>\n",
" <td>23.608120</td>\n",
" <td>1338.334248</td>\n",
" <td>1232.128069</td>\n",
" <td>106.206179</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>music_fest</td>\n",
" <td>7.935699</td>\n",
" <td>259</td>\n",
" <td>26</td>\n",
" <td>3</td>\n",
" <td>6.867508</td>\n",
" <td>2143.677462</td>\n",
" <td>396.022140</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7561</th>\n",
" <td>Agency_05</td>\n",
" <td>SKU_03</td>\n",
" <td>1184.6535</td>\n",
" <td>2014-02-01</td>\n",
" <td>425528909</td>\n",
" <td>734443953</td>\n",
" <td>28.668254</td>\n",
" <td>1369.556376</td>\n",
" <td>1161.135214</td>\n",
" <td>208.421162</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>15.218151</td>\n",
" <td>21</td>\n",
" <td>13</td>\n",
" <td>2</td>\n",
" <td>7.077206</td>\n",
" <td>1566.643589</td>\n",
" <td>1881.866367</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19204</th>\n",
" <td>Agency_11</td>\n",
" <td>SKU_05</td>\n",
" <td>5.5593</td>\n",
" <td>2017-08-01</td>\n",
" <td>623319783</td>\n",
" <td>1049868815</td>\n",
" <td>31.915385</td>\n",
" <td>1922.486644</td>\n",
" <td>1651.307674</td>\n",
" <td>271.178970</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>14.105636</td>\n",
" <td>17</td>\n",
" <td>55</td>\n",
" <td>8</td>\n",
" <td>1.715472</td>\n",
" <td>1385.225478</td>\n",
" <td>109.699200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8781</th>\n",
" <td>Agency_48</td>\n",
" <td>SKU_04</td>\n",
" <td>4275.1605</td>\n",
" <td>2013-03-01</td>\n",
" <td>509281531</td>\n",
" <td>892192092</td>\n",
" <td>26.767857</td>\n",
" <td>1761.258209</td>\n",
" <td>1546.059670</td>\n",
" <td>215.198539</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>music_fest</td>\n",
" <td>12.218455</td>\n",
" <td>151</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>8.360577</td>\n",
" <td>1757.950603</td>\n",
" <td>1925.272108</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2540</th>\n",
" <td>Agency_07</td>\n",
" <td>SKU_21</td>\n",
" <td>0.0000</td>\n",
" <td>2015-10-01</td>\n",
" <td>544203593</td>\n",
" <td>761469815</td>\n",
" <td>28.987755</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>0.000000</td>\n",
" <td>300</td>\n",
" <td>33</td>\n",
" <td>10</td>\n",
" <td>-18.420681</td>\n",
" <td>0.000000</td>\n",
" <td>2418.719550</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12084</th>\n",
" <td>Agency_21</td>\n",
" <td>SKU_03</td>\n",
" <td>46.3608</td>\n",
" <td>2017-04-01</td>\n",
" <td>589969396</td>\n",
" <td>940912941</td>\n",
" <td>32.478910</td>\n",
" <td>1675.922116</td>\n",
" <td>1413.571789</td>\n",
" <td>262.350327</td>\n",
" <td>...</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>15.654088</td>\n",
" <td>181</td>\n",
" <td>51</td>\n",
" <td>4</td>\n",
" <td>3.836454</td>\n",
" <td>2034.293024</td>\n",
" <td>109.381800</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" agency sku volume date industry_volume soda_volume \\\n",
"291 Agency_25 SKU_03 0.5076 2013-01-01 492612703 718394219 \n",
"871 Agency_29 SKU_02 8.7480 2015-01-01 498567142 762225057 \n",
"19532 Agency_47 SKU_01 4.9680 2013-09-01 454252482 789624076 \n",
"2089 Agency_53 SKU_07 21.6825 2013-10-01 480693900 791658684 \n",
"9755 Agency_17 SKU_02 960.5520 2015-03-01 515468092 871204688 \n",
"7561 Agency_05 SKU_03 1184.6535 2014-02-01 425528909 734443953 \n",
"19204 Agency_11 SKU_05 5.5593 2017-08-01 623319783 1049868815 \n",
"8781 Agency_48 SKU_04 4275.1605 2013-03-01 509281531 892192092 \n",
"2540 Agency_07 SKU_21 0.0000 2015-10-01 544203593 761469815 \n",
"12084 Agency_21 SKU_03 46.3608 2017-04-01 589969396 940912941 \n",
"\n",
" avg_max_temp price_regular price_actual discount ... \\\n",
"291 25.845238 1264.162234 1152.473405 111.688829 ... \n",
"871 27.584615 1316.098485 1296.804924 19.293561 ... \n",
"19532 30.665957 1269.250000 1266.490490 2.759510 ... \n",
"2089 29.197727 1193.842373 1128.124395 65.717978 ... \n",
"9755 23.608120 1338.334248 1232.128069 106.206179 ... \n",
"7561 28.668254 1369.556376 1161.135214 208.421162 ... \n",
"19204 31.915385 1922.486644 1651.307674 271.178970 ... \n",
"8781 26.767857 1761.258209 1546.059670 215.198539 ... \n",
"2540 28.987755 0.000000 0.000000 0.000000 ... \n",
"12084 32.478910 1675.922116 1413.571789 262.350327 ... \n",
"\n",
" football_gold_cup beer_capital music_fest discount_in_percent \\\n",
"291 - - - 8.835008 \n",
"871 - - - 1.465966 \n",
"19532 - - - 0.217413 \n",
"2089 - beer_capital - 5.504745 \n",
"9755 - - music_fest 7.935699 \n",
"7561 - - - 15.218151 \n",
"19204 - - - 14.105636 \n",
"8781 - - music_fest 12.218455 \n",
"2540 - - - 0.000000 \n",
"12084 - - - 15.654088 \n",
"\n",
" timeseries time_idx month log_volume avg_volume_by_sku \\\n",
"291 228 0 1 -0.678062 1225.306376 \n",
"871 177 24 1 2.168825 1634.434615 \n",
"19532 322 8 9 1.603017 2625.472644 \n",
"2089 240 9 10 3.076505 38.529107 \n",
"9755 259 26 3 6.867508 2143.677462 \n",
"7561 21 13 2 7.077206 1566.643589 \n",
"19204 17 55 8 1.715472 1385.225478 \n",
"8781 151 2 3 8.360577 1757.950603 \n",
"2540 300 33 10 -18.420681 0.000000 \n",
"12084 181 51 4 3.836454 2034.293024 \n",
"\n",
" avg_volume_by_agency \n",
"291 99.650400 \n",
"871 11.397086 \n",
"19532 48.295650 \n",
"2089 2511.035175 \n",
"9755 396.022140 \n",
"7561 1881.866367 \n",
"19204 109.699200 \n",
"8781 1925.272108 \n",
"2540 2418.719550 \n",
"12084 109.381800 \n",
"\n",
"[10 rows x 31 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import warnings\n",
"import copy\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor\n",
"from pytorch_lightning.loggers import TensorBoardLogger\n",
"import torch\n",
"\n",
"from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet\n",
"\n",
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths\n",
"from pytorch_forecasting.data import GroupNormalizer\n",
"from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss\n",
"from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters\n",
"\n",
"from pytorch_forecasting.data.examples import get_stallion_data\n",
"\n",
"data = get_stallion_data()\n",
"\n",
"# add time index\n",
"data[\"time_idx\"] = data[\"date\"].dt.year * 12 + data[\"date\"].dt.month\n",
"data[\"time_idx\"] -= data[\"time_idx\"].min()\n",
"\n",
"# add additional features\n",
"data[\"month\"] = data.date.dt.month.astype(str).astype(\"category\") # categories have be strings\n",
"data[\"log_volume\"] = np.log(data.volume + 1e-8)\n",
"data[\"avg_volume_by_sku\"] = data.groupby([\"time_idx\", \"sku\"], observed=True).volume.transform(\"mean\")\n",
"data[\"avg_volume_by_agency\"] = data.groupby([\"time_idx\", \"agency\"], observed=True).volume.transform(\"mean\")\n",
"\n",
"# we want to encode special days as one variable and thus need to first reverse one-hot encoding\n",
"special_days = [\n",
" \"easter_day\",\n",
" \"good_friday\",\n",
" \"new_year\",\n",
" \"christmas\",\n",
" \"labor_day\",\n",
" \"independence_day\",\n",
" \"revolution_day_memorial\",\n",
" \"regional_games\",\n",
" \"fifa_u_17_world_cup\",\n",
" \"football_gold_cup\",\n",
" \"beer_capital\",\n",
" \"music_fest\",\n",
"]\n",
"data[special_days] = data[special_days].apply(lambda x: x.map({0: \"-\", 1: x.name})).astype(\"category\")\n",
"data.sample(10, random_state=521)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9 (pytorch)",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@Yuu1001
Copy link

Yuu1001 commented Oct 13, 2022

Hi, can I ask why, when I run the code above in PyCharm, I get no output like yours? Nothing is displayed at all, just the message 'Process finished with exit code 0'.

@cicidodogoat
Copy link

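You need to wrap the last line in print(): print(data.sample(10, random_state=521)). A notebook displays the value of a cell's final expression automatically, but a plain Python script run in PyCharm does not.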

@wayu0730
Copy link

Hi, may I know what version of PyTorch you are using?

@aryan-jadon
Copy link
Author

Hi, may I know what version of PyTorch you are using?

Installation steps and version information can be found here: https://github.com/aryan-jadon/Regression-Loss-Functions-in-Time-Series-Forecasting-PyTorch

@aryan-jadon
Copy link
Author

Hi, can I ask why, when I run the code above in PyCharm, I get no output like yours? Nothing is displayed at all, just the message 'Process finished with exit code 0'.

Try running this on Google Colab.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment