-
-
Save ProGamerGov/17a99ebab50e4fe7079ebc1c40c889a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.autograd import Variable | |
import torchvision | |
from torchvision import transforms | |
from PIL import Image | |
from collections import OrderedDict | |
import math | |
import os | |
import sys | |
import argparse | |
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
# Optimization options
parser.add_argument("-num_iterations", help="iterations", type=int, default=1000)
parser.add_argument("-optimizer", help="optimiser", default="lbfgs", choices=["lbfgs", "adam"])
# BUG FIX: without type=float a value given on the command line stayed a
# string, and optim.Adam(lr="0.1") would fail.
parser.add_argument("-learning_rate", type=float, default=1.0)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-model_file", help="VGG 19 model file location", type=str, default='vgg19-d01eb7cb.pth')
parser.add_argument("-seed", help="random number seed", type=int, default=-1)
params = parser.parse_args()

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Seed the RNGs for reproducibility when a non-negative seed is given.
if params.seed >= 0:
    torch.manual_seed(params.seed)
    # BUG FIX: only seed CUDA when it is available; the original called
    # torch.cuda.manual_seed unconditionally and crashed CPU-only runs.
    if use_cuda:
        torch.cuda.manual_seed(params.seed)
def ImageSetup(image_name, image_size):
    """Load an image file and return it as a 1 x C x H x W float tensor.

    NOTE(review): transforms.Resize(image_size) with an int scales the
    *smaller* edge to image_size, but the -image_size help text says
    "Maximum height / width" — confirm which behavior is intended.
    """
    image = Image.open(image_name).convert('RGB')
    loader = transforms.Compose([
        transforms.Resize(image_size),  # resize, then convert to tensor
        transforms.ToTensor(),
    ])
    # Variable() is kept for consistency with the rest of this script
    # (deprecated in modern PyTorch, where it is a no-op wrapper).
    image = Variable(loader(image))
    return image.unsqueeze(0)  # add the batch dimension
def SaveImage(output_img, output_name):
    """Write a (batch of) image tensor(s) to disk as an image file.

    BUG FIX: the original spelled out every save_image default explicitly,
    including the `range` kwarg that newer torchvision versions removed
    (renamed to value_range), which raises TypeError there.  Relying on the
    defaults is byte-identical in behavior and works on old and new
    torchvision alike.
    """
    torchvision.utils.save_image(output_img, output_name)
