#!/usr/bin/env python
# coding: utf-8

import random
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import ignite
from ignite.utils import convert_tensor

print(torch.__version__)
print(ignite.__file__)
print(ignite.__version__)

# NVIDIA/Apex Amp optimization level
opt_level = "O1"

seed = 17
random.seed(seed)
_ = torch.manual_seed(seed)

def get_conv_inorm_relu(in_planes, out_planes, kernel_size, stride, reflection_pad=True, with_relu=True):
    layers = []
    padding = (kernel_size - 1) // 2
    if reflection_pad:
        layers.append(nn.ReflectionPad2d(padding=padding))
        padding = 0
    layers += [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
    ]
    if with_relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
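
# Illustrative sanity check (not part of the original gist): with reflection
# padding of (kernel_size - 1) // 2 and stride 1, the block should preserve
# the spatial size while changing the number of channels. The input shape
# below is arbitrary.
_probe = get_conv_inorm_relu(3, 64, kernel_size=7, stride=1)
assert _probe(torch.rand(1, 3, 32, 32)).shape == (1, 64, 32, 32)
del _probe
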
def get_conv_transposed_inorm_relu(in_planes, out_planes, kernel_size, stride):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.ReLU(inplace=True)
    )
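
# Illustrative check (an assumption spelled out here, not stated in the gist):
# with kernel_size=3, stride=2, padding=1 and output_padding=1, the transposed
# conv exactly doubles the spatial size:
#   out = (in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * in
_probe = get_conv_transposed_inorm_relu(64, 32, kernel_size=3, stride=2)
assert _probe(torch.rand(1, 64, 16, 16)).shape == (1, 32, 32, 32)
del _probe
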
class ResidualBlock(nn.Module):

    def __init__(self, in_planes):
        super(ResidualBlock, self).__init__()
        self.conv1 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1)
        self.conv2 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1, with_relu=False)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.c7s1_64 = get_conv_inorm_relu(3, 64, kernel_size=7, stride=1)
        self.d128 = get_conv_inorm_relu(64, 128, kernel_size=3, stride=2, reflection_pad=False)
        self.d256 = get_conv_inorm_relu(128, 256, kernel_size=3, stride=2, reflection_pad=False)
        self.resnet9 = nn.Sequential(*[ResidualBlock(256) for _ in range(9)])
        self.u128 = get_conv_transposed_inorm_relu(256, 128, kernel_size=3, stride=2)
        self.u64 = get_conv_transposed_inorm_relu(128, 64, kernel_size=3, stride=2)
        self.c7s1_3 = get_conv_inorm_relu(64, 3, kernel_size=7, stride=1, with_relu=False)
        # Replace the last instance norm by a tanh activation
        self.c7s1_3[-1] = nn.Tanh()

    def forward(self, x):
        # Encoding
        x = self.c7s1_64(x)
        x = self.d128(x)
        x = self.d256(x)
        # 9 residual blocks
        x = self.resnet9(x)
        # Decoding
        x = self.u128(x)
        x = self.u64(x)
        y = self.c7s1_3(x)
        return y
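
# Illustrative end-to-end shape check (not in the original gist): the generator
# downsamples twice by 2 and upsamples twice by 2, so for inputs whose height
# and width are divisible by 4 it maps (N, 3, H, W) -> (N, 3, H, W), with tanh
# output in [-1, 1]. A 200x200 input goes 200 -> 100 -> 50 -> 100 -> 200.
_g = Generator()
_y = _g(torch.rand(1, 3, 64, 64))
assert _y.shape == (1, 3, 64, 64)
assert _y.min() >= -1.0 and _y.max() <= 1.0
del _g, _y
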
def get_conv_inorm_lrelu(in_planes, out_planes, stride=2, negative_slope=0.2):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=stride, padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
    )


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.c64 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.c128 = get_conv_inorm_lrelu(64, 128)
        self.c256 = get_conv_inorm_lrelu(128, 256)
        self.c512 = get_conv_inorm_lrelu(256, 512, stride=1)
        self.last_conv = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        x = self.c64(x)
        x = self.relu(x)
        x = self.c128(x)
        x = self.c256(x)
        x = self.c512(x)
        y = self.last_conv(x)
        return y
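
# Hedged note, derived from the layer arithmetic rather than stated in the gist:
# this is the usual 70x70 PatchGAN discriminator and returns a decision map
# rather than a single scalar. For a 200x200 input the map is 1x23x23:
# 200 -> 100 -> 50 -> 25 -> 24 -> 23.
_d = Discriminator()
assert _d(torch.rand(1, 3, 200, 200)).shape == (1, 1, 23, 23)
del _d
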
def init_weights(module):
    # Recursively initialize weights with N(0, 0.02) and biases with zeros.
    # Parameter-free modules are skipped by the guards below (e.g. InstanceNorm2d
    # with affine=False has weight=None).
    assert isinstance(module, nn.Module)
    if hasattr(module, "weight") and module.weight is not None:
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if hasattr(module, "bias") and module.bias is not None:
        torch.nn.init.constant_(module.bias, 0.0)
    for c in module.children():
        init_weights(c)


assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True

device = "cuda"

generator_A2B = Generator().to(device)
init_weights(generator_A2B)
discriminator_B = Discriminator().to(device)
init_weights(discriminator_B)

generator_B2A = Generator().to(device)
init_weights(generator_B2A)
discriminator_A = Discriminator().to(device)
init_weights(discriminator_A)
lr = 0.0002
beta1 = 0.5

optimizer_G = optim.Adam(chain(generator_A2B.parameters(), generator_B2A.parameters()), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(chain(discriminator_A.parameters(), discriminator_B.parameters()), lr=lr, betas=(beta1, 0.999))


def toggle_grad(model, on_or_off):
    # https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L674
    for param in model.parameters():
        param.requires_grad = on_or_off


try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

# Initialize Amp
models, optimizers = amp.initialize([generator_A2B, generator_B2A, discriminator_A, discriminator_B],
                                    [optimizer_G, optimizer_D],
                                    opt_level=opt_level, num_losses=2)
generator_A2B, generator_B2A, discriminator_A, discriminator_B = models
optimizer_G, optimizer_D = optimizers
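
# Note: num_losses=2 tells Amp to maintain a separate loss scaler for the
# generator loss and the discriminator loss; the matching loss_id values
# (0 and 1) are passed to amp.scale_loss() in update_fn below.
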
buffer_size = 50
fake_a_buffer = []
fake_b_buffer = []


def buffer_insert_and_get(buffer, batch):
    output_batch = []
    for b in batch:
        b = b.unsqueeze(0)
        # While the buffer is not fully filled, pass the image through and store a copy:
        if len(buffer) < buffer_size:
            output_batch.append(b)
            buffer.append(b.cpu())
        elif random.uniform(0, 1) > 0.5:
            # Put the newly created image into the buffer and return one from the buffer instead
            random_index = random.randint(0, buffer_size - 1)
            output_batch.append(buffer[random_index].clone().to(device))
            buffer[random_index] = b.cpu()
        else:
            output_batch.append(b)
    return torch.cat(output_batch, dim=0)
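
# Hedged illustration (not in the original gist) of the image-history buffer
# from Shrivastava et al., as used in CycleGAN: while the pool is below
# buffer_size, fake images pass through unchanged; once full, each image is
# either passed through or swapped with a randomly stored one, with
# probability 0.5. Example usage with a tiny batch (device is the "cuda"
# global defined above):
#
#     batch = torch.rand(4, 3, 8, 8, device=device)
#     out = buffer_insert_and_get(fake_a_buffer, batch)
#     assert out.shape == batch.shape
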
lambda_value = 10.0


def discriminators_forward_pass(discriminator, batch_real, batch_fake, fake_buffer):
    decision_real = discriminator(batch_real)
    batch_fake = buffer_insert_and_get(fake_buffer, batch_fake)
    # Detach so that the discriminator loss does not backpropagate into the generator:
    batch_fake = batch_fake.detach()
    decision_fake = discriminator(batch_fake)
    return decision_real, decision_fake


def compute_loss_generator(batch_decision, batch_real, batch_rec, lambda_value):
    # GAN loss: mean (D(G(x)) - 1)^2
    target = torch.ones_like(batch_decision)
    loss_gan = F.mse_loss(batch_decision, target)
    print("> loss_gan:", loss_gan)
    # Cycle loss: lambda * || F(G(x)) - x ||_1
    loss_cycle = F.l1_loss(batch_rec, batch_real) * lambda_value
    print("> loss_cycle:", loss_cycle)
    return loss_gan + loss_cycle


def compute_loss_discriminators(decision_real, decision_fake):
    # loss = mean (D(y) - 1)^2 + mean D(G(x))^2
    loss = F.mse_loss(decision_fake, torch.zeros_like(decision_fake))
    loss += F.mse_loss(decision_real, torch.ones_like(decision_real))
    return loss
def update_fn(engine, batch):
    generator_A2B.train()
    generator_B2A.train()
    discriminator_A.train()
    discriminator_B.train()

    real_a = convert_tensor(batch['A'], device=device, non_blocking=True)
    real_b = convert_tensor(batch['B'], device=device, non_blocking=True)

    fake_b = generator_A2B(real_a)
    rec_a = generator_B2A(fake_b)

    fake_a = generator_B2A(real_b)
    rec_b = generator_A2B(fake_a)

    # Disable grads computation for the discriminators before their forward
    # passes; toggling after the forward would not keep the generator backward
    # pass from accumulating discriminator grads:
    toggle_grad(discriminator_A, False)
    toggle_grad(discriminator_B, False)

    decision_fake_a = discriminator_A(fake_a)
    decision_fake_b = discriminator_B(fake_b)

    # Compute loss for generators and update generators
    # loss_a2b = GAN loss: mean (D_b(G(x)) - 1)^2 + forward cycle loss: || F(G(x)) - x ||_1
    loss_a2b = compute_loss_generator(decision_fake_b, real_a, rec_a, lambda_value)
    print("loss_a2b:", loss_a2b)

    # loss_b2a = GAN loss: mean (D_a(F(y)) - 1)^2 + backward cycle loss: || G(F(y)) - y ||_1
    loss_b2a = compute_loss_generator(decision_fake_a, real_b, rec_b, lambda_value)
    print("loss_b2a:", loss_b2a)

    # Total generators loss:
    loss_generators = loss_a2b + loss_b2a
    print("- loss_generators:", loss_generators)

    optimizer_G.zero_grad()
    with amp.scale_loss(loss_generators, optimizer_G, loss_id=0) as scaled_loss:
        scaled_loss.backward()
    optimizer_G.step()
    print("-- loss_generators:", loss_generators)
    print("-- loss_a2b:", loss_a2b)
    print("-- loss_b2a:", loss_b2a)

    decision_fake_a = rec_a = decision_fake_b = rec_b = None

    # Enable grads computation for the discriminators:
    toggle_grad(discriminator_A, True)
    toggle_grad(discriminator_B, True)

    decision_real_a, decision_fake_a = discriminators_forward_pass(discriminator_A, real_a, fake_a, fake_a_buffer)
    decision_real_b, decision_fake_b = discriminators_forward_pass(discriminator_B, real_b, fake_b, fake_b_buffer)

    # Compute loss for discriminators and update discriminators
    # loss_a = mean (D_a(y) - 1)^2 + mean D_a(F(x))^2
    loss_a = compute_loss_discriminators(decision_real_a, decision_fake_a)

    # loss_b = mean (D_b(y) - 1)^2 + mean D_b(G(x))^2
    loss_b = compute_loss_discriminators(decision_real_b, decision_fake_b)

    # Total discriminators loss:
    loss_discriminators = 0.5 * (loss_a + loss_b)

    optimizer_D.zero_grad()
    with amp.scale_loss(loss_discriminators, optimizer_D, loss_id=1) as scaled_loss:
        scaled_loss.backward()
    optimizer_D.step()

    return {
        "loss_generators": loss_generators.item(),
        "loss_generator_a2b": loss_a2b.item(),
        "loss_generator_b2a": loss_b2a.item(),
        "loss_discriminators": loss_discriminators.item(),
        "loss_discriminators_a": loss_a.item(),
        "loss_discriminators_b": loss_b.item(),
    }
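
# In a full training script this update function would be wrapped into an
# ignite Engine, e.g. (sketch, assuming a `dataloader` yielding batches of the
# form {"A": ..., "B": ...}):
#
#     from ignite.engine import Engine
#     trainer = Engine(update_fn)
#     trainer.run(dataloader, max_epochs=200)
#
# Below it is only smoke-tested directly with engine=None on random batches.
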
# Run the update three times to exercise Amp loss scaling on fresh random data
for _ in range(3):
    real_batch = {
        "A": 2.0 * torch.rand(6, 3, 200, 200) - 1.0,
        "B": 2.0 * torch.rand(6, 3, 200, 200) - 1.0
    }
    print("\nRun update")
    res = update_fn(engine=None, batch=real_batch)
    print(res)