"""
Pytorch tutorials for Neural Style transfer
https://github.com/alexis-jacq/Pytorch-Tutorials
"""
# Packages
from PIL import Image
import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import models, transforms
# CUDA
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
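# Note: on PyTorch >= 0.4 the same dispatch is more often written with
# torch.device rather than a dtype alias; a minimal sketch (not used below):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# tensor = tensor.to(device)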
# Load images
imsize = 200 # Desired size of the output image
loader = transforms.Compose([
    transforms.Resize(imsize),  # Resize the imported image (Scale is the deprecated name)
    transforms.ToTensor()])     # Transform it into a Torch tensor
def image_loader(image_name):
    image = Image.open(image_name)
    image = Variable(loader(image))
    # Fake batch dimension required to fit the network's input dimensions
    image = image.unsqueeze(0)
    return image
style = image_loader("style.jpg").type(dtype)
content = image_loader("content.jpg").type(dtype)
assert style.size() == content.size(), "We need to import style and content images of the same size"
# Display images
unloader = transforms.ToPILImage() # Reconvert into PIL image
def imshow(tensor):
    image = tensor.clone().cpu()  # Clone the tensor so we do not modify it
    # Remove the fake batch dimension; squeeze(0) also works for non-square
    # images, where view(3, imsize, imsize) would fail
    image = image.squeeze(0)
    image = unloader(image)
    image.show()
imshow(style.data)
imshow(content.data)
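# imshow above opens the platform image viewer; inside a notebook a
# matplotlib-based variant may be more convenient (a sketch, assuming
# matplotlib is installed; not used in the rest of this script):
# import matplotlib.pyplot as plt
# plt.imshow(unloader(style.data.cpu().squeeze(0)))
# plt.show()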
# Content loss
class ContentLoss(nn.Module):
    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        # We 'detach' the target content from the tree used to dynamically
        # compute the gradient: it is a stated value, not a variable.
        # Otherwise the forward method of the criterion would throw an error.
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        # retain_graph (formerly retain_variables) keeps the graph alive so
        # the other loss modules can also backpropagate through it
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
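# A quick sanity check of ContentLoss (a sketch using the class above,
# commented out to keep the script's output clean): when the input features
# equal the target and the weight is 1, the MSE loss is 0.
# t = torch.rand(1, 3, 4, 4).type(dtype)
# cl = ContentLoss(Variable(t), weight=1)
# _ = cl(Variable(t))        # the forward pass stores the loss on the module
# print(cl.loss.data[0])     # ~0.0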
# Style loss
class GramMatrix(nn.Module):
    def forward(self, input):
        # a = batch size, b = number of feature maps,
        # (c, d) = dimensions of a feature map (N = c * d)
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)  # Resize F_XL into \hat F_XL
        G = torch.mm(features, features.t())  # Compute the Gram product
        # 'Normalise' the values of the Gram matrix by dividing by the number
        # of elements in each feature map
        return G.div(a * b * c * d)
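# In index notation the (unnormalised) Gram entry is
# G[i, j] = sum_k F[i, k] * F[j, k], i.e. the inner product between the
# flattened feature maps i and j. A quick shape check (a sketch using the
# class above, commented out):
# fm = torch.rand(1, 2, 3, 3).type(dtype)
# print(GramMatrix()(Variable(fm)).size())   # torch.Size([2, 2])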
class StyleLoss(nn.Module):
    def __init__(self, target, weight):
        super(StyleLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.criterion(self.G, self.target)
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
# Load the neural network
cnn = models.vgg19(pretrained=True).features
# Move it to the GPU if possible
if use_cuda:
    cnn = cnn.cuda()
# Desired depth layers to compute style/content losses
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
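# The conv_i names above follow the order in which convolutions appear in
# cnn; to check which VGG module each index maps to, one can enumerate the
# feature extractor (a sketch, commented out to keep the output clean):
# for idx, layer in enumerate(list(cnn)):
#     print(idx, layer)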
# Keep the loss modules in lists so we can iterate over them later
content_losses = []
style_losses = []
model = nn.Sequential() # The new Sequential module network
gram = GramMatrix() # We need a Gram module in order to compute the style targets
# Move these modules to GPU if possible
if use_cuda:
    model = model.cuda()
    gram = gram.cuda()
# Weight associated with content and style losses
content_weight = 1
style_weight = 1000
i = 1
for layer in list(cnn):
    if isinstance(layer, nn.Conv2d):
        name = "conv_" + str(i)
        model.add_module(name, layer)
        if name in content_layers:
            # Add content loss
            target = model(content).clone()
            content_loss = ContentLoss(target, content_weight)
            model.add_module("content_loss_" + str(i), content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            # Add style loss
            target_feature = model(style).clone()
            target_feature_gram = gram(target_feature)
            style_loss = StyleLoss(target_feature_gram, style_weight)
            model.add_module("style_loss_" + str(i), style_loss)
            style_losses.append(style_loss)
    if isinstance(layer, nn.ReLU):
        name = "relu_" + str(i)
        model.add_module(name, layer)
        if name in content_layers:
            # Add content loss
            target = model(content).clone()
            content_loss = ContentLoss(target, content_weight)
            model.add_module("content_loss_" + str(i), content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            # Add style loss
            target_feature = model(style).clone()
            target_feature_gram = gram(target_feature)
            style_loss = StyleLoss(target_feature_gram, style_weight)
            model.add_module("style_loss_" + str(i), style_loss)
            style_losses.append(style_loss)
        # The layer index only advances after the ReLU that closes a conv block
        i += 1
    if isinstance(layer, nn.MaxPool2d):
        name = "pool_" + str(i)
        model.add_module(name, layer)
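# At this point the loss modules sit interleaved with the VGG layers; printing
# the assembled model is an easy way to verify the layout (commented out):
# print(model)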
# Input image
input = image_loader("content.jpg").type(dtype)
# If we want to fill it with white noise
# input.data = torch.randn(input.data.size()).type(dtype)
# Display the input image
imshow(input.data)
# Gradient descent
# Wrapping the image in nn.Parameter tells the optimiser it is a trainable value
input = nn.Parameter(input.data)
optimiser = optim.LBFGS([input])
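# L-BFGS re-evaluates the model several times per optimisation step, so
# optim.LBFGS requires a closure that recomputes the losses and gradients
# on every call (see the training loop below)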
run = [0]
while run[0] <= 300:

    def closure():
        # Correct the values of the updated input image
        input.data.clamp_(0, 1)
        optimiser.zero_grad()
        model(input)
        style_score = 0
        content_score = 0
        for sl in style_losses:
            style_score += sl.backward()
        for cl in content_losses:
            content_score += cl.backward()
        run[0] += 1
        if run[0] % 10 == 0:
            print("run " + str(run[0]) + ":")
            print(style_score.data[0])
            print(content_score.data[0])
            imshow(input.data)
        return style_score + content_score

    optimiser.step(closure)
# A last correction
input.data.clamp_(0, 1)
# Finally, enjoy the result
imshow(input.data)
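# To persist the result to disk (a sketch; "output.jpg" is a placeholder
# file name, not part of the original script):
# unloader(input.data.cpu().squeeze(0)).save("output.jpg")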