# Parameterization code from:
# https://github.com/ProGamerGov/dream-creator
# And from my and Ludwig's work on PyTorch's Captum:
# https://github.com/ProGamerGov/captum/tree/optim-wip
# https://github.com/ludwigschubert
# Test params:
# python neural_style.py -optimizer adam -backend cudnn -cudnn_autotune
# -learning_rate 0.0025 -init image -content_weight 100 -style_weight 6000 -normalize_gradients -style_image wave.jpg
import os
import copy
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from CaffeLoader import loadCaffemodel, ModelParallel

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu to 'c'", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='adam')
parser.add_argument("-learning_rate", type=float, default=0.024)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
params = parser.parse_args()

Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images

def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess_functional(preprocess(image, style_size).type(dtype))
        style_images_caffe.append(img_caffe)
    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"
        # Normalize the style blending weights so they sum to 1
        style_blend_sum = 0.0
        for i, blend_weight in enumerate(style_blend_weights):
            style_blend_weights[i] = float(blend_weight)
            style_blend_sum += style_blend_weights[i]
        for i, blend_weight in enumerate(style_blend_weights):
            style_blend_weights[i] = blend_weight / style_blend_sum
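    # Example: "-style_blend_weights 3,7" with two style images yields
    # normalized weights [0.3, 0.7].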
    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)
    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(preprocess_functional(content_image.clone().detach()))

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'
    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i + 1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()

    # Set up a color- and spatially-decorrelated parameterization of the image
    # for the optimizer to optimize
    squash_func = lambda x: torch.sigmoid(x)  # alternative: lambda x: x.clamp(0, 1)
    img_param = ImageParam(init=img, squash_func=squash_func).to(backward_device)
    preprocess_tensor = TransformLayer()  # Only needed for the optimization step, not for saving.
    robust_transforms = torch.nn.Sequential(
        # torch.nn.ReflectionPad2d(16),
        # RandomSpatialJitter(16),
        # RandomScale(scale=(1, 0.975, 1.025, 0.95, 1.05)),
        torchvision.transforms.RandomRotation(degrees=(-1, 1)),
        # RandomSpatialJitter(8),
        # CenterCrop(16),
    )
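    # A light robustness-transform stack in the spirit of Lucid / Captum's
    # optim-wip transforms; the commented-out entries assume the helper classes
    # defined further below. To apply it during optimization, uncomment the
    # robust_transforms(img) call in feval().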
    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img_param().clone().detach())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)
            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        img = img_param()
        img = preprocess_tensor(img)
        # img = robust_transforms(img)
        net(img)
        loss = 0
        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)
        loss.backward()
        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)
        return loss

    optimizer, loopVal = setup_optimizer(img_param)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)

# Configure the optimizer
def setup_optimizer(img_param):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS(img_param.parameters(), **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam(img_param.parameters(), lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
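
# Note: L-BFGS runs all of its iterations inside a single optimizer.step() call
# (via max_iter above), so loopVal is 1; Adam performs one iteration per step()
# call, so the outer while loop in main() drives num_iterations steps.
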
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True
        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda()
            setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor
    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device

def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."
    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net

# Load and resize an image before passing it to a model. The Caffe-style
# preprocessing (rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel) is applied separately by preprocess_functional()
# or TransformLayer below, so that it can run on the parameterized tensor.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    input_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    tensor = input_transforms(image).unsqueeze(0)
    return tensor
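
# Example: preprocess("examples/inputs/tubingen.jpg", 512) resizes the image so
# its longer side is 512 pixels and returns a 1x3xHxW float tensor in [0, 1].
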
# Undo the above preprocessing.
def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu()
    # output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image

# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')

# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")

# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for i in content_losses:
        i.strength = i.strength / max(i.target.size())
    for i in style_losses:
        i.strength = i.strength / max(i.target.size())

# Scale gradients in the backward pass
class ScaleGradients(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, strength):
        ctx.strength = strength
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
        return grad_input * ctx.strength * ctx.strength, None

# Define an nn Module to compute content loss
class ContentLoss(nn.Module):
    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.normalize = normalize

    def forward(self, input):
        if self.mode == 'loss':
            loss = self.crit(input, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = loss * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input

class GramMatrix(nn.Module):
    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
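
# Sanity check (hypothetical shapes): for a 1x64x32x32 feature map, GramMatrix
# returns a 64x64 matrix whose (i, j) entry is the dot product of flattened
# channels i and j. Note the view(C, H * W) assumes batch size 1; StyleLoss
# then divides the result by input.nelement().
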
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):
    def __init__(self, strength, normalize):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None
        self.normalize = normalize

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            loss = self.crit(self.G, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = self.strength * loss
        return input

class TVLoss(nn.Module):
    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
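
# This is anisotropic (L1) total variation: the sum of absolute differences
# between vertically and horizontally adjacent pixels, which penalizes
# high-frequency noise in the generated image.
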
class FFTImage(nn.Module):
    """Parameterize an image using inverse real 2D FFT"""

    def __init__(
        self,
        size=None,
        channels: int = 3,
        batch: int = 1,
        init=None,
    ) -> None:
        super().__init__()
        if init is None:
            assert size is not None and len(size) == 2
            self.size = size
        else:
            assert init.dim() == 3 or init.dim() == 4
            self.size = (
                (init.size(1), init.size(2))
                if init.dim() == 3
                else (init.size(2), init.size(3))
            )

        # Fall back to CPU when no init tensor is given (init.device would fail)
        device = init.device if init is not None else torch.device("cpu")
        frequencies = FFTImage.rfft2d_freqs(*self.size).to(device)
        scale = 1.0 / torch.max(
            frequencies,
            torch.full_like(frequencies, 1.0 / (max(self.size[0], self.size[1]))),
        )
        scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2))
        spectrum_scale = scale[None, :, :, None]
        self.register_buffer("spectrum_scale", spectrum_scale)

        if init is None:
            coeffs_shape = (channels, self.size[0], self.size[1] // 2 + 1, 2)
            random_coeffs = torch.randn(
                coeffs_shape
            )  # names=["C", "H_f", "W_f", "complex"]
            fourier_coeffs = random_coeffs / 50
        else:
            # torch.rfft is the pre-PyTorch-1.8 FFT API
            fourier_coeffs = torch.rfft(init, signal_ndim=2)
            # Trim the scale buffer if the input produced fewer frequency columns
            w = (
                spectrum_scale.size(2) - fourier_coeffs.size(3)
                if init.dim() == 4
                else spectrum_scale.size(2) - fourier_coeffs.size(2)
            )
            self.spectrum_scale = spectrum_scale[:, :, w:, :] if w > 0 else self.spectrum_scale
            fourier_coeffs = fourier_coeffs / self.spectrum_scale

        self.fourier_coeffs = nn.Parameter(fourier_coeffs)
    @staticmethod
    def rfft2d_freqs(height: int, width: int) -> torch.Tensor:
        """Computes 2D spectrum frequencies."""
        fy = FFTImage.pytorch_fftfreq(height)[:, None]
        # on odd input dimensions we need to keep one additional frequency
        wadd = 2 if width % 2 == 1 else 1
        fx = FFTImage.pytorch_fftfreq(width)[: width // 2 + wadd]
        return torch.sqrt((fx * fx) + (fy * fy))

    @staticmethod
    def pytorch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
        """PyTorch version of np.fft.fftfreq"""
        results = torch.empty(v)
        s = (v - 1) // 2 + 1
        results[:s] = torch.arange(0, s)
        results[s:] = torch.arange(-(v // 2), 0)
        return results * (1.0 / (v * d))
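
    # Sanity check mirroring NumPy:
    #   np.fft.fftfreq(4)           -> [ 0.  ,  0.25, -0.5 , -0.25]
    #   FFTImage.pytorch_fftfreq(4) -> tensor([ 0.0000,  0.2500, -0.5000, -0.2500])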
    def forward(self) -> torch.Tensor:
        h, w = self.size
        scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
        output = torch.irfft(scaled_spectrum, signal_ndim=2)[:, :, :h, :w]
        return output

# Parameterized input image
class ImageParam(nn.Module):
    def __init__(self, init, squash_func, decorrelate_color=True):
        super(ImageParam, self).__init__()
        self.transform = self.get_color_matrix().to(init.device)
        self.squash_func = squash_func
        self.decorrelate_color = decorrelate_color
        if self.decorrelate_color:
            init = self.forward_image(init, inverse=True)  # Decorrelate content_image colors
        self.image = FFTImage(init=init)

    def get_color_matrix(self):
        transform = torch.Tensor([[0.26, 0.09, 0.02],
                                  [0.27, 0.00, -0.05],
                                  [0.27, -0.09, 0.03]])
        transform = transform / torch.max(torch.norm(transform, dim=0))
        return transform
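
    # These coefficients appear to be the color correlation (KLT) matrix used
    # for color decorrelation in tensorflow/lucid and Captum's optim-wip
    # branch, normalized here by its largest column norm.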
    def forward_image(self, x, inverse=False):
        x = x.refine_names("B", "C", "H", "W")
        h, w = x.size("H"), x.size("W")
        flat = x.flatten(("H", "W"), "spatials")
        if inverse:
            correct = torch.inverse(self.transform) @ flat
        else:
            correct = self.transform @ flat
        chw = correct.unflatten("spatials", (("H", h), ("W", w))).rename(None)
        return chw

    def forward(self):
        image = self.image()
        if self.decorrelate_color:
            image = self.forward_image(image)
        return self.squash_func(image)

# Preprocess input after decorrelation
class TransformLayer(torch.nn.Module):
    def __init__(self, mean=[1, 1, 1], device='cpu'):
        super(TransformLayer, self).__init__()
        self.input_mean = torch.as_tensor(mean).view(3, 1, 1).to(device)  # Currently unused

    def forward(self, x):
        assert x.dim() == 4
        x = x[:, [2, 1, 0]]  # RGB to BGR
        x = x * 255  # Scale input range
        return x - torch.as_tensor([103.939, 116.779, 123.68], device=x.device).view(3, 1, 1)

# Transform images for the model
def preprocess_functional(tensor):
    tensor = tensor[:, [2, 1, 0]]
    return (tensor * 255) - torch.as_tensor([103.939, 116.779, 123.68]).view(3, 1, 1).to(tensor.device)
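
# TransformLayer and preprocess_functional apply the same Caffe-style VGG
# preprocessing: a [0, 1] RGB tensor becomes a mean-subtracted [0, 255]-range
# BGR tensor. For example, a pure-white pixel (1, 1, 1) maps to
# (255 - 103.939, 255 - 116.779, 255 - 123.68) in BGR order.
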
def logit(p: torch.Tensor, epsilon: float = 1e-4) -> torch.Tensor:
    p = torch.clamp(p, min=epsilon, max=1.0 - epsilon)
    assert p.min() >= 0 and p.max() < 1
    return torch.log(p / (1 - p))
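
# Note: logit is the inverse of torch.sigmoid. It is unused as written, but
# applying it to the init image (before the color decorrelation in ImageParam)
# would make the sigmoid squash_func reproduce the original image at step 0.
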
class CenterCrop(torch.nn.Module):
    """
    Center crop the specified amount of pixels from the edges.
    Arguments:
        size (int or sequence): Number of pixels to center crop away.
    """

    def __init__(self, size=0) -> None:
        super(CenterCrop, self).__init__()
        if type(size) is list or type(size) is tuple:
            assert len(size) == 2, (
                "CenterCrop requires a single crop value or a tuple of (height, width) "
                + "in pixels for cropping."
            )
            self.crop_val = size
        else:
            self.crop_val = [size] * 2

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert (
            input.dim() == 3 or input.dim() == 4
        ), "Input to CenterCrop must be 3D or 4D"
        if input.dim() == 4:
            h, w = input.size(2), input.size(3)
        elif input.dim() == 3:
            h, w = input.size(1), input.size(2)
        h_crop = h - self.crop_val[0]
        w_crop = w - self.crop_val[1]
        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
        return input[..., sh : sh + h_crop, sw : sw + w_crop]
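
# Example: CenterCrop(16) on a 1x3x64x64 tensor removes 16 pixels total from
# each spatial dimension, returning the central 1x3x48x48 region.
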
def rand_select(transform_values):
    """
    Randomly return a value from the provided tuple or list
    """
    # torch.randint's upper bound is exclusive, so use len(transform_values)
    # directly; "high=len(transform_values) - 1" would never pick the last item
    n = torch.randint(low=0, high=len(transform_values), size=[1]).item()
    return transform_values[n]

class RandomScale(nn.Module):
    """
    Apply random rescaling on a NCHW tensor.
    Arguments:
        scale (float, sequence): Tuple of rescaling values to randomly select from.
    """

    def __init__(self, scale) -> None:
        super(RandomScale, self).__init__()
        self.scale = scale

    def get_scale_mat(
        self, m, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        scale_mat = torch.tensor(
            [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype
        )
        return scale_mat

    def scale_tensor(self, x: torch.Tensor, scale) -> torch.Tensor:
        scale_matrix = self.get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
            x.shape[0], 1, 1
        )
        grid = F.affine_grid(scale_matrix, x.size())
        x = F.grid_sample(x, grid)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        scale = rand_select(self.scale)
        return self.scale_tensor(input, scale=scale)

class RandomSpatialJitter(torch.nn.Module):
    """
    Apply random spatial translations on a NCHW tensor.
    Arguments:
        translate (int): Maximum translation in pixels along each axis.
    """

    def __init__(self, translate: int) -> None:
        super(RandomSpatialJitter, self).__init__()
        self.pad_range = 2 * translate
        self.pad = nn.ReflectionPad2d(translate)

    def translate_tensor(self, x: torch.Tensor, insets: torch.Tensor) -> torch.Tensor:
        padded = self.pad(x)
        tblr = [
            -insets[0],
            -(self.pad_range - insets[0]),
            -insets[1],
            -(self.pad_range - insets[1]),
        ]
        cropped = F.pad(padded, pad=tblr)
        assert cropped.shape == x.shape
        return cropped

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        insets = torch.randint(high=self.pad_range, size=(2,))
        return self.translate_tensor(input, insets)

if __name__ == "__main__":
    main()