Skip to content

Instantly share code, notes, and snippets.

@JossWhittle
Last active April 6, 2018 22:13
Show Gist options
  • Select an option

  • Save JossWhittle/4719ede7e961e3143230674ec74bfcd0 to your computer and use it in GitHub Desktop.

Select an option

Save JossWhittle/4719ede7e961e3143230674ec74bfcd0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artistic Style Transfer\n",
"\n",
"Code modified from: https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"import sys\n",
"import random\n",
"import time\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import backend as K\n",
"from keras.preprocessing.image import load_img, img_to_array\n",
"from keras.models import Model\n",
"from keras.layers import Input, Conv2D, AveragePooling2D\n",
"from keras.engine.topology import get_source_inputs\n",
"from keras.utils import layer_utils\n",
"from keras.utils.data_utils import get_file\n",
"from keras.applications.imagenet_utils import decode_predictions, preprocess_input, _obtain_input_shape\n",
"from tensorflow.python.client import device_lib\n",
"\n",
"from scipy.misc import imsave, imresize\n",
"from scipy.optimize import minimize"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Using GPU(s):', [x.name for x in device_lib.list_local_devices() if x.device_type == 'GPU'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VGG19 convolutional base (fully connected top layers removed) with average pooling in place of max pooling\n",
"#\n",
"def VGG19_avg_no_top(input_tensor):\n",
" # Determine proper input shape\n",
" input_shape = _obtain_input_shape(None, default_size=224, \n",
" min_size=48, data_format=K.image_data_format(), \n",
" include_top=False)\n",
" with tf.device('/gpu:0'):\n",
" x = Input(tensor=input_tensor, shape=input_shape)\n",
"\n",
" # Block 1\n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x)\n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)\n",
" x = AveragePooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)\n",
"\n",
" # Block 2\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)\n",
" x = AveragePooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)\n",
"\n",
" # Block 3\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x)\n",
" x = AveragePooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)\n",
" \n",
" # Block 4\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x)\n",
" x = AveragePooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)\n",
"\n",
" # Block 5\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x)\n",
" x = AveragePooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)\n",
"\n",
" # Create model.\n",
" model = Model(get_source_inputs(input_tensor), x, name='vgg19_avg_no_top')\n",
"\n",
" # Download (or reuse the cached copy of) the pretrained VGG19 'notop' weights and load them\n",
" weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',\n",
" 'https://github.com/fchollet/deep-learning-models/releases/'\n",
" 'download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5', \n",
" cache_subdir='models')\n",
" model.load_weights(weights_path) \n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def getImageDimensions(image_file, target_height):\n",
" w, h = load_img(image_file).size\n",
" height = target_height\n",
" width = int(w * height / h)\n",
" return (width, height)\n",
"\n",
"def preprocessImage(image_file, width, height):\n",
" img = np.expand_dims(img_to_array(load_img(image_file, target_size=(height, width))), axis=0)\n",
" return preprocess_input(img)\n",
"\n",
"def deprocessImage(img, width, height):\n",
" img = img.reshape((height, width, 3))\n",
" \n",
" # Undo preprocess_input's zero-centering: add back the per-channel mean pixel values (BGR order)\n",
" img[:, :, 0] += 103.939\n",
" img[:, :, 1] += 116.779\n",
" img[:, :, 2] += 123.68\n",
" \n",
" # 'BGR'->'RGB'\n",
" img = img[:,:,::-1]\n",
" img = np.clip(img, 0, 255).astype('uint8')\n",
" return img\n",
"\n",
"def getModelLayers(inputs):\n",
" model = VGG19_avg_no_top(input_tensor=inputs)\n",
" layers = dict([(layer.name, layer.output) for layer in model.layers])\n",
" return layers\n",
"\n",
"def getCombinedInputs(base_file, style_file, width, height):\n",
" base_image = K.variable(preprocessImage(base_file, width, height))\n",
" style_image = K.variable(preprocessImage(style_file, width, height))\n",
" combination_image = K.placeholder((1, height, width, 3))\n",
" \n",
" inputs = K.concatenate([base_image, style_image, combination_image], axis=0)\n",
" return (inputs, combination_image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Loss Function\n",
"#\n",
"def gramMatrix(img):\n",
" assert K.ndim(img) == 3\n",
" \n",
" features = K.batch_flatten(K.permute_dimensions(img, (2, 0, 1)))\n",
" return K.dot(features, K.transpose(features))\n",
"\n",
"def styleLoss(style, combination, width, height):\n",
" assert K.ndim(style) == 3\n",
" assert K.ndim(combination) == 3\n",
" channels = 3\n",
" \n",
" S = gramMatrix(style)\n",
" C = gramMatrix(combination)\n",
" return K.sum(K.square(S - C)) / (4. * (channels ** 2) * ((height * width) ** 2))\n",
"\n",
"def contentLoss(base, combination):\n",
" return K.sum(K.square(combination - base))\n",
"\n",
"def totalVariationLoss(img, width, height):\n",
" assert K.ndim(img) == 4\n",
" a = K.square(img[:,:height-1,:width-1,:] - img[:,1:,:width-1,:])\n",
" b = K.square(img[:,:height-1,:width-1,:] - img[:,:height-1,1:,:])\n",
" return K.sum(K.pow(a + b, 1.25))\n",
"\n",
"def combinedLoss(layers, combination_image, width, height, \n",
" content_layer='block4_conv2', \n",
" style_layers =['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1'],\n",
" total_variation_weight = 1.0,\n",
" style_weight = 1.0,\n",
" content_weight = 0.1):\n",
" \n",
" layer_features = layers[content_layer]\n",
" base_image_features = layer_features[0,:,:,:]\n",
" combination_features = layer_features[2,:,:,:]\n",
" \n",
" loss = K.variable(0.)\n",
" loss += content_weight * contentLoss(base_image_features, combination_features)\n",
"\n",
" for layer_name in style_layers:\n",
" layer_features = layers[layer_name]\n",
" style_reference_features = layer_features[1,:,:,:]\n",
" combination_features = layer_features[2,:,:,:]\n",
" \n",
" loss += (style_weight / len(style_layers)) * styleLoss(style_reference_features, combination_features, width, height)\n",
" \n",
" if (total_variation_weight > 0.):\n",
" loss += total_variation_weight * totalVariationLoss(combination_image, width, height)\n",
"\n",
" return K.function([combination_image], [loss] + K.gradients(loss, combination_image))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wraps the Keras loss/gradient function in the fun(x) -> (loss, gradient) form that scipy.optimize.minimize expects with jac=True\n",
"#\n",
"class Evaluator(object):\n",
" def __init__(self, func, width, height):\n",
" self.func = func\n",
" self.width = width\n",
" self.height = height\n",
"\n",
" def loss(self, img):\n",
" outs = self.func([img.reshape((1, self.height, self.width, 3))])\n",
" \n",
" if len(outs[1:]) == 1:\n",
" grad_values = outs[1].flatten().astype('float64')\n",
" else:\n",
" grad_values = np.array(outs[1:]).flatten().astype('float64')\n",
" return (outs[0], grad_values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# Transfer Artistic Style from one image to another\n",
"#\n",
"try:\n",
" iterations = 20\n",
" optimization_method = 'CG' # Conjugate Gradient Optimizer\n",
" output_height = 720\n",
"\n",
" # Images to process\n",
" base_file = 'base/cat.jpg'\n",
" style_file = 'style/seated.jpg'\n",
" \n",
" # Output\n",
" output_file = 'cat-seated'\n",
" \n",
" # Style Params\n",
" total_variation_weight = 0.000000001\n",
" style_weight = 1.0\n",
" content_weight = 0.000001\n",
" \n",
" intial_variance = 10\n",
" \n",
" # VGG19 Layers\n",
" content_layer = 'block4_conv2'\n",
" style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']\n",
"\n",
" # Build the stacked input tensor (base, style, combination images) and run it through the VGG19 network\n",
" width, height = getImageDimensions(base_file, output_height)\n",
" inputs, combination_image = getCombinedInputs(base_file, style_file, width, height)\n",
" layers = getModelLayers(inputs)\n",
"\n",
" # Normalize the loss weights so they sum to 1\n",
" sum_weights = total_variation_weight + style_weight + content_weight\n",
" total_variation_weight /= sum_weights\n",
" style_weight /= sum_weights\n",
" content_weight /= sum_weights\n",
" \n",
" # Loss-and-gradient callable for the SciPy optimizer\n",
" evaluator = Evaluator(combinedLoss(layers, combination_image, width, height, content_layer, style_layers,\n",
" total_variation_weight, style_weight, content_weight), width, height)\n",
" \n",
" # Initial output image: the preprocessed base image plus Gaussian noise (intial_variance is passed as the standard deviation)\n",
" img = preprocessImage(base_file, width, height)\n",
" img += np.random.normal(0, intial_variance, (1, height, width, 3))\n",
" \n",
" bimg = deprocessImage(preprocessImage(base_file, width, height), width, height)\n",
" simg = deprocessImage(preprocessImage(style_file, width, height), width, height)\n",
" \n",
" # Build the left panel of the combined image: base thumbnail above style thumbnail, separated by a 4px white strip\n",
" thumb_height = int((height - 4) / 2)\n",
" thumb_width = int((thumb_height / height) * width)\n",
" bthumb = imresize(bimg, (thumb_height, thumb_width), interp='lanczos')\n",
" sthumb = imresize(simg, (thumb_height, thumb_width), interp='lanczos')\n",
" timg = np.vstack([bthumb, np.ones((4, thumb_width, 3)) * 255, sthumb])\n",
" \n",
" epochs = []\n",
" losses = []\n",
" for epoch in range(iterations):\n",
" start_time = time.time()\n",
"\n",
" # Refine img with up to 20 optimizer iterations, warm-starting from the previous result\n",
" res = minimize(method=optimization_method, fun=evaluator.loss, x0=img.flatten(), jac=True, options={'maxiter':20})\n",
" img = res.x\n",
" loss = res.fun\n",
" oimg = deprocessImage(img.copy(), width, height)\n",
" \n",
" end_time = time.time()\n",
" \n",
" epochs.append(epoch)\n",
" losses.append(loss)\n",
" \n",
" print('Iteration %d | Loss %f | Time %f' % (epoch, loss, (end_time - start_time)), end='\\r')\n",
" \n",
" # Save current generated image\n",
" imsave('output/%s.png' % (output_file), oimg)\n",
" imsave('combined/%s-combined.png' % (output_file), np.hstack([timg, np.ones((height, 4, 3)) * 255, oimg]))\n",
"\n",
"except (KeyboardInterrupt, SystemExit):\n",
" print()\n",
"print('\\nHalting...')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment