Skip to content

Instantly share code, notes, and snippets.

Last active March 7, 2018 01:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ProGamerGov/f735c1360207b420c4f920d69853e157 to your computer and use it in GitHub Desktop.
Save ProGamerGov/f735c1360207b420c4f920d69853e157 to your computer and use it in GitHub Desktop.
# Code - Trying to translate to PyTorch.
from __future__ import print_function
import torch
import torch.legacy.nn as nn
from torch.autograd import Variable
import torch.legacy.optim as optim
from PIL import Image
#from skimage import io,transform,img_as_float
#from skimage.io import imread,imsave
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import copy
import argparse
# Command-line interface. Every numeric option declares a `type` so that a
# value supplied on the command line arrives as a number, not a string; the
# script later compares some of them numerically (e.g. `params.seed >= 0`).
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
# Optimization options
# Weights are parsed as floats so fractional values such as 0.5 or 1e2 are
# accepted; whole numbers still parse as before.
parser.add_argument("-content_weight", help="content weight", type=float, default=5)
parser.add_argument("-style_weight", help="style weight", type=float, default=10)
parser.add_argument("-num_iterations", help="iterations", type=int, default=1000)
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-init", help="initialisation type", default="random", choices=["random", "image"])
parser.add_argument("-init_image", help="initial image", default="")
parser.add_argument("-optimizer", help="optimiser", default="lbfgs", choices=["lbfgs", "adam"])
parser.add_argument("-learning_rate", help="learning rate", type=float, default=1)
parser.add_argument("-lbfgs_num_correction", help="lbfgs num correction", type=int, default=0)
# Output options
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", help="style scale", type=float, default=1.0)
#parser.add_argument("-proto_file", default='models/VGG_ILSVRC_19_layers_deploy.prototxt')
#parser.add_argument("-model_file", default='models/VGG_ILSVRC_19_layers.caffemodel')
parser.add_argument("-backend", choices=["nn", "cudnn", "clnn"], default='cudnn')
parser.add_argument("-seed", help="random number seed", type=int, default=-1)
# Read the run configuration from the command line.
params = parser.parse_args()

# Select the float tensor type that matches the available hardware.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

#cnn = loadcaffe.load(params.proto_file, params.model_file, params.backend) #.type(dtype)
# Use the `features` sub-module of torchvision's pretrained VGG-19.
cnn = models.vgg19(pretrained=True).features

# Preprocessing applied to every input image.
loader = transforms.Compose([
    transforms.Scale(params.image_size),  # scale imported image
    transforms.ToTensor(),  # transform it into a torch tensor
])
def image_loader(image_name):
    """Load an image file and return it as a batched ``Variable``.

    Args:
        image_name: Path of the image file to open.

    Returns:
        The image after the module-level ``loader`` transform, wrapped in a
        ``Variable`` with a leading batch dimension of size 1.
    """
    # The right-hand side of this assignment was missing (a syntax error);
    # opening the file with PIL is what the `loader` transform expects.
    # Converting to RGB keeps grayscale / RGBA files at three channels.
    # NOTE(review): assumes the network wants 3-channel input - confirm.
    image = Image.open(image_name).convert("RGB")
    image = Variable(loader(image))
    # fake batch dimension required to fit network's input dimensions
    image = image.unsqueeze(0)
    return image
# Load both target images, then cast each to the selected tensor type.
content_image_caffe = image_loader(params.content_image)
content_image_caffe = content_image_caffe.type(dtype)
style_image_caffe = image_loader(params.style_image)
style_image_caffe = style_image_caffe.type(dtype)

# move it to the GPU if possible:
if use_cuda:
    cnn = cnn.cuda()

# Layer names at which losses are attached unless a caller overrides them.
content_layers_default = ['relu_4']
style_layers_default = [
    'relu_1',
    'relu_2',
    'relu_3',
    'relu_4',
    'relu_5',
]
def create_model(cnn, style_image_caffe, content_image_caffe,
                 style_weight=params.style_weight,
                 content_weight=params.content_weight,
                 content_layers=content_layers_default,
                 style_layers=style_layers_default):
    """Build the loss network from ``cnn`` and collect its loss modules.

    Args:
        cnn: Iterable of layers to copy into the new network.
        style_image_caffe: Style image used to compute the style targets.
        content_image_caffe: Content image used to compute content targets.
        style_weight: Strength handed to every ``StyleLoss``.
        content_weight: Strength handed to every ``ContentLoss``.
        content_layers: Names of the layers followed by a content loss.
        style_layers: Names of the layers followed by a style loss.

    Returns:
        ``(model, style_losses, content_losses)``: the assembled network and
        the loss modules that were inserted into it, in insertion order.
    """
    cnn = copy.deepcopy(cnn)
    content_losses = []
    style_losses = []
    model = nn.Sequential()  # the new Sequential module network
    # move these modules to the GPU if possible:
    if use_cuda:
        model = model.cuda()

    # NOTE(review): `nn` is torch.legacy.nn while `cnn` comes from
    # torchvision, so this isinstance test may never match, and non-ReLU
    # layers are never added to `model`. Left as in the original - confirm
    # the intended layer handling.
    i = 1
    for layer in list(cnn):
        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            model.add_module(name, layer)

            if name in content_layers:
                # add content loss:
                target = model(content_image_caffe).clone()
                content_loss = ContentLoss(target, content_weight,
                                           params.normalize_gradients)
                model.add_module("content_loss_" + str(i), content_loss)
                # Record the module so the returned list is not empty.
                content_losses.append(content_loss)

            if name in style_layers:
                # add style loss:
                target_feature = model(style_image_caffe).clone()
                # `gram` was never defined (its construction is commented
                # out), so the Gram matrix F * F^T is computed directly.
                # The result stays on the input's device, so no .cuda().
                # assumes a batch of one image - TODO confirm
                flat = target_feature.view(target_feature.size(-3), -1)
                target_feature_gram = torch.mm(flat, flat.t())
                style_loss = StyleLoss(target_feature_gram, style_weight,
                                       params.normalize_gradients)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)

            i += 1
    return model, style_losses, content_losses
# Define an nn Module to compute content loss in-place
class ContentLoss(nn.Module):
    """Pass-through module that measures the MSE between input and a target.

    The module returns its input unchanged. What it does besides that is
    chosen by ``self.mode``:

    * ``'capture'``: remember the current input as the new target.
    * ``'loss'``: store ``strength * MSE(input, target)`` in ``self.loss``
      and add the matching gradient during the backward pass.
    * anything else: do nothing beyond passing data through.

    NOTE(review): several expressions in this class were truncated in the
    source; every use of ``self.target`` below is a reconstruction.
    """

    def __init__(self, target, strength, normalize=False):
        """Store the target and loss configuration.

        Args:
            target: Initial target features; detached from its graph.
            strength: Weight applied to the loss and to its gradient.
            normalize: If true, L1-normalise the gradient before weighting.
        """
        super(ContentLoss, self).__init__()
        self.strength = strength
        # The target is kept unscaled: the strength is applied to the loss
        # and gradient, so scaling the target too would compare the input
        # against the wrong values.
        self.target = target.detach()
        # The original assigned the undefined name `false`.
        self.normalize = bool(normalize)
        self.loss = 0
        self.crit = nn.MSECriterion()
        self.mode = None

    def updateOutput(self, input):
        """Update the loss or the target for ``input`` and return ``input``."""
        if self.mode == 'loss':
            self.loss = self.crit.updateOutput(input, self.target) * self.strength  # Forward
        elif self.mode == 'capture':
            self.target = input.clone()
        self.output = input
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return ``gradOutput`` plus, in loss mode, the content gradient."""
        if self.mode != 'loss' or input.nelement() != self.target.nelement():
            # No usable target for this input: just pass the gradient on.
            self.gradInput = gradOutput
            return self.gradInput

        grad = self.crit.updateGradInput(input, self.target)  # Backward
        if self.normalize:
            # `Tensor.div` is not in-place, so the result must be kept.
            grad = grad / (torch.norm(grad, 1) + 1e-8)  # Normalize Gradients
        # Scale like the loss and keep the gradient from later modules.
        self.gradInput = grad * self.strength + gradOutput
        return self.gradInput
class GramMatrix(nn.Module):
    """Compute the Gram matrix ``F * F^T`` of a ``(C, H, W)`` feature map.

    ``F`` is the input flattened to ``(C, H * W)``; the output is ``(C, C)``.
    """

    def __init__(self, input=None):
        """Create the module.

        Args:
            input: Unused; kept (now optional) so existing positional calls
                keep working.
        """
        super(GramMatrix, self).__init__()

    def updateOutput(self, input):
        """Return the ``(C, C)`` Gram matrix of a 3-D ``input``."""
        assert input.dim() == 3
        # Python tensors are indexed from 0; the original read sizes 1..3.
        C, H, W = input.size(0), input.size(1), input.size(2)
        x_flat = input.view(C, H * W)
        self.output = torch.mm(x_flat, x_flat.t())
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Back-propagate a ``(C, C)`` ``gradOutput`` to the input's shape."""
        assert input.dim() == 3 and input.size(0) > 0
        C, H, W = input.size(0), input.size(1), input.size(2)
        x_flat = input.view(C, H * W)
        # d(F F^T)/dF applied to gradOutput: (dG + dG^T) * F.
        grad_flat = torch.mm(gradOutput, x_flat) + torch.mm(gradOutput.t(), x_flat)
        self.gradInput = grad_flat.view(C, H, W)
        return self.gradInput
# Define an nn Module to compute style loss in-place
class StyleLoss(nn.Module):
    """Pass-through module that measures the MSE between Gram matrices.

    The module returns its input unchanged. ``self.mode`` selects the extra
    work:

    * ``'capture'``: store (or blend into) the target Gram matrix.
    * ``'loss'``: store ``strength * MSE(G, target)`` in ``self.loss`` and
      add the matching gradient during the backward pass.
    * anything else: only pass data through.

    NOTE(review): several expressions in this class were truncated in the
    source; every use of ``self.target`` below is a reconstruction.
    """

    def __init__(self, target, strength, normalize=False):
        """Store the target Gram matrix and loss configuration.

        Args:
            target: Initial target Gram matrix; detached from its graph.
            strength: Weight applied to the loss and to its gradient.
            normalize: If true, L1-normalise the gradient before weighting.
        """
        super(StyleLoss, self).__init__()
        # The original assigned the undefined names `false` and `nil`.
        self.normalize = bool(normalize)
        self.strength = strength
        # Kept unscaled; the strength is applied to the loss and gradient.
        self.target = target.detach()
        self.mode = None
        self.loss = 0
        # GramMatrix declares an unused `input` constructor parameter.
        self.gram = GramMatrix(None)
        self.blend_weight = None
        self.G = None
        self.crit = nn.MSECriterion()

    def updateOutput(self, input):
        """Update the loss or the target for ``input`` and return ``input``."""
        self.G = self.gram.updateOutput(input)  # Forward Gram
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.clone()
            elif self.target.nelement() == 0:
                self.target = self.G * self.blend_weight
            else:
                # Accumulate a weighted Gram onto the existing target.
                self.target = self.target + self.G * self.blend_weight
        elif self.mode == 'loss':
            # Compare Gram with Gram; the original passed the raw input.
            self.loss = self.strength * self.crit.updateOutput(self.G, self.target)  # Forward
        self.output = input
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return ``gradOutput`` plus, in loss mode, the style gradient."""
        if self.mode != 'loss':
            self.gradInput = gradOutput
            return self.gradInput

        dG = self.crit.updateGradInput(self.G, self.target)  # Backward
        grad = self.gram.updateGradInput(input, dG)  # Gram Backward
        if self.normalize:
            # `Tensor.div` is not in-place, so the result must be kept.
            grad = grad / (torch.norm(grad, 1) + 1e-8)  # Normalize Gradients
        # Scale like the loss and keep the gradient from later modules.
        self.gradInput = grad * self.strength + gradOutput
        return self.gradInput
model, style_losses, content_losses = create_model(
    cnn, style_image_caffe, content_image_caffe,
    params.style_weight, params.content_weight,
    content_layers_default, style_layers_default)
img = content_image_caffe.clone()

# Run it through the network once to get the proper size for the gradient
# All the gradients will come from the extra loss modules, so we just pass
# zeros into the top of the net on the backward pass.
y = model.updateOutput(img)
# The zero gradient must match the network *output*, so it is built from
# `y` rather than from the input image.
dy = y.clone().zero_()

# Declared at module level so later parts of the script can read it.
optim_state = None
if params.optimizer == 'lbfgs':
    optim_state = {
        "maxIter": params.num_iterations,
        "verbose": True,
    }
    # The option may still be a string if it came from the command line.
    if int(params.lbfgs_num_correction) > 0:
        optim_state["nCorrection"] = int(params.lbfgs_num_correction)
elif params.optimizer == 'adam':
    optim_state = {
        "learningRate": float(params.learning_rate),
    }
# Function to evaluate loss and gradient. We run the net forward and
# backward to get the gradient, and sum up losses from the loss modules.
# optim.lbfgs internally handles iteration and calls this function many
# times, so we manually count the number of iterations to handle printing
# and saving intermediate results.
# A one-element list lets feval update the counter without `global`.
num_calls = [0]


def feval(x):
    """Evaluate the total loss and its gradient at image ``x``.

    Args:
        x: Current image, passed through the module-level ``model``.

    Returns:
        ``(loss, grad)``: the sum of every content and style module's
        ``loss``, and the gradient with respect to ``x`` flattened to 1-D.
    """
    num_calls[0] += 1
    # The forward pass refreshes each loss module's `loss` for this `x`;
    # without it the backward pass would reuse stale values.
    model.updateOutput(x)
    grad = model.updateGradInput(x, dy)
    # The loss lists hold modules, not (index, module) pairs.
    loss = sum(mod.loss for mod in content_losses)
    loss = loss + sum(mod.loss for mod in style_losses)
    # optim.lbfgs expects a vector for gradients
    return loss, grad.view(grad.nelement())
print("Model Loaded")

# Capture content targets
# The lists hold the loss modules themselves, so iterate over them directly
# rather than using each module as a list index.
for mod in content_losses:
    mod.mode = 'capture'
print("Capturing content targets")
content_image_caffe = content_image_caffe.type(dtype)
# A forward pass is what lets modules in capture mode record their target.
# NOTE(review): this call was absent from the source - confirm.
model.updateOutput(content_image_caffe)

# Capture style targets
for mod in content_losses:
    mod.mode = None
print("Capturing style target")
for mod in style_losses:
    mod.mode = 'capture'
    # `style_blend_weights` is not defined anywhere in this script; with a
    # single style image no blending is needed, so leave the weight unset.
    mod.blend_weight = None
model.updateOutput(style_image_caffe)

# Set all loss modules to loss mode
for mod in content_losses:
    mod.mode = 'loss'
for mod in style_losses:
    mod.mode = 'loss'

# Initialize the image
# The seed may still be a string if it came from the command line.
if int(params.seed) >= 0:
    torch.manual_seed(int(params.seed))

# Run optimization.
if params.optimizer == 'lbfgs':
    print("Running optimization with L-BFGS")
    x, losses = optim.lbfgs(feval, img, optim_state)
elif params.optimizer == 'adam':
    print("Running optimization with ADAM")
    for t in range(params.num_iterations):
        x, losses = optim.adam(feval, img, optim_state)

print("Test CNN")
# `output_img` was never assigned; save the image the optimiser was given.
# NOTE(review): assumes the optimiser updates `img` in place - confirm.
output_img = img.data
# The removed keyword arguments were all save_image's defaults.
torchvision.utils.save_image(output_img, params.output_image)
Copy link

ProGamerGov commented Mar 6, 2018

The torch.legacy.nn package in PyTorch doesn't support the Conv2d and MaxPool2d layers that all of the pretrained VGG models seem to use.

So I have been trying to figure out how to either use those layers with torch.legacy.nn, load the model and replace those layers, or convert a model's layers to the applicable legacy layers.

I went with torch.legacy.nn instead of torch.nn because it lets me use the same functions in the ContentLoss, StyleLoss, and GramMatrix modules that Neural-Style uses.

There are also numerous other issues, such as how the images are processed and how the input images need to be exactly the same size, but I am trying to solve the issue of setting up the model first before I address those.

Copy link

I made a script which can change the model layers:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment