-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathinterpolate.py
122 lines (106 loc) · 4.67 KB
/
interpolate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# python 3.6
"""Interpolates real images with In-domain GAN Inversion.
The real images should first be inverted to latent codes with `invert.py`. After
that, this script can be used for image interpolation.
NOTE: This script interpolates every image pair formed by taking one image from
the source directory and one image from the target directory.
"""
import os
import argparse
from tqdm import tqdm
import numpy as np
from models.helper import build_generator
from utils.logger import setup_logger
from utils.editor import interpolate
from utils.visualizer import load_image
from utils.visualizer import HtmlPageVisualizer
def parse_args():
    """Builds the command-line parser and parses the process arguments.

    Positional arguments are `model_name`, `src_dir` and `dst_dir`. Optional
    arguments are `-o/--output_dir`, `--step`, `--viz_size` and `--gpu_id`.

    Returns:
      The `argparse.Namespace` obtained from parsing the command line.
    """
    # Each entry is (name_or_flags, keyword options), registered in order.
    argument_specs = (
        (('model_name',),
         dict(type=str, help='Name of the GAN model.')),
        (('src_dir',),
         dict(type=str,
              help='Source directory, which includes original images, '
                   'inverted codes, and image list.')),
        (('dst_dir',),
         dict(type=str,
              help='Target directory, which includes original images, '
                   'inverted codes, and image list.')),
        (('-o', '--output_dir'),
         dict(type=str, default='',
              help='Directory to save the results. If not specified, '
                   '`./results/interpolation` will be used by default.')),
        (('--step',),
         dict(type=int, default=5,
              help='Number of steps for interpolation. (default: 5)')),
        (('--viz_size',),
         dict(type=int, default=256,
              help='Image size for visualization. (default: 256)')),
        (('--gpu_id',),
         dict(type=str, default='0',
              help='Which GPU(s) to use. (default: `0`)')),
    )

    cli = argparse.ArgumentParser()
    for name_or_flags, options in argument_specs:
        cli.add_argument(*name_or_flags, **options)
    return cli.parse_args()
def _check_input_dir(data_dir):
    """Raises `FileNotFoundError` unless `data_dir` holds the inversion files."""
    required_paths = (
        data_dir,
        f'{data_dir}/image_list.txt',
        f'{data_dir}/inverted_codes.npy',
    )
    for path in required_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f'Required path `{path}` does not exist!')


def _load_names_and_codes(data_dir):
    """Loads the image names and the inverted codes stored in `data_dir`.

    Args:
      data_dir: Directory containing `image_list.txt`, `inverted_codes.npy`
        and one `<name>_ori.png` per listed image.

    Returns:
      A tuple `(names, codes)`, where `names` is the list of image base names
        (directory and extension removed) and `codes` is the array loaded
        from `inverted_codes.npy`.

    Raises:
      FileNotFoundError: If a listed image has no `<name>_ori.png` file.
      ValueError: If the number of codes differs from the number of images.
    """
    names = []
    with open(f'{data_dir}/image_list.txt', 'r') as f:
        for line in f:
            entry = line.strip()
            if not entry:
                # A blank line does not name an image; skip it instead of
                # looking for a bogus `_ori.png` file.
                continue
            name = os.path.splitext(os.path.basename(entry))[0]
            image_path = f'{data_dir}/{name}_ori.png'
            if not os.path.exists(image_path):
                raise FileNotFoundError(
                    f'Original image `{image_path}` does not exist!')
            names.append(name)

    codes = np.load(f'{data_dir}/inverted_codes.npy')
    if codes.shape[0] != len(names):
        raise ValueError(
            f'Number of codes ({codes.shape[0]}) in `{data_dir}` does not '
            f'match number of listed images ({len(names)})!')
    return names, codes


def main():
    """Main function.

    Interpolates between every source image and every target image and saves
    the result as one HTML page, `<output_dir>/<src>_TO_<dst>.html`, with one
    row per (source, target) pair.

    Raises:
      FileNotFoundError: If a required input directory or file is missing.
      ValueError: If a directory's code count and image count disagree.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Validate inputs with real exceptions: `assert` is removed under
    # `python -O` and would let a bad path fail much later.
    src_dir = args.src_dir
    src_dir_name = os.path.basename(src_dir.rstrip('/'))
    _check_input_dir(src_dir)
    dst_dir = args.dst_dir
    dst_dir_name = os.path.basename(dst_dir.rstrip('/'))
    _check_input_dir(dst_dir)

    output_dir = args.output_dir or 'results/interpolation'
    job_name = f'{src_dir_name}_TO_{dst_dir_name}'
    logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger')

    # Load model.
    logger.info('Loading generator.')
    generator = build_generator(args.model_name)

    # Load image and codes.
    logger.info('Loading images and corresponding inverted latent codes.')
    src_list, src_codes = _load_names_and_codes(src_dir)
    num_src = src_codes.shape[0]
    dst_list, dst_codes = _load_names_and_codes(dst_dir)
    num_dst = dst_codes.shape[0]

    # Interpolate images.
    logger.info('Start interpolation.')
    # Two extra steps on top of `--step`: they fill the 'Source Inversion' and
    # 'Target Inversion' columns around the `Step XX` columns.
    step = args.step + 2
    viz_size = None if args.viz_size == 0 else args.viz_size
    # Column layout: original source, `step` synthesized images, original
    # target.
    visualizer = HtmlPageVisualizer(
        num_rows=num_src * num_dst, num_cols=step + 2, viz_size=viz_size)
    visualizer.set_headers(
        ['Source', 'Source Inversion'] +
        [f'Step {i:02d}' for i in range(1, step - 1)] +
        ['Target Inversion', 'Target']
    )

    for src_idx in tqdm(range(num_src), leave=False):
        src_code = src_codes[src_idx:src_idx + 1]
        src_path = f'{src_dir}/{src_list[src_idx]}_ori.png'
        # The source image is the same for every target, so read it once.
        # NOTE(review): assumes `set_cell` does not modify the image in place.
        src_image = load_image(src_path)
        # Pair this single source code with every target code in one call.
        codes = interpolate(src_codes=np.repeat(src_code, num_dst, axis=0),
                            dst_codes=dst_codes,
                            step=step)
        for dst_idx in tqdm(range(num_dst), leave=False):
            dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png'
            output_images = generator.easy_synthesize(
                codes[dst_idx], latent_space_type='wp')['image']
            row_idx = src_idx * num_dst + dst_idx
            visualizer.set_cell(row_idx, 0, image=src_image)
            visualizer.set_cell(row_idx, step + 1, image=load_image(dst_path))
            for s, output_image in enumerate(output_images):
                visualizer.set_cell(row_idx, s + 1, image=output_image)

    # Save results.
    visualizer.save(f'{output_dir}/{job_name}.html')
# Run the interpolation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()