-
-
Save adarsh-kr/b53ff2a7e2163ea4205ce7ac444277ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
class EmptyShortcutLayer(nn.Module):
    """Identity shortcut module: forwards its input unchanged.

    Used as the residual connection in blocks whose input and output
    shapes already match, so no projection is required.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return *x* untouched."""
        return x
class ResNet18(nn.Module):
    """ResNet-18 built layer-by-layer (no nn.Sequential block factories).

    Layout (CIFAR-style): a 3x3 stem conv, then four segments of two
    basic residual blocks each at 64/128/256/512 channels.  The first
    block of segments 2-4 downsamples with stride 2 and uses a
    projection shortcut; every other block uses an identity shortcut.
    Average pooling over the remaining spatial map feeds one linear
    classifier.

    Fixes vs. the original draft:
      * ReLU non-linearities added after the stem, inside each block,
        and after each residual addition — without them the network was
        a purely affine stack of conv+BN layers.
      * integer division (``//``) when tracking the spatial size, so
        ``nn.AvgPool2d`` receives an int kernel size instead of a float
        (a float kernel size fails at forward time).
      * hand-rolled EmptyShortcutLayer replaced by ``nn.Identity``
        (identical behavior, no parameters, same state_dict keys).
      * debug ``print`` removed; post-pool ``output_size`` now correctly
        set to 1 (pooling over the whole map yields a 1x1 output).

    Args:
        input_size: side length of the square input image; must be
            divisible by 8 (three stride-2 stages), e.g. 32 for CIFAR.
        block: unused placeholder, kept for interface compatibility.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, block="Basic", num_classes=10):
        super(ResNet18, self).__init__()
        # segments hold 2,2,2,2 basic blocks
        self.in_planes = 64
        self.input_size = input_size
        self.output_size = input_size  # spatial size tracked through the net
        self.num_classes = num_classes

        # Stem (CIFAR-style: 3x3, stride 1, no max-pool).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # BatchNorm2d is parameterized by the channel count C.
        self.bn1 = nn.BatchNorm2d(64)

        # segment 1 (64 channels, stride 1)
        # block 1
        self.block_1_conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_1_bn1 = nn.BatchNorm2d(64)
        self.block_1_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_1_bn2 = nn.BatchNorm2d(64)
        self.block_1_shortcut = nn.Identity()
        # block 2
        self.block_2_conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_2_bn1 = nn.BatchNorm2d(64)
        self.block_2_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_2_bn2 = nn.BatchNorm2d(64)
        self.block_2_shortcut = nn.Identity()

        # segment 2 (128 channels, first block downsamples)
        # block 3
        self.block_3_conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.block_3_bn1 = nn.BatchNorm2d(128)
        self.output_size //= 2  # stride 2 halves the spatial size
        self.block_3_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_3_bn2 = nn.BatchNorm2d(128)
        # NOTE(review): canonical ResNet uses a 1x1 projection shortcut;
        # the 3x3 kernel is kept to preserve this model's parameter shapes.
        self.block_3_shortcut = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        # block 4
        self.block_4_conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_4_bn1 = nn.BatchNorm2d(128)
        self.block_4_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_4_bn2 = nn.BatchNorm2d(128)
        self.block_4_shortcut = nn.Identity()

        # segment 3 (256 channels, first block downsamples)
        # block 5
        self.block_5_conv1 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.block_5_bn1 = nn.BatchNorm2d(256)
        self.output_size //= 2  # stride 2 halves the spatial size
        self.block_5_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_5_bn2 = nn.BatchNorm2d(256)
        self.block_5_shortcut = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
        )
        # block 6
        self.block_6_conv1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_6_bn1 = nn.BatchNorm2d(256)
        self.block_6_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_6_bn2 = nn.BatchNorm2d(256)
        self.block_6_shortcut = nn.Identity()

        # segment 4 (512 channels, first block downsamples)
        # block 7
        self.block_7_conv1 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.block_7_bn1 = nn.BatchNorm2d(512)
        self.output_size //= 2  # stride 2 halves the spatial size
        self.block_7_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_7_bn2 = nn.BatchNorm2d(512)
        self.block_7_shortcut = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
        )
        # block 8
        self.block_8_conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_8_bn1 = nn.BatchNorm2d(512)
        self.block_8_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.block_8_bn2 = nn.BatchNorm2d(512)
        self.block_8_shortcut = nn.Identity()

        # Average-pool the whole remaining map down to 1x1.
        self.avg_pool = nn.AvgPool2d(self.output_size)
        self.output_size = 1  # pooling kernel == map size -> 1x1 output

        # Final classifier over the 512 pooled features.
        self.linear = nn.Linear(512, self.num_classes)

    def _residual(self, x, conv1, bn1, conv2, bn2, shortcut):
        """One basic block: conv-bn-relu-conv-bn, add shortcut, relu."""
        out = torch.relu(bn1(conv1(x)))
        out = bn2(conv2(out))
        return torch.relu(shortcut(x) + out)

    def forward(self, x):
        """Compute class logits for a batch of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, num_classes).
        """
        y = torch.relu(self.bn1(self.conv1(x)))
        y = self._residual(y, self.block_1_conv1, self.block_1_bn1,
                           self.block_1_conv2, self.block_1_bn2, self.block_1_shortcut)
        y = self._residual(y, self.block_2_conv1, self.block_2_bn1,
                           self.block_2_conv2, self.block_2_bn2, self.block_2_shortcut)
        y = self._residual(y, self.block_3_conv1, self.block_3_bn1,
                           self.block_3_conv2, self.block_3_bn2, self.block_3_shortcut)
        y = self._residual(y, self.block_4_conv1, self.block_4_bn1,
                           self.block_4_conv2, self.block_4_bn2, self.block_4_shortcut)
        y = self._residual(y, self.block_5_conv1, self.block_5_bn1,
                           self.block_5_conv2, self.block_5_bn2, self.block_5_shortcut)
        y = self._residual(y, self.block_6_conv1, self.block_6_bn1,
                           self.block_6_conv2, self.block_6_bn2, self.block_6_shortcut)
        y = self._residual(y, self.block_7_conv1, self.block_7_bn1,
                           self.block_7_conv2, self.block_7_bn2, self.block_7_shortcut)
        y = self._residual(y, self.block_8_conv1, self.block_8_bn1,
                           self.block_8_conv2, self.block_8_bn2, self.block_8_shortcut)
        y = self.avg_pool(y)
        y = y.view(y.size(0), -1)  # flatten to (N, 512)
        return self.linear(y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment