Skip to content

Instantly share code, notes, and snippets.

@GrandArth
Last active June 9, 2022 00:35
Show Gist options
  • Save GrandArth/fcbf461064e447d60142c3fda77750d5 to your computer and use it in GitHub Desktop.
Save GrandArth/fcbf461064e447d60142c3fda77750d5 to your computer and use it in GitHub Desktop.
CycleGAN Implementation in Pytorch
"""This file implement a CycleGAN in pytorch,
the Structure of the Model follows the One described
in David Foster's 'Generative Deep Learning' Ch.5"""
import glob
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import wandb
class DownSampleBlock(nn.Module):
    def __init__(self, num_channel_in: int = 3, num_channel_out: int = 32,
                 kernel_size: int = 4, stride: int = 2,
                 padding: int = 1, dilation: int = 1, useTanh: bool = False):
        """Custom down-sampling group: one Conv2d followed by either
        InstanceNorm2d + LeakyReLU (default) or a single Tanh (useTanh=True).

        Note that with some hyper-parameter choices (e.g. kernel 7, stride 1,
        padding 3) the spatial output size equals the input size."""
        super(DownSampleBlock, self).__init__()
        self.useTanh = useTanh
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=num_channel_in, out_channels=num_channel_out,
                      kernel_size=(kernel_size, kernel_size),
                      stride=(stride, stride), padding=(padding, padding),
                      dilation=(dilation, dilation)))
        self.normalActivation = nn.Sequential(
            nn.InstanceNorm2d(num_features=num_channel_out),
            nn.LeakyReLU())
        self.tanhActivation = nn.Tanh()

    def forward(self, inputTensor):
        """Convolve, then activate with Tanh if self.useTanh else norm + LeakyReLU."""
        convolved = self.main(inputTensor)
        activate = self.tanhActivation if self.useTanh else self.normalActivation
        return activate(convolved)
class UpSampleBlock(nn.Module):
    def __init__(self, input_size: int = 16, num_feature_in: int = 256,
                 num_feature_out: int = 128, kernel_size: int = 4,
                 stride: int = 1, padding: int = 3, dilation: int = 2,
                 useTanh: bool = False):
        """Up-sampling group built from UpsamplingBilinear2d + Conv2d.

        The bilinear layer resizes to (input_size, input_size); the default
        conv hyper-parameters then keep that size unchanged.  Though perhaps
        not necessary, Tanh can be used as the activation instead of the
        default InstanceNorm2d + LeakyReLU."""
        super(UpSampleBlock, self).__init__()
        self.useTanh = useTanh
        self.main = nn.Sequential(
            nn.UpsamplingBilinear2d(size=(input_size, input_size)),
            nn.Conv2d(in_channels=num_feature_in, out_channels=num_feature_out,
                      kernel_size=(kernel_size, kernel_size),
                      stride=(stride, stride), padding=(padding, padding),
                      dilation=(dilation, dilation)))
        self.normalActivation = nn.Sequential(
            nn.InstanceNorm2d(num_features=num_feature_out),
            nn.LeakyReLU())
        self.tanhActivation = nn.Tanh()

    def forward(self, inputTensor):
        """Resize + convolve, then apply the activation selected at init."""
        resizedFeatures = self.main(inputTensor)
        if not self.useTanh:
            return self.normalActivation(resizedFeatures)
        return self.tanhActivation(resizedFeatures)
class ResidualBlock(nn.Module):
    def __init__(self, channels: int = 3):
        """Residual block (Conv-Norm-LeakyReLU-Conv-Norm) whose output is added
        element-wise to the input (skip connection); channel count and spatial
        size are both preserved."""
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                      dilation=(1, 1)),
            nn.InstanceNorm2d(num_features=channels),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                      dilation=(1, 1)),
            nn.InstanceNorm2d(num_features=channels))

    def forward(self, inputTensor):
        """Return main(inputTensor) + inputTensor."""
        residual = self.main(inputTensor)
        return residual + inputTensor
class UpSampleConvT(UpSampleBlock):
    def __init__(self, num_feature_in: int = 256,
                 num_feature_out: int = 128, kernel_size: int = 4,
                 stride: int = 1, padding: int = 3, dilation: int = 2,
                 useTanh: bool = False, output_padding: int = 0):
        """Up-sample group that uses ConvTranspose2d instead of the parent's
        UpsamplingBilinear2d + Conv2d combination.

        Only ``self.main`` is overridden; the activation sub-modules and
        ``forward`` are inherited from UpSampleBlock.

        ``output_padding`` is a new, backward-compatible parameter (default 0
        preserves the previous behavior).  It allows an exact spatial
        doubling: a transposed conv with kernel_size=3, stride=2, padding=1
        maps H to 2*H - 1 when output_padding=0 but to 2*H when
        output_padding=1 (see the ConvTranspose2d output-size formula).
        """
        super(UpSampleConvT, self).__init__(num_feature_in=num_feature_in,
                                            num_feature_out=num_feature_out,
                                            kernel_size=kernel_size,
                                            stride=stride, padding=padding,
                                            dilation=dilation, useTanh=useTanh)
        # Replace the parent's upsample+conv stack; everything else is reused.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels=num_feature_in,
                               out_channels=num_feature_out,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(stride, stride),
                               padding=(padding, padding),
                               output_padding=(output_padding, output_padding),
                               dilation=(dilation, dilation)))
class UnetGen(nn.Module):
    def __init__(self):
        """U-Net style generator sized for 128x128 RGB inputs (the up-sample
        target sizes 16/32/64/128 assume it).

        Four stride-2 down-sample blocks (3->32->64->128->256 channels) are
        mirrored by four up-sample blocks; each decoder stage after the first
        consumes its input concatenated with the matching encoder features,
        which is why its in-channel count is doubled.

        To keep the output and input shape of a Conv2d layer consistent,
        stride=(1, 1), padding=(3, 3), dilation=(2, 2) are used as
        parameters; (1, 3, 1) halves the output.  Note this differs from the
        original U-Net implementation."""
        super(UnetGen, self).__init__()
        self.downSampleBlock1 = DownSampleBlock(3, 32)
        self.downSampleBlock2 = DownSampleBlock(32, 64)
        self.downSampleBlock3 = DownSampleBlock(64, 128)
        self.downSampleBlock4 = DownSampleBlock(128, 256)
        self.upSampleBlock1 = UpSampleBlock(16, 256, 128)
        self.upSampleBlock2 = UpSampleBlock(32, 128 * 2, 64)
        self.upSampleBlock3 = UpSampleBlock(64, 64 * 2, 32)
        self.upSampleBlock4 = UpSampleBlock(128, 32 * 2, 3)

    def forward(self, inputTensor):
        """Encode through four halvings, decode with skip connections."""
        # Encoder path; each stage halves the spatial size.
        skip1 = self.downSampleBlock1(inputTensor)
        skip2 = self.downSampleBlock2(skip1)
        skip3 = self.downSampleBlock3(skip2)
        bottleneck = self.downSampleBlock4(skip3)
        # Decoder path; concatenate along the channel axis with encoder features.
        decoded = self.upSampleBlock1(bottleneck)
        decoded = self.upSampleBlock2(torch.cat((decoded, skip3), dim=1))
        decoded = self.upSampleBlock3(torch.cat((decoded, skip2), dim=1))
        decoded = self.upSampleBlock4(torch.cat((decoded, skip1), dim=1))
        return decoded
class ResNetGen(nn.Module):
    def __init__(self):
        """ResNet-style generator: size-preserving 7x7 stem plus two stride-2
        down-samples (3->32->64->128 channels), nine residual blocks at 128
        channels, then a mirrored up-sampling path back to a 3-channel Tanh
        output.

        BUGFIX: the original up path used three stride-2 transposed
        convolutions with no output_padding.  A ConvTranspose2d with kernel 3
        / stride 2 / padding 1 maps H to 2*H - 1, and there was one doubling
        more than the two halvings in the down path, so a 128x128 input came
        out as 249x249 — unusable for cycle-consistency losses.  The up path
        now mirrors the down path exactly: two transposed convolutions with
        output_padding=1 (exact doubling), then the size-preserving 7x7 Tanh
        block, so output size == input size."""
        super(ResNetGen, self).__init__()
        self.down = nn.Sequential(
            DownSampleBlock(num_channel_in=3, num_channel_out=32, kernel_size=7,
                            stride=1, padding=3, dilation=1),  # size-preserving stem
            DownSampleBlock(num_channel_in=32, num_channel_out=64, kernel_size=3,
                            stride=2, padding=1, dilation=1),  # H -> H/2
            DownSampleBlock(num_channel_in=64, num_channel_out=128, kernel_size=3,
                            stride=2, padding=1, dilation=1))  # H/2 -> H/4
        # Nine residual blocks operating at 128 channels.
        self.residuals = nn.Sequential(*(ResidualBlock(128) for _ in range(9)))
        self.up = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2),
                               padding=(1, 1), output_padding=(1, 1)),  # H/4 -> H/2
            nn.InstanceNorm2d(num_features=64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2),
                               padding=(1, 1), output_padding=(1, 1)),  # H/2 -> H
            nn.InstanceNorm2d(num_features=32),
            nn.LeakyReLU(),
            # Size-preserving 7x7 conv + Tanh maps features to RGB in [-1, 1].
            DownSampleBlock(num_channel_in=32, num_channel_out=3, kernel_size=7,
                            stride=1, padding=3, dilation=1, useTanh=True))

    def forward(self, inputTensor):
        """Encode, apply the residual body, and decode back to image space."""
        features = self.down(inputTensor)
        features = self.residuals(features)
        return self.up(features)
class Discriminator(nn.Module):
    def __init__(self):
        """PatchGAN-style discriminator: four stride-2 stages
        (3->32->64->128->256 channels, the first without normalization)
        followed by a 1-channel conv, so the output is a patch map of
        realness scores rather than a single scalar."""
        super(Discriminator, self).__init__()
        stages = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(4, 4),
                      stride=(2, 2), padding=(1, 1), dilation=(1, 1)),
            nn.LeakyReLU(),
            DownSampleBlock(num_channel_in=32, num_channel_out=64,
                            kernel_size=4, stride=2, padding=1, dilation=1),
            DownSampleBlock(num_channel_in=64, num_channel_out=128,
                            kernel_size=4, stride=2, padding=1, dilation=1),
            DownSampleBlock(num_channel_in=128, num_channel_out=256,
                            kernel_size=4, stride=2, padding=1, dilation=1),
            nn.Conv2d(in_channels=256, out_channels=1, kernel_size=(4, 4),
                      stride=(1, 1), padding=(3, 3), dilation=(2, 2)),
        ]
        self.mainBlock = nn.Sequential(*stages)

    def forward(self, inputTensor):
        """Return the per-patch prediction map for inputTensor."""
        return self.mainBlock(inputTensor)
class CycleGAN(nn.Module):
    """Container for the four CycleGAN sub-networks: two U-Net generators and
    one patch discriminator per image domain.  Per the training loops below,
    GenB maps set-A images to style B (judged by DiscB) and GenA maps set-B
    images to style A (judged by DiscA)."""
    def __init__(self):
        super().__init__()
        self.GenA = UnetGen()  # generator producing style-A images
        self.GenB = UnetGen()  # generator producing style-B images
        self.DiscA = Discriminator()  # judges real vs. generated style-A images
        self.DiscB = Discriminator()  # judges real vs. generated style-B images
class CostumeImageSet(Dataset):
    def __init__(self, root="data/Accountant", transform=None):
        """Dataset of the *.jpg files found under `root` at construction time.
        Each item pairs an image with the constant label 1.0, which the
        training loops use as the 'real sample' target."""
        self.inputDirectory = root
        self.inputList = glob.glob(os.path.join(self.inputDirectory, "*.jpg"),
                                   recursive=True)
        self.transform = transform

    def __len__(self):
        """Number of jpg files discovered at construction time."""
        return len(self.inputList)

    def __getitem__(self, idx):
        """Offline version: load image idx as RGB, apply the optional
        transform, and return it with the constant real-label 1.0."""
        image = Image.open(self.inputList[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(1, dtype=torch.float32)
def FreezeParameter(model, freeze=True):
    """Toggle gradient tracking for every parameter of `model`.

    freeze=True detaches the parameters from optimization (used to hold one
    half of the GAN fixed while training the other); freeze=False re-enables
    training."""
    requires = not freeze
    for parameter in model.parameters():
        parameter.requires_grad = requires
def _train_generator_direction(loader, gen_forward, gen_backward, disc,
                               src, dst, lossFnMSE, lossFNMAE, optimizer,
                               default_device):
    """Train one generator direction for a full pass over `loader`.

    gen_forward restyles src-domain images into the dst domain, the (frozen)
    disc judges them, and gen_backward maps them back for the
    cycle-consistency and identity losses.  `src`/`dst` are the single-letter
    domain names ("A"/"B") used to build the wandb keys and captions, which
    match the previous hand-written loops exactly."""
    for step, (image, label) in enumerate(loader):
        image = image.to(device=default_device)
        label = label.to(device=default_device, dtype=torch.float32)
        # Adversarial loss: the frozen discriminator should rate the
        # stylized image towards 1 on every patch.
        stylized_image = gen_forward(image)
        patch_prediction = disc(stylized_image)
        # expand label (originally a single int) to match size of prediction
        label = label.expand(patch_prediction.shape)
        loss_traditional = lossFnMSE(label, patch_prediction)
        wandb.log({f"gen_loss_traditional_{src}": loss_traditional.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(
                stylized_image,
                caption=f"Image Style{dst} from {src} set by Gen{dst} on step {step}")
            wandb.log({f"Gen{dst}_stylized": wandb_log_image})
        # Cycle-consistency loss: mapping back should reproduce the input.
        reConstructedImage = gen_backward(stylized_image)
        loss_reconstruction = lossFNMAE(image, reConstructedImage)
        wandb.log({f"gen_loss_reconstruction_{src}": loss_reconstruction.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(
                reConstructedImage,
                caption=f"Image Style{src} reconed by Gen{src} on step {step}")
            wandb.log({f"Gen{dst}_recon_by_Gen{src}": wandb_log_image})
        # Identity loss: a src-domain image fed to the src-producing
        # generator should come back unchanged.
        identicalImage = gen_backward(image)
        loss_identical = lossFNMAE(image, identicalImage)
        wandb.log({f"gen_loss_identical_{src}": loss_identical.item()})
        if step % 100 == 0:
            wandb_log_image = wandb.Image(
                identicalImage,
                caption=f"Image Style{src} made by Gen{src} from set{src} on step {step}")
            wandb.log({f"Gen{src}_iden_by_Gen{src}": wandb_log_image})
        loss_sum = loss_traditional + loss_reconstruction + loss_identical
        wandb.log({f"gen_loss_sum_{src}": loss_sum.item()})
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()


def train_generator(loader_A, loader_B, model, lossFnMSE, lossFNMAE, optimizer, default_device):
    """Train both generators for one pass over each loader while the
    discriminators are frozen.

    Per-batch loss = adversarial MSE + cycle-consistency L1 + identity L1.
    The two previous copy-pasted loops are now one parameterized helper
    (the refactor the old docstring asked for)."""
    FreezeParameter(model.DiscA)
    FreezeParameter(model.DiscB)
    # Set A -> style B via GenB, judged by DiscB, cycled back with GenA.
    _train_generator_direction(loader_A, model.GenB, model.GenA, model.DiscB,
                               "A", "B", lossFnMSE, lossFNMAE, optimizer,
                               default_device)
    # Set B -> style A via GenA, judged by DiscA, cycled back with GenB.
    _train_generator_direction(loader_B, model.GenA, model.GenB, model.DiscA,
                               "B", "A", lossFnMSE, lossFNMAE, optimizer,
                               default_device)
    FreezeParameter(model.DiscA, False)
    FreezeParameter(model.DiscB, False)
def train_discriminator(loader_A, loader_B, model, lossFnMSE, optimizer, default_device):
    """Train both discriminators for one pass over each loader while the
    generators are frozen.

    For each real batch the matching discriminator is pushed towards 1; the
    batch is then restyled by the frozen opposite generator and the other
    discriminator is pushed towards 0 on the fake.

    BUGFIX: the all-zero negative target was created with torch.zeros(...),
    which lives on the CPU and crashes with a device mismatch when training
    on CUDA; torch.zeros_like builds it on the prediction's device/dtype."""
    FreezeParameter(model.GenA)
    FreezeParameter(model.GenB)
    for image, label in loader_A:
        image = image.to(device=default_device)
        label = label.to(device=default_device, dtype=torch.float32)
        # Train DiscA on positive samples; expand the scalar label to match
        # the patch prediction as usual.
        patch_prediction = model.DiscA(image)
        label = label.expand(patch_prediction.shape)
        loss_positive_A = lossFnMSE(label, patch_prediction)
        wandb.log({"disc_loss_positive_A": loss_positive_A.item()})
        optimizer.zero_grad()
        loss_positive_A.backward()
        optimizer.step()
        # Train DiscB on negative samples: GenB turns the set-A image into a
        # B-stylized fake that DiscB should predict towards 0.
        stylized_image = model.GenB(image)
        patch_prediction = model.DiscB(stylized_image)
        label = torch.zeros_like(patch_prediction)  # zero target on same device
        loss_negative_B = lossFnMSE(label, patch_prediction)
        wandb.log({"disc_loss_negative_B": loss_negative_B.item()})
        optimizer.zero_grad()
        loss_negative_B.backward()
        optimizer.step()
    for image, label in loader_B:
        # Just like the loop above, with the roles of A and B reversed.
        image = image.to(device=default_device)
        label = label.to(device=default_device, dtype=torch.float32)
        patch_prediction = model.DiscB(image)
        label = label.expand(patch_prediction.shape)
        loss_positive_B = lossFnMSE(label, patch_prediction)
        wandb.log({"disc_loss_positive_B": loss_positive_B.item()})
        optimizer.zero_grad()
        loss_positive_B.backward()
        optimizer.step()
        stylized_image = model.GenA(image)
        patch_prediction = model.DiscA(stylized_image)
        label = torch.zeros_like(patch_prediction)  # zero target on same device
        loss_negative_A = lossFnMSE(label, patch_prediction)
        wandb.log({"disc_loss_negative_A": loss_negative_A.item()})
        optimizer.zero_grad()
        loss_negative_A.backward()
        optimizer.step()
    FreezeParameter(model.GenA, False)
    FreezeParameter(model.GenB, False)
def train(loader_A, loader_B, model, lossFnMSE, lossFNMAE, optimizer, ratio, default_device):
    """One training round: a single generator pass followed by `ratio`
    discriminator passes.  Note the batch size should be 1."""
    train_generator(loader_A, loader_B, model, lossFnMSE, lossFNMAE,
                    optimizer, default_device)
    for _repeat in range(ratio):
        train_discriminator(loader_A, loader_B, model, lossFnMSE,
                            optimizer, default_device)
if __name__ == '__main__':
    # Run wandb offline by default so no API key is needed.
    offline_wandb = True
    if offline_wandb:
        os.environ["WANDB_API_KEY"] = ""
        os.environ["WANDB_MODE"] = "offline"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Images are resized to the 128x128 the UnetGen architecture expects,
    # then normalized with the usual ImageNet statistics.
    transform_list = T.Compose([
        T.ToTensor(),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch_size = 1  # the training loops assume batch size 1
    epochs = 10
    learning_rate = 1e-3
    wandb.init(project="CycleGan")
    wandb.config = {
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
    }
    main_model = CycleGAN()
    main_model.to(device=device)
    loss_mse = nn.MSELoss()  # adversarial loss
    loss_mae = nn.L1Loss()   # cycle-consistency / identity loss
    # NOTE(review): a single optimizer covers all four sub-networks; the
    # FreezeParameter calls in the train functions gate which weights move.
    Optimizer = torch.optim.Adam(main_model.parameters(), lr=learning_rate)
    datasetA = CostumeImageSet(root="run/86", transform=transform_list)
    dataloaderA = DataLoader(datasetA, batch_size=1, shuffle=True)
    datasetB = CostumeImageSet(root="run/st", transform=transform_list)
    dataloaderB = DataLoader(datasetB, batch_size=1, shuffle=True)
    for epoch in range(epochs):
        # BUGFIX: pass dataloaderB (previously the raw datasetB was passed,
        # which iterates un-batched 3-D tensors and breaks the model forward).
        train(dataloaderA, dataloaderB, main_model, loss_mse, loss_mae,
              Optimizer, 5, device)
    print("Done!")
    PATH = 'CycleGan.pth'
    torch.save(main_model.state_dict(), PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment