Working PyTorch Style Transfer With CPU/CUDA
# CVGG.py -- VGG model definitions (imported by the main script below).
import torch
import torch.nn as nn

# VGG class based on: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
class VGG(nn.Module):

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
def make_layers(cfg, pool='max'):
    layers = []
    in_channels = 3
    if pool == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pool == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        raise ValueError("Unrecognized pooling parameter: " + str(pool))
    for v in cfg:
        if v == 'M':
            layers += [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
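# For illustration (an added note, not in the original file): each number in a
# configuration is a conv layer's output channel count, and 'M' marks a pooling
# layer. The leading entries of configuration 'E' therefore expand to
# Conv2d(3, 64) -> ReLU -> Conv2d(64, 64) -> ReLU -> MaxPool2d, e.g.:
#   print(make_layers([64, 64, 'M'], pool='max'))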
def vgg19(pool='max', **kwargs):
    # VGG 19-layer model (configuration "E")
    model = VGG(make_layers(cfg['E'], pool=pool), **kwargs)
    return model

def vgg16(pool='max', **kwargs):
    # VGG 16-layer model (configuration "D")
    model = VGG(make_layers(cfg['D'], pool=pool), **kwargs)
    return model

def vgg13(pool='max', **kwargs):
    # VGG 13-layer model (configuration "B")
    model = VGG(make_layers(cfg['B'], pool=pool), **kwargs)
    return model

def vgg11(pool='max', **kwargs):
    # VGG 11-layer model (configuration "A")
    model = VGG(make_layers(cfg['A'], pool=pool), **kwargs)
    return model
# LossModulesNew.py -- loss modules (imported by the main script below).
import torch
import torch.nn as nn

# Define an nn Module to compute content loss in-place
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'capture':
            self.target = input.detach() * self.strength
        elif self.mode == 'loss':
            self.loss = self.crit(input * self.strength, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
class GramMatrix(nn.Module):

    def forward(self, input):
        # Flatten each channel and take channel-wise correlations.
        # Assumes a single image per batch (B == 1).
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        self.output = torch.mm(x_flat, x_flat.t())
        self.output.div_(B * C * H * W)
        return self.output
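# Illustrative sketch (an added note about typical use): for a 1 x C x H x W
# activation, GramMatrix returns a C x C matrix of channel-wise feature
# correlations, normalized by the total number of elements. For example:
#   g = GramMatrix()(torch.randn(1, 64, 32, 32))
#   print(g.size())  # -> torch.Size([64, 64])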
# Define an nn Module to compute style loss in-place
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        # Note: dividing self.G by input.nelement() here zeroed the loss every iteration.
        self.G.mul_(self.strength)
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            # Multiple blend weights and multiple style images don't currently work
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                # target += blend_weight * G
                self.target = torch.add(self.target, self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.crit(self.G, self.target)
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
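# A minimal sketch of the capture/loss protocol these modules implement (an
# added demo block; the main script below drives the same sequence through
# captureTargets): first run in 'capture' mode to record a target, then switch
# to 'loss' mode so later forward passes populate .loss for the backward step.
if __name__ == "__main__":
    sl = StyleLoss(100.0)
    sl.mode = 'capture'
    sl(torch.randn(1, 64, 32, 32))  # Records the target Gram matrix.
    sl.mode = 'loss'
    sl(torch.randn(1, 64, 32, 32))  # Computes sl.loss against the target.
    print(sl.loss)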
# @ProGamerGov 25 March 2018
# Main script; expects CVGG.py and LossModulesNew.py alongside it.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image
import os
import copy
from LossModulesNew import ContentLoss
from LossModulesNew import StyleLoss
from LossModulesNew import GramMatrix
from CVGG import vgg19, vgg16
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu to -1", default='0')
# Optimization options
parser.add_argument("-content_weight", type=float, default=1)
parser.add_argument("-style_weight", type=float, default=2000)
parser.add_argument("-tv_weight", type=float, default=0.001)  # Currently unused.
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-init", default="random", choices=["random", "image"])
parser.add_argument("-init_image", help="Initial image", default="")
parser.add_argument("-optimizer", help="Optimizer", default="lbfgs", choices=["lbfgs", "adam"])
parser.add_argument("-learning_rate", type=float, default=1e0)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", help="Style scale", type=float, default=1.0)
parser.add_argument("-pooling", help="avg or max pooling", type=str, default='max')
parser.add_argument("-model_file", help="VGG 19 model file location", type=str, default='vgg19-d01eb7cb.pth')
parser.add_argument("-seed", help="Random number seed", type=int, default=-1)
parser.add_argument("-content_layers", help="VGG 19 content layers", default='relu2_2')
parser.add_argument("-style_layers", help="VGG 19 style layers", default='relu1_1,relu1_2,relu2_1,relu2_2,relu3_1')
params = parser.parse_args()
use_cuda = None
if params.gpu == '0':
    use_cuda = torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    print("CUDA available: " + str(use_cuda))
elif params.gpu == '-1':
    use_cuda = False
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Optionally set the seed value
if params.seed >= 0:
    torch.manual_seed(params.seed)
    if use_cuda:
        torch.cuda.manual_seed(params.seed)
    torch.backends.cudnn.deterministic = True
# Preprocess an image before passing it to a model.
def ImageSetup(image_name, image_size):
    image = Image.open(image_name)
    # Resize and convert to a tensor, then add the batch dimension:
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = Variable(Loader(image))
    tensor = tensor.unsqueeze(0)
    return tensor

# Undo the above preprocessing.
def SaveImage(output_img, output_name):
    B, C, H, W = output_img.size()
    output_img = output_img.view(C, H, W)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_img.cpu().data)
    image.save(str(output_name))

def maybe_save(t):
    should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save:
        output_filename, file_extension = os.path.splitext(params.output_image)
        if t == params.num_iterations:
            filename = output_filename + str(file_extension)
        else:
            filename = str(output_filename) + "_" + str(t) + str(file_extension)
        SaveImage(img.clone(), filename)
content_image = ImageSetup(params.content_image, params.image_size).type(dtype)

style_image_list = params.style_image.split(',')
style_images_caffe = []
for image in style_image_list:
    image_size = int(params.image_size * params.style_scale)
    img_caffe = ImageSetup(image, image_size).type(dtype)
    style_images_caffe.append(img_caffe)

if params.style_blend_weights is None:
    # Style blending not specified, so use equal weighting
    style_blend_weights = [1.0 for _ in style_image_list]
else:
    style_blend_weights = [float(w) for w in params.style_blend_weights.split(',')]
    # Normalize the style blending weights so they sum to 1
    blend_sum = sum(style_blend_weights)
    style_blend_weights = [w / blend_sum for w in style_blend_weights]
    for w in style_blend_weights:
        print(w)
content_layers = params.content_layers.split(',')
style_layers = params.style_layers.split(',')

vgg16_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu4_1', 'relu4_2', 'relu4_3', 'relu5_1', 'relu5_2', 'relu5_3'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}
vgg19_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}
# Get the model class, and configure pooling layer type
def buildCNN(model_file, pooling):
    cnn = None
    layerList = []
    if "vgg19" in str(model_file):
        layerList = vgg19_dict
        print("VGG-19 Architecture Detected")
        cnn = vgg19(pooling)
    elif "vgg16" in str(model_file):
        layerList = vgg16_dict
        print("VGG-16 Architecture Detected")
        cnn = vgg16(pooling)
    else:
        raise ValueError("Model architecture not recognized from filename: " + str(model_file))
    return cnn, layerList
def modelSetup(cnn, layerList):
    cnn = copy.deepcopy(cnn)
    content_losses = []
    style_losses = []
    net = nn.Sequential()
    gram = GramMatrix()
    if use_cuda:
        gram = gram.cuda()
    i = 1
    c, r = 0, 0
    for layer in list(cnn):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            net.add_module(name, layer)
            layerType = layerList['C']
            if layerType[c] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerType[c]))
                loss_module = ContentLoss(params.content_weight)
                net.add_module("content_loss_" + str(i), loss_module)
                content_losses.append(loss_module)
            if layerType[c] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerType[c]))
                loss_module = StyleLoss(params.style_weight)
                net.add_module("style_loss_" + str(i), loss_module)
                style_losses.append(loss_module)
            c += 1
        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            net.add_module(name, layer)
            layerType = layerList['R']
            if layerType[r] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerType[r]))
                loss_module = ContentLoss(params.content_weight)
                net.add_module("content_loss_" + str(i), loss_module)
                content_losses.append(loss_module)
            if layerType[r] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerType[r]))
                loss_module = StyleLoss(params.style_weight)
                net.add_module("style_loss_" + str(i), loss_module)
                style_losses.append(loss_module)
            r += 1
            i += 1
        if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
            name = "pool_" + str(i)
            net.add_module(name, layer)
    return net, style_losses, content_losses
def captureTargets():
    # Capture content targets
    print("Capturing content targets")
    for i in content_losses:
        i.mode = 'capture'
    net(content_image)
    # Disable the content modules so style passes don't overwrite their targets
    for i in content_losses:
        i.mode = 'None'
    # Capture style targets
    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i + 1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(image)
    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'
    return
# Configure optimizer and input image
def setupOptimizer(img):
    img = nn.Parameter(img.data)
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optimizer = optim.LBFGS([img])
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate, betas=(0.99, 0.999), eps=1e-8)
    return img, optimizer
# Build the model definition and set up pooling layers:
cnn, layerList = buildCNN(params.model_file, params.pooling)
cnn.load_state_dict(torch.load(params.model_file))  # Use the model definition to load the model file.

# Convert the model to CUDA now, to avoid later issues:
if use_cuda:
    cnn = cnn.cuda()
# We only need the features from the model:
cnn = cnn.features

# Build the style transfer network:
net, style_losses, content_losses = modelSetup(cnn, layerList)
captureTargets()  # Capture content and style targets separately, to avoid size mismatches.

img = content_image.clone()
img, optimizer = setupOptimizer(img)  # Set up the optimizer.
num_calls = [0]
while num_calls[0] <= params.num_iterations:

    def feval():
        num_calls[0] += 1
        img.data.clamp_(0, 1)
        optimizer.zero_grad()
        net(img)
        contentLoss = 0
        styleLoss = 0
        for mod in content_losses:
            contentLoss += mod.backward()
        for mod in style_losses:
            styleLoss += mod.backward()
        if num_calls[0] % params.print_iter == 0:
            print("Iteration: " + str(num_calls[0]))
            print('Style Loss: {:4f} Content Loss: {:4f}'.format(styleLoss.data[0], contentLoss.data[0]))
            print()
        maybe_save(num_calls[0])
        return contentLoss + styleLoss

    optimizer.step(feval)
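# Example invocation (an added note; the script filename "neural_style.py" is
# an assumption -- the imports above expect CVGG.py and LossModulesNew.py
# alongside it, and vgg19-d01eb7cb.pth in the working directory):
#
#   python neural_style.py -content_image examples/inputs/tubingen.jpg \
#       -style_image examples/inputs/seated-nude.jpg -gpu 0 -output_image out.png
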
This code was continued from the discussion on this issue thread: jcjohnson/neural-style#450

A more up to date version of this code can be found here: https://gist.github.com/ProGamerGov/089a082c2a000d1e1cc034fc75ff5931
