Last active
April 6, 2018 22:13
-
-
Save JossWhittle/4719ede7e961e3143230674ec74bfcd0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Artistic Style Transfer\n", | |
| "\n", | |
| "Code modified from: https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", | |
| "\n", | |
| "import sys\n", | |
| "import random\n", | |
| "import time\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "import keras\n", | |
| "from keras import backend as K\n", | |
| "from keras.preprocessing.image import load_img, img_to_array\n", | |
| "from keras.models import Model\n", | |
| "from keras.layers import Input, Conv2D, AveragePooling2D\n", | |
| "from keras.engine.topology import get_source_inputs\n", | |
| "from keras.utils import layer_utils\n", | |
| "from keras.utils.data_utils import get_file\n", | |
| "from keras.applications.imagenet_utils import decode_predictions, preprocess_input, _obtain_input_shape\n", | |
| "from tensorflow.python.client import device_lib\n", | |
| "\n", | |
| "from scipy.misc import imsave, imresize\n", | |
| "from scipy.optimize import minimize" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Report which GPUs TensorFlow can see (CUDA_VISIBLE_DEVICES is set in the first cell).
gpu_devices = [device.name
               for device in device_lib.list_local_devices()
               if device.device_type == 'GPU']
print('Using GPU(s):', gpu_devices)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Modified VGG19 Model without Fully Connected Layers and using Average Pooling
#
def VGG19_avg_no_top(input_tensor, device='/gpu:0'):
    """Build the convolutional part of VGG19, with average pooling, on `input_tensor`.

    Five blocks of 3x3 ReLU convolutions, each followed by a 2x2/stride-2
    AveragePooling2D, are stacked on top of `input_tensor`. The result is
    wrapped in a Model named 'vgg19_avg_no_top', and the
    'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5' weights are loaded
    into it (downloaded by `get_file` and cached under the 'models' subdir).

    Args:
        input_tensor: backend tensor used as the network input (in this
            notebook, the base/style/combination batch from getCombinedInputs).
        device: TensorFlow device string the graph is built under. Defaults to
            '/gpu:0', the value that was previously hard-coded.

    Returns:
        The Keras Model with pretrained weights loaded.
    """
    # (filters, number of 3x3 conv layers) for blocks 1..5. Layer names are
    # generated as blockN_convM / blockN_pool, in the same order as before,
    # so the pretrained weight file still lines up layer-for-layer.
    block_config = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

    # Determine proper input shape
    input_shape = _obtain_input_shape(None, default_size=224,
                                      min_size=48, data_format=K.image_data_format(),
                                      include_top=False)
    with tf.device(device):
        x = Input(tensor=input_tensor, shape=input_shape)

        for block_idx, (filters, n_convs) in enumerate(block_config, start=1):
            for conv_idx in range(1, n_convs + 1):
                x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                           name='block%d_conv%d' % (block_idx, conv_idx))(x)
            # Average (not max) pooling; pooling layers carry no weights.
            x = AveragePooling2D((2, 2), strides=(2, 2),
                                 name='block%d_pool' % block_idx)(x)

        # Create model.
        model = Model(get_source_inputs(input_tensor), x, name='vgg19_avg_no_top')

        # load weights
        weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                'https://github.com/fchollet/deep-learning-models/releases/'
                                'download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                cache_subdir='models')
        model.load_weights(weights_path)
    return model
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def getImageDimensions(image_file, target_height):
    """Return (width, height) for `image_file` scaled to `target_height` rows.

    The width keeps the source aspect ratio and is truncated to an int.
    """
    src_width, src_height = load_img(image_file).size
    scaled_width = int(src_width * target_height / src_height)
    return (scaled_width, target_height)
| "\n", | |
def preprocessImage(image_file, width, height):
    """Load `image_file` at (height, width) as a single-image batch and apply `preprocess_input`."""
    pil_image = load_img(image_file, target_size=(height, width))
    batch = np.expand_dims(img_to_array(pil_image), axis=0)
    return preprocess_input(batch)
| "\n", | |
def deprocessImage(img, width, height):
    """Turn a mean-centred BGR image back into a displayable RGB uint8 array.

    Adds back the per-channel mean pixel, flips channel order BGR -> RGB,
    clips to [0, 255] and casts to uint8. This is intended to undo
    preprocessImage.

    The work is done on a float64 copy, so the caller's array (for example,
    the optimizer's current solution vector) is never modified.

    Args:
        img: array with height * width * 3 elements in BGR order; a
            (1, height, width, 3) batch or a flat vector both work.
        width, height: image size in pixels.

    Returns:
        (height, width, 3) uint8 RGB array.
    """
    # np.array copies by default. Reshaping the caller's buffer and using
    # += on that view would mutate the caller's data in place.
    img = np.array(img, dtype=np.float64).reshape((height, width, 3))

    # Remove zero-center by mean pixel (channels in B, G, R order)
    img += np.array([103.939, 116.779, 123.68])

    # 'BGR'->'RGB'
    img = img[:, :, ::-1]
    return np.clip(img, 0, 255).astype('uint8')
