# -*- coding: utf-8 -*-
"""A demonstration script for Autoencoder on MNIST dataset.

Copyright (C) 2025 by Akira TAMAMORI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import argparse
import os
from typing import NamedTuple, final, override

import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image


class Arguments(NamedTuple):
    """Defines a class for miscellaneous configurations."""

    batch_size_val: int  # Batch size for validation
    out_dir: str  # Directory to store output sample images
    device: str  # Device to use for training


class TrainingConfig(NamedTuple):
    """Defines a class for training configuration."""

    batch_size: int  # Batch size for training
    learning_rate: float  # Learning rate for optimizer
    num_epochs: int  # Number of epochs for training
    use_bce: bool  # Use BCE loss instead of MSE loss


class ModelConfig(NamedTuple):
    """Defines a class for model configuration."""

    hidden_dim: int  # Number of features in intermediate layers
    latent_dim: int  # Number of features in bottleneck layer
    use_affine: bool  # Use affine transform in training


def parse_args() -> tuple[Arguments, TrainingConfig, ModelConfig]:
    """Parse command line arguments.

    Returns:
        arguments (Arguments): miscellaneous configurations
        train_config (TrainingConfig): configurations for model training
        model_config (ModelConfig): configurations for model definition
    """
    parser = argparse.ArgumentParser(description="AE training script")
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--batch_size_val", type=int, default=64, help="Batch size for validation"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
    )
    parser.add_argument(
        "--num_epochs", type=int, default=30, help="Number of epochs for training"
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=512,
        help="Number of features in intermediate layers",
    )
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=26,
        help="Number of features in bottleneck layer",
    )
    parser.add_argument(  # True if --use_bce is specified on the command line.
        "--use_bce", action="store_true", help="Use BCE loss instead of MSE loss"
    )
    parser.add_argument(  # True if --use_affine is specified on the command line.
        "--use_affine", action="store_true", help="Use affine transform in training"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./images",
        help="Directory to store output sample images",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for training",
    )
    args = parser.parse_args()
    arguments = Arguments(
        batch_size_val=args.batch_size_val, out_dir=args.out_dir, device=args.device
    )
    train_config = TrainingConfig(
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_bce=args.use_bce,
    )
    model_config = ModelConfig(
        hidden_dim=args.hidden_dim,
        latent_dim=args.latent_dim,
        use_affine=args.use_affine,
    )
    return arguments, train_config, model_config


@final
class AE(nn.Module):
    """Autoencoder."""

    def __init__(self, config: ModelConfig, use_bce: bool):
        """Initialize module."""
        super().__init__()
        hidden_dim = config.hidden_dim
        latent_dim = config.latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),  # [128, 28 x 28 x 1] = [128, 784]
            nn.Linear(28 * 28, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, latent_dim),
        )
        # -6: affine parameters (translation and rotation in 2-d Euclidean space)
        # +2: number of coordinates
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim - 6 + 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid() if use_bce else nn.Tanh(),
        )
        coord = torch.cartesian_prod(
            torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28)
        )
        coord = torch.reshape(coord, (28, 28, 2)).unsqueeze(0)  # [1, 28, 28, 2]
        self.register_buffer("coord", coord)
        self.use_affine = config.use_affine
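
    # Design note (inferred from the layer shapes above): the decoder is
    # coordinate-based. Rather than emitting all 784 pixels at once, it takes
    # one (x, y) grid coordinate concatenated with the (latent_dim - 6)-dim
    # code and predicts a single pixel intensity; the full image is recovered
    # by evaluating it over the whole 28 x 28 grid (see augment_latent and
    # decode below). The last 6 latent dimensions are reserved for a 2x3
    # affine matrix that can warp the coordinate grid before decoding.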

    def encode(self, inputs: Tensor, scale: float = 3.5) -> Tensor:
        """Encode input image.

        Args:
            inputs (torch.Tensor): input image
            scale (float): scaling factor

        Returns:
            latent (torch.Tensor): bottleneck features in latent space
        """
        hidden = self.encoder(inputs)
        latent: Tensor = scale * torch.tanh(hidden)
        return latent
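
    # Note: scale * tanh(.) bounds every latent coordinate (including the six
    # affine parameters) to (-scale, scale); presumably this keeps the affine
    # warp applied in augment_latent within a controlled range.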

    def augment_latent(
        self, latent: Tensor, use_affine: bool, scale: float = 0.1
    ) -> Tensor:
        """Augment bottleneck features.

        Args:
            latent (torch.Tensor): bottleneck features in latent space
            use_affine (bool): flag to apply affine transform
            scale (float): scaling factor for affine transform

        Returns:
            outputs (torch.Tensor): augmented bottleneck features
        """
        batch_size = latent.shape[0]  # 128
        h_size = self.coord.shape[1]  # 28
        w_size = self.coord.shape[2]  # 28
        coord = self.coord.repeat(batch_size, 1, 1, 1)  # [128, 28, 28, 2]
        if use_affine:
            affine = torch.reshape(latent[:, -6:], (-1, 2, 3))  # [128, 2, 3]
            zeros = torch.zeros_like(affine[:, 0:1, :])  # [128, 1, 3]
            affine = torch.concat([affine, zeros], dim=-2)  # [128, 3, 3]
            affine = scale * affine + torch.eye(3).to(latent.device)  # [128, 3, 3]
            ones = torch.ones_like(coord[:, :, :, 0:1])  # [128, 28, 28, 1]
            coord = torch.concat([coord, ones], dim=-1)  # [128, 28, 28, 3]
            coord = torch.einsum("bhwj, bji -> bhwi", coord, affine)
            coord = coord[:, :, :, 0:2]  # [128, 28, 28, 2]
        latent_ = latent[:, :-6]  # [128, 20]
        latent_ = latent_[:, :, None, None]  # [128, 20, 1, 1]
        latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [128, 1, 1, 20]
        latent_ = latent_.repeat(1, h_size, w_size, 1)  # [128, 28, 28, 20]
        outputs = torch.concat([coord, latent_], dim=-1)  # [128, 28, 28, 22]
        outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))
        return outputs  # [128 * 28 * 28, 22] = [100352, 22]
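
    # Note on the affine construction: "scale * affine + I" (scale defaults
    # to 0.1) keeps the learned transform close to the identity, so the
    # decoder initially sees a nearly unwarped coordinate grid.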

    def decode(self, latent: Tensor, use_affine: bool) -> Tensor:
        """Decode bottleneck features.

        Args:
            latent (torch.Tensor): bottleneck features in latent space
            use_affine (bool): flag to apply affine transform

        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        batch_size = latent.shape[0]
        latent = self.augment_latent(latent, use_affine)
        hidden = self.decoder(latent)
        hidden = torch.reshape(
            hidden, (batch_size, self.coord.shape[1], self.coord.shape[2], 1)
        )
        reconst: Tensor = torch.permute(hidden, (0, 3, 1, 2))
        return reconst

    @override
    def forward(self, inputs: Tensor) -> Tensor:
        """Forward propagation.

        Args:
            inputs (torch.Tensor): input image

        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        latent = self.encode(inputs)
        reconst = self.decode(latent, self.use_affine)
        return reconst


def get_dataloader(
    is_train: bool, transform: transforms.Compose, batch_size: int
) -> DataLoader[tuple[Tensor, Tensor]]:
    """Get a dataloader for training or validation.

    Args:
        is_train (bool): a flag to determine training mode
        transform (transforms.Compose): a chain of transforms to be applied
        batch_size (int): batch size of data loader

    Returns:
        dataloader (DataLoader): a dataloader for training or validation
    """
    if is_train:
        dataset = datasets.MNIST(
            root="./data", train=True, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    else:
        dataset = datasets.MNIST(
            root="./data", train=False, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    return dataloader
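
# Note: drop_last=True on the training loader avoids a final short batch; in
# particular, a trailing batch of size 1 would break the BatchNorm1d layers
# in training mode.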


def loss_function(model: AE, inputs: Tensor, use_bce: bool) -> Tensor:
    """Compute loss function.

    Args:
        model (AE): AE module
        inputs (torch.Tensor): input image
        use_bce (bool): flag to apply BCE loss or MSE loss

    Returns:
        loss (torch.Tensor): reconstruction loss
    """
    reconst = model(inputs)
    if use_bce:
        reconst_loss = nn.BCELoss(reduction="sum")(reconst, inputs)
    else:
        reconst_loss = nn.MSELoss(reduction="sum")(reconst, inputs)
    loss: Tensor = reconst_loss
    return loss
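
# Note: BCELoss expects both inputs and reconstructions in [0, 1], which is
# why --use_bce pairs a Sigmoid output with plain ToTensor() inputs, while
# the MSE path pairs a Tanh output with inputs normalized to [-1, 1]
# (see the transform selection in main()).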


def generate_sample(model: AE, val_data: Tensor, epoch: int) -> None:
    """Generate samples from trained model.

    Args:
        model (AE): AE module
        val_data (torch.Tensor): validation data
        epoch (int): current epoch
    """
    args, train_config, _ = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    with torch.no_grad():
        # save reconstructed images of validation data for comparison
        latent = model.encode(val_data)
        val_reconstructed = model.decode(latent, False)  # without affine transform
        val_reconstructed = val_reconstructed.view(val_data.size())
        comparison = torch.cat([val_data.cpu(), val_reconstructed.cpu()], dim=3)
        if not train_config.use_bce:
            # undo Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1]
            comparison = 0.5 + 0.5 * comparison
        save_image(comparison, f"{args.out_dir}/reconstructed_image_{epoch+1}.png")


def main() -> None:
    """Perform demonstration."""
    args, train_config, model_config = parse_args()
    if train_config.use_bce:
        transform = transforms.Compose([transforms.ToTensor()])
    else:
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
    train_loader = get_dataloader(True, transform, train_config.batch_size)
    test_loader = get_dataloader(False, transform, args.batch_size_val)
    model = AE(model_config, train_config.use_bce).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=train_config.learning_rate)

    # prepare validation data
    val_data, _ = next(iter(test_loader))
    val_data = val_data.to(args.device)

    for epoch in range(train_config.num_epochs):
        model.train()
        epoch_loss = []
        for data, _ in train_loader:
            data = data.to(args.device)
            optimizer.zero_grad()
            loss = loss_function(model, data, train_config.use_bce)
            epoch_loss.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch+1}, Average Loss: {np.average(epoch_loss):.12f}")

        # visualise training progress by generating samples from current model
        model.eval()
        generate_sample(model, val_data, epoch)

    print("Training finished.")


if __name__ == "__main__":
    main()