Implementation of SRGAN in PyTorch. Superresolves to 4x
Based on this paper: https://arxiv.org/pdf/1609.04802.pdf
Model is to be trained for 150 epochs.
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.1" | |
}, | |
"colab": { | |
"name": "srgan.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"5b14de92c3bb43f5b769eb8b8918d393": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_92a437f3d4d74fab94d4dfb34cad8cfc", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_6047442ef1464494a1775e852d98de6e", | |
"IPY_MODEL_35baa95ba9434b2f94fbd8335df48a41" | |
] | |
} | |
}, | |
"92a437f3d4d74fab94d4dfb34cad8cfc": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"6047442ef1464494a1775e852d98de6e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_3b3fcd76032c4f4997e819e68f7ddef0", | |
"_dom_classes": [], | |
"description": "100%", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 553433881, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 553433881, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_b201fa430ceb42e3bc438dc673856622" | |
} | |
}, | |
"35baa95ba9434b2f94fbd8335df48a41": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_39c14f33da7a4d85bfcb187c0cacd68a", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 528M/528M [00:09<00:00, 55.6MB/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_638747b422554adb8c64e087f52f4fd6" | |
} | |
}, | |
"3b3fcd76032c4f4997e819e68f7ddef0": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"b201fa430ceb42e3bc438dc673856622": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"39c14f33da7a4d85bfcb187c0cacd68a": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"638747b422554adb8c64e087f52f4fd6": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/SharanSMenon/7afd37c9cac76a736fd1a592966608c0/srgan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "wtSTQigkkUSR" | |
}, | |
"source": [ | |
"# Super resolution with SRGAN" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "OLfJgUAekUST" | |
}, | |
"source": [ | |
"**Dataset Link**: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nggOyQPRke9t" | |
}, | |
"source": [ | |
"!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip\n", | |
"!unzip DIV2K_train_HR.zip" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ll7Qe9DGkUST" | |
}, | |
"source": [ | |
"## Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kEbzONOhkUST" | |
}, | |
"source": [ | |
"import torch\n", | |
"import math\n", | |
"from os import listdir\n", | |
"import numpy as np\n", | |
"from torch.autograd import Variable" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qI1niZaykUST" | |
}, | |
"source": [ | |
"from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2av_Pl2SkUSU" | |
}, | |
"source": [ | |
"from torch.utils.data import DataLoader, Dataset" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RaYLriLqkUSU" | |
}, | |
"source": [ | |
"from srgandata import TrainDatasetFromFolder, ValDatasetFromFolder, TestDatasetFromFolder, display_transform" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "UJ0FVu-NkUSU", | |
"outputId": "29ac2361-0f74-4c20-bcde-f58487383694" | |
}, | |
"source": [ | |
"torch.autograd.set_detect_anomaly(True)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f71f6ead908>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "4OQpaNjXkUSU" | |
}, | |
"source": [ | |
"## Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "C8i4bP25kUSU" | |
}, | |
"source": [ | |
# Super-resolution factor (4x, as in the SRGAN paper) and the high-resolution
# crop size used when building training patches.
UPSCALE_FACTOR = 4
CROP_SIZE = 88
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3OkxFCSYkUSU" | |
}, | |
"source": [ | |
# Standard ImageNet per-channel mean/std. NOTE(review): these are defined but
# not referenced elsewhere in this notebook — presumably intended for input
# normalization; confirm before removing.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "F8DP8DVokUSV" | |
}, | |
"source": [ | |
"Makes a low resolution image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cmejgK8xkUSV" | |
}, | |
"source": [ | |
# Training pairs built by TrainDatasetFromFolder (implementation lives in the
# local `srgandata` module, not shown here); the validation set is disabled.
train_set = TrainDatasetFromFolder('DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
# val_set = ValDatasetFromFolder('DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
# val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-VSfRAQWkUSV" | |
}, | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "12sK5ob7kUSV" | |
}, | |
"source": [ | |
"## Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ENEfy134kUSV" | |
}, | |
"source": [ | |
"from torch import nn, optim" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xGPGQ_BpkUSV" | |
}, | |
"source": [ | |
class ResidualBlock(nn.Module):
    """Identity-skip residual block: conv -> BN -> PReLU -> conv -> BN.

    The transformed features are added back onto the input, as in the
    SRGAN generator's residual trunk.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        # Two conv/BN stages with a PReLU in between, then the skip connection.
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return x + out
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6ceKahB2kUSV" | |
}, | |
"source": [ | |
class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling: conv -> PixelShuffle(up_scale) -> PReLU.

    The convolution widens the channels by up_scale**2 so PixelShuffle can
    rearrange them into an up_scale-times-larger spatial grid.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        widened = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, widened, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "E9sHfbpRkUSW" | |
}, | |
"source": [ | |
class Generator(nn.Module):
    """SRGAN generator: upscales a 3-channel image by `scale_factor`.

    scale_factor must be a power of two; log2(scale_factor) factor-2
    sub-pixel upsampling blocks are stacked in block8. The final tanh is
    remapped to [0, 1] via (tanh + 1) / 2.
    """

    def __init__(self, scale_factor):
        n_upsamples = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        tail = [UpsampleBLock(64, 2) for _ in range(n_upsamples)]
        tail.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*tail)

    def forward(self, x):
        head = self.block1(x)
        out = head
        # Residual trunk, then the conv/BN stage that closes the long skip.
        for res in (self.block2, self.block3, self.block4, self.block5, self.block6):
            out = res(out)
        out = self.block7(out)
        sr = self.block8(head + out)
        return (torch.tanh(sr) + 1) / 2
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5v9c_NYNkUSW" | |
}, | |
"source": [ | |
class Discriminator(nn.Module):
    """SRGAN discriminator: scores each image as real (→1) or generated (→0).

    A VGG-style stack of conv/BN/LeakyReLU blocks alternating stride 1
    (widen channels) and stride 2 (downsample), followed by a global
    average pool and two 1x1 convolutions acting as the classifier head.
    Returns a 1-D tensor of per-sample sigmoid probabilities.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # (in_channels, out_channels, stride) for each conv/BN/LeakyReLU block.
        for in_ch, out_ch, stride in [
            (64, 64, 2), (64, 128, 1), (128, 128, 2),
            (128, 256, 1), (256, 256, 2), (256, 512, 1), (512, 512, 2),
        ]:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),              # collapse to 1x1 spatial
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)                      # shape (batch, 1, 1, 1)
        return torch.sigmoid(scores.view(x.size(0)))
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "UbJRjjxBkUSW" | |
}, | |
"source": [ | |
"from torchvision.models.vgg import vgg16" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "PBErzouckUSW" | |
}, | |
"source": [ | |
"### Loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1dBQrQBgkUSW" | |
}, | |
"source": [ | |
class GeneratorLoss(nn.Module):
    """Composite SRGAN generator loss.

    total = pixel MSE
          + 0.001 * adversarial term (mean of 1 - D(G(z)))
          + 0.006 * perceptual MSE in a frozen VGG16 feature space
          + 2e-8  * total-variation regularizer
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # Frozen feature extractor: first 31 layers of pretrained VGG16.
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        adversarial = torch.mean(1 - out_labels)
        perceptual = self.mse_loss(
            self.loss_network(out_images), self.loss_network(target_images)
        )
        pixel = self.mse_loss(out_images, target_images)
        tv = self.tv_loss(out_images)
        return pixel + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * tv
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uGPc4qnWkUSW" | |
}, | |
"source": [ | |
class TVLoss(nn.Module):
    """Total-variation loss: penalizes squared differences between
    neighboring pixels.

    Returns weight * 2 * (mean squared vertical diff + mean squared
    horizontal diff), averaged over the batch dimension.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch = x.size(0)
        height, width = x.size(2), x.size(3)
        # Element counts (per batch sample) used to turn sums into means.
        n_vert = self.tensor_size(x[:, :, 1:, :])
        n_horiz = self.tensor_size(x[:, :, :, 1:])
        vert = ((x[:, :, 1:, :] - x[:, :, :height - 1, :]) ** 2).sum()
        horiz = ((x[:, :, :, 1:] - x[:, :, :, :width - 1]) ** 2).sum()
        return self.tv_loss_weight * 2 * (vert / n_vert + horiz / n_horiz) / batch

    @staticmethod
    def tensor_size(t):
        # Number of elements in one batch sample (C * H * W).
        return t.size(1) * t.size(2) * t.size(3)
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "duZ8UvogmJ8X" | |
}, | |
"source": [ | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9QPi_X9_kUSW" | |
}, | |
"source": [ | |
"netG = Generator(UPSCALE_FACTOR)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6sQrfCZHkUSW" | |
}, | |
"source": [ | |
"netD = Discriminator()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PiiqnrvkkUSX", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 103, | |
"referenced_widgets": [ | |
"5b14de92c3bb43f5b769eb8b8918d393", | |
"92a437f3d4d74fab94d4dfb34cad8cfc", | |
"6047442ef1464494a1775e852d98de6e", | |
"35baa95ba9434b2f94fbd8335df48a41", | |
"3b3fcd76032c4f4997e819e68f7ddef0", | |
"b201fa430ceb42e3bc438dc673856622", | |
"39c14f33da7a4d85bfcb187c0cacd68a", | |
"638747b422554adb8c64e087f52f4fd6" | |
] | |
}, | |
"outputId": "596a3dca-fa3f-4329-b29c-85d632f9792b" | |
}, | |
"source": [ | |
"generator_criterion = GeneratorLoss()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "5b14de92c3bb43f5b769eb8b8918d393", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aQaIjG8UnunL" | |
}, | |
"source": [ | |
"generator_criterion = generator_criterion.to(device)\n", | |
"netG = netG.to(device)\n", | |
"netD = netD.to(device)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YRPW-NqlkUSX" | |
}, | |
"source": [ | |
"optimizerG = optim.Adam(netG.parameters(), lr=0.0002)\n", | |
"optimizerD = optim.Adam(netD.parameters(), lr=0.0002)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "236B74VqOLgq" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "yULcXdFokUSX" | |
}, | |
"source": [ | |
"results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9Esuf4_xkUSX" | |
}, | |
"source": [ | |
"## Train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rg9EVhMGkUSX" | |
}, | |
"source": [ | |
"from tqdm import tqdm\n", | |
"import os" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lDi9oyr-kUSX" | |
}, | |
"source": [ | |
"N_EPOCHS = 150" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dJSFZRzykUSX", | |
"outputId": "dc79719e-2477-4621-d365-c6d7c5fb99c2" | |
}, | |
"source": [ | |
# Alternating GAN training: one discriminator step (maximize D(x) - D(G(z)))
# then one generator step (composite SRGAN loss) per batch.
# Fixes vs. the original cell: removed the unused `g_update_first` flag and
# the deprecated `Variable` wrappers; removed two redundant extra forward
# passes of netG/netD per batch; replaced the `retain_graph=True` workaround
# with the standard `detach()` on the fake image for the D update.
for epoch in range(1, N_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.to(device)
        z = data.to(device)

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img).mean()
        # detach() stops the D loss from backpropagating into G, so no
        # retain_graph / graph-rebuild workaround is needed.
        fake_out = netD(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        # Re-score the (still graph-attached) fake image with the updated D.
        fake_out = netD(fake_img).mean()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        # loss for current batch before optimization
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))

    netG.eval()
    out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    if not os.path.exists(out_path):
        os.makedirs(out_path)
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[1/150] Loss_D: 0.8393 Loss_G: 0.0452 D(x): 0.5709 D(G(z)): 0.3487: 100%|██████████| 13/13 [01:01<00:00, 4.72s/it]\n", | |
"[2/150] Loss_D: 0.7786 Loss_G: 0.0193 D(x): 0.5395 D(G(z)): 0.3129: 100%|██████████| 13/13 [01:00<00:00, 4.66s/it]\n", | |
"[3/150] Loss_D: 0.3920 Loss_G: 0.0157 D(x): 0.7983 D(G(z)): 0.1449: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
"[4/150] Loss_D: 0.1291 Loss_G: 0.0145 D(x): 0.9329 D(G(z)): 0.0532: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
"[5/150] Loss_D: 0.0426 Loss_G: 0.0136 D(x): 0.9778 D(G(z)): 0.0177: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
"[6/150] Loss_D: 0.0277 Loss_G: 0.0131 D(x): 0.9834 D(G(z)): 0.0386: 100%|██████████| 13/13 [01:00<00:00, 4.64s/it]\n", | |
"[7/150] Loss_D: 0.8722 Loss_G: 0.0119 D(x): 0.5780 D(G(z)): 0.3579: 100%|██████████| 13/13 [00:59<00:00, 4.59s/it]\n", | |
"[8/150] Loss_D: 0.7058 Loss_G: 0.0104 D(x): 0.5374 D(G(z)): 0.2106: 100%|██████████| 13/13 [00:59<00:00, 4.60s/it]\n", | |
"[9/150] Loss_D: 0.4493 Loss_G: 0.0108 D(x): 0.7283 D(G(z)): 0.1293: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
"[10/150] Loss_D: 0.1850 Loss_G: 0.0102 D(x): 0.9002 D(G(z)): 0.0621: 100%|██████████| 13/13 [00:59<00:00, 4.58s/it]\n", | |
" 0%| | 0/13 [00:00<?, ?it/s]" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YTPIWjk2NC5P" | |
}, | |
"source": [ | |
"from torchvision.transforms import ToTensor, ToPILImage" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O9yt2ExXLxjO" | |
}, | |
"source": [ | |
"from PIL import Image" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Qz8w5AtxcuRi" | |
}, | |
"source": [ | |
"torch.save(netG.state_dict(), \"super_res_gen.pth\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4CIfaGsmQVwg", | |
"outputId": "eb6874a0-3589-400f-c571-9cc9f2d32d28" | |
}, | |
"source": [ | |
"netG.load_state_dict(torch.load(\"super_res_gen.pth\")) # If you already have a pretrained weights file for this model." | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<All keys matched successfully>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8LI5BqU5NIWu", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 165 | |
}, | |
"outputId": "97eef7b0-b8f0-48d5-b532-c4afc52a4b7d" | |
}, | |
"source": [ | |
"# LOAD your OWN image here. This cell wont work unless you upload your own image to Colab\n", | |
"image = Image.open(\"table2.jpg\")" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "NameError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-1-369897577819>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mimage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"table2.jpg\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mNameError\u001b[0m: name 'Image' is not defined" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TQjCZ7CBkUSX", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "f4b8339f-09b8-4c82-8886-4af5ef4530c8" | |
}, | |
"source": [ | |
"image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", | |
" \"\"\"Entry point for launching an IPython kernel.\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VpCwx3FXRXYO" | |
}, | |
"source": [ | |
# Use the `device` selected earlier (cuda if available, else cpu) instead of
# hard-coding "cuda", so inference also works on CPU-only machines.
netG = netG.to(device)
image = image.to(device)
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "MehEf581NQTH" | |
}, | |
"source": [ | |
# Run super-resolution without building an autograd graph (saves memory and
# time at inference).
with torch.no_grad():
    out = netG(image)
out_img = ToPILImage()(out[0].cpu())
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wrPdGj02NVRi" | |
}, | |
"source": [ | |
"out_img.save(\"table2-superres-4x.jpg\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |