import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-mean_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)

parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')

parser.add_argument("-multidevice_strategy", default='4,7,29')

parser.add_argument("-mean_layers", help="layers for mean", default='')
parser.add_argument("-eps", type=float, default=1e-5)
parser.add_argument("-transfer_mode", choices=['pca', 'sym', 'chol'], default='pca')
parser.add_argument("-hist_image", default=None)
parser.add_argument("-hist_target", choices=['content', 'style', 'content,style'], default='content')
params = parser.parse_args()
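
# Example invocation (the script filename is assumed; paths are the repo defaults):
#   python neural_style.py -content_image examples/inputs/tubingen.jpg \
#       -style_image examples/inputs/seated-nude.jpg -image_size 512 -gpu 0
# Pass -gpu c for CPU-only mode, or a comma-separated list such as -gpu 0,1
# together with -multidevice_strategy for multi-device runs.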


Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images


def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    if params.hist_image is not None:
        MH = MatchHistogram(params.eps, params.transfer_mode)
        if 'content' in params.hist_target:
            image_size = (content_image.size(2), content_image.size(3))
            hist_image = preprocess(params.hist_image, image_size).type(dtype)
            content_image = MH.match(content_image, hist_image)
            if params.init_image is not None:
                init_image = MH.match(init_image, hist_image)
        if 'style' in params.hist_target:
            for i, image in enumerate(style_images_caffe):
                image_size = (style_images_caffe[i].size(2), style_images_caffe[i].size(3))
                hist_image = preprocess(params.hist_image, image_size).type(dtype)
                style_images_caffe[i] = MH.match(style_images_caffe[i], hist_image)

    # Handle style blending weights for multiple style inputs
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        style_blend_weights = [1.0 for _ in style_image_list]
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_weights = [float(w) for w in style_blend_weights]
    style_blend_sum = sum(style_blend_weights)
    style_blend_weights = [w / style_blend_sum for w in style_blend_weights]
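    # Worked example: "-style_blend_weights 3,7" with two style images
    # normalizes to [0.3, 0.7]; left unset with three style images, each
    # image gets a weight of 1/3.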

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')
    mean_layers = params.mean_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses, mean_losses = [], [], [], []
    next_content_idx, next_style_idx, next_mean_idx = 1, 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers) or next_mean_idx <= len(mean_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1

                if layerList['R'][r] in mean_layers:
                    print("Setting up mean layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = MeanLoss(params.mean_weight)
                    net.add_module(str(len(net)), loss_module)
                    mean_losses.append(loss_module)
                    next_mean_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)

    # Capture mean targets (only the first style image is used)
    print("Capturing mean targets")
    for i in mean_losses:
        i.mode = 'captureS'
    net(style_images_caffe[0])
    for i in mean_losses:
        i.mode = 'None'

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'

    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i + 1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'
    for i in mean_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses, mean_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img)

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(mean_losses):
                print("  Mean " + str(i + 1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + file_extension
            else:
                filename = output_filename + "_" + str(t) + file_extension
            disp = deprocess(img.clone())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)
        for mod in mean_losses:
            loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        # L-BFGS runs all iterations inside a single optimizer.step(feval) call
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        # Adam performs one iteration per optimizer.step(feval) call
        loopVal = params.num_iterations - 1
    return optimizer, loopVal


def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda()
            setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device
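
# Illustrative -gpu values as parsed above: "0" (single GPU), "0,1" (model
# split across two GPUs), "c" (CPU only), "c,0" (CPU plus GPU). Multi-device
# values also require a matching -multidevice_strategy.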


def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
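
# Round-trip sketch (the path is the repo's example content image; the output
# filename is illustrative):
#   t = preprocess('examples/inputs/tubingen.jpg', 512)  # 1 x 3 x H x W, BGR, mean-subtracted
#   deprocess(t).save('check.png')                       # recovers the resized input, up to rounding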


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return " (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")


# Divide weights by channel size
def normalize_weights(content_losses, style_losses, mean_losses=None):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())
    if mean_losses is not None:
        for n, i in enumerate(mean_losses):
            i.strength = i.strength / max(i.target_size)


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
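
# For a 1 x C x H x W activation flattened to F of shape C x (H*W), the Gram
# matrix G = F @ F.T is C x C; entry G[i, j] is the inner product of channels
# i and j, a summary of feature co-occurrence that discards spatial layout.
# StyleLoss below divides G by the number of input elements to keep its scale
# comparable across layers.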


# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
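
# Anisotropic total variation: loss = strength * sum(|x[i+1, j] - x[i, j]| +
# |x[i, j+1] - x[i, j]|) over all pixels and channels, which penalizes
# high-frequency noise in the synthesized image.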


# Define a module to match histograms
class MatchHistogram(nn.Module):
    def __init__(self, eps=1e-5, mode='pca'):
        super(MatchHistogram, self).__init__()
        self.eps = eps or 1e-5
        self.mode = mode or 'pca'
        self.dim_val = 3

    def get_histogram(self, tensor):
        m = tensor.mean(0).mean(0)
        h = (tensor - m).permute(2, 0, 1).reshape(tensor.size(2), -1)
        ch = torch.mm(h, h.T) / h.shape[1] + self.eps * torch.eye(h.shape[0], device=h.device)
        return m, h, ch

    def convert_tensor(self, tensor):
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0).permute(2, 1, 0)
            self.dim_val = 4
        elif tensor.dim() == 3 and self.dim_val != 4:
            tensor = tensor.permute(2, 1, 0)
        elif tensor.dim() == 3 and self.dim_val == 4:
            tensor = tensor.permute(2, 1, 0).unsqueeze(0)
        return tensor

    def nan2zero(self, tensor):
        tensor[tensor != tensor] = 0  # NaN != NaN, so this zeroes out NaNs
        return tensor

    def chol(self, t, c, s):
        chol_t, chol_s = torch.cholesky(c), torch.cholesky(s)
        return torch.mm(torch.mm(chol_s, torch.inverse(chol_t)), t)

    def sym(self, t, c, s):
        p = self.pca(t, c)
        psp = torch.mm(torch.mm(p, s), p)
        eval_psp, evec_psp = torch.symeig(psp, eigenvectors=True, upper=True)
        e = self.nan2zero(torch.sqrt(torch.diagflat(eval_psp)))
        evec_mm = torch.mm(torch.mm(evec_psp, e), evec_psp.T)
        return torch.mm(torch.mm(torch.mm(torch.inverse(p), evec_mm), torch.inverse(p)), t)

    def pca(self, t, c):
        eval_t, evec_t = torch.symeig(c, eigenvectors=True, upper=True)
        e = self.nan2zero(torch.sqrt(torch.diagflat(eval_t)))
        return torch.mm(torch.mm(evec_t, e), evec_t.T)

    def match(self, target_tensor, source_tensor):
        source_tensor = self.convert_tensor(source_tensor)
        target_tensor = self.convert_tensor(target_tensor)

        _, t, ct = self.get_histogram(target_tensor)
        ms, s, cs = self.get_histogram(source_tensor)

        if self.mode == 'pca':
            mt = torch.mm(torch.mm(self.pca(s, cs), torch.inverse(self.pca(t, ct))), t)
        elif self.mode == 'sym':
            mt = self.sym(t, ct, cs)
        elif self.mode == 'chol':
            mt = self.chol(t, ct, cs)

        matched_tensor = mt.reshape(*target_tensor.permute(2, 0, 1).shape).permute(1, 2, 0) + ms
        return self.convert_tensor(matched_tensor)

    def forward(self, input, source_tensor):
        return self.match(input, source_tensor)
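
# The matching above is a linear color transform: with per-channel means
# mu_t, mu_s and covariances C_t, C_s of the target and source, it computes
#   matched = A (x - mu_t) + mu_s,  where  A C_t A^T = C_s,
# building A from matrix square roots of the covariances ('pca', 'sym') or
# their Cholesky factors ('chol').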


# Define an nn Module to compute mean loss
class MeanLoss(nn.Module):

    def __init__(self, strength):
        super(MeanLoss, self).__init__()
        self.target = torch.Tensor()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.strength = strength

    def double_mean(self, tensor):
        # Per-channel mean over the spatial dimensions
        tensor = tensor.squeeze(0).permute(2, 1, 0)
        return tensor.mean(0).mean(0)

    def forward(self, input):
        if self.mode == 'captureS':
            self.target = self.double_mean(input.detach())
            self.target_size = list(input.size())
        elif self.mode == 'loss':
            self.loss = 0.01 * self.strength * self.crit(self.double_mean(input.clone()), self.target)
        return input


if __name__ == "__main__":
    main()