By @ancri, last active May 19, 2022 09:57
Quickly calculate required Linear layer input size in a Conv --> Linear architecture
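For comparison, this is the tedious math the utility below lets you skip. It is a minimal sketch that applies the output-size formula from the PyTorch `Conv2d` and `MaxPool2d` documentation layer by layer, assuming the default `ceil_mode=False`. The helper `out_size`, the four-layer network, and the 32x32 input are hypothetical and exist only for illustration.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Spatial output length of Conv2d or MaxPool2d along one dimension.

    Formula from the PyTorch docs (floor mode, i.e. ceil_mode=False):
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical network: Conv(3->16, k=3) -> MaxPool(2) -> Conv(16->32, k=3) -> MaxPool(2)
h = w = 32
for kernel, stride in [(3, 1), (2, 2), (3, 1), (2, 2)]:  # MaxPool's stride defaults to its kernel size
    h = out_size(h, kernel, stride)
    w = out_size(w, kernel, stride)
linear_in = 32 * h * w  # channels of the last conv * remaining height * width
print(linear_in)  # 1152
```

Every layer with a kernel, stride, padding, or dilation needs this bookkeeping, which is why letting PyTorch report the number is attractive.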
# Below is a cheeky little utility that helps you
# avoid doing all the tedious math by piggy-backing
# on torch's error messages. The assumption is that
# net is your Module, composed of a sequence of
# input-shape-agnostic layers (such as Conv, ReLU,
# BatchNorm and MaxPool layers) followed at some
# point by a flatten step and the Linear layer whose
# required input size you're trying to figure out.
# height and width refer to your desired input image
# shape.
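import re

import torch

net.eval()  # BatchNorm1d can't handle a batch of one in training mode; call net.train() later if needed
try:
    with torch.no_grad():  # no gradients needed for a shape check
        net(torch.rand((1, 3, height, width)))  # 1 image, 3 channels (RGB); adjust if needed
    print("Image size is compatible with layer sizes.")
except RuntimeError as err:
    msg = str(err)
    if msg.endswith("Output size is too small") or "Kernel size can't be greater" in msg:
        # A pooling or conv layer shrank the feature map below its kernel size.
        print("Image size is too small.")
    elif "shapes cannot be multiplied" in msg:
        # e.g. "mat1 and mat2 shapes cannot be multiplied (1x2704 and 1024x10)":
        # the second number of the first shape is the flattened feature count.
        match = re.search(r"\(\d+x(\d+) and", msg)
        required_shape = match.group(1) if match else msg
        print(f"Linear layer needs to have in_features: {required_shape}")
    else:
        print(f"Error not understood: {msg}")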
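If you can keep the convolutional part of the model as its own module, this sketch shows an alternative that needs no error-message parsing. It runs one dummy image through that part and counts the output features per sample. The names `flattened_size` and `features`, and the example architecture, are hypothetical.

```python
import torch
import torch.nn as nn


def flattened_size(features: nn.Module, channels: int, height: int, width: int) -> int:
    """Number of inputs the first Linear layer must accept after flattening.

    Runs a single dummy image through the convolutional part of the network
    and counts the elements per sample, so no error message is parsed.
    """
    was_training = features.training
    features.eval()  # avoid BatchNorm batch-statistic issues with a batch of one
    with torch.no_grad():
        out = features(torch.zeros(1, channels, height, width))
    features.train(was_training)  # restore the module's original mode
    return out.flatten(1).shape[1]


# Hypothetical example: the feature extractor is built first, measured,
# and only then is the Linear head created with the right in_features.
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),   # 32x32 -> 30x30
    nn.ReLU(),
    nn.MaxPool2d(2),                   # 30x30 -> 15x15
    nn.Conv2d(16, 32, kernel_size=3),  # 15x15 -> 13x13
    nn.ReLU(),
    nn.MaxPool2d(2),                   # 13x13 -> 6x6
)
n_in = flattened_size(features, 3, 32, 32)
net = nn.Sequential(features, nn.Flatten(), nn.Linear(n_in, 10))
print(n_in)  # 32 * 6 * 6 = 1152
```

Because it measures the real output tensor, this also covers padding, dilation, or adaptive pooling without extra math. It does assume the feature extractor can be built separately from the Linear head.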