Skip to content

Commit

Permalink
init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Riga2 committed Mar 4, 2024
1 parent b5a1876 commit d285738
Show file tree
Hide file tree
Showing 28 changed files with 2,199 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# user settings
dataset/
experiment/
src/check/*
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
63 changes: 61 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,61 @@
# NSRD
[CVPR 2024] Neural Super-Resolution for Real-time Rendering with Radiance Demodulation
## <u>N</u>eural <u>S</u>uper-Resolution for Real-time Rendering with <u>R</u>adiance <u>D</u>emodulation (CVPR 2024)

### [Paper](https://markdown.com.cn) | [Datasets](https://markdown.com.cn)

### Installation

Tested on Windows + CUDA 11.3 + Pytorch 1.12.1

Install environment:

```bazaar
git clone https://github.com/riga2/NSRD.git
cd NSRD
conda create -n NSRD python=3.9
conda activate NSRD
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
```

The following training and testing take the Bistro scene (X4) as an example.

### Training
1. Download the dataset and put it in the "dataset" folder.
```bazaar
|--dataset
|--Bistro
|--train
|---GT
|--0
|--1
...
|---X4
|--0
|--1
...
|---test
|---GT
...
|---X4
...
```
2. Use the Anaconda Prompt to run the following commands to train. The trained model is stored in "experiment\Bistro_X4\model".
```bazaar
cd src
.\script\BistroX4_train.bat
```

### Testing
1. Run the following commands to perform super-resolution on the LR lighting components. The SR results are stored in "experiment\Bistro_X4\sr_results_x4".
```bazaar
cd src
.\test_script\BistroX4_test.bat
```
2. Run the following commands to perform remodulation on the SR lighting components. The final results are stored in "experiment\Bistro_X4\final_results_x4".
```bazaar
cd src
python remodulation.py --exp_dir ../experiment/Bistro_X4 --gt_dir ../dataset/Bistro/test/GT
```

### Citations

9 changes: 9 additions & 0 deletions configs/Bistro_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
root_dir = ../dataset/Bistro/test
grain = 300
total_folder_num = 4
test_folder = [1] # Currently only supports testing one folder at a time
valid_folder_num = 0
test_only = True
save = Bistro_X4
resume = -1
save_results = True
9 changes: 9 additions & 0 deletions configs/Bistro_train.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
root_dir = ../dataset/Bistro/train
grain = 100
total_folder_num = 60
test_folder = ()
valid_folder_num = 6
test_only = False
save = Bistro_X4
resume = 1
save_results = False
8 changes: 8 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
numpy==1.24.3
opencv-python==4.7.0.72
Pillow==9.5.0
tqdm==4.65.0
matplotlib
imageio
scikit-image
configparser
19 changes: 19 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from importlib import import_module
from data import data_loader
import os

class Data:
def __init__(self, args):
self.sr_content = args.sr_content
self.loader_train = None
module_name = self.sr_content
m = import_module('data.' + module_name.lower() + '_dataset')
if not args.test_only:
train_datasets = m.make_model(args, train=True)
valid_datasets = m.make_model(args, train=False)
self.loader_train = data_loader.RenderingDataLoader(args, train_datasets)
self.loader_valid = self.loader_train.split_validation(valid_datasets)
else:
test_datasets = m.make_model(args, train=False)
self.loader_valid = data_loader.RenderingDataLoader(args, test_datasets)

57 changes: 57 additions & 0 deletions src/data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
import numpy as np
from torch.utils.data import Dataset
from data import data_utils

class BaseDataset(Dataset):
def __init__(self, args=None):
super(BaseDataset, self).__init__()
self.args = args
self.scale = args.scale
self.number_previous_frames = args.num_pre_frames
self.hr_size = args.gt_size
self.lr_size = (args.gt_size[0] // self.scale, args.gt_size[1] // self.scale)

self.gt_dir = os.path.join(args.root_dir, 'GT')
self.lr_dir = os.path.join(args.root_dir, 'X' + str(args.scale))

self.useNormal, self.useDepth = args.use_normal, args.use_depth
self.depth_dirname = args.depth_dirname
self.normal_dirname = args.normal_dirname
self.mv_dirname = args.mv_dirname
self.ocmv_dirname = args.ocmv_dirname
self.grain = args.grain
self.total_folder_num = args.total_folder_num

def load_Normal_Unity(self, folder_index, file_index, ext='.exr'):
filename = str(file_index) + ext
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.normal_dirname, filename)
lr = data_utils.getFromExr(lr_file_path)[:, :, :3]
return lr

def load_Depth_Unity(self, folder_index, file_index, ext='.exr'):
filename = str(file_index) + ext
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.depth_dirname, filename)
lr = data_utils.getFromExr(lr_file_path)[:, :, 0][:, :, None]
return lr

def load_MV(self, folder_index, file_index, ext='.exr'):
filename = str(file_index) + ext
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.mv_dirname, filename)
# lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2]
lr = data_utils.getFromExr(lr_file_path)[:, :, :2]
lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1]
lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0]
return lr

def load_OCMV(self, folder_index, file_index, ext='.exr'):
filename = str(file_index) + ext
lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.ocmv_dirname, filename)
# lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2]
lr = data_utils.getFromExr(lr_file_path)[:, :, :2]
lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1]
lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0]
return lr

93 changes: 93 additions & 0 deletions src/data/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler, Sampler
import copy
from typing import Callable
import random

class SubsetSequenceSampler(Sampler):
def __init__(self, indices):
super(SubsetSequenceSampler, self).__init__(indices)
self.indices = indices

def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))

def __len__(self):
return len(self.indices)


class RenderingDataLoader(DataLoader):
def __init__(self, args, dataset):
self.args = args
self.dataset = dataset
self.batch_size = 1 if args.test_only else args.batch_size
self.number_previous_frames = args.num_pre_frames
self.num_frames_samples = args.num_frames_samples
self.test_every = args.test_every
self.n_samples = args.total_folder_num * args.grain

self.valid_range = []
if args.test_only is False:
valid_folders = np.random.choice(np.arange(0, args.total_folder_num), args.valid_folder_num, replace=False)
for i in valid_folders:
self.valid_range += range(i * args.grain, (i+1) * args.grain)

self.test_range = []
for i in args.test_folder:
self.test_range += range(i * args.grain, (i+1) * args.grain)

self.idx_train = []
for idx in range(self.n_samples):
if (idx not in self.valid_range) and (idx not in self.test_range):
idx_mod = idx % args.grain
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
self.idx_train.append(idx)

self.idx_valid = []
for idx in self.valid_range:
idx_mod = idx % args.grain
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
self.idx_valid.append(idx)

self.idx_test = []
for idx in self.test_range:
idx_mod = idx % args.grain
if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
self.idx_test.append(idx)

self.sampler, self.valid_sampler = self._split_sampler()

init_kwargs = {
'dataset': dataset,
'batch_size': self.batch_size,
'num_workers': args.n_threads
}
super().__init__(sampler=self.sampler, **init_kwargs, drop_last=True)

def _split_sampler(self):
# For test
if len(self.valid_range) == 0:
train_sampler = SubsetSequenceSampler(self.idx_test)
return train_sampler, None

repeat = (self.batch_size * self.test_every) // len(self.idx_train)
train_idx = np.repeat(self.idx_train, repeat)

np.random.seed(0)
np.random.shuffle(train_idx)

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetSequenceSampler(self.idx_valid)
self.n_samples = len(train_idx)

return train_sampler, valid_sampler

def split_validation(self, valid_dataset):
if self.valid_sampler is None:
return None
else:
valid_dataloader = DataLoader(sampler=self.valid_sampler, dataset=valid_dataset, batch_size=1, num_workers=self.num_workers, drop_last=True)
return valid_dataloader
Loading

0 comments on commit d285738

Please sign in to comment.