# -*- coding: utf-8 -*-
"""A demonstration script for Autoencoder on MNIST dataset.
Copyright (C) 2025 by Akira TAMAMORI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import argparse
import os
from typing import NamedTuple, final, override

import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image


class Arguments(NamedTuple):
    """Defines a class for miscellaneous configurations."""

    batch_size_val: int  # Batch size for validation
    out_dir: str  # Directory to store output sample images
    device: str  # Device to use for training


class TrainingConfig(NamedTuple):
    """Defines a class for training configuration."""

    batch_size: int  # Batch size for training
    learning_rate: float  # Learning rate for optimizer
    num_epochs: int  # Number of epochs for training
    use_bce: bool  # Use BCE loss instead of MSE loss


class ModelConfig(NamedTuple):
    """Defines a class for model configuration."""

    hidden_dim: int  # Number of features in intermediate layers
    latent_dim: int  # Number of features in bottleneck layer
    use_affine: bool  # Use affine transform in training


def parse_args() -> tuple[Arguments, TrainingConfig, ModelConfig]:
    """Parse command line arguments.

    Returns:
        arguments (Arguments): miscellaneous configurations
        train_config (TrainingConfig): configurations for model training
        model_config (ModelConfig): configurations for model definition
    """
    parser = argparse.ArgumentParser(description="AE training script")
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--batch_size_val", type=int, default=64, help="Batch size for validation"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
    )
    parser.add_argument(
        "--num_epochs", type=int, default=30, help="Number of epochs for training"
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=512,
        help="Number of features in intermediate layers",
    )
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=26,
        help="Number of features in bottleneck layer",
    )
    parser.add_argument(  # True if --use_bce is specified on the command line.
        "--use_bce", action="store_true", help="Use BCE loss instead of MSE loss"
    )
    parser.add_argument(  # True if --use_affine is specified on the command line.
        "--use_affine", action="store_true", help="Use affine transform in training"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./images",
        help="Directory to store output sample images",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for training",
    )
    args = parser.parse_args()
    arguments = Arguments(
        batch_size_val=args.batch_size_val, out_dir=args.out_dir, device=args.device
    )
    train_config = TrainingConfig(
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_bce=args.use_bce,
    )
    model_config = ModelConfig(
        hidden_dim=args.hidden_dim,
        latent_dim=args.latent_dim,
        use_affine=args.use_affine,
    )
    return arguments, train_config, model_config


@final
class AE(nn.Module):
    """Autoencoder."""

    def __init__(self, config: ModelConfig, use_bce: bool):
        """Initialize module."""
        super().__init__()
        hidden_dim = config.hidden_dim
        latent_dim = config.latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),  # [128, 28 x 28 x 1] = [128, 784]
            nn.Linear(28 * 28, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, latent_dim),
        )
        # -6: affine parameters (translation and rotation in 2-d Euclidean space)
        # +2: number of coordinates
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim - 6 + 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid() if use_bce else nn.Tanh(),
        )
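        # The decoder is coordinate-based: instead of mapping the latent code
        # to all 784 pixels at once, it maps one (x, y) coordinate pair,
        # concatenated with the (latent_dim - 6)-dimensional content code, to a
        # single pixel intensity. The full image is recovered by querying the
        # decoder at every point of the 28 x 28 grid built below; the last 6
        # latent dimensions are reserved for the affine warp of that grid.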
        coord = torch.cartesian_prod(
            torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28)
        )
        coord = torch.reshape(coord, (28, 28, 2)).unsqueeze(0)  # [1, 28, 28, 2]
        self.register_buffer("coord", coord)
        self.use_affine = config.use_affine

    def encode(self, inputs: Tensor, scale: float = 3.5) -> Tensor:
        """Encode input image.

        Args:
            inputs (torch.Tensor): input image
            scale (float): scaling factor

        Returns:
            latent (torch.Tensor): bottleneck features in latent space
        """
        hidden = self.encoder(inputs)
        latent: Tensor = scale * torch.tanh(hidden)
        return latent
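
    # A note on scales: encode() bounds each latent coordinate to (-3.5, 3.5)
    # via tanh, and augment_latent() multiplies the 6 affine entries by 0.1
    # before adding the identity, so the per-sample warp below stays a small
    # perturbation of the identity map.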
    def augment_latent(
        self, latent: Tensor, use_affine: bool, scale: float = 0.1
    ) -> Tensor:
        """Augment bottleneck features.

        Args:
            latent (torch.Tensor): bottleneck features in latent space
            use_affine (bool): flag to apply affine transform
            scale (float): scaling factor for affine transform

        Returns:
            outputs (torch.Tensor): augmented bottleneck features
        """
        batch_size = latent.shape[0]  # 128
        h_size = self.coord.shape[1]  # 28
        w_size = self.coord.shape[2]  # 28
        coord = self.coord.repeat(batch_size, 1, 1, 1)  # [128, 28, 28, 2]
        if use_affine:
            affine = torch.reshape(latent[:, -6:], (-1, 2, 3))  # [128, 2, 3]
            zeros = torch.zeros_like(affine[:, 0:1, :])  # [128, 1, 3]
            affine = torch.concat([affine, zeros], dim=-2)  # [128, 3, 3]
            affine = scale * affine + torch.eye(3).to(latent.device)  # [128, 3, 3]
            ones = torch.ones_like(coord[:, :, :, 0:1])  # [128, 28, 28, 1]
            coord = torch.concat([coord, ones], dim=-1)  # [128, 28, 28, 3]
            coord = torch.einsum("bhwj, bji -> bhwi", coord, affine)
            coord = coord[:, :, :, 0:2]  # [128, 28, 28, 2]
        latent_ = latent[:, :-6]  # [128, 20]
        latent_ = latent_[:, :, None, None]  # [128, 20, 1, 1]
        latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [128, 1, 1, 20]
        latent_ = latent_.repeat(1, h_size, w_size, 1)  # [128, 28, 28, 20]
        outputs = torch.concat([coord, latent_], dim=-1)  # [128, 28, 28, 22]
        outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))
        return outputs  # [128 * 28 * 28, 22] = [100352, 22]
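
    # In homogeneous coordinates the warp in augment_latent() is, per sample,
    #   [x', y', w'] = [x, y, 1] @ (I + scale * A),
    # where A is latent[:, -6:] reshaped to 2 x 3 and zero-padded to 3 x 3;
    # the einsum applies this row-vector product at every grid point.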
    def decode(self, latent: Tensor, use_affine: bool) -> Tensor:
        """Decode bottleneck features.

        Args:
            latent (torch.Tensor): bottleneck features in latent space
            use_affine (bool): flag to apply affine transform

        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        batch_size = latent.shape[0]
        latent = self.augment_latent(latent, use_affine)
        hidden = self.decoder(latent)
        hidden = torch.reshape(
            hidden, (batch_size, self.coord.shape[1], self.coord.shape[2], 1)
        )
        reconst: Tensor = torch.permute(hidden, (0, 3, 1, 2))
        return reconst

    @override
    def forward(self, inputs: Tensor) -> Tensor:
        """Forward propagation.

        Args:
            inputs (torch.Tensor): input image

        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        latent = self.encode(inputs)
        reconst = self.decode(latent, self.use_affine)
        return reconst


def get_dataloader(
    is_train: bool, transform: transforms.Compose, batch_size: int
) -> DataLoader[tuple[Tensor, Tensor]]:
    """Get a dataloader for training or validation.

    Args:
        is_train (bool): a flag to determine training mode
        transform (transforms.Compose): a chain of transforms to be applied
        batch_size (int): batch size of data loader

    Returns:
        dataloader (DataLoader): a dataloader for training or validation
    """
    if is_train:
        dataset = datasets.MNIST(
            root="./data", train=True, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    else:
        dataset = datasets.MNIST(
            root="./data", train=False, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    return dataloader
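

# drop_last=True on the training loader keeps every batch at the full batch
# size; as a side effect, a size-1 remainder batch, which BatchNorm1d rejects
# in training mode, can never occur.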


def loss_function(model: AE, inputs: Tensor, use_bce: bool) -> Tensor:
    """Compute loss function.

    Args:
        model (AE): AE module
        inputs (torch.Tensor): input image
        use_bce (bool): flag to apply BCE loss or MSE loss

    Returns:
        loss (torch.Tensor): reconstruction loss
    """
    reconst = model(inputs)
    if use_bce:
        reconst_loss = nn.BCELoss(reduction="sum")(reconst, inputs)
    else:
        reconst_loss = nn.MSELoss(reduction="sum")(reconst, inputs)
    loss: Tensor = reconst_loss
    return loss
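

# The loss must match the output activation and input normalization chosen in
# main(): with --use_bce the decoder ends in Sigmoid and inputs stay in [0, 1]
# (ToTensor only), so BCE is valid; otherwise the decoder ends in Tanh and the
# inputs are normalized to [-1, 1], so MSE is used instead.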


def generate_sample(model: AE, val_data: Tensor, epoch: int) -> None:
    """Generate samples from trained model.

    Args:
        model (AE): AE module
        val_data (torch.Tensor): validation data
        epoch (int): current epoch
    """
    args, train_config, _ = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    with torch.no_grad():
        # save reconstructed images of validation data for comparison
        latent = model.encode(val_data)
        val_reconstructed = model.decode(latent, False)  # without affine transform
        val_reconstructed = val_reconstructed.view(val_data.size())
        comparison = torch.cat([val_data.cpu(), val_reconstructed.cpu()], dim=3)
        if not train_config.use_bce:
            comparison = 0.5 + 0.5 * comparison  # map [-1, 1] back to [0, 1]
        save_image(comparison, f"{args.out_dir}/reconstructed_image_{epoch + 1}.png")


def main() -> None:
    """Perform demonstration."""
    args, train_config, model_config = parse_args()
    if train_config.use_bce:
        transform = transforms.Compose([transforms.ToTensor()])
    else:
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
    train_loader = get_dataloader(True, transform, train_config.batch_size)
    test_loader = get_dataloader(False, transform, args.batch_size_val)
    model = AE(model_config, train_config.use_bce).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=train_config.learning_rate)

    # prepare validation data
    val_data, _ = next(iter(test_loader))
    val_data = val_data.to(args.device)

    for epoch in range(train_config.num_epochs):
        model.train()
        epoch_loss = []
        for data, _ in train_loader:
            data = data.to(args.device)
            optimizer.zero_grad()
            loss = loss_function(model, data, train_config.use_bce)
            epoch_loss.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch + 1}, Average Loss: {np.average(epoch_loss):.12f}")

        # visualise training progress by generating samples from current model
        model.eval()
        generate_sample(model, val_data, epoch)
    print("Training finished.")


if __name__ == "__main__":
    main()