@ProGamerGov
Last active April 6, 2020 20:32
neural-style-pt with tiling

Tiling

The tiling feature is based on neural-dream's tiling system.
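In outline, the content image is split into overlapping tiles, each tile is stylized separately, and the results are re-blended with linear ramp masks across the overlap regions (see mask_tile and add_tiles in the script below). Here is a minimal 1-D sketch of that blending idea, where blend_pair is a hypothetical stand-in for the script's 2-D masking, not part of the gist itself:

import torch

# Minimal 1-D sketch of overlap blending; mask_tile()/add_tiles() below
# do the same thing in 2-D, with a ramp mask per tile side.
def blend_pair(left_tile, right_tile, overlap):
    ramp = torch.linspace(0, 1, overlap)  # 0 -> 1 across the overlap region
    left_tile[-overlap:] *= 1 - ramp      # fade the left tile out
    right_tile[:overlap] *= ramp          # fade the right tile in
    out = torch.zeros(left_tile.numel() + right_tile.numel() - overlap)
    out[:left_tile.numel()] += left_tile
    out[-right_tile.numel():] += right_tile
    return out

blended = blend_pair(torch.ones(256), torch.ones(256), 128)  # 384 samples, all 1.0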

Usage

Basic usage:

python neural_style_tile.py -style_image <image.jpg> -content_image <image.jpg> -tile_size 256 -image_size 512
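A fuller example that also sets the tiling options (the sizes and iteration counts here are only illustrative; all flags are defined in the parser below):

python neural_style_tile.py -style_image <image.jpg> -content_image <image.jpg> -image_size 2048 -tile_size 512 -overlap_percent 50 -tile_iter 50 -roll_image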

Tiling options:

  • -tile_size: The desired tile size. Default is 256. The style image size will be equal to tile_size * style_scale.
  • -overlap_percent: The percentage of overlap to use for the tiles. Default is 50. Values less than or equal to 1 are treated as fractions, so 50 and 0.5 are equivalent.
  • -print_tile: Print which tile is currently being processed, every print_tile tiles, with no other information. Default is 1. Set it to 0 to disable printing.
  • -tile_iter: How many iterations to perform on each tile during every normal iteration. Default is 0. If set to 0, the script currently falls back to int(1000 / 3) iterations per tile.
  • -print_tile_iter: Print tile progress every print_tile_iter iterations. Default is 0, which disables printing.
  • -roll_image: If this flag is enabled, the image is randomly shifted between iterations.
  • -jitter: Randomly shift the input by up to the given number of pixels on every forward pass. Default is 0, which disables jitter.

Note that for every normal iteration, every tile has a set number of iterations run on it, so the total amount of work is num_iterations * tile_iter * (the number of tiles) per-tile optimization steps. You may therefore want to play around with the values for -num_iterations, -tile_iter, and possibly other parameters as well; a rough example follows.
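For example, a back-of-the-envelope estimate (the parameter values here are hypothetical; the tile count follows from get_tile_coords in the script below):

# Hypothetical settings: 512px image, 256px tiles, 50% overlap.
# get_tile_coords(512, 256, 0.5) yields tile starts [0, 128, 256] per axis,
# i.e. the 3x3 pattern the script reports as "Tile pattern: 3x3".
num_iterations = 10   # -num_iterations
tile_iter = 50        # -tile_iter
num_tiles = 3 * 3
total_steps = num_iterations * tile_iter * num_tiles
print(total_steps)    # 4500 per-tile optimization steps in total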

import os
import copy
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=10)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
# Output options
parser.add_argument("-print_iter", type=int, default=1)
parser.add_argument("-save_iter", type=int, default=1)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
# Tile options
parser.add_argument("-tile_size", type=int, default=256)
parser.add_argument("-overlap_percent", type=float, default=0.5)
parser.add_argument("-tile_iter", type=int, default=0)
parser.add_argument("-print_tile_iter", type=int, default=0)
parser.add_argument("-print_tile", type=int, default=1)
parser.add_argument("-roll_image", action='store_true')
parser.add_argument("-jitter", type=int, default=0)
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.tile_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image != None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights == None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net_base = nn.Sequential()
    c, r = 0, 0
    if params.jitter > 0:
        jitter_mod = Jitter(params.jitter).type(dtype)
        net_base.add_module(str(len(net_base)), jitter_mod)
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net_base.add_module(str(len(net_base)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net_base.add_module(str(len(net_base)), layer)

    if multidevice:
        net_base = setup_multi_device(net_base)

    print_torch(net_base, multidevice)

    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
    else:
        print("Running optimization with ADAM")

    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True

    overlap_percent = params.overlap_percent / 100 if params.overlap_percent > 1 else params.overlap_percent
    if params.init_image != None:
        init_image_tiles = tile_image(init_image.clone(), params.tile_size, overlap_percent)
    output_tiles = []
    total_content_losses, total_style_losses, total_loss = [], [], [0]
    content_tiles = tile_image(content_image.clone(), params.tile_size, overlap_percent)
    first_run = True
    h_roll, w_roll = 0, 0

    _, _, tile_pattern, num_tiles = tile_image(content_image.clone(), params.tile_size, overlap_percent, True)
    print('\nCreated ' + str(num_tiles) + ' tiles')
    print('Tile pattern: ' + str(tile_pattern[0]) + 'x' + str(tile_pattern[1]))

    if params.tile_iter <= 0:
        sub_iter = int(1000 / 3)  # int(params.num_iterations / 3)
    else:
        sub_iter = params.tile_iter

    for iter in range(1, params.num_iterations + 1):
        for tile_num, c_tile in enumerate(content_tiles):
            net = copy.deepcopy(net_base)
            content_losses, style_losses, tv_losses = [], [], []
            for i, layer in enumerate(net):
                if isinstance(layer, TVLoss):
                    tv_losses.append(layer)
                elif isinstance(layer, ContentLoss):
                    content_losses.append(layer)
                elif isinstance(layer, StyleLoss):
                    style_losses.append(layer)

            maybe_print_tile(tile_num, num_tiles)

            # Capture content targets
            for i in content_losses:
                i.mode = 'capture'
            net(c_tile)

            # Capture style targets
            for i in content_losses:
                i.mode = 'None'
            for i, image in enumerate(style_images_caffe):
                for j in style_losses:
                    j.mode = 'capture'
                    j.blend_weight = style_blend_weights[i]
                net(style_images_caffe[i])

            # Set all loss modules to loss mode
            for i in content_losses:
                i.mode = 'loss'
            for i in style_losses:
                i.mode = 'loss'

            # Maybe normalize content and style weights
            if params.normalize_weights:
                normalize_weights(content_losses, style_losses)

            # Freeze the network in order to prevent
            # unnecessary gradient calculations
            for param in net.parameters():
                param.requires_grad = False

            # Initialize the image
            if params.init == 'random':
                B, C, H, W = c_tile.size()
                img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
            elif params.init == 'image':
                if params.init_image != None:
                    img = init_image_tiles[tile_num].clone()
                else:
                    img = c_tile.clone()
            if first_run == False:
                img = output_img_tiles[tile_num].clone()
            img = nn.Parameter(img)

            # Function to evaluate loss and gradient. We run the net forward and
            # backward to get the gradient, and sum up losses from the loss modules.
            # optim.lbfgs internally handles iteration and calls this function many
            # times, so we manually count the number of iterations to handle printing
            # and saving intermediate results.
            num_calls = [0]
            def feval():
                num_calls[0] += 1
                optimizer.zero_grad()
                net(img)
                loss = 0

                for mod in content_losses:
                    loss += mod.loss.to(backward_device)
                for mod in style_losses:
                    loss += mod.loss.to(backward_device)
                if params.tv_weight > 0:
                    for mod in tv_losses:
                        loss += mod.loss.to(backward_device)

                total_loss[0] += loss.item()
                loss.backward()

                maybe_print_tile_iter(num_calls[0], len(output_tiles), sub_iter)
                return loss

            optimizer, loopVal = setup_optimizer(img)
            while num_calls[0] <= sub_iter:
                optimizer.step(feval)

            if len(output_tiles) == 0:
                for mod in content_losses:
                    total_content_losses.append(mod.loss.item())
                for mod in style_losses:
                    total_style_losses.append(mod.loss.item())
            else:
                for c_loss, mod in enumerate(content_losses):
                    total_content_losses[c_loss] += mod.loss.item()
                for s_loss, mod in enumerate(style_losses):
                    total_style_losses[s_loss] += mod.loss.item()

            output_tiles.append(img.clone())

        if len(output_tiles) == len(content_tiles):
            first_run = False
            output_img = rebuild_image(output_tiles, content_image.clone(), params.tile_size, overlap_percent)
            output_tiles = []

            if params.roll_image:
                output_img, _, _ = roll_tensor(output_img, -h_roll, -w_roll)

            maybe_print(iter, total_loss[0], total_content_losses, total_style_losses)
            maybe_save(iter, output_img)

            if params.roll_image:
                output_img, h_roll, w_roll = roll_tensor(output_img.clone())

            output_img_tiles = tile_image(output_img.clone(), params.tile_size, overlap_percent)
            output_tiles = []
            total_content_losses, total_style_losses, total_loss = [], [], [0]
def maybe_save(t, save_img):
    should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save:
        output_filename, file_extension = os.path.splitext(params.output_image)
        if t == params.num_iterations:
            filename = output_filename + str(file_extension)
        else:
            filename = str(output_filename) + "_" + str(t) + str(file_extension)
        disp = deprocess(save_img.clone())

        # Maybe perform postprocessing for color-independent style transfer.
        # The content image is reloaded here because it is local to main().
        if params.original_colors == 1:
            content_image = preprocess(params.content_image, params.image_size)
            disp = original_colors(deprocess(content_image.clone()), disp)
        disp.save(str(filename))
def maybe_print(t, loss, content_losses, style_losses):
    if params.print_iter > 0 and t % params.print_iter == 0:
        print("Iteration " + str(t) + " / " + str(params.num_iterations))
        for i, loss_module in enumerate(content_losses):
            print("  Content " + str(i+1) + " loss: " + str(loss_module))
        for i, loss_module in enumerate(style_losses):
            print("  Style " + str(i+1) + " loss: " + str(loss_module))
        print("  Total loss: " + str(loss))

def maybe_print_tile_iter(t, n, total):
    if params.print_tile_iter > 0 and t % params.print_tile_iter == 0:
        print("Tile " + str(n+1) + " iteration " + str(t) + " / " + str(total))

def maybe_print_tile(tile_num, num_tiles):
    if params.print_tile > 0 and (tile_num + 1) % params.print_tile == 0:
        print('Processing tile: ' + str(tile_num+1) + ' of ' + str(num_tiles))
# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        optim_state = {
            'max_iter': 1,  # params.num_iterations
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        optimizer = optim.Adam([img], lr = params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device
def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."
    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net
# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor
# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')
# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ",' ') + st.replace(", ",')')
                print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
        else:
            print(n())
    print(")")
# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())
# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight == None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input
class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:,:,1:,:] - input[:,:,:-1,:]
        self.y_diff = input[:,:,:,1:] - input[:,:,:,:-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
# Shift tensor, possibly randomly.
def roll_tensor(tensor, h_shift=None, w_shift=None):
    if h_shift == None:
        h_shift = torch.LongTensor(10).random_(-tensor.size(1), tensor.size(1))[0].item()
    if w_shift == None:
        w_shift = torch.LongTensor(10).random_(-tensor.size(2), tensor.size(2))[0].item()
    tensor = torch.roll(torch.roll(tensor, shifts=h_shift, dims=2), shifts=w_shift, dims=3)
    return tensor, h_shift, w_shift
# Apply blend masks to tiles
def mask_tile(tile, overlap, side='bottom'):
    h, w = tile.size(2), tile.size(3)
    top_overlap, bottom_overlap, right_overlap, left_overlap = overlap[0], overlap[1], overlap[2], overlap[3]
    if tile.is_cuda:
        if 'left' in side and 'left-special' not in side:
            lin_mask_left = torch.linspace(0,1,left_overlap, device=tile.get_device()).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'right' in side and 'right-special' not in side:
            lin_mask_right = torch.linspace(1,0,right_overlap, device=tile.get_device()).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'top' in side and 'top-special' not in side:
            lin_mask_top = torch.linspace(0,1,top_overlap, device=tile.get_device()).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'bottom' in side and 'bottom-special' not in side:
            lin_mask_bottom = torch.linspace(1,0,bottom_overlap, device=tile.get_device()).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'left-special' in side:
            lin_mask_left = torch.linspace(0,1,left_overlap, device=tile.get_device())
            zeros_mask = torch.zeros(w-(left_overlap*2), device=tile.get_device())
            ones_mask = torch.ones(left_overlap, device=tile.get_device())
            lin_mask_left = torch.cat([zeros_mask, lin_mask_left, ones_mask], 0).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'right-special' in side:
            lin_mask_right = torch.linspace(1,0,right_overlap, device=tile.get_device())
            ones_mask = torch.ones(w-right_overlap, device=tile.get_device())
            lin_mask_right = torch.cat([ones_mask, lin_mask_right], 0).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'top-special' in side:
            lin_mask_top = torch.linspace(0,1,top_overlap, device=tile.get_device())
            zeros_mask = torch.zeros(h-(top_overlap*2), device=tile.get_device())
            ones_mask = torch.ones(top_overlap, device=tile.get_device())
            lin_mask_top = torch.cat([zeros_mask, lin_mask_top, ones_mask], 0).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'bottom-special' in side:
            lin_mask_bottom = torch.linspace(1,0,bottom_overlap, device=tile.get_device())
            ones_mask = torch.ones(h-bottom_overlap, device=tile.get_device())
            lin_mask_bottom = torch.cat([ones_mask, lin_mask_bottom], 0).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
    else:
        if 'left' in side and 'left-special' not in side:
            lin_mask_left = torch.linspace(0,1,left_overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'right' in side and 'right-special' not in side:
            lin_mask_right = torch.linspace(1,0,right_overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'top' in side and 'top-special' not in side:
            lin_mask_top = torch.linspace(0,1,top_overlap).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'bottom' in side and 'bottom-special' not in side:
            lin_mask_bottom = torch.linspace(1,0,bottom_overlap).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'left-special' in side:
            lin_mask_left = torch.linspace(0,1,left_overlap)
            zeros_mask = torch.zeros(w-(left_overlap*2))
            ones_mask = torch.ones(left_overlap)
            lin_mask_left = torch.cat([zeros_mask, lin_mask_left, ones_mask], 0).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'right-special' in side:
            lin_mask_right = torch.linspace(1,0,right_overlap)
            ones_mask = torch.ones(w-right_overlap)
            lin_mask_right = torch.cat([ones_mask, lin_mask_right], 0).repeat(h,1).repeat(3,1,1).unsqueeze(0)
        if 'top-special' in side:
            lin_mask_top = torch.linspace(0,1,top_overlap)
            zeros_mask = torch.zeros(h-(top_overlap*2))
            ones_mask = torch.ones(top_overlap)
            lin_mask_top = torch.cat([zeros_mask, lin_mask_top, ones_mask], 0).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)
        if 'bottom-special' in side:
            lin_mask_bottom = torch.linspace(1,0,bottom_overlap)
            ones_mask = torch.ones(h-bottom_overlap)
            lin_mask_bottom = torch.cat([ones_mask, lin_mask_bottom], 0).repeat(w,1).rot90(3).repeat(3,1,1).unsqueeze(0)

    base_mask = torch.ones_like(tile)

    if 'right' in side and 'right-special' not in side:
        base_mask[:,:,:,w-right_overlap:] = base_mask[:,:,:,w-right_overlap:] * lin_mask_right
    if 'left' in side and 'left-special' not in side:
        base_mask[:,:,:,:left_overlap] = base_mask[:,:,:,:left_overlap] * lin_mask_left
    if 'bottom' in side and 'bottom-special' not in side:
        base_mask[:,:,h-bottom_overlap:,:] = base_mask[:,:,h-bottom_overlap:,:] * lin_mask_bottom
    if 'top' in side and 'top-special' not in side:
        base_mask[:,:,:top_overlap,:] = base_mask[:,:,:top_overlap,:] * lin_mask_top
    if 'right-special' in side:
        base_mask = base_mask * lin_mask_right
    if 'left-special' in side:
        base_mask = base_mask * lin_mask_left
    if 'bottom-special' in side:
        base_mask = base_mask * lin_mask_bottom
    if 'top-special' in side:
        base_mask = base_mask * lin_mask_top
    return tile * base_mask
# Compute the starting coordinate of each tile along one dimension,
# plus the stride between tile starts.
def get_tile_coords(d, tile_dim, overlap=0):
    overlap = int(tile_dim * (1-overlap))
    c, tile_start, coords = 1, 0, [0]
    while tile_start + tile_dim < d:
        tile_start = overlap * c
        if tile_start + tile_dim >= d:
            coords.append(d - tile_dim)
        else:
            coords.append(tile_start)
        c += 1
    return coords, overlap
def get_tiles(img, tile_coords, tile_size, info_only=False):
    tile_list = []
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            tile = img[:, :, y:y+tile_size[0], x:x+tile_size[1]]
            tile_list.append(tile)
    if not info_only:
        return tile_list
    else:
        return tile_list[0].size(2), tile_list[0].size(3)

def final_overlap(tile_coords):
    r, c = len(tile_coords[0]), len(tile_coords[1])
    return (tile_coords[0][r-1] - tile_coords[0][r-2], tile_coords[1][c-1] - tile_coords[1][c-2])
def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    f_ovlp = final_overlap(tile_coords)
    h, w = tiles[0].size(2), tiles[0].size(3)
    t = 0
    column, row = 0, 0
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            mask_sides = ''
            c_overlap = overlap.copy()
            if len(tile_coords[0]) > 1:
                if row == 0:
                    if row == len(tile_coords[0]) - 2:
                        mask_sides += 'bottom-special'
                        c_overlap[1] = f_ovlp[0]  # Change bottom overlap
                    else:
                        mask_sides += 'bottom'
                elif row > 0 and row < len(tile_coords[0]) - 2:
                    mask_sides += 'bottom,top'
                elif row == len(tile_coords[0]) - 2:
                    if f_ovlp[0] > 0:
                        mask_sides += 'bottom-special,top'
                        c_overlap[1] = f_ovlp[0]  # Change bottom overlap
                    elif f_ovlp[0] <= 0:
                        mask_sides += 'bottom,top'
                elif row == len(tile_coords[0]) - 1:
                    if f_ovlp[0] > 0:
                        mask_sides += 'top-special'
                        c_overlap[0] = f_ovlp[0]  # Change top overlap
                    elif f_ovlp[0] <= 0:
                        mask_sides += 'top'
            if len(tile_coords[1]) > 1:
                if column == 0:
                    if column == len(tile_coords[1]) - 2:
                        mask_sides += ',right-special'
                        c_overlap[2] = f_ovlp[1]  # Change right overlap
                    else:
                        mask_sides += ',right'
                elif column > 0 and column < len(tile_coords[1]) - 2:
                    mask_sides += ',right,left'
                elif column == len(tile_coords[1]) - 2:
                    if f_ovlp[1] > 0:
                        mask_sides += ',right-special,left'
                        c_overlap[2] = f_ovlp[1]  # Change right overlap
                    elif f_ovlp[1] <= 0:
                        mask_sides += ',right,left'
                elif column == len(tile_coords[1]) - 1:
                    if f_ovlp[1] > 0:
                        mask_sides += ',left-special'
                        c_overlap[3] = f_ovlp[1]  # Change left overlap
                    elif f_ovlp[1] <= 0:
                        mask_sides += ',left'

            if t < len(tiles):
                tile = mask_tile(tiles[t], c_overlap, side=mask_sides)
                base_img[:, :, y:y+tile_size[0], x:x+tile_size[1]] = base_img[:, :, y:y+tile_size[0], x:x+tile_size[1]] + tile
                t += 1
            column += 1
        row += 1
        column = 0
    return base_img
def tile_setup(tile_size, overlap_percent, base_size):
    if type(tile_size) is not tuple and type(tile_size) is not list:
        tile_size = (tile_size, tile_size)
    if type(overlap_percent) is not tuple and type(overlap_percent) is not list:
        overlap_percent = (overlap_percent, overlap_percent)
    x_coords, x_ovlp = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords, y_ovlp = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    # The overlap list is ordered (top, bottom, right, left) to match mask_tile().
    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]
def tile_image(img, tile_size, overlap_percent, info_only=False):
    tile_coords, tile_size, _ = tile_setup(tile_size, overlap_percent, (img.size(2), img.size(3)))
    if not info_only:
        return get_tiles(img, tile_coords, tile_size)
    else:
        tile_size = get_tiles(img, tile_coords, tile_size, info_only)
        return tile_size[0], tile_size[1], (len(tile_coords[0]), len(tile_coords[1])), (len(tile_coords[0]) * len(tile_coords[1]))

def rebuild_image(tiles, base_img, tile_size, overlap_percent):
    base_img = torch.zeros_like(base_img)
    tile_coords, tile_size, overlap = tile_setup(tile_size, overlap_percent, (base_img.size(2), base_img.size(3)))
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)
# Define an nn Module to apply jitter
class Jitter(torch.nn.Module):

    def __init__(self, jitter_val):
        super(Jitter, self).__init__()
        self.jitter_val = jitter_val

    def roll_tensor(self, input):
        h_shift = random.randint(-self.jitter_val, self.jitter_val)
        w_shift = random.randint(-self.jitter_val, self.jitter_val)
        return torch.roll(torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3)

    def forward(self, input):
        return self.roll_tensor(input)


if __name__ == "__main__":
    main()

ProGamerGov commented Mar 28, 2020

Using some tile sizes means that the network has to spend more time changing the image, or something like that, I think?

[Image: out5_1]

If you use better tile sizes, you get something that looks more like this:

With poor content/style weight choices:

[Image: out4_2]

ProGamerGov commented:

Normal style transfer up to here:
[Image: out3]

Then 2048 image size with 1024 tile size and 2500 style weight.
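(For reference, that run corresponds to a command along these lines, using the flags defined in the parser above; the exact invocation is a reconstruction, not the author's:)

python neural_style_tile.py -content_image <image.jpg> -style_image <image.jpg> -image_size 2048 -tile_size 1024 -style_weight 2500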

ProGamerGov commented:

[Image: out4_7]
