-
-
Save choupi/05cac0a1b9dd44d5fda91f45755b8e09 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using Theano backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"%matplotlib inline\n", | |
"from __future__ import print_function\n", | |
"from keras.preprocessing.image import load_img, img_to_array\n", | |
"from scipy.misc import imsave\n", | |
"import numpy as np\n", | |
"from scipy.optimize import fmin_l_bfgs_b\n", | |
"import time\n", | |
"\n", | |
"import vgg16\n", | |
"from keras import backend as K\n", | |
"from keras.applications.imagenet_utils import preprocess_input" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"base_image_path = 'tubingen.jpg'\n", | |
"style_reference_image_path = 'starry_night_google.jpg'\n", | |
"result_prefix = 'art'\n", | |
"iterations = 5\n", | |
"total_variation_weight = 1.0\n", | |
"style_weight = 1.0\n", | |
"content_weight = 0.075" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"width, height = load_img(base_image_path).size\n", | |
"img_nrows = 400\n", | |
"img_ncols = int(width * img_nrows / height)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def preprocess_image(image_path):\n", | |
" img = load_img(image_path, target_size=(img_nrows, img_ncols))\n", | |
" img = img_to_array(img)\n", | |
" img = np.expand_dims(img, axis=0)\n", | |
" img = vgg16.preprocess_input(img)\n", | |
" return img\n", | |
"\n", | |
"def deprocess_image(x):\n", | |
" x = x.reshape((3, img_nrows, img_ncols))\n", | |
" x = x.transpose((1, 2, 0))\n", | |
" # Remove zero-center by mean pixel\n", | |
" x[:, :, 0] += 103.939\n", | |
" x[:, :, 1] += 116.779\n", | |
" x[:, :, 2] += 123.68\n", | |
" # 'BGR'->'RGB'\n", | |
" x = x[:, :, ::-1]\n", | |
" x = np.clip(x, 0, 255).astype('uint8')\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"base_image = K.variable(preprocess_image(base_image_path))\n", | |
"style_reference_image = K.variable(preprocess_image(style_reference_image_path))\n", | |
"combination_image = K.placeholder((1, 3, img_nrows, img_ncols))\n", | |
"input_tensor = K.concatenate([base_image,\n", | |
" style_reference_image,\n", | |
" combination_image], axis=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model loaded.\n" | |
] | |
} | |
], | |
"source": [ | |
"model = vgg16.VGG16(input_tensor=input_tensor,\n", | |
" weights='imagenet', include_top=False)\n", | |
"print('Model loaded.')\n", | |
"outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def gram_matrix(x):\n", | |
" assert K.ndim(x) == 3\n", | |
" features = K.batch_flatten(x)\n", | |
" gram = K.dot(features, K.transpose(features))\n", | |
" return gram\n", | |
"\n", | |
"def style_loss(style, combination):\n", | |
" assert K.ndim(style) == 3\n", | |
" assert K.ndim(combination) == 3\n", | |
" S = gram_matrix(style)\n", | |
" C = gram_matrix(combination)\n", | |
" channels = 3\n", | |
" size = img_nrows * img_ncols\n", | |
" return K.sum(K.square(S - C)) / (4. * (channels ** 2) * (size ** 2))\n", | |
"\n", | |
"def content_loss(base, combination):\n", | |
" return K.sum(K.square(combination - base))\n", | |
"\n", | |
"def total_variation_loss(x):\n", | |
" assert K.ndim(x) == 4\n", | |
" a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1])\n", | |
" b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:])\n", | |
" return K.sum(K.pow(a + b, 1.25))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"loss = K.variable(0.)\n", | |
"layer_features = outputs_dict['block5_conv2']\n", | |
"base_image_features = layer_features[0, :, :, :]\n", | |
"combination_features = layer_features[2, :, :, :]\n", | |
"loss += content_weight * content_loss(base_image_features,\n", | |
" combination_features)\n", | |
"\n", | |
"feature_layers = ['block1_conv1', 'block2_conv1',\n", | |
" 'block3_conv1', 'block4_conv1',\n", | |
" 'block5_conv1']\n", | |
"for layer_name in feature_layers:\n", | |
" layer_features = outputs_dict[layer_name]\n", | |
" style_reference_features = layer_features[1, :, :, :]\n", | |
" combination_features = layer_features[2, :, :, :]\n", | |
" sl = style_loss(style_reference_features, combination_features)\n", | |
" loss += (style_weight / len(feature_layers)) * sl\n", | |
"loss += total_variation_weight * total_variation_loss(combination_image)\n", | |
"grads = K.gradients(loss, combination_image)\n", | |
"outputs = [loss]\n", | |
"if isinstance(grads, (list, tuple)):\n", | |
" outputs += grads\n", | |
"else:\n", | |
" outputs.append(grads)\n", | |
"\n", | |
"f_outputs = K.function([combination_image], outputs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def eval_loss_and_grads(x):\n", | |
" x = x.reshape((1, 3, img_nrows, img_ncols))\n", | |
" print('.', end='')\n", | |
" outs = f_outputs([x])\n", | |
" loss_value = outs[0]\n", | |
" if len(outs[1:]) == 1:\n", | |
" grad_values = outs[1].flatten().astype('float64')\n", | |
" else:\n", | |
" grad_values = np.array(outs[1:]).flatten().astype('float64')\n", | |
" return loss_value, grad_values\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Evaluator(object):\n", | |
"\n", | |
" def __init__(self):\n", | |
" self.loss_value = None\n", | |
" self.grads_values = None\n", | |
"\n", | |
" def loss(self, x):\n", | |
" assert self.loss_value is None\n", | |
" loss_value, grad_values = eval_loss_and_grads(x)\n", | |
" self.loss_value = loss_value\n", | |
" self.grad_values = grad_values\n", | |
" return self.loss_value\n", | |
"\n", | |
" def grads(self, x):\n", | |
" assert self.loss_value is not None\n", | |
" grad_values = np.copy(self.grad_values)\n", | |
" self.loss_value = None\n", | |
" self.grad_values = None\n", | |
" return grad_values\n", | |
"\n", | |
"evaluator = Evaluator()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(1L, 3L, 400L, 533L)\n", | |
"Start of iteration 0\n", | |
".....................\n", | |
"Current loss value: 2857231872.0\n", | |
"Image saved as art_at_iteration_0.png\n", | |
"Iteration 0 completed in 223s\n", | |
"Start of iteration 1\n", | |
".....................\n", | |
"Current loss value: 2399473664.0\n", | |
"Image saved as art_at_iteration_1.png\n", | |
"Iteration 1 completed in 223s\n", | |
"Start of iteration 2\n", | |
".....................\n", | |
"Current loss value: 2273068544.0\n", | |
"Image saved as art_at_iteration_2.png\n", | |
"Iteration 2 completed in 243s\n", | |
"Start of iteration 3\n", | |
".....................\n", | |
"Current loss value: 2208670208.0\n", | |
"Image saved as art_at_iteration_3.png\n", | |
"Iteration 3 completed in 251s\n", | |
"Start of iteration 4\n", | |
".....................\n", | |
"Current loss value: 2167820032.0\n", | |
"Image saved as art_at_iteration_4.png\n", | |
"Iteration 4 completed in 260s\n" | |
] | |
} | |
], | |
"source": [ | |
"x = preprocess_image(base_image_path)\n", | |
"print(x.shape)\n", | |
"for i in range(iterations):\n", | |
" print('Start of iteration', i)\n", | |
" start_time = time.time()\n", | |
" x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), fprime=evaluator.grads, maxfun=20)\n", | |
" print()\n", | |
" print('Current loss value:', min_val)\n", | |
" # save current generated image\n", | |
" img = deprocess_image(x.copy())\n", | |
" fname = result_prefix + '_at_iteration_%d.png' % i\n", | |
" imsave(fname, img)\n", | |
" end_time = time.time()\n", | |
" print('Image saved as', fname)\n", | |
" print('Iteration %d completed in %ds' % (i, end_time - start_time))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment