Last active
March 17, 2020 05:50
-
-
Save xmodar/14701088e8b135b5f6c88413a3548c0f to your computer and use it in GitHub Desktop.
PyTorch fully-convolutional auto-encoder for arbitrary image sizes (including rectangles). It can also be used for a DCGAN.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
class Generator(nn.Module):
    """Fully-convolutional decoder that maps a latent code to an image.

    The network is a stack of ``nn.ConvTranspose2d`` layers.  Every layer
    except the last is followed by batch normalization and a leaky ReLU;
    the last layer is followed by ``tanh``.  Layer hyper-parameters are
    derived from the binary representation of the target height and width
    (see :meth:`conv_transpose_2d_config`) so that a ``1 x 1`` input grows
    to exactly ``image_shape[1:]``.

    Args:
        input_dim: Number of channels of the latent code.
        image_shape: Output shape as ``(channels, height, width)``.
        memory: Base channel width; hidden layers use ``memory * 2**i``
            channels, with the widest layer next to the latent code.
    """

    def __init__(self, input_dim, image_shape, memory):
        super().__init__()
        self.memory = memory
        self.input_dim = input_dim
        self.image_shape = image_shape
        convs = self.conv_transpose_2d_layers(input_dim, image_shape, memory)
        layers = []
        for conv in convs[:-1]:
            layers += [
                conv,
                nn.BatchNorm2d(conv.out_channels),
                nn.LeakyReLU(0.2, True),
            ]
        # The output layer gets no normalization; tanh bounds it to [-1, 1].
        self.decoder = nn.Sequential(*layers, convs[-1], nn.Tanh())

    def forward(self, latent_code):  # pylint: disable=arguments-differ
        """Decode ``latent_code`` after viewing it as ``(-1, input_dim, 1, 1)``."""
        return self.decoder(latent_code.view(-1, self.input_dim, 1, 1))

    @staticmethod
    def conv_transpose_2d_layers(input_dim, image_shape, memory, bias=False):
        """Build the list of ``nn.ConvTranspose2d`` layers of the decoder.

        Args:
            input_dim: Number of input channels of the first layer.
            image_shape: Target shape as ``(channels, *spatial_size)``.
            memory: Base channel width of the hidden layers.
            bias: Whether the convolutions have a bias term.

        Returns:
            A list of ``nn.ConvTranspose2d`` modules, one per layer.
        """
        config = Generator.conv_transpose_2d_config(*image_shape[1:])
        num_layers = len(config['stride'])
        # Hidden widths shrink towards the output: ..., 2 * memory, memory.
        channels = [memory * 2**i for i in range(num_layers - 2, -1, -1)]
        config['in_channels'] = [input_dim] + channels
        config['out_channels'] = channels + [image_shape[0]]
        # Transpose the dict of per-layer lists into one kwargs dict per layer.
        kwargs = [dict(zip(config.keys(), x)) for x in zip(*config.values())]
        return [nn.ConvTranspose2d(**k, bias=bias) for k in kwargs]

    @staticmethod
    def conv_transpose_2d_config(*size):
        """Compute per-layer hyper-parameters that grow ``1`` to each ``size``.

        Each target size is written in binary.  Starting from the leading
        ``1`` bit, every following bit ``b`` becomes one layer with stride 2,
        kernel 4, padding 1 and output padding ``b``, which maps a length
        ``n`` to ``2 * n + b``.  Dimensions whose binary form is shorter
        than the longest one are left-padded with identity layers
        (stride 1, kernel 3, padding 1) so all dimensions share the same
        number of layers.

        Args:
            *size: Target length of every spatial dimension.

        Returns:
            A dict with keys ``kernel_size``, ``stride``, ``padding``,
            ``output_padding`` and ``dilation``; each value is a list with
            one tuple (of ``len(size)`` ints) per layer.
        """
        binary = [bin(s)[2:] for s in size]
        num_layers = max((len(b) for b in binary), default=0) - 1
        # Every key must hold exactly ``num_layers`` characters.  ``b[1:]``
        # contributes ``len(b) - 1`` of them, hence the ``+ 1``; without it
        # shorter dimensions get a key that is one character short and the
        # ``zip`` calls below silently drop a layer.
        keys = [(num_layers - len(b) + 1) * '-' + b[1:] for b in binary]
        stride = lambda k: (1 if c == '-' else 2 for c in k)
        kernel_size = lambda k: (3 if c == '-' else 4 for c in k)
        output_padding = lambda k: (1 if c == '1' else 0 for c in k)
        return {
            'kernel_size': list(zip(*map(kernel_size, keys))),
            'stride': list(zip(*map(stride, keys))),
            'padding': [(1,) * len(keys)] * num_layers,
            'output_padding': list(zip(*map(output_padding, keys))),
            'dilation': [(1,) * len(keys)] * num_layers,
        }
class Discriminator(nn.Module):
    """Fully-convolutional encoder that maps an image to a feature vector.

    The network is a stack of stride-2 ``nn.Conv2d`` layers with 3x3
    kernels.  Every layer except the last is followed by batch
    normalization and a leaky ReLU; the last layer's output is returned
    without an activation.

    Args:
        output_dim: Number of output channels of the last convolution.
        image_shape: Input shape as ``(channels, height, width)``.
        memory: Base channel width; hidden layers use ``memory * 2**i``
            channels, with the narrowest layer next to the image.
    """

    def __init__(self, output_dim, image_shape, memory):
        super().__init__()
        self.memory = memory
        self.output_dim = output_dim
        self.image_shape = image_shape
        convs = self.conv_2d_layers(output_dim, image_shape, memory)
        modules = []
        for conv in convs[:-1]:
            modules.append(conv)
            modules.append(nn.BatchNorm2d(conv.out_channels))
            modules.append(nn.LeakyReLU(0.2, True))
        # The final convolution is appended bare (no norm, no activation).
        modules.append(convs[-1])
        self.encoder = nn.Sequential(*modules)

    def forward(self, images):  # pylint: disable=arguments-differ
        """Encode ``images`` and flatten everything after the batch axis."""
        features = self.encoder(images)
        return features.flatten(1)

    @staticmethod
    def conv_2d_layers(output_dim, image_shape, memory, bias=False):
        """Build the list of ``nn.Conv2d`` layers of the encoder.

        Args:
            output_dim: Number of output channels of the last layer.
            image_shape: Input shape as ``(channels, *spatial_size)``.
            memory: Base channel width of the hidden layers.
            bias: Whether the convolutions have a bias term.

        Returns:
            A list of ``nn.Conv2d`` modules, one per layer.
        """
        config = Discriminator.conv_2d_config(*image_shape[1:])
        depth = len(config['stride'])
        # Hidden widths grow away from the image: memory, 2 * memory, ...
        hidden = [memory * 2**i for i in range(depth - 1)]
        config['in_channels'] = [image_shape[0]] + hidden
        config['out_channels'] = hidden + [output_dim]
        names = list(config.keys())
        convs = []
        # Walk the per-layer columns of the config table.
        for values in zip(*config.values()):
            layer_kwargs = dict(zip(names, values))
            convs.append(nn.Conv2d(**layer_kwargs, bias=bias))
        return convs

    @staticmethod
    def conv_2d_config(*size):
        """Compute per-layer hyper-parameters for the stride-2 convolutions.

        The number of layers is one less than the longest binary length of
        ``s + 1`` over all sizes ``s``.  For each dimension, the first
        ``len(bin(s + 1)) - 2`` layers use no padding and the remaining
        ones use a padding of 1.

        Args:
            *size: Length of every spatial dimension of the input.

        Returns:
            A dict with keys ``kernel_size``, ``stride``, ``padding`` and
            ``dilation``; each value is a list with one tuple (of
            ``len(size)`` ints) per layer.
        """
        dims = len(size)
        binary = [bin(s + 1)[2:] for s in size]
        num_layers = max(map(len, binary), default=0) - 1

        def uniform(value):
            # The same value for every dimension of every layer.
            return [(value,) * dims] * num_layers

        pads_per_dim = []
        for bits in binary:
            key = bits[2:] + '-' * (num_layers - len(bits) + 2)
            # Only the trailing '-' placeholders request padding.
            pads_per_dim.append([1 if c == '-' else 0 for c in key])

        return {
            'kernel_size': uniform(3),
            'stride': uniform(2),
            'padding': list(zip(*pads_per_dim)),
            'dilation': uniform(1),
        }
class AutoEncoder(nn.Module):
    """Auto-encoder assembled from a Discriminator encoder and a Generator decoder.

    Args:
        channels: Triple ``(in_channels, latent_dim, out_channels)``.
        size: Pair ``(height, width)`` shared by the input and the output.
        memory: Base channel width forwarded to both sub-networks.
    """

    def __init__(self, channels, size, memory):
        super().__init__()
        self.memory = memory
        self.height, self.width = size
        self.in_channels, self.latent_dim, self.out_channels = channels
        spatial = (self.height, self.width)
        # Only the inner ``nn.Sequential`` stacks are kept, so the latent
        # feature map goes from the encoder to the decoder without the
        # wrappers' flatten/view steps.
        source = Discriminator(
            self.latent_dim, (self.in_channels, *spatial), memory)
        self.encoder = source.encoder
        target = Generator(
            self.latent_dim, (self.out_channels, *spatial), memory)
        self.decoder = target.decoder

    def forward(self, images):  # pylint: disable=arguments-differ
        """Encode ``images`` to the latent feature map and decode it back."""
        latent = self.encoder(images)
        return self.decoder(latent)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment