Skip to content

Instantly share code, notes, and snippets.

@pigghead
Created March 31, 2023 16:21
Show Gist options
  • Save pigghead/2d90da0ce83250eb0d3bfdc4b8af00a3 to your computer and use it in GitHub Desktop.
Save pigghead/2d90da0ce83250eb0d3bfdc4b8af00a3 to your computer and use it in GitHub Desktop.
DCGAN-SimpleImplementation.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPmQr3EPD0CZLWuDNpws92m",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"84b2967b5937443095e0cb89c5519704": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_78459a4f65314489831c725d83c0877a",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_2e709235ea2b42dab063e14f08e6383f",
"IPY_MODEL_67a1305205db49bcac1a3be7f177282d"
]
}
},
"78459a4f65314489831c725d83c0877a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2e709235ea2b42dab063e14f08e6383f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_8e9c9f422b964f118c65b0c28ccb4ab4",
"_dom_classes": [],
"description": " 0%",
"_model_name": "FloatProgressModel",
"bar_style": "danger",
"max": 136,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 0,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_33606ed261614b998d9f2c0e08b54963"
}
},
"67a1305205db49bcac1a3be7f177282d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_4b67b60e414f435b99169c408ab52ce5",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 0/136 [00:01<?, ?it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b414b40e11e84e4f98d5e1b72494fdc6"
}
},
"8e9c9f422b964f118c65b0c28ccb4ab4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"33606ed261614b998d9f2c0e08b54963": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"4b67b60e414f435b99169c408ab52ce5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"b414b40e11e84e4f98d5e1b72494fdc6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/pigghead/2d90da0ce83250eb0d3bfdc4b8af00a3/dcgan-simpleimplementation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VNZcgLAJ-Exe",
"outputId": "904fc33c-ea1d-401d-92d4-7531764080e8"
},
"source": [
"# Mount Google Drive so the .npy dataset under MyDrive can be loaded below.\n",
"import google.colab.drive as drive\n",
"drive.mount('/content/drive', force_remount=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "K1nHyBGMrRz7",
"outputId": "3b915685-e4e7-4493-f99a-ed2809c203e0"
},
"source": [
"import torch\n",
"from torch import nn\n",
"from tqdm.auto import tqdm\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST # training dataset\n",
"from torchvision.utils import make_grid\n",
"from torch.utils.data import DataLoader\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"torch.manual_seed(0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7ff17dcfdcb0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 44
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "95Fjqut-BfoL",
"outputId": "1c1198cd-b220-48ef-a855-e9320f8b8246"
},
"source": [
"### File retrieved from separate colab file\n",
"### (gist: https://gist.github.com/pigghead/ece7aea9b51b799fbb17b5f1e7c1521b#file-downscale-image-test-ipynb)\n",
"# Load the preprocessed flower images as a single ndarray laid out as\n",
"# (samples, channels, height, width) -- confirm_import() asserts this layout.\n",
"flower_images = np.load('/content/drive/MyDrive/YEAR 5/IGME 797/flwr_data_64x64.npy')\n",
"\n",
"def confirm_import():\n",
"    '''Sanity-check the loaded array: 4323 images of shape (3, 64, 64) with a float dtype.'''\n",
"    assert flower_images.shape == (4323, 3, 64, 64), flower_images.shape  # entire ndarray\n",
"    assert flower_images[0].shape == (3, 64, 64)  # single instance\n",
"    # issubdtype accepts float32 as well as float64, so this still holds after the cast below\n",
"    assert np.issubdtype(flower_images[0].dtype, np.floating), flower_images[0].dtype\n",
"    print(type(flower_images[0]))\n",
"    print('import successful')\n",
"\n",
"confirm_import()\n",
"\n",
"# The models' parameters are float32 (PyTorch default). np.load keeps the stored dtype\n",
"# (float64 here), and float64 batches make the conv layers raise\n",
"# 'expected scalar type Double but found Float', so cast once up front.\n",
"flower_images = flower_images.astype(np.float32)\n",
"print(flower_images.dtype)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'numpy.ndarray'>\n",
"import successful\n",
"float64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pzy0utf4rJCr",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e7c387d5-eb5b-4675-b049-ee3832391609"
},
"source": [
"\"\"\"\n",
"Discriminator and Generator implementation from DCGAN paper\n",
"\"\"\"\n",
"\n",
"class Discriminator(nn.Module):\n",
"    '''DCGAN discriminator for 64x64 images.\n",
"\n",
"    Maps a batch of shape (N, channels_img, 64, 64) to (N, 1, 1, 1) sigmoid scores.\n",
"    features_d is the channel width of the first conv; each following block doubles it.\n",
"    '''\n",
"\n",
"    def __init__(self, channels_img, features_d):\n",
"        super().__init__()\n",
"        widths = [features_d * 2 ** i for i in range(4)]  # d, 2d, 4d, 8d\n",
"        self.disc = nn.Sequential(\n",
"            # 64x64 -> 32x32; the input layer is not normalized\n",
"            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2),\n",
"            # three strided blocks: 32x32 -> 16x16 -> 8x8 -> 4x4\n",
"            *(self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])),\n",
"            # 4x4 -> 1x1: one score per image\n",
"            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
"        '''Bias-free strided convolution followed by LeakyReLU(0.2).'''\n",
"        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)\n",
"        # NOTE(review): BatchNorm2d(out_channels) is disabled here, although the DCGAN paper\n",
"        # normalizes the discriminator's hidden layers -- confirm this is intentional.\n",
"        return nn.Sequential(conv, nn.LeakyReLU(0.2))\n",
"\n",
"    def forward(self, x):\n",
"        return self.disc(x)\n",
"\n",
"\n",
"class Generator(nn.Module):\n",
"    '''DCGAN generator for 64x64 images.\n",
"\n",
"    Maps noise of shape (N, channels_noise, 1, 1) to images of shape\n",
"    (N, channels_img, 64, 64) squashed by Tanh into (-1, 1).\n",
"    '''\n",
"\n",
"    def __init__(self, channels_noise, channels_img, features_g):\n",
"        super().__init__()\n",
"        layers = [self._block(channels_noise, features_g * 16, 4, 1, 0)]  # 1x1 -> 4x4\n",
"        for mult in (16, 8, 4):  # 4x4 -> 8x8 -> 16x16 -> 32x32, halving the width each time\n",
"            layers.append(self._block(features_g * mult, features_g * mult // 2, 4, 2, 1))\n",
"        layers.append(nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1))  # 32x32 -> 64x64\n",
"        layers.append(nn.Tanh())\n",
"        self.net = nn.Sequential(*layers)\n",
"\n",
"    def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
"        '''Bias-free transposed convolution followed by ReLU.'''\n",
"        upsample = nn.ConvTranspose2d(\n",
"            in_channels, out_channels, kernel_size, stride, padding, bias=False\n",
"        )\n",
"        # NOTE(review): BatchNorm2d(out_channels) is disabled here, unlike the DCGAN paper\n",
"        # -- confirm this is intentional.\n",
"        return nn.Sequential(upsample, nn.ReLU())\n",
"\n",
"    def forward(self, x):\n",
"        return self.net(x)\n",
"\n",
"\n",
"def initialize_weights(model):\n",
"    '''Initialize model weights in place as in the DCGAN paper / PyTorch DCGAN tutorial.\n",
"\n",
"    Conv and transposed-conv weights ~ N(0, 0.02). BatchNorm scales ~ N(1, 0.02) with\n",
"    zero shift: a scale centered on 0 would shrink every normalized activation toward 0.\n",
"    Other parameters (e.g. conv biases) keep their PyTorch default initialization.\n",
"    '''\n",
"    for m in model.modules():\n",
"        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):\n",
"            nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
"        elif isinstance(m, nn.BatchNorm2d):\n",
"            nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
"            nn.init.constant_(m.bias.data, 0.0)\n",
"\n",
"\n",
"def test():\n",
"    '''Shape smoke test for the 64x64 Generator and Discriminator.'''\n",
"    N, in_channels, H, W = 8, 3, 64, 64\n",
"    noise_dim = 100\n",
"    gen = Generator(noise_dim, in_channels, 8)\n",
"    z = torch.randn((N, noise_dim, 1, 1))\n",
"    disc = Discriminator(in_channels, 8)\n",
"    x = torch.randn((N, in_channels, H, W))\n",
"\n",
"    with torch.no_grad():  # shapes only; no gradients needed\n",
"        fake = gen(z)\n",
"        real_scores = disc(x)\n",
"        fake_scores = disc(fake)\n",
"    print('gen(z) shape: ', fake.shape)\n",
"    print('disc(x) shape: ', real_scores.shape)\n",
"    print('disc(gen(z)) shape: ', fake_scores.shape)\n",
"\n",
"    # Expected: generator -> (N=8, C=3, H=64, W=64); discriminator -> one score per image\n",
"    assert fake.shape == (N, in_channels, H, W), fake.shape\n",
"    assert real_scores.shape == (N, 1, 1, 1), real_scores.shape\n",
"    assert fake_scores.shape == (N, 1, 1, 1), fake_scores.shape\n",
"\n",
"\n",
"test()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"gen(z) shape: torch.Size([8, 3, 64, 64])\n",
"disc(x) shape: torch.Size([8, 1, 1, 1])\n",
"disc(gen(z)) shape: torch.Size([8, 1, 1, 1])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A6wAPd4cZ2iH",
"outputId": "428b6a17-f4f5-4f51-f82b-f3ef5fc6f099"
},
"source": [
"class discriminator_32x32(nn.Module):\n",
"    '''DCGAN-style discriminator for 32x32 images.\n",
"\n",
"    Maps (N, channels_img, 32, 32) to (N, 1, 1, 1) sigmoid scores in (0, 1).\n",
"    Channel widths grow 2*features_d -> 4*features_d -> 8*features_d.\n",
"    '''\n",
"\n",
"    def __init__(self, channels_img, features_d):\n",
"        super(discriminator_32x32, self).__init__()\n",
"        self.disc = nn.Sequential(\n",
"            # Input: N x channels_img x 32 x 32\n",
"            # NOTE(review): the DCGAN paper leaves the discriminator input layer\n",
"            # un-normalized, but this first block includes BatchNorm -- confirm intended.\n",
"            self._block(channels_img, features_d * 2, 4, 2, 1),  # 16x16\n",
"            self._block(features_d * 2, features_d * 4, 4, 2, 1),  # 8x8\n",
"            self._block(features_d * 4, features_d * 8, 4, 2, 1),  # 4x4\n",
"            nn.Conv2d(features_d * 8, 1, 4, 1, 0),  # 1x1: one score per image\n",
"            nn.Sigmoid(),  # squashes the score into (0, 1)\n",
"        )\n",
"\n",
"    def _block(self, in_channels, out_channels, kernel_size, stride, padding):\n",
"        '''Bias-free strided convolution, batch normalization, then LeakyReLU(0.2).'''\n",
"        return nn.Sequential(\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),\n",
"            nn.BatchNorm2d(out_channels),\n",
"            nn.LeakyReLU(0.2, inplace=True)\n",
"        )\n",
"\n",
"    def forward(self, img):\n",
"        return self.disc(img)\n",
"\n",
"''' EXAMPLE TEST CLASS\n",
"\n",
" N, in_channels, H, W = 8, 3, 64, 64\n",
" noise_dim = 100\n",
" x = torch.randn((N, in_channels, H, W))\n",
"\n",
" disc = Discriminator(in_channels, 8)\n",
" assert disc(x).shape == (N, 1, 1, 1), \"Discriminator test failed\"\n",
"'''\n",
"\n",
"def test_disc():\n",
"    '''Shape smoke test for discriminator_32x32 on single-channel 32x32 inputs.'''\n",
"    N, in_channels, H, W = 8, 1, 32, 32\n",
"    x = torch.randn((N, in_channels, H, W))\n",
"    print(x.shape)\n",
"    disc = discriminator_32x32(in_channels, 32)\n",
"\n",
"    # One forward pass: in training mode every call also updates BatchNorm running stats.\n",
"    scores = disc(x)\n",
"    print(scores[0])\n",
"    assert scores.shape == (N, 1, 1, 1), 'Discriminator test failed'\n",
"\n",
"\n",
"test_disc()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([8, 1, 32, 32])\n",
"tensor([[[0.2947]]], grad_fn=<SelectBackward>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
},
"id": "pALZEKeiQVUW",
"outputId": "804ba980-47a4-430a-8799-353648ba4ce9"
},
"source": [
"class generator_32x32(nn.Module):\n",
"    '''DCGAN-style generator for 32x32 images.\n",
"\n",
"    Maps noise of shape (N, noise_channels, 1, 1) to images of shape\n",
"    (N, im_channels, 32, 32) squashed by Tanh into (-1, 1).\n",
"    '''\n",
"\n",
"    def __init__(self, noise_channels, im_channels, features_g):\n",
"        super().__init__()\n",
"        up_blocks = [\n",
"            self._block(noise_channels, features_g * 8, 4, 1, 0),  # 1x1 -> 4x4\n",
"            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # 4x4 -> 8x8\n",
"            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # 8x8 -> 16x16\n",
"        ]\n",
"        # ConvTranspose2d output size = stride*(n-1) + kernel - 2*padding = 2*15 + 4 - 2 = 32\n",
"        head = [nn.ConvTranspose2d(features_g * 2, im_channels, 4, 2, 1), nn.Tanh()]\n",
"        self.gen = nn.Sequential(*up_blocks, *head)\n",
"\n",
"    def _block(self, in_channel, out_channel, kernel_size, stride, padding):\n",
"        '''Transposed convolution, batch normalization, then LeakyReLU(0.2).'''\n",
"        layers = (\n",
"            nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False),\n",
"            nn.BatchNorm2d(out_channel),\n",
"            nn.LeakyReLU(0.2),\n",
"        )\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        return self.gen(x)\n",
"\n",
"\n",
"''' EXAMPLE TEST CLASS\n",
"\n",
" N, in_channels, H, W = 8, 3, 64, 64\n",
" noise_dim = 100\n",
" x = torch.randn((N, in_channels, H, W))\n",
"\n",
" disc = Discriminator(in_channels, 8)\n",
" assert disc(x).shape == (N, 1, 1, 1), \"Discriminator test failed\"\n",
"\n",
" gen = Generator(noise_dim, in_channels, 8)\n",
" z = torch.randn((N, noise_dim, 1, 1))\n",
" assert gen(z).shape == (N, in_channels, H, W), \"Generator test failed\"\n",
"'''\n",
"\n",
"\n",
"# Shape test for generator_32x32 (despite its name it checks 32x32 output;\n",
"# the name is kept so existing calls keep working).\n",
"def test_gen_28x28():\n",
"    '''Assert generator_32x32 maps (N, 100, 1, 1) noise to (N, 1, 32, 32) images.'''\n",
"    N, in_channels, H, W = 8, 1, 32, 32\n",
"    noise_dim = 100\n",
"    z = torch.randn((N, noise_dim, 1, 1))\n",
"    gen1 = generator_32x32(noise_dim, in_channels, 32)\n",
"    # One forward pass: in training mode every call also updates BatchNorm running stats.\n",
"    out = gen1(z)\n",
"    assert out.shape == (N, in_channels, H, W), out.shape\n",
"    print('gen1(z).shape', out.shape)\n",
"\n",
"\n",
"#test_gen_28x28()\n",
"\n",
"\n",
"def print_gen_28x28():\n",
"    '''Display one grayscale sample from a freshly initialized (untrained) generator_32x32.'''\n",
"    N, in_channels, H, W = 1, 1, 32, 32\n",
"    noise_dim = 128\n",
"    z = torch.randn((N, noise_dim, 1, 1), device='cpu')\n",
"    gen = generator_32x32(noise_dim, in_channels, 8).to('cpu')\n",
"    with torch.no_grad():  # display only; no gradients needed\n",
"        img = gen(z)\n",
"    print('gen(z) shape: ', img.shape)\n",
"    # plt.figure (not the plt.Figure class) creates the current pyplot figure, so figsize\n",
"    # actually applies; a bare Figure(...) object is never attached or shown.\n",
"    plt.figure(figsize=(4, 4))\n",
"    plt.imshow(img[0, 0, :, :].cpu(), cmap='gray')\n",
"    plt.show()\n",
"\n",
"\n",
"#test_gen_28x28()\n",
"print_gen_28x28()\n",
"\n",
"\n",
"# test both\n",
"def test_both():\n",
"    '''Print the output shape of discriminator_32x32 applied to generator_32x32 samples.'''\n",
"    batch, channels, latent_dim = 8, 1, 100\n",
"    noise = torch.randn((batch, latent_dim, 1, 1))  # generator input\n",
"    generator = generator_32x32(latent_dim, channels, 32)\n",
"    discriminator = discriminator_32x32(channels, 32)\n",
"    scores = discriminator(generator(noise))\n",
"    print(scores.shape)\n",
"\n",
"\n",
"#test_both()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"gen(z) shape: torch.Size([1, 1, 32, 32])\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8-KR-tcPyqhc"
},
"source": [
"def show_tensor_image(img_tensor, plotTitle, num_images=25, size=(1,32,32)):\n",
"    '''Plot the first num_images images of a 4-D (N, C, H, W) batch as a 5-column grid.\n",
"\n",
"    img_tensor -- batch assumed to lie in [-1, 1] (the Generator's Tanh range); it is\n",
"                  mapped to [0, 1] for display. TODO confirm real dataset batches use this range.\n",
"    plotTitle  -- title drawn above the grid\n",
"    num_images -- how many images from the front of the batch to show\n",
"    size       -- unused; kept so existing callers keep working\n",
"    '''\n",
"    # detach before the arithmetic so no autograd graph is built, and clamp so imshow\n",
"    # receives valid RGB values instead of clipping them with a warning\n",
"    image_unflat = ((img_tensor.detach().cpu() + 1) / 2).clamp(0, 1)\n",
"\n",
"    # make_grid returns (C, H, W); imshow expects channels last\n",
"    image_grid = make_grid(image_unflat[:num_images], nrow=5)\n",
"\n",
"    plt.title(plotTitle)\n",
"    plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n",
"    plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 441,
"referenced_widgets": [
"84b2967b5937443095e0cb89c5519704",
"78459a4f65314489831c725d83c0877a",
"2e709235ea2b42dab063e14f08e6383f",
"67a1305205db49bcac1a3be7f177282d",
"8e9c9f422b964f118c65b0c28ccb4ab4",
"33606ed261614b998d9f2c0e08b54963",
"4b67b60e414f435b99169c408ab52ce5",
"b414b40e11e84e4f98d5e1b72494fdc6"
]
},
"id": "aMgfXY1aqaR_",
"outputId": "0162ac23-0683-41ce-8496-30d21949b5d8"
},
"source": [
"### TRAINING ###\n",
"criterion = nn.BCELoss()\n",
"n_epochs = 500\n",
"display_step = 500  # for visualization\n",
"cur_step = 0\n",
"noise_dim = 100\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 0.0002  # alpha in the DCGAN paper\n",
"FEATURES = 64  # base channel width of the 64x64 Generator / Discriminator\n",
"NUM_CHANNELS = 3\n",
"# The notebook requests a GPU runtime; fall back to the CPU when none is available.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# DCGAN paper: Adam with beta1 = 0.5; beta2 is left at Adam's default of 0.999\n",
"beta_1 = 0.5\n",
"beta_2 = 0.999\n",
"\n",
"# Each batch is a slice of flower_images, i.e. (B, 3, 64, 64).\n",
"# NOTE(review): the Generator's Tanh outputs lie in (-1, 1); confirm flower_images was\n",
"# scaled to the same range during preprocessing.\n",
"d_loader = DataLoader(\n",
"    flower_images,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=True\n",
")\n",
"\n",
"# Initialize generator and discriminator (64x64 variant)\n",
"gen = Generator(noise_dim, NUM_CHANNELS, FEATURES).to(device)\n",
"disc = Discriminator(NUM_CHANNELS, FEATURES).to(device)\n",
"\n",
"# Adam optimizers, as in the DCGAN paper\n",
"gen_optim = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(beta_1, beta_2))\n",
"disc_optim = torch.optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(beta_1, beta_2))\n",
"\n",
"# initialize weights in place (the optimizers hold the same parameter tensors)\n",
"initialize_weights(gen)\n",
"initialize_weights(disc)\n",
"\n",
"mean_gen_loss = 0.0\n",
"mean_disc_loss = 0.0\n",
"for epoch in range(n_epochs):\n",
"    # labels are not used right now (StyleGAN will use these)\n",
"    for real in tqdm(d_loader):\n",
"        cur_batch_size = len(real)\n",
"        # Model weights are float32; a float64 batch raises\n",
"        # 'expected scalar type Double but found Float', so always cast to float32.\n",
"        real = real.float().to(device)\n",
"        fake_noise = torch.randn((cur_batch_size, noise_dim, 1, 1), device=device)\n",
"        fake = gen(fake_noise)\n",
"\n",
"        ### Update discriminator: real -> 1, fake -> 0\n",
"        disc_real_pred = disc(real).reshape(-1)\n",
"        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))\n",
"        # fake is detached so this loss does not backpropagate into the generator\n",
"        disc_fake_pred = disc(fake.detach()).reshape(-1)\n",
"        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))\n",
"        # total loss is the average of the two\n",
"        disc_loss = (disc_fake_loss + disc_real_loss) / 2\n",
"\n",
"        disc.zero_grad()\n",
"        # no retain_graph needed: the generator step below runs a fresh forward pass\n",
"        disc_loss.backward()\n",
"        disc_optim.step()\n",
"\n",
"        ### Update generator: push D to label the fakes as real\n",
"        output = disc(fake).reshape(-1)\n",
"        gen_loss = criterion(output, torch.ones_like(output))\n",
"        gen.zero_grad()\n",
"        gen_loss.backward()\n",
"        gen_optim.step()\n",
"\n",
"        # running means over the current display window (.item() avoids keeping graphs alive)\n",
"        mean_disc_loss += disc_loss.item() / display_step\n",
"        mean_gen_loss += gen_loss.item() / display_step\n",
"\n",
"        if cur_step % display_step == 0 and cur_step > 0:\n",
"            print(f'Step {cur_step}: Generator Loss: {mean_gen_loss}, discriminator loss: {mean_disc_loss}')\n",
"            show_tensor_image(fake, 'fake/generated', num_images=25)\n",
"            show_tensor_image(real, 'real/dataset', num_images=25)\n",
"            mean_gen_loss = 0\n",
"            mean_disc_loss = 0\n",
"        cur_step += 1\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "84b2967b5937443095e0cb89c5519704",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=136.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"torch.float64\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-83-1c31cc3762f1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;31m### Updating disc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mdisc_real_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0mdisc_real_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdisc_real_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdisc_real_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-46-6554cda8df86>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 397\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 398\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 399\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 400\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 394\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[1;32m 395\u001b[0m return F.conv2d(input, weight, bias, self.stride,\n\u001b[0;32m--> 396\u001b[0;31m self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m 397\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 398\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: expected scalar type Double but found Float"
]
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rKUG2IlWKm-q"
},
"source": [],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment