@CookieBox26
Last active August 17, 2022 14:34
{
"cells": [
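{
"cell_type": "markdown",
"id": "e71f743f-5d4b-44da-8282-2fbc5d8bc838",
"metadata": {},
"source": [
"## LeNet\n",
"\n",
"#### References\n",
"- [1] [Probabilistic Machine Learning: An Introduction](https://probml.github.io/pml-book/book1.html) (the textbook)  \n",
"- [2] [Yann LeCun, Leon Bottou, Yoshua Bengio, Patrick Haffner. Gradient-Based Learning Applied to Document Recognition. Proceedings of the IEEE, 1998.](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) (the original paper)\n",
"  - The section on LeNet-5 starts on page 7. An implementation faithful to the original paper is complicated, so this notebook follows [1] instead, although the activation function used there is unclear.\n",
"\n",
"#### What is LeNet? [1][2]\n",
"One of the earliest convolutional neural networks, proposed by Yann LeCun et al. in 1998.  \n",
"Trained on MNIST, it is reported to reach a test accuracy of 98.8% after one epoch (the run recorded in this notebook reaches 97.94%).  \n",
"With further training, the accuracy rises to the point where the remaining errors are on the level of label noise.  \n",
"The original paper is actually about document recognition, and LeNet serves as the character-recognition component of that system."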
]
},
{
"cell_type": "markdown",
"id": "5633ca2b-2e4c-4b2d-8ac1-88a14e5eb5a6",
"metadata": {},
"source": [
"#### So let's get MNIST ready\n",
"Even the very first training sample, a 5, is rather sloppily written."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "34952295-671a-4680-b726-9b550424faa2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ Data sizes (number of samples, height in pixels, width in pixels)\n",
"Training data torch.Size([60000, 28, 28])\n",
"Test data torch.Size([10000, 28, 28])\n",
"\n",
"◆ Plot the first 4 training images\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9UAAAEVCAYAAADuPcaCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1wklEQVR4nO3deXxU5d338e9kISEBgmHTCARkc8EIIhoXQKWCVatVoxYU484mLuhTxduKcgvUPmprEXfA6m3tU5BWRcWlEgFJ2KpUiEtVCAQULIEkQMg21/NH7oyMyXWSOcxklnzer9f8wfzOcuUw35n5zZk5l8cYYwQAAAAAAAIWF+4BAAAAAAAQrWiqAQAAAABwiaYaAAAAAACXaKoBAAAAAHCJphoAAAAAAJdoqgEAAAAAcImmGgAAAAAAl2iqAQAAAABwiaYaESs/Pz8o29m5c6f+/e9/B2VbAAJHloHYQJaB6EeOQ4OmGhHphRde0BlnnKFHHnnksLZTVFSk4447ThdffLH27dsXpNEBaC6yDMQGsgxEP3IcOjTViDhbt27VXXfdpYSEBP3iF784rG1lZmZqyJAh+uKLL3TfffcFaYQAmoMsA7GBLAPRjxyHFk11GL344ovyeDy67rrrwj2UiHLvvfeqrKxMkydP1vHHH+9X83g8zbrl5eX51pkzZ44SExM1d+5cffbZZy3816A1IMuNc8qyJNXW1mrWrFnq2bOnPB6PUlNTdcEFF6iwsLDR7ZFlhBpZblxTWf6pqqoqnXDCCfJ4PHrxxRcb1MkyQokcN665Of7qq68UFxens88+23F75NgfTTUiSnFxsRYuXKjExERNmzbNulxmZqbjLTk52bfsscceq7Fjx8rr9erBBx9sgb8CQFNZNsZo7Nix+q//+i9t27ZNnTt3VnV1td555x2deuqpKigoaLAOWQZaXnNflw/1u9/9zvrhmESWgZbW3Bx///33uummm2SMaXKb5NgfTTUiypNPPqmamhpdfPHF6tatm3W5LVu2ON6ys7P9lp84caIk6fXXX9e2bdtC+jcAaDrLTz31lP7617+qW7duys/P1w8//KDdu3fr5ptv1v79+3XppZeqrKyswXpkGWhZzX1drvfNN99o5syZTS5HloGW01SO9+7dK4/Ho6OOOkorVqxo9nbJ8Y9oqhFR/vKXv0iSbrjhhqBu97TTTlNWVpZqa2v15z//OajbBtCQU5Zra2s1Y8YMSXUXTan/EKx9+/Z69tlndcYZZ+j777/XE0880WBdsgy0rEBflydOnKiDBw/K4/E4LkeWgZbTVI7j4uJ83/Zszodn9cjxIQxa3LJly4wk661ebW2teeaZZ8ygQYNMUlKSiY+PN3369DEPPPCAOXDgQIPt5ubmGklmwYIFZvv27ea2224zmZmZJj4+3rRv395cffXVZufOnQ3W+/77782tt95q+vbta1JSUowk06VLFzNmzBjz1VdfNfo3VFVVmVmzZpk+ffqYuLg4ExcXZ4499ljz9NNPG6/X22D5ESNGGEnm8ccfN7/85S9Nhw4djCTz+9//3rfMli1bjCSTkJBg9u/f3+h+f3qMAnHXXXcZSWbYsGGu1gd+iiy7y3L9cevbt2+jY3r99deNJHPcccc1WifLCDay7P51+VCvvPKKb6y/+MUvfH+7DVlGMJHj4OS4/jiOGDGiyWWNIcf1aKrDID8/32RmZppOnToZSSY1NdVkZmb6bsYY4/V6za9+9SvfE0GHDh1M165dff/Ozs42+/bt89tufeivueYak5aW5ls2PT3dxMXFGUlmwIABfk8YZWVlpm/fvr5lO3fubDIyMkx8fLyRZNq1a2fWrl3rt5+Kigrzs5/9zLdOx44dTfv27X3/HjNmjKmtrfVbpz709be0tDSTmZlp5s2b51vmpZdeMpLMkCFDrMeufv0BAwaYY445xpx00knmyiuvNC+++KKpqKhwPO5/+9vfjCSTnJxsqqurHZ
cFmoMsu8vy7NmzjSQzadKkRusVFRUmKSnJSDIlJSUN6mQZwUaW3b8u1yspKTHdunUzksxLL73k14jYkGUEEzk+/BwbE3hTTY7r0FSH0YIFC4wkk5ub26A2f/58X6Bef/1136dTn376qTnuuOOMJPPrX//ab5360NffrrjiCvP5558bY4zZvHmzyczMbPACN2fOHCPJnHzyyaa4uNh3/65du0xOTo6RZLKysvw+HbvjjjuMJNOpUyezZMkSY4wxNTU1ZsGCBSY5OdlIMg899JDf2OpDn5mZad57771GP2275557jCQzYcIE6zFz+gSyd+/eJj8/37pucXGxb9n64wIEA1n211SWr7vuOiPJPP3008aYujfjPXv2NK+88opvmcGDBxtJ5uOPP26wPllGqJBlf815Xa53yy23GEnmnHPO8fvbnZpqsoxQIMf+AsmxMYE31eS4Dk11GDmFfujQoUaSefnllxvU1qxZ4wvdoZ8I1Yfe4/H43qwe6uGHHzaSzJQpU3z3TZw40Uh1Xxv5qYqKCnPUUUcZSaagoMAYU/dVljZt2hiPx2M++uijBussXLjQSDIpKSlm9+7dvvvrQ//OO+9Yj8eNN95oJJkHHnjAuswTTzxh1q1bZ3bu3GmqqqpMSUmJWbp0qRk5cqTvE8fCwsJG162qqvKFfunSpdZ9AIEiy/6ayvLFF19sJJnXX3/dGGPMm2++aSSZnJwc6zKHIssIFbLsrzmvy8YY8/HHHxuPx2OSkpLMl19+6fe3OzXVZBmhQI79NTfH9QJtqslxHS5UFoEqKir0z3/+U/Hx8br88ssb1IcOHaqsrCzt3r1bGzZsaFCfOHGiJkyY0OD+o48+WpJUXl7uu6979+6SpPXr18vr9fotn5ycrBdeeEF///vf1bdvX0nS0qVLVVVVpXPPPVfDhw9vsI+cnBydfvrpOnDggJYuXdqgfuhUVz/1n//8R5KUnp5uXea2227TkCFD1LVrVyUmJuqII47Q6NGj9f7772vs2LEqKyuzTkKfmJioDh06SKq7yiEQamS58Szv379fkpSamirpxzzu2bPHt0z79u0lSfv27WuwPllGSyPL9tfl6upqjR8/XsYY3Xvvverfv7912Z8iy2hJ5Nie48NBjuvQVEegb7/9VrW1teratavatm3b6DK9e/eWJH3xxRcNakOHDnXcvjlk7rmbb75Z3bp10yuvvKKsrCzNmDFDH3zwgQ4cOCBJuuCCC3TJJZeoU6dOkqSNGzdKks4991zr9kePHi1JjT4hOfnpG+1AeDwePfbYY4qLi9M777yjioqKRpdr166dJKmqqirgfQCBIsuNZzk+Pt5v/B07dpQkHXHEEb5l6t+E2K4gTJbRksiy/XX5scce08aNG9W/f/9mz2N9KLKMlkKOA39/3VzkmKY6ItXPzfrdd9/J4/E0env99dcl/fjpk1tdunTRRx99pGHDhmnTpk2aPn26zjvvPHXp0kXXXnutvvrqK7/l688kOV1uv/4Tu5KSkoDGUv+GurS0NKD16h155JEaMGCAKisr9e233za6TP0naLYnUyCYyHLjWa5/Ya9/c3HmmWeqZ8+e+uUvf+lbpv5NQP0L9U+RZbQkstx4ljdv3uybHu/pp59WUlJSQNuXyDJaDjl29/66OcixlBDuAcAuPj7e9/URm4SEw/8vHDBggJYvX66vvvpK7733nlauXKl3331XL7/8shYtWqSFCxfqwgsvlPTjWaPmfBLV1ByVP1X/aV2gTxaHcnriOHjwoO9NfP2+gJZAlv3Vv2n47rvvJNXltqioyG+ZrVu3+i17KLKMcCHL/iZNmqSKigqNGzfO8QybDVlGOJDj4CLHdWiqI1B9Y9ipUydt2bKlxfbbv39/9e/fX7feeqsOHDigBx54QI899pjGjRunbdu2KTU11fd7jJ++AT5U/RvlQIPVuXNnv/Xd2LlzpyQpLS3NOi5JyszMdL0PoLnIcuNZrv/NZWNfr5PqXqA///
xzSdKxxx5rHZdEltEyyHLDLK9du9b32853331XvXr18qvXn+m7++679eCDD+rRRx9VTk5Oo+OSyDJCjxy7f3/thBzX4evfEahv375q27atfvjhh5B9qlTv9ttv18UXX+x3cQVJSklJ0aOPPqrs7Gzt2bNHy5cvlySddNJJkqRVq1ZZt/nhhx9KkrKysgIay8knnyxJWrNmTUDr1SssLNQ333yjdu3aNXqhlIKCAkl1DXf9b2aAUCLLjWf59NNPlyQtW7as0fp7772nyspKnXTSSb6LnxyKLKOlkeWGWa7/iYYk7dq1S0VFRX63+vru3btVVFTU6EUHyTJaEjl29/66KeS4Dk11BEpISND5558vY4yef/75Rpe5+eabNXbsWH3yySeHta9169bpzTffbPRKgpKUkZEh6ccXz5///OdKTk7WihUrtGLFigbLL1u2THl5eWrbtq3OP//8gMYybNgweTweFRYWBvy7j6qqKk2aNEmSdOWVVyoxMbHBMh9//LGkujf0gX51BnCDLDee5VNPPVVHH320NmzYoJUrV/rVvF6vZs+eLUn61a9+1ej2yTJaGllumOWzzz5bpm5q1kZvubm5kqQFCxbIGKPrrruuwfbJMloSOQ78/XVzkOM6NNVhVP97jcZ+P3HvvfcqPj5ev/nNb/TEE0/4rma9f/9+zZgxQ/PmzdNf//rXw76S39VXXy2p7ndR//jHP3z3G2P01ltv6e2335bH49GgQYMk1X11ZsqUKZLq3vDWf2pmjNEbb7zhm6Lgjjvu8LuSb3N07txZxx9/vIwxev/99xvUb7/9dj311FPavHmz78rAXq9XK1as0Nlnn62PPvpInTt31kMPPdRgXWOM3nvvPUnSRRddFNC4gKaQZX9NZTkuLk7333+/JOmaa67Rp59+KqnuIjI33XSTCgoKdNRRR+nWW29tsC5ZRiiRZX9NZflwkGWECjn2R45bSKgnwoZdQUGBkWTi4uJMZmam71bvqaeeMnFxcUaSadOmjcnIyDCJiYm+CejnzJnjt736yekXLFjQ6P4WLFhgJJnc3FzffZWVleass87yTdresWNH0717d5OSkuK7b/LkyX7bqaioMOecc46vnp6ebtLS0nz/Pv/8883Bgwf91qmfnH7ZsmWOx+SRRx4xkszo0aMb1E466STfPhITE01GRobfOI888kizZs2aRrf7/vvvG0kmKSnJ/PDDD45jAAJFlhtyyrIxxni9XnPppZf69tW1a1ffMUlNTTWrVq1qdD2yjFAiyw01lWWbpv52soxQIccNBZLjZcuWGUlmxIgRTS5Ljn9EUx1m9913n+ncubMvMD/9nGPNmjXmqquuMkcddZRp06aN6d69u7nkkktMXl5eg225Cb0xxuzfv988+OCD5sQTTzSpqakmISHBdOrUyQwbNsw899xzpra2tsG2qqqqzOOPP24GDx5sUlJSTLt27czQoUPNU089ZWpqahos39zQl5SUmNTUVBMXF2c2b97sV8vLyzNTpkwxgwcPNl27djUJCQkmPT3dnHnmmWb27Nlm79691u1efvnlRpK58cYbHfcPuEWW/TlluV51dbV58MEHzdFHH20kmbZt25pRo0aZf/3rX9btkmWEGln215wsN6apv50sI5TIsb9AchxIU02Of+Qx5pCZyoEIcPvtt+uPf/yjxo4dq1deeeWwt5efn68zzzxTSUlJ+vLLL9WzZ88gjBJAU8gyEBvIMhD9yHFo0VQj4pSWlurEE0/Utm3b9NFHH2n48OGut1VbW6shQ4Zow4YN+u1vf6t77rkniCMF4IQsA7GBLAPRjxyHFhcqQ8RJS0vT/PnzlZqaetjzCH7//fcyxmjkyJG6++67gzNAAM1CloHYQJaB6EeOQ4sz1YhYJSUlSk9PP+zt1NTU6MCBA43Odwsg9MgyEBvIMhD9yHFo0FQDAAAAAOBSyL/+XVVVpXvuuUedO3eWx+NRRkaGZs6cqdra2lDvGkCQkGMgNp
BlIDaQZSCyhPxM9eWXX67Fixdr1KhROuWUU7R8+XKtXLlS48aN00svvdSsbXi9Xu3YsUPt27eXx+MJ5XCBqGWMUXl5uTIyMhQXF9zPy8gx0DJCmWOJLAMtJdKzTI6B5ml2lkM5X9eiRYuMJHPLLbf47vN6vWbMmDFGklm8eHGztrNt2za/eea4ceNmv23bto0cc+MW5bdg55gsc+MWnlukZpkcc+MW2K2pLIf0TPWFF16od955R0VFRerRo4fv/h07digzM1MjR47U0qVLm9xOaWmpOnbsqLN0gRKUGKrhAlGtRtVaqbe1d+9epaWlBW275BhoOaHKsUSWgZYU6Vkmx0DzNDfLCaEcREFBgQYPHuwXeEnKyMhQdna2Vq1aJWNMk187qa8nKFEJHoIPNOp/Px4L9te4yDHQgkKUY4ksAy0qwrNMjoFmamaWQ3ahstLSUpWUlKhv376SpNWrV6tbt27aunWrJGngwIEqLy/Xzp07G6xbWVmpsrIyvxuAlkeOgdhAloHY4DbL5BgIrZA11fv27ZMkdenSRZKUl5enXbt2KT8/X5J886OVl5c3WHf27NlKS0vz3X76SRyAlkGOgdhAloHY4DbL5BgIrZA11fVXR0tIaPwb5vU/5W7sVPq0adNUWlrqu23bti1UwwTggBwDsYEsA7HBbZbJMRBaIftNdYcOHSTVfU1FkkaMGKEuXbooOztbklRSUiJJjf7gOykpSUlJSaEaGoBmIsdAbCDLQGxwm2VyDIRWyJrq1NRUdevWTV9//bUkKTs7W7t27fLVN27cqI4dO/q+vgIg8pBjIDaQZSA2kGUgMoXs69+SdMYZZ2jdunUNLoawY8cOrV69WsOGDQvl7gEEATkGYgNZBmIDWQYiT0ib6tzcXB08eFCzZs3y3ef1ejV16lTV1NRo/Pjxodw9gCAgx0BsIMtAbCDLQOQJ6TzVl1xyiS699FI98sgjKiwsVFZWlj788EPl5+crNzdXF154YSh3DyAIyDEQG8gyEBvIMhB5QnqmWpJeffVV3XXXXfr44481c+ZMbd68WQ899JDmzZsX6l0DCBJyDMQGsgzEBrIMRBaPqb/2fgQrKytTWlqaztYlSvAkhns4QESqMdXK0+sqLS31XR00kpBjoGmRnmOJLAPNEelZJsdA8zQ3yyE/Uw0AAAAAQKyiqQYAAAAAwCWaagAAAAAAXKKpBgAAAADAJZpqAAAAAABcoqkGAAAAAMAlmmoAAAAAAFyiqQYAAAAAwCWaagAAAAAAXKKpBgAAAADAJZpqAAAAAABcoqkGAAAAAMAlmmoAAAAAAFyiqQYAAAAAwCWaagAAAAAAXEoI9wAAAJGr5twh1tp3kyqttQ2n/8laOyk/13GfGXPbWGvxy/7puC4AAEBL40w1AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgElNqQZ4E+8MgvkvnkOzzy7t7WWu1KV5rLbPPLsftpkzyWGvfP26fpuefp/w/a+0/tfsd93nawrustb5TCxzXBcLNO2KwY/2P85+01vom2p877CmWPjl9geM+vzyl1lr7P72yHdcFEB3255xmrT3yu6ettf++8lrH7Zp1G12PCWiNvvm/pzvWPx9rfx+Q6Im31oZPusVaa/v3NU0PLMpwphoAAAAAAJdoqgEAAAAAcImmGgAAAAAAl2iqAQAAAABwiaYaAAAAAACXaKoBAAAAAHCJKbUiTPxx/aw1k5Rore0Y0dFxuxXZ9mmh0tPstRUn2aeaCod3DrR3rD/y5PnW2uoT/2ytba6usNZ+u/M8x31mrDCOdSDcqkedYq39+qmXHdftn2ifis7rMHHWt9XV1lqpN8lxn4MdypU/H2qttV32mbXmPXjQcZ+IfBWXnGqvdbJP65I+Pz8Uw8Fh2nWK/bzOf2/5RQuOBIh93995hrWWd9XvHNetNvb3AY5a2dtjzlQDAAAAAOASTTUAAAAAAC
7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOAS81SHQe3ZJ1trj78411pzmi82llSbWmvtgTnXOa6bsN8+Kd7pC2+11tpvr7HWkv5jn8NaklLWrXasA8ES36GDtbZ/+LHW2p2/t8/Rfk7bfU3s1d1nry/usc+J+Y+nTndc9+MH/2itvf/CM9ba8f9jz/gx9zBXcbTbMdz+WEzps9e+4vzgjwXNFGefP9z0tL+2juz6hbX2D4/9uQVA4/b18Fpr6XGto78INc5UAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALjGlVhgkfbnDWlt/sIe11j9xZyiG49pd32U71r/d19lae7HPImut1GufFqvbH1c1PbAgs48GaFnFLx1tra0dap+OLxxmdF1rrS1t5zwlzvVbRllrf+r1gbXW4fjdTQ8MUeuhixZaa498bn/MIHzi+2Raa1+MsM91NmjNNdZaxtrPDmtMQKzad8Vp1tprlz7hsKbHcbvP7LVP2fnBladYa6lFm6w1+wRf0Ysz1QAAAAAAuERTDQAAAACASzTVAAAAAAC4RFMNAAAAAIBLNNUAAAAAALhEUw0AAAAAgEtMqRUGNd99b63NeeQKa23m+futtfh/tXPc54ZJc5oeWCMe/k+Wtfb1z1Ic163d+521Nvb0Sdbaltvs2+ytDY77BKJZzblDHOuvDnrSWotTG1f7vL5opGN93QfHWWuf3Wgfz7KKZGut67oKx31+vcc+fUfirGXWWpzzrCCIcomemnAPAQFKeOGAq/UqvukQ5JEAseHgRadaa9Nn26ep65/o/gXyT8+fb60dWdjyU91GKs5UAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4JLrpnrJkiXyeDxasmRJo/WysjLdfPPN6tChgzwej/r06aNnn33W9UABhAZZBqIfOQZiA1kGolPAU2rt2LFDxcXFuvvuu63LVFVVafTo0SooKFBOTo769eunt956SxMmTND27ds1Y8aMwxp0LEtfkG+tdXmzk7VWu7vEcbsnDLzBWts03H4J/jeeG2Gtdd3r/jL6nnz71Fi97YcAQUSWw8M7YrC19sf59imqJKlvov0p2yuvtXbxF5daa/E59qn6JKnjhcZaO/7lW621/nO3WWtx2z5x3OcRK+y16pm11tprWfbnshvOcZirT1L8sn861iNVrOXYe9Yga21Y8sqWGwiColfqblfr9fjAnvNYFWtZRmh8d81Ba+2ctvaaFG+t5G75meM+j3yCabOaI+Az1ddee61OO+00ffnll9Zl5syZo4KCAs2aNUsLFy7UrFmztHbtWg0fPlwzZ87U+vXrD2vQAA4fWQaiHzkGYgNZBqJbwE319OnTtXDhQk2ePNm6zPz589WuXTvdeeedvvvatGmjRx99VF6vV88//7y70QIIGrIMRD9yDMQGsgxEt4C//j1s2DBJ0r59+xqtl5WVqbCwUJdddpmSk5P9akOHDlX37t21ciVf4QLCjSwD0Y8cA7GBLAPRLehX/968ebMkqW/fvpKk1157Tb169dL+/XW/2xs4cKC+/vprx21UVlaqrKzM7wagZR1ulskxEH68JgOxgddkILIFvamu/4StS5cukqQPPvhARUVF2rhxoyQpPT1dlZWVqq6utm5j9uzZSktL89169OgR7GECaMLhZpkcA+HHazIQG3hNBiJb0JvquLi6TSYkNP7NcmPqriTr8Xis25g2bZpKS0t9t23b7FeRBRAah5tlcgyEH6/JQGzgNRmIbAH/propaWlpkqTS0lJJ0siRI/X2229r4MCBkqSSkhKlpKRYnxQkKSkpSUlJScEeGoAAHG6WyTEQfrwmA7GB12QgsgW9qe7du7c8Ho/vdx05OTnKycnx1Tdt2qR+/foFe7etQu1/3M33KEnVZW1crXfC1YXW2g9P2+e8ky
R5W988k7GELLvnGXKCtfafqRXWWv9E55yur7TXPtx3vLW2+y/2r/l12uM8MXza/xTYaw7r1ThuNTS6xdvfMO6+44Djul2XBXs0kSHaclx0UVtrrWt8SguOBM2R0KunYz0n/Q1X2227eY+11lrfWURbluFOQvejHeubhi2w1qqNPR2f23/ho62P93fcZ6pWO9ZRJ+hf/27btq0GDRqkvLw8eb1ev9ratWtVXFys4cOHB3u3AIKMLAPRjxwDsYEsA5Et6E21JOXm5qq4uFjPPfec777KykpNnTpV8fHxuummm0KxWwBBRpaB6EeOgdhAloHIFfSvf0vSxIkT9eqrr2ry5MlasWKFMjMz9cYbb2jTpk2aPn26srKyQrFbAEFGloHoR46B2ECWgcgVkqa6TZs2evfdd3XnnXdq0aJFKi8vV69evTR37lxNmjQpFLsEEAJkGYh+5BiIDWQZiFyum+rrrrtO1113nbWelpam+fPna/78+W53AaAFkGUg+pFjIDaQZSA6heQ31QAAAAAAtAYh+fo3Is9x93xlrV1/4khrbUHmP6y1EVdMdtxn+/9nn4oHiHZxKfbpfWp+V2atFRy72FrbXFPluM+p991lrR2xYqu11jV1l7XWWqanOfWoIsf6lpYZBpqQ0Lfc1XoHv+gY3IGgWbb9IdWxfmaS11qbV9bdvuJe+3MoEO3iTxhgrZ3y540h2edVi2+z1vq8xvv1YOBMNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJTarUStXtLrbXdE4+z1ra+UWGt3fvwS477nHblpdaa+STNWusxM9++UWMc9wm0lIoRJ1hr7x77lKtt3nT7nY719n+3T3tR42qPQGzous4+dRPqxHfuZK3tvLy/tZZ+ZbG19lH/eU3sNdlaeXruL621rjtXNbFdIHoVXWzP4qJOnzSxdry1MvabX1hr/X/7jbXWWqbWDDXOVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC4xpRbk3fC5tfarh/6PtfbK9Ecdt/tptsOUW9n20gmpt1pr/Z7/znGfNd9ucawDwZL1359aa3EOn1deXzTSWmv79zWHM6RWIdFjn06k2mHGvXgP0/HFsop0e+ZSQ7RP77DB1pqJ9ziuu+1nSdZaVUa1tRbXxj75zXvD5jjuM9FhSN/X2sfzm2/t02OWeJ2nMkuJs4+32+pya420ItqVXH+6tfa3Cf/XYc1Ex+1O2DbCWqvOtee49oetjtvF4eNMNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALjFPNRylz8+31m79crLjuh1+W2ytvXrMu9bapmuftNaO7XGT4z4HPGT/nKj23986rgscau84+xyTknR/N/s87V61sdbWv3e8tdZTq5oeWCtXbezz3nplnzN36ef24y5J/fRP12NC8FQetM/R6nWYvXjBfb+31t64ddDhDMnqnk4vWGtxcp6nusJUWWs7au2P8Sd/ONta+9kHdzjus+Mn9uelo97baa15iuyv5T983tZxn93i7XNum7WfOa4LRLr4EwZYa6setr+XlZJd7zO/uJe11mPLRtfbxeHjTDUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASU2rBNc/HnzrWD+R0tdaGXjXFWlt9zxPW2hfn2KcwkaSre42y1krPclwV8FPjPFOM0uLs09PkH0yy1o55aYd9n02OKjbEpaQ41r94dKBDdb21cvW3P7fWjr19s+M+7ZMYoSX1veYTa+2E2bdaaz2Gbg/FcBwt29XfWvvhne6O63baZJ9qqs3StQ5r2tfrr3WO+3Ti9Pjffs8Z1trQJPu0m5L0l31HuxwREPm+us/+WuY0/ePh6Plbe8
0+6SBaAmeqAQAAAABwiaYaAAAAAACXaKoBAAAAAHCJphoAAAAAAJdoqgEAAAAAcImmGgAAAAAAl5hSCyFTu3OXtdbtj/bawV/bJxZK8dinMZKk53stsdYuuvQO+3b/ttpxu0Agdte2s9Zqvt3ScgMJI6dps7787YmO635xyZPW2jsH0qy1HXP7Wmvt9xQ47hORr/c05+mbIslR2hruIQRNyvAfXK97/7LLrbX+WuN6u0BL8I4Y7Fh/+JS/B32f5238lWO93bqNQd8ngoMz1QAAAAAAuERTDQAAAACASzTVAAAAAAC4RFMNAAAAAIBLNNUAAAAAALhEUw0AAAAAgEtMqQXXvGcNcqx/c0WytTZw0BZrralps5zMKbFPf5Dy+jrX2wUCcffHV1hr/bW+BUcSWk7TjeyaWmGtfX6KfcosSRr52VXWWur531pr7cW0WUAkyXzdhHsIgGszX3zOsT4w0d3j++7vhltraWP2OK5b62qPaAmcqQYAAAAAwCWaagAAAAAAXKKpBgAAAADAJZpqAAAAAABcoqkGAAAAAMAlmmoAAAAAAFyiqQYAAAAAwCXmqYY8pwy01r66zT5n9PNn/slxu8OTq1yPyabSVDvWC0p624ve74I8GsQ0j3M5zuEzySfOetVam6v+bkcUFkUzTrfWXrv2cWutf6L9uePkNbmO+8y4tLDpgQEAEEKD2zife6w27maNzl9wsrXWdc8qV9tE+HGmGgAAAAAAl2iqAQAAAABwiaYaAAAAAACXaKoBAAAAAHCJphoAAAAAAJdcNdUbNmzQRRddpHbt2ikuLk69e/fWlClTVFxc7LdcWVmZbr75ZnXo0EEej0d9+vTRs88+G5SBAzg85BiIDWQZiA1kGYheAU+ptWLFCo0aNUq1tbW64oorlJmZqX/961+aO3euXn31Va1Zs0bHHHOMqqqqNHr0aBUUFCgnJ0f9+vXTW2+9pQkTJmj79u2aMWNGKP6eViuhd6Zj/ZvrM6y1B6/6i7V2ebv/uB6TW/ftPMVa++iJbMd1j/hTfrCHE5PIcTMY57JXXmttRNvd1todLw6x1vossG9TkhK/L7fWdo7oYq2lX1VsrU3p+Q/Hff48Zb219sb+btbatZ+db611fjbVcZ9oPrKMcIr3OJ+b2dM/0Vo78p1gjya6keXw2LbIPq1soufTkOzzqDz7e2t3k3QhEgR8pnrixImSpJUrV+qVV17RrFmztGTJEr300kvavXu3nnjiCUnSnDlzVFBQoFmzZmnhwoWaNWuW1q5dq+HDh2vmzJlav97+Rg1AaJFjIDaQZSA2kGUgugXUVJeUlKi4uFhXX321Tj31VL/amDFj1KZNG3355ZeSpPnz56tdu3a68847fcu0adNGjz76qLxer55//vkgDB9AoMgxEBvIMhAbyDIQ/QL6+nd6err27t3baG3v3r2qqqpSRkaGysrKVFhYqMsuu0zJycl+yw0dOlTdu3fXypUrrfuprKxUZWWl799lZWWBDBOAA3IMxAayDMSGlsgyOQZCK2hX/3744Ycl1X2itnnzZklS3759JUmvvfaaevXqpf3790uSBg4cqK+//tq6rdmzZystLc1369GjR7CGCcABOQZiA1kGYkOwskyOgdAKSlP92GOP6Q9/+IOuvfZanXfeedq3b58kqUuXugvnfPDBByoqKtLGjRsl1X0iV1lZqerq6ka3N23aNJWWlvpu27ZtC8YwATggx0BsIMtAbAhmlskxEFoBX/37UHv27NH48eO1cOFCjRs3TvPmzZMkxcXV9eoJCY1v3pi6y+p6PJ5G60lJSUpKSjqcoQFoJnIMxAayDMSGUGSZHAOh5bqpzsvL07hx47Rnzx7NmzdPN9xwg6+WlpYmSSotLZUkjRw5Um+//bYGDqy7bH1JSYlSUlKsTwqtWUKvntZa6ZCjrLWrZix13O6Ejotdj8mtu76zT3+V/5R92qz0F9dYa0d4mTIrmMhxaCR77Mfk8/OesdZWDk
u21iTp35VHWmvXp21pclxu3L5jmLW2dNUga63f7QUhGA1syDLCodY4TwMYvB8Zth5kOfi8IwZba38Y9D/WWrVxnuCq1HvQWhv6zh3W2rFFhY7bRXRy9XS3ePFijRo1Sp07d9ann37qF3hJ6t27tzwej+93HTk5OSoqKlJqat3cpJs2bVK/fv0Oc+gADgc5BmIDWQZiA1kGolfATfWGDRs0ZswYZWdna/ny5b6LJRyqbdu2GjRokPLy8uT1+n+KuXbtWhUXF2v48OHuRw3gsJBjIDaQZSA2kGUgugXcVE+YMEEdO3bU4sWL1b59e+tyubm5Ki4u1nPPPee7r7KyUlOnTlV8fLxuuukmdyMGcNjIMRAbyDIQG8gyEN0C+tFFYWGhCgoKNGTIED3zjP13gffff78mTpyoV199VZMnT9aKFSuUmZmpN954Q5s2bdL06dOVlZV12IMHEDhyDMQGsgzEBrIMRL+Amupdu3ZJktavX6/169dbl7v//vvVpk0bvfvuu7rzzju1aNEilZeXq1evXpo7d64mTZp0eKMG4Bo5BmIDWQZiA1kGol9ATfXZZ5/tu1x/c6SlpWn+/PmaP39+wAMDEBrkGIgNZBmIDWQZiH5ccz8EEo6yT3sjSSXzU621ib0/stbGtN/pekxu3br9LGvtn08Pcly386KN1lp6OVNjIbJ1y9vlWL9n/OnW2iNHunt8D0+ucqyflbzF1XY/qbRfPmPMR7c4rtv/evtZk35i2iwAdgeGHgj3EAAdTG9jrZ2VvN9hzXjH7b57wD4Nbv9b1lprTUxEhyjFDIIAAAAAALhEUw0AAAAAgEs01QAAAAAAuERTDQAAAACASzTVAAAAAAC4RFMNAAAAAIBLNNUAAAAAALjEPNUOqkafYq/dWWKt3df3bcftjmrrNCdeaOysrbDWhr9xl7V27P1fWGvpe53n4mUePkSz2q++caz/+4pe1trxU6ZYa4VXznE7JEfHvj3JWhvwlH2u2P6f2OehBgAn8R7OzQCAxJlqAAAAAABco6kGAAAAAMAlmmoAAAAAAFyiqQYAAAAAwCWaagAAAAAAXKKpBgAAAADAJabUcrDll/bPHL46cWFI9jl3bx9r7YmPRllrnlqP43aPfXiztdZv52prrdZxq0DrVfPtFmut75322sV3Dg3+YCT111przYRkjwBag8oPulhrtYOYPBORr8On31trU4rPtdae6fFRKIaDGMWZagAAAAAAXKKpBgAAAADAJZpqAAAAAABcoqkGAAAAAMAlmmoAAAAAAFyiqQYAAAAAwCWm1HLQf+Iaa+2iiUNacCR1+ss+nqYwNRYAAAjUkb9fZa1d8PuTHdc9Rp8GeTRA4Go2F1lrxdn29S5Sy7/XR/TiTDUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4BJNNQAAAAAALtFUAwAAAADgEk01AAAAAAAu0VQDAAAAAOASTTUAAAAAAC7RVAMAAAAA4FJCuAfQHMYYSVKNqiUT5sEAEapG1ZJ+zEukIcdA0yI9xxJZBpoj0rNMjoHmaW6Wo6KpLi8vlySt1NthHgkQ+crLy5WWlhbuYTRAjoHmi9QcS2QZCESkZpkcA4FpKsseE6kfoR3C6/Vqx44dat++vcrLy9WjRw9t27ZNHTp0CPfQIk5ZWRnHpwmxeoyMMSovL1dGRobi4iLvlx2H5tjj8cTs/0OwcHycxerxifQcS7wmBypWH6vBEqvHJ9KzzGtyYDg+zmL5+DQ3y1FxpjouLk7du3eXJHk8HklShw4dYu4/LZg4Pk2LxWMUiZ+G1zs0x4eKxf+HYOL4OIvF4xPJOZZ4TXaLY+QsFo9PJGeZ12R3OD7OYvX4NCfLkffRGQAAAAAAUYKmGg
AAAAAAl6KuqU5KStL06dOVlJQU7qFEJI5P0zhGkYH/B2ccH2ccn8jA/0PTOEbOOD6Rgf8HZxwfZxyfKLlQGQAAAAAAkSjqzlQDAAAAABApaKoBAAAAAHCJphoAAAAAAJdoqgEAAAAAcImmGgAAAAAAl6Kqqa6qqtI999yjzp07y+PxKCMjQzNnzlRtbW24hxYWS5Yskcfj0ZIlSxqtl5WV6eabb1aHDh3k8XjUp08fPfvssy08ypa3YcMGXXTRRWrXrp3i4uLUu3dvTZkyRcXFxX7LtdbjE27k2B85tiPLkY0s+yPLdmQ5cpHjhshy48hxE0wUueyyy4wkM2rUKHPfffeZs846y0gy48aNC/fQWtT27dvN6tWrzYABA4wk8+abbzZYprKy0mRnZxtJJicnx0ybNs1kZWUZSeY3v/lNGEbdMpYvX26Sk5NNYmKiGTt2rJk2bZq58MILjcfjMZ06dTLffPONMab1Hp9IQI7rkGNnZDnykeU6ZNkZWY5s5PhHZNmOHDctaprqRYsWGUnmlltu8d3n9XrNmDFjjCSzePHiMI6uZY0cOdJI8t0aC/2jjz5qJJlZs2b57qusrDTDhw83cXFxZt26dS055BZzwgknmOTkZLN69Wq/+19++WUjydx2223GmNZ7fMKNHP+IHDsjy5GNLP+ILDsjy5GLHPsjy3bkuGlR01RfcMEFxuPxmK1bt/rdv337dpOQkGBGjx4dppG1vOXLl5uFCxeayZMnW0N//PHHm3bt2pmKigq/+9esWWMkmfHjx7fUcFvM7t27TVpamrnxxhsb1GpqakybNm18j5PWeHwiATn+ETm2I8uRjyz/iCzbkeXIRo79keXGkePmSXD/xfGWVVBQoMGDB6tHjx5+92dkZCg7O1urVq2SMUYejydMI2w5w4YNkyTt27ev0XpZWZkKCwt12WWXKTk52a82dOhQde/eXStXrgz5OFtaenq69u7d22ht7969qqqqUkZGRqs9PpGAHP+IHNuR5chHln9Elu3IcmQjx/7IcuPIcfNExYXKSktLVVJSor59+0qSVq9erW7dumnr1q2SpIEDB6q8vFw7d+4M5zAjxubNmyXJd7xee+019erVS/v375dUd7y+/vrrsI0vHB5++GFJ0pgxYzg+YUKOA8PjtHFkOfzIcmB4nDaOLIcXOQ4cj9OGyPGPoqKprv/EqEuXLpKkvLw87dq1S/n5+ZLqPkGRpPLy8vAMMML89Hh98MEHKioq0saNGyXVHa/KykpVV1eHbYwt6bHHHtMf/vAHXXvttTrvvPM4PmFCjgPD47QhshwZyHJgeJw2RJbDjxwHjsepP3LsLyq+/h0XV9f7JyQ0PlxjjCS1mq+nNIXjVWfPnj0aP368Fi5cqHHjxmnevHmSOD7hwnEPDMfrR2Q5snDcA8Px+hFZjhwc88BxzOqQ48ZFxZnqDh06SKr7qookjRgxQl26dFF2drYkqaSkRJKUlpYWngFGmPrjUH+8Ro4cqZ49e2rgwIGS6o5XSkqK9UEfC/Ly8pSVlaW3335b8+bN00svveT7ezk+4UGOA8PjtA5ZjjxkOTA8TuuQ5chCjgPH45QcO4mKpjo1NVXdunXzfQ8/Oztbu3btUmZmpiRp48aN6tixo+/rBq1d79695fF4fMcrJydHRUVFSk1NlSRt2rRJ/fr1C+cQQ2rx4sUaNWqUOnfurE8//VQ33HCDX721H59wIceB4XFKliMVWQ4Mj1OyHInIceBa++OUHDuLiqZaks444wytW7dOZWVlfvfv2LFDq1ev9l2xD1Lbtm01aNAg5eXlyev1+tXWrl2r4uJiDR8+PEyjC60NGzZozJgxys7O1vLly30XSzhUaz4+4UaOm6+1P07JcmQjy83X2h+nZDlykePAtObHKTluWtQ01bm5uTp48KBmzZrlu8/r9Wrq1KmqqanR+PHjwzi6yJObm6vi4mI999xzvvsqKys1depUxcfH66abbgrj6EJnwoQJ6t
ixoxYvXqz27dtbl2utxyfcyHFgWvPjlCxHNrIcmNb8OCXLkYscB661Pk7JcdOi5kvtl1xyiS699FI98sgjKiwsVFZWlj788EPl5+crNzdXF154YbiHGFEmTpyoV199VZMnT9aKFSuUmZmpN954Q5s2bdL06dOVlZUV7iEGXWFhoQoKCjRkyBA988wz1uXuv//+Vnl8IgE5DkxrfZyS5chHlgPTWh+nZDmykePAtcbHKTluJhNFDh48aO666y6Tnp5uJJkjjzzSPPTQQ6ampibcQwuLBQsWGEnmzTffbLS+d+9ec/3115v27dsbSaZXr15m7ty5LTzKlrNs2TIjqclbvdZ2fCIFOfZHjhsiy9GBLPsjyw2R5chHjhsiy/7IcfN4jPnf65sDAAAAAICARM1vqgEAAAAAiDQ01QAAAAAAuERTDQAAAACASzTVAAAAAAC4RFMNAAAAAIBLNNUAAAAAALhEUw0AAAAAgEs01QAAAAAAuERTDQAAAACASzTVAAAAAAC4RFMNAAAAAIBL/x/5UwG0lM3OxgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 1200x300 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import warnings\n",
"warnings.simplefilter('ignore')\n",
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from pylab import rcParams\n",
"rcParams['figure.figsize'] = 12, 3\n",
"rcParams['font.size'] = 14\n",
"rcParams['font.family'] = 'Ume Hy Gothic O5'\n",
"\n",
"# Per-image preprocessing pipeline, applied in this order:\n",
"# - convert to a Tensor (this also rescales the range from [0,255] to [0,1])\n",
"# - subtract 0.5 from each pixel and divide by 0.5 (rescales the range from [0,1] to [-1,1])\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
"                                transforms.Normalize((0.5,), (0.5,))])\n",
"\n",
"# The datasets are downloaded the first time this runs\n",
"root = '../data'\n",
"batch_size = 4\n",
"trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
"testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)\n",
"\n",
"print('◆ Data sizes (number of samples, height in pixels, width in pixels)')\n",
"print('Training data', trainset.data.shape)\n",
"print('Test data', testset.data.shape)\n",
"\n",
"print('\\n◆ Plot the first 4 training images')\n",
"fig = plt.figure()\n",
"ax1 = fig.add_subplot(1, 4, 1)\n",
"ax2 = fig.add_subplot(1, 4, 2)\n",
"ax3 = fig.add_subplot(1, 4, 3)\n",
"ax4 = fig.add_subplot(1, 4, 4)\n",
"ax1.imshow(trainset.data[0])\n",
"ax1.set_title(trainset.targets[0])\n",
"ax2.imshow(trainset.data[1])\n",
"ax2.set_title(trainset.targets[1])\n",
"ax3.imshow(trainset.data[2])\n",
"ax3.set_title(trainset.targets[2])\n",
"ax4.imshow(trainset.data[3])\n",
"ax4.set_title(trainset.targets[3])\n",
"plt.show()"
]
},
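{
"cell_type": "markdown",
"id": "note-pixel-normalization",
"metadata": {},
"source": [
"As a worked check of the preprocessing above (a sketch that assumes the documented behaviour of `ToTensor` and `Normalize`; the symbol $u$ for a raw 8-bit pixel value is introduced here for illustration), the two steps compose to\n",
"\n",
"$$x = \\frac{u}{255} \\in [0, 1], \\qquad x' = \\frac{x - 0.5}{0.5} = 2x - 1 \\in [-1, 1],$$\n",
"\n",
"so a black pixel ($u = 0$) becomes $-1$ and a white pixel ($u = 255$) becomes $+1$. Note that the images plotted above come from `trainset.data`, which holds the raw values before this transform is applied."
]
},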
{
"cell_type": "markdown",
"id": "890e83a4-c2a6-42a2-9c52-2823a295c9f2",
"metadata": {},
"source": [
"#### Identify the layers needed to reproduce the textbook's model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "947a4447-cf6f-4d1c-99a0-d99e91f1479d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ Work out a model that matches Figures 14.15 and 14.16 on page 474 of the textbook\n",
"Note: a dummy batch of size 4 with 1 channel and 28x28 pixels is passed through.\n",
"Right after input\n",
"torch.Size([4, 1, 28, 28])\n",
"After the 1st convolution\n",
"torch.Size([4, 6, 28, 28])\n",
"After the 1st pooling\n",
"torch.Size([4, 6, 14, 14])\n",
"After the 2nd convolution\n",
"torch.Size([4, 16, 10, 10])\n",
"After the 2nd pooling\n",
"torch.Size([4, 16, 5, 5])\n",
"After reshaping\n",
"torch.Size([4, 400])\n",
"After the 1st fully connected layer\n",
"torch.Size([4, 120])\n",
"After the 2nd fully connected layer\n",
"torch.Size([4, 84])\n",
"After the 3rd fully connected layer\n",
"torch.Size([4, 10])\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"print('◆ Work out a model that matches Figures 14.15 and 14.16 on page 474 of the textbook')\n",
"print(f'Note: a dummy batch of size {batch_size} with 1 channel and 28x28 pixels is passed through.')\n",
"x = torch.randn(batch_size, 1, 28, 28)\n",
"print(f'Right after input\\n{x.size()}')\n",
"x = F.relu(nn.Conv2d(1, 6, 5, padding=2)(x))\n",
"print(f'After the 1st convolution\\n{x.size()}')\n",
"x = nn.AvgPool2d(2)(x)\n",
"print(f'After the 1st pooling\\n{x.size()}')\n",
"x = F.relu(nn.Conv2d(6, 16, 5)(x))\n",
"print(f'After the 2nd convolution\\n{x.size()}')\n",
"x = nn.AvgPool2d(2)(x)\n",
"print(f'After the 2nd pooling\\n{x.size()}')\n",
"# Number of features per sample: the product of all dimensions except the batch dimension\n",
"size = x.size()[1:]\n",
"num_features = 1\n",
"for s in size:\n",
"    num_features *= s\n",
"x = x.view(-1, num_features)\n",
"print(f'After reshaping\\n{x.size()}')\n",
"x = F.relu(nn.Linear(num_features, 120)(x))\n",
"print(f'After the 1st fully connected layer\\n{x.size()}')\n",
"x = F.relu(nn.Linear(120, 84)(x))\n",
"print(f'After the 2nd fully connected layer\\n{x.size()}')\n",
"x = nn.Linear(84, 10)(x)\n",
"print(f'After the 3rd fully connected layer\\n{x.size()}')"
]
},
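{
"cell_type": "markdown",
"id": "note-output-size-rule",
"metadata": {},
"source": [
"The spatial sizes printed above can be checked by hand. A sketch, assuming PyTorch's defaults (stride 1 and padding 0 for `Conv2d` unless specified, dilation 1, and a pooling stride equal to the kernel size): with input size $n_{\\text{in}}$, kernel size $k$, padding $p$ and stride $s$ (notation introduced here), the output size along one spatial dimension is\n",
"\n",
"$$n_{\\text{out}} = \\left\\lfloor \\frac{n_{\\text{in}} + 2p - k}{s} \\right\\rfloor + 1.$$\n",
"\n",
"Applied layer by layer:\n",
"\n",
"$$\\begin{aligned}\n",
"\\text{conv1}\\ (k=5,\\ p=2,\\ s=1) &:\\quad \\lfloor (28 + 2 \\cdot 2 - 5)/1 \\rfloor + 1 = 28 \\\\\n",
"\\text{pool1}\\ (k=2,\\ s=2) &:\\quad \\lfloor (28 - 2)/2 \\rfloor + 1 = 14 \\\\\n",
"\\text{conv2}\\ (k=5,\\ p=0,\\ s=1) &:\\quad \\lfloor (14 - 5)/1 \\rfloor + 1 = 10 \\\\\n",
"\\text{pool2}\\ (k=2,\\ s=2) &:\\quad \\lfloor (10 - 2)/2 \\rfloor + 1 = 5 \\\\\n",
"\\text{flatten} &:\\quad 16 \\cdot 5 \\cdot 5 = 400\n",
"\\end{aligned}$$\n",
"\n",
"The last line is where the 400 input features of the first fully connected layer come from."
]
},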
{
"cell_type": "markdown",
"id": "25ec0a82-2747-452b-ab94-16b0dfca4454",
"metadata": {},
"source": [
"#### Implement the model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5bd5eee6-98e8-4439-b31f-bcb3061fa466",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ The network trained in this notebook\n",
"LeNet5_MNIST(\n",
" (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
")\n",
"\n",
"◆ The parameters trained in this notebook\n",
"conv1.weight torch.Size([6, 1, 5, 5])\n",
"conv1.bias torch.Size([6])\n",
"conv2.weight torch.Size([16, 6, 5, 5])\n",
"conv2.bias torch.Size([16])\n",
"fc1.weight torch.Size([120, 400])\n",
"fc1.bias torch.Size([120])\n",
"fc2.weight torch.Size([84, 120])\n",
"fc2.bias torch.Size([84])\n",
"fc3.weight torch.Size([10, 84])\n",
"fc3.bias torch.Size([10])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class LeNet5_MNIST(nn.Module):\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(400, 120)  # TODO: 400 = 16*5*5 hard-codes a 28x28-pixel input\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # conv -> ReLU -> 2x2 average pooling, twice\n",
"        x = F.avg_pool2d(F.relu(self.conv1(x)), 2)\n",
"        x = F.avg_pool2d(F.relu(self.conv2(x)), 2)\n",
"        # flatten to (batch, 400), then three fully connected layers\n",
"        x = x.view(-1, 400)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)  # raw logits; the loss applies the softmax\n",
"        return x\n",
"\n",
"net = LeNet5_MNIST()\n",
"\n",
"print('◆ The network trained in this notebook')\n",
"print(net)\n",
"print('\\n◆ The parameters trained in this notebook')\n",
"for name, param in net.named_parameters():\n",
"    print(name.ljust(14), param.size())"
]
},
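{
"cell_type": "markdown",
"id": "note-parameter-count",
"metadata": {},
"source": [
"The parameter shapes listed above can be tallied by hand. This is a worked count under the usual conventions, with notation introduced here: a convolution with $C_{\\text{in}}$ input channels, $C_{\\text{out}}$ output channels and a $k \\times k$ kernel has $C_{\\text{out}} C_{\\text{in}} k^2$ weights plus $C_{\\text{out}}$ biases, and a fully connected layer has $n_{\\text{out}} n_{\\text{in}}$ weights plus $n_{\\text{out}}$ biases.\n",
"\n",
"$$\\begin{aligned}\n",
"\\text{conv1} &:\\quad 6 \\cdot 1 \\cdot 5 \\cdot 5 + 6 = 156 \\\\\n",
"\\text{conv2} &:\\quad 16 \\cdot 6 \\cdot 5 \\cdot 5 + 16 = 2416 \\\\\n",
"\\text{fc1} &:\\quad 120 \\cdot 400 + 120 = 48120 \\\\\n",
"\\text{fc2} &:\\quad 84 \\cdot 120 + 84 = 10164 \\\\\n",
"\\text{fc3} &:\\quad 10 \\cdot 84 + 10 = 850 \\\\\n",
"\\text{total} &:\\quad 156 + 2416 + 48120 + 10164 + 850 = 61706\n",
"\\end{aligned}$$\n",
"\n",
"Roughly 78% of the parameters sit in `fc1`, the layer whose input size is hard-coded to 400."
]
},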
{
"cell_type": "markdown",
"id": "a3723955-316e-48fb-9df4-97d43e79a68a",
"metadata": {},
"source": [
"#### Train for one epoch"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "61a185f9-95b7-427d-a2d8-f77bf8d11d3d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trained 2000 batches, current loss per batch: 2.025\n",
"Trained 4000 batches, current loss per batch: 0.399\n",
"Trained 6000 batches, current loss per batch: 0.185\n",
"Trained 8000 batches, current loss per batch: 0.144\n",
"Trained 10000 batches, current loss per batch: 0.126\n",
"Trained 12000 batches, current loss per batch: 0.111\n",
"Trained 14000 batches, current loss per batch: 0.104\n",
"Accuracy on the 10000 test images: 97.94%\n"
]
}
],
"source": [
"import torch.optim as optim\n",
"\n",
"def train(model, trainloader, criterion, optimizer, print_interval=-1):\n",
"    running_loss = 0.0  # accumulated loss, used only for logging\n",
"    for i, data in enumerate(trainloader):\n",
"        inputs, labels = data\n",
"        optimizer.zero_grad()\n",
"        outputs = model(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()  # used only for logging\n",
"\n",
"        # Every print_interval batches, report the mean loss over that interval\n",
"        if print_interval > 0:\n",
"            if i % print_interval == (print_interval - 1):\n",
"                print(f'Trained {i + 1} batches, '\n",
"                      f'current loss per batch: {running_loss / print_interval:.3f}')\n",
"                running_loss = 0.0\n",
"\n",
"def test(model, testloader):\n",
"    correct = 0\n",
"    total = 0\n",
"    with torch.no_grad():\n",
"        for data in testloader:\n",
"            images, labels = data\n",
"            outputs = model(images)\n",
"            # The predicted class is the index of the largest logit\n",
"            _, predicted = torch.max(outputs, 1)\n",
"            total += labels.size(0)\n",
"            correct += (predicted == labels).sum().item()\n",
"    return correct / total\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"train(net, trainloader, criterion, optimizer, print_interval=2000)\n",
"accuracy = test(net, testloader)\n",
"print(f'Accuracy on the 10000 test images: {accuracy:.2%}')"
]
},
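{
"cell_type": "markdown",
"id": "note-loss-and-sgd-update",
"metadata": {},
"source": [
"For reference, the quantities in the training cell above can be written out as follows. This is a sketch based on the documented defaults of `nn.CrossEntropyLoss` (mean reduction over the batch) and `optim.SGD` (dampening 0, no weight decay, no Nesterov momentum); the symbols $z$, $y$, $g_t$, $v_t$, $\\theta_t$, $\\mu$ and $\\eta$ are notation introduced here. For logits $z \\in \\mathbb{R}^{10}$ and true label $y$, the per-sample loss is\n",
"\n",
"$$\\ell(z, y) = -\\log \\frac{\\exp(z_y)}{\\sum_{j=0}^{9} \\exp(z_j)},$$\n",
"\n",
"and with mini-batch gradient $g_t$, momentum $\\mu = 0.9$ and learning rate $\\eta = 0.001$ the parameter update is\n",
"\n",
"$$v_t = \\mu v_{t-1} + g_t, \\qquad \\theta_t = \\theta_{t-1} - \\eta v_t, \\qquad v_0 = 0.$$\n",
"\n",
"With 60000 training images and a batch size of 4, one epoch is $60000 / 4 = 15000$ batches, which is why the log stops at 14000: the last 1000 batches fall short of the next multiple of 2000."
]
},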
{
"cell_type": "code",
"execution_count": null,
"id": "8ebec1fa-1758-4acf-ac28-2acfa40ad8ff",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}