srgan.ipynb (gist by @SharanSMenon, last active December 5, 2021)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"colab": {
"name": "srgan.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"include_colab_link": true
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"5b14de92c3bb43f5b769eb8b8918d393": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_92a437f3d4d74fab94d4dfb34cad8cfc",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_6047442ef1464494a1775e852d98de6e",
"IPY_MODEL_35baa95ba9434b2f94fbd8335df48a41"
]
}
},
"92a437f3d4d74fab94d4dfb34cad8cfc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"6047442ef1464494a1775e852d98de6e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_3b3fcd76032c4f4997e819e68f7ddef0",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 553433881,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 553433881,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b201fa430ceb42e3bc438dc673856622"
}
},
"35baa95ba9434b2f94fbd8335df48a41": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_39c14f33da7a4d85bfcb187c0cacd68a",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 528M/528M [00:09<00:00, 55.6MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_638747b422554adb8c64e087f52f4fd6"
}
},
"3b3fcd76032c4f4997e819e68f7ddef0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"b201fa430ceb42e3bc438dc673856622": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"39c14f33da7a4d85bfcb187c0cacd68a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"638747b422554adb8c64e087f52f4fd6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SharanSMenon/7afd37c9cac76a736fd1a592966608c0/srgan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wtSTQigkkUSR"
},
"source": [
"# Super resolution with SRGAN"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OLfJgUAekUST"
},
"source": [
"**Dataset Link**: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nggOyQPRke9t"
},
"source": [
"!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip\n",
"!unzip DIV2K_train_HR.zip"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ll7Qe9DGkUST"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kEbzONOhkUST"
},
"source": [
"import torch\n",
"import math\n",
"from os import listdir\n",
"import numpy as np\n",
"from torch.autograd import Variable"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qI1niZaykUST"
},
"source": [
"from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2av_Pl2SkUSU"
},
"source": [
"from torch.utils.data import DataLoader, Dataset"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RaYLriLqkUSU"
},
"source": [
"from srgandata import TrainDatasetFromFolder, ValDatasetFromFolder, TestDatasetFromFolder, display_transform"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UJ0FVu-NkUSU",
"outputId": "29ac2361-0f74-4c20-bcde-f58487383694"
},
"source": [
"torch.autograd.set_detect_anomaly(True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f71f6ead908>"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4OQpaNjXkUSU"
},
"source": [
"## Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "C8i4bP25kUSU"
},
"source": [
"UPSCALE_FACTOR = 4\n",
"CROP_SIZE = 88"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3OkxFCSYkUSU"
},
"source": [
"mean = np.array([0.485, 0.456, 0.406])\n",
"std = np.array([0.229, 0.224, 0.225])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "F8DP8DVokUSV"
},
"source": [
"Makes a low resokution image"
]
},
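{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `srgandata` module is not included in this gist. The next cell is only a minimal sketch of what `TrainDatasetFromFolder` is assumed to do; the class name `PairDatasetSketch`, the file filtering, and the bilinear downscale are assumptions, not the original code."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Assumed stand-in for srgandata.TrainDatasetFromFolder: pairs a random HR crop\n",
"# with an LR copy downscaled by the upscale factor (sketch only).\n",
"from os.path import join\n",
"from PIL import Image as PILImage\n",
"\n",
"class PairDatasetSketch(Dataset):\n",
"    def __init__(self, dataset_dir, crop_size, upscale_factor):\n",
"        self.filenames = [join(dataset_dir, f) for f in listdir(dataset_dir)\n",
"                          if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
"        crop_size = crop_size - (crop_size % upscale_factor)  # keep the crop divisible by the scale\n",
"        self.hr_transform = Compose([RandomCrop(crop_size), ToTensor()])\n",
"        self.lr_transform = Compose([ToPILImage(), Resize(crop_size // upscale_factor), ToTensor()])\n",
"\n",
"    def __getitem__(self, index):\n",
"        hr_image = self.hr_transform(PILImage.open(self.filenames[index]).convert('RGB'))\n",
"        lr_image = self.lr_transform(hr_image)\n",
"        return lr_image, hr_image  # the training loop unpacks (data, target) = (LR, HR)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.filenames)"
],
"execution_count": null,
"outputs": []
},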
{
"cell_type": "code",
"metadata": {
"id": "cmejgK8xkUSV"
},
"source": [
"train_set = TrainDatasetFromFolder('DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)\n",
"# val_set = ValDatasetFromFolder('DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)\n",
"train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)\n",
"# val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-VSfRAQWkUSV"
},
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "12sK5ob7kUSV"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ENEfy134kUSV"
},
"source": [
"from torch import nn, optim"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xGPGQ_BpkUSV"
},
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, channels):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)\n",
" self.bn1 = nn.BatchNorm2d(channels)\n",
" self.prelu = nn.PReLU()\n",
" self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)\n",
" self.bn2 = nn.BatchNorm2d(channels)\n",
"\n",
" def forward(self, x):\n",
" residual = self.conv1(x)\n",
" residual = self.bn1(residual)\n",
" residual = self.prelu(residual)\n",
" residual = self.conv2(residual)\n",
" residual = self.bn2(residual)\n",
"\n",
" return x + residual"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6ceKahB2kUSV"
},
"source": [
"class UpsampleBLock(nn.Module):\n",
" def __init__(self, in_channels, up_scale):\n",
" super(UpsampleBLock, self).__init__()\n",
" self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)\n",
" self.pixel_shuffle = nn.PixelShuffle(up_scale)\n",
" self.prelu = nn.PReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.pixel_shuffle(x)\n",
" x = self.prelu(x)\n",
" return x"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "E9sHfbpRkUSW"
},
"source": [
"class Generator(nn.Module):\n",
" def __init__(self, scale_factor):\n",
" upsample_block_num = int(math.log(scale_factor, 2))\n",
"\n",
" super(Generator, self).__init__()\n",
" self.block1 = nn.Sequential(\n",
" nn.Conv2d(3, 64, kernel_size=9, padding=4),\n",
" nn.PReLU()\n",
" )\n",
" self.block2 = ResidualBlock(64)\n",
" self.block3 = ResidualBlock(64)\n",
" self.block4 = ResidualBlock(64)\n",
" self.block5 = ResidualBlock(64)\n",
" self.block6 = ResidualBlock(64)\n",
" self.block7 = nn.Sequential(\n",
" nn.Conv2d(64, 64, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(64)\n",
" )\n",
" block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]\n",
" block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))\n",
" self.block8 = nn.Sequential(*block8)\n",
"\n",
" def forward(self, x):\n",
" block1 = self.block1(x)\n",
" block2 = self.block2(block1)\n",
" block3 = self.block3(block2)\n",
" block4 = self.block4(block3)\n",
" block5 = self.block5(block4)\n",
" block6 = self.block6(block5)\n",
" block7 = self.block7(block6)\n",
" block8 = self.block8(block1 + block7)\n",
"\n",
" return (torch.tanh(block8) + 1) / 2"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5v9c_NYNkUSW"
},
"source": [
"class Discriminator(nn.Module):\n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Conv2d(3, 64, kernel_size=3, padding=1),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
" nn.BatchNorm2d(64),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(128),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),\n",
" nn.BatchNorm2d(128),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(128, 256, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(256),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),\n",
" nn.BatchNorm2d(256),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(256, 512, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(0.2),\n",
"\n",
" nn.AdaptiveAvgPool2d(1),\n",
" nn.Conv2d(512, 1024, kernel_size=1),\n",
" nn.LeakyReLU(0.2),\n",
" nn.Conv2d(1024, 1, kernel_size=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size = x.size(0)\n",
" return torch.sigmoid(self.net(x).view(batch_size))"
],
"execution_count": null,
"outputs": []
},
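{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick shape check (an illustrative sketch, not part of the training pipeline): with `CROP_SIZE = 88` and `UPSCALE_FACTOR = 4`, the generator should map a 22×22 LR patch to an 88×88 image, and the discriminator should return one score per image."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sanity-check the network shapes on random data (no training involved).\n",
"g_check = Generator(UPSCALE_FACTOR)\n",
"d_check = Discriminator()\n",
"lr_patch = torch.randn(2, 3, CROP_SIZE // UPSCALE_FACTOR, CROP_SIZE // UPSCALE_FACTOR)\n",
"with torch.no_grad():\n",
"    sr_patch = g_check(lr_patch)  # expected shape: (2, 3, 88, 88)\n",
"    d_score = d_check(sr_patch)   # expected shape: (2,)\n",
"print(sr_patch.shape, d_score.shape)"
],
"execution_count": null,
"outputs": []
},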
{
"cell_type": "code",
"metadata": {
"id": "UbJRjjxBkUSW"
},
"source": [
"from torchvision.models.vgg import vgg16"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PBErzouckUSW"
},
"source": [
"### Loss"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1dBQrQBgkUSW"
},
"source": [
"class GeneratorLoss(nn.Module):\n",
" def __init__(self):\n",
" super(GeneratorLoss, self).__init__()\n",
" vgg = vgg16(pretrained=True)\n",
" loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()\n",
" for param in loss_network.parameters():\n",
" param.requires_grad = False\n",
" self.loss_network = loss_network\n",
" self.mse_loss = nn.MSELoss()\n",
" self.tv_loss = TVLoss()\n",
"\n",
" def forward(self, out_labels, out_images, target_images):\n",
" # Adversarial Loss\n",
" adversarial_loss = torch.mean(1 - out_labels)\n",
" # Perception Loss\n",
" perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))\n",
" # Image Loss\n",
" image_loss = self.mse_loss(out_images, target_images)\n",
" # TV Loss\n",
" tv_loss = self.tv_loss(out_images)\n",
" return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss\n"
],
"execution_count": null,
"outputs": []
},
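{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the weights used in `forward` above, the generator objective is\n",
"\n",
"$$l_G = l_{MSE} + 10^{-3}\\,l_{adv} + 6 \\times 10^{-3}\\,l_{VGG} + 2 \\times 10^{-8}\\,l_{TV},$$\n",
"\n",
"where $l_{adv} = 1 - D(G(z))$ averaged over the batch, $l_{VGG}$ is the MSE between VGG-16 feature maps of the super-resolved and target images, and $l_{TV}$ is the total-variation term defined in the next cell."
]
},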
{
"cell_type": "code",
"metadata": {
"id": "uGPc4qnWkUSW"
},
"source": [
"class TVLoss(nn.Module):\n",
" def __init__(self, tv_loss_weight=1):\n",
" super(TVLoss, self).__init__()\n",
" self.tv_loss_weight = tv_loss_weight\n",
"\n",
" def forward(self, x):\n",
" batch_size = x.size()[0]\n",
" h_x = x.size()[2]\n",
" w_x = x.size()[3]\n",
" count_h = self.tensor_size(x[:, :, 1:, :])\n",
" count_w = self.tensor_size(x[:, :, :, 1:])\n",
" h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()\n",
" w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()\n",
" return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size\n",
"\n",
" @staticmethod\n",
" def tensor_size(t):\n",
" return t.size()[1] * t.size()[2] * t.size()[3]"
],
"execution_count": null,
"outputs": []
},
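{
"cell_type": "markdown",
"metadata": {},
"source": [
"`TVLoss` is an anisotropic total-variation penalty. For a batch of $N$ images of shape $C \\times H \\times W$ it computes\n",
"\n",
"$$l_{TV}(x) = \\frac{2}{N}\\left(\\frac{\\sum_{c,i,j}(x_{c,i+1,j}-x_{c,i,j})^2}{C\\,(H-1)\\,W} + \\frac{\\sum_{c,i,j}(x_{c,i,j+1}-x_{c,i,j})^2}{C\\,H\\,(W-1)}\\right),$$\n",
"\n",
"which penalizes differences between neighbouring pixels and encourages spatially smooth super-resolved images."
]
},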
{
"cell_type": "code",
"metadata": {
"id": "duZ8UvogmJ8X"
},
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9QPi_X9_kUSW"
},
"source": [
"netG = Generator(UPSCALE_FACTOR)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6sQrfCZHkUSW"
},
"source": [
"netD = Discriminator()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PiiqnrvkkUSX",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 103,
"referenced_widgets": [
"5b14de92c3bb43f5b769eb8b8918d393",
"92a437f3d4d74fab94d4dfb34cad8cfc",
"6047442ef1464494a1775e852d98de6e",
"35baa95ba9434b2f94fbd8335df48a41",
"3b3fcd76032c4f4997e819e68f7ddef0",
"b201fa430ceb42e3bc438dc673856622",
"39c14f33da7a4d85bfcb187c0cacd68a",
"638747b422554adb8c64e087f52f4fd6"
]
},
"outputId": "596a3dca-fa3f-4329-b29c-85d632f9792b"
},
"source": [
"generator_criterion = GeneratorLoss()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b14de92c3bb43f5b769eb8b8918d393",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aQaIjG8UnunL"
},
"source": [
"generator_criterion = generator_criterion.to(device)\n",
"netG = netG.to(device)\n",
"netD = netD.to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YRPW-NqlkUSX"
},
"source": [
"optimizerG = optim.Adam(netG.parameters(), lr=0.0002)\n",
"optimizerD = optim.Adam(netD.parameters(), lr=0.0002)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yULcXdFokUSX"
},
"source": [
"results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Esuf4_xkUSX"
},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rg9EVhMGkUSX"
},
"source": [
"from tqdm import tqdm\n",
"import os"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDi9oyr-kUSX"
},
"source": [
"N_EPOCHS = 150"
],
"execution_count": null,
"outputs": []
},
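{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each epoch alternates the two updates described in the comments inside the loop below: the discriminator is trained to minimize $L_D = 1 - D(x) + D(G(z))$, and the generator is trained to minimize the combined `GeneratorLoss` $l_G$, whose adversarial term $1 - D(G(z))$ rewards super-resolved images that the discriminator scores as real."
]
},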
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dJSFZRzykUSX",
"outputId": "dc79719e-2477-4621-d365-c6d7c5fb99c2"
},
"source": [
"for epoch in range(1, N_EPOCHS + 1):\n",
" train_bar = tqdm(train_loader)\n",
" running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}\n",
"\n",
" netG.train()\n",
" netD.train()\n",
" for data, target in train_bar:\n",
" g_update_first = True\n",
" batch_size = data.size(0)\n",
" running_results['batch_sizes'] += batch_size\n",
" \n",
" real_img = Variable(target)\n",
" if torch.cuda.is_available():\n",
" real_img = real_img.cuda()\n",
" z = Variable(data)\n",
" if torch.cuda.is_available():\n",
" z = z.cuda()\n",
" \n",
" ############################\n",
" # (1) Update D network: maximize D(x)-1-D(G(z))\n",
" ###########################\n",
" fake_img = netG(z)\n",
"\n",
" netD.zero_grad()\n",
" real_out = netD(real_img).mean()\n",
" fake_out = netD(fake_img).mean()\n",
" d_loss = 1 - real_out + fake_out\n",
" d_loss.backward(retain_graph=True)\n",
" optimizerD.step()\n",
"\n",
" ############################\n",
" # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss\n",
" ###########################\n",
" ###### Was causing Runtime Error ######\n",
" fake_img = netG(z)\n",
" fake_out = netD(fake_img).mean()\n",
" #######################################\n",
" netG.zero_grad()\n",
" g_loss = generator_criterion(fake_out, fake_img, real_img)\n",
" g_loss.backward()\n",
"\n",
" fake_img = netG(z)\n",
" fake_out = netD(fake_img).mean()\n",
"\n",
" optimizerG.step()\n",
"\n",
" # loss for current batch before optimization \n",
" running_results['g_loss'] += g_loss.item() * batch_size\n",
" running_results['d_loss'] += d_loss.item() * batch_size\n",
" running_results['d_score'] += real_out.item() * batch_size\n",
" running_results['g_score'] += fake_out.item() * batch_size\n",
"\n",
" train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (\n",
" epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],\n",
" running_results['g_loss'] / running_results['batch_sizes'],\n",
" running_results['d_score'] / running_results['batch_sizes'],\n",
" running_results['g_score'] / running_results['batch_sizes']))\n",
"\n",
" netG.eval()\n",
" out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'\n",
" if not os.path.exists(out_path):\n",
" os.makedirs(out_path)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[1/150] Loss_D: 0.8393 Loss_G: 0.0452 D(x): 0.5709 D(G(z)): 0.3487: 100%|██████████| 13/13 [01:01<00:00, 4.72s/it]\n",
"[2/150] Loss_D: 0.7786 Loss_G: 0.0193 D(x): 0.5395 D(G(z)): 0.3129: 100%|██████████| 13/13 [01:00<00:00, 4.66s/it]\n",
"[3/150] Loss_D: 0.3920 Loss_G: 0.0157 D(x): 0.7983 D(G(z)): 0.1449: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n",
"[4/150] Loss_D: 0.1291 Loss_G: 0.0145 D(x): 0.9329 D(G(z)): 0.0532: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n",
"[5/150] Loss_D: 0.0426 Loss_G: 0.0136 D(x): 0.9778 D(G(z)): 0.0177: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n",
"[6/150] Loss_D: 0.0277 Loss_G: 0.0131 D(x): 0.9834 D(G(z)): 0.0386: 100%|██████████| 13/13 [01:00<00:00, 4.64s/it]\n",
"[7/150] Loss_D: 0.8722 Loss_G: 0.0119 D(x): 0.5780 D(G(z)): 0.3579: 100%|██████████| 13/13 [00:59<00:00, 4.59s/it]\n",
"[8/150] Loss_D: 0.7058 Loss_G: 0.0104 D(x): 0.5374 D(G(z)): 0.2106: 100%|██████████| 13/13 [00:59<00:00, 4.60s/it]\n",
"[9/150] Loss_D: 0.4493 Loss_G: 0.0108 D(x): 0.7283 D(G(z)): 0.1293: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n",
"[10/150] Loss_D: 0.1850 Loss_G: 0.0102 D(x): 0.9002 D(G(z)): 0.0621: 100%|██████████| 13/13 [00:59<00:00, 4.58s/it]\n",
" 0%| | 0/13 [00:00<?, ?it/s]"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YTPIWjk2NC5P"
},
"source": [
"from torchvision.transforms import ToTensor, ToPILImage"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "O9yt2ExXLxjO"
},
"source": [
"from PIL import Image"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Qz8w5AtxcuRi"
},
"source": [
"torch.save(netG.state_dict(), \"super_res_gen.pth\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4CIfaGsmQVwg",
"outputId": "eb6874a0-3589-400f-c571-9cc9f2d32d28"
},
"source": [
"netG.load_state_dict(torch.load(\"super_res_gen.pth\")) # If you already have a pretrained weights file for this model."
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8LI5BqU5NIWu",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 165
},
"outputId": "97eef7b0-b8f0-48d5-b532-c4afc52a4b7d"
},
"source": [
"# LOAD your OWN image here. This cell wont work unless you upload your own image to Colab\n",
"image = Image.open(\"table2.jpg\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-369897577819>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mimage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"table2.jpg\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'Image' is not defined"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TQjCZ7CBkUSX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f4b8339f-09b8-4c82-8886-4af5ef4530c8"
},
"source": [
"image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VpCwx3FXRXYO"
},
"source": [
"netG = netG.to(torch.device(\"cuda\"))\n",
"image = image.to(torch.device(\"cuda\"))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MehEf581NQTH"
},
"source": [
"out = netG(image)\n",
"out_img = ToPILImage()(out[0].data.cpu())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wrPdGj02NVRi"
},
"source": [
"out_img.save(\"table2-superres-4x.jpg\")"
],
"execution_count": null,
"outputs": []
}
]
}
@SharanSMenon (Author)

Add this to the beginning of the file:

from PIL import Image
