

@xmodar
Last active October 25, 2021 17:47
"""Deconvolution https://api.semanticscholar.org/CorpusID:208192734"""
import torch
from torch import nn
class Deconv(nn.Module):
"""Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d"""
def __init__(self, conv, output_padding=0):
dim = len(conv.padding)
if isinstance(output_padding, int):
output_padding = (output_padding, ) * dim
assert isinstance(output_padding, tuple) and len(output_padding) == dim
super().__init__()
self.stride = conv.stride
self.padding = conv.padding
self.output_padding = output_padding
self.groups = conv.groups
self.dilation = conv.dilation
weight = conv.weight.detach()
weight = weight.view(conv.groups, -1, weight.shape[1:].numel())
weight = weight.pinverse().transpose(-1, -2).reshape_as(conv.weight)
self.weight = nn.Parameter(weight)
if conv.bias is None:
self.bias = None
else:
self.bias = nn.Parameter(-conv.bias.detach())
def forward(self, inputs, output_size=None):
"""Perform the forward pass"""
if self.bias is not None:
inputs = inputs + self.bias.view(-1, *[1] * self.dim)
output_padding = self.get_output_padding(output_size)
outputs = self.conv_transpose(inputs, self.weight, output_padding)
one = torch.ones((), device=outputs.device, dtype=outputs.dtype)
factor = self.groups / self.in_channels
inputs = (one * factor).expand(1, *inputs.shape[1:])
weight = one.expand(self.in_channels, 1, *self.kernel_size)
overlaps = self.conv_transpose(inputs, weight, output_padding)
return outputs.div_(overlaps.clamp_min_(1))
def get_output_padding(self, output_size=None):
"""Get the output padding given the output size"""
if output_size is None:
return self.output_padding
def rule(size, kernel_size, stride, padding, dilation):
margin = 2 * padding - dilation * (kernel_size - 1) - 1
return size - ((size + margin) // stride) * stride + margin
output_size = tuple(output_size)[-self.dim:]
args = self.kernel_size, self.stride, self.padding, self.dilation
return tuple(map(lambda x: rule(*x), zip(output_size, *args)))
def set_output_padding(self, output_size=None):
"""Set the output padding given the output size"""
self.output_padding = self.get_output_padding(output_size)
return self
def conv_transpose(self, inputs, weight, output_padding):
"""Compute conv_transpose"""
conv = getattr(nn.functional, f'conv_transpose{self.dim}d')
return conv(inputs, weight, None, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@property
def in_channels(self):
"""Number of input channels"""
return self.weight.shape[0]
@property
def out_channels(self):
"""Number of output channels"""
return self.weight.shape[1]
@property
def kernel_size(self):
"""Kernel size"""
return tuple(self.weight.shape[2:])
@property
def dim(self):
"""Number of dimensions"""
return self.weight.ndim - 2
@torch.no_grad()
def _test():
conv = nn.Conv2d(3, 81, 3, stride=2, padding=2).double()
inputs = torch.randn(1, conv.in_channels, 10, 10).double()
deconv = Deconv(conv).set_output_padding(inputs.shape[2:])
print(torch.allclose(inputs, deconv(conv(inputs))))
if __name__ == '__main__':
_test()
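
For reference, `forward` also accepts an `output_size` argument, so the output padding can be resolved per call instead of being fixed once with `set_output_padding`. A minimal usage sketch along those lines (the `Conv1d` configuration, tensor sizes, and the `_usage_sketch` name are illustrative assumptions, not taken from the gist):

@torch.no_grad()
def _usage_sketch():
    # Illustrative sizes: 27 >= 3 * 3, so the flattened filter matrix of
    # Conv1d(3, 27, 3) has a left inverse and every patch is recoverable.
    conv = nn.Conv1d(3, 27, 3, stride=2, padding=1).double()
    x = torch.randn(4, conv.in_channels, 50).double()
    deconv = Deconv(conv)
    # Passing output_size resolves the output padding for this call only.
    x_hat = deconv(conv(x), output_size=x.shape)
    print(torch.allclose(x, x_hat))  # should print True for these sizes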
coolbay commented Oct 25, 2021

class InvertibleConv(nn.Module):
    """Invertible conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d"""
    def __init__(self, conv):
        super().__init__()
        # Each group's filter bank, flattened to a (rows x columns) matrix,
        # needs at least as many rows as columns to have a left inverse.
        columns = conv.weight.shape[1:].numel()
        rows = conv.weight.shape[0] // conv.groups
        assert columns <= rows, 'out_channels is too small to invert the conv'
        self.conv = conv

    def forward(self, inputs):
        """Perform the forward pass"""
        input_padding = self.get_input_padding(inputs.shape)
        assert sum(input_padding) == 0, f'inputs need padding: {inputs.shape}'
        return self.conv(inputs)

    def inverse(self, inputs):
        """Perform the inverse pass"""
        # Pseudo-invert the kernel on the fly so it always tracks the
        # current conv weight, then undo the bias before deconvolving.
        weight = self.weight.view(self.groups, -1, self.weight.shape[1:].numel())
        weight = weight.pinverse().transpose(-1, -2).reshape_as(self.weight)
        if self.bias is not None:
            inputs = inputs - self.bias.view(-1, *[1] * self.dim)
        outputs = self.conv_transpose(inputs, weight)
        # Average overlapping patch reconstructions by dividing by the
        # per-position overlap count (clamped to avoid division by zero).
        one = torch.ones((), device=outputs.device, dtype=outputs.dtype)
        factor = self.groups / self.out_channels
        inputs = (one * factor).expand(1, *inputs.shape[1:])
        overlap_weights = one.expand(self.out_channels, 1, *self.kernel_size)
        overlaps = self.conv_transpose(inputs, overlap_weights)
        return outputs.div_(overlaps.clamp_min_(1))

    def get_input_padding(self, input_size):
        """Get the input padding given the input size"""
        def rule(size, kernel_size, stride, padding, dilation):
            margin = 2 * padding - dilation * (kernel_size - 1) - 1
            return size - ((size + margin) // stride) * stride + margin

        input_size = tuple(input_size)[-self.dim:]
        args = self.kernel_size, self.stride, self.padding, self.dilation
        return tuple(map(lambda x: rule(*x), zip(input_size, *args)))

    def conv_transpose(self, inputs, weight):
        """Compute conv_transpose"""
        conv = getattr(nn.functional, f'conv_transpose{self.dim}d')
        return conv(inputs, weight, None, self.stride, self.padding,
                    0, self.groups, self.dilation)

    @property
    def dim(self):
        """Number of dimensions"""
        return self.weight.ndim - 2

    def __getattr__(self, item):
        # Fall back to the wrapped conv for anything not defined here
        # (weight, bias, stride, padding, groups, dilation, ...).
        try:
            return super().__getattr__(item)
        except AttributeError:
            return getattr(self.conv, item)
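
The comment above does not include a round-trip check; a test in the same style as `_test` from the gist might look like the following sketch (the layer configuration and the `_test_invertible` name are illustrative assumptions, with the 11x11 input size chosen so that `get_input_padding` returns zeros):

@torch.no_grad()
def _test_invertible():
    # Illustrative sizes: 81 >= 3 * 3 * 3 output channels, and 11x11 inputs
    # need no extra padding for this kernel/stride/padding combination.
    conv = nn.Conv2d(3, 81, 3, stride=2, padding=2).double()
    inv_conv = InvertibleConv(conv)
    inputs = torch.randn(1, conv.in_channels, 11, 11).double()
    print(torch.allclose(inputs, inv_conv.inverse(inv_conv(inputs))))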
