Forked from Erotemic/output_shape_for.py
Output shape information for PyTorch
import math

import torch
import torch.nn as nn
import torchvision


REGISTERED_OUTPUT_SHAPE_TYPES = []


def compute_type(type):
    def _wrap(func):
        REGISTERED_OUTPUT_SHAPE_TYPES.append((type, func))
        return func
    return _wrap
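
# A minimal sketch of registering a shape function for a custom layer
# (`MyIdentity` is hypothetical, not part of this gist): decorate a
# function that maps an input shape to the module's output shape.
#
#     @compute_type(MyIdentity)
#     def my_identity(module, input_shape):
#         return input_shape  # a shape-preserving layer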


class OutputShapeFor(object):
    math = math  # for hacking in sympy

    def __init__(self, module):
        self.module = module
        self._func = getattr(module, 'output_shape_for', None)
        if self._func is None:
            # Look up a registered shape function if the module
            # does not define `output_shape_for` itself
            for type, _func in REGISTERED_OUTPUT_SHAPE_TYPES:
                try:
                    if module is type or isinstance(module, type):
                        self._func = _func
                except TypeError:
                    pass
            if not self._func:
                raise TypeError('Unknown module type {}'.format(module))

    def __call__(self, *args, **kwargs):
        if isinstance(self.module, nn.Module):
            # bound methods don't need the module passed in explicitly
            is_bound = hasattr(self._func, '__func__') and getattr(self._func, '__func__', None) is not None
            is_bound |= hasattr(self._func, 'im_func') and getattr(self._func, 'im_func', None) is not None
            if is_bound:
                output_shape = self._func(*args, **kwargs)
            else:
                # an unbound function registered for an nn.Module with state
                output_shape = self._func(self.module, *args, **kwargs)
        else:
            # a simple pytorch function such as torch.cat
            output_shape = self._func(*args, **kwargs)
        return output_shape
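
    # The three dispatch paths, illustrated (comment sketch; `custom` is a
    # hypothetical module defining its own `output_shape_for` method):
    #
    #     OutputShapeFor(nn.Conv2d(3, 8, 3))(shape)   # registered nn.Module type
    #     OutputShapeFor(torch.cat)([s1, s2], dim=1)  # plain torch function
    #     OutputShapeFor(custom)(shape)               # module's own bound method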

    @staticmethod
    @compute_type(nn.UpsamplingBilinear2d)
    def UpsamplingBilinear2d(module, input_shape):
        r"""
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where
          :math:`H_{out} = floor(H_{in} * scale\_factor)`
          :math:`W_{out} = floor(W_{in} * scale\_factor)`

        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape = (1, 3, 256, 256)
            >>> module = nn.UpsamplingBilinear2d(scale_factor=2)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 3, 512, 512)
        """
        math = OutputShapeFor.math
        (N, C, H_in, W_in) = input_shape
        H_out = math.floor(H_in * module.scale_factor)
        W_out = math.floor(W_in * module.scale_factor)
        output_shape = (N, C, H_out, W_out)
        return output_shape

    @staticmethod
    @compute_type(nn.ConvTranspose2d)
    def conv2dT(module, input_shape):
        r"""
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
          :math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
          :math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`

        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape = (1, 3, 256, 256)
            >>> module = nn.ConvTranspose2d(input_shape[1], 11, kernel_size=2, stride=2)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 11, 512, 512)
        """
        (N, C_in, H_in, W_in) = input_shape
        C_out = module.out_channels
        stride = module.stride
        kernel_size = module.kernel_size
        output_padding = module.output_padding
        padding = module.padding
        H_out = (H_in - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]
        W_out = (W_in - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]
        output_shape = (N, C_out, H_out, W_out)
        return output_shape

    @staticmethod
    @compute_type(nn.Conv2d)
    def conv2d(module, input_shape):
        r"""
        Notes:
            - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
            - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
              :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
              :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`

        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape = (1, 3, 256, 256)
            >>> module = nn.Conv2d(input_shape[1], 11, 3, 1, 0)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 11, 254, 254)
        """
        math = OutputShapeFor.math
        N, C_in, H_in, W_in = input_shape
        C_out = module.out_channels
        padding = module.padding
        stride = module.stride
        dilation = module.dilation
        kernel_size = module.kernel_size
        H_out = math.floor((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = math.floor((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        output_shape = (N, C_out, H_out, W_out)
        return output_shape

    @staticmethod
    @compute_type(nn.Conv3d)
    def conv3d(module, input_shape):
        r"""
        Notes:
            - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
            - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
              :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
              :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
              :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`

        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape = (1, 3, 25, 32, 32)
            >>> module = nn.Conv3d(in_channels=input_shape[1], out_channels=11,
            >>>                    kernel_size=(3, 3, 3), stride=1, padding=0,
            >>>                    dilation=1, groups=1, bias=True)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 11, 23, 30, 30)
        """
        math = OutputShapeFor.math
        N, C_in, D_in, H_in, W_in = input_shape
        C_out = module.out_channels
        padding = module.padding
        stride = module.stride
        dilation = module.dilation
        kernel_size = module.kernel_size
        D_out = math.floor((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        H_out = math.floor((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        W_out = math.floor((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
        output_shape = (N, C_out, D_out, H_out, W_out)
        return output_shape

    @staticmethod
    @compute_type(nn.MaxPool3d)
    def max_pool3d(module, input_shape):
        r"""
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
          :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
          :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
          :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`
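
        Example (a minimal sketch; tuple arguments are needed because this
        helper indexes padding/stride/dilation directly):
            >>> input_shape = (1, 3, 25, 32, 32)
            >>> module = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2),
            >>>                       padding=(0, 0, 0), dilation=(1, 1, 1))
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 3, 12, 16, 16)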
"""
math = OutputShapeFor.math
N, C_in, D_in, H_in, W_in = input_shape
padding = module.padding
stride = module.stride
dilation = module.dilation
kernel_size = module.kernel_size
D_out = math.floor((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
H_out = math.floor((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
W_out = math.floor((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
output_shape = (N, C_in, D_out, H_out, W_out)
return output_shape

    @staticmethod
    @compute_type(torch.cat)
    def cat(input_shapes, dim=0):
        """
        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape1 = (1, 3, 256, 256)
            >>> input_shape2 = (1, 4, 256, 256)
            >>> input_shapes = [input_shape1, input_shape2]
            >>> output_shape = OutputShapeFor(torch.cat)(input_shapes, dim=1)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = [1, 7, 256, 256]
        """
        n_dims = max(map(len, input_shapes))
        assert n_dims == min(map(len, input_shapes))
        output_shape = [None] * n_dims
        for shape in input_shapes:
            for i, v in enumerate(shape):
                if output_shape[i] is None:
                    output_shape[i] = v
                else:
                    if i == dim:
                        output_shape[i] += v
                    else:
                        assert output_shape[i] == v, 'inconsistent dims'
        return output_shape

    @staticmethod
    @compute_type(nn.MaxPool2d)
    def maxpool2(module, input_shape):
        r"""
        Example:
            >>> from pysseg.torch.models.output_shape_for import *
            >>> input_shape = (1, 3, 256, 256)
            >>> module = nn.MaxPool2d(kernel_size=2)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 3, 128, 128)

        Shape:
            Same as the conv2d formula, except C_out = C_in
            - Input: :math:`(N, C, H_{in}, W_{in})`
            - Output: :math:`(N, C, H_{out}, W_{out})` where
              :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
              :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
        """
        math = OutputShapeFor.math
        N, C, H_in, W_in = input_shape

        def ensure_iterable2(scalar):
            try:
                iter(scalar)
            except TypeError:
                return [scalar] * 2
            return scalar

        padding = ensure_iterable2(module.padding)
        stride = ensure_iterable2(module.stride)
        dilation = ensure_iterable2(module.dilation)
        kernel_size = ensure_iterable2(module.kernel_size)
        H_out = math.floor((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
        W_out = math.floor((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
        output_shape = (N, C, H_out, W_out)
        return output_shape

    @staticmethod
    @compute_type(nn.AvgPool2d)
    def avepool2d(module, input_shape):
        r"""
        Shape:
            - Input: :math:`(N, C, H_{in}, W_{in})`
            - Output: :math:`(N, C, H_{out}, W_{out})` where
              :math:`H_{out} = floor((H_{in} + 2 * padding[0] - kernel\_size[0]) / stride[0] + 1)`
              :math:`W_{out} = floor((W_{in} + 2 * padding[1] - kernel\_size[1]) / stride[1] + 1)`
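
        Example (a minimal sketch; scalar args are fine here because of
        `ensure_iterable2` below):
            >>> input_shape = (1, 3, 256, 256)
            >>> module = nn.AvgPool2d(kernel_size=2)
            >>> output_shape = OutputShapeFor(module)(input_shape)
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = (1, 3, 128, 128)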
"""
math = OutputShapeFor.math
N, C, H_in, W_in = input_shape
def ensure_iterable2(scalar):
try:
iter(scalar)
except TypeError:
return [scalar] * 2
return scalar
padding = ensure_iterable2(module.padding)
stride = ensure_iterable2(module.stride)
kernel_size = ensure_iterable2(module.kernel_size)
H_out = math.floor((H_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
W_out = math.floor((W_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
output_shape = (N, C, H_out, W_out)
return output_shape

    @staticmethod
    @compute_type(nn.Linear)
    def linear(module, input_shape):
        r"""
        Shape:
            - Input: :math:`(N, *, in\_features)` where `*` means any number
              of additional dimensions
            - Output: :math:`(N, *, out\_features)` where all but the last
              dimension are the same shape as the input.
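
        Example (a minimal sketch):
            >>> module = nn.Linear(in_features=128, out_features=10)
            >>> output_shape = OutputShapeFor(module)((4, 128))
            >>> print('output_shape = {!r}'.format(output_shape))
            output_shape = [4, 10]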
"""
N, *other, in_feat = input_shape
output_shape = [N] + other + [module.out_features]
return output_shape

    @staticmethod
    @compute_type(nn.BatchNorm2d)
    def batchnorm(module, input_shape):
        return input_shape

    @staticmethod
    @compute_type(nn.ReLU)
    def relu(module, input_shape):
        return input_shape

    @staticmethod
    @compute_type(nn.LeakyReLU)
    def leaky_relu(module, input_shape):
        return input_shape

    @staticmethod
    @compute_type(nn.Sequential)
    def sequential(module, input_shape):
        shape = input_shape
        for child in module._modules.values():
            shape = OutputShapeFor(child)(shape)
        return shape

    @staticmethod
    @compute_type(torchvision.models.resnet.BasicBlock)
    def resnet_basic_block(module, input_shape):
        residual_shape = input_shape
        shape = input_shape
        shape = OutputShapeFor(module.conv1)(shape)
        shape = OutputShapeFor(module.bn1)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        shape = OutputShapeFor(module.conv2)(shape)
        shape = OutputShapeFor(module.bn2)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        if module.downsample is not None:
            residual_shape = OutputShapeFor(module.downsample)(residual_shape)
        # assert residual_shape[-2:] == shape[-2:], 'cannot add residual {} {}'.format(residual_shape, shape)
        # out += residual
        shape = OutputShapeFor(module.relu)(shape)
        # print('BASIC residual_shape = {!r}'.format(residual_shape[-2:]))
        # print('BASIC shape = {!r}'.format(shape[-2:]))
        # print('---')
        return shape
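
    # A quick sanity sketch for the block above (assumes torchvision's
    # resnet18 layout):
    #
    #     >>> resnet = torchvision.models.resnet18()
    #     >>> block = resnet.layer1[0]  # a BasicBlock with no downsample
    #     >>> OutputShapeFor(block)((1, 64, 56, 56))
    #     (1, 64, 56, 56)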

    @staticmethod
    @compute_type(torchvision.models.resnet.Bottleneck)
    def resnet_bottleneck(module, input_shape):
        residual_shape = input_shape
        shape = input_shape
        shape = OutputShapeFor(module.conv1)(shape)
        shape = OutputShapeFor(module.bn1)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        shape = OutputShapeFor(module.conv2)(shape)
        shape = OutputShapeFor(module.bn2)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        shape = OutputShapeFor(module.conv3)(shape)
        shape = OutputShapeFor(module.bn3)(shape)
        if module.downsample is not None:
            residual_shape = OutputShapeFor(module.downsample)(input_shape)
        assert residual_shape[-2:] == shape[-2:], 'cannot add residual {} {}'.format(residual_shape, shape)
        # out += residual
        shape = OutputShapeFor(module.relu)(shape)
        # print('bottle downsample = {!r}'.format(module.downsample))
        # print('bottle input_shape = {!r}'.format(input_shape[-2:]))
        # print('bottle residual_shape = {!r}'.format(residual_shape[-2:]))
        # print('bottle shape = {!r}'.format(shape[-2:]))
        # print('---')
        return shape

    @staticmethod
    @compute_type(torchvision.models.resnet.ResNet)
    def resnet_model(module, input_shape):
        shape = input_shape
        shape = OutputShapeFor(module.conv1)(shape)
        shape = OutputShapeFor(module.bn1)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        shape = OutputShapeFor(module.maxpool)(shape)
        shape = OutputShapeFor(module.layer1)(shape)
        shape = OutputShapeFor(module.layer2)(shape)
        shape = OutputShapeFor(module.layer3)(shape)
        shape = OutputShapeFor(module.layer4)(shape)
        shape = OutputShapeFor(module.avgpool)(shape)
        # print('pre-flatten-shape = {!r}'.format(shape))

        def prod(args):
            result = args[0]
            for arg in args[1:]:
                result = result * arg
            return result
        # flatten: shape = shape.view(shape.size(0), -1)
        shape = (shape[0], prod(shape[1:]))
        shape = OutputShapeFor(module.fc)(shape)
        return shape

    @staticmethod
    def resnet_conv_part(module, input_shape):
        shape = input_shape
        shape = OutputShapeFor(module.conv1)(shape)
        shape = OutputShapeFor(module.bn1)(shape)
        shape = OutputShapeFor(module.relu)(shape)
        shape = OutputShapeFor(module.maxpool)(shape)
        shape = OutputShapeFor(module.layer1)(shape)
        shape = OutputShapeFor(module.layer2)(shape)
        shape = OutputShapeFor(module.layer3)(shape)
        shape = OutputShapeFor(module.layer4)(shape)
        shape = OutputShapeFor(module.avgpool)(shape)
        # print('pre-flatten-shape = {!r}'.format(shape))

        def prod(args):
            result = args[0]
            for arg in args[1:]:
                result = result * arg
            return result
        # flatten: shape = shape.view(shape.size(0), -1)
        shape = (shape[0], prod(shape[1:]))
        return shape
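

if __name__ == '__main__':
    # A minimal usage sketch, assuming only the modules registered above:
    # chain shape inference through a small Sequential stack and check it
    # against a real forward pass.
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    )
    input_shape = (1, 3, 64, 64)
    pred_shape = OutputShapeFor(model)(input_shape)
    real_shape = tuple(model(torch.zeros(input_shape)).shape)
    print('predicted = {!r}'.format(pred_shape))
    print('actual    = {!r}'.format(real_shape))
    assert tuple(pred_shape) == real_shape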