GAN 训练框架版本v2。
对简单的 Conditional GAN 和 Unconditional GAN 可以直接简单套用。
自定义的训练方式通过继承 TrainerFrame 类并重写其中的某些函数来实现,例如示例中的 DCGANTrainer。
- Loss 的计算还是不能很好地兼容不同种类的GAN,需要进一步设计 Loss 的计算类。
import os, time, json | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.tensorboard import SummaryWriter | |
from torchvision.utils import make_grid, save_image | |
from torch.nn.init import kaiming_normal_, xavier_normal_ | |
import torch.autograd as autograd | |
# -------------------------------- | |
# Loss | |
class BaseGANLoss:
    """Common interface of the adversarial-loss calculators in this file.

    Subclasses override both methods. The base implementations only signal
    that the loss is not available.
    """

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return the discriminator loss for one batch.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # NotImplementedError is a subclass of Exception, so callers that
        # catch Exception keep working.
        raise NotImplementedError("Not Implemented")

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the generator loss for one batch.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Not Implemented")
class GANLoss(BaseGANLoss):
    """Binary cross-entropy adversarial loss.

    The discriminator is trained towards 1 on real images and 0 on fake
    images; the generator is trained so that the discriminator outputs 1 on
    fake images.
    """

    def __init__(self, last_layer_without_sigmoid=False, device="cuda"):
        """Choose the BCE variant and create the constant target tensors.

        Args:
            last_layer_without_sigmoid: if true, use ``BCEWithLogitsLoss``
                (the discriminator returns raw logits); otherwise use
                ``BCELoss``.
            device: device on which the scalar targets are allocated.
        """
        if last_layer_without_sigmoid:
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            self.criterion = nn.BCELoss()
        self.device = device
        self.one = torch.tensor(1, dtype=torch.float, device=device)
        self.zero = torch.tensor(0, dtype=torch.float, device=device)

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return ``(total_loss, real_loss, fake_loss)`` for the discriminator.

        ``netG`` and ``labels`` are accepted for interface compatibility and
        are not used.
        """
        outputs = netD(images)
        # Expand the scalar target to the full output shape so that both
        # (B,) and (B, 1, ...) discriminator outputs are accepted.
        true_loss = self.criterion(outputs, self.one.expand_as(outputs))
        outputs = netD(fake_images)
        fake_loss = self.criterion(outputs, self.zero.expand_as(outputs))
        return true_loss + fake_loss, true_loss, fake_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the generator loss: BCE of ``netD(fake_images)`` against 1."""
        outputs = netD(fake_images)
        fake_loss = self.criterion(outputs, self.one.expand_as(outputs))
        return fake_loss
class WGANLoss(BaseGANLoss):
    """Wasserstein critic/generator loss with an optional gradient penalty."""

    def __init__(self, gp_lambda=10, use_gp=True, device="cuda"):
        """Store the penalty settings.

        Args:
            gp_lambda: weight of the gradient penalty term.
            use_gp: whether to add the gradient penalty to the critic loss.
            device: device used for the interpolation coefficients.
        """
        self.use_gp = use_gp
        self.device = device
        self.gp_lambda = gp_lambda

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return the critic loss ``mean(D(fake)) - mean(D(real))``.

        Returns:
            ``(loss, gp_loss)`` when ``use_gp`` is true (``loss`` already
            includes the penalty), otherwise the bare loss tensor.
        """
        out_true = netD(images)
        out_fake = netD(fake_images)
        loss = out_fake.mean() - out_true.mean()
        if self.use_gp:
            gp_loss = self.calculate_gradient_penalty(
                netD, images, fake_images, self.gp_lambda, self.device)
            # Out-of-place add: avoids mutating a tensor autograd may track.
            loss = loss + gp_loss
            return loss, gp_loss
        return loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the generator loss ``-mean(D(fake))``."""
        return -netD(fake_images).mean()

    def calculate_gradient_penalty(self, netD, images, fake_images, gp_lambda, device):
        """Return ``gp_lambda * mean((||grad D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` is a per-sample random interpolation between ``images`` and
        ``fake_images``.
        """
        batch_size = images.size(0)
        # One coefficient per sample, drawn uniformly from [0, 1) so that
        # x_hat lies on the segment between the real and the fake sample.
        # (The previous torch.randn draw produced points outside it.)
        # Trailing singleton dims broadcast over any input rank.
        alpha_shape = (batch_size,) + (1,) * (images.dim() - 1)
        alpha = torch.rand(alpha_shape, device=device)
        interpolate = alpha * images + (1 - alpha) * fake_images
        interpolate = interpolate.to(device)
        interpolate.requires_grad_(True)

        out = netD(interpolate)
        if isinstance(out, tuple):
            # Discriminators with auxiliary heads: the first element is
            # taken as the critic score.
            out = out[0]
        grads = autograd.grad(
            out, interpolate,
            grad_outputs=torch.ones_like(out).type(torch.float).to(device),
            retain_graph=True, create_graph=True)[0]
        grads = grads.view(grads.size(0), -1)
        return gp_lambda * ((grads.norm(p=2, dim=1) - 1) ** 2).mean()
class HingeLoss(BaseGANLoss):
    """Hinge adversarial loss as used in "Geometric GAN" and "Self Attention GAN".

    The discriminator is called as ``netD(images, labels)``.
    """

    def __init__(self, device="cuda"):
        """Store the device and the ``min(0, x)`` helper."""
        self.device = device
        # x - relu(x) == min(0, x)
        self.min0 = lambda x: x - torch.relu(x)

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return ``mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))``."""
        out_true = netD(images, labels)
        true_loss = (-self.min0(out_true - 1)).mean()
        out_fake = netD(fake_images, labels)
        fake_loss = (-self.min0(-1 - out_fake)).mean()
        return true_loss + fake_loss

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the scalar generator loss ``-mean(D(fake))``."""
        # Reduce over the batch: backward() needs a scalar, and the
        # discriminator loss above is batch-averaged as well.
        lossG = -netD(fake_images, labels).mean()
        return lossG
class LSGANLoss(BaseGANLoss):
    """Adversarial loss proposed in Least-Squares GAN (placeholder).

    Only the constructor is functional; both loss methods raise.
    """

    def __init__(self, a, b, device="cuda"):
        """Store the target values ``a`` and ``b`` and the device.

        NOTE(review): presumably ``a``/``b`` are the least-squares target
        labels; confirm their roles before implementing the losses.
        """
        self.a = a
        self.b = b
        self.device = device

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not Implemented")
class MarginLoss(BaseGANLoss):
    """Margin-based adversarial loss proposed in "EBGAN".

    ``netD`` is treated as an energy function: its output is minimised on
    real images and pushed up to ``margin`` on fake images.
    """

    def __init__(self, margin):
        """Store the margin used in the fake-sample hinge term."""
        self.margin = margin

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return ``mean(D(real)) + mean(relu(margin - D(fake)))``."""
        # Batch-average both terms so the result is a scalar that
        # backward() accepts even when netD returns per-sample energies.
        true_loss = netD(images).mean()
        fake_loss = torch.relu(self.margin - netD(fake_images)).mean()
        lossD = true_loss + fake_loss
        return lossD

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the generator loss ``mean(D(fake))``."""
        # The generator minimises the energy of its samples. The previous
        # ``-netD(fake)`` pushed energy up, the same direction as the
        # discriminator's fake term, so the two were not adversarial.
        lossG = netD(fake_images).mean()
        return lossG
class ACGANLoss(BaseGANLoss):
    """ACGAN adversarial loss: real/fake BCE plus auxiliary classification.

    ``netD(x)`` must return a pair ``(real_fake_output, class_logits)``.
    """

    def __init__(self, last_layer_without_sigmoid=True, device="cuda"):
        """Build the two criteria and the constant real/fake targets."""
        self.clsc = nn.CrossEntropyLoss()
        self.dc = nn.BCEWithLogitsLoss() if last_layer_without_sigmoid else nn.BCELoss()
        self.one = torch.tensor(1, dtype=torch.float, device=device)
        self.zero = torch.tensor(0, dtype=torch.float, device=device)

    def calculate_loss_D(self, netD, netG, images, fake_images, labels):
        """Return the summed adversarial and classification loss on real and fake batches."""
        # Real batch: target 1, classified against the given labels.
        real_score, real_logits = netD(images)
        real_adv = self.dc(real_score, self.one.expand_as(real_score))
        real_cls = self.clsc(real_logits, labels)

        # Fake batch: target 0, classified against the same labels.
        fake_score, fake_logits = netD(fake_images)
        fake_adv = self.dc(fake_score, self.zero.expand_as(fake_score))
        fake_cls = self.clsc(fake_logits, labels)

        return real_adv + real_cls + fake_adv + fake_cls

    def calculate_loss_G(self, netD, netG, images, fake_images, labels):
        """Return the generator loss: fake batch scored against 1, plus classification."""
        score, logits = netD(fake_images)
        adv_term = self.dc(score, self.one.expand_as(score))
        cls_term = self.clsc(logits, labels)
        return adv_term + cls_term
# ------------------------- | |
# Training Frame for training simple conditional GAN and unconditional GAN | |
# default hyper parameters | |
params = {
    # Runtime / model input.
    "device": None,              # None -> resolved by TrainerConfig
    "noise_size": 100,
    # Schedule.
    "epochs": 200,
    "train_D_interval": 1,
    "train_G_interval": 1,
    "save_model_interval": 5,
    # Optimisation.
    "D_lr": 1e-4,
    "G_lr": 1e-4,
    "optimizerD": "adam",
    "optimizerG": "adam",
    "D_optimizer_params": None,  # None -> defaults chosen in get_optimizer
    "G_optimizer_params": None,
    # Loss selection; loss_params is forwarded to the loss constructor.
    "use_loss": "gan",
    "loss_params": {},
}
class TrainerConfig:
    """Hyper-parameter container exposing every dict entry as an attribute."""

    def __init__(self, params=params):
        """Resolve the device and mirror ``params`` onto the instance.

        Args:
            params: dict of hyper-parameters; must contain at least the
                ``"device"`` and ``"loss_params"`` keys.
        """
        # Work on a private copy. The default argument is the shared
        # module-level dict, and writing the resolved device into it would
        # leak state into every later TrainerConfig.
        self.params = dict(params)
        self.params["loss_params"] = dict(self.params["loss_params"])

        if self.params["device"] is None:
            self.params["device"] = "cuda" if torch.cuda.is_available() else "cpu"
        # The loss calculators allocate their constants on this device.
        self.params["loss_params"]["device"] = self.params["device"]

        for name, value in self.params.items():
            setattr(self, name, value)

    def __str__(self):
        """Return the parameters serialised as a JSON string."""
        return json.dumps(self.params)

    def dump_as_json(self, root, name="config.json"):
        """Write the parameters as JSON to ``name`` inside directory ``root``."""
        # os.path.join works whether or not ``root`` ends with a separator.
        with open(os.path.join(root, name), "w", encoding="utf-8") as fp:
            json.dump(self.params, fp)

    @staticmethod
    def get_timestr(self=None):
        """Return the local time formatted as ``YYYY_mm_dd_HH_MM_SS``.

        The unused ``self`` parameter is kept (now optional) so that
        existing calls passing a placeholder argument keep working.
        """
        time_format = "%Y_%m_%d_%H_%M_%S"
        return time.strftime(time_format, time.localtime())
# 设计一个框架,必要的时候,可以通过继承的方式修改关键的函数才是最好的方案。 | |
class TrainerFrame:
    """Generic training loop for simple conditional and unconditional GANs.

    To adapt it to a specific model, subclass and override:

    + training hooks: `generate_fake_images`, `train_netG_one_time`,
      `train_netD_one_time`
    + selection hooks: `select_loss`, `get_optimizer`,
      `get_weight_initializer`
    """

    def __init__(self, opt, save_root=".", test_frame=False,
                 load_model=False, load_root=None,
                 load_netG_name="netG.pth", load_netD_name="netD.pth"):
        """Create the output directories, the TensorBoard writer and the training state.

        Args:
            opt: configuration object exposing the hyper-parameters as
                attributes (``device``, ``epochs``, ``D_lr``, ...).
            save_root: directory that receives the logs plus the ``models``
                and ``images`` sub-directories.
            test_frame: if true, shorten training to a quick smoke run.
            load_model, load_root, load_netG_name, load_netD_name: accepted
                for compatibility; checkpoint loading is not implemented in
                this class, so they are currently ignored.
        """
        self.save_root = save_root
        os.makedirs(self.save_root, exist_ok=True)
        if not self.save_root.endswith(os.sep):
            self.save_root = self.save_root + os.sep
        self.save_model_root = self.save_root + "models" + os.sep
        self.save_images_root = self.save_root + "images" + os.sep
        os.makedirs(self.save_model_root, exist_ok=True)
        os.makedirs(self.save_images_root, exist_ok=True)

        self.logger = SummaryWriter(self.save_root)
        self.device = opt.device
        self.opt = opt
        self.test_frame = test_frame
        self._initialize_parameters()
        if self.test_frame:
            self._initialize_test_param()

    # --------------------- initialization functions ---------------------

    def _initialize_net(self, netD, netG, method="normal", var=0.02):
        """Apply the selected weight initialiser to both networks."""
        netD.apply(self.get_weight_initializer(method=method, var=var))
        netG.apply(self.get_weight_initializer(method=method, var=var))

    def _initialize_parameters(self):
        """Copy the hyper-parameters from ``self.opt`` and build the loss calculator."""
        self.steps = {  # step counters used as TensorBoard x-axes
            "epoch_step": 0,
            "iter_step": 0,
            "netD_train_step": 0,
            "netG_train_step": 0,
        }
        self.iter_intervals = {  # training intervals, in iterations
            "netD_train": self.opt.train_D_interval,
            "netG_train": self.opt.train_G_interval,
        }
        self.epoch_interval = {  # recording intervals, in epochs
            "save_model": self.opt.save_model_interval,
        }
        self.epochs = self.opt.epochs
        self.train_D_interval = self.iter_intervals["netD_train"]
        self.train_G_interval = self.iter_intervals["netG_train"]
        self.save_model_interval = self.epoch_interval["save_model"]
        self.D_lr = self.opt.D_lr
        self.G_lr = self.opt.G_lr
        self.D_optimizer = self.opt.optimizerD
        self.G_optimizer = self.opt.optimizerG
        self.D_optimizer_params = self.opt.D_optimizer_params
        self.G_optimizer_params = self.opt.G_optimizer_params
        self.use_loss = self.opt.use_loss
        self.loss_params = self.opt.loss_params
        self.loss_calculator = self.select_loss(self.use_loss, self.loss_params)

    def _initialize_test_param(self):
        """Override some parameters so the whole frame can be smoke-tested quickly."""
        self.epochs = 1
        self.epoch_interval = {  # recording intervals, in epochs
            "save_model": 1,
        }
        # train() reads save_model_interval, so keep it in sync.
        self.save_model_interval = self.epoch_interval["save_model"]

    def switch_to_train(self):
        """Switch back to full training if the trainer is in test mode."""
        if self.test_frame:
            self.epochs = self.opt.epochs
            self.test_frame = False
            self.epoch_interval = {  # recording intervals, in epochs
                "save_model": self.opt.save_model_interval,
            }
            self.save_model_interval = self.epoch_interval["save_model"]

    # --------------------- selection functions ---------------------

    def select_loss(self, name, loss_params, *args, **kws):
        """Instantiate the loss calculator registered under ``name``.

        Args:
            name: one of ``"gan"``, ``"wgan"``, ``"hinge"``, ``"acgan"``.
            loss_params: keyword arguments forwarded to the loss constructor.

        Raises:
            Exception: if ``name`` is not a supported loss.
        """
        if name == "gan":
            calculator = GANLoss(**loss_params)
        elif name == "wgan":
            calculator = WGANLoss(**loss_params)
        elif name == "hinge":
            calculator = HingeLoss(**loss_params)
        elif name == "acgan":
            calculator = ACGANLoss(**loss_params)
        else:
            raise Exception("No implemented Loss")
        return calculator

    def get_optimizer(self, net, lr, name="sgd", optim_params=None):
        """Build an Adam or SGD optimizer over ``net.parameters()``.

        Args:
            net: module whose parameters are optimised.
            lr: learning rate.
            name: ``"adam"`` or ``"sgd"``.
            optim_params: extra keyword arguments for the optimizer; when
                ``None`` a default set is chosen for ``name``.

        Raises:
            Exception: if ``name`` is not a supported optimizer.
        """
        if name not in ("adam", "sgd"):
            raise Exception("No support optimizer")
        if optim_params is None:
            if name == "adam":
                optim_params = {"betas": (0.5, 0.99), "weight_decay": 0,
                                "eps": 1e-8, "amsgrad": False}
            else:
                optim_params = {"momentum": 0, "weight_decay": 0,
                                "dampening": 0, "nesterov": False}
        assert isinstance(optim_params, dict), "optim_params should be dict type."
        optimizer_cls = optim.Adam if name == "adam" else optim.SGD
        return optimizer_cls(net.parameters(), lr=lr, **optim_params)

    def get_weight_initializer(self, method="normal", var=0.02):
        """Return a function suitable for ``nn.Module.apply``.

        Convolution layers are initialised according to ``method``
        (``"normal"``, ``"kaiming"`` or ``"xavier"``). Normalisation layers
        get weights drawn around 1 and zero biases.

        Raises:
            NotImplementedError: from the returned function, when it meets a
                convolution layer and ``method`` is unknown.
        """
        def weight_init(m):
            # Match case-insensitively: torch layer classes are CamelCase
            # ("Conv2d", "BatchNorm2d"), so a lowercase search on the raw
            # class name never matched and the initialiser did nothing.
            class_name = m.__class__.__name__.lower()
            weight = getattr(m, "weight", None)
            bias = getattr(m, "bias", None)
            if "conv" in class_name:
                if weight is None:
                    # Containers with "conv" in their name have no weight
                    # of their own.
                    return
                if method == "normal":
                    weight.data.normal_(0, var)
                elif method == "kaiming":
                    kaiming_normal_(weight.data, 0.0)
                elif method == "xavier":
                    # ``var`` is passed as the xavier gain.
                    xavier_normal_(weight.data, var)
                else:
                    raise NotImplementedError("Init Method %s not implemented" % (method))
            elif "norm" in class_name:
                # Non-affine norm layers have no weight or bias.
                if weight is not None:
                    # Centre the scale at 1 (DCGAN convention); a scale
                    # centred at 0 would almost zero the activations.
                    weight.data.normal_(1.0, 0.02)
                if bias is not None:
                    bias.data.fill_(0)
        return weight_init

    # --------------------- train related functions ---------------------

    def set_grad(self, net, open=False):
        """Enable (``open=True``) or disable gradients on every parameter of ``net``."""
        for param in net.parameters():
            param.requires_grad_(open)

    def train(self, netD, netG, dataloader, initialize_model=None):
        """Run the full training loop.

        Args:
            netD: discriminator network.
            netG: generator network.
            dataloader: iterable yielding ``(images, labels)`` batches.
            initialize_model: name of a weight-init method, or ``None`` to
                keep the current weights.
        """
        self.optimizerD = self.get_optimizer(netD, self.D_lr, self.D_optimizer, self.D_optimizer_params)
        self.optimizerG = self.get_optimizer(netG, self.G_lr, self.G_optimizer, self.G_optimizer_params)
        if initialize_model is not None:
            self._initialize_net(netD, netG, method=initialize_model)
        print("Using {}. Start Training:".format(self.device))

        # time.clock() was removed in Python 3.8; perf_counter() is the
        # replacement for measuring elapsed time.
        starttime = time.perf_counter()
        for epoch in range(1, self.epochs + 1):
            # Reset so an empty dataloader cannot leave these unbound.
            images = fake_images = None
            for i, (images, labels) in enumerate(tqdm(dataloader)):
                images, labels = images.to(self.device), labels.to(self.device)
                if i % self.train_D_interval == 0:
                    self.train_netD_one_time(netD, netG, images, labels)
                if i % self.train_G_interval == 0:
                    fake_images = self.train_netG_one_time(netD, netG, images, labels)
                if i == 10 and self.test_frame:
                    # Smoke-test mode: a few iterations are enough.
                    break
                self.steps["iter_step"] += 1
            self.steps["epoch_step"] += 1

            # Save the last real batch under its own name. It used to be
            # written as "fake{}.png" and overwritten by the fake batch.
            if images is not None:
                self.save_images(images, "real{}.png".format(epoch))
            if fake_images is not None:
                self.save_images(fake_images.detach(), "fake{}.png".format(epoch))
            if epoch % self.save_model_interval == 0 or epoch == self.epochs:
                self.save_model(netD, netG)
        endtime = time.perf_counter()
        train_time = (endtime - starttime)
        print("Training Using %5.2fs" % (train_time))

    def generate_fake_images(self, netG, batch_size, labels, *args, **kws):
        """Sample Gaussian noise and return ``netG(z, labels)``.

        Override this for unconditional generators that take only ``z``.
        """
        z = torch.randn((batch_size, self.opt.noise_size)).to(self.device)
        return netG(z, labels)

    def train_netD_one_time(self, netD, netG, images, labels, *args, **kws):
        """Run one discriminator optimisation step and log its loss."""
        batch_size = images.size(0)
        # Freeze G so only D accumulates gradients.
        self.set_grad(netG, open=False)
        self.optimizerD.zero_grad()

        fake_images = self.generate_fake_images(netG, batch_size, labels)
        losses = self.loss_calculator.calculate_loss_D(netD, netG, images, fake_images, labels)
        # Loss calculators may return (total, *components) or a bare tensor.
        lossD = losses[0] if isinstance(losses, tuple) else losses
        lossD.backward()
        self.optimizerD.step()
        self.set_grad(netG, open=True)
        self.steps["netD_train_step"] += 1

        # record
        this_step = self.steps["netD_train_step"]
        self.logger.add_scalar("Loss/LossD", lossD.item(), this_step)
        # The penalty is only returned when the WGAN loss runs with
        # use_gp=True; otherwise ``losses`` is a bare tensor.
        if self.opt.use_loss == "wgan" and isinstance(losses, tuple) and len(losses) > 1:
            loss_gp = losses[1].item()
            self.logger.add_scalar("Loss/LossGP", loss_gp, this_step)

    def train_netG_one_time(self, netD, netG, images, labels, *args, **kws):
        """Run one generator optimisation step, log its loss, and return the fake batch."""
        batch_size = images.size(0)
        # Freeze D so only G accumulates gradients.
        self.set_grad(netD, open=False)
        self.optimizerG.zero_grad()

        fake_images = self.generate_fake_images(netG, batch_size, labels)
        losses = self.loss_calculator.calculate_loss_G(netD, netG, images, fake_images, labels)
        lossG = losses[0] if isinstance(losses, tuple) else losses
        lossG.backward()
        self.optimizerG.step()
        self.set_grad(netD, open=True)
        self.steps["netG_train_step"] += 1

        # record
        self.logger.add_scalar("Loss/LossG", lossG.item(), self.steps["netG_train_step"])
        return fake_images

    # --------------------- save and load functions ---------------------

    def save_model(self, netD, netG):
        """Save both state dicts into the ``models`` directory."""
        torch.save(netG.state_dict(), self.save_model_root + "netG.pth")
        torch.save(netD.state_dict(), self.save_model_root + "netD.pth")

    def save_images(self, fake_images, name, nrow=16):
        """Save a batch as an image grid, normalising from the range (-1, 1)."""
        path = self.save_images_root + name
        try:
            # Newer torchvision names this keyword ``value_range``.
            save_image(fake_images, path, nrow=nrow, normalize=True, value_range=(-1, 1))
        except TypeError:
            # Older torchvision only accepts the original ``range`` keyword.
            save_image(fake_images, path, nrow=nrow, normalize=True, range=(-1, 1))
class DCGANTrainer(TrainerFrame):
    """Example subclass: trainer for an unconditional generator."""

    def generate_fake_images(self, netG, batch_size, labels, *args, **kws):
        """Sample Gaussian noise and return ``netG(noise)``; ``labels`` is ignored."""
        noise_shape = (batch_size, self.opt.noise_size)
        noise = torch.randn(noise_shape)
        noise = noise.to(self.device)
        return netG(noise)