-
-
Save ProGamerGov/1cef6405c822e10272535131ef70143e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code based on: http://pytorch.org/tutorials/advanced/neural_style_tutorial.html | |
# Trying to copy: https://github.com/jcjohnson/neural-style/blob/master/neural_style.lua | |
from __future__ import print_function

import copy
import os

from PIL import Image
from skimage import io, transform, img_as_float
from skimage.io import imread, imsave
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import save_image
##################################################################################################################################
import argparse
parser = argparse.ArgumentParser()
# Options holding numbers declare an explicit ``type``: without one, argparse
# hands values given on the command line back as strings.
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", help="style image blending weights", default="")
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
# NOTE(review): params.gpu is not read by any code visible here; CUDA is used
# whenever torch.cuda.is_available() is true.
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = 0", type=int, default=0)
parser.add_argument("-multigpu_strategy", help="multi-GPU layer splits", default="")
# Optimization options
# Loss weights are floats so values such as ``1e2`` are accepted.
parser.add_argument("-content_weight", help="content weight", type=float, default=5.0)
parser.add_argument("-style_weight", help="style weight", type=float, default=10.0)
parser.add_argument("-tv_weight", help="tv weight", type=float, default=0.001)
parser.add_argument("-num_iterations", help="iterations", type=int, default=1000)
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-init", help="initialisation type", default="random", choices=["random", "image"])
parser.add_argument("-init_image", help="initial image", default="")
parser.add_argument("-optimizer", help="optimiser", default="lbfgs", choices=["lbfgs", "adam"])
parser.add_argument("-learning_rate", type=float, default=1.0)
parser.add_argument("-lbfgs_num_correction", help="lbfgs num correction", type=int, default=0)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", help="style scale", type=float, default=1.0)
# Integer choices so that the default (0) is itself one of the allowed choices.
parser.add_argument("-original_colors", help="use original colours", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", help="pooling type", choices=["max", "avg"], default='max')
parser.add_argument("-proto_file", help="VGG 19 proto file location", default='models/VGG_ILSVRC_19_layers_deploy.prototxt')
parser.add_argument("-model_file", help="VGG 19 model file location", default='models/VGG_ILSVRC_19_layers.caffemodel')
parser.add_argument("-backend", help="backend", choices=["nn", "cudnn", "clnn"], default='cudnn')
parser.add_argument("-cudnn_autotune", help="cudnn autotune flag", action='store_true')
parser.add_argument("-seed", help="random number seed", type=int, default=-1)
parser.add_argument("-content_layers", help="VGG 19 content layers", default='relu4_2')
parser.add_argument("-style_layers", help="VGG 19 style layers", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
params = parser.parse_args()
##################################################################################################################################
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Desired size of the output image: the LONGER side is scaled to this many
# pixels. ``int`` guards against the value arriving as a command-line string.
imsize = int(params.image_size)


def _scale_longest_side(img, max_size=imsize):
    """Resize a PIL image so its longer side is ``max_size`` pixels, keeping aspect ratio.

    This is what the ``-image_size`` help text ("Maximum height / width")
    promises. ``transforms.Scale(imsize)`` matched the *shorter* side instead,
    so a non-square image came out larger than ``imsize`` along its long side.
    Bilinear resampling is used, as ``transforms.Scale`` did by default.
    """
    width, height = img.size
    ratio = float(max_size) / max(width, height)
    new_size = (max(1, int(round(width * ratio))), max(1, int(round(height * ratio))))
    return img.resize(new_size, Image.BILINEAR)


loader = transforms.Compose([
    transforms.Lambda(_scale_longest_side),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor
def image_loader(image_name):
    """Load an image file as a batched Variable ready for the network.

    The file is opened in a ``with`` block so its handle is released. The image
    is converted to RGB so that greyscale, palette and RGBA files still yield
    three channels. It is then passed through the module-level ``loader`` and
    given a leading batch dimension of size 1.
    """
    with Image.open(image_name) as raw:
        image = raw.convert('RGB')
    image = Variable(loader(image))
    # fake batch dimension required to fit network's input dimensions
    return image.unsqueeze(0)
# Load the style and content targets (in that order) and cast them to the
# active tensor type, which is the CUDA FloatTensor when a GPU is available.
style_img, content_img = [image_loader(path).type(dtype)
                          for path in (params.style_image, params.content_image)]
class ContentLoss(nn.Module):
    """Transparent layer that records a weighted content (feature-reconstruction) loss.

    ``forward`` passes its input through unchanged. As a side effect it stores
    ``strength * MSE(input, target)`` in ``self.loss`` so the optimisation loop
    can back-propagate it.

    Args:
        target: activations of the content image at this layer. They are
            detached, so no gradient flows back into the target.
        strength: scalar weight applied to the loss value.
    """

    def __init__(self, target, strength):
        super(ContentLoss, self).__init__()
        # The target is stored unscaled and the weight is applied once, to the
        # loss value. Scaling only the target made the loss pull the image
        # towards ``strength * target`` rather than ``target``.
        self.target = target.detach()
        self.strength = strength
        self.crit = nn.MSELoss()

    def forward(self, input):
        """Record the weighted content loss for ``input`` and return ``input`` unchanged."""
        self.loss = self.crit(input, self.target) * self.strength
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        """Back-propagate the last recorded loss and return it.

        ``retain_graph`` defaults to True so that several loss modules sharing
        one forward graph can each call ``backward`` in turn.
        """
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
class GramMatrix(nn.Module):
    """Normalised Gram matrix of a 4-D feature tensor.

    An input of size ``(a, b, c, d)`` (a = batch size, b = number of feature
    maps, (c, d) = spatial size of a map) is flattened to an ``(a*b, c*d)``
    matrix F. ``F @ F.T`` is returned, divided by the total element count
    ``a*b*c*d``.
    """

    def forward(self, input):
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)  # one row per feature map
        G = torch.mm(features, features.t())  # compute the gram product
        # 'normalize' by the number of elements in the input
        return G.div(a * b * c * d)


class StyleLoss(nn.Module):
    """Transparent layer that records a weighted style (Gram-matrix) loss.

    ``forward`` stores the input's Gram matrix in ``self.G`` and
    ``strength * MSE(G, target)`` in ``self.loss``, then returns a copy of the
    input.

    Args:
        target: Gram matrix of the style image's activations at this layer
            (detached).
        strength: scalar weight applied to the loss value.
    """

    def __init__(self, target, strength):
        super(StyleLoss, self).__init__()
        # Apply the weight exactly once, to the loss value. Scaling the target,
        # the Gram matrix and the loss by ``strength`` weighted it strength**3.
        self.target = target.detach()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()

    def forward(self, input):
        # Return a copy rather than ``input`` itself, so an in-place layer placed
        # after this module cannot modify the tensor the Gram matrix came from.
        self.output = input.clone()
        self.G = self.gram(input)
        self.loss = self.crit(self.G, self.target) * self.strength
        return self.output

    def backward(self, retain_graph=True):
        """Back-propagate the last recorded loss and return it."""
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
cnn = models.vgg19(pretrained=True).features
# The network is only a fixed feature extractor; freeze its weights so the
# optimisation loop does not compute (and accumulate) gradients for them.
for vgg_param in cnn.parameters():
    vgg_param.requires_grad = False
# move it to the GPU if possible:
if use_cuda:
    cnn = cnn.cuda()
# Default layers at which losses are attached. The names follow the scheme of
# get_style_model_and_losses ("conv_<n>" for the n-th conv layer). They are not
# the "relu4_2"-style names of the -content_layers / -style_layers options.
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(cnn, style_img, content_img,
                               style_weight=params.style_weight, content_weight=params.content_weight,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default,
                               pooling=None):
    """Build a feature network with content/style loss modules spliced in.

    Layers of a deep copy of ``cnn`` are appended to a new ``nn.Sequential``
    under the names ``conv_<i>``, ``relu_<i>`` and ``pool_<i>``. Here ``i``
    counts conv/ReLU pairs and is incremented after each ReLU, so a pooling
    layer carries the number of the *next* pair. Other layer types are skipped.

    Losses are inserted as follows:
    - Right after a layer named in ``content_layers``, a ``ContentLoss`` is
      inserted, targeting the content image's activations there.
    - Right after a layer named in ``style_layers``, a ``StyleLoss`` is
      inserted, targeting the Gram matrix of the style image's activations.

    Args:
        cnn: sequence of layers to copy (the argument itself is not modified).
        style_img, content_img: image batches used to compute loss targets.
        style_weight, content_weight: strengths passed to the loss modules.
        content_layers, style_layers: layer names at which to attach losses.
        pooling: ``'max'`` keeps max-pooling layers; ``'avg'`` replaces them
            with average pooling of the same geometry. ``None`` reads the
            ``-pooling`` command-line option.

    Returns:
        ``(model, style_losses, content_losses)``. Once every requested layer
        has received its losses, construction stops. Later layers cannot
        affect any loss, so they are omitted from ``model``.
    """
    if pooling is None:
        pooling = params.pooling
    cnn = copy.deepcopy(cnn)

    content_losses = []
    style_losses = []
    model = nn.Sequential()  # the new Sequential module network
    gram = GramMatrix()  # we need a gram module in order to compute style targets
    # move these modules to the GPU if possible:
    if use_cuda:
        model = model.cuda()
        gram = gram.cuda()

    # Requested layer names that have not been reached yet.
    pending = set(content_layers) | set(style_layers)

    def add_losses(name, i):
        """Append the loss modules requested for layer ``name`` (if any) to ``model``."""
        if name in content_layers:
            target = model(content_img).clone()
            content_loss = ContentLoss(target, content_weight)
            model.add_module("content_loss_" + str(i), content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            target_feature = model(style_img).clone()
            target_feature_gram = gram(target_feature)
            style_loss = StyleLoss(target_feature_gram, style_weight)
            model.add_module("style_loss_" + str(i), style_loss)
            style_losses.append(style_loss)
        pending.discard(name)

    i = 1
    for layer in list(cnn):
        if not pending:
            break
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            model.add_module(name, layer)
            add_losses(name, i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            # Out-of-place ReLU: an in-place one would overwrite activations that a
            # preceding loss module may have saved for the backward pass.
            model.add_module(name, nn.ReLU(inplace=False))
            add_losses(name, i)
            i += 1
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            if pooling == 'avg':
                layer = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride,
                                     padding=layer.padding, ceil_mode=layer.ceil_mode)
            model.add_module(name, layer)
    return model, style_losses, content_losses
def _initial_image(content):
    """Return the starting image for the optimisation, as selected by the CLI.

    - ``-init_image PATH``: that image, converted to RGB and resized to the
      content image's height and width.
    - ``-init random``: Gaussian noise with the content image's size.
    - ``-init image``: a copy of the content image.
    """
    if params.init_image:
        height, width = content.size()[2], content.size()[3]
        with Image.open(params.init_image) as raw:
            init_pil = raw.convert('RGB').resize((width, height), Image.BILINEAR)
        return Variable(transforms.ToTensor()(init_pil)).unsqueeze(0).type(dtype)
    if params.init == "random":
        return Variable(torch.randn(content.data.size())).type(dtype)
    return content.clone()


input_img = _initial_image(content_img)
def get_input_param_optimizer(input_img, optimizer_name=None, learning_rate=None,
                              num_correction=None):
    """Wrap ``input_img`` in a trainable ``nn.Parameter`` and build its optimizer.

    Args:
        input_img: image tensor/Variable to optimise; the parameter is built
            from its ``.data``.
        optimizer_name: ``'lbfgs'`` or ``'adam'``; ``None`` reads ``-optimizer``.
        learning_rate: Adam learning rate; ``None`` reads ``-learning_rate``.
            L-BFGS keeps PyTorch's default step size.
        num_correction: L-BFGS history size. Values <= 0 keep PyTorch's default.
            ``None`` reads ``-lbfgs_num_correction``.

    Returns:
        ``(input_param, optimizer)``.

    Raises:
        ValueError: if ``optimizer_name`` is not ``'lbfgs'`` or ``'adam'``.
    """
    if optimizer_name is None:
        optimizer_name = params.optimizer
    if learning_rate is None:
        learning_rate = params.learning_rate
    if num_correction is None:
        num_correction = params.lbfgs_num_correction

    # input is a parameter that requires a gradient
    input_param = nn.Parameter(input_img.data)
    if optimizer_name == 'lbfgs':
        num_correction = int(num_correction)
        if num_correction > 0:
            optimizer = optim.LBFGS([input_param], history_size=num_correction)
        else:
            optimizer = optim.LBFGS([input_param])
    elif optimizer_name == 'adam':
        optimizer = optim.Adam([input_param], lr=float(learning_rate))
    else:
        raise ValueError("Unknown optimizer: {}".format(optimizer_name))
    return input_param, optimizer
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=params.num_iterations, | |
style_weight=params.style_weight, content_weight=params.content_weight): | |
"""Run the style transfer.""" | |
print('Building the style transfer model..') | |
model, style_losses, content_losses = get_style_model_and_losses(cnn, | |
style_img, content_img, style_weight, content_weight) | |
input_param, optimizer = get_input_param_optimizer(input_img) | |
print('Optimizing..') | |
run = [0] | |
while run[0] <= num_steps: | |
def closure(): | |
# correct the values of updated input image | |
input_param.data.clamp_(0, 1) | |
optimizer.zero_grad() | |
model(input_param) | |
style_score = 0 | |
content_score = 0 | |
for sl in style_losses: | |
style_score += sl.backward() | |
for cl in content_losses: | |
content_score += cl.backward() | |
run[0] += 1 | |
if run[0] % params.print_iter == 0: | |
print("run {}:".format(run)) | |
print('Style Loss : {:4f} Content Loss: {:4f}'.format( | |
style_score.data[0], content_score.data[0])) | |
print() | |
if run[0] % params.save_iter == 0: | |
# iteration = run[0] | |
iter_img = input_param.data.clamp_(0, 1) | |
iter_save = params.output_image + str(run[0]) + ".png" | |
torchvision.utils.save_image(iter_img, iter_save, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) | |
return style_score + content_score | |
optimizer.step(closure) | |
# a last correction... | |
input_param.data.clamp_(0, 1) | |
return input_param.data | |
######################################################################
output_img = run_style_transfer(cnn, content_img, style_img, input_img)
output_image = params.output_image
print('Saving..')
# Only the tensor and path are passed. The other keyword arguments the call
# used to spell out were save_image's defaults, and ``range`` is deprecated
# (renamed ``value_range``) in later torchvision releases.
torchvision.utils.save_image(output_img, output_image)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment