@p-i-
Created November 12, 2023 19:02

Simple Variational Autoencoder over MNIST
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"! pip install -q numpy torch torchvision tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Encoder(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, latent_dim):\n",
" super().__init__()\n",
" # Define layers for the encoder\n",
" # Example: a couple of linear layers with activation\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
"\n",
" self.fc_mu = nn.Linear(hidden_dim, latent_dim)\n",
" self.fc_logvar = nn.Linear(hidden_dim, latent_dim)\n",
"\n",
" def forward(self, x):\n",
" # Apply layers to input\n",
" h = F.relu(self.fc1(x))\n",
" mu = self.fc_mu(h)\n",
" log_var = self.fc_logvar(h)\n",
" return mu, log_var\n",
"\n",
"class Decoder(nn.Module):\n",
" def __init__(self, latent_dim, hidden_dim, output_dim):\n",
" super().__init__()\n",
" # Define layers for the decoder\n",
" # Example: reverse of encoder\n",
" self.fc1 = nn.Linear(latent_dim, hidden_dim)\n",
" self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
"\n",
" def forward(self, z):\n",
" # Apply layers to latent variable\n",
" h = F.relu(self.fc1(z))\n",
" reconstruction = torch.sigmoid(self.fc2(h))\n",
" return reconstruction\n",
"\n",
"class VAE(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, latent_dim):\n",
" super().__init__()\n",
" self.encoder = Encoder(input_dim, hidden_dim, latent_dim)\n",
" self.decoder = Decoder(latent_dim, hidden_dim, output_dim=input_dim)\n",
"\n",
" def forward(self, x):\n",
" # Encode input to get mu and log_var\n",
" mu, log_var = self.encoder(x)\n",
" # Sample z from the distribution\n",
" std = torch.exp(0.5 * log_var)\n",
" eps = torch.randn_like(std)\n",
" z = mu + eps * std\n",
" # Decode z to reconstruct x\n",
" reconstructed_x = self.decoder(z)\n",
" return reconstructed_x, mu, log_var\n"
]
},
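{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch added for illustration, not part of the original run): push a random batch through an untrained VAE and confirm the tensor shapes. The dimensions match the model instantiated below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical shape check on an untrained model\n",
"_vae = VAE(input_dim=28*28, hidden_dim=128, latent_dim=16)\n",
"_x = torch.rand(8, 28*28)  # 8 fake 'images' with pixels in [0, 1]\n",
"_recon, _mu, _logvar = _vae(_x)\n",
"print(_recon.shape, _mu.shape, _logvar.shape)  # expect [8, 784], [8, 16], [8, 16]"
]
},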
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Define transformations for MNIST data\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(), # Convert images to PyTorch tensors\n",
" transforms.Normalize((0.5,), (0.5,)) # Normalize the images\n",
"])\n",
"\n",
"# Download and load the MNIST training data\n",
"train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"\n",
"# Optionally, download and load the MNIST test data\n",
"test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
]
},
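{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (added as a sketch): confirm the transform leaves pixels in [0, 1], so that comparing them against the decoder's sigmoid output is well-posed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"images, labels = next(iter(train_loader))\n",
"print(images.shape, images.min().item(), images.max().item())  # [64, 1, 28, 28], 0.0, 1.0"
]
},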
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"vae = VAE(input_dim=28*28, hidden_dim=128, latent_dim=16)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"# Define the learning rate\n",
"learning_rate = 0.001\n",
"\n",
"# Create the optimizer, using the Adam algorithm in this case\n",
"optimizer = optim.Adam(vae.parameters(), lr=learning_rate)"
]
},
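{
"cell_type": "markdown",
"metadata": {},
"source": [
"The objective below is the negative ELBO: a reconstruction term plus the closed-form KL divergence between the encoder's Gaussian $\\mathcal{N}(\\mu, \\sigma^2)$ and the standard-normal prior,\n",
"\n",
"$$D_{\\mathrm{KL}} = -\\tfrac{1}{2}\\sum_j \\left(1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right),$$\n",
"\n",
"which is exactly what `vae_loss` computes, with `log_var` standing in for $\\log \\sigma^2$."
]
},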
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [1/10]: 100%|██████████| 938/938 [00:04<00:00, 211.46it/s, kl_div=66.4, recon_loss=2.33e+4] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [2/10]: 100%|██████████| 938/938 [00:04<00:00, 220.60it/s, kl_div=79.3, recon_loss=2.32e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [3/10]: 100%|██████████| 938/938 [00:04<00:00, 219.60it/s, kl_div=166, recon_loss=2.26e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [4/10]: 100%|██████████| 938/938 [00:04<00:00, 217.36it/s, kl_div=187, recon_loss=2.26e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [5/10]: 100%|██████████| 938/938 [00:04<00:00, 218.44it/s, kl_div=244, recon_loss=2.24e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [6/10]: 100%|██████████| 938/938 [00:04<00:00, 220.89it/s, kl_div=264, recon_loss=2.25e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 6\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [7/10]: 100%|██████████| 938/938 [00:04<00:00, 222.55it/s, kl_div=298, recon_loss=2.24e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 7\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [8/10]: 100%|██████████| 938/938 [00:04<00:00, 217.78it/s, kl_div=295, recon_loss=2.23e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [9/10]: 100%|██████████| 938/938 [00:04<00:00, 217.70it/s, kl_div=295, recon_loss=2.25e+4]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch [10/10]: 100%|██████████| 938/938 [00:04<00:00, 217.63it/s, kl_div=317, recon_loss=2.23e+4]\n"
]
}
],
"source": [
"def vae_loss(reconstructed_x, x, mu, log_var):\n",
" # Reconstruction loss (e.g., MSE or Binary Cross Entropy)\n",
" recon_loss = F.mse_loss(reconstructed_x, x, reduction='sum')\n",
" # KL divergence\n",
" kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
" return recon_loss, kl_div\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"NEPOCH = 10\n",
"for epoch in range(NEPOCH):\n",
" print('Epoch:', epoch)\n",
" # Wrap the train_loader with tqdm for a progress bar\n",
" train_loop = tqdm(train_loader, leave=True)\n",
" for batch in train_loop:\n",
" x, _ = batch # x: 64 1 28 28\n",
"\n",
" # Reshape the batch from [64, 1, 28, 28] to [64, 784]\n",
" x = x.view(x.size(0), -1) # x.size(0) is the batch size, -1 will flatten the remaining dimensions\n",
"\n",
" reconstructed_x, mu, log_var = vae(x)\n",
" recon_loss, kl_div = vae_loss(reconstructed_x, x, mu, log_var)\n",
" loss = recon_loss + kl_div\n",
"\n",
" # Update model parameters\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # Update tqdm's description with the current loss\n",
" train_loop.set_description(f\"Epoch [{epoch+1}/{NEPOCH}]\")\n",
" train_loop.set_postfix(recon_loss=recon_loss.item(), kl_div=kl_div.item())\n",
"\n"
]
},
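{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once trained, the decoder doubles as a generator: sample $z \\sim \\mathcal{N}(0, I)$ and decode. A minimal sketch (not part of the original gist); visualizing the samples would need an extra dependency such as matplotlib."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate new digits by sampling the latent prior and decoding\n",
"with torch.no_grad():\n",
"    z = torch.randn(16, 16)  # 16 samples, latent_dim=16\n",
"    samples = vae.decoder(z).view(-1, 28, 28)\n",
"print(samples.shape)  # [16, 28, 28]; pixel values in [0, 1] from the sigmoid"
]
},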
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}