Skip to content

Instantly share code, notes, and snippets.

@Laurawly
Created February 23, 2021 22:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Laurawly/b7cdba5a6323390d84d9c9069934f1b1 to your computer and use it in GitHub Desktop.
Save Laurawly/b7cdba5a6323390d84d9c9069934f1b1 to your computer and use it in GitHub Desktop.
#from dlib_alignment import dlib_detect_face, face_recover
import torch
from PIL import Image
import torchvision.transforms as transforms
from models.SRGAN_model import SRGANModel
import numpy as np
import argparse
#import utils
import cv2
import random
import dlib
import time
# Per-image preprocessing applied before inference:
#   ToTensor  -> converts the PIL image to a CHW float tensor scaled to [0, 1]
#   Normalize -> computes (x - 0.5) / 0.5 per channel, mapping [0, 1] to [-1, 1]
_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def get_FaceSR_opt(argv=None):
    """Build the option namespace for the face super-resolution (SRGAN) model.

    Args:
        argv: Optional list of command-line tokens to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets callers and tests build options without
            depending on the real command line.

    Returns:
        argparse.Namespace holding the run, generator-network and checkpoint
        options declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_ids', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--scale', type=int, default=2)
    # NOTE(review): lr_size * scale = 256, which differs from hr_size = 512;
    # the checkpoint name mentions "x2-256face" -- confirm which size is intended.
    parser.add_argument('--lr_size', type=int, default=128)
    parser.add_argument('--hr_size', type=int, default=512)
    parser.add_argument('--kernel_size', type=int, default=15)
    # network G
    parser.add_argument('--which_model_G', type=str, default='RRDBNet')
    parser.add_argument('--G_in_nc', type=int, default=3)
    parser.add_argument('--out_nc', type=int, default=3)
    parser.add_argument('--G_nf', type=int, default=32)
    parser.add_argument('--nb', type=int, default=1)
    parser.add_argument('--gc', type=int, default=16)
    # pretrained generator checkpoint
    parser.add_argument('--pretrain_model_G', type=str, default='ESRGAN-x2-256face_nb1nf32gc16_100w.pth')
    return parser.parse_args(argv)
# --- Load and preprocess the benchmark input image ---------------------------
img_path = '00003799.jpg'
img = cv2.imread(img_path)
if img is None:
    # cv2.imread reports failure by returning None instead of raising; fail
    # here with a clear message rather than inside cv2.resize.
    raise FileNotFoundError(f'could not read image: {img_path}')
img = cv2.resize(img, (1280, 720))  # dsize is (width, height)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR order

# --- Build the model (inference mode) -----------------------------------------
# Parse the command line once and reuse the result for both the model and the
# batch size (it was previously parsed twice).
opt = get_FaceSR_opt()
sr_model = SRGANModel(opt=opt, is_train=False)
sr_model.load()
batch_size = opt.batch_size
print(batch_size)

# Add a leading batch dimension, then tile the single image batch_size times.
input_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
inputs = input_img.repeat(batch_size, 1, 1, 1)
print(inputs.size())
# Presumably test() runs the generator on var_L -- TODO confirm in SRGANModel.
sr_model.var_L = inputs.to(sr_model.device)

NUM_WARMUP = 20
NUM_ITERS = 100
# torch.cuda.synchronize() fails on machines without CUDA, so only call it
# when CUDA is actually available.
use_cuda = torch.cuda.is_available()

# --- Warm-up so one-time start-up costs are not counted in the timing ---------
for _ in range(NUM_WARMUP):
    img_out = sr_model.test()

# --- Timed runs -----------------------------------------------------------------
# CUDA work is queued asynchronously, so drain the queue before starting and
# before stopping the clock; perf_counter is the high-resolution timer.
if use_cuda:
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_ITERS):
    sr_model.test()
if use_cuda:
    torch.cuda.synchronize()
end = time.perf_counter()
elapsed_time = end - start

# NOTE(review): nothing is written to disk here despite the original "#save"
# comment. img_out comes from the last warm-up call, and this line assumes
# test() returns a tensor -- TODO confirm.
output_img = img_out.squeeze(0).cpu().numpy()
# Average wall-clock seconds per test() call (each call is given the whole batch).
print('all ', elapsed_time / NUM_ITERS)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment