Skip to content

Instantly share code, notes, and snippets.

@n-taku
Created April 16, 2020 14:23
Show Gist options
  • Save n-taku/bfe31282bf385c913f599b2ab03adb0a to your computer and use it in GitHub Desktop.
Save n-taku/bfe31282bf385c913f599b2ab03adb0a to your computer and use it in GitHub Desktop.
DropoutModel sample
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "DropoutModelSample.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "VPcDpmvhWWQp",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
},
"outputId": "5c0df660-64ae-42b8-a418-f3d6f2332510"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
"    \"\"\"Small CNN for 3x32x32 inputs with dropout after every stage.\n",
"\n",
"    forward() returns 10 scores per sample (no softmax is applied).\n",
"    fc1's input size (32 * 6 * 6) only matches 32x32 inputs:\n",
"    32 -conv5-> 28 -pool-> 14 -conv5-> 10 -conv5-> 6.\n",
"    Dropout layers are active only in train() mode.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv1 = nn.Conv2d(3, 16, 5)\n",
"        self.conv2 = nn.Conv2d(16, 32, 5)\n",
"        self.conv3 = nn.Conv2d(32, 32, 5)\n",
"        self.fc1 = nn.Linear(32 * 6 * 6, 256)\n",
"        self.fc2 = nn.Linear(256, 10)\n",
"        # Dropout2d zeroes whole channels of a feature map; Dropout zeroes\n",
"        # individual elements. Neither holds parameters, so dropout2 can be\n",
"        # shared by all three conv stages.\n",
"        self.dropout1 = nn.Dropout2d(p=0.2)\n",
"        self.dropout2 = nn.Dropout2d(p=0.3)\n",
"        self.dropout3 = nn.Dropout(p=0.3)\n",
"\n",
"    def forward(self, x):\n",
"        # Assumes x is (N, 3, 32, 32); returns (N, 10).\n",
"        # NOTE(review): dropout1 drops entire *input* channels (presumably\n",
"        # RGB) before any convolution -- confirm this is intended.\n",
"        x = self.dropout1(x)\n",
"        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 16, 14, 14)\n",
"        x = self.dropout2(x)\n",
"        x = F.relu(self.conv2(x))  # -> (N, 32, 10, 10)\n",
"        x = self.dropout2(x)\n",
"        x = F.relu(self.conv3(x))  # -> (N, 32, 6, 6)\n",
"        x = self.dropout2(x)\n",
"        x = torch.flatten(x, 1)  # keep batch dim -> (N, 1152)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.dropout3(x)\n",
"        x = self.fc2(x)\n",
"        return x\n",
"\n",
"# torchsummary's summary() defaults to device='cuda'; on a GPU runtime it\n",
"# would feed CUDA tensors to this CPU model and fail, so pin the device.\n",
"summary(Net(), (3, 32, 32), device='cpu')"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
"        Layer (type)               Output Shape         Param #\n",
"================================================================\n",
"         Dropout2d-1            [-1, 3, 32, 32]               0\n",
"            Conv2d-2           [-1, 16, 28, 28]           1,216\n",
"         MaxPool2d-3           [-1, 16, 14, 14]               0\n",
"         Dropout2d-4           [-1, 16, 14, 14]               0\n",
"            Conv2d-5           [-1, 32, 10, 10]          12,832\n",
"         Dropout2d-6           [-1, 32, 10, 10]               0\n",
"            Conv2d-7             [-1, 32, 6, 6]          25,632\n",
"         Dropout2d-8             [-1, 32, 6, 6]               0\n",
"            Linear-9                  [-1, 256]         295,168\n",
"          Dropout-10                  [-1, 256]               0\n",
"           Linear-11                   [-1, 10]           2,570\n",
"================================================================\n",
"Total params: 337,418\n",
"Trainable params: 337,418\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.01\n",
"Forward/backward pass size (MB): 0.24\n",
"Params size (MB): 1.29\n",
"Estimated Total Size (MB): 1.54\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment