Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Cleaned PyTorch GAN Code
import numpy
from pathlib import Path
import sys
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision import transforms
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# instead of failing as soon as models/tensors are moved to a missing GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 10  # images per training step
epochs = 100  # passes over the dataset
output_samples_per_epoch = 16  # images in each per-epoch sample grid
latent_space_size = 16  # channels of the (N, latent, 1, 1) noise fed to the generator
# The Generator / Discriminator layer stacks are hard-wired for 64x64 images
# (4 -> 8 -> 16 -> 32 -> 64), so this is not a free parameter.
image_width = 64
image_height = image_width
image_channels = 3  # number of colour channels per image
gen_channels = 64  # base feature-map count of the generator
disc_channels = 64  # base feature-map count of the discriminator
image_size = [image_channels, image_height, image_width]  # (C, H, W) of one image
class Generator(nn.Module):
    """DCGAN-style generator: turns latent noise into an image.

    Input is a tensor of shape (N, latent_space_size, 1, 1). Output is
    (N, image_channels, 64, 64) with values in [-1, 1] (final Tanh).
    """

    def __init__(self):
        super().__init__()
        layers = [
            # (N, latent, 1, 1) -> (N, gen_channels*8, 4, 4): kernel 4, stride 1, no padding.
            nn.ConvTranspose2d(latent_space_size, gen_channels * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(gen_channels * 8),
            nn.ReLU(True),
        ]
        # Each upsampling stage (kernel 4, stride 2, padding 1) doubles H and W
        # while halving the channel count: 4x4 -> 8x8 -> 16x16 -> 32x32.
        for in_mult, out_mult in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(gen_channels * in_mult, gen_channels * out_mult, 4, 2, 1, bias=False),
                nn.BatchNorm2d(gen_channels * out_mult),
                nn.ReLU(True),
            ]
        layers += [
            # Final stage: 32x32 -> 64x64, projected down to the image channels.
            nn.ConvTranspose2d(gen_channels, image_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Discriminator(nn.Module):
    """DCGAN-style discriminator: the mirror image of the generator.

    Input is (N, image_channels, 64, 64). Output is a 1-D tensor of shape (N,)
    with values in (0, 1) from the final Sigmoid (one score per image).
    """

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm: 64x64 -> 32x32.
        stages = [
            nn.Conv2d(image_channels, disc_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each further stage (kernel 4, stride 2, padding 1) halves H and W and
        # doubles the channel count: 32x32 -> 16x16 -> 8x8 -> 4x4.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            stages.extend((
                nn.Conv2d(disc_channels * in_mult, disc_channels * out_mult, 4, 2, 1, bias=False),
                nn.BatchNorm2d(disc_channels * out_mult),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        # Collapse the 4x4 feature map to a single score per image.
        stages.extend((
            nn.Conv2d(disc_channels * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # Flatten the (N, 1, 1, 1) scores to shape (N,).
        return self.net(x).view(-1)
def gan_weight_init(m):
    """Apply DCGAN weight initialisation to one module (use via ``model.apply``).

    Modules whose class name contains ``Conv`` (this includes
    ``ConvTranspose2d``) get weights drawn from N(0, 0.02). Modules whose class
    name contains ``BatchNorm`` get a scale drawn from N(1, 0.02) and a zero
    shift. All other modules are left untouched.

    Parameters that are absent or ``None`` are skipped instead of raising. This
    covers ``BatchNorm2d(affine=False)`` and container modules whose class name
    merely contains "Conv".
    """
    class_name = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in class_name:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in class_name:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        # The convolutions are built with bias=False, but BatchNorm has its own
        # learnable affine shift, so zeroing it here is meaningful.
        if bias is not None:
            nn.init.zeros_(bias)
def load_dataset(root_folder: Path):
    """Load an ``ImageFolder`` dataset of training images from *root_folder*.

    Each image is randomly cropped and resized to ``image_width`` x
    ``image_width`` and converted to a tensor. It is then normalised from
    [0, 1] to [-1, 1], so real samples share the value range of the
    generator's final ``Tanh``. Otherwise the discriminator could separate
    real from fake by value range alone.
    """
    half = [0.5] * image_channels
    return datasets.ImageFolder(
        root=str(root_folder),
        transform=transforms.Compose([
            # NOTE(review): RandomResizedCrop's default scale range lets crops
            # cover only a small fraction of the source image; confirm this
            # augmentation is intended (Resize + CenterCrop is the usual
            # DCGAN preprocessing).
            transforms.RandomResizedCrop(image_width),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output onto [-1, 1].
            transforms.Normalize(half, half),
        ]),
    )
def build_models(gen_fname: Path = None, disc_fname: Path = None):
    """Create the generator and discriminator on ``device``.

    A model is restored from a saved ``state_dict`` file when its filename
    argument is truthy; otherwise it is freshly initialised with
    ``gan_weight_init``.

    Returns:
        ``(gen, disc)``.
    """
    def _load_or_init(model, fname):
        # Restore *model* from *fname* if given, else apply the DCGAN init.
        if fname:
            # map_location restores tensors straight onto the current device,
            # so a checkpoint saved on a GPU can be loaded on another device.
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(str(fname), map_location=device))
        else:
            model.apply(gan_weight_init)
        return model

    gen = Generator().to(device)
    disc = Discriminator().to(device)
    return _load_or_init(gen, gen_fname), _load_or_init(disc, disc_fname)
def train(dset, gen, disc, opt_g, opt_d, loss_fn, epochs, batch_size):
    """Run the GAN training loop.

    Each full batch of real images gets one discriminator step (real pass plus
    fake pass), followed by one generator step. After every epoch this function:

    - writes samples from a fixed noise batch to ``fake_samples_epoch_<epoch>.png``;
    - saves both models' state dicts to ``generator.pkl`` / ``discriminator.pkl``
      in the working directory.

    Args:
        dset: dataset whose items are indexable with the image at index 0.
        gen, disc: generator and discriminator modules (on ``device``).
        opt_g, opt_d: optimizers for ``gen`` and ``disc`` respectively.
        loss_fn: loss called as ``loss_fn(prediction, target)``; ``main``
            passes ``nn.BCELoss``.
        epochs: number of passes over the dataset.
        batch_size: images per step. A trailing partial batch is dropped.
    """
    # Fixed noise so that per-epoch samples are comparable across epochs.
    debug_sample_noise = torch.randn(output_samples_per_epoch, latent_space_size, 1, 1, device=device)
    real_label = 1
    fake_label = 0
    # drop_last skips the trailing partial batch, so the preallocated label
    # tensor below always matches the batch size.
    data_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True, drop_last=True)
    label = torch.full((batch_size,), real_label, dtype=torch.float32, device=device)
    for epoch in range(epochs):
        for data in data_loader:
            # --- Update D: maximise log(D(x)) + log(1 - D(G(z))) ---
            disc.zero_grad()
            # Real pass. data[0] holds the images; data[1] the folder class labels.
            real = data[0].to(device)
            label.fill_(real_label)
            y = disc(real)
            disc_error_real = loss_fn(y, label)
            disc_error_real.backward()
            D_x = y.mean().item()  # mean D output on real images
            # Fake pass. detach() stops this backward pass from reaching the generator.
            noise = torch.randn(batch_size, latent_space_size, 1, 1, device=device)
            fake = gen(noise)
            label.fill_(fake_label)
            y = disc(fake.detach())
            disc_error_fake = loss_fn(y, label)
            disc_error_fake.backward()
            D_G_z1 = y.mean().item()  # mean D output on fakes, before D's update
            # Total D loss; gradients were already accumulated by the two backward() calls.
            disc_error = disc_error_real + disc_error_fake
            opt_d.step()

            # --- Update G: maximise log(D(G(z))) ---
            gen.zero_grad()
            # Non-saturating generator loss: with BCE, a target of 1 gives
            # -log(D(G(z))). Its gradient stays strong while D still rejects
            # fakes easily, unlike minimising log(1 - D(G(z))) directly.
            label.fill_(real_label)
            y = disc(fake)  # reuse the same fakes, this time attached to G's graph
            gen_error = loss_fn(y, label)
            # This also accumulates grads into D's params; disc.zero_grad()
            # clears them at the start of the next step.
            gen_error.backward()
            D_G_z2 = y.mean().item()  # mean D output on fakes, after D's update
            opt_g.step()

            print(
                f"{epoch*100.0/epochs}% | D_x: {D_x} D_G_z1: {D_G_z1} D_G_z2: {D_G_z2}"
                f" | loss_D: {disc_error.item()} loss_G: {gen_error.item()}"
            )

        # Sample from the fixed noise. no_grad avoids building an unused
        # autograd graph. gen is left in training mode, so its BatchNorm
        # layers use (and update their running stats from) this batch.
        with torch.no_grad():
            samples = gen(debug_sample_noise)
        torchvision.utils.save_image(samples, f"fake_samples_epoch_{epoch}.png", normalize=True)
        # Checkpoint after every epoch (overwrites the previous one).
        torch.save(gen.state_dict(), 'generator.pkl')
        torch.save(disc.state_dict(), 'discriminator.pkl')
def main(data_path: Path):
    """Train a freshly initialised GAN on the image folder at *data_path*.

    Both models use Adam (lr=1e-4, betas=(0.5, 0.999)). The loss is binary
    cross-entropy, which is exactly the log-loss for the discriminator's single
    real-vs-fake Sigmoid output.
    """
    adam_settings = {"lr": 0.0001, "betas": (0.5, 0.999)}
    dataset = load_dataset(data_path)
    generator, discriminator = build_models()
    train(
        dataset,
        generator,
        discriminator,
        torch.optim.Adam(generator.parameters(), **adam_settings),
        torch.optim.Adam(discriminator.parameters(), **adam_settings),
        nn.BCELoss(),
        epochs,
        batch_size,
    )
if __name__ == "__main__":
    # Expect the image-folder path as the first CLI argument; exit with a
    # usage message rather than an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_FOLDER")
    main(Path(sys.argv[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.