Skip to content

Instantly share code, notes, and snippets.

@n-taku
Created April 12, 2020 17:35
Show Gist options
Save n-taku/30fdb3396d3d408864992b67fee0cd17 to your computer and use it in GitHub Desktop.
GAP（Global Average Pooling）のサンプル
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "GAPModelSample.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "lnieGyUccbs-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"outputId": "2c3f054b-3ab6-433a-c3e6-25e011f77854"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchsummary import summary\n",
"\n",
"\n",
"class Net(nn.Module):\n",
"    \"\"\"All-convolutional classifier that ends in global average pooling (GAP).\n",
"\n",
"    Instead of flattening into fully connected layers, the last conv layer\n",
"    emits one feature map per class and GAP averages each map to one score.\n",
"\n",
"    Args:\n",
"        in_channels: channels of the input image (3 for RGB).\n",
"        num_classes: number of output scores (channels of the last conv).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=3, num_classes=10):\n",
"        super().__init__()\n",
"        # Four 5x5 convs without padding each shrink H and W by 4, so the\n",
"        # input needs H, W >= 17; a 32x32 input gives 28 -> 24 -> 20 -> 16.\n",
"        self.conv1 = nn.Conv2d(in_channels, 16, 5)\n",
"        self.conv2 = nn.Conv2d(16, 32, 5)\n",
"        self.conv3 = nn.Conv2d(32, 64, 5)\n",
"        self.conv4 = nn.Conv2d(64, num_classes, 5)\n",
"        # Output size (1, 1) averages each whole feature map, whatever its size.\n",
"        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map an (N, in_channels, H, W) batch to (N, num_classes) scores.\"\"\"\n",
"        x = F.relu(self.conv1(x))\n",
"        x = F.relu(self.conv2(x))\n",
"        x = F.relu(self.conv3(x))\n",
"        # NOTE(review): ReLU on the class maps makes every score >= 0;\n",
"        # confirm this is intended (drop it if raw logits are wanted).\n",
"        x = F.relu(self.conv4(x))\n",
"        # GAP: (N, num_classes, h, w) -> (N, num_classes, 1, 1) -> (N, num_classes)\n",
"        x = self.avgpool(x)\n",
"        x = torch.flatten(x, 1)\n",
"        return x\n",
"\n",
"\n",
"summary(Net(), (3, 32, 32))"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 16, 28, 28] 1,216\n",
" Conv2d-2 [-1, 32, 24, 24] 12,832\n",
" Conv2d-3 [-1, 64, 20, 20] 51,264\n",
" Conv2d-4 [-1, 10, 16, 16] 16,010\n",
" AdaptiveAvgPool2d-5 [-1, 10, 1, 1] 0\n",
"================================================================\n",
"Total params: 81,322\n",
"Trainable params: 81,322\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.01\n",
"Forward/backward pass size (MB): 0.45\n",
"Params size (MB): 0.31\n",
"Estimated Total Size (MB): 0.77\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment