From 53af79ac51e5f935acff59adfba76b9136033685 Mon Sep 17 00:00:00 2001 From: elxy Date: Wed, 12 Oct 2022 16:17:19 +0800 Subject: [PATCH] support different color spaces --- basicsr/data/data_util.py | 4 +- basicsr/data/ffhq_dataset.py | 6 +- basicsr/data/paired_image_dataset.py | 12 +- basicsr/data/realesrgan_dataset.py | 6 +- basicsr/data/realesrgan_paired_dataset.py | 7 +- basicsr/data/reds_dataset.py | 16 ++- basicsr/data/single_image_dataset.py | 7 +- basicsr/data/vimeo90k_dataset.py | 14 +- .../metrics/test_metrics/test_psnr_ssim.py | 6 +- basicsr/models/sr_model.py | 5 +- basicsr/utils/__init__.py | 3 +- basicsr/utils/color_util.py | 4 + basicsr/utils/img_util.py | 129 ++++++++++++++---- docs/Config.md | 4 + inference/inference_ridnet.py | 6 +- scripts/metrics/calculate_lpips.py | 4 +- 16 files changed, 170 insertions(+), 63 deletions(-) diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py index bf4c494b7..e01975a05 100644 --- a/basicsr/data/data_util.py +++ b/basicsr/data/data_util.py @@ -8,7 +8,7 @@ from basicsr.utils import img2tensor, scandir -def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): +def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False, color_space='rgb'): """Read a sequence of images from a given folder path. Args: @@ -30,7 +30,7 @@ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): if require_mod_crop: imgs = [mod_crop(img, scale) for img in imgs] - imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = img2tensor(imgs, color_space=color_space, float32=True) imgs = torch.stack(imgs, dim=0) if return_imgname: diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py index 23992eb87..0ddf69139 100644 --- a/basicsr/data/ffhq_dataset.py +++ b/basicsr/data/ffhq_dataset.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import normalize from basicsr.data.transforms import augment -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY @@ -70,8 +70,10 @@ def __getitem__(self, index): # random horizontal flip img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] # BGR to RGB, HWC to CHW, numpy to tensor - img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) + img_gt = img2tensor(img_gt, color_space=color_space, float32=True) # normalize normalize(img_gt, self.mean, self.std, inplace=True) return {'gt': img_gt, 'gt_path': gt_path} diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 9f5c8c6ad..03f1f8d02 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -3,7 +3,7 @@ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file from basicsr.data.transforms import augment, paired_random_crop -from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor +from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY @@ -83,18 +83,16 @@ def __getitem__(self, index): # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) - # color space transform - if 'color' in self.opt and self.opt['color'] == 'y': - img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] - img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] - # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets # TODO: It is better to update the datasets, rather than force to crop if self.opt['phase'] != 'train': img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) diff --git a/basicsr/data/realesrgan_dataset.py b/basicsr/data/realesrgan_dataset.py index 1616e9b91..368b33024 100644 --- a/basicsr/data/realesrgan_dataset.py +++ b/basicsr/data/realesrgan_dataset.py @@ -10,7 +10,7 @@ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels from basicsr.data.transforms import augment -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY @@ -181,8 +181,10 @@ def __getitem__(self, index): else: sinc_kernel = self.pulse_tensor + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] # BGR to RGB, HWC to CHW, numpy to tensor - img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + img_gt = img2tensor([img_gt], color_space=color_space, float32=True)[0] kernel = torch.FloatTensor(kernel) kernel2 = torch.FloatTensor(kernel2) diff --git a/basicsr/data/realesrgan_paired_dataset.py b/basicsr/data/realesrgan_paired_dataset.py index 604b026d5..749891117 100644 --- a/basicsr/data/realesrgan_paired_dataset.py +++ b/basicsr/data/realesrgan_paired_dataset.py @@ -4,7 +4,7 @@ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb from basicsr.data.transforms import augment, paired_random_crop -from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils import FileClient, imfrombytes, img2tensor, ColorSpace from basicsr.utils.registry import DATASET_REGISTRY @@ -93,8 +93,11 @@ def __getitem__(self, index): # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py index fabef1d7e..a0606e209 100644 --- a/basicsr/data/reds_dataset.py +++ b/basicsr/data/reds_dataset.py @@ -5,7 +5,7 @@ from torch.utils import data as data from basicsr.data.transforms import augment, paired_random_crop -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor, ColorSpace from basicsr.utils.flow_util import dequantize_flow from basicsr.utils.registry import DATASET_REGISTRY @@ -182,12 +182,16 @@ def __getitem__(self, index): else: img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) - img_results = img2tensor(img_results) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_results = img2tensor(img_results, color_space=color_space) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] if self.flow_root is not None: - img_flows = img2tensor(img_flows) + img_flows = img2tensor(img_flows, color_space=ColorSpace.RAW) # add the zero center flow img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) img_flows = torch.stack(img_flows, dim=0) @@ -339,7 +343,11 @@ def __getitem__(self, index): img_lqs.extend(img_gts) img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) - img_results = img2tensor(img_results) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_results = img2tensor(img_results, color_space=color_space) img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0) img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0) diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py index acbc7d921..3f60075f6 100644 --- a/basicsr/data/single_image_dataset.py +++ b/basicsr/data/single_image_dataset.py @@ -3,7 +3,7 @@ from torchvision.transforms.functional import normalize from basicsr.data.data_util import paths_from_lmdb -from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir +from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor, scandir from basicsr.utils.registry import DATASET_REGISTRY @@ -54,11 +54,10 @@ def __getitem__(self, index): img_lq = imfrombytes(img_bytes, float32=True) # color space transform - if 'color' in self.opt and self.opt['color'] == 'y': - img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] # BGR to RGB, HWC to CHW, numpy to tensor - img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + img_lq = img2tensor(img_lq, color_space=color_space, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py index e5e33e108..46ecb5735 100644 --- a/basicsr/data/vimeo90k_dataset.py +++ b/basicsr/data/vimeo90k_dataset.py @@ -4,7 +4,7 @@ from torch.utils import data as data from basicsr.data.transforms import augment, paired_random_crop -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor, ColorSpace from basicsr.utils.registry import DATASET_REGISTRY @@ -120,7 +120,11 @@ def __getitem__(self, index): img_lqs.append(img_gt) img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) - img_results = img2tensor(img_results) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_results = img2tensor(img_results, color_space=color_space) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] @@ -182,7 +186,11 @@ def __getitem__(self, index): img_lqs.extend(img_gts) img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) - img_results = img2tensor(img_results) + # color space transform + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_results = img2tensor(img_results, color_space=color_space) img_lqs = torch.stack(img_results[:7], dim=0) img_gts = torch.stack(img_results[7:], dim=0) diff --git a/basicsr/metrics/test_metrics/test_psnr_ssim.py b/basicsr/metrics/test_metrics/test_psnr_ssim.py index 18b05a73a..cccd6e91c 100644 --- a/basicsr/metrics/test_metrics/test_psnr_ssim.py +++ b/basicsr/metrics/test_metrics/test_psnr_ssim.py @@ -3,7 +3,7 @@ from basicsr.metrics import calculate_psnr, calculate_ssim from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt -from basicsr.utils import img2tensor +from basicsr.utils import img2tensor, ColorSpace def test(img_path, img_path2, crop_border, test_y_channel=False): @@ -16,8 +16,8 @@ def test(img_path, img_path2, crop_border, test_y_channel=False): print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') # --------------------- PyTorch (CPU) --------------------- - img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) - img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + img = img2tensor(img / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0) + img2 = img2tensor(img2 / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0) psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 787f1fd2e..62b4e70c2 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -183,6 +183,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): dataset_name = dataloader.dataset.opt['name'] + color_space = dataloader.dataset.opt['color'] if 'color' in dataloader.dataset.opt else 'rgb' with_metrics = self.opt['val'].get('metrics') is not None use_pbar = self.opt['val'].get('pbar', False) @@ -205,10 +206,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): self.test() visuals = self.get_current_visuals() - sr_img = tensor2img([visuals['result']]) + sr_img = tensor2img([visuals['result']], color_space=color_space) metric_data['img'] = sr_img if 'gt' in visuals: - gt_img = tensor2img([visuals['gt']]) + gt_img = tensor2img([visuals['gt']], color_space=color_space) metric_data['img2'] = gt_img del self.gt diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index 9569c5078..3886635c9 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -2,7 +2,7 @@ from .diffjpeg import DiffJPEG from .file_client import FileClient from .img_process_util import USMSharp, usm_sharp -from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .img_util import ColorSpace, crop_border, imfrombytes, img2tensor, imwrite, tensor2img from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt from .options import yaml_load @@ -17,6 +17,7 @@ # file_client.py 'FileClient', # img_util.py + 'ColorSpace', 'img2tensor', 'tensor2img', 'imfrombytes', diff --git a/basicsr/utils/color_util.py b/basicsr/utils/color_util.py index 4740d5c98..79cf106bc 100644 --- a/basicsr/utils/color_util.py +++ b/basicsr/utils/color_util.py @@ -91,6 +91,8 @@ def ycbcr2rgb(img): """ img_type = img.dtype img = _convert_input_type_range(img) * 255 + if img.shape[2] == 1: # only y channel + img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant') out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 out_img = _convert_output_type_range(out_img, img_type) @@ -120,6 +122,8 @@ def ycbcr2bgr(img): """ img_type = img.dtype img = _convert_input_type_range(img) * 255 + if img.shape[2] == 1: # only y channel + img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant') out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 out_img = _convert_output_type_range(out_img, img_type) diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py index 3a5f1da09..8b8168a7d 100644 --- a/basicsr/utils/img_util.py +++ b/basicsr/utils/img_util.py @@ -1,17 +1,76 @@ -import cv2 import math -import numpy as np import os +from enum import IntEnum, auto +from functools import partial + +import cv2 +import numpy as np import torch from torchvision.utils import make_grid - -def img2tensor(imgs, bgr2rgb=True, float32=True): +from .color_util import bgr2ycbcr, ycbcr2bgr + + +class ColorSpace(IntEnum): + RAW = auto() # do not convert colorspace + BGR = auto() + RGB = auto() + GRAY = auto() # YUVJ + XYZ = auto() + YCrCb = auto() + HSV = auto() + Lab = auto() + Luv = auto() + HLS = auto() + YUV = auto() # YUVJ + YUVI420 = auto() # YUVJ + Y = GRAY_BT601 = auto() + YUV_BT601 = auto() + GRAY_BT709 = auto() + YUV_BT709 = auto() + + +BGR2COLOR = { + ColorSpace.RAW: lambda x: x, + ColorSpace.BGR: lambda x: x, + ColorSpace.RGB: partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB), + ColorSpace.GRAY: partial(cv2.cvtColor, code=cv2.COLOR_BGR2GRAY), + ColorSpace.XYZ: partial(cv2.cvtColor, code=cv2.COLOR_BGR2XYZ), + ColorSpace.YCrCb: partial(cv2.cvtColor, code=cv2.COLOR_BGR2YCrCb), + ColorSpace.HSV: partial(cv2.cvtColor, code=cv2.COLOR_BGR2HSV), + ColorSpace.Lab: partial(cv2.cvtColor, code=cv2.COLOR_BGR2Lab), + ColorSpace.Luv: partial(cv2.cvtColor, code=cv2.COLOR_BGR2Luv), + ColorSpace.HLS: partial(cv2.cvtColor, code=cv2.COLOR_BGR2HLS), + ColorSpace.YUV: partial(cv2.cvtColor, code=cv2.COLOR_BGR2YUV), + ColorSpace.YUVI420: partial(cv2.cvtColor, code=cv2.COLOR_BGR2YUV_I420), + ColorSpace.GRAY_BT601: partial(bgr2ycbcr, y_only=True), + ColorSpace.YUV_BT601: partial(bgr2ycbcr, y_only=False), +} + +COLOR2BGR = { + ColorSpace.RAW: lambda x: x, + ColorSpace.BGR: lambda x: x, + ColorSpace.RGB: partial(cv2.cvtColor, code=cv2.COLOR_RGB2BGR), + ColorSpace.GRAY: partial(cv2.cvtColor, code=cv2.COLOR_GRAY2BGR), + ColorSpace.XYZ: partial(cv2.cvtColor, code=cv2.COLOR_XYZ2BGR), + ColorSpace.YCrCb: partial(cv2.cvtColor, code=cv2.COLOR_YCrCb2BGR), + ColorSpace.HSV: partial(cv2.cvtColor, code=cv2.COLOR_HSV2BGR), + ColorSpace.Lab: partial(cv2.cvtColor, code=cv2.COLOR_Lab2BGR), + ColorSpace.Luv: partial(cv2.cvtColor, code=cv2.COLOR_Luv2BGR), + ColorSpace.HLS: partial(cv2.cvtColor, code=cv2.COLOR_HLS2BGR), + ColorSpace.YUV: partial(cv2.cvtColor, code=cv2.COLOR_YUV2BGR), + ColorSpace.YUVI420: partial(cv2.cvtColor, code=cv2.COLOR_YUV2BGR_I420), + ColorSpace.GRAY_BT601: ycbcr2bgr, + ColorSpace.YUV_BT601: ycbcr2bgr, +} + + +def img2tensor(imgs, color_space=ColorSpace.RGB, float32=True): """Numpy array to tensor. Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. + imgs (list[ndarray] | ndarray): Input images, MUST be BGR + color_space (ColorSpace): Target color space of images. float32 (bool): Whether to change to float32. Returns: @@ -19,23 +78,35 @@ def img2tensor(imgs, bgr2rgb=True, float32=True): one element, just return tensor. """ - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: + def _totensor(img, color_space, float32): + if img.shape[2] == 3: # input is bgr if img.dtype == 'float64': img = img.astype('float32') - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) + img = BGR2COLOR[color_space](img) + if img.ndim == 3: # HWC to CHW + img = torch.from_numpy(img.transpose(2, 0, 1)) + elif img.ndim == 2: + img = torch.from_numpy(np.expand_dims(img, axis=0)) + else: + raise ValueError(f'Unsupported image dim {img.ndim}!') if float32: img = img.float() return img + if isinstance(color_space, str): + for cs in ColorSpace: + if color_space.lower() == cs.name.lower(): + color_space = cs + break + if isinstance(color_space, str): + raise ValueError(f'Do not support color space {color_space} yet!') if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] + return [_totensor(img, color_space, float32) for img in imgs] else: - return _totensor(imgs, bgr2rgb, float32) + return _totensor(imgs, color_space, float32) -def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): +def tensor2img(tensor, color_space=ColorSpace.RGB, out_type=np.uint8, min_max=None): """Convert torch Tensors into image numpy arrays. After clamping to [min, max], values will be normalized to [0, 1]. @@ -46,7 +117,7 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 2) 3D Tensor of shape (3/1 x H x W); 3) 2D Tensor of shape (H x W). Tensor channel should be in RGB order. - rgb2bgr (bool): Whether to change rgb to bgr. + color_space (ColorSpace): Color space of input tensor. out_type (numpy type): output types. If ``np.uint8``, transform outputs to uint8 type with range [0, 255]; otherwise, float type with range [0, 1]. Default: ``np.uint8``. @@ -59,31 +130,38 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + if isinstance(color_space, str): + for cs in ColorSpace: + if color_space.lower() == cs.name.lower(): + color_space = cs + break + if isinstance(color_space, str): + raise ValueError(f'Do not support color space {color_space} yet!') + if torch.is_tensor(tensor): tensor = [tensor] result = [] for _tensor in tensor: - _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) - _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + _tensor = _tensor.squeeze(0).float().detach().cpu() n_dim = _tensor.dim() if n_dim == 4: img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() - img_np = img_np.transpose(1, 2, 0) - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + img_np = img_np.transpose(1, 2, 0) # CHW to HWC + img_np = COLOR2BGR[color_space](img_np) elif n_dim == 3: img_np = _tensor.numpy() - img_np = img_np.transpose(1, 2, 0) + img_np = img_np.transpose(1, 2, 0) # CHW to HWC if img_np.shape[2] == 1: # gray image img_np = np.squeeze(img_np, axis=2) else: - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + img_np = COLOR2BGR[color_space](img_np) elif n_dim == 2: img_np = _tensor.numpy() else: raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + # clip BGR to (0, 1) + img_np = np.clip(img_np, 0, 1) if out_type == np.uint8: # Unlike MATLAB, numpy.unit8() WILL NOT round by default. img_np = (img_np * 255.0).round() @@ -94,20 +172,19 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): return result -def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): +def tensor2img_fast(tensor, color_space=ColorSpace.RGB, min_max=(0, 1)): """This implementation is slightly faster than tensor2img. It now only supports torch tensor with shape (1, c, h, w). Args: tensor (Tensor): Now only support torch tensor with (1, c, h, w). - rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + color_space (ColorSpace): Color space of tesnor. Default: ColorSpace.RGB. min_max (tuple[int]): min and max values for clamp. """ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 output = output.type(torch.uint8).cpu().numpy() - if rgb2bgr: - output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + output = COLOR2BGR[color_space](output) return output diff --git a/docs/Config.md b/docs/Config.md index d894407ed..bc12686c6 100644 --- a/docs/Config.md +++ b/docs/Config.md @@ -61,6 +61,8 @@ datasets: dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub # LQ (Low-Quality) folder path dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub + # Colorspace before feed into network, rgb in default + color: rgb # template for file name. Usually, LQ files have suffix like `_x4`. It is used for file name mismatching filename_tmpl: '{}' # IO backend, more details are in [docs/DatasetPreparation.md] @@ -96,6 +98,8 @@ datasets: dataroot_gt: datasets/Set5/GTmod12 # LQ (Low-Quality) folder path dataroot_lq: datasets/Set5/LRbicx4 + # Colorspace before feed into network + color: rgb # IO backend, more details are in [docs/DatasetPreparation.md] io_backend: # directly read from disk diff --git a/inference/inference_ridnet.py b/inference/inference_ridnet.py index 9825ba898..2c8733a2c 100644 --- a/inference/inference_ridnet.py +++ b/inference/inference_ridnet.py @@ -7,7 +7,7 @@ from tqdm import tqdm from basicsr.archs.ridnet_arch import RIDNet -from basicsr.utils.img_util import img2tensor, tensor2img +from basicsr.utils.img_util import ColorSpace, img2tensor, tensor2img if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -41,11 +41,11 @@ pbar.set_description(f'{idx}: {img_name}') # read image img = cv2.imread(img_path, cv2.IMREAD_COLOR) - img = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0).to(device) + img = img2tensor(img, color_space=ColorSpace.RGB, float32=True).unsqueeze(0).to(device) # inference with torch.no_grad(): output = net(img) # save image - output = tensor2img(output, rgb2bgr=True, out_type=np.uint8, min_max=(0, 255)) + output = tensor2img(output, color_space=ColorSpace.RGB, out_type=np.uint8, min_max=(0, 255)) save_img_path = os.path.join(result_root, f'{img_name}_x{args.noise_g}_RIDNet.png') cv2.imwrite(save_img_path, output) diff --git a/scripts/metrics/calculate_lpips.py b/scripts/metrics/calculate_lpips.py index 4170fb40e..e4d389641 100644 --- a/scripts/metrics/calculate_lpips.py +++ b/scripts/metrics/calculate_lpips.py @@ -4,7 +4,7 @@ import os.path as osp from torchvision.transforms.functional import normalize -from basicsr.utils import img2tensor +from basicsr.utils import ColorSpace, img2tensor try: import lpips @@ -32,7 +32,7 @@ def main(): img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype( np.float32) / 255. - img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True) + img_gt, img_restored = img2tensor([img_gt, img_restored], color_space=ColorSpace.RGB, float32=True) # norm to [-1, 1] normalize(img_gt, mean, std, inplace=True) normalize(img_restored, mean, std, inplace=True)