Skip to content

Instantly share code, notes, and snippets.

@Noob-can-Compile
Created January 20, 2020 11:33
Show Gist options
  • Save Noob-can-Compile/340cbf926a666b577bb66c5b1037e87b to your computer and use it in GitHub Desktop.
Save Noob-can-Compile/340cbf926a666b577bb66c5b1037e87b to your computer and use it in GitHub Desktop.
class EncoderCNN(nn.Module):
    """Encode images into fixed-size embedding vectors using a frozen ResNet-50.

    A pretrained ResNet-50 with its final child module removed is used as a
    frozen feature extractor. Only the linear projection ``self.embed`` has
    trainable parameters.

    Args:
        embed_size: Number of output features produced for each image.
    """

    def __init__(self, embed_size):
        super().__init__()
        # torchvision >= 0.13 exposes per-model weight enums and deprecates
        # ``pretrained=True``. IMAGENET1K_V1 is the checkpoint that
        # ``pretrained=True`` maps to, so both branches load the same weights.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Freeze every backbone parameter. self.embed is created after this
        # loop, so it stays trainable.
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final child module (the ``fc`` classifier in torchvision's
        # ResNet) and keep the rest as a sequential feature extractor.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # Project from the discarded classifier's input width to embed_size.
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        # NOTE(review): requires_grad_(False) does not stop BatchNorm layers
        # inside the backbone from updating their running statistics while the
        # module is in train mode. Confirm whether self.resnet should be kept
        # in eval() during training.

    def forward(self, images):
        """Return one embedding vector per image.

        Args:
            images: Batch of images. Presumably a float tensor of shape
                (batch, 3, H, W) normalized the way the pretrained weights
                expect — TODO confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        features = self.resnet(images)
        # Collapse all non-batch dimensions so the Linear layer receives a
        # (batch, features) matrix. Unlike view(), flatten() also copes with
        # non-contiguous tensors and an empty batch.
        features = features.flatten(1)
        return self.embed(features)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment