# Parameterization code from:
# https://github.com/ProGamerGov/dream-creator
# And from Ludwig Schubert's and my work on PyTorch's Captum:
# https://github.com/ProGamerGov/captum/tree/optim-wip
# https://github.com/ludwigschubert
# Test params:
# python neural_style.py -optimizer adam -backend cudnn -cudnn_autotune
# -learning_rate 0.0025 -init image -content_weight 100 -style_weight 6000 -normalize_gradients -style_image wave.jpg
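#
# Note: this script relies on torch.rfft / torch.irfft with signal_ndim (removed
# in PyTorch 1.8) and on named-tensor ops (refine_names / flatten / unflatten),
# so it is assumed to target a PyTorch release from around late 2020 (~1.6-1.7).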
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='adam')
parser.add_argument("-learning_rate", type=float, default=0.024)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
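# Rough pipeline: main() rebuilds the VGG model as an nn.Sequential with
# ContentLoss / StyleLoss / TVLoss modules spliced in, captures targets from the
# content and style images, and then optimizes the parameters of ImageParam (an
# FFT-based, color-decorrelated image parameterization) rather than raw pixels.
# The parameterized image is squashed into [0, 1] with a sigmoid and run through
# Caffe-style VGG preprocessing (TransformLayer) on every iteration.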
def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess_functional(preprocess(image, style_size).type(dtype))
        style_images_caffe.append(img_caffe)

    if params.init_image != None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights == None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)
    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(preprocess_functional(content_image.clone().detach()))

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'
    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i + 1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image != None:
            img = init_image.clone()
        else:
            img = content_image.clone()

    # Set up a color- and spatially-decorrelated version of the init image for the optimizer to optimize
    squash_func = lambda x: torch.sigmoid(x)  # lambda x: torch.sigmoid(x) or lambda x: x.clamp(0, 1)
    img_param = ImageParam(init=img, squash_func=squash_func).to(backward_device)
    preprocess_tensor = TransformLayer()  # This is only needed for the optimization function, not saving.
    robust_transforms = torch.nn.Sequential(
        #torch.nn.ReflectionPad2d(16),
        #RandomSpatialJitter(16),
        #RandomScale(scale=(1, 0.975, 1.025, 0.95, 1.05)),
        torchvision.transforms.RandomRotation(degrees=(-1, 1)),
        #RandomSpatialJitter(8),
        #CenterCrop(16),
    )
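    # robust_transforms implements the "transformation robustness" idea from
    # feature visualization (randomly jitter / scale / rotate the image each
    # step). Only a mild random rotation is enabled here; the other transforms
    # are commented out, and robust_transforms is only applied if the matching
    # line in feval() below is uncommented.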
    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img_param().clone().detach())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        img = img_param()
        img = preprocess_tensor(img)
        #img = robust_transforms(img)
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img_param)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)
# Configure the optimizer
def setup_optimizer(img_param):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS(img_param.parameters(), **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam(img_param.parameters(), lr = params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
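# Note on loopVal: optim.LBFGS performs up to max_iter closure evaluations inside
# a single optimizer.step(feval) call, so the loop in main() only needs one pass
# (loopVal = 1). Adam performs one update per step() call, so the loop in main()
# drives all num_iterations steps itself.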
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True
        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor
    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device
def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -gpu devices minus 1 must equal the number of -multidevice_strategy layer indices."
    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net
# Load an image and resize it to the requested size, returning an NCHW tensor in
# the [0, 1] range. The Caffe-style rescale to [0, 255], RGB -> BGR conversion, and
# mean-pixel subtraction are applied later by preprocess_functional() / TransformLayer,
# after the parameterized image has been generated.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    input_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    tensor = input_transforms(image).unsqueeze(0)
    #print(tensor)
    #quit()
    return tensor
# Convert a 1xCxHxW tensor in the [0, 1] range back into a PIL image.
def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu()
    #output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')
# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")
# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())
# Scale gradients in the backward pass
class ScaleGradients(torch.autograd.Function):
    @staticmethod
    def forward(self, input_tensor, strength):
        self.strength = strength
        return input_tensor

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
        return grad_input * self.strength * self.strength, None
# Define an nn Module to compute content loss
class ContentLoss(nn.Module):
    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.normalize = normalize

    def forward(self, input):
        if self.mode == 'loss':
            loss = self.crit(input, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = loss * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input
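# The Gram matrix is the style representation: with the activations flattened to a
# matrix F of shape (C, H*W), it computes G = F @ F.T (normalized by the number of
# elements in StyleLoss.forward). Note that view(C, H * W) assumes a batch size of 1.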
class GramMatrix(nn.Module):
    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):
    def __init__(self, strength, normalize):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None
        self.normalize = normalize

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight == None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            loss = self.crit(self.G, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = self.strength * loss
        return input
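# TVLoss below is anisotropic (L1) total variation: the sum of absolute differences
# between neighboring pixels along the height and width axes, which penalizes
# high-frequency noise in the generated image.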
class TVLoss(nn.Module):
    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
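# FFTImage stores the image as scaled Fourier coefficients, the spectral
# parameterization used in Lucid and in the Captum optim-wip branch linked above.
# spectrum_scale grows roughly like 1 / frequency, so low frequencies get larger
# effective step sizes; this preconditioning tends to reduce high-frequency
# artifacts during optimization.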
class FFTImage(nn.Module):
    """Parameterize an image using inverse real 2D FFT"""

    def __init__(
        self,
        size = None,
        channels: int = 3,
        batch: int = 1,
        init = None,
    ) -> None:
        super().__init__()
        if init is None:
            assert len(size) == 2
            self.size = size
        else:
            assert init.dim() == 3 or init.dim() == 4
            self.size = (
                (init.size(1), init.size(2))
                if init.dim() == 3
                else (init.size(2), init.size(3))
            )

        frequencies = FFTImage.rfft2d_freqs(*self.size).to(init.device)
        scale = 1.0 / torch.max(
            frequencies,
            torch.full_like(frequencies, 1.0 / (max(self.size[0], self.size[1]))),
        )
        scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2))
        spectrum_scale = scale[None, :, :, None]
        self.register_buffer("spectrum_scale", spectrum_scale)

        if init is None:
            coeffs_shape = (channels, self.size[0], self.size[1] // 2 + 1, 2)
            random_coeffs = torch.randn(
                coeffs_shape
            )  # names=["C", "H_f", "W_f", "complex"]
            fourier_coeffs = random_coeffs / 50
        else:
            fourier_coeffs = torch.rfft(init, signal_ndim=2)
            w = spectrum_scale.size(2) - fourier_coeffs.size(3) if init.dim() == 4 else spectrum_scale.size(2) - fourier_coeffs.size(2)
            self.spectrum_scale = spectrum_scale[:, :, w:, :] if w > 0 else self.spectrum_scale
            fourier_coeffs = fourier_coeffs / self.spectrum_scale

        self.fourier_coeffs = nn.Parameter(fourier_coeffs)

    @staticmethod
    def rfft2d_freqs(height: int, width: int) -> torch.Tensor:
        """Computes 2D spectrum frequencies."""
        fy = FFTImage.pytorch_fftfreq(height)[:, None]
        # on odd input dimensions we need to keep one additional frequency
        wadd = 2 if width % 2 == 1 else 1
        fx = FFTImage.pytorch_fftfreq(width)[: width // 2 + wadd]
        return torch.sqrt((fx * fx) + (fy * fy))

    @staticmethod
    def pytorch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
        """PyTorch version of np.fft.fftfreq"""
        results = torch.empty(v)
        s = (v - 1) // 2 + 1
        results[:s] = torch.arange(0, s)
        results[s:] = torch.arange(-(v // 2), 0)
        return results * (1.0 / (v * d))

    def forward(self) -> torch.Tensor:
        h, w = self.size
        scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
        output = torch.irfft(scaled_spectrum, signal_ndim=2)[:, :, :h, :w]
        return output
# Parameterized Input Image
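# ImageParam wraps FFTImage: the init image's colors are decorrelated with a fixed
# 3x3 matrix, the result is stored as FFT coefficients, and forward() re-correlates
# the colors and squashes the output into [0, 1] with squash_func (sigmoid by default).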
class ImageParam(nn.Module):
    def __init__(self, init, squash_func, decorrelate_color=True):
        super(ImageParam, self).__init__()
        self.transform = self.get_color_matrix().to(init.device)
        self.squash_func = squash_func
        self.decorrelate_color = decorrelate_color
        if self.decorrelate_color:
            init = self.forward_image(init, inverse=True)  # Decorrelate the init image's colors
        self.image = FFTImage(init=init)

    def get_color_matrix(self):
        transform = torch.Tensor([[0.26, 0.09, 0.02],
                                  [0.27, 0.00, -0.05],
                                  [0.27, -0.09, 0.03]])
        transform = transform / torch.max(torch.norm(transform, dim=0))
        return transform
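    # The hard-coded matrix above appears to be the color correlation matrix used
    # for decorrelated color spaces in Lucid / Captum's feature visualization code
    # (an approximate matrix square root of an ImageNet RGB color covariance),
    # normalized here by its largest column norm.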
    def forward_image(self, x, inverse=False):
        x = x.refine_names("B", "C", "H", "W")
        h, w = x.size("H"), x.size("W")
        flat = x.flatten(("H", "W"), "spatials")
        if inverse:
            correct = torch.inverse(self.transform) @ flat
        else:
            correct = self.transform @ flat
        chw = correct.unflatten("spatials", (("H", h), ("W", w))).rename(None)
        return chw

    def forward(self):
        image = self.image()
        if self.decorrelate_color:
            image = self.forward_image(image)
        return self.squash_func(image)
# Preprocess input after decorrelation
class TransformLayer(torch.nn.Module):
    def __init__(self, mean=[1, 1, 1], device='cpu'):
        super(TransformLayer, self).__init__()
        self.input_mean = torch.as_tensor(mean).view(3, 1, 1).to(device)

    def forward(self, x):
        assert x.dim() == 4
        x = x[:, [2, 1, 0]]  # RGB to BGR
        x = x * 255  # Scale input range
        return x - torch.as_tensor([103.939, 116.779, 123.68], device=x.device).view(3, 1, 1)


# Transform images for the model
def preprocess_functional(tensor):
    tensor = tensor[:, [2, 1, 0]]
    return (tensor * 255) - torch.as_tensor([103.939, 116.779, 123.68]).view(3, 1, 1).to(tensor.device)
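# TransformLayer and preprocess_functional both apply the Caffe-style VGG input
# convention expected by the loaded model: RGB -> BGR channel order, a [0, 1] ->
# [0, 255] rescale, and subtraction of the BGR ImageNet mean pixel
# (103.939, 116.779, 123.68).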
def logit(p: torch.Tensor, epsilon: float = 1e-4) -> torch.Tensor:
    p = torch.clamp(p, min=epsilon, max=1.0 - epsilon)
    assert p.min() >= 0 and p.max() < 1
    return torch.log(p / (1 - p))
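# logit() is the inverse of the sigmoid squash_func (useful if one wants to
# initialize the parameterization so that the squashed output reproduces a given
# image), but it is not actually called anywhere in this script.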
class CenterCrop(torch.nn.Module):
    """
    Center crop the specified amount of pixels from the edges.
    Arguments:
        size (int or sequence of int): Number of pixels to center crop away, either
            a single value or a (height, width) pair.
    """

    def __init__(self, size = 0) -> None:
        super(CenterCrop, self).__init__()
        if type(size) is list or type(size) is tuple:
            assert len(size) == 2, (
                "CenterCrop requires a single crop value or a tuple of (height, width)"
                + " in pixels for cropping."
            )
            self.crop_val = size
        else:
            self.crop_val = [size] * 2

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert (
            input.dim() == 3 or input.dim() == 4
        ), "Input to CenterCrop must be 3D or 4D"
        if input.dim() == 4:
            h, w = input.size(2), input.size(3)
        elif input.dim() == 3:
            h, w = input.size(1), input.size(2)
        h_crop = h - self.crop_val[0]
        w_crop = w - self.crop_val[1]
        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
        return input[..., sh : sh + h_crop, sw : sw + w_crop]
def rand_select(transform_values):
    """
    Randomly return a value from the provided tuple or list
    """
    # torch.randint's upper bound is exclusive, so use len(transform_values)
    # (not len - 1) so that every element can be selected.
    n = torch.randint(low=0, high=len(transform_values), size=[1]).item()
    return transform_values[n]
class RandomScale(nn.Module):
    """
    Apply random rescaling on a NCHW tensor.
    Arguments:
        scale (float, sequence): Tuple of rescaling values to randomly select from.
    """

    def __init__(self, scale) -> None:
        super(RandomScale, self).__init__()
        self.scale = scale

    def get_scale_mat(
        self, m, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        scale_mat = torch.tensor(
            [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype
        )
        return scale_mat

    def scale_tensor(self, x: torch.Tensor, scale) -> torch.Tensor:
        scale_matrix = self.get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
            x.shape[0], 1, 1
        )
        grid = F.affine_grid(scale_matrix, x.size())
        x = F.grid_sample(x, grid)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        scale = rand_select(self.scale)
        return self.scale_tensor(input, scale=scale)
class RandomSpatialJitter(torch.nn.Module):
    """
    Apply random spatial translations on a NCHW tensor.
    Arguments:
        translate (int): Maximum number of pixels to translate by along each
            spatial dimension.
    """

    def __init__(self, translate: int) -> None:
        super(RandomSpatialJitter, self).__init__()
        self.pad_range = 2 * translate
        self.pad = nn.ReflectionPad2d(translate)

    def translate_tensor(self, x: torch.Tensor, insets: torch.Tensor) -> torch.Tensor:
        padded = self.pad(x)
        tblr = [
            -insets[0],
            -(self.pad_range - insets[0]),
            -insets[1],
            -(self.pad_range - insets[1]),
        ]
        cropped = F.pad(padded, pad=tblr)
        assert cropped.shape == x.shape
        return cropped

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        insets = torch.randint(high=self.pad_range, size=(2,))
        return self.translate_tensor(input, insets)
if __name__ == "__main__":
    main()