Created
September 7, 2020 18:57
-
-
Save JosephCatrambone/9cd54b218c234042357877eb4ddb1791 to your computer and use it in GitHub Desktop.
Cleaned PyTorch GAN Code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys
from pathlib import Path
from typing import Optional

import numpy
import torch
import torchvision
from torch import nn
from torchvision import datasets
from torchvision import transforms
# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU
# instead of crashing at the first .to(device) call on a GPU-less machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training hyperparameters.
batch_size = 10
epochs = 100
output_samples_per_epoch = 16  # Number of debug images sampled after each epoch.
latent_space_size = 16  # Channels of the (latent_space_size, 1, 1) generator noise.

# Image geometry. The generator/discriminator stacks are hard-wired for 64x64
# images (see the shape comments in those classes), so keep this at 64.
image_width = 64
image_height = image_width
image_channels = 3

# Base feature-map widths; each network uses 1x, 2x, 4x and 8x these values.
gen_channels = 64
disc_channels = 64

image_size = [image_channels, image_height, image_width]
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to a 64x64 image.

    Input:  noise of shape (N, latent_size, 1, 1).
    Output: images of shape (N, out_channels, 64, 64) with values in [-1, 1]
    (Tanh output).

    Args:
        latent_size: channels of the input noise. Defaults to the module-level
            ``latent_space_size``.
        channels: base feature-map width; the hidden layers use 8x, 4x, 2x and
            1x this value. Defaults to the module-level ``gen_channels``.
        out_channels: channels of the produced image. Defaults to the
            module-level ``image_channels``.

    The layer order inside ``self.net`` is the same as before, so existing
    ``state_dict`` checkpoints still load.
    """

    def __init__(self, latent_size=None, channels=None, out_channels=None):
        super().__init__()
        # Resolve defaults at call time so the module constants remain the
        # single source of truth when no override is given.
        latent_size = latent_space_size if latent_size is None else latent_size
        ch = gen_channels if channels is None else channels
        out_channels = image_channels if out_channels is None else out_channels
        self.net = nn.Sequential(
            # (latent_size, 1, 1) -> (ch*8, 4, 4): 4x4 kernel, stride 1, no padding.
            nn.ConvTranspose2d(latent_size, ch * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ch * 8),
            nn.ReLU(True),
            # Each 4x4 kernel / stride 2 / padding 1 transposed conv doubles H and W.
            # (ch*8, 4, 4) -> (ch*4, 8, 8)
            nn.ConvTranspose2d(ch * 8, ch * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(True),
            # -> (ch*2, 16, 16)
            nn.ConvTranspose2d(ch * 4, ch * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(True),
            # -> (ch, 32, 32)
            nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch),
            nn.ReLU(True),
            # -> (out_channels, 64, 64)
            nn.ConvTranspose2d(ch, out_channels, 4, 2, 1, bias=False),
            # Tanh bounds the output to [-1, 1]; real training images must be
            # normalised to the same range for the discriminator to compare
            # like with like.
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from noise ``x`` of shape (N, latent_size, 1, 1)."""
        return self.net(x)
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (1) or fake (0).

    Mirrors ``Generator``: strided convolutions halve H and W at each stage.

    Input:  images of shape (N, in_channels, 64, 64).
    Output: shape (N,), one probability in [0, 1] per image (Sigmoid output).

    Args:
        channels: base feature-map width; the hidden layers use 1x, 2x, 4x and
            8x this value. Defaults to the module-level ``disc_channels``.
        in_channels: channels of the input images. Defaults to the
            module-level ``image_channels``.

    The layer order inside ``self.net`` is the same as before, so existing
    ``state_dict`` checkpoints still load.
    """

    def __init__(self, channels=None, in_channels=None):
        super().__init__()
        ch = disc_channels if channels is None else channels
        in_channels = image_channels if in_channels is None else in_channels
        self.net = nn.Sequential(
            # (in_channels, 64, 64) -> (ch, 32, 32); no BatchNorm on the first layer.
            nn.Conv2d(in_channels, ch, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ch*2, 16, 16)
            nn.Conv2d(ch, ch * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ch*4, 8, 8)
            nn.Conv2d(ch * 2, ch * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ch*8, 4, 4)
            nn.Conv2d(ch * 4, ch * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # A 4x4 kernel with stride 1 and no padding collapses 4x4 to 1x1.
            nn.Conv2d(ch * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one real/fake probability per image, shape (N,)."""
        # (N, 1, 1, 1) -> (N,). Inputs larger than 64x64 would produce several
        # scores per image, which are all flattened into the same vector.
        return self.net(x).view(-1)
def gan_weight_init(m):
    """Initialise module ``m`` in place with the DCGAN scheme (use via ``Module.apply``).

    * Conv / ConvTranspose layers: weight ~ N(0, 0.02). A conv bias, if any,
      is left untouched (the networks in this file build convs with bias=False).
    * BatchNorm layers: scale (weight) ~ N(1, 0.02), shift (bias) = 0.
    * Any other module is left unchanged.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in name:
        # BatchNorm built with affine=False has no learnable scale/shift
        # (weight and bias are None), so there is nothing to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            # This is BatchNorm's own learned shift (beta). It exists even
            # though the conv layers use bias=False, so zeroing it is needed.
            nn.init.zeros_(m.bias)
def load_dataset(root_folder: Path):
    """Load the training images under ``root_folder`` as a torchvision ImageFolder.

    ImageFolder expects ``root/<class_name>/<image>``; its class labels are not
    used by ``train``. Each image is randomly cropped and resized to
    ``image_width`` x ``image_width`` and converted to a float tensor in
    [-1, 1], the same range as the generator's Tanh output.
    """
    return datasets.ImageFolder(
        root=str(root_folder),
        transform=transforms.Compose([
            # NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) can crop
            # small patches; consider Resize + CenterCrop if the generator should
            # learn whole images.
            transforms.RandomResizedCrop(image_width),
            transforms.ToTensor(),  # -> float tensor in [0, 1]
            # Map [0, 1] -> [-1, 1]. Without this, real images are never negative
            # while generated ones are, which gives the discriminator a trivial
            # cue. ImageFolder's default loader converts images to RGB, hence
            # three channels.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )
def _init_or_load(model, fname):
    """Load ``fname``'s state dict into ``model`` if given, else apply ``gan_weight_init``."""
    if fname:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host
        # (and vice versa). torch.load unpickles the file: only load trusted
        # checkpoints.
        model.load_state_dict(torch.load(str(fname), map_location=device))
    else:
        model.apply(gan_weight_init)
    return model


def build_models(gen_fname: Optional[Path] = None, disc_fname: Optional[Path] = None):
    """Create the generator and discriminator on ``device``.

    Each network is restored from its checkpoint file when a filename is
    given; otherwise it is freshly initialised with ``gan_weight_init``.

    Returns:
        ``(gen, disc)`` tuple.
    """
    # Construct both networks before initialising either one, in the same order
    # as before, so seeded runs consume the RNG identically.
    gen = Generator().to(device)
    disc = Discriminator().to(device)
    _init_or_load(gen, gen_fname)
    _init_or_load(disc, disc_fname)
    return gen, disc
def train(dset, gen, disc, opt_g, opt_d, loss_fn, epochs, batch_size):
    """Train ``gen`` and ``disc`` adversarially on ``dset``.

    Each step follows the DCGAN recipe:

    1. Discriminator: maximise log D(x) + log(1 - D(G(z))), implemented as
       ``loss_fn`` on a real batch labelled 1 plus a fake batch labelled 0.
    2. Generator: maximise log D(G(z)) by scoring its fakes against label 1.

    After every epoch, ``output_samples_per_epoch`` images generated from fixed
    noise are saved to ``fake_samples_epoch_<epoch>.png``. Both networks are
    checkpointed to ``generator.pkl`` / ``discriminator.pkl``, overwritten
    each epoch.

    Args:
        dset: dataset whose items are indexable; item[0] (the image batch after
            collation) is used and the rest (e.g. labels) is ignored.
        gen, disc: the two networks, already on ``device``.
        opt_g, opt_d: optimisers for ``gen`` and ``disc``.
        loss_fn: callable ``(predictions, targets) -> loss``; presumably
            ``nn.BCELoss`` as in ``main`` — confirm for other callers.
        epochs: number of passes over ``dset``.
        batch_size: images per step; a trailing partial batch is dropped.
    """
    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    debug_sample_noise = torch.randn(
        output_samples_per_epoch, latent_space_size, 1, 1, device=device
    )
    real_label = 1
    fake_label = 0
    # drop_last keeps every batch exactly batch_size long, so a single label
    # buffer fits every step (previously the short tail batch was skipped by hand).
    data_loader = torch.utils.data.DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=True
    )
    label = torch.full((batch_size,), real_label, dtype=torch.float32, device=device)
    for epoch in range(epochs):
        for i, data in enumerate(data_loader):
            real = data[0].to(device)  # data[0] are the examples; data[1] the labels.

            # --- Discriminator update: maximise log D(x) + log(1 - D(G(z))) ---
            disc.zero_grad()
            label.fill_(real_label)
            y = disc(real)
            disc_error_real = loss_fn(y, label)
            disc_error_real.backward()
            D_x = y.mean().item()  # Mean D output on real images.

            noise = torch.randn(batch_size, latent_space_size, 1, 1, device=device)
            fake = gen(noise)
            label.fill_(fake_label)
            # detach(): this pass must not backpropagate into the generator.
            y = disc(fake.detach())
            disc_error_fake = loss_fn(y, label)
            disc_error_fake.backward()
            D_G_z1 = y.mean().item()  # Mean D output on fakes, before D's step.
            # Total discriminator loss; both terms' gradients are already
            # accumulated above, so this is for reporting only.
            disc_error = (disc_error_real + disc_error_fake).item()
            opt_d.step()

            # --- Generator update: maximise log D(G(z)) ---
            gen.zero_grad()
            # Fakes are scored against the *real* label so BCE becomes
            # -log D(G(z)). This is the non-saturating generator loss:
            # minimising log(1 - D(G(z))) directly gives vanishing gradients
            # early in training, when D rejects fakes with high confidence.
            label.fill_(real_label)
            y = disc(fake)  # Same fakes, now with gradients flowing into gen.
            gen_error = loss_fn(y, label)
            gen_error.backward()
            D_G_z2 = y.mean().item()  # Mean D output on fakes, after D's step.
            opt_g.step()

            print(
                f"{epoch * 100.0 / epochs}% | batch {i} | Loss_D: {disc_error} "
                f"Loss_G: {gen_error.item()} D_x: {D_x} D_G_z1: {D_G_z1} D_G_z2: {D_G_z2}"
            )

        # Sample from the fixed noise. no_grad skips building an autograd graph
        # that is never used; the generator stays in train mode, as before.
        with torch.no_grad():
            fake = gen(debug_sample_noise)
        torchvision.utils.save_image(fake, f"fake_samples_epoch_{epoch}.png", normalize=True)
        # Checkpoint (overwritten every epoch).
        torch.save(gen.state_dict(), 'generator.pkl')
        torch.save(disc.state_dict(), 'discriminator.pkl')
def main(data_path: Path, gen_fname: Optional[Path] = None, disc_fname: Optional[Path] = None):
    """Train a DCGAN on the ImageFolder dataset rooted at ``data_path``.

    Args:
        data_path: root folder of the training images.
        gen_fname, disc_fname: optional checkpoints to resume from (e.g. the
            ``generator.pkl`` / ``discriminator.pkl`` written by ``train``);
            a network without a checkpoint is freshly initialised.
    """
    dset = load_dataset(data_path)
    gen, disc = build_models(gen_fname, disc_fname)
    # beta1 = 0.5 follows the DCGAN paper's Adam settings (the paper uses
    # lr = 2e-4; this script uses 1e-4).
    opt_g = torch.optim.Adam(gen.parameters(), lr=0.0001, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=0.0001, betas=(0.5, 0.999))
    # Binary cross-entropy *is* the log loss for a single real/fake output. It
    # expects probabilities, which the discriminator's final Sigmoid produces.
    loss_fn = nn.BCELoss()
    train(dset, gen, disc, opt_g, opt_d, loss_fn, epochs, batch_size)


if __name__ == "__main__":
    # Usage: DATA_DIR [GENERATOR_CKPT [DISCRIMINATOR_CKPT]]
    if not 2 <= len(sys.argv) <= 4:
        sys.exit(f"usage: {sys.argv[0]} DATA_DIR [GENERATOR_CKPT [DISCRIMINATOR_CKPT]]")
    main(*(Path(arg) for arg in sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment