"""Cycle GAN model."""
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from PIL import Image
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.vgg19 import VGG19
from models.contextual_loss import contextual_loss as CX
import os
vgg19_path = os.path.join('.', 'vgg', 'imagenet-vgg-verydeep-19.mat')


class CycleGANModel(BaseModel):

    def name(self):
        return 'CycleGANModel'
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # The default CycleGAN does not use dropout.
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5,
                                help='use identity mapping. Setting lambda_identity other than 0 scales the weight of '
                                     'the identity mapping loss. For example, if the identity loss should be 10 times '
                                     'smaller than the reconstruction loss, set lambda_identity = 0.1')
            parser.add_argument('--lambda_G_A', type=float, default=1.0, help='weight for the G_A GAN loss')
            parser.add_argument('--lambda_D_A', type=float, default=1.0, help='weight for the D_A loss')
            parser.add_argument('--lambda_scale_G_A', type=float, default=10.0,
                                help='weight for the multi-scale gradient loss on G_A')
            parser.add_argument('--lambda_scale_G_B', type=float, default=10.0,
                                help='weight for the multi-scale gradient loss on G_B')
            parser.add_argument('--no_identity_b', action='store_true',
                                help='if set, add the identity_B term to the loss (despite the flag name); '
                                     'otherwise it is zero')
            parser.add_argument('--l0_reg', action='store_true',
                                help='if set, add the L0 regularization term to the loss; otherwise it is zero')
            parser.add_argument('--try_a', action='store_true',
                                help='experimental flag (see the commented-out try_a branches below)')
            parser.add_argument('--contextual_loss', action='store_true',
                                help='enable the contextual loss')
        return parser
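
    # A hypothetical invocation (assuming the standard train.py driver from the
    # pytorch-CycleGAN-and-pix2pix codebase this gist builds on; the script name
    # and --dataroot path are illustrative, not confirmed by this gist):
    #
    #   python train.py --dataroot ./datasets/my_data --model cycle_gan \
    #       --lambda_scale_G_A 10.0 --lambda_scale_G_B 10.0 --l0_reg --contextual_loss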
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # Specify the training losses you want to print out. The program will call base_model.get_current_losses.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'L0_reg', 'scale_G_A', 'scale_G_B', 'contextual']
        # Specify the images you want to save/display. The program will call base_model.get_current_visuals.
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # Specify the models you want to save to disk. The program will call base_model.save_networks and base_model.load_networks.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load the generators
            self.model_names = ['G_A', 'G_B']
        # Load/define networks.
        # The naming convention is different from the one used in the paper:
        # code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=True)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=False)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                            self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # Define loss functions.
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionScale = torch.nn.L1Loss()
            # self.l0norm = torch.sum()  # torch.nn.L1Loss()
            # Initialize optimizers.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            if self.opt.contextual_loss:
                # The contextual pass uses its own optimizer with a 25x larger learning rate.
                self.optimizer_G_contextual = torch.optim.Adam(
                    itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                    lr=opt.lr * 25, betas=(opt.beta1, opt.beta2))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            if self.opt.contextual_loss:
                self.optimizers.append(self.optimizer_G_contextual)
        # 3x3 discrete Laplacian kernel, shape (1, 1, 3, 3), used by calc_scale_loss
        # to extract gradients from single-channel (grayscale) images.
        if len(opt.gpu_ids) > 0:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).cuda().unsqueeze(0)
        else:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).unsqueeze(0)
        if self.opt.contextual_loss and len(opt.gpu_ids) > 0:
            self.vgg19 = VGG19(vgg19_path).cuda()
            self.vgg19.eval()
        elif self.opt.contextual_loss:
            self.vgg19 = VGG19(vgg19_path)
            self.vgg19.eval()
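
        # A quick sanity check of the kernel (illustrative, not part of training):
        # on a constant image the discrete Laplacian [[0,1,0],[1,-4,1],[0,1,0]]
        # responds with zeros, so only intensity edges contribute to the scale loss:
        #
        #   flat = torch.ones(1, 1, 5, 5)
        #   F.conv2d(flat, self.grad_conv_filter, padding=1)  # interior entries are 0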
    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        # The forward cycle is disabled: rec_A is a pass-through of real_A
        # (the corresponding cycle loss is zeroed out in backward_G).
        # self.rec_A = self.netG_B(self.fake_B)
        self.rec_A = self.real_A
        # if self.opt.try_a:
        #     self.fake_A = self.real_B
        #     self.rec_B = self.real_B
        # else:
        #     self.fake_A = self.netG_B(self.real_B)
        #     self.rec_B = self.netG_A(self.fake_A)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)
    def calc_scale_loss(self, real, fake):
        # Multi-scale gradient matching: compare the Laplacian responses of the
        # real and fake images over an image pyramid, weighting coarser scales more.
        list_of_scale = [1, 2, 4, 8, 16]
        scale_factor = [0.0001, 0.001, 0.01, 0.1, 1]
        # list_of_scale = [1]
        # scale_factor = [1]
        _, __, orig_w, orig_h = real.shape
        loss_scale = 0
        for index, scale in enumerate(list_of_scale):
            scaled_w = int(orig_w / scale)
            scaled_h = int(orig_h / scale)
            # Average-pool the grayscale images down to (scaled_w, scaled_h).
            scaled_real = F.adaptive_avg_pool3d(self.rgb2gray(real), (1, scaled_w, scaled_h))
            scaled_fake = F.adaptive_avg_pool3d(self.rgb2gray(fake), (1, scaled_w, scaled_h))
            grad_scaled_real = F.conv2d(scaled_real, self.grad_conv_filter, padding=1)  # TODO padding
            grad_scaled_fake = F.conv2d(scaled_fake, self.grad_conv_filter, padding=1)  # TODO padding
            # An earlier experiment masked out bright regions before comparing gradients:
            # my_filter = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]).cuda().unsqueeze(0)
            # image_filter = F.conv2d(scaled_real, my_filter, padding=1)
            # scaled = image_filter / 9
            # use_filter = (scaled < 0.3).type(torch.cuda.FloatTensor)
            # white = scaled * use_filter + 1 * (1 - use_filter)
            # white = 1 - white
            # curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake * white, grad_scaled_real * white)
            curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake, grad_scaled_real)
            loss_scale += curr_loss  # TODO factor (best for now is 10)
            # self.save_image2(grad_scaled_fake)
        return loss_scale
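
    # For a 256x256 input, the loop above compares gradients at resolutions
    # 256, 128, 64, 32 and 16 with weights 1e-4, 1e-3, 1e-2, 1e-1 and 1, so the
    # coarsest scale dominates the loss.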
    def calc_contextual_loss(self, generated_image, target_image):
        generated_image_vgg19_layers = self.vgg19(generated_image)
        with torch.no_grad():
            target_image_vgg19_layers = self.vgg19(target_image)
        loss = 0
        lambdas = []        # scaling parameters for the VGG layers
        num_elements = []   # number of elements in each VGG layer
        for img_layer in generated_image_vgg19_layers:
            num_elem = np.prod(img_layer.size()[1:])
            num_elements.append(num_elem)
            lambdas.append(1.0 / num_elem)
        # Normalize the per-layer weights so they sum to one.
        lambdas = [lamda / sum(lambdas) for lamda in lambdas]
        for i in range(len(generated_image_vgg19_layers)):
            loss += float(lambdas[i]) * CX(generated_image_vgg19_layers[i], target_image_vgg19_layers[i])
        return loss
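
    # The per-layer weights follow a 1/num_elements scheme normalized to sum to
    # one. For example (illustrative sizes), two layers with 64*128*128 and
    # 512*32*32 elements get weights proportional to 1/1048576 and 1/524288,
    # i.e. 1/3 and 2/3 after normalization, so smaller feature maps count more.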
    def save_image2(self, output, file_name='yoav.png'):
        # Save the first channel of a (B, C, H, W) tensor in [-1, 1] as a grayscale PNG.
        b, c, w, h = output.shape
        output = torch.clamp((output + 1) / 2, 0, 1)
        # output = self.gray2rgb(output)
        output = output.permute(0, 2, 3, 1)[0, :, :, :]
        I = output.data[:, :, 0]
        I = I.cpu().numpy()
        I8 = (((I - I.min()) / (I.max() - I.min())) * 255.9).astype(np.uint8)
        img = Image.fromarray(I8)
        img.save(file_name)
    def save_image(self, output, file_name='yoav.png'):
        # Usage: self.save_image(self.real_B[1, :, :, :].squeeze(0), 'real.png')
        __, w, h = output.shape
        output = torch.clamp(output + 0.5, 0, 1)
        picture = np.zeros((w, h, 3))
        picture[:, :, 0] = output.data[0, :, :].cpu().numpy()
        picture[:, :, 1] = output.data[1, :, :].cpu().numpy()
        picture[:, :, 2] = output.data[2, :, :].cpu().numpy()
        plt.imshow(picture)
        plt.savefig(file_name)

    # def backward_D_new(self, netD, fake):
    #     pred_fake = netD(fake)
    #     loss_D_fake = self.criterionGAN(pred_fake, False)
    #     # Combined loss
    #     loss_D = loss_D_fake * 0.5
    #     # backward
    #     loss_D.backward()
    #     return loss_D
    def backward_D_basic(self, netD, real, fake, g_real=None, do_another_value=False):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        if do_another_value:
            # Also treat the identity output as "real".  TODO: make sure this helps.
            loss_D_real2 = self.criterionGAN(netD(g_real.detach()), True)
        else:
            loss_D_real2 = 0
        # Combined loss
        loss_D = (loss_D_real + loss_D_real2) * 0.5 + loss_D_fake
        # Backward
        loss_D.backward()
        return loss_D
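
    # With the default LSGAN criterion, and do_another_value set, this is roughly
    #   loss_D = 0.5 * [(D(real) - 1)^2 + (D(g_real) - 1)^2] + D(fake)^2,
    # so unlike vanilla CycleGAN the fake term is not halved here.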
    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        # Note: backward() already ran inside backward_D_basic, so the scaling
        # below affects only the logged loss value, not the gradients.
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, self.idt_A,
                                              do_another_value=True) * self.opt.lambda_D_A * 0.5

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        # if self.opt.try_a:
        #     self.loss_D_B = 0
        # else:
        #     fake_A = self.fake_A_pool.query(self.fake_A)
        #     self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
    def backward_G_contextual_loss(self):
        self.loss_contextual = self.calc_contextual_loss(self.fake_B, self.real_B)  # TODO: consider adding for the input image
        # Combined loss (only the contextual term is optimized in this pass).
        self.loss_G = self.loss_contextual
        self.loss_G.backward()
    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should act as the identity when real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should act as the identity when real_A is fed.
            if not self.opt.try_a and self.opt.no_identity_b:
                self.idt_B = self.netG_B(self.real_A)
                self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
            else:
                # identity_B is removed: the noise generator is not expected to
                # reproduce the same noisy input.
                self.idt_B = self.real_A
                self.loss_idt_B = 0
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        if self.opt.l0_reg:
            # Soft L0 regularization: remap grayscale fake_B from [-1, 1] to [0, 1]
            # (inverted) and push every value toward 0 or 1.
            # self.loss_L0_reg = 0.000003 * torch.sum(-1 * torch.clamp(self.fake_B, -0.5, 0.5) + 0.5)
            image = -1 * torch.clamp(self.rgb2gray(self.fake_B), -0.5, 0.5) + 0.5
            mask_toward_zero = image.clone()
            mask_toward_one = image.clone()
            mask_toward_zero[mask_toward_zero > 0.5] = 0
            mask_toward_one[mask_toward_one < 0.5] = 1
            self.loss_L0_reg = 0.0001 * (torch.sum(mask_toward_zero) + torch.sum(1 - mask_toward_one))  # TODO expose as a parameter
        else:
            self.loss_L0_reg = 0
        self.loss_scale_G_A = self.opt.lambda_scale_G_A * self.calc_scale_loss(self.real_A, self.fake_B)
        self.loss_scale_G_B = self.opt.lambda_scale_G_B * self.calc_scale_loss(self.real_B, self.fake_A)
        # GAN loss D_A(G_A(A)); the identity output idt_A is also pushed to look real
        # (note this assumes lambda_identity > 0, so that idt_A exists).
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) * self.opt.lambda_G_A \
                        + self.criterionGAN(self.netD_A(self.idt_A), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # if self.opt.try_a:
        #     self.loss_G_B = 0
        # else:
        #     self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # The forward cycle loss is disabled (rec_A is a pass-through in forward()).
        # self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_A = 0
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # Combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B + self.loss_L0_reg
                       + self.loss_scale_G_A + self.loss_scale_G_B)
        self.loss_G.backward()
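
    # Toy example of the soft-L0 term (illustrative): after the [-1, 1] -> [0, 1]
    # remap, a mid-gray value of 0.5 contributes 0.5 to the toward-zero sum and
    # 0.5 to the toward-one sum, while values at exactly 0 or 1 contribute
    # nothing, so the regularizer rewards near-binary outputs.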
    def rgb2gray(self, rgb):
        # ITU-R BT.601 luma weights.
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray
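
    # Sanity check: a pure white pixel (r = g = b = 1) maps to
    # 0.2989 + 0.5870 + 0.1140 = 0.9999 ~ 1, as expected for BT.601 luma.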
    def gray2rgb(self, gray):
        batch, c, w, h = gray.shape
        if c == 1:
            # Spread the gray channel into R/G/B, scaled by the luma weights.
            r, g, b = 0.2989 * gray, 0.5870 * gray, 0.1140 * gray
            rgb = torch.zeros((batch, 3, w, h), device=gray.device)
            rgb[:, 0:1, :, :] = r
            rgb[:, 1:2, :, :] = g
            rgb[:, 2:3, :, :] = b
            return rgb
        else:
            return gray
    def print_layer_grad(self, initial_print):
        # Debug helper: print the L1 norm of the weight gradient of every
        # Conv2d/Linear layer in G_A.
        model = self.netG_A
        modules_list = list(model.modules())
        layer_list = [x for x in modules_list if isinstance(x, (torch.nn.Conv2d, torch.nn.Linear))]
        grad_list = []
        for layer in layer_list:
            grad_list.append(float(torch.norm(layer._parameters['weight'].grad, 1)))
        print(initial_print, grad_list)
    def optimize_parameters(self):
        # Forward pass
        self.forward()
        # Update G_A and G_B with the discriminators frozen.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        # self.print_layer_grad("regular ")
        self.optimizer_G.step()
        # Update D_A and D_B.
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        self.loss_contextual = 0
        if self.opt.contextual_loss:
            # Second generator pass, optimized with the contextual loss only.
            torch.cuda.empty_cache()
            self.fake_B = self.netG_A(self.real_A)
            self.optimizer_G_contextual.zero_grad()
            self.backward_G_contextual_loss()
            # self.print_layer_grad("con ")
            self.optimizer_G_contextual.step()
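

# A minimal sketch of how this model is typically driven, assuming the standard
# pytorch-CycleGAN-and-pix2pix training loop (the option and data helpers named
# below come from that codebase and are not defined in this gist):
#
#   opt = TrainOptions().parse()
#   dataset = CreateDataLoader(opt).load_data()
#   model = CycleGANModel()
#   model.initialize(opt)
#   for epoch in range(opt.niter + opt.niter_decay):
#       for data in dataset:
#           model.set_input(data)
#           model.optimize_parameters()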