Skip to content

Instantly share code, notes, and snippets.

@GallagherCommaJack
Created June 20, 2021 21:16
Show Gist options
  • Save GallagherCommaJack/52b995923d6f35526715eb2b74c5c00c to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as TF
import torch.utils.data as D
import torchvision as Tv
import pytorch_lightning as pl
def init_weights(m):
    """DCGAN-style initialization, meant to be used via ``module.apply(init_weights)``.

    Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
    with zero bias. Other module types are left untouched.
    """
    # BUG FIX: exact `type(m) ==` checks miss subclasses; isinstance is the
    # idiomatic (and equivalent-or-broader) test here.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
class Residual(nn.Module):
    """Skip-connection wrapper: ``forward(x)`` returns ``x + inner(x)``."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x) + x
def c2d(ch_in, ch_out, k, stride=1, padding=0):
    """Shorthand for a bias-free ``nn.Conv2d`` (BatchNorm follows it everywhere here)."""
    return nn.Conv2d(
        in_channels=ch_in,
        out_channels=ch_out,
        kernel_size=k,
        stride=stride,
        padding=padding,
        bias=False,
    )
def c2dt(ch_in, ch_out, k, stride=1, padding=0):
    """Shorthand for a bias-free ``nn.ConvTranspose2d`` (used for upsampling)."""
    return nn.ConvTranspose2d(
        in_channels=ch_in,
        out_channels=ch_out,
        kernel_size=k,
        stride=stride,
        padding=padding,
        bias=False,
    )
def res_block(ch_base, act):
    """Bottleneck residual block over ``4 * ch_base`` channels.

    Layout: 1x1 reduce -> 3x3 -> 3x3 expand, each conv followed by BatchNorm
    and a fresh activation produced by calling ``act()``; the whole stack is
    wrapped in a skip connection.
    """
    ch_wide = 4 * ch_base
    stack = nn.Sequential(
        c2d(ch_wide, ch_base, 1),
        nn.BatchNorm2d(ch_base),
        act(),
        c2d(ch_base, ch_base, 3, padding=1),
        nn.BatchNorm2d(ch_base),
        act(),
        c2d(ch_base, ch_wide, 3, padding=1),
        nn.BatchNorm2d(ch_wide),
        act(),
    )
    return Residual(stack)
class G(nn.Module):
    """DCGAN-style generator: latent vector -> image in [0, 1].

    Args:
        z_dim: latent dimensionality.
        ngf: base channel width.
        nc: output channel count.
    """

    def __init__(self, z_dim, ngf=128, nc=3):
        super().__init__()
        act = lambda: nn.ReLU(inplace=False)
        self.z_dim = z_dim
        # 1x1 -> 4x4
        self.project = nn.Sequential(
            c2dt(z_dim, ngf * 8, 4),
            nn.BatchNorm2d(ngf * 8),
        )
        # res_block(ngf * 2) operates on 4 * (ngf * 2) = ngf * 8 channels,
        # matching the projection output.
        self.res_1 = nn.Sequential(*[res_block(ngf * 2, act) for _ in range(1)])
        # 4x4 -> 8x8
        self.upsample1 = nn.Sequential(
            c2dt(ngf * 8, ngf * 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 4),
        )
        # res_block(ngf) operates on ngf * 4 channels.
        self.res_2 = nn.Sequential(*[res_block(ngf, act) for _ in range(4)])
        # 8x8 -> 64x64 via three stride-2 transposed convs, ending in tanh.
        self.upsample_final = nn.Sequential(
            c2dt(ngf * 4, ngf * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 2),
            c2dt(ngf * 2, ngf, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf),
            c2dt(ngf, nc, 4, stride=2, padding=1),
            nn.Tanh(),
        )
        self.apply(init_weights)

    def forward(self, x):
        out = x.reshape(-1, self.z_dim, 1, 1)
        out = self.project(out)
        out = self.res_1(out)
        out = self.upsample1(out)
        out = self.res_2(out)
        out = self.upsample_final(out)
        # Map tanh output from [-1, 1] to [0, 1].
        return (out + 1) / 2

    def generate(self, n):
        """Decode ``n`` random latent vectors sampled on the model's own device.

        BUG FIX: the device was hard-coded to 'cuda:0', which crashed on
        CPU-only machines and on models placed on other devices; follow the
        module's parameters instead.
        """
        device = next(self.parameters()).device
        zs = torch.randn((n, self.z_dim), device=device)
        return self.forward(zs)
def d(*, ndf=64, nc=3, d_out=2):
    """Build the convolutional discriminator.

    Keyword-only args: ``ndf`` base channel width, ``nc`` input channels,
    ``d_out`` number of output logits. Returns a ``nn.Sequential`` whose
    final ``Flatten`` yields a (batch, d_out) tensor, initialized with
    ``init_weights``.
    """
    leaky = lambda: nn.LeakyReLU(0.2, inplace=True)
    model = nn.Sequential(
        # Three stride-2 downsampling stages: nc -> ndf -> 2*ndf -> 4*ndf.
        c2d(nc, ndf, 4, stride=2, padding=1),
        leaky(),
        c2d(ndf, ndf * 2, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf * 2),
        leaky(),
        c2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf * 4),
        leaky(),
        # Residual bottlenecks over 4*ndf channels.
        res_block(ndf, leaky),
        res_block(ndf, leaky),
        res_block(ndf, leaky),
        # One more downsample to 8*ndf channels.
        c2d(ndf * 4, ndf * 8, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf * 8),
        res_block(ndf * 2, leaky),
        c2d(ndf * 8, ndf * 8, 1),
        nn.BatchNorm2d(ndf * 8),
        # Final 4x4 conv reduces the spatial map to the output logits.
        c2d(ndf * 8, d_out, 4),
        nn.Flatten(),
    )
    model.apply(init_weights)
    return model
from layers import ResConvBlock, ResConvFFTBlock
class Encoder(nn.Module):
    """Convolutional VAE encoder producing a diagonal-Gaussian latent.

    Args:
        nef: feature width of the flattened conv output / hidden linear layer.
        d_out: latent dimensionality.
        nc: input image channels.
    """

    def __init__(self, nef, d_out, nc=3):
        super().__init__()
        layers = [
            c2d(nc, nef // 16, 4, stride=2, padding=0),
            ResConvFFTBlock(nef // 16, 2),
            c2d(nef // 16, nef // 8, 4, stride=2, padding=0),
            c2d(nef // 8, nef // 8, 4, stride=2, padding=0),
            ResConvFFTBlock(nef // 8, 2),
            c2d(nef // 8, nef // 4, 4, stride=2, padding=0),
            ResConvFFTBlock(nef // 4, 2),
            nn.Flatten(),
            # NOTE(review): assumes the flattened conv output has exactly
            # `nef` features for the input resolution in use — confirm
            # against the data loader's image size.
            nn.Linear(nef, nef),
        ]
        self.conv = nn.Sequential(*layers)
        self.conv.apply(init_weights)
        # BUG FIX: was `nn.Lienar`, an AttributeError at construction time.
        self.fc_mu = nn.Linear(nef, d_out)
        self.fc_var = nn.Linear(nef, d_out)

    def sample(self, x):
        """Encode ``x`` and reparameterize.

        Returns:
            (kl, z): a scalar Monte-Carlo KL(q(z|x) || N(0, I)) estimate
            (mean over batch and latent dims) and a latent sample drawn with
            the reparameterization trick (gradients flow through ``rsample``).
        """
        e = self.conv(x)
        mu = self.fc_mu(e)
        log_var = self.fc_var(e)
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        kl = (q.log_prob(z) - p.log_prob(z)).mean()
        return kl, z

    def forward(self, x):
        """Return only the latent sample, discarding the KL term."""
        _, z = self.sample(x)
        return z
# Run configuration for the training script below.
run_label = 'ae-gan-fnet-smaller-batch'  # tag used in checkpoint / image directory names
img_dir = f'./grid-{run_label}'  # where demo reconstruction grids are saved
nef = 64*64  # encoder feature width (size of the Linear(nef, nef) hidden layer)
ndf = 64  # discriminator base channel width
ngf = 192  # generator base channel width
latent_size = nef  # VAE latent dimensionality (equals nef here)
batch_size = 32
class VAEGAN(pl.LightningModule):
    """VAE trained jointly with a pair discriminator (manual optimization).

    The discriminator receives a channel-concatenated (image, image) pair
    (hence nc=6 by default) and must predict which slot holds the real image.

    Args:
        latent_dim: VAE latent dimensionality.
        nef / ndf / ngf: channel widths for the default encoder /
            discriminator / generator when none are injected.
        rc_loss_mult: weight of the MSE reconstruction term.
        kl_mult: weight of the KL term.
        encoder / decoder / discriminator: optional pre-built modules.
        scheduler: optional ``opt -> lr_scheduler`` factory.
    """

    def __init__(
        self,
        latent_dim, nef, ndf, ngf, *,
        rc_loss_mult=1e3,
        kl_mult=5e-2,
        encoder=None, decoder=None, discriminator=None,
        scheduler=None,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        if self.encoder is None:
            self.encoder = Encoder(nef=nef, d_out=latent_dim)
        if self.decoder is None:
            self.decoder = G(z_dim=latent_dim, ngf=ngf)
        if self.discriminator is None:
            # nc=6: the discriminator consumes two images stacked on channels.
            self.discriminator = d(ndf=ndf, d_out=1, nc=6)
        self.scheduler = scheduler
        if self.scheduler is None:
            self.scheduler = lambda opt: torch.optim.lr_scheduler.MultiplicativeLR(opt, lambda e: 0.95)
        self.rc_mult = rc_loss_mult
        self.kl_mult = kl_mult
        # The two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

    def configure_optimizers(self):
        """One Adam for encoder+decoder, one for the discriminator."""
        rc_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        rc_opt = torch.optim.Adam(rc_params, 1e-4)
        self.rc_opt_sched = self.scheduler(rc_opt)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), 1e-4)
        self.d_opt_sched = self.scheduler(d_opt)
        return rc_opt, d_opt

    def forward(self, x):
        """Reconstruct ``x`` through the encoder/decoder pair."""
        return self.decoder(self.encoder(x))

    def step(self, x):
        """Shared VAE step: returns (reconstruction, weighted loss, log dict)."""
        # BUG FIX: Encoder.forward returns only z; unpacking it as (kl, z)
        # fails at runtime. The KL term comes from Encoder.sample.
        kl, z = self.encoder.sample(x)
        x_hat = self.decoder(z)
        rc_loss = TF.mse_loss(x, x_hat)
        logs = {
            'rc_loss': rc_loss,
            'kl': kl,
        }
        loss = rc_loss * self.rc_mult + kl * self.kl_mult
        return x_hat, loss, logs

    def training_step(self, batch, batch_idx):
        x = batch
        x_hat, vae_loss, logs = self.step(x)
        rc_opt, d_opt = self.optimizers()
        # Both channel orderings of the (real, reconstruction) pair.
        reals_first = torch.cat([x, x_hat], dim=1)
        fakes_first = torch.cat([x_hat, x], dim=1)
        d_reals_first = self.discriminator(reals_first)
        d_fakes_first = self.discriminator(fakes_first)
        # Discriminator targets: 1 when the real image comes first, else 0;
        # generator targets are the exact opposite.
        d_target_reals_first = torch.ones_like(d_reals_first)
        d_target_fakes_first = torch.zeros_like(d_fakes_first)
        g_target_reals_first = torch.zeros_like(d_reals_first)
        g_target_fakes_first = torch.ones_like(d_fakes_first)
        d_loss_reals_first = TF.binary_cross_entropy_with_logits(d_reals_first, d_target_reals_first)
        d_loss_fakes_first = TF.binary_cross_entropy_with_logits(d_fakes_first, d_target_fakes_first)
        g_loss_reals_first = TF.binary_cross_entropy_with_logits(d_reals_first, g_target_reals_first)
        g_loss_fakes_first = TF.binary_cross_entropy_with_logits(d_fakes_first, g_target_fakes_first)
        d_loss = d_loss_reals_first + d_loss_fakes_first
        g_loss = g_loss_reals_first + g_loss_fakes_first
        logs['d_loss'] = d_loss
        logs['g_loss'] = g_loss
        # Step whichever side is currently behind.
        if d_loss > g_loss:
            d_opt.zero_grad()
            self.manual_backward(d_loss, retain_graph=True)
            d_opt.step()
        else:
            # Fold the adversarial term into the VAE loss since we're going
            # to backprop through the generator anyway. Out-of-place add:
            # an in-place += on a graph tensor can trip autograd checks.
            vae_loss = vae_loss + g_loss
        rc_opt.zero_grad()
        self.manual_backward(vae_loss)
        rc_opt.step()
        self.log_dict({f'train_{k}': v for (k, v) in logs.items()}, prog_bar=True, on_step=True)

    def validation_step(self, batch, batch_idx):
        x = batch
        _, loss, logs = self.step(x)
        self.log_dict({f'val_{k}': v for (k, v) in logs.items()}, prog_bar=True, on_epoch=True)
        return loss
import wandb
class WandbDemoCallback(pl.Callback):
    """Periodically logs (original, reconstruction) image pairs to wandb.

    Args:
        demo_imgs: stacked tensor of demo images to reconstruct.
        width: number of pairs per grid row (must divide len(demo_imgs)).
        demo_every: log every N training batches.
    """

    def __init__(self, *, demo_imgs, width, demo_every=500):
        super().__init__()
        self.demo_imgs = demo_imgs
        self.demo_every = demo_every
        self.width = width
        assert len(self.demo_imgs) % self.width == 0

    def log_imgs(self, trainer, pl_module):
        """Reconstruct the demo images and push pair grids to the wandb logger."""
        demo_imgs = self.demo_imgs.to(device=pl_module.device)
        # No gradients are needed for logging.
        with torch.no_grad():
            demo_imgs_rc = pl_module.forward(demo_imgs)
        # BUG FIX: make_grid's second positional argument is nrow, not a
        # second image — pass each (original, reconstruction) pair as a list.
        pairs = [Tv.utils.make_grid([i, i_hat]) for i, i_hat in zip(demo_imgs, demo_imgs_rc)]
        grid = Tv.utils.make_grid(torch.stack(pairs), nrow=self.width)
        trainer.logger.experiment.log({
            'demo_examples': [wandb.Image(p, caption='left: original, right: reconstructed') for p in pairs],
            'demo_grid': wandb.Image(grid, caption='left: original, right: reconstructed'),
            'global_step': trainer.global_step,
        })
        return grid

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # Log on the first batch and every demo_every batches thereafter.
        if batch_idx % self.demo_every == 0:
            grid = self.log_imgs(trainer, pl_module)
            Tv.utils.save_image(grid, f'{img_dir}/{trainer.global_step:05}.png')
if __name__ == "__main__":
import loader
dset = loader.CUB(img_transform = loader.mk_image_transform(64))
card = len(dset)
card_train = int(card * 0.95)
card_val = int(card - card_train)
train_set, val_set = D.random_split(dset, [card_train, card_val], generator=torch.Generator().manual_seed(0))
train_loader = D.DataLoader(train_set, batch_size = batch_size, num_workers = 6, shuffle = True)
val_loader = D.DataLoader(val_set, batch_size = batch_size, num_workers = 6)
import os
os.makedirs(img_dir, exist_ok = True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor = 'val_rc_loss',
dirpath = f'./checkpoints-{run_label}',
filename = 'ae-{epoch:03d}',
save_top_k = 10,
save_last = True,
mode = 'min',
)
demo_callback = WandbDemoCallback(demo_imgs = torch.stack([val_set[i] for i in range(18)], width=3))
trainer = pl.Trainer(
# precision = 16,
gpus = 1,
#accelerator = 'ddp',
callbacks = [demo_callback, checkpoint_callback],
#max_epochs = epochs,
benchmark = True,
check_val_every_n_epoch = 10,
)
task = VAEGAN(latent_size, nef, ndf, ngf, grid_imgs=torch.stack([val_set[i] for i in range(18)]))
torch.autograd.set_detect_anomaly(True)
trainer.fit(task, train_loader, val_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment