Skip to content

Instantly share code, notes, and snippets.

@htoyryla
Last active June 11, 2018 11:13
Show Gist options
  • Save htoyryla/d57cf3889bdf32efd696a82547390ca5 to your computer and use it in GitHub Desktop.
Save htoyryla/d57cf3889bdf32efd696a82547390ca5 to your computer and use it in GitHub Desktop.
HT-GAN including AE and growing image size
Code for a GAN trainer,
adding an encoder to assist in generator training
(htoyryla, 8 Jun 2018)
and support for progressive training with growing image size
(htoyryla, 11 Jun 2018).
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import functools
from collections import OrderedDict
#
# code for models for traindcg2
# htoyryla 8 Jun 2018
#
# experimental
# seriously in need of refactoring
#
# v.2c3 try naming of layers
# v.2c4 larger kernels on larger layers
# v.2c5 models changed so that in D & E, layers are added to input size as image size grows
# Module-level default values. The network constructors in this file read
# nz / ndf / imageSize from their ``opt`` argument into locals of the same
# names, so these globals are not what configures the networks.
nz = 100    # latent vector length
ngf = 64    # generator base filter count
ndf = 64    # discriminator / encoder base filter count
size = 64   # image height / width
def weights_init(m):
    """Initialise one module in place; intended for ``net.apply(weights_init)``.

    * Modules whose class name contains ``Conv`` (this includes
      ``ConvTranspose2d``) get weights drawn from N(0, 0.02) and a zero bias.
    * Modules whose class name contains ``BatchNorm`` get weights drawn from
      N(1, 0.02) and a zero bias.
    * Every other module is left untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
        # bias is None when the layer was built with bias=False
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
    elif 'BatchNorm' in classname:
        # The class is spelled "BatchNorm2d"; the previous 'Batchnorm' test
        # never matched, so this branch was dead code.
        # weight / bias are None when the layer was built with affine=False.
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
# total variation loss module
# use as a loss module, not a layer
class TVLoss(nn.Module):
    """Total variation loss on a batch of images.

    Use it as a loss term (call it on an image batch), not as a layer inside
    a network.

    Args:
        TVLoss_weight: scalar multiplier applied to the result.
    """

    def __init__(self,TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted total variation of ``x``.

        ``x`` is indexed on four axes; differences are taken between
        neighbours along axis 2 and along axis 3. Each sum of squared
        differences is divided by the element count (axes 1..3) of the
        corresponding shifted view, and the total is divided by ``x.size(0)``.
        """
        rows = x.size(2)
        cols = x.size(3)
        shifted_down = x[:, :, 1:, :]
        shifted_right = x[:, :, :, 1:]
        vertical = (shifted_down - x[:, :, :rows - 1, :]).pow(2).sum()
        horizontal = (shifted_right - x[:, :, :, :cols - 1]).pow(2).sum()
        normalised = (vertical / self._tensor_size(shifted_down)
                      + horizontal / self._tensor_size(shifted_right))
        return self.TVLoss_weight * 2 * normalised / x.size(0)

    def _tensor_size(self,t):
        """Number of elements per sample: product of the sizes of axes 1, 2, 3."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
# generator, duplicates my mygan architecture as of 4/2017
class _netG(nn.Module):
    """Generator: a stack of transposed convolutions turning a latent Z into an image.

    The first layer expands Z to a 4x4 map; four stride-2 stages bring it to
    64x64, and one further stride-2 stage is appended for each of the
    thresholds ``imageSize`` > 64, 128, 256 and 512. A final 3x3 convolution
    followed by ``Tanh`` produces ``nc`` output channels.

    Layer names (``deconvN`` / ``normN`` / ``reluN`` / ``outconv``) are part
    of the state-dict keys, so they must stay stable for checkpoints to load.

    Args:
        ngpu: when greater than 1, CUDA inputs are run through
            ``nn.parallel.data_parallel`` on devices ``range(ngpu)``.
        norm_layer: normalisation layer class, or a ``functools.partial`` of
            one; it is called with the channel count. Convolutions get a bias
            only when this is ``nn.InstanceNorm2d``.
        opt: options object providing ``nz``, ``nc``, ``imageSize`` and
            ``ngf`` (``ndf`` is used when ``ngf`` is absent).

    Raises:
        ValueError: if ``opt`` is None.
    """

    def __init__(self, ngpu=1, norm_layer=nn.BatchNorm2d, opt=None):
        super(_netG, self).__init__()
        if opt is None:
            raise ValueError("_netG needs an options object (opt)")
        nz = opt.nz
        nc = opt.nc
        # The generator width is controlled by ngf. Earlier code read opt.ndf
        # here, which silently ignored --ngf; ndf remains the fallback for
        # option objects that carry no ngf.
        ngf = opt.ngf if hasattr(opt, 'ngf') else opt.ndf
        size = opt.imageSize
        self.ngpu = ngpu

        # Unwrap functools.partial so a partially-applied InstanceNorm2d is
        # recognised too.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        def upsample(idx, in_ch, out_ch, kernel=3, padding=1, output_padding=1):
            # One stride-2 transposed-conv stage; with the (kernel, padding,
            # output_padding) combinations used below it doubles height/width.
            return [
                ("deconv%d" % idx,
                 nn.ConvTranspose2d(in_ch, out_ch, kernel, 2, padding,
                                    output_padding, bias=use_bias)),
                ("norm%d" % idx, norm_layer(out_ch)),
                ("relu%d" % idx, nn.ReLU(True)),
            ]

        layers = [
            # input is Z, going into a convolution
            ("deconv1", nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=use_bias)),
            ("norm1", norm_layer(ngf * 8)),
            ("relu1", nn.ReLU(True)),
        ]
        # state size. (ngf*8) x 4 x 4
        layers.extend(upsample(2, ngf * 8, ngf * 4))
        # state size. (ngf*4) x 8 x 8
        layers.extend(upsample(3, ngf * 4, ngf * 2))
        # state size. (ngf*2) x 16 x 16
        layers.extend(upsample(4, ngf * 2, ngf))
        # state size. (ngf) x 32 x 32
        layers.extend(upsample(5, ngf, ngf))

        # Extra stages for larger target sizes.
        if size > 64:
            layers.extend(upsample(6, ngf, ngf))
        if size > 128:
            layers.extend(upsample(7, ngf, ngf))
        # The two largest stages use 6x6 kernels.
        if size > 256:
            layers.extend(upsample(8, ngf, ngf, kernel=6, padding=2, output_padding=0))
        if size > 512:
            layers.extend(upsample(9, ngf, ngf, kernel=6, padding=2, output_padding=0))

        layers.extend([
            ("outconv", nn.Conv2d(ngf, nc, kernel_size=3, stride=1, padding=1, bias=use_bias)),
            ("outactiv", nn.Tanh()),
        ])
        self.main = nn.Sequential(OrderedDict(layers))

    def forward(self, input):
        """Run the generator on ``input`` and return the ``Tanh`` output."""
        # is_cuda covers every CUDA tensor type, not only FloatTensor.
        if input.is_cuda and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
# discriminator
class _netD(nn.Module):
    """Discriminator: stride-2 4x4 convolutions reducing an image to one score.

    For ``imageSize`` above 64, additional stride-2 stages are inserted right
    after the input convolution (one per threshold 64, 128, 256, 512), so the
    fixed tail of the network is the same for every image size.

    Layer names are state-dict keys and are kept exactly as they were,
    including the irregular ``norm5`` / ``relu5`` pair of the 512 stage;
    renaming them would break loading of existing checkpoints.

    Args:
        ngpu: when greater than 1, CUDA inputs are run through
            ``nn.parallel.data_parallel`` on devices ``range(ngpu)``.
        use_sigmoid: append a ``Sigmoid`` after the last convolution.
        norm_layer: normalisation layer class, or a ``functools.partial`` of
            one. Convolutions get a bias only for ``nn.InstanceNorm2d``.
        opt: options object providing ``nc``, ``ndf`` and ``imageSize``.

    Raises:
        ValueError: if ``opt`` is None.
    """

    def __init__(self, ngpu=1, use_sigmoid=True, norm_layer=nn.BatchNorm2d, opt=None):
        super(_netD, self).__init__()
        if opt is None:
            raise ValueError("_netD needs an options object (opt)")
        nc = opt.nc
        ndf = opt.ndf
        size = opt.imageSize
        self.ngpu = ngpu

        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        def conv(in_ch, out_ch, stride=2):
            # All discriminator convolutions share kernel 4 / padding 1.
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride,
                             padding=1, bias=use_bias)

        sequence = [
            ("inconv", conv(nc, ndf)),
            ("inrelu", nn.LeakyReLU(0.2, inplace=True)),
        ]

        # (size threshold, conv name, norm name, relu name), largest first so
        # the stage for the biggest resolution sits closest to the input.
        extra_stages = (
            (512, "conv1024", "norm1024", "relu1024"),
            (256, "conv512", "norm5", "relu5"),
            (128, "conv256", "norm256", "relu256"),
            (64, "conv128", "norm128", "relu128"),
        )
        for threshold, conv_name, norm_name, relu_name in extra_stages:
            if size > threshold:
                sequence.extend([
                    (conv_name, conv(ndf, ndf)),
                    (norm_name, norm_layer(ndf)),
                    (relu_name, nn.LeakyReLU(0.2, inplace=True)),
                ])

        sequence.extend([
            ("conv2", conv(ndf, ndf * 2)),
            ("norm2", norm_layer(ndf * 2)),
            ("relu2", nn.LeakyReLU(0.2, inplace=True)),
            ("conv3", conv(ndf * 2, ndf * 4)),
            ("norm3", norm_layer(ndf * 4)),
            ("relu3", nn.LeakyReLU(0.2, inplace=True)),
            # two stride-4 convolutions collapse the remaining map
            ("convf4", conv(ndf * 4, 1, stride=4)),
            ("outconv", conv(1, 1, stride=4)),
        ])

        if use_sigmoid:
            sequence.append(("outactiv", nn.Sigmoid()))
        self.main = nn.Sequential(OrderedDict(sequence))

    def forward(self, input):
        """Return the network output flattened to a 1-D tensor."""
        # is_cuda covers every CUDA tensor type, not only FloatTensor.
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
# encoder, duplicates my mygan architecture as of 4/2017
class _netE(nn.Module):
    """Encoder: 3x3 convolution pairs reducing an image to an ``nz``-channel code.

    Each stage is a stride-2 convolution followed by a stride-1 convolution,
    a normalisation layer and a LeakyReLU. For ``imageSize`` above 64, extra
    stages are inserted right after the input convolution (one per threshold
    64, 128, 256, 512). The last layer is a 4x4 convolution with ``nz``
    output channels.

    Layer names are state-dict keys and are kept exactly as they were.

    Args:
        ngpu: when greater than 1, CUDA inputs are run through
            ``nn.parallel.data_parallel`` on devices ``range(ngpu)``.
        use_sigmoid: append a ``Sigmoid`` after the last convolution.
        norm_layer: normalisation layer class, or a ``functools.partial`` of
            one. Convolutions get a bias only for ``nn.InstanceNorm2d``.
        opt: options object providing ``nz``, ``nc``, ``ndf`` and
            ``imageSize``.

    Raises:
        ValueError: if ``opt`` is None.
    """

    def __init__(self, ngpu=1, use_sigmoid=True, norm_layer=nn.BatchNorm2d, opt=None):
        super(_netE, self).__init__()
        if opt is None:
            raise ValueError("_netE needs an options object (opt)")
        self.ngpu = ngpu
        nz = opt.nz
        nc = opt.nc
        ndf = opt.ndf
        size = opt.imageSize

        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        def conv(in_ch, out_ch, stride):
            # All encoder body convolutions share kernel 3 / padding 1.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride,
                             padding=1, bias=use_bias)

        sequence = [
            ("inconv", conv(nc, ndf, 2)),
            ("inrelu", nn.LeakyReLU(0.2, inplace=True)),
        ]

        # (size threshold, name suffix), largest first so the stage for the
        # biggest resolution sits closest to the input.
        for threshold, tag in ((512, "1024"), (256, "512"), (128, "256"), (64, "128")):
            if size > threshold:
                sequence.extend([
                    ("conv" + tag + "a", conv(ndf, ndf, 2)),
                    ("conv" + tag + "b", conv(ndf, ndf, 1)),
                    ("norm" + tag, norm_layer(ndf)),
                    ("relu" + tag, nn.LeakyReLU(0.2, inplace=True)),
                ])

        sequence.extend([
            ("conv2", conv(ndf, ndf * 2, 2)),
            ("conv3", conv(ndf * 2, ndf * 2, 1)),
            ("norm3", norm_layer(ndf * 2)),
            ("relu3", nn.LeakyReLU(0.2, inplace=True)),
            ("conv4", conv(ndf * 2, ndf * 4, 2)),
            ("conv5", conv(ndf * 4, ndf * 4, 1)),
            ("norm5", norm_layer(ndf * 4)),
            ("relu5", nn.LeakyReLU(0.2, inplace=True)),
            ("conv6", conv(ndf * 4, ndf * 8, 2)),
            ("conv7", conv(ndf * 8, ndf * 8, 1)),
            ("norm7", norm_layer(ndf * 8)),
            ("relu7", nn.LeakyReLU(0.2, inplace=True)),
            # The code length follows opt.nz. It used to be hard-coded to 100,
            # which only worked with the default --nz.
            ("outconv", nn.Conv2d(ndf * 8, nz, kernel_size=4, stride=1, padding=0, bias=use_bias)),
        ])

        if use_sigmoid:
            sequence.append(("outactiv", nn.Sigmoid()))
        self.main = nn.Sequential(OrderedDict(sequence))

    def forward(self, input):
        """Return the network output flattened to a 1-D tensor.

        The caller is expected to reshape it (the trainer in this gist views
        it as ``(batch, nz, 1, 1)``).
        """
        # is_cuda covers every CUDA tensor type, not only FloatTensor.
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
from __future__ import print_function
import argparse
import os
import sys
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from model2c5 import _netG, _netD, _netE, weights_init, TVLoss
from torch.optim import lr_scheduler
import numpy as np
import cv2
# gan trainer
# htoyryla 8 Jun 2018
#
# put images in datasets/<name>/train/
#
# v.2c3 adds noisy labels
# v.2c4 add loading of netE weights
# v.2c5 add non-strict (partial) loading of pretrained weights (requires model2c3)
# v.2c6 add optional freezing of already trained layers
# v.2c7 use noisy labels only in discriminator, models: use larger kernels in upper layers
# v.2c8 models changed so that in D & E, layers are added to input size as image size grows
# Command-line options for the trainer. ``opt`` is read throughout the rest
# of the script and is also handed to the network constructors.
parser = argparse.ArgumentParser()

# data
# NOTE: --dataset is required, so its default value is never used.
parser.add_argument('--dataset', default='folder', required=True, help=' folder | fake')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')

# network dimensions
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int, default=3)

# loss weights
parser.add_argument('--lmbd', type=float, default=100, help='lambda, default=100')
parser.add_argument('--tvloss', type=float, default=0.0002, help='tv loss, default=0.0002')

# schedule, saving and optimizer settings
parser.add_argument('--niter', type=int, default=60, help='number of epochs to train for')
parser.add_argument('--save_every', type=int, default=10, help='number of epochs between saves')
parser.add_argument('--imgStep', type=int, default=0, help='minibatches between image folder saves')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning G rate, default=0.0002')
# help text corrected: the actual default for D is 0.00005, not 0.0002
parser.add_argument('--lrD', type=float, default=0.00005, help='learning D rate, default=0.00005')
parser.add_argument('--lrE', type=float, default=0.0002, help='learning E rate, default=0.0002')
parser.add_argument('--step', type=int, default=40, help='lr step, default=40')
parser.add_argument('--gamma', type=float, default=0.1, help='gamma, default=0.1')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')

# pretrained weights and run identity
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--netE', default='', help="path to netE (to continue training)")
parser.add_argument('--name', default='baseline', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')

# behaviour switches
parser.add_argument('--lsgan', action='store_true', help='use lsgan')
parser.add_argument('--instance', action='store_true', help='use instance norm')
parser.add_argument('--withoutE', action='store_true', help='do not use Encoder Network')
parser.add_argument('--debug', action='store_true', help='show debug info')
parser.add_argument('--hsv', action='store_true', help='use hsv color space')
parser.add_argument('--weight_decay', type=float, default=0, help='L2 regularization weight. Greatly helps convergence but leads to artifacts in images, not recommended.')
parser.add_argument('--nlabels', action='store_true', help='use noisy labels')
parser.add_argument('--nostrict', action='store_true', help='allow partial loading of pretrained nets')
parser.add_argument('--freeze', action='store_true', help='freeze already trained layers')

opt = parser.parse_args()
# Parse the comma-separated GPU list; negative ids (e.g. "-1") are dropped.
str_ids = opt.gpu_ids.split(',')
gpu_ids = [gpu for gpu in (int(str_id) for str_id in str_ids) if gpu >= 0]
print(opt)


def _ensure_dir(path):
    """Create ``path`` (with parents) unless it already exists as a directory."""
    try:
        os.makedirs(path)
    except OSError:
        # Only "already there" is acceptable; anything else is a real error.
        if not os.path.isdir(path):
            raise


# Each output directory is created on its own: with both calls inside one
# try block, an existing ./model/<name> prevented ./visual/<name> from ever
# being created.
_ensure_dir(os.path.join('./model', opt.name))
_ensure_dir(os.path.join('./visual', opt.name))

# seeding
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# CUDA is used only when it is available AND at least one non-negative GPU id
# was given; indexing gpu_ids[0] on an empty list used to raise IndexError.
opt.cuda = torch.cuda.is_available() and len(gpu_ids) > 0
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)
    torch.cuda.set_device(gpu_ids[0])
cudnn.benchmark = True

# dataset
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
else:
    # previously fell through to a NameError on the undefined ``dataset``
    raise ValueError("unknown --dataset value: %r" % (opt.dataset,))
if not dataset:
    raise ValueError("dataset is empty")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

ngpu = len(gpu_ids)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3  # fixed here; the networks read opt.nc themselves
lmbd = opt.lmbd

# Explicit check instead of assert, which is stripped under ``python -O``.
if opt.imageSize not in [64, 128, 256, 512, 1024]:
    raise ValueError("--imageSize must be one of 64, 128, 256, 512, 1024")
# pre/deprocessing for using hsv color space in model
# not fully working at the moment
def preproc(im, blur=0):
    """Convert one image tensor to a channel-first HSV numpy array.

    Steps: multiply by 255 and cast to uint8, reverse the axis order,
    optionally median-blur (only when ``blur`` > 2), convert RGB -> HSV with
    OpenCV, divide H by 179.9 and S, V by 255, move the channel axis back to
    the front and subtract 0.5.

    NOTE(review): marked "not fully working" by the author. Points to confirm:
      * the folder datasets are normalised with mean/std 0.5 before this is
        called, so ``im`` may contain negative values that wrap in the uint8
        cast -- TODO confirm the expected input range;
      * the result is presumably within [-0.5, 0.5], while ``deproc`` undoes a
        scaling of ``x/2 + 0.5`` -- TODO confirm the intended range;
      * ``transpose((2, 1, 0))`` also swaps the two spatial axes.
    """
    rgb = (im.numpy() * 255).astype(np.uint8)
    rgb = rgb.transpose((2, 1, 0))
    if blur > 2:
        rgb = cv2.medianBlur(rgb, blur)
    hue, sat, val = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV))
    hue = hue / 179.9
    sat = sat / 255.
    val = val / 255.
    channels_first = cv2.merge((hue, sat, val)).transpose(2, 0, 1)
    return channels_first - 0.5
def deproc(hsv, blur=0):
    """Convert one channel-first HSV tensor back to an RGB numpy array.

    Steps: map the tensor through ``x/2 + 0.5``, move it to the CPU as a
    channel-last numpy array, multiply H by 179.9 and S, V by 255, cast to
    uint8, convert HSV -> RGB with OpenCV, divide by 255 and reverse the axis
    order. ``blur`` is accepted but not used.

    NOTE(review): the final ``transpose(2, 1, 0)`` swaps the two spatial axes
    relative to the input; this mirrors the swap in ``preproc`` -- TODO
    confirm it is intended. Values are not clamped before the uint8 cast.
    """
    unit = (hsv / 2 + 0.5).cpu().numpy().transpose(1, 2, 0)
    hue, sat, val = cv2.split(unit)
    hue = hue * 179.9
    sat = sat * 255.
    val = val * 255.
    hsv_bytes = cv2.merge((hue, sat, val)).astype(np.uint8)
    rgb = cv2.cvtColor(hsv_bytes, cv2.COLOR_HSV2RGB) / 255.
    return rgb.transpose(2, 1, 0)
# create models
# Keyword arguments shared by the three constructors: instance norm is passed
# explicitly only when requested, otherwise each network uses its default.
norm_kwargs = {'norm_layer': nn.InstanceNorm2d} if opt.instance else {}
sigmoid_output = not opt.lsgan
strict_loading = not opt.nostrict

# generator
netG = _netG(ngpu, opt=opt, **norm_kwargs)
netG.apply(weights_init)
# load prelearned weights if any
if opt.netG != '':
    Gpar = torch.load(opt.netG)
    try:
        netG.load_state_dict(Gpar, strict=strict_loading)
    except RuntimeError:
        # best effort: a checkpoint from a differently sized net is tolerated
        print("Layer size mismatch during loading")
print(netG)

# discriminator
netD = _netD(ngpu, use_sigmoid=sigmoid_output, opt=opt, **norm_kwargs)
netD.apply(weights_init)
# load prelearned weights if any
if opt.netD != '':
    Dpar = torch.load(opt.netD)
    try:
        netD.load_state_dict(Dpar, strict=strict_loading)
    except RuntimeError:
        print("Layer size mismatch during loading")
print(netD)

# encoder (skipped entirely with --withoutE)
if not opt.withoutE:
    netE = _netE(ngpu, use_sigmoid=sigmoid_output, opt=opt, **norm_kwargs)
    netE.apply(weights_init)
    # load prelearned weights if any
    if opt.netE != '':
        Epar = torch.load(opt.netE)
        try:
            netE.load_state_dict(Epar, strict=strict_loading)
        except RuntimeError:
            print("Layer size mismatch during loading")
    print(netE)
# freeze
def _freeze_pretrained(net, pretrained, keep_trainable, label):
    """Stop gradient updates for the layers of ``net.main`` found in a checkpoint.

    A child layer counts as pretrained when ``main.<name>.weight`` is a key of
    ``pretrained``. The layer called ``keep_trainable`` is always left alone.

    Setting ``requires_grad`` on the module object (as was done before) only
    creates a plain Python attribute and freezes nothing; the flag has to be
    set on each parameter.

    Args:
        net: network whose ``main`` container is scanned.
        pretrained: the loaded state dict.
        keep_trainable: name of the child layer that must stay trainable.
        label: network name used in the log line (e.g. ``"netG"``).
    """
    for name, mod in net.main.named_children():
        if name == keep_trainable:
            continue
        weight_key = "main." + name + ".weight"
        if weight_key in pretrained:
            print("freezing " + label + "." + weight_key)
            for param in mod.parameters():
                param.requires_grad = False


if opt.freeze:
    # A network can only be frozen when a checkpoint was loaded for it; the
    # encoder additionally has to exist (no --withoutE).
    if opt.netG != '':
        _freeze_pretrained(netG, Gpar, "outconv", "netG")  # do not freeze the final output layer
    if opt.netD != '':
        _freeze_pretrained(netD, Dpar, "inconv", "netD")   # do not freeze the input layer
    if not opt.withoutE and opt.netE != '':
        _freeze_pretrained(netE, Epar, "inconv", "netE")   # do not freeze the input layer
# loss module for real / fake testing
class GANLoss(nn.Module):
    """Real/fake loss that builds its own target tensor for each call.

    Args:
        use_lsgan: use ``nn.MSELoss`` when True, ``nn.BCELoss`` otherwise.
        target_real_label: target value for "real" when labels are not noisy.
        target_fake_label: target value for "fake" when labels are not noisy.
        noisy: draw targets uniformly from [0.8, 1.0] (real) or [0, 0.2]
            (fake) instead of using the fixed labels.
        tensor: kept for backward compatibility and stored as ``self.Tensor``.
            Targets are now allocated with the type and device of the input,
            so a mismatching value here no longer causes a crash.
    """

    def __init__(self, use_lsgan=False, target_real_label=1.0, target_fake_label=0.0, noisy=False, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.noisy = noisy
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a target shaped like ``input`` for the real or the fake case.

        The most recent target is also kept on ``self.real_label_var`` /
        ``self.fake_label_var``.
        """
        # Allocate next to the input so CPU and GPU inputs both work.
        target = input.data.new(input.size())
        if target_is_real:
            if self.noisy:
                target.uniform_(0.8, 1.0)
            else:
                target.fill_(self.real_label)
            self.real_label_var = Variable(target, requires_grad=False)
            return self.real_label_var
        if self.noisy:
            target.uniform_(0, 0.2)
        else:
            target.fill_(self.fake_label)
        self.fake_label_var = Variable(target, requires_grad=False)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Return the loss of ``input`` against an all-real or all-fake target."""
        # Implemented as forward() rather than by overriding __call__, so the
        # regular nn.Module call machinery stays in place.
        return self.loss(input, self.get_target_tensor(input, target_is_real))
# additional loss functions for AE loss
def mse_loss(input, target):
    """Sum of squared differences divided by the element count of ``input``."""
    squared_error = (input - target) ** 2
    return squared_error.sum() / input.data.nelement()
def l1_loss(input, target):
    """Sum of absolute differences divided by the element count of ``input``."""
    abs_error = (input - target).abs()
    return abs_error.sum() / input.data.nelement()
# Loss targets must live where the network outputs live. The CUDA tensor type
# used to be hard-coded, which failed on machines without a usable GPU.
loss_tensor_type = torch.cuda.FloatTensor if opt.cuda else torch.FloatTensor
criterion = GANLoss(use_lsgan=opt.lsgan, tensor=loss_tensor_type)
# noisy labels (if requested) are used by the discriminator criterion only
Dcriterion = GANLoss(use_lsgan=opt.lsgan, noisy=opt.nlabels, tensor=loss_tensor_type)
criterionL1 = nn.L1Loss()
tvloss = TVLoss(opt.tvloss)

# general purpose vectors
# These buffers are resized / refilled in place inside the training loop.
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

if opt.cuda:
    netD.cuda()
    netG.cuda()
    if not opt.withoutE:
        netE.cuda()
    criterion.cuda()
    criterionL1.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

# fixed_noise is reused for the periodic sample images
fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
if not opt.withoutE:
    optimizerE = optim.Adam(netE.parameters(), lr=opt.lrE, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)

# one StepLR scheduler per optimizer; they are stepped once per epoch
schedulers = []
schedulers.append(lr_scheduler.StepLR(optimizerD, step_size=opt.step, gamma=opt.gamma))
schedulers.append(lr_scheduler.StepLR(optimizerG, step_size=opt.step, gamma=opt.gamma))
if not opt.withoutE:
    schedulers.append(lr_scheduler.StepLR(optimizerE, step_size=opt.step, gamma=opt.gamma))
# main training loop starts here
# Per minibatch: update D on real and generated images, update G (GAN loss,
# plus reconstruction and TV terms), then update E to recover z from G(z).

# Single-sample images go to ./images, which nothing else creates; without
# this the first save fails when --imgStep is set.
if opt.imgStep != 0:
    try:
        os.makedirs('./images')
    except OSError:
        if not os.path.isdir('./images'):
            raise

imgCtr = 0
for epoch in range(opt.niter):
    # get a batch of input images
    for i, data in enumerate(dataloader, 0):
        iimgs = data[0].clone()  # store input images for later display

        # convert to hsv if needed
        if opt.hsv:
            data_ = []
            for im in data[0]:
                data_.append(preproc(im, 0))
            data[0] = torch.Tensor(data_)

        # update netD
        # first with real
        netD.zero_grad()
        real_cpu, _ = data
        batch_size = real_cpu.size(0)  # the last batch of an epoch may be incomplete
        if opt.cuda:
            real_cpu = real_cpu.cuda()
        input.resize_as_(real_cpu).copy_(real_cpu)
        inputv = Variable(input)

        output = netD(inputv)                    # get D(x)
        errD_real = Dcriterion(output, True)     # get err ref to real
        errD_real.backward()                     # get D_real gradients
        D_x = output.data.mean()

        # train with fake
        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)  # get z
        noisev = Variable(noise)
        fake_z = netG(noisev)                    # G(z)
        output = netD(fake_z.detach())           # D(G(z)), detached so G gets no gradient here
        errD_fake = Dcriterion(output, False)    # get err ref to fake
        errD_fake.backward()                     # get D_fake gradients
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake             # total D err, for display
        optimizerD.step()                        # update D weights

        #
        # Update netG
        #
        netG.zero_grad()
        output = netD(fake_z)  # use G(z) fake from above, TODO should we use another?
        tvl = 0

        if not opt.withoutE:
            # encode input into Z
            embedding = netE(inputv.detach()).view(batch_size, opt.nz, 1, 1)  # z = E(x)
            fake_e = netG(embedding.detach())  # new fake = G(E(x)), detach from E, train E later
            errG = criterion(output, True)     # get D(G(z)) err
            dist = l1_loss(inputv.detach(), fake_e) * lmbd  # err between input and G(E(x))
            errG = errG + dist
            if opt.tvloss:
                tvl = tvloss(fake_e)           # tv loss on G(E(x))
                errG = errG + tvl
        else:
            errG = criterion(output, True)     # no E used, take plain gan loss as errG
            if opt.tvloss:
                tvl = tvloss(fake_z)           # just add TV loss on G(z)
                errG = errG + tvl
            dist = 0

        errG.backward()                        # get G gradients
        D_G_z2 = output.data.mean()
        optimizerG.step()                      # update G parameters

        # Update E
        if not opt.withoutE:
            netE.zero_grad()
            embedding = netE(fake_z.detach())  # E(G(z))
            errE = criterionL1(embedding.view(batch_size, opt.nz, 1, 1), noisev)  # err between E(G(z)) and z
            errE.backward()
            optimizerE.step()
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_E: %.4f D(x): %.4f Dist: %4f TVLoss: %4f D(G(z)): %.4f / %.4f'
                  % (epoch, opt.niter, i, len(dataloader),
                     errD.data, errG.data, errE.data, D_x, dist, tvl, D_G_z1, D_G_z2))
        else:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f TVLss: %4f'
                  % (epoch, opt.niter, i, len(dataloader),
                     errD.data, errG.data, D_x, D_G_z1, D_G_z2, tvl))

        # save single samples if opt.imgStep > 0
        if opt.imgStep != 0 and imgCtr % opt.imgStep == 0:
            sampleNoise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
            if opt.cuda:
                sampleNoise = sampleNoise.cuda()
            fakeimg = netG(Variable(sampleNoise))
            fakeimg = fakeimg.data
            if opt.hsv:
                sample = []
                for hsv in fakeimg:
                    sample.append(deproc(hsv))
                fakeimg = torch.Tensor(sample)
            # imgCtr is an exact multiple of imgStep here, so floor division
            # gives the running sample index
            vutils.save_image(fakeimg,
                              './images/' + opt.name + '-sample%06d.png' % (imgCtr // opt.imgStep),
                              normalize=True)
        imgCtr = imgCtr + 1

        # visualize results
        if i % 100 == 0:
            vutils.save_image(iimgs,
                              './visual/%s/real_samples.png' % opt.name,
                              normalize=True)
            fake = netG(fixed_noise)
            fake = fake.data
            if opt.hsv:
                sample = []
                for hsv in fake:
                    sample.append(deproc(hsv))
                fake = torch.Tensor(sample)
            print('saving fakes ', fake.shape)
            vutils.save_image(fake,
                              './visual/%s/fake_samples_epoch_%03d.png' % (opt.name, epoch),
                              normalize=True)

    # do checkpointing
    # Also save after the last epoch: with the defaults (60 epochs, every 10)
    # the newest checkpoint would otherwise be from epoch 50 and the final
    # weights would be lost.
    if epoch % opt.save_every == 0 or epoch == opt.niter - 1:
        torch.save(netG.state_dict(), './model/%s/netG_epoch_%d.pth' % (opt.name, epoch))
        torch.save(netD.state_dict(), './model/%s/netD_epoch_%d.pth' % (opt.name, epoch))
        if not opt.withoutE:
            torch.save(netE.state_dict(), './model/%s/netE_epoch_%d.pth' % (opt.name, epoch))

    # step lrRate
    for scheduler in schedulers:
        scheduler.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment