@hereismari
Created August 28, 2019 20:54

import math

import torch


def maxpool2d(a_sh, kernel_size: int = 1, stride: int = 1, padding: int = 0,
              dilation: int = 1, ceil_mode: bool = False):
    """Applies a 2D max pooling over an input signal composed of several input planes.

    This interface is similar to torch.nn.MaxPool2d.

    Args:
        a_sh: 4-D input of shape (batch, channels, height, width)
        kernel_size: the size of the window to take a max over
        stride: the stride of the window
        padding: implicit zero padding to be added on both sides
        dilation: a parameter that controls the stride of elements in the window
        ceil_mode: when True, will use ceil instead of floor to compute the output shape
    """
    assert len(a_sh.shape) == 4, "expected a 4-D input (N, C, H, W)"

    # Turn each int argument into an (h, w) pair, as torch does, and unpack it
    h_kernel_size, w_kernel_size = torch.nn.modules.utils._pair(kernel_size)
    h_stride, w_stride = torch.nn.modules.utils._pair(stride)
    h_padding, w_padding = torch.nn.modules.utils._pair(padding)
    h_dilation, w_dilation = torch.nn.modules.utils._pair(dilation)

    # Extract a few useful values
    bh_in, ch_in, h_in, w_in = a_sh.shape

    # ########## Calculate output shapes ###############
    round_op = math.ceil if ceil_mode else math.floor
    h_out = round_op((h_in + 2 * h_padding - h_dilation * (h_kernel_size - 1) - 1) / h_stride + 1)
    w_out = round_op((w_in + 2 * w_padding - w_dilation * (w_kernel_size - 1) - 1) / w_stride + 1)
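The snippet stops after computing the output shape. As a rough reference for what the finished function is meant to return, here is a minimal pure-Python sketch on nested lists, with no torch dependency. The names `pool_output_size`, `maxpool_plane` and `naive_maxpool2d` are hypothetical. Two assumptions are worth flagging:

- Padded positions are skipped. This behaves like the -inf padding used by `torch.nn.MaxPool2d`, not the zero padding the docstring describes.
- In ceil mode, the sketch drops a last window that would start inside the right padding. PyTorch's own output-shape computation applies this rule, but the formula above does not.

```python
import math
from typing import List, Tuple

Grid = List[List[float]]  # one (H, W) plane as a list of rows


def pool_output_size(size: int, kernel: int, stride: int, padding: int,
                     dilation: int, ceil_mode: bool) -> int:
    """Output length along one spatial axis (same formula as the snippet)."""
    round_op = math.ceil if ceil_mode else math.floor
    out = round_op((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)
    # Assumption: mirror PyTorch's ceil_mode rule that the last window must
    # start inside the input or the left padding, never in the right padding.
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out


def maxpool_plane(plane: Grid, kernel: Tuple[int, int], stride: Tuple[int, int],
                  padding: Tuple[int, int], dilation: Tuple[int, int],
                  ceil_mode: bool = False) -> Grid:
    """Naive max pooling over a single 2-D plane (hypothetical helper)."""
    h_in, w_in = len(plane), len(plane[0])
    h_out = pool_output_size(h_in, kernel[0], stride[0], padding[0], dilation[0], ceil_mode)
    w_out = pool_output_size(w_in, kernel[1], stride[1], padding[1], dilation[1], ceil_mode)
    out = []
    for i in range(h_out):
        row = []
        for j in range(w_out):
            best = -math.inf  # positions in the padding never win the max
            for ki in range(kernel[0]):
                for kj in range(kernel[1]):
                    # Input coordinates of this window element, shifted back by the padding
                    r = i * stride[0] - padding[0] + ki * dilation[0]
                    c = j * stride[1] - padding[1] + kj * dilation[1]
                    if 0 <= r < h_in and 0 <= c < w_in:
                        best = max(best, plane[r][c])
            row.append(best)
        out.append(row)
    return out


def naive_maxpool2d(x, kernel_size=1, stride=1, padding=0, dilation=1,
                    ceil_mode=False):
    """Hypothetical reference: pools every (H, W) plane of an N x C x H x W nested list."""
    def pair(v):
        # Accept an int or an (h, w) tuple, like torch.nn.modules.utils._pair
        return v if isinstance(v, tuple) else (v, v)

    args = (pair(kernel_size), pair(stride), pair(padding), pair(dilation), ceil_mode)
    return [[maxpool_plane(plane, *args) for plane in sample] for sample in x]
```

Here are two worked cases:

- **A basic window.** A 4×4 plane holding 1 to 16 row by row, pooled with `kernel_size=2, stride=2`, gives `[[6, 8], [14, 16]]`.
- **The ceil-mode adjustment.** Take a length-3 axis with kernel 2, stride 2, padding 1 and `ceil_mode=True`. The raw formula gives 3 windows, but `pool_output_size` returns 2, because the third window would lie entirely in the right padding.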