@CookieBox26
Created September 6, 2020 14:50
Training a TCN on Sequential MNIST
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TCN で Sequential MNIST を学習する\n",
"\n",
"### 参考文献\n",
"1. [[1803.01271]An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271) ;TCNの原論文。\n",
"1. [locuslab/TCN: Sequence modeling benchmarks and temporal convolutional networks](https://github.com/locuslab/TCN) ;TCNの原論文のリポジトリ。\n",
"1. [Normalization in the mnist example - PyTorch Forums](https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457); (0.1307,), (0.3081,) は MNIST の訓練データの平均と標準偏差であるらしい。\n",
"\n",
"### 関連Gist\n",
"- [ニューラルネットで足し算する(Temporal Convolutional Network )](https://gist.github.com/CookieBox26/8e314f1164d7d5beea5312d625115fed); TCN による足し算タスク。\n",
"- [LSTM で足し算する](https://gist.github.com/CookieBox26/31a1247c0e31d6109067229a151ead66); LSTM による足し算タスク。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table width=\"100%\">\n",
"<tr>\n",
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n",
"<td style=\"vertical-align:top;text-align:left;\">TCN の原論文では、系列データのモデリング性能を検証するタスクの1つとして、手書き数字データの MNIST をいちいちピクセルの系列として読み込んで分類するタスクをしていますよね。もっとも、このタスク自体はこの論文以前から思考的な系列モデリングタスクとしてちょいちょい採用されているものですが。ともかく著者のリポジトリでいうと、Sequential MNIST and P-MNIST タスクは以下のディレクトリですね。\n",
"<ul style=\"margin:0.3em 0\"><li><a href=\"https://github.com/locuslab/TCN/tree/master/TCN/mnist_pixel\">https://github.com/locuslab/TCN/tree/master/TCN/mnist_pixel</a></li></ul>\n",
"データ取得するコードを以下に抜粋しましたが、transforms.Compose([]) にはデータに施すべきプレ処理を登録しておくのですよね。ここではTorchテンソル化と正規化を登録していますが、transforms.Normalize((0.1307,), (0.3081,)) とは、「すべてのピクセルから 0.1307 を引いて 0.3081 で割ります」というプレ処理ですか? 何ですかこの 0.1307 と 0.3081 は。\n",
"</td>\n",
"</tr>\n",
"<tr>\n",
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/2.png\"></td>\n",
"<td style=\"vertical-align:top;text-align:left;\">\n",
"MNIST の訓練データの平均と標準偏差っぽいね。0.1307 と 0.3081 でぐぐったら出てきた。\n",
"<ul style=\"margin:0.3em 0\"><li><a href=\"https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457\">Normalization in the mnist example - PyTorch Forums</a></li></ul>\n",
"</td>\n",
"</tr>\n",
"<tr>\n",
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n",
"<td style=\"vertical-align:top;text-align:left;\">ええ…ハイコンテクストすぎませんか…。\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADZCAYAAABl0n+gAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAStUlEQVR4nO3dfZRcdX3H8fcHyBN5gkKCkjYsopBiOEazYIWAOZEiAgdCEBSLpxTbUB6UUEq1VIEIKqJggORUQyPQ1iICSVDxWJBDalFTWIVoeAgKhhAxkAiJkAcIybd/3Lsnk9nf/rIPs3t3k8/rnDnZ+d47c7+z2c/87v3NnRlFBGaWtlvVDZj1ZQ6IWYYDYpbhgJhlOCBmGQ7ITkrSGEljyp9HSzqg6p76IwekCyR9VNJBVfexA98C7ih/vhb4n47cSNJCSYt6qqn+ZpcJiKS/kBRduAyuu5+3Av8FnCtpXHkZuIP7+GjN7Sd1cvurevj3Ml7SXj25jcQ2J+/gMTf1Zj85e1TdQC96DPjzRH1P4OfA54C7Estfr7t+NiDg0vICMK7mvu8DfgxcDewL/G/d7X/eTh/tebMT63bFr4C/AW7t4e2kfAZ4LlFf3duNtGeXCUhEbAKeqq9LGlb+uCoi2iyvW3cEcBFFkM6i+ONaDTwdESHpDGB/4PqIeErS2PKm62v62Jjqo6sknQp8AfhkRDxQs53JNT+f3ajtNdh/R8RjVTeRs8vsYjXIV4H9gPcCS4F3AGOBeZLOBW4D/gH4laRxwBnl7Z7uiWYkTQT+E3gbcK+kU3piO53o58pyF2lRB28yVNJISerJvrpjlwuIpPdL+moXbncG8Hfl1auBtwM3Ab8GPgCcCJwUETdS7EI9CVwF/GtELKu5n8e6cBx0ZaKf/YHvAiuBA4HFwF2S/qqzj61CDwFrgVfLyYEJVTdUb5fZxapxOHAJ8I919RGS3lJXeykitpY/30VxzPEVoHUG663ASOC3EXGypI9JugIYDTwTEW9PbH8qMDhRTxkHLACeqS1KGgLcA4wA/jIifi/pROBe4N8lDY+Ir3dwG400G/g2sGEH662iOOZ7CthE8fs8H1gs6YMR0aEZt14REbvUhSIYUXN9GBDtXP607raTy/rr5b+bgS3AonL5/wG/o9jteY7iD3wcMLyLvX6O4iD9T2pqAr5Tbv+MuvWHAovKZf/UwW0EcHbN9YWtj6eX/1+GAMuAR6r+G6m97IojSHtuopiBqrWmnXUPAX5Lscs1GWiSNJ5idBLQupvzZPnvmRTPrB0maU/gAmB+RLxcs+jzwOnAdRHxndrbRMR6SScAPwC+LGlkRPxLZ7ab6edvgSbgc1H+RTdSRGyU9G8UfQ+NiPU7vFEvcEC2+WVEfL+D615S/jsNOAB4heK4ZAvF8ccM4PhI72J11FUU08RfaC1IehvF1OiDwKdTN4qIDWVI7gUuK2fePpX6o5Y0qPyxfiq7fr2RZR9LeyIcNV6geIIZSc3MX5V2uYP0btq9/HcMxSvVf6TYRQO4Dvh0RPymuxuRdA7FbNg1EbGktR4RzwInAB+NiC3t3T4iNlBMGiwCLgRukbR7YtXWaeiVmV4EzAOGA9N30HdnZ7HqvRN4A/hDF2/fcB5BOqf1IP6ciFgLIOlWil2PH5fXr2tdWVLrs+1zEdG0ozsvn9GvAP4ZuBu4sn6diLi/I42WI8mJwPeBvwaGSzozIt6oWW0axeixNHNX15TrfTwinsms1y2SDqYI870RkR3RepMD0jlTgN+3hqNO7avjnwWOBj5YXt+cu9NyN+jjwMUUMzqzgYsjoluvopchOYkiJNOAmRThQ9KBFBMWCyLilURP+wE3UryW88mI+FYHNtmhWSxJP6QYtX5JsSs1nuJ47lWKkbPvqHqWoBdnSR6j/dmq3OXW8vYHUPzHf6O83kQxQ/UD4L66bc0GfpPpRcDfA9dTvBawudzWT4AP9MBj35Pij31EeX0S8CzFdOv+desupDimWg+8RPHaTqP7mQE8DLxMsUv1HPB1YEzVfyf1l11pBJkJ7NOF27W+Cj6K4kXBr5TXrwE+QvGsd35n7jAiQtIk4GTgcYqzbRdEREsX+uvI9jYAn6opTaII5PER8ULd6qspjrXmAl+OiIafLBkRs4BZjb7fnqAy0dbLJA0ENkcF/wHlAfvQiPhjYtkgYHBErOvtvvoiB8Qsw9O8ZhkNCYikAZJulvS6pFclzWzE/ZpVrVEH6VcDR1KcajECuFPSyxFxQ4Pu36wS3T4GKc8sXQN8JMpTNcr3RsykmELcmrrdvvvuG01NTd3atlkjLF++nDVr1iTfk9KIEeTdwADggZraQop57UPYdsLedpqammhp6ZFZTbNOaW5ubndZI45BxgCrozgbc5akiyPiRYoX1fxRM9avNWIEGQy0zpmPZ9vJe2upe2OQpOmUJ7yNHTsWs76uESPIBrYF4XHgifLnvag7Jyci5kZEc0Q0jxo1qgGbNutZjRhBVgD7S9ojIi6C4pP8KM7/SX2ki1m/0YgRZAnFyXZH1tSmUbz5pUc+zcOst3Q7IFG8v2AOMEvSoZKOoXgfwzVVnGdk1kiNeqHwcopP8niUYjS5ISJuatB9m1WmIQEpR5FzyovZTsMnK5plOCBmGQ6IWYYDYpbhgJhlOCBmGQ6IWYYDYpbhgJhlOCBmGQ6IWYYDYpbhgJhlOCBmGQ6IWYYDYpbhgJhlOCBmGbvSN0z1C1u3Jj/KmNdfb8z3Wt52223J+vr1bb91+YknnkisCbNmpb8c6rLLLkvWZ8+enawPGTIkWb/uuuuS9fPOOy9Z70keQcwyHBCzDAfELMMBMctwQMwyPIvVBevWpb8hecuWLcn6kiVL2tTuu+++5Lpr165N1ufOndvB7hqnvW8Au+SSS5L1efPmJesjR45M1o8++uhkfcqUKTturpd4BDHLcEDMMhwQswwHxCzDATHL8CxWxsqVK5P1CRMmJOuvvPJKT7bTo3bbre1zZXuzUu2dQ/WJT3wiWR89enSyPmzYsGS9L31/pUcQswwHxCzDATHLcEDMMhwQswzPYmXss88+yfp+++2XrFcxi3Xccccl6+31Pn/+/GR90KBBbWqTJ0/ucl87C48gZhkOiFmGA2KW4YCYZTggZhmexcpo75yjW2+9NVm/6667kvX3ve99bWqnnXZap3qZNGlSsn7PPfck6wMHDkzWV61alazfcMMNnepnV+ERxCzDATHLcEDMMjoVEEnXS5pdVxshaYGkzZLWSOr9D1A16yEdDoikEcDpiUU3A3sD7wTOBq6VdEpDujOrmCJixytJJwJ3A4OAORFxYVk/AFgOHBYRS8val4BjI+Lw3H02NzdHS0tL97rvY9r7BPbUjFJ7n4R+7bXXJusPPvhgsn7MMcd0sDtrT3NzMy0tLUot6+gI8iAwjiIktY4CVraGo7QQmChpaKc7NetjOhSQiNgQEcuB1+oWjQFWAJTHIdOAZwEBf9bAPs0q0d1ZrMFA6+dwHgYcCqytWbYdSdMlt
UhqWb16dTc3bdbzuhuQDWwLwlLgSWCvmmXbiYi5EdEcEc196ZMrzNrT3VNNVgBjASJiKoCkI4CtwPPdvO9+J/Wmo/bsvffenbrvG2+8MVlv7wOgpeQxp3VSd0eQnwIHlrNZraYBD0fExm7et1nluhWQiPgdcCfwDUkHSToZuBD4UiOaM6taI041ORd4A1gG3AJ8JiK+24D7Natcp45BIuLsRG0dcHKjGjLrS3yyolmG3zBVkRkzZiTrDz/8cLK+YMGCZP3xxx9P1sePH9+1xmw7HkHMMhwQswwHxCzDATHLcEDMMjyLVZH2PpZn7ty5yfoDDzyQrJ9ySvrNm1OnTk3WjzrqqGT91FNPbVPz+VweQcyyHBCzDAfELMMBMctwQMwyOvSxPz1hZ/zYn57U3jlaxx9/fLK+bt26ZL093/zmN9vU2vuA7WHDhnXqvvu6Rnzsj9kuyQExy3BAzDIcELMMB8Qsw+di9RNHHHFEst7eOwovvvjiZP3OO+9M1s8555w2tWeeeSa57qWXXpqsDx8+PFnvzzyCmGU4IGYZDohZhgNiluGAmGX4XKyd1KZNm5L1xYsXJ+vHHntsm1p7fxsf/vCHk/U77rijg931LT4Xy6yLHBCzDAfELMMBMctwQMwyfC7WTmrw4DZfMgzA5MmTk/Xdd9+9Te3NN99Mrrtw4cJkfdmyZcn6IYcckqz3Bx5BzDIcELMMB8QswwExy/BBej/3wgsvJOvz589P1n/2s58l6+0dkKccfvjhyfrBBx/c4fvoLzyCmGU4IGYZDohZhgNiluGAmGV4FquPWb16dbI+Z86cZP2WW25J1leuXNntXlKnnwA0NTUl6zvjV7Z5BDHLcEDMMhwQs4wOBUTSOyTdL2mTpDWS5kgaUi4bI2mRpC2SnpfU9vuEzfqpHQZE0lDgh8BLwGHAScDRwNdUHJXdDbwIHAR8FrhdUnOPdWzWizoyizUJ2Bs4OyI2A7+WNBP4KkVQJgIfiohXgOWSjgc+DZzeQz33O6+99lqb2ve+973kup///OeT9aeffrqhPdWbMmVKm9o111yTXHfixIk92ktf0pFdrPuBt5ThaDUUCIrwLC7D0WohcEzjWjSrzg4DEhFbI+KN1uvlscclwHxgDLCirLeUu1bPAqMlpd/zadaPdGoWS9JIihFid2AmMBho/TrVCcA4YG15vU1AJE0vg9TS3gtiZn1JhwMi6VDgEWAQ8P6IeBXYwLYgLAGWAXuV1zfU30dEzI2I5ohoHjVqVLcaN+sNHZ3mPRR4iOJ45NiI+EO5aAUwFiAiJkbEI8CBwKra3TKz/mqHs1iS9gC+DdwRERfULf4JcLmkoRGxvqxNAxY1tMs+Zv369cn6888/n6yfddZZbWqPPvpoQ3uqd9xxxyXrM2fOTNZT7xLcGc+t6qyOTPOeBDQBJ0rat27ZYmApMEfSFcAUYCpwZCObNKtKRwLyLmA45WxVnQMpXu+4nWL2ahVwVkT8omEdmlVohwGJiJkUM1Y5RzemHbO+xScrmmU4IGYZfkchsHHjxmR9xowZyfpDDz2UrD/11FMN66neCSeckKxffvnlyfqECROS9QEDBjSsp12BRxCzDAfELMMBMctwQMwyHBCzjJ1yFmv58uXJ+he/+MVk/Uc/+lGy/txzzzWqpTb23HPPZP2qq65K1s8///xkfeDAgQ3rydryCGKW4YCYZTggZhkOiFmGA2KWsVPOYt19993J+rx58xpy/+95z3uS9TPPPDNZ32OPtr/m6dOnJ9cdPNgfBtOXeAQxy3BAzDIcELMMB8QsQxFRyYabm5ujpaWlkm2b1WpubqalpSX5GUceQcwyHBCzDAfELMMBMctwQMwyHBCzDAfELMMBMctwQMwyHBCzDAfELMMBMctwQMwyHBCzDAfELMMBMctwQMwyKntHoaTVQOunQ+8LrKmkkd7lx9k3HRARo1ILKgvIdk1ILRHRXHUfPc2Ps//xLpZZhgNiltFXAjK36gZ6iR9nP9MnjkHM+qq+MoKY9UmVBkTSAEk3S3pd0quSZlbZT6NJul7S7LraCEkLJG2WtEbSeVX1112S3iHpfkmbyscyR9KQctkYSYskbZH0vKRTq+63K6oeQa4GjgQOBz4ETJd0UbUtNYakEcDpiUU3A3sD7wTOBq6VdEovttYQkoYCPwReAg4DTgKOBr4mScDdwIvAQcBngdsl9b+p34io5AIMAdYDJ9XUzgVWAbtV1VeDHtuJwCYggNk19QPK2via2peAR6ruuQuP8YPAy8CAmtppwG+BY4DNwN41y24H7qy6785eqhxB3g0MAB6oqS0E9gMOqaSjxnkQGEfxLFrrKGBlRCytqS0EJpbPyP3J/cBbImJzTW0oxRPAJGBxRLxSs2whRXD6lSoDMgZYHREbJc2SdHFEvAhsoHim7bciYkNELAdeq1s0BlgBUB6HTAOeBQT8Wa822U0RsTUi3mi9Xh57XALMZ/vH2VLuWj0LjJbUr75Cq8qADAbWlT+Pp9gnB1hbLtsZ1T7mw4BDKR5v67J+SdJIihFid2Am2z/OCRSjab98nFUGZAPbflmPA0+UP+9VLtsZ1T7mpcCTFI+3dVm/I+lQ4BFgEPD+iHiV7R/nEmAZ/fRxVvklniuA/SXtEREXAUgaDezJtrN8dzYrgLEAETEVQNIRwFbg+Qr76pIyHA9RHIBfFBFvlotWUBzEExETy3XPAFbV7pb1B1WOIEsoZjqOrKlNA14Anq6ko573U+BASbXHWNOAhyNiY0U9dYmkPYBvA3dExAU14QD4CfDeuomHacCiXmyxISobQSLiDUlzgFmSzqJ4D8GVwBeinBfc2UTE7yTdCXxD0gUUx10XAh+rtrMuOQloAk6UtG/dssUUu5BzJF0BTAGmsv2TYb9Q9fekXw6MBh6lGE1uiIibqm2px50L/AfFfvk64DMR8d1qW+qSdwHDKWer6hxI8SLp7RSzV6uAsyLiF73XXmP4ZEWzjKpPNTHr0xwQswwHxCzDATHLcEDMMhwQswwHxCzDATHL+H9Yk74D4bSz6wAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 864x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ 元々の1バッチのサイズ\n",
"torch.Size([64, 1, 28, 28])\n",
"◆ 系列データ化後の1バッチのサイズ\n",
"torch.Size([64, 1, 784])\n"
]
}
],
"source": [
"# https://github.com/locuslab/TCN/blob/master/TCN/mnist_pixel/utils.py の内容そのまま\n",
"\n",
"import torch\n",
"from torchvision import datasets, transforms\n",
"\n",
"\n",
"def data_generator(root, batch_size):\n",
" train_set = datasets.MNIST(root=root, train=True, download=True,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ]))\n",
" test_set = datasets.MNIST(root=root, train=False, download=True,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ]))\n",
"\n",
" train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)\n",
" return train_loader, test_loader\n",
"\n",
"\n",
"# データをみてみるテスト\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from pylab import rcParams\n",
"rcParams['figure.figsize'] = 12, 3\n",
"rcParams['font.size'] = 14\n",
"rcParams['font.family']='Ume Hy Gothic O5'\n",
"\n",
"root = '../data'\n",
"batch_size = 64\n",
"train_loader, test_loader = data_generator(root, batch_size)\n",
"for batch_idx, (data, target) in enumerate(train_loader):\n",
" fig = plt.figure()\n",
" ax1 = fig.add_subplot(1, 1, 1)\n",
" ax1.imshow(data[0, 0], cmap=\"Greys\")\n",
" ax1.set_title(f'正解ラベル : {target[0].item()}')\n",
" plt.show()\n",
"\n",
" print('◆ 元々の1バッチのサイズ')\n",
" print(data.size())\n",
" data = data.view(-1, 1, 784)\n",
" print('◆ 系列データ化後の1バッチのサイズ')\n",
" print(data.size())\n",
" \n",
" break"
]
},
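{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an addition, not part of the original gist: a quick numerical check that 0.1307 and 0.3081 really are the mean and standard deviation of the MNIST training pixels scaled to [0, 1]. It assumes that torchvision's datasets.MNIST exposes the raw images as a uint8 tensor via its data attribute."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check, not in the original gist:\n",
"# recompute the constants hard-coded in Normalize((0.1307,), (0.3081,)).\n",
"raw_train_set = datasets.MNIST(root=root, train=True, download=True)\n",
"pixels = raw_train_set.data.float() / 255.0  # uint8 [0, 255] -> float [0, 1]\n",
"print('mean = {:.4f}'.format(pixels.mean().item()))  # expected: ~0.1307\n",
"print('std  = {:.4f}'.format(pixels.std().item()))   # expected: ~0.3081"
]
},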
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table width=\"100%\">\n",
"<tr>\n",
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n",
"<td style=\"vertical-align:top;text-align:left;\">次はモデルですが、足し算タスクのときと違うのは最後に log_softmax している点ですね。log_softmax は文字通り softmax してから log を取る関数で、log まで取る方が softmax より安定とのことです。ネットワークの出力を log_softmax にしたら、損失を正解との nll_loss にすれば交差エントロピーになるのですよね。\n",
"<ul style=\"margin:0.3em 0\"><li><a href=\"https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.log_softmax\">torch.nn.functional &mdash; PyTorch master documentation#torch.nn.functional.log_softmax</a></li></ul>\n",
"あとネットワークの重みを初期化していませんが、わざとなんでしょうか? とりあえず初期化はコメントアウトしておきます。\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
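{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an addition, not part of the original gist: a quick check on random logits that the combination used here, log_softmax followed by nll_loss, agrees with F.cross_entropy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check, not in the original gist:\n",
"# nll_loss(log_softmax(x)) equals cross_entropy(x) up to floating-point error.\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"logits = torch.randn(5, 10)          # 5 samples, 10 classes\n",
"labels = torch.randint(0, 10, (5,))  # random true labels\n",
"a = F.nll_loss(F.log_softmax(logits, dim=1), labels)\n",
"b = F.cross_entropy(logits, labels)\n",
"print(a.item(), b.item(), torch.allclose(a, b))"
]
},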
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from torch.nn.utils import weight_norm\n",
"from collections import OrderedDict\n",
"\n",
"\n",
"def _debug_print(debug, *content):\n",
" if debug:\n",
" print(*content)\n",
"\n",
"\n",
"class Chomp1d(nn.Module):\n",
" def __init__(self, chomp_size):\n",
" super(Chomp1d, self).__init__()\n",
" self.chomp_size = chomp_size\n",
"\n",
" def forward(self, x):\n",
" return x[:, :, :-self.chomp_size].contiguous()\n",
"\n",
"\n",
"# 以下と同じネットワークを1クラスで実装したもの\n",
"# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py\n",
"class TCN(nn.Module):\n",
" def __init__(self,\n",
" input_size=1,\n",
" output_size=10,\n",
" num_channels=[25]*8,\n",
" kernel_size=7,\n",
" dropout=0.0):\n",
" super(TCN, self).__init__()\n",
" self.layers = OrderedDict()\n",
" self.num_levels = len(num_channels)\n",
"\n",
" for i in range(self.num_levels):\n",
" dilation = 2 ** i\n",
" n_in = input_size if (i == 0) else num_channels[i-1]\n",
" n_out = num_channels[i]\n",
" padding = (kernel_size - 1) * dilation\n",
" # ========== TemporalBlock ==========\n",
" self.layers[f'conv1_{i}'] \\\n",
" = weight_norm(nn.Conv1d(n_in, n_out, kernel_size,\n",
" padding=padding,\n",
" dilation=dilation))\n",
" self.layers[f'chomp1_{i}'] = Chomp1d(padding)\n",
" self.layers[f'relu1_{i}'] = nn.ReLU()\n",
" self.layers[f'dropout1_{i}'] = nn.Dropout(dropout)\n",
" self.layers[f'conv2_{i}'] \\\n",
" = weight_norm(nn.Conv1d(n_out, n_out, kernel_size,\n",
" padding=padding,\n",
" dilation=dilation))\n",
" self.layers[f'chomp2_{i}'] = Chomp1d(padding)\n",
" self.layers[f'relu2_{i}'] = nn.ReLU()\n",
" self.layers[f'dropout2_{i}'] = nn.Dropout(dropout)\n",
" self.layers[f'downsample_{i}'] = nn.Conv1d(n_in, n_out, 1) \\\n",
" if (n_in != n_out) else None\n",
" self.layers[f'relu_{i}'] = nn.ReLU()\n",
" # ===================================\n",
" self.network = nn.Sequential(self.layers)\n",
" self.linear = nn.Linear(num_channels[-1], output_size)\n",
" # self.init_weights()\n",
"\n",
" # def init_weights(self):\n",
" # for i in range(self.num_levels):\n",
" # self.layers[f'conv1_{i}'].weight.data.normal_(0, 0.01)\n",
" # self.layers[f'conv2_{i}'].weight.data.normal_(0, 0.01)\n",
" # if self.layers[f'downsample_{i}'] is not None:\n",
" # self.layers[f'downsample_{i}'].weight.data.normal_(0, 0.01)\n",
" # self.linear.weight.data.normal_(0, 0.01)\n",
"\n",
" def forward(self, x, debug=False):\n",
" _debug_print(debug, '========== forward ==========')\n",
" _debug_print(debug, x.size())\n",
" for i in range(self.num_levels):\n",
" _debug_print(debug, f'---------- block {i} ----------')\n",
" _debug_print(debug, 'in : ', x.size())\n",
" # Residual Connection\n",
" res = x if (self.layers[f'downsample_{i}'] is None) \\\n",
" else self.layers[f'downsample_{i}'](x)\n",
" out = self.layers[f'conv1_{i}'](x)\n",
" out = self.layers[f'chomp1_{i}'](out)\n",
" out = self.layers[f'relu1_{i}'](out)\n",
" out = self.layers[f'dropout1_{i}'](out)\n",
" out = self.layers[f'conv2_{i}'](out)\n",
" out = self.layers[f'chomp2_{i}'](out)\n",
" out = self.layers[f'relu2_{i}'](out)\n",
" out = self.layers[f'dropout2_{i}'](out)\n",
" _debug_print(debug, 'out: ', out.size())\n",
" _debug_print(debug, 'res: ', res.size())\n",
" x = self.layers[f'relu_{i}'](out + res)\n",
" _debug_print(debug, x.size())\n",
" _debug_print(debug, '-----------------------------')\n",
" _debug_print(debug, x.size())\n",
" x = self.linear(x[:, :, -1])\n",
" _debug_print(debug, x.size())\n",
" _debug_print(debug, '=============================')\n",
" return F.log_softmax(x, dim=1)"
]
},
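{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two added sanity checks, not in the original gist. First, the receptive field: each of the 8 blocks applies two convolutions with kernel size k = 7 and dilation 2^i, so the receptive field is 1 + 2(k-1)(2^8 - 1) = 3061, which comfortably covers the 784-step pixel sequence. Second, a forward pass on a dummy batch with debug=True to confirm the tensor shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity checks, not in the original gist.\n",
"\n",
"# 1. Receptive field: 1 + 2 * (kernel_size - 1) * (sum of dilations over levels).\n",
"kernel_size, num_levels = 7, 8\n",
"receptive_field = 1 + 2 * (kernel_size - 1) * sum(2 ** i for i in range(num_levels))\n",
"print(f'receptive field = {receptive_field} (sequence length = 784)')\n",
"\n",
"# 2. Shape check: a dummy batch of 2 one-channel sequences of length 784.\n",
"dummy_model = TCN(input_size=1, output_size=10, num_channels=[25] * 8, kernel_size=7)\n",
"dummy_out = dummy_model(torch.randn(2, 1, 784), debug=True)\n",
"print(dummy_out.size())  # expected: torch.Size([2, 10])"
]
},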
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table width=\"100%\">\n",
"<tr>\n",
"<td width=\"80px\" style=\"vertical-align:top;\"><img src=\"https://cookiebox26.github.io/ToyBox/20200407_100ML/1.png\"></td>\n",
"<td style=\"vertical-align:top;text-align:left;\">\n",
"後は学習ですね。コードの引数のデフォルト値には原論文の Table 2. と全く同じパラメータが入っていました。走らせてみると普通に学習は進んでいくようです。以下のコードは動作確認のため1エポック目をわざと早く終わらせています。\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [1280/60000 (2%)]\tLoss: 2.041475\tSteps: 16464\n",
"Train Epoch: 1 [2560/60000 (4%)]\tLoss: 0.820613\tSteps: 32144\n",
"Train Epoch: 1 [3840/60000 (6%)]\tLoss: 0.649577\tSteps: 47824\n",
"Train Epoch: 1 [5120/60000 (9%)]\tLoss: 0.560962\tSteps: 63504\n",
"Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.446930\tSteps: 79184\n",
"\n",
"Test set: Average loss: 0.6064, Accuracy: 8041/10000 (80%)\n",
"\n",
"Train Epoch: 2 [1280/60000 (2%)]\tLoss: 0.488135\tSteps: 95648\n",
"Train Epoch: 2 [2560/60000 (4%)]\tLoss: 0.327060\tSteps: 111328\n",
"Train Epoch: 2 [3840/60000 (6%)]\tLoss: 0.281475\tSteps: 127008\n",
"Train Epoch: 2 [5120/60000 (9%)]\tLoss: 0.267830\tSteps: 142688\n",
"Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.232774\tSteps: 158368\n",
"Train Epoch: 2 [7680/60000 (13%)]\tLoss: 0.306946\tSteps: 174048\n",
"Train Epoch: 2 [8960/60000 (15%)]\tLoss: 0.334192\tSteps: 189728\n"
]
}
],
"source": [
"# https://github.com/locuslab/TCN/blob/master/TCN/mnist_pixel/pmnist_test.py から適宜必要なコードを抜粋.\n",
"\n",
"from torch.autograd import Variable\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"root = '../data'\n",
"batch_size = 64\n",
"train_loader, test_loader = data_generator(root, batch_size)\n",
"\n",
"model = TCN(input_size=1,\n",
" output_size=10,\n",
" num_channels=[25]*8,\n",
" kernel_size=7)\n",
"optimizer = optim.Adam(model.parameters(), lr=2e-3)\n",
"\n",
"input_channels = 1\n",
"seq_length = 784\n",
"epochs = 20\n",
"log_interval = 20\n",
"\n",
"steps = 0\n",
"\n",
"def train(ep):\n",
" global steps\n",
" train_loss = 0\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" data = data.view(-1, input_channels, seq_length)\n",
" data, target = Variable(data), Variable(target)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss\n",
" steps += seq_length\n",
" if batch_idx > 0 and batch_idx % log_interval == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tSteps: {}'.format(\n",
" ep, batch_idx * batch_size, len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), train_loss.item()/log_interval, steps))\n",
" train_loss = 0\n",
" \n",
" # 動作確認のため1エポック目をわざと早く終わらせる\n",
" if (ep == 1) and (batch_idx == 100):\n",
" break\n",
" \n",
"def test():\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" data = data.view(-1, input_channels, seq_length)\n",
" data, target = Variable(data), Variable(target)\n",
" output = model(data)\n",
" test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n",
"\n",
" test_loss /= len(test_loader.dataset)\n",
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(test_loader.dataset),\n",
" 100. * correct / len(test_loader.dataset)))\n",
" return test_loss\n",
"\n",
"\n",
"# 実行\n",
"for epoch in range(1, epochs + 1):\n",
" train(epoch)\n",
" test()\n",
" if epoch % 10 == 0:\n",
" lr /= 10\n",
" for param_group in optimizer.param_groups:\n",
" param_group['lr'] = lr"
]
},
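{
"cell_type": "markdown",
"metadata": {},
"source": [
"An added cell, not in the original gist, for eyeballing the trained model: take one test batch, flatten it to pixel sequences as in training, and compare predicted and true labels for the first few digits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added: inspect a few predictions of the trained model (run after training).\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    data, target = next(iter(test_loader))\n",
"    output = model(data.view(-1, input_channels, seq_length))\n",
"    pred = output.argmax(dim=1)\n",
"print('predicted:', pred[:10].tolist())\n",
"print('actual   :', target[:10].tolist())"
]
},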
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}