# BUG FIX: use the dtype computed from torch.cuda.is_available() instead of
# an unconditional .cuda() call, so the script also runs on CPU-only hosts.
content_image = ImageSetup(params.content_image, params.image_size).type(dtype)
style_image = ImageSetup(params.style_image, params.image_size).type(dtype)
class VGG(nn.Module):
    """VGG network whose forward pass returns named intermediate activations.

    Parameters:
        features: nn.Sequential conv/relu/pool stack from make_layers().
        layer_list: mapping with key 'C' -> list of conv-layer names
            (see vgg19_dict).  NOTE(review): vgg16() passes the flat
            vgg16_list instead of a dict, so layer_list['C'] in forward()
            would fail for it — confirm intent.
        num_classes: size of the classifier output (ImageNet default 1000).
        init_weights: randomize weights on construction (later overwritten
            by load_state_dict when a pretrained file is loaded).
    """
    def __init__(self, features, layer_list, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.layer_list = layer_list
        # Standard VGG classifier head; unused by forward() below but
        # required so a pretrained full-VGG state_dict loads cleanly.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x, out_keys):
        """Run x through the conv layers, returning the activations named
        in out_keys (in that order).

        NOTE(review): this walks self.modules() and chains Conv2d + relu
        directly; the pooling layers built into self.features are never
        applied, so spatial resolution is never reduced — this does not
        match the canonical VGG forward pass.  Confirm this is deliberate.
        """
        layer = self.layer_list['C']  # conv-layer names, e.g. 'conv1_1', ...
        i, c = 0, 0                   # i counts all modules but is otherwise unused
        out = {}
        for l in self.modules():
            if isinstance(l, nn.Conv2d):
                if c < 1:
                    # First conv consumes the input image.
                    out[layer[c]] = F.relu(l(x))
                else:
                    # Each later conv consumes the previous conv's activation.
                    s = c - 1
                    out[layer[c]] = F.relu(l(out[layer[s]]))
                c = c+1
            i=i+1
        return [out[key] for key in out_keys]

    def _initialize_weights(self):
        """He-style init for conv layers, small-normal init for linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
# Channel configurations for the VGG variants ('M' marks a pooling layer).
# 'D' is VGG-16 and 'E' is VGG-19 (torchvision naming convention).
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
# Flat layer-name sequences for the full (conv + classifier) VGG-19/16 nets.
vgg19_list = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5', 'torch_view', 'fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8', 'prob']
vgg16_list = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5', 'torch_view', 'fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8', 'prob']
# VGG-19 layer names grouped by kind: 'C' convs, 'R' relus, 'P' pools.
# VGG.forward() indexes this with 'C'.
vgg19_dict = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4'],
    'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}
def make_layers(cfg, pool='max'):
    """Build the VGG conv stack from a channel configuration list.

    Args:
        cfg: sequence of ints (conv output channels) and 'M' markers
            (insert a pooling layer), e.g. cfg['E'] for VGG-19.
        pool: 'max' or 'avg' pooling for the 'M' positions.

    Returns:
        nn.Sequential of Conv2d/ReLU/pool layers starting from 3 input channels.

    Raises:
        ValueError: for an unrecognized pool value.  BUG FIX: the original
        printed a message and called quit(), killing the whole interpreter;
        raising lets callers handle the error.
    """
    if pool == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
    elif pool == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
    else:
        raise ValueError("Unrecognized pooling parameter: " + str(pool))
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
def vgg19(pool='max', **kwargs):
    """VGG 19-layer model (configuration 'E').

    BUG FIX: forward the `pool` argument to make_layers; the original
    hard-coded pool='max' and silently ignored the parameter.
    """
    return VGG(make_layers(cfg['E'], pool=pool), vgg19_dict, **kwargs)
def vgg16(pool='max', **kwargs):
    """VGG 16-layer model (configuration 'D').

    BUG FIX: forward the `pool` argument to make_layers; the original
    hard-coded pool='max' and silently ignored the parameter.

    NOTE(review): vgg16_list is a flat list, but VGG.forward() indexes
    layer_list['C'] — using this model's forward would raise TypeError.
    A vgg16 analogue of vgg19_dict appears to be missing; confirm.
    """
    return VGG(make_layers(cfg['D'], pool=pool), vgg16_list, **kwargs)
# Pick the architecture from the weight file's name prefix, e.g.
# "vgg19-d01eb7cb.pth" -> "vgg19".  BUG FIX: the original computed
# model_name and then ignored it, always constructing vgg19 even for a
# vgg16 weight file (whose state_dict would then fail to load).
model_name = os.path.splitext(params.model_file)[0].split('-')[0]
cnn = vgg16() if model_name == 'vgg16' else vgg19()
cnn.load_state_dict(torch.load(params.model_file))
# BUG FIX: move to the GPU only when CUDA is available.
if use_cuda:
    cnn = cnn.cuda()
# gram matrix and loss
class GramMatrix(nn.Module):
    """Compute the Gram matrix of a batch of feature maps.

    Input:  (b, c, h, w) feature tensor.
    Output: (b, c, c) Gram matrix, normalized by the spatial size h*w.
    """
    def forward(self, input):
        b, c, h, w = input.size()
        # FIX (readability): the original named the flattened features `F`,
        # shadowing the module-level torch.nn.functional alias.
        feats = input.view(b, c, h * w)                  # (b, c, h*w)
        gram = torch.bmm(feats, feats.transpose(1, 2))   # (b, c, c)
        gram.div_(h * w)                                 # normalize by map size
        return gram
class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of `input` and a
    precomputed target Gram matrix (the usual style-loss term)."""
    def forward(self, input, target):
        gram = GramMatrix()(input)
        return nn.MSELoss()(gram, target)
def maybe_print(t, loss):
    """Report optimization progress every `print_iter` iterations.

    BUG FIX: the original printed the literal placeholder string
    "maybe_print" and ignored `loss`; print the iteration number and the
    current loss value instead.
    """
    if params.print_iter > 0 and t % params.print_iter == 0:
        print("Iteration %d, loss %f" % (t, loss.item()))
def maybe_save(t):
    """Write the current image to disk every `save_iter` iterations and at
    the final iteration.  Intermediate saves get an "_<t>" name suffix; the
    final save uses the plain -output_image name."""
    periodic = params.save_iter > 0 and t % params.save_iter == 0
    final = t == params.num_iterations
    if not (periodic or final):
        return
    base, ext = os.path.splitext(params.output_image)
    filename = base + ext if final else "%s_%d%s" % (base, t, ext)
    SaveImage(img.data, filename)
# Layers whose responses define the style and content targets
# (Gatys et al. neural-style defaults).
style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
content_layers = ['conv4_2']

# Precompute the fixed targets once; detach() so they are constants
# during optimization.
style_targets = [GramMatrix()(A).detach() for A in cnn(style_image, style_layers)]
content_targets = [A.detach() for A in cnn(content_image, content_layers)]
targets = style_targets + content_targets

# One loss module per target layer, with per-layer weights
# (style weights scaled by 1/channels^2).
loss_layers = style_layers + content_layers
loss_fns = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers)
style_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
content_weights = [1e0]
weights = style_weights + content_weights

# Optimize a copy of the content image directly.
img = Variable(content_image.data.clone(), requires_grad=True)
# Run the optimization.  feval is the closure optimizer.step() calls: it
# zeroes gradients, evaluates the weighted style+content loss, backpropagates,
# and handles periodic printing/saving.
num_calls = [0]

def feval():
    """Compute loss and gradients w.r.t. the image; return the loss."""
    optimizer.zero_grad()
    out = cnn(img, loss_layers)
    layer_losses = [weights[a] * loss_fns[a](A, targets[a]) for a, A in enumerate(out)]
    loss = sum(layer_losses)
    loss.backward()
    num_calls[0] += 1
    maybe_print(num_calls[0], loss)
    maybe_save(num_calls[0])
    return loss

if params.optimizer == 'lbfgs':
    print("Running optimization with L-BFGS")
    # L-BFGS performs all num_iterations iterations inside one step() call.
    optimizer = optim.LBFGS([img], max_iter=params.num_iterations,
                            tolerance_change=-1, tolerance_grad=-1)
    optimizer.step(feval)
elif params.optimizer == 'adam':
    print("Running optimization with ADAM")
    # BUG FIX: the original used Python-2 xrange (NameError on Python 3)
    # and re-created the Adam optimizer inside the loop, resetting its
    # moment estimates every iteration.  Create it once and take one
    # step per iteration.
    optimizer = optim.Adam([img], lr=params.learning_rate)
    for t in range(params.num_iterations):
        optimizer.step(feval)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment