Last active
January 30, 2020 15:30
Output shape of a PyTorch layer
import torch  # 1.3.1
import torch.nn as nn
import torchvision  # 0.4.2

# For example, take a ResNet-50
resnet = torchvision.models.resnet50(pretrained=False)

# Strip away its FC layer
resnet_noFC = nn.Sequential(*list(resnet.children())[:-1])

# Suppose you want to use it for binary classification.
# How would you know the shape of the tensor to feed into your new FC layer?
# Feed it a dummy tensor and see what comes out.
x = torch.zeros(1, 3, 33, 33)  # batch size, channels, height, width
y = resnet_noFC(x)
print(y.shape)  # torch.Size([1, 2048, 1, 1])
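The dummy-tensor result can also be sanity-checked by hand. The sketch below is not part of the original snippet. It applies the standard convolution/pooling output-size formula, `floor((n + 2p - k) / s) + 1`, to the layers that shrink the feature map. The kernel, stride, and padding values are assumptions taken from torchvision's ResNet-50 definition, and the helper names `conv_out` and `trace` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one spatial axis of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for each stage that changes the spatial size.
# Assumed values, based on torchvision's ResNet-50; layer1 has stride 1 and
# keeps the size unchanged, so it is omitted.
DOWNSAMPLING = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]


def trace(size):
    """Return the spatial size at the input and after each downsampling stage."""
    sizes = [size]
    for _name, kernel, stride, padding in DOWNSAMPLING:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


sizes = trace(33)
print(sizes)  # [33, 17, 9, 5, 3, 2]
```

Under these assumptions a 33x33 input reaches the final pooling layer as a 2x2 map. The adaptive average pool then reduces it to 1x1, which matches the trailing `1, 1` in the printed shape. After flattening, the new FC layer would therefore take 2048 input features.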