Last active
April 6, 2023 12:47
-
-
Save kaparoo/5c3cac442253d6e38864510ddaf50cba to your computer and use it in GitHub Desktop.
Standalone pix2pix framework using PyTorch Lightning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__all__ = ("Pix2Pix",) | |
from collections.abc import Sequence | |
from itertools import pairwise | |
from typing import Any | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from lightning import LightningModule | |
from torch import Tensor | |
from torch.nn.common_types import _size_2_t | |
from torch.optim import Adam, Optimizer | |
def make_exponential_channels(
    num_channels: int = 3,
    unit_channels: int = 64,
    max_scale: int = 8,
) -> Sequence[int]:
    """Return channel widths that double per entry until a cap is reached.

    Entry ``i`` equals ``unit_channels * min(2**i, max_scale)``; with the
    default arguments the result is ``(64, 128, 256)``.

    Args:
        num_channels: Number of entries to generate.
        unit_channels: Width of the first entry.
        max_scale: Largest multiplier applied to ``unit_channels``.

    Returns:
        A tuple with ``num_channels`` widths (empty if ``num_channels <= 0``).
    """
    widths: list[int] = []
    for level in range(num_channels):
        multiplier = min(2**level, max_scale)
        widths.append(unit_channels * multiplier)
    return tuple(widths)
class GAN(LightningModule):
    """Base module for GANs that step their optimizers manually.

    Stores the source/target channel counts and provides the binary
    cross-entropy adversarial loss shared by subclasses.
    """

    def __init__(
        self,
        source_channels: int = 3,
        target_channels: int | None = None,
    ) -> None:
        """Record the channel counts and switch off automatic optimization.

        Args:
            source_channels: Channel count of the source images.
            target_channels: Channel count of the target images; any non-int
                value (e.g. ``None``) falls back to ``source_channels``.
        """
        super().__init__()
        # Subclasses call the optimizers themselves inside `training_step`.
        self.automatic_optimization = False

        self.source_channels = source_channels
        if isinstance(target_channels, int):
            self.target_channels = target_channels
        else:
            self.target_channels = source_channels

    def get_labels(self, preds: Tensor, as_real: bool = True) -> Tensor:
        """Return an all-ones (real) or all-zeros (fake) tensor shaped like ``preds``."""
        if as_real:
            return torch.ones_like(preds)
        return torch.zeros_like(preds)

    def adversarial_loss(self, preds: Tensor, as_real: bool = True) -> Tensor:
        """Return the binary cross-entropy of ``preds`` against constant labels.

        ``preds`` is passed straight to ``F.binary_cross_entropy``, which
        expects probabilities rather than logits.
        """
        targets = self.get_labels(preds, as_real)
        return F.binary_cross_entropy(preds, targets)
