{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Informer を動かすだけ\n",
"\n",
"このノートを動作させるには手元に以下の2つのリポジトリをcloneし、Informer2020の直下で作業する。 \n",
"https://github.com/zhouhaoyi/Informer2020 \n",
"https://github.com/zhouhaoyi/ETDataset \n",
" \n",
"デフォルトの設定では、ETDataset の1時間毎のデータ(ETDataset/ETT-small/ETTh1.csv)を、\n",
"- 現時点までの96ステップ(seq_len)を使って未来の24ステップ(pred_len)を予測する。\n",
"- ただし、デコーダには現時点までの48ステップ(label_len)をフィードする。\n",
"- つまり、エンコーダへの入力の長さは96、デコーダへの入力の長さは72になる。\n",
" - デコーダへの入力は、予測すべき箇所はゼロで塗りつぶしておく。\n",
"\n",
"| date | x = encoder input | y = decoder target | decoder input | seq_len = 96 | label_len = 48 | pred_len = 24 |\n",
"| ---- | ---- | ---- | ---- | ---- | ---- | ---- |\n",
"| 2016-07-01 00:00:00 | X | - | - | * | - | - |\n",
"| 2016-07-01 01:00:00 | X | - | - | * | - | - |\n",
"| *** | X | - | - | * | - | - |\n",
"| 2016-07-02 23:00:00 | X | - | - | * | - | - |\n",
"| 2016-07-03 00:00:00 | X | X | X | * | * | - |\n",
"| *** | X | X | X | * | * | - |\n",
"| 2016-07-04 23:00:00 | X | X | X | * | * | - |\n",
"| 2016-07-05 00:00:00 | - | X | O | - | - | * |\n",
"| *** | - | X | O | - | - | * |\n",
"| 2016-07-05 23:00:00 | - | X | O | - | - | * |"
]
},
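{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal sketch, independent of the Informer2020 code, of how one training sample is sliced out of a series according to the table above. It uses only numpy; the index names (`s_begin`, `s_end`, `r_begin`, `r_end`) are illustrative. `seq_x` plays the role of the encoder input, `seq_y` the decoder target, and the decoder input is `seq_y` with its last `pred_len` steps zeroed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"seq_len, label_len, pred_len = 96, 48, 24\n",
"data = np.arange(200.0).reshape(-1, 1)  # dummy series: 200 steps, 1 feature\n",
"\n",
"s_begin = 0                      # start of the encoder window\n",
"s_end = s_begin + seq_len        # end of the encoder window (exclusive)\n",
"r_begin = s_end - label_len      # decoder window starts label_len steps before s_end\n",
"r_end = r_begin + label_len + pred_len\n",
"\n",
"seq_x = data[s_begin:s_end]      # encoder input: 96 steps\n",
"seq_y = data[r_begin:r_end]      # decoder target: 48 + 24 = 72 steps\n",
"\n",
"dec_inp = seq_y.copy()\n",
"dec_inp[-pred_len:] = 0.0        # zero-fill the part to be predicted\n",
"\n",
"print(seq_x.shape, seq_y.shape, dec_inp.shape)  # (96, 1) (72, 1) (72, 1)"
]
},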
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"===== batch 0 =====\n",
"----- batch_x -----\n",
"torch.Size([32, 96, 7])\n",
"----- batch_y -----\n",
"torch.Size([32, 72, 7])\n",
"----- batch_x_mark -----\n",
"torch.Size([32, 96, 4])\n",
"----- batch_y_mark -----\n",
"torch.Size([32, 72, 4])\n",
"----- dec_inp -----\n",
"torch.Size([32, 72, 7])\n",
"----- outputs -----\n",
"torch.Size([32, 24, 7])\n",
"----- loss -----\n",
"tensor(2.6383, dtype=torch.float64, grad_fn=<MseLossBackward>)\n",
"===== batch 1 =====\n",
"----- batch_x -----\n",
"torch.Size([32, 96, 7])\n",
"----- batch_y -----\n",
"torch.Size([32, 72, 7])\n",
"----- batch_x_mark -----\n",
"torch.Size([32, 96, 4])\n",
"----- batch_y_mark -----\n",
"torch.Size([32, 72, 4])\n",
"----- dec_inp -----\n",
"torch.Size([32, 72, 7])\n",
"----- outputs -----\n",
"torch.Size([32, 24, 7])\n",
"----- loss -----\n",
"tensor(0.3763, dtype=torch.float64, grad_fn=<MseLossBackward>)\n",
"===== batch 2 =====\n",
"----- batch_x -----\n",
"torch.Size([32, 96, 7])\n",
"----- batch_y -----\n",
"torch.Size([32, 72, 7])\n",
"----- batch_x_mark -----\n",
"torch.Size([32, 96, 4])\n",
"----- batch_y_mark -----\n",
"torch.Size([32, 72, 4])\n",
"----- dec_inp -----\n",
"torch.Size([32, 72, 7])\n",
"----- outputs -----\n",
"torch.Size([32, 24, 7])\n",
"----- loss -----\n",
"tensor(0.4360, dtype=torch.float64, grad_fn=<MseLossBackward>)\n",
"===== batch 3 =====\n",
"----- batch_x -----\n",
"torch.Size([32, 96, 7])\n",
"----- batch_y -----\n",
"torch.Size([32, 72, 7])\n",
"----- batch_x_mark -----\n",
"torch.Size([32, 96, 4])\n",
"----- batch_y_mark -----\n",
"torch.Size([32, 72, 4])\n",
"----- dec_inp -----\n",
"torch.Size([32, 72, 7])\n",
"----- outputs -----\n",
"torch.Size([32, 24, 7])\n",
"----- loss -----\n",
"tensor(1.6757, dtype=torch.float64, grad_fn=<MseLossBackward>)\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"from data.data_loader import Dataset_ETT_hour\n",
"from torch.utils.data import DataLoader\n",
"from models.model import Informer\n",
"torch.manual_seed(0)\n",
"\n",
"seq_len = 24 * 4 # エンコーダへの入力ステップ数\n",
"label_len = 24 * 2 # デコーダへシードとしてフィードするステップ数\n",
"pred_len = 24 # 予測するステップ数\n",
"\n",
"# エンコーダへの入力は以下の2テンソル\n",
"# [バッチサイズ, seq_len, 入力次元数]\n",
"# [バッチサイズ, seq_len, タイムスタンプ次元数]\n",
"# デコーダへの入力は以下の2テンソル\n",
"# [バッチサイズ, label_len + pred_len, 入力次元数] # [:,-pred_len:,:] はゼロ埋め\n",
"# [バッチサイズ, label_len + pred_len, タイムスタンプ次元数]\n",
"\n",
"# データ\n",
"data_set = Dataset_ETT_hour(\n",
" root_path='../ETDataset/ETT-small/',\n",
" data_path='ETTh1.csv',\n",
" flag='train', \n",
" size=[seq_len, label_len, pred_len],\n",
" features='M', # M の場合は HUFL, HULL, MUFL, MULL, LUFL, LULL も利用, S の場合は OT のみ利用\n",
" scale=True,\n",
")\n",
"data_loader = DataLoader(\n",
" data_set,\n",
" batch_size=32,\n",
" shuffle=False,\n",
" num_workers=0,\n",
" drop_last=True\n",
")\n",
"\n",
"# モデル\n",
"model = Informer(\n",
" enc_in=7, dec_in=7, c_out=7, # ETTh1 データの次元数は 7 (HUFL, HULL, MUFL, MULL, LUFL, LULL, OT)\n",
" seq_len=seq_len, label_len=label_len, out_len=pred_len,\n",
" d_ff=1024,\n",
" data='ETTh', # データによって埋め込みの仕方が変わるため指定\n",
").double()\n",
"# モデルの詳細を表示\n",
"# print(model)\n",
"# 学習対象パラメータを表示\n",
"# for name, param in model.named_parameters():\n",
"# print(name.ljust(14), param.size())\n",
"\n",
"# 訓練\n",
"model_optim = optim.Adam(model.parameters(), lr=0.0001)\n",
"criterion = nn.MSELoss()\n",
"for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):\n",
" print(f'===== batch {i} =====')\n",
" batch_x = batch_x.double()\n",
" batch_y = batch_y.double()\n",
" batch_x_mark = batch_x_mark.double()\n",
" batch_y_mark = batch_y_mark.double()\n",
" print('----- batch_x -----')\n",
" print(batch_x.size()) \n",
" print('----- batch_y -----')\n",
" print(batch_y.size())\n",
" print('----- batch_x_mark -----')\n",
" print(batch_x_mark.size())\n",
" print('----- batch_y_mark -----')\n",
" print(batch_y_mark.size())\n",
"\n",
" # decoder input\n",
" dec_inp = torch.zeros_like(batch_y[:,-pred_len:,:]).double()\n",
" dec_inp = torch.cat([batch_y[:,:label_len,:], dec_inp], dim=1).double()\n",
" print('----- dec_inp -----')\n",
" print(dec_inp.size())\n",
" \n",
" outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n",
" print('----- outputs -----')\n",
" print(outputs.size())\n",
" \n",
" batch_y = batch_y[:,-pred_len:,:]\n",
" loss = criterion(outputs, batch_y)\n",
" print('----- loss -----')\n",
" print(loss)\n",
"\n",
" loss.backward()\n",
" model_optim.step()\n",
" \n",
" if i == 3:\n",
" break"
]
},
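{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a follow-up, a minimal evaluation sketch without gradient tracking. It assumes that Dataset_ETT_hour also accepts flag='val' (as in the Informer2020 repo) and reuses model, criterion, and the length variables from the training cell above; the loop mirrors the training loop minus the backward pass and optimizer step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluation sketch: same batch handling as training, but no backward()/step()\n",
"val_set = Dataset_ETT_hour(\n",
"    root_path='../ETDataset/ETT-small/',\n",
"    data_path='ETTh1.csv',\n",
"    flag='val',  # assumption: the repo's Dataset_ETT_hour supports a validation split\n",
"    size=[seq_len, label_len, pred_len],\n",
"    features='M',\n",
"    scale=True,\n",
")\n",
"val_loader = DataLoader(val_set, batch_size=32, shuffle=False, drop_last=True)\n",
"\n",
"model.eval()\n",
"losses = []\n",
"with torch.no_grad():\n",
"    for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:\n",
"        batch_x, batch_y = batch_x.double(), batch_y.double()\n",
"        batch_x_mark, batch_y_mark = batch_x_mark.double(), batch_y_mark.double()\n",
"        dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :])\n",
"        dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)\n",
"        outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n",
"        losses.append(criterion(outputs, batch_y[:, -pred_len:, :]).item())\n",
"print(sum(losses) / len(losses))  # mean validation MSE"
]
},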
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}