@nikogamulin
Created November 1, 2020 12:50
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 1000])\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class Block(nn.Module):\n",
" def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):\n",
" assert num_layers in [18, 34, 50, 101, 152], \"should be a a valid architecture\"\n",
" super(Block, self).__init__()\n",
" self.num_layers = num_layers\n",
" if self.num_layers > 34:\n",
" self.expansion = 4\n",
" else:\n",
" self.expansion = 1\n",
" # ResNet50, 101, and 152 include additional layer of 1x1 kernels\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" if self.num_layers > 34:\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)\n",
" else:\n",
" # for ResNet18 and 34, connect input directly to (3x3) kernel (skip first (1x1))\n",
" self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
" self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)\n",
" self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)\n",
" self.relu = nn.ReLU()\n",
" self.identity_downsample = identity_downsample\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.num_layers > 34:\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu(x)\n",
" x = self.conv3(x)\n",
" x = self.bn3(x)\n",
"\n",
" if self.identity_downsample is not None:\n",
" identity = self.identity_downsample(identity)\n",
"\n",
" x += identity\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"\n",
"class ResNet(nn.Module):\n",
" def __init__(self, num_layers, block, image_channels, num_classes):\n",
" assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has ' \\\n",
" f'to be 18, 34, 50, 101, or 152 '\n",
" super(ResNet, self).__init__()\n",
" if num_layers < 50:\n",
" self.expansion = 1\n",
" else:\n",
" self.expansion = 4\n",
" if num_layers == 18:\n",
" layers = [2, 2, 2, 2]\n",
" elif num_layers == 34 or num_layers == 50:\n",
" layers = [3, 4, 6, 3]\n",
" elif num_layers == 101:\n",
" layers = [3, 4, 23, 3]\n",
" else:\n",
" layers = [3, 8, 36, 3]\n",
" self.in_channels = 64\n",
" self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)\n",
" self.bn1 = nn.BatchNorm2d(64)\n",
" self.relu = nn.ReLU()\n",
" self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
"\n",
" # ResNetLayers\n",
" self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)\n",
" self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)\n",
" self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)\n",
" self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)\n",
"\n",
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
" self.fc = nn.Linear(512 * self.expansion, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu(x)\n",
" x = self.maxpool(x)\n",
"\n",
" x = self.layer1(x)\n",
" x = self.layer2(x)\n",
" x = self.layer3(x)\n",
" x = self.layer4(x)\n",
"\n",
" x = self.avgpool(x)\n",
" x = x.reshape(x.shape[0], -1)\n",
" x = self.fc(x)\n",
" return x\n",
"\n",
" def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):\n",
" layers = []\n",
"\n",
" identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, intermediate_channels*self.expansion, kernel_size=1, stride=stride),\n",
" nn.BatchNorm2d(intermediate_channels*self.expansion))\n",
" layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))\n",
" self.in_channels = intermediate_channels * self.expansion # 256\n",
" for i in range(num_residual_blocks - 1):\n",
" layers.append(block(num_layers, self.in_channels, intermediate_channels)) # 256 -> 64, 64*4 (256) again\n",
" return nn.Sequential(*layers)\n",
"\n",
"\n",
"def ResNet18(img_channels=3, num_classes=1000):\n",
" return ResNet(18, Block, img_channels, num_classes)\n",
"\n",
"\n",
"def ResNet34(img_channels=3, num_classes=1000):\n",
" return ResNet(34, Block, img_channels, num_classes)\n",
"\n",
"\n",
"def ResNet50(img_channels=3, num_classes=1000):\n",
" return ResNet(50, Block, img_channels, num_classes)\n",
"\n",
"\n",
"def ResNet101(img_channels=3, num_classes=1000):\n",
" return ResNet(101, Block, img_channels, num_classes)\n",
"\n",
"\n",
"def ResNet152(img_channels=3, num_classes=1000):\n",
" return ResNet(152, Block, img_channels, num_classes)\n",
"\n",
"\n",
"def test():\n",
" net = ResNet18(img_channels=3, num_classes=1000)\n",
" y = net(torch.randn(4, 3, 224, 224)).to(\"cuda\")\n",
" print(y.size())\n",
"\n",
"\n",
"test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"gist": {
"data": {
"description": "ResNet.ipynb",
"public": true
},
"id": ""
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
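
The notebook's `test()` only exercises ResNet18. As a quick sanity check, a minimal sketch like the one below (not part of the original gist; it assumes the cell above has been run in the same session, so `Block`, `ResNet`, and the factory functions are in scope) verifies that every variant maps a `(batch, 3, 224, 224)` input to a `(batch, 1000)` output on CPU:

```python
import torch

# Sanity-check sketch: assumes ResNet18/34/50/101/152 from the cell above
# are defined in the current session. Not part of the original gist.
def check_output_shapes():
    x = torch.randn(2, 3, 224, 224)  # small batch keeps the CPU check fast
    for factory in (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152):
        model = factory(img_channels=3, num_classes=1000)
        model.eval()
        with torch.no_grad():  # no gradients needed for a shape check
            y = model(x)
        print(factory.__name__, tuple(y.shape))  # expected: (2, 1000)

check_output_shapes()
```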
@mukul54

mukul54 commented Mar 25, 2021

Hi, thanks for the code. In the Block class, for conv3, won't the filter size be 3 for ResNet18 and ResNet34?
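
For comparison (not part of the gist): the two-layer residual block used for ResNet-18/34 in the original paper and in torchvision stacks two 3x3 convolutions, whereas the Block above keeps a 1x1 conv3 for every depth. A minimal, illustrative sketch of that two-conv block, with hypothetical names, looks like this:

```python
import torch
import torch.nn as nn


class BasicBlockSketch(nn.Module):
    """Illustrative two-conv residual block in the ResNet-18/34 style."""

    def __init__(self, in_channels, out_channels, downsample=None, stride=1):
        super().__init__()
        # Both convolutions are 3x3; only the first one may stride.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        identity = x
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        if self.downsample is not None:
            identity = self.downsample(identity)
        return self.relu(x + identity)
```

As written, the gist's Block still runs for ResNet18/34 (the notebook output confirms it); the 1x1 conv3 is simply a departure from the paper's two-3x3 layout that the comment above points at.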
