@alkalait
Last active January 30, 2020 15:30
Output shape of a PyTorch layer
import torch # 1.3.1
import torch.nn as nn
import torchvision # 0.4.2
# For example, take a ResNet-50 with randomly initialised weights
resnet = torchvision.models.resnet50(pretrained=False)
# strip away its FC layer
resnet_noFC = nn.Sequential(*list(resnet.children())[:-1])
# Suppose you want to use it for binary classification.
# How would you know the shape of the tensor to feed into your new FC layer?
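# Feed it a dummy tensor and see what comes out.
resnet_noFC.eval()  # inference mode, so BatchNorm uses its running statistics
x = torch.zeros(1, 3, 33, 33)  # batch size, channels, height, width
with torch.no_grad():  # no gradients are needed for a shape check
    y = resnet_noFC(x)
print(y.shape)  # torch.Size([1, 2048, 1, 1])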
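The same dummy-tensor trick can be wrapped in a small helper that sizes the new FC layer automatically. The following is only a sketch: `flat_features` is a hypothetical helper name, the tiny convolutional backbone is a stand-in so the example runs without torchvision, and the single output logit assumes a binary classifier trained with something like `nn.BCEWithLogitsLoss`. A truncated network such as `resnet_noFC` could be passed in the same way.

```python
import torch
import torch.nn as nn


def flat_features(backbone, input_shape):
    """Return how many features per sample `backbone` emits for `input_shape`.

    `input_shape` excludes the batch dimension, e.g. (3, 33, 33).
    (Hypothetical helper, not part of PyTorch.)
    """
    was_training = backbone.training
    backbone.eval()  # avoid updating BatchNorm statistics with the dummy input
    with torch.no_grad():
        y = backbone(torch.zeros(1, *input_shape))
    backbone.train(was_training)  # restore the original mode
    return y.flatten(1).shape[1]  # flatten everything except the batch dim


# Tiny stand-in backbone, assumed purely for illustration.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
)

n_features = flat_features(backbone, (3, 33, 33))

# New head: flatten the backbone output, then map it to one logit.
model = nn.Sequential(backbone, nn.Flatten(), nn.Linear(n_features, 1))

logits = model(torch.zeros(4, 3, 33, 33))
print(n_features, logits.shape)
```

Measuring the shape instead of hard-coding it keeps the head correct if the backbone or the input resolution changes later.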