Last active
September 4, 2023 07:32
-
-
Save daisycamber/a4f38402199c1adf7955571d6cb6eaa9 to your computer and use it in GitHub Desktop.
Super-resolution with options (Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
from shell.execute import run_command | |
from django.conf import settings | |
import numpy as np | |
import os, sys | |
from importlib import reload | |
from cv2 import dnn_superres | |
from importlib import import_module | |
from threading import Thread | |
import traceback | |
import math | |
# Number of tiles per axis: superres() cuts the image into (up to)
# SLICES x SLICES tiles and upsamples each tile separately.
SLICES = 20
# Intended number of tiles superres() upsamples concurrently.
SIMULTANEOUS_THREADS = 1
def _fix_permissions(image_path):
    """Grant read/write to all users on *image_path* and chown it to love:users.

    Both commands go through ``run_command``. The path is passed through
    ``shlex.quote``: plain paths come back unchanged, while paths containing
    spaces or shell metacharacters are quoted. (This assumes ``run_command``
    hands the string to a shell -- TODO confirm.)
    """
    import shlex

    quoted = shlex.quote(str(image_path))
    run_command('sudo chmod a+rwX ' + quoted)
    # `love:users` is an owner:group spec, which only `chown` accepts;
    # `chmod` rejects it as an invalid mode.
    run_command('sudo chown love:users ' + quoted)


def _tile_bounds(length, slices):
    """Return ``slices + 1`` cut points splitting ``[0, length]`` evenly.

    Consecutive cut points differ by at most one and the last one equals
    *length*, so the tiles cover every pixel (no remainder is dropped).
    """
    return [i * length // slices for i in range(slices + 1)]


def superres(image_path, model, mode, size, slices=None, workers=None,
             max_attempts=3):
    """Upscale the image at *image_path* with an OpenCV DNN model, in place.

    The image is cut into a grid of tiles. Each tile is upsampled
    independently, and the upsampled tiles are stitched back together and
    written over the original file.

    Args:
        image_path: Path of the image to read and overwrite.
        model: Model file path, relative to ``settings.BASE_DIR``.
        mode: Model name passed to ``setModel`` (e.g. ``'edsr'``).
        size: Upscale factor passed to ``setModel``.
        slices: Tiles per axis; defaults to the module-level ``SLICES``.
            Clamped per axis so no tile is empty on very small images.
        workers: Number of tiles upsampled concurrently; defaults to
            ``SIMULTANEOUS_THREADS``.
        max_attempts: How many times a failing tile is tried before its
            error is raised.

    Raises:
        FileNotFoundError: If OpenCV cannot read *image_path*.
        OSError: If OpenCV cannot write the result back.
        Exception: Whatever the model raises on a tile's final attempt.
    """
    import threading
    from concurrent.futures import ThreadPoolExecutor

    if slices is None:
        slices = SLICES
    if workers is None:
        workers = SIMULTANEOUS_THREADS
    attempts = max(1, max_attempts)

    _fix_permissions(image_path)
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError('could not read image: {}'.format(image_path))

    height, width = image.shape[:2]
    # Never cut more tiles than there are pixels, or some tiles would be empty.
    ys = _tile_bounds(height, max(1, min(slices, height)))
    xs = _tile_bounds(width, max(1, min(slices, width)))
    rows, cols = len(ys) - 1, len(xs) - 1

    model_path = os.path.join(settings.BASE_DIR, model)
    per_thread = threading.local()

    def load_model():
        # Each worker thread loads the network once and reuses it; model
        # instances are never shared between threads.
        if getattr(per_thread, 'sr', None) is None:
            sr = cv2.dnn_superres.DnnSuperResImpl_create()
            sr.readModel(model_path)
            sr.setModel(mode, size)
            per_thread.sr = sr
        return per_thread.sr

    def upsample_tile(cell):
        y, x = cell
        tile = np.ascontiguousarray(image[ys[y]:ys[y + 1], xs[x]:xs[x + 1]])
        for attempt in range(1, attempts + 1):
            print('superres slice {},{}'.format(x, y))
            try:
                return load_model().upsample(tile)
            except Exception:
                print(traceback.format_exc())
                per_thread.sr = None  # retry with a freshly loaded model
                if attempt == attempts:
                    raise

    cells = [(y, x) for y in range(rows) for x in range(cols)]
    # The pool both caps concurrency and joins every worker before exiting.
    with ThreadPoolExecutor(max_workers=max(1, workers)) as pool:
        tiles = list(pool.map(upsample_tile, cells))

    # Tiles in one row share a height and tiles in one column share a width,
    # so the grid reassembles with plain horizontal/vertical stacking.
    output = np.vstack([np.hstack(tiles[r * cols:(r + 1) * cols])
                        for r in range(rows)])
    if not cv2.imwrite(str(image_path), output.astype(np.uint8, copy=False)):
        raise OSError('could not write image: {}'.format(image_path))
def superres_x8(image_path):
    """Upscale the image at *image_path* 8x in place with the LapSRN model."""
    superres(image_path, model="enhance/LapSRN_x8.pb", mode='lapsrn', size=8)
def superres_x4(image_path):
    """Upscale the image at *image_path* 4x in place with the EDSR model."""
    superres(image_path, model="enhance/EDSR_x4.pb", mode='edsr', size=4)
def superres_x2(image_path):
    """Upscale the image at *image_path* 2x in place with the EDSR model."""
    superres(image_path, model="enhance/EDSR_x2.pb", mode='edsr', size=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment