Created
September 13, 2020 13:13
-
-
Save pyk/d11f19c89ecb3b15796ea790aa0d9608 to your computer and use it in GitHub Desktop.
Apply pre-trained EDVR model without the ground truth
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VideoTestDatasetWithoutGT(data.Dataset):
    """Video test dataset that provides only low-quality (LQ) frames.

    Use this to run a pre-trained video restoration model (e.g. EDVR) on
    clips that have no ground-truth frames. The expected layout is::

        dataroot
        ├── subfolder1
        │   ├── frame000
        │   ├── frame001
        │   ├── ...
        ├── subfolder2
        │   ├── frame000
        │   ├── frame001
        │   ├── ...
        ├── ...

    Each subfolder is one clip. Frames inside a clip are ordered by sorted
    path. Every frame of every clip is one sample. For testing datasets,
    there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following
            keys:
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs. The type
                must not be 'lmdb'.
            cache_data (bool): Whether to read every clip with
                ``util.read_img_seq`` up front instead of on each access.
            name (str): Dataset name (used in log messages).
            meta_info_file (str, optional): Path to a file listing the test
                folders, one per line. Only the first whitespace-separated
                token of each line is used, and blank lines are ignored. If
                missing or empty, every sub-directory of ``dataroot_lq`` is
                used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode passed to
                ``util.generate_frame_indices``.
    """

    def __init__(self, opt):
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.lq_root = opt['dataroot_lq']
        # Parallel per-sample lists: entry i of every list describes sample i.
        self.data_info = {
            'lq_path': [],
            'folder': [],
            'idx': [],
            'border': [],
        }
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt[
            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(
            f'Generate data info for VideoTestDatasetWithoutGT - {opt["name"]}')

        # folder name -> cached frame sequence, or list of frame paths.
        self.imgs_lq = {}
        meta_info_file = opt.get('meta_info_file')
        if meta_info_file:
            subfolders_lq = [
                osp.join(self.lq_root, key)
                for key in self._read_subfolder_names(meta_info_file)
            ]
        else:
            # Only directories are clips; stray files in the root are skipped.
            subfolders_lq = [
                path
                for path in sorted(glob.glob(osp.join(self.lq_root, '*')))
                if osp.isdir(path)
            ]

        num_frame = self.opt['num_frame']
        for subfolder_lq in subfolders_lq:
            # normpath so a trailing separator does not yield an empty name
            subfolder_name = osp.basename(osp.normpath(subfolder_lq))
            # get frame list for lq
            img_paths_lq = sorted(
                osp.join(subfolder_lq, v) for v in mmcv.scandir(subfolder_lq))
            max_idx = len(img_paths_lq)
            if max_idx == 0:
                # An empty clip contributes no samples; nothing to index.
                logger.warning(f'Skip empty folder {subfolder_lq}.')
                continue

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            self.data_info['idx'].extend(
                f'{i}/{max_idx}' for i in range(max_idx))
            self.data_info['border'].extend(
                self._border_flags(max_idx, num_frame))

            # cache data or save the frame list
            if self.cache_data:
                logger.info(
                    f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = util.read_img_seq(
                    img_paths_lq)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq

    @staticmethod
    def _read_subfolder_names(meta_info_file):
        """Return the folder names listed in ``meta_info_file``.

        Keep the first whitespace-separated token of each non-blank line.
        Both ``000 100 (720,1280,3)`` and a bare ``000`` therefore yield
        ``'000'``, with no trailing newline.
        """
        with open(meta_info_file, 'r') as fin:
            return [line.split()[0] for line in fin if line.strip()]

    @staticmethod
    def _border_flags(num_imgs, num_frame):
        """Return a list of 0/1 border flags for a clip of ``num_imgs`` frames.

        Flag the first and last ``num_frame // 2`` frames with 1 and all
        others with 0. The half window is clamped to the clip length, so
        clips shorter than the window do not raise ``IndexError``.
        """
        border = [0] * num_imgs
        for i in range(min(num_frame // 2, num_imgs)):
            border[i] = 1
            border[num_imgs - i - 1] = 1
        return border

    def __getitem__(self, index):
        """Return the LQ input window for sample ``index``.

        Returns:
            dict: with keys
                lq: the ``num_frame`` frames at the indices returned by
                    ``util.generate_frame_indices``. These come from the
                    cache via ``index_select`` or are read with
                    ``util.read_img_seq``.
                folder: clip folder name.
                idx: ``'i/total'`` string for this frame.
                border: 1 for border frames, 0 otherwise.
                lq_path: path of this sample's own frame.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = util.generate_frame_indices(
            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(
                0, torch.LongTensor(select_idx))
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = util.read_img_seq(img_paths_lq)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        """Return the total number of frames (samples) across all clips."""
        return len(self.data_info['lq_path'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment