diff --git a/README.md b/README.md index 53d9cbbe..7894aa7f 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ *[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland* _______ +- **_News (2022-06-01)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising. + - **_News (2022-05-05)_**: Try the [online demo](https://replicate.com/cszn/scunet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising. - **_News (2022-03-23)_**: We release [the testing codes](https://github.com/cszn/SCUNet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising. diff --git a/data/dataset_video_test.py b/data/dataset_video_test.py index e3614413..512bd55c 100755 --- a/data/dataset_video_test.py +++ b/data/dataset_video_test.py @@ -129,7 +129,7 @@ def __len__(self): class SingleVideoRecurrentTestDataset(data.Dataset): - """Single ideo test dataset for recurrent architectures, which takes LR video + """Single video test dataset for recurrent architectures, which takes LR video frames as input and output corresponding HR video frames (only input LQ path). More generally, it supports testing dataset with following structures: @@ -245,11 +245,12 @@ def __init__(self, opt): super(VideoTestVimeo90KDataset, self).__init__() self.opt = opt self.cache_data = opt['cache_data'] + temporal_scale = opt.get('temporal_scale', 1) if self.cache_data: raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.') self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq'] self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []} - neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] + neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: temporal_scale] with open(opt['meta_info_file'], 'r') as fin: subfolders = [line.split(' ')[0] for line in fin] @@ -263,6 +264,7 @@ def __init__(self, opt): self.data_info['border'].append(0) self.pad_sequence = opt.get('pad_sequence', False) + self.mirror_sequence = opt.get('mirror_sequence', False) def __getitem__(self, index): lq_path = self.data_info['lq_path'][index] @@ -274,6 +276,9 @@ def __getitem__(self, index): if self.pad_sequence: # pad the sequence: 7 frames to 8 frames imgs_lq = torch.cat([imgs_lq, imgs_lq[-1:,...]], dim=0) + if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames + imgs_lq = torch.cat([imgs_lq, imgs_lq.flip(0)], dim=0) + return { 'L': imgs_lq, # (t, c, h, w) 'H': img_gt, # (c, h, w) @@ -286,97 +291,3 @@ def __getitem__(self, index): def __len__(self): return len(self.data_info['gt_path']) - -class SingleVideoRecurrentTestDataset(data.Dataset): - """Single Video test dataset (only input LQ path). - - Supported datasets: Vid4, REDS4, REDSofficial. - More generally, it supports testing dataset with following structures: - - dataroot - ├── subfolder1 - ├── frame000 - ├── frame001 - ├── ... - ├── subfolder1 - ├── frame000 - ├── frame001 - ├── ... - ├── ... - - For testing datasets, there is no need to prepare LMDB files. - - Args: - opt (dict): Config for train dataset. It contains the following keys: - dataroot_gt (str): Data root path for gt. - dataroot_lq (str): Data root path for lq. - io_backend (dict): IO backend type and other kwarg. - cache_data (bool): Whether to cache testing datasets. - name (str): Dataset name. - meta_info_file (str): The path to the file storing the list of test - folders. If not provided, all the folders in the dataroot will - be used. - num_frame (int): Window size for input frames. - padding (str): Padding mode. - """ - - def __init__(self, opt): - super(SingleVideoRecurrentTestDataset, self).__init__() - self.opt = opt - self.cache_data = opt['cache_data'] - self.lq_root = opt['dataroot_lq'] - self.data_info = {'lq_path': [], 'folder': [], 'idx': [], 'border': []} - # file client (io backend) - self.file_client = None - - self.imgs_lq = {} - if 'meta_info_file' in opt: - with open(opt['meta_info_file'], 'r') as fin: - subfolders = [line.split(' ')[0] for line in fin] - subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders] - else: - subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*'))) - - for subfolder_lq in subfolders_lq: - # get frame list for lq and gt - subfolder_name = osp.basename(subfolder_lq) - img_paths_lq = sorted(list(utils_video.scandir(subfolder_lq, full_path=True))) - - max_idx = len(img_paths_lq) - - self.data_info['lq_path'].extend(img_paths_lq) - self.data_info['folder'].extend([subfolder_name] * max_idx) - for i in range(max_idx): - self.data_info['idx'].append(f'{i}/{max_idx}') - border_l = [0] * max_idx - for i in range(self.opt['num_frame'] // 2): - border_l[i] = 1 - border_l[max_idx - i - 1] = 1 - self.data_info['border'].extend(border_l) - - # cache data or save the frame list - if self.cache_data: - logger.info(f'Cache {subfolder_name} for VideoTestDataset...') - self.imgs_lq[subfolder_name] = utils_video.read_img_seq(img_paths_lq) - else: - self.imgs_lq[subfolder_name] = img_paths_lq - - # Find unique folder strings - self.folders = sorted(list(set(self.data_info['folder']))) - - def __getitem__(self, index): - folder = self.folders[index] - - if self.cache_data: - imgs_lq = self.imgs_lq[folder] - else: - imgs_lq = utils_video.read_img_seq(self.imgs_lq[folder]) - - return { - 'L': imgs_lq, - 'folder': folder, - 'lq_path': self.imgs_lq[folder], - } - - def __len__(self): - return len(self.folders) diff --git a/data/dataset_video_train.py b/data/dataset_video_train.py index c65b3fac..b596ed1a 100755 --- a/data/dataset_video_train.py +++ b/data/dataset_video_train.py @@ -322,7 +322,7 @@ def __init__(self, opt): self.random_reverse = opt['random_reverse'] print(f'Random reverse is {self.random_reverse}.') - self.flip_sequence = opt.get('flip_sequence', False) + self.mirror_sequence = opt.get('mirror_sequence', False) self.pad_sequence = opt.get('pad_sequence', False) self.neighbor_list = [1, 2, 3, 4, 5, 6, 7] @@ -370,7 +370,7 @@ def __getitem__(self, index): img_lqs = torch.stack(img_results[:7], dim=0) img_gts = torch.stack(img_results[7:], dim=0) - if self.flip_sequence: # flip the sequence: 7 frames to 14 frames + if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0) img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0) elif self.pad_sequence: # pad the sequence: 7 frames to 8 frames diff --git a/docs/README_RVRT.md b/docs/README_RVRT.md new file mode 100644 index 00000000..5c70d168 --- /dev/null +++ b/docs/README_RVRT.md @@ -0,0 +1,180 @@ +# [Recurrent Video Restoration Transformer with Guided Deformable Attention (RVRT)](https://github.com/JingyunLiang/RVRT) +[arxiv](https://arxiv.org/abs/2206.02146) +**|** +[supplementary](https://github.com/JingyunLiang/RVRT/releases/download/v0.0/RVRT_supplementary.pdf) +**|** +[pretrained models](https://github.com/JingyunLiang/RVRT/releases) +**|** +[visual results](https://github.com/JingyunLiang/RVRT/releases) +**|** +[original project page](https://github.com/JingyunLiang/RVRT) + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2206.02146) +[![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) +[![download](https://img.shields.io/github/downloads/JingyunLiang/RVRT/total.svg)](https://github.com/JingyunLiang/RVRT/releases) +![visitors](https://visitor-badge.glitch.me/badge?page_id=jingyunliang/RVRT) +[ google colab logo](https://colab.research.google.com/gist/JingyunLiang/23502e2c65d82144219fa3e3322e4fc3/rvrt-demo-on-video-restoration.ipynb) + +This is the readme of "Recurrent Video Restoration Transformer with Guided Deformable Attention" +([arxiv](https://arxiv.org/pdf/2206.02146.pdf), [supp](https://github.com/JingyunLiang/RVRT/releases/download/v0.0/RVRT_supplementary.pdf), [pretrained models](https://github.com/JingyunLiang/RVRT/releases), [visual results](https://github.com/JingyunLiang/RVRT/releases)). RVRT ahcieves state-of-the-art performance with balanced model size, testing memory and runtime in +- video SR (REDS, Vimeo90K, Vid4, UDM10) +- video deblurring (GoPro, DVD) +- video denoising (DAVIS, Set8) + +

+ + + + + +

+ +--- + +> Video restoration aims at restoring multiple high-quality frames from multiple low-quality frames. Existing video restoration methods generally fall into two extreme cases, i.e., they either restore all frames in parallel or restore the video frame by frame in a recurrent way, which would result in different merits and drawbacks. Typically, the former has the advantage of temporal information fusion. However, it suffers from large model size and intensive memory consumption; the latter has a relatively small model size as it shares parameters across frames; however, it lacks long-range dependency modeling ability and parallelizability. In this paper, we attempt to integrate the advantages of the two cases by proposing a recurrent video restoration transformer, namely RVRT. RVRT processes local neighboring frames in parallel within a globally recurrent framework which can achieve a good trade-off between model size, effectiveness, and efficiency. Specifically, RVRT divides the video into multiple clips and uses the previously inferred clip feature to estimate the subsequent clip feature. Within each clip, different frame features are jointly updated with implicit feature aggregation. Across different clips, the guided deformable attention is designed for clip-to-clip alignment, which predicts multiple relevant locations from the whole inferred clip and aggregates their features by the attention mechanism. Extensive experiments on video super-resolution, deblurring, and denoising show that the proposed RVRT achieves state-of-the-art performance on benchmark datasets with balanced model size, testing memory and runtime. +

+
+ + +

+ +#### Contents + +1. [Requirements](#Requirements) +1. [Quick Testing](#Quick-Testing) +1. [Training](#Training) +1. [Results](#Results) +1. [Citation](#Citation) +1. [License and Acknowledgement](#License-and-Acknowledgement) + + +## Requirements +> - Python 3.8, PyTorch >= 1.9.1 +> - Requirements: see requirements.txt +> - Platforms: Ubuntu 18.04, cuda-11.1 + +## Quick Testing +Following commands will download [pretrained models](https://github.com/JingyunLiang/RVRT/releases) and [test datasets](https://github.com/JingyunLiang/VRT/releases) **automatically** (except Vimeo-90K testing set). If out-of-memory, try to reduce `--tile` at the expense of slightly decreased performance. + +You can also try to test it on Colab[ google colab logo](https://colab.research.google.com/gist/JingyunLiang/23502e2c65d82144219fa3e3322e4fc3/rvrt-demo-on-video-restoration.ipynb), but the results may be slightly different due to `--tile` difference. +```bash +# download code +git clone https://github.com/JingyunLiang/RVRT +cd RVRT +pip install -r requirements.txt + +# 001, video sr trained on REDS, tested on REDS4 +python main_test_rvrt.py --task 001_RVRT_videosr_bi_REDS_30frames --folder_lq testsets/REDS4/sharp_bicubic --folder_gt testsets/REDS4/GT --tile 100 128 128 --tile_overlap 2 20 20 + +# 002, video sr trained on Vimeo (bicubic), tested on Vid4 and Vimeo +python main_test_rvrt.py --task 002_RVRT_videosr_bi_Vimeo_14frames --folder_lq testsets/Vid4/BIx4 --folder_gt testsets/Vid4/GT --tile 0 0 0 --tile_overlap 2 20 20 +python main_test_rvrt.py --task 002_RVRT_videosr_bi_Vimeo_14frames --folder_lq testsets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences --folder_gt testsets/vimeo90k/vimeo_septuplet/sequences --tile 0 0 0 --tile_overlap 0 20 20 + +# 003, video sr trained on Vimeo (blur-downsampling), tested on Vid4, UDM10 and Vimeo +python main_test_rvrt.py --task 003_RVRT_videosr_bd_Vimeo_14frames --folder_lq testsets/Vid4/BDx4 --folder_gt testsets/Vid4/GT --tile 0 0 0 --tile_overlap 2 20 20 +python main_test_rvrt.py --task 003_RVRT_videosr_bd_Vimeo_14frames --folder_lq testsets/UDM10/BDx4 --folder_gt testsets/UDM10/GT --tile 0 0 0 --tile_overlap 2 20 20 +python main_test_rvrt.py --task 003_RVRT_videosr_bd_Vimeo_14frames --folder_lq testsets/vimeo90k/vimeo_septuplet_BDLRx4/sequences --folder_gt testsets/vimeo90k/vimeo_septuplet/sequences --tile 0 0 0 --tile_overlap 0 20 20 + +# 004, video deblurring trained and tested on DVD +python main_test_rvrt.py --task 004_RVRT_videodeblurring_DVD_16frames --folder_lq testsets/DVD10/test_GT_blurred --folder_gt testsets/DVD10/test_GT --tile 0 256 256 --tile_overlap 2 20 20 + +# 005, video deblurring trained and tested on GoPro +python main_test_rvrt.py --task 005_RVRT_videodeblurring_GoPro_16frames --folder_lq testsets/GoPro11/test_GT_blurred --folder_gt testsets/GoPro11/test_GT --tile 0 256 256 --tile_overlap 2 20 20 + +# 006, video denoising trained on DAVIS (noise level 0-50) and tested on Set8 and DAVIS +python main_test_rvrt.py --task 006_RVRT_videodenoising_DAVIS_16frames --sigma 50 --folder_lq testsets/Set8 --folder_gt testsets/Set8 --tile 0 256 256 --tile_overlap 2 20 20 +python main_test_rvrt.py --task 006_RVRT_videodenoising_DAVIS_16frames --sigma 50 --folder_lq testsets/DAVIS-test --folder_gt testsets/DAVIS-test --tile 0 256 256 --tile_overlap 2 20 20 + +# test on your own datasets (an example) +python main_test_rvrt.py --task 001_RVRT_videosr_bi_REDS_30frames --folder_lq testsets/your/own --tile 0 0 0 --tile_overlap 2 20 20 +``` + +**All visual results of RVRT can be downloaded [here](https://github.com/JingyunLiang/RVRT/releases)**. + + +## Training +The training and testing sets are as follows (see the [supplementary](https://github.com/JingyunLiang/RVRT/releases) for a detailed introduction of all datasets). For better I/O speed, use [create_lmdb.py](https://github.com/cszn/KAIR/tree/master/scripts/data_preparation/create_lmdb.py) to convert `.png` datasets to `.lmdb` datasets. + +Note: You do **NOT need** to prepare the datasets if you just want to test the model. `main_test_rvrt.py` will download the testing set automaticaly. + + +| Task | Training Set | Testing Set | Pretrained Model and Visual Results of RVRT | +|:--------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| :---: | +| video SR (setting 1, BI) | [REDS sharp & sharp_bicubic](https://seungjunnah.github.io/Datasets/reds.html) (266 videos, 266000 frames: train + val except REDS4)

*Use [regroup_reds_dataset.py](https://github.com/cszn/KAIR/tree/master/scripts/data_preparation/regroup_reds_dataset.py) to regroup and rename REDS val set | REDS4 (4 videos, 400 frames: 000, 011, 015, 020 of REDS) | [here](https://github.com/JingyunLiang/RVRT/releases) | +| video SR (setting 2 & 3, BI & BD) | [Vimeo90K](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip) (64612 seven-frame videos as in `sep_trainlist.txt`)

* Use [generate_LR_Vimeo90K.m](https://github.com/cszn/KAIR/tree/master/scripts/matlab_scripts/generate_LR_Vimeo90K.m) and [generate_LR_Vimeo90K_BD.m](https://github.com/cszn/KAIR/tree/master/scripts/matlab_scripts/generate_LR_Vimeo90K_BD.m) to generate LR frames for bicubic and blur-downsampling VSR, respectively. | Vimeo90K-T (the rest 7824 7-frame videos) + [Vid4](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view) (4 videos) + [UDM10](https://www.terabox.com/web/share/link?surl=LMuQCVntRegfZSxn7s3hXw&path=%2Fproject%2Fpfnl) (10 videos)

*Use [prepare_UDM10.py](https://github.com/cszn/KAIR/tree/master/scripts/data_preparation/prepare_UDM10.py) to regroup and rename the UDM10 dataset | [here](https://github.com/JingyunLiang/RVRT/releases) | +| video deblurring (setting 1, motion blur) | [DVD](http://www.cs.ubc.ca/labs/imager/tr/2017/DeepVideoDeblurring/DeepVideoDeblurring_Dataset.zip) (61 videos, 5708 frames)

*Use [prepare_DVD.py](https://github.com/cszn/KAIR/tree/master/scripts/data_preparation/prepare_DVD.py) to regroup and rename the dataset. | DVD (10 videos, 1000 frames)

*Use [evaluate_video_deblurring.m](https://github.com/cszn/KAIR/tree/master/scripts/matlab_scripts/evaluate_video_deblurring.m) for final evaluation. | [here](https://github.com/JingyunLiang/RVRT/releases) | +| video deblurring (setting 2, motion blur) | [GoPro](http://data.cv.snu.ac.kr:8008/webdav/dataset/GOPRO/GOPRO_Large.zip) (22 videos, 2103 frames)

*Use [prepare_GoPro_as_video.py](https://github.com/cszn/KAIR/tree/master/scripts/data_preparation/prepare_GoPro_as_video.py) to regroup and rename the dataset. | GoPro (11 videos, 1111 frames)

*Use [evaluate_video_deblurring.m](https://github.com/cszn/KAIR/tree/master/scripts/matlab_scripts/evaluate_video_deblurring.m) for final evaluation. | [here](https://github.com/JingyunLiang/RVRT/releases) | +| video denoising (Gaussian noise) | [DAVIS-2017](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-480p.zip) (90 videos, 6208 frames)

*Use all files in DAVIS/JPEGImages/480p | [DAVIS-2017-test](https://github.com/JingyunLiang/RVRT/releases) (30 videos) + [Set8](https://www.dropbox.com/sh/20n4cscqkqsfgoj/AABGftyJuJDwuCLGczL-fKvBa/test_sequences?dl=0&subfolder_nav_tracking=1) (8 videos: tractor, touchdown, park_joy and sunflower selected from DERF + hypersmooth, motorbike, rafting and snowboard from GOPRO_540P) | [here](https://github.com/JingyunLiang/RVRT/releases) | + +Run following commands for training: +```bash +# download code +git clone https://github.com/cszn/KAIR +cd KAIR +pip install -r requirements.txt + +# 001, video sr trained on REDS, tested on REDS4 +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/001_train_rvrt_videosr_bi_reds_30frames.json --dist True + +# 002, video sr trained on Vimeo (bicubic), tested on Vid4 and Vimeo +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/002_train_rvrt_videosr_bi_vimeo_14frames.json --dist True + +# 003, video sr trained on Vimeo (blur-downsampling), tested on Vid4, Vimeo and UDM10 +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/003_train_rvrt_videosr_bd_vimeo_14frames.json --dist True + +# 004, video deblurring trained and tested on DVD +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/004_train_rvrt_videodeblurring_dvd.json --dist True + +# 005, video deblurring trained and tested on GoPro +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/005_train_rvrt_videodeblurring_gopro.json --dist True + +# 006, video denoising trained on DAVIS (noise level 0-50) and tested on Set8 and DAVIS +python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_vrt.py --opt options/rvrt/006_train_rvrt_videodenoising_davis.json --dist True +``` +Tip: The training process will terminate automatically at 20000 iteration due to a bug. Just resume training after that. +
+Bug +Bug: PyTorch DistributedDataParallel (DDP) does not support `torch.utils.checkpoint` well. To alleviate the problem, set `find_unused_parameters=False` when `use_checkpoint=True`. If there are other errors, make sure that unused parameters will not change during training loop and set `use_static_graph=True`. + +If you find a better solution, feel free to pull a request. Thank you. +
+ +## Results +We achieved state-of-the-art performance on video SR, video deblurring and video denoising. Detailed results can be found in the [paper](https://arxiv.org/abs/2206.02146). + +
+Video Super-Resolution (click me) +

+ + +

+
+ +
+Video Deblurring +

+ + +

+
+ +
+Video Denoising +

+ + +

+
+ + +## Citation + @article{liang2022rvrt, + title={Recurrent Video Restoration Transformer with Guided Deformable Attention}, + author={Liang, Jingyun and Fan, Yuchen and Xiang, Xiaoyu and Ranjan, Rakesh and Ilg, Eddy and Green, Simon and Cao, Jiezhang and Zhang, Kai and Timofte, Radu and Van Gool, Luc}, + journal={arXiv preprint arXiv:2206.02146}, + year={2022} + } + + +## License and Acknowledgement +This project is released under the CC-BY-NC license. We refer to codes from [KAIR](https://github.com/cszn/KAIR), [BasicSR](https://github.com/xinntao/BasicSR), [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer) and [mmediting](https://github.com/open-mmlab/mmediting). Thanks for their awesome works. The majority of VRT is licensed under CC-BY-NC, however portions of the project are available under separate license terms: KAIR is licensed under the MIT License, BasicSR, Video Swin Transformer and mmediting are licensed under the Apache 2.0 license. \ No newline at end of file diff --git a/main_test_rvrt.py b/main_test_rvrt.py new file mode 100644 index 00000000..e0d6833f --- /dev/null +++ b/main_test_rvrt.py @@ -0,0 +1,336 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import cv2 +import glob +import os +import torch +import requests +import numpy as np +from os import path as osp +from collections import OrderedDict +from torch.utils.data import DataLoader + +from models.network_rvrt import RVRT as net +from utils import utils_image as util +from data.dataset_video_test import VideoRecurrentTestDataset, VideoTestVimeo90KDataset, SingleVideoRecurrentTestDataset + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='001_RVRT_videosr_bi_REDS_30frames', help='tasks: 001 to 006') + parser.add_argument('--sigma', type=int, default=0, help='noise level for denoising: 10, 20, 30, 40, 50') + parser.add_argument('--folder_lq', type=str, default='testsets/REDS4/sharp_bicubic', + help='input low-quality test video folder') + parser.add_argument('--folder_gt', type=str, default=None, + help='input ground-truth test video folder') + parser.add_argument('--tile', type=int, nargs='+', default=[100,128,128], + help='Tile size, [0,0,0] for no tile during testing (testing as a whole)') + parser.add_argument('--tile_overlap', type=int, nargs='+', default=[2,20,20], + help='Overlapping of different tiles') + parser.add_argument('--num_workers', type=int, default=16, help='number of workers in data loading') + parser.add_argument('--save_result', action='store_true', help='save resulting image') + args = parser.parse_args() + + # define model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = prepare_model_dataset(args) + model.eval() + model = model.to(device) + if 'vimeo' in args.folder_lq.lower(): + test_set = VideoTestVimeo90KDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_lq, + 'meta_info_file': "data/meta_info/meta_info_Vimeo90K_test_GT.txt", + 'mirror_sequence': True, 'num_frame': 7, 'cache_data': False}) + elif args.folder_gt is not None: + test_set = VideoRecurrentTestDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_lq, + 'sigma':args.sigma, 'num_frame':-1, 'cache_data': False}) + else: + test_set = SingleVideoRecurrentTestDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_lq, + 'sigma':args.sigma, 'num_frame':-1, 'cache_data': False}) + + test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=1, shuffle=False) + + save_dir = f'results/{args.task}' + if args.save_result: + os.makedirs(save_dir, exist_ok=True) + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + assert len(test_loader) != 0, f'No dataset found at {args.folder_lq}' + + for idx, batch in enumerate(test_loader): + lq = batch['L'].to(device) + folder = batch['folder'] + gt = batch['H'] if 'H' in batch else None + + # inference + with torch.no_grad(): + output = test_video(lq, model, args) + + if 'vimeo' in args.folder_lq.lower(): + output = (output[:, 3:4, :, :, :] + output[:, 10:11, :, :, :]) / 2 + gt = gt.unsqueeze(0) + batch['lq_path'] = [['im4.png']] + + test_results_folder = OrderedDict() + test_results_folder['psnr'] = [] + test_results_folder['ssim'] = [] + test_results_folder['psnr_y'] = [] + test_results_folder['ssim_y'] = [] + + for i in range(output.shape[1]): + # save image + img = output[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy() + if img.ndim == 3: + img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR + img = (img * 255.0).round().astype(np.uint8) # float32 to uint8 + if args.save_result: + seq_ = osp.basename(batch['lq_path'][i][0]).split('.')[0] + os.makedirs(f'{save_dir}/{folder[0]}', exist_ok=True) + cv2.imwrite(f'{save_dir}/{folder[0]}/{seq_}.png', img) + + # evaluate psnr/ssim + if gt is not None: + img_gt = gt[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy() + if img_gt.ndim == 3: + img_gt = np.transpose(img_gt[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR + img_gt = (img_gt * 255.0).round().astype(np.uint8) # float32 to uint8 + img_gt = np.squeeze(img_gt) + + test_results_folder['psnr'].append(util.calculate_psnr(img, img_gt, border=0)) + test_results_folder['ssim'].append(util.calculate_ssim(img, img_gt, border=0)) + if img_gt.ndim == 3: # RGB image + img = util.bgr2ycbcr(img.astype(np.float32) / 255.) * 255. + img_gt = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255. + test_results_folder['psnr_y'].append(util.calculate_psnr(img, img_gt, border=0)) + test_results_folder['ssim_y'].append(util.calculate_ssim(img, img_gt, border=0)) + else: + test_results_folder['psnr_y'] = test_results_folder['psnr'] + test_results_folder['ssim_y'] = test_results_folder['ssim'] + + if gt is not None: + psnr = sum(test_results_folder['psnr']) / len(test_results_folder['psnr']) + ssim = sum(test_results_folder['ssim']) / len(test_results_folder['ssim']) + psnr_y = sum(test_results_folder['psnr_y']) / len(test_results_folder['psnr_y']) + ssim_y = sum(test_results_folder['ssim_y']) / len(test_results_folder['ssim_y']) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + print('Testing {:20s} ({:2d}/{}) - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'. + format(folder[0], idx, len(test_loader), psnr, ssim, psnr_y, ssim_y)) + else: + print('Testing {:20s} ({:2d}/{})'.format(folder[0], idx, len(test_loader))) + + # summarize psnr/ssim + if gt is not None: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + print('\n{} \n-- Average PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'. + format(save_dir, ave_psnr, ave_ssim, ave_psnr_y, ave_ssim_y)) + + +def prepare_model_dataset(args): + ''' prepare model and dataset according to args.task. ''' + + # define model + if args.task == '001_RVRT_videosr_bi_REDS_30frames': + model = net(upscale=4, clip_size=2, img_size=[2, 64, 64], window_size=[2, 8, 8], num_blocks=[1, 2, 1], + depths=[2, 2, 2], embed_dims=[144, 144, 144], num_heads=[6, 6, 6], + inputconv_groups=[1, 1, 1, 1, 1, 1], deformable_groups=12, attention_heads=12, + attention_window=[3, 3], cpu_cache_length=100) + datasets = ['REDS4'] + args.scale = 4 + args.window_size = [2,8,8] + args.nonblind_denoising = False + + elif args.task in ['002_RVRT_videosr_bi_Vimeo_14frames', '003_RVRT_videosr_bd_Vimeo_14frames']: + model = net(upscale=4, clip_size=2, img_size=[2, 64, 64], window_size=[2, 8, 8], num_blocks=[1, 2, 1], + depths=[2, 2, 2], embed_dims=[144, 144, 144], num_heads=[6, 6, 6], + inputconv_groups=[1, 1, 1, 1, 1, 1], deformable_groups=12, attention_heads=12, + attention_window=[3, 3], cpu_cache_length=100) + datasets = ['Vid4'] # 'Vimeo'. Vimeo dataset is too large. Please refer to #training to download it. + args.scale = 4 + args.window_size = [2,8,8] + args.nonblind_denoising = False + + elif args.task in ['004_RVRT_videodeblurring_DVD_16frames']: + model = net(upscale=1, clip_size=2, img_size=[2, 64, 64], window_size=[2, 8, 8], num_blocks=[1, 2, 1], + depths=[2, 2, 2], embed_dims=[192, 192, 192], num_heads=[6, 6, 6], + inputconv_groups=[1, 3, 3, 3, 3, 3], deformable_groups=12, attention_heads=12, + attention_window=[3, 3], cpu_cache_length=100) + datasets = ['DVD10'] + args.scale = 1 + args.window_size = [2,8,8] + args.nonblind_denoising = False + + elif args.task in ['005_RVRT_videodeblurring_GoPro_16frames']: + model = net(upscale=1, clip_size=2, img_size=[2, 64, 64], window_size=[2, 8, 8], num_blocks=[1, 2, 1], + depths=[2, 2, 2], embed_dims=[192, 192, 192], num_heads=[6, 6, 6], + inputconv_groups=[1, 3, 3, 3, 3, 3], deformable_groups=12, attention_heads=12, + attention_window=[3, 3], cpu_cache_length=100) + datasets = ['GoPro11-part1', 'GoPro11-part2'] + args.scale = 1 + args.window_size = [2,8,8] + args.nonblind_denoising = False + + elif args.task == '006_RVRT_videodenoising_DAVIS_16frames': + model = net(upscale=1, clip_size=2, img_size=[2, 64, 64], window_size=[2, 8, 8], num_blocks=[1, 2, 1], + depths=[2, 2, 2], embed_dims=[192, 192, 192], num_heads=[6, 6, 6], + inputconv_groups=[1, 3, 4, 6, 8, 4], deformable_groups=12, attention_heads=12, + attention_window=[3, 3], nonblind_denoising=True, cpu_cache_length=100) + datasets = ['Set8', 'DAVIS-test'] + args.scale = 1 + args.window_size = [2,8,8] + args.nonblind_denoising = True + + # download model + model_path = f'model_zoo/rvrt/{args.task}.pth' + if os.path.exists(model_path): + print(f'loading model from ./model_zoo/rvrt/{model_path}') + else: + os.makedirs(os.path.dirname(model_path), exist_ok=True) + url = 'https://github.com/JingyunLiang/RVRT/releases/download/v0.0/{}'.format(os.path.basename(model_path)) + r = requests.get(url, allow_redirects=True) + print(f'downloading model {model_path}') + open(model_path, 'wb').write(r.content) + + pretrained_model = torch.load(model_path) + model.load_state_dict(pretrained_model['params'] if 'params' in pretrained_model.keys() else pretrained_model) + + # download datasets + if os.path.exists(f'{args.folder_lq}'): + print(f'using dataset from {args.folder_lq}') + else: + if 'vimeo' in args.folder_lq.lower(): + print(f'Vimeo dataset is not at {args.folder_lq}! Please refer to #training of Readme.md to download it.') + else: + os.makedirs('testsets', exist_ok=True) + for dataset in datasets: + url = f'https://github.com/JingyunLiang/VRT/releases/download/v0.0/testset_{dataset}.tar.gz' + r = requests.get(url, allow_redirects=True) + print(f'downloading testing dataset {dataset}') + open(f'testsets/{dataset}.tar.gz', 'wb').write(r.content) + os.system(f'tar -xvf testsets/{dataset}.tar.gz -C testsets') + os.system(f'rm testsets/{dataset}.tar.gz') + + return model + + +def test_video(lq, model, args): + '''test the video as a whole or as clips (divided temporally). ''' + + num_frame_testing = args.tile[0] + if num_frame_testing: + # test as multiple clips if out-of-memory + sf = args.scale + num_frame_overlapping = args.tile_overlap[0] + not_overlap_border = False + b, d, c, h, w = lq.size() + c = c - 1 if args.nonblind_denoising else c + stride = num_frame_testing - num_frame_overlapping + d_idx_list = list(range(0, d-num_frame_testing, stride)) + [max(0, d-num_frame_testing)] + E = torch.zeros(b, d, c, h*sf, w*sf) + W = torch.zeros(b, d, 1, 1, 1) + + for d_idx in d_idx_list: + lq_clip = lq[:, d_idx:d_idx+num_frame_testing, ...] + out_clip = test_clip(lq_clip, model, args) + out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1)) + + if not_overlap_border: + if d_idx < d_idx_list[-1]: + out_clip[:, -num_frame_overlapping//2:, ...] *= 0 + out_clip_mask[:, -num_frame_overlapping//2:, ...] *= 0 + if d_idx > d_idx_list[0]: + out_clip[:, :num_frame_overlapping//2, ...] *= 0 + out_clip_mask[:, :num_frame_overlapping//2, ...] *= 0 + + E[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip) + W[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip_mask) + output = E.div_(W) + else: + # test as one clip (the whole video) if you have enough memory + window_size = args.window_size + d_old = lq.size(1) + d_pad = (window_size[0] - d_old % window_size[0]) % window_size[0] + lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1) if d_pad else lq + output = test_clip(lq, model, args) + output = output[:, :d_old, :, :, :] + + return output + + +def test_clip(lq, model, args): + ''' test the clip as a whole or as patches. ''' + + sf = args.scale + window_size = args.window_size + size_patch_testing = args.tile[1] + assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.' + + if size_patch_testing: + # divide the clip to patches (spatially only, tested patch by patch) + overlap_size = args.tile_overlap[1] + not_overlap_border = True + + # test patch by patch + b, d, c, h, w = lq.size() + c = c - 1 if args.nonblind_denoising else c + stride = size_patch_testing - overlap_size + h_idx_list = list(range(0, h-size_patch_testing, stride)) + [max(0, h-size_patch_testing)] + w_idx_list = list(range(0, w-size_patch_testing, stride)) + [max(0, w-size_patch_testing)] + E = torch.zeros(b, d, c, h*sf, w*sf) + W = torch.zeros_like(E) + + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = lq[..., h_idx:h_idx+size_patch_testing, w_idx:w_idx+size_patch_testing] + out_patch = model(in_patch).detach().cpu() + + out_patch_mask = torch.ones_like(out_patch) + + if not_overlap_border: + if h_idx < h_idx_list[-1]: + out_patch[..., -overlap_size//2:, :] *= 0 + out_patch_mask[..., -overlap_size//2:, :] *= 0 + if w_idx < w_idx_list[-1]: + out_patch[..., :, -overlap_size//2:] *= 0 + out_patch_mask[..., :, -overlap_size//2:] *= 0 + if h_idx > h_idx_list[0]: + out_patch[..., :overlap_size//2, :] *= 0 + out_patch_mask[..., :overlap_size//2, :] *= 0 + if w_idx > w_idx_list[0]: + out_patch[..., :, :overlap_size//2] *= 0 + out_patch_mask[..., :, :overlap_size//2] *= 0 + + E[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch) + W[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch_mask) + output = E.div_(W) + + else: + _, _, _, h_old, w_old = lq.size() + h_pad = (window_size[1] - h_old % window_size[1]) % window_size[1] + w_pad = (window_size[2] - w_old % window_size[2]) % window_size[2] + + lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3) if h_pad else lq + lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4) if w_pad else lq + + output = model(lq).detach().cpu() + + output = output[:, :, :, :h_old*sf, :w_old*sf] + + return output + + +if __name__ == '__main__': + main() diff --git a/main_train_vrt.py b/main_train_vrt.py index 0fa62647..3ce4db34 100644 --- a/main_train_vrt.py +++ b/main_train_vrt.py @@ -23,7 +23,7 @@ ''' # -------------------------------------------- -# training code for VRT +# training code for VRT/RVRT # -------------------------------------------- ''' diff --git a/models/network_rvrt.py b/models/network_rvrt.py new file mode 100755 index 00000000..69ccab4e --- /dev/null +++ b/models/network_rvrt.py @@ -0,0 +1,1198 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import warnings +import math +import torch +import torch.nn as nn +import torchvision +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from distutils.version import LooseVersion +import numpy as np +from functools import reduce, lru_cache +from operator import mul +from einops import rearrange +from einops.layers.torch import Rearrange +from .op.deform_attn import deform_attn, DeformAttnPack + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + + Returns: + Tensor: Warped image or feature map. + """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + + +def make_layer(block, num_blocks, **kwarg): + """Make layers by stacking the same blocks. + + Args: + block (nn.module): nn.module class for basic block. + num_blocks (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_blocks): + layers.append(block(**kwarg)) + return nn.Sequential(*layers) + + +class BasicModule(nn.Module): + """Basic Module for SpyNet. + """ + + def __init__(self): + super(BasicModule, self).__init__() + + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + return self.basic_module(tensor_input) + + +class SpyNet(nn.Module): + """SpyNet architecture. + + Args: + load_path (str): path for pretrained SpyNet. Default: None. + return_levels (list[int]): return flows of different levels. Default: [5]. + """ + + def __init__(self, load_path=None, return_levels=[5]): + super(SpyNet, self).__init__() + self.return_levels = return_levels + self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) + if load_path: + if not os.path.exists(load_path): + import requests + url = 'https://github.com/JingyunLiang/RVRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth' + r = requests.get(url, allow_redirects=True) + print(f'downloading SpyNet pretrained model from {url}') + os.makedirs(os.path.dirname(load_path), exist_ok=True) + open(load_path, 'wb').write(r.content) + + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def preprocess(self, tensor_input): + tensor_output = (tensor_input - self.mean) / self.std + return tensor_output + + def process(self, ref, supp, w, h, w_floor, h_floor): + flow_list = [] + + ref = [self.preprocess(ref)] + supp = [self.preprocess(supp)] + + for level in range(5): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + flow = ref[0].new_zeros( + [ref[0].size(0), 2, + int(math.floor(ref[0].size(2) / 2.0)), + int(math.floor(ref[0].size(3) / 2.0))]) + + for level in range(len(ref)): + upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + + if upsampled_flow.size(2) != ref[level].size(2): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') + if upsampled_flow.size(3) != ref[level].size(3): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') + + flow = self.basic_module[level](torch.cat([ + ref[level], + flow_warp( + supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), + upsampled_flow + ], 1)) + upsampled_flow + + if level in self.return_levels: + scale = 2 ** (5 - level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) + flow_out = F.interpolate(input=flow, size=(h // scale, w // scale), mode='bilinear', + align_corners=False) + flow_out[:, 0, :, :] *= float(w // scale) / float(w_floor // scale) + flow_out[:, 1, :, :] *= float(h // scale) / float(h_floor // scale) + flow_list.insert(0, flow_out) + + return flow_list + + def forward(self, ref, supp): + assert ref.size() == supp.size() + + h, w = ref.size(2), ref.size(3) + w_floor = math.floor(math.ceil(w / 32.0) * 32.0) + h_floor = math.floor(math.ceil(h / 32.0) * 32.0) + + ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + + flow_list = self.process(ref, supp, w, h, w_floor, h_floor) + + return flow_list[0] if len(flow_list) == 1 else flow_list + + +class GuidedDeformAttnPack(DeformAttnPack): + """Guided deformable attention module. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. + clip_size (int): clip size. Default: 2. + max_residue_magnitude (int): The maximum magnitude of the offset residue. Default: 10. + Ref: + Recurrent Video Restoration Transformer with Guided Deformable Attention + + """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(GuidedDeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv3d(self.in_channels * (1 + self.clip_size) + self.clip_size * 2, 64, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, self.clip_size * self.deformable_groups * self.attn_size * 2, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + ) + self.init_offset() + + # proj to a higher dimension can slightly improve the performance + self.proj_channels = int(self.in_channels * 2) + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.proj_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + + def init_offset(self): + if hasattr(self, 'conv_offset'): + self.conv_offset[-1].weight.data.zero_() + self.conv_offset[-1].bias.data.zero_() + + def forward(self, q, k, v, v_prop_warped, flows, return_updateflow): + offset1, offset2 = torch.chunk(self.max_residue_magnitude * torch.tanh( + self.conv_offset(torch.cat([q] + v_prop_warped + flows, 2).transpose(1, 2)).transpose(1, 2)), 2, dim=2) + offset1 = offset1 + flows[0].flip(2).repeat(1, 1, offset1.size(2) // 2, 1, 1) + offset2 = offset2 + flows[1].flip(2).repeat(1, 1, offset2.size(2) // 2, 1, 1) + offset = torch.cat([offset1, offset2], dim=2).flatten(0, 1) + + b, t, c, h, w = offset1.shape + q = self.proj_q(q).view(b * t, 1, self.proj_channels, h, w) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size).view(b, t, self.proj_channels, h, + w) + v = self.proj(v) + v = v + self.mlp(v) + + if return_updateflow: + return v, offset1.view(b, t, c // 2, 2, h, w).mean(2).flip(2), offset2.view(b, t, c // 2, 2, h, w).mean( + 2).flip(2) + else: + return v + + +def window_partition(x, window_size): + """ Partition the input into windows. Attention will be conducted within the windows. + + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], + window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ Reverse windows back to the original input. Attention was conducted within the windows. + + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """ Get the window size and the shift size """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + """ Compute attnetion mask for input of size (D, H, W). @lru_cache caches each stage results. """ + + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +class Mlp(nn.Module): + """ Multilayer perceptron. + + Args: + x: (B, D, H, W, C) + + Returns: + x: (B, D, H, W, C) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None): + super().__init__() + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + self.register_buffer("relative_position_index", self.get_position_index(window_size)) + self.qkv_self = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + + # self attention + B_, N, C = x.shape + qkv = self.qkv_self(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + x_out = self.attention(q, k, v, mask, (B_, N, C)) + + # projection + x = self.proj(x_out) + + return x + + def attention(self, q, k, v, mask, x_shape): + B_, N, C = x_shape + attn = (q * self.scale) @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1) # Wd*Wh*Ww, Wd*Wh*Ww,nH + attn = attn + relative_position_bias.permute(2, 0, 1).unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask[:, :N, :N].unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = F.softmax(attn, -1, dtype=q.dtype) # Don't use attn.dtype after addition! + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + return x + + def get_position_index(self, window_size): + ''' Get pair-wise relative position index for each token inside the window. ''' + + coords_d = torch.arange(window_size[0]) + coords_h = torch.arange(window_size[1]) + coords_w = torch.arange(window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 2] += window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + + return relative_position_index + + +class STL(nn.Module): + """ Swin Transformer Layer (STL). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for mutual and self attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=(2, 8, 8), + shift_size=(0, 0, 0), + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint_attn=False, + use_checkpoint_ffn=False + ): + super().__init__() + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.use_checkpoint_attn = use_checkpoint_attn + self.use_checkpoint_ffn = use_checkpoint_ffn + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, + qk_scale=qk_scale) + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + + x = self.norm1(x) + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1), mode='constant') + + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C + + # attention / shifted attention + attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C + + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :] + + return x + + def forward_part2(self, x): + return self.mlp(self.norm2(x)) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + # attention + if self.use_checkpoint_attn: + x = x + checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = x + self.forward_part1(x, mask_matrix) + + # feed-forward + if self.use_checkpoint_ffn: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class STG(nn.Module): + """ Swin Transformer Group (STG). + + Args: + dim (int): Number of feature channels + input_resolution (tuple[int]): Input resolution. + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (6,8,8). + shift_size (tuple[int]): Shift size for mutual and self attention. Default: None. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=[2, 8, 8], + shift_size=None, + mlp_ratio=2., + qkv_bias=False, + qk_scale=None, + norm_layer=nn.LayerNorm, + use_checkpoint_attn=False, + use_checkpoint_ffn=False, + ): + super().__init__() + self.input_resolution = input_resolution + self.window_size = window_size + self.shift_size = list(i // 2 for i in window_size) if shift_size is None else shift_size + + # build blocks + self.blocks = nn.ModuleList([ + STL( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=[0, 0, 0] if i % 2 == 0 else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=use_checkpoint_attn, + use_checkpoint_ffn=use_checkpoint_ffn + ) + for i in range(depth)]) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for attention + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + + for blk in self.blocks: + x = blk(x, attn_mask) + + x = x.view(B, D, H, W, -1) + x = rearrange(x, 'b d h w c -> b c d h w') + + return x + + +class RSTB(nn.Module): + """ Residual Swin Transformer Block (RSTB). + + Args: + kwargs: Args for RSTB. + """ + + def __init__(self, **kwargs): + super(RSTB, self).__init__() + self.input_resolution = kwargs['input_resolution'] + + self.residual_group = STG(**kwargs) + self.linear = nn.Linear(kwargs['dim'], kwargs['dim']) + + def forward(self, x): + return x + self.linear(self.residual_group(x).transpose(1, 4)).transpose(1, 4) + + +class RSTBWithInputConv(nn.Module): + """RSTB with a convolution in front. + + Args: + in_channels (int): Number of input channels of the first conv. + kernel_size (int): Size of kernel of the first conv. + stride (int): Stride of the first conv. + group (int): Group of the first conv. + num_blocks (int): Number of residual blocks. Default: 2. + **kwarg: Args for RSTB. + """ + + def __init__(self, in_channels=3, kernel_size=(1, 3, 3), stride=1, groups=1, num_blocks=2, **kwargs): + super().__init__() + + main = [] + main += [Rearrange('n d c h w -> n c d h w'), + nn.Conv3d(in_channels, + kwargs['dim'], + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2), + groups=groups), + Rearrange('n c d h w -> n d h w c'), + nn.LayerNorm(kwargs['dim']), + Rearrange('n d h w c -> n c d h w')] + + # RSTB blocks + kwargs['use_checkpoint_attn'] = kwargs.pop('use_checkpoint_attn')[0] + kwargs['use_checkpoint_ffn'] = kwargs.pop('use_checkpoint_ffn')[0] + main.append(make_layer(RSTB, num_blocks, **kwargs)) + + main += [Rearrange('n c d h w -> n d h w c'), + nn.LayerNorm(kwargs['dim']), + Rearrange('n d h w c -> n d c h w')] + + self.main = nn.Sequential(*main) + + def forward(self, x): + """ + Forward function for RSTBWithInputConv. + + Args: + feat (Tensor): Input feature with shape (n, t, in_channels, h, w) + + Returns: + Tensor: Output feature with shape (n, t, out_channels, h, w) + """ + return self.main(x) + + +class Upsample(nn.Sequential): + """Upsample module for video SR. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + assert LooseVersion(torch.__version__) >= LooseVersion('1.8.1'), \ + 'PyTorch version >= 1.8.1 to support 5D PixelShuffle.' + + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))) + m.append(Rearrange('n c d h w -> n d c h w')) + m.append(nn.PixelShuffle(2)) + m.append(Rearrange('n c d h w -> n d c h w')) + m.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) + m.append(nn.Conv3d(num_feat, num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))) + elif scale == 3: + m.append(nn.Conv3d(num_feat, 9 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))) + m.append(Rearrange('n c d h w -> n d c h w')) + m.append(nn.PixelShuffle(3)) + m.append(Rearrange('n c d h w -> n d c h w')) + m.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) + m.append(nn.Conv3d(num_feat, num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class RVRT(nn.Module): + """ Recurrent Video Restoration Transformer with Guided Deformable Attention (RVRT). + A PyTorch impl of : `Recurrent Video Restoration Transformer with Guided Deformable Attention` - + https://arxiv.org/pdf/2205.00000 + + Args: + upscale (int): Upscaling factor. Set as 1 for video deblurring, etc. Default: 4. + clip_size (int): Size of clip in recurrent restoration transformer. + img_size (int | tuple(int)): Size of input video. Default: [2, 64, 64]. + window_size (int | tuple(int)): Window size. Default: (2,8,8). + num_blocks (list[int]): Number of RSTB blocks in each stage. + depths (list[int]): Depths of each RSTB. + embed_dims (list[int]): Number of linear projection output channels. + num_heads (list[int]): Number of attention head of each stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + norm_layer (obj): Normalization layer. Default: nn.LayerNorm. + inputconv_groups (int): Group of the first convolution layer in RSTBWithInputConv. Default: [1,1,1,1,1,1] + spynet_path (str): Pretrained SpyNet model path. + deformable_groups (int): Number of deformable groups in deformable attention. Default: 12. + attention_heads (int): Number of attention heads in deformable attention. Default: 12. + attention_window (list[int]): Attention window size in aeformable attention. Default: [3, 3]. + nonblind_denoising (bool): If True, conduct experiments on non-blind denoising. Default: False. + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + no_checkpoint_attn_blocks (list[int]): Layers without torch.checkpoint for attention modules. + no_checkpoint_ffn_blocks (list[int]): Layers without torch.checkpoint for feed-forward modules. + cpu_cache_length: (int): Maximum video length without cpu caching. Default: 100. + """ + + def __init__(self, + upscale=4, + clip_size=2, + img_size=[2, 64, 64], + window_size=[2, 8, 8], + num_blocks=[1, 2, 1], + depths=[2, 2, 2], + embed_dims=[144, 144, 144], + num_heads=[6, 6, 6], + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + norm_layer=nn.LayerNorm, + inputconv_groups=[1, 1, 1, 1, 1, 1], + spynet_path=None, + max_residue_magnitude=10, + deformable_groups=12, + attention_heads=12, + attention_window=[3, 3], + nonblind_denoising=False, + use_checkpoint_attn=False, + use_checkpoint_ffn=False, + no_checkpoint_attn_blocks=[], + no_checkpoint_ffn_blocks=[], + cpu_cache_length=100 + ): + + super().__init__() + self.upscale = upscale + self.clip_size = clip_size + self.nonblind_denoising = nonblind_denoising + use_checkpoint_attns = [False if i in no_checkpoint_attn_blocks else use_checkpoint_attn for i in range(100)] + use_checkpoint_ffns = [False if i in no_checkpoint_ffn_blocks else use_checkpoint_ffn for i in range(100)] + self.cpu_cache_length = cpu_cache_length + + # optical flow + self.spynet = SpyNet(spynet_path) + + # shallow feature extraction + if self.upscale == 4: + # video sr + self.feat_extract = RSTBWithInputConv(in_channels=3, + kernel_size=(1, 3, 3), + groups=inputconv_groups[0], + num_blocks=num_blocks[0], + dim=embed_dims[0], + input_resolution=[1, img_size[1], img_size[2]], + depth=depths[0], + num_heads=num_heads[0], + window_size=[1, window_size[1], window_size[2]], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=[False], + use_checkpoint_ffn=[False] + ) + else: + # video deblurring/denoising + self.feat_extract = nn.Sequential(Rearrange('n d c h w -> n c d h w'), + nn.Conv3d(4 if self.nonblind_denoising else 3, embed_dims[0], (1, 3, 3), + (1, 2, 2), (0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(embed_dims[0], embed_dims[0], (1, 3, 3), (1, 2, 2), (0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + Rearrange('n c d h w -> n d c h w'), + RSTBWithInputConv( + in_channels=embed_dims[0], + kernel_size=(1, 3, 3), + groups=inputconv_groups[0], + num_blocks=num_blocks[0], + dim=embed_dims[0], + input_resolution=[1, img_size[1], img_size[2]], + depth=depths[0], + num_heads=num_heads[0], + window_size=[1, window_size[1], window_size[2]], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=[False], + use_checkpoint_ffn=[False] + ) + ) + + # check if the sequence is augmented by flipping + self.is_mirror_extended = False + + # recurrent feature refinement + self.backbone = nn.ModuleDict() + self.deform_align = nn.ModuleDict() + modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + for i, module in enumerate(modules): + # deformable attention + self.deform_align[module] = GuidedDeformAttnPack(embed_dims[1], + embed_dims[1], + attention_window=attention_window, + attention_heads=attention_heads, + deformable_groups=deformable_groups, + clip_size=clip_size, + max_residue_magnitude=max_residue_magnitude) + + # feature propagation + self.backbone[module] = RSTBWithInputConv( + in_channels=(2 + i) * embed_dims[0], + kernel_size=(1, 3, 3), + groups=inputconv_groups[i + 1], + num_blocks=num_blocks[1], + dim=embed_dims[1], + input_resolution=img_size, + depth=depths[1], + num_heads=num_heads[1], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=[use_checkpoint_attns[i]], + use_checkpoint_ffn=[use_checkpoint_ffns[i]] + ) + + # reconstruction + self.reconstruction = RSTBWithInputConv( + in_channels=5 * embed_dims[0], + kernel_size=(1, 3, 3), + groups=inputconv_groups[5], + num_blocks=num_blocks[2], + + dim=embed_dims[2], + input_resolution=[1, img_size[1], img_size[2]], + depth=depths[2], + num_heads=num_heads[2], + window_size=[1, window_size[1], window_size[2]], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=[False], + use_checkpoint_ffn=[False] + ) + self.conv_before_upsampler = nn.Sequential( + nn.Conv3d(embed_dims[-1], 64, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + nn.LeakyReLU(negative_slope=0.1, inplace=True) + ) + self.upsampler = Upsample(4, 64) + self.conv_last = nn.Conv3d(64, 3, kernel_size=(1, 3, 3), padding=(0, 1, 1)) + + def compute_flow(self, lqs): + """Compute optical flow using SPyNet for feature alignment. + + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the + flows used for forward-time propagation (current to previous). + 'flows_backward' corresponds to the flows used for + backward-time propagation (current to next). + """ + + n, t, c, h, w = lqs.size() + lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w) + lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = None + else: + flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w) + + return flows_forward, flows_backward + + def check_if_mirror_extended(self, lqs): + """Check whether the input is a mirror-extended sequence. + + If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the + (t-1-i)-th frame. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + """ + + if lqs.size(1) % 2 == 0: + lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1) + if torch.norm(lqs_1 - lqs_2.flip(1)) == 0: + self.is_mirror_extended = True + + def propagate(self, feats, flows, module_name, updated_flows=None): + """Propagate the latent clip features throughout the sequence. + + Args: + feats dict(list[tensor]): Features from previous branches. Each + component is a list of tensors with shape (n, clip_size, c, h, w). + flows (tensor): Optical flows with shape (n, t - 1, 2, h, w). + module_name (str): The name of the propgation branches. Can either + be 'backward_1', 'forward_1', 'backward_2', 'forward_2'. + updated_flows dict(list[tensor]): Each component is a list of updated + optical flows with shape (n, clip_size, 2, h, w). + + Return: + dict(list[tensor]): A dictionary containing all the propagated + features. Each key in the dictionary corresponds to a + propagation branch, which is represented by a list of tensors. + """ + + n, t, _, h, w = flows.size() + if 'backward' in module_name: + flow_idx = range(0, t + 1)[::-1] + clip_idx = range(0, (t + 1) // self.clip_size)[::-1] + else: + flow_idx = range(-1, t) + clip_idx = range(0, (t + 1) // self.clip_size) + + if '_1' in module_name: + updated_flows[f'{module_name}_n1'] = [] + updated_flows[f'{module_name}_n2'] = [] + + feat_prop = torch.zeros_like(feats['shallow'][0]) + if self.cpu_cache: + feat_prop = feat_prop.cuda() + + last_key = list(feats)[-2] + for i in range(0, len(clip_idx)): + idx_c = clip_idx[i] + if i > 0: + if '_1' in module_name: + flow_n01 = flows[:, flow_idx[self.clip_size * i - 1], :, :, :] + flow_n12 = flows[:, flow_idx[self.clip_size * i], :, :, :] + flow_n23 = flows[:, flow_idx[self.clip_size * i + 1], :, :, :] + flow_n02 = flow_n12 + flow_warp(flow_n01, flow_n12.permute(0, 2, 3, 1)) + flow_n13 = flow_n23 + flow_warp(flow_n12, flow_n23.permute(0, 2, 3, 1)) + flow_n03 = flow_n23 + flow_warp(flow_n02, flow_n23.permute(0, 2, 3, 1)) + flow_n1 = torch.stack([flow_n02, flow_n13], 1) + flow_n2 = torch.stack([flow_n12, flow_n03], 1) + if self.cpu_cache: + flow_n1 = flow_n1.cuda() + flow_n2 = flow_n2.cuda() + else: + module_name_old = module_name.replace('_2', '_1') + flow_n1 = updated_flows[f'{module_name_old}_n1'][i - 1] + flow_n2 = updated_flows[f'{module_name_old}_n2'][i - 1] + + if self.cpu_cache: + if 'backward' in module_name: + feat_q = feats[last_key][idx_c].flip(1).cuda() + feat_k = feats[last_key][clip_idx[i - 1]].flip(1).cuda() + else: + feat_q = feats[last_key][idx_c].cuda() + feat_k = feats[last_key][clip_idx[i - 1]].cuda() + else: + if 'backward' in module_name: + feat_q = feats[last_key][idx_c].flip(1) + feat_k = feats[last_key][clip_idx[i - 1]].flip(1) + else: + feat_q = feats[last_key][idx_c] + feat_k = feats[last_key][clip_idx[i - 1]] + + feat_prop_warped1 = flow_warp(feat_prop.flatten(0, 1), + flow_n1.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + feat_prop_warped2 = flow_warp(feat_prop.flip(1).flatten(0, 1), + flow_n2.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + + if '_1' in module_name: + feat_prop, flow_n1, flow_n2 = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + True) + updated_flows[f'{module_name}_n1'].append(flow_n1) + updated_flows[f'{module_name}_n2'].append(flow_n2) + else: + feat_prop = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + False) + + if 'backward' in module_name: + feat = [feats[k][idx_c].flip(1) for k in feats if k not in [module_name]] + [feat_prop] + else: + feat = [feats[k][idx_c] for k in feats if k not in [module_name]] + [feat_prop] + + if self.cpu_cache: + feat = [f.cuda() for f in feat] + + feat_prop = feat_prop + self.backbone[module_name](torch.cat(feat, dim=2)) + feats[module_name].append(feat_prop) + + if self.cpu_cache: + feats[module_name][-1] = feats[module_name][-1].cpu() + torch.cuda.empty_cache() + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + feats[module_name] = [f.flip(1) for f in feats[module_name]] + + return feats + + def upsample(self, lqs, feats): + """Compute the output image given the features. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + feats (dict): The features from the propgation branches. + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + + """ + + feats['shallow'] = torch.cat(feats['shallow'], 1) + feats['backward_1'] = torch.cat(feats['backward_1'], 1) + feats['forward_1'] = torch.cat(feats['forward_1'], 1) + feats['backward_2'] = torch.cat(feats['backward_2'], 1) + feats['forward_2'] = torch.cat(feats['forward_2'], 1) + + if self.cpu_cache: + outputs = [] + for i in range(0, feats['shallow'].shape[1]): + hr = torch.cat([feats[k][:, i:i + 1, :, :, :] for k in feats], dim=2) + hr = self.reconstruction(hr.cuda()) + hr = self.conv_last(self.upsampler(self.conv_before_upsampler(hr.transpose(1, 2)))).transpose(1, 2) + hr += torch.nn.functional.interpolate(lqs[:, i:i + 1, :, :, :].cuda(), size=hr.shape[-3:], + mode='trilinear', align_corners=False) + hr = hr.cpu() + outputs.append(hr) + torch.cuda.empty_cache() + + return torch.cat(outputs, dim=1) + + else: + hr = torch.cat([feats[k] for k in feats], dim=2) + hr = self.reconstruction(hr) + hr = self.conv_last(self.upsampler(self.conv_before_upsampler(hr.transpose(1, 2)))).transpose(1, 2) + hr += torch.nn.functional.interpolate(lqs, size=hr.shape[-3:], mode='trilinear', align_corners=False) + + return hr + + def forward(self, lqs): + """Forward function for RVRT. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, _, h, w = lqs.size() + + # whether to cache the features in CPU + self.cpu_cache = True if t > self.cpu_cache_length else False + + if self.upscale == 4: + lqs_downsample = lqs.clone() + else: + lqs_downsample = F.interpolate(lqs[:, :, :3, :, :].view(-1, 3, h, w), scale_factor=0.25, mode='bicubic')\ + .view(n, t, 3, h // 4, w // 4) + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lqs) + + # shallow feature extractions + feats = {} + if self.cpu_cache: + feats['shallow'] = [] + for i in range(0, t // self.clip_size): + feat = self.feat_extract(lqs[:, i * self.clip_size:(i + 1) * self.clip_size, :, :, :]).cpu() + feats['shallow'].append(feat) + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + lqs = lqs.cpu() + lqs_downsample = lqs_downsample.cpu() + flows_backward = flows_backward.cpu() + flows_forward = flows_forward.cpu() + torch.cuda.empty_cache() + else: + feats['shallow'] = list(torch.chunk(self.feat_extract(lqs), t // self.clip_size, dim=1)) + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + # recurrent feature refinement + updated_flows = {} + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + if direction == 'backward': + flows = flows_backward + else: + flows = flows_forward if flows_forward is not None else flows_backward.flip(1) + + module_name = f'{direction}_{iter_}' + feats[module_name] = [] + feats = self.propagate(feats, flows, module_name, updated_flows) + + # reconstruction + return self.upsample(lqs[:, :, :3, :, :], feats) + + +if __name__ == '__main__': + device = torch.device('cpu') + upscale = 4 + window_size = 8 + height = (256 // upscale // window_size) * window_size + width = (256 // upscale // window_size) * window_size + + model = RVRT(upscale=upscale, + clip_size=2, + img_size=[2, 64, 64], + window_size=[2, 8, 8], + num_blocks=[1, 2, 1], + depths=[2, 2, 2], + embed_dims=[24, 24, 24], + num_heads=[2, 2, 2], + deformable_groups=2, + attention_heads=2, + ).to(device) + print(model) + print('{:>16s} : {:<.4f} [M]'.format('#Params', sum(map(lambda x: x.numel(), model.parameters())) / 10 ** 6)) + + x = torch.randn((1, 30, 3, height, width)).to(device) + x = model(x) + print(x.shape) diff --git a/models/op/deform_attn.py b/models/op/deform_attn.py new file mode 100644 index 00000000..7ef80f05 --- /dev/null +++ b/models/op/deform_attn.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from einops.layers.torch import Rearrange +from distutils.version import LooseVersion +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +deform_attn_ext = load( + 'deform_attn', + sources=[ + os.path.join(module_path, 'deform_attn_ext.cpp'), + os.path.join(module_path, 'deform_attn_cuda_pt110.cpp' if LooseVersion(torch.__version__) >= LooseVersion( + '1.10.0') else 'deform_attn_cuda_pt109.cpp'), + os.path.join(module_path, 'deform_attn_cuda_kernel.cu'), +], +) + + +class Mlp(nn.Module): + """ Multilayer perceptron. + + Args: + x: (B, D, H, W, C) + + Returns: + x: (B, D, H, W, C) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class DeformAttnFunction(Function): + + @staticmethod + def forward(ctx, + q, + kv, + offset, + kernel_h, + kernel_w, + stride=1, + padding=0, + dilation=1, + attention_heads=1, + deformable_groups=1, + clip_size=1): + ctx.kernel_h = kernel_h + ctx.kernel_w = kernel_w + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.attention_heads = attention_heads + ctx.deformable_groups = deformable_groups + ctx.clip_size = clip_size + if q.requires_grad or kv.requires_grad or offset.requires_grad: + ctx.save_for_backward(q, kv, offset) + output = q.new_empty(q.shape) + ctx._bufs = [q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0)] + deform_attn_ext.deform_attn_forward(q, kv, offset, output, + ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + q, kv, offset = ctx.saved_tensors + grad_q = torch.zeros_like(q) + grad_kv = torch.zeros_like(kv) + grad_offset = torch.zeros_like(offset) + deform_attn_ext.deform_attn_backward(q, kv, offset, ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx._bufs[3], ctx._bufs[4], + grad_q, grad_kv, grad_offset, + grad_output, ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + + return (grad_q, grad_kv, grad_offset, None, None, None, None, None, None, None, None) + + +deform_attn = DeformAttnFunction.apply + + +class DeformAttn(nn.Module): + + def __init__(self, + in_channels, + out_channels, + attention_window=[3, 3], + deformable_groups=12, + attention_heads=12, + clip_size=1): + super(DeformAttn, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_h = attention_window[0] + self.kernel_w = attention_window[1] + self.attn_size = self.kernel_h * self.kernel_w + self.deformable_groups = deformable_groups + self.attention_heads = attention_heads + self.clip_size = clip_size + self.stride = 1 + self.padding = self.kernel_h//2 + self.dilation = 1 + + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2), + Rearrange('n d h w c -> n d c h w')) + + def forward(self, q, k, v, offset): + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v + + +class DeformAttnPack(DeformAttn): + """A Deformable Attention Encapsulation that acts as normal attention layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. + clip_size (int): clip size. Default: 2. + """ + + def __init__(self, *args, **kwargs): + super(DeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels * (1 + self.clip_size), + self.clip_size * self.deformable_groups * self.attn_size * 2, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + bias=True) + self.init_weight() + + def init_weight(self): + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, q, k, v): + out = self.conv_offset(torch.cat([q.flatten(1, 2), k.flatten(1, 2)], 1)) + o1, o2 = torch.chunk(out, 2, dim=1) + offset = torch.cat((o1, o2), dim=1) + + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v diff --git a/models/op/deform_attn_cuda_kernel.cu b/models/op/deform_attn_cuda_kernel.cu new file mode 100755 index 00000000..98752dcc --- /dev/null +++ b/models/op/deform_attn_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/models/op/deform_attn_cuda_pt109.cpp b/models/op/deform_attn_cuda_pt109.cpp new file mode 100755 index 00000000..9207956f --- /dev/null +++ b/models/op/deform_attn_cuda_pt109.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); +// // for PyTorch 1.10.1 +// const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// for PyTorch 1.9.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 +// // for PyTorch 1.10.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. kv + modulated_deformable_col2im_cuda( + columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated + } + } + + // resize gradidents back + grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width + grad_offset = grad_offset.flatten(1, 2); + grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width}); +} diff --git a/models/op/deform_attn_cuda_pt110.cpp b/models/op/deform_attn_cuda_pt110.cpp new file mode 100755 index 00000000..bc979b81 --- /dev/null +++ b/models/op/deform_attn_cuda_pt110.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + // for PyTorch 1.10.1 + const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// // for PyTorch 1.9.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 + // for PyTorch 1.10.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. kv + modulated_deformable_col2im_cuda( + columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated + } + } + + // resize gradidents back + grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width + grad_offset = grad_offset.flatten(1, 2); + grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width}); +} diff --git a/models/op/deform_attn_ext.cpp b/models/op/deform_attn_ext.cpp new file mode 100755 index 00000000..79cf83a5 --- /dev/null +++ b/models/op/deform_attn_ext.cpp @@ -0,0 +1,75 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ); + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ); +#endif + +void deform_attn_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ) { + if (q.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_attn_cuda_forward(q, kv, + offset, output, columns, attns, mask_ones, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size); +#else + AT_ERROR("modulated deform attn is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform attn is not implemented on CPU"); +} + +void deform_attn_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor columns, + at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ) { + if (q.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_attn_cuda_backward(q, kv, + offset, columns, attns, mask_ones, grad_attns, grad_mask_ones, grad_q, grad_kv, grad_offset, + grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size); +#else + AT_ERROR("modulated deform attn is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform attn is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_attn_forward", + &deform_attn_forward, + "deform attn forward"); + m.def("deform_attn_backward", + &deform_attn_backward, + "deform attn backward"); +} diff --git a/models/select_network.py b/models/select_network.py index c5f92d19..dfe4797b 100644 --- a/models/select_network.py +++ b/models/select_network.py @@ -222,6 +222,31 @@ def define_G(opt): no_checkpoint_attn_blocks=opt_net['no_checkpoint_attn_blocks'], no_checkpoint_ffn_blocks=opt_net['no_checkpoint_ffn_blocks']) + # ---------------------------------------- + # RVRT + # ---------------------------------------- + elif net_type == 'rvrt': + from models.network_rvrt import RVRT as net + netG = net(upscale=opt_net['upscale'], + clip_size=opt_net['clip_size'], + img_size=opt_net['img_size'], + window_size=opt_net['window_size'], + num_blocks=opt_net['num_blocks'], + depths=opt_net['depths'], + embed_dims=opt_net['embed_dims'], + num_heads=opt_net['num_heads'], + inputconv_groups=opt_net['inputconv_groups'], + spynet_path=opt_net['spynet_path'], + deformable_groups=opt_net['deformable_groups'], + attention_heads=opt_net['attention_heads'], + attention_window=opt_net['attention_window'], + nonblind_denoising=opt_net['nonblind_denoising'], + use_checkpoint_attn=opt_net['use_checkpoint_attn'], + use_checkpoint_ffn=opt_net['use_checkpoint_ffn'], + no_checkpoint_attn_blocks=opt_net['no_checkpoint_attn_blocks'], + no_checkpoint_ffn_blocks=opt_net['no_checkpoint_ffn_blocks'], + cpu_cache_length=opt_net['cpu_cache_length']) + # ---------------------------------------- # others # ---------------------------------------- diff --git a/options/rvrt/001_train_rvrt_videosr_bi_reds_30frames.json b/options/rvrt/001_train_rvrt_videosr_bi_reds_30frames.json new file mode 100644 index 00000000..fa50a5b6 --- /dev/null +++ b/options/rvrt/001_train_rvrt_videosr_bi_reds_30frames.json @@ -0,0 +1,121 @@ +{ + "task": "001_train_rvrt_videosr_bi_reds_30frames" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 4 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": null + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainDataset" + , "dataroot_gt": "trainsets/REDS/train_sharp_with_val.lmdb" + , "dataroot_lq": "trainsets/REDS/train_sharp_bicubic_with_val.lmdb" + , "meta_info_file": "data/meta_info/meta_info_REDS_GT.txt" + , "filename_tmpl": "08d" + , "filename_ext": "png" + , "val_partition": "REDS4" + , "test_mode": false + , "io_backend": {"type": "lmdb"} + , "num_frame": 30 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": false + , "use_hflip": true + , "use_rot": true + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/REDS4/GT" + , "dataroot_lq": "testsets/REDS4/sharp_bicubic" + , "cache_data": true + , "io_backend": {"type": "disk"} + , "num_frame": -1 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 4 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [144, 144, 144] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 1, 1, 1, 1, 1] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": true + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 4e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": 30000 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 600000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 600000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": true + , "E_param_strict": true + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 128 + } + +} diff --git a/options/rvrt/002_train_rvrt_videosr_bi_vimeo_14frames.json b/options/rvrt/002_train_rvrt_videosr_bi_vimeo_14frames.json new file mode 100644 index 00000000..d0936960 --- /dev/null +++ b/options/rvrt/002_train_rvrt_videosr_bi_vimeo_14frames.json @@ -0,0 +1,118 @@ +{ + "task": "002_train_rvrt_videosr_bi_vimeo_14frames" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 4 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": "model_zoo/rvrt/001_RVRT_videosr_bi_REDS_30frames.pth" + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainVimeoDataset" + , "dataroot_gt": "trainsets/vimeo90k/vimeo90k_train_GT_all.lmdb" + , "dataroot_lq": "trainsets/vimeo90k/vimeo90k_train_LR7frames.lmdb" + , "meta_info_file": "data/meta_info/meta_info_Vimeo90K_train_GT.txt" + , "io_backend": {"type": "lmdb"} + , "num_frame": -1 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": true + , "use_hflip": true + , "use_rot": true + , "mirror_sequence": true + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/Vid4/GT" + , "dataroot_lq": "testsets/Vid4/BIx4" + , "cache_data": true + , "io_backend": {"type": "disk"} + , "num_frame": -1 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 4 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [144, 144, 144] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 1, 1, 1, 1, 1] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": false + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 2e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": -1 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 300000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 300000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": false + , "E_param_strict": false + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 0 + } + +} diff --git a/options/rvrt/003_train_rvrt_videosr_bd_vimeo_14frames.json b/options/rvrt/003_train_rvrt_videosr_bd_vimeo_14frames.json new file mode 100644 index 00000000..b343b585 --- /dev/null +++ b/options/rvrt/003_train_rvrt_videosr_bd_vimeo_14frames.json @@ -0,0 +1,118 @@ +{ + "task": "003_train_rvrt_videosr_bd_vimeo_14frames" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 4 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": "model_zoo/rvrt/001_RVRT_videosr_bi_REDS_30frames.pth" + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainVimeoDataset" + , "dataroot_gt": "trainsets/vimeo90k/vimeo90k_train_GT_all.lmdb" + , "dataroot_lq": "trainsets/vimeo90k/vimeo90k_train_BDLR7frames.lmdb" + , "meta_info_file": "data/meta_info/meta_info_Vimeo90K_train_GT.txt" + , "io_backend": {"type": "lmdb"} + , "num_frame": -1 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": true + , "use_hflip": true + , "use_rot": true + , "mirror_sequence": true + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/Vid4/GT" + , "dataroot_lq": "testsets/Vid4/BDx4" + , "cache_data": true + , "io_backend": {"type": "disk"} + , "num_frame": -1 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 4 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [144, 144, 144] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 1, 1, 1, 1, 1] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": false + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 2e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": -1 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 300000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 300000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": false + , "E_param_strict": false + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 0 + } + +} diff --git a/options/rvrt/004_train_rvrt_videodeblurring_dvd.json b/options/rvrt/004_train_rvrt_videodeblurring_dvd.json new file mode 100644 index 00000000..b6d99852 --- /dev/null +++ b/options/rvrt/004_train_rvrt_videodeblurring_dvd.json @@ -0,0 +1,120 @@ +{ + "task": "004_train_rvrt_videodeblurring_dvd" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 1 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": null + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainDataset" + , "dataroot_gt": "trainsets/DVD/train_GT.lmdb" + , "dataroot_lq": "trainsets/DVD/train_GT_blurred.lmdb" + , "meta_info_file": "data/meta_info/meta_info_DVD_train_GT.txt" + , "filename_tmpl": "05d" + , "filename_ext": "jpg" + , "test_mode": false + , "io_backend": {"type": "lmdb"} + , "num_frame": 16 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": false + , "use_hflip": true + , "use_rot": true + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/DVD10/test_GT" + , "dataroot_lq": "testsets/DVD10/test_GT_blurred" + , "cache_data": false + , "io_backend": {"type": "disk"} + , "num_frame": -1 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 1 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [192, 192, 192] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 3, 3, 3, 3, 3] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": false + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 4e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": 30000 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 600000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 600000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": true + , "E_param_strict": true + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 256 + } + +} diff --git a/options/rvrt/005_train_rvrt_videodeblurring_gopro.json b/options/rvrt/005_train_rvrt_videodeblurring_gopro.json new file mode 100644 index 00000000..f658aa78 --- /dev/null +++ b/options/rvrt/005_train_rvrt_videodeblurring_gopro.json @@ -0,0 +1,120 @@ +{ + "task": "005_train_rvrt_videodeblurring_gopro" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 1 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": null + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainDataset" + , "dataroot_gt": "trainsets/GoPro/train_GT.lmdb" + , "dataroot_lq": "trainsets/GoPro/train_GT_blurred.lmdb" + , "meta_info_file": "data/meta_info/meta_info_GoPro_train_GT.txt" + , "filename_tmpl": "06d" + , "filename_ext": "png" + , "test_mode": false + , "io_backend": {"type": "lmdb"} + , "num_frame": 16 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": false + , "use_hflip": true + , "use_rot": true + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/GoPro11/test_GT" + , "dataroot_lq": "testsets/GoPro11/test_GT_blurred" + , "cache_data": false + , "io_backend": {"type": "disk"} + , "num_frame": -1 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 1 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [192, 192, 192] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 3, 3, 3, 3, 3] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": false + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 4e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": 30000 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 600000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 600000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": true + , "E_param_strict": true + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 256 + } + +} diff --git a/options/rvrt/006_train_rvrt_videodenoising_davis.json b/options/rvrt/006_train_rvrt_videodenoising_davis.json new file mode 100644 index 00000000..a10ca1dd --- /dev/null +++ b/options/rvrt/006_train_rvrt_videodenoising_davis.json @@ -0,0 +1,126 @@ +{ + "task": "006_train_rvrt_videodenoising_davis" + , "model": "vrt" + , "gpu_ids": [0,1,2,3,4,5,6,7] + , "dist": true + , "find_unused_parameters": false + , "use_static_graph": true + + ,"scale": 1 + , "n_channels": 3 + + , "path": { + "root": "experiments" + , "pretrained_netG": null + , "pretrained_netE": null + } + + , "datasets": { + "train": { + "name": "train_dataset" + , "dataset_type": "VideoRecurrentTrainNonblindDenoisingDataset" + , "dataroot_gt": "trainsets/DAVIS/train_GT.lmdb" + , "dataroot_lq": "trainsets/DAVIS/train_GT.lmdb" + , "meta_info_file": "data/meta_info/meta_info_DAVIS_train_GT.txt" + , "filename_tmpl": "05d" + , "filename_ext": "jpg" + , "test_mode": false + , "io_backend": {"type": "lmdb"} + , "num_frame": 16 + , "gt_size": 256 + , "interval_list": [1] + , "random_reverse": false + , "use_hflip": true + , "use_rot": true + + , "sigma_min": 0 + , "sigma_max": 50 + + , "dataloader_shuffle": true + , "dataloader_num_workers": 32 + , "dataloader_batch_size": 8 + } + , "test": { + "name": "test_dataset" + , "dataset_type": "VideoRecurrentTestDataset" + , "dataroot_gt": "testsets/Set8" + , "dataroot_lq": "testsets/Set8" + , "cache_data": true + , "io_backend": {"type": "disk"} + , "num_frame": -1 + + , "sigma": 30 + } + } + + , "netG": { + "net_type": "rvrt" + , "upscale": 1 + , "clip_size": 2 + , "img_size": [2, 64, 64] + , "window_size": [2, 8, 8] + , "num_blocks": [1, 2, 1] + , "depths": [2, 2, 2] + , "embed_dims": [192, 192, 192] + , "num_heads": [6, 6, 6] + , "inputconv_groups": [1, 3, 4, 6, 8, 4] + , "spynet_path": "model_zoo/rvrt/spynet_sintel_final-3d2a1287.pth" // automatical download + , "deformable_groups": 12 + , "attention_heads": 12 + , "nonblind_denoising": true + , "attention_window": [3, 3] + , "use_checkpoint_attn": false + , "use_checkpoint_ffn": false + , "no_checkpoint_attn_blocks": [] + , "no_checkpoint_ffn_blocks": [] + , "cpu_cache_length": 100 + + , "init_type": "default" + } + + + , "train": { + "G_lossfn_type": "charbonnier" + , "G_lossfn_weight": 1.0 + , "G_charbonnier_eps": 1e-9 + + , "E_decay": 0 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999 + + , "G_optimizer_type": "adam" // fixed, adam is enough + , "G_optimizer_lr": 4e-4 // learning rate + , "G_optimizer_betas": [0.9,0.99] + , "G_optimizer_wd": 0 // weight decay, default 0 + , "G_optimizer_clipgrad": null // unused + , "G_optimizer_reuse": true // + + , "fix_iter": 30000 + , "fix_lr_mul": 0.25 + , "fix_keys": ["spynet"] + + , "total_iter": 600000 + , "G_scheduler_type": "CosineAnnealingWarmRestarts" + , "G_scheduler_periods": 600000 + , "G_scheduler_eta_min": 1e-7 + + , "G_regularizer_orthstep": null // unused + , "G_regularizer_clipstep": null // unused + + , "G_param_strict": true + , "E_param_strict": true + + , "checkpoint_test": 5000 // for testing + , "checkpoint_save": 5000 // for saving model + , "checkpoint_print": 200 // for print + } + + , "val": { + "save_img": false + , "pad_seq": false + , "flip_seq": false + , "center_frame_only": false + , "num_frame_testing": 0 + , "num_frame_overlapping": 2 + , "size_patch_testing": 256 + } + +}