diff --git a/VAE.py b/VAE.py
index 3d4b824e..60000459 100644
--- a/VAE.py
+++ b/VAE.py
@@ -168,13 +168,13 @@ def train(cfg, logger, local_rank, world_size, distributed):
     ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

     scheduler = ComboMultiStepLR(optimizers=
-                                 {
-                                     'generator': autoencoder_optimizer,
-                                     # 'discriminator': discriminator_optimizer
-                                 },
-                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
-                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
-                                 reference_batch_size=32)
+                                 {
+                                     'autoencoder': autoencoder_optimizer,
+                                     # 'discriminator': discriminator_optimizer
+                                 },
+                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
+                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
+                                 reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

     model_dict = {
         'discriminator': encoder,
@@ -200,7 +200,7 @@ def train(cfg, logger, local_rank, world_size, distributed):
                                 logger=logger,
                                 save=local_rank == 0)

-    extra_checkpoint_data = checkpointer.load()
+    extra_checkpoint_data = checkpointer.load()#file_name='results_ae/model_tmp.pth')

     arguments.update(extra_checkpoint_data)

@@ -224,21 +224,44 @@ def process_batch(batch):

     sample = process_batch(data_train[:32])
     del data_train

+    lod2batch.set_epoch(scheduler.start_epoch(), [autoencoder_optimizer])
+
+    print(decoder.get_statistics(lod2batch.lod))
+    print(encoder.get_statistics(lod2batch.lod))
+
+    # stds = []
+    # dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
+    # batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)
+    # for x_orig in tqdm(batches):
+    #     x_orig = (x_orig / 127.5 - 1.)
+    #     x = x_orig.std()
+    #     stds.append(x.item())
+    #
+    # print(sum(stds) / len(stds))
+
+    # exit()
+
     for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
         model.train()
         lod2batch.set_epoch(epoch, [autoencoder_optimizer])

-        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, dataset size: %d" % (lod2batch.get_batch_size(),
+        print(decoder.get_statistics(lod2batch.lod))
+        print(encoder.get_statistics(lod2batch.lod))
+        # exit()
+
+        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
+            lod2batch.get_batch_size(),
             lod2batch.get_per_GPU_batch_size(),
             lod2batch.lod,
             2 ** lod2batch.get_lod_power2(),
             2 ** lod2batch.get_lod_power2(),
+            lod2batch.get_blend_factor(),
             len(dataset) * world_size))

         dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
         batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)

-        scheduler.set_batch_size(32)
+        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

         model.train()
@@ -308,8 +331,12 @@ def process_batch(batch):
             per_epoch_ptime = epoch_end_time - epoch_start_time

             lod2batch.step()
-            if local_rank == 0 and lod2batch.is_time_to_report():
-                save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)
+            if local_rank == 0:
+                if lod2batch.is_time_to_save():
+                    checkpointer.save("model_tmp_intermediate")
+                if lod2batch.is_time_to_report():
+                    save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)
+
             #
             #
             # with torch.no_grad():
@@ -342,12 +369,11 @@ def process_batch(batch):
             #     save_image(resultsample,
             #                'results_gen/sample_' + str(epoch) + "_" + str(i // lod_2_batch[lod]) + '.jpg', nrow=8)

+        scheduler.step()
+
         if local_rank == 0:
+            checkpointer.save("model_tmp")
             save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, autoencoder_optimizer)
-            if epoch > 2:
-                checkpointer.save("model_tmp_%d" % epoch)
-
-    scheduler.step()

     logger.info("Training finish!... save training results")
     if local_rank == 0:
diff --git a/configs/experiment.yaml b/configs/experiment.yaml
index bd053359..21a2a7d7 100644
--- a/configs/experiment.yaml
+++ b/configs/experiment.yaml
@@ -11,15 +11,19 @@ MODEL:
   DLATENT_AVG_BETA: 0.995
   # MAPPING_LAYERS: 5
   MAPPING_LAYERS: 8
-OUTPUT_DIR: results_ae_2
+OUTPUT_DIR: results_ae_6
 TRAIN:
   BASE_LEARNING_RATE: 0.002
   EPOCHS_PER_LOD: 6
   LEARNING_DECAY_RATE: 0.1
   LEARNING_DECAY_STEPS: []
   TRAIN_EPOCHS: 44
-  # 4 8 16 32 64 128 256
-  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
-  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
-  LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 16]
-  LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
+  # 4 8 16 32 64 128 256 512 1024
+  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32]
+  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32, 32, 16]
+  LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 32, 16]
+  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]
+
+  LEARNING_RATES: [0.006, 0.006, 0.006, 0.003, 0.0015, 0.0015, 0.002, 0.003, 0.003]
+
+  # LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003,]
diff --git a/defaults.py b/defaults.py
index 036534f1..047f6592 100644
--- a/defaults.py
+++ b/defaults.py
@@ -60,7 +60,12 @@
 _C.TRAIN.LOD_2_BATCH_2GPU = [256, 256, 128, 64, 32, 16]
 _C.TRAIN.LOD_2_BATCH_1GPU = [128, 128, 128, 64, 32, 16]

-_C.TRAIN.SNAPSHOT_FREQ = [100, 80, 60, 30, 20, 10, 20, 20, 20]
+
+_C.TRAIN.SNAPSHOT_FREQ = [300, 300, 300, 100, 50, 30, 20, 20, 10]
+
+_C.TRAIN.REPORT_FREQ = [100, 80, 60, 30, 20, 10, 10, 5, 5]
+
+_C.TRAIN.LEARNING_RATES = [0.002]


 def get_cfg_defaults():
diff --git a/lod_driver.py b/lod_driver.py
index 14bfef0d..da4daf61 100644
--- a/lod_driver.py
+++ b/lod_driver.py
@@ -42,8 +42,10 @@ def __init__(self, cfg, logger, world_size, dataset_size):
         self.epoch_end_time = 0
         self.epoch_start_time = 0
         self.per_epoch_ptime = 0
+        self.reports = cfg.TRAIN.REPORT_FREQ
         self.snapshots = cfg.TRAIN.SNAPSHOT_FREQ
-        self.tick_start_nimg = 0
+        self.tick_start_nimg_report = 0
+        self.tick_start_nimg_snapshot = 0

     def get_lod_power2(self):
         return self.lod + 2
@@ -68,8 +70,14 @@ def get_blend_factor(self):
         return blend_factor

     def is_time_to_report(self):
-        if self.iteration >= self.tick_start_nimg + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000:
-            self.tick_start_nimg = self.iteration
+        if self.iteration >= self.tick_start_nimg_report + self.reports[min(self.lod, len(self.reports) - 1)] * 1000:
+            self.tick_start_nimg_report = self.iteration
+            return True
+        return False
+
+    def is_time_to_save(self):
+        if self.iteration >= self.tick_start_nimg_snapshot + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000:
+            self.tick_start_nimg_snapshot = self.iteration
             return True
         return False

@@ -81,7 +89,8 @@ def step(self):
     def set_epoch(self, epoch, optimizers):
         self.current_epoch = epoch
         self.iteration = 0
-        self.tick_start_nimg = 0
+        self.tick_start_nimg_report = 0
+        self.tick_start_nimg_snapshot = 0
         self.epoch_start_time = time.time()

         new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD)
diff --git a/net.py b/net.py
index f022c3c8..798fa103 100644
--- a/net.py
+++ b/net.py
@@ -252,7 +252,7 @@ def __init__(self, inputs, channels):
         super(ToRGB, self).__init__()
         self.inputs = inputs
         self.channels = channels
-        self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=1)
+        self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=0.03)

     def forward(self, x):
         x = self.to_rgb(x)
@@ -265,7 +265,7 @@ def __init__(self, startf, maxf, layer_count, latent_size, channels=3):
         self.maxf = maxf
         self.startf = startf
         self.layer_count = layer_count
-        self.from_rgb = nn.ModuleList()
+        self.from_rgb: nn.ModuleList[FromRGB] = nn.ModuleList()
         self.channels = channels
         self.latent_size = latent_size

@@ -332,6 +332,19 @@ def forward(self, x, lod, blend):
         else:
             return self.encode2(x, lod, blend)

+    def get_statistics(self, lod):
+        rgb_std = self.from_rgb[self.layer_count - lod - 1].from_rgb.weight.std().item()
+        rgb_std_c = self.from_rgb[self.layer_count - lod - 1].from_rgb.std
+
+        layers = []
+        for i in range(self.layer_count - lod - 1, self.layer_count):
+            conv_1 = self.encode_block[i].conv_1.weight.std().item()
+            conv_1_c = self.encode_block[i].conv_1.std
+            conv_2 = self.encode_block[i].conv_2.weight.std().item()
+            conv_2_c = self.encode_block[i].conv_2.std
+            layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c)))
+        return rgb_std / rgb_std_c, layers
+

 class Discriminator(nn.Module):
     def __init__(self, startf=32, maxf=256, layer_count=3, channels=3):
@@ -480,6 +493,22 @@ def forward(self, styles, lod, blend):
         else:
             return self.decode2(styles, lod, blend, 1.)

+    def get_statistics(self, lod):
+        rgb_std = self.to_rgb[lod].to_rgb.weight.std().item()
+        rgb_std_c = self.to_rgb[lod].to_rgb.std
+
+        layers = []
+        for i in range(lod + 1):
+            conv_1 = 1.0
+            conv_1_c = 1.0
+            if i != 0:
+                conv_1 = self.decode_block[i].conv_1.weight.std().item()
+                conv_1_c = self.decode_block[i].conv_1.std
+            conv_2 = self.decode_block[i].conv_2.weight.std().item()
+            conv_2_c = self.decode_block[i].conv_2.std
+            layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c)))
+        return rgb_std / rgb_std_c, layers
+

 def minibatch_stddev_layer(x, group_size=4):
     group_size = min(group_size, x.shape[0])
diff --git a/scheduler.py b/scheduler.py
index 49985068..d92284a7 100644
--- a/scheduler.py
+++ b/scheduler.py
@@ -12,7 +12,8 @@ def __init__(
         warmup_factor=1.0 / 1.0,
         warmup_iters=1,
         last_epoch=-1,
-        reference_batch_size=128
+        reference_batch_size=128,
+        lr=[]
     ):
         if not list(milestones) == sorted(milestones):
             raise ValueError(
@@ -24,11 +25,37 @@ def __init__(
         self.warmup_factor = warmup_factor
         self.warmup_iters = warmup_iters
         self.batch_size = 1
+        self.lod = 0
         self.reference_batch_size = reference_batch_size
-        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

-    def set_batch_size(self, batch_size):
+        self.optimizer = optimizer
+        self.base_lrs = []
+        for _ in self.optimizer.param_groups:
+            self.base_lrs.append(lr)
+
+        self.last_epoch = last_epoch
+
+        if not isinstance(optimizer, torch.optim.Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        if last_epoch == -1:
+            for group in optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+            last_epoch = 0
+
+        self.last_epoch = last_epoch
+
+        self.optimizer._step_count = 0
+        self._step_count = 0
+        self.step(last_epoch)
+
+    def set_batch_size(self, batch_size, lod):
         self.batch_size = batch_size
+        self.lod = lod
+        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            param_group['lr'] = lr

     def get_lr(self):
         warmup_factor = 1
@@ -36,7 +63,7 @@ def get_lr(self):
             alpha = float(self.last_epoch) / self.warmup_iters
             warmup_factor = self.warmup_factor * (1 - alpha) + alpha
         return [
-            base_lr
+            base_lr[self.lod]
             * warmup_factor
             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
             # * float(self.batch_size)
@@ -44,21 +71,29 @@
             for base_lr in self.base_lrs
         ]

+    def state_dict(self):
+        return {
+            "last_epoch": self.last_epoch
+        }
+
+    def load_state_dict(self, state_dict):
+        self.__dict__.update(dict(last_epoch=state_dict["last_epoch"]))
+

 class ComboMultiStepLR:
     def __init__(
             self,
-            optimizers,
+            optimizers, base_lr,
             **kwargs
     ):
         self.schedulers = dict()
         for name, opt in optimizers.items():
-            self.schedulers[name] = WarmupMultiStepLR(opt, **kwargs)
+            self.schedulers[name] = WarmupMultiStepLR(opt, lr=base_lr, **kwargs)
         self.last_epoch = 0

-    def set_batch_size(self, batch_size):
+    def set_batch_size(self, batch_size, lod):
         for x in self.schedulers.values():
-            x.set_batch_size(batch_size)
+            x.set_batch_size(batch_size, lod)

     def step(self, epoch=None):
         for x in self.schedulers.values():
@@ -72,10 +107,11 @@ def state_dict(self):

     def load_state_dict(self, state_dict):
         for k, x in self.schedulers.items():
-            x.__dict__.update(state_dict[k])
+            x.load_state_dict(state_dict[k])
+
         last_epochs = [x.last_epoch for k, x in self.schedulers.items()]
         assert np.all(np.asarray(last_epochs) == last_epochs[0])
-        self.last_epoch = last_epochs[0] + 1
+        self.last_epoch = last_epochs[0]

     def start_epoch(self):
         return self.last_epoch
diff --git a/tracker.py b/tracker.py
index 28b967d3..33bf5f3e 100644
--- a/tracker.py
+++ b/tracker.py
@@ -44,16 +44,18 @@ def __init__(self):
         self.values = []

     def __iadd__(self, value):
-        self.values.append(value.detach().unsqueeze(0))
-        return self
+        with torch.no_grad():
+            self.values.append(value.detach().cpu().unsqueeze(0))
+            return self

     def reset(self):
         self.values = []

     def mean(self):
-        if len(self.values) == 0:
-            return 0.0
-        return float(torch.cat(self.values).mean().item())
+        with torch.no_grad():
+            if len(self.values) == 0:
+                return 0.0
+            return float(torch.cat(self.values).mean().item())


 class LossTracker:
@@ -81,9 +83,14 @@ def add(self, name, pytorch=True):

     def register_means(self, epoch):
         self.epochs.append(epoch)
-        for key, value in self.tracks.items():
-            self.means_over_epochs[key].append(value.mean())
-            value.reset()
+
+        for key in self.means_over_epochs.keys():
+            if key in self.tracks:
+                value = self.tracks[key]
+                self.means_over_epochs[key].append(value.mean())
+                value.reset()
+            else:
+                self.means_over_epochs[key].append(None)

         with open(os.path.join(self.output_folder, 'log.csv'), mode='w') as csv_file:
             fieldnames = ['epoch'] + list(self.tracks.keys())
@@ -125,3 +132,16 @@ def load_state_dict(self, state_dict):
         self.epochs = state_dict['epochs']
         self.means_over_epochs = state_dict['means_over_epochs']
         self.output_folder = state_dict['output_folder']
+
+        counts = list(map(len, self.means_over_epochs.values()))
+
+        if len(counts) == 0:
+            counts = [0]
+        m = min(counts)
+
+        if m < len(self.epochs):
+            self.epochs = self.epochs[:m]
+
+        for key in self.means_over_epochs.keys():
+            if len(self.means_over_epochs[key]) > m:
+                self.means_over_epochs[key] = self.means_over_epochs[key][:m]
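
For reference, a minimal sketch (not part of the patch) of the per-LOD learning-rate lookup that the modified WarmupMultiStepLR.get_lr() performs once warmup is over. The helper name lr_for and the printed values are illustrative assumptions; the indexing by lod, the gamma decay over milestones, and the TRAIN.LEARNING_RATES list come from the patch above.

    from bisect import bisect_right

    # base_lrs mirrors cfg.TRAIN.LEARNING_RATES: one base learning rate per LOD
    # (index 0 = 4x4, index 8 = 1024x1024). gamma and milestones mirror
    # TRAIN.LEARNING_DECAY_RATE / TRAIN.LEARNING_DECAY_STEPS.
    def lr_for(base_lrs, lod, last_epoch, milestones=(), gamma=0.1):
        # Same expression as get_lr() with warmup_factor == 1: pick the LOD's
        # base rate, then decay it once per milestone already passed.
        return base_lrs[lod] * gamma ** bisect_right(list(milestones), last_epoch)

    learning_rates = [0.006, 0.006, 0.006, 0.003, 0.0015, 0.0015, 0.002, 0.003, 0.003]
    print(lr_for(learning_rates, lod=0, last_epoch=5))   # 0.006 at 4x4
    print(lr_for(learning_rates, lod=4, last_epoch=20))  # 0.0015 at 64x64

set_batch_size(batch_size, lod) then writes the result of get_lr() into every optimizer param group, so the effective learning rate changes whenever the LOD switches, not only at decay milestones.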