-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathgenerate.py
214 lines (189 loc) · 9.12 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Samples a large number of images from a pre-trained SiT model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
For a simple single-GPU/CPU sampling script, see sample.py.
"""
import torch
import torch.distributed as dist
from models.sit import SiT_models
from diffusers.models import AutoencoderKL
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse
from samplers import euler_sampler, euler_maruyama_sampler
from utils import load_legacy_checkpoints, download_model
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Reads ``{sample_dir}/000000.png``, ``{sample_dir}/000001.png``, ... for the
    first ``num`` indices, stacks them into one uint8 array and saves it as
    ``{sample_dir}.npz`` under the key ``arr_0``.

    Args:
        sample_dir: Folder containing the zero-padded, six-digit .png files.
        num: Number of images to read, starting from index 0.

    Returns:
        The path of the .npz file that was written.

    Raises:
        AssertionError: If the stacked array does not have shape (num, H, W, 3).
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Open via a context manager so each file handle is released right away
        # instead of relying on implicit cleanup across tens of thousands of files.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # Every image must be 3-channel and share the same height/width.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
def main(args):
    """
    Run sampling.

    Initializes an NCCL process group, builds the SiT model and the VAE,
    samples class-conditional latents on every rank, decodes them to images,
    writes each image as ``{index:06d}.png`` into a run-specific folder under
    ``args.sample_dir`` and finally lets rank 0 pack the first
    ``args.num_fid_samples`` images into a .npz file.

    Args:
        args: Parsed command-line namespace (see the argument parser in this
            file). Note that ``args.ckpt`` is overwritten when it is None.

    Raises:
        NotImplementedError: If ``args.mode`` is neither "ode" nor "sde".
        AssertionError: If CUDA is unavailable, ``args.cfg_scale`` < 1.0, or the
            default-checkpoint configuration checks fail.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Validate cheap options up front so a bad flag fails before the process
    # group, the model and the VAE are set up.
    samplers = {"sde": euler_maruyama_sampler, "ode": euler_sampler}
    if args.mode not in samplers:
        raise NotImplementedError(
            f"Unsupported sampling mode {args.mode!r}; expected 'ode' or 'sde'."
        )
    sampler = samplers[args.mode]
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    # Give every rank its own seed so ranks do not draw identical noise/labels.
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Load model:
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=True,
        z_dims=[int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)

    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt
    if ckpt_path is None:
        # args.ckpt is only used below to name the output folder.
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']
    if args.legacy:
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)
    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    # Other ranks wait here until rank 0 has created the folder.
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * world_size
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % world_size == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // world_size)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar

    # Per-channel scale/bias used to map sampled latents back before VAE
    # decoding (the bias is zero here). Loop-invariant, so built once.
    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = -torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)

    total = 0
    for _ in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
        )
        with torch.no_grad():
            samples = sampler(**sampling_kwargs).to(torch.float32)

            samples = vae.decode((samples - latents_bias) / latents_scale).sample
            # Map decoder output from [-1, 1] to [0, 255] uint8, channels last.
            samples = (samples + 1) / 2.
            samples = torch.clamp(
                255. * samples, 0, 255
            ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

            # Save samples to disk as individual .png files
            for i, sample in enumerate(samples):
                # Interleave ranks so global indices are unique and contiguous.
                index = i * world_size + rank + total
                Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line entry point: parse the sampling options and hand them to main().
    parser = argparse.ArgumentParser()

    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)

    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)

    # sampling related hyperparameters
    # Restrict to the two modes main() implements so typos fail at parse time.
    parser.add_argument("--mode", type=str, default="ode", choices=["ode", "sde"])
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False)  # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)

    # will be deprecated
    # When set, the loaded state dict is passed through load_legacy_checkpoints.
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)

    args = parser.parse_args()
    main(args)