Variational Autoencoders Will Never Work

So you want to generate images with neural networks. You're in luck! VAEs are here to save the day. They're simple to implement, they generate images in one inference step (unlike those awful slow autoregressive models), and (most importantly) VAEs are 🚀🎉🎂🥳 theoretically grounded 🚀🎉🎂🥳 (unlike those scary GANs - don't look at the GANs)!

The idea

The idea of VAE is so simple, even an AI chatbot could explain it:

  1. Your goal is to train a "decoder" neural network that consumes blobs of random noise from a fixed distribution (like torch.randn(1024)), interprets that noise as decisions about what to generate, and produces corresponding real-looking images. You want to train this network with nice simple image-space MSE loss against your dataset of real images.
  2. You don't have a proper dataset for this decoder yet, because you don't know which input noise blob is supposed to describe each of the real images in your dataset.
  3. So, you simultaneously train a secondary "encoder" that tries to autolabel a torch.randn-looking latent vector from each real image. You end up with two losses - a "reconstruction" (MSE) loss in image space, and a sort of "make it torch.randn-like" (KL divergence) loss in latent space. As the network learns to minimize both losses, it will eventually learn a perfect encoder/decoder to and from the fixed noise distribution, which yields a beautiful single-step generative model.
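For concreteness, here is roughly what the two losses from step 3 look like when written out. This is a framework-free Python sketch: plain lists of floats stand in for tensors, and the function names are hypothetical, not any library's API. It assumes the usual setup, an encoder that outputs a mean `mu` and log-variance `logvar` per latent dimension, the reparameterization trick for sampling, and the standard closed-form KL against N(0, I). The identity encoder/decoder at the bottom is made up purely so the sketch runs.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + std * eps with eps ~ N(0, 1), independently per latent dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def reconstruction_mse(x, x_hat):
    """Image-space reconstruction loss: mean squared error over pixels."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one sample, summed over dims."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def vae_loss(x, x_hat, mu, logvar, kl_weight=1.0):
    """Reconstruction term plus weighted "make it torch.randn-like" term."""
    return reconstruction_mse(x, x_hat) + kl_weight * kl_to_standard_normal(mu, logvar)


# Usage with a hypothetical identity encoder (mu = x, std = 1) and identity decoder.
rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(4)]
mu, logvar = list(x), [0.0] * len(x)
z = reparameterize(mu, logvar, rng)
x_hat = z
print("loss:", vae_loss(x, x_hat, mu, logvar))
```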

Not sold yet? Well, look here, there's some math that shows why VAE is guaranteed to work. Guaranteed! The whole thing's a parameterized probabilistic model of images, and you're optimizing it to maximize the likelihood of the real image distribution - here, I even have a nice \begin{aligned} derivation for the losses. All you need to do is code it up.
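For reference, the standard textbook version of that derivation (the evidence lower bound, or ELBO) goes roughly like this. The notation is ours, not the gist's: q_phi(z | x) is the encoder, p_theta(x | z) the decoder, D_theta the decoder network, and p(z) = N(0, I) is the torch.randn prior. Assuming a Gaussian decoder with fixed variance, the reconstruction term turns into MSE up to a constant, and with a diagonal Gaussian encoder the KL term has the closed form on the last line. Note that the KL is computed separately for each image x.

```latex
\begin{aligned}
\log p_\theta(x)
  &= \log \int p_\theta(x \mid z)\, p(z)\, dz \\
  &= \log \mathbb{E}_{q_\phi(z \mid x)}\!\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] \\
  &\ge \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
     - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x) \,\middle\|\, p(z)\right) \\[4pt]
-\log p_\theta(x \mid z)
  &= \frac{1}{2\sigma^2}\,\lVert x - D_\theta(z) \rVert^2 + \text{const}
  \qquad \text{(decoder } p_\theta(x \mid z) = \mathcal{N}(D_\theta(z), \sigma^2 I)\text{)} \\
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}\sigma_z^2) \,\middle\|\, \mathcal{N}(0, I)\right)
  &= \tfrac{1}{2} \sum_j \left(\mu_j^2 + \sigma_{z,j}^2 - 1 - \log \sigma_{z,j}^2\right)
\end{aligned}
```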

Cracks in the foundation

If you take the VAE and "code it up", and start training, you quickly notice that it doesn't work. Specifically, one of two things is always occurring:

  1. The decoder can reconstruct images perfectly, so the decoder's gradient is 0, so the decoder doesn't learn anything further about decoding images.
  2. The decoder can't reconstruct images perfectly, so the decoder's output is blurry yes yes "noisy"... whatever... point is, the decoder output doesn't look like anything resembling real images. It looks very obviously not like real images.
  3. Your training died because of NaNs or something.

Anyway, that's a bit worrying, but okay, maybe it's still salvageable. For the NaNs we can just spam clipping / epsilons and stuff until it's more or less stable. For the gradient / blurriness issues... maybe we can just adjust the weighting between the KL and reconstruction loss dynamically, so that we prioritize KL minimization early on and then switch to favoring reconstruction at the end? Or we can even set a different KL/rec weighting per-sample and condition the model on that weight, so that the model is always learning compression while still being able to generate sharp samples on command? Or maybe...
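For what it's worth, here is what those band-aids tend to look like in code: a linear KL-weight schedule that favors KL early and reconstruction late, log-uniform per-sample KL weights for the conditioned variant, and log-variance clamping as the "spam clipping" part. Everything here (names, bounds, schedule shape) is a hypothetical plain-Python sketch, not a recommended recipe.

```python
import math
import random


def kl_weight_schedule(step, total_steps, start=1.0, end=0.01):
    """Hypothetical linear schedule: KL weight `start` at step 0, `end` from `total_steps` on.

    High weight early prioritizes KL minimization; low weight late favors reconstruction.
    """
    t = min(max(step / total_steps, 0.0), 1.0)
    return start + (end - start) * t


def sample_per_example_weights(batch_size, rng, lo=1e-3, hi=1.0):
    """Hypothetical per-sample KL weights, log-uniform in [lo, hi].

    In the conditioned variant, each weight would also be fed to the encoder and decoder
    as an extra input, so the compression/sharpness trade-off can be picked at test time.
    """
    return [math.exp(rng.uniform(math.log(lo), math.log(hi))) for _ in range(batch_size)]


def clamp_logvar(logvar, lo=-30.0, hi=20.0):
    """The "spam clipping" part: bound log-variance so exp(logvar) stays finite and nonzero."""
    return [min(max(lv, lo), hi) for lv in logvar]


def stable_kl(mu, logvar):
    """Per-sample KL to N(0, I), computed on clamped log-variances."""
    logvar = clamp_logvar(logvar)
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


for step in (0, 250, 500, 1000, 2000):
    print(step, round(kl_weight_schedule(step, total_steps=1000), 4))
```

Without the clamp, a single runaway `logvar` of 1000 makes `math.exp` raise `OverflowError`; with it, the loss stays finite (if meaningless).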

Anyway, turns out that's a distraction. The actual issue with VAE runs much deeper and cannot be fixed.

Why VAE is broken

The fundamental issue with VAE is that the KL loss implementation flat out doesn't work.

What you want your VAE to learn is a 1:1 mapping between image space and torch.randn space. The KL divergence you want is between the population of encoded latents (all latents - the result of running the encoder on all real images in your dataset) and the population of noise vectors generated by torch.randn. But what everyone actually implements is a KL divergence between the per-sample latent "distribution" and torch.randn - essentially telling the encoder "for each sample, your optimal behavior is to ignore it and always report 0s (zero mean, zero log-variance), since that corresponds exactly to the unit diagonal Gaussian that torch.randn samples from... oh, but also try to minimize the reconstruction term somehow".
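To see the "ignore it, report 0s" incentive concretely: the closed-form per-sample KL for a diagonal Gaussian encoder is minimized at mu = 0, logvar = 0 for every input, so a (hypothetical) encoder that ignores the image scores a perfect 0 on this term, while any encoder that actually transmits the image pays for it. Plain-Python sketch; the toy 8-dimensional "images" and both encoder functions are made up for illustration.

```python
import math
import random


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ) for ONE sample, in nats."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def collapsed_encoder(x):
    """Ignores the image entirely: mu = 0, logvar = 0 (i.e. std = 1) for every input."""
    return [0.0] * len(x), [0.0] * len(x)


def informative_encoder(x, logvar_value=-2.0):
    """Hypothetical encoder that actually transmits x: mu = x, with a fixed smallish std."""
    return list(x), [logvar_value] * len(x)


rng = random.Random(0)
images = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(5)]  # toy 8-dim "images"

for x in images:
    kl_collapsed = kl_to_standard_normal(*collapsed_encoder(x))
    kl_informative = kl_to_standard_normal(*informative_encoder(x))
    print(f"collapsed KL = {kl_collapsed:.3f} nats, informative KL = {kl_informative:.3f} nats")
```

Averaged over the dataset, this per-sample term equals the mutual information between x and z plus the KL between the aggregate latent distribution and N(0, I), so minimizing it penalizes the information the encoder transmits, not just any mismatch between the two populations.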

As practically implemented, the "KL" loss term in VAEs isn't actually aligning the latent / randn distributions... it's only acting as an information bottleneck, forcing the encoder to transmit as few bits as possible, and relying on undefined behavior (sorry, "careful implementation of inductive biases") to make sure that the decoder somehow samples somewhat-plausible outputs from 0-bits-of-information inputs at test time.

The actual KL loss you would want to use for VAEs (the KL between the population of all latents and torch.randn) is fundamentally intractable to hand-code, because your latents are high-dimensional compared to your tiny batch / dataset size.
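To put "high-dimensional compared to your tiny dataset" in numbers, take the crudest possible estimator of the population-level KL: histogram the encoded latents on a grid and compare against the Gaussian. With a hypothetical 10 bins per axis, the number of grid cells is 10^dim, which for the torch.randn(1024) latents from the setup above dwarfs any real dataset (a hypothetical million images here). Smarter estimators help, but the histogram makes the scaling easy to see.

```python
import math


def histogram_cells_log10(latent_dim, bins_per_dim=10):
    """log10 of the number of cells in a grid histogram with `bins_per_dim` bins per latent axis."""
    return latent_dim * math.log10(bins_per_dim)


dataset_size = 10**6  # hypothetical: one million training images

for latent_dim in (1, 2, 10, 1024):
    log_cells = histogram_cells_log10(latent_dim)
    print(f"dim={latent_dim:>4}: ~10^{log_cells:.0f} cells to fill with "
          f"10^{math.log10(dataset_size):.0f} samples")
```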

Not convinced yet? Fix your "image" dataset to be torch.randn(10), set up a 10-dimensional latent space, and hand-initialize your encoder and decoder MLPs to be the identity function (yes, yes, identity for mu - meaning the stds should be 0). At initialization, this is already a perfect, beautiful, 0-KL-divergence, randn-distributed-latents, 0-reconstruction-error-and-sharp-outputs generative model of the data distribution. But the moment you start training this perfect generative model with VAE losses, it will cease to function! You'll start throwing away information to minimize the KL loss term, and get blurry (yes...yes... "noisy"... omg) outputs as the decoder tries to minimize the reconstruction term, and your total loss will never reach 0 no matter how you tune the hyperparameters.
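Here is that experiment as a plain-Python sketch (the sample count, seed, and learning rate are arbitrary choices). One wrinkle: with std exactly 0 the per-sample KL contains -log(0), i.e. it's infinite, so the sketch uses std = 1e-6. That alone says a lot: the "perfect" model gets near-zero reconstruction error and randn-looking latents, yet the KL term scores it as arbitrarily bad, and one gradient step on that term already shrinks mu toward 0 and grows std toward 1, i.e. starts throwing information away.

```python
import math
import random

DIM = 10
N_SAMPLES = 1000
STD = 1e-6  # stand-in for the thought experiment's std = 0 (the KL term is infinite at exactly 0)
LOGVAR = math.log(STD ** 2)


def kl_to_standard_normal(mu, logvar):
    """Per-sample VAE KL term: KL( N(mu, diag(exp(logvar))) || N(0, I) ), in nats."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def kl_gradients(mu, logvar):
    """Analytic gradients of that KL term: d/dmu = mu, d/dlogvar = 0.5 * (exp(logvar) - 1)."""
    return list(mu), [0.5 * (math.exp(lv) - 1.0) for lv in logvar]


rng = random.Random(0)
data = [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(N_SAMPLES)]  # "images" ~ randn(10)

recon_mse = 0.0
kl_total = 0.0
for x in data:
    mu, logvar = list(x), [LOGVAR] * DIM  # identity encoder: mu = x, tiny std
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
    x_hat = z  # identity decoder
    recon_mse += sum((a - b) ** 2 for a, b in zip(x, x_hat)) / DIM
    kl_total += kl_to_standard_normal(mu, logvar)
recon_mse /= N_SAMPLES
mean_kl = kl_total / N_SAMPLES

# Population view: the encoded latents are the data itself, so they already look like randn.
flat = [v for x in data for v in x]
pop_mean = sum(flat) / len(flat)
pop_var = sum((v - pop_mean) ** 2 for v in flat) / len(flat)

# One plain gradient-descent step on the KL term alone, for the first sample.
lr = 0.1
mu0, lv0 = data[0], [LOGVAR] * DIM
g_mu, g_lv = kl_gradients(mu0, lv0)
mu1 = [m - lr * g for m, g in zip(mu0, g_mu)]  # each mu shrinks toward 0 by a factor (1 - lr)
lv1 = [lv - lr * g for lv, g in zip(lv0, g_lv)]  # logvar rises, i.e. std grows toward 1

print(f"reconstruction MSE: {recon_mse:.1e}")
print(f"latent population mean {pop_mean:.3f}, variance {pop_var:.3f}")
print(f"mean per-sample KL: {mean_kl:.1f} nats")
print(f"|mu|^2 before/after one KL step: {sum(m * m for m in mu0):.3f} / {sum(m * m for m in mu1):.3f}")
```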

Variational autoencoders don't work. Variational autoencoders will never work. Please save yourself the trouble.
