@CookieBox26
Created February 12, 2021 15:50
Just checking the periods of the inputs to the Informer encoder and decoder
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Just checking the periods of the inputs to the Informer encoder and decoder\n",
"\n",
"To run this notebook, clone the following two repositories locally and work from the top directory of Informer2020. \n",
"https://github.com/zhouhaoyi/Informer2020 \n",
"https://github.com/zhouhaoyi/ETDataset \n",
" \n",
"With the default settings, the hourly data from ETDataset (ETDataset/ETT-small/ETTh1.csv) is handled as follows:\n",
"- The 96 steps up to the present (seq_len) are used to predict the next 24 steps (pred_len).\n",
"- In addition, the most recent 48 steps (label_len) are fed to the decoder.\n",
"- So the encoder input has length 96 and the decoder input has length 72 (= 48 + 24).\n",
" - In the decoder input, the positions to be predicted are zero-filled.\n",
"\n",
"| date | x = encoder input | y = decoder target | decoder input | seq_len = 96 | label_len = 48 | pred_len = 24 |\n",
"| ---- | ---- | ---- | ---- | ---- | ---- | ---- |\n",
"| 2016-07-01 00:00:00 | X | - | - | * | - | - |\n",
"| 2016-07-01 01:00:00 | X | - | - | * | - | - |\n",
"| *** | X | - | - | * | - | - |\n",
"| 2016-07-02 23:00:00 | X | - | - | * | - | - |\n",
"| 2016-07-03 00:00:00 | X | X | X | * | * | - |\n",
"| *** | X | X | X | * | * | - |\n",
"| 2016-07-04 23:00:00 | X | X | X | * | * | - |\n",
"| 2016-07-05 00:00:00 | - | X | O | - | - | * |\n",
"| *** | - | X | O | - | - | * |\n",
"| 2016-07-05 23:00:00 | - | X | O | - | - | * |"
]
},
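{
"cell_type": "markdown",
"metadata": {},
"source": [
"The window arithmetic above can be sketched without the repository. This is a minimal illustration of the index slicing for the first sample; the variable names follow our reading of Dataset_ETT_hour and should be treated as assumptions rather than the repository's exact code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumption): index ranges for sample 0, mirroring the table above.\n",
"seq_len, label_len, pred_len = 96, 48, 24\n",
"s_begin = 0                              # encoder window start\n",
"s_end = s_begin + seq_len                # encoder input covers [0, 96)\n",
"r_begin = s_end - label_len              # decoder input starts label_len steps back\n",
"r_end = r_begin + label_len + pred_len   # decoder input covers [48, 120)\n",
"print('encoder input:', (s_begin, s_end), 'length', s_end - s_begin)\n",
"print('decoder input:', (r_begin, r_end), 'length', r_end - r_begin)\n",
"print('zero-filled  :', (s_end, r_end), 'length', r_end - s_end)"
]
},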
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check the raw data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.display import display, HTML\n",
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"df_raw = pd.read_csv('../ETDataset/ETT-small/ETTh1.csv')\n",
"display(HTML('<h4>Raw data</h4>'))\n",
"display(HTML(df_raw.head().to_html(index=True)))\n",
"\n",
"scaler = StandardScaler()\n",
"df_data = df_raw.iloc[:,1:]\n",
"data = scaler.fit_transform(df_data.values)\n",
"df_raw.iloc[:,1:] = data\n",
"display(HTML('<h4>After scaling (training uses scaled data, so compare against these)</h4>'))\n",
"display(HTML(df_raw.head().to_html(index=True)))\n",
"display(HTML(df_raw.iloc[46:,:].head().to_html(index=True)))\n",
"display(HTML(df_raw.iloc[94:,:].head().to_html(index=True)))\n",
"display(HTML(df_raw.iloc[118:,:].head().to_html(index=True)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check the first sample of the first batch against the raw data\n",
"\n",
"batch_x and batch_y match the data shown above. \n",
"In dec_inp, the positions to be predicted are zero-filled. \n",
"batch_x_mark and batch_y_mark are (month, day, weekday, hour) timestamps; they are embedded into the same dimension as the data and added to it."
]
},
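{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of where batch_x_mark / batch_y_mark come from (assumed from the description above; the repository uses its own time-feature helpers and also rescales the values): for hourly data each timestamp yields (month, day, weekday, hour)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Sketch (assumption): per-timestamp calendar features for hourly data.\n",
"dates = pd.to_datetime(['2016-07-01 00:00:00', '2016-07-01 01:00:00'])\n",
"marks = pd.DataFrame({'month': dates.month, 'day': dates.day,\n",
"                      'weekday': dates.weekday, 'hour': dates.hour})\n",
"print(marks)  # the actual loader rescales these before embedding them"
]
},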
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from data.data_loader import Dataset_ETT_hour\n",
"from torch.utils.data import DataLoader\n",
"\n",
"seq_len = 24 * 4\n",
"label_len = 24 * 2\n",
"pred_len = 24\n",
"\n",
"data_set = Dataset_ETT_hour(\n",
" root_path='../ETDataset/ETT-small/',\n",
" data_path='ETTh1.csv',\n",
" flag='train', \n",
" size=[seq_len, label_len, pred_len],\n",
" features='M', # 'M': use all columns (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT); 'S': use OT only\n",
" scale=True,\n",
")\n",
"data_loader = DataLoader(\n",
" data_set,\n",
" batch_size=32,\n",
" shuffle=False,\n",
" num_workers=0,\n",
" drop_last=True\n",
")\n",
"\n",
"for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):\n",
" print('----- batch_x -----')\n",
" print(batch_x.size())\n",
" print(batch_x[0,:3,:])\n",
" print(batch_x[0,-3:,:])\n",
" \n",
" print('\\n----- batch_y -----')\n",
" print(batch_y.size())\n",
" print(batch_y[0,:3,:])\n",
" print(batch_y[0,-3:,:])\n",
" \n",
" print('\\n----- batch_x_mark -----')\n",
" print(batch_x_mark.size())\n",
" print(batch_x_mark[0,:3,:])\n",
" print(batch_x_mark[0,-3:,:])\n",
" \n",
" print('\\n----- batch_y_mark -----')\n",
" print(batch_y_mark.size())\n",
" print(batch_y_mark[0,:3,:])\n",
" print(batch_y_mark[0,-3:,:])\n",
"\n",
" # decoder input\n",
" dec_inp = torch.zeros_like(batch_y[:,-pred_len:,:]).double()\n",
" dec_inp = torch.cat([batch_y[:,:label_len,:], dec_inp], dim=1)\n",
" \n",
" print('\\n----- dec_inp -----')\n",
" print(dec_inp.size())\n",
" print(dec_inp[0,:3,:])\n",
" print(dec_inp[0,-26:-22,:])\n",
" print(dec_inp[0,-3:,:])\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}