@StEvUgnIn
Created February 26, 2022 11:54
dcgan.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/StEvUgnIn/e8d7a91a68c2848df65e004810d84249/dcgan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LhAeqTlqapxM"
},
"source": [
"The following additional libraries are needed to run this\n",
"notebook. Note that running on Colab is experimental; please report a GitHub\n",
"issue if you have any problems."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CTz4bdv3apxR",
"outputId": "b7338de9-a822-4abe-fd11-80ec30293a14"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: mxnet-cu101==1.7.0 in /usr/local/lib/python3.7/dist-packages (1.7.0)\n",
"Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101==1.7.0) (1.18.5)\n",
"Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101==1.7.0) (2.25.1)\n",
"Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101==1.7.0) (0.8.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (2.10)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (2021.10.8)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (3.0.4)\n",
"Requirement already satisfied: d2l==0.17.2 in /usr/local/lib/python3.7/dist-packages (0.17.2)\n",
"Requirement already satisfied: pandas==1.2.2 in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.2) (1.2.2)\n",
"Requirement already satisfied: requests==2.25.1 in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.2) (2.25.1)\n",
"Requirement already satisfied: numpy==1.18.5 in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.2) (1.18.5)\n",
"Requirement already satisfied: matplotlib==3.3.3 in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.2) (3.3.3)\n",
"Requirement already satisfied: jupyter==1.0.0 in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.2) (1.0.0)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (5.3.1)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (5.6.1)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (4.10.1)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (5.2.0)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (7.6.5)\n",
"Requirement already satisfied: qtconsole in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l==0.17.2) (5.2.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l==0.17.2) (0.11.0)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l==0.17.2) (7.1.2)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l==0.17.2) (2.8.2)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l==0.17.2) (3.0.7)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l==0.17.2) (1.3.2)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas==1.2.2->d2l==0.17.2) (2018.9)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l==0.17.2) (2.10)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l==0.17.2) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l==0.17.2) (2021.10.8)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l==0.17.2) (1.24.3)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.3.3->d2l==0.17.2) (1.15.0)\n",
"Requirement already satisfied: jupyter-client in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l==0.17.2) (5.3.5)\n",
"Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l==0.17.2) (5.5.0)\n",
"Requirement already satisfied: traitlets>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l==0.17.2) (5.1.1)\n",
"Requirement already satisfied: tornado>=4.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l==0.17.2) (5.1.1)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (0.7.5)\n",
"Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (1.0.18)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (2.6.1)\n",
"Requirement already satisfied: pexpect in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (4.8.0)\n",
"Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (0.8.1)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (57.4.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l==0.17.2) (0.2.5)\n",
"Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l==0.17.2) (0.2.0)\n",
"Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l==0.17.2) (1.0.2)\n",
"Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l==0.17.2) (3.5.2)\n",
"Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l==0.17.2) (5.1.3)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (4.9.2)\n",
"Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (4.3.3)\n",
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (4.11.1)\n",
"Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (5.4.0)\n",
"Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (21.4.0)\n",
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (0.18.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (3.10.0.2)\n",
"Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from importlib-resources>=1.4.0->jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l==0.17.2) (3.7.0)\n",
"Requirement already satisfied: terminado>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l==0.17.2) (0.13.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l==0.17.2) (2.11.3)\n",
"Requirement already satisfied: Send2Trash in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l==0.17.2) (1.8.0)\n",
"Requirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.7/dist-packages (from jupyter-client->ipykernel->jupyter==1.0.0->d2l==0.17.2) (22.3.0)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.7/dist-packages (from terminado>=0.8.1->notebook->jupyter==1.0.0->d2l==0.17.2) (0.7.0)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->notebook->jupyter==1.0.0->d2l==0.17.2) (2.0.1)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (1.5.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (0.7.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (0.5.0)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (0.8.4)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (0.4)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l==0.17.2) (4.1.0)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter==1.0.0->d2l==0.17.2) (0.5.1)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter==1.0.0->d2l==0.17.2) (21.3)\n",
"Requirement already satisfied: qtpy in /usr/local/lib/python3.7/dist-packages (from qtconsole->jupyter==1.0.0->d2l==0.17.2) (2.0.1)\n"
]
}
],
"source": [
"!pip install -U mxnet-cu101==1.7.0\n",
"!pip install d2l==0.17.2\n"
]
},
{
"cell_type": "code",
"source": [
"%matplotlib inline"
],
"metadata": {
"id": "gWX868hcutj5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 0,
"id": "7VkMAsK8apxT"
},
"source": [
"# Deep Convolutional Generative Adversarial Networks\n",
":label:`sec_dcgan`\n",
"\n",
"In :numref:`sec_basic_gan`, we introduced the basic ideas behind how GANs work. We showed that they can draw samples from some simple, easy-to-sample distribution, like a uniform or normal distribution, and transform them into samples that appear to match the distribution of some dataset. And while our example of matching a 2D Gaussian distribution got the point across, it is not especially exciting.\n",
"\n",
"In this section, we will demonstrate how you can use GANs to generate photorealistic images. We will base our models on the deep convolutional GANs (DCGAN) introduced in :cite:`Radford.Metz.Chintala.2015`. We will borrow the convolutional architectures that have proven so successful for discriminative computer vision problems and show how, via GANs, they can be leveraged to generate photorealistic images.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 1,
"tab": [
"mxnet"
],
"id": "c5BOGgdlapxU"
},
"outputs": [],
"source": [
"from mxnet import gluon, init, np, npx\n",
"from mxnet.gluon import nn\n",
"from d2l import mxnet as d2l\n",
"\n",
"npx.set_np()"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 4,
"id": "q9e5r7T4apxV"
},
"source": [
"## The Pokemon Dataset\n",
"\n",
"The dataset we will use is a collection of Pokemon sprites obtained from [pokemondb](https://pokemondb.net/sprites). First download, extract and load this dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 5,
"tab": [
"mxnet"
],
"id": "msDZAay1apxV"
},
"outputs": [],
"source": [
"#@save\n",
"d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',\n",
" 'c065c0e2593b8b161a2d7873e42418bf6a21106c')\n",
"\n",
"data_dir = d2l.download_extract('pokemon')\n",
"pokemon = gluon.data.vision.datasets.ImageFolderDataset(data_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 8,
"id": "mPutJNifapxX"
},
"source": [
"We resize each image into $64\\times 64$. The `ToTensor` transformation will project the pixel value into $[0, 1]$, while our generator will use the tanh function to obtain outputs in $[-1, 1]$. Therefore we normalize the data with $0.5$ mean and $0.5$ standard deviation to match the value range.\n"
]
},
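{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the value ranges, here is a worked sketch. The symbols $x$ (a pixel value after `ToTensor`) and $\\hat{x}$ (the normalized value) are introduced here only for illustration. Normalizing with mean $0.5$ and standard deviation $0.5$ gives\n",
"\n",
"$$\n",
"\\hat{x} = \\frac{x - 0.5}{0.5} = 2x - 1, \\qquad x \\in [0, 1] \\;\\Rightarrow\\; \\hat{x} \\in [-1, 1],\n",
"$$\n",
"\n",
"which is the interval that the tanh outputs fall in. The mapping is inverted by $x = \\hat{x}/2 + 0.5$, which recovers values in $[0, 1]$ for display.\n"
]
},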
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 9,
"tab": [
"mxnet"
],
"id": "T27XDwlEapxY"
},
"outputs": [],
"source": [
"batch_size = 256\n",
"transformer = gluon.data.vision.transforms.Compose([\n",
" gluon.data.vision.transforms.Resize(64),\n",
" gluon.data.vision.transforms.ToTensor(),\n",
" gluon.data.vision.transforms.Normalize(0.5, 0.5)\n",
"])\n",
"data_iter = gluon.data.DataLoader(\n",
" pokemon.transform_first(transformer), batch_size=batch_size,\n",
" shuffle=True, num_workers=d2l.get_dataloader_workers())"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 12,
"id": "uHtIRIPvapxZ"
},
"source": [
"Let us visualize the first 20 images.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 13,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 885
},
"id": "uf_fO7pYapxa",
"outputId": "2c445f92-6bd9-4ed2-92ca-df83da1942f4"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([<matplotlib.axes._subplots.AxesSubplot object at 0x7fed01f240d0>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed06f314d0>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed01e4fd10>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed07087c50>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed06f2d450>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed0391fc90>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed039e9710>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed003a6490>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed002799d0>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed001c8f90>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed073cb550>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed002bbb10>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed002d2110>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed075916d0>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fed075d2c90>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fecedd05290>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7feceddc9790>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fecedd89d50>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fecedd5a350>,\n",
" <matplotlib.axes._subplots.AxesSubplot object at 0x7fecedcdb910>],\n",
" dtype=object)"
]
},
"metadata": {},
"execution_count": 64
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 540x432 with 20 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"d2l.set_figsize((4, 4))\n",
"X, y = next(iter(data_iter))\n",
"imgs = X[0:20,:,:,:].transpose(0, 2, 3, 1)/2+0.5\n",
"d2l.show_images(imgs, num_rows=4, num_cols=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 16,
"id": "ZBMQcrm3apxb"
},
"source": [
"## The Generator\n",
"\n",
"The generator needs to map the noise variable $\\mathbf z\\in\\mathbb R^d$, a length-$d$ vector, to an RGB image with width and height $64\\times 64$. In :numref:`sec_fcn` we introduced the fully convolutional network, which uses transposed convolution layers (refer to :numref:`sec_transposed_conv`) to enlarge the input size. The basic block of the generator contains a transposed convolution layer followed by batch normalization and a ReLU activation.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 17,
"tab": [
"mxnet"
],
"id": "ExNYVBt9apxc"
},
"outputs": [],
"source": [
"class G_block(nn.Block):\n",
" def __init__(self, channels, kernel_size=4,\n",
" strides=2, padding=1, **kwargs):\n",
" super(G_block, self).__init__(**kwargs)\n",
" self.conv2d_trans = nn.Conv2DTranspose(\n",
" channels, kernel_size, strides, padding, use_bias=False)\n",
" self.batch_norm = nn.BatchNorm()\n",
" self.activation = nn.Activation('relu')\n",
"\n",
" def forward(self, X):\n",
" return self.activation(self.batch_norm(self.conv2d_trans(X)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 20,
"id": "n5gqbRUtapxc"
},
"source": [
"By default, the transposed convolution layer uses a $k_h = k_w = 4$ kernel, $s_h = s_w = 2$ strides, and $p_h = p_w = 1$ padding. With an input shape of $n_h \\times n_w = 16 \\times 16$, the generator block will double the input's width and height.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"n_h^{'} \\times n_w^{'} &= [n_h k_h - (n_h-1)(k_h-s_h) - 2p_h] \\times [n_w k_w - (n_w-1)(k_w-s_w) - 2p_w]\\\\\n",
" &= [k_h + s_h (n_h-1) - 2p_h] \\times [k_w + s_w (n_w-1) - 2p_w]\\\\\n",
" &= [4 + 2 \\times (16-1) - 2 \\times 1] \\times [4 + 2 \\times (16-1) - 2 \\times 1]\\\\\n",
" &= 32 \\times 32 .\\\\\n",
"\\end{aligned}\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 21,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "O8W8VLqtapxd",
"outputId": "da79d553-b892-4c09-d850-06649dfa084e"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2, 20, 32, 32)"
]
},
"metadata": {},
"execution_count": 66
}
],
"source": [
"x = np.zeros((2, 3, 16, 16))\n",
"g_blk = G_block(20)\n",
"g_blk.initialize()\n",
"g_blk(x).shape"
]
},
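{
"cell_type": "markdown",
"metadata": {},
"source": [
"The doubling is not specific to a $16 \\times 16$ input. As a short sketch, assume the output size of a transposed convolution along one dimension is $n' = k + s(n-1) - 2p$, where $n$ is a shorthand used only here for either input dimension. With the default $k = 4$, $s = 2$, and $p = 1$,\n",
"\n",
"$$\n",
"n' = 4 + 2(n-1) - 2 \\times 1 = 2n ,\n",
"$$\n",
"\n",
"so the block doubles the width and height for any input size, consistent with the $32 \\times 32$ output obtained from a $16 \\times 16$ input.\n"
]
},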
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 24,
"id": "mHcqxXuGapxe"
},
"source": [
"If we change the transposed convolution layer to use a $4\\times 4$ kernel, $1\\times 1$ strides, and zero padding, then with an input size of $1 \\times 1$ the output's width and height each increase by 3.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 25,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SXnBWTEvapxe",
"outputId": "1c1e190c-ccdb-44e0-9b8c-6e6174f287e3"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2, 20, 4, 4)"
]
},
"metadata": {},
"execution_count": 67
}
],
"source": [
"x = np.zeros((2, 3, 1, 1))\n",
"g_blk = G_block(20, strides=1, padding=0)\n",
"g_blk.initialize()\n",
"g_blk(x).shape"
]
},
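{
"cell_type": "markdown",
"metadata": {},
"source": [
"The $4 \\times 4$ result can be checked by hand. As a sketch, assume the one-dimensional output size formula $n' = k + s(n-1) - 2p$, with $n$ denoting either input dimension (a shorthand used only here). For $k = 4$, $s = 1$, and $p = 0$,\n",
"\n",
"$$\n",
"n' = 4 + 1 \\times (n-1) - 2 \\times 0 = n + 3 ,\n",
"$$\n",
"\n",
"so an input with $n = 1$ yields $n' = 4$.\n"
]
},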
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 28,
"id": "zyR5CZFDapxf"
},
"source": [
"The generator consists of four basic blocks that increase input's both width and height from 1 to 32. At the same time, it first projects the latent variable into $64\\times 8$ channels, and then halve the channels each time. At last, a transposed convolution layer is used to generate the output. It further doubles the width and height to match the desired $64\\times 64$ shape, and reduces the channel size to $3$. The tanh activation function is applied to project output values into the $(-1, 1)$ range.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 29,
"tab": [
"mxnet"
],
"id": "z9GkPiqkapxf"
},
"outputs": [],
"source": [
"n_G = 64\n",
"net_G = nn.Sequential()\n",
"net_G.add(G_block(n_G*8, strides=1, padding=0), # Output: (64 * 8, 4, 4)\n",
" G_block(n_G*4), # Output: (64 * 4, 8, 8)\n",
" G_block(n_G*2), # Output: (64 * 2, 16, 16)\n",
" G_block(n_G), # Output: (64, 32, 32)\n",
" nn.Conv2DTranspose(\n",
" 3, kernel_size=4, strides=2, padding=1, use_bias=False,\n",
" activation='tanh')) # Output: (3, 64, 64)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 32,
"id": "1x6jQtoOapxg"
},
"source": [
"Generate a 100 dimensional latent variable to verify the generator's output shape.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 33,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vnJMm4Ecapxg",
"outputId": "91ffdee4-3679-4de4-ee42-18d26e0380ad"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 3, 64, 64)"
]
},
"metadata": {},
"execution_count": 69
}
],
"source": [
"x = np.zeros((1, 100, 1, 1))\n",
"net_G.initialize()\n",
"net_G(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 36,
"id": "p6mQRL94apxh"
},
"source": [
"## Discriminator\n",
"\n",
"The discriminator is a normal convolutional network network except that it uses a leaky ReLU as its activation function. Given $\\alpha \\in[0, 1]$, its definition is\n",
"\n",
"$$\\textrm{leaky ReLU}(x) = \\begin{cases}x & \\text{if}\\ x > 0\\\\ \\alpha x &\\text{otherwise}\\end{cases}.$$\n",
"\n",
"As it can be seen, it is normal ReLU if $\\alpha=0$, and an identity function if $\\alpha=1$. For $\\alpha \\in (0, 1)$, leaky ReLU is a nonlinear function that give a non-zero output for a negative input. It aims to fix the \"dying ReLU\" problem that a neuron might always output a negative value and therefore cannot make any progress since the gradient of ReLU is 0.\n"
]
},
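{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the last point concrete, the derivative of the definition above can be written out. This is a standard calculation rather than something taken from this notebook, and the value at $x = 0$ is left out because it is a matter of convention:\n",
"\n",
"$$\\frac{d}{dx}\\textrm{leaky ReLU}(x) = \\begin{cases}1 & \\text{if}\\ x > 0\\\\ \\alpha &\\text{if}\\ x < 0\\end{cases}.$$\n",
"\n",
"For $\\alpha > 0$ the gradient on the negative side is $\\alpha$ rather than $0$, so a unit whose input is negative still receives a scaled update signal.\n"
]
},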
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 37,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 503
},
"id": "k7RmoxFSapxh",
"outputId": "eb5aee7d-80f9-40d0-8a2b-8e9ad7c2dfd9"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 252x180 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"alphas = [0, .2, .4, .6, .8, 1]\n",
"x = np.arange(-2, 1, 0.1)\n",
"Y = [nn.LeakyReLU(alpha)(x).asnumpy() for alpha in alphas]\n",
"d2l.plot(x.asnumpy(), Y, 'x', 'y', alphas)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 39,
"id": "T3K97yFfapxh"
},
"source": [
"The basic block of the discriminator is a convolution layer followed by a batch normalization layer and a leaky ReLU activation. The hyperparameters of the convolution layer are similar to the transpose convolution layer in the generator block.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 40,
"tab": [
"mxnet"
],
"id": "U2nMK1Osapxi"
},
"outputs": [],
"source": [
"class D_block(nn.Block):\n",
" def __init__(self, channels, kernel_size=4, strides=2,\n",
" padding=1, alpha=0.2, **kwargs):\n",
" super(D_block, self).__init__(**kwargs)\n",
" self.conv2d = nn.Conv2D(\n",
" channels, kernel_size, strides, padding, use_bias=False)\n",
" self.batch_norm = nn.BatchNorm()\n",
" self.activation = nn.LeakyReLU(alpha)\n",
"\n",
" def forward(self, X):\n",
" return self.activation(self.batch_norm(self.conv2d(X)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 43,
"id": "soOZfwBLapxi"
},
"source": [
"A basic block with default settings will halve the width and height of the inputs, as we demonstrated in :numref:`sec_padding`. For example, given a input shape $n_h = n_w = 16$, with a kernel shape $k_h = k_w = 4$, a stride shape $s_h = s_w = 2$, and a padding shape $p_h = p_w = 1$, the output shape will be:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"n_h^{'} \\times n_w^{'} &= \\lfloor(n_h-k_h+2p_h+s_h)/s_h\\rfloor \\times \\lfloor(n_w-k_w+2p_w+s_w)/s_w\\rfloor\\\\\n",
" &= \\lfloor(16-4+2\\times 1+2)/2\\rfloor \\times \\lfloor(16-4+2\\times 1+2)/2\\rfloor\\\\\n",
" &= 8 \\times 8 .\\\\\n",
"\\end{aligned}\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 44,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4YXw775Kapxj",
"outputId": "9106844e-da25-488a-fc98-eaf853806579"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2, 20, 8, 8)"
]
},
"metadata": {},
"execution_count": 72
}
],
"source": [
"x = np.zeros((2, 3, 16, 16))\n",
"d_blk = D_block(20)\n",
"d_blk.initialize()\n",
"d_blk(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 47,
"id": "Iw3KoJcuapxj"
},
"source": [
"The discriminator is a mirror of the generator.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 48,
"tab": [
"mxnet"
],
"id": "uK3sqkr1apxj"
},
"outputs": [],
"source": [
"n_D = 64\n",
"net_D = nn.Sequential()\n",
"net_D.add(D_block(n_D), # Output: (64, 32, 32)\n",
" D_block(n_D*2), # Output: (64 * 2, 16, 16)\n",
" D_block(n_D*4), # Output: (64 * 4, 8, 8)\n",
" D_block(n_D*8), # Output: (64 * 8, 4, 4)\n",
" nn.Conv2D(1, kernel_size=4, use_bias=False)) # Output: (1, 1, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 51,
"id": "E7Exncabapxj"
},
"source": [
"It uses a convolution layer with output channel $1$ as the last layer to obtain a single prediction value.\n"
]
},
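{
"cell_type": "markdown",
"metadata": {},
"source": [
"The $1 \\times 1$ spatial size of that final prediction can be checked with the same output-size formula used for the basic block. The sketch below assumes that the last `nn.Conv2D` falls back to a stride of $1$ and zero padding, since the call only specifies `kernel_size=4`; its input is the $4 \\times 4$ feature map produced by the fourth block:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"n_h^{'} \\times n_w^{'} &= \\lfloor(n_h-k_h+2p_h+s_h)/s_h\\rfloor \\times \\lfloor(n_w-k_w+2p_w+s_w)/s_w\\rfloor\\\\\n",
"  &= \\lfloor(4-4+2\\times 0+1)/1\\rfloor \\times \\lfloor(4-4+2\\times 0+1)/1\\rfloor\\\\\n",
"  &= 1 \\times 1 .\\\\\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Because the kernel covers the whole $4 \\times 4$ feature map, this layer acts on each example much like a fully connected layer with a single output.\n"
]
},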
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 52,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dquptkmpapxk",
"outputId": "2b5c96e4-c85e-4c15-b885-6789818dbb2f"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 1, 1, 1)"
]
},
"metadata": {},
"execution_count": 74
}
],
"source": [
"x = np.zeros((1, 3, 64, 64))\n",
"net_D.initialize()\n",
"net_D(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 55,
"id": "UwmC_MNbapxk"
},
"source": [
"## Training\n",
"\n",
"Compared to the basic GAN in :numref:`sec_basic_gan`, we use the same learning rate for both generator and discriminator since they are similar to each other. In addition, we change $\\beta_1$ in Adam (:numref:`sec_adam`) from $0.9$ to $0.5$. It decreases the smoothness of the momentum, the exponentially weighted moving average of past gradients, to take care of the rapid changing gradients because the generator and the discriminator fight with each other. Besides, the random generated noise `Z`, is a 4-D tensor and we are using GPU to accelerate the computation.\n"
]
},
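{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of $\\beta_1$ can be seen by writing out the momentum term. The notation here is an assumption made for illustration, since the Adam section is not part of this notebook: $g_t$ denotes the gradient at step $t$, $v_t$ the momentum with $v_0 = 0$, and bias correction is omitted:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"v_t &= \\beta_1 v_{t-1} + (1-\\beta_1) g_t\\\\\n",
"  &= (1-\\beta_1) \\sum_{i=0}^{t-1} \\beta_1^{i} g_{t-i} .\\\\\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"A gradient from $i$ steps ago is weighted by $\\beta_1^{i}$, so the average effectively spans about $1/(1-\\beta_1)$ steps: roughly $10$ steps for $\\beta_1 = 0.9$ but only about $2$ for $\\beta_1 = 0.5$. With the smaller value, the momentum forgets old gradients much faster.\n"
]
},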
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 56,
"tab": [
"mxnet"
],
"id": "TyxwsKoeapxk"
},
"outputs": [],
"source": [
"def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,\n",
" device=d2l.try_gpu()):\n",
" loss = gluon.loss.SigmoidBCELoss()\n",
" net_D.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)\n",
" net_G.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)\n",
" trainer_hp = {'learning_rate': lr, 'beta1': 0.5}\n",
" trainer_D = gluon.Trainer(net_D.collect_params(), 'adam', trainer_hp)\n",
" trainer_G = gluon.Trainer(net_G.collect_params(), 'adam', trainer_hp)\n",
" animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
" xlim=[1, num_epochs], nrows=2, figsize=(5, 5),\n",
" legend=['discriminator', 'generator'])\n",
" animator.fig.subplots_adjust(hspace=0.3)\n",
" for epoch in range(1, num_epochs + 1):\n",
" # Train one epoch\n",
" timer = d2l.Timer()\n",
" metric = d2l.Accumulator(3) # loss_D, loss_G, num_examples\n",
" for X, _ in data_iter:\n",
" batch_size = X.shape[0]\n",
" Z = np.random.normal(0, 1, size=(batch_size, latent_dim, 1, 1))\n",
" X, Z = X.as_in_ctx(device), Z.as_in_ctx(device),\n",
" metric.add(d2l.update_D(X, Z, net_D, net_G, loss, trainer_D),\n",
" d2l.update_G(Z, net_D, net_G, loss, trainer_G),\n",
" batch_size)\n",
" # Show generated examples\n",
" Z = np.random.normal(0, 1, size=(21, latent_dim, 1, 1), ctx=device)\n",
" # Normalize the synthetic data to N(0, 1)\n",
" fake_x = net_G(Z).transpose(0, 2, 3, 1) / 2 + 0.5\n",
" imgs = np.concatenate(\n",
" [np.concatenate([fake_x[i * 7 + j] for j in range(7)], axis=1)\n",
" for i in range(len(fake_x)//7)], axis=0)\n",
" animator.axes[1].cla()\n",
" animator.axes[1].imshow(imgs.asnumpy())\n",
" # Show the losses\n",
" loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]\n",
" animator.add(epoch, (loss_D, loss_G))\n",
" print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '\n",
" f'{metric[2] / timer.stop():.1f} examples/sec on {str(device)}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 59,
"id": "Enqd8rRIapxl"
},
"source": [
"We train the model with a small number of epochs just for demonstration.\n",
"For better performance,\n",
"the variable `num_epochs` can be set to a larger number.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"origin_pos": 60,
"tab": [
"mxnet"
],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 522
},
"id": "MVdi-06eapxl",
"outputId": "545a3f8c-d985-4579-d9d5-2a31fd876140"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"loss_D 0.202, loss_G 6.449, 329.8 examples/sec on gpu(0)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 360x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"latent_dim, lr, num_epochs = 100, 0.005, 20\n",
"train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 62,
"id": "8vWWc40japxp"
},
"source": [
"## Summary\n",
"\n",
"* DCGAN architecture has four convolutional layers for the Discriminator and four \"fractionally-strided\" convolutional layers for the Generator.\n",
"* The Discriminator is a 4-layer strided convolutions with batch normalization (except its input layer) and leaky ReLU activations.\n",
"* Leaky ReLU is a nonlinear function that give a non-zero output for a negative input. It aims to fix the “dying ReLU” problem and helps the gradients flow easier through the architecture.\n",
"\n",
"\n",
"## Exercises\n",
"\n",
"1. What will happen if we use standard ReLU activation rather than leaky ReLU?\n",
"1. Apply DCGAN on Fashion-MNIST and see which category works well and which does not.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"origin_pos": 63,
"tab": [
"mxnet"
],
"id": "xL5vXyVSapxq"
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/409)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"colab": {
"name": "dcgan.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}