Skip to content

Instantly share code, notes, and snippets.

Created September 13, 2020 13:13
Show Gist options
  • Save pyk/d11f19c89ecb3b15796ea790aa0d9608 to your computer and use it in GitHub Desktop.
Save pyk/d11f19c89ecb3b15796ea790aa0d9608 to your computer and use it in GitHub Desktop.
Apply a pre-trained EDVR model without the ground truth
class VideoTestDatasetWithoutGT(data.Dataset):
    """Video test dataset that yields only low-quality (LQ) input frames.

    Use it to run a model on videos for which no ground-truth frames exist;
    the returned samples therefore have no ``'gt'`` entry.

    It supports testing datasets with the following structure::

        dataroot_lq
        ├── subfolder1
        │   ├── frame000
        │   ├── frame001
        │   ├── ...
        ├── subfolder2
        │   ├── frame000
        │   ├── frame001
        │   ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following
            keys:
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs. The
                ``'lmdb'`` type is rejected.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name (used in log messages).
            meta_info_file (str): The path to the file storing the list of
                test folders; the first whitespace-separated token of each
                non-blank line is used as the folder name. If not provided,
                all the folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDatasetWithoutGT, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.lq_root = opt['dataroot_lq']
        # Flat per-frame index over all folders; entry k of every list
        # describes dataset item k (read by __getitem__ / __len__).
        self.data_info = {
            'lq_path': [],
            'folder': [],
            'idx': [],
            'border': []
        }
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt[
            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(
            f'Generate data info for VideoTestDatasetWithoutGT - {opt["name"]}')

        # folder name -> cached LQ frames (cache_data) or list of frame paths.
        self.imgs_lq = {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = self._parse_meta_info(fin)
            subfolders_lq = [
                osp.join(self.lq_root, key) for key in subfolders
            ]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))

        # Iterate the folder paths directly: zip() over a single list would
        # yield 1-tuples, which osp.basename cannot handle.
        for subfolder_lq in subfolders_lq:
            # get frame list for lq
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted([
                osp.join(subfolder_lq, v)
                for v in mmcv.scandir(subfolder_lq)
            ])
            max_idx = len(img_paths_lq)

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            # Stored as 'i/max_idx' strings; __getitem__ splits them back.
            self.data_info['idx'].extend(
                f'{i}/{max_idx}' for i in range(max_idx))
            self.data_info['border'].extend(
                self._border_flags(max_idx, self.opt['num_frame']))

            # cache data or save the frame list
            if self.cache_data:
                logger.info(
                    f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = util.read_img_seq(
                    img_paths_lq)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq

    @staticmethod
    def _parse_meta_info(lines):
        """Return the first whitespace-separated token of each non-blank line.

        Using ``split()`` (not ``split(' ')``) avoids keeping a trailing
        newline in the folder name when a line holds only the name.
        """
        return [line.split()[0] for line in lines if line.strip()]

    @staticmethod
    def _border_flags(max_idx, num_frame):
        """Return ``max_idx`` flags: 1 for the first and last
        ``num_frame // 2`` frames, 0 otherwise.

        The half-window is clamped to ``max_idx`` so clips shorter than the
        half-window are marked entirely as border instead of raising
        IndexError.
        """
        border = [0] * max_idx
        for i in range(min(num_frame // 2, max_idx)):
            border[i] = 1
            border[max_idx - i - 1] = 1
        return border

    def __getitem__(self, index):
        """Return the LQ frame window for dataset item ``index``.

        The neighbouring frame indices come from
        ``util.generate_frame_indices`` using ``opt['num_frame']`` and
        ``opt['padding']``.

        Returns:
            dict: ``'lq'`` (stacked LQ frames), ``'folder'``, ``'idx'``
            (``'i/max_idx'`` string), ``'border'`` (1 for border, 0
            otherwise) and ``'lq_path'`` (path of frame ``index``).
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = util.generate_frame_indices(
            idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(
                0, torch.LongTensor(select_idx))
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = util.read_img_seq(img_paths_lq)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        """Return the total number of LQ frames across all folders."""
        return len(self.data_info['lq_path'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment