Skip to content

Instantly share code, notes, and snippets.

@dusanstanojeviccs
Last active December 29, 2019 08:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dusanstanojeviccs/2b3635d8cd4d4d5c52c1fc2043e5ba4a to your computer and use it in GitHub Desktop.
Save dusanstanojeviccs/2b3635d8cd4d4d5c52c1fc2043e5ba4a to your computer and use it in GitHub Desktop.
PyTorch - How to calculate the Conv2d output (stride, padding, dilation)
def _as_hw_pair(value):
    """Return ``value`` as a ``(height, width)`` pair, duplicating a scalar."""
    if isinstance(value, (tuple, list)):
        height, width = value
        return height, width
    return value, value


def _conv_axis_size(size, kernel, stride, padding, dilation):
    """Return the convolution output length along one spatial axis."""
    # Span covered by the kernel once the dilation gaps are included.
    effective_kernel = dilation * (kernel - 1) + 1
    padded = size + 2 * padding
    # Floor division (not truncation toward zero): when the kernel does not
    # fit into the padded input the result is <= 0 instead of a spurious 1.
    # int() keeps the result an int even if a float size was passed in.
    return int((padded - effective_kernel) // stride) + 1


def conv2d_output(conv_layer, input_height = 128, input_width = 128):
    """Calculates the output dimensions of a convolutional layer in PyTorch.

    Keyword arguments:
    conv_layer -- the nn.Conv2d object (any object exposing kernel_size,
                  stride, padding, dilation and out_channels works)
    input_height -- the height of the input image to be processed by Conv2d
    input_width -- the width of the input image to be processed by Conv2d

    Returns a tuple (out_channels, output_height, output_width).

    kernel_size, stride, padding and dilation may each be a single number or
    a (height, width) pair. padding may also be the string 'valid' (treated
    as no padding) or 'same' (output has the spatial size of the input).
    A computed size of zero or less means the dilated kernel does not fit
    into the padded input.

    Raises ValueError if padding is a string other than 'same' or 'valid'.
    """
    out_channels = conv_layer.out_channels

    padding = conv_layer.padding
    if isinstance(padding, str):
        if padding == "same":
            # 'same' pads so that the spatial size is preserved.
            return out_channels, input_height, input_width
        if padding != "valid":
            raise ValueError("Unsupported padding mode: %r" % (padding,))
        padding = 0

    kernel_h, kernel_w = _as_hw_pair(conv_layer.kernel_size)
    stride_h, stride_w = _as_hw_pair(conv_layer.stride)
    padding_h, padding_w = _as_hw_pair(padding)
    dilation_h, dilation_w = _as_hw_pair(conv_layer.dilation)

    output_h = _conv_axis_size(
        input_height, kernel_h, stride_h, padding_h, dilation_h
    )
    output_w = _conv_axis_size(
        input_width, kernel_w, stride_w, padding_w, dilation_w
    )
    return out_channels, output_h, output_w
def test_conv2d_output():
    """Compare conv2d_output against the shapes nn.Conv2d actually produces.

    For every case an nn.Conv2d layer is built, a random tensor of shape
    (1, in_channels, height, width) is passed through it, and the resulting
    (channels, height, width) is compared with the prediction made by
    conv2d_output. A mismatch is reported by printing a message; nothing is
    raised or returned.
    """
    # Imported locally so the function does not depend on module-level
    # imports being present.
    import torch
    from torch import nn

    test_data = [
        # in_channels, out_channels, kernel_size, stride, padding, dilation, h, w
        [3, 4, 3, 1, 0, 1, 128, 128],
        [3, 5, 5, 2, 1, 2, 64, 64],
        [3, 6, 10, 3, 2, 3, 512, 39],
        [3, 7, 14, 4, 3, 4, 245, 190],
        [3, 8, 17, 5, 4, 5, 143, 341],
        # Same idea with different values per axis: (height, width) pairs.
        [3, 4, (3, 3 + 2), (2 + 1, 2), (1 + 0, 0), (1, 1 + 1), 128, 128],
        [3, 5, (5, 5 + 2), (2 + 2, 2), (1 + 1, 1), (2, 1 + 2), 28, 124],
        [3, 6, (10, 10 + 2), (2 + 3, 2), (1 + 2, 2), (3, 1 + 3), 512, 123],
        [3, 7, (14, 14 + 2), (2 + 4, 2), (1 + 3, 3), (4, 1 + 4), 150, 190],
        [3, 8, (17, 17 + 2), (2 + 5, 2), (1 + 4, 4), (5, 1 + 5), 143, 341],
    ]
    for case in test_data:
        (in_channels, out_channels, kernel_size, stride,
         padding, dilation, height, width) = case
        conv_demo = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        # Tensor layout is (batch, channels, height, width).
        x = torch.rand((1, in_channels, height, width))
        actual = tuple(conv_demo(x).size()[1:4])
        predicted = tuple(conv2d_output(conv_demo, height, width))
        if actual != predicted:
            print("Testing has failed :(")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment