@eyaler
Last active April 7, 2022 14:37
deep-painterly-harmonization.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "deep-painterly-harmonization.ipynb",
"provenance": [],
"collapsed_sections": [],
"private_outputs": true,
"machine_shape": "hm",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/eyaler/5303782669fb43510d398bd346c6e3e6/deep-painterly-harmonization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Deep Painterly Harmonization\n",
"### A complete Colab adaptation of the original code including the post-processing stage + optional automatic mask generation\n",
"\n",
"Paper: https://arxiv.org/pdf/1804.03189.pdf\n",
"\n",
"Original repo: https://github.com/luanfujun/deep-painterly-harmonization\n",
"\n",
"Notebook adapted from: https://colab.research.google.com/github/marcduda/deep-painterly-harmonization/blob/master/deep_painterly_harmonization_colab.ipynb\n",
"\n",
"![](https://eyalgruss.com/share/rushmore.jpg?)\n",
"\n",
"More examples: https://github.com/luanfujun/deep-painterly-harmonization/blob/master/README2.md\n",
"\n",
"Shortcut to this notebook: [j.mp/deepharm](https://j.mp/deepharm)\n",
"\n",
"Notebook by: [Eyal Gruss](https://eyalgruss.com) \\([@eyaler](https://twitter.com/eyaler)\\)\n",
"\n",
"A curated list of online generative tools: [j.mp/generativetools](https://j.mp/generativetools)"
],
"metadata": {
"id": "EqR57s8oim4K"
}
},
{
"cell_type": "code",
"metadata": {
"id": "B2kC8y5u9spV",
"cellView": "form"
},
"source": [
"#@title Setup\n",
"#@markdown Takes about 20 minutes...\n",
"\n",
"!apt --purge remove \"*cublas*11*\" \"*cuda*11*\"\n",
"!apt install cuda-10-0 --reinstall\n",
"!rm /usr/local/cuda\n",
"!ln -s /usr/local/cuda-10.0 /usr/local/cuda\n",
"\n",
"%cd /content\n",
"!git clone --depth 1 https://github.com/nagadomi/distro torch --recursive\n",
"!git clone --depth 1 https://github.com/marcduda/deep-painterly-harmonization\n",
"%cd /content/torch\n",
"!bash install-deps\n",
"!yes | ./install.sh\n",
"!./install/bin/torch-activate\n",
"%cd /content/deep-painterly-harmonization\n",
"!sh models/download_models.sh\n",
"!wget -nc --no-check-certificate https://raw.githubusercontent.com/Gasp34/PatchMatch/fix_propagation/PatchMatch.py\n",
"!make clean && make\n",
"import os\n",
"if not os.path.exists('data-paper'):\n",
"    !mv data data-paper\n",
"    !mkdir data\n",
"if not os.path.exists('results-paper'):\n",
"    !mv results results-paper\n",
"%cd data\n",
"!wget -nc --no-check-certificate https://eyalgruss.com/share/rushmore_naive.png\n",
"!wget -nc --no-check-certificate https://eyalgruss.com/share/rushmore_target.png\n",
"!apt-get install libprotobuf-dev protobuf-compiler\n",
"!git config --global url.https://github.com/.insteadOf git://github.com && /content/torch/install/bin/luarocks install loadcaffe"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Upload files\n",
"#@markdown Click the button below to select files, or cancel to use the demo images\n",
"\n",
"#@markdown Image file sets (see example above) should be named e.g.: **SOMENAME_target.jpg**, **SOMENAME_naive.jpg**, and optionally: SOMENAME_mask.jpg\n",
"\n",
"#@markdown Note: If the mask is missing, the notebook will try to generate one. This works better for original .png files (and may fail for .jpg files with unequal original sizes)\n",
"\n",
"from google.colab import files\n",
"\n",
"!rm -f /content/sample_data/*\n",
"%cd /content/sample_data\n",
"try:\n",
"    uploaded = files.upload()\n",
"    if uploaded:\n",
"        !rm -f /content/deep-painterly-harmonization/data/*\n",
"        !cp /content/sample_data/* /content/deep-painterly-harmonization/data/\n",
"except:\n",
"    pass\n",
"%cd /content/deep-painterly-harmonization"
],
"metadata": {
"cellView": "form",
"id": "J5ZMs9tzMBHY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Harmonize!\n",
"#@markdown 1. If you do not provide a mask, the notebook will try to generate one using missing_mask_tolerance. This works better for original .png files\n",
"#@markdown 2. Set size=0 to use the original sizes. If you get a memory error, reduce the size parameter or set it to 700\n",
"#@markdown 3. Play with the style_weight parameter. The paper recommends the following values: \n",
"#@markdown >* **1** - for Art Nouveau (Modern), Baroque, Early Renaissance, High Renaissance, Mannerism (Late Renaissance), Naive Art (Primitivism), Northern Renaissance, Realism, Surrealism, Symbolism and Ukiyo-e\n",
"#@markdown >* **5** - for Abstract Art, Abstract Expressionism, Color Field Painting, Impressionism and Post-Impressionism\n",
"#@markdown >* **10** - for Cubism and Expressionism \n",
"#@markdown 4. Intermediate results are saved in: /content/deep-painterly-harmonization/results\n",
"#@markdown 5. You will get an error if the generated mask is close to uniform, i.e. when the target and naive images are either too similar (e.g. when no object was added) or too different (e.g. when they are .jpg files with unequal original sizes)\n",
"\n",
"%cd /content/deep-painterly-harmonization\n",
"import os\n",
"import numpy as np\n",
"import cv2\n",
"from cv2.ximgproc import createGuidedFilter\n",
"from PatchMatch import NNS\n",
"from google.colab.patches import cv2_imshow\n",
"from google.colab import files\n",
"from IPython.utils.capture import capture_output\n",
"\n",
"missing_mask_tolerance = 0#@param {type: \"slider\", max: 255}\n",
"mask_dilate_size = 35 #@param {type: \"integer\"}\n",
"size = 700#@param {type: \"integer\"}\n",
"style_weight = 10#@param {type: \"number\"}\n",
"iterations = 1000 #@param {type: \"integer\"}\n",
"original_colors = False #@param {type: \"boolean\"}\n",
"patch_match_size = 7#@param {type: \"integer\"}\n",
"# Gaussian kernel and patch sizes must be odd\n",
"if not mask_dilate_size % 2:\n",
"    mask_dilate_size += 1\n",
"if not patch_match_size % 2:\n",
"    patch_match_size += 1\n",
"original_colors = int(original_colors)\n",
"data_folder = 'data'\n",
"result_folder = 'results'\n",
"\n",
"def resize(im, size):\n",
"    # scale so that the longer side equals size, keeping the aspect ratio\n",
"    h, w = im.shape[:2]\n",
"    if not max(h, w) == size:\n",
"        if h > w:\n",
"            w = size * w // h\n",
"            h = size\n",
"        else:\n",
"            h = size * h // w\n",
"            w = size\n",
"    return cv2.resize(im, (w, h))\n",
"\n",
"os.makedirs(result_folder, exist_ok=True)\n",
"images = os.listdir(data_folder)\n",
"for naive in images:\n",
"    if '_naive' not in naive:\n",
"        continue\n",
"    naive = os.path.join(data_folder, naive)\n",
"    naive_im = cv2.imread(naive)\n",
"    if not size:\n",
"        new_size = max(naive_im.shape[:2])\n",
"    else:\n",
"        new_size = size\n",
"    prefix = naive.split('_naive')[0]\n",
"    for target in images:\n",
"        target = os.path.join(data_folder, target)\n",
"        if target.startswith(prefix + '_target'):\n",
"            break\n",
"    else:\n",
"        print('Could not find', prefix + '_target')\n",
"        continue\n",
"    for mask in images:\n",
"        mask = os.path.join(data_folder, mask)\n",
"        if mask.startswith(prefix + '_mask'):\n",
"            break\n",
"    else:\n",
"        print('Could not find', prefix + '_mask', '- will create mask!')\n",
"        target_im = cv2.imread(target)\n",
"        naive_im = resize(naive_im, new_size)\n",
"        if target_im.shape[:2] != naive_im.shape[:2]:\n",
"            target_im = cv2.resize(target_im, naive_im.shape[:2][::-1])\n",
"        mask = prefix + '_genmask' + os.path.splitext(naive)[1]\n",
"        # cv2.absdiff avoids uint8 wrap-around when subtracting the two images\n",
"        cv2.imwrite(mask, ((np.max(cv2.absdiff(naive_im, target_im), axis=-1) > missing_mask_tolerance) * 255).astype(np.uint8))\n",
"    print(prefix)\n",
"    mask_im = cv2.imread(mask)\n",
"    h, w, c = mask_im.shape\n",
"    if c == 3:\n",
"        mask_im = mask_im[..., 0]\n",
"    mask_im = resize(mask_im, new_size)\n",
"    # dilate the mask: blur it, then threshold at a low value\n",
"    dilated_im = cv2.GaussianBlur(mask_im / 255, (mask_dilate_size, mask_dilate_size), mask_dilate_size / 3)\n",
"    dilated_im[dilated_im > 0.1] = 255\n",
"    dilated_im[dilated_im <= 0.1] = 0\n",
"    dilated = prefix + '_dilated' + os.path.splitext(mask)[1]\n",
"    cv2.imwrite(dilated, dilated_im)\n",
"    res_prefix = os.path.join(result_folder, prefix.split('/')[-1])\n",
"    inter = res_prefix + '_inter_res.jpg'\n",
"    final = res_prefix + '_final_res.jpg'\n",
"    final1 = res_prefix + '_final_res1.jpg'\n",
"    final2 = res_prefix + '_final_res2.jpg'\n",
"    with open('style.txt', 'w') as f:\n",
"        f.write(f'idx=0, classifed label=sheker kolshehu, weight={style_weight}\\n')\n",
"    common_args = f'-content_image {naive} -style_image {target} -tmask_image {mask} -mask_image {dilated} -gpu 0 -original_colors {original_colors} -image_size {new_size} -print_iter 100 -save_iter 0'\n",
"    with capture_output() as cap:\n",
"        !/content/torch/install/bin/th neural_gram.lua $common_args -output_image $inter && /content/torch/install/bin/th neural_paint.lua $common_args -cnnmrf_image $inter -wikiart_fn style.txt -output_image $final -num_iterations $iterations\n",
"    out = cap.stdout\n",
"    if 'out of memory' in out:\n",
"        print('\\nOut of memory error! Reduce size parameter or set to 700')\n",
"        break\n",
"\n",
"    # post-processing: soft blending mask\n",
"    tr = 3\n",
"    dilated_im = cv2.GaussianBlur(mask_im / 255, (tr * 2 + 1, tr * 2 + 1), tr)\n",
"    dilated_im[dilated_im > 0.01] = 1\n",
"    dilated_im[dilated_im <= 0.01] = 0\n",
"    dilated_im = cv2.GaussianBlur(dilated_im, (tr * 2 + 1, tr * 2 + 1), tr)\n",
"\n",
"    # post-processing stage 1: guided filter on the a and b channels\n",
"    r = 2 # try 2, 4, 8\n",
"    eps = 0.1**2 # try 0.1**2, 0.2**2, 0.4**2\n",
"    final_im = cv2.imread(final)\n",
"    if naive_im.shape[:2] != final_im.shape[:2]:\n",
"        # the guide image must match the size of the filtered image\n",
"        naive_im = cv2.resize(naive_im, final_im.shape[:2][::-1])\n",
"    final_im = cv2.cvtColor(final_im, cv2.COLOR_BGR2LAB)\n",
"    guided = createGuidedFilter(naive_im, r, eps * 255 * 255)\n",
"    final_im[..., 1] = guided.filter(final_im[..., 1])\n",
"    final_im[..., 2] = guided.filter(final_im[..., 2])\n",
"    final_im = cv2.cvtColor(final_im, cv2.COLOR_LAB2BGR)\n",
"    cv2.imwrite(final1, final_im)\n",
"\n",
"    # post-processing stage 2: PatchMatch against the target painting\n",
"    iter = 5\n",
"    target_im = cv2.imread(target)\n",
"    if target_im.shape[:2] != final_im.shape[:2]:\n",
"        target_im = cv2.resize(target_im, final_im.shape[:2][::-1])\n",
"    with capture_output() as cap:\n",
"        ann, dist, score_list = NNS(final_im, target_im, patch_match_size, iter)\n",
"    final_im2_base = np.zeros_like(final_im)\n",
"    for i in range(final_im.shape[0]):\n",
"        for j in range(final_im.shape[1]):\n",
"            final_im2_base[i, j] = target_im[ann[i, j][0], ann[i, j][1]]\n",
"\n",
"    # add back the detail layer (image minus its blur), then blend with the target\n",
"    fr = 3\n",
"    final_im_base = cv2.GaussianBlur(final_im, (fr * 2 + 1, fr * 2 + 1), fr)\n",
"    final_im2 = final_im2_base.astype(np.float32) + final_im - final_im_base\n",
"    dilated_im = dilated_im[..., np.newaxis]\n",
"    final_im2 = final_im2 * dilated_im + target_im.astype(np.float32) * (1 - dilated_im)\n",
"\n",
"    cv2.imwrite(final2, final_im2)\n",
"    cv2_imshow(final_im2)\n",
"    files.download(final2)\n"
],
"metadata": {
"id": "OSRYToagdtxs",
"cellView": "form"
},
"execution_count": null,
"outputs": []
}
]
}
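
The automatic mask generation in the "Harmonize!" cell marks every pixel where the naive and target images differ by more than `missing_mask_tolerance` in any channel. The step can be sketched in isolation with NumPy only. This is an illustrative sketch, not the notebook's code: the helper name `difference_mask` and the toy 4x4 images are assumptions made for the example.

```python
import numpy as np

def difference_mask(naive, target, tolerance=0):
    """Return a uint8 mask: 255 where the images differ by more than tolerance."""
    # Cast to a signed type first so the subtraction cannot wrap around in uint8.
    diff = np.abs(naive.astype(np.int16) - target.astype(np.int16))
    # A pixel belongs to the mask if any of its channels exceeds the tolerance.
    return ((diff.max(axis=-1) > tolerance) * 255).astype(np.uint8)

# Toy data (hypothetical): a uniform "painting" with a 2x2 pasted "object".
target = np.full((4, 4, 3), 100, dtype=np.uint8)
naive = target.copy()
naive[1:3, 1:3] = (90, 100, 120)

mask = difference_mask(naive, target, tolerance=5)
print(mask)
```

A tolerance above 0 helps when the inputs are lossy .jpg files, where compression noise makes nearly every pixel differ slightly between the two images.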