Commit

add RVRT training codes
JingyunLiang committed Jun 13, 2022
1 parent 06bd194 commit d1c0d9d
Showing 19 changed files with 4,045 additions and 99 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -6,6 +6,8 @@
*[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland*

_______
- **_News (2022-06-01)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.

- **_News (2022-05-05)_**: Try the [online demo](https://replicate.com/cszn/scunet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.

- **_News (2022-03-23)_**: We release [the testing codes](https://github.com/cszn/SCUNet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.
103 changes: 7 additions & 96 deletions data/dataset_video_test.py
@@ -129,7 +129,7 @@ def __len__(self):


class SingleVideoRecurrentTestDataset(data.Dataset):
"""Single ideo test dataset for recurrent architectures, which takes LR video
"""Single video test dataset for recurrent architectures, which takes LR video
frames as input and output corresponding HR video frames (only input LQ path).
More generally, it supports testing dataset with following structures:
@@ -245,11 +245,12 @@ def __init__(self, opt):
super(VideoTestVimeo90KDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
temporal_scale = opt.get('temporal_scale', 1)
if self.cache_data:
raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: temporal_scale]

with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
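As a rough illustration of the new `temporal_scale` option in the hunk above: it subsamples the 7-frame Vimeo-90K neighbor list, so a stride of 2 keeps every other frame. A minimal sketch in plain Python, with `num_frame=7` and `temporal_scale=2` chosen purely as example values:

```python
# Minimal sketch of the neighbor-list computation shown in the diff above.
# num_frame and temporal_scale are example values, not repository defaults.
num_frame = 7
temporal_scale = 2

# Center a num_frame-long window inside the 7-frame Vimeo-90K clip
# (frames im1.png ... im7.png), then keep every temporal_scale-th index.
neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)][::temporal_scale]

print(neighbor_list)  # [1, 3, 5, 7]
```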
@@ -263,6 +264,7 @@ def __init__(self, opt):
self.data_info['border'].append(0)

self.pad_sequence = opt.get('pad_sequence', False)
self.mirror_sequence = opt.get('mirror_sequence', False)

def __getitem__(self, index):
lq_path = self.data_info['lq_path'][index]
@@ -274,6 +276,9 @@ def __getitem__(self, index):
if self.pad_sequence: # pad the sequence: 7 frames to 8 frames
imgs_lq = torch.cat([imgs_lq, imgs_lq[-1:,...]], dim=0)

if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames
imgs_lq = torch.cat([imgs_lq, imgs_lq.flip(0)], dim=0)

return {
'L': imgs_lq, # (t, c, h, w)
'H': img_gt, # (c, h, w)
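For context, the two sequence options in `__getitem__` only differ in how they extend the clip: `pad_sequence` repeats the last frame (7 to 8 frames), while the new `mirror_sequence` appends the temporally reversed clip (7 to 14 frames). A self-contained sketch with made-up tensor shapes:

```python
import torch

# Made-up (t, c, h, w) LQ clip standing in for the dataset output.
imgs_lq = torch.rand(7, 3, 64, 112)

# pad_sequence: repeat the last frame, 7 -> 8 frames.
padded = torch.cat([imgs_lq, imgs_lq[-1:, ...]], dim=0)

# mirror_sequence: append the reversed clip, 7 -> 14 frames (a palindrome in time).
mirrored = torch.cat([imgs_lq, imgs_lq.flip(0)], dim=0)

print(padded.shape, mirrored.shape)  # torch.Size([8, 3, 64, 112]) torch.Size([14, 3, 64, 112])
```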
@@ -286,97 +291,3 @@ def __getitem__(self, index):
def __len__(self):
return len(self.data_info['gt_path'])


class SingleVideoRecurrentTestDataset(data.Dataset):
"""Single Video test dataset (only input LQ path).
Supported datasets: Vid4, REDS4, REDSofficial.
More generally, it supports testing dataset with following structures:
dataroot
├── subfolder1
├── frame000
├── frame001
├── ...
├── subfolder1
├── frame000
├── frame001
├── ...
├── ...
For testing datasets, there is no need to prepare LMDB files.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
cache_data (bool): Whether to cache testing datasets.
name (str): Dataset name.
meta_info_file (str): The path to the file storing the list of test
folders. If not provided, all the folders in the dataroot will
be used.
num_frame (int): Window size for input frames.
padding (str): Padding mode.
"""

def __init__(self, opt):
super(SingleVideoRecurrentTestDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
self.lq_root = opt['dataroot_lq']
self.data_info = {'lq_path': [], 'folder': [], 'idx': [], 'border': []}
# file client (io backend)
self.file_client = None

self.imgs_lq = {}
if 'meta_info_file' in opt:
with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
else:
subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))

for subfolder_lq in subfolders_lq:
# get frame list for lq and gt
subfolder_name = osp.basename(subfolder_lq)
img_paths_lq = sorted(list(utils_video.scandir(subfolder_lq, full_path=True)))

max_idx = len(img_paths_lq)

self.data_info['lq_path'].extend(img_paths_lq)
self.data_info['folder'].extend([subfolder_name] * max_idx)
for i in range(max_idx):
self.data_info['idx'].append(f'{i}/{max_idx}')
border_l = [0] * max_idx
for i in range(self.opt['num_frame'] // 2):
border_l[i] = 1
border_l[max_idx - i - 1] = 1
self.data_info['border'].extend(border_l)

# cache data or save the frame list
if self.cache_data:
logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
self.imgs_lq[subfolder_name] = utils_video.read_img_seq(img_paths_lq)
else:
self.imgs_lq[subfolder_name] = img_paths_lq

# Find unique folder strings
self.folders = sorted(list(set(self.data_info['folder'])))

def __getitem__(self, index):
folder = self.folders[index]

if self.cache_data:
imgs_lq = self.imgs_lq[folder]
else:
imgs_lq = utils_video.read_img_seq(self.imgs_lq[folder])

return {
'L': imgs_lq,
'folder': folder,
'lq_path': self.imgs_lq[folder],
}

def __len__(self):
return len(self.folders)
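The hunk above removes this second copy of `SingleVideoRecurrentTestDataset`. Its docstring enumerates the expected `opt` keys; as a hypothetical illustration (all paths and values below are invented placeholders, not taken from any config in the repository), such a dict could look like:

```python
# Hypothetical opt dict for SingleVideoRecurrentTestDataset, following the
# docstring above; every path and value is an invented placeholder.
opt = {
    'name': 'Vid4',                        # dataset name
    'dataroot_lq': 'testsets/Vid4/BIx4',   # one subfolder of frames per clip
    'cache_data': False,                   # read frames from disk on the fly
    'io_backend': {'type': 'disk'},
    'num_frame': -1,                       # example value; recurrent testing feeds the whole clip
    # 'meta_info_file': '...',             # optional list of test folders
}

dataset = SingleVideoRecurrentTestDataset(opt)
print(len(dataset))                        # one item per clip subfolder
```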
4 changes: 2 additions & 2 deletions data/dataset_video_train.py
@@ -322,7 +322,7 @@ def __init__(self, opt):
self.random_reverse = opt['random_reverse']
print(f'Random reverse is {self.random_reverse}.')

self.flip_sequence = opt.get('flip_sequence', False)
self.mirror_sequence = opt.get('mirror_sequence', False)
self.pad_sequence = opt.get('pad_sequence', False)
self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

@@ -370,7 +370,7 @@ def __getitem__(self, index):
img_lqs = torch.stack(img_results[:7], dim=0)
img_gts = torch.stack(img_results[7:], dim=0)

if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames
img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
elif self.pad_sequence: # pad the sequence: 7 frames to 8 frames
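The training-side change is just the rename from `flip_sequence` to `mirror_sequence`, matching the test dataset above; the behaviour itself is unchanged and extends the LQ and GT clips in lockstep. One side effect worth noting: since the option is read with `opt.get('mirror_sequence', False)`, an existing training config that still sets the old `flip_sequence` key would silently fall back to `False`. A hypothetical dataset-options fragment (only the two sequence flags come from the diff, everything else is a placeholder):

```python
# Hypothetical Vimeo-90K training dataset options; paths are placeholders.
train_dataset_opt = {
    'dataroot_gt': 'trainsets/vimeo90k/vimeo_septuplet/sequences',       # placeholder path
    'dataroot_lq': 'trainsets/vimeo90k/vimeo_septuplet_BIx4/sequences',  # placeholder path
    'mirror_sequence': True,   # was `flip_sequence` before this commit
    'pad_sequence': False,
}
```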
