@IFeelBloated
Created March 30, 2022 14:31
1200-layer GAN
# Loss.py
import torch
import torch.nn as nn

# Zero-centered (R1/R2-style) gradient penalty: 0.5 * mean squared gradient norm of the critic w.r.t. its input samples.
def ZeroCenteredGradientPenalty(Samples, Critics):
    Gradient, = torch.autograd.grad(outputs=Critics.sum(), inputs=Samples, create_graph=True, only_inputs=True)
    return 0.5 * Gradient.square().sum([1, 2, 3]).mean()

# Relativistic pairing loss: classify the difference of critic scores as "real".
def RelativisticLoss(PositiveCritics, NegativeCritics):
    return nn.functional.binary_cross_entropy_with_logits(PositiveCritics - NegativeCritics, torch.ones_like(PositiveCritics))
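
# A minimal, hedged self-test (not part of the original gist): exercise the two helpers above with
# a throwaway critic, combining them the same way the training script further down does
# (relativistic loss plus gamma-weighted R1/R2 penalties).
if __name__ == '__main__':
    DummyCritic = nn.Conv2d(3, 1, kernel_size=4)            # stand-in critic: one logit per 4x4 image
    Real = torch.randn(2, 3, 4, 4, requires_grad=True)      # requires_grad so the penalty can differentiate w.r.t. the samples
    Fake = torch.randn(2, 3, 4, 4, requires_grad=True)
    
    RealCritics = DummyCritic(Real).view(2)
    FakeCritics = DummyCritic(Fake).view(2)
    
    Penalties = ZeroCenteredGradientPenalty(Real, RealCritics) + ZeroCenteredGradientPenalty(Fake, FakeCritics)
    print(RelativisticLoss(RealCritics, FakeCritics) + 0.05 * Penalties)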
# Network.py
import math
import torch
import torch.nn as nn

CompressionFactor = 1
SiLUGain = math.sqrt(2)

# He (MSR) initialization; with SpatialScaleFactor > 1 the weight is built as a bilinearly
# upsampled smaller kernel, pixel-unshuffled back into the convolution's weight layout.
def MSRInitializer(Layer, ActivationGain=1, SpatialScaleFactor=1):
    if SpatialScaleFactor == 1:
        FanIn = Layer.weight.data.size(1) * Layer.weight.data[0][0].numel()
        Layer.weight.data.normal_(0, ActivationGain / math.sqrt(FanIn))
    else:
        SubpixelKernel = torch.empty(Layer.weight.shape[0] // (SpatialScaleFactor * SpatialScaleFactor), *Layer.weight.shape[1:])
        FanIn = SubpixelKernel.size(1) * SubpixelKernel[0][0].numel()
        SubpixelKernel.normal_(0, ActivationGain / math.sqrt(FanIn))
        
        SubpixelKernel = nn.functional.interpolate(SubpixelKernel, scale_factor=SpatialScaleFactor, mode='bilinear', align_corners=False)
        SubpixelKernel = nn.functional.pixel_unshuffle(SubpixelKernel.transpose(0, 1), SpatialScaleFactor)
        
        Layer.weight.data.copy_(SubpixelKernel.transpose(0, 1))
    
    if Layer.bias is not None:
        Layer.bias.data.zero_()
    
    return Layer
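
# A hedged usage sketch (mirroring calls made further down in this file, not a new API):
#   MSRInitializer(nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), ActivationGain=SiLUGain)
#       -> fan-in-scaled He initialization for a convolution feeding a SiLU nonlinearity
#   MSRInitializer(nn.Conv2d(64, 64 * 4, kernel_size=3, padding=1, bias=False), ActivationGain=SiLUGain, SpatialScaleFactor=2)
#       -> pixel-shuffle (subpixel) upsampling convolution, initialized from a bilinearly upsampled smaller kernel
#   MSRInitializer(nn.Conv2d(64, 3, kernel_size=1, bias=False), ActivationGain=0)
#       -> residual branch output, initialized to zero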
# SiLU preceded by a learned per-channel bias (supplying the bias the bias-free layers omit).
class BiasedActivation(nn.Module):
    def __init__(self, InputUnits, ConvolutionalLayer=True):
        super(BiasedActivation, self).__init__()
        
        self.Bias = nn.Parameter(torch.empty(InputUnits))
        self.Bias.data.zero_()
        
        self.ConvolutionalLayer = ConvolutionalLayer
        
    def forward(self, x):
        y = x + self.Bias.view(1, -1, 1, 1) if self.ConvolutionalLayer else x + self.Bias.view(1, -1)
        return nn.functional.silu(y)
# Residual generator block; consumes both the pre-activation stream x and its activated counterpart.
class GeneratorBlock(nn.Module):
    def __init__(self, InputChannels, ReceptiveField=3):
        super(GeneratorBlock, self).__init__()
        
        CompressedChannels = InputChannels // CompressionFactor
        
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, InputChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=0)
        
        self.NonLinearity1 = BiasedActivation(CompressedChannels)
        self.NonLinearity2 = BiasedActivation(InputChannels)
        
    def forward(self, x, ActivationMaps):
        y = self.LinearLayer1(ActivationMaps)
        y = self.NonLinearity1(y)
        
        y = self.LinearLayer2(y)
        y = x + y
        
        return y, self.NonLinearity2(y)
class DiscriminatorBlock(nn.Module):
    def __init__(self, InputChannels, ReceptiveField=3):
        super(DiscriminatorBlock, self).__init__()
        
        CompressedChannels = InputChannels // CompressionFactor
        
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, InputChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=0)
        
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(CompressedChannels)
        
    def forward(self, x):
        y = self.LinearLayer1(self.NonLinearity1(x))
        y = self.LinearLayer2(self.NonLinearity2(y))
        
        return x + y
# Residual generator block that doubles the resolution via pixel shuffle; the identity branch is
# bilinearly upsampled, through a 1x1 shortcut when the channel count changes.
class GeneratorUpsampleBlock(nn.Module):
    def __init__(self, InputChannels, OutputChannels, ReceptiveField=3):
        super(GeneratorUpsampleBlock, self).__init__()
        
        CompressedChannels = InputChannels // CompressionFactor
        
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels * 4, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=SiLUGain, SpatialScaleFactor=2)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels, OutputChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=0)
        
        self.NonLinearity1 = BiasedActivation(CompressedChannels)
        self.NonLinearity2 = BiasedActivation(OutputChannels)
        
        if InputChannels != OutputChannels:
            self.ShortcutLayer = MSRInitializer(nn.Conv2d(InputChannels, OutputChannels, kernel_size=1, stride=1, padding=0, bias=False))
        
    def forward(self, x, ActivationMaps):
        if hasattr(self, 'ShortcutLayer'):
            x = self.ShortcutLayer(x)
        
        y = self.LinearLayer1(ActivationMaps)
        y = self.NonLinearity1(nn.functional.pixel_shuffle(y, 2))
        
        y = self.LinearLayer2(y)
        y = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y
        
        return y, self.NonLinearity2(y)
# Residual discriminator block that halves the resolution via pixel unshuffle; the identity branch
# is bilinearly downsampled.
class DiscriminatorDownsampleBlock(nn.Module):
    def __init__(self, InputChannels, OutputChannels, ReceptiveField=3):
        super(DiscriminatorDownsampleBlock, self).__init__()
        
        CompressedChannels = OutputChannels // CompressionFactor
        
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, CompressedChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Conv2d(CompressedChannels * 4, OutputChannels, kernel_size=ReceptiveField, stride=1, padding=(ReceptiveField - 1) // 2, bias=False), ActivationGain=0)
        
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(CompressedChannels)
        
        if InputChannels != OutputChannels:
            self.ShortcutLayer = MSRInitializer(nn.Conv2d(InputChannels, OutputChannels, kernel_size=1, stride=1, padding=0, bias=False))
        
    def forward(self, x):
        y = self.LinearLayer1(self.NonLinearity1(x))
        
        y = nn.functional.pixel_unshuffle(self.NonLinearity2(y), 2)
        y = self.LinearLayer2(y)
        
        x = nn.functional.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        if hasattr(self, 'ShortcutLayer'):
            x = self.ShortcutLayer(x)
        
        return x + y
# Modulates a learned 4x4 basis (the generator's constant input) by a linear projection of w.
class GeneratorOpeningLayer(nn.Module):
    def __init__(self, LatentDimension, OutputChannels):
        super(GeneratorOpeningLayer, self).__init__()
        
        self.Basis = nn.Parameter(torch.empty((OutputChannels, 4, 4)))
        self.LinearLayer = MSRInitializer(nn.Linear(LatentDimension, OutputChannels, bias=False))
        self.NonLinearity = BiasedActivation(OutputChannels)
        
        self.Basis.data.normal_(0, SiLUGain)
        
    def forward(self, w):
        x = self.LinearLayer(w).view(w.shape[0], -1, 1, 1)
        y = self.Basis.view(1, -1, 4, 4) * x
        
        return y, self.NonLinearity(y)
# Collapses the final 4x4 feature maps into a latent-sized vector (depthwise 4x4 conv + linear layer).
class DiscriminatorClosingLayer(nn.Module):
    def __init__(self, InputChannels, LatentDimension):
        super(DiscriminatorClosingLayer, self).__init__()
        
        self.LinearLayer1 = MSRInitializer(nn.Conv2d(InputChannels, InputChannels, kernel_size=4, stride=1, padding=0, groups=InputChannels, bias=False))
        self.LinearLayer2 = MSRInitializer(nn.Linear(InputChannels, LatentDimension, bias=False), ActivationGain=SiLUGain)
        
        self.NonLinearity1 = BiasedActivation(InputChannels)
        self.NonLinearity2 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        
    def forward(self, x):
        y = self.LinearLayer1(self.NonLinearity1(x)).view(x.shape[0], -1)
        return self.NonLinearity2(self.LinearLayer2(y))
class FullyConnectedBlock(nn.Module):
    def __init__(self, LatentDimension):
        super(FullyConnectedBlock, self).__init__()
        
        self.LinearLayer1 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.LinearLayer2 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=0)
        
        self.NonLinearity1 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        self.NonLinearity2 = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        
    def forward(self, x):
        y = self.LinearLayer1(self.NonLinearity1(x))
        y = self.LinearLayer2(self.NonLinearity2(y))
        
        return x + y
# 8-layer z -> w mapping network: an opening linear layer, three residual fully connected blocks,
# and a closing linear layer.
class MappingBlock(nn.Module):
    def __init__(self, LatentDimension):
        super(MappingBlock, self).__init__()
        
        self.LinearLayer1 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        
        self.Layer2To3 = FullyConnectedBlock(LatentDimension)
        self.Layer4To5 = FullyConnectedBlock(LatentDimension)
        self.Layer6To7 = FullyConnectedBlock(LatentDimension)
        
        self.NonLinearity = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        self.LinearLayer8 = MSRInitializer(nn.Linear(LatentDimension, LatentDimension, bias=False), ActivationGain=SiLUGain)
        self.ClosingNonLinearity = BiasedActivation(LatentDimension, ConvolutionalLayer=False)
        
    def forward(self, z):
        w = self.LinearLayer1(z)
        
        w = self.Layer2To3(w)
        w = self.Layer4To5(w)
        w = self.Layer6To7(w)
        
        w = self.LinearLayer8(self.NonLinearity(w))
        return self.ClosingNonLinearity(w)
# 1x1 projection to RGB; residual ToRGB branches start at zero.
def ToRGB(InputChannels, ResidualComponent=False):
    return MSRInitializer(nn.Conv2d(InputChannels, 3, kernel_size=1, stride=1, padding=0, bias=False), ActivationGain=0 if ResidualComponent else 1)
# Generator: mapping network, 4x4 opening layer, and three upsampling stages (4x4 -> 8x8 -> 16x16 -> 32x32),
# each padded with 100 extra residual blocks and a skip-style RGB output accumulated across resolutions.
class Generator(nn.Module):
    def __init__(self, LatentDimension):
        super(Generator, self).__init__()
        
        self.LatentLayer = MappingBlock(LatentDimension)
        
        self.Layer4x4 = GeneratorOpeningLayer(LatentDimension, 64)
        self.ToRGB4x4 = ToRGB(64)
        
        Layer8x8 = [GeneratorUpsampleBlock(64, 32)]
        self.ToRGB8x8 = ToRGB(32, ResidualComponent=True)
        
        Layer16x16 = [GeneratorUpsampleBlock(32, 16)]
        self.ToRGB16x16 = ToRGB(16, ResidualComponent=True)
        
        Layer32x32 = [GeneratorUpsampleBlock(16, 16)]
        self.ToRGB32x32 = ToRGB(16, ResidualComponent=True)
        
        for _ in range(100):
            Layer8x8 += [GeneratorBlock(64)]
        for _ in range(100):
            Layer16x16 += [GeneratorBlock(32)]
        for _ in range(100):
            Layer32x32 += [GeneratorBlock(16)]
        
        # Each stage's list is built upsample-first, then reversed so the 100 plain blocks run
        # before the upsample block.
        Layer8x8.reverse()
        Layer16x16.reverse()
        Layer32x32.reverse()
        
        self.Layer8x8 = nn.ModuleList(Layer8x8)
        self.Layer16x16 = nn.ModuleList(Layer16x16)
        self.Layer32x32 = nn.ModuleList(Layer32x32)
        
    def forward(self, z, EnableLatentMapping=True):
        w = self.LatentLayer(z) if EnableLatentMapping else z
        
        y, ActivationMaps = self.Layer4x4(w)
        Output4x4 = self.ToRGB4x4(ActivationMaps)
        
        for Block in self.Layer8x8:
            y, ActivationMaps = Block(y, ActivationMaps)
        Output8x8 = nn.functional.interpolate(Output4x4, scale_factor=2, mode='bilinear', align_corners=False) + self.ToRGB8x8(ActivationMaps)
        
        for Block in self.Layer16x16:
            y, ActivationMaps = Block(y, ActivationMaps)
        Output16x16 = nn.functional.interpolate(Output8x8, scale_factor=2, mode='bilinear', align_corners=False) + self.ToRGB16x16(ActivationMaps)
        
        for Block in self.Layer32x32:
            y, ActivationMaps = Block(y, ActivationMaps)
        Output32x32 = nn.functional.interpolate(Output16x16, scale_factor=2, mode='bilinear', align_corners=False) + self.ToRGB32x32(ActivationMaps)
        
        return Output32x32
# Discriminator: three downsampling stages (32x32 -> 16x16 -> 8x8 -> 4x4) with 100 extra residual
# blocks per stage, then a closing layer feeding a single critic output. Unlike the generator, the
# lists are not reversed: each stage downsamples first, then runs its plain blocks.
class Discriminator(nn.Module):
    def __init__(self, LatentDimension):
        super(Discriminator, self).__init__()
        
        self.FromRGB = MSRInitializer(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), ActivationGain=SiLUGain)
        
        Layer32x32 = [DiscriminatorDownsampleBlock(16, 16)]
        Layer16x16 = [DiscriminatorDownsampleBlock(16, 32)]
        Layer8x8 = [DiscriminatorDownsampleBlock(32, 64)]
        
        self.Layer4x4 = DiscriminatorClosingLayer(64, LatentDimension)
        self.CriticLayer = MSRInitializer(nn.Linear(LatentDimension, 1))
        
        for _ in range(100):
            Layer8x8 += [DiscriminatorBlock(64)]
        for _ in range(100):
            Layer16x16 += [DiscriminatorBlock(32)]
        for _ in range(100):
            Layer32x32 += [DiscriminatorBlock(16)]
        
        self.Layer8x8 = nn.ModuleList(Layer8x8)
        self.Layer16x16 = nn.ModuleList(Layer16x16)
        self.Layer32x32 = nn.ModuleList(Layer32x32)
        
    def forward(self, x):
        x = self.FromRGB(x)
        
        for Block in self.Layer32x32:
            x = Block(x)
        for Block in self.Layer16x16:
            x = Block(x)
        for Block in self.Layer8x8:
            x = Block(x)
        
        x = self.Layer4x4(x)
        return self.CriticLayer(x).squeeze()
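
# A minimal, hedged smoke test (not part of the original gist): build both networks at the
# 64-dimensional latent size used by the training script below and check the end-to-end shapes.
# With 100 extra residual blocks per stage in each network, even a single forward pass is slow.
if __name__ == '__main__':
    G, D = Generator(64), Discriminator(64)
    z = torch.randn(2, 64)
    
    Fake = G(z)            # expected shape: (2, 3, 32, 32), built up through the 4x4 -> 8x8 -> 16x16 -> 32x32 stages
    Critics = D(Fake)      # expected shape: (2,), one critic score per sample
    
    print(Fake.shape, Critics.shape)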
# Training script (the gist does not show its file name).
import argparse
import os
import random
import numpy
import torch
import logging
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import copy

from Network import Generator, Discriminator
from Loss import ZeroCenteredGradientPenalty, RelativisticLoss

ema_beta = 0.998  # ffhq: 0.99778438712388889017237329703832, cifar: 0.99991128109664301904760707704894
w_avg_beta = 0.998
gamma = 0.05  # weight on the R1/R2 gradient penalties
def run():
    torch.set_printoptions(threshold=1)
    logging.basicConfig(level=logging.INFO, filename='train_log.txt')
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to the network')
    parser.add_argument('--lr', type=float, default=4e-5, help='learning rate')
    parser.add_argument('--nz', type=int, default=64, help='size of the latent z vector')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
    parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
    
    model = None  # torch.load('epoch_613.pth', map_location='cpu') to resume from a checkpoint
    
    opt = parser.parse_args()
    print(opt)
    
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass
    
    manualSeed = 42
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    numpy.random.seed(manualSeed)
    
    cudnn.benchmark = True
    
    dataset = dset.ImageFolder('./Data', transform=transforms.Compose([
        transforms.Resize(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.RandomHorizontalFlip(p=0.5)
    ]))
    # Each fetched batch is twice the training batch size: the first half feeds the discriminator
    # update, the second half feeds the generator update.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize * 2,
                                             shuffle=True, num_workers=int(opt.workers), drop_last=True)
    
    device = torch.device("cuda:0")
    nz = int(opt.nz)
    
    fixed_noise = torch.randn(64, nz, device=device)
    w_avg = torch.zeros(nz)
    
    netG = Generator(nz).to(device)
    netD = Discriminator(nz).to(device)
    G_ema = copy.deepcopy(netG).eval()
    
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0, 0.99))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0, 0.99))
    
    if model is not None:
        w_avg = model['w_avg']
        fixed_noise = model['fixed_noise'].to(device)
        
        netG.load_state_dict(model['g_state_dict'], strict=True)
        netD.load_state_dict(model['d_state_dict'], strict=True)
        G_ema.load_state_dict(model['g_ema_state_dict'], strict=True)
        
        optimizerD.load_state_dict(model['optimizerD_state_dict'])
        optimizerG.load_state_dict(model['optimizerG_state_dict'])
    
    print('G params: ' + str(sum(p.numel() for p in netG.parameters() if p.requires_grad)))
    print('D params: ' + str(sum(p.numel() for p in netD.parameters() if p.requires_grad)))
    for epoch in range(0 if model is None else model['epoch'] + 1, 1000000):
        for i, data in enumerate(dataloader, 0):
            ########################### discriminator update
            # Note: assigning requires_grad on an nn.Module only sets a plain attribute, not the
            # parameters' flags; since each optimizer steps only its own network, this is effectively a no-op.
            netD.requires_grad = True
            netG.requires_grad = False
            netD.zero_grad()
            
            real = data[0][0:opt.batchSize, :, :, :].to(device)
            real.requires_grad = True  # needed so the R1 penalty can differentiate w.r.t. the real samples
            
            noise = torch.randn(opt.batchSize, nz, device=device)
            fake = netG(noise)
            
            output_r = netD(real)
            output_f = netD(fake)
            
            r1_penalty = ZeroCenteredGradientPenalty(real, output_r)
            r2_penalty = ZeroCenteredGradientPenalty(fake, output_f)
            
            errD = RelativisticLoss(output_r, output_f) + gamma * (r1_penalty + r2_penalty)
            errD.backward()
            optimizerD.step()
            
            ########################### generator update
            netD.requires_grad = False
            netG.requires_grad = True
            netG.zero_grad()
            
            real = data[0][opt.batchSize:2 * opt.batchSize, :, :, :].to(device)
            noise = torch.randn(opt.batchSize, nz, device=device)
            fake = netG(noise)
            
            output_f = netD(fake)
            output_r = netD(real)
            
            errG = RelativisticLoss(output_f, output_r)
            errG.backward()
            optimizerG.step()
            
            ########################### EMA copy of the generator
            with torch.no_grad():
                for p_ema, p in zip(G_ema.parameters(), netG.parameters()):
                    p_ema.copy_(p.lerp(p_ema, ema_beta))
                for b_ema, b in zip(G_ema.buffers(), netG.buffers()):
                    b_ema.copy_(b)
            
            ########################### running average of w (used for the "mean" sample snapshot)
            noise = torch.randn(opt.batchSize, nz, device=device)
            w = G_ema.LatentLayer(noise)
            w_avg = w_avg + (1 - w_avg_beta) * (w.mean(0).detach().cpu() - w_avg)
            
            ########################### logging
            log_str = '[%d][%d/%d] Loss_D: %.4f Loss_G: %.4f R1: %.4f R2: %.4f' % (epoch, i, len(dataloader), errD.detach().item(), errG.detach().item(), gamma * r1_penalty.detach().item(), gamma * r2_penalty.detach().item())
            log_str += ' w_avg: ' + str(w_avg).removeprefix('tensor(').removesuffix(')')
            print(log_str)
            logging.info(log_str)
            
            ########################### snapshots and checkpoints
            if i % 100 == 0:
                fake = G_ema(fixed_noise)
                mean = G_ema(w_avg.to(device).view(1, -1), EnableLatentMapping=False)
                
                vutils.save_image(real, '%s/real_samples.png' % opt.outf, normalize=True, nrow=8)
                vutils.save_image(torch.clamp(fake.detach(), -1, 1), '%s/fake_samples_epoch_%04d_%04d.png' % (opt.outf, epoch, i), normalize=True, nrow=8)
                vutils.save_image(torch.clamp(mean.detach(), -1, 1), '%s/mean_sample_epoch_%04d_%04d.png' % (opt.outf, epoch, i), normalize=True, nrow=1)
                
                torch.save({
                    'epoch': epoch,
                    'g_ema_state_dict': G_ema.state_dict(),
                    'g_state_dict': netG.state_dict(),
                    'd_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'w_avg': w_avg,
                    'fixed_noise': fixed_noise,
                    'loss_D': errD.detach().item(),
                    'loss_G': errG.detach().item(),
                    'r1_penalty': gamma * r1_penalty.detach().item(),
                    'r2_penalty': gamma * r2_penalty.detach().item(),
                }, '%s/epoch_%d.pth' % (opt.outf, epoch))
if __name__ == '__main__':
    run()
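
# A hedged invocation example (the gist does not name this file; 'Train.py' and the output folder
# are placeholders). The script expects an ImageFolder-compatible dataset under ./Data and a CUDA device:
#
#   python Train.py --batchSize 64 --imageSize 32 --lr 4e-5 --nz 64 --workers 8 --outf ./Checkpoints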