neural-style-pt with tiling

Tiling

The tiling is inspired by VaKonS's neural-style tiling implementation.

The tiling feature is still highly experimental, so input images may require manual resizing to get the desired results.

Usage

Basic usage:

python neural_style_tile.py -style_image <image.jpg> -content_image <image.jpg>

Example Set Of Parameters

python neural_style_tile.py -style_image examples/inputs/starry_night_google.jpg -content_image examples/inputs/brad_pitt.jpg -backend cudnn -tile_size 300 -image_size 512 -init image -output_image blended_tiling/outputs/out.png -print_iter 1 -save_iter 5 -tile_iter 250 -style_weight 4000 -tv_weight 0

Tiling Options:

  • -tile_size: The desired tile size. Default is 256.
  • -tile_iter: The number of iterations to perform for each tile. The default of -1 automatically calculates the tile iteration count.
  • -print_tile_iter: Print tile progress every print_tile_iter iterations. The default of 0 disables printing.
  • -match_image_size: This flag tries to create tiles with an HxW ratio that matches the original image.
  • -height_offset: Value added to the tile height. Default is 0.
  • -width_offset: Value added to the tile width. Default is 0.

Note that for every normal iteration, every tile has a set number of iterations run on it, so the total amount of work is num_iterations * tile_iter * the number of tiles. You may therefore want to play around with the values for -num_iterations, and possibly -tile_iter. If no -tile_iter value is specified, -tile_iter is set to -num_iterations / 3.
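As a rough worked example (numbers chosen purely for illustration): an image that splits into 6 tiles, run with -num_iterations 12 and -tile_iter 250, performs 12 * 250 * 6 = 18,000 individual tile iterations in total.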

import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel
from nspt_extras import rebuild_tensor, split_tensor_equal
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
# Tiling Options
parser.add_argument("-tile_size", type=int, default=256)
parser.add_argument("-tile_iter", type=int, help="Set to 0 or -1 for auto tile iteration count", default=-1)
parser.add_argument("-print_tile_iter", type=int, default=0)
parser.add_argument("-match_image_size", action='store_true')
parser.add_argument("-width_offset", type=int, help="Number to be added to tile width", default=0)
parser.add_argument("-height_offset", type=int, help="Number to be added to tile height", default=0)
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image != None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights == None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net_base = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net_base.add_module(str(len(net_base)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net_base.add_module(str(len(net_base)), layer)

    if multidevice:
        net_base = setup_multi_device(net_base)

    print_torch(net_base, multidevice)

    if params.init_image != None:
        init_image_tiles, _, _ = split_tensor_equal(init_image.clone(), params.tile_size, match_size=params.match_image_size, offset_x=params.width_offset, offset_y=params.height_offset)

    output_tiles = []
    total_content_losses, total_style_losses, total_loss = [], [], [0]
    content_tiles, c_rows, c_overlap = split_tensor_equal(content_image.clone(), params.tile_size, match_size=params.match_image_size, offset_x=params.width_offset, offset_y=params.height_offset)
    hs, ws = 0, 0
    output_img, first_run = None, True

    print('---------------------')
    print('Creating', len(content_tiles), 'tiles')
    print('Tile pattern:', c_rows[0], 'x', c_rows[1])
    print('Overlap:', c_overlap)
    print('---------------------')

    if params.tile_iter < 0:
        sub_iter = int(params.num_iterations / 3)
    else:
        sub_iter = params.tile_iter

    for iter in range(params.num_iterations):
        if len(output_tiles) == len(content_tiles):
            first_run = False
            output_img = rebuild_tensor(output_tiles, c_rows, c_overlap)
            maybe_print(iter, total_loss[0], total_content_losses, total_style_losses)
            maybe_save(iter, output_img)
            output_img_tiles, _, _ = split_tensor_equal(output_img.clone(), params.tile_size, match_size=params.match_image_size, offset_x=params.width_offset, offset_y=params.height_offset)
            output_tiles = []
            total_content_losses, total_style_losses, total_loss = [], [], [0]

        for tile_num, c_tile in enumerate(content_tiles):
            net = copy.deepcopy(net_base)
            content_losses, style_losses, tv_losses = [], [], []
            for i, layer in enumerate(net):
                if isinstance(layer, TVLoss):
                    tv_losses.append(layer)
                elif isinstance(layer, ContentLoss):
                    content_losses.append(layer)
                elif isinstance(layer, StyleLoss):
                    style_losses.append(layer)

            # Capture content targets
            for i in content_losses:
                i.mode = 'capture'
            #print("Capturing content targets")
            net(c_tile)

            # Capture style targets
            for i in content_losses:
                i.mode = 'None'
            for i, image in enumerate(style_images_caffe):
                #print("Capturing style target " + str(i+1))
                for j in style_losses:
                    j.mode = 'capture'
                    j.blend_weight = style_blend_weights[i]
                net(style_images_caffe[i])

            # Set all loss modules to loss mode
            for i in content_losses:
                i.mode = 'loss'
            for i in style_losses:
                i.mode = 'loss'

            # Maybe normalize content and style weights
            if params.normalize_weights:
                normalize_weights(content_losses, style_losses)

            # Freeze the network in order to prevent
            # unnecessary gradient calculations
            for param in net.parameters():
                param.requires_grad = False

            # Initialize the image
            if params.seed >= 0:
                torch.manual_seed(params.seed)
                torch.cuda.manual_seed_all(params.seed)
                torch.backends.cudnn.deterministic = True
            if params.init == 'random':
                B, C, H, W = c_tile.size()
                img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
            elif params.init == 'image':
                if params.init_image != None:
                    img = init_image_tiles[tile_num].clone()
                else:
                    img = c_tile.clone()
            if first_run == False:
                img = output_img_tiles[tile_num].clone()
            img = nn.Parameter(img)

            # Function to evaluate loss and gradient. We run the net forward and
            # backward to get the gradient, and sum up losses from the loss modules.
            # optim.lbfgs internally handles iteration and calls this function many
            # times, so we manually count the number of iterations to handle printing
            # and saving intermediate results.
            num_calls = [0]
            def feval():
                num_calls[0] += 1
                optimizer.zero_grad()
                net(img)
                loss = 0

                for mod in content_losses:
                    loss += mod.loss.to(backward_device)
                for mod in style_losses:
                    loss += mod.loss.to(backward_device)
                if params.tv_weight > 0:
                    for mod in tv_losses:
                        loss += mod.loss.to(backward_device)

                total_loss[0] += loss.item()
                loss.backward()

                maybe_print_tile(num_calls[0], len(output_tiles), sub_iter)
                return loss

            optimizer, loopVal = setup_optimizer(img)
            while num_calls[0] <= sub_iter:  # loopVal:
                optimizer.step(feval)

            if len(output_tiles) == 0:
                for mod in content_losses:
                    total_content_losses.append(mod.loss.item())
                for mod in style_losses:
                    total_style_losses.append(mod.loss.item())
            else:
                for c_loss, mod in enumerate(content_losses):
                    total_content_losses[c_loss] += mod.loss.item()
                for s_loss, mod in enumerate(style_losses):
                    total_style_losses[s_loss] += mod.loss.item()
            output_tiles.append(img.clone())
def maybe_save(t, save_img):
    should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save:
        output_filename, file_extension = os.path.splitext(params.output_image)
        if t == params.num_iterations:
            filename = output_filename + str(file_extension)
        else:
            filename = str(output_filename) + "_" + str(t) + str(file_extension)
        disp = deprocess(save_img.clone())

        # Maybe perform postprocessing for color-independent style transfer.
        # The content image is reloaded here because the preprocessed tensor
        # created in main() is local to that function.
        if params.original_colors == 1:
            content_image = preprocess(params.content_image, params.image_size)
            disp = original_colors(deprocess(content_image.clone()), disp)

        disp.save(str(filename))
def maybe_print(t, loss, content_losses, style_losses):
    if params.print_iter > 0 and t % params.print_iter == 0:
        print("Iteration " + str(t) + " / " + str(params.num_iterations))
        for i, loss_module in enumerate(content_losses):
            print("  Content " + str(i + 1) + " loss: " + str(loss_module))
        for i, loss_module in enumerate(style_losses):
            print("  Style " + str(i + 1) + " loss: " + str(loss_module))
        print("  Total loss: " + str(loss))


def maybe_print_tile(t, n, total):
    if params.print_tile_iter > 0 and t % params.print_tile_iter == 0:
        print("Tile " + str(n + 1) + " iteration " + str(t) + " / " + str(total))
# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        #print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': 1,  # params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        #print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True
        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda()
            setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor
    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device
def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -gpu devices minus 1 must equal the number of -multidevice_strategy layer indices."
    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net
# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor
# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')
# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")
# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())
# Define an nn Module to compute content loss
class ContentLoss(nn.Module):
    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input
class GramMatrix(nn.Module):
    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):
    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight == None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input
class TVLoss(nn.Module):
    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input


if __name__ == "__main__":
    main()
import torch
# Import all tiling functions with:
# from nspt_tile import rebuild_tensor, split_tensor_equal
'''
tiles, rows, overlap_hw = split_tensor_equal(tensor, tile_size=256, offset_x=0, offset_y=0)
complete_tensor = rebuild_tensor(tensor_list, rows, overlap_hw)
'''
######################################################################
# Tiling Blending Functions
######################################################################
# from nspt_tile import rebuild_tensor
# Apply blend masks to tiles
def prepare_tile(tile, overlap, side='both'):
    h, w = tile.size(2), tile.size(3)
    if tile.is_cuda:
        lin_mask_left = torch.linspace(0, 1, overlap, device=tile.get_device()).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
        lin_mask_right = torch.linspace(1, 0, overlap, device=tile.get_device()).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    else:
        lin_mask_left = torch.linspace(0, 1, overlap).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
        lin_mask_right = torch.linspace(1, 0, overlap).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    if side == 'both' or side == 'right':
        tile[:, :, :, w - overlap:] = tile[:, :, :, w - overlap:] * lin_mask_right
    if side == 'both' or side == 'left':
        tile[:, :, :, :overlap] = tile[:, :, :, :overlap] * lin_mask_left
    return tile
# Calculate width for base row tensor
def calc_length(w, overlap, rows):
    count = 0
    l_max = w
    for y in range(rows[0]):
        for x in range(rows[1]):
            if count % rows[1] != 0:
                l_max += w - overlap
                l_min = l_max - w
            else:
                l_max = w
            count += 1
    return l_max
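# Worked example (illustration only, not part of the original gist):
# calc_length(256, 64, (2, 2)) returns 256 + (256 - 64) = 448, the width of a
# row built from two 256px tiles that share a 64px overlap.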
# Combine tiles into rows
def overlay_tiles(tile_list, rows, overlap):
    c = 1
    f_tiles = []
    for i, tile in enumerate(tile_list):
        if c == 1:
            f_tile = prepare_tile(tile.clone(), overlap, side='right')
        elif c == rows[1]:
            f_tile = prepare_tile(tile.clone(), overlap, side='left')
        elif c > 0 and c < rows[1]:
            f_tile = prepare_tile(tile.clone(), overlap, side='both')
        f_tiles.append(f_tile)
        if c == rows[1]:
            c = 0
        c += 1

    w = tile_list[0].size(3)
    base_length = calc_length(w, overlap, rows)
    if tile_list[0].is_cuda:
        base_tensor = torch.zeros(3, tile_list[0].size(2), base_length, device=tile_list[0].get_device()).unsqueeze(0)
    else:
        base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)

    row_list = []
    for row in range(rows[0]):
        row_list.append(base_tensor.clone())

    row_num, num_tiles = 0, 0
    l_max = w
    for y in range(rows[0]):
        for x in range(rows[1]):
            if num_tiles % rows[1] != 0:
                l_max += w - overlap
                l_min = l_max - w
                row_list[row_num][:, :, :, l_min:l_max] = row_list[row_num][:, :, :, l_min:l_max] + f_tiles[num_tiles]
            else:
                row_list[row_num][:, :, :, :w] = f_tiles[num_tiles]
                l_max = w
            num_tiles += 1
        row_num += 1
    return row_list
# Apply blend masks to row tensors
def prepare_row(row_tensor, overlap, side='both'):
    h, w = row_tensor.size(2), row_tensor.size(3)
    if row_tensor.is_cuda:
        lin_mask_top = torch.linspace(0, 1, overlap, device=row_tensor.get_device()).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
        lin_mask_bottom = torch.linspace(1, 0, overlap, device=row_tensor.get_device()).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
    else:
        lin_mask_top = torch.linspace(0, 1, overlap).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
        lin_mask_bottom = torch.linspace(1, 0, overlap).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
    if side == 'both' or side == 'top':
        row_tensor[:, :, :overlap, :] = row_tensor[:, :, :overlap, :] * lin_mask_top
    if side == 'both' or side == 'bottom':
        row_tensor[:, :, h - overlap:, :] = row_tensor[:, :, h - overlap:, :] * lin_mask_bottom
    return row_tensor
# Calculate base tensor height
def calc_height(h, overlap, rows):
    num_rows = 0
    l_max = h
    for y in range(rows[0]):
        if num_rows > 0:
            l_max += (h - overlap)
            l_min = l_max - h
        else:
            l_max = h
        num_rows += 1
    return l_max
# Combine row tensors into final output
def overlay_rows(row_list, rows, overlap):
    c = 1
    f_rows = []
    for i, row_tensor in enumerate(row_list):
        if c == 1:
            f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
        elif c == rows[0]:
            f_row = prepare_row(row_tensor.clone(), overlap, side='top')
        elif c > 0 and c < rows[0]:
            f_row = prepare_row(row_tensor.clone(), overlap, side='both')
        f_rows.append(f_row)
        if c == rows[0]:
            c = 0
        c += 1

    h = row_list[0].size(2)
    base_height = calc_height(h, overlap, rows)
    if row_list[0].is_cuda:
        base_tensor = torch.zeros(3, base_height, row_list[0].size(3), device=row_list[0].get_device()).unsqueeze(0)
    else:
        base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)

    num_rows = 0
    l_max = row_list[0].size(3)
    for y in range(rows[0]):
        if num_rows > 0:
            l_max += (h - overlap)
            l_min = l_max - h
            base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
        else:
            base_tensor[:, :, :h, :] = f_rows[num_rows]
            l_max = h
        num_rows += 1
    return base_tensor
# Combine tiles into final output with blending
def rebuild_tensor(tensor_list, rows, overlap_hw):
    row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
    full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
    return full_tensor
######################################################################
# Tiling Splitting Functions
######################################################################
# from nspt_tile import split_tensor_equal
# Calculate tile index locations
def tile_calc(tile_size, v, d):
    max_val = max(min(tile_size * v + tile_size, d), 0)
    min_val = tile_size * v
    if abs(min_val - max_val) < tile_size:
        min_val = max_val - tile_size
    return min_val, max_val
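# Worked example (illustration only, not part of the original gist): with
# tile_size=256 and an image dimension d=600, tile_calc(256, 0, 600) -> (0, 256),
# tile_calc(256, 1, 600) -> (256, 512), and tile_calc(256, 2, 600) clamps to the
# image edge and returns (344, 600), so the final tile overlaps its neighbour
# instead of running past the border.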
# Split tensor into tiles, possibly with overlap
def split_tensor_equal(tensor, tile_size=256, match_size=False, offset_x=0, offset_y=0):
    tiles, tile_idx = [], []
    h, w = tensor.size(2), tensor.size(3)
    # Match the tile aspect ratio to the image when requested; otherwise
    # (and for square images) fall back to square tiles.
    if match_size and h > w:
        d = w / h
        tile_size_y, tile_size_x = tile_size + offset_y, int(tile_size / d) + offset_x
    elif match_size and h < w:
        d = w / h
        tile_size_y, tile_size_x = int(tile_size / d) + offset_y, tile_size + offset_x
    else:
        tile_size_y, tile_size_x = tile_size + offset_y, tile_size + offset_x
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))

    for y in range(h_range):
        for x in range(w_range):
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)
            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])

    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]
    return tiles, (h_range, w_range), (h_overlap, w_overlap)
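# Round-trip sketch (illustration only, not from the original gist): for inputs
# between one and two tile sizes per side, splitting and rebuilding restores the
# original spatial size, since every adjacent pair of tiles shares a uniform
# overlap:
#
#   x = torch.randn(1, 3, 384, 448)
#   tiles, rows, overlap_hw = split_tensor_equal(x, tile_size=256)
#   y = rebuild_tensor(tiles, rows, overlap_hw)
#   assert y.shape[2:] == x.shape[2:]
#
# Larger images may need manual resizing so the tiles overlap evenly, as noted
# in the README above.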
######################################################################
# Histogram Related Functions
######################################################################
# Match content image to histogram input
# Note: preprocess() is defined in neural_style_tile.py and is assumed to be
# in scope when these histogram matching helpers are used.
def content_hm(content_image, hist_image=None, target='content', eps=1e-5, mode='pca'):
    if hist_image != None:
        if content_image.is_cuda:
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor
        MH = MatchHistogram(eps, mode)
        if 'content' in target:
            image_size = (content_image.size(2), content_image.size(3))
            hist_image = preprocess(hist_image, image_size).type(dtype)
            content_image = MH.match(content_image, hist_image)
    return content_image
# Match initialization image to histogram input
def init_hm(init_image, hist_image=None, target='content', eps=1e-5, mode='pca'):
    if hist_image != None:
        if init_image.is_cuda:
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor
        MH = MatchHistogram(eps, mode)
        if 'content' in target:
            # Size the histogram image to the init image (the init image, not
            # the content image, is what is being matched here)
            image_size = (init_image.size(2), init_image.size(3))
            hist_image = preprocess(hist_image, image_size).type(dtype)
            init_image = MH.match(init_image, hist_image)
    return init_image
# Match style image to histogram input
def style_hm(style_images, hist_image=None, target='content', eps=1e-5, mode='pca'):
    if hist_image != None:
        if style_images[0].is_cuda:
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor
        MH = MatchHistogram(eps, mode)
        if 'style' in target:
            for i, image in enumerate(style_images):
                image_size = (style_images[i].size(2), style_images[i].size(3))
                # Keep the original hist_image path intact so it can be
                # preprocessed again for each style image in the list
                hist_tensor = preprocess(hist_image, image_size).type(dtype)
                style_images[i] = MH.match(style_images[i], hist_tensor)
    return style_images
# Define a module to match histograms
class MatchHistogram(torch.nn.Module):
    def __init__(self, eps=1e-5, mode='pca'):
        super(MatchHistogram, self).__init__()
        self.eps = eps or 1e-5
        self.mode = mode or 'pca'
        self.dim_val = 3

    def get_histogram(self, tensor):
        m = tensor.mean(0).mean(0)
        h = (tensor - m).permute(2, 0, 1).reshape(tensor.size(2), -1)
        if h.is_cuda:
            ch = torch.mm(h, h.T) / h.shape[1] + self.eps * torch.eye(h.shape[0], device=h.get_device())
        else:
            ch = torch.mm(h, h.T) / h.shape[1] + self.eps * torch.eye(h.shape[0])
        return m, h, ch

    def convert_tensor(self, tensor):
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0).permute(2, 1, 0)
            self.dim_val = 4
        elif tensor.dim() == 3 and self.dim_val != 4:
            tensor = tensor.permute(2, 1, 0)
        elif tensor.dim() == 3 and self.dim_val == 4:
            tensor = tensor.permute(2, 1, 0).unsqueeze(0)
        return tensor

    def nan2zero(self, tensor):
        tensor[tensor != tensor] = 0
        return tensor

    def chol(self, t, c, s):
        chol_t, chol_s = torch.cholesky(c), torch.cholesky(s)
        return torch.mm(torch.mm(chol_s, torch.inverse(chol_t)), t)

    def sym(self, t, c, s):
        p = self.pca(t, c)
        psp = torch.mm(torch.mm(p, s), p)
        eval_psp, evec_psp = torch.symeig(psp, eigenvectors=True, upper=True)
        e = self.nan2zero(torch.sqrt(torch.diagflat(eval_psp)))
        evec_mm = torch.mm(torch.mm(evec_psp, e), evec_psp.T)
        return torch.mm(torch.mm(torch.mm(torch.inverse(p), evec_mm), torch.inverse(p)), t)

    def pca(self, t, c):
        eval_t, evec_t = torch.symeig(c, eigenvectors=True, upper=True)
        e = self.nan2zero(torch.sqrt(torch.diagflat(eval_t)))
        return torch.mm(torch.mm(evec_t, e), evec_t.T)

    def match(self, target_tensor, source_tensor):
        source_tensor = self.convert_tensor(source_tensor)
        target_tensor = self.convert_tensor(target_tensor)
        _, t, ct = self.get_histogram(target_tensor)
        ms, s, cs = self.get_histogram(source_tensor)
        if self.mode == 'pca':
            mt = torch.mm(torch.mm(self.pca(s, cs), torch.inverse(self.pca(t, ct))), t)
        elif self.mode == 'sym':
            mt = self.sym(t, ct, cs)
        elif self.mode == 'chol':
            mt = self.chol(t, ct, cs)
        matched_tensor = mt.reshape(*target_tensor.permute(2, 0, 1).shape).permute(1, 2, 0) + ms
        return self.convert_tensor(matched_tensor)

    def forward(self, input, source_tensor):
        return self.match(input, source_tensor)
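# Usage sketch (illustration only, not from the original gist): match the color
# statistics of one preprocessed 1x3xHxW tensor to another in 'pca' mode:
#
#   MH = MatchHistogram(eps=1e-5, mode='pca')
#   matched = MH.match(target_tensor, source_tensor)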
# Define an nn Module to compute mean loss
class MeanLoss(torch.nn.Module):
    def __init__(self, strength):
        super(MeanLoss, self).__init__()
        self.target = torch.Tensor()
        self.crit = torch.nn.MSELoss()
        self.mode = 'None'
        self.strength = strength

    def double_mean(self, tensor):
        tensor = tensor.squeeze(0).permute(2, 1, 0)
        return tensor.mean(0).mean(0)

    def forward(self, input):
        if self.mode == 'captureS':
            self.target = self.double_mean(input.detach())
            self.target_size = list(input.size())
        elif self.mode == 'loss':
            self.loss = 0.01 * self.strength * self.crit(self.double_mean(input.clone()), self.target)
        return input
######################################################################
# DeepDream Functions
######################################################################
# from nspt_tile import roll_tensor, resize_tensor, rescale_tensor
# Shift tensor, possibly randomly.
def roll_tensor(tensor, h_shift=None, w_shift=None):
    if h_shift == None:
        h_shift = torch.LongTensor(10).random_(-tensor.size(1), tensor.size(1))[0].item()
    if w_shift == None:
        w_shift = torch.LongTensor(10).random_(-tensor.size(2), tensor.size(2))[0].item()
    if tensor.dim() == 3:
        tensor = torch.roll(torch.roll(tensor, shifts=h_shift, dims=1), shifts=w_shift, dims=2)
    elif tensor.dim() == 4:
        tensor = torch.roll(torch.roll(tensor, shifts=h_shift, dims=2), shifts=w_shift, dims=3)
    return tensor, h_shift, w_shift
# Rescale tensor with scale factor
def rescale_tensor(tensor, scale_factor):
    tensor = torch.nn.functional.interpolate(tensor, scale_factor=(scale_factor, scale_factor))
    return tensor

# Resize tensor to an exact size
def resize_tensor(tensor, size):
    tensor = torch.nn.functional.interpolate(tensor, size=size)
    return tensor

# Add random noise to every pixel
def pixel_jitter(tensor):
    jit_tensor = torch.randn_like(tensor) * 5
    return tensor + jit_tensor
# Maybe shift tensor before splitting into tiles
def tiled_input(tensor, tile_size=256, shift_tensor=False, hs=None, ws=None):
    if shift_tensor:
        tensor, hs, ws = roll_tensor(tensor, hs, ws)
    tensor_tiles, rows, overlap = split_tensor_equal(tensor, tile_size)
    return tensor_tiles, rows, overlap, hs, ws
# Reverse shifting of tensor and rebuilding from tiles
def tiled_output(tensor_tiles, rows, overlap_hw, shift_tensor=False, hs=None, ws=None):
    tensor = rebuild_tensor(tensor_tiles, rows, overlap_hw)
    if shift_tensor:
        tensor, _, _ = roll_tensor(tensor, -hs, -ws)
    return tensor
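# Usage sketch (illustration only, not from the original gist): randomly shift
# the tensor before tiling so that tile seams fall in different places on each
# pass, then undo the shift after rebuilding:
#
#   tiles, rows, overlap_hw, hs, ws = tiled_input(x, tile_size=256, shift_tensor=True)
#   # ...process each tile here...
#   x_out = tiled_output(tiles, rows, overlap_hw, shift_tensor=True, hs=hs, ws=ws)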
# Define an nn Module to compute simple DeepDream loss
class SimpleDeepDreamLoss(torch.nn.Module):
    def __init__(self, strength):
        super(SimpleDeepDreamLoss, self).__init__()
        self.strength = strength
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = -input.mean() * self.strength
        elif self.mode == 'None':
            self.target_size = list(input.size())
        return input
class DreamLossType(torch.nn.Module):
    def __init__(self, loss_mode, channels, channel_mode):
        super(DreamLossType, self).__init__()
        self.get_mode(loss_mode)
        self.channels = channels
        self.channel_mode = channel_mode

    def get_mode(self, loss_mode):
        self.loss_mode_string = loss_mode
        if loss_mode.lower() == 'norm':
            self.loss_mode = self.norm_loss
        elif loss_mode.lower() == 'mean':
            self.loss_mode = self.mean_loss
        elif loss_mode.lower() == 'mse':
            self.crit = torch.nn.MSELoss()
            self.loss_mode = self.crit_loss
        elif loss_mode.lower() == 'bce':
            self.crit = torch.nn.BCEWithLogitsLoss()
            self.loss_mode = self.crit_loss

    def norm_loss(self, input):
        return self.ch(input).norm()

    def mean_loss(self, input):
        return self.ch(input).mean()  # was .norm(), which duplicated norm_loss

    def crit_loss(self, input):
        target = torch.zeros_like(self.ch(input.detach()))
        loss = self.crit(input, target)
        return loss

    def rank_channels(self, input):
        top_channels = int(self.channels[0])
        channels, channel_list = [], []
        for i in range(input.size(1)):
            channel_list.append(torch.mean(input.clone().squeeze(0).narrow(0, i, 1)).item())
        sorted_channel_list = sorted((c, v) for v, c in enumerate(channel_list))

        # Maybe sort by strongest channels
        if self.channel_mode == 'strong':
            sorted_channel_list.reverse()

        # Maybe only select some of the channels
        if top_channels > 0:
            channel_range = top_channels
        else:
            channel_range = input.size(1)
        for i in range(channel_range):
            channels.append(sorted_channel_list[i][1])
        return channels

    def select_c(self, input):
        if self.channel_mode != 'all':
            channel_list = self.rank_channels(input)
        else:
            channel_list = self.channels
        return channel_list

    def ch(self, input):
        if self.channel_mode != 'all':
            channel_list = self.select_c(input)
            for c in channel_list:
                if int(c) < input.size(1):
                    if self.loss_mode_string != 'mse' and self.loss_mode_string != 'bce':
                        input = input[0, int(c)]
        return input

    def forward(self, input):
        loss = self.loss_mode(input)
        return loss
class DreamLoss(torch.nn.Module):
    def __init__(self, loss_mode, strength, channels, channel_mode):
        super(DreamLoss, self).__init__()
        self.dream = DreamLossType(loss_mode, channels.split(','), channel_mode)
        self.strength = strength
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.dream(input) * self.strength
        elif self.mode == 'None':
            self.target_size = input.size()
        return input