Skip to content

Instantly share code, notes, and snippets.

@Daiver
Created December 23, 2018 12:10
Show Gist options
Save Daiver/9eca72f486b3a5bb65d5148621311611 to your computer and use it in GitHub Desktop.
from __future__ import print_function
import argparse
import os
import sys
import random
import warnings
import ignite
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
try:
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
except ImportError:
raise ImportError("Please install torchvision to run this example, for example "
"via conda by running 'conda install -c pytorch torchvision'. ")
# Log metrics to the TSV file / progress bar every PRINT_FREQ iterations.
PRINT_FREQ = 100
# File-name templates for the per-epoch sample images ({:04d} = epoch number).
FAKE_IMG_FNAME = 'fake_sample_epoch_{:04d}.png'
REAL_IMG_FNAME = 'real_sample_epoch_{:04d}.png'
# Tab-separated training-metrics log.
LOGS_FNAME = 'logs.tsv'
# Output names for the loss plot and sample grid.
PLOT_FNAME = 'plot.svg'
SAMPLES_FNAME = 'samples.svg'
# Prefix used by ignite's ModelCheckpoint for saved network files.
CKPT_PREFIX = 'networks'
class Net(nn.Module):
    """Shared base class for the generator and the discriminator.

    Provides the common DCGAN weight-initialization scheme; subclasses
    are expected to call :meth:`weights_init` after building their layers.
    """

    def weights_init(self):
        """Initialize conv weights from N(0, 0.02); batch-norm weights from N(1, 0.02), biases to 0."""
        for module in self.modules():
            layer_kind = type(module).__name__
            if 'Conv' in layer_kind:
                module.weight.data.normal_(0.0, 0.02)
            elif 'BatchNorm' in layer_kind:
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.fill_(0)

    def forward(self, x):
        # Identity pass-through; subclasses override this with their own net.
        return x
class Generator(Net):
    """DCGAN generator: maps (N, z_dim, 1, 1) noise to a (N, 3, 64, 64) image in [-1, 1].

    Args:
        z_dim (int): dimensionality of the input noise vector.
        nf (int): number of filters in the second-to-last deconv layer.
    """

    def __init__(self, z_dim, nf):
        super(Generator, self).__init__()
        layers = []
        # 1x1 -> 4x4
        layers += [
            nn.ConvTranspose2d(in_channels=z_dim, out_channels=nf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(inplace=True),
        ]
        # 4x4 -> 8x8
        layers += [
            nn.ConvTranspose2d(in_channels=nf * 8, out_channels=nf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.ReLU(inplace=True),
        ]
        # 8x8 -> 16x16
        layers += [
            nn.ConvTranspose2d(in_channels=nf * 4, out_channels=nf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.ReLU(inplace=True),
        ]
        # 16x16 -> 32x32
        layers += [
            nn.ConvTranspose2d(in_channels=nf * 2, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf),
            nn.ReLU(inplace=True),
        ]
        # 32x32 -> 64x64, squashed into [-1, 1] to match normalized real images
        layers += [
            nn.ConvTranspose2d(in_channels=nf, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        self.weights_init()

    def forward(self, x):
        return self.net(x)
class Discriminator(Net):
    """DCGAN discriminator: maps a (N, 3, 64, 64) image to a (N,) real-vs-fake probability.

    Args:
        nf (int): number of filters in the first conv layer.
    """

    def __init__(self, nf):
        super(Discriminator, self).__init__()
        layers = []
        # 64x64 -> 32x32 (no batch-norm on the input layer, per the DCGAN recipe)
        layers += [
            nn.Conv2d(in_channels=3, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 32x32 -> 16x16
        layers += [
            nn.Conv2d(in_channels=nf, out_channels=nf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 16x16 -> 8x8
        layers += [
            nn.Conv2d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 8x8 -> 4x4
        layers += [
            nn.Conv2d(in_channels=nf * 4, out_channels=nf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 4x4 -> 1x1 scalar score per image
        layers += [
            nn.Conv2d(in_channels=nf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)
        self.weights_init()

    def forward(self, x):
        # Flatten the (N, 1, 1, 1) map to a (N,) probability vector.
        scores = self.net(x)
        return scores.view(-1, 1).squeeze(1)
def check_manual_seed(seed):
    """Seed the `random` and `torch` RNGs, picking a random seed when none is given.

    Args:
        seed (int or None): manual seed; when None, a random seed in
            [1, 10000] is chosen and reported to the user.
    """
    # `seed = seed or random.randint(...)` would silently discard a
    # legitimate seed of 0, so test explicitly against None instead.
    if seed is None:
        seed = random.randint(1, 10000)
    random.seed(seed)
    torch.manual_seed(seed)
    print('Using manual seed: {seed}'.format(seed=seed))
def to_rgb(img):
    """Return *img* converted to 3-channel RGB via its `convert` method (PIL-style)."""
    converted = img.convert('RGB')
    return converted
def check_dataset(dataset, dataroot):
    """Resolve a dataset name into a torchvision Dataset ready for DCGAN training.

    Args:
        dataset (str): name of the dataset to use ('imagenet', 'folder', 'lfw',
            'lsun', 'cifar10', 'mnist' or 'fake').
        dataroot (str): root directory where the dataset is stored/downloaded.

    Returns:
        data.Dataset: torchvision Dataset object yielding 64x64 tensors.

    Raises:
        RuntimeError: if *dataset* is not a recognized name.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset in {'imagenet', 'folder', 'lfw'}:
        pipeline = transforms.Compose([resize, crop, to_tensor, normalize])
        return dset.ImageFolder(root=dataroot, transform=pipeline)

    if dataset == 'lsun':
        pipeline = transforms.Compose([resize, crop, to_tensor, normalize])
        return dset.LSUN(root=dataroot, classes=['bedroom_train'], transform=pipeline)

    if dataset == 'cifar10':
        # CIFAR images are already square, so no center crop is needed.
        pipeline = transforms.Compose([resize, to_tensor, normalize])
        return dset.CIFAR10(root=dataroot, download=True, transform=pipeline)

    if dataset == 'mnist':
        # MNIST is grayscale: expand to RGB so it matches the nets' 3 channels.
        pipeline = transforms.Compose([to_rgb, resize, to_tensor, normalize])
        return dset.MNIST(root=dataroot, download=True, transform=pipeline)

    if dataset == 'fake':
        return dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)

    raise RuntimeError("Invalid dataset name: {}".format(dataset))
def main(dataset, dataroot,
         z_dim, g_filters, d_filters,
         batch_size, epochs,
         learning_rate, beta_1,
         saved_G, saved_D,
         seed,
         n_workers, device,
         alpha, output_dir):
    """Train a DCGAN with ignite, saving checkpoints, logs and sample images.

    Args:
        dataset (str): dataset name understood by `check_dataset`.
        dataroot (str): root directory where the dataset is stored/downloaded.
        z_dim (int): dimensionality of the generator's input noise.
        g_filters (int): base filter count of the generator.
        d_filters (int): base filter count of the discriminator.
        batch_size (int): training batch size.
        epochs (int): number of epochs to train for.
        learning_rate (float): Adam learning rate for both networks.
        beta_1 (float): Adam beta1 for both optimizers.
        saved_G: optional path to a generator state dict to resume from.
        saved_D: optional path to a discriminator state dict to resume from.
        seed (int or None): manual RNG seed; a random one is chosen when falsy.
        n_workers (int): number of DataLoader worker processes.
        device: torch device spec to train on (e.g. 'cuda:0').
        alpha (float): smoothing factor for the RunningAverage metrics.
        output_dir (str): directory for logs, checkpoints and sample images.
    """
    # seed
    check_manual_seed(seed)

    # networks
    netG = Generator(z_dim, g_filters).to(device)
    netD = Discriminator(d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # data
    dataset = check_dataset(dataset, dataroot)
    # drop_last=True keeps every batch exactly batch_size, matching the
    # pre-built real_labels / fake_labels tensors below.
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))
    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc: constant BCE targets, sized for the fixed batch size
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    # Fixed noise is reused every epoch so saved fake samples are comparable.
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        # Fresh latent batch for each training step.
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):
        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()
        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake; detach() keeps this backward pass from reaching netG
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()
        errD_fake.backward()

        # gradient update (gradients from both real and fake passes accumulated)
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()
        errG.backward()

        # gradient update
        optimizerG.step()

        # Scalars only: these feed the RunningAverage metrics and the TSV log.
        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, save_interval=1, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errD']).attach(trainer, 'errD')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errG']).attach(trainer, 'errG')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_x']).attach(trainer, 'D_x')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z1']).attach(trainer, 'D_G_z1')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z2']).attach(trainer, 'D_G_z2')

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        # Every PRINT_FREQ iterations: append metrics to the TSV log and
        # echo a one-line summary through the progress bar.
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [str(round(value, 5)) for value in engine.state.metrics.values()]
            with open(fname, 'a') as f:
                if f.tell() == 0:
                    # File is empty: write the header row first.
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(epoch=engine.state.epoch,
                                                                  max_epoch=epochs,
                                                                  i=(engine.state.iteration % len(loader)),
                                                                  max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)
            pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        # Render the fixed-noise batch so progress is visually comparable across epochs.
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        # Save the last real batch of the epoch for side-by-side comparison.
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={
                                  'netG': netG,
                                  'netD': netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        # NOTE(review): plotting is currently disabled -- the imports are kept so
        # the warning fires when matplotlib/pandas are missing, but the plotting
        # code itself is commented out in the `else` branch below.
        try:
            import matplotlib as mpl
            mpl.use('agg')
            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt
        except ImportError:
            warnings.warn('Loss plots will not be generated -- pandas or matplotlib not found')
        else:
            pass
            # df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t')
            # x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
            # _ = df.plot(x=x, subplots=True, figsize=(20, 20))
            # _ = plt.xlabel('Iteration number')
            # fig = plt.gcf()
            # path = os.path.join(output_dir, PLOT_FNAME)
            #
            # fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        # On Ctrl-C (after at least one iteration), stop cleanly and snapshot
        # both networks before exiting; any other exception is re-raised.
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            create_plots(engine)
            checkpoint_handler(engine, {
                'netG_exception': netG,
                'netD_exception': netD
            })
        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
if __name__ == '__main__':
    # Hard-coded run configuration (this gist inlines what the upstream
    # example exposes as CLI arguments).
    dev = 'cuda:0'
    output_dir = '../data/dcgan/'
    # os.makedirs(output_dir)
    # Report library versions up front to make runs reproducible/debuggable.
    print('torch version', torch.__version__)
    print('torchvision version', torchvision.__version__)
    print('ignite version, ', ignite.__version__)
    print(sys.version)
    main(dataset='mnist', dataroot='data/dcgan/',
         z_dim=100, g_filters=64, d_filters=64,
         batch_size=64, epochs=2500,
         learning_rate=0.0002, beta_1=0.5,
         saved_D=None, saved_G=None,
         seed=42,
         device=dev, n_workers=4,
         alpha=0.98, output_dir=output_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment