Fixes

Stanislav Pidhorskyi committed Oct 11, 2019
1 parent 6ab4cec commit 7258f82
Showing 7 changed files with 176 additions and 47 deletions.
58 changes: 42 additions & 16 deletions VAE.py
@@ -168,13 +168,13 @@ def train(cfg, logger, local_rank, world_size, distributed):
], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

scheduler = ComboMultiStepLR(optimizers=
{
'generator': autoencoder_optimizer,
# 'discriminator': discriminator_optimizer
},
milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
reference_batch_size=32)
{
'autoencoder': autoencoder_optimizer,
# 'discriminator': discriminator_optimizer
},
milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

model_dict = {
'discriminator': encoder,
@@ -200,7 +200,7 @@ def train(cfg, logger, local_rank, world_size, distributed):
logger=logger,
save=local_rank == 0)

extra_checkpoint_data = checkpointer.load()
extra_checkpoint_data = checkpointer.load()#file_name='results_ae/model_tmp.pth')

arguments.update(extra_checkpoint_data)

@@ -224,21 +224,44 @@ def process_batch(batch):
sample = process_batch(data_train[:32])
del data_train

lod2batch.set_epoch(scheduler.start_epoch(), [autoencoder_optimizer])

print(decoder.get_statistics(lod2batch.lod))
print(encoder.get_statistics(lod2batch.lod))

# stds = []
# dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
# batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)
# for x_orig in tqdm(batches):
# x_orig = (x_orig / 127.5 - 1.)
# x = x_orig.std()
# stds.append(x.item())
#
# print(sum(stds) / len(stds))

# exit()

for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
model.train()
lod2batch.set_epoch(epoch, [autoencoder_optimizer])

logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, dataset size: %d" % (lod2batch.get_batch_size(),
print(decoder.get_statistics(lod2batch.lod))
print(encoder.get_statistics(lod2batch.lod))
# exit()

logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
lod2batch.get_batch_size(),
lod2batch.get_per_GPU_batch_size(),
lod2batch.lod,
2 ** lod2batch.get_lod_power2(),
2 ** lod2batch.get_lod_power2(),
lod2batch.get_blend_factor(),
len(dataset) * world_size))

dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)

scheduler.set_batch_size(32)
scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

model.train()

@@ -308,8 +331,12 @@ def process_batch(batch):
per_epoch_ptime = epoch_end_time - epoch_start_time

lod2batch.step()
if local_rank == 0 and lod2batch.is_time_to_report():
save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)
if local_rank == 0:
if lod2batch.is_time_to_save():
checkpointer.save("model_tmp_intermediate")
if lod2batch.is_time_to_report():
save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)

#
#
# with torch.no_grad():
@@ -342,12 +369,11 @@ def process_batch(batch):
# save_image(resultsample,
# 'results_gen/sample_' + str(epoch) + "_" + str(i // lod_2_batch[lod]) + '.jpg', nrow=8)

scheduler.step()

if local_rank == 0:
checkpointer.save("model_tmp")
save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)
if epoch > 2:
checkpointer.save("model_tmp_%d" % epoch)

scheduler.step()

logger.info("Training finish!... save training results")
if local_rank == 0:
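Taken together, the VAE.py changes decouple lightweight reporting from checkpointing inside an epoch and add numbered end-of-epoch checkpoints. A minimal sketch of that policy, using a stub in place of the repository's Checkpointer (the stub, its print calls, and the epoch count are illustrative only, not the repo's API):

class StubCheckpointer:
    # Stand-in for the repository's Checkpointer; only save() is mirrored.
    def save(self, name):
        print("saved:", name)

checkpointer = StubCheckpointer()
local_rank = 0
TRAIN_EPOCHS = 5          # placeholder for cfg.TRAIN.TRAIN_EPOCHS

for epoch in range(TRAIN_EPOCHS):
    # ... inner loop omitted: is_time_to_save() -> "model_tmp_intermediate",
    #                         is_time_to_report() -> save_sample(...) ...
    if local_rank == 0:
        checkpointer.save("model_tmp")                 # rolling checkpoint, overwritten every epoch
        if epoch > 2:
            checkpointer.save("model_tmp_%d" % epoch)  # numbered copy kept from epoch 3 onward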
16 changes: 10 additions & 6 deletions configs/experiment.yaml
@@ -11,15 +11,19 @@ MODEL:
DLATENT_AVG_BETA: 0.995
# MAPPING_LAYERS: 5
MAPPING_LAYERS: 8
OUTPUT_DIR: results_ae_2
OUTPUT_DIR: results_ae_6
TRAIN:
BASE_LEARNING_RATE: 0.002
EPOCHS_PER_LOD: 6
LEARNING_DECAY_RATE: 0.1
LEARNING_DECAY_STEPS: []
TRAIN_EPOCHS: 44
# 4 8 16 32 64 128 256
LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 16]
LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
# 4 8 16 32 64 128 256 512 1024
LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32]
LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32, 32, 16]
LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 32, 16]
LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]

LEARNING_RATES: [0.006, 0.006, 0.006, 0.003, 0.0015, 0.0015, 0.002, 0.003, 0.003]

# LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003,]
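Each per-LOD list above is indexed by the level of detail, where LOD 0 corresponds to 4×4 images and each step doubles the resolution (the # 4 8 16 ... ruler in the YAML). A small illustrative lookup using the single-GPU batch list and learning rates from this file; the helper name settings_for_lod is not part of the repo:

# Values copied from configs/experiment.yaml (single-GPU batch sizes).
LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16]
LEARNING_RATES = [0.006, 0.006, 0.006, 0.003, 0.0015, 0.0015, 0.002, 0.003, 0.003]

def settings_for_lod(lod):
    # LOD 0 -> 4x4, LOD 1 -> 8x8, ... (resolution = 2 ** (lod + 2))
    resolution = 2 ** (lod + 2)
    batch = LOD_2_BATCH_1GPU[lod]
    lr = LEARNING_RATES[lod]
    return resolution, batch, lr

for lod in range(len(LOD_2_BATCH_1GPU)):
    res, batch, lr = settings_for_lod(lod)
    print("LOD %d: %dx%d, batch %d, lr %g" % (lod, res, res, batch, lr))
# LOD 0: 4x4, batch 128, lr 0.006
# ...
# LOD 5: 128x128, batch 16, lr 0.0015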
7 changes: 6 additions & 1 deletion defaults.py
@@ -60,7 +60,12 @@
_C.TRAIN.LOD_2_BATCH_2GPU = [256, 256, 128, 64, 32, 16]
_C.TRAIN.LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16]

_C.TRAIN.SNAPSHOT_FREQ = [100, 80, 60, 30, 20, 10, 20, 20, 20]

_C.TRAIN.SNAPSHOT_FREQ = [300, 300, 300, 100, 50, 30, 20, 20, 10]

_C.TRAIN.REPORT_FREQ = [100, 80, 60, 30, 20, 10, 10, 5, 5]

_C.TRAIN.LEARNING_RATES = [0.002]


def get_cfg_defaults():
17 changes: 13 additions & 4 deletions lod_driver.py
@@ -42,8 +42,10 @@ def __init__(self, cfg, logger, world_size, dataset_size):
self.epoch_end_time = 0
self.epoch_start_time = 0
self.per_epoch_ptime = 0
self.reports = cfg.TRAIN.REPORT_FREQ
self.snapshots = cfg.TRAIN.SNAPSHOT_FREQ
self.tick_start_nimg = 0
self.tick_start_nimg_report = 0
self.tick_start_nimg_snapshot = 0

def get_lod_power2(self):
return self.lod + 2
@@ -68,8 +70,14 @@ def get_blend_factor(self):
return blend_factor

def is_time_to_report(self):
if self.iteration >= self.tick_start_nimg + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000:
self.tick_start_nimg = self.iteration
if self.iteration >= self.tick_start_nimg_report + self.reports[min(self.lod, len(self.reports) - 1)] * 1000:
self.tick_start_nimg_report = self.iteration
return True
return False

def is_time_to_save(self):
if self.iteration >= self.tick_start_nimg_snapshot + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000:
self.tick_start_nimg_snapshot = self.iteration
return True
return False

@@ -81,7 +89,8 @@ def step(self):
def set_epoch(self, epoch, optimizers):
self.current_epoch = epoch
self.iteration = 0
self.tick_start_nimg = 0
self.tick_start_nimg_report = 0
self.tick_start_nimg_snapshot = 0
self.epoch_start_time = time.time()

new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD)
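The driver now keeps two independent image counters, so sample reports (REPORT_FREQ) and intermediate snapshots (SNAPSHOT_FREQ) fire on their own schedules; both thresholds are expressed in thousands of images. A self-contained sketch of that cadence (the assumption that iteration advances by the batch size each step mirrors how the driver is used, but is not shown in this diff):

class Cadence:
    # Minimal stand-in for the report/snapshot logic in lod_driver.py.
    def __init__(self, report_kimg, snapshot_kimg):
        self.report_kimg = report_kimg          # e.g. cfg.TRAIN.REPORT_FREQ[lod]
        self.snapshot_kimg = snapshot_kimg      # e.g. cfg.TRAIN.SNAPSHOT_FREQ[lod]
        self.iteration = 0
        self.tick_start_nimg_report = 0
        self.tick_start_nimg_snapshot = 0

    def step(self, batch_size):
        self.iteration += batch_size            # assumed: one step = one batch of images

    def is_time_to_report(self):
        if self.iteration >= self.tick_start_nimg_report + self.report_kimg * 1000:
            self.tick_start_nimg_report = self.iteration
            return True
        return False

    def is_time_to_save(self):
        if self.iteration >= self.tick_start_nimg_snapshot + self.snapshot_kimg * 1000:
            self.tick_start_nimg_snapshot = self.iteration
            return True
        return False

cadence = Cadence(report_kimg=10, snapshot_kimg=30)
for _ in range(1000):
    cadence.step(batch_size=32)
    if cadence.is_time_to_save():
        print("snapshot at", cadence.iteration, "images")
    if cadence.is_time_to_report():
        print("report at", cadence.iteration, "images")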
33 changes: 31 additions & 2 deletions net.py
@@ -252,7 +252,7 @@ def __init__(self, inputs, channels):
super(ToRGB, self).__init__()
self.inputs = inputs
self.channels = channels
self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=1)
self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=0.03)

def forward(self, x):
x = self.to_rgb(x)
@@ -265,7 +265,7 @@ def __init__(self, startf, maxf, layer_count, latent_size, channels=3):
self.maxf = maxf
self.startf = startf
self.layer_count = layer_count
self.from_rgb = nn.ModuleList()
self.from_rgb: nn.ModuleList[FromRGB] = nn.ModuleList()
self.channels = channels
self.latent_size = latent_size

@@ -332,6 +332,19 @@ def forward(self, x, lod, blend):
else:
return self.encode2(x, lod, blend)

def get_statistics(self, lod):
rgb_std = self.from_rgb[self.layer_count - lod - 1].from_rgb.weight.std().item()
rgb_std_c = self.from_rgb[self.layer_count - lod - 1].from_rgb.std

layers = []
for i in range(self.layer_count - lod - 1, self.layer_count):
conv_1 = self.encode_block[i].conv_1.weight.std().item()
conv_1_c = self.encode_block[i].conv_1.std
conv_2 = self.encode_block[i].conv_2.weight.std().item()
conv_2_c = self.encode_block[i].conv_2.std
layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c)))
return rgb_std / rgb_std_c, layers


class Discriminator(nn.Module):
def __init__(self, startf=32, maxf=256, layer_count=3, channels=3):
@@ -480,6 +493,22 @@ def forward(self, styles, lod, blend):
else:
return self.decode2(styles, lod, blend, 1.)

def get_statistics(self, lod):
rgb_std = self.to_rgb[lod].to_rgb.weight.std().item()
rgb_std_c = self.to_rgb[lod].to_rgb.std

layers = []
for i in range(lod + 1):
conv_1 = 1.0
conv_1_c = 1.0
if i != 0:
conv_1 = self.decode_block[i].conv_1.weight.std().item()
conv_1_c = self.decode_block[i].conv_1.std
conv_2 = self.decode_block[i].conv_2.weight.std().item()
conv_2_c = self.decode_block[i].conv_2.std
layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c)))
return rgb_std / rgb_std_c, layers


def minibatch_stddev_layer(x, group_size=4):
group_size = min(group_size, x.shape[0])
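The new get_statistics helpers return, per active layer, the ratio between the empirical std of a conv weight and the layer's stored std attribute, which appears to be the expected std from the equalized-learning-rate initialization in ln.Conv2d (that attribute is not shown in this diff, so treating it as the init-time std is an assumption). The same drift check for a plain nn.Conv2d might look like this, with the expected He std computed explicitly:

import torch.nn as nn

def expected_he_std(conv):
    # Std a He-initialized conv weight is expected to have; illustrative stand-in
    # for the `.std` attribute that ln.Conv2d seems to carry.
    fan_in = conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1]
    return (2.0 / fan_in) ** 0.5

def weight_drift(conv):
    # ~1.0 right after He init; drifts away as training reshapes the weights.
    return conv.weight.std().item() / expected_he_std(conv)

conv = nn.Conv2d(64, 64, 3, padding=1)
nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')
print(weight_drift(conv))   # close to 1.0 at initialization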
56 changes: 46 additions & 10 deletions scheduler.py
@@ -12,7 +12,8 @@ def __init__(
warmup_factor=1.0 / 1.0,
warmup_iters=1,
last_epoch=-1,
reference_batch_size=128
reference_batch_size=128,
lr=[]
):
if not list(milestones) == sorted(milestones):
raise ValueError(
@@ -24,41 +25,75 @@ def __init__(
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.batch_size = 1
self.lod = 0
self.reference_batch_size = reference_batch_size
super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

def set_batch_size(self, batch_size):
self.optimizer = optimizer
self.base_lrs = []
for _ in self.optimizer.param_groups:
self.base_lrs.append(lr)

self.last_epoch = last_epoch

if not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
self.optimizer = optimizer

if last_epoch == -1:
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
last_epoch = 0

self.last_epoch = last_epoch

self.optimizer._step_count = 0
self._step_count = 0
self.step(last_epoch)

def set_batch_size(self, batch_size, lod):
self.batch_size = batch_size
self.lod = lod
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr

def get_lr(self):
warmup_factor = 1
if self.last_epoch < self.warmup_iters:
alpha = float(self.last_epoch) / self.warmup_iters
warmup_factor = self.warmup_factor * (1 - alpha) + alpha
return [
base_lr
base_lr[self.lod]
* warmup_factor
* self.gamma ** bisect_right(self.milestones, self.last_epoch)
# * float(self.batch_size)
# / float(self.reference_batch_size)
for base_lr in self.base_lrs
]

def state_dict(self):
return {
"last_epoch": self.last_epoch
}

def load_state_dict(self, state_dict):
self.__dict__.update(dict(last_epoch=state_dict["last_epoch"]))


class ComboMultiStepLR:
def __init__(
self,
optimizers,
optimizers, base_lr,
**kwargs
):
self.schedulers = dict()
for name, opt in optimizers.items():
self.schedulers[name] = WarmupMultiStepLR(opt, **kwargs)
self.schedulers[name] = WarmupMultiStepLR(opt, lr=base_lr, **kwargs)
self.last_epoch = 0

def set_batch_size(self, batch_size):
def set_batch_size(self, batch_size, lod):
for x in self.schedulers.values():
x.set_batch_size(batch_size)
x.set_batch_size(batch_size, lod)

def step(self, epoch=None):
for x in self.schedulers.values():
@@ -72,10 +107,11 @@ def state_dict(self):

def load_state_dict(self, state_dict):
for k, x in self.schedulers.items():
x.__dict__.update(state_dict[k])
x.load_state_dict(state_dict[k])

last_epochs = [x.last_epoch for k, x in self.schedulers.items()]
assert np.all(np.asarray(last_epochs) == last_epochs[0])
self.last_epoch = last_epochs[0] + 1
self.last_epoch = last_epochs[0]

def start_epoch(self):
return self.last_epoch
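With these changes the scheduler keys the base learning rate on the current LOD: set_batch_size(batch_size, lod) records the LOD and get_lr() reads base_lr[self.lod] for every param group (the batch-size scaling stays commented out). A stripped-down, runnable sketch of just that mechanism, without the warmup and milestone decay handled by WarmupMultiStepLR:

import torch

class PerLodLR:
    # Minimal mirror of the LOD handling in WarmupMultiStepLR.set_batch_size() / get_lr().
    def __init__(self, optimizer, base_lrs):
        self.optimizer = optimizer
        self.base_lrs = base_lrs        # one learning rate per LOD, e.g. cfg.TRAIN.LEARNING_RATES
        self.lod = 0

    def set_batch_size(self, batch_size, lod):
        # batch_size kept for signature parity; only lod drives the LR in this sketch
        self.lod = lod
        for group in self.optimizer.param_groups:
            group['lr'] = self.base_lrs[self.lod]

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=0.002)
sched = PerLodLR(opt, base_lrs=[0.006, 0.006, 0.006, 0.003, 0.0015, 0.0015, 0.002, 0.003, 0.003])
sched.set_batch_size(32, lod=4)
print(opt.param_groups[0]['lr'])    # 0.0015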