Skip to content

Instantly share code, notes, and snippets.

@MercuriXito
Last active April 28, 2020 01:37
Show Gist options
  • Save MercuriXito/f8aa4944493865974f5381f9cf0849f5 to your computer and use it in GitHub Desktop.
Save MercuriXito/f8aa4944493865974f5381f9cf0849f5 to your computer and use it in GitHub Desktop.
GAN Training Framework

Train Frame

GAN 训练框架版本v2。

Modified

对简单的 Conditional GAN 和 Unconditional GAN 可以直接套用。自定义的训练方式可以通过继承 TrainerFrame 类并重写其中部分函数来实现,例如示例中的 DCGANTrainer。

Remaining Problems

  • Loss 的计算还是不能很好地兼容不同种类的GAN,需要进一步设计 Loss 的计算类。
import copy
import json
import os
import time

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torch.nn.init import kaiming_normal_, xavier_normal_
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
# --------------------------------
# Loss
class BaseGANLoss:
    """ Interface shared by every adversarial loss calculator in this file.

    Subclasses override both methods. The trainer backpropagates the returned
    value, or its first element when a tuple is returned.
    """

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return the discriminator loss for one batch.

        Args:
            netD: discriminator network.
            netG: generator network.
            images: batch of real images.
            fake_images: batch of generated images.
            labels: labels accompanying `images`.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        # NotImplementedError is still an Exception, so existing handlers keep working.
        raise NotImplementedError("Not Implemented")

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return the generator loss for one batch (same arguments as `calculate_loss_D`).

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Not Implemented")
class GANLoss(BaseGANLoss):
    """ Standard GAN loss based on binary cross-entropy.

    `netD` is called with the image batch only (labels are ignored).
    """

    def __init__(self, last_layer_without_sigmoid=False, device="cuda"):
        """
        Args:
            last_layer_without_sigmoid: when True the discriminator is taken to
                emit raw logits and `BCEWithLogitsLoss` is used; otherwise
                `BCELoss` is used.
            device: device on which the constant real/fake targets are created.
        """
        if last_layer_without_sigmoid:
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            self.criterion = nn.BCELoss()
        self.device = device
        # Scalar targets; expanded (without copying) to the discriminator output shape.
        self.one = torch.tensor(1, dtype=torch.float, device=device)
        self.zero = torch.tensor(0, dtype=torch.float, device=device)

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return `(total_loss, real_loss, fake_loss)` for the discriminator."""
        outputs = netD(images)
        # Expand to the full output shape (not just the batch dimension) so
        # discriminators emitting (B, 1) or (B, 1, 1, 1) work as well as (B,).
        true_loss = self.criterion(outputs, self.one.expand_as(outputs))
        outputs = netD(fake_images)
        fake_loss = self.criterion(outputs, self.zero.expand_as(outputs))
        return true_loss + fake_loss, true_loss, fake_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return the generator loss: fakes scored against the "real" target."""
        outputs = netD(fake_images)
        return self.criterion(outputs, self.one.expand_as(outputs))
class WGANLoss(BaseGANLoss):
    """ Wasserstein GAN loss with an optional gradient penalty (WGAN-GP).

    `netD` is called with the image batch only (labels are ignored).
    """

    def __init__(self, gp_lambda=10, use_gp=True, device="cuda"):
        """
        Args:
            gp_lambda: weight of the gradient-penalty term.
            use_gp: whether the gradient penalty is added to the critic loss.
            device: device used for the tensors created by the penalty.
        """
        self.use_gp = use_gp
        self.device = device
        self.gp_lambda = gp_lambda

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return the critic loss.

        Returns:
            `(loss, gp_loss)` when `use_gp` is True, otherwise just `loss`.
        """
        out_true = netD(images)
        out_fake = netD(fake_images)
        loss = out_fake.mean() - out_true.mean()
        if not self.use_gp:
            return loss
        gp_loss = self.calculate_gradient_penalty(
            netD, images, fake_images, self.gp_lambda, self.device)
        return loss + gp_loss, gp_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return the generator loss: the negated mean critic score of the fakes."""
        return -netD(fake_images).mean()

    def calculate_gradient_penalty(self, netD, images, fake_images, gp_lambda, device):
        """ Return `gp_lambda * mean((||grad_x D(x_hat)||_2 - 1) ** 2)`.

        `x_hat` is a per-sample random interpolation between real and fake images.
        """
        batch_size = images.size(0)
        # One coefficient per sample, broadcast over all remaining dimensions,
        # so inputs of any rank are supported.
        alpha_shape = (batch_size,) + (1,) * (images.dim() - 1)
        # The coefficient must be uniform in [0, 1) so that x_hat lies on the
        # segment between the real and the fake sample (randn is unbounded).
        alpha = torch.rand(alpha_shape, device=device)
        # Detach so x_hat is a leaf tensor; the penalty only regularizes netD.
        interpolate = alpha * images.detach() + (1 - alpha) * fake_images.detach()
        interpolate = interpolate.to(device)
        interpolate.requires_grad_(True)
        out = netD(interpolate)
        if isinstance(out, tuple):
            # Multi-head discriminators: only the first output is penalized.
            out = out[0]
        grads = autograd.grad(out, interpolate,
                              grad_outputs=torch.ones_like(out),
                              retain_graph=True, create_graph=True)[0]
        grads = grads.reshape(grads.size(0), -1)
        return gp_lambda * ((grads.norm(p=2, dim=1) - 1) ** 2).mean()
class HingeLoss(BaseGANLoss):
    """ Adversarial Loss proposed and used in "Geometric GAN" and "Self Attention GAN".

    `netD` is called as `netD(batch, labels)`.
    """

    def __init__(self, device="cuda"):
        """
        Args:
            device: stored on the instance; not used by the loss computation itself.
        """
        self.device = device
        # min(0, x) written with relu: x - max(0, x).
        self.min0 = lambda x: x - torch.relu(x)

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return `mean(-min(0, D(x) - 1)) + mean(-min(0, -1 - D(G(z))))`."""
        out_true = netD(images, labels)
        true_loss = (-self.min0(out_true - 1)).mean()
        out_fake = netD(fake_images, labels)
        fake_loss = (-self.min0(-1 - out_fake)).mean()
        return true_loss + fake_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return the generator loss `-mean(D(G(z)))` as a scalar."""
        # Reduce to a scalar: the trainer calls .backward() and .item() on it,
        # both of which fail on a per-sample tensor.
        return -netD(fake_images, labels).mean()
class LSGANLoss(BaseGANLoss):
    """ Adversarial Loss proposed in Least-Squares GAN.

    Placeholder: the constructor stores its arguments, but both loss methods
    are not implemented yet.
    """

    def __init__(self, a, b, device="cuda"):
        """
        Args:
            a: stored as `self.a`. Presumably the LSGAN target for fake
                samples -- TODO confirm once the loss is implemented.
            b: stored as `self.b`. Presumably the LSGAN target for real
                samples -- TODO confirm once the loss is implemented.
            device: stored as `self.device`.
        """
        self.a = a
        self.b = b
        self.device = device

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")
class MarginLoss(BaseGANLoss):
    """ Adversarial Loss proposed in "EBGAN".

    `netD` is treated as an energy function and is called with the image
    batch only (labels are ignored).
    """

    def __init__(self, margin, device="cuda"):
        """
        Args:
            margin: margin `m` of the hinge applied to the energy of fake samples.
            device: stored on the instance for parity with the other loss
                classes; not used by the loss computation itself.
        """
        self.margin = margin
        self.device = device

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return `mean(D(x)) + mean(relu(margin - D(G(z))))`."""
        # Both terms are reduced to scalars so the result can be backpropagated directly.
        true_loss = netD(images).mean()
        fake_loss = torch.relu(self.margin - netD(fake_images)).mean()
        return true_loss + fake_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return `mean(D(G(z)))`."""
        # The generator minimizes the energy assigned to its samples. Negating
        # it would push fakes towards high energy, i.e. help the discriminator.
        return netD(fake_images).mean()
class ACGANLoss(BaseGANLoss):
    """ ACGAN adversarial loss: a real/fake term plus an auxiliary classification term.

    `netD` is called with the image batch only and must return a pair
    `(real_fake_output, class_output)`.
    """

    def __init__(self, last_layer_without_sigmoid = True, device = "cuda"):
        """
        Args:
            last_layer_without_sigmoid: selects `BCEWithLogitsLoss` (True) or
                `BCELoss` (False) for the real/fake term.
            device: device on which the constant real/fake targets are created.
        """
        self.clsc = nn.CrossEntropyLoss()
        self.dc = nn.BCEWithLogitsLoss() if last_layer_without_sigmoid else nn.BCELoss()
        self.one = torch.tensor(1, dtype=torch.float, device=device)
        self.zero = torch.tensor(0, dtype=torch.float, device=device)

    def _score(self, netD, batch, target, labels):
        """ Run `netD` on `batch`; return `(adversarial_loss, classification_loss)`."""
        validity, class_scores = netD(batch)
        adversarial = self.dc(validity, target.expand(validity.size()))
        classification = self.clsc(class_scores, labels)
        return adversarial, classification

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """ Return the summed adversarial and classification losses on real and fake batches."""
        real_adv, real_cls = self._score(netD, images, self.one, labels)
        fake_adv, fake_cls = self._score(netD, fake_images, self.zero, labels)
        return real_adv + real_cls + fake_adv + fake_cls

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """ Return the generator loss: fakes scored as real, plus their classification loss."""
        fooled_adv, fooled_cls = self._score(netD, fake_images, self.one, labels)
        return fooled_adv + fooled_cls
# -------------------------
# Training Frame for training simple conditional GAN and unconditional GAN
# default hyper parameters
params = {
    "device": None,
    "noise_size" : 100,
    "epochs" : 200,
    "train_D_interval" : 1,
    "train_G_interval" : 1,
    "save_model_interval" : 5,
    "D_lr" : 1e-4,
    "G_lr" : 1e-4,
    "optimizerD" : "adam",
    "optimizerG" : "adam",
    "D_optimizer_params" : None,
    "G_optimizer_params" : None,
    "use_loss" : "gan",
    "loss_params" : {}
}


class TrainerConfig:
    """ Hyper-parameter container exposing every entry of `params` as an attribute.

    The effective dictionary is kept in `self.params`. The dictionary passed
    in (including the module-level default) is never modified.
    """

    # Bound here because the constructor argument shadows the module-level name.
    DEFAULT_PARAMS = params

    def __init__(self, params = params):
        """
        Args:
            params: hyper-parameter dictionary. Keys missing from it fall back
                to the module-level defaults. A `device` of None is resolved
                to "cuda" when available, else "cpu", and the resolved device
                is also written into `loss_params["device"]`.
        """
        # Work on a private copy: the default dict is shared by every instance,
        # so mutating it would leak the resolved device into later configs.
        merged = copy.deepcopy(self.DEFAULT_PARAMS)
        if params is not None and params is not self.DEFAULT_PARAMS:
            merged.update(copy.deepcopy(params))
        if merged.get("loss_params") is None:
            merged["loss_params"] = {}
        if merged["device"] is None:
            merged["device"] = "cuda" if torch.cuda.is_available() else "cpu"
        merged["loss_params"]["device"] = merged["device"]
        self.params = merged
        for name, value in merged.items():
            setattr(self, name, value)

    def __str__(self):
        """ Return the parameters serialized as a JSON string."""
        return json.dumps(self.params)

    def dump_as_json(self, root, name = "config.json"):
        """ Write the parameters as JSON to the file `name` inside directory `root`."""
        # os.path.join works whether or not `root` ends with a path separator.
        with open(os.path.join(root, name), "w") as fp:
            json.dump(self.params, fp)

    @staticmethod
    def get_timestr(self = None):
        """ Return the local time formatted as `YYYY_MM_DD_HH_MM_SS`.

        `self` is ignored; it is kept (now optional) only so that existing
        calls passing a dummy argument keep working.
        """
        return time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
# Design note: the best approach is a framework whose key functions can be overridden through inheritance when necessary.
class TrainerFrame:
    """ Reusable training loop for simple conditional and unconditional GANs.

    To adapt it to a specific model, subclass it and override some functions, especially:
    + functions related to training: `generate_fake_images`, `train_netG_one_time`, `train_netD_one_time`
    + functions related to selection: `select_loss`, `get_optimizer`, `get_weight_initializer`
    """

    def __init__(self, opt, save_root = ".", test_frame = False,
            load_model = False, load_root = None, load_netG_name = "netG.pth", load_netD_name = "netD.pth"):
        """
        Args:
            opt: configuration object exposing the hyper-parameters as
                attributes (e.g. a `TrainerConfig`).
            save_root: output directory; `models/` and `images/` are created
                below it and TensorBoard logs are written into it.
            test_frame: when True, run a shortened smoke test (one epoch,
                stopping after iteration index 10).
            load_model, load_root, load_netG_name, load_netD_name: accepted
                for interface compatibility; model loading is not implemented
                in this class, so they are currently unused.
        """
        # The roots keep a trailing separator because file paths are built by
        # plain string concatenation elsewhere in this class.
        self.save_root = os.path.join(save_root or ".", "")
        self.save_model_root = os.path.join(self.save_root, "models", "")
        self.save_images_root = os.path.join(self.save_root, "images", "")
        for directory in (self.save_root, self.save_model_root, self.save_images_root):
            os.makedirs(directory, exist_ok=True)

        self.logger = SummaryWriter(self.save_root)
        self.device = opt.device
        self.opt = opt
        self.test_frame = test_frame

        self._initialize_parameters()
        if self.test_frame:
            self._initialize_test_param()

    #--------------------- initialization functions -----------------------------
    def _initialize_net(self, netD, netG, method = "normal", var = 0.02):
        """ Apply the weight initializer chosen by `method` to both networks."""
        netD.apply(self.get_weight_initializer(method=method, var=var))
        netG.apply(self.get_weight_initializer(method=method, var=var))

    def _initialize_parameters(self):
        """ Copy the hyper-parameters from `self.opt` and build the loss calculator."""
        self.steps = { # record steps for tensorboard
            "epoch_step": 0,
            "iter_step": 0,
            "netD_train_step": 0,
            "netG_train_step": 0,
        }
        self.iter_intervals = { # training interval
            "netD_train": self.opt.train_D_interval,
            "netG_train": self.opt.train_G_interval,
        }
        self.epoch_interval = { # intervals in epochs for record
            "save_model": self.opt.save_model_interval,
        }

        self.epochs = self.opt.epochs
        self.train_D_interval = self.iter_intervals["netD_train"]
        self.train_G_interval = self.iter_intervals["netG_train"]
        self.save_model_interval = self.epoch_interval["save_model"]

        self.D_lr = self.opt.D_lr
        self.G_lr = self.opt.G_lr
        self.D_optimizer = self.opt.optimizerD
        self.G_optimizer = self.opt.optimizerG
        self.D_optimizer_params = self.opt.D_optimizer_params
        self.G_optimizer_params = self.opt.G_optimizer_params

        self.use_loss = self.opt.use_loss
        self.loss_params = self.opt.loss_params
        self.loss_calculator = self.select_loss(self.use_loss, self.loss_params)

    def _initialize_test_param(self):
        """ Renew some parameters for testing the whole frame.
        """
        self.epochs = 1
        self.epoch_interval = { # intervals in epochs for record
            "save_model": 1,
        }
        # `train` reads this attribute, so it has to follow the dict above.
        self.save_model_interval = self.epoch_interval["save_model"]

    def switch_to_train(self):
        """ Switch to training mode if the trainer is in the test phase."""
        if self.test_frame:
            self.epochs = self.opt.epochs
            self.test_frame = False
            self.epoch_interval = { # intervals in epochs for record
                "save_model": self.opt.save_model_interval,
            }
            self.save_model_interval = self.epoch_interval["save_model"]

    #--------------------- selection functions -----------------------------
    def select_loss(self, name, loss_params, *args, **kws):
        """ Build the loss calculator registered under `name`.

        Args:
            name: one of "gan", "wgan", "hinge", "acgan".
            loss_params: keyword arguments forwarded to the loss constructor.

        Raises:
            Exception: if `name` is not a supported loss.
        """
        registry = {
            "gan": GANLoss,
            "wgan": WGANLoss,
            "hinge": HingeLoss,
            "acgan": ACGANLoss,
        }
        if name not in registry:
            raise Exception("No implemented Loss")
        return registry[name](**loss_params)

    def get_optimizer(self, net, lr, name = "sgd", optim_params = None):
        """ Create an optimizer over the parameters of `net`.

        Args:
            net: module whose parameters are optimized.
            lr: learning rate.
            name: "adam" or "sgd".
            optim_params: extra keyword arguments for the optimizer; when None
                the defaults defined below are used.

        Raises:
            Exception: if `name` is not a supported optimizer.
        """
        if optim_params is None:
            if name == "adam":
                optim_params = {"betas":(0.5,0.99), "weight_decay": 0, "eps": 1e-8, "amsgrad":False}
            elif name == "sgd":
                optim_params = {"momentum":0, "weight_decay":0, "dampening": 0, "nesterov":False}
            else:
                raise Exception("No support optimizer")
        assert isinstance(optim_params, dict), "optim_params should be dict type."

        if name == "adam":
            return optim.Adam(net.parameters(), lr = lr, **optim_params)
        if name == "sgd":
            return optim.SGD(net.parameters(), lr = lr, **optim_params)
        raise Exception("No support optimizer")

    def get_weight_initializer(self, method = "normal", var = 0.02):
        """ Return a function suitable for `nn.Module.apply` that initializes weights.

        Modules whose class name contains "conv" are initialized according to
        `method`. Modules whose class name contains "norm" get a scale drawn
        from N(1, 0.02) and a zero bias.

        Args:
            method: "normal" (N(0, var)), "kaiming" or "xavier".
            var: standard deviation for "normal"; passed as `gain` for "xavier".

        Raises:
            NotImplementedError: from the returned function, when it meets a
                convolution module and `method` is unknown.
        """
        def weight_init(m):
            # Class names are capitalized ("Conv2d", "BatchNorm2d"); compare in
            # lower case, otherwise no module is ever matched.
            class_name = m.__class__.__name__.lower()
            weight = getattr(m, "weight", None)
            bias = getattr(m, "bias", None)
            if "conv" in class_name:
                if method not in ("normal", "kaiming", "xavier"):
                    raise NotImplementedError("Init Method %s not implemented"%(method))
                if not isinstance(weight, torch.Tensor):
                    # e.g. a user-defined container that merely has "conv" in its name
                    return
                if method == "normal":
                    weight.data.normal_(0, var)
                elif method == "kaiming":
                    kaiming_normal_(weight.data, a=0.0)
                else:
                    xavier_normal_(weight.data, gain=var)
            elif "norm" in class_name:
                # Norm layers built with affine=False have neither weight nor bias.
                if isinstance(weight, torch.Tensor):
                    # The scale is centred on 1: a scale near 0 would squash
                    # every normalized activation.
                    weight.data.normal_(1.0, 0.02)
                if isinstance(bias, torch.Tensor):
                    bias.data.fill_(0)
        return weight_init

    #--------------------- train related functions -----------------------------
    def set_grad(self, net, open = False):
        """ Enable (`open=True`) or disable gradient tracking for all parameters of `net`."""
        for param in net.parameters():
            param.requires_grad_(open)

    def train(self, netD, netG, dataloader, initialize_model = None):
        """ Run the full training loop.

        Args:
            netD: discriminator network.
            netG: generator network.
            dataloader: iterable yielding `(images, labels)` batches.
            initialize_model: optional weight-initialization method name; when
                given, both networks are re-initialized before training.
        """
        self.optimizerD = self.get_optimizer(netD, self.D_lr, self.D_optimizer, self.D_optimizer_params)
        self.optimizerG = self.get_optimizer(netG, self.G_lr, self.G_optimizer, self.G_optimizer_params)

        if initialize_model is not None:
            self._initialize_net(netD, netG, method=initialize_model)

        print("Using {}. Start Training:".format(self.device))

        # time.clock() was removed in Python 3.8; perf_counter is its replacement.
        starttime = time.perf_counter()
        for epoch in range(1, self.epochs + 1):
            # Reset so that an empty dataloader cannot leave these names unbound.
            images, fake_images = None, None
            for i, (images, labels) in enumerate(tqdm(dataloader)):
                images, labels = images.to(self.device), labels.to(self.device)

                if i % self.train_D_interval == 0:
                    self.train_netD_one_time(netD, netG, images, labels)

                if i % self.train_G_interval == 0:
                    fake_images = self.train_netG_one_time(netD, netG, images, labels)

                if i == 10 and self.test_frame:
                    break

                self.steps["iter_step"] += 1

            self.steps["epoch_step"] += 1

            # Real and fake samples go to different files; previously both used
            # the "fake" name and the real batch was overwritten.
            if images is not None:
                self.save_images(images, "real{}.png".format(epoch))
            if fake_images is not None:
                self.save_images(fake_images, "fake{}.png".format(epoch))

            if epoch % self.save_model_interval == 0 or epoch == self.epochs:
                self.save_model(netD, netG)

        endtime = time.perf_counter()
        train_time = (endtime - starttime)
        print("Training Using %5.2fs" %(train_time))

    def generate_fake_images(self, netG, batch_size, labels, *args, **kws):
        """ Sample noise of size `opt.noise_size` and return `netG(z, labels)`.

        Override this for generators with a different call signature.
        """
        z = torch.randn((batch_size, self.opt.noise_size)).to(self.device)
        return netG(z, labels)

    def train_netD_one_time(self, netD, netG, images, labels, *args, **kws):
        """ Perform one discriminator update and log its loss."""
        batch_size = images.size(0)

        # train D: the generator is frozen while the discriminator is updated
        self.set_grad(netG, open = False)
        self.optimizerD.zero_grad()

        # calculate D loss
        fake_images = self.generate_fake_images(netG, batch_size, labels)
        losses = self.loss_calculator.calculate_loss_D(netD, netG, images, fake_images, labels)
        # Loss calculators may return extra terms; the first one is optimized.
        lossD = losses[0] if isinstance(losses, tuple) else losses

        lossD.backward()
        self.optimizerD.step()
        self.set_grad(netG, open = True)
        self.steps["netD_train_step"] += 1

        # record
        this_step = self.steps["netD_train_step"]
        self.logger.add_scalar("Loss/LossD", lossD.item(), this_step)
        # The penalty term only exists when the WGAN loss runs with use_gp=True.
        if self.opt.use_loss == "wgan" and isinstance(losses, tuple) and len(losses) > 1:
            loss_gp = losses[1].item()
            self.logger.add_scalar("Loss/LossGP", loss_gp, this_step)

    def train_netG_one_time(self, netD, netG, images, labels, *args, **kws):
        """ Perform one generator update, log its loss and return the generated batch."""
        batch_size = images.size(0)

        # train G: the discriminator is frozen while the generator is updated
        self.set_grad(netD, open = False)
        self.optimizerG.zero_grad()

        # calculate G loss
        fake_images = self.generate_fake_images(netG, batch_size, labels)
        losses = self.loss_calculator.calculate_loss_G(netD, netG, images, fake_images, labels)
        lossG = losses[0] if isinstance(losses, tuple) else losses

        lossG.backward()
        self.optimizerG.step()
        self.set_grad(netD, open = True)
        self.steps["netG_train_step"] += 1

        # record
        self.logger.add_scalar("Loss/LossG", lossG.item(), self.steps["netG_train_step"])
        return fake_images

    #--------------------- save and load functions -----------------------------
    def save_model(self, netD, netG):
        """ Save both state dicts as `netG.pth` / `netD.pth` in the model directory."""
        torch.save(netG.state_dict(), self.save_model_root + "netG.pth")
        torch.save(netD.state_dict(), self.save_model_root + "netD.pth")

    def save_images(self, fake_images, name, nrow = 16):
        """ Save a batch of images as one grid image, normalized from (-1, 1).

        Args:
            fake_images: image batch to save (despite the name, any batch).
            name: file name inside the images directory.
            nrow: number of images per grid row.
        """
        target = self.save_images_root + name
        try:
            save_image(fake_images, target, nrow = nrow, normalize=True, value_range=(-1,1))
        except TypeError:
            # Older torchvision releases only know this keyword as `range`.
            save_image(fake_images, target, nrow = nrow, normalize=True, range=(-1,1))
class DCGANTrainer(TrainerFrame):
    """ Example subclass for an unconditional generator called as `netG(z)`.
    """

    def generate_fake_images(self, netG, batch_size, labels, *args, **kws):
        """ Sample a `(batch_size, opt.noise_size)` noise batch and return `netG(noise)`.

        `labels` is accepted for interface compatibility and ignored.
        """
        noise_shape = (batch_size, self.opt.noise_size)
        noise = torch.randn(noise_shape)
        noise = noise.to(self.device)
        return netG(noise)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment