-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathremodulation.py
115 lines (90 loc) · 4.17 KB
/
remodulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import configargparse
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from tqdm import tqdm
def getFromExr(path):
    """Read an image file with OpenCV and hand it back in RGB channel order.

    The file is read with ``cv2.IMREAD_UNCHANGED`` so its stored depth and
    channel count are kept (EXR support depends on the
    ``OPENCV_IO_ENABLE_OPENEXR`` variable set before ``cv2`` is imported).
    Despite the name, any format OpenCV can decode is accepted (the script also
    uses it for PNG ground truth).

    Args:
        path: filesystem path of the image.

    Returns:
        The decoded image after ``cv2.COLOR_BGR2RGB`` conversion.

    Raises:
        ValueError: if OpenCV could not read the file (``imread`` returned None).
    """
    # NOTE(review): COLOR_BGR2RGB expects 3 or 4 channels; a single-channel
    # image would make cvtColor raise cv2.error rather than ValueError.
    decoded = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if decoded is not None:
        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Failed to load image from {path}")
def gamma_correct(data):
    """Encode linear-light values with the piecewise sRGB transfer curve.

    Values below the breakpoint use the linear segment ``12.92 * x``; the rest
    use ``1.055 * x**(1/2.4) - 0.055``. The breakpoint is the standard
    0.0031308 (matching ``sRGB`` in this module) rather than the previously
    truncated 0.0031, so the two segments meet continuously.

    Args:
        data: scalar or array of linear values.

    Returns:
        ndarray of encoded values with the same shape as ``data`` (not clamped).
    """
    data = np.asarray(data)
    threshold = 0.0031308
    # np.where evaluates both branches for every element. Clamping the power
    # base keeps negative inputs (which select the linear branch anyway) from
    # producing NaN and an "invalid value" RuntimeWarning.
    curved = np.power(np.maximum(data, threshold), 1.0 / 2.4) * 1.055 - 0.055
    return np.where(data < threshold, data * 12.92, curved)
def sRGB(linear):
    """Apply the sRGB transfer function (linear light -> encoded values).

    Values below 0.0031308 use the linear segment ``12.92 * x``; the rest use
    ``1.055 * x**(1/2.4) - 0.055``.

    Args:
        linear: scalar or array of linear values.

    Returns:
        ndarray of encoded values with the same shape as ``linear``. No
        clamping is done here.
    """
    linear = np.asarray(linear)
    threshold = 0.0031308
    # np.where evaluates both branches for every element. Clamp the base of the
    # fractional power so negative inputs (which take the linear branch anyway)
    # don't produce NaN plus an "invalid value" RuntimeWarning.
    encoded = 1.055 * np.power(np.maximum(linear, threshold), 1 / 2.4) - 0.055
    return np.where(linear < threshold, 12.92 * linear, encoded)
def toneMapTev(color):
    """Convert a linear RGB image to sRGB-encoded values clamped to [0, 1].

    The first three channels of the last axis go through ``sRGB``; any further
    channels are only clamped. Unlike the previous version, the caller's array
    is not modified: the conversion runs on a copy.

    Args:
        color: image array with channels on the last axis (at least three),
            e.g. H x W x C as the rest of this script uses it.

    Returns:
        New array with the same shape and dtype as ``color``, clipped to [0, 1].
    """
    out = np.array(color, copy=True)
    # Writing back into `out` keeps the input dtype, as the in-place original did.
    out[..., :3] = sRGB(out[..., :3])
    return np.clip(out, 0.0, 1.0)
def calc_psnr_ssim(res_dir, gt_dir, data_range=None):
    """Print and return the mean PSNR/SSIM of each image in ``res_dir`` vs ``gt_dir``.

    Every regular file in ``res_dir`` is compared with the file of the same
    name in ``gt_dir``; both are loaded with ``getFromExr`` (RGB order).

    Args:
        res_dir: directory containing result images.
        gt_dir: directory containing ground-truth images with matching names.
        data_range: value range passed to both metrics. ``None`` keeps
            scikit-image's dtype-based default; floating-point inputs such as
            EXR generally need an explicit value.

    Returns:
        ``(avg_psnr, avg_ssim)``.

    Raises:
        ValueError: if ``res_dir`` contains no files (previously a
            ZeroDivisionError), or as raised by ``getFromExr``/scikit-image.
    """
    names = sorted(
        name for name in os.listdir(res_dir)
        if os.path.isfile(os.path.join(res_dir, name))
    )
    if not names:
        raise ValueError(f"No result images found in {res_dir}")

    total_psnr = 0.0
    total_ssim = 0.0
    for name in tqdm(names):
        res_img = getFromExr(os.path.join(res_dir, name))
        gt_img = getFromExr(os.path.join(gt_dir, name))
        total_psnr += psnr(gt_img, res_img, data_range=data_range)
        # `multichannel=True` is deprecated in scikit-image in favour of
        # `channel_axis` (which `remodulation` already uses).
        total_ssim += ssim(gt_img, res_img, channel_axis=-1, data_range=data_range)

    avg_psnr = total_psnr / len(names)
    avg_ssim = total_ssim / len(names)
    print("avg_psnr : {}".format(avg_psnr))
    print("avg_ssim : {}".format(avg_ssim))
    return avg_psnr, avg_ssim
def _posix_join(*parts):
    """Join path parts and normalise backslashes to forward slashes."""
    return os.path.join(*parts).replace("\\", "/")


def _remodulate_frame(irr_path, gt_seq_dir, name):
    """Rebuild one uint8 RGB frame from SR irradiance and the GT BRDF / Emission_Sky buffers.

    Errors from loading (e.g. ValueError from ``getFromExr`` when a file is
    missing) or from the array maths propagate to the caller.
    """
    irr = getFromExr(irr_path)
    irr[irr < 0] = 0  # clamp negative irradiance before multiplying by the BRDF
    brdf = getFromExr(_posix_join(gt_seq_dir, "BRDF", name))
    emiss_sky = getFromExr(_posix_join(gt_seq_dir, "Emission_Sky", name))
    # Where any RGB channel of the emission/sky buffer is non-negligible, that
    # buffer is used directly instead of brdf * irr.
    emiss_sky_mask = np.any(np.abs(emiss_sky[:, :, :3]) >= 1e-4, axis=2, keepdims=True)
    sr_img = np.where(emiss_sky_mask, emiss_sky, brdf * irr)
    # toneMapTev clips to [0, 1]; astype then truncates (no rounding), as before.
    return (toneMapTev(sr_img) * 255).astype(np.uint8)


def remodulation(exp_dir, gt_dir):
    """Remodulate super-resolved irradiance with GT buffers, save it and score it.

    Layout read and written (i = 0 .. N-1, N = number of entries in
    ``<exp_dir>/sr_results_x4``):
        <exp_dir>/sr_results_x4/<i>/<frame>     SR irradiance (``.png`` entries skipped)
        <gt_dir>/<i>/BRDF/<frame>               BRDF, same file name as the irradiance
        <gt_dir>/<i>/Emission_Sky/<frame>       emission/sky, same file name
        <gt_dir>/<i>/View_PNG/<stem>.png        ground-truth image for PSNR/SSIM
        <exp_dir>/final_results_x4/<i>/<stem>.png   written result

    Frames that fail (ValueError or cv2.error) are reported and skipped. They
    are not counted in the averages; previously they, and any skipped ``.png``
    entries, inflated the divisor and biased the scores low. The overall score
    is the mean over sequences (with at least one scored frame) of each
    sequence's per-frame mean.

    Args:
        exp_dir: experiment directory containing ``sr_results_x4``.
        gt_dir: ground-truth root with one numbered sub-folder per sequence.

    Returns:
        ``(avg_psnr, avg_ssim)``; NaN when no frame could be scored.
    """
    sr_res_dir = _posix_join(exp_dir, 'sr_results_x4')
    img_save_dir = _posix_join(exp_dir, 'final_results_x4')
    folder_num = len(os.listdir(sr_res_dir))

    seq_psnrs, seq_ssims = [], []
    for ind in range(folder_num):
        cur_res_dir = _posix_join(sr_res_dir, str(ind))
        gt_seq_dir = _posix_join(gt_dir, str(ind))
        save_dir = _posix_join(img_save_dir, str(ind))
        os.makedirs(save_dir, exist_ok=True)

        frame_psnrs, frame_ssims = [], []
        for name in tqdm(sorted(os.listdir(cur_res_dir))):
            if name.lower().endswith(".png"):
                continue
            # splitext keeps inner dots ("a.b.exr" -> "a.b.png"); split('.')[0] did not.
            png_name = os.path.splitext(name)[0] + '.png'
            try:
                sr_img = _remodulate_frame(_posix_join(cur_res_dir, name), gt_seq_dir, name)
                gt_img = getFromExr(_posix_join(gt_seq_dir, "View_PNG", png_name))
                frame_psnr = psnr(gt_img, sr_img)
                frame_ssim = ssim(gt_img, sr_img, win_size=11, channel_axis=2, data_range=255)
            except (ValueError, cv2.error) as e:
                print(f"Error processing {name}: {e}")
                continue
            frame_psnrs.append(float(frame_psnr))
            frame_ssims.append(float(frame_ssim))
            # OpenCV writes BGR, so reverse the RGB channel order before saving.
            cv2.imwrite(_posix_join(save_dir, png_name), sr_img[:, :, ::-1])

        if not frame_psnrs:
            print(f"No frames could be evaluated in {cur_res_dir}")
            continue
        seq_psnrs.append(sum(frame_psnrs) / len(frame_psnrs))
        seq_ssims.append(sum(frame_ssims) / len(frame_ssims))

    avg_psnr = sum(seq_psnrs) / len(seq_psnrs) if seq_psnrs else float("nan")
    avg_ssim = sum(seq_ssims) / len(seq_ssims) if seq_ssims else float("nan")
    print("Avg_psnr: {}".format(avg_psnr))
    print("Avg_ssim: {}".format(avg_ssim))
    return avg_psnr, avg_ssim
if __name__ == '__main__':
    parser = configargparse.ArgumentParser()
    parser.add_argument('--exp_dir', type=str, default=r"../experiment/Bistro_X4",
                        help='experiment dir')
    # The sub-folder name must match what remodulation() reads: "Emission_Sky".
    parser.add_argument('--gt_dir', type=str, default=r"../dataset/Bistro/test/GT",
                        help='ground truth dir, which contains View_PNG, BRDF and Emission_Sky.')
    args = parser.parse_args()
    # Normalise Windows separators so every path built downstream uses '/'.
    remodulation(args.exp_dir.replace("\\", "/"), args.gt_dir.replace("\\", "/"))