Skip to content

Instantly share code, notes, and snippets.

View SannaPersson's full-sized avatar

Sanna Persson SannaPersson

View GitHub Profile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
class VariationalAutoEncoder(nn.Module):
    """Variational autoencoder for flattened images (encoder shown here).

    NOTE(review): this snippet is truncated — the decoder layers and
    ``forward`` are not visible in this fragment; confirm against the
    full source.

    Args:
        input_dim: Flattened input size (e.g. 784 for 28x28 MNIST).
        z_dim: Dimensionality of the latent code.
        h_dim: Width of the hidden layer (default 200).
    """

    def __init__(self, input_dim, z_dim, h_dim=200):
        super().__init__()
        # encoder: image vector -> hidden representation
        self.img_2hid = nn.Linear(input_dim, h_dim)
        # Two heads: one for mu and one for stds. Note how we only output
        # the diagonal values of the covariance matrix, i.e. we assume the
        # latent dimensions are conditionally independent.
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        # Fix: the comment above promises a std head, but only the mu head
        # was defined in the original fragment.
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
# Runtime configuration and training hyperparameters.
# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

INPUT_DIM = 784     # 28 x 28 MNIST image, flattened
Z_DIM = 20          # size of the latent code
H_DIM = 200         # hidden-layer width of the encoder/decoder
NUM_EPOCHS = 10     # full passes over the training set
BATCH_SIZE = 32     # mini-batch size
LR_RATE = 3e-4      # Adam learning rate
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    """Run the VAE training loop for `num_epochs` epochs.

    NOTE(review): this fragment is truncated — the loss computation,
    backward pass and optimizer step are not shown here. It also reads a
    module-level `train_loader` that is not defined in this file;
    presumably a DataLoader over MNIST — confirm against the full source.
    """
    # Start training
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, y) in loop:
            # Forward pass: flatten each image batch to (N, INPUT_DIM)
            # vectors on the chosen device before feeding the model.
            x = x.to(device).view(-1, INPUT_DIM)
            x_reconst, mu, sigma = model(x)
# Initialize model, optimizer, loss
model = VariationalAutoEncoder(INPUT_DIM, Z_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
# Summed per-pixel binary cross-entropy — the VAE reconstruction term.
# NOTE(review): BCELoss expects model outputs in [0, 1]; presumably the
# decoder ends in a sigmoid — confirm against the full model definition.
loss_fn = nn.BCELoss(reduction="sum")
# Run training
train(NUM_EPOCHS, model, optimizer, loss_fn)
def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.
    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    # NOTE(review): the function body is missing in this fragment — only
    # the docstring is shown; see the full source for the implementation.
@SannaPersson
SannaPersson / f0beb25f-1b7a-4666-af7a-84daca6a8d57.py
Created September 13, 2022 14:03
variational_autoencoder0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
@SannaPersson
SannaPersson / 4b81b90e-68a1-429b-b4e4-4ba26349315a.py
Created September 13, 2022 14:03
variational_autoencoder1
class VariationalAutoEncoder(nn.Module):
    # Duplicate of the class shown earlier on this page (fragment
    # "variational_autoencoder1"); truncated — no sigma head, decoder
    # or forward() visible here.
    def __init__(self, input_dim, z_dim, h_dim=200):
        """Build the encoder layers of the VAE.

        Args:
            input_dim: flattened input size (e.g. 784 for 28x28 MNIST)
            z_dim: latent dimensionality
            h_dim: hidden-layer width (default 200)
        """
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        # one for mu and one for stds, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # the pixels are conditionally independent
        self.hid_2mu = nn.Linear(h_dim, z_dim)
@SannaPersson
SannaPersson / 81444f49-7158-4617-83c2-52aaf52c30d0.py
Created September 13, 2022 14:03
variational_autoencoder2
# Configuration
# (duplicate of the configuration fragment shown earlier on this page)
# Use the GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 784   # 28 x 28 MNIST image, flattened
Z_DIM = 20        # latent dimensionality
H_DIM = 200       # hidden-layer width
NUM_EPOCHS = 10   # full passes over the training set
BATCH_SIZE = 32   # mini-batch size
LR_RATE = 3e-4    # Adam learning rate
@SannaPersson
SannaPersson / 57612d23-edc6-4d33-a332-ed8cbb5314bc.py
Created September 13, 2022 14:03
variational_autoencoder3
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    """Run the VAE training loop for `num_epochs` epochs.

    NOTE(review): duplicate of the earlier fragment, equally truncated —
    loss computation, backward pass and optimizer step are not shown, and
    the module-level `train_loader` it reads is not defined in this file.
    """
    # Start training
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, y) in loop:
            # Forward pass: flatten each image batch to (N, INPUT_DIM)
            # vectors on the chosen device before feeding the model.
            x = x.to(device).view(-1, INPUT_DIM)
            x_reconst, mu, sigma = model(x)