@tam17aki
Last active February 6, 2025 09:31
# -*- coding: utf-8 -*-
"""A demonstration script for Restricted Boltzmann Machine on MNIST dataset.
Copyright (C) 2025 by Akira TAMAMORI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import argparse
import os
from typing import NamedTuple, final, override
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision import datasets, transforms
class Arguments(NamedTuple):
"""Defines a class for miscellaneous configurations."""
batch_size_val: int # Batch size for validation
log_dir: str # Directory to store training log
out_dir: str # Directory to store output sample images
model_dir: str # Directory to store trained network parameters
model_file: str # Filename to store trained network parameters
num_reconst: int # Number of reconstructions to visualize
device: str # Device to use for training
class RBMConfig(NamedTuple):
"""Configuration for the RBM architecture."""
visible_units: int # Number of units in the visible layer
hidden_units: int # Number of units in the hidden layer
class TrainingConfig(NamedTuple):
"""Configuration for training the RBM."""
learning_rate: float
epochs: int # Number of training epochs
batch_size: int # Batch size for training
cd_steps: int # Number of steps for Contrastive Divergence
def parse_arguments() -> tuple[Arguments, RBMConfig, TrainingConfig]:
"""Parses command-line arguments and returns RBMConfig and TrainingConfig.
Returns:
tuple[Arguments, RBMConfig, TrainingConfig]: Configuration objects.
"""
parser = argparse.ArgumentParser(
description="Train and generate samples "
+ "using a Restricted Boltzmann Machine (RBM)."
)
parser.add_argument(
"--batch_size_val",
type=int,
default=64,
help="Batch size for validation (default: 64).",
)
parser.add_argument(
"--log_dir",
type=str,
default="./log",
help="Directory to store training log (default: './log').",
)
parser.add_argument(
"--out_dir",
type=str,
default="./images",
help="Directory to store output sample images (default: './images').",
)
parser.add_argument(
"--model_dir",
type=str,
default="./model",
help="Directory to store trained network parameters (default: './model').",
)
parser.add_argument(
"--model_file",
type=str,
default="model.pth",
help="Filename to store trained network parameters (default: 'model.pth').",
)
parser.add_argument(
"--num_reconst",
type=int,
default=5,
help="Number of reconstructions to visualize (default: 5).",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use for training (default: 'cuda')",
)
parser.add_argument(
"--visible_units",
type=int,
default=784,
help="Number of units in the visible layer (default: 784).",
)
parser.add_argument(
"--hidden_units",
type=int,
default=256,
help="Number of units in the hidden layer (default: 256).",
)
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="Learning rate (default: 0.001).",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
help="Number of training epochs (default: 10).",
)
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="Batch size for training (default: 64).",
)
parser.add_argument(
"--cd_steps",
type=int,
default=1,
help="Number of steps for Contrastive Divergence (default: 1).",
)
args = parser.parse_args()
arguments = Arguments(
batch_size_val=args.batch_size_val,
log_dir=args.log_dir,
out_dir=args.out_dir,
model_dir=args.model_dir,
model_file=args.model_file,
num_reconst=args.num_reconst,
device=args.device,
)
rbm_config = RBMConfig(
visible_units=args.visible_units, hidden_units=args.hidden_units
)
training_config = TrainingConfig(
learning_rate=args.learning_rate,
epochs=args.epochs,
batch_size=args.batch_size,
cd_steps=args.cd_steps,
)
return arguments, rbm_config, training_config
@final
class RBM(nn.Module):
"""Implementation of a Restricted Boltzmann Machine (RBM).
Learns the connections between visible and hidden layers to model
the probability distribution of data.
"""
def __init__(self, visible_units: int, hidden_units: int):
"""Initializes the RBM.
Args:
visible_units (int): The number of units in the visible layer.
hidden_units (int): The number of units in the hidden layer.
"""
super().__init__()
self.w = nn.Parameter(torch.randn(visible_units, hidden_units) * 0.01)
self.bv = nn.Parameter(torch.zeros(visible_units))
self.bh = nn.Parameter(torch.zeros(hidden_units))
self.softplus = nn.Softplus()
def sample_h(self, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Samples the probability of the hidden layer.
Args:
v (torch.Tensor): State of the visible layer.
Returns:
h (torch.Tensor): Sampled state of the hidden layer.
"""
wx_b = torch.matmul(v, self.w) + self.bh
p_h = torch.sigmoid(wx_b)
h = torch.bernoulli(p_h) # Convert to a binary state
return h, p_h
def sample_v(self, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Samples the state of the visible layer.
Args:
h (torch.Tensor): State of the hidden layer.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
v (torch.Tensor): Sampled state of the visible layer.
p_v (torch.Tensor): Sampling probability of the visible layer.
"""
wh_b = torch.matmul(h, self.w.t()) + self.bv
p_v = torch.sigmoid(wh_b)
v = torch.bernoulli(p_v)
        return v, p_v  # p_v lets callers use probabilities instead of binary samples
def free_energy(self, v: torch.Tensor) -> torch.Tensor:
"""Calculates free energy.
Args:
v (torch.Tensor): Input state of the visible layer.
Returns:
free_energy (torch.Tensor): Free energy.
"""
vbias_term = torch.matmul(v, self.bv)
wx_b = torch.matmul(v, self.w) + self.bh
hidden_term = torch.sum(self.softplus(wx_b), dim=1)
free_energy = -hidden_term - vbias_term
return free_energy
@override
def forward(self, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Reconstructs the visible layer state from the input visible layer state.
        Implements one step of Gibbs sampling (v -> h -> v').
Args:
v (torch.Tensor): Input state of the visible layer.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
v_recon (torch.Tensor): Reconstructed state of the visible layer.
p_v_recon (torch.Tensor): Probability of the reconstructed state.
"""
h, _ = self.sample_h(v) # Sample a binary state of the hidden layer
# Reconstruct the visible layer using the hidden layer's binary state
v_recon, p_v_recon = self.sample_v(h)
return v_recon, p_v_recon
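# Minimal usage sketch (illustrative only; shapes assume the MNIST defaults above):
#   rbm = RBM(visible_units=784, hidden_units=256)
#   v = torch.bernoulli(torch.rand(8, 784))  # a fake batch of binary vectors
#   v_recon, p_v_recon = rbm(v)              # one Gibbs step: v -> h -> v'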
def binarize_tensor(x: torch.Tensor) -> torch.Tensor:
"""Binarize a PyTorch tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The binarized tensor.
"""
return (x > 0.5).float()
def flatten_tensor(x: torch.Tensor) -> torch.Tensor:
"""Flattens a PyTorch tensor into a 1D tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The flattened 1D tensor.
"""
return x.flatten()
def get_dataloader(
is_train: bool, batch_size: int
) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]:
"""Get a dataloader for training or validation.
Args:
is_train (bool): a flag to determine training mode
batch_size (int): batch size of data loader
Returns:
dataloader (Dataloader): a dataloader for training/test
"""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(binarize_tensor),
transforms.Lambda(flatten_tensor),
]
)
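    # Pixels are binarized because this RBM models Bernoulli (binary) visible
    # units; ToTensor scales values to [0, 1] first, so a 0.5 threshold applies.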
    if is_train:
dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
else:
dataset = datasets.MNIST(
root="./data", train=False, transform=transform, download=True
)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)
return dataloader
def cd_loss(
rbm: RBM, v0: torch.Tensor, config: TrainingConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Performs Contrastive Divergence algorithm and calculates the loss.
Args:
rbm (RBM): The RBM model.
v0 (torch.Tensor): The visible layer state from the training data.
config (TrainingConfig): The training configuration.
Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            loss (torch.Tensor): The calculated loss value.
            free_energy_0 (torch.Tensor): Mean free energy of the input data (v0).
            free_energy_k (torch.Tensor): Mean free energy of the reconstruction (vk).
"""
# k-step Gibbs sampling
vk = v0
for _ in range(config.cd_steps):
vk, _ = rbm(vk)
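    # Note: torch.bernoulli is non-differentiable, so autograd does not
    # propagate through the sampling chain; vk acts as a detached negative
    # sample, which is exactly what CD-k assumes.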
# Calculate Free Energies and the loss
free_energy_0 = rbm.free_energy(v0).mean()
free_energy_k = rbm.free_energy(vk).mean()
loss: torch.Tensor = free_energy_0 - free_energy_k
return loss, free_energy_0, free_energy_k
def generate_sample(
rbm: RBM, initial_state: torch.Tensor, steps: int = 200
) -> tuple[torch.Tensor, torch.Tensor]:
"""Generates a sample using the RBM.
Args:
rbm (RBM): Trained RBM model.
initial_state (torch.Tensor): Initial state of the visible layer for sampling.
steps (int): Number of sampling steps.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
visible_state (torch.Tensor): State of the generated sample.
p_visible_state (torch.Tensor): Probability of the visible state.
"""
visible_state: torch.Tensor = initial_state.clone()
p_visible_state: torch.Tensor = torch.zeros_like(initial_state)
rbm.eval()
for _ in range(steps):
visible_state, p_visible_state = rbm(visible_state)
return visible_state, p_visible_state
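# Example chain start (a sketch; assumes a trained `rbm` with 784 visible units):
#   init = torch.bernoulli(0.5 * torch.ones(1, 784))
#   sample, p_sample = generate_sample(rbm, init, steps=200)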
def visualize_reconstruction(
rbm: RBM,
data_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
args: Arguments,
) -> None:
"""Reconstructs the input data using the trained RBM and visualizes the results.
Args:
rbm (RBM): The trained RBM model.
data_loader (DataLoader): data loader for evaluation.
args (Arguments): Miscellaneous configuration.
Returns:
None
"""
rbm.eval()
val_data, _ = next(iter(data_loader))
val_data = val_data.to(args.device)
num_images = min(args.num_reconst, val_data.shape[0])
_, axes = plt.subplots(2, num_images, figsize=(2 * num_images, 4))
if num_images == 1:
        axes = np.array(axes).reshape(2, -1)  # Ensure a 2-D axes array for one image
with torch.no_grad():
for i in range(num_images):
original_image = val_data[i].reshape(28, 28).cpu().detach().numpy()
reconst_data, _ = generate_sample(rbm, val_data[i].unsqueeze(0), 1)
reconst_image = reconst_data[0].reshape(28, 28).cpu().detach().numpy()
axes[0, i].imshow(original_image, cmap="gray")
axes[0, i].set_title("")
axes[1, i].imshow(reconst_image, cmap="gray")
axes[1, i].set_title("")
axes[0, i].axis("off")
axes[1, i].axis("off")
plt.suptitle("Test Raw Data (Top), Reconstruction Results (Bottom)")
plt.tight_layout()
plt.show()
def main() -> None:
"""Perform training."""
args, rbm_config, training_config = parse_arguments()
train_loader = get_dataloader(True, training_config.batch_size)
# Model and optimizer initialization
model = RBM(rbm_config.visible_units, rbm_config.hidden_units).to(args.device)
optimizer = optim.Adam(model.parameters(), lr=training_config.learning_rate)
writer = SummaryWriter(log_dir=args.log_dir)
# Training loop
loss_dict: dict[str, list[float]] = {}
global_step = 0
for epoch in range(training_config.epochs):
loss_dict["energy_v0"] = []
loss_dict["energy_vk"] = []
loss_dict["loss"] = []
for data, _ in train_loader:
global_step += 1
data = data.view(-1, rbm_config.visible_units).to(args.device)
optimizer.zero_grad()
loss, energy_v0, energy_vk = cd_loss(model, data, training_config)
loss.backward()
optimizer.step()
loss_dict["loss"].append(loss.item())
loss_dict["energy_v0"].append(energy_v0.item())
loss_dict["energy_vk"].append(energy_vk.item())
writer.add_scalar("Loss/Free Energy (v0)", energy_v0.item(), global_step)
writer.add_scalar("Loss/Free Energy (vk)", energy_vk.item(), global_step)
writer.add_scalar("Loss/Loss Function", loss.item(), global_step)
print(
f"Epoch: {epoch+1} Average Loss: {np.average(loss_dict['loss']):.4f}, "
+ f"Free Energy (v0): {np.average(loss_dict['energy_v0']):.4f}, "
+ f"Free Energy (vk): {np.average(loss_dict['energy_vk']):.4f}"
)
# Visualize reconstruction
test_loader = get_dataloader(False, args.batch_size_val)
visualize_reconstruction(model, test_loader, args)
# Save parameters
os.makedirs(args.model_dir, exist_ok=True)
torch.save(model.state_dict(), f=os.path.join(args.model_dir, args.model_file))
print("Training finished.")
if __name__ == "__main__":
main()
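# Example invocation (the filename is hypothetical; flags are defined in
# parse_arguments above):
#   python rbm_mnist.py --epochs 10 --hidden_units 256 --cd_steps 1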