from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--epochs', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--gpus', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--output-folder', default='out', help='output directory')
options = parser.parse_args()
print(options)
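# Example invocation (the script filename below is illustrative, not part of this gist):
#   python dcgan_mnist.py --dataroot ./data --cuda --epochs 25 --output-folder out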
random.seed(123)
torch.manual_seed(123)
cudnn.benchmark = True
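# cudnn.benchmark lets cuDNN auto-tune convolution algorithms; input sizes are fixed here, so this is safe.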
if not os.path.exists(options.output_folder):
    os.makedirs(options.output_folder)
device = torch.device("cuda:0" if options.cuda else "cpu")
noise_size = 100
batch_size = 64
dataset = dset.MNIST(root=options.dataroot, download=True,
                     transform=transforms.Compose([
                         transforms.Resize(28),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5], [0.5]),
                     ]))
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=options.workers,
    drop_last=True)
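# Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], matching the Tanh output
# of the generator. Each batch yielded here is an (images, labels) pair with images of shape
# [64, 1, 28, 28]; drop_last=True discards a short final batch so it always matches the
# fixed-size label tensors created in the training loop.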
# custom weights initialization called on generator and discriminator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
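# The normal(0, 0.02) initialization follows the DCGAN paper's recommendation; BatchNorm
# weights are drawn around 1 and biases zeroed so those layers start near identity scaling.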
class Generator(nn.Module):
    def __init__(self, gpus):
        super(Generator, self).__init__()
        self.gpus = gpus
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(noise_size, 256, kernel_size=4, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # state size. 256 x 4 x 4
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # state size. 128 x 7 x 7
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # state size. 64 x 14 x 14
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # state size. 1 x 28 x 28
        )

    def forward(self, input):
        if input.is_cuda and self.gpus > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.gpus))
        else:
            output = self.main(input)
        return output
generator = Generator(options.gpus).to(device)
generator.apply(weights_init)
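# Spatial sizes through the generator follow the transposed-convolution formula
# out = (in - 1) * stride - 2 * padding + kernel:
#   1 -> (1-1)*1 - 0 + 4 = 4,   4 -> (4-1)*2 - 2 + 3 = 7,
#   7 -> (7-1)*2 - 2 + 4 = 14, 14 -> (14-1)*2 - 2 + 4 = 28.
# Optional sanity check (left commented out so the training script is unchanged):
# with torch.no_grad():
#     assert generator(torch.randn(2, noise_size, 1, 1, device=device)).shape == (2, 1, 28, 28)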
class Discriminator(nn.Module):
    def __init__(self, gpus):
        super(Discriminator, self).__init__()
        self.gpus = gpus
        self.main = nn.Sequential(
            # input is 1 x 28 x 28
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 64 x 14 x 14
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 128 x 7 x 7
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 256 x 3 x 3
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.gpus > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.gpus))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
discriminator = Discriminator(options.gpus).to(device)
discriminator.apply(weights_init)
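# Discriminator sizes follow out = floor((in + 2 * padding - kernel) / stride) + 1:
#   28 -> 14 -> 7 -> floor((7 + 2 - 4) / 2) + 1 = 3 -> (3 + 0 - 3) / 1 + 1 = 1,
# so the final Sigmoid emits one probability per image, flattened to a 1-D batch in forward().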
criterion = nn.BCELoss()
fixed_noise = torch.randn(batch_size, noise_size, 1, 1, device=device)
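# fixed_noise is reused every time samples are saved, so image quality can be compared across epochs.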
# setup optimizer
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=options.lr, betas=(0.5, 0.999))
generator_optimizer = optim.Adam(generator.parameters(), lr=options.lr, betas=(0.5, 0.999))
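# With BCELoss, the targets below recover the usual GAN objectives (up to the label smoothing
# applied to real images): BCE(D(x), 1) = -log D(x) and BCE(D(G(z)), 0) = -log(1 - D(G(z))),
# so the discriminator step maximizes log D(x) + log(1 - D(G(z))), while the generator step
# minimizes BCE(D(G(z)), 1) = -log D(G(z)). beta1 = 0.5 for Adam matches the DCGAN paper.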
for epoch in range(options.epochs):
    for i, data in enumerate(dataloader, 0):
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # train with real images, using one-sided label smoothing (targets in [0.8, 1.0])
        discriminator.zero_grad()
        real_images = data[0].to(device)
        real_labels = torch.empty([batch_size], device=device).uniform_(0.8, 1.0)
        output = discriminator(real_images)
        d_loss_real = criterion(output, real_labels)
        d_loss_real.backward()
        # train with fake images, detached so gradients do not flow into the generator
        noise = torch.randn([batch_size, noise_size, 1, 1], device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros([batch_size], device=device)
        output = discriminator(fake_images.detach())
        d_loss_fake = criterion(output, fake_labels)
        d_loss_fake.backward()
        d_loss = d_loss_real + d_loss_fake
        discriminator_optimizer.step()
        # (2) Update G network: maximize log(D(G(z)))
        generator.zero_grad()
        # for the generator loss, the fake images are labeled as real
        real_targets = torch.ones([batch_size], device=device)
        output = discriminator(fake_images)
        g_loss = criterion(output, real_targets)
        g_loss.backward()
        generator_optimizer.step()
        print('[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}'
              .format(epoch, options.epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
        if i % 100 == 0:
            vutils.save_image(real_images,
                              os.path.join(options.output_folder, 'real_samples.png'),
                              normalize=True)
            fake_images = generator(fixed_noise)
            vutils.save_image(fake_images.detach(),
                              os.path.join(options.output_folder, 'fake_samples_epoch_{}.png'.format(epoch)),
                              normalize=True)