@jaircastruita
Created March 5, 2021 06:42
import torch
import torch.nn as nn
import torch.nn.functional as F


class MildNet(nn.Module):
    '''
    VGG16-style backbone with multi-scale skip connections: the output of every
    conv block is global-average-pooled and the pooled vectors are concatenated
    into a single embedding.

    Reference:
    https://github.com/gofynd/mildnet/blob/master/trainer/model.py
    '''
    def __init__(self):
        super(MildNet, self).__init__()
        # VGG16 part
        self.convblock1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.convblock2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.convblock3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.convblock4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.convblock5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Embedding head: 64 + 128 + 256 + 512 + 512 = 1472 pooled channels in
        self.fc1 = nn.Linear(1472, 2048)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(2048, 2048)

    def forward(self, X):
        out1 = self.convblock1(X)
        out2 = self.convblock2(out1)
        out3 = self.convblock3(out2)
        out4 = self.convblock4(out3)
        out5 = self.convblock5(out4)
        # Global average pooling over the spatial dims of each block's output
        agp1 = torch.mean(out1, dim=(2, 3))
        agp2 = torch.mean(out2, dim=(2, 3))
        agp3 = torch.mean(out3, dim=(2, 3))
        agp4 = torch.mean(out4, dim=(2, 3))
        agp5 = torch.mean(out5, dim=(2, 3))
        # Concatenate the multi-scale pooled features into one vector
        emb = torch.cat([agp1, agp2, agp3, agp4, agp5], dim=1)
        out = self.fc1(emb)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = F.relu(out)
        # L2-normalize so embeddings lie on the unit hypersphere
        out = F.normalize(out, dim=1, p=2)
        return out
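
A minimal usage sketch (not part of the original gist): it instantiates MildNet, runs a dummy batch through it, and checks that the result is a batch of L2-normalized 2048-d embeddings. The 224x224 input size is an assumption; any resolution that survives five 2x max-pools works, since each block's output is averaged over whatever spatial size remains.

if __name__ == "__main__":
    model = MildNet().eval()
    x = torch.randn(2, 3, 224, 224)   # dummy batch: 2 RGB images, 224x224 (assumed size)
    with torch.no_grad():
        emb = model(x)
    print(emb.shape)        # torch.Size([2, 2048])
    print(emb.norm(dim=1))  # ~1.0 per row, since the output is L2-normalized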