# -*- coding: utf-8 -*-
"""A training script for a Variational Autoencoder (VAE) on the MNIST dataset.

Copyright (C) 2025 by Akira TAMAMORI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import argparse
import os
from typing import NamedTuple, final, override

import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

class Arguments(NamedTuple):
    """Defines a class for miscellaneous configurations."""

    batch_size_val: int  # Batch size for validation
    log_dir: str  # Directory to store training log
    out_dir: str  # Directory to store output sample images
    model_dir: str  # Directory to store trained network parameters
    model_file: str  # Filename to store trained network parameters
    device: str  # Device to use for training


class TrainingConfig(NamedTuple):
    """Defines a class for training configuration."""

    batch_size: int  # Batch size for training
    learning_rate: float  # Learning rate for optimizer
    num_epochs: int  # Number of epochs for training
    use_mse: bool  # Use MSE loss instead of BCE loss


class ModelConfig(NamedTuple):
    """Defines a class for model configuration."""

    hidden_dim: int  # Number of features in intermediate layers
    latent_dim: int  # Dimension of latent space
    use_affine: bool  # Use affine transform in training

def parse_args() -> tuple[Arguments, TrainingConfig, ModelConfig]:
    """Parse command line arguments.

    Returns:
        arguments (Arguments): miscellaneous configurations.
        train_config (TrainingConfig): configurations for model training.
        model_config (ModelConfig): configurations for model definition.
    """
    parser = argparse.ArgumentParser(description="VAE training script")
    parser.add_argument(
        "--batch_size_val", type=int, default=64, help="Batch size for validation"
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="./log",
        help="Directory to store training log",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./images",
        help="Directory to store output sample images",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="./model",
        help="Directory to store trained network parameters",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="model.pth",
        help="Filename to store trained network parameters",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for training",
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
    )
    parser.add_argument(
        "--num_epochs", type=int, default=30, help="Number of epochs for training"
    )
    parser.add_argument(  # True if --use_mse is specified on the command line.
        "--use_mse", action="store_true", help="Use MSE loss instead of BCE loss"
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=512,
        help="Number of features in intermediate layers",
    )
    parser.add_argument(
        "--latent_dim", type=int, default=26, help="Dimension of latent space"
    )
    parser.add_argument(  # True if --use_affine is specified on the command line.
        "--use_affine", action="store_true", help="Use affine transform in training"
    )
    args = parser.parse_args()
    arguments = Arguments(
        batch_size_val=args.batch_size_val,
        log_dir=args.log_dir,
        out_dir=args.out_dir,
        model_dir=args.model_dir,
        model_file=args.model_file,
        device=args.device,
    )
    train_config = TrainingConfig(
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_mse=args.use_mse,
    )
    model_config = ModelConfig(
        hidden_dim=args.hidden_dim,
        latent_dim=args.latent_dim,
        use_affine=args.use_affine,
    )
    return arguments, train_config, model_config

@final
class VAE(nn.Module):
    """Variational Autoencoder."""

    def __init__(self, config: ModelConfig, use_mse: bool):
        """Initialize module."""
        super().__init__()
        hidden_dim = config.hidden_dim
        latent_dim = config.latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),  # [128, 28 x 28 x 1] = [128, 784]
            nn.Linear(28 * 28, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # -5: affine parameters (scaling, rotation, and translation
        #     in 2-d Euclidean space)
        # +2: number of coordinates
        self.decoder_fc = nn.Linear(latent_dim - 5 + 2, hidden_dim)
        self.decoder = nn.Sequential(
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Tanh() if use_mse else nn.Sigmoid(),
        )
        coord = torch.cartesian_prod(
            torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28)
        )
        coord = torch.reshape(coord, (28, 28, 2)).unsqueeze(0)  # [1, 28, 28, 2]
        self.register_buffer("coord", coord)
        self.use_affine = config.use_affine
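
    # Architectural note: this is a coordinate-conditioned decoder. Instead of
    # emitting all 784 pixels at once, the decoder MLP maps one (latent, x, y)
    # tuple to a single intensity and is evaluated over the whole 28 x 28 grid
    # stored in `self.coord`. The last 5 latent dimensions are reserved as
    # affine parameters (2 scale, 1 rotation, 2 translation; see
    # augment_latent below), which is why the decoder input width is
    # latent_dim - 5 + 2.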

    def encode(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
        """Encode inputs.

        Args:
            inputs (torch.Tensor): input image

        Returns:
            mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-variance vector of posterior dist.
        """
        hidden = self.encoder(inputs)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Perform reparameterization trick.

        Args:
            mu (torch.Tensor): mean vector
            logvar (torch.Tensor): log-variance vector

        Returns:
            latent (torch.Tensor): latent variables
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        latent = mu + eps * std
        return latent
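
    # Note on the trick: sampling z ~ N(mu, sigma^2) directly is not
    # differentiable w.r.t. (mu, logvar). Rewriting the sample as
    #     z = mu + sigma * eps,  eps ~ N(0, I),  sigma = exp(0.5 * logvar),
    # moves the randomness into eps, so gradients can flow through mu and
    # logvar during backpropagation.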

    def augment_latent(self, latent: Tensor, use_affine: bool) -> Tensor:
        """Augment latent variables.

        Args:
            latent (torch.Tensor): latent variables
            use_affine (bool): flag to apply affine transform

        Returns:
            outputs (torch.Tensor): augmented latent variables
        """
        batch_size = latent.shape[0]  # 128
        coord = self.coord.repeat(batch_size, 1, 1, 1)  # [128, 28, 28, 2]
        if use_affine:
            scale = latent[:, -5:-3]
            angle = latent[:, -3:-2]
            translate = latent[:, -2:]
            # scale
            scales = torch.zeros((batch_size, 3, 3)).to(latent.device)  # [128, 3, 3]
            scales[:, 0, 0] = 0.2 * torch.tanh(scale[:, 0]) + 1.0
            scales[:, 1, 1] = 0.2 * torch.tanh(scale[:, 1]) + 1.0
            scales[:, 2, 2] = 1.0
            # rotation
            rotates = torch.zeros((batch_size, 3, 3)).to(latent.device)  # [128, 3, 3]
            rotates[:, 0, 0] = torch.cos(angle[:, 0])
            rotates[:, 0, 1] = -torch.sin(angle[:, 0])
            rotates[:, 1, 0] = torch.sin(angle[:, 0])
            rotates[:, 1, 1] = torch.cos(angle[:, 0])
            rotates[:, 2, 2] = 1.0
            # translation
            translates = torch.zeros((batch_size, 3, 3)).to(latent.device)
            translates[:, 0, 0] = 1.0
            translates[:, 1, 1] = 1.0
            translates[:, 2, 2] = 1.0
            translates[:, 0, 2] = translate[:, 0]
            translates[:, 1, 2] = translate[:, 1]
            # affine transform
            affine = torch.matmul(scales, rotates)
            affine = torch.matmul(affine, translates)
            ones = torch.ones_like(coord[:, :, :, 0:1])  # [128, 28, 28, 1]
            coord = torch.concat([coord, ones], dim=-1)  # [128, 28, 28, 3]
            coord = torch.einsum("bhwj, bij -> bhwi", coord, affine)
            coord = coord[:, :, :, 0:2]  # [128, 28, 28, 2]
        latent_ = latent[:, :-5]  # [128, 21]
        latent_ = latent_[:, :, None, None]  # [128, 21, 1, 1]
        latent_ = torch.permute(latent_, (0, 2, 3, 1))  # [128, 1, 1, 21]
        latent_ = latent_.repeat(
            1, self.coord.shape[1], self.coord.shape[2], 1
        )  # [128, 28, 28, 21]
        outputs = torch.concat([latent_, coord], dim=-1)  # [128, 28, 28, 23]
        outputs = torch.reshape(outputs, (-1, outputs.shape[-1]))
        return outputs  # [128 * 28 * 28, 23] = [100352, 23]
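
    # The matrices above act on homogeneous coordinates [x, y, 1]^T, so each
    # grid point is warped as p' = (S @ R @ T) p: translated, then rotated,
    # then scaled anisotropically, with 0.2 * tanh(.) + 1.0 keeping each
    # scale factor inside (0.8, 1.2).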

    def decode(self, latent: Tensor, use_affine: bool) -> Tensor:
        """Decode latent variables.

        Args:
            latent (torch.Tensor): latent variables
            use_affine (bool): flag to apply affine transform

        Returns:
            reconst (torch.Tensor): reconstructed image
        """
        batch_size = latent.shape[0]
        latent = self.augment_latent(latent, use_affine)
        hidden = self.decoder_fc(latent)
        hidden = self.decoder(hidden)
        hidden = torch.reshape(
            hidden, (batch_size, self.coord.shape[1], self.coord.shape[2], 1)
        )
        reconst: Tensor = torch.permute(hidden, (0, 3, 1, 2))
        return reconst

    @override
    def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Forward propagation.

        Args:
            inputs (torch.Tensor): input image

        Returns:
            reconst (torch.Tensor): reconstructed image
            mu (torch.Tensor): mean vector of posterior dist.
            logvar (torch.Tensor): log-variance vector of posterior dist.
        """
        mu, logvar = self.encode(inputs)
        latent = self.reparameterize(mu, logvar)
        reconst = self.decode(latent, self.use_affine)
        return reconst, mu, logvar

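
# A minimal shape check (hypothetical usage sketch, not executed by this
# script; the config values match the argparse defaults above):
#     model = VAE(ModelConfig(hidden_dim=512, latent_dim=26, use_affine=True),
#                 use_mse=False)
#     reconst, mu, logvar = model(torch.rand(4, 1, 28, 28))
#     assert reconst.shape == (4, 1, 28, 28)
#     assert mu.shape == logvar.shape == (4, 26)
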

def get_dataloader(
    is_train: bool, use_mse: bool, batch_size: int
) -> DataLoader[tuple[Tensor, Tensor]]:
    """Get a dataloader for training or validation.

    Args:
        is_train (bool): a flag to determine training mode
        use_mse (bool): flag to apply MSE loss or BCE loss
        batch_size (int): batch size of data loader

    Returns:
        dataloader (DataLoader): a dataloader for training or validation
    """
    if use_mse:  # convert to a tensor with range [-1, 1]
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
    else:  # convert to a tensor with range [0, 1]
        transform = transforms.Compose([transforms.ToTensor()])
    if is_train:
        dataset = datasets.MNIST(
            root="./data", train=True, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    else:
        dataset = datasets.MNIST(
            root="./data", train=False, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    return dataloader

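
# Note: the [-1, 1] normalization pairs with the Tanh decoder head selected
# when use_mse is set, while the raw [0, 1] tensors pair with the Sigmoid
# head required by BCELoss. drop_last=True on the training loader keeps every
# batch the same size and (presumably) avoids a stray size-1 final batch that
# would break BatchNorm1d in training mode.
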

def loss_function(
    model: VAE, inputs: Tensor, use_mse: bool
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute loss function (negative ELBO).

    Args:
        model (VAE): VAE module
        inputs (torch.Tensor): input image
        use_mse (bool): flag to apply MSE loss or BCE loss

    Returns:
        reconst_error (torch.Tensor): reconstruction error
        kl_divergence (torch.Tensor): KL divergence
        loss (torch.Tensor): loss function (negative ELBO)
    """
    reconst, mu, logvar = model(inputs)
    if use_mse:
        reconst_error = nn.MSELoss(reduction="sum")(reconst, inputs)
    else:
        reconst_error = nn.BCELoss(reduction="sum")(reconst, inputs)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = reconst_error + kl_divergence
    return reconst_error, kl_divergence, loss

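
# The KL term above is the closed form for diagonal Gaussians,
#     KL( N(mu, diag(sigma^2)) || N(0, I) )
#         = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2),
# which with logvar = log(sigma^2) is exactly
# -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
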

def generate_sample(
    model: VAE, val_loader: DataLoader[tuple[Tensor, Tensor]], epoch: int
) -> None:
    """Generate samples from trained model.

    Args:
        model (VAE): VAE module
        val_loader (DataLoader[tuple[Tensor, Tensor]]): dataloader for validation
        epoch (int): current epoch

    Returns:
        None
    """
    args, train_config, model_config = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    batch_size = args.batch_size_val
    val_data, _ = next(iter(val_loader))
    val_data = val_data.to(args.device)
    with torch.no_grad():
        # sample latent variables randomly from the prior distribution
        latent = torch.randn(batch_size, model_config.latent_dim).to(args.device)
        generated_images = model.decode(latent, False)  # without affine transform
        images = generated_images.cpu().view(val_data.size())
        if train_config.use_mse:
            images = 0.5 + 0.5 * images
        save_image(
            images[:batch_size], f"{args.out_dir}/generated_image_{epoch+1}.png"
        )
        # save reconstructed images of validation data for comparison
        mu, logvar = model.encode(val_data)
        latent = model.reparameterize(mu, logvar)
        val_reconstructed = model.decode(latent, False)  # without affine transform
        val_reconstructed = val_reconstructed.view(val_data.size())
        images = torch.cat([val_data.cpu(), val_reconstructed.cpu()], dim=3)
        if train_config.use_mse:
            images = 0.5 + 0.5 * images
        save_image(images, f"{args.out_dir}/reconstructed_image_{epoch+1}.png")


def main() -> None:
    """Perform demonstration."""
    args, train_config, model_config = parse_args()
    train_loader = get_dataloader(True, train_config.use_mse, train_config.batch_size)
    val_loader = get_dataloader(False, train_config.use_mse, args.batch_size_val)
    model = VAE(model_config, train_config.use_mse).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=train_config.learning_rate)

    # Auto-Encoding VB (AEVB) algorithm
    loss_dict = {}
    global_step = 0
    for epoch in range(train_config.num_epochs):
        model.train()
        loss_dict["reconst"] = 0.0
        loss_dict["kl_div"] = 0.0
        loss_dict["neg_elbo"] = 0.0
        for data, _ in train_loader:
            global_step += 1
            data = data.to(args.device)
            optimizer.zero_grad()
            # an estimator of the negative ELBO of the full dataset:
            # scaling the per-minibatch (sum-reduced) loss by the number of
            # batches estimates the loss summed over the whole training set
            reconst_error, kl_div, loss = loss_function(
                model, data, train_config.use_mse
            )
            reconst_error *= len(train_loader)
            kl_div *= len(train_loader)
            loss *= len(train_loader)
            loss.backward()
            optimizer.step()
            loss_dict["reconst"] += reconst_error.item()
            loss_dict["kl_div"] += kl_div.item()
            loss_dict["neg_elbo"] += loss.item()
        print(
            f"Epoch: {epoch+1}, "
            + f"Negative ELBO: {loss_dict['neg_elbo'] / len(train_loader):.6f}, "
            + f"Reconstruction Error: {loss_dict['reconst'] / len(train_loader):.6f}, "
            + f"KL Divergence: {loss_dict['kl_div'] / len(train_loader):.6f}"
        )
        # visualise training progress by generating samples from current model
        model.eval()
        generate_sample(model, val_loader, epoch)
    os.makedirs(args.model_dir, exist_ok=True)
    torch.save(model.state_dict(), f=os.path.join(args.model_dir, args.model_file))
    print("Training finished.")


if __name__ == "__main__":
    main()
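
# Example invocations (the script filename "train_vae.py" is hypothetical;
# all flags are defined in parse_args above):
#     python train_vae.py                      # BCE loss, Sigmoid decoder head
#     python train_vae.py --use_mse            # MSE loss, Tanh decoder head
#     python train_vae.py --use_affine --latent_dim 26 --num_epochs 30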