-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
2,199 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,61 @@ | ||
# NSRD | ||
[CVPR 2024] Neural Super-Resolution for Real-time Rendering with Radiance Demodulation | ||
## <u>N</u>eural <u>S</u>uper-Resolution for Real-time Rendering with <u>R</u>adiance <u>D</u>emodulation (CVPR 2024) | ||
|
||
### [Paper](https://markdown.com.cn) | [Datasets](https://markdown.com.cn) | ||
|
||
### Installation | ||
|
||
Tested on Windows + CUDA 11.3 + Pytorch 1.12.1 | ||
|
||
Install environment: | ||
|
||
```bazaar | ||
git clone https://github.com/riga2/NSRD.git | ||
cd NSRD | ||
conda create -n NSRD python=3.9 | ||
conda activate NSRD | ||
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 | ||
pip install -r requirements.txt | ||
``` | ||
|
||
The following training and testing take the Bistro scene (X4) as an example. | ||
|
||
### Training | ||
1. Download the dataset and put it in the "dataset" folder. | ||
```bazaar | ||
|--dataset | ||
|--Bistro | ||
|--train | ||
|---GT | ||
|--0 | ||
|--1 | ||
... | ||
|---X4 | ||
|--0 | ||
|--1 | ||
... | ||
|---test | ||
|---GT | ||
... | ||
|---X4 | ||
... | ||
``` | ||
2. Use the Anaconda Prompt to run the following commands to train. The trained model is stored in "experiment\Bistro_X4\model". | ||
```bazaar | ||
cd src | ||
.\script\BistroX4_train.bat | ||
``` | ||
|
||
### Testing | ||
1. Run the following commands to perform super-resolution on the LR lighting components. The SR results are stored in "experiment\Bistro_X4\sr_results_x4". | ||
```bazaar | ||
cd src | ||
.\test_script\BistroX4_test.bat | ||
``` | ||
2. Run the following commands to perform remodulation on the SR lighting components. The final results are stored in "experiment\Bistro_X4\final_results_x4". | ||
```bazaar | ||
cd src | ||
python remodulation.py --exp_dir ../experiment/Bistro_X4 --gt_dir ../dataset/Bistro/test/GT | ||
``` | ||
|
||
### Citations | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
root_dir = ../dataset/Bistro/test | ||
grain = 300 | ||
total_folder_num = 4 | ||
test_folder = [1] # Currently only supports testing one folder at a time | ||
valid_folder_num = 0 | ||
test_only = True | ||
save = Bistro_X4 | ||
resume = -1 | ||
save_results = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
root_dir = ../dataset/Bistro/train | ||
grain = 100 | ||
total_folder_num = 60 | ||
test_folder = () | ||
valid_folder_num = 6 | ||
test_only = False | ||
save = Bistro_X4 | ||
resume = 1 | ||
save_results = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
numpy==1.24.3 | ||
opencv-python==4.7.0.72 | ||
Pillow==9.5.0 | ||
tqdm==4.65.0 | ||
matplotlib | ||
imageio | ||
scikit-image | ||
configparser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from importlib import import_module | ||
from data import data_loader | ||
import os | ||
|
||
class Data: | ||
def __init__(self, args): | ||
self.sr_content = args.sr_content | ||
self.loader_train = None | ||
module_name = self.sr_content | ||
m = import_module('data.' + module_name.lower() + '_dataset') | ||
if not args.test_only: | ||
train_datasets = m.make_model(args, train=True) | ||
valid_datasets = m.make_model(args, train=False) | ||
self.loader_train = data_loader.RenderingDataLoader(args, train_datasets) | ||
self.loader_valid = self.loader_train.split_validation(valid_datasets) | ||
else: | ||
test_datasets = m.make_model(args, train=False) | ||
self.loader_valid = data_loader.RenderingDataLoader(args, test_datasets) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" | ||
import cv2 | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
from data import data_utils | ||
|
||
class BaseDataset(Dataset): | ||
def __init__(self, args=None): | ||
super(BaseDataset, self).__init__() | ||
self.args = args | ||
self.scale = args.scale | ||
self.number_previous_frames = args.num_pre_frames | ||
self.hr_size = args.gt_size | ||
self.lr_size = (args.gt_size[0] // self.scale, args.gt_size[1] // self.scale) | ||
|
||
self.gt_dir = os.path.join(args.root_dir, 'GT') | ||
self.lr_dir = os.path.join(args.root_dir, 'X' + str(args.scale)) | ||
|
||
self.useNormal, self.useDepth = args.use_normal, args.use_depth | ||
self.depth_dirname = args.depth_dirname | ||
self.normal_dirname = args.normal_dirname | ||
self.mv_dirname = args.mv_dirname | ||
self.ocmv_dirname = args.ocmv_dirname | ||
self.grain = args.grain | ||
self.total_folder_num = args.total_folder_num | ||
|
||
def load_Normal_Unity(self, folder_index, file_index, ext='.exr'): | ||
filename = str(file_index) + ext | ||
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.normal_dirname, filename) | ||
lr = data_utils.getFromExr(lr_file_path)[:, :, :3] | ||
return lr | ||
|
||
def load_Depth_Unity(self, folder_index, file_index, ext='.exr'): | ||
filename = str(file_index) + ext | ||
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.depth_dirname, filename) | ||
lr = data_utils.getFromExr(lr_file_path)[:, :, 0][:, :, None] | ||
return lr | ||
|
||
def load_MV(self, folder_index, file_index, ext='.exr'): | ||
filename = str(file_index) + ext | ||
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.mv_dirname, filename) | ||
# lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2] | ||
lr = data_utils.getFromExr(lr_file_path)[:, :, :2] | ||
lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1] | ||
lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0] | ||
return lr | ||
|
||
def load_OCMV(self, folder_index, file_index, ext='.exr'): | ||
filename = str(file_index) + ext | ||
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.ocmv_dirname, filename) | ||
# lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2] | ||
lr = data_utils.getFromExr(lr_file_path)[:, :, :2] | ||
lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1] | ||
lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0] | ||
return lr | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
|
||
import numpy as np | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler, Sampler | ||
import copy | ||
from typing import Callable | ||
import random | ||
|
||
class SubsetSequenceSampler(Sampler): | ||
def __init__(self, indices): | ||
super(SubsetSequenceSampler, self).__init__(indices) | ||
self.indices = indices | ||
|
||
def __iter__(self): | ||
return (self.indices[i] for i in range(len(self.indices))) | ||
|
||
def __len__(self): | ||
return len(self.indices) | ||
|
||
|
||
class RenderingDataLoader(DataLoader): | ||
def __init__(self, args, dataset): | ||
self.args = args | ||
self.dataset = dataset | ||
self.batch_size = 1 if args.test_only else args.batch_size | ||
self.number_previous_frames = args.num_pre_frames | ||
self.num_frames_samples = args.num_frames_samples | ||
self.test_every = args.test_every | ||
self.n_samples = args.total_folder_num * args.grain | ||
|
||
self.valid_range = [] | ||
if args.test_only is False: | ||
valid_folders = np.random.choice(np.arange(0, args.total_folder_num), args.valid_folder_num, replace=False) | ||
for i in valid_folders: | ||
self.valid_range += range(i * args.grain, (i+1) * args.grain) | ||
|
||
self.test_range = [] | ||
for i in args.test_folder: | ||
self.test_range += range(i * args.grain, (i+1) * args.grain) | ||
|
||
self.idx_train = [] | ||
for idx in range(self.n_samples): | ||
if (idx not in self.valid_range) and (idx not in self.test_range): | ||
idx_mod = idx % args.grain | ||
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples): | ||
self.idx_train.append(idx) | ||
|
||
self.idx_valid = [] | ||
for idx in self.valid_range: | ||
idx_mod = idx % args.grain | ||
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples): | ||
self.idx_valid.append(idx) | ||
|
||
self.idx_test = [] | ||
for idx in self.test_range: | ||
idx_mod = idx % args.grain | ||
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples): | ||
self.idx_test.append(idx) | ||
|
||
self.sampler, self.valid_sampler = self._split_sampler() | ||
|
||
init_kwargs = { | ||
'dataset': dataset, | ||
'batch_size': self.batch_size, | ||
'num_workers': args.n_threads | ||
} | ||
super().__init__(sampler=self.sampler, **init_kwargs, drop_last=True) | ||
|
||
def _split_sampler(self): | ||
# For test | ||
if len(self.valid_range) == 0: | ||
train_sampler = SubsetSequenceSampler(self.idx_test) | ||
return train_sampler, None | ||
|
||
repeat = (self.batch_size * self.test_every) // len(self.idx_train) | ||
train_idx = np.repeat(self.idx_train, repeat) | ||
|
||
np.random.seed(0) | ||
np.random.shuffle(train_idx) | ||
|
||
train_sampler = SubsetRandomSampler(train_idx) | ||
valid_sampler = SubsetSequenceSampler(self.idx_valid) | ||
self.n_samples = len(train_idx) | ||
|
||
return train_sampler, valid_sampler | ||
|
||
def split_validation(self, valid_dataset): | ||
if self.valid_sampler is None: | ||
return None | ||
else: | ||
valid_dataloader = DataLoader(sampler=self.valid_sampler, dataset=valid_dataset, batch_size=1, num_workers=self.num_workers, drop_last=True) | ||
return valid_dataloader |
Oops, something went wrong.