@gabrieldernbach
Created May 30, 2022 11:51

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import rearrange
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T
from tqdm import tqdm

class AE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 4, 2, 1),    # (1, 28, 28) -> (8, 14, 14)
            nn.ReLU(),
            nn.Conv2d(8, 512, 4, 2, 1),  # (8, 14, 14) -> (512, 7, 7)
            nn.ReLU(),
            nn.Conv2d(512, 64, 3, 2, 1), # (512, 7, 7) -> (64, 4, 4)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 512, 3, 2, 1), # (64, 4, 4) -> (512, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(512, 8, 4, 2, 1),  # (512, 7, 7) -> (8, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, 4, 2, 1),    # (8, 14, 14) -> (1, 28, 28)
        )

    def loss(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return F.mse_loss(x, dec)

class VQ(nn.Module):
    def __init__(self, n_emb=512, emb_dim=64, commit_cost=0.25):
        super().__init__()
        self.n_emb = n_emb
        self.emb_dim = emb_dim
        self.emb = nn.Embedding(self.n_emb, self.emb_dim)
        self.emb.weight.data.uniform_(-1 / self.n_emb, 1 / self.n_emb)
        self.commit_cost = commit_cost

    def forward(self, inputs):
        b, c, h, w = inputs.shape
        # treat every spatial position as one emb_dim-dimensional vector
        inputs_ = rearrange(inputs, "b c h w -> (b h w) c")
        # nearest-neighbour lookup in the codebook
        idx = torch.cdist(inputs_, self.emb.weight).argmin(1)
        quantized_ = self.emb.weight[idx]
        quantized = rearrange(quantized_, "(b h w) c -> b c h w", b=b, h=h)
        # codebook loss (updates the embeddings) plus commitment loss
        # (updates the encoder, weighted by commit_cost), as in the VQ-VAE paper
        q_loss = F.mse_loss(inputs.detach(), quantized)
        e_loss = F.mse_loss(inputs, quantized.detach())
        qeloss = q_loss + self.commit_cost * e_loss
        # straight-through estimator: copy gradients from quantized to the encoder
        quantized = inputs + (quantized - inputs).detach()
        return qeloss, quantized, idx

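# Quick sanity check for the VQ layer (a minimal sketch added here, not part of
# the training pipeline): feed a random tensor shaped like the encoder output
# and verify that shapes are preserved and indices stay inside the codebook.
_vq = VQ()
_x = torch.randn(2, 64, 4, 4)
_qeloss, _quant, _idx = _vq(_x)
assert _quant.shape == _x.shape
assert _idx.min() >= 0 and _idx.max() < _vq.n_emb
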
class VQVAE(AE):
    def __init__(self):
        super().__init__()
        self.vq = VQ()

    def loss(self, x):
        enc = self.encoder(x)
        qloss, quant, idx = self.vq(enc)
        dec = self.decoder(quant)
        return F.mse_loss(x, dec) + qloss

tfm = T.Compose([
    T.RandomRotation(360),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])
ds = MNIST('./', download=True, transform=tfm)
dl = DataLoader(ds, batch_size=512, shuffle=True, drop_last=True)
vqvae = VQVAE().train()
optim = torch.optim.Adam(vqvae.parameters(), lr=1e-3)

for epoch in range(10):
    loss_avg = 0
    for i, (x, _) in enumerate(tqdm(dl), start=1):
        loss = vqvae.loss(x)
        loss.backward()
        optim.step()
        optim.zero_grad()
        # running mean of the loss over the epoch
        loss_avg += (loss.item() - loss_avg) / i
    print(epoch, loss_avg)

# reconstruct the last training batch and tile originals next to reconstructions
with torch.no_grad():
    out = vqvae.decoder(vqvae.vq(vqvae.encoder(x))[1])
toplot = torch.stack([x, out])
toplot = rearrange(toplot, 't (b1 b2) 1 h w -> (b1 h) (b2 t w)', b1=32).detach()
plt.figure(figsize=(13, 13))
plt.imshow(toplot)
plt.axis("off")
plt.show()
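
# Optional follow-up (a rough sketch, not part of the original gist): the VQ
# layer also returns codebook indices, so each image can be summarised as a
# 4x4 grid of discrete codes, e.g. as tokens for a downstream prior model.
with torch.no_grad():
    codes = vqvae.vq(vqvae.encoder(x))[2]  # flat (batch * 4 * 4,) index vector
codes = rearrange(codes, "(b h w) -> b h w", b=x.shape[0], h=4)
print(codes[0])  # 4x4 integer grid of codebook indices for the first image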