Skip to content

Instantly share code, notes, and snippets.

Last active May 11, 2019 15:28
Show Gist options
  • Save ata4/619b28422d288605685200a8c0edfd6b to your computer and use it in GitHub Desktop.
Save ata4/619b28422d288605685200a8c0edfd6b to your computer and use it in GitHub Desktop.
# ESRGAN launcher with tiling support
import sys
import os
import glob
import math
import argparse
import cv2
import numpy as np
import torch
import architecture as arch
class ESRGAN:
    """ESRGAN image upscaler that can process large images tile by tile.

    Images larger than ``tile_size`` in either dimension are cut into square
    tiles, each tile is upscaled on its own, and the results are pasted into
    a pre-allocated output image.
    """

    def __init__(self, model_path, device, scale_factor=4, tile_size=256):
        """Build the RRDB network, load its weights and move it to *device*.

        Args:
            model_path: Path to the saved state dict for ``arch.RRDB_Net``.
            device: ``torch.device`` the model and the input tensors are
                moved to.
            scale_factor: Upscale factor passed to the network; also used to
                size the output image.
            tile_size: Edge length of the input tiles in pixels. A value of
                0 (or less) disables tiling.
        """
        self.scale_factor = scale_factor
        self.tile_size = tile_size
        self.device = device

        model = arch.RRDB_Net(3, 3, 64, 23, upscale=self.scale_factor)
        # map_location lets weights that were saved on a GPU be loaded on a
        # machine that only has a CPU.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
        model.eval()
        # inference only: no gradients are needed for any parameter
        for _, v in model.named_parameters():
            v.requires_grad = False
        self.model = model.to(device)

    def upscale(self, img):
        """Run the network on a single image (or tile).

        Args:
            img: ``H x W x 3`` array with values in 0..255. The channel
                order is reversed before and after the network call, i.e.
                a BGR image from ``cv2.imread`` is fed to the model as RGB
                and returned as BGR again.

        Returns:
            ``(H * s) x (W * s) x 3`` float array (``s`` being the model's
            upscale factor) with rounded values in the range 0..255.
        """
        img = img * 1.0 / 255
        # reverse channel order and reorder HWC -> CHW
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_LR = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(img_LR)
        # squeeze(0) removes only the batch axis; a bare squeeze() would
        # also drop any spatial axis that happens to have length 1.
        output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # reverse channel order again and reorder CHW -> HWC
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        return (output * 255.0).round()

    def _upscale_image(self, image):
        """Upscale a whole image array, tile by tile if it exceeds the tile size.

        Args:
            image: ``H x W x C`` array as returned by ``cv2.imread``.

        Returns:
            ``uint8`` array of shape ``(H * scale, W * scale, C)``.
        """
        # numpy image arrays are indexed (row, column, channel)
        rows, cols, depth = image.shape
        tile = self.tile_size

        # tiling disabled, or the image fits into a single tile
        if tile <= 0 or (rows <= tile and cols <= tile):
            return self.upscale(image).astype(np.uint8)

        scale = self.scale_factor
        # pre-allocate upscaled output image
        output = np.zeros((rows * scale, cols * scale, depth), np.uint8)
        tiles_x = math.ceil(cols / tile)
        tiles_y = math.ceil(rows / tile)
        tile_count = tiles_x * tiles_y

        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image; edge tiles may be smaller
                start_x = x * tile
                end_x = min(start_x + tile, cols)
                start_y = y * tile
                end_y = min(start_y + tile, rows)

                tile_idx = y * tiles_x + x + 1
                print('Tile %d/%d (x=%d y=%d %dx%d)' % (tile_idx, tile_count, x, y, end_x - start_x, end_y - start_y), flush=True)

                # upscale tile
                output_tile = self.upscale(image[start_y:end_y, start_x:end_x])

                # put tile into output image
                output[start_y * scale:end_y * scale, start_x * scale:end_x * scale] = output_tile

        return output

    def process(self, input_path, output_path):
        """Upscale the image file at *input_path* and write it to *output_path*.

        Args:
            input_path: Path of the image to read.
            output_path: Path the upscaled image is written to.

        Raises:
            OSError: If the input image cannot be read or the output image
                cannot be written.
        """
        # read input image; cv2.imread signals failure by returning None
        image = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if image is None:
            raise OSError("Unable to read image '%s'" % input_path)

        # cv2.imwrite signals failure by returning False
        if not cv2.imwrite(output_path, self._upscale_image(image)):
            raise OSError("Unable to write image '%s'" % output_path)
def main():
    """Parse the command line and upscale every input image.

    The ``input`` argument may be a folder, in which case every file
    directly inside it is processed, or a glob pattern. Each result is
    written to the output folder as ``<name>_esrgan.png``.
    """
    parser = argparse.ArgumentParser(description='ESRGAN image upscaler with tiling support')
    parser.add_argument('input', help='Path to input folder')
    parser.add_argument('output', help='Path to output folder')
    parser.add_argument('model', help='Path to model file')
    parser.add_argument('--tilesize', type=int, metavar='N', default=256, help='size of tiles in pixels (0 = don\'t use tiles)')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU/CUDA (very slow!)')
    args = parser.parse_args()

    # honour --cpu; CUDA is the default otherwise
    device = torch.device('cpu' if args.cpu else 'cuda')

    input_folder = args.input
    output_folder = args.output
    model_path = args.model

    # Globbing a bare folder path only matches the folder itself, so expand
    # it to the folder's contents; anything else is used as a glob pattern.
    if os.path.isdir(input_folder):
        pattern = os.path.join(input_folder, '*')
    else:
        pattern = input_folder

    # make sure the output folder exists before any image is written
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    print("Initializing ESRGAN using model '%s'" % os.path.basename(model_path), flush=True)
    esrgan = ESRGAN(model_path, device, tile_size=args.tilesize)

    for input_path in sorted(glob.glob(pattern)):
        # sub-folders cannot be read as images
        if not os.path.isfile(input_path):
            continue
        input_name = os.path.basename(input_path)
        print('Upscaling', input_name, flush=True)
        input_name = os.path.splitext(input_name)[0]
        output_path = os.path.join(output_folder, input_name + '_esrgan.png')
        esrgan.process(input_path, output_path)
# run the command line interface only when executed as a script
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment