Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Last active November 10, 2020 11:03
Show Gist options
  • Save seanbenhur/ffc3b49f38c489f7308050741d3fadb9 to your computer and use it in GitHub Desktop.
class ResNet(nn.Module):
    """ResNet classifier: conv stem, four residual stages, global pooling, linear head.

    Args:
        block: residual block class passed to ``_make_layer`` for every stage.
        layers: per-stage block counts; only the first four entries are used.
        image_channels: number of channels in the input images.
        num_classes: width of the output logits.

    NOTE(review): ``_make_layer`` is called below but is not defined in this
    class body as shown. It must be supplied elsewhere (subclass / mixin /
    another file) before this class can be instantiated. TODO confirm.
    """

    def __init__(self, block, layers, image_channels, num_classes):
        super().__init__()
        # Running channel count. Presumably read and advanced by _make_layer
        # as stages are stacked. TODO confirm against its implementation.
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv -> batch norm -> ReLU -> 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (intermediate width, stride) for each residual stage. Stages are built
        # in order so that any in_channels bookkeeping inside _make_layer runs
        # in sequence, then registered as layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        stages = [
            self._make_layer(block, layers[idx], int_channels=width, stride=step)
            for idx, (width, step) in enumerate(stage_specs)
        ]
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 512 * 4 input features: assumes the last stage emits 4x its intermediate
        # width (bottleneck-style expansion). TODO confirm against `block`.
        self.fc1 = nn.Linear(512 * 4, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the image batch ``x``.

        ``x`` is presumably (N, image_channels, H, W). TODO confirm with callers.
        """
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        # Flatten everything except the batch dimension before the linear head.
        return self.fc1(pooled.reshape(pooled.shape[0], -1))
def ResNet18(img_channel=3, num_classes=1000):
    """Build a ResNet with two ``ResBlock`` units in each of its four stages.

    Args:
        img_channel: number of input image channels.
        num_classes: number of output logits.
    """
    # NOTE(review): the reference ResNet-18 uses two-conv "basic" blocks, while
    # this uses ResBlock (not visible here). Confirm the block type is intended.
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(
        block=ResBlock,
        layers=blocks_per_stage,
        image_channels=img_channel,
        num_classes=num_classes,
    )
def ResNet34(img_channel=3, num_classes=1000):
    """Build a ResNet with a [3, 4, 6, 3] stage layout of ``ResBlock`` units.

    Args:
        img_channel: number of input image channels.
        num_classes: number of output logits.
    """
    # NOTE(review): this passes exactly the same block type and layout as
    # ResNet50, so the two factories build identical networks. The reference
    # ResNet-34 uses basic blocks. Confirm intent.
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(
        block=ResBlock,
        layers=blocks_per_stage,
        image_channels=img_channel,
        num_classes=num_classes,
    )
def ResNet50(img_channel=3, num_classes=1000):
    """Build a ResNet with a [3, 4, 6, 3] stage layout of ``ResBlock`` units.

    Args:
        img_channel: number of input image channels.
        num_classes: number of output logits.
    """
    # NOTE(review): identical arguments to ResNet34, so both factories return
    # the same architecture. Confirm that is intended.
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(
        block=ResBlock,
        layers=blocks_per_stage,
        image_channels=img_channel,
        num_classes=num_classes,
    )
def ResNet101(img_channel=3, num_classes=1000):
    """Build a ResNet with a [3, 4, 23, 3] stage layout of ``ResBlock`` units.

    Args:
        img_channel: number of input image channels.
        num_classes: number of output logits.
    """
    blocks_per_stage = [3, 4, 23, 3]
    return ResNet(
        block=ResBlock,
        layers=blocks_per_stage,
        image_channels=img_channel,
        num_classes=num_classes,
    )
def ResNet152(img_channel=3, num_classes=1000):
    """Build a ResNet with a [3, 8, 36, 3] stage layout of ``ResBlock`` units.

    Args:
        img_channel: number of input image channels.
        num_classes: number of output logits. Defaults to 1000, matching the
            other ResNet factory functions in this module.
    """
    # The default used to be 10000, an apparent typo: every sibling factory
    # defaults to 1000, and a 10000-wide head silently mismatches 1000-class labels.
    return ResNet(ResBlock, [3, 8, 36, 3], img_channel, num_classes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment