#!/usr/bin/env python
# coding: utf-8

# Amp optimization level for NVIDIA Apex mixed-precision training
opt_level = "O1"

import random

import torch
import torch.nn as nn

import ignite

print(torch.__version__)
print(ignite.__file__)
print(ignite.__version__)

# Fix seeds for reproducibility
seed = 17
random.seed(seed)
_ = torch.manual_seed(seed)
def get_conv_inorm_relu(in_planes, out_planes, kernel_size, stride, reflection_pad=True, with_relu=True):
    layers = []
    padding = (kernel_size - 1) // 2
    if reflection_pad:
        layers.append(nn.ReflectionPad2d(padding=padding))
        padding = 0
    layers += [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
    ]
    if with_relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
def get_conv_transposed_inorm_relu(in_planes, out_planes, kernel_size, stride):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.ReLU(inplace=True)
    )
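# Note: with stride=2, padding=1 and output_padding=1 this ConvTranspose2d exactly doubles
# the spatial size ((H - 1) * 2 - 2 + kernel_size + 1 = 2 * H for kernel_size=3), which is
# what the decoder needs to undo the two stride-2 downsampling convolutions.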
class ResidualBlock(nn.Module):

    def __init__(self, in_planes):
        super(ResidualBlock, self).__init__()
        self.conv1 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1)
        self.conv2 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1, with_relu=False)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.c7s1_64 = get_conv_inorm_relu(3, 64, kernel_size=7, stride=1)
        self.d128 = get_conv_inorm_relu(64, 128, kernel_size=3, stride=2, reflection_pad=False)
        self.d256 = get_conv_inorm_relu(128, 256, kernel_size=3, stride=2, reflection_pad=False)
        self.resnet9 = nn.Sequential(*[ResidualBlock(256) for _ in range(9)])
        self.u128 = get_conv_transposed_inorm_relu(256, 128, kernel_size=3, stride=2)
        self.u64 = get_conv_transposed_inorm_relu(128, 64, kernel_size=3, stride=2)
        self.c7s1_3 = get_conv_inorm_relu(64, 3, kernel_size=7, stride=1, with_relu=False)
        # Replace the final instance norm with a tanh activation
        self.c7s1_3[-1] = nn.Tanh()

    def forward(self, x):
        # Encoding
        x = self.c7s1_64(x)
        x = self.d128(x)
        x = self.d256(x)
        # 9 residual blocks
        x = self.resnet9(x)
        # Decoding
        x = self.u128(x)
        x = self.u64(x)
        y = self.c7s1_3(x)
        return y
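# Quick shape sanity check (cheap, runs on CPU): the generator should be shape-preserving
# for inputs whose spatial size is divisible by 4, since the two stride-2 downsampling
# convolutions are mirrored by the two stride-2 upsampling ones.
with torch.no_grad():
    assert Generator()(torch.rand(1, 3, 64, 64)).shape == (1, 3, 64, 64)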
def get_conv_inorm_lrelu(in_planes, out_planes, stride=2, negative_slope=0.2):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=stride, padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
    )
class discriminators(nn.Module):

    def __init__(self):
        super(discriminators, self).__init__()
        self.c64 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.c128 = get_conv_inorm_lrelu(64, 128)
        self.c256 = get_conv_inorm_lrelu(128, 256)
        self.c512 = get_conv_inorm_lrelu(256, 512, stride=1)
        self.last_conv = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        x = self.c64(x)
        x = self.relu(x)
        x = self.c128(x)
        x = self.c256(x)
        x = self.c512(x)
        y = self.last_conv(x)
        return y
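# Quick shape sanity check: this is a PatchGAN-style discriminator (70x70 receptive field)
# that outputs a grid of per-patch decisions rather than a single scalar, e.g. a 256x256
# input yields a 1x1x30x30 decision map.
with torch.no_grad():
    assert discriminators()(torch.rand(1, 3, 256, 256)).shape == (1, 1, 30, 30)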
def init_weights(module):
    # Recursively initialize weights with N(0, 0.02) and biases with zero; the
    # InstanceNorm2d layers use affine=False and therefore have no parameters to init.
    assert isinstance(module, nn.Module)
    if hasattr(module, "weight") and module.weight is not None:
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if hasattr(module, "bias") and module.bias is not None:
        torch.nn.init.constant_(module.bias, 0.0)
    for c in module.children():
        init_weights(c)
assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True
device = "cuda"
generator_A2B = Generator().to(device)
init_weights(generator_A2B)
discriminators_B = discriminators().to(device)
init_weights(discriminators_B)
generator_B2A = Generator().to(device)
init_weights(generator_B2A)
discriminators_A = discriminators().to(device)
init_weights(discriminators_A)
from itertools import chain
import torch.optim as optim
lr = 0.0002
beta1 = 0.5
optimizer_G = optim.Adam(chain(generator_A2B.parameters(), generator_B2A.parameters()), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(chain(discriminators_A.parameters(), discriminators_B.parameters()), lr=lr, betas=(beta1, 0.999))
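# One Adam instance updates both generators jointly (and another both discriminators)
# by chaining their parameters; lr=0.0002 with beta1=0.5 are the usual CycleGAN/DCGAN
# settings rather than anything specific to this script.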
def toggle_grad(model, on_or_off):
    # https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L674
    for param in model.parameters():
        param.requires_grad = on_or_off
try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

# Initialize Amp
models, optimizers = amp.initialize(
    [generator_A2B, generator_B2A, discriminators_A, discriminators_B],
    [optimizer_G, optimizer_D],
    opt_level=opt_level, num_losses=2
)
generator_A2B, generator_B2A, discriminators_A, discriminators_B = models
optimizer_G, optimizer_D = optimizers
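# With opt_level="O1", Amp patches torch functions so whitelisted ops run in fp16 while
# the model weights stay in fp32; num_losses=2 gives the generator and discriminator
# losses separate loss scalers, selected via loss_id in amp.scale_loss below.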
buffer_size = 50
fake_a_buffer = []
fake_b_buffer = []
def buffer_insert_and_get(buffer, batch):
    output_batch = []
    for b in batch:
        b = b.unsqueeze(0)
        if len(buffer) < buffer_size:
            # Buffer is not full yet: return the image as-is and store a CPU copy
            output_batch.append(b)
            buffer.append(b.cpu())
        elif random.uniform(0, 1) > 0.5:
            # Store the newly created image in the buffer and put one from the buffer into the output
            random_index = random.randint(0, buffer_size - 1)
            output_batch.append(buffer[random_index].clone().to(device))
            buffer[random_index] = b.cpu()
        else:
            output_batch.append(b)
    return torch.cat(output_batch, dim=0)
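# Minimal illustration (cheap, CPU-only): while the pool is below buffer_size every fake
# image passes through unchanged and a CPU copy is stored; once full, each incoming image
# has a 50% chance of being swapped with a stored one, so the discriminators also see
# older generator outputs (the CycleGAN "image pool" trick).
_demo_pool, _demo_fakes = [], torch.rand(4, 3, 8, 8)
assert buffer_insert_and_get(_demo_pool, _demo_fakes).shape == _demo_fakes.shape
assert len(_demo_pool) == 4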
from ignite.utils import convert_tensor
import torch.nn.functional as F
lambda_value = 10.0
def discriminators_forward_pass(discriminators, batch_real, batch_fake, fake_buffer):
    decision_real = discriminators(batch_real)
    batch_fake = buffer_insert_and_get(fake_buffer, batch_fake)
    batch_fake = batch_fake.detach()
    decision_fake = discriminators(batch_fake)
    return decision_real, decision_fake
def compute_loss_generator(batch_decision, batch_real, batch_rec, lambda_value):
    # GAN loss: push the discriminator decision on fake images towards 1
    target = torch.ones_like(batch_decision)
    loss_gan = F.mse_loss(batch_decision, target)
    print("> loss_gan:", loss_gan)
    # Cycle-consistency loss
    loss_cycle = F.l1_loss(batch_rec, batch_real) * lambda_value
    print("> loss_cycle:", loss_cycle)
    return loss_gan + loss_cycle
def compute_loss_discriminators(decision_real, decision_fake):
    # loss = mean (D_b(y) − 1)^2 + mean D_b(G(x))^2
    loss = F.mse_loss(decision_fake, torch.zeros_like(decision_fake))
    loss += F.mse_loss(decision_real, torch.ones_like(decision_real))
    return loss
def update_fn(engine, batch):
    generator_A2B.train()
    generator_B2A.train()
    discriminators_A.train()
    discriminators_B.train()

    real_a = convert_tensor(batch['A'], device=device, non_blocking=True)
    real_b = convert_tensor(batch['B'], device=device, non_blocking=True)

    fake_b = generator_A2B(real_a)
    rec_a = generator_B2A(fake_b)

    fake_a = generator_B2A(real_b)
    rec_b = generator_A2B(fake_a)

    decision_fake_a = discriminators_A(fake_a)
    decision_fake_b = discriminators_B(fake_b)

    # Disable grads computation for the discriminators:
    toggle_grad(discriminators_A, False)
    toggle_grad(discriminators_B, False)

    # Compute loss for generators and update generators
    # loss_a2b = GAN loss: mean (D_b(G(x)) − 1)^2 + Forward cycle loss: || F(G(x)) - x ||_1
    loss_a2b = compute_loss_generator(decision_fake_b, real_a, rec_a, lambda_value)
    print("loss_a2b:", loss_a2b)

    # loss_b2a = GAN loss: mean (D_a(F(x)) − 1)^2 + Backward cycle loss: || G(F(y)) - y ||_1
    loss_b2a = compute_loss_generator(decision_fake_a, real_b, rec_b, lambda_value)
    print("loss_b2a:", loss_b2a)

    # total generators loss:
    loss_generators = loss_a2b + loss_b2a
    print("- loss_generators:", loss_generators)

    optimizer_G.zero_grad()
    with amp.scale_loss(loss_generators, optimizer_G, loss_id=0) as scaled_loss:
        scaled_loss.backward()
    optimizer_G.step()
    print("-- loss_generators:", loss_generators)
    print("-- loss_a2b:", loss_a2b)
    print("-- loss_b2a:", loss_b2a)

    decision_fake_a = rec_a = decision_fake_b = rec_b = None

    # Enable grads computation for the discriminators:
    toggle_grad(discriminators_A, True)
    toggle_grad(discriminators_B, True)

    decision_real_a, decision_fake_a = discriminators_forward_pass(discriminators_A, real_a, fake_a, fake_a_buffer)
    decision_real_b, decision_fake_b = discriminators_forward_pass(discriminators_B, real_b, fake_b, fake_b_buffer)

    # Compute loss for discriminators and update discriminators
    # loss_a = mean (D_a(y) − 1)^2 + mean D_a(F(x))^2
    loss_a = compute_loss_discriminators(decision_real_a, decision_fake_a)
    # loss_b = mean (D_b(y) − 1)^2 + mean D_b(G(x))^2
    loss_b = compute_loss_discriminators(decision_real_b, decision_fake_b)
    # total discriminators loss:
    loss_discriminators = 0.5 * (loss_a + loss_b)

    optimizer_D.zero_grad()
    with amp.scale_loss(loss_discriminators, optimizer_D, loss_id=1) as scaled_loss:
        scaled_loss.backward()
    optimizer_D.step()

    return {
        "loss_generators": loss_generators.item(),
        "loss_generator_a2b": loss_a2b.item(),
        "loss_generator_b2a": loss_b2a.item(),
        "loss_discriminators": loss_discriminators.item(),
        "loss_discriminators_a": loss_a.item(),
        "loss_discriminators_b": loss_b.item(),
    }
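# In a full training script this update function would typically be wrapped into an
# ignite Engine, e.g. (sketch, not run here; `train_dataloader` is assumed to yield
# dicts with "A" and "B" image batches):
#
#     from ignite.engine import Engine
#     trainer = Engine(update_fn)
#     trainer.run(train_dataloader, max_epochs=200)
#
# Below it is instead called directly on random batches to exercise the Amp setup.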
# Run the update step a few times on random image batches scaled to [-1, 1]
for _ in range(3):
    real_batch = {
        "A": 2.0 * torch.rand(6, 3, 200, 200) - 1.0,
        "B": 2.0 * torch.rand(6, 3, 200, 200) - 1.0
    }
    print("\nRun update")
    res = update_fn(engine=None, batch=real_batch)
    print(res)