Last active
February 26, 2020 08:25
-
-
Save pranavpandey2511/6832ba565bf4d3793141bf8d4920dd83 to your computer and use it in GitHub Desktop.
Cycle GAN model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import itertools | |
from util.image_pool import ImagePool | |
from .base_model import BaseModel | |
from . import networks | |
from PIL import Image | |
import torch.nn.functional as F | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from models.vgg19 import VGG19 | |
from models.contextual_loss import contextual_loss as CX | |
import os | |
# Location of the pre-trained VGG19 weights used by the contextual loss.
vgg19_path = os.path.join('.', 'vgg', 'imagenet-vgg-verydeep-19.mat')
class CycleGANModel(BaseModel):
    """Bidirectional image-to-image translation model (CycleGAN variant).

    Extends the standard CycleGAN objectives with multi-scale gradient
    losses, an optional L0-style sparsity regulariser, and an optional
    VGG19-based contextual loss.
    """

    def name(self):
        """Return the identifier the framework uses for this model."""
        return 'CycleGANModel'
@staticmethod
def modify_commandline_options(parser, is_train=True):
    """Add model-specific command-line options and defaults.

    Parameters
    ----------
    parser : option parser (argparse-compatible)
    is_train : bool
        Training-only options are added only when True.

    Returns
    -------
    The same parser, for chaining.
    """
    # default CycleGAN did not use dropout
    parser.set_defaults(no_dropout=True)
    if is_train:
        parser.add_argument('--lambda_A', type=float, default=10.0,
                            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument('--lambda_B', type=float, default=10.0,
                            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument('--lambda_identity', type=float, default=0.5,
                            help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        parser.add_argument('--lambda_G_A', type=float, default=1.0, help='weight for G_a loss')
        parser.add_argument('--lambda_D_A', type=float, default=1.0, help='weight for D_a loss')
        # Help strings below were copy-pasted from other flags
        # ('weight for D_a loss' / 'lo_reg'); corrected to describe
        # what each option actually controls.
        parser.add_argument('--lambda_scale_G_A', type=float, default=10.0,
                            help='weight for the multi-scale gradient loss on G_A')
        parser.add_argument('--lambda_scale_G_B', type=float, default=10.0,
                            help='weight for the multi-scale gradient loss on G_B')
        parser.add_argument('--no_identity_b', action='store_true',
                            help='if need to add identity_b to loss , otherwise it is zero')
        parser.add_argument('--l0_reg', action='store_true',
                            help='if need to add lo_reg to loss , otherwise it is zero')
        parser.add_argument('--try_a', action='store_true',
                            help='experimental toggle for the B->A branch (see commented-out code in the model)')
        parser.add_argument('--contextual_loss', action='store_true',
                            help='enable contextual loss')
    return parser
def initialize(self, opt):
    """Build generators, discriminators, loss functions and optimizers.

    Parameters
    ----------
    opt : option namespace produced by the option parser.
    """
    BaseModel.initialize(self, opt)
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'L0_reg', 'scale_G_A', 'scale_G_B', 'contextual']
    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    visual_names_A = ['real_A', 'fake_B', 'rec_A']
    visual_names_B = ['real_B', 'fake_A', 'rec_B']
    if self.isTrain and self.opt.lambda_identity > 0.0:
        visual_names_A.append('idt_A')
        visual_names_B.append('idt_B')
    self.visual_names = visual_names_A + visual_names_B
    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    if self.isTrain:
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
    else:  # during test time, only load Gs
        self.model_names = ['G_A', 'G_B']
    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                    not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=True)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                    not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=False)
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                        self.gpu_ids)
    if self.isTrain:
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionScale = torch.nn.L1Loss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, opt.beta2))
        if self.opt.contextual_loss:
            # separate optimizer with a larger step size for the contextual term
            self.optimizer_G_contextual = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                           lr=opt.lr * 25, betas=(opt.beta1, opt.beta2))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        if self.opt.contextual_loss:
            self.optimizers.append(self.optimizer_G_contextual)
    # 3x3 Laplacian kernel used by calc_scale_loss. Placed on self.device
    # (the same device criterionGAN uses) instead of the former duplicated
    # .cuda()/CPU branches, so CPU-only runs also work.
    self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).unsqueeze(0).to(self.device)
    if self.opt.contextual_loss:
        # frozen VGG19 feature extractor for the contextual loss
        self.vgg19 = VGG19(vgg19_path).to(self.device)
        self.vgg19.eval()
def set_input(self, input):
    """Unpack a dataloader batch and move the images to the model device.

    The 'direction' option decides which side of the pair plays the
    role of domain A.
    """
    a_to_b = self.opt.direction == 'AtoB'
    self.real_A = input['A' if a_to_b else 'B'].to(self.device)
    self.real_B = input['B' if a_to_b else 'A'].to(self.device)
    self.image_paths = input['A_paths' if a_to_b else 'B_paths']
def forward(self):
    """Run both generators on the current inputs.

    NOTE: rec_A is deliberately aliased to real_A (the A-side
    reconstruction through G_B is disabled in this variant), which
    makes the forward cycle loss effectively zero.
    """
    self.fake_B = self.netG_A(self.real_A)
    # A-side reconstruction intentionally bypassed.
    self.rec_A = self.real_A
    self.fake_A = self.netG_B(self.real_B)
    self.rec_B = self.netG_A(self.fake_A)
def calc_scale_loss(self, real, fake):
    """Multi-scale gradient (Laplacian) L1 loss between two image batches.

    Both batches are converted to grayscale, pooled down by each factor
    in `scales`, convolved with the Laplacian kernel, and compared with
    L1; coarser scales receive exponentially larger weights.
    """
    scales = [1, 2, 4, 8, 16]
    weights = [0.0001, 0.001, 0.01, 0.1, 1]
    # assumes NCHW layout, so these are the two spatial dims — TODO confirm
    dim_a, dim_b = real.shape[2], real.shape[3]
    gray_real = self.rgb2gray(real)
    gray_fake = self.rgb2gray(fake)
    total = 0
    for scale, weight in zip(scales, weights):
        target_size = (1, int(dim_a / scale), int(dim_b / scale))
        down_real = F.adaptive_avg_pool3d(gray_real, target_size)
        down_fake = F.adaptive_avg_pool3d(gray_fake, target_size)
        edges_real = F.conv2d(down_real, self.grad_conv_filter, padding=1)  # TODO padding
        edges_fake = F.conv2d(down_fake, self.grad_conv_filter, padding=1)  # TODO padding
        total += weight * self.criterionScale(edges_fake, edges_real)
    return total
def calc_contextual_loss(self, generated_image, target_image):
    """Contextual loss between VGG19 feature maps of two images.

    Each layer's contribution is weighted inversely to its number of
    elements, with the weights normalised to sum to one. Target
    features are extracted without gradient tracking.
    """
    gen_layers = self.vgg19(generated_image)
    with torch.no_grad():
        tgt_layers = self.vgg19(target_image)
    # per-layer weight ~ 1 / (number of feature elements), normalised
    raw_weights = [1.0 / np.prod(layer.size()[1:]) for layer in gen_layers]
    norm = sum(raw_weights)
    weights = [w / norm for w in raw_weights]
    loss = 0
    for weight, gen_feat, tgt_feat in zip(weights, gen_layers, tgt_layers):
        loss += float(weight) * CX(gen_feat, tgt_feat)
    return loss
def save_image2(self, output, file_name='yoav.png'):
    """Debug helper: save channel 0 of a (B,C,H,W) tensor as a PNG.

    Only the first batch element is used. Values are rescaled to the
    full uint8 range before saving, so relative contrast is preserved
    but absolute intensities are not.
    """
    batch, channels, dim_a, dim_b = output.shape
    # map from [-1, 1] into [0, 1]
    output = torch.clamp((output + 1) / 2, 0, 1)
    first = output.permute(0, 2, 3, 1)[0, :, :, :]
    plane = first.data[:, :, 0].cpu().numpy()
    # stretch to 0..255 (255.9 keeps the max below 256 before the cast)
    as_uint8 = (((plane - plane.min()) / (plane.max() - plane.min())) * 255.9).astype(np.uint8)
    Image.fromarray(as_uint8).save(file_name)
def save_image(self, output, file_name='yoav.png'):
    """Debug helper: save a single (3,H,W) tensor as an RGB figure.

    Example: self.save_image(self.real_B[1,:,:,:].squeeze(0), 'real.png')
    Values are shifted by +0.5 and clamped into [0, 1] for display
    (presumably the tensor is roughly in [-0.5, 0.5] — TODO confirm).
    """
    __, w, h = output.shape
    output = torch.clamp(output + 0.5, 0, 1)
    picture = np.zeros((w, h, 3))
    picture[:, :, 0] = output.data[0, :, :]
    picture[:, :, 1] = output.data[1, :, :]
    picture[:, :, 2] = output.data[2, :, :]
    # Bug fix: imshow was previously given picture.data — for a numpy
    # array that attribute is a memoryview, which matplotlib cannot
    # render. Pass the ndarray itself.
    plt.imshow(picture)
    plt.savefig(file_name)
# def backward_D_new(self,netD,fake): | |
# | |
# pred_fake = netD(fake) | |
# loss_D_fake = self.criterionGAN(pred_fake, False) | |
# # Combined loss | |
# loss_D = (loss_D_fake) * 0.5 | |
# # backward | |
# loss_D.backward() | |
# | |
# return loss_D | |
def backward_D_basic(self, netD, real, fake, g_real=None, do_another_value=False):
    """GAN loss and backward pass for a single discriminator.

    Fake images are detached so no gradient reaches the generators.
    When do_another_value is set, g_real (e.g. an identity output) is
    also scored as real and the two real terms are averaged.
    Returns the combined discriminator loss (already backpropagated).
    """
    real_loss = self.criterionGAN(netD(real), True)
    fake_loss = self.criterionGAN(netD(fake.detach()), False)
    if do_another_value:  # TODO
        extra_real_loss = self.criterionGAN(netD(g_real.detach()), True)  # TODO make sure it helps
    else:
        extra_real_loss = 0
    # combined loss
    total = (real_loss + extra_real_loss) * 0.5 + fake_loss
    total.backward()
    return total
def backward_D_A(self):
    """Backward pass for D_A using a history pool of generated B images."""
    fake_B = self.fake_B_pool.query(self.fake_B)
    # NOTE(review): backward_D_basic already calls .backward() internally,
    # so the `* lambda_D_A * 0.5` below scales only the *reported* loss
    # value, not the gradients — confirm this is intentional.
    # NOTE(review): self.idt_A is only assigned when lambda_identity > 0
    # (see backward_G) — confirm that configuration is always used.
    self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, self.idt_A, do_another_value=True) * self.opt.lambda_D_A * 0.5
def backward_D_B(self):
    """Backward pass for D_B using a history pool of generated A images."""
    pooled_fake_A = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, pooled_fake_A)
def backward_G_contextual_loss(self):
    """Backward pass for the contextual-loss-only generator update."""
    # TODO consider adding a term against the input image as well
    self.loss_contextual = self.calc_contextual_loss(self.fake_B, self.real_B)
    # combined loss: only the contextual term participates in this update
    self.loss_G = self.loss_contextual
    self.loss_G.backward()
def backward_G(self):
    """Compute all generator losses and backpropagate their sum.

    Terms: GAN losses for both generators, identity losses, the
    backward cycle loss (the forward cycle loss is fixed at 0 in this
    variant), an optional L0-style sparsity regulariser on fake_B, and
    the multi-scale gradient losses.
    """
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss
    if lambda_idt > 0:
        # G_A should be identity if real_B is fed.
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        # G_B should be identity if real_A is fed.
        if self.opt.try_a == False and self.opt.no_identity_b:
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else :
            self.idt_B = self.real_A
            # identity-B removed: the noise generator is not expected to
            # reproduce its noisy input
            self.loss_idt_B = 0
    else:
        self.loss_idt_A = 0
        self.loss_idt_B = 0
    if (self.opt.l0_reg) :
        # push grayscale fake_B values toward the extremes (sparsity prior)
        image = -1 * torch.clamp( self.rgb2gray( self.fake_B), -0.5, 0.5) + 0.5
        mask_toward_zero = image.clone()
        mask_toward_one = image.clone()
        mask_toward_zero[mask_toward_zero > 0.5] = 0
        mask_toward_one[mask_toward_one < 0.5] = 1
        self.loss_L0_reg = 0.0001 *( torch.sum( mask_toward_zero ) + torch.sum( 1 - mask_toward_one ) ) # TODO add parameter
    else:
        self.loss_L0_reg = 0
    self.loss_scale_G_A = self.opt.lambda_scale_G_A * self.calc_scale_loss(self.real_A,self.fake_B)
    self.loss_scale_G_B = self.opt.lambda_scale_G_B * self.calc_scale_loss(self.real_B, self.fake_A)
    # GAN loss D_A(G_A(A))
    # NOTE(review): self.idt_A is only assigned in the lambda_idt > 0
    # branch above; with lambda_identity == 0 the next line would raise
    # AttributeError — confirm lambda_identity is always positive here.
    self.loss_G_A = ( self.criterionGAN(self.netD_A(self.fake_B), True) )* self.opt.lambda_G_A + self.criterionGAN(self.netD_A(self.idt_A), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss: disabled (forward() aliases rec_A to real_A)
    self.loss_cycle_A = 0
    # Backward cycle loss
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
    # combined loss
    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B + self.loss_L0_reg + self.loss_scale_G_A + self.loss_scale_G_B
    self.loss_G.backward()
def rgb2gray(self, rgb):
    """Convert an NCHW RGB tensor to a single-channel luma tensor
    using the BT.601 weights (0.2989, 0.5870, 0.1140)."""
    red = rgb[:, 0:1, :, :]
    green = rgb[:, 1:2, :, :]
    blue = rgb[:, 2:3, :, :]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue
def gray2rgb(self, gray):
    """Expand a single-channel NCHW tensor to three channels.

    Each output channel is the input scaled by a BT.601 luma weight
    (so gray2rgb(rgb2gray(x)) does NOT reconstruct x). Inputs with
    more than one channel are returned unchanged.
    """
    batch, c, w, h = gray.shape
    if (c == 1):
        r, g, b = 0.2989 * gray, 0.5870 * gray, 0.1140 * gray
        # Bug fix: allocate on the input's device instead of the former
        # hard-coded .cuda(), which crashed on CPU-only runs.
        rgb = torch.zeros((batch, 3, w, h), device=gray.device)
        rgb[:, 0:1, :, :] = r
        rgb[:, 1:2, :, :] = g
        rgb[:, 2:3, :, :] = b
        return rgb
    else:
        return gray
def print_layer_grad(self, initial_print):
    """Print the L1 norm of every Conv2d/Linear weight gradient in netG_A,
    prefixed with `initial_print` (debugging aid)."""
    tracked_types = (torch.nn.Conv2d, torch.nn.Linear)
    layers = [m for m in self.netG_A.modules() if isinstance(m, tracked_types)]
    grad_norms = [float(torch.norm(layer._parameters['weight'].grad, 1)) for layer in layers]
    print(initial_print, grad_norms)
def optimize_parameters(self):
    """One full optimisation step: update generators first, then the
    discriminators, then (optionally) run a separate contextual-loss
    generator update."""
    # forward
    self.forward()
    # G_A and G_B: freeze discriminators while updating generators
    self.set_requires_grad([self.netD_A, self.netD_B], False)
    self.optimizer_G.zero_grad()
    self.backward_G()
    self.optimizer_G.step()
    # D_A and D_B
    self.set_requires_grad([self.netD_A, self.netD_B], True)
    self.optimizer_D.zero_grad()
    self.backward_D_A()
    self.backward_D_B()
    self.optimizer_D.step()
    self.loss_contextual = 0
    if self.opt.contextual_loss:
        # free cached GPU memory before the VGG19 pass, then recompute
        # fake_B with the freshly updated G_A for the contextual update
        torch.cuda.empty_cache()
        self.fake_B = self.netG_A(self.real_A)
        self.optimizer_G_contextual.zero_grad()
        self.backward_G_contextual_loss()
        self.optimizer_G_contextual.step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment