Created
June 20, 2021 21:16
-
-
Save GallagherCommaJack/52b995923d6f35526715eb2b74c5c00c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as TF | |
import torch.utils.data as D | |
import torchvision as Tv | |
import pytorch_lightning as pl | |
def init_weights(m):
    """DCGAN-style weight initialization, meant to be used via ``module.apply``.

    Conv / transposed-conv weights are drawn from N(0, 0.02); BatchNorm scale
    is drawn from N(1, 0.02) with the shift zeroed. Other modules are left
    untouched.
    """
    # isinstance instead of `type(m) ==` — idiomatic, and equivalent here
    # since this file only ever instantiates these exact classes.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
class Residual(nn.Module):
    """Skip-connection wrapper: ``forward(x)`` returns ``x + inner(x)``."""

    def __init__(self, inner):
        super().__init__()
        # The wrapped transformation; must map its input to a tensor of the
        # same shape so the addition is valid.
        self.inner = inner

    def forward(self, x):
        shortcut = x
        return shortcut + self.inner(x)
def c2d(ch_in, ch_out, k, stride=1, padding=0):
    """Bias-free 2-D convolution (every use here is followed by BatchNorm,
    which makes a bias term redundant)."""
    conv = nn.Conv2d(
        ch_in,
        ch_out,
        k,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return conv
def c2dt(ch_in, ch_out, k, stride=1, padding=0):
    """Bias-free 2-D transposed convolution (BatchNorm follows every use,
    so a bias term would be redundant)."""
    deconv = nn.ConvTranspose2d(
        ch_in,
        ch_out,
        k,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return deconv
def res_block(ch_base, act):
    """Bottleneck residual block over ``4 * ch_base`` channels.

    The inner path is 1x1 reduce -> 3x3 -> 3x3 expand, each convolution
    followed by BatchNorm and a fresh activation built by calling ``act()``.
    Spatial size is preserved, so the skip connection adds cleanly.
    """
    ch_wide = ch_base * 4
    inner = [
        c2d(ch_wide, ch_base, 1),
        nn.BatchNorm2d(ch_base),
        act(),
        c2d(ch_base, ch_base, 3, padding=1),
        nn.BatchNorm2d(ch_base),
        act(),
        c2d(ch_base, ch_wide, 3, padding=1),
        nn.BatchNorm2d(ch_wide),
        act(),
    ]
    return Residual(nn.Sequential(*inner))
class G(nn.Module):
    """DCGAN-style generator/decoder: latent vector -> (nc, 64, 64) image.

    Pipeline: 1x1 latent projected to (ngf*8, 4, 4), one residual block,
    upsample to (ngf*4, 8, 8), four residual blocks, then three strided
    transposed convolutions up to (nc, 64, 64) with a final Tanh.
    ``forward`` rescales the Tanh output from [-1, 1] to [0, 1].
    """

    def __init__(self, z_dim, ngf=128, nc=3):
        super().__init__()
        # inplace=False: the residual blocks add their input back, so the
        # activation must not clobber it.
        act = lambda: nn.ReLU(inplace=False)
        self.z_dim = z_dim
        self.project = nn.Sequential(
            c2dt(z_dim, ngf * 8, 4),
            nn.BatchNorm2d(ngf * 8)
        )
        # res_block(ngf*2) operates on ngf*8 channels (4x bottleneck base).
        self.res_1 = nn.Sequential(*[res_block(ngf*2, act) for _ in range(1)])
        self.upsample1 = nn.Sequential(
            c2dt(ngf*8, ngf*4, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf*4)
        )
        # res_block(ngf) operates on ngf*4 channels.
        self.res_2 = nn.Sequential(*[res_block(ngf, act) for _ in range(4)])
        self.upsample_final = nn.Sequential(
            c2dt(ngf*4, ngf*2, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf*2),
            c2dt(ngf*2, ngf, 4, stride=2, padding=1),
            nn.BatchNorm2d(ngf),
            c2dt(ngf, nc, 4, stride=2, padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, x):
        """Decode a batch of latents; ``x`` is reshaped to (B, z_dim, 1, 1)."""
        out = x.reshape(-1, self.z_dim, 1, 1)
        out = self.project(out)
        out = self.res_1(out)
        out = self.upsample1(out)
        out = self.res_2(out)
        out = self.upsample_final(out)
        # Map Tanh range [-1, 1] into image range [0, 1].
        return (out + 1) / 2

    def generate(self, n):
        """Sample ``n`` images from standard-normal latents.

        The latents are created on whatever device the module lives on
        (was hard-coded to 'cuda:0', which broke CPU-only use).
        """
        device = next(self.parameters()).device
        zs = torch.randn((n, self.z_dim), device=device)
        return self.forward(zs)
def d(*, ndf = 64, nc=3, d_out=2):
    """Build a convolutional discriminator for 64x64, ``nc``-channel input.

    Four strided downsampling stages interleaved with bottleneck residual
    blocks, ending in a 1x1 mix, a 4x4 convolution down to ``d_out`` logits,
    and a flatten. Weights are DCGAN-initialized before returning.
    """
    # inplace LeakyReLU is safe here: activations in this stack never need
    # their pre-activation input again.
    leaky = lambda: nn.LeakyReLU(0.2, inplace=True)
    stages = [
        c2d(nc, ndf, 4, stride=2, padding=1),
        leaky(),
        c2d(ndf, ndf*2, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf*2),
        leaky(),
        c2d(ndf*2, ndf*4, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf*4),
        leaky(),
    ]
    # Three residual blocks at ndf*4 channels (bottleneck base ndf).
    stages += [res_block(ndf, leaky) for _ in range(3)]
    stages += [
        c2d(ndf*4, ndf*8, 4, stride=2, padding=1),
        nn.BatchNorm2d(ndf*8),
        res_block(ndf*2, leaky),
        c2d(ndf*8, ndf*8, 1),
        nn.BatchNorm2d(ndf*8),
        c2d(ndf*8, d_out, 4),
        nn.Flatten(),
    ]
    net = nn.Sequential(*stages)
    net.apply(init_weights)
    return net
from layers import ResConvBlock, ResConvFFTBlock | |
class Encoder(nn.Module):
    """Convolutional VAE encoder producing a diagonal-Gaussian posterior.

    A strided conv / ResConvFFTBlock stack flattens to a ``nef``-dim feature,
    from which two linear heads emit the posterior mean and log-variance.

    NOTE(review): ``nn.Flatten`` into ``nn.Linear(nef, nef)`` assumes the conv
    output has exactly ``nef`` elements per sample — depends on input image
    size; confirm against the data pipeline.
    """

    def __init__(self, nef, d_out, nc=3):
        super().__init__()
        leaky = lambda: nn.LeakyReLU(0.2, inplace=True)
        layers = [
            c2d(nc, nef//16, 4, stride=2, padding=0),
            ResConvFFTBlock(nef//16, 2),
            c2d(nef//16, nef//8, 4, stride=2, padding=0),
            c2d(nef//8, nef//8, 4, stride=2, padding=0),
            ResConvFFTBlock(nef//8, 2),
            c2d(nef//8, nef//4, 4, stride=2, padding=0),
            ResConvFFTBlock(nef//4, 2),
            nn.Flatten(),
            nn.Linear(nef, nef)
        ]
        self.conv = nn.Sequential(*layers)
        self.conv.apply(init_weights)
        # BUG FIX: was `nn.Lienar`, which raised AttributeError at construction.
        self.fc_mu = nn.Linear(nef, d_out)
        self.fc_var = nn.Linear(nef, d_out)

    def sample(self, x):
        """Encode ``x`` and reparameterize.

        Returns ``(kl, z)`` where ``z`` is an rsample from the posterior
        q = N(mu, std) and ``kl`` is a Monte-Carlo estimate (mean of
        log q(z) - log p(z) under the single sample) against p = N(0, I).
        """
        e = self.conv(x)
        mu = self.fc_mu(e)
        log_var = self.fc_var(e)
        std = torch.exp(log_var/2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        kl = (q.log_prob(z) - p.log_prob(z)).mean()
        return kl, z

    def forward(self, x):
        """Return only the latent sample (KL term discarded)."""
        _, z = self.sample(x)
        return z
# Experiment configuration.
run_label = 'ae-gan-fnet-smaller-batch'  # tag used in output/checkpoint paths
img_dir = f'./grid-{run_label}'          # where demo image grids are written
nef = 64*64        # encoder feature width (flattened conv output size)
ndf = 64           # discriminator base channel count
ngf = 192          # generator base channel count
latent_size = nef  # VAE latent dimensionality
batch_size = 32
class VAEGAN(pl.LightningModule):
    """VAE-GAN trained with manual optimization.

    The encoder/decoder pair is trained on reconstruction + KL (plus the
    generator adversarial term when the discriminator is ahead); the
    discriminator sees (real, fake) channel-concatenated pairs in both
    orders and classifies which ordering it was given.
    """

    def __init__(
        self,
        latent_dim, nef, ndf, ngf, *,
        rc_loss_mult = 1e3,
        kl_mult = 5e-2,
        encoder=None, decoder=None, discriminator=None,
        scheduler=None,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        if self.encoder is None:
            self.encoder = Encoder(nef=nef, d_out = latent_dim)
        if self.decoder is None:
            self.decoder = G(z_dim=latent_dim, ngf=ngf)
        if self.discriminator is None:
            # nc=6: the discriminator input is two 3-channel images
            # concatenated along the channel axis.
            self.discriminator = d(ndf=ndf, d_out=1, nc=6)
        self.scheduler = scheduler
        if self.scheduler is None:
            self.scheduler = lambda opt: torch.optim.lr_scheduler.MultiplicativeLR(opt, lambda e: 0.95)
        self.rc_mult = rc_loss_mult
        self.kl_mult = kl_mult
        # Two optimizers stepped conditionally -> manual optimization.
        self.automatic_optimization = False

    def configure_optimizers(self):
        rc_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        rc_opt = torch.optim.Adam(rc_params, 1e-4)
        # NOTE(review): these schedulers are created but never stepped
        # anywhere visible in this file — confirm whether per-epoch stepping
        # was intended.
        self.rc_opt_sched = self.scheduler(rc_opt)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), 1e-4)
        self.d_opt_sched = self.scheduler(d_opt)
        return rc_opt, d_opt

    def forward(self, x):
        """Reconstruct ``x`` through the encoder/decoder pair."""
        return self.decoder(self.encoder(x))

    def step(self, x):
        """Run the VAE half: returns (x_hat, weighted loss, log dict)."""
        # BUG FIX: Encoder.forward returns only z; the (kl, z) pair comes
        # from Encoder.sample, so unpacking self.encoder(x) raised at runtime.
        kl, z = self.encoder.sample(x)
        x_hat = self.decoder(z)
        rc_loss = TF.mse_loss(x, x_hat)
        logs = {
            'rc_loss': rc_loss,
            'kl': kl,
        }
        loss = rc_loss * self.rc_mult + kl * self.kl_mult
        return x_hat, loss, logs

    def training_step(self, batch, batch_idx):
        x = batch
        x_hat, vae_loss, logs = self.step(x)
        rc_opt, d_opt = self.optimizers()
        # Pair real and reconstruction in both channel orders; the
        # discriminator must tell which ordering it got.
        reals_first = torch.cat([x, x_hat], dim=1)
        fakes_first = torch.cat([x_hat, x], dim=1)
        d_reals_first = self.discriminator(reals_first)
        d_fakes_first = self.discriminator(fakes_first)
        d_target_reals_first = torch.ones_like(d_reals_first)
        d_target_fakes_first = torch.zeros_like(d_fakes_first)
        # Generator targets are the discriminator targets inverted.
        g_target_reals_first = torch.zeros_like(d_reals_first)
        g_target_fakes_first = torch.ones_like(d_fakes_first)
        d_loss_reals_first = TF.binary_cross_entropy_with_logits(d_reals_first, d_target_reals_first)
        d_loss_fakes_first = TF.binary_cross_entropy_with_logits(d_fakes_first, d_target_fakes_first)
        g_loss_reals_first = TF.binary_cross_entropy_with_logits(d_reals_first, g_target_reals_first)
        g_loss_fakes_first = TF.binary_cross_entropy_with_logits(d_fakes_first, g_target_fakes_first)
        d_loss = d_loss_reals_first + d_loss_fakes_first
        g_loss = g_loss_reals_first + g_loss_fakes_first
        logs['d_loss'] = d_loss
        logs['g_loss'] = g_loss
        # Train whichever side is currently losing.
        if d_loss > g_loss:
            d_opt.zero_grad()
            # retain_graph: the VAE backward below reuses the same graph.
            self.manual_backward(d_loss, retain_graph = True)
            d_opt.step()
        else:
            # just add to vae loss since we're going to backprop anyway
            vae_loss += g_loss
        rc_opt.zero_grad()
        self.manual_backward(vae_loss)
        rc_opt.step()
        self.log_dict({f'train_{k}': v for (k,v) in logs.items()}, prog_bar=True, on_step=True)

    def validation_step(self, batch, batch_idx):
        x = batch
        _, loss, logs = self.step(x)
        self.log_dict({f'val_{k}': v for (k,v) in logs.items()}, prog_bar=True, on_epoch=True)
        return loss
import wandb | |
class WandbDemoCallback(pl.Callback):
    """Periodically log original/reconstruction image pairs to wandb.

    Every ``demo_every`` training batches, reconstructs ``demo_imgs`` with the
    module, logs per-pair images plus a combined grid, and returns the grid so
    the caller can also save it to disk.
    """

    def __init__(self, *, demo_imgs, width, demo_every=500):
        super().__init__()
        self.demo_imgs = demo_imgs      # fixed batch of validation images
        self.demo_every = demo_every    # log every N training batches
        self.width = width              # pairs per row in the combined grid
        assert len(self.demo_imgs) % self.width == 0

    def log_imgs(self, trainer, pl_module):
        demo_imgs = self.demo_imgs.to(device = pl_module.device)
        # Demo inference only — no need to build an autograd graph.
        with torch.no_grad():
            demo_imgs_rc = pl_module.forward(demo_imgs)
        # BUG FIX: make_grid(i, i_hat) passed the reconstruction tensor as the
        # `nrow` argument; the pair must be stacked into one batch instead.
        pairs = [
            Tv.utils.make_grid(torch.stack([i, i_hat]), nrow=2)
            for i, i_hat in zip(demo_imgs, demo_imgs_rc)
        ]
        grid = Tv.utils.make_grid(torch.stack(pairs), nrow=self.width)
        trainer.logger.experiment.log({
            'demo_examples': [wandb.Image(p, caption='left: original, right: reconstructed') for p in pairs],
            'demo_grid': wandb.Image(grid, caption='left: original, right: reconstructed'),
            'global_step': trainer.global_step,
        })
        return grid

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx % self.demo_every == 0:
            grid = self.log_imgs(trainer, pl_module)
            Tv.utils.save_image(grid, f'{img_dir}/{trainer.global_step:05}.png')
if __name__ == "__main__":
    import loader
    # CUB birds dataset, resized/cropped to 64x64 by the loader transform.
    dset = loader.CUB(img_transform = loader.mk_image_transform(64))
    card = len(dset)
    card_train = int(card * 0.95)
    card_val = int(card - card_train)
    # Fixed seed so the train/val split is reproducible across runs.
    train_set, val_set = D.random_split(dset, [card_train, card_val], generator=torch.Generator().manual_seed(0))
    train_loader = D.DataLoader(train_set, batch_size = batch_size, num_workers = 6, shuffle = True)
    val_loader = D.DataLoader(val_set, batch_size = batch_size, num_workers = 6)
    import os
    os.makedirs(img_dir, exist_ok = True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor = 'val_rc_loss',
        dirpath = f'./checkpoints-{run_label}',
        filename = 'ae-{epoch:03d}',
        save_top_k = 10,
        save_last = True,
        mode = 'min',
    )
    # BUG FIX: `width=3` was inside the torch.stack(...) call (TypeError);
    # it belongs to the callback constructor. 18 images / width 3 satisfies
    # the callback's divisibility assert.
    demo_callback = WandbDemoCallback(
        demo_imgs = torch.stack([val_set[i] for i in range(18)]),
        width = 3,
    )
    trainer = pl.Trainer(
        # precision = 16,
        gpus = 1,
        #accelerator = 'ddp',
        callbacks = [demo_callback, checkpoint_callback],
        #max_epochs = epochs,
        benchmark = True,
        check_val_every_n_epoch = 10,
    )
    # BUG FIX: VAEGAN.__init__ has no `grid_imgs` parameter; passing it
    # raised TypeError. Demo images are handled by the callback above.
    task = VAEGAN(latent_size, nef, ndf, ngf)
    # Anomaly detection is slow; useful while debugging NaNs/inplace errors.
    torch.autograd.set_detect_anomaly(True)
    trainer.fit(task, train_loader, val_loader)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment