Skip to content

Instantly share code, notes, and snippets.

@ata4 ata4/process.py
Last active May 11, 2019

Embed
What would you like to do?
ESRGAN launcher with tiling support
#!/usr/bin/python3
import sys
import os
import glob
import math
import argparse
import cv2
import numpy as np
import torch
import architecture as arch
class ESRGAN:
    """Upscale images with an ESRGAN RRDB network, optionally tile by tile.

    Images larger than ``tile_size`` in either dimension are split into
    ``tile_size`` x ``tile_size`` tiles that are upscaled independently, so each
    forward pass only sees a bounded-size input. Tiles are not overlapped or
    blended, so tile borders may be visible in the output.
    """

    def __init__(self, model_path, device, scale_factor=4, tile_size=256):
        """Load the network weights and move the model to ``device``.

        Args:
            model_path: path to the checkpoint (a state dict) to load.
            device: ``torch.device`` to run inference on.
            scale_factor: upscale factor the network is built with; presumably
                it must match the checkpoint -- TODO confirm against the model.
            tile_size: tile edge length in pixels; ``<= 0`` disables tiling.
        """
        self.scale_factor = scale_factor
        self.tile_size = tile_size
        # NOTE(review): positional args presumably mean (in channels, out
        # channels, features, RRDB blocks); confirm in architecture.py.
        model = arch.RRDB_Net(3, 3, 64, 23, upscale=self.scale_factor)
        # map_location lets a checkpoint saved with CUDA tensors load on a
        # CPU-only machine (the --cpu path). torch.load unpickles the file, so
        # only load model files from trusted sources.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        self.model = model.to(device)
        self.device = device

    def upscale(self, img):
        """Run the network once on a whole BGR image.

        Args:
            img: H x W x 3 array in OpenCV BGR channel order, values in 0..255.

        Returns:
            Float array in HWC BGR order, values rounded to integers in 0..255.
            Its height and width are those of the model output.
        """
        # Normalize to 0..1, reorder BGR -> RGB and HWC -> CHW, add a batch axis.
        img = img * 1.0 / 255
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_lr = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(img_lr)
        # squeeze(0) removes only the batch axis, never a size-1 spatial axis.
        output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # Back to HWC and RGB -> BGR for OpenCV.
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        return (output * 255.0).round()

    def upscale_image(self, img):
        """Upscale an in-memory BGR image, tiling it when it exceeds ``tile_size``.

        Args:
            img: H x W x 3 array in OpenCV BGR order, values in 0..255.

        Returns:
            uint8 array of shape (H * scale_factor, W * scale_factor, 3).
        """
        height, width, depth = img.shape  # numpy images are (rows, cols, channels)
        tile = self.tile_size
        # tile_size <= 0 means "don't use tiles"; small images fit in one pass.
        if tile <= 0 or (height <= tile and width <= tile):
            return self.upscale(img).astype(np.uint8)

        scale = self.scale_factor
        output = np.zeros((height * scale, width * scale, depth), np.uint8)
        tiles_x = math.ceil(width / tile)
        tiles_y = math.ceil(height / tile)
        for ty in range(tiles_y):
            for tx in range(tiles_x):
                # Input tile bounds; edge tiles are clipped to the image.
                x0 = tx * tile
                x1 = min(x0 + tile, width)
                y0 = ty * tile
                y1 = min(y0 + tile, height)
                tile_idx = ty * tiles_x + tx + 1
                print('Tile %d/%d (x=%d y=%d %dx%d)' % (tile_idx, tiles_x * tiles_y, tx, ty, x1 - x0, y1 - y0), flush=True)
                output_tile = self.upscale(img[y0:y1, x0:x1])
                # The upscaled tile lands at the scaled position of its input.
                output[y0 * scale:y1 * scale, x0 * scale:x1 * scale] = output_tile
        return output

    def process(self, input_path, output_path):
        """Read ``input_path``, upscale it and write the result to ``output_path``.

        Raises:
            ValueError: if ``input_path`` cannot be read/decoded as an image.
            OSError: if the output image could not be written.
        """
        img = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports failure (missing file, unsupported format) by
            # returning None instead of raising.
            raise ValueError("Could not read image '%s'" % input_path)
        output = self.upscale_image(img)
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(output_path, output):
            raise OSError("Could not write image '%s'" % output_path)
def main():
    """Command-line entry point: upscale every image selected by the input argument.

    Each result is saved as ``<input stem>_esrgan.png`` in the output folder.

    Returns:
        0 when every input was upscaled, 1 when no input matched or some inputs
        could not be processed (passed to ``sys.exit`` by the caller).
    """
    parser = argparse.ArgumentParser(description='ESRGAN image upscaler with tiling support')
    parser.add_argument('input', help='Path to input folder (or a glob pattern)')
    parser.add_argument('output', help='Path to output folder')
    parser.add_argument('model', help='Path to model file')
    parser.add_argument('--tilesize', type=int, metavar='N', default=256, help='size of tiles in pixels (0 = don\'t use tiles)')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU/CUDA (very slow!)')
    args = parser.parse_args()

    if args.cpu:
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        # Fail early with an actionable message (parser.error exits).
        parser.error('CUDA is not available; rerun with --cpu')

    # The argument has always been handed to glob, so patterns like 'in/*.jpg'
    # keep working; a bare folder now expands to every file inside it.
    pattern = args.input
    if os.path.isdir(pattern):
        pattern = os.path.join(pattern, '*')
    # Sorted for a deterministic processing order; directories are skipped.
    input_paths = sorted(path for path in glob.glob(pattern) if os.path.isfile(path))
    if not input_paths:
        print("No input files found for '%s'" % args.input, file=sys.stderr, flush=True)
        return 1

    output_folder = args.output
    if output_folder:
        # cv2.imwrite does not create missing directories; it just returns False.
        os.makedirs(output_folder, exist_ok=True)

    model_path = args.model
    print("Initializing ESRGAN using model '%s'" % os.path.basename(model_path), flush=True)
    esrgan = ESRGAN(model_path, device, tile_size=args.tilesize)

    failures = 0
    for input_path in input_paths:
        input_name = os.path.basename(input_path)
        print('Upscaling', input_name, flush=True)
        output_name = os.path.splitext(input_name)[0] + '_esrgan.png'
        output_path = os.path.join(output_folder, output_name)
        try:
            esrgan.process(input_path, output_path)
        except (ValueError, OSError) as err:
            # One unreadable/unwritable file should not abort the whole batch.
            print('Skipping %s: %s' % (input_name, err), file=sys.stderr, flush=True)
            failures += 1
    return 1 if failures else 0
if __name__ == '__main__':
    # sys.exit is always available; the bare exit() helper is installed by the
    # site module and is missing under `python -S` or in frozen executables.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.