from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import copy
import argparse

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
# Optimization options
parser.add_argument("-content_weight", help="content weight", type=int, default=5)
parser.add_argument("-style_weight", help="style weight", type=int, default=10)
parser.add_argument("-normalize_gradients", action='store_true')
# Output options
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", help="style scale", type=float, default=1.0)
parser.add_argument("-proto_file", default='models/VGG_ILSVRC_19_layers_deploy.prototxt')
parser.add_argument("-model_file", default='models/VGG_ILSVRC_19_layers.caffemodel')
parser.add_argument("-backend", choices=["nn", "cudnn", "clnn"], default='cudnn')
params = parser.parse_args()

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
#cnn = loadcaffe.load(params.proto_file, params.model_file, params.backend) #.type(dtype)
cnn = models.vgg19(pretrained=True).features

loader = transforms.Compose([
    transforms.Scale(params.image_size),  # scale imported image (Scale was later renamed Resize in torchvision)
    transforms.ToTensor()])  # transform it into a torch tensor

def image_loader(image_name):
    image = Image.open(image_name)
    image = Variable(loader(image))
    # fake batch dimension required to fit network's input dimensions
    image = image.unsqueeze(0)
    return image

content_image_caffe = image_loader(params.content_image).type(dtype)
style_image_caffe = image_loader(params.style_image).type(dtype)

# move the network to the GPU if possible:
if use_cuda:
    cnn = cnn.cuda()
#print(cnn)

content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
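
# Build a new nn.Sequential that mirrors VGG's feature layers, inserting a
# ContentLoss / StyleLoss module right after each layer named in
# content_layers / style_layers. The losses are computed against the
# content / style images' activations captured at those same depths.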
def create_model(cnn, style_image_caffe, content_image_caffe,
                 style_weight=params.style_weight, content_weight=params.content_weight,
                 content_layers=content_layers_default, style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)
    content_losses = []
    style_losses = []
    model = nn.Sequential()  # the new Sequential module network
    gram = GramMatrix()  # we need a gram module in order to compute style targets

    # move these modules to the GPU if possible:
    if use_cuda:
        model = model.cuda()
        gram = gram.cuda()

    i = 1
    for layer in list(cnn):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            model.add_module(name, layer)

            if name in content_layers:
                # add content loss:
                target = model(content_image_caffe).clone()
                content_loss = ContentLoss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)

            if name in style_layers:
                # add style loss:
                target_feature = model(style_image_caffe).clone()
                target_feature_gram = gram(target_feature)
                style_loss = StyleLoss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)

        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            model.add_module(name, layer)

            if name in content_layers:
                # add content loss:
                target = model(content_image_caffe).clone()
                content_loss = ContentLoss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)

            if name in style_layers:
                # add style loss:
                target_feature = model(style_image_caffe).clone()
                target_feature_gram = gram(target_feature)
                style_loss = StyleLoss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)

            i += 1

        if isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            model.add_module(name, layer)  # (average pooling can be substituted here, as suggested in Gatys et al.)

    return model, style_losses, content_losses
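

# Each loss module below is a pass-through layer in the old (pre-0.4 PyTorch)
# style: forward() returns its input unchanged while recording the loss, and
# backward() backpropagates that stored loss so the optimizer closure can sum
# the individual loss values.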
class ContentLoss(nn.Module):

    def __init__(self, target, strength):
        super(ContentLoss, self).__init__()
        # keep the target unscaled; the weight is applied to the loss instead
        self.target = target.detach()
        self.strength = strength
        self.crit = nn.MSELoss()

    def forward(self, input):
        self.loss = self.crit(input, self.target) * self.strength
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
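

# The Gram matrix G = F F^T of the (b x N) feature matrix F measures which
# feature maps co-activate, discarding spatial arrangement; matching Gram
# matrices between images is what makes this a comparison of style rather
# than content.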
class GramMatrix(nn.Module):

    def forward(self, input):
        a, b, c, d = input.size()  # a=batch size (=1)
        # b=number of feature maps
        # (c,d)=dimensions of a feature map (N=c*d)
        features = input.view(a * b, c * d)  # reshape F_XL into \hat F_XL
        G = torch.mm(features, features.t())  # compute the gram product
        # we 'normalize' the values of the gram matrix
        # by dividing by the number of elements in each feature map
        return G.div(a * b * c * d)


class StyleLoss(nn.Module):

    def __init__(self, target, strength):
        super(StyleLoss, self).__init__()
        # target is the precomputed Gram matrix of the style features
        self.target = target.detach()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()

    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        self.loss = self.crit(self.G, self.target) * self.strength
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss


model, style_losses, content_losses = create_model(
    cnn, style_image_caffe, content_image_caffe,
    params.style_weight, params.content_weight,
    content_layers_default, style_layers_default)
print("Model Loaded")

input_img = content_image_caffe.clone()
input_param = nn.Parameter(input_img.data)

learning_rate = 2
#optimizer = torch.optim.Adam([input_param], lr=learning_rate)
optimizer = optim.LBFGS([input_param])
num_steps = 25
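

# L-BFGS may evaluate the objective several times per optimizer.step(), so
# step() takes a closure that re-runs the forward pass, re-computes the
# gradients, and returns the total loss.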
def feval(input_param):
    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def iterate():
            # keep the updated image's values in the valid [0, 1] range
            input_param.data.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_param)
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.backward()
            for cl in content_losses:
                content_score += cl.backward()

            run[0] += 1
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_score.data[0], content_score.data[0]))
            return style_score + content_score

        optimizer.step(iterate)

    # a final clamp before returning the image
    input_param.data.clamp_(0, 1)
    return input_param.data
print("Test CNN") | |
#print(model) | |
torchvision.utils.save_image(output_img, params.output_image, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) |