Created
March 30, 2022 14:31
-
-
Save IFeelBloated/e6ff8f3447648b5d81913ca018b23f77 to your computer and use it in GitHub Desktop.
1200-layer GAN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
def ZeroCenteredGradientPenalty(Samples, Critics):
    """Zero-centered gradient penalty (R1 on reals / R2 on fakes).

    Differentiates the summed critic scores w.r.t. the input samples and
    returns 0.5 * E[||grad||^2], averaged over the batch. `create_graph=True`
    keeps the penalty itself differentiable so it can appear in the loss.
    Expects `Samples` to be a 4-D (NCHW) tensor with requires_grad set.
    """
    (Gradient,) = torch.autograd.grad(
        outputs=Critics.sum(),
        inputs=Samples,
        create_graph=True,
        only_inputs=True,
    )
    # Squared gradient norm per sample (sum over C, H, W), then batch mean.
    PerSampleNorm = Gradient.square().sum(dim=[1, 2, 3])
    return 0.5 * PerSampleNorm.mean()
def RelativisticLoss(PositiveCritics, NegativeCritics):
    """Relativistic GAN loss: BCE-with-logits on the critic difference.

    Pushes `PositiveCritics` to exceed `NegativeCritics`; equals
    softplus(-(positive - negative)) averaged over the batch.
    """
    RelativisticLogits = PositiveCritics - NegativeCritics
    Target = torch.ones_like(PositiveCritics)
    return nn.functional.binary_cross_entropy_with_logits(RelativisticLogits, Target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
import torch.nn as nn | |
# Channel compression ratio used inside the residual blocks (1 = no compression).
CompressionFactor = 1
# He-init gain for SiLU-activated layers (sqrt(2), same as the ReLU gain).
SiLUGain = math.sqrt(2)
def MSRInitializer(Layer, ActivationGain=1, SpatialScaleFactor=1):
    """MSR (He) initialization in place: weight std = gain / sqrt(fan-in), bias zeroed.

    With SpatialScaleFactor > 1 the layer is assumed to feed a pixel-shuffle
    upsampler: a smaller kernel is drawn, bilinearly upsampled, and re-packed
    via pixel_unshuffle so that the shuffled output starts out smooth.
    Returns the same Layer to allow inline use at construction time.
    """
    if SpatialScaleFactor == 1:
        # Fan-in = input channels * spatial kernel elements.
        FanIn = Layer.weight.data.size(1) * Layer.weight.data[0][0].numel()
        Layer.weight.data.normal_(0, ActivationGain / math.sqrt(FanIn))
    else:
        # Draw a kernel with 1/(s*s) of the output channels...
        SubpixelKernel = torch.empty(Layer.weight.shape[0] // (SpatialScaleFactor * SpatialScaleFactor), *Layer.weight.shape[1:])
        FanIn = SubpixelKernel.size(1) * SubpixelKernel[0][0].numel()
        SubpixelKernel.normal_(0, ActivationGain / math.sqrt(FanIn))
        # ...upsample it spatially, then fold the spatial factor back into the
        # channel dimension so pixel_shuffle on the layer output reproduces
        # the bilinear structure. NOTE(review): interpolate here treats the
        # (out, in, kh, kw) weight as an NCHW batch — intentional, but confirm
        # the kernel size is >1 so bilinear resampling is meaningful.
        SubpixelKernel = nn.functional.interpolate(SubpixelKernel, scale_factor=SpatialScaleFactor, mode='bilinear', align_corners=False)
        SubpixelKernel = nn.functional.pixel_unshuffle(SubpixelKernel.transpose(0, 1), SpatialScaleFactor)
        Layer.weight.data.copy_(SubpixelKernel.transpose(0, 1))
    if Layer.bias is not None:
        Layer.bias.data.zero_()
    return Layer
class BiasedActivation(nn.Module):
    """SiLU preceded by a learnable, zero-initialized per-channel bias.

    Used because the convolution/linear layers in this network are created
    with bias=False; the bias is added here, just before the nonlinearity.
    `ConvolutionalLayer` selects NCHW (1,C,1,1) vs. NC (1,C) broadcasting.
    """

    def __init__(self, InputUnits, ConvolutionalLayer=True):
        super(BiasedActivation, self).__init__()
        self.Bias = nn.Parameter(torch.zeros(InputUnits))
        self.ConvolutionalLayer = ConvolutionalLayer

    def forward(self, x):
        BroadcastShape = (1, -1, 1, 1) if self.ConvolutionalLayer else (1, -1)
        return nn.functional.silu(x + self.Bias.view(*BroadcastShape))
class GeneratorBlock(nn.Module):
    """Resolution-preserving residual block for the generator.

    Carries two streams: the raw residual trunk `x` and its activated view
    `ActivationMaps`. The second conv is zero-initialized (ActivationGain=0)
    so each block starts as the identity.
    """

    def __init__(self, InputChannels, ReceptiveField=3):
        super(GeneratorBlock, self).__init__()
        CompressedChannels = InputChannels // CompressionFactor
        Padding = (ReceptiveField - 1) // 2
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, InputChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=0)
        self.NonLinearity1 = BiasedActivation(CompressedChannels)
        self.NonLinearity2 = BiasedActivation(InputChannels)

    def forward(self, x, ActivationMaps):
        Residual = self.LinearLayer1(ActivationMaps)
        Residual = self.NonLinearity1(Residual)
        Residual = self.LinearLayer2(Residual)
        Trunk = x + Residual
        # Return both the trunk and its activated view for the next block.
        return Trunk, self.NonLinearity2(Trunk)
class DiscriminatorBlock(nn.Module):
    """Pre-activation residual block for the discriminator.

    Activation comes before each conv; the closing conv is zero-initialized
    so the block starts as the identity.
    """

    def __init__(self, InputChannels, ReceptiveField=3):
        super(DiscriminatorBlock, self).__init__()
        CompressedChannels = InputChannels // CompressionFactor
        Padding = (ReceptiveField - 1) // 2
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, InputChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=0)
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(CompressedChannels)

    def forward(self, x):
        Residual = self.LinearLayer1(self.NonLinearity1(x))
        Residual = self.LinearLayer2(self.NonLinearity2(Residual))
        return x + Residual
class GeneratorUpsampleBlock(nn.Module):
    """Residual 2x upsampling block for the generator.

    Residual path: conv to 4x channels -> pixel_shuffle (2x spatial) ->
    activation -> zero-initialized conv. Skip path: bilinear 2x upsample,
    with a 1x1 shortcut conv only when the channel count changes.
    """

    def __init__(self, InputChannels, OutputChannels, ReceptiveField=3):
        super(GeneratorUpsampleBlock, self).__init__()
        CompressedChannels = InputChannels // CompressionFactor
        Padding = (ReceptiveField - 1) // 2
        # SpatialScaleFactor=2 makes the init consistent with the pixel shuffle.
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels * 4, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=SiLUGain, SpatialScaleFactor=2)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, OutputChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=0)
        self.NonLinearity1 = BiasedActivation(CompressedChannels)
        self.NonLinearity2 = BiasedActivation(OutputChannels)
        if InputChannels != OutputChannels:
            self.ShortcutLayer = MSRInitializer(nn.Conv2d(InputChannels, OutputChannels, kernel_size=1, stride=1, padding=0, bias=False))

    def forward(self, x, ActivationMaps):
        if hasattr(self, 'ShortcutLayer'):
            x = self.ShortcutLayer(x)
        Residual = nn.functional.pixel_shuffle(self.LinearLayer1(ActivationMaps), 2)
        Residual = self.LinearLayer2(self.NonLinearity1(Residual))
        Skip = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        Trunk = Skip + Residual
        return Trunk, self.NonLinearity2(Trunk)
class DiscriminatorDownsampleBlock(nn.Module):
    """Residual 2x downsampling block for the discriminator.

    Residual path: pre-activation conv -> activation -> pixel_unshuffle
    (halves resolution, 4x channels) -> zero-initialized conv. Skip path:
    bilinear 0.5x downsample, with a 1x1 shortcut conv only when the channel
    count changes.
    """

    def __init__(self, InputChannels, OutputChannels, ReceptiveField=3):
        super(DiscriminatorDownsampleBlock, self).__init__()
        CompressedChannels = OutputChannels // CompressionFactor
        Padding = (ReceptiveField - 1) // 2
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels * 4, OutputChannels, kernel_size=ReceptiveField, stride=1, padding=Padding, bias=False), ActivationGain=0)
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(CompressedChannels)
        if InputChannels != OutputChannels:
            self.ShortcutLayer = MSRInitializer(nn.Conv2d(InputChannels, OutputChannels, kernel_size=1, stride=1, padding=0, bias=False))

    def forward(self, x):
        Residual = self.LinearLayer1(self.NonLinearity1(x))
        Residual = self.LinearLayer2(nn.functional.pixel_unshuffle(self.NonLinearity2(Residual), 2))
        Skip = nn.functional.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        if hasattr(self, 'ShortcutLayer'):
            Skip = self.ShortcutLayer(Skip)
        return Skip + Residual
class GeneratorOpeningLayer(nn.Module):
    """Start of the synthesis path: a learned 4x4 basis modulated per channel.

    The latent `w` is linearly mapped to one scale per channel, which
    multiplies a learned (OutputChannels, 4, 4) basis tensor.
    """

    def __init__(self, LatentDimension, OutputChannels):
        super(GeneratorOpeningLayer, self).__init__()
        self.Basis = nn.Parameter(torch.empty((OutputChannels, 4, 4)))
        self.Basis.data.normal_(0, SiLUGain)
        self.LinearLayer = MSRInitializer(nn.Linear(LatentDimension, OutputChannels, bias=False))
        self.NonLinearity = BiasedActivation(OutputChannels)

    def forward(self, w):
        Modulation = self.LinearLayer(w).view(w.shape[0], -1, 1, 1)
        Trunk = self.Basis.view(1, -1, 4, 4) * Modulation
        return Trunk, self.NonLinearity(Trunk)
class DiscriminatorClosingLayer(nn.Module):
    """End of the discriminator trunk: collapse the 4x4 map to a latent vector.

    A depthwise 4x4 convolution (groups=InputChannels) acts as a learned
    per-channel spatial pooling, followed by a linear projection to
    LatentDimension.
    """

    def __init__(self, InputChannels, LatentDimension):
        super(DiscriminatorClosingLayer, self).__init__()
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, InputChannels, kernel_size=4, stride=1, padding=0, groups=InputChannels, bias=False))
        self.LinearLayer2 = MSRInitializer(nn.Linear(InputChannels, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)

    def forward(self, x):
        Pooled = self.LinearLayer1(self.NonLinearity1(x))
        Flattened = Pooled.view(x.shape[0], -1)
        return self.NonLinearity2(self.LinearLayer2(Flattened))
class FullyConnectedBlock(nn.Module):
    """Two-layer residual MLP block; the closing layer is zero-initialized
    so the block starts as the identity."""

    def __init__(self, LatentDimension):
        super(FullyConnectedBlock, self).__init__()
        self.LinearLayer1 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=0)
        self.NonLinearity1 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        self.NonLinearity2 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)

    def forward(self, x):
        Residual = self.LinearLayer1(self.NonLinearity1(x))
        Residual = self.LinearLayer2(self.NonLinearity2(Residual))
        return x + Residual
class MappingBlock(nn.Module):
    """8-layer latent mapping network z -> w (StyleGAN-style).

    One opening linear layer, three residual FC blocks (layers 2-7), and a
    closing activated linear layer (layer 8).
    """

    def __init__(self, LatentDimension):
        super(MappingBlock, self).__init__()
        self.LinearLayer1 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.Layer2To3 = FullyConnectedBlock(LatentDimension)
        self.Layer4To5 = FullyConnectedBlock(LatentDimension)
        self.Layer6To7 = FullyConnectedBlock(LatentDimension)
        self.NonLinearity = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        self.LinearLayer8 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.ClosingNonLinearity = BiasedActivation(LatentDimension, ConvolutionalLayer=False)

    def forward(self, z):
        w = self.LinearLayer1(z)
        for ResidualBlock in (self.Layer2To3, self.Layer4To5, self.Layer6To7):
            w = ResidualBlock(w)
        w = self.LinearLayer8(self.NonLinearity(w))
        return self.ClosingNonLinearity(w)
