Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active March 17, 2020 05:50
Show Gist options
  • Save xmodar/14701088e8b135b5f6c88413a3548c0f to your computer and use it in GitHub Desktop.
Save xmodar/14701088e8b135b5f6c88413a3548c0f to your computer and use it in GitHub Desktop.
PyTorch fully-convolutional auto-encoder for arbitrary image sizes (including rectangular ones). It can also be used for a DCGAN.
from torch import nn
class Generator(nn.Module):
    """Fully-convolutional decoder that expands a latent code into an image.

    The latent code is viewed as a ``1 x 1`` feature map with ``input_dim``
    channels and passed through the ``nn.ConvTranspose2d`` layers built by
    :meth:`conv_transpose_2d_layers`. Every layer but the last is followed by
    batch normalization and ``LeakyReLU(0.2)``; the last one is followed by
    ``Tanh``.

    Args:
        input_dim: Number of entries in each latent code.
        image_shape: Target shape as ``(channels, height, width)``; the
            spatial sizes are expected to be positive integers.
        memory: Base channel width; hidden layers use ``memory * 2**i``
            channels, widest first.

    Raises:
        IndexError: If no layer is needed (every spatial size is 1), because
            the final convolution is taken from an empty list.
    """

    def __init__(self, input_dim, image_shape, memory):
        super().__init__()
        self.memory = memory
        self.input_dim = input_dim
        self.image_shape = image_shape
        convs = self.conv_transpose_2d_layers(input_dim, image_shape, memory)
        layers = []
        for conv in convs[:-1]:
            layers.append(conv)
            layers.append(nn.BatchNorm2d(conv.out_channels))
            layers.append(nn.LeakyReLU(0.2, True))
        # The output layer gets no normalization, only the Tanh squashing.
        self.decoder = nn.Sequential(*layers, convs[-1], nn.Tanh())

    def forward(self, latent_code):  # pylint: disable=arguments-differ
        """Decode ``latent_code`` after viewing it as ``(-1, input_dim, 1, 1)``."""
        return self.decoder(latent_code.view(-1, self.input_dim, 1, 1))

    @staticmethod
    def conv_transpose_2d_layers(input_dim, image_shape, memory, bias=False):
        """Build the transposed convolutions that grow ``1 x 1`` to the image.

        Args:
            input_dim: Input channels of the first layer.
            image_shape: ``(channels, height, width)`` of the target image.
            memory: Output channels of the last hidden layer; each earlier
                hidden layer has twice as many channels as the next one.
            bias: Forwarded to every ``nn.ConvTranspose2d``.

        Returns:
            list: One ``nn.ConvTranspose2d`` per layer, in application order.
        """
        config = Generator.conv_transpose_2d_config(*image_shape[1:])
        num_layers = len(config['stride'])
        channels = [memory * 2**i for i in reversed(range(num_layers - 1))]
        config['in_channels'] = [input_dim] + channels
        config['out_channels'] = channels + [image_shape[0]]
        return [
            nn.ConvTranspose2d(**dict(zip(config, values)), bias=bias)
            for values in zip(*config.values())
        ]

    @staticmethod
    def conv_transpose_2d_config(*size):
        """Derive per-layer transposed-convolution settings for a target size.

        Each spatial size is grown from 1 by replaying its binary digits after
        the leading one: a ``0`` doubles the current size (kernel 4, stride 2,
        padding 1) and a ``1`` doubles it and adds one (the same layer with
        ``output_padding=1``). Dimensions with fewer digits than the largest
        one are left-padded with size-preserving steps (kernel 3, stride 1,
        padding 1), marked ``'-'``.

        Args:
            *size: Target spatial sizes (positive integers), one per dimension.

        Returns:
            dict: Maps ``kernel_size``, ``stride``, ``padding``,
            ``output_padding`` and ``dilation`` to lists holding one tuple per
            layer; each tuple has one value per dimension.
        """
        binary = [bin(s)[2:] for s in size]
        num_layers = max(map(len, binary), default=0) - 1
        # Every key must hold exactly `num_layers` symbols. The digits after
        # the leading one provide `len(b) - 1` of them, so shorter dimensions
        # are left-padded up to the full length. Padding one symbol short would
        # make the `zip` below truncate to the shortest key, silently dropping
        # a layer so that the larger dimensions never reach their target size.
        keys = [b[1:].rjust(num_layers, '-') for b in binary]

        def per_layer(symbol_value):
            # Transpose per-dimension symbol strings into per-layer tuples.
            return [tuple(map(symbol_value, column)) for column in zip(*keys)]

        return {
            'kernel_size': per_layer(lambda c: 3 if c == '-' else 4),
            'stride': per_layer(lambda c: 1 if c == '-' else 2),
            'padding': [(1,) * len(keys)] * num_layers,
            'output_padding': per_layer(lambda c: 1 if c == '1' else 0),
            'dilation': [(1,) * len(keys)] * num_layers,
        }
class Discriminator(nn.Module):
    """Fully-convolutional encoder that reduces an image to a flat vector.

    The image goes through the ``nn.Conv2d`` layers built by
    :meth:`conv_2d_layers`. Every layer except the last is followed by batch
    normalization and ``LeakyReLU(0.2)``; the last layer's output is returned
    without an activation.

    Args:
        output_dim: Output channels of the final convolution.
        image_shape: Input shape as ``(channels, height, width)``.
        memory: Output channels of the first layer; each following hidden
            layer doubles the channel count.
    """

    def __init__(self, output_dim, image_shape, memory):
        super().__init__()
        self.memory = memory
        self.output_dim = output_dim
        self.image_shape = image_shape
        convs = self.conv_2d_layers(output_dim, image_shape, memory)
        stages = []
        for conv in convs[:-1]:
            stages.extend((
                conv,
                nn.BatchNorm2d(conv.out_channels),
                nn.LeakyReLU(0.2, True),
            ))
        stages.append(convs[-1])
        self.encoder = nn.Sequential(*stages)

    def forward(self, images):  # pylint: disable=arguments-differ
        """Encode ``images`` and flatten everything after the batch dimension."""
        features = self.encoder(images)
        return features.flatten(1)

    @staticmethod
    def conv_2d_layers(output_dim, image_shape, memory, bias=False):
        """Build the strided convolutions described by :meth:`conv_2d_config`.

        Args:
            output_dim: Output channels of the last layer.
            image_shape: ``(channels, height, width)`` of the input image.
            memory: Output channels of the first layer.
            bias: Forwarded to every ``nn.Conv2d``.

        Returns:
            list: One ``nn.Conv2d`` per layer, in application order.
        """
        config = Discriminator.conv_2d_config(*image_shape[1:])
        depth = len(config['stride'])
        hidden = [memory * 2**level for level in range(depth - 1)]
        config['in_channels'] = [image_shape[0]] + hidden
        config['out_channels'] = hidden + [output_dim]
        names = list(config.keys())
        modules = []
        for values in zip(*config.values()):
            options = dict(zip(names, values))
            modules.append(nn.Conv2d(**options, bias=bias))
        return modules

    @staticmethod
    def conv_2d_config(*size):
        """Derive per-layer convolution settings that shrink ``size`` to ones.

        All layers use kernel 3, stride 2 and dilation 1; only the padding
        varies. With padding 0 a layer maps ``n`` to ``(n - 1) // 2`` and with
        padding 1 to ``(n + 1) // 2``. The padding schedule of a dimension is
        read from the binary digits of ``s + 1`` after the first two (padding
        0 each), followed by padded layers (marked ``'-'``) up to the common
        layer count.

        Args:
            *size: Input spatial sizes, one per dimension.

        Returns:
            dict: Maps ``kernel_size``, ``stride``, ``padding`` and
            ``dilation`` to lists holding one tuple per layer; each tuple has
            one value per dimension.
        """
        digits = [bin(extent + 1)[2:] for extent in size]
        num_layers = max(map(len, digits), default=0) - 1
        keys = [d[2:] + '-' * (num_layers - len(d) + 2) for d in digits]
        rank = len(size)
        # Transpose the per-dimension keys into per-layer padding tuples.
        padding = [
            tuple(1 if symbol == '-' else 0 for symbol in column)
            for column in zip(*keys)
        ]
        return {
            'kernel_size': [(3,) * rank] * num_layers,
            'stride': [(2,) * rank] * num_layers,
            'padding': padding,
            'dilation': [(1,) * rank] * num_layers,
        }
class AutoEncoder(nn.Module):
    """Auto-encoder assembled from a Discriminator encoder and a Generator decoder.

    Args:
        channels: ``(in_channels, latent_dim, out_channels)``.
        size: ``(height, width)`` shared by the input and output images.
        memory: Base channel width forwarded to both sub-networks.
    """

    def __init__(self, channels, size, memory):
        super().__init__()
        self.memory = memory
        self.height, self.width = size
        self.in_channels, self.latent_dim, self.out_channels = channels
        spatial = (self.height, self.width)
        # Only the inner nn.Sequential of each network is kept, so the
        # encoder's unflattened output is fed straight into the decoder.
        encoder_net = Discriminator(
            self.latent_dim, (self.in_channels, *spatial), memory)
        self.encoder = encoder_net.encoder
        decoder_net = Generator(
            self.latent_dim, (self.out_channels, *spatial), memory)
        self.decoder = decoder_net.decoder

    def forward(self, images):  # pylint: disable=arguments-differ
        """Encode ``images`` and decode the result back into images."""
        latent = self.encoder(images)
        return self.decoder(latent)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment