@OhadRubin
Last active December 28, 2021 10:43
rewriting-interface.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/OhadRubin/de32999222e0b414ddbf92869a977b03/rewriting-interface.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aUdqs7fXeR0y"
},
"source": [
"### Instructions\n",
"To run the code below:\n",
"1. Read the comments, and alter settings if you wish to change the model or layer.\n",
"2. Then run all the cells to download models and run the UI.\n",
"\n",
"To operate the interface:\n",
"1. In the UI, enter numbers such as \"200-212\" in the Image Number entry to see generated images.\n",
"2. Click on a region of an image and \"Add to context\" to help define a rule to change.\n",
"3. Click \"Search\" to find other images that are affected by the same rule.\n",
"4. Click on a region of an image and \"Copy\" to choose a new pattern to use.\n",
"5. Click on a target region of an image and \"Paste\" to paste the pattern in a new place. Clicking around the red box can adjust it.\n",
"6. Click \"Execute\" to see the effects. \"Toggle\" compares to the original model, and \"Revert\" discards the edit.\n",
"\n",
"Editing a model is challenging, because you need to develop an understanding of the way the model organies its rules. You can build your intuition by using the \"Highlight\" button to see how the model generalizes regions that you select in the context.\n",
"\n",
"Particular model edits can be saved or loaded as json files; and other geeky details\n",
"can be seen in the source code at http://github.com/davidbau/rewriting. This\n",
"research code is not yet optimized for speed, and contributions are welcome.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-rpAqI2meR02"
},
"outputs": [],
"source": [
"%%bash\n",
"!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit \n",
"pip install ninja 2>> install.log\n",
"git clone https://github.com/davidbau/rewriting.git tutorial_code 2>> install.log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZgP193ZzeR03"
},
"outputs": [],
"source": [
"try: # set up path\n",
" import google.colab, sys, torch\n",
" sys.path.append('/content/tutorial_code')\n",
" if not torch.cuda.is_available():\n",
" print(\"Change runtime type to include a GPU.\") \n",
"except:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false,
"id": "6ZNA6tPLeR03"
},
"outputs": [],
"source": [
"from utils import zdataset, show, labwidget\n",
"from rewrite import ganrewrite, rewriteapp\n",
"import torch, copy, os, json\n",
"from torchvision.utils import save_image\n",
"import utils.stylegan2, utils.proggan\n",
"from utils.stylegan2 import load_seq_stylegan\n",
"\n",
"# Choices: ganname = 'stylegan' or ganname = 'proggan'\n",
"ganname = 'stylegan'\n",
"\n",
"# Choices: modelname = 'church' or faces' or 'horse' or 'kitchen'\n",
"modelname = 'church'\n",
"\n",
"# layer 6,7,8,9,10 work OK for different things.\n",
"# layer 8 is good for trees or domes in churches, and hats on horses\n",
"# layer 6 is good for smiles on faces\n",
"# layer 10 is good for hair on faces\n",
"layernum = 8\n",
"\n",
"# Number of images to sample when gathering statistics.\n",
"size = 1000\n",
"\n",
"# Make a directory for caching some data.\n",
"layerscheme = 'default'\n",
"expdir = 'results/pgw/%s/%s/%s/layer%d' % (ganname, modelname, layerscheme, layernum)\n",
"os.makedirs(expdir, exist_ok=True)\n",
"\n",
"# Load (and download) a pretrained GAN\n",
"if ganname == 'stylegan':\n",
" model = load_seq_stylegan(modelname, mconv='seq', truncation=0.50)\n",
" Rewriter = ganrewrite.SeqStyleGanRewriter\n",
"elif ganname == 'proggan':\n",
" model = utils.proggan.load_pretrained(modelname)\n",
" Rewriter = ganrewrite.ProgressiveGanRewriter\n",
" \n",
"# Create a Rewriter object - this implements our method.\n",
"zds = zdataset.z_dataset_for_model(model, size=size)\n",
"gw = Rewriter(\n",
" model, zds, layernum, cachedir=expdir,\n",
" low_rank_insert=True, low_rank_gradient=False,\n",
" use_linear_insert=False, # If you set this to True, increase learning rate.e\n",
" key_method='zca')\n",
"\n",
"# Display a user interface to allow model rewriting.\n",
"savedir = f'masks/{ganname}/{modelname}'\n",
"interface = rewriteapp.GanRewriteApp(gw, size=256, mask_dir=savedir, num_canvases=32)\n",
"show(interface)"
]
}
],
"metadata": {
"accelerator": "GPU",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"colab": {
"name": "rewriting-interface.ipynb",
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}