# -*- coding: utf-8 -*-
"""A demonstration script for Restricted Boltzmann Machine on MNIST dataset.

Copyright (C) 2025 by Akira TAMAMORI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import argparse
import os
from typing import NamedTuple, final, override

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision import datasets, transforms


class Arguments(NamedTuple):
    """Defines a class for miscellaneous configurations."""

    batch_size_val: int  # Batch size for validation
    log_dir: str  # Directory to store training log
    out_dir: str  # Directory to store output sample images
    model_dir: str  # Directory to store trained network parameters
    model_file: str  # Filename to store trained network parameters
    num_reconst: int  # Number of reconstructions to visualize
    device: str  # Device to use for training


class RBMConfig(NamedTuple):
    """Configuration for the RBM architecture."""

    visible_units: int  # Number of units in the visible layer
    hidden_units: int  # Number of units in the hidden layer


class TrainingConfig(NamedTuple):
    """Configuration for training the RBM."""

    learning_rate: float  # Learning rate for the optimizer
    epochs: int  # Number of training epochs
    batch_size: int  # Batch size for training
    cd_steps: int  # Number of steps for Contrastive Divergence


def parse_arguments() -> tuple[Arguments, RBMConfig, TrainingConfig]:
    """Parses command-line arguments into configuration objects.

    Returns:
        tuple[Arguments, RBMConfig, TrainingConfig]: Miscellaneous, model
            architecture, and training configurations.
    """
    parser = argparse.ArgumentParser(
        description="Train and generate samples "
        + "using a Restricted Boltzmann Machine (RBM)."
    )
    parser.add_argument(
        "--batch_size_val",
        type=int,
        default=64,
        help="Batch size for validation (default: 64).",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="./log",
        help="Directory to store training log (default: './log').",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./images",
        help="Directory to store output sample images (default: './images').",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="./model",
        help="Directory to store trained network parameters (default: './model').",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="model.pth",
        help="Filename to store trained network parameters (default: 'model.pth').",
    )
    parser.add_argument(
        "--num_reconst",
        type=int,
        default=5,
        help="Number of reconstructions to visualize (default: 5).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for training (default: 'cuda' if available, else 'cpu').",
    )
    parser.add_argument(
        "--visible_units",
        type=int,
        default=784,
        help="Number of units in the visible layer (default: 784).",
    )
    parser.add_argument(
        "--hidden_units",
        type=int,
        default=256,
        help="Number of units in the hidden layer (default: 256).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.001,
        help="Learning rate (default: 0.001).",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of training epochs (default: 10).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for training (default: 64).",
    )
    parser.add_argument(
        "--cd_steps",
        type=int,
        default=1,
        help="Number of steps for Contrastive Divergence (default: 1).",
    )
    args = parser.parse_args()
    arguments = Arguments(
        batch_size_val=args.batch_size_val,
        log_dir=args.log_dir,
        out_dir=args.out_dir,
        model_dir=args.model_dir,
        model_file=args.model_file,
        num_reconst=args.num_reconst,
        device=args.device,
    )
    rbm_config = RBMConfig(
        visible_units=args.visible_units, hidden_units=args.hidden_units
    )
    training_config = TrainingConfig(
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        batch_size=args.batch_size,
        cd_steps=args.cd_steps,
    )
    return arguments, rbm_config, training_config
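

# Example invocation (illustrative; the script filename is an assumption):
#     python rbm_mnist.py --hidden_units 128 --cd_steps 5 --epochs 20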


@final
class RBM(nn.Module):
    """Implementation of a Restricted Boltzmann Machine (RBM).

    Learns the connections between visible and hidden layers to model
    the probability distribution of data.
    """

    def __init__(self, visible_units: int, hidden_units: int):
        """Initializes the RBM.

        Args:
            visible_units (int): The number of units in the visible layer.
            hidden_units (int): The number of units in the hidden layer.
        """
        super().__init__()
        self.w = nn.Parameter(torch.randn(visible_units, hidden_units) * 0.01)
        self.bv = nn.Parameter(torch.zeros(visible_units))
        self.bh = nn.Parameter(torch.zeros(hidden_units))
        self.softplus = nn.Softplus()

    def sample_h(self, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Samples the state of the hidden layer.

        Args:
            v (torch.Tensor): State of the visible layer.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                h (torch.Tensor): Sampled binary state of the hidden layer.
                p_h (torch.Tensor): Activation probability of the hidden layer.
        """
        wx_b = torch.matmul(v, self.w) + self.bh
        p_h = torch.sigmoid(wx_b)
        h = torch.bernoulli(p_h)  # Convert to a binary state
        return h, p_h

    def sample_v(self, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Samples the state of the visible layer.

        Args:
            h (torch.Tensor): State of the hidden layer.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                v (torch.Tensor): Sampled binary state of the visible layer.
                p_v (torch.Tensor): Sampling probability of the visible layer.
        """
        wh_b = torch.matmul(h, self.w.t()) + self.bv
        p_v = torch.sigmoid(wh_b)
        v = torch.bernoulli(p_v)
        return v, p_v  # p_v can serve as a smooth reconstruction

    def free_energy(self, v: torch.Tensor) -> torch.Tensor:
        """Calculates free energy.

        Args:
            v (torch.Tensor): Input state of the visible layer.

        Returns:
            free_energy (torch.Tensor): Free energy, one value per batch row.
        """
        vbias_term = torch.matmul(v, self.bv)
        wx_b = torch.matmul(v, self.w) + self.bh
        hidden_term = torch.sum(self.softplus(wx_b), dim=1)
        free_energy = -hidden_term - vbias_term
        return free_energy
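
    # For a Bernoulli-Bernoulli RBM, summing out the binary hidden units
    # analytically gives the standard free-energy formula used above:
    #     F(v) = -bv^T v - sum_j softplus(w_j^T v + bh_j)
    # so no sampling is needed to evaluate it.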

    @override
    def forward(self, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reconstructs the visible layer state from the input visible layer state.

        Implements one step of Gibbs sampling (v -> h -> v').

        Args:
            v (torch.Tensor): Input state of the visible layer.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                v_recon (torch.Tensor): Reconstructed state of the visible layer.
                p_v_recon (torch.Tensor): Probability of the reconstructed state.
        """
        h, _ = self.sample_h(v)  # Sample a binary state of the hidden layer
        # Reconstruct the visible layer using the hidden layer's binary state
        v_recon, p_v_recon = self.sample_v(h)
        return v_recon, p_v_recon
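

# Minimal shape sketch (assumed sizes, not executed by this script):
#     rbm = RBM(visible_units=784, hidden_units=256)
#     v = torch.bernoulli(torch.rand(32, 784))  # batch of binary vectors
#     v_recon, p_v = rbm(v)                     # both of shape (32, 784)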


def binarize_tensor(x: torch.Tensor) -> torch.Tensor:
    """Binarizes a PyTorch tensor.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The binarized tensor.
    """
    return (x > 0.5).float()


def flatten_tensor(x: torch.Tensor) -> torch.Tensor:
    """Flattens a PyTorch tensor into a 1D tensor.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The flattened 1D tensor.
    """
    return x.flatten()


def get_dataloader(
    is_train: bool, batch_size: int
) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]:
    """Gets a dataloader for training or validation.

    Args:
        is_train (bool): A flag to select the training split.
        batch_size (int): Batch size of the data loader.

    Returns:
        dataloader (DataLoader): A dataloader for training/test data.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(binarize_tensor),
            transforms.Lambda(flatten_tensor),
        ]
    )
    if is_train:
        dataset = datasets.MNIST(
            root="./data", train=True, download=True, transform=transform
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    else:
        dataset = datasets.MNIST(
            root="./data", train=False, transform=transform, download=True
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    return dataloader
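

# Note: each MNIST image of shape (1, 28, 28) is binarized at a 0.5 threshold
# and flattened into a 784-dimensional {0, 1} vector, which matches the
# default size of the RBM's visible layer.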


def cd_loss(
    rbm: RBM, v0: torch.Tensor, config: TrainingConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Performs the Contrastive Divergence algorithm and calculates the loss.

    Args:
        rbm (RBM): The RBM model.
        v0 (torch.Tensor): The visible layer state from the training data.
        config (TrainingConfig): The training configuration.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            loss (torch.Tensor): The calculated loss value.
            free_energy_0 (torch.Tensor): Free energy of the input visible state.
            free_energy_k (torch.Tensor): Free energy of the reconstructed
                visible state.
    """
    # k-step Gibbs sampling
    vk = v0
    for _ in range(config.cd_steps):
        vk, _ = rbm(vk)

    # Calculate free energies and the loss
    free_energy_0 = rbm.free_energy(v0).mean()
    free_energy_k = rbm.free_energy(vk).mean()
    loss: torch.Tensor = free_energy_0 - free_energy_k
    return loss, free_energy_0, free_energy_k
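

# Note on gradients: torch.bernoulli() is non-differentiable, so autograd does
# not propagate through the Gibbs chain; only the two free-energy terms are
# differentiated with respect to the parameters, which yields the usual CD-k
# update. An explicit `vk = vk.detach()` before free_energy(vk) would make
# this intent clearer without changing behavior.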


def generate_sample(
    rbm: RBM, initial_state: torch.Tensor, steps: int = 200
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a sample using the RBM.

    Args:
        rbm (RBM): Trained RBM model.
        initial_state (torch.Tensor): Initial state of the visible layer for
            sampling.
        steps (int): Number of sampling steps.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            visible_state (torch.Tensor): State of the generated sample.
            p_visible_state (torch.Tensor): Probability of the visible state.
    """
    visible_state: torch.Tensor = initial_state.clone()
    p_visible_state: torch.Tensor = torch.zeros_like(initial_state)
    rbm.eval()
    for _ in range(steps):
        visible_state, p_visible_state = rbm(visible_state)
    return visible_state, p_visible_state
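

# Illustrative sketch (not invoked by main()): sampling digits from random
# noise with a trained model. The checkpoint path, layer sizes, and helper
# name below are assumptions based on this script's defaults.
def _demo_generate(model_path: str = "./model/model.pth") -> None:
    """Plots probabilities of 16 samples drawn from a trained RBM."""
    rbm = RBM(visible_units=784, hidden_units=256)
    rbm.load_state_dict(torch.load(model_path, map_location="cpu"))
    init = torch.bernoulli(torch.full((16, 784), 0.5))  # random binary start
    with torch.no_grad():
        _, probs = generate_sample(rbm, init, steps=200)
    _, axes = plt.subplots(2, 8, figsize=(16, 4))
    for axis, prob in zip(axes.flatten(), probs):
        axis.imshow(prob.reshape(28, 28).numpy(), cmap="gray")
        axis.axis("off")
    plt.show()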


def visualize_reconstruction(
    rbm: RBM,
    data_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    args: Arguments,
) -> None:
    """Reconstructs the input data using the trained RBM and visualizes the results.

    Args:
        rbm (RBM): The trained RBM model.
        data_loader (DataLoader): Data loader for evaluation.
        args (Arguments): Miscellaneous configuration.

    Returns:
        None
    """
    rbm.eval()
    val_data, _ = next(iter(data_loader))
    val_data = val_data.to(args.device)
    num_images = min(args.num_reconst, val_data.shape[0])
    _, axes = plt.subplots(2, num_images, figsize=(2 * num_images, 4))
    if num_images == 1:
        axes = np.array(axes).reshape(2, -1)  # Keep a 2-D grid for one image
    with torch.no_grad():
        for i in range(num_images):
            original_image = val_data[i].reshape(28, 28).cpu().detach().numpy()
            reconst_data, _ = generate_sample(rbm, val_data[i].unsqueeze(0), 1)
            reconst_image = reconst_data[0].reshape(28, 28).cpu().detach().numpy()
            axes[0, i].imshow(original_image, cmap="gray")
            axes[1, i].imshow(reconst_image, cmap="gray")
            axes[0, i].axis("off")
            axes[1, i].axis("off")
    plt.suptitle("Test Raw Data (Top), Reconstruction Results (Bottom)")
    plt.tight_layout()
    plt.show()


def main() -> None:
    """Performs training."""
    args, rbm_config, training_config = parse_arguments()
    train_loader = get_dataloader(True, training_config.batch_size)

    # Model and optimizer initialization
    model = RBM(rbm_config.visible_units, rbm_config.hidden_units).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=training_config.learning_rate)
    writer = SummaryWriter(log_dir=args.log_dir)

    # Training loop
    loss_dict: dict[str, list[float]] = {}
    global_step = 0
    for epoch in range(training_config.epochs):
        loss_dict["energy_v0"] = []
        loss_dict["energy_vk"] = []
        loss_dict["loss"] = []
        for data, _ in train_loader:
            global_step += 1
            data = data.view(-1, rbm_config.visible_units).to(args.device)
            optimizer.zero_grad()
            loss, energy_v0, energy_vk = cd_loss(model, data, training_config)
            loss.backward()
            optimizer.step()
            loss_dict["loss"].append(loss.item())
            loss_dict["energy_v0"].append(energy_v0.item())
            loss_dict["energy_vk"].append(energy_vk.item())
            writer.add_scalar("Loss/Free Energy (v0)", energy_v0.item(), global_step)
            writer.add_scalar("Loss/Free Energy (vk)", energy_vk.item(), global_step)
            writer.add_scalar("Loss/Loss Function", loss.item(), global_step)
        print(
            f"Epoch: {epoch + 1} Average Loss: {np.average(loss_dict['loss']):.4f}, "
            + f"Free Energy (v0): {np.average(loss_dict['energy_v0']):.4f}, "
            + f"Free Energy (vk): {np.average(loss_dict['energy_vk']):.4f}"
        )
    writer.close()  # Flush pending TensorBoard events

    # Visualize reconstruction
    test_loader = get_dataloader(False, args.batch_size_val)
    visualize_reconstruction(model, test_loader, args)

    # Save parameters
    os.makedirs(args.model_dir, exist_ok=True)
    torch.save(model.state_dict(), f=os.path.join(args.model_dir, args.model_file))
    print("Training finished.")


if __name__ == "__main__":
    main()
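
# To inspect the training curves written by SummaryWriter, run the standard
# TensorBoard CLI and point it at the log directory:
#     tensorboard --logdir ./log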