{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPmQr3EPD0CZLWuDNpws92m",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"84b2967b5937443095e0cb89c5519704": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_78459a4f65314489831c725d83c0877a",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_2e709235ea2b42dab063e14f08e6383f",
"IPY_MODEL_67a1305205db49bcac1a3be7f177282d"
]
}
},
"78459a4f65314489831c725d83c0877a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2e709235ea2b42dab063e14f08e6383f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_8e9c9f422b964f118c65b0c28ccb4ab4",
"_dom_classes": [],
"description": " 0%",
"_model_name": "FloatProgressModel",
"bar_style": "danger",
"max": 136,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 0,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_33606ed261614b998d9f2c0e08b54963"
}
},
"67a1305205db49bcac1a3be7f177282d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_4b67b60e414f435b99169c408ab52ce5",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 0/136 [00:01<?, ?it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b414b40e11e84e4f98d5e1b72494fdc6"
}
},
"8e9c9f422b964f118c65b0c28ccb4ab4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"33606ed261614b998d9f2c0e08b54963": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"4b67b60e414f435b99169c408ab52ce5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"b414b40e11e84e4f98d5e1b72494fdc6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/pigghead/2d90da0ce83250eb0d3bfdc4b8af00a3/dcgan-simpleimplementation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VNZcgLAJ-Exe",
"outputId": "904fc33c-ea1d-401d-92d4-7531764080e8"
},
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive', force_remount=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "K1nHyBGMrRz7",
"outputId": "3b915685-e4e7-4493-f99a-ed2809c203e0"
},
"source": [
"import torch\n",
"from torch import nn\n",
"from tqdm.auto import tqdm\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST # training dataset\n",
"from torchvision.utils import make_grid\n",
"from torch.utils.data import DataLoader\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"torch.manual_seed(0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7ff17dcfdcb0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 44
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "95Fjqut-BfoL",
"outputId": "1c1198cd-b220-48ef-a855-e9320f8b8246"
},
"source": [
"### File retrieved from separate colab file \n",
"### (gist: https://gist.github.com/pigghead/ece7aea9b51b799fbb17b5f1e7c1521b#file-downscale-image-test-ipynb)\n",
"# bring flwr_data.npy in as an ndarray\n",
"flower_images = np.load('/content/drive/MyDrive/YEAR 5/IGME 797/flwr_data_64x64.npy')\n",
"\n",
"def confirm_import():\n",
" assert flower_images.shape == (4323, 3, 64, 64), flower_images.shape #entire ndarray\n",
" assert flower_images[0].shape == (3,64,64) # single instance\n",
" assert flower_images[0].dtype == 'float'\n",
" print(type(flower_images[0]))\n",
" print('import successful')\n",
"\n",
"confirm_import()\n",
"### NOTE: Shape is already Samples, Channels, Height, Width (N, C, H, W),\n",
"### as the asserts in confirm_import() above verify\n",
"\n",
"\n",
"\n",
"print(flower_images.astype(np.double).dtype)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'numpy.ndarray'>\n",
"import successful\n",
"float64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pzy0utf4rJCr",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e7c387d5-eb5b-4675-b049-ee3832391609"
},
"source": [
"\"\"\"\n",
"Discriminator and Generator implementation from DCGAN paper\n",
"\"\"\"\n",
"\n",
"class Discriminator(nn.Module):\n",
" def __init__(self, channels_img, features_d):\n",
" super(Discriminator, self).__init__()\n",
" self.disc = nn.Sequential(\n",
" # input: N x channels_img x 64 x 64\n",
" nn.Conv2d(\n",
" channels_img, features_d, kernel_size=4, stride=2, padding=1\n",
" ),\n",
" nn.LeakyReLU(0.2),\n",
" # _block(in_channels, out_channels, kernel_size, stride, padding)\n",
" self._block(features_d, features_d * 2, 4, 2, 1),\n",
" self._block(features_d * 2, features_d * 4, 4, 2, 1),\n",
" self._block(features_d * 4, features_d * 8, 4, 2, 1),\n",
" # After all _block img output is 4x4 (Conv2d below makes into 1x1)\n",
" nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
" return nn.Sequential(\n",
" nn.Conv2d(\n",
" in_channels,\n",
" out_channels,\n",
" kernel_size,\n",
" stride,\n",
" padding,\n",
" bias=False,\n",
" ),\n",
" #nn.BatchNorm2d(out_channels),\n",
" nn.LeakyReLU(0.2),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.disc(x)\n",
"\n",
"\n",
"class Generator(nn.Module):\n",
" def __init__(self, channels_noise, channels_img, features_g):\n",
" super(Generator, self).__init__()\n",
" self.net = nn.Sequential(\n",
" self._block(channels_noise, features_g*16, 4, 1, 0), # 4x4\n",
" self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8\n",
" self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16\n",
" self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32\n",
" nn.ConvTranspose2d(features_g*2, channels_img, 4, 2, 1), #64x64\n",
" nn.Tanh()\n",
" )\n",
"\n",
"\n",
" def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
" return nn.Sequential(\n",
" nn.ConvTranspose2d(\n",
" in_channels,\n",
" out_channels,\n",
" kernel_size,\n",
" stride,\n",
" padding,\n",
" bias=False,\n",
" ),\n",
" #nn.BatchNorm2d(out_channels),\n",
" nn.ReLU(),\n",
" )\n",
"\n",
"\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"\n",
"def initialize_weights(model):\n",
" # Initializes weights according to the DCGAN paper\n",
" for m in model.modules():\n",
" if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):\n",
" nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
"\n",
"\n",
"def test():\n",
" N, in_channels, H, W = 8, 3, 64, 64\n",
" noise_dim = 100\n",
" gen = Generator(noise_dim, in_channels, 8)\n",
" z = torch.randn((N, noise_dim, 1, 1))\n",
" print('gen(z) shape: ', gen(z).shape)\n",
"\n",
" disc = Discriminator(in_channels, 8)\n",
" x = torch.randn((N, in_channels, H, W))\n",
" print('disc(x) shape: ', disc(x).shape)\n",
"\n",
" print('disc(gen(z)) shape: ', disc(gen(z)).shape)\n",
" ####### Expected: N=8, C=3, H=64, W=64 (in_channels=3)\n",
" assert gen(z).shape == (N, in_channels, H, W), gen(z).shape\n",
" #assert gen(z2).shape == (N, in_channels, H, W), gen(z2).shape\n",
"\n",
"\n",
"test()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"gen(z) shape: torch.Size([8, 3, 64, 64])\n",
"disc(x) shape: torch.Size([8, 1, 1, 1])\n",
"disc(gen(z)) shape: torch.Size([8, 1, 1, 1])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A6wAPd4cZ2iH",
"outputId": "428b6a17-f4f5-4f51-f82b-f3ef5fc6f099"
},
"source": [
"class discriminator_32x32(nn.Module):\n",
" \n",
" def __init__(self, channels_img, features_d):\n",
" super(discriminator_32x32, self).__init__()\n",
" self.disc = nn.Sequential(\n",
" # Input: N x channels_img x 32 x 32\n",
" self._block(channels_img, features_d*2, 4, 2, 1), # 16x16\n",
" self._block(features_d*2, features_d*4, 4, 2, 1), #8x8\n",
" self._block(features_d*4, features_d*8, 4, 2, 1), #4x4\n",
" #self._block(features_d*8, features_d*16, 4, 2, 1),\n",
" nn.Conv2d(features_d*8, 1, 4, 1, 0), #2x2\n",
" nn.Sigmoid() # 1 or 0\n",
" )\n",
"\n",
" def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
" return nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),\n",
" nn.BatchNorm2d(out_channels),\n",
" nn.LeakyReLU(0.2, inplace=True)\n",
" )\n",
"\n",
"\n",
" def forward(self, img):\n",
" return self.disc(img)\n",
"\n",
"''' EXAMPLE TEST CLASS\n",
"\n",
" N, in_channels, H, W = 8, 3, 64, 64\n",
" noise_dim = 100\n",
" x = torch.randn((N, in_channels, H, W))\n",
"\n",
" disc = Discriminator(in_channels, 8)\n",
" assert disc(x).shape == (N, 1, 1, 1), \"Discriminator test failed\"\n",
"'''\n",
"\n",
"def test_disc():\n",
" N, in_channels, H, W = 8, 1, 32, 32\n",
" noise_dim = 100\n",
" x = torch.randn((N, in_channels, H, W)) # !!ERROR HERE (resolved)\n",
" print(x.shape)\n",
" disc = discriminator_32x32(in_channels, 32) \n",
"\n",
" print( disc(x)[0] )\n",
" assert disc(x).shape == (N, 1, 1, 1), \"Discriminator test failed\"\n",
"\n",
"\n",
"test_disc()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([8, 1, 32, 32])\n",
"tensor([[[0.2947]]], grad_fn=<SelectBackward>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
},
"id": "pALZEKeiQVUW",
"outputId": "804ba980-47a4-430a-8799-353648ba4ce9"
},
"source": [
"class generator_32x32(nn.Module):\n",
" \n",
" def __init__(self, noise_channels, im_channels, features_g):\n",
" super(generator_32x32, self).__init__()\n",
" self.gen = nn.Sequential( \n",
" self._block(noise_channels, features_g*8, 4, 1, 0), #4x4\n",
" self._block(features_g*8, features_g*4, 4, 2, 1), #8x8\n",
" self._block(features_g*4, features_g*2, 4, 2, 1), #16x16\n",
" nn.ConvTranspose2d(\n",
" features_g*2, im_channels, 4, 2, 1\n",
" ),\n",
" # Output= s(n-1)+f-2p, 32x32 img\n",
" nn.Tanh()\n",
" )\n",
"\n",
"\n",
" def _block(self, in_channel, out_channel, kernel_size, stride, padding):\n",
" ''' \n",
" Performs transposed convolution, batch normalization, and Leaky ReLU activation \n",
" '''\n",
" return nn.Sequential (\n",
" nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False),\n",
" # Batch norm?\n",
" nn.BatchNorm2d(out_channel),\n",
" nn.LeakyReLU(0.2)\n",
" )\n",
"\n",
"\n",
" def forward(self, x):\n",
" return self.gen(x)\n",
"\n",
"\n",
"''' EXAMPLE TEST CLASS\n",
"\n",
" N, in_channels, H, W = 8, 3, 64, 64\n",
" noise_dim = 100\n",
" x = torch.randn((N, in_channels, H, W))\n",
"\n",
" disc = Discriminator(in_channels, 8)\n",
" assert disc(x).shape == (N, 1, 1, 1), \"Discriminator test failed\"\n",
"\n",
" gen = Generator(noise_dim, in_channels, 8)\n",
" z = torch.randn((N, noise_dim, 1, 1))\n",
" assert gen(z).shape == (N, in_channels, H, W), \"Generator test failed\"\n",
"'''\n",
"\n",
"\n",
"# test our 32x32 generator\n",
"def test_gen_28x28():\n",
" N, in_channels, H, W = 8, 1, 32, 32\n",
" noise_dim = 100\n",
" z = torch.randn((N, noise_dim, 1, 1))\n",
" gen1 = generator_32x32(noise_dim, in_channels, 32)\n",
" assert gen1(z).shape == (N, in_channels, H, W)\n",
" print('gen1(z).shape',gen1(z).shape)\n",
"\n",
"\n",
"#test_gen_28x28()\n",
"\n",
"\n",
"def print_gen_28x28():\n",
" N, in_channels, H, W = 1, 1, 32, 32\n",
" noise_dim = 128\n",
" z = torch.randn((N, noise_dim, 1, 1),device='cpu')\n",
" gen = generator_32x32(noise_dim, in_channels, 8).to('cpu')\n",
" print('gen(z) shape: ', gen(z).shape)\n",
" plt.Figure(figsize=(28, 28))\n",
" plt.imshow(gen(z)[0, 0, :, :].detach().cpu(), cmap='gray')\n",
" plt.show()\n",
"\n",
" \n",
"#test_gen_28x28()\n",
"print_gen_28x28()\n",
"\n",
"\n",
"# test both\n",
"def test_both():\n",
" N, in_channels, H, W = 8, 1, 32, 32\n",
" noise_dim = 100\n",
" # A test generator\n",
" z = torch.randn((N, noise_dim, 1, 1))\n",
" gen = generator_32x32(noise_dim, in_channels, 32)\n",
" # A test discriminator\n",
" disc = discriminator_32x32(in_channels, 32)\n",
" print(disc(gen(z)).shape)\n",
"\n",
"\n",
"#test_both()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"gen(z) shape: torch.Size([1, 1, 32, 32])\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAevElEQVR4nO2deZRV1bHGv2ISFGSUUWSQdgEiAWwICqIRQY0DqBFFRaIoGiHGgSRITECcEARnJSgiEglDHCAGBURdRI1gowgqRBBBRWhwohFlaur9cS/vodnf7qaH2+Sd77cWi9v1dd2z+/Spe27vulVl7g4hxP9/ypX1AoQQmUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQqhQHGczOw3AfQDKA3jM3UfFvr9atWpeu3btoJafnx87TtBepUoV6rN7926qxY4V0w4++GCqMcqV46+nW7ZsodohhxxCtW3btlGNnatDDz2U+uzcuZNqlSpVolpeXt5++8VSvbFzX7lyZaqxnxkAvvrqq6C9evXq1GfPnj1U27VrF9WK+jtj1/H27dupz0EHHRS0b9q0CXl5ecETUuRgN7PyAB4C0APAZwDeMrPZ7v4B86lduzaGDx8e1L7++mt6LHbhtG7dmvqwXzIAfPPNN1SLXcDt2rUL2mMXW+wCeOGFF6jWoUMHquXk5FCtYsWKQXvPnj2pz6effkq1Ro0aUW3evHlUa9KkSdAeC+jYi1/Lli2pVqECv4z/8pe/BO1nn3029fnuu++otmHDBqode+yxVIv9zo455pigfeXKldSnRYsWQfsNN9xAfYrzNr4TgNXuvsbddwKYBqBXMZ5PCFGKFCfYGwHY95bwWdomhDgAKfUNOjMbaGY5Zpbz7bfflvbhhBCE4gT7egCN9/n68LTtB7j7BHfPdvfsqlWrFuNwQojiUJxgfwtAlpk1M7NKAC4EMLtkliWEKGmKvBvv7rvNbDCAuUil3h539/cL8KHphNgOKHv7P2nSJOrTpUsXqn3//fdUY7vIALBo0aKgPbZjzXZaAaBBgwZUmzBhAtVuuukmqrEUWywrcO+991Ltl7/8JdXeeOMNqg0ePDhov+SSS6jPxIkTqZabm0u1WFqRabFd9VNOOYVq7BoAgH//+99Ua9u2LdVY6i127bDMRSy1Waw8u7vPATCnOM8hhMgM+gSdEAlBwS5EQlCwC5EQFOxCJAQFuxAJwTLZcLJu3brep0+foBYratmxY0fQ3qxZM+pz1113UW3QoEFU27hxI9V69Qp/9P+WW26hPuvX/8fnjP6XWAqtffv2VPv73/9OtVdffTVov/HGG6lPzZo1qTZ9+vQiaQMGDAjaY9WIWVlZVDvssMOoFkvLrV27NmiPfcArVpR11FFHUa1u3bpUi6UwZ86cGbSz3yUA1K9fP2gfMmQIVq9eHazM0p1diISgYBciISjYhUgICnYhEoKCXYiEkNHd+FatWvkTTzwR1ObPn0/9OnfuHLQ3btw4aAeAxYsXU+2pp56i2hFHHEG1W2+9NWh/+umnqc+oUbwtH+sjBsTXH+tN9v774VqkWM+12C54rFVUjRo1qDZr1qyg/aKLLqI+t99+O9XWrFlDtaVLl1Ltt7/9bdC+efNm6tO3b1+qxTIop556KtVi53jVqlVB+4MPPkh92HV1xRVXYOXKldqNFyLJKNiFSAgKdiESgoJdiISgYBciISjYhUgIxWpLtb9s2bIFc+aEu1jFilpYGufOO++kPrGJMFdffTXVLr/8cqqxnmCxSSCxIpns7Gyqxaa0fPzxx1Rjk05OOOEE6vP4449TLTZtZevWrVSLpUUZ7NoA4hNtWrVqtd/a7Nm8N+rYsWOpds0111Cte/fuVJs7dy7Vvvzyy6A9VkTF0q+x0VW6swuREBTsQiQEBbsQCUHBLkRCULALkRAU7EIkhGJVvZnZWgBbAeQD2O3uPJcEoH79+t6/f/+gFq
uG+vzzz4P2WO+0jh07Um3btm1Ue/vtt6nGxvTs2rWL+nTt2pVqrCILAMaNG0e1Dz74gGotWrQI2vPz86nP0UcfTbWRI0dSjfXkA4BzzjknaF++fDn16datG9WmTZtGtZdeeolqv/nNb4L22Bin3r17Uy2WIo6dj1ha8YILLgja2XUP8Eq/yZMnY+PGjcGqt5LIs//M3b8ogecRQpQiehsvREIobrA7gHlmtsTMBpbEgoQQpUNx38Z3dff1ZlYXwHwzW+nuC/f9hvSLwEAgPlpXCFG6FOvO7u7r0/9vAvAsgE6B75ng7tnuns02uIQQpU+Rg93MDjGzansfA+gJ4L2SWpgQomQpcurNzJojdTcHUn8OTHV33jEQQIsWLXzMmDFBLZa+eu6554L2Rx55hPq8/PLLVDvjjDOoFqsoa9iwYdA+Y8YM6hNLKcaaOca0Tp3+4w3U/8KaFMZ8YmOLYqOmunTpQrV//etfQXts/FOsOeeVV15JtZNOOolqLK345z//mfrEKvZiv+t3332XauPHj6daXl5e0L5ixQrqU6lSpaB97Nix+PTTT0s29ebuawD8pKj+QojMotSbEAlBwS5EQlCwC5EQFOxCJAQFuxAJIaMNJ/Pz82lK6d5776V+bNbbkiVLqE/btm2pdu6551KtZcuWVGOpplha6/XXX6dau3btqBabRbZp0yaqsUqp559/nvoMHjyYarFKrtjcs5UrVwbtN998M/Xp0KED1S699FKq3XHHHVS7++67g/ZYNWJsdt8NN9xAtW+++YZqubm5VLvvvvuCdjYXMfZ8hxxyCPXRnV2IhKBgFyIhKNiFSAgKdiESgoJdiISQ0d34nTt30t3i6dOnU78BAwYE7bFCjOHDh1OtefPmVIsVSLD+dJdddhn1iRVOLF68mGovvvgi1c4880yqsVFOw4YNoz5sZBQArFq1imonnngi1aZMmRK0xzIXsTXGyqNzcnKoxrI/P/kJL+soV47fA2PjwZ588kmqPfbYY1Rj5/+1116jPllZWVRj6M4uREJQsAuREBTsQiQEBbsQCUHBLkRCULALkRAymnqrUKECHdlUu3Zt6sf6073yyivUp0ePHlSLjV2qVasW1d55552gnaW7AODOO++k2hdf8EE6L7zwAtXmzZtHNdZDL9ZbL3Y+tm/fTrWxY8dSjf3csXTS9ddfT7X69etTbfXq1VRj/QZvu+026hM7V5MmTaLaH/7wB6qNGDGCahMmTAjaGzVqRH3YOK9YT0nd2YVICAp2IRKCgl2IhKBgFyIhKNiFSAgKdiESQoGpNzN7HMCZADa5e5u0rRaA6QCaAlgLoI+7f12I56Jja3bu3En9Zs6cGbSzsTkA0KRJE6pVrlyZarEeY4cffnjQft5551Gf0aNHU23RokVUi1Xtxc4VS9e89dZb1CfW/y+WVoyNa+rdu3fQnp2dTX0WLlxItVatWlEtlg5j62/atCn1iVWoxVLEsaq3yZMnU6179+5Be6xvIFv/1q1bqU9h7uxPADjtR7ahABa4exaABemvhRAHMAUGe3re+lc/MvcCsPelajKA8Mu4EOKAoah/s9dz9w3pxxsB1Cuh9QghSolib9B56vN59DN6ZjbQzHLMLCf294QQonQparDnmlkDAEj/T6cWuPsEd8929+xq1aoV8XBCiOJS1GCfDaB/+nF/ALNKZjlCiNKiMKm3vwI4CUAdM/sMwHAAowDMMLMBANYB6FOYg5UrV46Op+nbty/1u+aaa4L2xo0bU59YVVNstFK3bt2o1q9fv6A9Vp0UezfTunVrqsWaOX700UdUu+CCC4L2Z555hvrERjzFUpGxysLx48cH7bHRVbH02vLly6kWG7v061//OmiPVbYdc8wxRdJixM7xs88+G7Tv2bOH+nTs2DFonzp1KvUpMNjdnUVhODkohDgg0SfohEgICnYhEoKCXYiEoGAXIiEo2IVICBlvOFmjRo2gFqtgY3PgHn30UerDKokAoF27dl
Tr0KED1VjqrV49/mnhV199lWrr16+nWqx5YSw9yCqvYikjlsYBgDfffJNq33//PdXOOuusoD1W6RerNvvVr35FtUceeYRqrEIwltqMVRxeccUVVPv9739PtdmzZ1OtWbNmQfuXX35Jfdq0aRO0x2bi6c4uREJQsAuREBTsQiQEBbsQCUHBLkRCULALkRAymnqLNZyMNTbMysoK2h988EHqE0uDXHzxxVTbvXs31e6+++6g/eGHH6Y+f/zjH6kWS5PEKuliKZkPP/wwaP/HP/5BfWLNKH/xi19QLVZtxubYbdy4kfqce+65VIv9zmLzzS655JKgPVZ9F0sDx9K2LK0MAPfccw/V2Ky32LUzY8aMoP2rr37cQe7/0J1diISgYBciISjYhUgICnYhEoKCXYiEkNHdeAAoX7580L506VLqc9dddwXt48aNoz6xHdUpU6ZQbe3atVR75513gvbLL7+c+nTp0oVqK1asoFpst7hOnTpUY73a1q1bR33YWCsAmDZtGtWuvPJKqrHj3XrrrdQnthsfI5YVYP3dYrv7sTFOsRFVrP8fEP/Z2PUd66PI+i/Gskm6swuREBTsQiQEBbsQCUHBLkRCULALkRAU7EIkhMKMf3ocwJkANrl7m7RtBIArAeydozTM3ecU9Fz5+fn0g/r//Oc/qR9LQcydO5f6sAIIgBcRADxVAwB169YN2jt37kx9KlTgp3jDhg1UixV3xIpJWAHNsmXLqM+hhx5KtWuvvZZq2dnZVGO/s4oVK1Kfiy66iGqx3nWsuArgo76GDh1KfQYMGEC1QYMGUe3zzz+n2qRJk6g2Z044dG655Rbqs2XLlqA9VvBUmDv7EwBOC9jvcfd26X8FBroQomwpMNjdfSEAXjcnhPivoDh/sw82s2Vm9riZ1SyxFQkhSoWiBvsjAI4E0A7ABgBj2Tea2UAzyzGznNhHWIUQpUuRgt3dc9093933AHgUQKfI905w92x3z45tBAkhSpciBbuZNdjny3MAvFcyyxFClBaFSb39FcBJAOqY2WcAhgM4yczaAXAAawFcVZiDbd++nfZIi/WMGzlyZNA+cOBA6tO6dWuqvfLKK1SbOXMm1fr37x+0xyr2Yu9mrr/+eqr17duXarGxUSz1Ekvj3H///VSLVRa2b9+ealdffXXQXrt2berTs2dPqsXSWrF03kEHHRS09+nTh/rE0sAtW7ak2vz586k2ePBgqrF0aaxS8Zxzzgnay5Xj9+8Cg93dQ1fdxIL8hBAHFvoEnRAJQcEuREJQsAuREBTsQiQEBbsQCSGjDSd37txJGzrGmi+yhnyXXXYZ9YmlQWLVZqNGjaJamzZtgvYhQ4ZQn9GjR1NtzJgx+30sIF7ZdOyxxwbtsZFRRakaA4BZs2ZRjTXF7NixI/WJpZpiTTZZBRjA03lnnXUW9alVqxbV+vXrR7VYSvf888+nGquM/Pjjj6kPayz6/fffUx/d2YVICAp2IRKCgl2IhKBgFyIhKNiFSAgKdiESQkZTb5UqVUKzZs2CWuXKlanfiBEjgvZTTz2V+rCqICA+byxWyfXdd98F7du3b6c+rJkgADRt2pRqTz31FNXeeOMNqrFGlSxVA/AZdgBog1AgPmONnZN3332X+sQq0YYPH0612LVz3HHHBe0TJ/JarsWLF1MtNrOtSZMmVItVdbLmqKxiDwC+/vrroF2z3oQQCnYhkoKCXYiEoGAXIiEo2IVICBndjS9fvjxq1KgR1A4++GDqt2PHjqA9NgYp1kfsmGOOoRorugH4uKNvv/2W+jz22GNUmzp1KtVuuukmqtWsydv0f/TRR0F7lSpVqM/RRx9NtdiIrdjPdtRRRwXtZ599NvW5/fbbqfb6669TjV0fAL8OYgU5n332GdViWY2GDRtSzcyo9vLLLwftsQKln/3sZ0F7LDOhO7sQCUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQijM+KfGAJ4EUA+pcU8T3P0+M6sFYDqApkiNgO
rj7uFP56epVKkSGjduHNRiKQNWcHHppZdSn3nz5lFt0KBBVHv44Yf3W/vkk0+oz+zZs6k2adIkqsUKLqZPn061k08+OWhfsmQJ9WE90ADgwQcfpFqsF96ECROC9lixSGyc14ABA6gW+52xNcaKfy6++GKqxYqoevXqRbXJkydTrUePHkF7LM23atWqoD1WlFWYO/tuADe6e2sAnQEMMrPWAIYCWODuWQAWpL8WQhygFBjs7r7B3d9OP94KYAWARgB6Adj7cjUZQO/SWqQQovjs19/sZtYUQHsAiwDUc/e9PZk3IvU2XwhxgFLoYDezqgCeBnCdu+ftq3mqY0Kwa4KZDTSzHDPLycvLC32LECIDFCrYzawiUoH+lLs/kzbnmlmDtN4AwKaQr7tPcPdsd8+OzSoXQpQuBQa7pT7BPxHACncft480G0D/9OP+APh4ECFEmVOYqrcuAPoBWG5mS9O2YQBGAZhhZgMArAPAG4jtPViFCnSMD0vVAMARRxwRtLPeXQBw+eWXU+26666jWqw66e233w7ajz/+eOpTrhx/PY2l7Dp06EC1wYMHU23ZsmVBe25uLvWJVRzG+t2VL19+v9cRG0/06KOPUi1WUda7N98bfuGFF4L2unXrUp9Y/8J69fjWFDsWACxYsIBqrLotNk6KjcqKXW8FBru7vwaARUD3gvyFEAcG+gSdEAlBwS5EQlCwC5EQFOxCJAQFuxAJwdi4oNKgYcOGzqqXYmko1ogwVq119dVXU61///5Ua926NdVYE8VYqiZW7VS9enWqffHFF1S7+eabqcYqr2Ln97LLLqNarBqRjcMCgM2bNwft9evXpz6rV6+mGhsbBgDXX3891Via9Y477qA+rHIQiKcbY009V6xYQbXzzjsvaI81MmXVgz179sTSpUuD2TPd2YVICAp2IRKCgl2IhKBgFyIhKNiFSAgKdiESQkZTb3Xq1HE26ytWpcYaLI4ZM4b6fP7551Rj1WsAsHjxYqo98MADQTtriAkAffv2pRqb1wXEK9tmzJhBtaysrKB9yJAh1OfJJ5+kWiy9FltHTk5O0N6qVSvqc+edd1KNzbAD4ilYliqL/Vw9e/akWqwRaOz32bVrV6q9//77VGOwNLBSb0IIBbsQSUHBLkRCULALkRAU7EIkhML0oCsxqlSpgjZt2gS12Die9evXB+0VK1akPi+99BLVYgUG06ZNo9rEiROD9tiOe35+PtVq1qxJtR07dlAtNkKJ9SDbtm0b9Tn88MOpduaZZ1Ktbdu2VGPn/95776U+ffrwNoax30tsN57tWlerVo36xApyYr0SY33+HnroIaqdccYZQXvsXLFrLpZl0J1diISgYBciISjYhUgICnYhEoKCXYiEoGAXIiEUWAhjZo0BPInUSGYHMMHd7zOzEQCuBLC32dgwd58Te64jjzzSR40aFdQOO+ww6sdSELF+ccOHD6fa888/T7XTTz+dam+++eZ+P1/z5s2p9tVXX1EtNiaJ9ZkD+Mig+fPnUx82kguIj0m6++67qdajR4+gPdY/7+ijj6YaG5EEAFu3bqXarl27gvZ+/fpRny5dulDt2muvpVrsfMQKeU466aSgvXPnztRn9OjRQfsbb7yBLVu2BAthCpNn3w3gRnd/28yqAVhiZnuvnHvcnf+EQogDhsLMetsAYEP68VYzWwGgUWkvTAhRsuzX3+xm1hRAewCL0qbBZrbMzB43M/5xMCFEmVPoYDezqgCeBnCdu+cBeATAkQDaIXXnH0v8BppZjpnl5OXllcCShRBFoVDBbmYVkQr0p9z9GQBw91x3z3f3PQAeBdAp5OvuE9w9292zDz300JJatxBiPykw2M3MAEwEsMLdx+1jb7DPt50D4L2SX54QoqQoTOqtK4B/AlgOYE/aPAxAX6TewjuAtQCuSm/mUY466ii///77g9qCBQuoHxuvdOONN1KfWOXS7373O6otX76cameddVbQPm7cuKAdAK666iqqnXrqqV
SL9XeLpcPYCKLYiKTYSKZYz7VYtRw7V7FqvlgvPDY2DODViADvbThlyhTq86c//Ylqr7/+OtXee4/f70455RSqsR50rIIR4H0UR48ejU8++aRoqTd3fw1AyDmaUxdCHFjoE3RCJAQFuxAJQcEuREJQsAuREBTsQiSEjDac3LlzJ20eGUt3NG3aNGhfunQp9Yk13mvUiH+0P5Y+YdVtTzzxBPWJVb0NGzaMav3796fa3/72N6qdcMIJQftxxx1HfWLpxg8++IBqsYaT8+bNC9qnTp1KfWJVXrNnz6ZaLB1Wu3btoP3444+nPnPnzqXarFmzqFavXj2qxSoczz///KA9liI+8cQTg/aqVatSH93ZhUgICnYhEoKCXYiEoGAXIiEo2IVICAp2IRJCRlNv3333HZ3p1qpVK+o3Z0645iZW/fXMM89Qbc2aNVS78MILqcaq7Bo0aBC0A8DAgQOp1rVrV6qxGWVAPNXEGlV++OGH1Cc2qy5WWRhLDd1+++1Be/v27alPLPV23nnnUW3x4sVU69atW9Ae+7li6atYw0xWcQgAP/3pT6l22223Be2swSnAq+hSFelhdGcXIiEo2IVICAp2IRKCgl2IhKBgFyIhKNiFSAgZTb3VrVsX11xzTVCLNW1cu3Zt0B5LGcXSFrFquTFjxlCNpX8+/vhj6lPUijj2MwPA+PHjqcaq5R5++GHqw+bvAcCLL75ItcmTJ1ONVXnFGkfGmmJ++eWXVHvggQeotmXLlv32iV071atXp1rsOli0aBHV7rvvvqC9U6dgd/bo823bto366M4uREJQsAuREBTsQiQEBbsQCUHBLkRCKMz4p8oAFgI4CKnd+7+5+3AzawZgGoDaAJYA6OfuO2PPlZWV5WznkfWmA4CNGzcG7WynFQDy8/OplpubSzXW7w4AGjduHLTPnz+f+sSKRWK7vg899BDVhg4dSjVWlDNy5EjqExvJFBuTxIpuYuuoUIEngGI/c6zYaPPmzVRjPehi133Hjh2pFsuEXHzxxVSbNGkS1diIsNjOOuuTN23aNOTm5garYQpzZ98B4GR3/wlSs91OM7POAO4CcI+7twDwNQCeUxFClDkFBrun+Db9ZcX0PwdwMoC9bU4nA+hdKisUQpQIhZ3PXt7MlgLYBGA+gI8AfOPuu9Pf8hkA3p9ZCFHmFCrY3T3f3dsBOBxAJwAtC3sAMxtoZjlmlhP7G1sIUbrs1268u38D4BUAxwGoYWZ7d1sOBxDcYXP3Ce6e7e7ZsY8aCiFKlwKD3cwOM7Ma6cdVAPQAsAKpoP9F+tv6A+CjMoQQZU5hUm9tkdqAK4/Ui8MMdx9pZs2RSr3VAvAOgEvcnedwABx55JF+xx13BLVdu3ZRv3bt2gXtS5YsoT4NGzakWk5ODtVio6Fq1KgRtFeqVIn6bN++nWqxFOC6deuo1r17d6otXLgwaO/Rowf1ycvLo1psHNbBBx9MNdZPLvZzxdJyzz33HNVOP/10qrGed7HrPjbyKpaajRUvxVKHrCiHjXgCgN27dwfto0aNwrp164KptwKr3tx9GYD/OGPuvgapv9+FEP8F6BN0QiQEBbsQCUHBLkRCULALkRAU7EIkhAJTbyV6MLPNAPbmXuoA+CJjB+doHT9E6/gh/23raOLuh4WEjAb7Dw5sluPu2WVycK1D60jgOvQ2XoiEoGAXIiGUZbBPKMNj74vW8UO0jh/y/2YdZfY3uxAis+htvBAJoUyC3cxOM7N/m9lqM+PdE0t/HWvNbLmZLTUzXgpX8sd93Mw2mdl7+9hqmdl8M1uV/r9mGa1jhJmtT5+TpWb28wyso7GZvWJmH5jZ+2b2m7Q9o+ckso6MnhMzq2xmi83s3fQ6bknbm5nZonTcTDczXm4Zwt0z+g+pUtmPADQHUAnAuwBaZ3od6bWsBVCnDI7bDUAHAO/tYxsNYGj68VAAd5XROkYAGJLh89EAQIf042oAPgTQOtPnJLKOjJ4TAAagavpxRQCLAH
QGMAPAhWn7eAC/2p/nLYs7eycAq919jadaT08D0KsM1lFmuPtCAD+efNgLqb4BQIYaeJJ1ZBx33+Dub6cfb0WqOUojZPicRNaRUTxFiTd5LYtgbwTg032+LstmlQ5gnpktMbOBZbSGvdRz9w3pxxsB1CvDtQw2s2Xpt/ml/ufEvphZU6T6JyxCGZ6TH60DyPA5KY0mr0nfoOvq7h0AnA5gkJl1K+sFAalXdqReiMqCRwAcidSMgA0AxmbqwGZWFcDTAK5z9x+0z8nkOQmsI+PnxIvR5JVRFsG+HsC+o1Vos8rSxt3Xp//fBOBZlG3nnVwzawAA6f83lcUi3D03faHtAfAoMnROzKwiUgH2lLs/kzZn/JyE1lFW5yR97P1u8sooi2B/C0BWemexEoALAczO9CLM7BAzq7b3MYCeAHjDtdJnNlKNO4EybOC5N7jSnIMMnBMzMwATAaxw93H7SBk9J2wdmT4npdbkNVM7jD/abfw5UjudHwH4QxmtoTlSmYB3AbyfyXUA+CtSbwd3IfW31wCkZuYtALAKwEsAapXROqYAWA5gGVLB1iAD6+iK1Fv0ZQCWpv/9PNPnJLKOjJ4TAG2RauK6DKkXlj/tc80uBrAawEwAB+3P8+oTdEIkhKRv0AmRGBTsQiQEBbsQCUHBLkRCULALkRAU7EIkBAW7EAlBwS5EQvgfGZfi7XEnxDoAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8-KR-tcPyqhc"
},
"source": [
"def show_tensor_image(img_tensor, plotTitle, num_images=25, size=(1,32,32)):\n",
" ''' Make a grid of images ''' \n",
" image_tensor = (img_tensor+1)/2\n",
" image_unflat = image_tensor.detach().cpu()\n",
"\n",
" image_grid = make_grid(image_unflat[:num_images], nrow=5)\n",
"\n",
" plt.title(plotTitle)\n",
" plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n",
" plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 441,
"referenced_widgets": [
"84b2967b5937443095e0cb89c5519704",
"78459a4f65314489831c725d83c0877a",
"2e709235ea2b42dab063e14f08e6383f",
"67a1305205db49bcac1a3be7f177282d",
"8e9c9f422b964f118c65b0c28ccb4ab4",
"33606ed261614b998d9f2c0e08b54963",
"4b67b60e414f435b99169c408ab52ce5",
"b414b40e11e84e4f98d5e1b72494fdc6"
]
},
"id": "aMgfXY1aqaR_",
"outputId": "0162ac23-0683-41ce-8496-30d21949b5d8"
},
"source": [
"### TRAINING ###\n",
"criterion = nn.BCELoss()\n",
"n_epochs = 500\n",
"display_step = 500 # for visualization\n",
"cur_step = 0\n",
"noise_dim = 100\n",
"IMG_SIZE = 32\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 0.0002 # alpha, the learning rate used in the DCGAN paper\n",
"FEATURES = 32 # this must correspond to the architecture of both d and g (32x32)\n",
"NUM_CHANNELS = 3\n",
"device = 'cpu'\n",
"\n",
"# Adam betas; beta_1 = 0.5 follows the DCGAN paper (NOTE: the paper uses beta_2 = 0.999, not 0.9)\n",
"beta_1 = 0.5\n",
"beta_2 = 0.9\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize(IMG_SIZE),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(\n",
" [0.5 for _ in range(1)] , \n",
" [0.5 for _ in range(1)]\n",
" )\n",
"])\n",
"\n",
"d_loader = DataLoader(\n",
" flower_images,\n",
" batch_size = BATCH_SIZE,\n",
" shuffle = True\n",
")\n",
"\n",
"dataloader = DataLoader(\n",
" MNIST('.', train=True, download=False, transform=transform),\n",
" batch_size = BATCH_SIZE,\n",
" shuffle = True\n",
")\n",
"\n",
"# Initialize generator and discriminator (32x32 variant)\n",
"gen = Generator(noise_dim, NUM_CHANNELS, 64).to(device)\n",
"disc = Discriminator(NUM_CHANNELS, 64).to(device)\n",
"\n",
"# adam optimizers, dcgan paper\n",
"gen_optim = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(beta_1, beta_2))\n",
"disc_optim = torch.optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(beta_1, beta_2))\n",
"\n",
"# initialize weights\n",
"#gen = gen.apply(initialize_weights)\n",
"#disc = disc.apply(initialize_weights)\n",
"initialize_weights(gen)\n",
"initialize_weights(disc)\n",
"\n",
"for epochs in range(n_epochs):\n",
" # we aren't using labels as of right now (StyleGAN will use these)\n",
" for real in tqdm(d_loader):\n",
" cur_batch_size = len(real)\n",
" real = real.double().to(device)\n",
" fake_noise = torch.randn((cur_batch_size, noise_dim, 1, 1), device=device)\n",
" fake = gen(fake_noise)\n",
"\n",
" print(real.dtype)\n",
"\n",
" ### Updating disc\n",
" disc_real_pred = disc(real)\n",
" disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))\n",
"\n",
" disc_fake_pred = disc(fake.detach()).reshape(-1)\n",
" disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))\n",
"\n",
" # total loss is the avg of the two\n",
" #disc_loss = disc_real_loss - disc_fake_loss\n",
" disc_loss = (disc_fake_loss + disc_real_loss) / 2\n",
"\n",
" # update gradients\n",
" disc.zero_grad()\n",
" disc_loss.backward(retain_graph=True)\n",
" disc_optim.step()\n",
"\n",
" ### Updating gen\n",
" output = disc(fake).reshape(-1)\n",
" gen_loss = criterion(output, torch.ones_like(output))\n",
" gen.zero_grad()\n",
" gen_loss.backward()\n",
" gen_optim.step()\n",
"\n",
" if cur_step % display_step == 0 and cur_step > 0:\n",
" print(f'Step {cur_step}: Generator Loss: {gen_loss}, discriminator loss: {disc_loss}')\n",
" show_tensor_image(fake, 'fake/generated', num_images=25)\n",
" show_tensor_image(real, 'real/dataset', num_images=25)\n",
" mean_gen_loss = 0\n",
" mean_disc_loss = 0\n",
" cur_step+=1\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "84b2967b5937443095e0cb89c5519704",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=136.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"torch.float64\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-83-1c31cc3762f1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;31m### Updating disc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mdisc_real_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0mdisc_real_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdisc_real_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdisc_real_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-46-6554cda8df86>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 397\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 398\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 399\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 400\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 394\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[1;32m 395\u001b[0m return F.conv2d(input, weight, bias, self.stride,\n\u001b[0;32m--> 396\u001b[0;31m self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m 397\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 398\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: expected scalar type Double but found Float"
]
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rKUG2IlWKm-q"
},
"source": [],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment