# Associative Compression Networks
# Gist by @Kaixhin, created November 5, 2018
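# A single-file PyTorch implementation sketch: an encoder maps each image to a deterministic
# code c (with fixed unit encoding variance), a codebook C stores the current code of every
# training example, a prior network is conditioned on a code drawn from the K nearest
# neighbours of c in C, and a decoder reconstructs the image from a sample of that conditional
# prior; the loss is reconstruction error plus KL(N(c, I) || prior).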
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

class Encoder(nn.Module):
  def __init__(self, latent_size):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 8, 5, stride=2, padding=2, bias=False) # 1x28x28 -> 8x14x14
    self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=2, bias=False) # 8x14x14 -> 16x8x8
    self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False) # 16x8x8 -> 32x4x4
    self.fc_c = nn.Linear(32 * 4 * 4, latent_size)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = x.view(-1, 32 * 4 * 4)
    c = self.fc_c(x) # Only produces a deterministic encoding; the encoding variance is fixed (see log_ones below)
    return c

class Prior(nn.Module):
  def __init__(self, latent_size):
    super().__init__()
    self.fc1 = nn.Linear(latent_size, 256)
    self.fc_mu = nn.Linear(256, latent_size)
    self.fc_log_var = nn.Linear(256, latent_size)

  def forward(self, c):
    x = torch.tanh(self.fc1(c))
    mu, log_var = self.fc_mu(x), None
    z = mu
    if self.training: # Sample via the reparameterisation trick during training; use the mean otherwise
      log_var = self.fc_log_var(x)
      z = z + log_var.mul(0.5).exp() * torch.randn_like(z)
    return z, mu, log_var

class Decoder(nn.Module):
  def __init__(self, latent_size):
    super().__init__()
    self.fc_dec = nn.Linear(latent_size, 32 * 4 * 4)
    self.conv4 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1, bias=False) # 32x4x4 -> 16x8x8
    self.conv5 = nn.ConvTranspose2d(16, 8, 3, stride=2, padding=2, output_padding=1, bias=False) # 16x8x8 -> 8x14x14
    self.conv6 = nn.ConvTranspose2d(8, 1, 5, stride=2, padding=2, output_padding=1) # 8x14x14 -> 1x28x28

  def forward(self, z):
    x = self.fc_dec(z)
    x = x.view(-1, 32, 4, 4)
    x = F.relu(self.conv4(x))
    x = F.relu(self.conv5(x))
    return self.conv6(x) # Returns logits; apply a sigmoid for pixel probabilities

def kl_normal(mu_0, log_var_0, mu_1, log_var_1):
  # KL(N(mu_0, exp(log_var_0)) || N(mu_1, exp(log_var_1))) for diagonal Gaussians, summed over dimensions, averaged over the batch
  kl = (log_var_0 - log_var_1).exp() + (mu_0 - mu_1) ** 2 / log_var_1.exp() + log_var_1 - log_var_0 - 1
  return 0.5 * kl.sum(1).mean()
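# For reference, the standard per-dimension closed form implemented above:
#   KL(N(mu_0, s_0^2) || N(mu_1, s_1^2)) = log(s_1 / s_0) + (s_0^2 + (mu_0 - mu_1)^2) / (2 * s_1^2) - 1 / 2
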
latent_size = 16 # Latent/code size
k = 5 # Number of nearest neighbours
batch_size = 128
epochs = 10
train_data = datasets.MNIST(os.path.join(os.path.expanduser('~'), '.torch', 'datasets', 'mnist'), transform=transforms.ToTensor(), download=True)
train_dataloader = DataLoader(train_data, batch_size=batch_size, drop_last=True, num_workers=4) # No shuffling or partial final batch, so batch positions map directly onto codebook indices
C = torch.randn(len(train_data), latent_size).cuda() # Associative dataset (codebook): one code per training example
log_ones = torch.zeros(batch_size, latent_size).cuda() # Fixed encoding log-variance (log of unit variance)
encoder = Encoder(latent_size).cuda()
prior = Prior(latent_size).cuda()
decoder = Decoder(latent_size).cuda()
optimiser = optim.Adam(list(encoder.parameters()) + list(prior.parameters()) + list(decoder.parameters()), lr=5e-3)

for epoch in range(epochs):
  indices = torch.arange(batch_size)
  for i, (x, _) in enumerate(train_dataloader):
    x = x.cuda()
    c = encoder(x) # Get codes for the batch
    with torch.no_grad():
      C[indices, :] = c # Update C with the new codes
      # Squared Euclidean distances from each batch code to every code in C, sorted ascending
      l2_dists, l2_inds = (C.expand(batch_size, C.size(0), C.size(1)) - c.unsqueeze(1)).pow(2).sum(2).sort(dim=1, descending=False)
      knn_inds = l2_inds[:, 1:k + 1] # Drop the element itself (distance 0)
      # TODO: Make sure that neighbouring codes are only used once per pass through the dataset
      c_hat = C[knn_inds[range(batch_size), torch.randint(k, (batch_size,)).long()]] # Pick c_hat uniformly at random from KNN(x)
    z, mu, log_var = prior(c_hat) # Sample from the prior network conditioned on the neighbouring code
    x_hat = decoder(z) # Decode sample
    recon_loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction='sum') # Reconstruction loss
    kld = kl_normal(c, log_ones, mu, log_var) # KL divergence between variational posterior and conditional prior
    loss = (recon_loss + kld) / batch_size
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    indices += batch_size
    if i % 100 == 0:
      print(recon_loss.item(), kld.item())

  # Reconstruction
  save_image(torch.sigmoid(x_hat[:64].cpu()), 'x_hat_%d.png' % epoch)

  # Daydream
  with torch.no_grad():
    xs = [torch.sigmoid(x_hat[:8])]
    for _ in range(7):
      c = encoder(xs[-1])
      z, _, _ = prior(c)
      x_hat = torch.sigmoid(decoder(z))
      xs.append(x_hat)
  save_image(torch.cat(xs, 0).cpu(), 'z_%d.png' % epoch)
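
# A minimal sketch of one way to draw samples after training: condition the prior on randomly
# chosen stored codes from C and decode; the modules are still in training mode here, so the
# prior draws a sample rather than returning its mean (output filename below is arbitrary).
with torch.no_grad():
  seed_codes = C[torch.randint(len(train_data), (64,))] # 64 random codebook entries
  z, _, _ = prior(seed_codes)
  save_image(torch.sigmoid(decoder(z)).cpu(), 'samples.png')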