Created February 12, 2021 15:50
Just checking the time spans fed to the Informer encoder and decoder
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Just checking the time spans fed to the Informer encoder and decoder\n",
    "\n",
    "To run this notebook, clone the two repositories below and work from the top level of Informer2020: \n",
    "https://github.com/zhouhaoyi/Informer2020 \n",
    "https://github.com/zhouhaoyi/ETDataset \n",
    " \n",
    "With the default settings, for the hourly data in ETDataset (ETDataset/ETT-small/ETTh1.csv):\n",
    "- The 96 steps up to the present (seq_len) are used to predict the next 24 steps (pred_len).\n",
    "- However, the decoder is also fed the 48 steps up to the present (label_len).\n",
    "- So the encoder input is 96 steps long and the decoder input is 48 + 24 = 72 steps long.\n",
    "  - In the decoder input, the positions to be predicted are filled with zeros.\n",
    "\n",
    "| date | x = encoder input | y = decoder target | decoder input | seq_len = 96 | label_len = 48 | pred_len = 24 |\n",
    "| ---- | ---- | ---- | ---- | ---- | ---- | ---- |\n",
    "| 2016-07-01 00:00:00 | X | - | - | * | - | - |\n",
    "| 2016-07-01 01:00:00 | X | - | - | * | - | - |\n",
    "| *** | X | - | - | * | - | - |\n",
    "| 2016-07-02 23:00:00 | X | - | - | * | - | - |\n",
    "| 2016-07-03 00:00:00 | X | X | X | * | * | - |\n",
    "| *** | X | X | X | * | * | - |\n",
    "| 2016-07-04 23:00:00 | X | X | X | * | * | - |\n",
    "| 2016-07-05 00:00:00 | - | X | O | - | - | * |\n",
    "| *** | - | X | O | - | - | * |\n",
    "| 2016-07-05 23:00:00 | - | X | O | - | - | * |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the raw data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "df_raw = pd.read_csv('../ETDataset/ETT-small/ETTh1.csv')\n",
    "display(HTML('<h4>Raw data</h4>'))\n",
    "display(HTML(df_raw.head().to_html(index=True)))\n",
    "\n",
    "# Standardize every column except the date column, as is done during training\n",
    "scaler = StandardScaler()\n",
    "df_data = df_raw.iloc[:,1:]\n",
    "data = scaler.fit_transform(df_data.values)\n",
    "df_raw.iloc[:,1:] = data\n",
    "display(HTML('<h4>After scaling (the data is scaled during training, so compare against these values)</h4>'))\n",
    "display(HTML(df_raw.head().to_html(index=True)))\n",
    "display(HTML(df_raw.iloc[46:,:].head().to_html(index=True)))\n",
    "display(HTML(df_raw.iloc[94:,:].head().to_html(index=True)))\n",
    "display(HTML(df_raw.iloc[118:,:].head().to_html(index=True)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the first sample of the first batch against the raw data\n",
    "\n",
    "batch_x and batch_y match the data shown above. \n",
    "In dec_inp, the positions to be predicted are filled with zeros. \n",
    "batch_x_mark and batch_y_mark are timestamps (month, day, weekday, hour); they are embedded into the same dimension as the data and added to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from data.data_loader import Dataset_ETT_hour\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "seq_len = 24 * 4\n",
    "label_len = 24 * 2\n",
    "pred_len = 24\n",
    "\n",
    "data_set = Dataset_ETT_hour(\n",
    "    root_path='../ETDataset/ETT-small/',\n",
    "    data_path='ETTh1.csv',\n",
    "    flag='train',\n",
    "    size=[seq_len, label_len, pred_len],\n",
    "    features='M', # with M, HUFL, HULL, MUFL, MULL, LUFL, LULL are used; with S, only OT is used\n",
    "    scale=True,\n",
    ")\n",
    "data_loader = DataLoader(\n",
    "    data_set,\n",
    "    batch_size=32,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    drop_last=True\n",
    ")\n",
    "\n",
    "for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):\n",
    "    print('----- batch_x -----')\n",
    "    print(batch_x.size())\n",
    "    print(batch_x[0,:3,:])\n",
    "    print(batch_x[0,-3:,:])\n",
    "\n",
    "    print('\\n----- batch_y -----')\n",
    "    print(batch_y.size())\n",
    "    print(batch_y[0,:3,:])\n",
    "    print(batch_y[0,-3:,:])\n",
    "\n",
    "    print('\\n----- batch_x_mark -----')\n",
    "    print(batch_x_mark.size())\n",
    "    print(batch_x_mark[0,:3,:])\n",
    "    print(batch_x_mark[0,-3:,:])\n",
    "\n",
    "    print('\\n----- batch_y_mark -----')\n",
    "    print(batch_y_mark.size())\n",
    "    print(batch_y_mark[0,:3,:])\n",
    "    print(batch_y_mark[0,-3:,:])\n",
    "\n",
    "    # decoder input\n",
    "    dec_inp = torch.zeros_like(batch_y[:,-pred_len:,:]).double()\n",
    "    dec_inp = torch.cat([batch_y[:,:label_len,:], dec_inp], dim=1)\n",
    "\n",
    "    print('\\n----- dec_inp -----')\n",
    "    print(dec_inp.size())\n",
    "    print(dec_inp[0,:3,:])\n",
    "    print(dec_inp[0,-26:-22,:])\n",
    "    print(dec_inp[0,-3:,:])\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
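The window layout in the notebook's table can be reproduced without the Informer code at all, as a sanity check. This is a minimal sketch (not part of the original gist), assuming Informer-style slicing in which the decoder window starts label_len steps before the encoder window ends:

```python
# Sketch: reproduce the encoder/decoder windows from the table above
# using plain datetime arithmetic on hourly timestamps.
# Assumed slicing convention (Informer-style):
#   encoder input: steps [0, seq_len)
#   decoder input: steps [seq_len - label_len, seq_len + pred_len)
from datetime import datetime, timedelta

seq_len, label_len, pred_len = 96, 48, 24
start = datetime(2016, 7, 1, 0)  # first timestamp in ETTh1.csv
hours = [start + timedelta(hours=h) for h in range(seq_len + pred_len)]

enc = hours[:seq_len]                # encoder input: 96 steps
dec = hours[seq_len - label_len:]    # decoder input: 48 + 24 = 72 steps

print('encoder:', enc[0], '..', enc[-1])          # 2016-07-01 00:00 .. 2016-07-04 23:00
print('decoder:', dec[0], '..', dec[-1])          # 2016-07-03 00:00 .. 2016-07-05 23:00
print('predict:', dec[label_len], '..', dec[-1])  # 2016-07-05 00:00 .. 2016-07-05 23:00
```

The printed spans match the `X`/`O` columns of the table: the encoder sees the first 96 hours, the decoder input overlaps the encoder's last 48 hours, and its final 24 positions (the zero-filled part of dec_inp) are the hours to be predicted.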