Skip to content

Instantly share code, notes, and snippets.

@kaparoo
Last active April 6, 2023 12:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kaparoo/5c3cac442253d6e38864510ddaf50cba to your computer and use it in GitHub Desktop.
Save kaparoo/5c3cac442253d6e38864510ddaf50cba to your computer and use it in GitHub Desktop.
Standalone pix2pix framework using PyTorch Lightning
# -*- coding: utf-8 -*-
__all__ = ("Pix2Pix",)
from collections.abc import Sequence
from itertools import pairwise
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from torch import Tensor
from torch.nn.common_types import _size_2_t
from torch.optim import Adam, Optimizer
def make_exponential_channels(
    num_channels: int = 3,
    unit_channels: int = 64,
    max_scale: int = 8,
) -> Sequence[int]:
    """Return channel widths that double per entry until a cap is reached.

    Entry ``i`` equals ``unit_channels * min(2**i, max_scale)``.

    Args:
        num_channels: how many widths to generate (the length of the result).
        unit_channels: width of the first entry.
        max_scale: largest multiplier applied to ``unit_channels``.

    Returns:
        A tuple with ``num_channels`` widths (empty when ``num_channels <= 0``).
    """
    widths: list[int] = []
    for level in range(num_channels):
        scale = 2**level
        # Clamp the multiplier once the doubling passes the cap.
        if scale > max_scale:
            scale = max_scale
        widths.append(unit_channels * scale)
    return tuple(widths)
class GAN(LightningModule):
    """Base Lightning module for GANs that are optimized manually.

    Args:
        source_channels: stored as ``self.source_channels``.
        target_channels: stored as ``self.target_channels``; any non-``int``
            value (e.g. ``None``) falls back to ``source_channels``.
    """

    def __init__(
        self,
        source_channels: int = 3,
        target_channels: int | None = None,
    ) -> None:
        super().__init__()
        # Turn off Lightning's automatic optimization for this module.
        self.automatic_optimization = False
        self.source_channels = source_channels
        self.target_channels = (
            target_channels if isinstance(target_channels, int) else source_channels
        )

    def get_labels(self, preds: Tensor, as_real: bool = True) -> Tensor:
        """Return an all-ones (real) or all-zeros (fake) tensor shaped like ``preds``."""
        make_labels = torch.ones_like if as_real else torch.zeros_like
        return make_labels(preds)

    def adversarial_loss(self, preds: Tensor, as_real: bool = True) -> Tensor:
        """Return the binary cross-entropy between ``preds`` and constant labels.

        ``F.binary_cross_entropy`` expects probabilities, not logits.
        """
        return F.binary_cross_entropy(preds, self.get_labels(preds, as_real))
class Pix2Pix(GAN):
    """Pix2Pix conditional GAN trained with manual optimization.

    The discriminator receives the conditioning image together with either
    the real target or the generator output. The generator is trained with
    the adversarial loss plus ``lambda_rcn`` times an L1 reconstruction loss.

    Args:
        source_channels: channels of the conditioning (input) images.
        target_channels: channels of the target images; non-``int`` values
            fall back to ``source_channels``.
        input_first: if true a batch is ``(cond_images, real_images)``,
            otherwise ``(real_images, cond_images)``.
        lambda_rcn: weight of the L1 reconstruction term in the generator loss.
        num_dis_blocks: forwarded to ``Discriminator`` as ``num_blocks``.
        num_gen_blocks: forwarded to the generator as ``num_blocks``.
        use_autoencoder: use ``Autoencoder`` instead of ``UNet`` as generator.
        num_dropouts: forwarded to the generator as ``num_decoder_dropouts``.
        dropout: forwarded to the generator as ``decoder_dropout``.
        dis_channels: forwarded to ``Discriminator`` as ``blocks_channels``.
        gen_channels: forwarded to the generator as ``blocks_channels``.
        dis_conv_kwargs: forwarded to ``Discriminator`` as ``conv_kwargs``.
        gen_conv_kwargs: forwarded to the generator as ``conv_kwargs``.
        num_encoder_dropouts: forwarded to the generator unchanged.
        encoder_dropout: forwarded to the generator unchanged.
        learning_rate: learning rate of both Adam optimizers.
        beta1: first Adam beta of both optimizers.
        beta2: second Adam beta of both optimizers.
    """

    def __init__(
        self,
        source_channels: int = 3,
        target_channels: int | None = None,
        input_first: bool = True,
        lambda_rcn: float = 100.0,
        num_dis_blocks: int | None = 3,
        num_gen_blocks: int | None = 8,
        use_autoencoder: bool = False,
        num_dropouts: int = 3,
        dropout: float = 0.5,
        dis_channels: Sequence[int] | None = None,
        gen_channels: Sequence[int] | None = None,
        dis_conv_kwargs: dict[str, Any] | None = None,
        gen_conv_kwargs: dict[str, Any] | None = None,
        num_encoder_dropouts: int = 0,
        encoder_dropout: float = 0.0,
        learning_rate: float = 0.0002,
        beta1: float = 0.5,
        beta2: float = 0.999,
    ) -> None:
        super().__init__(source_channels, target_channels)
        # Records the constructor arguments in ``self.hparams``;
        # ``configure_optimizers`` reads the optimizer settings from there.
        self.save_hyperparameters()

        self.lambda_rcn = lambda_rcn
        self.input_first = input_first

        # The discriminator sees condition and candidate stacked on channels.
        self.discriminator = Discriminator(
            in_channels=self.source_channels + self.target_channels,
            num_blocks=num_dis_blocks,
            conv_kwargs=dis_conv_kwargs,
            blocks_channels=dis_channels,
        )

        Generator = Autoencoder if use_autoencoder else UNet
        self.generator = Generator(
            in_channels=self.source_channels,
            out_channels=self.target_channels,
            num_blocks=num_gen_blocks,
            num_encoder_dropouts=num_encoder_dropouts,
            num_decoder_dropouts=num_dropouts,
            encoder_dropout=encoder_dropout,
            decoder_dropout=dropout,
            conv_kwargs=gen_conv_kwargs,
            blocks_channels=gen_channels,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the generator on conditioning images ``x``."""
        return self.generator(x)  # type: ignore

    def reconstruction_loss(self, fake_images: Tensor, real_images: Tensor) -> Tensor:
        """Return the L1 loss between generated and real images."""
        return F.l1_loss(fake_images, real_images)

    def train_discriminator(
        self,
        cond_images: Tensor,
        real_images: Tensor,
        fake_images: Tensor,
    ) -> None:
        """Perform one discriminator update and log its losses.

        Note the parameter order: real images come BEFORE fake images here.
        """
        optimizer = self.optimizers()[0]
        self.toggle_optimizer(optimizer)
        optimizer.zero_grad()

        preds_real = self.discriminator(cond_images, real_images)
        # Detach so this backward pass does not reach (or free) the
        # generator graph, which the generator step still needs.
        preds_fake = self.discriminator(cond_images, fake_images.detach())

        loss_real = self.adversarial_loss(preds_real, as_real=True)
        loss_fake = self.adversarial_loss(preds_fake, as_real=False)
        loss = (loss_real + loss_fake) / 2

        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("d_loss", loss, prog_bar=True)
        self.log("d_loss_real", loss_real)
        self.log("d_loss_fake", loss_fake)

    def train_generator(
        self,
        cond_images: Tensor,
        fake_images: Tensor,
        real_images: Tensor,
    ) -> None:
        """Perform one generator update and log its losses.

        Note the parameter order: fake images come BEFORE real images here.
        """
        optimizer = self.optimizers()[1]
        self.toggle_optimizer(optimizer)
        optimizer.zero_grad()

        preds = self.discriminator(cond_images, fake_images)
        # The generator is rewarded when its output is classified as real.
        loss_adv = self.adversarial_loss(preds, as_real=True)
        loss_rcn = self.reconstruction_loss(fake_images, real_images)
        loss = loss_adv + self.lambda_rcn * loss_rcn

        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("g_loss", loss, prog_bar=True)
        self.log("g_loss_adv", loss_adv)
        self.log("g_loss_rcn", loss_rcn)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> None:
        """Update the discriminator, then the generator, on one batch."""
        if self.input_first:
            cond_images, real_images = batch
        else:
            real_images, cond_images = batch

        fake_images = self.forward(cond_images)

        # Pass the images by keyword: the two helpers declare real/fake in
        # different positional orders, and a positional call previously fed
        # the generator output to the discriminator as the "real" sample.
        self.train_discriminator(
            cond_images=cond_images,
            real_images=real_images,
            fake_images=fake_images,
        )
        self.train_generator(
            cond_images=cond_images,
            fake_images=fake_images,
            real_images=real_images,
        )

    def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
        """Return ``(discriminator_optimizer, generator_optimizer)``.

        The order matters: the training helpers index ``self.optimizers()``
        with 0 for the discriminator and 1 for the generator.
        """
        learning_rate = self.hparams.learning_rate
        betas = (self.hparams.beta1, self.hparams.beta2)
        d_opt = Adam(self.discriminator.parameters(), lr=learning_rate, betas=betas)
        g_opt = Adam(self.generator.parameters(), lr=learning_rate, betas=betas)
        return (d_opt, g_opt)
class Autoencoder(nn.Module):
class _Block(nn.Module):
def __init__(
self,
out_channels: int,
norm: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__()
self.conv: nn.Module
self.actv: nn.Module
self.norm = nn.BatchNorm2d(out_channels) if norm else nn.Identity()
self.drop = nn.Dropout2d(dropout) if 0 < dropout < 1 else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
z = self.conv(x)
z = self.norm(z)
z = self.actv(z)
z = self.drop(z)
return z # type: ignore
class Encoder(_Block):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t,
padding: _size_2_t | str,
bias: bool = True,
norm: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__(out_channels, norm, dropout)
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self.actv = nn.LeakyReLU(0.2, inplace=True)
class Decoder(_Block):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t,
padding: _size_2_t,
bias: bool = True,
norm: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__(out_channels, norm, dropout)
self.conv = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self.actv = nn.ReLU(inplace=True)
def __init__(
self,
in_channels: int = 3,
out_channels: int | None = 3,
num_blocks: int | None = 8,
num_encoder_dropouts: int = 0,
num_decoder_dropouts: int = 3,
encoder_dropout: float = 0.0,
decoder_dropout: float = 0.5,
conv_kwargs: dict[str, Any] | None = None,
blocks_channels: Sequence[int] | None = None,
) -> None:
super().__init__()
self._set_all_channels(in_channels, out_channels, num_blocks, blocks_channels)
self._set_conv_kwargs(conv_kwargs)
self._build_encoders(num_encoder_dropouts, encoder_dropout)
self._build_decoders(num_decoder_dropouts, decoder_dropout)
def _set_all_channels(
self,
in_channels: int,
out_channels: int | None,
num_blocks: int | None,
blocks_channels: Sequence[int] | None,
) -> None:
if not isinstance(out_channels, int):
out_channels = in_channels
if not blocks_channels:
if not isinstance(num_blocks, int):
message = "one of `num_blocks` or `blocks_channels` must not be emtpy"
raise TypeError(message)
blocks_channels = make_exponential_channels(num_blocks)
self.encoders_channels = (in_channels, *blocks_channels)
self.decoders_channels = (out_channels, *blocks_channels)
def _set_conv_kwargs(self, conv_kwargs: dict[str, Any] | None) -> None:
if conv_kwargs is None:
conv_kwargs = {}
else:
raise NotImplementedError("user-defined conv_kwargs is not ready")
# allowed = {"kernel_size", "stride", "padding", "bias"}
# if notallowed := set(conv_kwargs.keys()) - allowed:
# raise KeyError(f"not allowed key(s): {notallowed} (allowed: {allowed})") # noqa: E501
self.conv_kwargs = {
"kernel_size": 4,
"stride": 2,
"padding": 1,
**conv_kwargs,
}
def _build_encoders(self, num_dropouts: int, dropout: float) -> None:
in_channels, *block_channels = self.encoders_channels
if num_dropouts >= len(block_channels): # exclude first encoder
raise ValueError("`num_dropouts` cannot exceed `len(blocks_channels)` - 1")
encoders, kwargs = [], self.conv_kwargs
# this loop runs only when len(blocks_channels) >= 2
for idx, channels in enumerate((pairwise(reversed(block_channels)))):
if idx >= num_dropouts:
dropout = 0.0
encoders.append(self.Encoder(*channels[::-1], **kwargs, dropout=dropout)) # type: ignore # noqa: E501
encoders.append(
self.Encoder(
in_channels=in_channels,
out_channels=block_channels[0],
**kwargs, # type: ignore
norm=False,
)
)
encoders.reverse()
self.encoders = nn.ModuleList(encoders)
def _decoder_channels_hook(
self,
channels: tuple[int, int],
idx: int | None = None,
) -> tuple[int, int]:
return channels
def _build_decoders(self, num_dropouts: int, dropout: float) -> None:
out_channels, *blocks_channels = self.decoders_channels
if num_dropouts >= len(blocks_channels): # exclude last decoder
raise ValueError("`num_dropouts` cannot exceed `len(blocks_channels)` - 1")
decoders, kwargs = [], self.conv_kwargs
# this loop runs only when len(blocks_channels) >= 2
for idx, channels in enumerate(pairwise(reversed(blocks_channels))):
if idx >= num_dropouts:
dropout = 0.0
channels = self._decoder_channels_hook(channels, idx)
decoders.append(self.Decoder(*channels, **kwargs, dropout=dropout)) # type: ignore # noqa: E501
channels = self._decoder_channels_hook((blocks_channels[0], out_channels))
decoders.append(
nn.Sequential(
nn.ConvTranspose2d(*channels, **kwargs), # type: ignore
nn.Tanh(),
)
)
self.decoders = nn.ModuleList(decoders)
def forward(self, x: Tensor) -> Tensor:
for encoder in self.encoders:
x = encoder(x)
for decoder in self.decoders:
x = decoder(x)
return x # type: ignore
class UNet(Autoencoder):
    """``Autoencoder`` whose decoders also receive encoder outputs.

    Every decoder except the first one takes the previous decoder output
    concatenated (on the channel axis) with the matching encoder output.
    """

    def _decoder_channels_hook(
        self,
        channels: tuple[int, int],
        idx: int | None = None,
    ) -> tuple[int, int]:
        """Double the input channels of every decoder except the first one."""
        if idx == 0:
            # The first decoder only sees the bottleneck: no skip connection.
            return channels
        skip_in, plain_out = channels
        return (2 * skip_in, plain_out)

    def _forward_encoders(self, encoder_input: Tensor) -> Sequence[Tensor]:
        """Run all encoders and return every intermediate output in order."""
        features: list[Tensor] = []
        current = encoder_input
        for encoder in self.encoders:
            current = encoder(current)
            features.append(current)
        return features

    def _forward_decoders(
        self,
        decoder_input: Tensor,
        encoder_outputs: Sequence[Tensor],
    ) -> Tensor:
        """Run all decoders, feeding skip connections from deepest to shallowest."""
        first_decoder, *skip_decoders = self.decoders
        hidden = first_decoder(decoder_input)  # no skip connection
        for decoder, skip in zip(skip_decoders, reversed(encoder_outputs)):
            hidden = decoder(torch.cat((hidden, skip), dim=1))
        return hidden  # type: ignore

    def forward(self, encoder_input: Tensor) -> Tensor:
        """Encode the input, then decode from the last encoder output with skips."""
        *skips, bottleneck = self._forward_encoders(encoder_input)
        return self._forward_decoders(bottleneck, skips)
class Discriminator(nn.Module):
class Block(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t = 4,
stride: _size_2_t = 2,
padding: _size_2_t | str = 1,
bias: bool = True,
norm: bool = True,
) -> None:
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self.norm = nn.BatchNorm2d(out_channels) if norm else nn.Identity()
self.actv = nn.LeakyReLU(0.2, inplace=True)
def forward(self, x: Tensor) -> Tensor:
h = self.conv(x)
h = self.norm(h)
h = self.actv(h)
return h # type: ignore
def __init__(
self,
in_channels: int = 6,
num_blocks: int | None = 3,
return_logits: bool = False,
conv_kwargs: dict[str, Any] | None = None,
blocks_channels: Sequence[int] | None = None,
) -> None:
super().__init__()
self._set_all_channels(in_channels, num_blocks, blocks_channels)
self._set_conv_kwargs(num_blocks == 0, conv_kwargs)
self._build_input()
self._build_hiddens()
self._build_output(return_logits)
def _set_all_channels(
self,
in_channels: int,
num_blocks: int | None,
blocks_channels: Sequence[int] | None,
) -> None:
if not blocks_channels:
if not isinstance(num_blocks, int):
message = "one of `num_blocks` or `blocks_channels` must not be emtpy"
raise TypeError(message)
if num_blocks == 0: # PixelGAN (now num_blocks >= 1)
num_blocks = 2
blocks_channels = make_exponential_channels(num_blocks)
self.all_channels = (in_channels, *blocks_channels, 1)
def _set_conv_kwargs(
self,
pixelgan: bool,
conv_kwargs: dict[str, Any] | None,
) -> None:
if conv_kwargs is None:
conv_kwargs = {}
else:
allowed = {"kernel_size", "stride", "padding", "bias"}
if notallowed := set(conv_kwargs.keys()) - allowed:
raise KeyError(f"not allowed key(s): {notallowed} (allowed: {allowed})")
if pixelgan:
self.conv_kwargs = {
**conv_kwargs,
"kernel_size": 1,
"stride": 1,
"padding": 0,
}
else:
self.conv_kwargs = {
"kernel_size": 4,
"stride": 2,
"padding": 1,
**conv_kwargs,
}
def _build_input(self) -> None:
channels, kwargs = self.all_channels[:2], self.conv_kwargs
self.input = self.Block(*channels, **kwargs, norm=False) # type: ignore
def _build_hiddens(self) -> None:
if len(blocks_channels := self.all_channels[1:-1]) < 2:
self.hidden = nn.Identity()
return
blocks, kwargs = [], self.conv_kwargs
# this loop only runs when len(blocks_channels) > 2
for channels in pairwise(blocks_channels[:-1]):
blocks.append(self.Block(*channels, **kwargs)) # type: ignore
channels, kwargs = blocks_channels[-2:], {**kwargs, "stride": 1} # type: ignore
blocks.append(self.Block(*channels, **kwargs)) # type: ignore
self.hidden = nn.Sequential(*blocks) # type: ignore
def _build_output(self, return_logits: bool) -> None:
channels = self.all_channels[-2:]
kwargs = {**self.conv_kwargs, "stride": 1}
output: list[nn.Module] = [nn.Conv2d(*channels, **kwargs)] # type: ignore
# return probabilities for log-likelihood losses (e.g., vanilla gan loss)
if not return_logits:
output.append(nn.Sigmoid())
self.output = nn.Sequential(*output)
def forward(self, x1: Tensor, x2: Tensor | None = None) -> Tensor:
# x.shape[1] == self.all_channels[0] (i.e., in_channels)
x = torch.cat((x1, x2), dim=1) if isinstance(x2, Tensor) else x1
z = self.input(x)
z = self.hidden(z)
z = self.output(z)
return z # type: ignore
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment