Last active
July 17, 2020 11:42
-
-
Save PrimeF/2f49e6ffb5f0be279250907aca886220 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
---
# Training configuration: a U-Net that predicts the two LAB colour channels
# ('ab') from the single LAB lightness channel ('l').
optimizer:
  # Dotted path of the optimizer class; presumably resolved to a class by the
  # config loader (TODO confirm).
  class: "torch.optim.SGD"
  # Keyword arguments for the optimizer; train() calls
  # optimizer(model.parameters(), **optim_params).
  params:
    lr: 0.1
    momentum: 0.9
    nesterov: True
# Loss class; train() instantiates it with no arguments (loss_fn()).
loss: "torch.nn.MSELoss"
# Activation classes, resolved via resolve_class() in Model.__init__.
# inner_activation is used inside every encoder/decoder block; outer_activation
# is applied after the final 1x1 convolution.
inner_activation: "torch.nn.ReLU"
outer_activation: "torch.nn.Softsign"
# NOTE(review): presumably mapped to train()'s epochs / DataLoader batch_size /
# checkpoint_freq by the caller - verify.
num_epochs: 50
batch_size: 20
checkpoint_frequency: 1
# Number of encoder (and matching decoder) levels; `channels` must have at
# least this many entries because Model indexes channels[i] for i < num_layers.
num_layers: 5
# 1 input channel = the 'l' tensor; 2 output channels = the 'ab' target.
in_channels: 1
out_channels: 2
# Feature channels produced by each encoder level.
channels: [64, 128, 256, 512, 1024]
# 3x3 convolutions with stride 1 and padding 1 preserve the spatial size.
conv_kernel: 3
conv_stride: 1
conv_padding: 1
# Transposed convolutions; forward() passes an explicit output_size, so these
# only need to be able to reach each target size.
upconv_kernel: 3
upconv_stride: 2
upconv_padding: 1
# 2x2 max-pooling with stride 2 (ceil_mode=True in UNetDownBlock).
maxpool_kernel: 2
maxpool_stride: 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python standard libraries | |
import os | |
import glob | |
# Installed libraries | |
import numpy as np | |
from PIL import Image, ImageCms | |
import torch | |
from torch.utils.data import DataLoader | |
from torch.utils.data.dataset import Dataset | |
class ImageDataset(Dataset):
    """Dataset of images converted to CIE LAB and split into input/target channels.

    Each item is a dict with two float32 tensors created on ``device``:

    - ``'l'``:  shape (1, H, W), the LAB L channel divided by 255;
    - ``'ab'``: shape (2, H, W), the LAB a and b channels divided by 255.

    :param path: Directory searched (non-recursively) for image files.
    :param device: Device the returned tensors are created on.
    :param transforms: Optional callable applied to the RGB PIL image before the
        LAB conversion; it must return a PIL image because ImageCms operates on
        PIL images.
    :param patterns: Glob pattern(s), relative to ``path``, selecting the image
        files. Defaults to ``('*.jpg',)``, the original behaviour.
    """

    def __init__(self, path, device, transforms=None, patterns=('*.jpg',)):
        super(ImageDataset, self).__init__()
        self.img_paths = self._collect_paths(path, patterns)
        self.device = device
        self.transforms = transforms
        # Build the sRGB -> LAB colour transform once and reuse it for every item.
        self._rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
        self._lab_profile = ImageCms.createProfile(colorSpace='LAB')
        self.rgb2lab = ImageCms.buildTransform(
            inputProfile=self._rgb_profile,
            outputProfile=self._lab_profile,
            inMode='RGB',
            outMode='LAB'
        )

    @staticmethod
    def _collect_paths(path, patterns):
        """Return the sorted, de-duplicated files in ``path`` matching any pattern."""
        if isinstance(patterns, str):
            patterns = (patterns,)
        found = set()
        for pattern in patterns:
            found.update(glob.glob(os.path.join(path, pattern)))
        # glob order is filesystem-dependent; sorting keeps idx -> file stable.
        return sorted(found)

    @staticmethod
    def _lab_to_tensors(lab, device):
        """Split an (H, W, 3) LAB image/array into normalized 'l' and 'ab' tensors."""
        # float32 rather than numpy's default float64: autocast leaves float64
        # inputs uncast, and they do not match the model's float32 parameters.
        chw = np.asarray(lab, dtype=np.float32).transpose(2, 0, 1) / 255.0
        # torch.tensor copies, so the results are contiguous and own their memory.
        return {'l': torch.tensor(chw[:1], device=device),
                'ab': torch.tensor(chw[1:], device=device)}

    def __getitem__(self, idx):
        # Close the underlying file handle; convert() returns an independent image.
        with Image.open(self.img_paths[idx]) as img_file:
            img = img_file.convert('RGB')  # also strips any alpha channel
        if self.transforms:
            img = self.transforms(img)
        lab = ImageCms.applyTransform(img, self.rgb2lab)
        return self._lab_to_tensors(lab, self.device)

    def __len__(self):
        return len(self.img_paths)
class ImagePrefetcher:
    """Copy the next DataLoader batch to the GPU on a side stream while the
    current batch is being processed.

    Batches are expected to be dicts of tensors; each value is made contiguous
    and copied with ``.cuda(non_blocking=True)`` on a dedicated CUDA stream.
    Call :meth:`next` until it returns ``None``, or iterate over the instance.
    """

    def __init__(self, data_loader: DataLoader):
        self._data_loader = iter(data_loader)
        self._stream = torch.cuda.Stream()
        self._next_batch = None
        self.preload()

    def preload(self):
        """Fetch the next batch and start its host-to-device copy on the side stream."""
        try:
            batch = next(self._data_loader)
        except StopIteration:
            self._next_batch = None
            return
        with torch.cuda.stream(self._stream):
            self._next_batch = {key: value.contiguous().cuda(non_blocking=True)
                                for key, value in batch.items()}

    def next(self):
        """Return the prefetched batch (or ``None`` when exhausted) and start the next copy."""
        batch = self._next_batch
        if batch is None:
            # Nothing was copied, so there is no stream work to wait for.
            return None
        current = torch.cuda.current_stream()
        # Make the compute stream wait until the side-stream copy has finished.
        current.wait_stream(self._stream)
        for tensor in batch.values():
            # The tensors were allocated on the side stream but are used on the
            # current one; tell the caching allocator so their memory is not
            # reused before that work completes.
            tensor.record_stream(current)
        self.preload()
        return batch

    def __iter__(self):
        """Yield prefetched batches until the underlying loader is exhausted."""
        batch = self.next()
        while batch is not None:
            yield batch
            batch = self.next()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Installed libraries | |
import torch | |
import torch.nn as nn | |
from torch.cuda.amp import autocast | |
# Custom libraries | |
from utils import resolve_class | |
class Model(nn.Module):
    """U-Net style encoder/decoder built from keyword arguments.

    Required kwargs: num_layers, inner_activation, outer_activation,
    in_channels, out_channels, channels, conv_kernel, conv_stride,
    conv_padding, upconv_kernel, upconv_stride, upconv_padding,
    maxpool_kernel, maxpool_stride. The two activation entries are passed to
    ``resolve_class``; the resulting classes are instantiated inside the blocks.

    The encoder holds ``num_layers`` UNetDownBlocks. The decoder holds
    ``num_layers`` UNetUpBlocks (deepest first) followed by a 1x1
    UNetOutputBlock that maps ``channels[0]`` to ``out_channels``.

    :raises ValueError: if ``num_layers < 1`` or ``channels`` has fewer than
        ``num_layers`` entries.
    """

    def __init__(self, **kwargs):
        super(Model, self).__init__()
        # Parse parameters
        self.num_layers = kwargs['num_layers']
        self.inner_activation = resolve_class(kwargs['inner_activation'])
        self.outer_activation = resolve_class(kwargs['outer_activation'])
        self.in_channels = kwargs['in_channels']
        self.out_channels = kwargs['out_channels']
        self.channels = kwargs['channels']
        self.conv_kernel = kwargs['conv_kernel']
        self.conv_stride = kwargs['conv_stride']
        self.conv_padding = kwargs['conv_padding']
        self.upconv_kernel = kwargs['upconv_kernel']
        self.upconv_stride = kwargs['upconv_stride']
        self.upconv_padding = kwargs['upconv_padding']
        self.maxpool_kernel = kwargs['maxpool_kernel']
        self.maxpool_stride = kwargs['maxpool_stride']
        if self.num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {self.num_layers}")
        if len(self.channels) < self.num_layers:
            raise ValueError(
                f"channels has {len(self.channels)} entries but num_layers is "
                f"{self.num_layers}")

        # Define architecture. Blocks are created in the same interleaved order
        # as before (enc0, dec0, enc1, dec1, ...), so seeded initialisation is unchanged.
        encoder, decoder = [], []
        # Running channel count of the previous level. A local variable is used
        # so that self.in_channels keeps the configured input channel count.
        prev_channels = self.in_channels
        for i in range(self.num_layers):
            is_first_layer = (i == 0)
            is_last_layer = (i == self.num_layers - 1)
            encoder.append(UNetDownBlock(prev_channels,
                                         self.channels[i],
                                         self.inner_activation,
                                         self.conv_kernel,
                                         self.conv_stride,
                                         self.conv_padding,
                                         self.maxpool_kernel,
                                         self.maxpool_stride))
            # Except at the deepest level, the up block's input is the previous
            # decoder output concatenated with the encoder skip (2x channels).
            decoder.append(UNetUpBlock(self.channels[i] if is_last_layer else 2 * self.channels[i],
                                       self.channels[i] if is_first_layer else prev_channels,
                                       self.inner_activation,
                                       self.conv_kernel,
                                       self.conv_stride,
                                       self.conv_padding,
                                       self.upconv_kernel,
                                       self.upconv_stride,
                                       self.upconv_padding))
            prev_channels = self.channels[i]
        # The decoder runs from the deepest level back up to the shallowest.
        decoder.reverse()
        # Output block (1x1 convolution)
        decoder.append(UNetOutputBlock(self.channels[0], self.out_channels, self.outer_activation))
        self.encoder = nn.ModuleList(encoder)
        self.decoder = nn.ModuleList(decoder)

    @autocast()
    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        :param x: 4-D tensor; dim 1 is presumably ``in_channels`` (TODO confirm
            against callers). Dims 2 and 3 are the spatial size.
        :returns: Tensor of shape (x.size(0), out_channels, x.size(2), x.size(3)).
        """
        h = x
        # Encoder: remember each level's input size and output (the skip tensor).
        sizes, skips = [], []
        for encoder_layer in self.encoder:
            sizes.append(h.size())
            h = encoder_layer(h)
            skips.append(h)
        sizes.reverse()
        skips.reverse()

        # The last up block restores the input's spatial size with channels[0] channels.
        output_size = torch.Size([x.size(0), self.channels[0], x.size(2), x.size(3)])
        for i, decoder_layer in enumerate(self.decoder):
            if i == self.num_layers:
                # Final 1x1 output block: no skip connection, no output_size.
                h = decoder_layer(h)
                continue
            if i > 0:
                # skips[0] is the bottleneck output, which is already `h`.
                h = torch.cat([skips[i], h], dim=1)
            size = output_size if i == self.num_layers - 1 else sizes[i]
            h = decoder_layer(h, output_size=size)
        return h
class UNetDownBlock(nn.Module):
    """Encoder stage: two activated convolutions, then max-pooling and batch norm.

    Data flows conv1 -> activation -> conv2 -> activation -> MaxPool2d
    (ceil_mode=True) -> BatchNorm2d. The output has ``out_channels`` channels.
    ``activation`` is a class; a single instance is shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, activation,
                 conv_kernel, conv_stride, conv_padding,
                 maxpool_kernel, maxpool_stride):
        super().__init__()
        conv_geometry = (conv_kernel, conv_stride, conv_padding)
        self.activation = activation()
        self.conv1 = nn.Conv2d(in_channels, out_channels, *conv_geometry)
        self.conv2 = nn.Conv2d(out_channels, out_channels, *conv_geometry)
        self.mp = nn.MaxPool2d(maxpool_kernel, maxpool_stride, ceil_mode=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    @autocast()
    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.activation(conv(features))
        pooled = self.mp(features)
        return self.bn1(pooled)
class UNetUpBlock(nn.Module):
    """Decoder stage: activated transposed convolution, two activated convolutions, batch norm.

    The transposed convolution maps ``in_channels`` to ``out_channels`` and is
    given an explicit ``output_size`` at call time. The two following
    convolutions keep ``out_channels``. ``activation`` is a class; one instance
    is shared by all three layers.
    """

    def __init__(self, in_channels, out_channels, activation,
                 conv_kernel, conv_stride, conv_padding,
                 upconv_kernel, upconv_stride, upconv_padding):
        super().__init__()
        self.activation = activation()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         upconv_kernel, upconv_stride, upconv_padding)
        conv_geometry = (conv_kernel, conv_stride, conv_padding)
        self.conv1 = nn.Conv2d(out_channels, out_channels, *conv_geometry)
        self.conv2 = nn.Conv2d(out_channels, out_channels, *conv_geometry)
        self.bn1 = nn.BatchNorm2d(out_channels)

    @autocast()
    def forward(self, x, output_size):
        upsampled = self.upconv(x, output_size=output_size)
        features = self.activation(upsampled)
        for conv in (self.conv1, self.conv2):
            features = self.activation(conv(features))
        return self.bn1(features)
class UNetOutputBlock(nn.Module):
    """Final projection: a 1x1 convolution to ``out_channels`` followed by ``activation``.

    ``activation`` is a class and is instantiated once here.
    """

    def __init__(self, in_channels, out_channels, activation):
        super().__init__()
        self.activation = activation()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    @autocast()
    def forward(self, x):
        projected = self.conv(x)
        return self.activation(projected)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(data_loader_, model_, optimizer, optim_params, loss_fn, epochs, checkpoint_freq, model_dir_, checkpoint=None):
    """
    Trains the model with AMP gradient scaling, prefetching batches to the GPU.

    :param data_loader_: DataLoader yielding dicts with 'l' (model input) and
        'ab' (target) tensors
    :param model_: Model to train
    :param optimizer: Optimizer class; instantiated here as
        optimizer(model_.parameters(), **optim_params)
    :param optim_params: Keyword arguments for the optimizer (None = no extra arguments)
    :param loss_fn: Loss class; instantiated with no arguments and moved to the GPU
    :param epochs: Number of epochs to train
    :param checkpoint_freq: Frequency to save model checkpoints (in epochs).
        NOTE(review): not used by this function - checkpointing is not implemented here.
    :param model_dir_: Checkpoint directory. NOTE(review): currently unused.
    :param checkpoint: NOTE(review): currently unused.
    :returns:
        - Trained model (the same object as ``model_``)
        - List with the mean training loss of each epoch (NaN if an epoch had no batches)
    """
    # Define loss function
    loss_fn = loss_fn().cuda()
    # Define optimizer
    optimizer = optimizer(model_.parameters(), **(optim_params or {}))
    # Gradient scaler for mixed precision. Reached through torch so this does
    # not depend on a separate `amp` import.
    scaler = torch.cuda.amp.GradScaler()
    # Set model in train mode
    model_.train()
    epoch_losses = []
    for _ in trange(epochs, desc='epochs'):
        prefetcher = ImagePrefetcher(data_loader=data_loader_)
        running_loss = 0.0
        num_batches = 0
        batch = prefetcher.next()
        # Loop over batches until the prefetcher is exhausted
        while batch is not None:
            inputs = batch['l']
            targets = batch['ab']
            # Forward propagation
            pred = model_(inputs)
            # Compute loss
            loss = loss_fn(pred, targets)
            # Backpropagation with AMP loss scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Apply gradients (skipped by the scaler if they contain inf/NaN)
            scaler.step(optimizer)
            # Update the scale for the next iteration
            scaler.update()
            # Accumulate on the device in float32: calling .item() every step
            # would force a GPU sync, and a float16 sum could overflow.
            running_loss = running_loss + loss.detach().float()
            num_batches += 1
            # Advance and prefetch the batch for the next iteration
            batch = prefetcher.next()
        epoch_losses.append(float(running_loss) / num_batches if num_batches else float('nan'))
        torch.cuda.empty_cache()
    return model_, epoch_losses
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment