Created
September 16, 2022 02:12
-
-
Save aryan-jadon/80a39ef6d01de934ff6654ef09191477 to your computer and use it in GitHub Desktop.
Part-2.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "c5fcc709", | |
"metadata": {}, | |
"source": [ | |
"### Creating dataset and dataloaders" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "c960e129", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"max_prediction_length = 6\n", | |
"max_encoder_length = 24\n", | |
"training_cutoff = data[\"time_idx\"].max() - max_prediction_length\n", | |
"\n", | |
"training = TimeSeriesDataSet(\n", | |
" data[lambda x: x.time_idx <= training_cutoff],\n", | |
" time_idx=\"time_idx\",\n", | |
" target=\"volume\",\n", | |
" group_ids=[\"agency\", \"sku\"],\n", | |
" min_encoder_length=max_encoder_length // 2, # keep encoder length long (as it is in the validation set)\n", | |
" max_encoder_length=max_encoder_length,\n", | |
" min_prediction_length=1,\n", | |
" max_prediction_length=max_prediction_length,\n", | |
" static_categoricals=[\"agency\", \"sku\"],\n", | |
" static_reals=[\"avg_population_2017\", \"avg_yearly_household_income_2017\"],\n", | |
" time_varying_known_categoricals=[\"special_days\", \"month\"],\n", | |
" variable_groups={\"special_days\": special_days}, # group of categorical variables can be treated as one variable\n", | |
" time_varying_known_reals=[\"time_idx\", \"price_regular\", \"discount_in_percent\"],\n", | |
" time_varying_unknown_categoricals=[],\n", | |
" time_varying_unknown_reals=[\n", | |
" \"volume\",\n", | |
" \"log_volume\",\n", | |
" \"industry_volume\",\n", | |
" \"soda_volume\",\n", | |
" \"avg_max_temp\",\n", | |
" \"avg_volume_by_agency\",\n", | |
" \"avg_volume_by_sku\",\n", | |
" ],\n", | |
" target_normalizer=GroupNormalizer(\n", | |
" groups=[\"agency\", \"sku\"], transformation=\"softplus\"\n", | |
" ), # use softplus and normalize by group\n", | |
" add_relative_time_idx=True,\n", | |
" add_target_scales=True,\n", | |
" add_encoder_length=True,\n", | |
")\n", | |
"\n", | |
"# create validation set (predict=True) which means to predict the last max_prediction_length points in time\n", | |
"# for each series\n", | |
"validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)\n", | |
"\n", | |
"# create dataloaders for model\n", | |
"batch_size = 128 # set this between 32 to 128\n", | |
"train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)\n", | |
"val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "df575038", | |
"metadata": {}, | |
"source": [ | |
"### Creating baseline model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dda95fef", | |
"metadata": {}, | |
"source": [ | |
"Evaluating a Baseline model that predicts the next 6 months by simply repeating the last observed volume gives us a simle benchmark that we want to outperform." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "f41d36fa", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"293.0088195800781" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# calculate baseline mean absolute error, i.e. predict next value as the last available value from the history\n", | |
"actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)])\n", | |
"baseline_predictions = Baseline().predict(val_dataloader)\n", | |
"(actuals - baseline_predictions).abs().mean().item()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.9 (pytorch)", | |
"language": "python", | |
"name": "pytorch" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment