Created
February 13, 2021 15:32
-
-
Save CookieBox26/f066660436ae4cd6bfe4962e512c097d to your computer and use it in GitHub Desktop.
Informer を動かすだけ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Informer を動かすだけ\n", | |
"\n", | |
"このノートを動作させるには、手元の同じディレクトリ内に以下の2つのリポジトリをcloneし、Informer2020の直下で作業する(データは ../ETDataset/ から読み込む)。\n", | |
"- https://github.com/zhouhaoyi/Informer2020\n", | |
"- https://github.com/zhouhaoyi/ETDataset\n", | |
" \n", | |
"デフォルトの設定では、ETDataset の1時間毎のデータ(ETDataset/ETT-small/ETTh1.csv)に対して、\n", | |
"- 現時点までの96ステップ(seq_len)を使って未来の24ステップ(pred_len)を予測する。\n", | |
"- ただし、デコーダには現時点までの48ステップ(label_len)をフィードする。\n", | |
"- つまり、エンコーダへの入力の長さは96、デコーダへの入力の長さは72(= label_len + pred_len)になる。\n", | |
"  - デコーダへの入力のうち、予測すべき箇所(下の表の O)はゼロで塗りつぶしておく。\n", | |
"\n", | |
"| date | x = encoder input | y = decoder target | decoder input | seq_len = 96 | label_len = 48 | pred_len = 24 |\n", | |
"| ---- | ---- | ---- | ---- | ---- | ---- | ---- |\n", | |
"| 2016-07-01 00:00:00 | X | - | - | * | - | - |\n", | |
"| 2016-07-01 01:00:00 | X | - | - | * | - | - |\n", | |
"| *** | X | - | - | * | - | - |\n", | |
"| 2016-07-02 23:00:00 | X | - | - | * | - | - |\n", | |
"| 2016-07-03 00:00:00 | X | X | X | * | * | - |\n", | |
"| *** | X | X | X | * | * | - |\n", | |
"| 2016-07-04 23:00:00 | X | X | X | * | * | - |\n", | |
"| 2016-07-05 00:00:00 | - | X | O | - | - | * |\n", | |
"| *** | - | X | O | - | - | * |\n", | |
"| 2016-07-05 23:00:00 | - | X | O | - | - | * |" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== batch 0 =====\n", | |
"----- batch_x -----\n", | |
"torch.Size([32, 96, 7])\n", | |
"----- batch_y -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- batch_x_mark -----\n", | |
"torch.Size([32, 96, 4])\n", | |
"----- batch_y_mark -----\n", | |
"torch.Size([32, 72, 4])\n", | |
"----- dec_inp -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- outputs -----\n", | |
"torch.Size([32, 24, 7])\n", | |
"----- loss -----\n", | |
"tensor(2.6383, dtype=torch.float64, grad_fn=<MseLossBackward>)\n", | |
"===== batch 1 =====\n", | |
"----- batch_x -----\n", | |
"torch.Size([32, 96, 7])\n", | |
"----- batch_y -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- batch_x_mark -----\n", | |
"torch.Size([32, 96, 4])\n", | |
"----- batch_y_mark -----\n", | |
"torch.Size([32, 72, 4])\n", | |
"----- dec_inp -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- outputs -----\n", | |
"torch.Size([32, 24, 7])\n", | |
"----- loss -----\n", | |
"tensor(0.3763, dtype=torch.float64, grad_fn=<MseLossBackward>)\n", | |
"===== batch 2 =====\n", | |
"----- batch_x -----\n", | |
"torch.Size([32, 96, 7])\n", | |
"----- batch_y -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- batch_x_mark -----\n", | |
"torch.Size([32, 96, 4])\n", | |
"----- batch_y_mark -----\n", | |
"torch.Size([32, 72, 4])\n", | |
"----- dec_inp -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- outputs -----\n", | |
"torch.Size([32, 24, 7])\n", | |
"----- loss -----\n", | |
"tensor(0.4360, dtype=torch.float64, grad_fn=<MseLossBackward>)\n", | |
"===== batch 3 =====\n", | |
"----- batch_x -----\n", | |
"torch.Size([32, 96, 7])\n", | |
"----- batch_y -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- batch_x_mark -----\n", | |
"torch.Size([32, 96, 4])\n", | |
"----- batch_y_mark -----\n", | |
"torch.Size([32, 72, 4])\n", | |
"----- dec_inp -----\n", | |
"torch.Size([32, 72, 7])\n", | |
"----- outputs -----\n", | |
"torch.Size([32, 24, 7])\n", | |
"----- loss -----\n", | |
"tensor(1.6757, dtype=torch.float64, grad_fn=<MseLossBackward>)\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from torch import optim\n", | |
"from data.data_loader import Dataset_ETT_hour\n", | |
"from torch.utils.data import DataLoader\n", | |
"from models.model import Informer\n", | |
"torch.manual_seed(0)\n", | |
"\n", | |
"seq_len = 24 * 4 # エンコーダへの入力ステップ数\n", | |
"label_len = 24 * 2 # デコーダへシードとしてフィードするステップ数\n", | |
"pred_len = 24 # 予測するステップ数\n", | |
"\n", | |
"# エンコーダへの入力は以下の2テンソル\n", | |
"# [バッチサイズ, seq_len, 入力次元数]\n", | |
"# [バッチサイズ, seq_len, タイムスタンプ次元数]\n", | |
"# デコーダへの入力は以下の2テンソル\n", | |
"# [バッチサイズ, label_len + pred_len, 入力次元数] # [:,-pred_len:,:] はゼロ埋め\n", | |
"# [バッチサイズ, label_len + pred_len, タイムスタンプ次元数]\n", | |
"\n", | |
"# データ\n", | |
"data_set = Dataset_ETT_hour(\n", | |
" root_path='../ETDataset/ETT-small/',\n", | |
" data_path='ETTh1.csv',\n", | |
" flag='train', \n", | |
" size=[seq_len, label_len, pred_len],\n", | |
" features='M', # M の場合は HUFL, HULL, MUFL, MULL, LUFL, LULL も利用, S の場合は OT のみ利用\n", | |
" scale=True,\n", | |
")\n", | |
"data_loader = DataLoader(\n", | |
" data_set,\n", | |
" batch_size=32,\n", | |
" shuffle=False,\n", | |
" num_workers=0,\n", | |
" drop_last=True\n", | |
")\n", | |
"\n", | |
"# モデル\n", | |
"model = Informer(\n", | |
" enc_in=7, dec_in=7, c_out=7, # ETTh1 データの次元数は 7 (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT)\n", | |
" seq_len=seq_len, label_len=label_len, out_len=pred_len,\n", | |
" d_ff=1024,\n", | |
" data='ETTh', # データによって埋め込みの仕方が変わるため指定\n", | |
").double()\n", | |
"# モデルの詳細を表示\n", | |
"# print(model)\n", | |
"# 学習対象パラメータを表示\n", | |
"# for name, param in model.named_parameters():\n", | |
"# print(name.ljust(14), param.size())\n", | |
"\n", | |
"# 訓練\n", | |
"model_optim = optim.Adam(model.parameters(), lr=0.0001)\n", | |
"criterion = nn.MSELoss()\n", | |
"for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):\n", | |
" print(f'===== batch {i} =====')\n", | |
" batch_x = batch_x.double()\n", | |
" batch_y = batch_y.double()\n", | |
" batch_x_mark = batch_x_mark.double()\n", | |
" batch_y_mark = batch_y_mark.double()\n", | |
" print('----- batch_x -----')\n", | |
" print(batch_x.size()) \n", | |
" print('----- batch_y -----')\n", | |
" print(batch_y.size())\n", | |
" print('----- batch_x_mark -----')\n", | |
" print(batch_x_mark.size())\n", | |
" print('----- batch_y_mark -----')\n", | |
" print(batch_y_mark.size())\n", | |
"\n", | |
" # decoder input\n", | |
" dec_inp = torch.zeros_like(batch_y[:,-pred_len:,:]).double()\n", | |
" dec_inp = torch.cat([batch_y[:,:label_len,:], dec_inp], dim=1).double()\n", | |
" print('----- dec_inp -----')\n", | |
" print(dec_inp.size())\n", | |
" \n", | |
" outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n", | |
" print('----- outputs -----')\n", | |
" print(outputs.size())\n", | |
" \n", | |
" batch_y = batch_y[:,-pred_len:,:]\n", | |
" loss = criterion(outputs, batch_y)\n", | |
" print('----- loss -----')\n", | |
" print(loss)\n", | |
"\n", | |
" loss.backward()\n", | |
" model_optim.step()\n", | |
" \n", | |
" if i == 3:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment