Skip to content

Instantly share code, notes, and snippets.

@KushajveerSingh
Last active August 13, 2023 16:30
Show Gist options
  • Save KushajveerSingh/7773052dfb6d8adedc53d0544dedaf60 to your computer and use it in GitHub Desktop.
Save KushajveerSingh/7773052dfb6d8adedc53d0544dedaf60 to your computer and use it in GitHub Desktop.
VGG16 PyTorch implementation
class VGG(nn.Module):
    """VGG16 network whose layers mirror torchvision's ``vgg16`` state_dict.

    Layer order and indices inside ``features`` and ``classifier`` match
    ``torchvision.models.vgg16``, so its ImageNet weights can be copied in
    with ``load_state_dict``. This class does not load any weights itself.

    Differences from the torchvision model, useful for deconvnet-style
    visualization:

    * every ``nn.MaxPool2d`` is built with ``return_indices=True``, and the
      returned switch locations are kept in ``self.pool_locs`` so a later
      ``nn.MaxUnpool2d`` can invert the pooling;
    * ``nn.ReLU`` is not in-place, so a conv output captured elsewhere is
      not overwritten by the following activation.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # conv1
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv3
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv4
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv5
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, return_indices=True),
        )
        # Parameter-free, so it adds no state_dict keys and pretrained weights
        # still load. It is the identity for 224x224 inputs (7x7 maps after
        # five 2x2 pools) and lets any input of at least 32x32 reach the
        # 512 * 7 * 7 classifier input.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 1000),
        )
        # Indices of the nn.Conv2d layers inside self.features.
        # We need these for the MaxUnpool / deconvolution pass.
        self.conv_layer_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
        # Keyed by layer index in self.features. Not written by this class;
        # presumably filled by external code (e.g. forward hooks) — TODO confirm.
        self.feature_maps = OrderedDict()
        # pool_locs[i]: indices returned by the MaxPool2d at self.features[i]
        # during the most recent forward pass (overwritten on every call).
        self.pool_locs = OrderedDict()

    def forward(self, x):
        """Run the network and return the classifier logits.

        Args:
            x: input batch for the first ``nn.Conv2d(3, ...)`` layer
               (assumed ``(N, 3, H, W)`` with ``H, W >= 32``).

        Returns:
            Output of ``self.classifier``, shape ``(N, 1000)``.

        Side effect: stores each max-pool's switch indices in
        ``self.pool_locs`` under that pool's index in ``self.features``.
        """
        for idx, layer in enumerate(self.features):
            if isinstance(layer, nn.MaxPool2d):
                # return_indices=True makes the pool return (output, indices).
                x, location = layer(x)
                self.pool_locs[idx] = location
            else:
                x = layer(x)
        x = self.avgpool(x)
        # flatten(1) keeps the batch dim and, unlike view, also works on
        # non-contiguous tensors.
        x = x.flatten(1)
        return self.classifier(x)
def get_vgg():
    """Return a ``VGG`` populated with torchvision's ImageNet VGG16 weights.

    Builds torchvision's pretrained ``vgg16`` and copies its ``state_dict``
    into a fresh ``VGG``. This relies on ``VGG`` using the same parameter
    names as torchvision's model.

    The module is returned without calling ``.eval()``, so Dropout is active
    until the caller switches modes.
    """
    vgg = VGG()
    try:
        # torchvision >= 0.13: explicit weights enum. ``pretrained=True`` is
        # deprecated there and maps to this same IMAGENET1K_V1 checkpoint.
        weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1
    except AttributeError:
        # Older torchvision only understands the boolean flag.
        reference = torchvision.models.vgg16(pretrained=True)
    else:
        reference = torchvision.models.vgg16(weights=weights)
    vgg.load_state_dict(reference.state_dict())
    return vgg
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment