Skip to content

Instantly share code, notes, and snippets.

@n-taku
Created March 29, 2020 04:21
Show Gist options
  • Save n-taku/8464b2cfaf3bb7c50aa812cd5351338e to your computer and use it in GitHub Desktop.
Save n-taku/8464b2cfaf3bb7c50aa812cd5351338e to your computer and use it in GitHub Desktop.
BatchNormalizationのモデルサンプル1
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "CIFAR10BN_model1.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "3mgvuAy0flDD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
},
"outputId": "11f4171c-152f-47f7-8c00-74f9cbfc9f1f"
},
"source": [
"import torch\n",
"from torchsummary import summary\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(3, 64, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(64, 128, 5)\n",
"        self.fc1 = nn.Linear(128 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 10)\n",
"        self.bn1 = nn.BatchNorm2d(64)\n",
"        self.bn2 = nn.BatchNorm2d(128)\n",
"        self.bn3 = nn.BatchNorm1d(120)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv1(x)\n",
"        x = self.bn1(x)\n",
"        x = self.pool(F.relu(x))\n",
"        x = self.conv2(x)\n",
"        x = self.bn2(x)\n",
"        x = self.pool(F.relu(x))\n",
"        x = x.view(-1, 128 * 5 * 5)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.bn3(x)\n",
"        x = self.fc2(x)\n",
"        return x\n",
"\n",
"summary(Net(), (3, 32, 32))"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
"        Layer (type)               Output Shape         Param #\n",
"================================================================\n",
"            Conv2d-1           [-1, 64, 28, 28]           4,864\n",
"       BatchNorm2d-2           [-1, 64, 28, 28]             128\n",
"         MaxPool2d-3           [-1, 64, 14, 14]               0\n",
"            Conv2d-4          [-1, 128, 10, 10]         204,928\n",
"       BatchNorm2d-5          [-1, 128, 10, 10]             256\n",
"         MaxPool2d-6            [-1, 128, 5, 5]               0\n",
"            Linear-7                  [-1, 120]         384,120\n",
"       BatchNorm1d-8                  [-1, 120]             240\n",
"            Linear-9                   [-1, 10]           1,210\n",
"================================================================\n",
"Total params: 595,746\n",
"Trainable params: 595,746\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.01\n",
"Forward/backward pass size (MB): 1.08\n",
"Params size (MB): 2.27\n",
"Estimated Total Size (MB): 3.37\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment