CycleGAN Implementation in PyTorch
"""This file implement a CycleGAN in pytorch, | |
the Structure of the Model follows the One described | |
in David Foster's 'Generative Deep Learning' Ch.5""" | |
import glob
import os

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import wandb

class DownSampleBlock(nn.Module):
    def __init__(self, num_channel_in: int = 3, num_channel_out: int = 32,
                 kernel_size: int = 4, stride: int = 2,
                 padding: int = 1, dilation: int = 1, useTanh: bool = False):
        """Custom layer group for a downsampling block.
        Note that with some parameter settings the output size equals the input size."""
        super(DownSampleBlock, self).__init__()
        self.useTanh = useTanh
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=num_channel_in, out_channels=num_channel_out,
                      kernel_size=(kernel_size, kernel_size), stride=(stride, stride),
                      padding=(padding, padding), dilation=(dilation, dilation)))
        self.normalActivation = nn.Sequential(nn.InstanceNorm2d(num_features=num_channel_out),
                                              nn.LeakyReLU())
        self.tanhActivation = nn.Tanh()

    def forward(self, inputTensor):
        """If self.useTanh is True, a single Tanh is used instead of InstanceNorm2d + LeakyReLU."""
        output = self.main(inputTensor)
        if self.useTanh:
            output = self.tanhActivation(output)
        else:
            output = self.normalActivation(output)
        return output

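# Illustrative sanity check (not part of the original gist): Conv2d output size is
# out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, so the
# defaults (kernel_size=4, stride=2, padding=1, dilation=1) halve the spatial size.
def _check_downsample_shape():
    block = DownSampleBlock(num_channel_in=3, num_channel_out=32)
    x = torch.randn(1, 3, 128, 128)
    assert block(x).shape == (1, 32, 64, 64)
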
class UpSampleBlock(nn.Module):
    def __init__(self, input_size: int = 16, num_feature_in: int = 256,
                 num_feature_out: int = 128, kernel_size: int = 4,
                 stride: int = 1, padding: int = 3, dilation: int = 2,
                 useTanh: bool = False):
        """Upsampling block that uses UpsamplingBilinear2d followed by Conv2d
        (the latter keeps the input and output sizes the same).
        Though perhaps unnecessary, this group also supports Tanh as the activation."""
        super(UpSampleBlock, self).__init__()
        self.useTanh = useTanh
        self.main = nn.Sequential(
            nn.UpsamplingBilinear2d(size=(input_size, input_size)),
            nn.Conv2d(in_channels=num_feature_in, out_channels=num_feature_out,
                      kernel_size=(kernel_size, kernel_size), stride=(stride, stride),
                      padding=(padding, padding), dilation=(dilation, dilation)))
        self.normalActivation = nn.Sequential(nn.InstanceNorm2d(num_features=num_feature_out),
                                              nn.LeakyReLU())
        self.tanhActivation = nn.Tanh()

    def forward(self, inputTensor):
        output = self.main(inputTensor)
        if self.useTanh:
            output = self.tanhActivation(output)
        else:
            output = self.normalActivation(output)
        return output

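# Illustrative sanity check (not part of the original gist): with the defaults
# (kernel_size=4, stride=1, padding=3, dilation=2) the Conv2d keeps its input size
# (in + 2*3 - 2*(4 - 1) - 1 + 1 = in), so the block's output size is set entirely
# by the UpsamplingBilinear2d target.
def _check_upsample_shape():
    block = UpSampleBlock(input_size=16, num_feature_in=256, num_feature_out=128)
    x = torch.randn(1, 256, 8, 8)
    assert block(x).shape == (1, 128, 16, 16)
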
class ResidualBlock(nn.Module):
    def __init__(self, channels: int = 3):
        """Block for the residual body; uses an add operation to form the skip connection."""
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3),
                      stride=(1, 1), padding=(1, 1), dilation=(1, 1)),
            nn.InstanceNorm2d(num_features=channels),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3),
                      stride=(1, 1), padding=(1, 1), dilation=(1, 1)),
            nn.InstanceNorm2d(num_features=channels))

    def forward(self, inputTensor):
        output = self.main(inputTensor)
        return torch.add(output, inputTensor)

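# Illustrative sanity check (not part of the original gist): both convolutions use
# kernel_size=3, stride=1, padding=1, so shapes are preserved and the element-wise
# skip-connection add is valid.
def _check_residual_shape():
    block = ResidualBlock(channels=128)
    x = torch.randn(1, 128, 32, 32)
    assert block(x).shape == x.shape
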
class UpSampleConvT(UpSampleBlock):
    def __init__(self, num_feature_in: int = 256,
                 num_feature_out: int = 128, kernel_size: int = 4,
                 stride: int = 1, padding: int = 3, dilation: int = 2,
                 useTanh: bool = False):
        """Upsampling group that uses ConvTranspose2d instead of the
        UpsamplingBilinear2d + Conv2d combination."""
        super(UpSampleConvT, self).__init__(num_feature_in=num_feature_in, num_feature_out=num_feature_out,
                                            kernel_size=kernel_size, stride=stride, padding=padding,
                                            dilation=dilation, useTanh=useTanh)
        # Only override the original main block; the activations are inherited.
        self.main = nn.Sequential(nn.ConvTranspose2d(in_channels=num_feature_in, out_channels=num_feature_out,
                                                     kernel_size=(kernel_size, kernel_size),
                                                     stride=(stride, stride), padding=(padding, padding),
                                                     dilation=dilation))

class UnetGen(nn.Module):
    def __init__(self):
        """To keep a Conv2d layer's output shape equal to its input shape, I use
        stride=(1, 1), padding=(3, 3), dilation=(2, 2) as parameters;
        stride=(2, 2), padding=(1, 1), dilation=(1, 1) (the DownSampleBlock defaults)
        halves the output instead.
        Note this differs from the original U-Net implementation."""
        super(UnetGen, self).__init__()
        self.downSampleBlock1 = DownSampleBlock(3, 32)
        self.downSampleBlock2 = DownSampleBlock(32, 64)
        self.downSampleBlock3 = DownSampleBlock(64, 128)
        self.downSampleBlock4 = DownSampleBlock(128, 256)
        self.upSampleBlock1 = UpSampleBlock(16, 256, 128)
        self.upSampleBlock2 = UpSampleBlock(32, 128 * 2, 64)
        self.upSampleBlock3 = UpSampleBlock(64, 64 * 2, 32)
        self.upSampleBlock4 = UpSampleBlock(128, 32 * 2, 3)

    def forward(self, inputTensor):
        downSampled1 = self.downSampleBlock1(inputTensor)
        downSampled2 = self.downSampleBlock2(downSampled1)
        downSampled3 = self.downSampleBlock3(downSampled2)
        downSampled4 = self.downSampleBlock4(downSampled3)
        upSampled1 = self.upSampleBlock1(downSampled4)
        upSampled2 = self.upSampleBlock2(torch.cat((upSampled1, downSampled3), dim=1))
        upSampled3 = self.upSampleBlock3(torch.cat((upSampled2, downSampled2), dim=1))
        upSampled4 = self.upSampleBlock4(torch.cat((upSampled3, downSampled1), dim=1))
        return upSampled4

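# Illustrative sanity check (not part of the original gist): the encoder halves the
# spatial size four times (128 -> 64 -> 32 -> 16 -> 8), the decoder's hard-coded
# UpsamplingBilinear2d targets (16, 32, 64, 128) mirror it, and each torch.cat
# doubles the channels fed to the next up block. The hard-coded sizes mean this
# generator assumes 128x128 inputs.
def _check_unet_shapes():
    gen = UnetGen()
    x = torch.randn(1, 3, 128, 128)
    assert gen(x).shape == (1, 3, 128, 128)
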
class ResNetGen(nn.Module):
    def __init__(self):
        super(ResNetGen, self).__init__()
        self.down = nn.Sequential(DownSampleBlock(num_channel_in=3, num_channel_out=32, kernel_size=7,
                                                  stride=1, padding=3, dilation=1),
                                  DownSampleBlock(num_channel_in=32, num_channel_out=64, kernel_size=3,
                                                  stride=2, padding=1, dilation=1),
                                  DownSampleBlock(num_channel_in=64, num_channel_out=128, kernel_size=3,
                                                  stride=2, padding=1, dilation=1))
        self.residuals = nn.Sequential(*[ResidualBlock(128) for _ in range(9)])
        self.up = nn.Sequential(UpSampleConvT(num_feature_in=128, num_feature_out=64, kernel_size=3, stride=2,
                                              padding=1, dilation=1),
                                UpSampleConvT(num_feature_in=64, num_feature_out=32, kernel_size=3, stride=2,
                                              padding=1, dilation=1),
                                UpSampleConvT(num_feature_in=32, num_feature_out=16, kernel_size=3, stride=2,
                                              padding=1, dilation=1),
                                # This DownSampleBlock configuration keeps the input size.
                                DownSampleBlock(num_channel_in=16, num_channel_out=3, kernel_size=7,
                                                stride=1, padding=3, dilation=1, useTanh=True))

    def forward(self, inputTensor):
        output = self.down(inputTensor)
        output = self.residuals(output)
        output = self.up(output)
        return output

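# Illustrative note (not part of the original gist): the CycleGAN class below
# instantiates UnetGen, so ResNetGen is an unused alternative generator here. As
# written, its decoder has three stride-2 transposed convolutions against only two
# stride-2 downsampling stages, and with output_padding=0 each "doubling" yields
# 2*in - 1, so a 128x128 input comes out as 249x249 rather than 128x128.
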
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.mainBlock = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(4, 4), stride=(2, 2),
                      padding=(1, 1), dilation=(1, 1)),
            nn.LeakyReLU(),
            DownSampleBlock(num_channel_in=32, num_channel_out=64, kernel_size=4, stride=2, padding=1, dilation=1),
            DownSampleBlock(num_channel_in=64, num_channel_out=128, kernel_size=4, stride=2, padding=1, dilation=1),
            DownSampleBlock(num_channel_in=128, num_channel_out=256, kernel_size=4, stride=2, padding=1, dilation=1),
            nn.Conv2d(in_channels=256, out_channels=1, kernel_size=(4, 4), stride=(1, 1),
                      padding=(3, 3), dilation=(2, 2)))

    def forward(self, inputTensor):
        outputTensor = self.mainBlock(inputTensor)
        return outputTensor

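# Illustrative sanity check (not part of the original gist): this is a PatchGAN-style
# discriminator; four stride-2 stages take 128 -> 64 -> 32 -> 16 -> 8 and the final
# size-preserving Conv2d emits a 1-channel 8x8 grid of per-patch real/fake scores.
def _check_discriminator_shape():
    disc = Discriminator()
    x = torch.randn(1, 3, 128, 128)
    assert disc(x).shape == (1, 1, 8, 8)
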
class CycleGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.GenA = UnetGen()  # maps style B -> style A
        self.GenB = UnetGen()  # maps style A -> style B
        self.DiscA = Discriminator()
        self.DiscB = Discriminator()

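# Illustrative sketch (not part of the original gist): the four sub-networks form the
# two translation cycles trained below; GenB maps style A -> B, GenA maps B -> A, and
# each discriminator judges one style.
def _cycle_sketch(model, real_A):
    fake_B = model.GenB(real_A)  # translate A -> B
    rec_A = model.GenA(fake_B)   # round trip for the cycle-consistency (reconstruction) loss
    return fake_B, rec_A
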
class CustomImageSet(Dataset):
    def __init__(self, root="data/Accountant", transform=None):
        self.inputDirectory = root
        self.inputList = glob.glob(os.path.join(self.inputDirectory, "*.jpg"), recursive=True)
        self.transform = transform

    def __len__(self):
        return len(self.inputList)

    def __getitem__(self, idx):
        """Offline version."""
        input_data = Image.open(self.inputList[idx]).convert('RGB')
        if self.transform:
            input_data = self.transform(input_data)
        return input_data, torch.tensor(1, dtype=torch.float32)

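# Illustrative usage (not part of the original gist; "data/A" is a hypothetical path):
# the dataset yields (image, 1.0) pairs, the constant 1.0 serving as the "real" label
# that the training loops later expand to the discriminator's patch-grid shape, e.g.
#   loader = DataLoader(CustomImageSet(root="data/A", transform=T.ToTensor()),
#                       batch_size=1, shuffle=True)
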
def FreezeParameter(model, freeze=True):
    """Freeze part of the model while training the rest (pass freeze=False to unfreeze)."""
    for p in model.parameters():
        p.requires_grad = not freeze

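# Illustrative usage (not part of the original gist): the training functions below
# freeze one half of the model while stepping the other, e.g.
#   FreezeParameter(model.DiscA)         # requires_grad = False for all DiscA weights
#   ...train the generators...
#   FreezeParameter(model.DiscA, False)  # thaw DiscA again
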
def train_generator(loader_A, loader_B, model, lossFnMSE, lossFNMAE, optimizer, default_device):
    """This function should be refactored: some kind of dict is needed to match each
    dataset to its corresponding model group."""
    FreezeParameter(model.DiscA)
    FreezeParameter(model.DiscB)
    for step, (image, label) in enumerate(loader_A):
        # Train Generator B
        image, label = image.to(device=default_device), label.to(device=default_device, dtype=torch.float32)
        stylized_image = model.GenB(image)
        patch_prediction = model.DiscB(stylized_image)
        # expand label (originally a single scalar) to match the size of the prediction
        label = label.expand(patch_prediction.shape)
        loss_traditional = lossFnMSE(patch_prediction, label)
        wandb.log({"gen_loss_traditional_A": loss_traditional.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(stylized_image, caption=f"Image StyleB from A set by GenB on step {step}")
            wandb.log({"GenB_stylized": wandb_log_image})
        reConstructedImage = model.GenA(stylized_image)
        loss_reconstruction = lossFNMAE(reConstructedImage, image)
        wandb.log({"gen_loss_reconstruction_A": loss_reconstruction.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(reConstructedImage, caption=f"Image StyleA reconstructed by GenA on step {step}")
            wandb.log({"GenB_recon_by_GenA": wandb_log_image})
        identicalImage = model.GenA(image)
        loss_identical = lossFNMAE(identicalImage, image)
        wandb.log({"gen_loss_identical_A": loss_identical.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(identicalImage, caption=f"Image StyleA made by GenA from setA on step {step}")
            wandb.log({"GenA_iden_by_GenA": wandb_log_image})
        loss_sum = loss_traditional + loss_reconstruction + loss_identical
        wandb.log({"gen_loss_sum_A": loss_sum.item()})
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()
    for step, (image, label) in enumerate(loader_B):
        # Train Generator A
        image, label = image.to(device=default_device), label.to(device=default_device, dtype=torch.float32)
        stylized_image = model.GenA(image)
        patch_prediction = model.DiscA(stylized_image)
        # expand label (originally a single scalar) to match the size of the prediction
        label = label.expand(patch_prediction.shape)
        loss_traditional = lossFnMSE(patch_prediction, label)
        wandb.log({"gen_loss_traditional_B": loss_traditional.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(stylized_image, caption=f"Image StyleA from B set by GenA on step {step}")
            wandb.log({"GenA_stylized": wandb_log_image})
        reConstructedImage = model.GenB(stylized_image)
        loss_reconstruction = lossFNMAE(reConstructedImage, image)
        wandb.log({"gen_loss_reconstruction_B": loss_reconstruction.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(reConstructedImage, caption=f"Image StyleB reconstructed by GenB on step {step}")
            wandb.log({"GenA_recon_by_GenB": wandb_log_image})
        identicalImage = model.GenB(image)
        loss_identical = lossFNMAE(identicalImage, image)
        wandb.log({"gen_loss_identical_B": loss_identical.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(identicalImage, caption=f"Image StyleB made by GenB from setB on step {step}")
            wandb.log({"GenB_iden_by_GenB": wandb_log_image})
        loss_sum = loss_traditional + loss_reconstruction + loss_identical
        wandb.log({"gen_loss_sum_B": loss_sum.item()})
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()
    FreezeParameter(model.DiscA, False)
    FreezeParameter(model.DiscB, False)

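# Illustrative note (not part of the original gist): each generator step above sums
# three terms in the style of the CycleGAN objective:
#   loss_sum = adversarial loss (MSE against all-ones patch labels)
#            + cycle-reconstruction loss (L1 between input and round-trip image)
#            + identity loss (L1 between input and its same-domain mapping).
# All three are weighted equally here; the original CycleGAN paper weights the cycle
# term more heavily.
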
def train_discriminator(loader_A, loader_B, model, lossFnMSE, optimizer, default_device):
    FreezeParameter(model.GenA)
    FreezeParameter(model.GenB)
    for image, label in loader_A:
        image, label = image.to(device=default_device), label.to(device=default_device, dtype=torch.float32)
        # Train DiscA on positive samples, expanding labels to match the patch prediction as usual
        patch_prediction = model.DiscA(image)
        label = label.expand(patch_prediction.shape)
        loss_positive_A = lossFnMSE(patch_prediction, label)
        wandb.log({"disc_loss_positive_A": loss_positive_A.item()})
        optimizer.zero_grad()
        loss_positive_A.backward()
        optimizer.step()
        # Train DiscB on negative samples: let GenB turn images from set A into
        # B-stylized images, and push DiscB's predictions on them towards 0.
        # detach() keeps the frozen generator out of the discriminator's graph.
        stylized_image = model.GenB(image).detach()
        patch_prediction = model.DiscB(stylized_image)
        # zeros_like puts the fake labels on the same device as the prediction
        label = torch.zeros_like(patch_prediction)
        loss_negative_B = lossFnMSE(patch_prediction, label)
        wandb.log({"disc_loss_negative_B": loss_negative_B.item()})
        optimizer.zero_grad()
        loss_negative_B.backward()
        optimizer.step()
    for image, label in loader_B:
        image, label = image.to(device=default_device), label.to(device=default_device, dtype=torch.float32)
        # Just like the loop above, but with A and B reversed
        patch_prediction = model.DiscB(image)
        label = label.expand(patch_prediction.shape)
        loss_positive_B = lossFnMSE(patch_prediction, label)
        wandb.log({"disc_loss_positive_B": loss_positive_B.item()})
        optimizer.zero_grad()
        loss_positive_B.backward()
        optimizer.step()
        stylized_image = model.GenA(image).detach()
        patch_prediction = model.DiscA(stylized_image)
        label = torch.zeros_like(patch_prediction)
        loss_negative_A = lossFnMSE(patch_prediction, label)
        wandb.log({"disc_loss_negative_A": loss_negative_A.item()})
        optimizer.zero_grad()
        loss_negative_A.backward()
        optimizer.step()
    FreezeParameter(model.GenA, False)
    FreezeParameter(model.GenB, False)

def train(loader_A, loader_B, model, lossFnMSE, lossFNMAE, optimizer, ratio, default_device):
    """Note: the batch size should be 1."""
    train_generator(loader_A, loader_B, model, lossFnMSE, lossFNMAE, optimizer, default_device)
    for _ in range(ratio):
        train_discriminator(loader_A, loader_B, model, lossFnMSE, optimizer, default_device)

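# Illustrative note (not part of the original gist): "ratio" sets how many
# discriminator passes follow each generator pass; the main block below uses
# ratio=5, i.e. a 1:5 generator-to-discriminator update schedule.
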
if __name__ == '__main__':
    offline_wandb = True
    if offline_wandb:
        os.environ["WANDB_API_KEY"] = ""
        os.environ["WANDB_MODE"] = "offline"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform_list = T.Compose([
        T.ToTensor(),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch_size = 1
    epochs = 10
    learning_rate = 1e-3
    # pass the hyperparameters to wandb.init so they are recorded with the run
    wandb.init(project="CycleGan", config={
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
    })
    main_model = CycleGAN()
    main_model.to(device=device)
    loss_mse = nn.MSELoss()
    loss_mae = nn.L1Loss()
    Optimizer = torch.optim.Adam(main_model.parameters(), lr=learning_rate)
    datasetA = CustomImageSet(root="run/86", transform=transform_list)
    dataloaderA = DataLoader(datasetA, batch_size=batch_size, shuffle=True)
    datasetB = CustomImageSet(root="run/st", transform=transform_list)
    dataloaderB = DataLoader(datasetB, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        train(dataloaderA, dataloaderB, main_model, loss_mse, loss_mae, Optimizer, 5, device)
    print("Done!")
    PATH = 'CycleGan.pth'
    torch.save(main_model.state_dict(), PATH)