@airalcorn2
Last active August 7, 2022 14:57
Minimal example demonstrating how a variational autoencoder frequently generates unrealistic samples when optimized to learn a simple 2D bimodal distribution.
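Why a VAE can fail here, roughly: given a latent z, the decoder below emits the two coordinates through separate softmaxes, so they are conditionally independent; if a prior sample decodes to logits that sit between the two modes, the coordinates can disagree, producing pairs like (0, 1) that never occur in the data. The autoregressive RNN baseline instead samples the second coordinate conditioned on the first. As a back-of-the-envelope sketch (separate from the script below, assuming fully independent coordinates with the correct 0.5 / 0.5 marginals):

# Not part of the script: if each coordinate is drawn independently from its marginal,
# roughly half of the generated pairs mix the two modes, even though the data only
# contains (0, 0) and (1, 1).
import torch

pairs = torch.randint(0, 2, (10_000, 2))  # independent fair draws per coordinate
mixed = (pairs[:, 0] != pairs[:, 1]).float().mean().item()
print(f"mixed %: {100 * mixed:.1f}")  # roughly 50

The "different %" printed by the script measures exactly this kind of impossible pair.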
# Adapted from: https://github.com/pytorch/examples/blob/main/vae/main.py.
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
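

# Encoder: 2 -> 50 -> a 20-D diagonal Gaussian (fc21 is the mean head, fc22 the log-variance
# head). Decoder: 20 -> 50 -> 4 logits, which the loss below reshapes into two separate
# 2-way categoricals, one per output coordinate.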
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 50)
        self.fc21 = nn.Linear(50, 20)
        self.fc22 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 50)
        self.fc4 = nn.Linear(50, 4)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return (self.fc21(h1), self.fc22(h1))

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return self.fc4(h3)

    def forward(self, x):
        (mu, logvar) = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return (self.decode(z), mu, logvar)
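

# Autoregressive baseline: at step 0 the RNN receives a zero "start" input and predicts the
# first coordinate; at step 1 it receives the first coordinate (encoded as -1 / +1) and
# predicts the second, so the second coordinate can depend on the first one sampled.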
class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(1, 25)
        self.linear = nn.Linear(25, 2)

    def forward(self, x):
        (hiddens, _) = self.rnn(x)
        return self.linear(hiddens)
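

# NLL is the summed categorical cross-entropy over both coordinates; KLD is the closed-form
# KL divergence between the diagonal-Gaussian posterior and the N(0, I) prior. `is_vae` is a
# module-level flag set in the __main__ block below; when it is False (the RNN experiment),
# only the NLL term matters and KLD stays 0.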
def loss_function(recon_x, y, mu=None, logvar=None):
    # Reconstruction + KL divergence losses summed over all elements and batch.
    recon_x = recon_x.reshape(-1, 2)
    log_probs = F.log_softmax(recon_x, dim=1)
    NLL = F.nll_loss(log_probs, y, reduction="sum")
    KLD = 0
    if is_vae:
        # See Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014.
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return (NLL, KLD)
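

# Experiment 1: fit the VAE to the two-point dataset {(-1, -1), (1, 1)}, whose coordinates
# correspond to class indices (0, 0) and (1, 1), then sample from the prior and count how
# often the two sampled coordinates disagree (a pattern that never occurs in the data).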
if __name__ == "__main__":
    # Fall back to the CPU when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VAE().to(device)
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params}")
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # The whole dataset: the two points (-1, -1) and (1, 1). The targets are the matching
    # per-coordinate class indices, (0, 0) and (1, 1), flattened to [0, 0, 1, 1].
    X = torch.Tensor([[-1, -1], [1, 1]]).to(device)
    y = torch.LongTensor([0, 0, 1, 1]).to(device)

    epochs = 100
    updates = 1024
    is_vae = True
    best_train_loss = float("inf")
    patience = 5
    no_improvement = 0
    lr_drops = 0

    for epoch in range(1, epochs + 1):
        model.train()
        NLL_total = 0
        KLD_total = 0
        train_loss = 0
        for _ in range(updates):
            optimizer.zero_grad()
            (recon_batch, mu, logvar) = model(X)
            (NLL, KLD) = loss_function(recon_batch, y, mu, logvar)
            loss = NLL + KLD
            loss.backward()
            NLL_total += NLL.item() / len(X)
            KLD_total += KLD.item() / len(X)
            train_loss += loss.item() / len(X)
            optimizer.step()

        NLL_total /= updates
        KLD_total /= updates
        train_loss /= updates
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            no_improvement = 0
        else:
            no_improvement += 1
            if no_improvement == patience:
                lr_drops += 1
                if lr_drops == 2:
                    break
                print("Reducing learning rate.")
                no_improvement = 0
                for g in optimizer.param_groups:
                    g["lr"] *= 0.1

        print(f"====> Epoch: {epoch} Average loss: {train_loss:.4f}")
        print(f"====> Epoch: {epoch} Best average loss: {best_train_loss:.4f}")
        print(f"====> Epoch: {epoch} Average NLL: {NLL_total:.4f}")
        print(f"====> Epoch: {epoch} Average KLD: {KLD_total:.4f}")

    # Sample from the VAE: draw z ~ N(0, I), decode to logits, and sample each coordinate
    # from its own 2-way categorical.
    with torch.no_grad():
        samples = model.decode(torch.randn(500, 20).to(device)).cpu().reshape(-1, 2)
        probs = torch.softmax(samples, dim=1)
        samples = torch.multinomial(probs, 1).reshape(-1, 2).numpy()
        both_zeros = ((samples[:, 0] == 0) & (samples[:, 1] == 0)).sum()
        print(f"both_zeros %: {100 * both_zeros / len(samples)}")
        both_ones = ((samples[:, 0] == 1) & (samples[:, 1] == 1)).sum()
        print(f"both_ones %: {100 * both_ones / len(samples)}")
        different = (samples[:, 0] != samples[:, 1]).sum()
        print(f"different %: {100 * different / len(samples)}")
    model = RNN().to(device)
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params}")
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    X = torch.Tensor([[0, -1], [0, 1]]).to(device).T.unsqueeze(2)
    is_vae = False
    best_train_loss = float("inf")
    no_improvement = 0
    lr_drops = 0

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        for _ in range(updates):
            optimizer.zero_grad()
            recon_batch = model(X).permute(1, 0, 2)
            loss = loss_function(recon_batch, y)[0]
            loss.backward()
            train_loss += loss.item() / len(X)
            optimizer.step()

        train_loss /= updates
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            no_improvement = 0
        else:
            no_improvement += 1
            if no_improvement == patience:
                lr_drops += 1
                if lr_drops == 2:
                    break
                print("Reducing learning rate.")
                no_improvement = 0
                for g in optimizer.param_groups:
                    g["lr"] *= 0.1

        print(f"====> Epoch: {epoch} Average loss: {train_loss:.4f}")
        print(f"====> Epoch: {epoch} Best average loss: {best_train_loss:.4f}")

    # Sample from the RNN one coordinate at a time, feeding the sampled first coordinate
    # (mapped to -1 / +1) back in as the input for the second step.
    with torch.no_grad():
        samples = []
        for sample in range(500):
            X = torch.zeros(2, 1).to(device)
            sample_vals = []
            for step in range(2):
                preds = model(X.unsqueeze(0).permute(1, 0, 2))
                probs = torch.softmax(preds.permute(1, 0, 2).squeeze(0)[step], dim=0)
                sample_val = torch.multinomial(probs, 1)
                sample_vals.append(sample_val.item())
                if step == 0:
                    X[step + 1] = 2 * sample_val - 1

            samples.append(sample_vals)

        samples = torch.Tensor(samples).cpu().numpy()
        both_zeros = ((samples[:, 0] == 0) & (samples[:, 1] == 0)).sum()
        print(f"both_zeros %: {100 * both_zeros / len(samples)}")
        both_ones = ((samples[:, 0] == 1) & (samples[:, 1] == 1)).sum()
        print(f"both_ones %: {100 * both_ones / len(samples)}")
        different = (samples[:, 0] != samples[:, 1]).sum()
        print(f"different %: {100 * different / len(samples)}")