Skip to content

Instantly share code, notes, and snippets.

@eyaler
Last active December 23, 2022 19:49
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eyaler/0cee9a71f5dd3fdfa9c0c03656ebdd4c to your computer and use it in GitHub Desktop.
Save eyaler/0cee9a71f5dd3fdfa9c0c03656ebdd4c to your computer and use it in GitHub Desktop.
ruDALLE-Outpainting.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/eyaler/0cee9a71f5dd3fdfa9c0c03656ebdd4c/rudalle-outpainting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Py2MOWx5kNxH"
},
"source": [
"# ruDALL-E Outpainting\n",
"\n",
"ruDALL-E article: https://habr.com/ru/company/sberbank/blog/589673\n",
"\n",
"Original notebook: https://colab.research.google.com/github/sberbank-ai/ru-dalle/blob/master/jupyters/ruDALLE-image-prompts-A100.ipynb\n",
"\n",
"Arbitraty resolution notebook (not implemented here yet): https://colab.research.google.com/drive/1DbqOIUIVBPOrJ4MeaV4YkAlb7ilWQjKZ\n",
"\n",
"Inspired by: https://twitter.com/MichaelFriese10/status/1456023409213726725\n",
"\n",
"Experiments twitter thread I: https://twitter.com/eyaler/status/1468682110860992521\n",
"\n",
"Experiments twitter thread II: https://twitter.com/eyaler/status/1470150704488660993\n",
"\n",
"More results image gallery: https://imgur.com/gallery/tcwYSzM\n",
"\n",
"Shortcut to this notebook: [j.mp/outpaint](https://j.mp/outpaint)\n",
"\n",
"Notebook by: [Eyal Gruss](https://eyalgruss.com) \\([@eyaler](https://twitter.com/eyaler)\\)\n",
"\n",
"A curated list of online generative tools: [j.mp/generativetools](https://j.mp/generativetools)"
],
"id": "Py2MOWx5kNxH"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "118b2319",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Setup {run: 'auto'}\n",
"\n",
"!pip install rudalle==1.0.0 > /dev/null 2>&1\n",
"!pip install ruclip==0.0.1 > /dev/null 2>&1\n",
"!pip install translators==4.11.0 > /dev/null 2>&1\n",
"\n",
"\n",
"from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip\n",
"from rudalle.image_prompts import ImagePrompts\n",
"from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan\n",
"from rudalle.utils import seed_everything\n",
"import ruclip\n",
"from PIL import Image, ImageOps\n",
"import torch\n",
"from google.colab import files, _message\n",
"import numpy as np\n",
"import translators\n",
"import requests\n",
"import os\n",
"import json\n",
"\n",
"\n",
"model = 'Malevich_v3' #@param ['Malevich_v3', 'Malevich_v2', 'Emojich']\n",
"if model == 'Malevich_v3':\n",
" model = 'Malevich'\n",
"dalle = get_rudalle_model(model, pretrained=True, fp16=True, device='cuda')\n",
"realesrgan = {x: get_realesrgan('x%d'%x, device='cuda') for x in [2,4,8]} \n",
"tokenizer = get_tokenizer()\n",
"vae = get_vae().to('cuda')\n",
"dwt_vae = get_vae(dwt=True).to('cuda')\n",
"clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device='cuda')\n",
"clip_predictor = ruclip.Predictor(clip, processor, device='cuda')\n"
],
"id": "118b2319"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "81b664b6"
},
"outputs": [],
"source": [
"#@title Upload images\n",
"#@markdown Click button below to select files or cancel to use default image\n",
"\n",
"#@markdown Note: Works best for square images. Consider padding/cropping your images beforehand\n",
"fallback_image_url = 'https://web.archive.org/web/20210213021410if_/https://www.gallerypop.co.uk/wp-content/uploads/2017/11/Soup.jpg'\n",
"try:\n",
" filenames\n",
"except Exception:\n",
" filenames = []\n",
"streams = []\n",
"try:\n",
" new_filenames = files.upload()\n",
"except Exception:\n",
" pass\n",
"else:\n",
" if new_filenames:\n",
" filenames = new_filenames\n",
"if not filenames:\n",
" streams = [requests.get(fallback_image_url, stream=True).raw]\n",
"orig_images = [ImageOps.exif_transpose(Image.open(f)) for f in filenames or streams]\n",
"if not filenames:\n",
" filenames = [fallback_image_url.rsplit('/',1)[-1]]"
],
"id": "81b664b6"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cebf4449",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Set options and run\n",
"#@markdown Take top part and complete bottom part (downpainting):\n",
"take_top = True #@param {type:'boolean'}\n",
"top_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n",
"take_bottom = False\n",
"bottom_frac = 0.5\n",
"#take_bottom = False #@param {type:'boolean'}\n",
"#bootom_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n",
"#@markdown Add additional runs taking different image parts (will probably not work so good):\n",
"flip_and_take_top = False #@param {type:'boolean'}\n",
"flipped_top_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n",
"flip_result_back = 'no' #@param ['no', 'before_clip', 'finally']\n",
"take_left = False #@param {type:'boolean'}\n",
"left_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n",
"take_right = False #@param {type:'boolean'}\n",
"right_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n",
"#@markdown before_encoding will give less constrained results adequate for natual images but will give noncontextual completions otherwise:\n",
"crop_order = 'after_encoding' #@param ['after_encoding', 'before_encoding', 'both_after_and_before_runs']\n",
"#@markdown Increase for more diverse results, decrease for less:\n",
"temperature = 1#@param {type:'number'}\n",
"#@markdown Increase these for more outputs and better results, but will take longer. decrease for quick and dirty results:\n",
"num_levels = 1#@param {type:'slider', min:1, max:9}\n",
"retries_per_level = 4#@param {type:'integer'}\n",
"order_of_levels = 'high_to_low' #@param ['high_to_low', 'low_to_high', 'interleaved_extreme_to_middle', 'interleaved_middle_to_extreme']\n",
"#@markdown Optional text prompt (can leave empty; will auto-translate from any language to Russian - check if the back translation is OK):\n",
"text = '' #@param {type:'string'}\n",
"additional_run_without_text = False #@param {type:'boolean'}\n",
"#@markdown Output options:\n",
"fix_aspect_ratio = 'before_clip' #@param ['no', 'before_clip', 'after_clip']\n",
"paste_original_part_over_generated = 'no' #@param ['no', 'before_clip', 'between_clip_and_super_resolution', 'after_super_resolution','before_clip_and_again_after_super_resolution']\n",
"blend_frac = 0.1 #@param {type:'slider', max:1, step:0.005}\n",
"dwt_decoder_upscale = False #@param {type:'boolean'}\n",
"super_resolution_factor = 4 #@param [1, 2, 4, 8] {type:'raw'}\n",
"limit_display_results = 0#@param {type:'integer'}\n",
"display_width = 30 #@param {type:'number'}\n",
"#@markdown Set to a positive number for reproducing results (different results for different numbers):\n",
"random_seed = 42 #@param {type:'integer'}\n",
"if random_seed < 1:\n",
" random_seed = None\n",
"if random_seed:\n",
" seed_everything(random_seed)\n",
"if limit_display_results < 1 or limit_display_results > num_levels * retries_per_level:\n",
" limit_display_results = num_levels * retries_per_level\n",
"nrow = int(np.ceil((limit_display_results+2)/np.ceil((limit_display_results+2)/6)))\n",
"display_width = 30 \n",
"crop_orders = [0,1]\n",
"if crop_order == 'after_encoding':\n",
" crop_orders = [0]\n",
"elif crop_order == 'before_encoding':\n",
" crop_orders = [1]\n",
"assert take_top or flip_and_take_top or take_bottom or take_left or take_right, 'Select at least one of the take options'\n",
"\n",
"enc_size = 32\n",
"\n",
"def crop(im, borders, mask=None):\n",
" im = np.array(im)\n",
" if mask is not None:\n",
" mask = np.array(mask)\n",
" assert im.shape == mask.shape, (im.shape, mask.shape) \n",
" if borders['up'] is not None:\n",
" i = int(round(im.shape[0]*borders['up']/enc_size))\n",
" z = max(i-int(round(blend_frac*im.shape[0])), 0)\n",
" if mask is None:\n",
" im[i:] = 255\n",
" else:\n",
" im[:z] = mask[:z]\n",
" for j in range(z, i):\n",
" alpha = (j-z+1)/(i-z+1)\n",
" im[j] = im[j]*alpha+mask[j]*(1-alpha)\n",
" elif borders['down'] is not None:\n",
" i = int(round(im.shape[0]*borders['down']/enc_size))\n",
" z = min(i+int(round(blend_frac*im.shape[0])), im.shape[0])\n",
" if mask is None:\n",
" im[:i] = 255\n",
" else:\n",
" im[z:] = mask[z:]\n",
" for j in range(i, z):\n",
" alpha = (j-i+1)/(z-i+1)\n",
" im[j] = im[j]*(1-alpha)+mask[i:z]*alpha\n",
" elif borders['left'] is not None:\n",
" i = int(round(im.shape[1]*borders['left']/enc_size))\n",
" z = max(i-int(round(blend_frac*im.shape[1])), 0)\n",
" if mask is None:\n",
" im[:,i:] = 255\n",
" else:\n",
" im[:,:z] = mask[:,:z]\n",
" for j in range(z, i):\n",
" alpha = (j-z+1)/(i-z+1)\n",
" im[:,j] = im[:,j]*alpha+mask[:,j]*(1-alpha)\n",
" elif borders['right'] is not None:\n",
" i = int(round(im.shape[1]*borders['right']/enc_size))\n",
" z = min(i+int(round(blend_frac*im.shape[1])), im.shape[1])\n",
" if mask is None:\n",
" im[:,:i] = 255\n",
" else:\n",
" im[:,z:] = mask[:,z:]\n",
" for j in range(i, z):\n",
" alpha = (j-i+1)/(z-i+1)\n",
" im[:,j] = im[:,j]*(1-alpha)+mask[:,j]*alpha\n",
" im = Image.fromarray(im)\n",
" return im\n",
"\n",
"images = [im.resize((256,256)) for im in orig_images]\n",
"borders_flips = []\n",
"if take_top:\n",
" borders_flips.append(({'up': int(round(enc_size*top_frac)), 'left': None, 'right': None, 'down': None}, False))\n",
"if flip_and_take_top:\n",
" borders_flips.append(({'up': int(round(enc_size*flipped_top_frac)), 'left': None, 'right': None, 'down': None}, True))\n",
"if take_bottom:\n",
" borders_flips.append(({'up': None, 'left': None, 'right': None, 'down': int(round(enc_size*bottom_frac))}, False))\n",
"if take_left:\n",
" borders_flips.append(({'up': None, 'left': int(round(enc_size*left_frac)), 'right': None, 'down': None}, False))\n",
"if take_right:\n",
" borders_flips.append(({'up': None, 'left': None, 'right': int(round(enc_size*right_frac)), 'down': None}, False))\n",
"\n",
"def simple_detect_lang(text):\n",
" if len(set('абвгдежзийклмнопрстуфхцчшщъыьэюяё').intersection(text.lower())):\n",
" return 'ru'\n",
" if len(set('אבגדהוזחטיכךלמםנןסעפצץקרשת').intersection(text)):\n",
" return 'iw'\n",
" if len(set('abcdefghijklmnopqrstuvwxyz').intersection(text.lower())):\n",
" return 'en'\n",
" return 'auto'\n",
"\n",
"if text:\n",
" orig_text = text\n",
" lang = simple_detect_lang(text)\n",
" if lang != 'ru':\n",
" text = translators.google(text, from_language=lang, to_language='ru')\n",
" back_text = translators.google(text, from_language='ru', to_language=lang if lang not in ['auto', 'ru'] else 'en')\n",
" print('original text:', orig_text)\n",
" print('language detected:', lang)\n",
" print('prompt in russian:', text)\n",
" print('back translation:' if lang not in ['auto','ru'] else 'english translation', back_text)\n",
"texts = [text]\n",
"if text and additional_run_without_text:\n",
" texts.append('')\n",
"\n",
"levels = [\n",
" (2048, 0.995),\n",
" (1024, 0.98),\n",
" (1536, 0.99),\n",
" (1024, 0.99),\n",
" (512, 0.97),\n",
" (384, 0.96),\n",
" (256, 0.95),\n",
" (128, 0.95),\n",
" (64, 0.92),\n",
" ]\n",
"level_indices = list(range(len(levels)))\n",
"if order_of_levels == 'low_to_high':\n",
" level_indices.reverse()\n",
"elif order_of_levels.startswith('interleaved'):\n",
" level_indices = [i for pair in zip(level_indices, reversed(level_indices)) for i in pair][:len(levels)]\n",
" if order_of_levels == 'interleaved_middle_to_extreme':\n",
" level_indices.reverse()\n",
"\n",
"save_dir = None\n",
"save_dirs = []\n",
"os.makedirs('/content/output', exist_ok=True)\n",
"notebook = _message.blocking_request('get_ipynb', timeout_sec=60)\n",
"all_hires = []\n",
"for j, (image, orig, filename) in enumerate(zip(images, orig_images, filenames), start=1):\n",
" print('%d/%d: %s'%(j,len(filenames),filename))\n",
" for borders, flip in borders_flips:\n",
" if flip:\n",
" image = image.transpose(Image.FLIP_TOP_BOTTOM)\n",
" orig = orig.transpose(Image.FLIP_TOP_BOTTOM)\n",
" for text_to_use in texts:\n",
" for crop_first in crop_orders:\n",
" out_images = []\n",
" scores = []\n",
" image_prompt = ImagePrompts(image, {k: v or 0 for k, v in borders.items()}, dwt_vae if dwt_decoder_upscale else vae, device='cuda', crop_first=crop_first)\n",
" for i in level_indices[:num_levels]:\n",
" top_k, top_p = levels[i]\n",
" _pil_images, _scores = generate_images(\n",
" text_to_use,\n",
" tokenizer,\n",
" dalle,\n",
" dwt_vae if dwt_decoder_upscale else vae,\n",
" top_k=top_k,\n",
" top_p=top_p,\n",
" images_num=retries_per_level,\n",
" image_prompts=image_prompt,\n",
" temperature=temperature,\n",
" seed=random_seed\n",
" )\n",
" out_images += _pil_images\n",
" scores += _scores\n",
" aspect = (1, orig.size[1]/orig.size[0]) if orig.size[1]>orig.size[0] else (orig.size[0]/orig.size[1], 1) if orig.size[0]>orig.size[1] else (1,1)\n",
" if fix_aspect_ratio == 'before_clip' and (aspect[0]!=1 or aspect[1]!=1):\n",
" out_images = [im.resize((int(aspect[0]*im.size[0]), int(aspect[1]*im.size[0]))) for im in out_images]\n",
" rescaled_orig = None\n",
" if paste_original_part_over_generated.startswith('before_clip'):\n",
" if image.size == out_images[0].size: \n",
" rescaled_orig = image\n",
" elif orig.size == out_images[0].size:\n",
" rescaled_orig = orig\n",
" else:\n",
" rescaled_orig = orig.resize(out_images[0].size)\n",
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images]\n",
" if text_to_use:\n",
" if flip and flip_result_back == 'before_clip':\n",
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images]\n",
" out_images, _ = cherry_pick_by_ruclip(out_images, text_to_use, clip_predictor, count=None)\n",
" if flip and flip_result_back == 'before_clip':\n",
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images]\n",
" else:\n",
" out_images, _ = zip(*sorted(zip(out_images, scores), key=lambda x: x[1]))\n",
" out_images = list(out_images)\n",
" if fix_aspect_ratio == 'after_clip' and (aspect[0]!=1 or aspect[1]!=1):\n",
" out_images = [im.resize((int(aspect[0]*im.size[0]), int(aspect[1]*im.size[0]))) for im in out_images]\n",
" if paste_original_part_over_generated == 'between_clip_and_super_resolution':\n",
" if rescaled_orig is None or rescaled_orig.size != out_images[0].size: \n",
" if image.size == out_images[0].size: \n",
" rescaled_orig = image\n",
" elif orig.size == out_images[0].size:\n",
" rescaled_orig = orig\n",
" else:\n",
" rescaled_orig = orig.resize(out_images[0].size) \n",
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images]\n",
" if super_resolution_factor > 1:\n",
" out_images = super_resolution(out_images, realesrgan[super_resolution_factor])\n",
" if rescaled_orig is None or rescaled_orig.size != out_images[0].size: \n",
" if image.size == out_images[0].size: \n",
" rescaled_orig = image\n",
" elif orig.size == out_images[0].size:\n",
" rescaled_orig = orig\n",
" else:\n",
" rescaled_orig = orig.resize(out_images[0].size)\n",
" if paste_original_part_over_generated.endswith('after_super_resolution'):\n",
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images] \n",
" out_images = [rescaled_orig, crop(rescaled_orig, borders)] + out_images\n",
" if flip and flip_result_back != 'no':\n",
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images] \n",
" all_hires.append(out_images)\n",
" folders = [int(folder) for folder in os.listdir('/content/output') if os.path.isdir(os.path.join('/content/output', folder)) and folder.isnumeric()]\n",
" save_dir = os.path.join('/content/output', '%04d'%(max(folders, default=0)+1)) \n",
" save_dirs.append(save_dir)\n",
" x, y = out_images[0].size\n",
" out_images = out_images[:limit_display_results+2]\n",
" show(out_images, nrow, save_dir=save_dir, size=display_width * max(1, (len(out_images)//nrow) / nrow * y / x))\n",
" if notebook:\n",
" with open(os.path.join(save_dir, 'notebook.ipynb'), 'w', encoding='utf8') as f:\n",
" json.dump(notebook, f)\n",
"if save_dirs:\n",
" print('files saved to output folders:')\n",
" for save_dir in save_dirs:\n",
" print(save_dir)"
],
"id": "cebf4449"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "yULZADI6pGqH"
},
"outputs": [],
"source": [
"#@title Download high-resolution images and notebook copy\n",
"if save_dirs:\n",
" save_file = os.path.join('/content', save_dirs[-1].rsplit('/',1)[-1] + '.zip') \n",
" if len(save_dirs)==1:\n",
" save_dirs_str = save_dirs[-1]\n",
" !zip -rjqFS $save_file $save_dirs_str\n",
" else:\n",
" save_dirs_str = ' '.join(save_dir.rsplit('/',1)[-1] for save_dir in save_dirs)\n",
" %pushd /content/output\n",
" !zip -rqFS $save_file $save_dirs_str\n",
" %popd\n",
" files.download(save_file)"
],
"id": "yULZADI6pGqH"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "1GJ90t3GSQ4T"
},
"outputs": [],
"source": [
"#@title Display all images together\n",
"limit_display_results = 0#@param {type:'integer'}\n",
"display_width = 30#@param {type:'number'}\n",
"if limit_display_results < 1 or limit_display_results > num_levels * retries_per_level:\n",
" limit_display_results = num_levels * retries_per_level\n",
"nrow = int(np.ceil((limit_display_results+2)/np.ceil((limit_display_results+2)/6)))\n",
"for hires_images in all_hires:\n",
" x, y = hires_images[0].size\n",
" hires_images = hires_images[:limit_display_results+2]\n",
" show(hires_images, nrow, size=display_width * max(1, (len(hires_images)//nrow) / nrow * y / x))"
],
"id": "1GJ90t3GSQ4T"
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "ruDALLE-Outpainting.ipynb",
"provenance": [],
"private_outputs": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment