Created
November 12, 2023 19:02
-
-
Save p-i-/89f7386a9b9ea180015b6d1da83d607a to your computer and use it in GitHub Desktop.
Simple Variational Autoencoder over MNIST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"! pip install -q numpy torch torchvision tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"class Encoder(nn.Module):\n", | |
" def __init__(self, input_dim, hidden_dim, latent_dim):\n", | |
" super().__init__()\n", | |
" # Define layers for the encoder\n", | |
" # Example: a couple of linear layers with activation\n", | |
" self.fc1 = nn.Linear(input_dim, hidden_dim)\n", | |
"\n", | |
" self.fc_mu = nn.Linear(hidden_dim, latent_dim)\n", | |
" self.fc_logvar = nn.Linear(hidden_dim, latent_dim)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" # Apply layers to input\n", | |
" h = F.relu(self.fc1(x))\n", | |
" mu = self.fc_mu(h)\n", | |
" log_var = self.fc_logvar(h)\n", | |
" return mu, log_var\n", | |
"\n", | |
"class Decoder(nn.Module):\n", | |
" def __init__(self, latent_dim, hidden_dim, output_dim):\n", | |
" super().__init__()\n", | |
" # Define layers for the decoder\n", | |
" # Example: reverse of encoder\n", | |
" self.fc1 = nn.Linear(latent_dim, hidden_dim)\n", | |
" self.fc2 = nn.Linear(hidden_dim, output_dim)\n", | |
"\n", | |
" def forward(self, z):\n", | |
" # Apply layers to latent variable\n", | |
" h = F.relu(self.fc1(z))\n", | |
" reconstruction = torch.sigmoid(self.fc2(h))\n", | |
" return reconstruction\n", | |
"\n", | |
"class VAE(nn.Module):\n", | |
" def __init__(self, input_dim, hidden_dim, latent_dim):\n", | |
" super().__init__()\n", | |
" self.encoder = Encoder(input_dim, hidden_dim, latent_dim)\n", | |
" self.decoder = Decoder(latent_dim, hidden_dim, output_dim=input_dim)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" # Encode input to get mu and log_var\n", | |
" mu, log_var = self.encoder(x)\n", | |
" # Sample z from the distribution\n", | |
" std = torch.exp(0.5 * log_var)\n", | |
" eps = torch.randn_like(std)\n", | |
" z = mu + eps * std\n", | |
" # Decode z to reconstruct x\n", | |
" reconstructed_x = self.decoder(z)\n", | |
" return reconstructed_x, mu, log_var\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torchvision import datasets, transforms\n", | |
"from torch.utils.data import DataLoader\n", | |
"\n", | |
"# Define transformations for MNIST data\n", | |
"transform = transforms.Compose([\n", | |
" transforms.ToTensor(), # Convert images to PyTorch tensors\n", | |
" transforms.Normalize((0.5,), (0.5,)) # Normalize the images\n", | |
"])\n", | |
"\n", | |
"# Download and load the MNIST training data\n", | |
"train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n", | |
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", | |
"\n", | |
"# Optionally, download and load the MNIST test data\n", | |
"test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n", | |
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vae = VAE(input_dim=28*28, hidden_dim=128, latent_dim=16)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.optim as optim\n", | |
"\n", | |
"# Define the learning rate\n", | |
"learning_rate = 0.001\n", | |
"\n", | |
"# Create the optimizer, using the Adam algorithm in this case\n", | |
"optimizer = optim.Adam(vae.parameters(), lr=learning_rate)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [1/10]: 100%|██████████| 938/938 [00:04<00:00, 211.46it/s, kl_div=66.4, recon_loss=2.33e+4] \n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 1\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [2/10]: 100%|██████████| 938/938 [00:04<00:00, 220.60it/s, kl_div=79.3, recon_loss=2.32e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 2\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [3/10]: 100%|██████████| 938/938 [00:04<00:00, 219.60it/s, kl_div=166, recon_loss=2.26e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 3\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [4/10]: 100%|██████████| 938/938 [00:04<00:00, 217.36it/s, kl_div=187, recon_loss=2.26e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 4\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [5/10]: 100%|██████████| 938/938 [00:04<00:00, 218.44it/s, kl_div=244, recon_loss=2.24e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 5\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [6/10]: 100%|██████████| 938/938 [00:04<00:00, 220.89it/s, kl_div=264, recon_loss=2.25e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 6\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [7/10]: 100%|██████████| 938/938 [00:04<00:00, 222.55it/s, kl_div=298, recon_loss=2.24e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 7\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [8/10]: 100%|██████████| 938/938 [00:04<00:00, 217.78it/s, kl_div=295, recon_loss=2.23e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 8\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [9/10]: 100%|██████████| 938/938 [00:04<00:00, 217.70it/s, kl_div=295, recon_loss=2.25e+4]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 9\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [10/10]: 100%|██████████| 938/938 [00:04<00:00, 217.63it/s, kl_div=317, recon_loss=2.23e+4]\n" | |
] | |
} | |
], | |
"source": [ | |
"def vae_loss(reconstructed_x, x, mu, log_var):\n", | |
" # Reconstruction loss (e.g., MSE or Binary Cross Entropy)\n", | |
" recon_loss = F.mse_loss(reconstructed_x, x, reduction='sum')\n", | |
" # KL divergence\n", | |
" kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n", | |
" return recon_loss, kl_div\n", | |
"\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"NEPOCH = 10\n", | |
"for epoch in range(NEPOCH):\n", | |
" print('Epoch:', epoch)\n", | |
" # Wrap the train_loader with tqdm for a progress bar\n", | |
" train_loop = tqdm(train_loader, leave=True)\n", | |
" for batch in train_loop:\n", | |
" x, _ = batch # x: 64 1 28 28\n", | |
"\n", | |
" # Reshape the batch from [64, 1, 28, 28] to [64, 784]\n", | |
" x = x.view(x.size(0), -1) # x.size(0) is the batch size, -1 will flatten the remaining dimensions\n", | |
"\n", | |
" reconstructed_x, mu, log_var = vae(x)\n", | |
" recon_loss, kl_div = vae_loss(reconstructed_x, x, mu, log_var)\n", | |
" loss = recon_loss + kl_div\n", | |
"\n", | |
" # Update model parameters\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" # Update tqdm's description with the current loss\n", | |
" train_loop.set_description(f\"Epoch [{epoch+1}/{NEPOCH}]\")\n", | |
" train_loop.set_postfix(recon_loss=recon_loss.item(), kl_div=kl_div.item())\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment