@sborquez
Created February 7, 2020 17:17

Applied Neural Style Transfer
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Applied Neural Style Transfer",
"provenance": [],
"private_outputs": true,
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/sborquez/2360dfe28cbe8a415563b0d21b239b58/applied-neural-style-transfer.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jo5PziEC4hWs",
"colab_type": "text"
},
"source": [
"# Neural Style Transfer with tf.keras\n",
"\n",
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eqxUicSPUOP6",
"colab_type": "text"
},
"source": [
"### Import and configure modules"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sc1OLbOWhPCO",
"colab_type": "code",
"colab": {}
},
"source": [
"from google.colab import files\n",
"uploaded = {}\n",
"content_path = '/content/content.jpg'\n",
"style_path = '/content/deathnote.jpg'\n",
"\n",
"import IPython.display\n",
"from tqdm import tqdm_notebook, tqdm\n",
"from ipywidgets import interact, interactive, fixed, interact_manual\n",
"import ipywidgets as widgets\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.figsize'] = (10,10)\n",
"mpl.rcParams['axes.grid'] = False\n",
"\n",
"import numpy as np\n",
"from PIL import Image\n",
"import time\n",
"import functools\n",
"\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.keras.preprocessing import image as kp_image\n",
"from tensorflow.keras import models \n",
"from tensorflow.keras import losses\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras import backend as K\n",
"\n",
"\n",
"print(f\"tf version: {tf.__version__}\")\n",
"\n",
"print(\"TESTING\")\n",
"tf.enable_eager_execution()\n",
"print(\"Eager execution: {}\".format(tf.executing_eagerly()))\n",
"\n",
"with tf.device('/gpu:0'):\n",
" a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')\n",
" b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')\n",
" c = tf.matmul(a, b)\n",
" print(c)\n",
"\n",
"def load_img(path_to_img):\n",
" max_dim = 512\n",
" img = Image.open(path_to_img)\n",
" long = max(img.size)\n",
" scale = max_dim/long\n",
" img = img.resize((round(img.size[0]*scale), round(img.size[1]*scale)), Image.ANTIALIAS)\n",
" \n",
" img = kp_image.img_to_array(img)\n",
" \n",
" # We need to broadcast the image array such that it has a batch dimension \n",
" img = np.expand_dims(img, axis=0)\n",
" return img\n",
"\n",
"def select_images(content = '/content/content.jpg', style = '/content/deathnote.jpg'):\n",
"\n",
" global content_path\n",
" global style_path\n",
"\n",
" content_path = content\n",
" style_path = style\n",
"\n",
" \n",
" def imshow(img, title=None):\n",
" # Remove the batch dimension\n",
" out = np.squeeze(img, axis=0)\n",
" # Normalize for display \n",
" out = out.astype('uint8')\n",
" plt.imshow(out)\n",
" if title is not None:\n",
" plt.title(title)\n",
" plt.imshow(out)\n",
"\n",
" plt.figure(figsize=(10,10))\n",
"\n",
" content = load_img(content_path).astype('uint8')\n",
" style = load_img(style_path).astype('uint8')\n",
"\n",
" plt.subplot(1, 2, 1)\n",
" imshow(content, 'Content Image')\n",
"\n",
" plt.subplot(1, 2, 2)\n",
" imshow(style, 'Style Image')\n",
" plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mjzlKRQRs_y2",
"colab_type": "code",
"colab": {}
},
"source": [
"def load_and_process_img(path_to_img):\n",
" img = load_img(path_to_img)\n",
" img = tf.keras.applications.vgg19.preprocess_input(img)\n",
" return img\n",
" \n",
"def deprocess_img(processed_img):\n",
" x = processed_img.copy()\n",
" if len(x.shape) == 4:\n",
" x = np.squeeze(x, 0)\n",
" assert len(x.shape) == 3, (\"Input to deprocess image must be an image of \"\n",
" \"dimension [1, height, width, channel] or [height, width, channel]\")\n",
" if len(x.shape) != 3:\n",
" raise ValueError(\"Invalid input to deprocessing image\")\n",
" \n",
" # perform the inverse of the preprocessiing step\n",
" x[:, :, 0] += 103.939\n",
" x[:, :, 1] += 116.779\n",
" x[:, :, 2] += 123.68\n",
" x = x[:, :, ::-1]\n",
"\n",
" x = np.clip(x, 0, 255).astype('uint8')\n",
" return x\n",
"\n",
"def get_model():\n",
" \"\"\" Creates our model with access to intermediate layers. \n",
" \n",
" This function will load the VGG19 model and access the intermediate layers. \n",
" These layers will then be used to create a new model that will take input image\n",
" and return the outputs from these intermediate layers from the VGG model. \n",
" \n",
" Returns:\n",
" returns a keras model that takes image inputs and outputs the style and \n",
" content intermediate layers. \n",
" \"\"\"\n",
" # Load our model. We load pretrained VGG, trained on imagenet data\n",
" vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')\n",
" vgg.trainable = False\n",
" # Get output layers corresponding to style and content layers \n",
" style_outputs = [vgg.get_layer(name).output for name in style_layers]\n",
" content_outputs = [vgg.get_layer(name).output for name in content_layers]\n",
" model_outputs = style_outputs + content_outputs\n",
" # Build model \n",
" return models.Model(vgg.input, model_outputs)\n",
"\n",
"def get_content_loss(base_content, target):\n",
" return tf.reduce_mean(tf.square(base_content - target))\n",
"\n",
"def gram_matrix(input_tensor):\n",
" # We make the image channels first \n",
" channels = int(input_tensor.shape[-1])\n",
" a = tf.reshape(input_tensor, [-1, channels])\n",
" n = tf.shape(a)[0]\n",
" gram = tf.matmul(a, a, transpose_a=True)\n",
" return gram / tf.cast(n, tf.float32)\n",
"\n",
"def get_style_loss(base_style, gram_target):\n",
" \"\"\"Expects two images of dimension h, w, c\"\"\"\n",
" # height, width, num filters of each layer\n",
" # We scale the loss at a given layer by the size of the feature map and the number of filters\n",
" height, width, channels = base_style.get_shape().as_list()\n",
" gram_style = gram_matrix(base_style)\n",
" \n",
" return tf.reduce_mean(tf.square(gram_style - gram_target))# / (4. * (channels ** 2) * (width * height) ** 2)\n",
"\n",
"def get_feature_representations(model, content_path, style_path):\n",
" \"\"\"Helper function to compute our content and style feature representations.\n",
"\n",
" This function will simply load and preprocess both the content and style \n",
" images from their path. Then it will feed them through the network to obtain\n",
" the outputs of the intermediate layers. \n",
" \n",
" Arguments:\n",
" model: The model that we are using.\n",
" content_path: The path to the content image.\n",
" style_path: The path to the style image\n",
" \n",
" Returns:\n",
" returns the style features and the content features. \n",
" \"\"\"\n",
" # Load our images in \n",
" content_image = load_and_process_img(content_path)\n",
" style_image = load_and_process_img(style_path)\n",
" \n",
" # batch compute content and style features\n",
" style_outputs = model(style_image)\n",
" content_outputs = model(content_image)\n",
" \n",
" \n",
" # Get the style and content feature representations from our model \n",
" style_features = [style_layer[0] for style_layer in style_outputs[:num_style_layers]]\n",
" content_features = [content_layer[0] for content_layer in content_outputs[num_style_layers:]]\n",
" return style_features, content_features\n",
"\n",
"def compute_loss(model, loss_weights, init_image, gram_style_features, content_features):\n",
" \"\"\"This function will compute the loss total loss.\n",
" \n",
" Arguments:\n",
" model: The model that will give us access to the intermediate layers\n",
" loss_weights: The weights of each contribution of each loss function. \n",
" (style weight, content weight, and total variation weight)\n",
" init_image: Our initial base image. This image is what we are updating with \n",
" our optimization process. We apply the gradients wrt the loss we are \n",
" calculating to this image.\n",
" gram_style_features: Precomputed gram matrices corresponding to the \n",
" defined style layers of interest.\n",
" content_features: Precomputed outputs from defined content layers of \n",
" interest.\n",
" \n",
" Returns:\n",
" returns the total loss, style loss, content loss, and total variational loss\n",
" \"\"\"\n",
" style_weight, content_weight = loss_weights\n",
" \n",
" # Feed our init image through our model. This will give us the content and \n",
" # style representations at our desired layers. Since we're using eager\n",
" # our model is callable just like any other function!\n",
" model_outputs = model(init_image)\n",
" \n",
" style_output_features = model_outputs[:num_style_layers]\n",
" content_output_features = model_outputs[num_style_layers:]\n",
" \n",
" style_score = 0\n",
" content_score = 0\n",
"\n",
" # Accumulate style losses from all layers\n",
" # Here, we equally weight each contribution of each loss layer\n",
" weight_per_style_layer = 1.0 / float(num_style_layers)\n",
" for target_style, comb_style in zip(gram_style_features, style_output_features):\n",
" style_score += weight_per_style_layer * get_style_loss(comb_style[0], target_style)\n",
" \n",
" # Accumulate content losses from all layers \n",
" weight_per_content_layer = 1.0 / float(num_content_layers)\n",
" for target_content, comb_content in zip(content_features, content_output_features):\n",
" content_score += weight_per_content_layer* get_content_loss(comb_content[0], target_content)\n",
" \n",
" style_score *= style_weight\n",
" content_score *= content_weight\n",
"\n",
" # Get total loss\n",
" loss = style_score + content_score \n",
" return loss, style_score, content_score\n",
"\n",
"def compute_grads(cfg):\n",
" with tf.GradientTape() as tape: \n",
" all_loss = compute_loss(**cfg)\n",
" # Compute gradients wrt input image\n",
" total_loss = all_loss[0]\n",
" return tape.gradient(total_loss, cfg['init_image']), all_loss\n",
"\n",
"def run_style_transfer(content_path, \n",
" style_path,\n",
" num_iterations=2000,\n",
" content_weight=1e3, \n",
" style_weight=1e-2,\n",
" display_interval=150): \n",
" # We don't need to (or want to) train any layers of our model, so we set their\n",
" # trainable to false. \n",
" model = get_model() \n",
" for layer in model.layers:\n",
" layer.trainable = False\n",
" \n",
" # Get the style and content feature representations (from our specified intermediate layers) \n",
" style_features, content_features = get_feature_representations(model, content_path, style_path)\n",
" gram_style_features = [gram_matrix(style_feature) for style_feature in style_features]\n",
" \n",
" # Set initial image\n",
" init_image = load_and_process_img(content_path)\n",
" init_image = tf.Variable(init_image, dtype=tf.float32)\n",
" # Create our optimizer\n",
" opt = tf.train.AdamOptimizer(learning_rate=5, beta1=0.99, epsilon=1e-1)\n",
"\n",
" # For displaying intermediate images \n",
" iter_count = 1\n",
" \n",
" # Store our best result\n",
" best_loss, best_img = float('inf'), None\n",
" \n",
" # Create a nice config \n",
" loss_weights = (style_weight, content_weight)\n",
" cfg = {\n",
" 'model': model,\n",
" 'loss_weights': loss_weights,\n",
" 'init_image': init_image,\n",
" 'gram_style_features': gram_style_features,\n",
" 'content_features': content_features\n",
" }\n",
" \n",
" # For displaying\n",
" num_rows = 2\n",
" num_cols = 5\n",
" #display_interval = 10 #num_iterations/(num_rows*num_cols)\n",
" start_time = time.time()\n",
" global_start = time.time()\n",
" \n",
" norm_means = np.array([103.939, 116.779, 123.68])\n",
" min_vals = -norm_means\n",
" max_vals = 255 - norm_means \n",
" \n",
" imgs = []\n",
" for i in tqdm(range(num_iterations)):\n",
" grads, all_loss = compute_grads(cfg)\n",
" loss, style_score, content_score = all_loss\n",
" opt.apply_gradients([(grads, init_image)])\n",
" clipped = tf.clip_by_value(init_image, min_vals, max_vals)\n",
" init_image.assign(clipped)\n",
" end_time = time.time() \n",
" \n",
" if loss < best_loss:\n",
" # Update best loss and best image from total loss. \n",
" best_loss = loss\n",
" best_img = deprocess_img(init_image.numpy())\n",
"\n",
" if i % display_interval== 0:\n",
" start_time = time.time()\n",
" \n",
" # Use the .numpy() method to get the concrete numpy array\n",
" plot_img = init_image.numpy()\n",
" plot_img = deprocess_img(plot_img)\n",
" imgs.append(plot_img)\n",
" IPython.display.clear_output(wait=True)\n",
" IPython.display.display_png(Image.fromarray(plot_img))\n",
" print('Iteration: {}'.format(i)) \n",
" print('Total loss: {:.4e}, ' \n",
" 'style loss: {:.4e}, '\n",
" 'content loss: {:.4e}, '\n",
" 'time: {:.4f}s'.format(loss, style_score, content_score, time.time() - start_time))\n",
" print('Total time: {:.4f}s'.format(time.time() - global_start))\n",
" IPython.display.clear_output(wait=True)\n",
" plt.figure(figsize=(25,30))\n",
" for i,img in enumerate(imgs):\n",
" plt.subplot(num_rows,num_cols,i+1)\n",
" plt.imshow(img)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" \n",
" return best_img, best_loss\n",
"\n",
"\n",
"def show_results(best_img, content_path, style_path, show_large_final=True):\n",
" plt.figure(figsize=(10, 5))\n",
" content = load_img(content_path) \n",
" style = load_img(style_path)\n",
"\n",
" plt.subplot(1, 2, 1)\n",
" imshow(content, 'Content Image')\n",
"\n",
" plt.subplot(1, 2, 2)\n",
" imshow(style, 'Style Image')\n",
"\n",
" if show_large_final: \n",
" plt.figure(figsize=(10, 10))\n",
"\n",
" plt.imshow(best_img)\n",
" plt.title('Output Image')\n",
" plt.show()"
],
"execution_count": 0,
"outputs": []
},
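{
"cell_type": "markdown",
"metadata": {
"id": "loss_math_note",
"colab_type": "text"
},
"source": [
"### The loss terms, in math\n",
"\n",
"For reference, the functions above implement the objective of Gatys et al. With a layer's activations reshaped to $F \\in \\mathbb{R}^{N \\times C}$ ($N$ spatial positions, $C$ channels), `gram_matrix` computes\n",
"\n",
"$$G = \\frac{F^{\\top}F}{N},$$\n",
"\n",
"`get_content_loss` is the mean squared error between the generated and content feature maps, and `get_style_loss` is the mean squared error between the generated image's Gram matrix and the style target's. `compute_loss` then returns the weighted sum\n",
"\n",
"$$\\mathcal{L} = w_{style}\\,\\frac{1}{L}\\sum_{l=1}^{L}\\mathcal{L}_{style}^{(l)} + w_{content}\\,\\mathcal{L}_{content},$$\n",
"\n",
"where $L$ is the number of style layers. The paper's extra $1/(4C^2N^2)$ scaling is left commented out in `get_style_loss`."
]
},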
{
"cell_type": "markdown",
"metadata": {
"id": "xE4Yt8nArTeR",
"colab_type": "text"
},
"source": [
"## Select Images and Layers\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vaxVKTRgEKE3",
"colab_type": "code",
"colab": {}
},
"source": [
"uploaded_ = files.upload()\n",
"uploaded.update(uploaded_)"
],
"execution_count": 0,
"outputs": []
},
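{
"cell_type": "markdown",
"metadata": {
"id": "fetch_samples_note",
"colab_type": "text"
},
"source": [
"If you'd rather skip the upload widget, a cell like the sketch below fetches a sample pair instead. The URLs are an assumption here (they point at the example images used by the TensorFlow style-transfer tutorial and may move); any two JPEGs work the same way."
]
},
{
"cell_type": "code",
"metadata": {
"id": "fetch_samples_code",
"colab_type": "code",
"colab": {}
},
"source": [
"# Optional: fetch sample content/style images instead of uploading.\n",
"# NOTE: these URLs are assumptions (TensorFlow tutorial sample images).\n",
"content_file = tf.keras.utils.get_file(\n",
"    'sample_content.jpg',\n",
"    'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg')\n",
"style_file = tf.keras.utils.get_file(\n",
"    'sample_style.jpg',\n",
"    'https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg')\n",
"\n",
"# Register them so the selector widget below offers them as choices\n",
"# (only the keys are used; the values are placeholders)\n",
"uploaded.update({content_file: None, style_file: None})"
],
"execution_count": 0,
"outputs": []
},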
{
"cell_type": "code",
"metadata": {
"id": "_UWQmeEaiKkP",
"colab_type": "code",
"colab": {}
},
"source": [
"interact( select_images, content=list(uploaded.keys()), style=list(uploaded.keys()));"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "N4-8eUp_Kc-j",
"colab_type": "code",
"colab": {}
},
"source": [
"# Content layer where will pull our feature maps\n",
"content_layers = ['block5_conv2'] \n",
"\n",
"# Style layer we are interested in\n",
"style_layers = ['block1_conv1',\n",
" 'block2_conv1',\n",
" #'block3_conv1', \n",
" 'block3_conv4',\n",
" 'block4_conv1', \n",
" #'block4_conv4', \n",
" 'block5_conv1',\n",
" #'block5_conv3',\n",
" 'block5_conv4',\n",
" ]\n",
"\n",
"num_content_layers = len(content_layers)\n",
"num_style_layers = len(style_layers)"
],
"execution_count": 0,
"outputs": []
},
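{
"cell_type": "markdown",
"metadata": {
"id": "layer_names_note",
"colab_type": "text"
},
"source": [
"Not sure which layer names are valid? A minimal sketch to print every layer in VGG19, so other `blockN_convM` names can be swapped into the lists above (it assumes the ImageNet weights download succeeds):"
]
},
{
"cell_type": "code",
"metadata": {
"id": "layer_names_code",
"colab_type": "code",
"colab": {}
},
"source": [
"# Print the layer names (and output shapes) available in VGG19\n",
"vgg_probe = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')\n",
"for layer in vgg_probe.layers:\n",
"    print(layer.name, layer.output_shape)\n",
"del vgg_probe  # drop the reference; get_model() loads its own copy"
],
"execution_count": 0,
"outputs": []
},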
{
"cell_type": "markdown",
"metadata": {
"id": "8QF_pHkelmMl",
"colab_type": "text"
},
"source": [
"# RUN!\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vSVMx4burydi",
"colab_type": "code",
"colab": {}
},
"source": [
"best, best_loss = run_style_transfer(content_path, style_path, num_iterations=5000)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "i6d6O50Yvs6a",
"colab_type": "code",
"colab": {}
},
"source": [
"show_results(best, content_path, style_path)"
],
"execution_count": 0,
"outputs": []
},
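{
"cell_type": "markdown",
"metadata": {
"id": "weights_note",
"colab_type": "text"
},
"source": [
"`style_weight` and `content_weight` trade off how strongly the output adopts the style's texture versus preserving the content's layout. Below is an illustrative sketch of a more heavily stylized run; the particular values are assumptions, and it is the ratio between the two weights, not their absolute size, that matters."
]
},
{
"cell_type": "code",
"metadata": {
"id": "weights_code",
"colab_type": "code",
"colab": {}
},
"source": [
"# Illustrative values only: a larger style/content weight ratio gives a\n",
"# more heavily stylized result; fewer iterations keep the experiment quick.\n",
"best_stylized, stylized_loss = run_style_transfer(\n",
"    content_path, style_path,\n",
"    num_iterations=1000,\n",
"    content_weight=1e3,\n",
"    style_weight=1e-1)\n",
"show_results(best_stylized, content_path, style_path)"
],
"execution_count": 0,
"outputs": []
},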
{
"cell_type": "code",
"metadata": {
"id": "SSH6OpyyQn7w",
"colab_type": "code",
"colab": {}
},
"source": [
"Image.fromarray(best)\n",
"\n",
"files.download('wave_turtle.png')"
],
"execution_count": 0,
"outputs": []
}
]
}