---
optimizer:
  class: "torch.optim.SGD"
  params:
    lr: 0.1
    momentum: 0.9
    nesterov: True
loss: "torch.nn.MSELoss"
inner_activation: "torch.nn.ReLU"
outer_activation: "torch.nn.Softsign"
num_epochs: 50
batch_size: 20
checkpoint_frequency: 1
num_layers: 5
in_channels: 1
out_channels: 2
channels: [64, 128, 256, 512, 1024]
conv_kernel: 3
conv_stride: 1
conv_padding: 1
upconv_kernel: 3
upconv_stride: 2
upconv_padding: 1
maxpool_kernel: 2
maxpool_stride: 2
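
The YAML above drives both the optimizer and the U-Net architecture. A minimal sketch of how such a config might be loaded, assuming PyYAML and a file name of "config.yaml" (the file name is an assumption, not part of the gist):

# Hedged sketch: load the YAML config above into a plain dict.
# Assumes PyYAML is installed and the config is saved as "config.yaml".
import yaml

with open('config.yaml') as f:
    config = yaml.safe_load(f)

print(config['optimizer']['class'])   # "torch.optim.SGD"
print(config['optimizer']['params'])  # {'lr': 0.1, 'momentum': 0.9, 'nesterov': True}
print(config['channels'])             # [64, 128, 256, 512, 1024]

The dotted class paths ("torch.optim.SGD", "torch.nn.MSELoss", ...) are turned into actual classes by the resolve_class helper sketched after the model code below.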
# Python standard libraries
import os
import glob

# Installed libraries
import numpy as np
from PIL import Image, ImageCms
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset


class ImageDataset(Dataset):
    def __init__(self, path, device, transforms=None):
        super(ImageDataset, self).__init__()
        self.img_paths = glob.glob(os.path.join(path, '*.jpg'))
        self.device = device
        self.transforms = transforms
        # Create input and output colour profiles
        self._rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
        self._lab_profile = ImageCms.createProfile(colorSpace='LAB')
        # Create a transform object from the input and output profiles
        self.rgb2lab = ImageCms.buildTransform(
            inputProfile=self._rgb_profile,
            outputProfile=self._lab_profile,
            inMode='RGB',
            outMode='LAB'
        )

    def __getitem__(self, idx):
        # Read PIL image and strip alpha channel if present
        img = Image.open(self.img_paths[idx]).convert('RGB')
        # Apply torch dataset transformation
        if self.transforms:
            img = self.transforms(img)
        # Convert to LAB
        lab = ImageCms.applyTransform(img, self.rgb2lab)
        # Convert to ndarray as float32, so the tensors match the
        # model's default parameter dtype (uint8 would stay float64 after /255.)
        lab = np.array(lab, dtype=np.float32)
        # Convert from (H x W x C) to (C x H x W)
        lab = np.transpose(lab, axes=[2, 0, 1])
        # Normalize
        lab = lab / 255.
        # Get grayscale L channel
        l = lab[:1, :, :]
        # Get colour AB channels
        ab = lab[1:, :, :]
        return {'l': torch.tensor(l, device=self.device),
                'ab': torch.tensor(ab, device=self.device)}

    def __len__(self):
        return len(self.img_paths)


class ImagePrefetcher:
    def __init__(self, data_loader: DataLoader):
        self._data_loader = iter(data_loader)
        self._stream = torch.cuda.Stream()
        self._next_batch = None
        self.preload()

    def preload(self):
        try:
            self._next_batch = next(self._data_loader)
        except StopIteration:
            self._next_batch = None
            return
        # Copy the next batch to the GPU on a side stream so the transfer
        # overlaps with compute running on the default stream
        with torch.cuda.stream(self._stream):
            self._next_batch = {u: v.contiguous().cuda(non_blocking=True)
                                for u, v in self._next_batch.items()}

    def next(self):
        torch.cuda.current_stream().wait_stream(self._stream)
        batch = self._next_batch
        self.preload()
        return batch
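
A minimal usage sketch for the dataset and prefetcher (the image directory, batch size, and worker count are assumptions, not part of the gist):

# Hedged sketch: wire the dataset, loader, and prefetcher together.
dataset = ImageDataset(path='./train_images', device='cpu')
loader = DataLoader(dataset, batch_size=20, shuffle=True,
                    num_workers=4, pin_memory=True)

prefetcher = ImagePrefetcher(data_loader=loader)
batch = prefetcher.next()
while batch is not None:
    l, ab = batch['l'], batch['ab']  # both already on the GPU
    batch = prefetcher.next()

Note that the dataset is created with device='cpu' here: the prefetcher, not the dataset, moves the tensors to the GPU, and its non_blocking=True copies only overlap with compute when the source tensors sit in pinned host memory (hence pin_memory=True).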
# Installed libraries
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

# Custom libraries
from utils import resolve_class


class Model(nn.Module):
    def __init__(self, **kwargs):
        super(Model, self).__init__()
        # Parse parameters
        self.num_layers = kwargs['num_layers']
        self.inner_activation = resolve_class(kwargs['inner_activation'])
        self.outer_activation = resolve_class(kwargs['outer_activation'])
        self.in_channels = kwargs['in_channels']
        self.out_channels = kwargs['out_channels']
        self.channels = kwargs['channels']
        self.conv_kernel = kwargs['conv_kernel']
        self.conv_stride = kwargs['conv_stride']
        self.conv_padding = kwargs['conv_padding']
        self.upconv_kernel = kwargs['upconv_kernel']
        self.upconv_stride = kwargs['upconv_stride']
        self.upconv_padding = kwargs['upconv_padding']
        self.maxpool_kernel = kwargs['maxpool_kernel']
        self.maxpool_stride = kwargs['maxpool_stride']
        # Define architecture
        self.encoder, self.decoder = [], []
        for i in range(0, self.num_layers):
            is_first_layer = (i == 0)
            is_last_layer = (i == (self.num_layers - 1))
            self.encoder.append(UNetDownBlock(self.in_channels,
                                              self.channels[i],
                                              self.inner_activation,
                                              self.conv_kernel,
                                              self.conv_stride,
                                              self.conv_padding,
                                              self.maxpool_kernel,
                                              self.maxpool_stride))
            self.decoder.append(UNetUpBlock(self.channels[i] if is_last_layer else 2 * self.channels[i],
                                            self.channels[i] if is_first_layer else self.in_channels,
                                            self.inner_activation,
                                            self.conv_kernel,
                                            self.conv_stride,
                                            self.conv_padding,
                                            self.upconv_kernel,
                                            self.upconv_stride,
                                            self.upconv_padding))
            # Set input channels for next layer
            self.in_channels = self.channels[i]
        # Reverse decoder so the deepest block comes first
        self.decoder = self.decoder[::-1]
        # Output block (1x1 convolution)
        self.decoder.append(UNetOutputBlock(self.channels[0], self.out_channels, self.outer_activation))
        # Create PyTorch module lists so the parameters are registered
        self.encoder = nn.ModuleList(self.encoder)
        self.decoder = nn.ModuleList(self.decoder)
    @autocast()
    def forward(self, x):
        input = x
        # Encoder
        sizes = []
        skip = []
        for encoder_layer in self.encoder:
            sizes.append(input.size())
            input = encoder_layer(input)
            skip.append(input)
        # Reverse sizes and skip lists for the decoder
        sizes = sizes[::-1]
        skip = skip[::-1]
        # Decoder
        output_size = torch.Size([x.size()[0], self.channels[0], x.size()[2], x.size()[3]])
        for i, decoder_layer in enumerate(self.decoder):
            is_first_layer = (i == 0)
            is_penultimate_layer = (i == (self.num_layers - 1))
            is_last_layer = (i == self.num_layers)
            if not is_last_layer:
                size = output_size if is_penultimate_layer else sizes[i]
            # Skip connection
            if (not is_first_layer) and (not is_last_layer):
                from_encoder = skip[i]
                input = torch.cat([from_encoder, input], dim=1)
            # Forward step
            if not is_last_layer:
                input = decoder_layer(input, output_size=size)
            else:
                input = decoder_layer(input)
        return input
class UNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation,
                 conv_kernel, conv_stride, conv_padding,
                 maxpool_kernel, maxpool_stride):
        super(UNetDownBlock, self).__init__()
        # Activation function
        self.activation = activation()
        # Double convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               conv_kernel, conv_stride, conv_padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               conv_kernel, conv_stride, conv_padding)
        # Max pooling
        self.mp = nn.MaxPool2d(maxpool_kernel, maxpool_stride, ceil_mode=True)
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(out_channels)

    @autocast()
    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.mp(x)
        x = self.bn1(x)
        return x


class UNetUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation,
                 conv_kernel, conv_stride, conv_padding,
                 upconv_kernel, upconv_stride, upconv_padding):
        super(UNetUpBlock, self).__init__()
        # Activation function
        self.activation = activation()
        # Up-convolution
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         upconv_kernel, upconv_stride, upconv_padding)
        # Double convolution
        self.conv1 = nn.Conv2d(out_channels, out_channels,
                               conv_kernel, conv_stride, conv_padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               conv_kernel, conv_stride, conv_padding)
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(out_channels)

    @autocast()
    def forward(self, x, output_size):
        x = self.activation(self.upconv(x, output_size=output_size))
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.bn1(x)
        return x


class UNetOutputBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation):
        super(UNetOutputBlock, self).__init__()
        # Activation function
        self.activation = activation()
        # 1x1 convolution
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    @autocast()
    def forward(self, x):
        x = self.activation(self.conv(x))
        return x
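
resolve_class is imported from a utils module the gist does not include. A plausible sketch, assuming it simply turns a dotted-path string from the config (e.g. "torch.nn.ReLU") into the class object it names:

# Hedged sketch of the unshown utils.resolve_class helper.
# Assumption: it maps a dotted path string to the class it names.
import importlib

def resolve_class(dotted_path):
    # Split "torch.nn.ReLU" into module "torch.nn" and attribute "ReLU"
    module_name, _, class_name = dotted_path.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

With that in place, resolve_class("torch.optim.SGD") returns the torch.optim.SGD class itself, ready to be instantiated with the params block from the config.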
# Installed libraries
import torch
from torch.cuda import amp
from tqdm import trange

# Custom libraries (module name "dataset" is an assumption; the gist does not show its file names)
from dataset import ImagePrefetcher


def train(data_loader_, model_, optimizer, optim_params, loss_fn, epochs, checkpoint_freq, model_dir_, checkpoint=None):
    """
    Trains the model.
    :param data_loader_: DataLoader to retrieve the training data
    :param model_: Model to train
    :param optimizer: Optimizer class (e.g. torch.optim.SGD)
    :param optim_params: Keyword arguments passed to the optimizer (learning rate, momentum, ...)
    :param loss_fn: Loss function class
    :param epochs: Number of epochs to train
    :param checkpoint_freq: Frequency to save model checkpoints (in epochs)
    :param model_dir_: Directory to save model checkpoints to
    :param checkpoint: Optional checkpoint to resume training from
    :returns:
        - Trained model
        - Final training loss
    """
    # Define loss function
    loss_fn = loss_fn().cuda()
    # Define optimizer
    optimizer = optimizer(model_.parameters(), **optim_params)
    # Set gradient scaler for mixed precision
    scaler = amp.GradScaler()
    # Set model in train mode
    model_.train()
    global_step = 0
    loss = None
    for t in trange(epochs, desc='epochs'):
        prefetcher = ImagePrefetcher(data_loader=data_loader_)
        batch = prefetcher.next()
        # Loop over batches
        while batch is not None:
            inputs = batch['l']
            outputs = batch['ab']
            # Forward propagation
            pred = model_(inputs)
            # Compute loss
            loss = loss_fn(pred, outputs)
            # Backpropagation
            optimizer.zero_grad()
            # AMP backward pass
            scaler.scale(loss).backward()
            # Apply gradients
            scaler.step(optimizer)
            # Update the scale for the next iteration
            scaler.update()
            # Advance and prefetch batch for next iteration
            global_step += 1
            batch = prefetcher.next()
        torch.cuda.empty_cache()
    return model_, loss
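
checkpoint_freq, model_dir_, and checkpoint are accepted but never used in the loop above, so checkpointing presumably lives elsewhere in the project. A minimal sketch of how the pieces might be wired together, with the kind of checkpoint save those parameters suggest (the file and directory names here are assumptions):

# Hedged sketch: wire config, data, model, and training loop together.
# Assumes config.yaml, ./train_images, and ./checkpoints exist, and that
# ImageDataset, DataLoader, Model, resolve_class, and train are importable.
import os
import yaml
import torch

with open('config.yaml') as f:
    config = yaml.safe_load(f)

dataset = ImageDataset(path='./train_images', device='cpu')
loader = DataLoader(dataset, batch_size=config['batch_size'],
                    shuffle=True, pin_memory=True)

model = Model(**config).cuda()
model, final_loss = train(
    data_loader_=loader,
    model_=model,
    optimizer=resolve_class(config['optimizer']['class']),
    optim_params=config['optimizer']['params'],
    loss_fn=resolve_class(config['loss']),
    epochs=config['num_epochs'],
    checkpoint_freq=config['checkpoint_frequency'],
    model_dir_='./checkpoints'
)

# The kind of save checkpoint_freq/model_dir_ hint at (name is hypothetical):
torch.save(model.state_dict(), os.path.join('./checkpoints', 'model_final.pt'))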