Support VFI and STVSR for VRT
JingyunLiang authored and committed on Oct 4, 2022
1 parent d1c0d9d commit 44c9eda
Showing 17 changed files with 8,371 additions and 145 deletions.
16 changes: 7 additions & 9 deletions README.md
@@ -1,12 +1,12 @@
## Training and testing codes for USRNet, DnCNN, FFDNet, SRMD, DPSR, MSRResNet, ESRGAN, BSRGAN, SwinIR, VRT
## Training and testing codes for USRNet, DnCNN, FFDNet, SRMD, DPSR, MSRResNet, ESRGAN, BSRGAN, SwinIR, VRT, RVRT
[![download](https://img.shields.io/github/downloads/cszn/KAIR/total.svg)](https://github.com/cszn/KAIR/releases) ![visitors](https://visitor-badge.glitch.me/badge?page_id=cszn/KAIR)

[Kai Zhang](https://cszn.github.io/)

*[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland*

_______
- **_News (2022-06-01)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.
- **_News (2022-10-04)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT, NeurIPS 2022 ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.

- **_News (2022-05-05)_**: Try the [online demo](https://replicate.com/cszn/scunet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.

@@ -23,13 +23,11 @@ We did not use the paired noisy/clean data by DND and SIDD during training!*__


- **_News (2022-02-15)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_VRT.md) of [VRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/VRT?style=social)](https://github.com/JingyunLiang/VRT) for video SR, deblurring and denoising.
<p align="center">
<a href="https://github.com/JingyunLiang/VRT">
<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vsr.gif"/>
<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdb.gif"/>
<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdn.gif"/>
</a>
</p>
![Eg1](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vsr.gif)
![Eg2](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdb.gif)
![Eg3](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdn.gif)
![Eg4](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vfi.gif)
![Eg5](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_stvsr.gif)

- **_News (2021-12-23)_**: Our techniques are adopted in [https://www.amemori.ai/](https://www.amemori.ai/).
- **_News (2021-12-23)_**: Our new work for practical image denoising.
142 changes: 137 additions & 5 deletions data/dataset_video_test.py
@@ -2,6 +2,8 @@
import torch
from os import path as osp
import torch.utils.data as data
from torchvision import transforms
from PIL import Image
import os # needed by the VFI_* datasets below (unless already imported at the top of the file)

import utils.utils_video as utils_video

@@ -245,12 +247,12 @@ def __init__(self, opt):
super(VideoTestVimeo90KDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
temporal_scale = opt.get('temporal_scale', 1)
self.temporal_scale = opt.get('temporal_scale', 1)
if self.cache_data:
raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: temporal_scale]
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: self.temporal_scale]

with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
@@ -259,7 +261,7 @@ def __init__(self, opt):
self.data_info['gt_path'].append(gt_path)
lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
self.data_info['lq_path'].append(lq_paths)
self.data_info['folder'].append('vimeo90k')
self.data_info['folder'].append(subfolder)
self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
self.data_info['border'].append(0)

@@ -271,7 +273,6 @@ def __getitem__(self, index):
gt_path = self.data_info['gt_path'][index]
imgs_lq = utils_video.read_img_seq(lq_path)
img_gt = utils_video.read_img_seq([gt_path])
img_gt.squeeze_(0)

if self.pad_sequence: # pad the sequence: 7 frames to 8 frames
imgs_lq = torch.cat([imgs_lq, imgs_lq[-1:,...]], dim=0)
@@ -285,9 +286,140 @@ def __getitem__(self, index):
'folder': self.data_info['folder'][index], # folder name
'idx': self.data_info['idx'][index], # e.g., 0/843
'border': self.data_info['border'][index], # 0 for non-border
'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
'lq_path': lq_path,
'gt_path': [gt_path]
}

def __len__(self):
return len(self.data_info['gt_path'])
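A quick check of the neighbor-list arithmetic above (values illustrative; num_frame and temporal_scale come from the dataset options):

num_frame = 7
neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)]
print(neighbor_list)       # [1, 2, 3, 4, 5, 6, 7] -- Vimeo-90K septuplets are im1..im7
print(neighbor_list[::2])  # [1, 3, 5, 7] -- temporal_scale=2 keeps every other frame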


class VFI_DAVIS(data.Dataset):
"""Video test dataset for DAVIS dataset in video frame interpolation.
Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py
"""

def __init__(self, data_root, ext="png"):

super().__init__()

self.data_root = data_root
self.images_sets = []

for label_id in os.listdir(self.data_root):
ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
for start_idx in range(0, len(ctg_imgs_) - 6, 2):
add_files = ctg_imgs_[start_idx:start_idx + 7:2] # four input frames at even offsets
add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:] # middle frame is the GT
self.images_sets.append(add_files)

self.transforms = transforms.Compose([
transforms.CenterCrop((480, 840)),
transforms.ToTensor()
])

def __getitem__(self, idx):

imgpaths = self.images_sets[idx]
images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]

return {
'L': torch.stack(images[:2] + images[3:], 0),
'H': images[2].unsqueeze(0),
'folder': str(idx),
'gt_path': ['vfi_result.png'],
}

def __len__(self):
return len(self.images_sets)
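A worked example of the septuplet windowing above, with a hypothetical 10-frame sequence f0..f9: each window takes four even-offset inputs and the odd middle frame as ground truth.

frames = [f'f{i}' for i in range(10)]
for start_idx in range(0, len(frames) - 6, 2):
    add_files = frames[start_idx:start_idx + 7:2]                        # inputs
    add_files = add_files[:2] + [frames[start_idx + 3]] + add_files[2:]  # + GT
    print(add_files)
# ['f0', 'f2', 'f3', 'f4', 'f6'] -> inputs f0, f2, f4, f6; GT f3
# ['f2', 'f4', 'f5', 'f6', 'f8'] -> inputs f2, f4, f6, f8; GT f5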


class VFI_UCF101(data.Dataset):
"""Video test dataset for UCF101 dataset in video frame interpolation.
Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/ucf101_test.py
"""

def __init__(self, data_root, ext="png"):
super().__init__()

self.data_root = data_root
self.file_list = sorted(os.listdir(self.data_root))

self.transforms = transforms.Compose([
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
])

def __getitem__(self, idx):

imgpath = os.path.join(self.data_root, self.file_list[idx])
# four input frames plus the ground-truth middle frame framet.png
imgpaths = [os.path.join(imgpath, name) for name in
('frame0.png', 'frame1.png', 'frame2.png', 'frame3.png', 'framet.png')]

images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]

return {
'L': torch.stack(images[:-1], 0),
'H': images[-1].unsqueeze(0),
'folder': self.file_list[idx],
'gt_path': ['vfi_result.png'],
}

def __len__(self):
return len(self.file_list)
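A minimal usage sketch for the dataset above (the data_root path is hypothetical): 'L' holds the four input frames, 'H' the ground-truth middle frame.

from torch.utils.data import DataLoader

dataset = VFI_UCF101(data_root='testsets/ucf101_interp')  # hypothetical path
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
for batch in loader:
    print(batch['L'].shape, batch['H'].shape)
    # torch.Size([1, 4, 3, 224, 224]) torch.Size([1, 1, 3, 224, 224])
    break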


class VFI_Vid4(data.Dataset):
"""Video test dataset for Vid4 dataset in video frame interpolation.
Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py
"""

def __init__(self, data_root, ext="png"):

super().__init__()

self.data_root = data_root
self.images_sets = []
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': []}
self.lq_path = []
self.folder = []

for label_id in os.listdir(self.data_root):
ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
if len(ctg_imgs_) % 2 == 0: # enforce an odd number of frames
ctg_imgs_.append(ctg_imgs_[-1])
# pad both ends: duplicate the edge frames, with None spacers so that real
# input frames stay on even positions and GT frames on odd positions
ctg_imgs_.insert(0, None)
ctg_imgs_.insert(0, ctg_imgs_[1])
ctg_imgs_.append(None)
ctg_imgs_.append(ctg_imgs_[-2])

for start_idx in range(0, len(ctg_imgs_) - 6, 2):
add_files = ctg_imgs_[start_idx:start_idx + 7:2]
self.data_info['lq_path'].append([os.path.basename(path) for path in add_files])
self.data_info['gt_path'].append(os.path.basename(ctg_imgs_[start_idx + 3]))
self.data_info['folder'].append(label_id)
add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:] # insert the GT frame in the middle
self.images_sets.append(add_files)

self.transforms = transforms.Compose([
transforms.ToTensor()
])

def __getitem__(self, idx):
imgpaths = self.images_sets[idx]
images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]

return {
'L': torch.stack(images[:2] + images[3:], 0),
'H': images[2].unsqueeze(0),
'folder': self.data_info['folder'][idx],
'lq_path': self.data_info['lq_path'][idx],
'gt_path': [self.data_info['gt_path'][idx]]
}

def __len__(self):
return len(self.images_sets)
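To make the boundary padding concrete, here is a trace for a hypothetical 5-frame clip f0..f4: the None spacers are never selected, and each missing boundary neighbor is filled by a duplicated edge frame.

clip = ['f0', 'f1', 'f2', 'f3', 'f4']        # already odd length
padded = ['f0', None] + clip + [None, 'f4']  # same layout the inserts/appends above produce
for start_idx in range(0, len(padded) - 6, 2):
    print(padded[start_idx:start_idx + 7:2], '->', padded[start_idx + 3])
# ['f0', 'f0', 'f2', 'f4'] -> f1   (first frame duplicated at the left edge)
# ['f0', 'f2', 'f4', 'f4'] -> f3   (last frame duplicated at the right edge)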
73 changes: 70 additions & 3 deletions data/dataset_video_train.py
@@ -3,6 +3,7 @@
import torch
from pathlib import Path
import torch.utils.data as data
from torchvision import transforms

import utils.utils_video as utils_video

@@ -302,6 +303,7 @@ def __init__(self, opt):
super(VideoRecurrentTrainVimeoDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.temporal_scale = opt.get('temporal_scale', 1)

with open(opt['meta_info_file'], 'r') as fin:
self.keys = [line.split(' ')[0] for line in fin]
@@ -316,15 +318,14 @@ def __init__(self, opt):
self.io_backend_opt['client_keys'] = ['lq', 'gt']

# indices of input images
self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][::self.temporal_scale]

# temporal augmentation configs
self.random_reverse = opt['random_reverse']
print(f'Random reverse is {self.random_reverse}.')

self.mirror_sequence = opt.get('mirror_sequence', False)
self.pad_sequence = opt.get('pad_sequence', False)
self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

def __getitem__(self, index):
if self.file_client is None:
@@ -378,9 +379,75 @@ def __getitem__(self, index):
img_gts = torch.cat([img_gts, img_gts[-1:,...]], dim=0)

# img_lqs: (t, c, h, w)
# img_gt: (c, h, w)
# img_gt: (t, c, h, w)
# key: str
return {'L': img_lqs, 'H': img_gts, 'key': key}

def __len__(self):
return len(self.keys)
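The pad_sequence option above duplicates the last frame so a 7-frame Vimeo clip becomes 8 frames; a minimal sketch (tensor sizes are illustrative):

import torch

imgs = torch.rand(7, 3, 64, 64)                  # (t, c, h, w): a 7-frame clip
imgs = torch.cat([imgs, imgs[-1:, ...]], dim=0)  # repeat the last frame
print(imgs.shape)                                # torch.Size([8, 3, 64, 64])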

class VideoRecurrentTrainVimeoVFIDataset(VideoRecurrentTrainVimeoDataset):
"""Vimeo-90K training dataset for video frame interpolation (VFI); the ground truth is the center frame im4."""

def __init__(self, opt):
super(VideoRecurrentTrainVimeoVFIDataset, self).__init__(opt)
self.color_jitter = self.opt.get('color_jitter', False)

if self.color_jitter:
self.transforms_color_jitter = transforms.ColorJitter(0.05, 0.05, 0.05, 0.05)

def __getitem__(self, index):
if self.file_client is None:
self.file_client = utils_video.FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

# random reverse
if self.random_reverse and random.random() < 0.5:
self.neighbor_list.reverse()

scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip, seq = key.split('/') # key example: 00001/0001

# get the neighboring LQ and GT frames
img_lqs = []
img_gts = []
for neighbor in self.neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip}/{seq}/im{neighbor}'
else:
img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
# LQ
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = utils_video.imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)

# GT
if self.is_lmdb:
img_gt_path = f'{clip}/{seq}/im4'
else:
img_gt_path = self.gt_root / clip / seq / 'im4.png'

img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = utils_video.imfrombytes(img_bytes, float32=True)
img_gts.append(img_gt)

# randomly crop
img_gts, img_lqs = utils_video.paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

# augmentation - flip, rotate
img_lqs.extend([img_gts]) # append the GT so the same flip/rotation is applied to inputs and target
img_results = utils_video.augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = utils_video.img2tensor(img_results)
img_results = torch.stack(img_results, dim=0)

if self.color_jitter: # same color_jitter for img_lqs and img_gts
img_results = self.transforms_color_jitter(img_results)

img_lqs = img_results[:-1, ...]
img_gts = img_results[-1:, ...]

# img_lqs: (t, c, h, w)
# img_gt: (t, c, h, w)
# key: str
return {'L': img_lqs, 'H': img_gts, 'key': key}
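A sketch of the joint color jitter above (shapes illustrative): torchvision's ColorJitter samples its random parameters once per call, so jittering the stacked (t, c, h, w) tensor applies identical brightness/contrast/saturation/hue shifts to the LQ inputs and the GT target, whereas per-frame calls would desynchronize them.

import torch
from torchvision import transforms

jitter = transforms.ColorJitter(0.05, 0.05, 0.05, 0.05)
frames = torch.rand(5, 3, 64, 64)   # 4 LQ frames + 1 GT frame, stacked

jittered = jitter(frames)           # one call -> one sampled transform for all frames
img_lqs, img_gts = jittered[:-1], jittered[-1:]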
(diffs for the remaining 14 changed files are not shown)