stylegan3_training_and_inference_2022_02_11.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/un1tz3r0/593f47cde9888b4d33e2d87ea681956c/stylegan3_fine_tuning_with_colab_2022_01_01.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# <big><big><big><big><big>Training StyleGAN3</big></big></big></big></big>\n",
"\n",
"This is my setup for training NVIDIA's [StyleGAN3](https://github.com/NVlabs/stylegan3) \\(aka [Alias-Free GAN](https://nvlabs.github.io/stylegan3/)\\) on Google's [Colab Pro+ GPUs](https://colab.research.google.com/). For a measly $50/month you can train your very own translation/rotation-equivariant alias-free GAN, from scratch or from a pretrained network as a starting point, for all kinds of fun applications. All you need is some images to train with.\n",
"\n",
"This work was inspired by many projects, most directly the work of [Max Braun](https://braun.design/), whose [eBoyGAN](https://onezero.medium.com/how-i-accidentally-created-an-infinite-pixel-hellscape-fe070551365f) I have attempted here to replicate using the more recent alias-free StyleGAN3. This notebook contains everything you will need to reproduce my results, and hopefully will act as a starting point for similar projects.\n",
"\n",
"One of the biggest motivations for this work and for publishing this notebook and the github repos below was the lack of pretrained models available for experimenting with inference, projection, semantic editing, and other applications that manipulate images in the GAN latent space. There are many examples that use older models such as StyleGAN2, but they all suffer from poor image quality and artifacts due to the aliasing issues addressed by StyleGAN3.\n",
"\n",
"\\- [Victor Condino](https://twitter.com/un1tz3r0/)\n",
"\n",
"## Github repos used by this notebook:\n",
"\n",
"- [un1tz3r0/stylegan3](https://github.com/un1tz3r0/stylegan3.git)\n",
"- [un1tz3r0/pixelscapes-dataset](https://github.com/un1tz3r0/pixelscapes-dataset.git)\n",
"\n",
"\n",
"<big><big>*If you wish to support my work, please donate generously to:*\n",
"\n",
"BTC: [3MzZseGSqXFo6GmJxthVRvdCre8CR2F1QJ](https://www.blockchain.com/btc/address/3MzZseGSqXFo6GmJxthVRvdCre8CR2F1QJ)\n",
"\n",
"Ethereum/Polygon/ERC20 Tokens: [0x0480409E69c4c89EeB4cDb84111B63976E56c389](https://www.blockchain.com/eth/address/0x0480409E69c4c89EeB4cDb84111B63976E56c389)\n",
"\n"
],
"metadata": {
"id": "TV85oI4C-AoM"
}
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"---\n",
"\n"
],
"metadata": {
"id": "W4MzW9Tyx8bY"
}
},
{
"cell_type": "code",
"source": [
"#@title <big><big><big><big>Notebook Mode</big></big></big></big> { display-mode: \"form\" }\n",
"#@markdown This notebook has some common setup cells, followed by one section of cells specific to training and another specific to inference. This setting disables whichever section you are not using, so you can set it once and do \"Run All Cells\".\n",
"\n",
"notebook_mode = \"Training\" #@param [\"Training\", \"Inference\"]\n",
"\n",
"\n"
],
"metadata": {
"id": "ENcxrjI-xCZe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# <big><big><big><big>Setup</big></big></big></big>\n"
],
"metadata": {
"id": "P9YyiOFP745y"
}
},
{
"cell_type": "code",
"source": [
"#@title <big><big><big>Connect with Google Drive</big></big></big> { vertical-output: true, display-mode: \"form\" }\n",
"#@markdown Run this cell to authorize the runtime instance to access files on your Google Drive.\n",
"#@markdown\n",
"#@markdown The files will be placed in a folder called 'stylegan3-training' in the top-level [My Drive] folder, which is created if it does not already exist. You can specify a different folder to use below:\n",
"\n",
"folderprefix = \"/stylegan3-training\" #@param {type:\"string\"}\n",
"\n",
"default_rclone_config_path = \"/content/rclone.config\"\n",
"\n",
"from shutil import make_archive\n",
"\n",
"def configure_rclone(configfile=default_rclone_config_path, folderprefix=folderprefix, overwrite=True, clearauth=True):\n",
" import pathlib\n",
" if (not pathlib.Path(configfile).exists()) or overwrite:\n",
" # Import PyDrive and associated libraries.\n",
" # This only needs to be done once per notebook.\n",
" from pydrive.auth import GoogleAuth\n",
" from pydrive.drive import GoogleDrive\n",
" from google.colab import auth\n",
" from oauth2client.client import GoogleCredentials\n",
" import json, datetime, tzlocal\n",
" import httplib2\n",
" import google.colab.drive\n",
" from pytz import timezone\n",
"\n",
" if clearauth:\n",
" google.colab.drive.flush_and_unmount()\n",
"\n",
" # Authenticate and create the PyDrive client.\n",
" # This only needs to be done once per notebook.\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" drive = GoogleDrive(gauth)\n",
" #print(drive.auth.credentials.get_access_token())\n",
"\n",
" # we have authenticated, write the credentials to rclone.config\n",
" rcloneconfig = \"\\n\".join([\n",
" '[driveapi]',\n",
" 'type = alias',\n",
" f'remote = driveroot:{folderprefix}',\n",
" '',\n",
" '[driveroot]',\n",
" 'type = drive',\n",
" f'client_id = {drive.auth.credentials.client_id}',\n",
" f'client_secret = {drive.auth.credentials.client_secret}',\n",
" 'scope = drive.file',\n",
" 'token = {}'.format(json.dumps({\n",
" \"access_token\":drive.auth.credentials.get_access_token().access_token,\n",
" \"token_type\":\"Bearer\",\n",
" \"refresh_token\":drive.auth.credentials.refresh_token,\n",
" \"expiry\":drive.auth.credentials.token_expiry.astimezone(tzlocal.get_localzone()).astimezone(datetime.timezone.utc).isoformat().replace(\"Z\", \"+\")\n",
" }))\n",
" ])\n",
"\n",
" with open(configfile, \"wt\") as fout:\n",
" print(f\"Writing rclone remote configuration with Google Drive auth token to {configfile}...\")\n",
" wrsz = fout.write(rcloneconfig)\n",
" print(f\"... wrote {wrsz} bytes\")\n",
"\n",
"# ------------------------------------------------------------------------------\n",
"# install rclone\n",
"# ------------------------------------------------------------------------------\n",
"\n",
"def install_rclone_from_github():\n",
" #!sudo apt install golang\n",
" %cd /content\n",
" !wget https://go.dev/dl/go1.17.5.linux-amd64.tar.gz\n",
" !rm -rf /usr/local/go && tar -C /usr/local -xzf go1.17.5.linux-amd64.tar.gz\n",
" \n",
" def extend_path(searchdir):\n",
" \n",
" # add the directory to the python interpreter's PATH env var (avoiding duplicates), \n",
" # so it takes effect immediately\n",
" import os, shlex\n",
" os.environ['PATH']=\":\".join([*[p for p in os.environ['PATH'].split(\":\") if p != \"/usr/local/go/bin\"], \"/usr/local/go/bin\"])\n",
" \n",
" # add the directory to the current user's .profile, which is sourced by shells\n",
" # on startup. also avoiding duplicates\n",
" lines = []\n",
" with open(os.path.expanduser(\"~/.profile\"), \"rt\") as fin:\n",
" lines = fin.read().splitlines() # strip trailing newlines so the join below doesn't double them\n",
" foundline = False\n",
" addline = 'export PATH=\"${PATH}:\"' + shlex.quote(searchdir)\n",
" for line in lines:\n",
" if line.strip() == addline:\n",
" foundline = True\n",
" break\n",
" if not foundline:\n",
" lines.append(addline)\n",
" with open(os.path.expanduser(\"~/.profile\"), \"wt\") as fout:\n",
" fout.write(\"\\n\".join(lines))\n",
" \n",
"\n",
" #!echo 'export PATH=\"${PATH}:/usr/local/go/bin\"' >> ~/.profile\n",
" \n",
" extend_path(\"/usr/local/go/bin\")\n",
" !go get github.com/rclone/rclone\n",
" # %cd /content\n",
" # !git clone https://github.com/rclone/rclone.git\n",
" # %cd rclone\n",
" # !make\n",
" # !sudo make install\n",
"\n",
"\n",
"def install_rclone_via_shell(quiet=True):\n",
" import pathlib, shutil\n",
" from google.colab import output\n",
" import json, binascii\n",
"\n",
" # check if rclone is in $PATH\n",
" if shutil.which(\"rclone\") == None:\n",
" # no rclone, install it!\n",
" print(\"Downloading and running rclone install.sh...\")\n",
" if not quiet:\n",
" !bash -c 'cd /content; curl https://rclone.org/install.sh | sudo bash'\n",
" else: \n",
" !bash -c 'cd /content; curl https://rclone.org/install.sh 2>/dev/null | sudo bash >/dev/null 2>&1'\n",
" assert(shutil.which(\"rclone\") != None)\n",
" else:\n",
" if not quiet:\n",
" print(\"It appears rclone is already installed!\")\n",
"\n",
"print(\"Authorizing notebook to use google drive...\")\n",
"configure_rclone(default_rclone_config_path)\n",
"\n",
"print(\"Downloading and installing rclone...\")\n",
"try:\n",
" import shutil\n",
" install_rclone_via_shell(quiet=True)\n",
" assert shutil.which(\"rclone\") is not None\n",
"except Exception: # fall back to building rclone from source if install.sh fails\n",
" install_rclone_from_github()\n",
" assert shutil.which(\"rclone\") is not None\n",
"\n",
"print(\"Testing rclone google drive remote...\")\n",
"!rclone --config=$default_rclone_config_path touch driveapi:.timestamp\n",
"print(\"... success!\")\n",
"\n",
"# -------------------------------------------------------------------\n",
"# Python subprocess wrappers for programmatically using rclone cli\n",
"# -------------------------------------------------------------------\n",
"\n",
"import os, pathlib\n",
"import subprocess, json\n",
"\n",
"if 'default_rclone_config_path' not in vars().keys():\n",
" default_rclone_config_path = \"/content/rclone.config\"\n",
"\n",
"def rclone(*args, output=\"pass\", check=True, config=None):\n",
" global default_rclone_config_path\n",
" if config == None:\n",
" config = default_rclone_config_path\n",
" if config != None:\n",
" args = list([f\"--config={config}\", *args])\n",
" \n",
" if output == \"pass\":\n",
" p = subprocess.run([\"rclone\", *args], check=check, stderr=subprocess.STDOUT)\n",
" if not check:\n",
" return p.returncode\n",
" else:\n",
" p = subprocess.run([\"rclone\", *args], check=check, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
" if output == \"raw\":\n",
" if not check:\n",
" return p.stdout, p.returncode\n",
" else:\n",
" return p.stdout\n",
" elif output == \"json\":\n",
" try:\n",
" jsonout = json.loads(p.stdout)\n",
" except json.JSONDecodeError as err:\n",
" jsonout = None\n",
" if not check:\n",
" return jsonout, p.returncode\n",
" else:\n",
" return jsonout\n",
"\n",
"def rclonels(*args, max_depth = None):\n",
" files, retcode = rclone(\"lsjson\", *args, *([f\"--max-depth={max_depth}\"] if max_depth != None else []), output=\"json\", check=False)\n",
" if retcode != 0:\n",
" return []\n",
" if files == None:\n",
" return []\n",
" return files\n",
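"\n",
"# Usage sketch (hypothetical paths) for the wrappers above:\n",
"# rclone(\"copyto\", \"driveapi:/some.pkl\", \"/content/some.pkl\", check=False) # run rclone, returns exit code\n",
"# for f in rclonels(\"driveapi:/\", max_depth=1): # list of lsjson dicts, [] on error\n",
"# print(f['Name'], f['ModTime'])\n",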
"\n",
"def syncnewestmatchingfile(searchdirs, pattern, max_depth = None, download_to = None):\n",
" import fnmatch\n",
" allfiles = []\n",
" for searchdir in searchdirs:\n",
" dirfiles = rclonels(searchdir, max_depth=max_depth)\n",
" allfiles = list(allfiles) + list([(pathlib.Path(searchdir)/df['Path'], df['ModTime']) for df in dirfiles if fnmatch.fnmatch(df['Name'], pattern)])\n",
" if len(allfiles) < 1:\n",
" return None # no matching files were found\n",
" newestfile = str(list([g[0] for g in list(sorted(allfiles, key=lambda j: j[1]))])[-1])\n",
" if download_to != None:\n",
" downloadedfile = str(pathlib.Path(download_to) / pathlib.Path(newestfile).name)\n",
" if str(pathlib.Path(newestfile)).startswith(str(pathlib.Path(download_to))):\n",
" return str(pathlib.Path(newestfile))\n",
" print(f\"Downloading newest file matching pattern {repr(pattern)} in search dirs {searchdirs} to {download_to}: {newestfile}\")\n",
" rclone(\"copy\", \"--progress\", newestfile, download_to, output=\"pass\", check=False)\n",
" print(f\"... done downloading {downloadedfile}\")\n",
" return downloadedfile\n",
" else:\n",
" return str(newestfile)\n",
"\n",
"def localnewestmatchingfile(dirpath, pattern):\n",
" import pathlib\n",
" sortedmatches = list([str(f) for f in sorted(pathlib.Path(dirpath).glob(pattern), key=lambda f: f.stat().st_mtime)])\n",
" if len(sortedmatches) > 0:\n",
" return sortedmatches[-1]\n",
" else:\n",
" return None\n"
],
"metadata": {
"id": "F97f4CeSkfiG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "00S7pPSdhZ5g",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title <big><big><big>Install **StyleGAN3** Fork</big></big></big>\n",
"#@markdown And various python dependencies it needs to run. [un1tz3r0/stylegan3](https://github.com/un1tz3r0/stylegan3.git)\n",
"\n",
"!pip install einops ninja gdown aiohttp\n",
"\n",
"import os\n",
"%cd /content/\n",
"\n",
"#!rm -rf /content/stylegan3\n",
"if not os.path.isdir('/content/stylegan3/'):\n",
" !git clone https://github.com/un1tz3r0/stylegan3.git /content/stylegan3/\n",
"else:\n",
" %cd /content/stylegan3/\n",
" !git pull\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L4Y28xIshBYs"
},
"source": [
"# <big><big><big>Prepare Model and Data</big></big></big>\n",
"If this is your first time running this notebook, the first two cells of this section won't do much, because you don't have any model snapshots on your Google Drive yet. The dropdown that appears after running the first cell, *Select a Model Snapshot to Resume Training*, will only contain the value `none` and the URL(s) of the pretrained models distributed by the NVIDIA team with the official StyleGAN3 (Alias-Free GAN) paper. Once you have successfully started training a model, the snapshots produced will appear in this dropdown as starting points for future training runs."
]
},
{
"cell_type": "code",
"source": [
"#@title <big><big><big>Select a **Model Snapshot**</big></big></big> { vertical-output: true, display-mode: \"form\" }\n",
"\n",
"#@markdown Run this cell to display a list of all `network-snapshot-*.pkl` files on your Google Drive, sorted in ascending chronological order by modification time.\n",
"#@markdown \n",
"#@markdown *the newest model snapshot found on google drive will be automatically selected*\n",
"#@markdown\n",
"#@markdown When you select a model, if there is a matching fakes*.png in the same folder it will be downloaded and shown below the dropdown to preview the selected model's generator output.\n",
"\n",
"import ipywidgets as widgets\n",
"from IPython import display\n",
"from fnmatch import fnmatch\n",
"import pathlib\n",
"\n",
"#if \"modelselect\" not in globals().keys():\n",
"modelselect = None\n",
"#if \"modellayout\" not in globals().keys():\n",
"modellayout = None\n",
"#if \"modelpreview\" not in globals().keys():\n",
"modelpreview = None\n",
"nodownloadpreview = False\n",
"\n",
"def modelselectchanged(evt):\n",
" global modelselection\n",
" global modelselect\n",
" global modelpreview\n",
" global nodownloadpreview\n",
"\n",
" newvalue = evt['new']\n",
" if isinstance(newvalue, int):\n",
" newvalue = modelselect.options[evt['new']]\n",
" modelselection = newvalue\n",
"\n",
" if modelpreview != None and not nodownloadpreview:\n",
" pngpath = pathlib.Path(str(newvalue.replace('network-snapshot-', 'fakes').replace('.pkl', '.png')))\n",
" if str(pngpath).startswith(\"driveapi:\"):\n",
" localpngpath = pathlib.Path(\"/content\") / pngpath.name\n",
" if not localpngpath.exists():\n",
" print(f\"Downloading preview image {pngpath} for model {newvalue} to {localpngpath}...\")\n",
" output, resultcode = rclone(\"copyto\", \"-P\", str(pngpath), str(localpngpath), check=False, output=\"raw\")\n",
" print(f\"download {str(pngpath)} to {str(localpngpath)} finished with result code {resultcode}: output={repr(output)}\")\n",
" pngpath = localpngpath\n",
" if pngpath.exists():\n",
" print(f\"found preview fakes png {repr(str(pngpath))} for model snapshot {repr(evt['new'])}, loading image in browser...\")\n",
" modelpreview.set_value_from_file(str(pngpath))\n",
" print(f\"loaded preview fakes png {repr(str(pngpath))}!\")\n",
" else:\n",
" print(f\"missing preview fakes png {repr(str(pngpath))} for model snapshot {repr(evt['new'])}\")\n",
" modelpreview.value = bytes()\n",
" elif nodownloadpreview:\n",
" print(\"Not downloading preview (choose a model from the dropdown to show the associated fakesN.png preview...)\")\n",
" nodownloadpreview = False\n",
"\n",
"def updatemodelselect():\n",
" global modelselect\n",
" global modelpreview\n",
" global modellayout\n",
" global nodownloadpreview\n",
" global modelselection\n",
"\n",
" nodownloadpreview = True\n",
" if modelpreview == None:\n",
" modelpreview = widgets.Image()\n",
" print(\"Getting list of models from google drive...\")\n",
" models = [str(pathlib.Path(\"driveapi:\") / f['Path']) for f in sorted(rclonels(\"driveapi:\", max_depth=3), key=lambda f: f['ModTime']) if fnmatch(f['Name'], \"network-snapshot-*.pkl\")]\n",
" models = ['none', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl'] + models\n",
" if 'modelselect' not in vars().keys() or modelselect == None:\n",
" modelselect = widgets.Dropdown(options=models)\n",
" modelselection = models[-1]\n",
" else:\n",
" modelselection = modelselect.value\n",
" modelselect.options = models\n",
" modelindex = list([-1, *[n for n, m in enumerate(models) if modelselection == m]])[-1]\n",
" modelselect.unobserve_all()\n",
" modelselect.observe(modelselectchanged, ['value', 'index'])\n",
" modelselect.index = modelindex\n",
" if modellayout == None:\n",
" modellayout = widgets.VBox(children=(modelselect, modelpreview), layout=widgets.Layout(height='auto', width='auto'))\n",
"\n",
"nodownloadpreview = True\n",
"updatemodelselect()\n",
"display.display(modellayout)\n",
"#nodownloadpreview = False\n",
"\n",
"\n"
],
"metadata": {
"id": "iy2jeejJMkQ0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title <big><big><big>**Download** the selected **Model Snapshot**</big></big></big> { vertical-output: true, display-mode: \"form\" }\n",
"\n",
"#@markdown Run this cell to download the model snapshot selected in the previous cell and set the downloaded file to be used for training and inference in the following sections.\n",
"\n",
"#latestpkl = syncnewestmatchingfile([\"driveapi:/\"], pattern=\"network-snapshot-*.pkl\", download_to=\"/content\")\n",
"modelpath = modelselection\n",
"if modelpath.startswith(\"driveapi:\"):\n",
" newmodelpath = str(pathlib.Path(\"/content\") / pathlib.Path(modelpath).name)\n",
" print(f\"Downloading model {repr(modelpath)} to local file {repr(newmodelpath)}...\")\n",
" out, code = rclone(\"copyto\", modelpath, newmodelpath, check=False, output=\"raw\")\n",
" print(f\"rclone_exit_code={repr(code)}, rclone_output={repr(out)}\")\n",
" modelpath = newmodelpath\n",
"print(f\"\\n\\nUsing network snapshot: {repr(modelpath)}...\")"
],
"metadata": {
"id": "F9MG7d12C83-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bECnYt6KXn8y"
},
"outputs": [],
"source": [
"#@title <big><big><big>Prepare **Dataset**</big></big></big> { vertical-output: true, display-mode: \"form\" }\n",
"\n",
"#@markdown Prepare the training image set by downloading it, or by generating new images from the github repo of large source images. First check Google Drive for dataset.zip and use it if found; if not, optionally generate a new dataset and upload it for future runs.\n",
"\n",
"#@markdown Options:\n",
"generate_missing_dataset = True #@param {type:\"boolean\"}\n",
"#@markdown > Check this box to enable generating a new dataset from the images and script in my [pixelscapes-dataset repo](https://github.com/un1tz3r0/pixelscapes-dataset.git)\n",
"force_regenerate_dataset = False #@param {type:\"boolean\"}\n",
"#@markdown > Check this box to skip checking google drive for dataset.zip, and rebuild a new dataset from the pixelscapes-dataset repo's source images and randomcrops.py script. When done, the new dataset will be uploaded via rclone to Google Drive (as dataset.zip, an existing dataset.zip, if present, will be backed up)\n",
"generate_dataset_count = 200000 #@param {type:\"integer\"}\n",
"#@markdown > Size of the dataset to generate, in number of training images. When generating a new dataset, output this many randomly cropped squares from the source images, sampled weighted by relative size so that all pixels contribute equally to the training.\n",
"upscale_factor = 2.0#@param {type:\"number\", min:1.0, max:4.0}\n",
"#@markdown > Zoom original images using a pretrained superresolution model with RealESRGAN by this factor before randomly cropping.\n",
"weighting_amount = 0.25 #@param {type:\"number\", min:0.0, max:1.0}\n",
"#@markdown > Amount of weighting based on source image size to use when sampling source images. 1.0=probability is proportional to ${width} \\times {height}$, 0.0 = even probability\n",
"unzip_dataset = True #@param {type: \"boolean\"}\n",
"#@markdown > Extract the dataset.zip to /content/dataset (needs the patched StyleGAN3 train.py, which this notebook already uses).\n",
"\n",
"# -----------------------------------------------------------------------------------------------\n",
"if 0: #notebook_mode != \"Training\":\n",
" # never get here, we use the training images for projection targets in the inference section below, so prepare the dataset either way!\n",
" print(\"No action taken, training is not enabled for this notebook_mode.\")\n",
"else:\n",
" # ---------------------------------------------------------------------------------------------\n",
"\n",
" datasetpath = None #\"/content/drive/dataset.zip\"\n",
"\n",
" if force_regenerate_dataset:\n",
" print(\"Forcing regeneration of dataset.zip, will back up existing dataset.zip on drive first...\")\n",
" if len(rclonels(\"driveapi:/dataset.zip\")) > 0:\n",
" generate_missing_dataset = True\n",
" import fnmatch, re\n",
" datasets = [f['Name'] for f in rclonels(\"driveapi:/\") if fnmatch.fnmatch(f['Name'], 'dataset-*.zip')]\n",
" datasetnumbers = [int(m.group(1)) for m in [re.match(r\"^dataset-([0-9]+)\\.zip$\", f) for f in datasets] if m is not None]\n",
" if len(datasetnumbers) < 1:\n",
" nextdatasetbackupnumber = 1\n",
" else:\n",
" nextdatasetbackupnumber = max(datasetnumbers) + 1\n",
" print(f\"Renaming existing dataset.zip on drive to dataset-{nextdatasetbackupnumber}.zip before generating new dataset...\")\n",
" rclone(\"moveto\", \"driveapi:/dataset.zip\", f\"driveapi:/dataset-{nextdatasetbackupnumber}.zip\", check=True)\n",
" if pathlib.Path(\"/content/dataset.zip\").exists():\n",
" print(\"Removing existing dataset.zip since we are about to generate a new one!\")\n",
" pathlib.Path(\"/content/dataset.zip\").unlink()\n",
" print(\"... removed /content/dataset.zip!\")\n",
" elif not os.path.exists(\"/content/dataset.zip\"):\n",
" # okay, first we check if there is a dataset.zip on google drive\n",
" print(\"Checking drive for existing dataset.zip to download...\")\n",
" if len(rclonels(\"driveapi:/dataset.zip\")) > 0:\n",
" # if it exists, download dataset from drive to local storage for speed\n",
" print(\"Copying dataset.zip from drive to /content...\")\n",
" resultcode = rclone(\"copyto\", \"--progress\", \"--stats=2s\", \"driveapi:/dataset.zip\", \"/content/dataset.zip\", output=\"pass\", check=False)\n",
" datasetpath = \"/content/dataset.zip\"\n",
" datasetbytes = pathlib.Path(datasetpath).stat().st_size\n",
" print(f\"... downloaded {(datasetbytes//1024//1024//10.24)/100} GB {datasetpath} from drive.\")\n",
" else:\n",
" # looks like the dataset.zip is missing from google drive, oops\n",
" print(\"No dataset.zip found on Google Drive! Nothing synced.\")\n",
"\n",
" # check if we need to generate a dataset from sources\n",
" if ((not os.path.exists(\"/content/dataset.zip\")) and \\\n",
" generate_missing_dataset) or force_regenerate_dataset:\n",
" print(\"Generating missing dataset.zip from github repo now!\")\n",
" # create a new dataset.zip from the source images and randomcrop script in our github repo\n",
" %cd '/content/'\n",
"\n",
" def install_realesrgan():\n",
" %cd '/content/'\n",
" # Clone Real-ESRGAN and enter the Real-ESRGAN\n",
" !git clone https://github.com/xinntao/Real-ESRGAN.git\n",
" %cd Real-ESRGAN\n",
" !pip3 install -r requirements.txt\n",
" !python3 setup.py develop --user\n",
" # Download the pre-trained model\n",
" !wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models\n",
"\n",
" def upscale_all(inputdir, outputdir, factor = 2.0):\n",
" import pathlib\n",
" outpath = pathlib.Path(outputdir).absolute()\n",
" if not (outpath.exists() and outpath.is_dir()):\n",
" outpath.mkdir()\n",
" infiles = [f.absolute() for f in pathlib.Path(inputdir).glob(\"*.png\")]\n",
" for f in infiles:\n",
" fin = str(f)\n",
" fo = outpath / f.name\n",
" if fo.exists():\n",
" print(f\"Skipping existing output file: {fo}\")\n",
" continue\n",
" #fo = outpath / f\"{fo.stem}-out.{fo.suffix}\"\n",
" fout = str(fo.parent.absolute())\n",
" print(f\"Upscaling x{factor}: {fin} -> {fout}\")\n",
" %cd /content/Real-ESRGAN\n",
" !python3 inference_realesrgan.py -i $fin --outscale $factor -o $fout -n RealESRGAN_x4plus\n",
"\n",
" import os, pathlib\n",
" if not pathlib.Path(\"pixelscapes-dataset\").exists():\n",
" !git clone https://github.com/un1tz3r0/pixelscapes-dataset.git\n",
" else:\n",
" %cd /content/pixelscapes-dataset\n",
" !git diff --no-ext-diff --quiet --exit-code || rm -Rf cropped ../dataset.zip\n",
" !git pull\n",
" \n",
" %cd /content\n",
" if not os.path.exists(\"/content/pixelscapes-dataset/scaled/\"):\n",
" #if True:\n",
" print(\">>> Installing upscaler network to zoom 2x source images\")\n",
" install_realesrgan()\n",
" print(\">>> Upscaling raw dataset images...\")\n",
" upscale_all(\"/content/pixelscapes-dataset/pixelscapes/\", \\\n",
" \"/content/pixelscapes-dataset/scaled\", upscale_factor)\n",
" print(\">>> Done, now cropping from upscaled images...\")\n",
" \n",
" %cd /content\n",
" if not os.path.exists(\"/content/pixelscapes-dataset/cropped\"):\n",
" !python3 pixelscapes-dataset/randomcrops.py \\\n",
" pixelscapes-dataset/scaled \\\n",
" pixelscapes-dataset/cropped \\\n",
" --jsonout pixelscapes-dataset/cropped/dataset.json \\\n",
" --count $generate_dataset_count \\\n",
" --size 256 --weighting $weighting_amount\n",
"\n",
" !python3 /content/stylegan3/dataset_tool.py \\\n",
" --source=pixelscapes-dataset/cropped \\\n",
" --dest=dataset.zip \\\n",
" --resolution='256x256'\n",
"\n",
" datasetpath = \"/content/dataset.zip\"\n",
"\n",
" # upload the newly created dataset\n",
" print(\"Syncing dataset.zip to drive...\")\n",
" resultcode = rclone(\"copy\", \"--progress\", \"/content/dataset.zip\", \"driveapi:/\", output=\"pass\", check=False) # \"copy\" uploads the file into the drive folder\n",
" if resultcode != 0:\n",
" print(f\"... not synced, result code is {resultcode}\")\n",
" else:\n",
" print(\"ok\")\n",
"\n",
" !rm -rf /content/dataset\n",
" # unzip the dataset.zip if needed and we have one\n",
" if unzip_dataset and pathlib.Path(\"/content/dataset.zip\").exists() and not (pathlib.Path(\"/content/dataset\").exists() and pathlib.Path(\"/content/dataset\").is_dir()):\n",
" print(\"Unzipping dataset.zip to /content/dataset/...\")\n",
" import ipywidgets\n",
" from IPython import display\n",
" outw = widgets.Output()\n",
" display.display(outw)\n",
" import subprocess\n",
" p = subprocess.Popen([\"unzip\", \"-d/content/dataset\", \"/content/dataset.zip\"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)\n",
" buf = bytes()\n",
" lineno = 0\n",
" while p.returncode is None:\n",
" o, e = p.communicate(bytes()) # blocks until unzip exits, returning all of its output\n",
" buf = buf + o\n",
" lines = buf.splitlines()\n",
" buf = lines[-1]\n",
" lines = lines[0:-1]\n",
" for line in lines:\n",
" lineno = lineno + 1\n",
" if lineno > 100:\n",
" lineno = 0\n",
" outw.clear_output(wait=True)\n",
" with outw:\n",
" print(line.strip().decode(\"utf8\"), flush=True)\n",
" print(\"done!\")\n",
" datasetpath = \"/content/dataset\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5YPvBu7plAqg"
},
"source": [
"# <big><big><big><big><big>**Training**</big></big></big></big></big>\n",
"- From scratch\n",
"- Resume training a model snapshot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6Dt7NAJ6EYhe"
},
"outputs": [],
"source": [
"#@title # <big><big><big>Run train.py</big></big></big> { vertical-output: true, display-mode: \"form\" }\n",
"#@markdown ### `train.py` Options (see `train.py --help`):\n",
"\n",
"\n",
"upload_to_subdir = \"fromscratchconditionalone\" #@param {type: \"string\"}\n",
"#@markdown > when copying output files to Google Drive, place them in this subdirectory so as not to overwrite other training runs. Allows substitutions of the form *\\{varname\\}*, where *varname* is one of:\n",
"#@markdown > - `{gamma}` the R1 regularisation parameter, $\\gamma$\n",
"\n",
"cfg = \"stylegan3-t\" #@param [\"stylegan2\", \"stylegan3-t\", \"stylegan3-r\"]\n",
"\n",
"kimg = 2000#@param {type:\"integer\"}\n",
"#@markdown > Fine-tune the pre-trained model for an additional ${kimg} \\times 10^3$ training images.\n",
"\n",
"tick = 5#@param {type:\"integer\"}\n",
"#@markdown > Log status after every ${tick}$ kimgs during training\n",
"\n",
"snap = 10 #@param {type:\"integer\"}\n",
"#@markdown > Save a model snapshot every ${snap}$ ticks during training\n",
"\n",
"img_snap = 1 #@param {type:\"integer\"}\n",
"#@markdown > Number of ticks between each fakes png grid saved during training\n",
"\n",
"gamma_values = \"6\" #@param {type: \"string\"}\n",
"#@markdown > critical hyperparameter ${\\gamma}$ is the R1 regularisation weight for the discriminator. Use -1 for the default heuristic, which is based on resolution and batch size.\n",
"#@markdown > \n",
"#@markdown > **for hyperparameter searching** give multiple, space separated ${\\gamma}$ values, and use `{gamma}` in your `upload_to_subdir` folder name, and training will be run with each ${\\gamma}$ value and the results uploaded to separate folders for analysis.\n",
"\n",
"ema_factor_values = \"1\" #@param {type: \"string\"}\n",
"#@markdown > critical hyperparameter ${ema}_{factor}$ scales the default heuristic value for the generator weights' exponential moving average, which smooths large gradients during training and stabilizes things significantly. Values below 1 shorten the averaging window.\n",
"\n",
"freezed = -1#@param {type: \"integer\", min: 0}\n",
"#@markdown > freeze the first ${freezed}$ layers of the discriminator network (mostly useful for transfer learning)\n",
"batch = 32#@param {type: \"integer\"}\n",
"#@markdown > ${batch}$ is the training batch size \n",
"\n",
"batch_gpu = 16 #@param {type: \"integer\"}\n",
"#@markdown > ${batch}_{GPU}$ is the per-gpu training batch size. If this is less than ${batch} \\div {gpus}$, gradient accumulation is used to reach the full batch size.\n",
"\n",
"mbstd_group = -1#@param {type: \"integer\"}\n",
"#@markdown > ${mbstd}_{group}$ is the minibatch standard deviation group size in the discriminator. Reduce it along with ${batch}_{GPU}$ if you run out of GPU memory.\n",
"\n",
"half_cbase = True #@param {type: \"boolean\"}\n",
"#@markdown > Set ${cbase}$ to half the usual channel count (affects the size of the neural network layers), to reduce GPU memory use.\n",
"\n",
"mirror = False #@param {type: \"boolean\"}\n",
"#@markdown > Augment the dataset by randomly mirroring the images left-right (a horizontal flip about the vertical centerline).\n",
"\n",
"image_snap_res = \"4k\" #@param [\"1080p\", \"4k\", \"8k\"]\n",
"#@markdown > the size of the fakes.png image grids written every ${image\\_ticks} \\times {tick\\_kimg}$ during training\n",
"\n",
"aug = \"ada\" #@param [\"noaug\", \"ada\", \"fixed\"]\n",
"\n",
"augpipe = \"blit\" #@param [\"none\", \"bgc\", \"bc\", \"blit\"] {allow-input: true}\n",
"\n",
"metrics = \"none\" #@param [\"none\", \"fid50k_full\"]\n",
"\n",
"\n",
"seed = 1983#@param {type: \"integer\"}\n",
"\n",
"#@markdown > seed for random values, use consistent seed for deterministic, replicatable training runs\n",
"#@markdown ---\n",
"\n",
"# -----------------------------------------------------------------------------------------------\n",
"if notebook_mode != \"Training\":\n",
" print(\"No action taken, training is not enabled for this notebook_mode.\")\n",
"else:\n",
" # ---------------------------------------------------------------------------------------------\n",
"\n",
" # if gamma <= 0:\n",
" # gamma = None\n",
" # if mbstd_group <= 0:\n",
" # mbstd_group = None\n",
" # if batch_gpu <= 0:\n",
" # batch_gpu = None\n",
" # if batch <= 0:\n",
" # batch = None\n",
"\n",
" import sys, os, pathlib, glob, re\n",
"\n",
" # resume training with the network-snapshot-######.pkl we downloaded above\n",
"\n",
" if modelpath != None and modelpath != 'none':\n",
" resume = modelpath\n",
" # determine the kimg that the model we are fine-tuning has already been trained for\n",
" # by extracting it from the filename. TODO: get this and other parameters from the\n",
" # training_options.json\n",
" try:\n",
" resume_kimg = int(re.match(\".*-(\\d+)\\..*?$\", pathlib.Path(resume).name).group(1))\n",
" except:\n",
" resume_kimg = 500\n",
" else:\n",
" resume = -1 #\"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl\"\n",
" resume_kimg = -1 #500\n",
"\n",
" # hpar fast orig\n",
" if (not isinstance(resume, int)) or resume != -1:\n",
" resume_kimg = int(((resume_kimg + tick-1) // tick) * tick)\n",
" kimg = int((int(resume_kimg + kimg + tick-1)//tick)*tick) # 500\n",
"\n",
" print(f'resuming from network snapshot pkl: {resume}')\n",
" print(f'resuming from kimg count: {resume_kimg}')\n",
" else:\n",
" print(f'training from scratch!')\n",
" print(f\"running until kimg: {kimg}\")\n",
"\n",
" %cd /content\n",
"\n",
" if False:\n",
" %cd /content/stylegan3\n",
"\n",
" if 'train' in sys.modules.keys():\n",
" del sys.modules['train']\n",
" if 'train' in locals().keys():\n",
" del train\n",
" import train\n",
"\n",
" def runtraining(outdirname=\"training-runs\", gamma=-1, ema_factor=-1, **overrides):\n",
" import shlex\n",
"\n",
"\n",
" def kwargstodict(*dicts, **kwargs):\n",
" import shlex\n",
" outdict = {}\n",
" for d in dicts:\n",
" for k,v in d.items():\n",
" outdict[k] = v\n",
" for k,v in kwargs.items():\n",
" if callable(v):\n",
" if k in outdict.keys():\n",
" outdict[k] = v(outdict[k])\n",
" else:\n",
" outdict[k] = v()\n",
" else:\n",
" outdict[k] = v\n",
" return outdict\n",
" \n",
" def quoteargs(args):\n",
" outargs = []\n",
" for k,v in args.items():\n",
" if v == None or v == '' or (isinstance(v, (int, float)) and v < 0):\n",
" print(f\"quoteargs(): ignoring deleted or empty or negative-integer long-option {repr(k)} with value {repr(v)}!\")\n",
" continue\n",
" outargs.append(f\"--{k.replace('_','-')}={shlex.quote(str(v))}\")\n",
" return \" \".join(outargs)\n",
"\n",
" args=kwargstodict(\n",
" outdir='/content/training-runs', \n",
" data=\"/content/dataset\", \n",
" resume=resume,\n",
" resume_kimg=resume_kimg, \n",
" cfg=cfg, \n",
" cond=True,\n",
" gpus=1, \n",
" workers=1,\n",
" kimg=kimg, \n",
" gamma=gamma,\n",
" ema_factor=ema_factor,\n",
" batch=batch, \n",
" batch_gpu=batch_gpu, \n",
" mbstd_group=mbstd_group,\n",
" aug=aug,\n",
" augpipe=augpipe, \n",
" mirror=mirror, \n",
" tick=tick, \n",
" snap=snap, \n",
" img_snap=img_snap, \n",
" snap_res=image_snap_res, \n",
" freezed=freezed if freezed > 0 else None,\n",
" cbase=16384 if half_cbase else None,\n",
" seed=seed,\n",
" metrics=metrics\n",
" )\n",
"\n",
" # apply any overrides over defaults\n",
" args = kwargstodict(args, **overrides) \n",
" \n",
" # add args for uploading output files to google drive using outdirname\n",
" args=kwargstodict(\n",
" args,\n",
" img_cmd='rclone --config=/content/rclone.config copy \"$1\" driveapi:/'+shlex.quote(outdirname)+'/training-run-\"$DESC\"/; echo \"$1\"',\n",
" snap_cmd='rclone --config=/content/rclone.config copy \"$1\" driveapi:/'+shlex.quote(outdirname)+'/training-run-\"$DESC\"/; echo \"$1\"'\n",
" )\n",
" \n",
" # produce a shell-quoted string with long-option-style args from the dict\n",
" qargs = quoteargs(args)\n",
"\n",
" print(f\"/content/stylegan3/train.py {repr(qargs)}\") \n",
" \n",
" # Fine-tune StyleGAN3-T for pixelscapes-256-50k using 1 GPU, starting from the pre-trained FFHQ-U pickle.\n",
" from IPython.display import display\n",
" from ipywidgets.widgets import Box, Output, Image\n",
" oimage = Image(value=bytes(),layout={'border': '1px solid black'})\n",
" #olog = Output(layout={'border': '1px solid black'})\n",
" #obox = Box(children=(oimage, olog))\n",
" #display(olog)\n",
" display(oimage)\n",
"\n",
" import subprocess, select, re\n",
" p = subprocess.Popen(f\"python3 /content/stylegan3/train.py {qargs}\", shell=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)\n",
" done = False\n",
" \n",
" #with olog:\n",
" if True:\n",
" print(\"running train.py script...\")\n",
"\n",
" outdir = None\n",
" lastpng = None\n",
" pngname = None\n",
"\n",
" try:\n",
" while not done:\n",
" rfds, wfds, xfds = select.select([p.stdout.fileno()], [], [])\n",
" if p.poll() != None:\n",
" # process \n",
" done=True\n",
" else:\n",
" #with olog:\n",
" #if out == '' and err == '':\n",
" # resultcode = p.wait()\n",
" # done = True\n",
" # continue\n",
" if p.stdout.fileno() in rfds:\n",
" line = p.stdout.readline().strip()\n",
" m = re.match(\"^(/content/.*\\.png);.*$\", line)\n",
" if m != None:\n",
" pngname = m.group(1)\n",
" oimg = Image(value=bytes())\n",
" oimg.set_value_from_file(str(pngname))\n",
" display(oimg)\n",
" #if lastpng != pngname:\n",
" # lastpng = pngname\n",
" # pngpath = pathlib.Path(lastpng)\n",
" # oimage.set_value_from_file(str(pngpath))\n",
" # display(oimage)\n",
" elif m == None:\n",
" print(line)\n",
"\n",
" #if p.stderr.fileno() in rfds:\n",
" # print(f\"{repr(p.stderr.read(1))}\", end='')\n",
" if len(rfds) == 0 or p.poll() != None:\n",
" done=True\n",
" except KeyboardInterrupt:\n",
" print(\"*** keyboardinterrupt in runtraining()\")\n",
" pass\n",
" except Exception as err:\n",
" import traceback as tb\n",
" print(f\"exception thrown in runtraining(): {err}\")\n",
" tb.print_exc()\n",
" finally:\n",
" try:\n",
" if p.poll() == None:\n",
" p.kill()\n",
" else:\n",
" p.wait(timeout=2.0)\n",
" except:\n",
" pass\n",
"\n",
" try:\n",
" import re\n",
" gammas = [float(word) for word in re.split(\"[, \\t;]+\", gamma_values)]\n",
" except Exception as err:\n",
" print(f\"Error parsing list of gamma values to run training on: {err}\")\n",
" exit(1)\n",
"\n",
" try:\n",
" import re\n",
" ema_factors = [float(word) for word in re.split(\"[, \\t;]+\", ema_factor_values)]\n",
" except Exception as err:\n",
" print(f\"Error parsing list of ema_factor values to run training on: {err}\")\n",
" exit(1)\n",
"\n",
" try:\n",
" for ema_factor in ema_factors:\n",
" for gamma in gammas:\n",
" print(f\"*** training run with gamma={gamma} and ema_factor={ema_factor}***\")\n",
" runtraining(\n",
" outdirname = upload_to_subdir,\n",
" gamma = float(gamma),\n",
" ema_factor = float(ema_factor)\n",
" #tick=1, kimg=2, gamma=lambda x: x*2\n",
" )\n",
" print(\"*** training run ended ***\")\n",
" except Exception as err:\n",
" import traceback as tb\n",
" print(f\"exception thrown in training main hyperparameter search loop: {err}\")\n",
" tb.print_exc()"
]
},
{
"cell_type": "markdown",
"source": [
"# <big><big><big><big><big>Inference</big></big></big></big></big>\n"
],
"metadata": {
"id": "6CzqZUcc_Jru"
}
},
{
"cell_type": "markdown",
"source": [
"These cells are a WIP interactive latent space explorer, allowing interpolation between two seed/class latent space points, and projection of real samples from the dataset or uploaded images, and interpolation between the projected latent space coordinates."
],
"metadata": {
"id": "inl-uHXghwOP"
}
},
{
"cell_type": "code",
"source": [
"#@title Real-time Inference { vertical-output: true, display-mode: \"form\" }\n",
"\n",
"%cd /content/stylegan3/\n",
"\n",
"import os\n",
"from typing import List, Optional, Union, Tuple\n",
"import click\n",
"\n",
"import dnnlib\n",
"from torch_utils import gen_utils\n",
"\n",
"import scipy\n",
"import numpy as np\n",
"import PIL.Image\n",
"import torch\n",
"\n",
"import legacy\n",
"import projector\n",
"from training.dataset import ImageFolderDataset\n",
"\n",
"dataset = ImageFolderDataset(datasetpath)\n",
"dataset._load_raw_labels()\n",
"\n",
"def init_model(network_pkl):\n",
"\tdevice = torch.device('cuda')\n",
"\twith dnnlib.util.open_url(network_pkl) as f:\n",
"\t\tG = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore\n",
"\tgen_utils.anchor_latent_space(G)\n",
"\treturn device, G\n",
"\n",
"def pil_to_jpeg(im):\n",
" from io import BytesIO\n",
" with BytesIO() as o:\n",
" im.save(o, \"JPEG\")\n",
" return o.getvalue()\n",
"\n",
"def latent_from_seed(device, G, seed, class_idx):\n",
"\tif G.c_dim != 0:\n",
"\t\tc = np.zeros([G.c_dim])\n",
"\t\tc[class_idx] = 1\n",
"\telse:\n",
"\t\tc = None\n",
"\tz = np.random.RandomState(seed).randn(G.z_dim)\n",
"\treturn z, c\t\n",
"\n",
"def latent_from_image(device, G, img, class_idx):\n",
"\tlabel = np.eye([G.c_dim])[class_idx] if G.c_dim > 0 else None\n",
"\tw = projector.project(G, target=img, label=label, projection_seed=0, truncation_psi=1, device=device)[0][-1]\n",
"\tzs = np.zeros([1, G.z_dim])\n",
"\tzs[0] = G.mapping(w, np.eye([G.c_dim])[class_idx])\n",
"\timg = gen_utils.z_to_img(G, z, c, truncation_psi=1, noise_mode='const')\n",
"\tpilim = PIL.Image.from_array(img[0], \"RGB\")\n",
"\treturn pilim\n",
"\n",
"def interpolate_z(device, G, z0, z1, interp):\n",
"\treturn gen_utils.slerp(interp, z0, z1)\n",
"\n",
"def interpolate_c(device, G, c0, c1, interp):\n",
"\tif c0 is not None and c1 is not None:\n",
"\t\treturn c1 * interp + c0 * (1.0-interp)\n",
"\treturn None\n",
"\n",
"\n",
"\n",
"def render_image(device, G, \n",
"\t\tseeds,\n",
"\t\tclass_idxes,\n",
"\t\ttruncation_psi = 1.0,\n",
"\t\tnoise_mode = 'const',\n",
"\t\tgrid_rows = 1,\n",
"\t\tgrid_cols = 1,\n",
"\t\tinterpz = 0.0,\n",
"\t\tinterpl = 0.0\n",
"\t\t):\n",
"\ttry:\n",
"\t\t\n",
"\t\tif len(seeds) == 1:\n",
"\t\t\tseed0 = seeds[0]\n",
"\t\t\tseed1 = seeds[0]\n",
"\t\telse:\n",
"\t\t\tseed0 = seeds[0]\n",
"\t\t\tseed1 = seeds[1]\n",
"\n",
"\t\tif class_idxes != None:\n",
"\t\t\tif G.c_dim == 0:\n",
"\t\t\t\traise RuntimeError(\"Error, cannot specify class for unconditional network\")\n",
"\t\t\tif len(class_idxes) == 1:\n",
"\t\t\t\tclass_idx0 = class_idxes[0]\n",
"\t\t\t\tclass_idx1 = class_idxes[0]\n",
"\t\t\telse:\n",
"\t\t\t\tclass_idx0 = class_idxes[0]\n",
"\t\t\t\tclass_idx1 = class_idxes[1]\n",
"\t\telse:\n",
"\t\t\tif G.c_dim != 0:\n",
"\t\t\t\traise ValueError(\"Error, must specify class for conditional network\")\n",
"\t\t\tclass_idx0, class_idx1 = None, None\n",
"\n",
"\t\t# generate latent z interpolation\n",
"\t\tz0, c0 = latent_from_seed(device, G, seed0, class_idx0)\n",
"\t\tz1, c1 = latent_from_seed(device, G, seed1, class_idx1)\n",
"\n",
"\t\tif G.c_dim != 0:\n",
"\t\t\tcs = np.zeros([grid_rows*grid_cols, G.c_dim])\n",
"\t\t\tfor col in range(0,grid_cols):\n",
"\t\t\t\ta = (col/max(grid_cols-1, 1)) + interpl\n",
"\t\t\t\tci = interpolate_c(device, G, c0, c1, interpl)\n",
"\t\t\t\tfor row in range(0,grid_rows):\n",
"\t\t\t\t\tcs[row*grid_cols+col, :] = ci\n",
"\t\telse:\n",
"\t\t\tcs = None\n",
"\t\t\n",
"\t\tc = None\n",
"\t\tif cs is not None:\n",
"\t\t\tc = torch.from_numpy(cs).to(device)\n",
"\n",
"\t\tzs = np.zeros([grid_rows * grid_cols, G.z_dim])\n",
"\t\tfor row in range(0, grid_cols):\n",
"\t\t\ta = row/max(1,(grid_rows-1)) + interpz\n",
"\t\t\tzi = interpolate_z(device, G, z0, z1, a)\n",
"\t\t\tfor col in range(0, grid_rows):\n",
"\t\t\t\tzs[row * grid_cols + col, :] = zi\n",
"\t\tz = torch.from_numpy(zs).to(device)\n",
"\t\t\n",
"\t\timgs = gen_utils.z_to_img(G, z, c, truncation_psi, noise_mode)\n",
"\t\timg = gen_utils.create_image_grid(imgs, (grid_rows, grid_cols))\n",
"\t\tim = PIL.Image.fromarray(img, 'RGB')\n",
"\t\t# from io import BytesIO\n",
"\t\t# with BytesIO() as stream:\n",
"\t\t# \tim.save(stream, \"JPEG\")\n",
"\t\t# \treturn stream.getvalue()\n",
"\t\treturn im\n",
"\texcept Exception as err:\n",
"\t\traise err\n",
"\n",
"\n",
"import IPython\n",
"import ipywidgets as widgets\n",
"\n",
"def label_on_left(label=\"\", widget=None, parent=None):\n",
"\tchild = widgets.HBox(children=[widgets.Label(value=label), widget])\n",
"\tif parent != None:\n",
"\t\tparent.children.append(child)\n",
"\telse:\n",
"\t\treturn child\n",
"\treturn widget\n",
"\n",
"seed_a_input = widgets.IntText(value=0, description=\"Seed A\")\n",
"seed_b_input = widgets.IntText(value=0, description=\"Seed B\")\n",
"class_a_input = widgets.IntSlider(value=0, min=0, max=70, description=\"Class A\")\n",
"class_b_input = widgets.IntSlider(value=0, min=0, max=70, description=\"Class B\")\n",
"\n",
"interp_seed_input = widgets.FloatSlider(value=0, min=0, max=1, label=\"Interpolate Seed\")\n",
"interp_class_input = widgets.FloatSlider(value=0, min=0, max=1, label=\"Interpolate Class\")\n",
"inputs_a = widgets.VBox(children=[\n",
"\t label_on_left(\"Seed\", seed_a_input), \n",
"\t\tlabel_on_left(\"Class\", class_a_input)\n",
"\t], layout=widgets.Layout(border=\"1px solid white\"))\n",
"\n",
"inputs_b = widgets.VBox(children=[\n",
"\t label_on_left(\"Seed\", seed_b_input), \n",
"\t\tlabel_on_left(\"Class\", class_b_input)\n",
"\t], layout=widgets.Layout(border=\"1px solid white\"))\n",
"\n",
"inputs_i = widgets.VBox(children=[interp_seed_input, interp_class_input], layout=widgets.Layout(border=\"1px solid white\"))\n",
"\n",
"inputs = widgets.VBox(children=[\n",
"\t label_on_left(\"Input A\", inputs_a), \n",
"\t\tlabel_on_left(\"Fade A->B\", inputs_i), \n",
"\t\tlabel_on_left(\"Input B\", inputs_b)\n",
"])\n",
"outputim = widgets.Image(value=bytes(), format=\"jpeg\")\n",
"\n",
"origim0 = widgets.Image(value=bytes(), format=\"jpeg\")\n",
"projectedim0 = widgets.Image(value=bytes(), format=\"jpeg\")\n",
"realindex0 = widgets.IntText(value=0)\n",
"realclass0 = widgets.Label(value=\"0\")\n",
"uploadim0 = widgets.FileUpload(multiple=False, accepts=\"image/*\")\n",
"projected0box = widgets.VBox(children=[origim0, projectedim0, uploadim0, realindex0, realclass0], layout=widgets.Layout(border=\"1px solid white\"))\n",
"def onrealindex0changed(evt):\n",
"\tglobal G\n",
"\tglobal device\n",
"\tglobal dataset\n",
"\tglobal realindex0\n",
"\tglobal projectedim0\n",
"\timage, clsx = dataset[realindex0.value]\n",
"\tdeets = dataset.get_details(realindex0.value)\n",
"\tclss = deets.raw_label\n",
"\timage = image.transpose(1, 2, 0) # CHW -> HWC\n",
"\tpilim = PIL.Image.fromarray(image, 'RGB')\n",
"\torigim0.value = pil_to_jpeg(pilim)\n",
"\trealclass0.value = str(clss)\n",
"\tprojim = latent_from_image(device, G, pilim, 0)\n",
"\tprojectedim0.value = pil_to_jpeg(projim)\n",
"realindex0.unobserve_all()\n",
"realindex0.observe(onrealindex0changed, ['value'])\n",
"\n",
"projectedim1 = widgets.Image(value=bytes(), format=\"jpeg\")\n",
"realindex1 = widgets.IntText(value=0)\n",
"realclass1 = widgets.Label(value=\"0\")\n",
"uploadim1 = widgets.FileUpload(multiple=False, accepts=\"image/*\")\n",
"projected1box = widgets.VBox(children=[projectedim1, uploadim1, realindex1, realclass1], layout=widgets.Layout(border=\"1px solid white\"))\n",
"def onrealindex1changed(evt):\n",
"\tglobal dataset\n",
"\tglobal realindex1\n",
"\tglobal projectedim1\n",
"\tsingim = dataset._load_raw_image(realindex1.value)\n",
"\tim = np.zeros([1, *singim.shape])\n",
"\tim[0] = singim\n",
"\tpilim = PIL.Image.fromarray(img, 'RGB')\n",
"\tprojectedim1.value = pil_to_jpeg(pilim)\n",
"\trealclass1.value = str(cls)\n",
"realindex1.observe(onrealindex1changed, ['value'])\n",
"\n",
"mainproj = widgets.HBox(children=[projected0box, projected1box])\n",
"mainio = widgets.HBox(children=[inputs, outputim])\n",
"main = widgets.VBox(children=[mainio, mainproj])\n",
"\n",
"print(\"loading model...\")\n",
"device, G = init_model(modelpath)\n",
"\n",
"\n",
"def refresh(evt,*args, **kwargs):\n",
" global outputim\n",
" result = render_image(device, G, [seed_a_input.value, seed_b_input.value],\n",
" [class_a_input.value, class_b_input.value],\n",
" interpz = interp_seed_input.value,\n",
" interpl = interp_class_input.value)\n",
" #print(result)\n",
" outputim.value=pil_to_jpeg(result)\n",
" #main.children = [inputs, output]\n",
" \n",
"\n",
"def hookupinputs(widget):\n",
" if 'children' in widget.keys:\n",
" for child in widget.children:\n",
" hookupinputs(child)\n",
" else:\n",
" widget.unobserve_all()\n",
" widget.observe(refresh, ['value'])\n",
"\n",
"hookupinputs(inputs)\n",
"IPython.display.display(main)\n",
"refresh(None)"
],
"metadata": {
"id": "2tGh_0ixxWcr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
""
],
"metadata": {
"id": "8FIFaQbyZ-TU"
}
},
{
"cell_type": "code",
"source": [
"#!curl -X GET --unix-socket /content/server.sock http://lodqlhost/\n",
"#!curl -X GET --unix-socket /content/server.sock http://lodqlhost/load?network_pkl=$modelpath"
],
"metadata": {
"id": "8rAyf0QWr6pe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"'''\n",
"# -----------------------------------------------------------------------------------------------\n",
"if notebook_mode != \"Inference\":\n",
" print(\"No action taken, inference is not enabled for this notebook_mode.\")\n",
"else:\n",
" # ---------------------------------------------------------------------------------------------\n",
" \n",
" #!python3 /content/stylegan3/generate.py random-video -anchor --seeds=1 \\\n",
" # --network=$modelpath --class-interp=30 --class-seed=1\n",
" # --out=lerpout --trunc=1 --seeds=0-31 --grid=4x2 \\\n",
" # --network=$modelpath --anchor-latent-space --class=1\n",
" %pip install aiohttp\n",
" %cd /content/stylegan3\n",
" import subprocess\n",
" if 'server_proc' in globals().keys():\n",
" server_proc.kill()\n",
" !rm /content/server.sock\n",
" #server_proc = subprocess.Popen([\"python3\", \"-m\", \"aiohttp.web\", \"-U\", \"/content/server.sock\", \"gen_server:init_app\"])\n",
" server_proc = subprocess.Popen([\"python3\", \"gen_server.py\"])\n",
" print(server_proc.pid)\n",
" #!python3 -m aiohttp.web -U /content/server.sock gen_server:init_app &\n",
"\n",
" #!python3 /content/stylegan3/gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --labels=0 \\\n",
" # --network=$modelpath --stabilize-video\n",
"'''"
],
"metadata": {
"id": "qhA5STOzAupj"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"background_execution": "on",
"collapsed_sections": [],
"machine_shape": "hm",
"name": "stylegan3_training_and_inference_2022_02_11.ipynb",
"provenance": [],
"private_outputs": true,
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
whatsnewsisyphus commented Apr 19, 2022

Thanks for this. I'm having some issues with the training, I think there might be a string issue with rclone trying to output to google drive, was wondering if you could take a peek. I haven't modified it, it just goes to training run ended immediately regardless of whether I run it from scratch or if I resume from a pickle

training from scratch!
running until kimg: 2000
/content
*** training run with gamma=6.0 and ema_factor=1.0***
quoteargs(): ignoring deleted or empty or negative-integer long-option 'resume' with value -1!
quoteargs(): ignoring deleted or empty or negative-integer long-option 'resume_kimg' with value -1!
quoteargs(): ignoring deleted or empty or negative-integer long-option 'mbstd_group' with value -1!
quoteargs(): ignoring deleted or empty or negative-integer long-option 'freezed' with value None!
/content/stylegan3/train.py '--outdir=/content/training-runs --data=/content/dataset --cfg=stylegan3-t --cond=True --gpus=1 --workers=1 --kimg=2000 --gamma=6.0 --ema-factor=1.0 --batch=32 --batch-gpu=16 --aug=ada --augpipe=blit --mirror=False --tick=5 --snap=5 --img-snap=1 --snap-res=4k --cbase=16384 --seed=1983 --metrics=none --img-cmd=\'rclone --config=/content/rclone.config copy "$1" driveapi:/scratch/training-run-"$DESC"/; echo "$1"\' --snap-cmd=\'rclone --config=/content/rclone.config copy "$1" driveapi:/scratch/training-run-"$DESC"/; echo "$1"\''

running train.py script...
Usage: train.py [OPTIONS]
*** training run ended ***

@un1tz3r0 let me know if I can provide anything else.

un1tz3r0 (Author) commented

Did you generate a dataset? The first time you run it, you must check the "generate_missing_dataset" toggle in the "Prepare Dataset" cell in order to create a dataset of cropped images from my pixelscapes-dataset repo. It will download realesrgan and some dependencies, grab the source images from the repo, upscale them, and then produce the crops and the class-conditional labels, which I found crucial for successfully training on diverse datasets like the pixelscapes. Once the dataset has been generated, a zip of it is uploaded to your Google Drive as dataset.zip, and future invocations should download and use that. If you are still having issues after checking that toggle, please include the output of the Prepare Dataset cell and check that /content/dataset exists and is not empty. Hopefully this helps!
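For reference, the class-conditional labels that StyleGAN3's `ImageFolderDataset` uses with `--cond=True` live in a `dataset.json` inside the dataset zip, as a `"labels"` list of `[filename, class_index]` pairs. A minimal, self-contained sketch of that layout (filenames hypothetical), handy for sanity-checking a hand-prepared dataset:

```python
import io, json, zipfile

# dataset.json maps each image file in the zip to an integer class index.
labels = {"labels": [["00000/img00000000.png", 0],
                     ["00000/img00000001.png", 3]]}

# Write it into an (in-memory) dataset zip, as dataset_tool.py would.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("dataset.json", json.dumps(labels))

# Read it back the way the dataset loader does.
with zipfile.ZipFile(buf) as zf:
    loaded = json.loads(zf.read("dataset.json"))
print(loaded["labels"][1][1])  # -> 3
```

If this file is missing or malformed, training with `--cond=True` will fail because the loader finds no labels.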

whatsnewsisyphus commented

I already had a dataset prepared with the dataset tool, which I use to train on another Colab without problems; I uploaded it and ran the cell without the generate-missing-dataset checkbox. The cell said it found and processed the dataset.