| "\n", | |
def getModelLayers(inputs):
    """Build VGG19_avg_no_top on `inputs` and map every layer name to its output tensor."""
    model = VGG19_avg_no_top(input_tensor=inputs)
    return {layer.name: layer.output for layer in model.layers}
| "\n", | |
def getCombinedInputs(base_file, style_file, width, height):
    """Stack base image, style image and a combination placeholder into one batch.

    Returns:
        (batch, combination_image), where batch index 0 is the base image,
        1 is the style image and 2 is the (1, height, width, 3) placeholder
        that the optimizer feeds.
    """
    # Base and style are fixed backend variables, created in that order.
    fixed_images = [K.variable(preprocessImage(path, width, height))
                    for path in (base_file, style_file)]
    combination_image = K.placeholder((1, height, width, 3))

    batch = K.concatenate(fixed_images + [combination_image], axis=0)
    return (batch, combination_image)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Loss Function
#
def gramMatrix(img):
    """Return the (channels, channels) Gram matrix of a rank-3 feature map.

    The channel axis (last) is moved first and each channel is flattened to a
    row; entry (i, j) is the dot product of channel rows i and j.
    """
    assert K.ndim(img) == 3

    channel_rows = K.batch_flatten(K.permute_dimensions(img, (2, 0, 1)))
    return K.dot(channel_rows, K.transpose(channel_rows))
| "\n", | |
def styleLoss(style, combination, width, height, channels=3):
    """Normalised squared distance between the Gram matrices of two feature maps.

    Args:
        style, combination: rank-3 feature tensors from the same layer.
        width, height: size of the input image in pixels, used only in the
            normalisation constant.
        channels: channel count used in the normalisation constant. Defaults
            to 3, the previously hard-coded value, so existing loss scales
            (and any weights tuned against them) are unchanged.

    Returns:
        Scalar tensor sum((G_style - G_comb)^2) / (4 * channels^2 * (height*width)^2).
    """
    assert K.ndim(style) == 3
    assert K.ndim(combination) == 3

    # NOTE: Gatys et al. normalise each layer by its own filter count and
    # feature-map area. Here a fixed channel count and the input-image size
    # are used instead, so all style layers share one scale factor.
    S = gramMatrix(style)
    C = gramMatrix(combination)
    return K.sum(K.square(S - C)) / (4. * (channels ** 2) * ((height * width) ** 2))
| "\n", | |
def contentLoss(base, combination):
    """Return the sum of squared differences between two feature tensors."""
    difference = combination - base
    return K.sum(K.square(difference))
| "\n", | |
def totalVariationLoss(img, width, height):
    """Total-variation smoothness penalty for a rank-4 (batch, height, width, channels) image.

    Each pixel (except the last row/column) is compared with its lower and
    right neighbours; the summed squared differences are raised to 1.25.
    """
    assert K.ndim(img) == 4

    anchor = img[:, :height - 1, :width - 1, :]
    below = img[:, 1:, :width - 1, :]
    right = img[:, :height - 1, 1:, :]

    vertical = K.square(anchor - below)
    horizontal = K.square(anchor - right)
    return K.sum(K.pow(vertical + horizontal, 1.25))
| "\n", | |
def combinedLoss(layers, combination_image, width, height,
                 content_layer='block4_conv2',
                 style_layers=('block1_conv1', 'block2_conv1', 'block3_conv1',
                               'block4_conv1', 'block5_conv1'),
                 total_variation_weight=1.0,
                 style_weight=1.0,
                 content_weight=0.1):
    """Build the style-transfer loss and return a backend function evaluating it.

    The batch axis of every tensor in `layers` is indexed as built by
    getCombinedInputs: 0 = base (content) image, 1 = style image,
    2 = combination image being optimised.

    Args:
        layers: dict mapping layer name -> output tensor (see getModelLayers).
        combination_image: placeholder fed with the image being optimised.
        width, height: image size in pixels.
        content_layer: layer whose features define the content term.
        style_layers: layers whose Gram matrices define the style term; the
            style weight is split evenly across them. Any sequence works (a
            tuple default avoids a shared mutable default argument).
        total_variation_weight, style_weight, content_weight: term weights.
            The total-variation term is skipped when its weight is <= 0.

    Returns:
        K.function taking [combination_array] and returning
        [loss, gradient of loss w.r.t. combination_image].
    """
    # Content term. The sum starts from this tensor directly rather than
    # from a freshly allocated K.variable(0.).
    content_features = layers[content_layer]
    base_image_features = content_features[0, :, :, :]
    combination_features = content_features[2, :, :, :]
    loss = content_weight * contentLoss(base_image_features, combination_features)

    # Style term, averaged over the requested layers.
    for layer_name in style_layers:
        layer_features = layers[layer_name]
        style_reference_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        loss += (style_weight / len(style_layers)) * styleLoss(
            style_reference_features, combination_features, width, height)

    # Smoothness term on the combination image itself.
    if total_variation_weight > 0.:
        loss += total_variation_weight * totalVariationLoss(combination_image, width, height)

    return K.function([combination_image], [loss] + K.gradients(loss, combination_image))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Evaluator Wrapper for Sci-Py Optimizer
#
class Evaluator(object):
    """Adapt a backend loss/gradient function to scipy.optimize.minimize with jac=True.

    `func` is called with a one-element list holding the image reshaped to
    (1, height, width, 3); its first output is the loss and the remaining
    outputs are gradients.
    """

    def __init__(self, func, width, height):
        self.func = func
        self.width = width
        self.height = height

    def loss(self, img):
        """Return (loss, gradient flattened to float64) for the flat image vector `img`."""
        batch = img.reshape((1, self.height, self.width, 3))
        outputs = self.func([batch])
        gradients = outputs[1:]

        # A single gradient tensor is flattened directly; several are stacked first.
        if len(gradients) == 1:
            flat = gradients[0].flatten()
        else:
            flat = np.array(gradients).flatten()
        return (outputs[0], flat.astype('float64'))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
# Transfer Artistic Style from one image to another
#
try:
    iterations = 20
    steps_per_iteration = 20  # scipy 'maxiter' for each outer iteration
    optimization_method = 'CG'  # Conjugate Gradient Optimizer
    output_height = 720

    # Images to process
    base_file = 'base/cat.jpg'
    style_file = 'style/seated.jpg'

    # Output
    output_file = 'cat-seated'

    # Style Params
    total_variation_weight = 0.000000001
    style_weight = 1.0
    content_weight = 0.000001

    # Standard deviation (scale) of the Gaussian noise added to the starting image
    initial_noise_std = 10

    # VGG19 Layers
    content_layer = 'block4_conv2'
    style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']

    # imsave does not create missing directories
    os.makedirs('output', exist_ok=True)
    os.makedirs('combined', exist_ok=True)

    # Construct tensor graph of input images for the VGG19 network
    width, height = getImageDimensions(base_file, output_height)
    inputs, combination_image = getCombinedInputs(base_file, style_file, width, height)
    layers = getModelLayers(inputs)

    # Normalize weights so the three terms sum to 1
    sum_weights = total_variation_weight + style_weight + content_weight
    total_variation_weight /= sum_weights
    style_weight /= sum_weights
    content_weight /= sum_weights

    # Function handle for Sci-Py Optimizer
    evaluator = Evaluator(combinedLoss(layers, combination_image, width, height, content_layer, style_layers,
                                       total_variation_weight, style_weight, content_weight), width, height)

    # Starting point: the preprocessed base image plus Gaussian noise.
    # The base image is loaded once and also reused for its thumbnail.
    base_pre = preprocessImage(base_file, width, height)
    img = base_pre.copy()
    img += np.random.normal(0, initial_noise_std, (1, height, width, 3))

    bimg = deprocessImage(base_pre, width, height)
    simg = deprocessImage(preprocessImage(style_file, width, height), width, height)

    # Create thumbnails for combined image: base over style, separated by a
    # white bar. The bar absorbs the extra row when `height` is odd, so the
    # thumbnail column is always exactly `height` rows and can be hstacked
    # beside the output image.
    thumb_height = int((height - 4) / 2)
    thumb_width = int((thumb_height / height) * width)
    bthumb = imresize(bimg, (thumb_height, thumb_width), interp='lanczos')
    sthumb = imresize(simg, (thumb_height, thumb_width), interp='lanczos')
    bar_rows = height - 2 * thumb_height
    # Everything stays uint8: scipy.misc.imsave rescales non-uint8 arrays to
    # their own min/max, which would contrast-stretch the combined image.
    timg = np.vstack([bthumb, np.full((bar_rows, thumb_width, 3), 255, dtype='uint8'), sthumb])
    column_bar = np.full((height, 4, 3), 255, dtype='uint8')

    epochs = []
    losses = []
    for epoch in range(iterations):
        start_time = time.time()

        # Move img closer to the optimum combined image
        res = minimize(method=optimization_method, fun=evaluator.loss, x0=img.flatten(), jac=True,
                       options={'maxiter': steps_per_iteration})
        img = res.x
        loss = res.fun
        oimg = deprocessImage(img.copy(), width, height)

        end_time = time.time()

        epochs.append(epoch)
        losses.append(loss)

        print('Iteration %d | Loss %f | Time %f' % (epoch, loss, (end_time - start_time)), end='\r')

        # Save current generated image
        imsave('output/%s.png' % (output_file), oimg)
        imsave('combined/%s-combined.png' % (output_file), np.hstack([timg, column_bar, oimg]))

except (KeyboardInterrupt, SystemExit):
    print()
print('\nHalting...')
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment