-
-
Save ProGamerGov/2066ec2c09117931f69d226b9115d21f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision | |
# nn Module that records / scores a content target and passes its input through unchanged.
class ContentLoss(nn.Module):
    """Content-loss probe inserted into a feature network.

    The forward pass is the identity. Depending on ``mode`` the module
    additionally:

    * ``'capture'`` -- stores a detached copy of its input as the target;
    * ``'loss'``    -- sets ``self.loss`` to ``strength * MSE(input, target)``;
    * anything else (e.g. ``None``) -- does nothing extra.

    ``updateOutput`` / ``updateGradInput`` keep the Lua-torch style API;
    ``forward`` is provided so the module also works inside ``nn.Sequential``
    and with autograd (``self.loss`` stays attached to the graph).

    Args:
        strength: Multiplier applied to the loss and to its gradient.
        normalize: If true, ``updateGradInput`` divides the loss gradient by
            its L1 norm before scaling. Booleans are used as-is; strings such
            as ``'true'`` / ``'False'`` are parsed case-insensitively.
    """

    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.target = torch.Tensor()
        # Honour the argument (it used to be ignored) and avoid the truthy-string trap.
        if isinstance(normalize, str):
            normalize = normalize.strip().lower() in ('true', '1', 'yes')
        self.normalize = bool(normalize)
        self.loss = 0
        self.crit = nn.MSELoss()
        self.mode = None

    def forward(self, input):
        # nn.Module.__call__ (and so nn.Sequential) dispatches to forward().
        return self.updateOutput(input)

    def updateOutput(self, input):
        """Identity pass that captures the target or computes the loss per ``mode``."""
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            # Detach so the stored target never carries gradient history.
            self.target = input.detach().clone()
        self.output = input
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return ``gradOutput`` plus the (optionally normalized) scaled loss gradient.

        Outside ``'loss'`` mode the upstream gradient is passed through as a copy.
        """
        if self.mode != 'loss':
            self.gradInput = gradOutput.clone()
            return self.gradInput
        if input.nelement() == self.target.nelement():
            # Differentiate the criterion itself so the gradient always matches self.crit.
            with torch.enable_grad():
                x = input.detach().requires_grad_(True)
                grad, = torch.autograd.grad(self.crit(x, self.target), x)
        else:
            # No comparable target: contribute no loss gradient.
            grad = torch.zeros_like(input)
        if self.normalize:
            grad = grad / (torch.norm(grad, 1) + 1e-8)  # Normalize Gradients
        self.gradInput = grad * self.strength + gradOutput
        return self.gradInput
class GramMatrix(nn.Module):
    """Unnormalized Gram matrix of a feature map.

    Accepts a ``(C, H, W)`` tensor or a batch of one, ``(1, C, H, W)``.
    With ``X`` the input flattened to ``(C, H*W)``, the output is the
    ``C x C`` matrix ``X X^T``.
    """

    def __init__(self):
        super(GramMatrix, self).__init__()

    @staticmethod
    def _flatten(input):
        """Return ``input`` as a ``(C, H*W)`` matrix; raise ValueError on other shapes."""
        if input.dim() == 4 and input.size(0) == 1:
            input = input[0]
        if input.dim() != 3:
            raise ValueError("GramMatrix expects a (C, H, W) or (1, C, H, W) tensor, got shape "
                             + str(tuple(input.size())))
        return input.contiguous().view(input.size(0), -1)

    def forward(self, input):
        # nn.Module.__call__ dispatches to forward(), not updateOutput().
        return self.updateOutput(input)

    def updateOutput(self, input):
        x_flat = self._flatten(input)
        self.output = torch.mm(x_flat, x_flat.t())
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Backprop ``gradOutput`` (dL/dG, C x C) to the input; result has ``input``'s shape."""
        x_flat = self._flatten(input)
        # For G = X X^T:  dL/dX = dG X + dG^T X.
        grad = torch.mm(gradOutput, x_flat) + torch.mm(gradOutput.t(), x_flat)
        self.gradInput = grad.view(input.size())
        return self.gradInput
# nn Module that records / scores a Gram-matrix style target and passes its input through unchanged.
class StyleLoss(nn.Module):
    """Style-loss probe inserted into a feature network.

    On every forward pass it computes ``G = gram(input) / input.nelement()``
    and returns ``input`` unchanged. Depending on ``mode``:

    * ``'capture'`` -- stores ``G`` as the target. When ``blend_weight`` is
      set, the target accumulates ``blend_weight * G`` over successive
      captures (starting from an empty target).
    * ``'loss'``    -- sets ``self.loss`` to ``strength * MSE(G, target)``.

    Args:
        strength: Multiplier applied to the loss and to its gradient.
        normalize: If true, ``updateGradInput`` divides the loss gradient by
            its L1 norm before scaling. Booleans are used as-is; strings such
            as ``'true'`` / ``'False'`` are parsed case-insensitively.
    """

    def __init__(self, strength, normalize):
        super(StyleLoss, self).__init__()
        # Honour the argument (it used to be ignored) and avoid the truthy-string trap.
        if isinstance(normalize, str):
            normalize = normalize.strip().lower() in ('true', '1', 'yes')
        self.normalize = bool(normalize)
        self.strength = strength
        self.target = torch.Tensor()
        self.mode = None
        self.loss = 0
        self.gram = GramMatrix()
        self.blend_weight = None
        self.G = None
        self.crit = nn.MSELoss()

    def forward(self, input):
        # nn.Module.__call__ (and so nn.Sequential) dispatches to forward().
        return self.updateOutput(input)

    def updateOutput(self, input):
        # Out-of-place division: the old non-in-place div() result was discarded.
        self.G = self.gram.updateOutput(input) / input.nelement()
        if self.mode == 'capture':
            G = self.G.detach()
            if self.blend_weight is None:
                self.target = G.clone()
            elif self.target.nelement() == 0:
                self.target = G * self.blend_weight
            else:
                self.target = self.target + self.blend_weight * G
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        self.output = input
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return ``gradOutput`` plus the (optionally normalized) scaled style-loss gradient.

        Requires a preceding ``updateOutput`` in ``'loss'`` mode (uses ``self.G``).
        Outside ``'loss'`` mode the upstream gradient is passed through.
        """
        if self.mode != 'loss':
            self.gradInput = gradOutput
            return self.gradInput
        with torch.enable_grad():
            G = self.G.detach().requires_grad_(True)
            dG, = torch.autograd.grad(self.crit(G, self.target), G)
        # Chain rule for the division by input.nelement() in updateOutput.
        dG = dG / input.nelement()
        grad = self.gram.updateGradInput(input, dG)
        if self.normalize:
            grad = grad / (torch.norm(grad, 1) + 1e-8)  # Normalize Gradients
        self.gradInput = grad * self.strength + gradOutput
        return self.gradInput
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from torch.autograd import Variable | |
import torch.optim as optim | |
from PIL import Image | |
from LossModules import ContentLoss | |
from LossModules import StyleLoss | |
from LossModules import GramMatrix | |
import argparse | |
# Command-line interface.
parser = argparse.ArgumentParser()

# Basic options
parser.add_argument(
    "-style_image",
    default='examples/inputs/seated-nude.jpg',
    help="Style target image",
)
parser.add_argument(
    "-content_image",
    default='examples/inputs/tubingen.jpg',
    help="Content target image",
)
parser.add_argument(
    "-image_size",
    type=int,
    default=512,
    help="Maximum height / width of generated image",
)

# Optimization options
parser.add_argument("-num_iterations", type=int, default=1000, help="iterations")
parser.add_argument("-optimizer", choices=["lbfgs", "adam"], default="lbfgs", help="optimiser")

params = parser.parse_args()

# Pick the tensor type once; CUDA is used whenever it is available.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
def ImageSetup(image_name, image_size):
    """Load ``image_name`` as an RGB tensor of shape (1, 3, H, W).

    The image is scaled, preserving aspect ratio, so that its LONGER side
    equals ``image_size`` (the ``-image_size`` option is documented as the
    maximum height / width). Pixel values are whatever ``transforms.ToTensor``
    produces. The tensor size is printed as a debugging aid.
    """
    # Close the file handle; convert() returns an independent, fully-loaded image.
    with Image.open(image_name) as raw:
        image = raw.convert('RGB')
    width, height = image.size
    scale = float(image_size) / max(width, height)
    # An explicit (h, w) is required: Resize(int) would match the SHORTER edge instead.
    new_size = (max(1, int(round(height * scale))), max(1, int(round(width * scale))))
    loader = transforms.Compose([transforms.Resize(new_size), transforms.ToTensor()])
    tensor = loader(image).unsqueeze(0)  # add a batch dimension of 1
    print(tensor.size())
    return tensor
def SaveImage(output_img, output_name):
    """Write ``output_img`` to the file ``output_name`` with torchvision.

    The original call spelled out nrow=8, padding=2, normalize=False,
    range=None, scale_each=False and pad_value=0, which are save_image's
    defaults. The ``range`` keyword was renamed ``value_range`` and deprecated
    in newer torchvision, so passing it can fail. Relying on the defaults
    keeps the same output and works across versions.
    """
    torchvision.utils.save_image(output_img, output_name)
# Load the content target first, then the style target, both at the requested size.
content_image, style_image = [
    ImageSetup(image_path, params.image_size)
    for image_path in (params.content_image, params.style_image)
]
# Separate names for layers.
# Names are looked up by position, so each list must follow its network's layer order.
VGG19_Layer_List = [
    'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5',
    'torch_view', 'fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8', 'prob',
]
VGG16_Layer_List = [
    'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3',
    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5',
    'torch_view', 'fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8', 'prob',
]
# Backward-compatible alias for the previously inconsistently-cased name.
VGG16_layer_List = VGG16_Layer_List
NIN_Layer_List = [
    'conv1', 'relu0', 'cccp1', 'relu1', 'cccp2', 'relu2', 'pool0',
    'conv2', 'relu3', 'cccp3', 'relu5', 'cccp4', 'relu6', 'pool2',
    'conv3', 'relu7', 'cccp5', 'relu8', 'cccp6', 'relu9', 'pool3',
    'drop', 'conv4-1024', 'relu10', 'cccp7-1024', 'relu11', 'cccp8-1024', 'relu12', 'pool4', 'loss',
]
def ModelSetup(cnn, style_weight, content_weight, Layer_List, content_layers, style_layers, normalize_gradients):
    """Rebuild ``cnn`` as an nn.Sequential with loss modules spliced in.

    Each layer of ``cnn`` is added under the name at the same position in
    ``Layer_List``. A ContentLoss (StyleLoss) is inserted directly after every
    layer whose name is in ``content_layers`` (``style_layers``). Construction
    stops once every requested loss layer has been placed, because later
    layers cannot influence any loss.

    Args:
        cnn: Iterable of layers (e.g. ``models.vgg19(...).features``).
        style_weight / content_weight: Strengths passed to the loss modules.
        Layer_List: Layer names, index-aligned with ``cnn``.
        content_layers / style_layers: Names after which to insert losses.
        normalize_gradients: Forwarded to each loss module's ``normalize``.

    Returns:
        ``(net, style_losses, content_losses)``; the lists hold the inserted
        loss modules in network order.
    """
    content_losses = []
    style_losses = []
    next_content_idx = 1
    next_style_idx = 1
    net = nn.Sequential()
    for i, layer in enumerate(cnn):
        if next_content_idx > len(content_layers) and next_style_idx > len(style_layers):
            break
        layer_name = Layer_List[i]
        if isinstance(layer, nn.ReLU):
            # An in-place ReLU would overwrite activations the loss modules
            # (and autograd) still reference.
            layer = nn.ReLU(inplace=False)
        # Keep every layer; dropping non conv/relu/pool layers (e.g. NIN's cccp*) corrupts the network.
        net.add_module(layer_name, layer)
        if layer_name in content_layers:
            print("Setting up content layer " + str(i) + ": " + str(layer_name))
            loss_module = ContentLoss(content_weight, normalize_gradients)
            # Distinct name: re-using layer_name would REPLACE the layer just added.
            net.add_module("content_loss_" + str(next_content_idx), loss_module)
            content_losses.append(loss_module)
            next_content_idx += 1
        if layer_name in style_layers:
            print("Setting up style layer " + str(i) + ": " + str(layer_name))
            loss_module = StyleLoss(style_weight, normalize_gradients)
            net.add_module("style_loss_" + str(next_style_idx), loss_module)
            style_losses.append(loss_module)
            next_style_idx += 1
    return net, style_losses, content_losses
# Configuration (hard-coded; not yet exposed on the command line).
model_type = 'vgg19'  # Default value for testing
style_weight = 1000  # Default value for testing
content_weight = 100  # Default value for testing
# Must be a real boolean: the string 'False' is truthy.
normalize_gradients = False  # Default value for testing
content_layers = ['relu4_2']  # Default value for testing
style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']  # Default value for testing

# Load the pretrained feature extractor and its matching layer names.
# NOTE(review): torchvision's pretrained VGG weights expect ImageNet mean/std
# normalized input; the images here are fed unnormalized.
if model_type == 'vgg19':
    cnn = models.vgg19(pretrained=True).features
    Layer_List = VGG19_Layer_List
elif model_type == 'vgg16':
    cnn = models.vgg16(pretrained=True).features
    Layer_List = VGG16_layer_List
else:
    raise ValueError("Unsupported model_type: " + str(model_type))

# Build the style transfer model.
net, style_losses, content_losses = ModelSetup(cnn, style_weight, content_weight, Layer_List,
                                               content_layers, style_layers, normalize_gradients)
net = net.type(dtype)
# Only the image is optimized; freeze the network weights.
for param in net.parameters():
    param.requires_grad = False

# Target images never need gradients.
content_image = content_image.detach().type(dtype)
style_image = style_image.detach().type(dtype)


def set_loss_mode(loss_modules, mode):
    """Set ``mode`` on every module in ``loss_modules``."""
    for mod in loss_modules:
        mod.mode = mode


# Capture content targets (style modules idle).
print("Capturing content targets")
set_loss_mode(content_losses, 'capture')
set_loss_mode(style_losses, None)
with torch.no_grad():
    net(content_image)

# Capture style targets through the same network (content modules idle).
print("Capturing style targets")
set_loss_mode(content_losses, None)
set_loss_mode(style_losses, 'capture')
with torch.no_grad():
    net(style_image)

# Set all loss modules to loss mode and optimize a copy of the content image.
set_loss_mode(content_losses, 'loss')
set_loss_mode(style_losses, 'loss')
img = nn.Parameter(content_image.clone())

num_calls = [0]


def feval():
    """Optimizer closure: clamp, forward, sum every module's loss, backprop."""
    num_calls[0] += 1
    img.data.clamp_(0, 1)
    optimizer.zero_grad()
    net(img)
    loss = 0
    for mod in content_losses:
        loss = loss + mod.loss
    for mod in style_losses:
        loss = loss + mod.loss
    loss.backward()
    print("Iteration " + str(num_calls[0]) + ": loss " + str(float(loss)))
    return loss


# Run optimization.
if params.optimizer == 'lbfgs':
    print("Running optimization with L-BFGS")
    # Negative tolerances can never be met, so L-BFGS does not stop early.
    optimizer = optim.LBFGS([img], max_iter=params.num_iterations,
                            tolerance_change=-1, tolerance_grad=-1)
    num_steps = 1  # one L-BFGS step performs up to max_iter iterations
else:
    print("Running optimization with ADAM")
    # NOTE(review): lr=1 is large for pixel values in [0, 1]; consider tuning.
    optimizer = optim.Adam([img], lr=1)
    num_steps = params.num_iterations

for _ in range(num_steps):
    optimizer.step(feval)

img.data.clamp_(0, 1)
# No output option exists yet; write to a fixed file name.
SaveImage(img.data, 'out.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment