Last active
December 29, 2019 08:42
-
-
Save dusanstanojeviccs/2b3635d8cd4d4d5c52c1fc2043e5ba4a to your computer and use it in GitHub Desktop.
PyTorch - How to calculate the Conv2d output (stride, padding, dilation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv2d_output(conv_layer, input_height=128, input_width=128):
    """Calculate the output dimensions of a PyTorch Conv2d layer.

    Keyword arguments:
    conv_layer -- the nn.Conv2d object; only its kernel_size, stride,
                  padding, dilation and out_channels attributes are read,
                  so any object exposing those works too
    input_height -- the height of the input image to be processed by Conv2d
    input_width -- the width of the input image to be processed by Conv2d

    Returns a (channels, height, width) tuple matching the layer's output
    for a single image (the batch dimension is not included).
    """
    def _pair(value):
        # Conv2d hyper-parameters may be a single int (applied to both
        # axes) or an (h, w) tuple; normalize so the math below is uniform.
        if isinstance(value, tuple):
            return value
        return (value, value)

    kernel_h, kernel_w = _pair(conv_layer.kernel_size)
    stride_h, stride_w = _pair(conv_layer.stride)
    padding_h, padding_w = _pair(conv_layer.padding)
    dilation_h, dilation_w = _pair(conv_layer.dilation)

    # Effective filter extent after dilation: dilation * (kernel - 1) + 1.
    f_h = dilation_h * (kernel_h - 1) + 1
    f_w = dilation_w * (kernel_w - 1) + 1

    # Spatial extent after zero-padding on both sides.
    h = input_height + 2 * padding_h
    w = input_width + 2 * padding_w

    # Standard Conv2d shape formula: floor((size - filter) / stride) + 1.
    # Floor division matches PyTorch's documented floor semantics.
    output_h = (h - f_h) // stride_h + 1
    output_w = (w - f_w) // stride_w + 1
    return conv_layer.out_channels, output_h, output_w
def test_conv2d_output():
    """Check conv2d_output against real nn.Conv2d forward passes.

    For each hyper-parameter combination below, builds a Conv2d layer,
    runs a random input through it, and compares the actual output shape
    (C, H, W) with the shape predicted by conv2d_output.  On a mismatch
    the failing configuration and both shapes are printed, so a failure
    is diagnosable; nothing is printed on success.
    """
    test_data = [
        # in_channels, out_channels, kernel_size, stride, padding, dilation,
        # then the two spatial sizes of the test input.
        [3, 4, 3, 1, 0, 1, 128, 128],
        [3, 5, 5, 2, 1, 2, 64, 64],
        [3, 6, 10, 3, 2, 3, 512, 39],
        [3, 7, 14, 4, 3, 4, 245, 190],
        [3, 8, 17, 5, 4, 5, 143, 341],
        # Tuple-valued hyper-parameters exercise the per-axis code path.
        [3, 4, (3, 5), (3, 2), (1, 0), (1, 2), 128, 128],
        [3, 5, (5, 7), (4, 2), (2, 1), (2, 3), 28, 124],
        [3, 6, (10, 12), (5, 2), (3, 2), (3, 4), 512, 123],
        [3, 7, (14, 16), (6, 2), (4, 3), (4, 5), 150, 190],
        [3, 8, (17, 19), (7, 2), (5, 4), (5, 6), 143, 341],
    ]
    for case in test_data:
        in_c, out_c, kernel, stride, padding, dilation, height, width = case
        conv_demo = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        # NCHW input; the first spatial size is fed to conv2d_output as
        # input_height, matching its position in the tensor shape, so
        # prediction and forward pass always agree on axis order.
        x = torch.rand((1, in_c, height, width))
        actual = tuple(conv_demo(x).size()[1:4])
        predicted = tuple(conv2d_output(conv_demo, height, width))
        if actual != predicted:
            print("Testing has failed :(",
                  "case =", case, "actual =", actual, "predicted =", predicted)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment