# -*- coding: utf-8 -*-
"""A training script for Variational Autoencoder on MNIST dataset.
Copyright (C) 2025 by Akira TAMAMORI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import argparse
import os
from typing import NamedTuple, final, override
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
class Arguments(NamedTuple):
"""Defines a class for miscellaneous configurations."""
batch_size_val: int # Batch size for validation
log_dir: str # Directory to store training log
out_dir: str # Directory to store output sample images
model_dir: str # Directory to store trained network parameters
model_file: str # Filename to store trained network parameters
device: str # Device to use for training
class TrainingConfig(NamedTuple):
"""Defines a class for training configuration."""
batch_size: int # Batch size for training
learning_rate: float # Learning rate for optimizer
num_epochs: int # Number of epochs for training
use_mse: bool # Use MSE loss instead of BCE loss
class ModelConfig(NamedTuple):
"""Defines a class for model configuration."""
hidden_dim: int # Number of features in intermediate layers
latent_dim: int # Dimension of latent space
use_affine: bool # Use affine transform in training
def parse_args() -> tuple[Arguments, TrainingConfig, ModelConfig]:
"""Parse command line arguments.
Returns:
arguments (Arguments): miscellaneous configurations.
train_config (TrainingConfig): configurations for model training.
model_config (ModelConfig): configurations for model definition.
"""
parser = argparse.ArgumentParser(description="VAE training script")
parser.add_argument(
"--batch_size_val", type=int, default=64, help="Batch size for validation"
)
parser.add_argument(
"--log_dir",
type=str,
default="./log",
help="Directory to store training log",
)
parser.add_argument(
"--out_dir",
type=str,
default="./images",
help="Directory to store output sample images",
)
parser.add_argument(
"--model_dir",
type=str,
default="./model",
help="Directory to store trained network parameters",
)
parser.add_argument(
"--model_file",
type=str,
default="model.pth",
help="Filename to store trained network parameters",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use for training",
)
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training"
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
)
parser.add_argument(
"--num_epochs", type=int, default=30, help="Number of epochs for training"
)
    parser.add_argument(  # True only if --use_mse is specified on the command line.
"--use_mse", action="store_true", help="Use MSE loss instead of BCE loss"
)
parser.add_argument(
"--hidden_dim",
type=int,
default=512,
help="Number of features in intermediate layers",
)
parser.add_argument(
"--latent_dim", type=int, default=26, help="Dimension of latent space"
)
    parser.add_argument(  # True only if --use_affine is specified on the command line.
"--use_affine", action="store_true", help="Use affine transform in training"
)
args = parser.parse_args()
arguments = Arguments(
batch_size_val=args.batch_size_val,
log_dir=args.log_dir,
out_dir=args.out_dir,
model_dir=args.model_dir,
model_file=args.model_file,
device=args.device,
)
train_config = TrainingConfig(
batch_size=args.batch_size,
learning_rate=args.learning_rate,
num_epochs=args.num_epochs,
use_mse=args.use_mse,
)
model_config = ModelConfig(
hidden_dim=args.hidden_dim,
latent_dim=args.latent_dim,
use_affine=args.use_affine,
)
return arguments, train_config, model_config
@final
class VAE(nn.Module):
"""Variational Autoencoder."""
def __init__(self, config: ModelConfig, use_mse: bool):
"""Initialize module."""
super().__init__()
hidden_dim = config.hidden_dim
latent_dim = config.latent_dim
self.encoder = nn.Sequential(
            nn.Flatten(),  # [B, 1, 28, 28] -> [B, 784]
nn.Linear(28 * 28, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # -5: affine parameters predicted by the encoder
        #     (2 for scale, 1 for rotation angle, 2 for translation in 2-D space)
        # +2: number of pixel coordinates (x, y)
        self.decoder_fc = nn.Linear(latent_dim - 5 + 2, hidden_dim)
self.decoder = nn.Sequential(
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, 1),
nn.Tanh() if use_mse else nn.Sigmoid(),
)
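        # Pixel-coordinate grid over [-1, 1]^2. The decoder is coordinate-based:
        # it maps each (latent, coordinate) pair to a single pixel intensity,
        # so the affine transform below can act directly on the coordinates.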
coord = torch.cartesian_prod(
torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28)
)
coord = torch.reshape(coord, (28, 28, 2)).unsqueeze(0) # [1, 28, 28, 2]
self.register_buffer("coord", coord)
self.use_affine = config.use_affine
def encode(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
"""Encode inputs.
Args:
inputs (torch.Tensor): input image
Returns:
mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-variance vector of posterior dist.
"""
hidden = self.encoder(inputs)
mu = self.fc_mu(hidden)
logvar = self.fc_logvar(hidden)
return mu, logvar
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""Perform reparameterization trick.
Args:
mu (torch.Tensor): mean vector
            logvar (torch.Tensor): log-variance vector
Returns:
latent (torch.Tensor): latent variables
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
latent = mu + eps * std
return latent
def augment_latent(self, latent: Tensor, use_affine: bool) -> Tensor:
"""Augment latent variables.
Args:
latent (torch.Tensor): latent variables
            use_affine (bool): flag to apply affine transform
Returns:
outputs (torch.Tensor): augmented latent variables
"""
batch_size = latent.shape[0] # 128
coord = self.coord.repeat(batch_size, 1, 1, 1) # [128, 28, 28, 2]
if use_affine:
scale = latent[:, -5:-3]
angle = latent[:, -3:-2]
translate = latent[:, -2:]
# scale
            scales = torch.zeros((batch_size, 3, 3)).to(latent.device)  # [128, 3, 3]
scales[:, 0, 0] = 0.2 * torch.tanh(scale[:, 0]) + 1.0
scales[:, 1, 1] = 0.2 * torch.tanh(scale[:, 1]) + 1.0
scales[:, 2, 2] = 1.0
# rotation
            rotates = torch.zeros((batch_size, 3, 3)).to(latent.device)  # [128, 3, 3]
rotates[:, 0, 0] = torch.cos(angle[:, 0])
rotates[:, 0, 1] = -torch.sin(angle[:, 0])
rotates[:, 1, 0] = torch.sin(angle[:, 0])
rotates[:, 1, 1] = torch.cos(angle[:, 0])
rotates[:, 2, 2] = 1.0
# translation
translates = torch.zeros((batch_size, 3, 3)).to(latent.device)
translates[:, 0, 0] = 1.0
translates[:, 1, 1] = 1.0
translates[:, 2, 2] = 1.0
translates[:, 0, 2] = translate[:, 0]
translates[:, 1, 2] = translate[:, 1]
            # affine transform: compose the homogeneous matrices, then apply to the grid
affine = torch.matmul(scales, rotates)
affine = torch.matmul(affine, translates)
ones = torch.ones_like(coord[:, :, :, 0:1]) # [128, 28, 28, 1]
coord = torch.concat([coord, ones], dim=-1) # [128, 28, 28, 3]
coord = torch.einsum("bhwj, bij -> bhwi", coord, affine)
coord = coord[:, :, :, 0:2] # [128, 28, 28, 2]
        latent_ = latent[:, :-5]  # [128, latent_dim - 5]
        latent_ = latent_[:, :, None, None]  # [128, latent_dim - 5, 1, 1]
        latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [128, 1, 1, latent_dim - 5]
        latent_ = latent_.repeat(
            1, self.coord.shape[1], self.coord.shape[2], 1
        )  # [128, 28, 28, latent_dim - 5]
        outputs = torch.concat([latent_, coord], dim=-1)  # [128, 28, 28, latent_dim - 3]
        outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))
        return outputs  # [128 * 28 * 28, latent_dim - 3]
def decode(self, latent: Tensor, use_affine: bool) -> Tensor:
"""Decode latent variables.
Args:
latent (torch.Tensor): latent variables
use_affine (bool): flag to apply affine transform
Returns:
reconst (torch.Tensor): reconstructed image
"""
batch_size = latent.shape[0]
latent = self.augment_latent(latent, use_affine)
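        # After augmentation, each row is one (content-latent, pixel-coordinate)
        # pair; the MLP below predicts that single pixel's intensity.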
hidden = self.decoder_fc(latent)
hidden = self.decoder(hidden)
hidden = torch.reshape(
hidden, (batch_size, self.coord.shape[1], self.coord.shape[2], 1)
)
reconst: Tensor = torch.permute(hidden, (0, 3, 1, 2))
return reconst
@override
def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""Forward propagation.
Args:
inputs (torch.Tensor): input image
Returns:
reconst (torch.Tensor): reconstructed image
mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-variance vector of posterior dist.
"""
mu, logvar = self.encode(inputs)
latent = self.reparameterize(mu, logvar)
reconst = self.decode(latent, self.use_affine)
return reconst, mu, logvar
def get_dataloader(
is_train: bool, use_mse: bool, batch_size: int
) -> DataLoader[tuple[Tensor, Tensor]]:
"""Get a dataloader for training or validation.
Args:
is_train (bool): a flag to determine training mode
use_mse (bool): flag to apply MSE loss or BCE loss
batch_size (int): batch size of data loader
Returns:
        dataloader (DataLoader): a dataloader for training or validation
"""
    if use_mse:  # map pixels to [-1, 1] to match the Tanh output
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
    else:  # keep pixels in [0, 1] to match the Sigmoid output
transform = transforms.Compose([transforms.ToTensor()])
    if is_train:
dataset = datasets.MNIST(
root="./data", train=True, transform=transform, download=True
)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
else:
dataset = datasets.MNIST(
root="./data", train=False, transform=transform, download=True
)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)
return dataloader
def loss_function(
model: VAE, inputs: Tensor, use_mse: bool
) -> tuple[Tensor, Tensor, Tensor]:
"""Compute loss function (negative ELBO).
Args:
model (VAE): VAE module
inputs (torch.Tensor): input image
use_mse (bool): flag to apply MSE loss or BCE loss
Returns:
reconst_error (torch.Tensor): reconstruction error
kl_divergence (torch.Tensor): KL divergence
loss (torch.Tensor): loss function (negative ELBO)
"""
reconst, mu, logvar = model(inputs)
if use_mse:
reconst_error = nn.MSELoss(reduction="sum")(reconst, inputs)
else:
reconst_error = nn.BCELoss(reduction="sum")(reconst, inputs)
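    # Closed-form KL between the Gaussian posterior N(mu, sigma^2) and the
    # standard normal prior: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).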
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = reconst_error + kl_divergence
return reconst_error, kl_divergence, loss
def generate_sample(
model: VAE, val_loader: DataLoader[tuple[Tensor, Tensor]], epoch: int
) -> None:
"""Generate samples from trained model.
Args:
model (VAE): VAE module
val_loader (DataLoader[tuple[Tensor, Tensor]]): dataloader for validation
epoch (int): current epoch
Returns:
None
"""
args, train_config, model_config = parse_args()
os.makedirs(args.out_dir, exist_ok=True)
batch_size = args.batch_size_val
val_data, _ = next(iter(val_loader))
val_data = val_data.to(args.device)
with torch.no_grad():
# sample latent variables randomly from the prior distribution
latent = torch.randn(batch_size, model_config.latent_dim).to(args.device)
generated_images = model.decode(latent, False) # without affine transform
images = generated_images.cpu().view(val_data.size())
if train_config.use_mse:
images = 0.5 + 0.5 * images
save_image(images[:batch_size], f"{args.out_dir}/generated_image_{epoch+1}.png")
# save reconstructed images of validation data for comparison
mu, logvar = model.encode(val_data)
latent = model.reparameterize(mu, logvar)
val_reconstructed = model.decode(latent, False) # without affine transform
val_reconstructed = val_reconstructed.view(val_data.size())
images = torch.cat([val_data.cpu(), val_reconstructed.cpu()], dim=3)
if train_config.use_mse:
images = 0.5 + 0.5 * images
save_image(images, f"{args.out_dir}/reconstructed_image_{epoch+1}.png")
def main() -> None:
"""Perform demonstration."""
args, train_config, model_config = parse_args()
train_loader = get_dataloader(True, train_config.use_mse, train_config.batch_size)
val_loader = get_dataloader(False, train_config.use_mse, args.batch_size_val)
model = VAE(model_config, train_config.use_mse).to(args.device)
optimizer = optim.Adam(model.parameters(), lr=train_config.learning_rate)
# Auto-Encoding VB (AEVB) algorithm
loss_dict = {}
global_step = 0
for epoch in range(train_config.num_epochs):
model.train()
loss_dict["reconst"] = 0.0
loss_dict["kl_div"] = 0.0
loss_dict["neg_elbo"] = 0.0
for data, _ in train_loader:
global_step += 1
data = data.to(args.device)
optimizer.zero_grad()
            # scale the minibatch loss by the number of batches to obtain
            # an estimator of the negative ELBO over the full dataset
reconst_error, kl_div, loss = loss_function(
model, data, train_config.use_mse
)
reconst_error *= len(train_loader)
kl_div *= len(train_loader)
loss *= len(train_loader)
loss.backward()
optimizer.step()
loss_dict["reconst"] += reconst_error.item()
loss_dict["kl_div"] += kl_div.item()
loss_dict["neg_elbo"] += loss.item()
print(
f"Epoch: {epoch+1}, "
+ f"Negative ELBO: {loss_dict['neg_elbo'] / len(train_loader):.6f}, "
+ f"Reconstruction Error: {loss_dict['reconst'] / len(train_loader):.6f}, "
+ f"KL Divergence: {loss_dict["kl_div"] / len(train_loader):.6f}"
)
# visualise training progress by generating samples from current model
model.eval()
generate_sample(model, val_loader, epoch)
os.makedirs(args.model_dir, exist_ok=True)
torch.save(model.state_dict(), f=os.path.join(args.model_dir, args.model_file))
print("Training finished.")
if __name__ == "__main__":
main()
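# A minimal sketch of reusing the saved weights for sampling afterwards
# (hypothetical follow-up snippet; assumes the default --model_dir/--model_file
# paths and the default model configuration above):
#
#     model = VAE(ModelConfig(hidden_dim=512, latent_dim=26, use_affine=False),
#                 use_mse=False)
#     model.load_state_dict(torch.load("./model/model.pth"))
#     model.eval()
#     with torch.no_grad():
#         samples = model.decode(torch.randn(64, 26), use_affine=False)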