import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    """U-Net-style generator: a VGG-like encoder with skip connections to the decoder."""
    def __init__(self, initial_features, num_channels=1, dropout=0.1):
        super(Generator, self).__init__()
        # Encoder
        self.conv_down1 = SingleConv3x3(num_channels, initial_features, dropout)
        self.conv_down2 = SingleConv3x3(initial_features, initial_features, dropout)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_down3_4 = DoubleConv3x3(initial_features, initial_features*2, dropout)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_down5_6_7 = TripleConv3x3(initial_features*2, initial_features*(2**2), dropout)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv_down8_9_10 = TripleConv3x3(initial_features*(2**2), initial_features*(2**3), dropout)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.conv_down11_12_13 = TripleConv3x3(initial_features*(2**3), initial_features*(2**3), dropout)
        # Bottleneck
        self.conv_down14_15_16 = TripleConv3x3(initial_features*(2**3), initial_features*(2**3), dropout)
        # Decoder: upsample, concatenate the matching encoder features, convolve
        self.upconcat1 = UpConcat(initial_features*(2**3), initial_features*(2**3))
        self.conv_up1_2_3 = TripleConv3x3(initial_features*(2**3)+initial_features*(2**3), initial_features*(2**2), dropout)
        self.upconcat2 = UpConcat(initial_features*(2**3), initial_features*(2**2))
        self.conv_up4_5_6 = TripleConv3x3(initial_features*(2**2)+initial_features*(2**2), initial_features*2, dropout)
        self.upconcat3 = UpConcat(initial_features*(2**2), initial_features*2)
        self.conv_up7_8 = DoubleConv3x3(initial_features*2+initial_features*2, initial_features, dropout)
        self.upconcat4 = UpConcat(initial_features*2, initial_features)
        self.conv_up9_10 = DoubleConv3x3(initial_features+initial_features, initial_features, dropout)
        # 1x1 convolution down to a single output channel
        self.final = SingleConv1x1(initial_features, dropout)

    def forward(self, inputs):
        # Encoder path (feature maps are kept for the skip connections)
        conv_down1_feat = self.conv_down1(inputs)
        conv_down2_feat = self.conv_down2(conv_down1_feat)
        maxpool1_feat = self.maxpool1(conv_down2_feat)
        conv_down3_4_feat = self.conv_down3_4(maxpool1_feat)
        maxpool2_feat = self.maxpool2(conv_down3_4_feat)
        conv_down5_6_7_feat = self.conv_down5_6_7(maxpool2_feat)
        maxpool3_feat = self.maxpool3(conv_down5_6_7_feat)
        conv_down8_9_10_feat = self.conv_down8_9_10(maxpool3_feat)
        maxpool4_feat = self.maxpool4(conv_down8_9_10_feat)
        conv_down11_12_13_feat = self.conv_down11_12_13(maxpool4_feat)
        conv_down14_15_16_feat = self.conv_down14_15_16(conv_down11_12_13_feat)
        # Decoder path
        upconcat1_feat = self.upconcat1(conv_down14_15_16_feat, conv_down8_9_10_feat)
        conv_up1_2_3_feat = self.conv_up1_2_3(upconcat1_feat)
        upconcat2_feat = self.upconcat2(conv_up1_2_3_feat, conv_down5_6_7_feat)
        conv_up4_5_6_feat = self.conv_up4_5_6(upconcat2_feat)
        upconcat3_feat = self.upconcat3(conv_up4_5_6_feat, conv_down3_4_feat)
        conv_up7_8_feat = self.conv_up7_8(upconcat3_feat)
        upconcat4_feat = self.upconcat4(conv_up7_8_feat, conv_down2_feat)
        conv_up9_10_feat = self.conv_up9_10(upconcat4_feat)
        outputs = self.final(conv_up9_10_feat)
        return outputs

class SingleConv1x1(nn.Module):
    def __init__(self, in_feat, dropout):
        super(SingleConv1x1, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, 1,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0),
                                   nn.Dropout2d(dropout))

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs

class SingleConv3x3(nn.Module):
    def __init__(self, in_feat, out_feat, dropout):
        super(SingleConv3x3, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.BatchNorm2d(out_feat),
                                   nn.Dropout2d(dropout),
                                   nn.LeakyReLU(0.2, True))

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs

class DoubleConv3x3(nn.Module):
    def __init__(self, in_feat, out_feat, dropout):
        super(DoubleConv3x3, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, in_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout2d(dropout),
                                   nn.BatchNorm2d(in_feat),
                                   nn.LeakyReLU(0.2, True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout2d(dropout),
                                   nn.BatchNorm2d(out_feat),
                                   nn.LeakyReLU(0.2, True))

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs

class TripleConv3x3(nn.Module):
    def __init__(self, in_feat, out_feat, dropout):
        super(TripleConv3x3, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, in_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout2d(dropout),
                                   nn.BatchNorm2d(in_feat),
                                   nn.LeakyReLU(0.2, True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_feat, in_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout2d(dropout),
                                   nn.BatchNorm2d(in_feat),
                                   nn.LeakyReLU(0.2, True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout2d(dropout),
                                   nn.BatchNorm2d(out_feat),
                                   nn.LeakyReLU(0.2, True))

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        return outputs

class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()
        # Note: in_feat is unused; the transposed convolution keeps out_feat
        # channels and, with kernel_size=3, stride=1, padding=1, the spatial size.
        # The actual 2x upscaling is done by the bilinear Upsample that follows.
        self.deconv = nn.ConvTranspose2d(out_feat, out_feat, kernel_size=3,
                                         padding=1, stride=1, dilation=1, output_padding=0)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, inputs, down_outputs):
        outputs = self.deconv(inputs)
        outputs = self.up(outputs)
        # Concatenate the upsampled features with the encoder skip connection
        out = torch.cat([down_outputs, outputs], 1)
        return out
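
# --- Optional sanity check (not in the original gist) ---
# A minimal sketch verifying that the generator preserves the spatial size and
# returns a single-channel map; the 1x1x256x256 shape is an assumption chosen
# to match the tensors used in run_training below.
def sanity_check_generator():
    g = Generator(initial_features=32)
    g.eval()  # use running BatchNorm statistics so a batch of one is fine
    with torch.no_grad():
        y = g(torch.randn(1, 1, 256, 256))
    print("Generator output shape:", tuple(y.shape))  # expected: (1, 1, 256, 256)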

def compute_l1_norm(model, lambda1=0.5):
    # L1 penalty over all model parameters, accumulated on the model's own device
    l1_regularization = torch.zeros((), device=next(model.parameters()).device)
    for param in model.parameters():
        l1_regularization = l1_regularization + torch.norm(param, 1)
    return lambda1 * l1_regularization

def compute_generator_loss(Y, Y_pred, ROI, weight=1.0):
    # Mean absolute error restricted to the ROI mask
    return weight * torch.sum(torch.abs(ROI * (Y - Y_pred))) / torch.sum(ROI)
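
# --- Worked example (not in the original gist) ---
# compute_generator_loss is a masked mean absolute error: with a constant
# difference of 0.5 inside a full mask, the loss is exactly 0.5 regardless of
# image size, since sum(|ROI * (Y - Y_pred)|) / sum(ROI) = (16 * 0.5) / 16.
def example_generator_loss():
    y = torch.zeros(1, 1, 4, 4)
    y_pred = torch.full((1, 1, 4, 4), 0.5)
    roi = torch.ones(1, 1, 4, 4)
    return compute_generator_loss(y, y_pred, roi)  # tensor(0.5000)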

def run_training(models, epoch, device):
    # Set optimizers
    optimizerG = optim.Adam(models["generator"].parameters(), lr=0.0001, weight_decay=0.0004)
    optimizerD = optim.Adam(models["discriminator"].parameters(), lr=0.0001, weight_decay=0.0004)
    # Define adversarial loss
    adversarial_loss = nn.BCELoss().to(device)
    # Dummy batch standing in for a real data loader
    # (X: input images, Y: target images, ROI: binary masks)
    X = torch.ones(20, 1, 256, 256, device=device, dtype=torch.float)
    Y = torch.zeros(20, 1, 256, 256, device=device, dtype=torch.float)
    ROI = torch.ones(20, 1, 256, 256, device=device, dtype=torch.float)
    total_images = X.shape[0]
    # Mini-batch bookkeeping: alternate between training the discriminator
    # and the generator every minibatch_size examples
    g_minibatch_counter = 0
    d_minibatch_counter = 0
    minibatch_size = 4
    minibatch_status = "discriminator"
    for i in range(total_images):
        # Define labels for discriminator training
        real_label = torch.full((1, 1), 1.0, dtype=torch.float, device=device)
        fake_label = torch.full((1, 1), 0.0, dtype=torch.float, device=device)
        # Reshape the current example into a batch of one
        X_batch = X[i, :, :, :].view(1, 1, 256, 256).to(device)
        Y_batch = Y[i, :, :, :].view(1, 1, 256, 256).to(device)
        ROI_batch = ROI[i, :, :, :].view(1, 1, 256, 256).to(device)
        # Run conversion, masked to the ROI
        Y_pred = models["generator"](X_batch) * ROI_batch
        #####################
        # Train discriminator
        #####################
        if minibatch_status == "discriminator":
            models["generator"].eval()
            models["discriminator"].train()
            # Zero the discriminator gradients only at the start of each
            # mini-batch, so they accumulate until the optimizer step
            if d_minibatch_counter == 0:
                models["discriminator"].zero_grad()
            if d_minibatch_counter < minibatch_size // 2:
                # First half of the mini-batch: real target images
                score_d_real = models["discriminator"](Y_batch)
                loss_D_real = adversarial_loss(score_d_real, real_label)
                loss_D_real.backward()
                loss_D = loss_D_real
            else:
                # Second half: generated images, detached from the generator graph
                score_d_fake = models["discriminator"](Y_pred.detach())
                loss_D_fake = adversarial_loss(score_d_fake, fake_label)
                loss_D_fake.backward()
                loss_D = loss_D_fake
            # Update mini-batch info
            d_minibatch_counter += 1
            if d_minibatch_counter == minibatch_size:
                optimizerD.step()
                minibatch_status = "generator"
                d_minibatch_counter = 0
            # Print on screen
            print(">>> Epoch %d (Discriminator) -- Example %d/%d -- Loss D = %.3f"
                  % (epoch, i, total_images, loss_D.item()))
        #####################
        # Train generator
        #####################
        if minibatch_status == "generator":
            models["generator"].train()
            models["discriminator"].eval()
            # Set the generator gradients to zero
            models["generator"].zero_grad()
            # Compute generative intensity loss
            loss_G_intensity = compute_generator_loss(Y_batch, Y_pred, ROI_batch)
            # Compute L1 penalty
            loss_G_norm = compute_l1_norm(models["generator"], 0.0004).to(device)
            # Adversarial term: the generator tries to make the discriminator
            # score its output as real
            score_d_g = models["discriminator"](Y_pred)
            loss_D_G = adversarial_loss(score_d_g, real_label)
            loss_G = loss_G_intensity + loss_G_norm + (10.0 * loss_D_G)
            # Run optimizer
            loss_G.backward()
            optimizerG.step()
            # Update mini-batch info
            g_minibatch_counter += 1
            if g_minibatch_counter == minibatch_size:
                minibatch_status = "discriminator"
                g_minibatch_counter = 0
            # Print on screen
            print(">>> Epoch %d (Generator) -- Example %d/%d -- Loss G = %.3f"
                  % (epoch, i, total_images, loss_G.item()))

# Define device
device = torch.device("cuda:0")
# device = torch.device("cpu")

# Create models
models = {}
models["generator"] = Generator(32).to(device)
# Discriminator: a ResNet-152 adapted to single-channel input and one sigmoid
# output (on torchvision >= 0.13 pass weights=None instead of pretrained=False)
models["discriminator"] = torch.hub.load('pytorch/vision', 'resnet152', pretrained=False)
models["discriminator"].conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
models["discriminator"].fc = nn.Sequential(nn.Linear(models["discriminator"].fc.in_features, 512),
                                           nn.ReLU(),
                                           nn.Dropout(),
                                           nn.Linear(512, 1),
                                           nn.Sigmoid())
models["discriminator"] = models["discriminator"].to(device)

# Run training
for epoch in range(100):
    run_training(models, epoch, device)
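
# --- Checkpointing sketch (not in the original gist) ---
# Persist the trained generator weights; the file name "generator.pth" is an
# assumption, not something the original gist specifies.
torch.save(models["generator"].state_dict(), "generator.pth")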