class Pix2Pix(GAN):
    """Conditional image-to-image GAN trained with manual optimization.

    A generator (``UNet`` by default, ``Autoencoder`` when
    ``use_autoencoder`` is true) maps conditioning images to target images.
    A ``Discriminator`` scores (conditioning, target) pairs. The generator
    objective is the adversarial loss plus ``lambda_rcn`` times an L1
    reconstruction loss.
    """

    def __init__(
        self,
        source_channels: int = 3,
        target_channels: int | None = None,
        input_first: bool = True,
        lambda_rcn: float = 100.0,
        num_dis_blocks: int | None = 3,
        num_gen_blocks: int | None = 8,
        use_autoencoder: bool = False,
        num_dropouts: int = 3,
        dropout: float = 0.5,
        dis_channels: Sequence[int] | None = None,
        gen_channels: Sequence[int] | None = None,
        dis_conv_kwargs: dict[str, Any] | None = None,
        gen_conv_kwargs: dict[str, Any] | None = None,
        num_encoder_dropouts: int = 0,
        encoder_dropout: float = 0.0,
        learning_rate: float = 0.0002,
        beta1: float = 0.5,
        beta2: float = 0.999,
    ) -> None:
        """Build the discriminator and generator sub-networks.

        Args:
            source_channels: Channels of the conditioning images.
            target_channels: Channels of the target images; defaults to
                ``source_channels`` when not an int.
            input_first: If true, batches are ``(cond, real)``; otherwise
                ``(real, cond)``.
            lambda_rcn: Weight of the L1 reconstruction loss.
            num_dis_blocks: Forwarded to ``Discriminator`` as ``num_blocks``.
            num_gen_blocks: Forwarded to the generator as ``num_blocks``.
            use_autoencoder: Use ``Autoencoder`` instead of ``UNet``.
            num_dropouts: Number of decoder blocks with dropout.
            dropout: Decoder dropout probability.
            dis_channels: Explicit discriminator block channels.
            gen_channels: Explicit generator block channels.
            dis_conv_kwargs: Convolution overrides for the discriminator.
            gen_conv_kwargs: Convolution overrides for the generator.
            num_encoder_dropouts: Number of encoder blocks with dropout.
            encoder_dropout: Encoder dropout probability.
            learning_rate: Adam learning rate for both optimizers.
            beta1: First Adam beta.
            beta2: Second Adam beta.
        """
        super().__init__(source_channels, target_channels)
        self.save_hyperparameters()

        self.lambda_rcn = lambda_rcn
        self.input_first = input_first

        # The discriminator sees conditioning and target images concatenated
        # along the channel axis.
        self.discriminator = Discriminator(
            in_channels=self.source_channels + self.target_channels,
            num_blocks=num_dis_blocks,
            conv_kwargs=dis_conv_kwargs,
            blocks_channels=dis_channels,
        )

        generator_cls = Autoencoder if use_autoencoder else UNet
        self.generator = generator_cls(
            in_channels=self.source_channels,
            out_channels=self.target_channels,
            num_blocks=num_gen_blocks,
            num_encoder_dropouts=num_encoder_dropouts,
            num_decoder_dropouts=num_dropouts,
            encoder_dropout=encoder_dropout,
            decoder_dropout=dropout,
            conv_kwargs=gen_conv_kwargs,
            blocks_channels=gen_channels,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the generator on conditioning images ``x``."""
        return self.generator(x)  # type: ignore

    def reconstruction_loss(self, fake_images: Tensor, real_images: Tensor) -> Tensor:
        """Return the mean absolute error between generated and real images."""
        return F.l1_loss(fake_images, real_images)

    def train_discriminator(
        self,
        cond_images: Tensor,
        real_images: Tensor,
        fake_images: Tensor,
    ) -> None:
        """Run one discriminator update and log its losses.

        ``fake_images`` is detached so this step never back-propagates into
        the generator.
        """
        optimizer = self.optimizers()[0]
        self.toggle_optimizer(optimizer)
        optimizer.zero_grad()

        preds_real = self.discriminator(cond_images, real_images)
        preds_fake = self.discriminator(cond_images, fake_images.detach())
        loss_real = self.adversarial_loss(preds_real, as_real=True)
        loss_fake = self.adversarial_loss(preds_fake, as_real=False)
        loss = (loss_real + loss_fake) / 2

        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("d_loss", loss, prog_bar=True)
        self.log("d_loss_real", loss_real)
        self.log("d_loss_fake", loss_fake)

    def train_generator(
        self,
        cond_images: Tensor,
        fake_images: Tensor,
        real_images: Tensor,
    ) -> None:
        """Run one generator update and log its losses.

        The generator is rewarded when the discriminator labels its output
        as real, and penalized by the weighted L1 distance to ``real_images``.
        """
        optimizer = self.optimizers()[1]
        self.toggle_optimizer(optimizer)
        optimizer.zero_grad()

        preds = self.discriminator(cond_images, fake_images)
        loss_adv = self.adversarial_loss(preds, as_real=True)
        loss_rcn = self.reconstruction_loss(fake_images, real_images)
        loss = loss_adv + self.lambda_rcn * loss_rcn

        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("g_loss", loss, prog_bar=True)
        self.log("g_loss_adv", loss_adv)
        self.log("g_loss_rcn", loss_rcn)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> None:
        """Update the discriminator, then the generator, on one batch."""
        if self.input_first:
            cond_images, real_images = batch
        else:
            real_images, cond_images = batch

        fake_images = self.forward(cond_images)

        # Keyword arguments are used on purpose: the two helpers take the
        # real/fake images in a different positional order.
        self.train_discriminator(
            cond_images=cond_images,
            real_images=real_images,
            fake_images=fake_images,
        )
        self.train_generator(
            cond_images=cond_images,
            fake_images=fake_images,
            real_images=real_images,
        )

    def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
        """Return ``(discriminator_optimizer, generator_optimizer)``.

        The order matters: the training helpers index the optimizers as
        ``[0]`` (discriminator) and ``[1]`` (generator).
        """
        learning_rate = self.hparams.learning_rate
        betas = (self.hparams.beta1, self.hparams.beta2)
        d_opt = Adam(self.discriminator.parameters(), lr=learning_rate, betas=betas)
        g_opt = Adam(self.generator.parameters(), lr=learning_rate, betas=betas)
        return (d_opt, g_opt)
class Autoencoder(nn.Module): | |
class _Block(nn.Module): | |
def __init__( | |
self, | |
out_channels: int, | |
norm: bool = True, | |
dropout: float = 0.0, | |
) -> None: | |
super().__init__() | |
self.conv: nn.Module | |
self.actv: nn.Module | |
self.norm = nn.BatchNorm2d(out_channels) if norm else nn.Identity() | |
self.drop = nn.Dropout2d(dropout) if 0 < dropout < 1 else nn.Identity() | |
def forward(self, x: Tensor) -> Tensor: | |
z = self.conv(x) | |
z = self.norm(z) | |
z = self.actv(z) | |
z = self.drop(z) | |
return z # type: ignore | |
class Encoder(_Block): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: _size_2_t, | |
stride: _size_2_t, | |
padding: _size_2_t | str, | |
bias: bool = True, | |
norm: bool = True, | |
dropout: float = 0.0, | |
) -> None: | |
super().__init__(out_channels, norm, dropout) | |
self.conv = nn.Conv2d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
padding=padding, | |
bias=bias, | |
) | |
self.actv = nn.LeakyReLU(0.2, inplace=True) | |
class Decoder(_Block): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: _size_2_t, | |
stride: _size_2_t, | |
padding: _size_2_t, | |
bias: bool = True, | |
norm: bool = True, | |
dropout: float = 0.0, | |
) -> None: | |
super().__init__(out_channels, norm, dropout) | |
self.conv = nn.ConvTranspose2d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
padding=padding, | |
bias=bias, | |
) | |
self.actv = nn.ReLU(inplace=True) | |
def __init__( | |
self, | |
in_channels: int = 3, | |
out_channels: int | None = 3, | |
num_blocks: int | None = 8, | |
num_encoder_dropouts: int = 0, | |
num_decoder_dropouts: int = 3, | |
encoder_dropout: float = 0.0, | |
decoder_dropout: float = 0.5, | |
conv_kwargs: dict[str, Any] | None = None, | |
blocks_channels: Sequence[int] | None = None, | |
) -> None: | |
super().__init__() | |
self._set_all_channels(in_channels, out_channels, num_blocks, blocks_channels) | |
self._set_conv_kwargs(conv_kwargs) | |
self._build_encoders(num_encoder_dropouts, encoder_dropout) | |
self._build_decoders(num_decoder_dropouts, decoder_dropout) | |
def _set_all_channels( | |
self, | |
in_channels: int, | |
out_channels: int | None, | |
num_blocks: int | None, | |
blocks_channels: Sequence[int] | None, | |
) -> None: | |
if not isinstance(out_channels, int): | |
out_channels = in_channels | |
if not blocks_channels: | |
if not isinstance(num_blocks, int): | |
message = "one of `num_blocks` or `blocks_channels` must not be emtpy" | |
raise TypeError(message) | |
blocks_channels = make_exponential_channels(num_blocks) | |
self.encoders_channels = (in_channels, *blocks_channels) | |
self.decoders_channels = (out_channels, *blocks_channels) | |
def _set_conv_kwargs(self, conv_kwargs: dict[str, Any] | None) -> None: | |
if conv_kwargs is None: | |
conv_kwargs = {} | |
else: | |
raise NotImplementedError("user-defined conv_kwargs is not ready") | |
# allowed = {"kernel_size", "stride", "padding", "bias"} | |
# if notallowed := set(conv_kwargs.keys()) - allowed: | |
# raise KeyError(f"not allowed key(s): {notallowed} (allowed: {allowed})") # noqa: E501 | |
self.conv_kwargs = { | |
"kernel_size": 4, | |
"stride": 2, | |
"padding": 1, | |
**conv_kwargs, | |
} | |
def _build_encoders(self, num_dropouts: int, dropout: float) -> None: | |
in_channels, *block_channels = self.encoders_channels | |
if num_dropouts >= len(block_channels): # exclude first encoder | |
raise ValueError("`num_dropouts` cannot exceed `len(blocks_channels)` - 1") | |
encoders, kwargs = [], self.conv_kwargs | |
# this loop runs only when len(blocks_channels) >= 2 | |
for idx, channels in enumerate((pairwise(reversed(block_channels)))): | |
if idx >= num_dropouts: | |
dropout = 0.0 | |
encoders.append(self.Encoder(*channels[::-1], **kwargs, dropout=dropout)) # type: ignore # noqa: E501 | |
encoders.append( | |
self.Encoder( | |
in_channels=in_channels, | |
out_channels=block_channels[0], | |
**kwargs, # type: ignore | |
norm=False, | |
) | |
) | |
encoders.reverse() | |
self.encoders = nn.ModuleList(encoders) | |
def _decoder_channels_hook( | |
self, | |
channels: tuple[int, int], | |
idx: int | None = None, | |
) -> tuple[int, int]: | |
return channels | |
def _build_decoders(self, num_dropouts: int, dropout: float) -> None: | |
out_channels, *blocks_channels = self.decoders_channels | |
if num_dropouts >= len(blocks_channels): # exclude last decoder | |
raise ValueError("`num_dropouts` cannot exceed `len(blocks_channels)` - 1") | |
decoders, kwargs = [], self.conv_kwargs | |
# this loop runs only when len(blocks_channels) >= 2 | |
for idx, channels in enumerate(pairwise(reversed(blocks_channels))): | |
if idx >= num_dropouts: | |
dropout = 0.0 | |
channels = self._decoder_channels_hook(channels, idx) | |
decoders.append(self.Decoder(*channels, **kwargs, dropout=dropout)) # type: ignore # noqa: E501 | |
channels = self._decoder_channels_hook((blocks_channels[0], out_channels)) | |
decoders.append( | |
nn.Sequential( | |
nn.ConvTranspose2d(*channels, **kwargs), # type: ignore | |
nn.Tanh(), | |
) | |
) | |
self.decoders = nn.ModuleList(decoders) | |
def forward(self, x: Tensor) -> Tensor: | |
for encoder in self.encoders: | |
x = encoder(x) | |
for decoder in self.decoders: | |
x = decoder(x) | |
return x # type: ignore | |
class UNet(Autoencoder):
    """``Autoencoder`` variant with encoder-to-decoder skip connections.

    Every decoder except the deepest one receives its predecessor's output
    concatenated (on the channel axis) with the matching encoder output.
    """

    def _decoder_channels_hook(
        self,
        channels: tuple[int, int],
        idx: int | None = None,
    ) -> tuple[int, int]:
        """Double the decoder's input channels when it receives a skip tensor.

        The deepest decoder (``idx == 0``) gets no skip connection. With a
        single block the output layer is also the deepest decoder, so its
        channels must stay unchanged as well; otherwise ``forward`` would
        feed it half the channels it was built for.
        """
        # `decoders_channels` is (out_channels, *blocks_channels).
        single_block = len(self.decoders_channels) <= 2
        if idx == 0 or single_block:
            return channels
        # doubling `in_channels` for skip connection
        in_channels, out_channels = channels
        return (2 * in_channels, out_channels)

    def _forward_encoders(self, encoder_input: Tensor) -> Sequence[Tensor]:
        """Return the output of every encoder, shallowest first."""
        encoder_outputs: list[Tensor] = []
        for encoder in self.encoders:
            encoder_input = encoder(encoder_input)
            encoder_outputs.append(encoder_input)
        return encoder_outputs

    def _forward_decoders(
        self,
        decoder_input: Tensor,
        encoder_outputs: Sequence[Tensor],
    ) -> Tensor:
        """Decode ``decoder_input``, concatenating skip tensors along the way.

        ``encoder_outputs`` holds the skip tensors from shallowest to
        deepest; they are consumed in reverse order.
        """
        first_decoder, *remaining = self.decoders
        decoder_output = first_decoder(decoder_input)  # no skip connection
        for decoder, skip in zip(remaining, reversed(encoder_outputs)):
            decoder_input = torch.cat((decoder_output, skip), dim=1)
            decoder_output = decoder(decoder_input)
        return decoder_output  # type: ignore

    def forward(self, encoder_input: Tensor) -> Tensor:
        """Encode the input, then decode it with skip connections."""
        # The deepest encoder output feeds the decoder; the rest are skips.
        *encoder_outputs, decoder_input = self._forward_encoders(encoder_input)
        return self._forward_decoders(decoder_input, encoder_outputs)
class Discriminator(nn.Module): | |
class Block(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: _size_2_t = 4, | |
stride: _size_2_t = 2, | |
padding: _size_2_t | str = 1, | |
bias: bool = True, | |
norm: bool = True, | |
) -> None: | |
super().__init__() | |
self.conv = nn.Conv2d( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=stride, | |
padding=padding, | |
bias=bias, | |
) | |
self.norm = nn.BatchNorm2d(out_channels) if norm else nn.Identity() | |
self.actv = nn.LeakyReLU(0.2, inplace=True) | |
def forward(self, x: Tensor) -> Tensor: | |
h = self.conv(x) | |
h = self.norm(h) | |
h = self.actv(h) | |
return h # type: ignore | |
def __init__( | |
self, | |
in_channels: int = 6, | |
num_blocks: int | None = 3, | |
return_logits: bool = False, | |
conv_kwargs: dict[str, Any] | None = None, | |
blocks_channels: Sequence[int] | None = None, | |
) -> None: | |
super().__init__() | |
self._set_all_channels(in_channels, num_blocks, blocks_channels) | |
self._set_conv_kwargs(num_blocks == 0, conv_kwargs) | |
self._build_input() | |
self._build_hiddens() | |
self._build_output(return_logits) | |
def _set_all_channels( | |
self, | |
in_channels: int, | |
num_blocks: int | None, | |
blocks_channels: Sequence[int] | None, | |
) -> None: | |
if not blocks_channels: | |
if not isinstance(num_blocks, int): | |
message = "one of `num_blocks` or `blocks_channels` must not be emtpy" | |
raise TypeError(message) | |
if num_blocks == 0: # PixelGAN (now num_blocks >= 1) | |
num_blocks = 2 | |
blocks_channels = make_exponential_channels(num_blocks) | |
self.all_channels = (in_channels, *blocks_channels, 1) | |
def _set_conv_kwargs( | |
self, | |
pixelgan: bool, | |
conv_kwargs: dict[str, Any] | None, | |
) -> None: | |
if conv_kwargs is None: | |
conv_kwargs = {} | |
else: | |
allowed = {"kernel_size", "stride", "padding", "bias"} | |
if notallowed := set(conv_kwargs.keys()) - allowed: | |
raise KeyError(f"not allowed key(s): {notallowed} (allowed: {allowed})") | |
if pixelgan: | |
self.conv_kwargs = { | |
**conv_kwargs, | |
"kernel_size": 1, | |
"stride": 1, | |
"padding": 0, | |
} | |
else: | |
self.conv_kwargs = { | |
"kernel_size": 4, | |
"stride": 2, | |
"padding": 1, | |
**conv_kwargs, | |
} | |
def _build_input(self) -> None: | |
channels, kwargs = self.all_channels[:2], self.conv_kwargs | |
self.input = self.Block(*channels, **kwargs, norm=False) # type: ignore | |
def _build_hiddens(self) -> None: | |
if len(blocks_channels := self.all_channels[1:-1]) < 2: | |
self.hidden = nn.Identity() | |
return | |
blocks, kwargs = [], self.conv_kwargs | |
# this loop only runs when len(blocks_channels) > 2 | |
for channels in pairwise(blocks_channels[:-1]): | |
blocks.append(self.Block(*channels, **kwargs)) # type: ignore | |
channels, kwargs = blocks_channels[-2:], {**kwargs, "stride": 1} # type: ignore | |
blocks.append(self.Block(*channels, **kwargs)) # type: ignore | |
self.hidden = nn.Sequential(*blocks) # type: ignore | |
def _build_output(self, return_logits: bool) -> None: | |
channels = self.all_channels[-2:] | |
kwargs = {**self.conv_kwargs, "stride": 1} | |
output: list[nn.Module] = [nn.Conv2d(*channels, **kwargs)] # type: ignore | |
# return probabilities for log-likelihood losses (e.g., vanilla gan loss) | |
if not return_logits: | |
output.append(nn.Sigmoid()) | |
self.output = nn.Sequential(*output) | |
def forward(self, x1: Tensor, x2: Tensor | None = None) -> Tensor: | |
# x.shape[1] == self.all_channels[0] (i.e., in_channels) | |
x = torch.cat((x1, x2), dim=1) if isinstance(x2, Tensor) else x1 | |
z = self.input(x) | |
z = self.hidden(z) | |
z = self.output(z) | |
return z # type: ignore |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment