Skip to content

Instantly share code, notes, and snippets.

@pzaffino
Created April 30, 2021 23:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pzaffino/7c3714ffe8eb867eb45b721ac4d2d808 to your computer and use it in GitHub Desktop.
Save pzaffino/7c3714ffe8eb867eb45b721ac4d2d808 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models
import torch.nn.functional as F
class Generator(nn.Module):
    """Encoder/decoder generator with skip connections.

    The encoder stacks 3x3 conv blocks separated by four 2x2 max-pools,
    widening from ``initial_features`` up to ``initial_features * 8``
    channels. The decoder mirrors it with four ``UpConcat`` stages; each one
    upsamples the current features, concatenates them with the encoder
    features of the same resolution, and is followed by conv blocks that
    narrow the channel count again. A final 1x1 conv block produces the
    output.

    Args:
        initial_features: Channel count of the first conv block; deeper
            levels use 2x, 4x and 8x this value.
        num_channels: Number of channels of the input image.
        dropout: Dropout probability forwarded to every conv block.
    """

    def __init__(self, initial_features, num_channels=1, dropout=0.1):
        super().__init__()

        # Channel widths of the four resolution levels.
        f1 = initial_features
        f2 = initial_features * 2
        f4 = initial_features * (2 ** 2)
        f8 = initial_features * (2 ** 3)

        # Encoder. Submodules are registered in the same order as their use
        # in forward(), which also fixes the parameter initialisation order.
        self.conv_down1 = SingleConv3x3(num_channels, f1, dropout)
        self.conv_down2 = SingleConv3x3(f1, f1, dropout)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_down3_4 = DoubleConv3x3(f1, f2, dropout)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_down5_6_7 = TripleConv3x3(f2, f4, dropout)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv_down8_9_10 = TripleConv3x3(f4, f8, dropout)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck at the lowest resolution.
        self.conv_down11_12_13 = TripleConv3x3(f8, f8, dropout)
        self.conv_down14_15_16 = TripleConv3x3(f8, f8, dropout)

        # Decoder. Each conv block takes the concatenation of the skip
        # features and the upsampled features, hence the summed input widths.
        self.upconcat1 = UpConcat(f8, f8)
        self.conv_up1_2_3 = TripleConv3x3(f8 + f8, f4, dropout)
        self.upconcat2 = UpConcat(f8, f4)
        self.conv_up4_5_6 = TripleConv3x3(f4 + f4, f2, dropout)
        self.upconcat3 = UpConcat(f4, f2)
        self.conv_up7_8 = DoubleConv3x3(f2 + f2, f1, dropout)
        self.upconcat4 = UpConcat(f2, f1)
        self.conv_up9_10 = DoubleConv3x3(f1 + f1, f1, dropout)

        self.final = SingleConv1x1(f1, dropout)

    def forward(self, inputs):
        """Run the encoder, bottleneck and decoder and return the final map.

        Spatial sizes that are multiples of 16 keep every skip connection
        aligned with its upsampled counterpart (four 2x poolings followed by
        four 2x upsamplings).
        """
        # Encoder: keep the pre-pooling features of each level for the skips.
        level1 = self.conv_down2(self.conv_down1(inputs))
        level2 = self.conv_down3_4(self.maxpool1(level1))
        level3 = self.conv_down5_6_7(self.maxpool2(level2))
        level4 = self.conv_down8_9_10(self.maxpool3(level3))

        # Bottleneck.
        bottom = self.conv_down11_12_13(self.maxpool4(level4))
        bottom = self.conv_down14_15_16(bottom)

        # Decoder: upsample, concatenate with the matching skip, convolve.
        decoded = self.conv_up1_2_3(self.upconcat1(bottom, level4))
        decoded = self.conv_up4_5_6(self.upconcat2(decoded, level3))
        decoded = self.conv_up7_8(self.upconcat3(decoded, level2))
        decoded = self.conv_up9_10(self.upconcat4(decoded, level1))

        return self.final(decoded)
class SingleConv1x1(nn.Module):
    """1x1 convolution mapping ``in_feat`` channels to one channel.

    The convolution is followed by ``nn.Dropout2d(dropout)``; no
    normalisation or activation is applied.

    Args:
        in_feat: Number of input channels.
        dropout: Probability passed to ``nn.Dropout2d``.
    """

    def __init__(self, in_feat, dropout):
        super().__init__()
        # NOTE(review): Dropout2d on a single-channel output drops the whole
        # output map of a sample while training -- confirm this is intended.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_feat, 1, kernel_size=1, stride=1, padding=0),
            nn.Dropout2d(dropout),
        )

    def forward(self, inputs):
        """Apply the 1x1 convolution (and dropout when training)."""
        return self.conv1(inputs)
class SingleConv3x3(nn.Module):
    """One 3x3 conv block: Conv2d -> BatchNorm2d -> Dropout2d -> LeakyReLU.

    With stride 1 and padding 1 the spatial size of the input is preserved.

    Args:
        in_feat: Number of input channels.
        out_feat: Number of output channels.
        dropout: Probability passed to ``nn.Dropout2d``.
    """

    def __init__(self, in_feat, out_feat, dropout):
        super().__init__()
        # The layer order inside the Sequential determines the state_dict
        # keys (conv1.0.*, conv1.1.*), so it must not be rearranged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_feat),
            nn.Dropout2d(dropout),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, inputs):
        """Apply the conv block to ``inputs``."""
        return self.conv1(inputs)
class DoubleConv3x3(nn.Module):
    """Two stacked 3x3 conv blocks.

    Each block is Conv2d -> Dropout2d -> BatchNorm2d -> LeakyReLU(0.2). The
    first block keeps ``in_feat`` channels; the second maps ``in_feat`` to
    ``out_feat``. Stride 1 and padding 1 preserve the spatial size.

    Args:
        in_feat: Number of input channels.
        out_feat: Number of output channels.
        dropout: Probability passed to every ``nn.Dropout2d``.
    """

    def __init__(self, in_feat, out_feat, dropout):
        super().__init__()
        self.conv1 = self._block(in_feat, in_feat, dropout)
        self.conv2 = self._block(in_feat, out_feat, dropout)

    @staticmethod
    def _block(in_channels, out_channels, dropout):
        """Build one Conv2d -> Dropout2d -> BatchNorm2d -> LeakyReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Dropout2d(dropout),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, inputs):
        """Apply both conv blocks in sequence."""
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs
class TripleConv3x3(nn.Module):
    """Three stacked 3x3 conv blocks.

    Each block is Conv2d -> Dropout2d -> BatchNorm2d -> LeakyReLU(0.2). The
    first two blocks keep ``in_feat`` channels; the third maps ``in_feat`` to
    ``out_feat``. Stride 1 and padding 1 preserve the spatial size.

    Args:
        in_feat: Number of input channels.
        out_feat: Number of output channels.
        dropout: Probability passed to every ``nn.Dropout2d``.
    """

    def __init__(self, in_feat, out_feat, dropout):
        super().__init__()
        self.conv1 = self._block(in_feat, in_feat, dropout)
        self.conv2 = self._block(in_feat, in_feat, dropout)
        self.conv3 = self._block(in_feat, out_feat, dropout)

    @staticmethod
    def _block(in_channels, out_channels, dropout):
        """Build one Conv2d -> Dropout2d -> BatchNorm2d -> LeakyReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Dropout2d(dropout),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, inputs):
        """Apply the three conv blocks in sequence."""
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        return outputs
class UpConcat(nn.Module):
    """Upsample decoder features and concatenate them with skip features.

    The incoming features pass through a 3x3 stride-1 transposed convolution
    (``out_feat`` -> ``out_feat`` channels, spatial size unchanged), are
    upsampled bilinearly, and are concatenated after the skip features along
    the channel dimension.

    Args:
        in_feat: Accepted for interface compatibility; not used when building
            the layers.
        out_feat: Channel count expected on ``inputs`` and kept by the
            transposed convolution.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(out_feat, out_feat, kernel_size=3,
                                         padding=1, stride=1, dilation=1,
                                         output_padding=0)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, inputs, down_outputs):
        """Return ``cat([down_outputs, upsampled(inputs)], dim=1)``.

        Args:
            inputs: Lower-resolution features to upsample.
            down_outputs: Skip features from the encoder.
        """
        features = self.deconv(inputs)
        outputs = self.up(features)

        # A 2x upsample only matches the skip when the encoder pooled an even
        # size. For odd sizes, resize straight to the skip resolution so that
        # the concatenation does not fail.
        skip_size = tuple(down_outputs.shape[2:])
        if tuple(outputs.shape[2:]) != skip_size:
            outputs = F.interpolate(features, size=skip_size,
                                    mode="bilinear", align_corners=True)

        return torch.cat([down_outputs, outputs], 1)
def compute_l1_norm(model, lambda1=0.5):
    """Return ``lambda1`` times the sum of the L1 norms of all parameters.

    Args:
        model: Module whose ``parameters()`` are penalised.
        lambda1: Scale applied to the summed L1 norm.

    Returns:
        A 0-dim float tensor. It lives on the device of the model's first
        parameter, or on the CPU when the model has no parameters.
    """
    params = list(model.parameters())

    # Take the device from the model rather than from a module-level global,
    # so the function works wherever the model lives.
    target_device = params[0].device if params else torch.device("cpu")
    l1_regularization = torch.zeros((), device=target_device, dtype=torch.float)

    for param in params:
        l1_regularization = l1_regularization + torch.norm(param, 1)

    return lambda1 * l1_regularization
def compute_generator_loss(Y, Y_pred, ROI, weight=1.0):
    """Return the weighted mean absolute error of ``Y_pred`` inside ``ROI``.

    The absolute error is multiplied element-wise by ``ROI``, summed, and
    divided by ``sum(ROI)``.

    Args:
        Y: Target tensor.
        Y_pred: Predicted tensor, broadcastable with ``Y``.
        ROI: Mask/weight tensor, broadcastable with ``Y``.
        weight: Scalar multiplier applied to the loss.

    Returns:
        A 0-dim tensor. When ``sum(ROI)`` is zero the divisor is replaced by
        one, so an empty ROI yields 0 instead of NaN.
    """
    masked_abs_error = torch.sum(torch.abs(ROI * (Y - Y_pred)))
    roi_total = torch.sum(ROI)
    # Guard only the exact-zero case; every other ROI divides by its true sum.
    divisor = torch.where(roi_total == 0, torch.ones_like(roi_total), roi_total)
    return weight * masked_abs_error / divisor
def run_training(models, epoch, device):
    """Run one pass over the training batch, alternating D and G phases.

    The pass starts in the discriminator phase. The discriminator accumulates
    gradients over ``minibatch_size`` consecutive examples (first half scored
    as real, second half as detached generator output) and then takes one
    optimizer step, after which the generator phase begins on the same
    example. The generator takes one optimizer step per example for
    ``minibatch_size`` examples, then the discriminator phase resumes.

    Args:
        models: Dict holding the ``"generator"`` and ``"discriminator"``
            modules.
        epoch: Epoch index; only used in the progress output.
        device: Device on which data, labels and the loss are created.
    """
    # Set optimizers
    # NOTE(review): the optimizers are rebuilt on every call, so Adam's
    # moment estimates do not persist between epochs -- confirm intended.
    optimizerG = optim.Adam(models["generator"].parameters(), lr=0.0001, weight_decay=0.0004)
    optimizerD = optim.Adam(models["discriminator"].parameters(), lr=0.0001, weight_decay=0.0004)

    # Define adversarial loss
    adversarial_loss = nn.BCELoss().to(device)

    # Initialize data batch for training (constant placeholder tensors)
    X = torch.ones(20, 1, 256, 256, device=device, dtype=torch.float)
    Y = torch.zeros(20, 1, 256, 256, device=device, dtype=torch.float)
    ROI = torch.ones(20, 1, 256, 256, device=device, dtype=torch.float)
    total_images = X.shape[0]

    # Labels for the adversarial loss; identical for every example.
    real_label = torch.full((1, 1), 1.0, dtype=torch.float, device=device)
    fake_label = torch.full((1, 1), 0.0, dtype=torch.float, device=device)

    # Mini batch bookkeeping
    g_minibatch_counter = 0
    d_minibatch_counter = 0
    minibatch_size = 4
    minibatch_status = "discriminator"

    for i in range(total_images):
        # Slice out a single example while keeping the batch dimension.
        X_batch = X[i:i + 1].to(device)
        Y_batch = Y[i:i + 1].to(device)
        ROI_batch = ROI[i:i + 1].to(device)

        # Run conversion
        Y_pred = models["generator"](X_batch) * ROI_batch

        #####################
        # Train discriminator
        #####################
        if minibatch_status == "discriminator":
            models["generator"].eval()
            models["discriminator"].train()

            # Gradients accumulate over the whole minibatch, so clear them
            # only at its start. This also discards gradients left on the
            # discriminator by the preceding generator phase.
            if d_minibatch_counter == 0:
                models["discriminator"].zero_grad()

            if d_minibatch_counter < minibatch_size // 2:
                # real
                # NOTE(review): the "real" sample is the generator input
                # X_batch, not the target Y_batch -- confirm this is intended.
                score_d_real = models["discriminator"](X_batch)
                loss_D = adversarial_loss(score_d_real, real_label)
            else:
                # fake (detached so no gradient reaches the generator)
                score_d_fake = models["discriminator"](Y_pred.detach())
                loss_D = adversarial_loss(score_d_fake, fake_label)
            loss_D.backward()

            # Update minibatch info
            d_minibatch_counter += 1
            if d_minibatch_counter == minibatch_size:
                optimizerD.step()
                minibatch_status = "generator"
                d_minibatch_counter = 0

            # Print on screen
            print(">>> Epoch %d (Discriminator) -- Example %d/%d -- Loss D = %.3f"
                  % (epoch, i, total_images, loss_D.item()))

        #####################
        # Train generator
        #####################
        # Deliberately `if`, not `elif`: the example that completes a
        # discriminator minibatch is also the first generator example.
        if minibatch_status == "generator":
            models["generator"].train()
            models["discriminator"].eval()

            # Set the generator gradients to zero
            models["generator"].zero_grad()

            # Compute generative intensity loss
            loss_G_intensity = compute_generator_loss(Y_batch, Y_pred, ROI_batch)

            # Compute l1 penalty
            loss_G_norm = compute_l1_norm(models["generator"], 0.0004).to(device)

            # Adversarial term: the generator is rewarded when the
            # discriminator scores its output as real.
            score_d_g = models["discriminator"](Y_pred)
            loss_D_G = adversarial_loss(score_d_g, real_label)
            loss_G = loss_G_intensity + loss_G_norm + (10.0 * loss_D_G)

            # Run optimizer
            loss_G.backward()
            optimizerG.step()

            # Update minibatch info
            g_minibatch_counter += 1
            if g_minibatch_counter == minibatch_size:
                minibatch_status = "discriminator"
                g_minibatch_counter = 0

            # Print on screen
            print(">>> Epoch %d (Generator) -- Example %d/%d -- Loss G = %.3f"
                  % (epoch, i, total_images, loss_G.item()))
# Define device: use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create models
models = {}
models["generator"] = Generator(32).to(device)

# Randomly initialised ResNet-152 from the installed torchvision package;
# this needs no network access and no deprecated `pretrained` argument.
models["discriminator"] = torchvision.models.resnet152()
# Replace the stem so the network accepts single-channel images.
models["discriminator"].conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Replace the classifier head with a single sigmoid score, as nn.BCELoss
# expects probabilities.
models["discriminator"].fc = nn.Sequential(
    nn.Linear(models["discriminator"].fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(512, 1),
    nn.Sigmoid(),
)
models["discriminator"] = models["discriminator"].to(device)

for epoch in range(100):
    run_training(models, epoch, device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment