"""Deconvolution https://api.semanticscholar.org/CorpusID:208192734""" | |
import torch | |
from torch import nn | |
class Deconv(nn.Module): | |
"""Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d""" | |

    def __init__(self, conv, output_padding=0):
        dim = len(conv.padding)
        if isinstance(output_padding, int):
            output_padding = (output_padding, ) * dim
        assert isinstance(output_padding, tuple) and len(output_padding) == dim
        super().__init__()
        self.stride = conv.stride
        self.padding = conv.padding
        self.output_padding = output_padding
        self.groups = conv.groups
        self.dilation = conv.dilation
        # Per group, flatten the kernel into an (out_channels // groups) x
        # (in_channels // groups * prod(kernel_size)) matrix, take its
        # Moore-Penrose pseudoinverse, and restore the conv weight layout;
        # the result is a conv_transpose weight that undoes the convolution.
        weight = conv.weight.detach()
        weight = weight.view(conv.groups, -1, weight.shape[1:].numel())
        weight = weight.pinverse().transpose(-1, -2).reshape_as(conv.weight)
        self.weight = nn.Parameter(weight)
        # Store the negated bias; forward() adds it to undo the conv bias
        if conv.bias is None:
            self.bias = None
        else:
            self.bias = nn.Parameter(-conv.bias.detach())
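
    # Shape walk-through for the `_test` conv below, nn.Conv2d(3, 81, 3)
    # with groups=1: conv.weight is (81, 3, 3, 3) and is viewed as an
    # 81 x 27 matrix W; W.pinverse() is 27 x 81, its transpose 81 x 27,
    # and it is reshaped back to (81, 3, 3, 3), which conv_transpose
    # interprets as (in_channels=81, out_channels // groups=3, 3, 3).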

    def forward(self, inputs, output_size=None):
        """Perform the forward pass"""
        if self.bias is not None:
            inputs = inputs + self.bias.view(-1, *[1] * self.dim)
        output_padding = self.get_output_padding(output_size)
        outputs = self.conv_transpose(inputs, self.weight, output_padding)
        # Each output position sums the reconstructions of every patch that
        # overlaps it; count those overlaps with an all-ones kernel and
        # divide to turn the sum into an average.
        one = torch.ones((), device=outputs.device, dtype=outputs.dtype)
        factor = self.groups / self.in_channels
        inputs = (one * factor).expand(1, *inputs.shape[1:])
        weight = one.expand(self.in_channels, 1, *self.kernel_size)
        overlaps = self.conv_transpose(inputs, weight, output_padding)
        # The count is identical across channels; keep a single channel so
        # the division broadcasts for any number of groups.
        return outputs.div_(overlaps[:, :1].clamp_min_(1))
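
    # For example, with kernel_size=3, stride=2, padding=0, and dilation=1
    # in 1-D, interior output positions are covered by either one or two
    # kernel taps, so `overlaps` alternates between 1 and 2 there.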

    def get_output_padding(self, output_size=None):
        """Get the output padding given the output size"""
        if output_size is None:
            return self.output_padding

        def rule(size, kernel_size, stride, padding, dilation):
            # Invert the conv output-size formula: this is the padding
            # conv_transpose needs so its output is exactly `size`.
            margin = 2 * padding - dilation * (kernel_size - 1) - 1
            return size - ((size + margin) // stride) * stride + margin

        output_size = tuple(output_size)[-self.dim:]
        args = self.kernel_size, self.stride, self.padding, self.dilation
        return tuple(map(lambda x: rule(*x), zip(output_size, *args)))
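
    # Worked example for the `_test` configuration below (size=10,
    # kernel_size=3, stride=2, padding=2, dilation=1):
    #   margin = 2 * 2 - 1 * (3 - 1) - 1 = 1
    #   output_padding = 10 - ((10 + 1) // 2) * 2 + 1 = 1
    # so the conv maps 10 -> 6 and the transposed conv maps 6 back to
    # (6 - 1) * 2 - 1 + 1 = 10.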

    def set_output_padding(self, output_size=None):
        """Set the output padding given the output size"""
        self.output_padding = self.get_output_padding(output_size)
        return self

    def conv_transpose(self, inputs, weight, output_padding):
        """Compute conv_transpose"""
        conv = getattr(nn.functional, f'conv_transpose{self.dim}d')
        return conv(inputs, weight, None, self.stride, self.padding,
                    output_padding, self.groups, self.dilation)

    @property
    def in_channels(self):
        """Number of input channels"""
        return self.weight.shape[0]

    @property
    def out_channels(self):
        """Number of output channels"""
        # conv_transpose weights are (in_channels, out_channels // groups, ...)
        return self.weight.shape[1] * self.groups

    @property
    def kernel_size(self):
        """Kernel size"""
        return tuple(self.weight.shape[2:])

    @property
    def dim(self):
        """Number of spatial dimensions"""
        return self.weight.ndim - 2


@torch.no_grad()
def _test():
    # 81 >= 3 * 3 * 3, so the flattened kernel matrix (almost surely) has
    # full column rank and the reconstruction is exact up to numerical error
    conv = nn.Conv2d(3, 81, 3, stride=2, padding=2).double()
    inputs = torch.randn(1, conv.in_channels, 10, 10).double()
    deconv = Deconv(conv).set_output_padding(inputs.shape[2:])
    print(torch.allclose(inputs, deconv(conv(inputs))))


if __name__ == '__main__':
    _test()
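

# Usage sketch: forward() also takes `output_size` directly, so a 1-D
# variant of the test above (assuming out_channels >= in_channels *
# prod(kernel_size) for an exact pseudoinverse) could read:
#
#     conv = nn.Conv1d(4, 64, 5, stride=3, padding=1).double()
#     x = torch.randn(2, 4, 33).double()
#     print(torch.allclose(x, Deconv(conv)(conv(x), x.shape[2:])))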