Gist ancri/953aea7c6ea8f59723e9d7e9d745413a, last active May 19, 2022 09:57
Quickly calculate required Linear layer input size in a Conv --> Linear architecture
# Below is a cheeky little utility that helps you
# avoid doing all the tedious math by piggy-backing
# on torch's error messages. The assumption is that
# net is your Module, composed of a sequence of
# input-shape-agnostic layers (such as Conv, ReLU,
# BatchNorm and MaxPool layers), followed at some
# point by a flatten step and a Linear layer whose
# required input size you're trying to figure out.
# height and width refer to your desired input image
# shape; net, height and width must already be defined.
import torch

try:
    # One random 3-channel image (batch size 1) is enough to
    # trigger any shape mismatch; no gradients are needed.
    with torch.no_grad():
        net(torch.rand((1, 3, height, width)))
    print("Image size is compatible with layer sizes.")
except RuntimeError as e:
    msg = str(e)
    if msg.endswith("Output size is too small"):
        # A pooling layer shrank the feature map below 1x1.
        print("Image size is too small.")
    elif "shapes cannot be multiplied" in msg:
        # In recent PyTorch versions the message looks like
        # "mat1 and mat2 shapes cannot be multiplied (1x400 and 256x10)";
        # the number after the first "x" is the flattened feature size
        # that the Linear layer must accept as in_features.
        required_shape = msg[msg.index("x") + 1:].split(" ")[0]
        print(f"Linear layer needs to have size: {required_shape}")
    else:
        print(f"Error not understood: {msg}")
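For comparison, the "tedious math" the utility avoids is PyTorch's output-size formula for `Conv2d` and `MaxPool2d` (with `ceil_mode=False`), out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1, applied once per spatial dimension. The sketch below is a minimal illustration, assuming a hypothetical LeNet-style feature extractor named `features`, a 32x32 RGB input and 10 output classes (none of these come from the original snippet; `conv_out` and `model` are likewise illustrative names). It computes the flattened size by hand and then checks it with an error-free alternative to the try/except trick: running a dummy tensor through the feature layers only and reading the size of the flattened output.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of Conv2d / MaxPool2d
    (PyTorch's formula with ceil_mode=False, i.e. floor division)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical LeNet-style feature extractor: everything before flatten + Linear.
features = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5),   # 32x32 -> 28x28
    nn.ReLU(),
    nn.MaxPool2d(2),                  # 28x28 -> 14x14 (stride defaults to kernel size)
    nn.Conv2d(6, 16, kernel_size=5),  # 14x14 -> 10x10
    nn.ReLU(),
    nn.MaxPool2d(2),                  # 10x10 -> 5x5
)
height, width = 32, 32

# 1) The "tedious math": apply the formula layer by layer.
#    ReLU does not change the shape, so only (kernel, stride) pairs matter.
h, w = height, width
for kernel, stride in [(5, 1), (2, 2), (5, 1), (2, 2)]:
    h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
by_hand = 16 * h * w  # 16 = out_channels of the last Conv2d

# 2) No math and no error parsing: run a dummy batch through the features only.
with torch.no_grad():
    by_probe = features(torch.zeros(1, 3, height, width)).flatten(start_dim=1).shape[1]

print(by_hand, by_probe)  # 400 400

# Either value can now be used as the Linear layer's in_features.
model = nn.Sequential(features, nn.Flatten(), nn.Linear(by_probe, 10))
with torch.no_grad():
    print(model(torch.zeros(1, 3, height, width)).shape)  # torch.Size([1, 10])
```

The probe approach keeps working if you later change kernel sizes or add layers, because it never depends on the wording of an error message. If you would rather not compute the size at all, recent PyTorch releases also provide `nn.LazyLinear`, which infers `in_features` on the first forward pass.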