def ToRGB(InputChannels, ResidualComponent=False):
    """Build a 1x1 projection to 3-channel RGB.

    Residual RGB heads (added onto an upsampled coarser output) are
    zero-initialized so they start contributing nothing.
    """
    Gain = 0 if ResidualComponent else 1
    RGBLayer = nn.Conv2d(InputChannels, 3, kernel_size=1, stride=1, padding=0, bias=False)
    return MSRInitializer(RGBLayer, ActivationGain=Gain)
class Generator(nn.Module):
    """Very deep generator: 4x4 opening, then three stages of 100 residual
    blocks followed by one 2x upsample block each (4->8->16->32), with a
    skip-RGB output accumulated bilinearly across resolutions.

    Note each stage's residual blocks run at the stage's *input* resolution
    and channel count; the upsample block comes last in the stage.
    """

    @staticmethod
    def _MakeStage(BlockChannels, OutputChannels):
        # 100 resolution-preserving blocks, then the single upsample block.
        Stage = [GeneratorBlock(BlockChannels) for _ in range(100)]
        Stage.append(GeneratorUpsampleBlock(BlockChannels, OutputChannels))
        return nn.ModuleList(Stage)

    def __init__(self, LatentDimension):
        super(Generator, self).__init__()
        self.LatentLayer = MappingBlock(LatentDimension)
        self.Layer4x4 = GeneratorOpeningLayer(LatentDimension, 64)
        self.ToRGB4x4 = ToRGB(64)
        self.ToRGB8x8 = ToRGB(32, ResidualComponent=True)
        self.ToRGB16x16 = ToRGB(16, ResidualComponent=True)
        self.ToRGB32x32 = ToRGB(16, ResidualComponent=True)
        self.Layer8x8 = self._MakeStage(64, 32)
        self.Layer16x16 = self._MakeStage(32, 16)
        self.Layer32x32 = self._MakeStage(16, 16)

    def forward(self, z, EnableLatentMapping=True):
        # EnableLatentMapping=False lets callers feed a pre-mapped w directly.
        w = self.LatentLayer(z) if EnableLatentMapping else z
        y, ActivationMaps = self.Layer4x4(w)
        Output = self.ToRGB4x4(ActivationMaps)
        Stages = ((self.Layer8x8, self.ToRGB8x8),
                  (self.Layer16x16, self.ToRGB16x16),
                  (self.Layer32x32, self.ToRGB32x32))
        for Stage, RGBHead in Stages:
            for Block in Stage:
                y, ActivationMaps = Block(y, ActivationMaps)
            # Upsample the accumulated RGB and add this stage's contribution.
            Output = nn.functional.interpolate(Output, scale_factor=2, mode='bilinear', align_corners=False) + RGBHead(ActivationMaps)
        return Output
class Discriminator(nn.Module):
    """Very deep discriminator mirroring the generator: FromRGB, then three
    stages of one downsample block plus 100 residual blocks (32->16->8),
    a 4x4 closing layer, and a scalar critic head.

    Returns a tensor of shape (N,) of critic scores.
    """

    def __init__(self, LatentDimension):
        super(Discriminator, self).__init__()
        self.FromRGB = MSRInitializer(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), ActivationGain=SiLUGain)
        # Each stage: downsample first, then 100 residual blocks at the new width.
        Layer32x32 = [DiscriminatorDownsampleBlock(16, 16)]
        Layer16x16 = [DiscriminatorDownsampleBlock(16, 32)]
        Layer8x8 = [DiscriminatorDownsampleBlock(32, 64)]
        self.Layer4x4 = DiscriminatorClosingLayer(64, LatentDimension)
        self.CriticLayer = MSRInitializer(nn.Linear(LatentDimension, 1))
        for _ in range(100):
            Layer8x8 += [DiscriminatorBlock(64)]
        for _ in range(100):
            Layer16x16 += [DiscriminatorBlock(32)]
        for _ in range(100):
            Layer32x32 += [DiscriminatorBlock(16)]
        self.Layer8x8 = nn.ModuleList(Layer8x8)
        self.Layer16x16 = nn.ModuleList(Layer16x16)
        self.Layer32x32 = nn.ModuleList(Layer32x32)

    def forward(self, x):
        x = self.FromRGB(x)
        for Block in self.Layer32x32:
            x = Block(x)
        for Block in self.Layer16x16:
            x = Block(x)
        for Block in self.Layer8x8:
            x = Block(x)
        x = self.Layer4x4(x)
        # FIX: squeeze only the trailing singleton dim. A bare .squeeze() would
        # also drop the batch dimension when N == 1, yielding a 0-d tensor and
        # breaking downstream code that expects per-sample critics.
        return self.CriticLayer(x).squeeze(-1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import random | |
import numpy | |
import torch | |
import logging | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
import copy | |
from Network import Generator, Discriminator | |
from Loss import ZeroCenteredGradientPenalty, RelativisticLoss | |
# EMA decay for the averaged generator G_ema (p_ema = lerp(p, p_ema, ema_beta)).
ema_beta = 0.998 # ffhq: 0.99778438712388889017237329703832 cifar: 0.99991128109664301904760707704894
# Decay for the running mean of mapped latents w (used for the "mean" sample).
w_avg_beta = 0.998
# Weight of the R1/R2 zero-centered gradient penalties in the D loss.
gamma = 0.05
def run():
    """Train the relativistic GAN with R1/R2 zero-centered gradient penalties.

    Alternates one discriminator and one generator step per iteration (each
    on its own half of a double-size batch), maintains an EMA copy of the
    generator for sampling, tracks a running mean of mapped latents, and
    periodically writes sample grids plus a full resumable checkpoint.
    """
    torch.set_printoptions(threshold=1)
    logging.basicConfig(level=logging.INFO, filename='train_log.txt')
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
    parser.add_argument('--lr', type=float, default=4e-5, help='learning rate')
    parser.add_argument('--nz', type=int, default=64, help='size of the latent z vector')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
    parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
    model = None  # torch.load('epoch_613.pth', map_location='cpu') to resume
    opt = parser.parse_args()
    print(opt)
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # output folder already exists
    manualSeed = 42
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    numpy.random.seed(manualSeed)
    cudnn.benchmark = True
    dataset = dset.ImageFolder('./Data', transform=transforms.Compose([
        transforms.Resize(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.RandomHorizontalFlip(p=0.5)
    ]))
    # Double-size batches: first half feeds the D step, second half the G step.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize * 2,
                                             shuffle=True, num_workers=int(opt.workers), drop_last=True)
    device = torch.device("cuda:0")
    nz = int(opt.nz)
    fixed_noise = torch.randn(64, nz, device=device)
    w_avg = torch.zeros(nz)  # running mean of mapped latents, kept on CPU
    netG = Generator(nz).to(device)
    netD = Discriminator(nz).to(device)
    G_ema = copy.deepcopy(netG).eval()
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0, 0.99))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0, 0.99))
    if model is not None:
        # Restore every piece of training state so the run resumes exactly.
        w_avg = model['w_avg']
        fixed_noise = model['fixed_noise'].to(device)
        netG.load_state_dict(model['g_state_dict'], strict=True)
        netD.load_state_dict(model['d_state_dict'], strict=True)
        G_ema.load_state_dict(model['g_ema_state_dict'], strict=True)
        optimizerD.load_state_dict(model['optimizerD_state_dict'])
        optimizerG.load_state_dict(model['optimizerG_state_dict'])
    print('G params: ' + str(sum(p.numel() for p in netG.parameters() if p.requires_grad)))
    print('D params: ' + str(sum(p.numel() for p in netD.parameters() if p.requires_grad)))
    for epoch in range(0 if model is None else model['epoch'] + 1, 1000000):
        for i, data in enumerate(dataloader, 0):
            ########################### discriminator step
            # FIX: the original wrote `netD.requires_grad = True`, which only
            # sets a plain attribute on the Module and freezes nothing.
            # Module.requires_grad_() actually toggles every parameter.
            netD.requires_grad_(True)
            netG.requires_grad_(False)
            netD.zero_grad()
            real = data[0][0 : opt.batchSize, :, :, :].to(device)
            real.requires_grad_(True)  # needed for the R1 penalty
            noise = torch.randn(opt.batchSize, nz, device=device)
            # Sample without building G's graph, then re-leaf `fake` so the R2
            # penalty can differentiate the critic w.r.t. the samples while
            # errD.backward() leaves G untouched.
            with torch.no_grad():
                fake = netG(noise)
            fake.requires_grad_(True)
            output_r = netD(real)
            output_f = netD(fake)
            r1_penalty = ZeroCenteredGradientPenalty(real, output_r)
            r2_penalty = ZeroCenteredGradientPenalty(fake, output_f)
            errD = RelativisticLoss(output_r, output_f) + gamma * (r1_penalty + r2_penalty)
            errD.backward()
            optimizerD.step()
            ########################### generator step
            netD.requires_grad_(False)
            netG.requires_grad_(True)
            netG.zero_grad()
            real = data[0][opt.batchSize : 2 * opt.batchSize, :, :, :].to(device)
            noise = torch.randn(opt.batchSize, nz, device=device)
            fake = netG(noise)
            output_f = netD(fake)
            output_r = netD(real)
            errG = RelativisticLoss(output_f, output_r)
            errG.backward()
            optimizerG.step()
            ########################### EMA + latent statistics
            with torch.no_grad():
                for p_ema, p in zip(G_ema.parameters(), netG.parameters()):
                    p_ema.copy_(p.lerp(p_ema, ema_beta))
                for b_ema, b in zip(G_ema.buffers(), netG.buffers()):
                    b_ema.copy_(b)
                # Running mean of mapped latents; under no_grad so no autograd
                # graph is built for this bookkeeping pass.
                noise = torch.randn(opt.batchSize, nz, device=device)
                w = G_ema.LatentLayer(noise)
                w_avg = w_avg + (1 - w_avg_beta) * (w.mean(0).cpu() - w_avg)
            ########################### logging
            log_str = '[%d][%d/%d] Loss_D: %.4f Loss_G: %.4f R1: %.4f R2: %.4f' % (epoch, i, len(dataloader), errD.detach().item(), errG.detach().item(), gamma * r1_penalty.detach().item(), gamma * r2_penalty.detach().item())
            log_str += ' w_avg: ' + str(w_avg).removeprefix('tensor(').removesuffix(')')
            print(log_str)
            logging.info(log_str)
            ########################### periodic samples + checkpoint
            if i % 100 == 0:
                with torch.no_grad():
                    fake = G_ema(fixed_noise)
                    # Image of the "average" latent, skipping the mapping net.
                    mean = G_ema(w_avg.to(device).view(1, -1), EnableLatentMapping=False)
                vutils.save_image(real, '%s/real_samples.png' % opt.outf, normalize=True, nrow=8)
                vutils.save_image(torch.clamp(fake.detach(), -1, 1), '%s/fake_samples_epoch_%04d_%04d.png' % (opt.outf, epoch, i), normalize=True, nrow=8)
                vutils.save_image(torch.clamp(mean.detach(), -1, 1), '%s/mean_sample_epoch_%04d_%04d.png' % (opt.outf, epoch, i), normalize=True, nrow=1)
                torch.save({
                    'epoch': epoch,
                    'g_ema_state_dict': G_ema.state_dict(),
                    'g_state_dict': netG.state_dict(),
                    'd_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'w_avg': w_avg,
                    'fixed_noise': fixed_noise,
                    'loss_D': errD.detach().item(),
                    'loss_G': errG.detach().item(),
                    'r1_penalty': gamma * r1_penalty.detach().item(),
                    'r2_penalty': gamma * r2_penalty.detach().item(),
                }, '%s/epoch_%d.pth' % (opt.outf, epoch))


if __name__ == '__main__':
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment