From d43e06174b9823701d20d39a31538fac79828f7b Mon Sep 17 00:00:00 2001
From: Stanislav Pidhorskyi
Date: Mon, 7 Dec 2020 08:10:39 -0500
Subject: [PATCH] Fixed confusing names. #38, #35

Fixed a memory leak related to matplotlib. #17
I hope this did not break anything. If you have any issues, try a revision
from before this commit.
---
 README.md                                    |   2 +-
 checkpointer.py                              |   2 +-
 configs/celeba_ablation_nostyle.yaml         |  32 --
 configs/celeba_ablation_separate.yaml        |  33 --
 configs/celeba_ablation_z_regression.yaml    |  32 --
 configs/mnist_fc.yaml                        |   2 +-
 configs/mnist_fc_no_lreq_js_loss.yaml        |  44 ---
 defaults.py                                  |   4 +-
 interactive_demo.py                          |  10 +-
 make_figures/make_generation_figure.py       |   4 +-
 .../make_recon_figure_celeba_pioneer.py      |   6 +-
 make_figures/make_recon_figure_ffhq_real.py  |   6 +-
 .../make_recon_figure_interpolation.py       |   8 +-
 make_figures/make_recon_figure_multires.py   |   6 +-
 make_figures/make_recon_figure_paged.py      |   6 +-
 make_figures/make_traversarls.py             |  10 +-
 make_figures/old/make_recon_figure_bed.py    |   6 +-
 make_figures/old/make_recon_figure_celeba.py |   6 +-
 metrics/fid.py                               |   2 +-
 metrics/fid_rec.py                           |   4 +-
 metrics/lpips.py                             |   2 +-
 metrics/ppl.py                               |   6 +-
 model.py                                     |  38 +-
 model_separate.py                            | 160 --------
 net.py                                       |  66 ++--
 principal_directions/extract_attributes.py   |   4 +-
 principal_directions/generate_images.py      |   4 +-
 style_mixing/stylemix.py                     |   6 +-
 tracker.py                                   |  34 +-
 train_alae.py                                |  54 +--
 train_alae_separate.py                       | 345 ------------------
 31 files changed, 156 insertions(+), 788 deletions(-)
 delete mode 100644 configs/celeba_ablation_nostyle.yaml
 delete mode 100644 configs/celeba_ablation_separate.yaml
 delete mode 100644 configs/celeba_ablation_z_regression.yaml
 delete mode 100644 configs/mnist_fc_no_lreq_js_loss.yaml
 delete mode 100644 model_separate.py
 delete mode 100644 train_alae_separate.py

diff --git a/README.md b/README.md
index bb80fb8d..3ce11a7c 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@
 > **Adversarial Latent Autoencoders**<br>
> Stanislav Pidhorskyi, Donald Adjeroh, Gianfranco Doretto
> -> **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond, the capabilities of a generator-only type of architecture.* +> **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond the capabilities of a generator-only type of architecture.* ## Citation * Stanislav Pidhorskyi, Donald A. Adjeroh, and Gianfranco Doretto. Adversarial Latent Autoencoders. In *Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR)*, 2020. [to appear] diff --git a/checkpointer.py b/checkpointer.py index 61e79b83..cc346a27 100644 --- a/checkpointer.py +++ b/checkpointer.py @@ -108,7 +108,7 @@ def load(self, ignore_last_checkpoint=False, file_name=None): self.auxiliary[name].load_state_dict(checkpoint["optimizers"].pop(name)) if name in checkpoint: self.auxiliary[name].load_state_dict(checkpoint.pop(name)) - except IndexError: + except (IndexError, ValueError): self.logger.warning('%s\nFailed to load: %s\n%s' % ('!' * 160, name, '!' 
* 160)) checkpoint.pop('auxiliary') diff --git a/configs/celeba_ablation_nostyle.yaml b/configs/celeba_ablation_nostyle.yaml deleted file mode 100644 index 2604fc52..00000000 --- a/configs/celeba_ablation_nostyle.yaml +++ /dev/null @@ -1,32 +0,0 @@ - -NAME: celeba -PPL_CELEBA_ADJUSTMENT: True -DATASET: - PART_COUNT: 16 - SIZE: 182637 - SIZE_TEST: 202576 - 182637 - PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d - PATH_TEST: /data/datasets/celeba-test/tfrecords/celeba-r%02d.tfrecords.%03d - MAX_RESOLUTION_LEVEL: 7 -MODEL: - LATENT_SPACE_SIZE: 256 - LAYER_COUNT: 6 - MAX_CHANNEL_COUNT: 256 - START_CHANNEL_COUNT: 64 - DLATENT_AVG_BETA: 0.995 - MAPPING_LAYERS: 8 - ENCODER: "EncoderNoStyle" -OUTPUT_DIR: /data/celeba_z_nostyle -TRAIN: - BASE_LEARNING_RATE: 0.002 - EPOCHS_PER_LOD: 6 - LEARNING_DECAY_RATE: 0.1 - LEARNING_DECAY_STEPS: [] - TRAIN_EPOCHS: 67 - # 4 8 16 32 64 128 256 512 1024 - LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32] - LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32, 32, 16] - LOD_2_BATCH_2GPU: [128, 128, 128, 64, 32, 32, 16] - LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16] - - LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] diff --git a/configs/celeba_ablation_separate.yaml b/configs/celeba_ablation_separate.yaml deleted file mode 100644 index 1ed21bb1..00000000 --- a/configs/celeba_ablation_separate.yaml +++ /dev/null @@ -1,33 +0,0 @@ - -NAME: celeba -PPL_CELEBA_ADJUSTMENT: True -DATASET: - PART_COUNT: 16 - SIZE: 182637 - SIZE_TEST: 202576 - 182637 - PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d - PATH_TEST: /data/datasets/celeba-test/tfrecords/celeba-r%02d.tfrecords.%03d - MAX_RESOLUTION_LEVEL: 7 -MODEL: - LATENT_SPACE_SIZE: 256 - LAYER_COUNT: 6 - MAX_CHANNEL_COUNT: 256 - START_CHANNEL_COUNT: 64 - DLATENT_AVG_BETA: 0.995 - MAPPING_LAYERS: 8 - MAPPING_FROM_LATENT: "MappingDefault" - -OUTPUT_DIR: /data/celeba_z_sep -TRAIN: - BASE_LEARNING_RATE: 0.002 - EPOCHS_PER_LOD: 6 - LEARNING_DECAY_RATE: 0.1 - LEARNING_DECAY_STEPS: [] - TRAIN_EPOCHS: 67 - # 4 8 16 32 64 128 256 512 1024 - LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32] - LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32, 32, 16] - LOD_2_BATCH_2GPU: [128, 128, 128, 64, 32, 32, 16] - LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16] - - LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] diff --git a/configs/celeba_ablation_z_regression.yaml b/configs/celeba_ablation_z_regression.yaml deleted file mode 100644 index 51a74dd0..00000000 --- a/configs/celeba_ablation_z_regression.yaml +++ /dev/null @@ -1,32 +0,0 @@ - -NAME: celeba -PPL_CELEBA_ADJUSTMENT: True -DATASET: - PART_COUNT: 16 - SIZE: 182637 - SIZE_TEST: 202576 - 182637 - PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d - PATH_TEST: /data/datasets/celeba-test/tfrecords/celeba-r%02d.tfrecords.%03d - MAX_RESOLUTION_LEVEL: 7 -MODEL: - LATENT_SPACE_SIZE: 256 - LAYER_COUNT: 6 - MAX_CHANNEL_COUNT: 256 - START_CHANNEL_COUNT: 64 - DLATENT_AVG_BETA: 0.995 - MAPPING_LAYERS: 8 - Z_REGRESSION: True -OUTPUT_DIR: /data/celeba_z_zreg -TRAIN: - BASE_LEARNING_RATE: 0.002 - EPOCHS_PER_LOD: 6 - LEARNING_DECAY_RATE: 0.1 - LEARNING_DECAY_STEPS: [] - TRAIN_EPOCHS: 67 - # 4 8 16 32 64 128 256 512 1024 - LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 32, 32] - LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32, 32, 16] - LOD_2_BATCH_2GPU: [128, 128, 128, 64, 32, 32, 16] - LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16] - - LEARNING_RATES: [0.0015, 
0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003] diff --git a/configs/mnist_fc.yaml b/configs/mnist_fc.yaml index 2c436197..7012ac5d 100644 --- a/configs/mnist_fc.yaml +++ b/configs/mnist_fc.yaml @@ -18,7 +18,7 @@ MODEL: CHANNELS: 1 GENERATOR: "GeneratorFC" ENCODER: "EncoderFC" - MAPPING_TO_LATENT: "MappingToLatentNoStyle" + MAPPING_D: "MappingDNoStyle" OUTPUT_DIR: mnist_results_fc2z_2 TRAIN: diff --git a/configs/mnist_fc_no_lreq_js_loss.yaml b/configs/mnist_fc_no_lreq_js_loss.yaml deleted file mode 100644 index e83e300e..00000000 --- a/configs/mnist_fc_no_lreq_js_loss.yaml +++ /dev/null @@ -1,44 +0,0 @@ -DATASET: - PART_COUNT: 1 - SIZE: 60000 - SIZE_TEST: 10000 - PATH: /data/datasets/mnist_fc/tfrecords/mnist-r%02d.tfrecords.%03d - PATH_TEST: /data/datasets/mnist_fc/tfrecords/mnist_test-r%02d.tfrecords.%03d - MAX_RESOLUTION_LEVEL: 5 - FLIP_IMAGES: False -MODEL: - # LATENT_SPACE_SIZE: 256 - LATENT_SPACE_SIZE: 50 - LAYER_COUNT: 4 - MAX_CHANNEL_COUNT: 256 - START_CHANNEL_COUNT: 64 - DLATENT_AVG_BETA: 0.995 - # MAPPING_LAYERS: 5 - MAPPING_LAYERS: 6 - CHANNELS: 1 - GENERATOR: "GeneratorFC" - ENCODER: "EncoderFC" - MAPPING_TO_LATENT: "MappingToLatentNoStyle" - -OUTPUT_DIR: mnist_results_fc2z_ne_js -TRAIN: - ADAM_BETA_0: 0.5 - ADAM_BETA_1: 0.999 - - BASE_LEARNING_RATE: 0.0002 - EPOCHS_PER_LOD: 0 - LEARNING_DECAY_RATE: 0.5 - - LEARNING_DECAY_STEPS: [] - - TRAIN_EPOCHS: 600 - LOD_2_BATCH_8GPU: [128, 128, 128, 1024] - LOD_2_BATCH_4GPU: [128, 128, 128, 128] - LOD_2_BATCH_2GPU: [128, 128, 128, 128] - LOD_2_BATCH_1GPU: [128, 128, 128, 128] - - LEARNING_RATES: [0.003, 0.003, 0.002, 0.002] - - SNAPSHOT_FREQ: [300, 300, 300, 300, 300] - - REPORT_FREQ: [100, 300, 300, 300, 300] diff --git a/defaults.py b/defaults.py index 41cd317e..e1146174 100644 --- a/defaults.py +++ b/defaults.py @@ -51,8 +51,8 @@ _C.MODEL.CHANNELS = 3 _C.MODEL.GENERATOR = "GeneratorDefault" _C.MODEL.ENCODER = "EncoderDefault" -_C.MODEL.MAPPING_TO_LATENT = "MappingToLatent" -_C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent" +_C.MODEL.MAPPING_D = "MappingD" +_C.MODEL.MAPPING_F = "MappingF" _C.MODEL.Z_REGRESSION = False _C.TRAIN = CN() diff --git a/interactive_demo.py b/interactive_demo.py index 6092e6af..e80ed738 100644 --- a/interactive_demo.py +++ b/interactive_demo.py @@ -62,8 +62,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -97,7 +97,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): @@ -186,9 +186,9 @@ def loadRandom(): def update_image(w, latents_original): with torch.no_grad(): w = w + model.dlatent_avg.buff.data[0] - w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) + w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1) - layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + layer_idx = torch.arange(model.mapping_f.num_layers)[np.newaxis, :, np.newaxis] cur_layers = (7 + 1) * 2 mixing_cutoff = cur_layers styles = torch.where(layer_idx < mixing_cutoff, w, latents_original) diff --git a/make_figures/make_generation_figure.py b/make_figures/make_generation_figure.py index b8264032..0cedec0c 100644 --- a/make_figures/make_generation_figure.py +++ 
b/make_figures/make_generation_figure.py @@ -65,8 +65,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg diff --git a/make_figures/make_recon_figure_celeba_pioneer.py b/make_figures/make_recon_figure_celeba_pioneer.py index 7e234f27..8f8b56a8 100644 --- a/make_figures/make_recon_figure_celeba_pioneer.py +++ b/make_figures/make_recon_figure_celeba_pioneer.py @@ -73,8 +73,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -108,7 +108,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/make_figures/make_recon_figure_ffhq_real.py b/make_figures/make_recon_figure_ffhq_real.py index 21ea4166..801b587d 100644 --- a/make_figures/make_recon_figure_ffhq_real.py +++ b/make_figures/make_recon_figure_ffhq_real.py @@ -69,8 +69,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -104,7 +104,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/make_figures/make_recon_figure_interpolation.py b/make_figures/make_recon_figure_interpolation.py index f42b4c40..1a74c4f2 100644 --- a/make_figures/make_recon_figure_interpolation.py +++ b/make_figures/make_recon_figure_interpolation.py @@ -70,8 +70,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -105,7 +105,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): @@ -144,7 +144,7 @@ def open_image(filename): def make(w): with torch.no_grad(): - w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) + w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1) x_rec = decode(w) return x_rec diff --git a/make_figures/make_recon_figure_multires.py b/make_figures/make_recon_figure_multires.py index 9627093e..16db392c 100644 --- a/make_figures/make_recon_figure_multires.py +++ b/make_figures/make_recon_figure_multires.py @@ -71,8 +71,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -106,7 +106,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, 
model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/make_figures/make_recon_figure_paged.py b/make_figures/make_recon_figure_paged.py index d2a6d620..5cd80ca6 100644 --- a/make_figures/make_recon_figure_paged.py +++ b/make_figures/make_recon_figure_paged.py @@ -72,8 +72,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -107,7 +107,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/make_figures/make_traversarls.py b/make_figures/make_traversarls.py index 4ec7a206..b3b5fbf7 100644 --- a/make_figures/make_traversarls.py +++ b/make_figures/make_traversarls.py @@ -48,8 +48,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -83,7 +83,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): @@ -122,9 +122,9 @@ def do_attribute_traversal(path, attrib_idx, start, end): def update_image(w): with torch.no_grad(): w = w + model.dlatent_avg.buff.data[0] - w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1) + w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1) - layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + layer_idx = torch.arange(model.mapping_f.num_layers)[np.newaxis, :, np.newaxis] cur_layers = (7 + 1) * 2 mixing_cutoff = cur_layers styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0]) diff --git a/make_figures/old/make_recon_figure_bed.py b/make_figures/old/make_recon_figure_bed.py index b31728bc..4cafdbd7 100644 --- a/make_figures/old/make_recon_figure_bed.py +++ b/make_figures/old/make_recon_figure_bed.py @@ -74,8 +74,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -109,7 +109,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/make_figures/old/make_recon_figure_celeba.py b/make_figures/old/make_recon_figure_celeba.py index 0658a52c..0557f3a2 100644 --- a/make_figures/old/make_recon_figure_celeba.py +++ b/make_figures/old/make_recon_figure_celeba.py @@ -74,8 +74,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -109,7 +109,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, 
model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/metrics/fid.py b/metrics/fid.py index efa42165..87001422 100644 --- a/metrics/fid.py +++ b/metrics/fid.py @@ -124,7 +124,7 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_fl = model.mapping_fl + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") diff --git a/metrics/fid_rec.py b/metrics/fid_rec.py index d5c9e1c5..e5085f94 100644 --- a/metrics/fid_rec.py +++ b/metrics/fid_rec.py @@ -138,8 +138,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") diff --git a/metrics/lpips.py b/metrics/lpips.py index 9b123dbf..b69e0e4d 100644 --- a/metrics/lpips.py +++ b/metrics/lpips.py @@ -109,7 +109,7 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_fl = model.mapping_fl + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") diff --git a/metrics/ppl.py b/metrics/ppl.py index 781f2e58..29da71cf 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -156,8 +156,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -191,7 +191,7 @@ def sample(cfg, logger): def encode(x): Z, _ = model.encode(x, layer_count - 1, 1) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/model.py b/model.py index 78c1da9a..ceb9bf34 100644 --- a/model.py +++ b/model.py @@ -35,13 +35,13 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ self.layer_count = layer_count self.z_regression = z_regression - self.mapping_tl = MAPPINGS["MappingToLatent"]( + self.mapping_d = MAPPINGS["MappingD"]( latent_size=latent_size, dlatent_size=latent_size, mapping_fmaps=latent_size, mapping_layers=3) - self.mapping_fl = MAPPINGS["MappingFromLatent"]( + self.mapping_f = MAPPINGS["MappingF"]( num_layers=2 * layer_count, latent_size=latent_size, dlatent_size=latent_size, @@ -62,7 +62,7 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ latent_size=latent_size, channels=channels) - self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers) + self.dlatent_avg = DLatent(latent_size, self.mapping_f.num_layers) self.latent_size = latent_size self.dlatent_avg_beta = dlatent_avg_beta self.truncation_psi = truncation_psi @@ -72,10 +72,10 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ def generate(self, lod, blend_factor, z=None, count=32, mixing=True, noise=True, return_styles=False, no_truncation=False): if z is None: z = torch.randn(count, self.latent_size) - styles = self.mapping_fl(z)[:, 0] + styles = self.mapping_f(z)[:, 0] s = styles.view(styles.shape[0], 1, styles.shape[1]) - styles = s.repeat(1, self.mapping_fl.num_layers, 1) + styles = s.repeat(1, self.mapping_f.num_layers, 1) if self.dlatent_avg_beta is not None: with torch.no_grad(): @@ -85,16 +85,16 @@ def generate(self, lod, blend_factor, z=None, count=32, mixing=True, noise=True, if mixing and 
self.style_mixing_prob is not None: if random.random() < self.style_mixing_prob: z2 = torch.randn(count, self.latent_size) - styles2 = self.mapping_fl(z2)[:, 0] - styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, self.mapping_fl.num_layers, 1) + styles2 = self.mapping_f(z2)[:, 0] + styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, self.mapping_f.num_layers, 1) - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + layer_idx = torch.arange(self.mapping_f.num_layers)[np.newaxis, :, np.newaxis] cur_layers = (lod + 1) * 2 mixing_cutoff = random.randint(1, cur_layers) styles = torch.where(layer_idx < mixing_cutoff, styles, styles2) if (self.truncation_psi is not None) and not no_truncation: - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + layer_idx = torch.arange(self.mapping_f.num_layers)[np.newaxis, :, np.newaxis] ones = torch.ones(layer_idx.shape, dtype=torch.float32) coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) @@ -107,8 +107,8 @@ def generate(self, lod, blend_factor, z=None, count=32, mixing=True, noise=True, def encode(self, x, lod, blend_factor): Z = self.encoder(x, lod, blend_factor) - Z_ = self.mapping_tl(Z) - return Z[:, :1], Z_[:, 1, 0] + discriminator_prediction = self.mapping_d(Z) + return Z[:, :1], discriminator_prediction def forward(self, x, lod, blend_factor, d_train, ae): if ae: @@ -136,7 +136,7 @@ def forward(self, x, lod, blend_factor, d_train, ae): _, d_result_real = self.encode(x, lod, blend_factor) - _, d_result_fake = self.encode(Xp.detach(), lod, blend_factor) + _, d_result_fake = self.encode(Xp, lod, blend_factor) loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x) return loss_d @@ -158,8 +158,8 @@ def lerp(self, other, betta): if hasattr(other, 'module'): other = other.module with torch.no_grad(): - params = list(self.mapping_tl.parameters()) + list(self.mapping_fl.parameters()) + list(self.decoder.parameters()) + list(self.encoder.parameters()) + list(self.dlatent_avg.parameters()) - other_param = list(other.mapping_tl.parameters()) + list(other.mapping_fl.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters()) + params = list(self.mapping_d.parameters()) + list(self.mapping_f.parameters()) + list(self.decoder.parameters()) + list(self.encoder.parameters()) + list(self.dlatent_avg.parameters()) + other_param = list(other.mapping_d.parameters()) + list(other.mapping_f.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters()) for p, p_other in zip(params, other_param): p.data.lerp_(p_other.data, 1.0 - betta) @@ -171,7 +171,7 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ self.layer_count = layer_count - self.mapping_fl = MAPPINGS["MappingFromLatent"]( + self.mapping_f = MAPPINGS["MappingF"]( num_layers=2 * layer_count, latent_size=latent_size, dlatent_size=latent_size, @@ -185,7 +185,7 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ latent_size=latent_size, channels=channels) - self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers) + self.dlatent_avg = DLatent(latent_size, self.mapping_f.num_layers) self.latent_size = latent_size self.dlatent_avg_beta = dlatent_avg_beta self.truncation_psi = truncation_psi @@ 
-193,12 +193,12 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_ self.truncation_cutoff = truncation_cutoff def generate(self, lod, blend_factor, z=None): - styles = self.mapping_fl(z)[:, 0] + styles = self.mapping_f(z)[:, 0] s = styles.view(styles.shape[0], 1, styles.shape[1]) - styles = s.repeat(1, self.mapping_fl.num_layers, 1) + styles = s.repeat(1, self.mapping_f.num_layers, 1) - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] + layer_idx = torch.arange(self.mapping_f.num_layers)[np.newaxis, :, np.newaxis] ones = torch.ones(layer_idx.shape, dtype=torch.float32) coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) diff --git a/model_separate.py b/model_separate.py deleted file mode 100644 index 2a50f238..00000000 --- a/model_separate.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2019 Stanislav Pidhorskyi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import random -import losses -from net import * -import numpy as np - - -class DLatent(nn.Module): - def __init__(self, dlatent_size, layer_count): - super(DLatent, self).__init__() - buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32) - self.register_buffer('buff', buffer) - - -class Model(nn.Module): - def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None, - truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="", encoder=""): - super(Model, self).__init__() - - self.layer_count = layer_count - - self.mapping_fl = Mapping( - num_layers=2 * layer_count, - latent_size=latent_size, - dlatent_size=latent_size, - mapping_fmaps=latent_size, - mapping_layers=mapping_layers) - - #self.decoder = Generator( - self.decoder = GENERATORS[generator]( - startf=startf, - layer_count=layer_count, - maxf=maxf, - latent_size=latent_size, - channels=channels) - - #self.encoder = Encoder_old( - self.encoder = ENCODERS[encoder]( - startf=startf, - layer_count=layer_count, - maxf=maxf, - latent_size=latent_size, - channels=channels) - - self.discriminator = Discriminator( - startf=startf, - layer_count=layer_count, - maxf=maxf, - channels=channels) - - self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers) - self.latent_size = latent_size - self.dlatent_avg_beta = dlatent_avg_beta - self.truncation_psi = truncation_psi - self.style_mixing_prob = style_mixing_prob - self.truncation_cutoff = truncation_cutoff - - def generate(self, lod, blend_factor, z=None, count=32, mixing=True, noise=True, return_styles=False, no_truncation=False): - if z is None: - z = torch.randn(count, self.latent_size) - styles = self.mapping_fl(z)[:, 0] - s = styles.view(styles.shape[0], 1, styles.shape[1]) - - styles = s.repeat(1, self.mapping_fl.num_layers, 1) - - if self.dlatent_avg_beta is not None: - with 
torch.no_grad(): - batch_avg = styles.mean(dim=0) - self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) - - if mixing and self.style_mixing_prob is not None: - if random.random() < self.style_mixing_prob: - z2 = torch.randn(count, self.latent_size) - styles2 = self.mapping_fl(z2)[:, 0] - styles2 = styles2.view(styles2.shape[0], 1, styles2.shape[1]).repeat(1, self.mapping_fl.num_layers, 1) - - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] - cur_layers = (lod + 1) * 2 - mixing_cutoff = random.randint(1, cur_layers) - styles = torch.where(layer_idx < mixing_cutoff, styles, styles2) - - if (self.truncation_psi is not None) and not no_truncation: - layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis] - ones = torch.ones(layer_idx.shape, dtype=torch.float32) - coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) - styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) - - rec = self.decoder.forward(styles, lod, blend_factor, noise) - if return_styles: - return s, rec - else: - return rec - - def encode(self, x, lod, blend_factor): - Z = self.encoder(x, lod, blend_factor) - return Z - - def forward(self, x, lod, blend_factor, d_train, ae, alt): - if ae: - self.encoder.requires_grad_(True) - - z = torch.randn(x.shape[0], self.latent_size) - s, rec = self.generate(lod, blend_factor, z=z, mixing=False, noise=True, return_styles=True) - - Z = self.encode(rec, lod, blend_factor) - - assert Z.shape == s.shape - - Lae = torch.mean(((Z - s.detach())**2)) - - return Lae - - elif d_train: - with torch.no_grad(): - Xp = self.generate(lod, blend_factor, count=x.shape[0], noise=True) - - self.discriminator.requires_grad_(True) - - d_result_real = self.discriminator(x, lod, blend_factor) - - d_result_fake = self.discriminator(Xp.detach(), lod, blend_factor) - - loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x) - return loss_d - else: - with torch.no_grad(): - z = torch.randn(x.shape[0], self.latent_size) - - self.discriminator.requires_grad_(False) - - rec = self.generate(lod, blend_factor, count=x.shape[0], z=z.detach(), noise=True) - - d_result_fake = self.discriminator(rec, lod, blend_factor) - - loss_g = losses.generator_logistic_non_saturating(d_result_fake) - - return loss_g - - def lerp(self, other, betta): - if hasattr(other, 'module'): - other = other.module - with torch.no_grad(): - params = list(self.mapping_fl.parameters()) + list(self.decoder.parameters()) + list(self.encoder.parameters()) + list(self.discriminator.parameters()) + list(self.dlatent_avg.parameters()) - other_param = list(other.mapping_fl.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters()) - for p, p_other in zip(params, other_param): - p.data.lerp_(p_other.data, 1.0 - betta) diff --git a/net.py b/net.py index 76b93d2f..09ec3b2b 100644 --- a/net.py +++ b/net.py @@ -64,7 +64,6 @@ class EncodeBlock(nn.Module): def __init__(self, inputs, outputs, latent_size, last=False, fused_scale=True): super(EncodeBlock, self).__init__() self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) - # self.conv_1 = ln.Conv2d(inputs + (1 if last else 0), inputs, 3, 1, 1, bias=False) self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False) self.blur = Blur(inputs) @@ -237,7 +236,7 @@ def forward(self, x, s1, s2, noise): tensor2=torch.randn([x.shape[0], 1, 
x.shape[2], x.shape[3]])) else: s = math.pow(self.layer + 1, 0.5) - x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 + x = x + s * torch.exp(-x * x / (2.0 * s * s)) / math.sqrt(2 * math.pi) * 0.8 x = x + self.bias_2 @@ -273,10 +272,11 @@ def forward(self, x): return x +# Default Encoder. E network @ENCODERS.register("EncoderDefault") -class Encoder_old(nn.Module): +class EncoderDefault(nn.Module): def __init__(self, startf, maxf, layer_count, latent_size, channels=3): - super(Encoder_old, self).__init__() + super(EncoderDefault, self).__init__() self.maxf = maxf self.startf = startf self.layer_count = layer_count @@ -301,7 +301,6 @@ def __init__(self, startf, maxf, layer_count, latent_size, channels=3): resolution //= 2 - #print("encode_block%d %s styles out: %d" % ((i + 1), millify(count_parameters(block)), inputs)) self.encode_block.append(block) inputs = outputs mul *= 2 @@ -361,6 +360,7 @@ def get_statistics(self, lod): return rgb_std / rgb_std_c, layers +# For ablation only. Not used in default configuration @ENCODERS.register("EncoderWithFC") class EncoderWithFC(nn.Module): def __init__(self, startf, maxf, layer_count, latent_size, channels=3): @@ -389,7 +389,6 @@ def __init__(self, startf, maxf, layer_count, latent_size, channels=3): resolution //= 2 - #print("encode_block%d %s styles out: %d" % ((i + 1), millify(count_parameters(block)), inputs)) self.encode_block.append(block) inputs = outputs mul *= 2 @@ -539,6 +538,7 @@ def get_statistics(self, lod): return rgb_std / rgb_std_c, layers +# For ablation only. Not used in default configuration @ENCODERS.register("EncoderNoStyle") class EncoderNoStyle(nn.Module): def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=512, channels=3): @@ -566,7 +566,6 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=512, channels resolution //= 2 - #print("encode_block%d %s" % ((i + 1), millify(count_parameters(block)))) self.encode_block.append(block) inputs = outputs mul *= 2 @@ -607,6 +606,7 @@ def forward(self, x, lod, blend): return self.encode2(x, lod, blend) +# For ablation only. Not used in default configuration @DISCRIMINATORS.register("DiscriminatorDefault") class Discriminator(nn.Module): def __init__(self, startf=32, maxf=256, layer_count=3, channels=3): @@ -634,7 +634,6 @@ def __init__(self, startf=32, maxf=256, layer_count=3, channels=3): resolution //= 2 - #print("encode_block%d %s" % ((i + 1), millify(count_parameters(block)))) self.encode_block.append(block) inputs = outputs mul *= 2 @@ -715,8 +714,6 @@ def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels to_rgb.append(ToRGB(outputs, channels)) - #print("decode_block%d %s styles in: %dl out resolution: %d" % ( - # (i + 1), millify(count_parameters(block)), outputs, resolution)) self.decode_block.append(block) inputs = outputs mul //= 2 @@ -773,6 +770,18 @@ def get_statistics(self, lod): return rgb_std / rgb_std_c, layers +def minibatch_stddev_layer(x, group_size=4): + group_size = min(group_size, x.shape[0]) + size = x.shape[0] + if x.shape[0] % group_size != 0: + x = torch.cat([x, x[:(group_size - (x.shape[0] % group_size)) % group_size]]) + y = x.view(group_size, -1, x.shape[1], x.shape[2], x.shape[3]) + y = y - y.mean(dim=0, keepdim=True) + y = torch.sqrt((y ** 2).mean(dim=0) + 1e-8).mean(dim=[1, 2, 3], keepdim=True) + y = y.repeat(group_size, 1, x.shape[2], x.shape[3]) + return torch.cat([x, y], dim=1)[:size] + + image_size = 64 # Number of channels in the training images. 
For color images this is 3 @@ -859,6 +868,7 @@ def forward(self, x): return x +# For ablation only. Not used in default configuration @MAPPINGS.register("MappingDefault") class Mapping(nn.Module): def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): @@ -871,7 +881,6 @@ def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=2 block = MappingBlock(inputs, outputs, lrmul=0.01) inputs = outputs setattr(self, "block_%d" % (i + 1), block) - #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) def forward(self, z): x = pixel_norm(z) @@ -882,10 +891,11 @@ def forward(self, z): return x.view(x.shape[0], 1, x.shape[1]).repeat(1, self.num_layers, 1) -@MAPPINGS.register("MappingToLatent") -class VAEMappingToLatent_old(nn.Module): +# Used in default configuration. The D network +@MAPPINGS.register("MappingD") +class MappingD(nn.Module): def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): - super(VAEMappingToLatent_old, self).__init__() + super(MappingD, self).__init__() inputs = latent_size self.mapping_layers = mapping_layers self.map_blocks: nn.ModuleList[MappingBlock] = nn.ModuleList() @@ -894,19 +904,19 @@ def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_ block = ln.Linear(inputs, outputs, lrmul=0.1) inputs = outputs self.map_blocks.append(block) - #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) def forward(self, x): for i in range(self.mapping_layers): x = self.map_blocks[i](x) + # We select just one output. For compatibility with older models. + # All other outputs are ignored + return x[:, 0, 0] - return x.view(x.shape[0], 2, x.shape[2] // 2) - -@MAPPINGS.register("MappingToLatentNoStyle") -class VAEMappingToLatentNoStyle(nn.Module): +@MAPPINGS.register("MappingDNoStyle") +class MappingDNoStyle(nn.Module): def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): - super(VAEMappingToLatentNoStyle, self).__init__() + super(MappingDNoStyle, self).__init__() inputs = latent_size self.mapping_layers = mapping_layers self.map_blocks: nn.ModuleList[MappingBlock] = nn.ModuleList() @@ -918,19 +928,14 @@ def __init__(self, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_ def forward(self, x): for i in range(self.mapping_layers): - if i == self.mapping_layers - 1: - #x = self.map_blocks[i](x) - x = self.map_blocks[i](x) - else: - #x = self.map_blocks[i](x) - x = self.map_blocks[i](x) - return x + x = self.map_blocks[i](x) + return x[:, 0] -@MAPPINGS.register("MappingFromLatent") -class VAEMappingFromLatent(nn.Module): +@MAPPINGS.register("MappingF") +class MappingF(nn.Module): def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): - super(VAEMappingFromLatent, self).__init__() + super(MappingF, self).__init__() inputs = dlatent_size self.mapping_layers = mapping_layers self.num_layers = num_layers @@ -940,7 +945,6 @@ def __init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=2 block = MappingBlock(inputs, outputs, lrmul=0.1) inputs = outputs self.map_blocks.append(block) - #print("dense %d %s" % ((i + 1), millify(count_parameters(block)))) def forward(self, x): x = pixel_norm(x) diff --git a/principal_directions/extract_attributes.py b/principal_directions/extract_attributes.py index 9e2f941f..7f9accb0 100644 --- a/principal_directions/extract_attributes.py +++ b/principal_directions/extract_attributes.py @@ 
-106,8 +106,8 @@ def main(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") diff --git a/principal_directions/generate_images.py b/principal_directions/generate_images.py index 05018d89..05cfb32a 100644 --- a/principal_directions/generate_images.py +++ b/principal_directions/generate_images.py @@ -80,8 +80,8 @@ def sample(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") diff --git a/style_mixing/stylemix.py b/style_mixing/stylemix.py index 15aa133b..731c28ea 100644 --- a/style_mixing/stylemix.py +++ b/style_mixing/stylemix.py @@ -61,8 +61,8 @@ def _main(cfg, logger): decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_tl = model.mapping_d + mapping_fl = model.mapping_f dlatent_avg = model.dlatent_avg logger.info("Trainable parameters generator:") @@ -104,7 +104,7 @@ def encode(x): Z, _ = model.encode(x[i][None, ...], layer_count - 1, 1) zlist.append(Z) Z = torch.cat(zlist) - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) return Z def decode(x): diff --git a/tracker.py b/tracker.py index 3fcbe0a4..4400ad17 100644 --- a/tracker.py +++ b/tracker.py @@ -16,6 +16,8 @@ import csv from collections import OrderedDict import matplotlib.pyplot as plt +import matplotlib +matplotlib.use('Agg') import numpy as np import torch import os @@ -23,20 +25,20 @@ class RunningMean: def __init__(self): - self.mean = 0.0 + self._mean = 0.0 self.n = 0 def __iadd__(self, value): - self.mean = (float(value) + self.mean * self.n)/(self.n + 1) + self._mean = (float(value) + self._mean * self.n)/(self.n + 1) self.n += 1 return self def reset(self): - self.mean = 0.0 + self._mean = 0.0 self.n = 0 def mean(self): - return self.mean + return self._mean class RunningMeanTorch: @@ -68,7 +70,7 @@ def __init__(self, output_folder='.'): def update(self, d): for k, v in d.items(): if k not in self.tracks: - self.add(k) + self.add(k, isinstance(v, torch.Tensor)) self.tracks[k] += v def add(self, name, pytorch=True): @@ -106,18 +108,24 @@ def __str__(self): return result[:-2] def plot(self): - plt.figure(figsize=(12, 8)) + fig = plt.figure() + fig.set_size_inches(12, 8) + ax = fig.add_subplot(111) for key in self.tracks.keys(): - plt.plot(self.epochs, self.means_over_epochs[key], label=key) + try: + plt.plot(self.epochs, self.means_over_epochs[key], label=key) + except ValueError: + continue - plt.xlabel('Epoch') - plt.ylabel('Loss') + ax.set_xlabel('Epoch') + ax.set_ylabel('Loss') - plt.legend(loc=4) - plt.grid(True) - plt.tight_layout() + ax.legend(loc=4) + ax.grid(True) + fig.tight_layout() - plt.savefig(os.path.join(self.output_folder, 'plot.png')) + fig.savefig(os.path.join(self.output_folder, 'plot.png')) + fig.clf() plt.close() def state_dict(self): diff --git a/train_alae.py b/train_alae.py index 9f1f30d2..00df8de5 100644 --- a/train_alae.py +++ b/train_alae.py @@ -33,7 +33,7 @@ from PIL import Image -def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cfg, encoder_optimizer, decoder_optimizer): +def save_sample(lod2batch, tracker, sample, samplez, x, 
logger, model, cmodel, cfg, encoder_optimizer, decoder_optimizer): os.makedirs('results', exist_ok=True) logger.info('\n[%d/%d] - ptime: %.2f, %s, blend: %.3f, lr: %.12f, %.12f, max mem: %f",' % ( @@ -44,6 +44,7 @@ def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cfg, enco with torch.no_grad(): model.eval() + cmodel.eval() sample = sample[:lod2batch.get_per_GPU_batch_size()] samplez = samplez[:lod2batch.get_per_GPU_batch_size()] @@ -63,22 +64,20 @@ def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cfg, enco Z, _ = model.encode(sample_in, lod2batch.lod, blend_factor) if cfg.MODEL.Z_REGRESSION: - Z = model.mapping_fl(Z[:, 0]) + Z = model.mapping_f(Z[:, 0]) else: - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) + Z = Z.repeat(1, model.mapping_f.num_layers, 1) - rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=False) - rec2 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) + rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) + rec2 = cmodel.decoder(Z, lod2batch.lod, blend_factor, noise=True) - # rec1 = F.interpolate(rec1, sample.shape[2]) - # rec2 = F.interpolate(rec2, sample.shape[2]) - # sample_in = F.interpolate(sample_in, sample.shape[2]) - - Z = model.mapping_fl(samplez) + Z = model.mapping_f(samplez) g_rec = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) - # g_rec = F.interpolate(g_rec, sample.shape[2]) - resultsample = torch.cat([sample_in, rec1, rec2, g_rec], dim=0) + Z = cmodel.mapping_f(samplez) + cg_rec = cmodel.decoder(Z, lod2batch.lod, blend_factor, noise=True) + + resultsample = torch.cat([sample_in, rec1, rec2, g_rec, cg_rec], dim=0) @utils.async_func def save_pic(x_rec): @@ -144,14 +143,15 @@ def train(cfg, logger, local_rank, world_size, distributed): decoder = model.module.decoder encoder = model.module.encoder - mapping_tl = model.module.mapping_tl - mapping_fl = model.module.mapping_fl + mapping_d = model.module.mapping_d + mapping_f = model.module.mapping_f + dlatent_avg = model.module.dlatent_avg else: decoder = model.decoder encoder = model.encoder - mapping_tl = model.mapping_tl - mapping_fl = model.mapping_fl + mapping_d = model.mapping_d + mapping_f = model.mapping_f dlatent_avg = model.dlatent_avg count_param_override.print = lambda a: logger.info(a) @@ -167,12 +167,12 @@ def train(cfg, logger, local_rank, world_size, distributed): decoder_optimizer = LREQAdam([ {'params': decoder.parameters()}, - {'params': mapping_fl.parameters()} + {'params': mapping_f.parameters()} ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) encoder_optimizer = LREQAdam([ {'params': encoder.parameters()}, - {'params': mapping_tl.parameters()}, + {'params': mapping_d.parameters()}, ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) scheduler = ComboMultiStepLR(optimizers= @@ -187,15 +187,15 @@ def train(cfg, logger, local_rank, world_size, distributed): model_dict = { 'discriminator': encoder, 'generator': decoder, - 'mapping_tl': mapping_tl, - 'mapping_fl': mapping_fl, + 'mapping_tl': mapping_d, + 'mapping_fl': mapping_f, 'dlatent_avg': dlatent_avg } if local_rank == 0: model_dict['discriminator_s'] = model_s.encoder model_dict['generator_s'] = model_s.decoder - model_dict['mapping_tl_s'] = model_s.mapping_tl - model_dict['mapping_fl_s'] = model_s.mapping_fl + model_dict['mapping_tl_s'] = model_s.mapping_d + model_dict['mapping_fl_s'] = model_s.mapping_f tracker = LossTracker(cfg.OUTPUT_DIR) @@ 
-225,7 +225,7 @@ def train(cfg, logger, local_rank, world_size, distributed): lod2batch = lod_driver.LODDriver(cfg, logger, world_size, dataset_size=len(dataset) * world_size) - if cfg.DATASET.SAMPLES_PATH: + if cfg.DATASET.SAMPLES_PATH != 'no_path': path = cfg.DATASET.SAMPLES_PATH src = [] with torch.no_grad(): @@ -308,7 +308,7 @@ def train(cfg, logger, local_rank, world_size, distributed): decoder_optimizer.zero_grad() lae = model(x, lod2batch.lod, blend_factor, d_train=True, ae=True) tracker.update(dict(lae=lae)) - (lae).backward() + lae.backward() encoder_optimizer.step() decoder_optimizer.step() @@ -325,14 +325,16 @@ def train(cfg, logger, local_rank, world_size, distributed): if lod2batch.is_time_to_save(): checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model) if lod2batch.is_time_to_report(): - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, + save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, + model.module if hasattr(model, "module") else model, cfg, encoder_optimizer, decoder_optimizer) scheduler.step() if local_rank == 0: checkpointer.save("model_tmp_lod%d" % lod_for_saving_model) - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, decoder_optimizer) + save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, + model.module if hasattr(model, "module") else model, cfg, encoder_optimizer, decoder_optimizer) logger.info("Training finish!... save training results") if local_rank == 0: diff --git a/train_alae_separate.py b/train_alae_separate.py deleted file mode 100644 index fe53c9b7..00000000 --- a/train_alae_separate.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright 2019-2020 Stanislav Pidhorskyi -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import torch.utils.data -from torchvision.utils import save_image -from net import * -import os -import utils -from checkpointer import Checkpointer -from scheduler import ComboMultiStepLR -from custom_adam import LREQAdam -from dataloader import * -from tqdm import tqdm -from dlutils.pytorch import count_parameters -import dlutils.pytorch.count_parameters as count_param_override -from tracker import LossTracker -from model_separate import Model -from launcher import run -from defaults import get_cfg_defaults -import lod_driver -from PIL import Image - - -def save_sample(lod2batch, tracker, sample, samplez, x, logger, model, cfg, encoder_optimizer, decoder_optimizer): - os.makedirs('results', exist_ok=True) - - logger.info('\n[%d/%d] - ptime: %.2f, %s, blend: %.3f, lr: %.12f, %.12f, max mem: %f",' % ( - (lod2batch.current_epoch + 1), cfg.TRAIN.TRAIN_EPOCHS, lod2batch.per_epoch_ptime, str(tracker), - lod2batch.get_blend_factor(), - encoder_optimizer.param_groups[0]['lr'], decoder_optimizer.param_groups[0]['lr'], - torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)) - - with torch.no_grad(): - model.eval() - sample = sample[:lod2batch.get_per_GPU_batch_size()] - samplez = samplez[:lod2batch.get_per_GPU_batch_size()] - - needed_resolution = model.decoder.layer_to_resolution[lod2batch.lod] - sample_in = sample - while sample_in.shape[2] != needed_resolution: - sample_in = F.avg_pool2d(sample_in, 2, 2) - - blend_factor = lod2batch.get_blend_factor() - if lod2batch.in_transition: - needed_resolution_prev = model.decoder.layer_to_resolution[lod2batch.lod - 1] - sample_in_prev = F.avg_pool2d(sample_in, 2, 2) - sample_in_prev_2x = F.interpolate(sample_in_prev, needed_resolution) - sample_in = sample_in * blend_factor + sample_in_prev_2x * (1.0 - blend_factor) - - Z = model.encode(sample_in, lod2batch.lod, blend_factor) - - Z = Z.repeat(1, model.mapping_fl.num_layers, 1) - rec1 = model.decoder(Z, lod2batch.lod, blend_factor, noise=False) - rec2 = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) - - rec1 = F.interpolate(rec1, sample.shape[2]) - rec2 = F.interpolate(rec2, sample.shape[2]) - sample_in = F.interpolate(sample_in, sample.shape[2]) - - Z = model.mapping_fl(samplez) - g_rec = model.decoder(Z, lod2batch.lod, blend_factor, noise=True) - g_rec = F.interpolate(g_rec, sample.shape[2]) - - resultsample = torch.cat([sample_in, rec1, rec2, g_rec], dim=0) - - @utils.async_func - def save_pic(x_rec): - tracker.register_means(lod2batch.current_epoch + lod2batch.iteration * 1.0 / lod2batch.get_dataset_size()) - tracker.plot() - - result_sample = x_rec * 0.5 + 0.5 - result_sample = result_sample.cpu() - f = os.path.join(cfg.OUTPUT_DIR, - 'sample_%d_%d.jpg' % ( - lod2batch.current_epoch + 1, - lod2batch.iteration // 1000) - ) - print("Saved to %s" % f) - save_image(result_sample, f, nrow=min(32, lod2batch.get_per_GPU_batch_size())) - - save_pic(resultsample) - - -def train(cfg, logger, local_rank, world_size, distributed): - torch.cuda.set_device(local_rank) - model = Model( - startf=cfg.MODEL.START_CHANNEL_COUNT, - layer_count=cfg.MODEL.LAYER_COUNT, - maxf=cfg.MODEL.MAX_CHANNEL_COUNT, - latent_size=cfg.MODEL.LATENT_SPACE_SIZE, - dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA, - style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, - mapping_layers=cfg.MODEL.MAPPING_LAYERS, - channels=cfg.MODEL.CHANNELS, - generator=cfg.MODEL.GENERATOR, - encoder=cfg.MODEL.ENCODER - ) - model.cuda(local_rank) - model.train() - - if 
local_rank == 0: - model_s = Model( - startf=cfg.MODEL.START_CHANNEL_COUNT, - layer_count=cfg.MODEL.LAYER_COUNT, - maxf=cfg.MODEL.MAX_CHANNEL_COUNT, - latent_size=cfg.MODEL.LATENT_SPACE_SIZE, - truncation_psi=cfg.MODEL.TRUNCATIOM_PSI, - truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF, - mapping_layers=cfg.MODEL.MAPPING_LAYERS, - channels=cfg.MODEL.CHANNELS, - generator=cfg.MODEL.GENERATOR, - encoder=cfg.MODEL.ENCODER) - model_s.cuda(local_rank) - model_s.eval() - model_s.requires_grad_(False) - - if distributed: - model = nn.parallel.DistributedDataParallel( - model, - device_ids=[local_rank], - broadcast_buffers=False, - bucket_cap_mb=25, - find_unused_parameters=True) - model.device_ids = None - - decoder = model.module.decoder - encoder = model.module.encoder - discriminator = model.module.discriminator - mapping_fl = model.module.mapping_fl - dlatent_avg = model.module.dlatent_avg - else: - decoder = model.decoder - encoder = model.encoder - discriminator = model.discriminator - mapping_fl = model.mapping_fl - dlatent_avg = model.dlatent_avg - - count_param_override.print = lambda a: logger.info(a) - - logger.info("Trainable parameters generator:") - count_parameters(decoder) - - logger.info("Trainable parameters discriminator:") - count_parameters(encoder) - - arguments = dict() - arguments["iteration"] = 0 - - decoder_optimizer = LREQAdam([ - {'params': decoder.parameters()}, - {'params': mapping_fl.parameters()} - ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) - - encoder_optimizer = LREQAdam([ - {'params': encoder.parameters()}, - ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) - - discriminator_optimizer = LREQAdam([ - {'params': discriminator.parameters()}, - ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0) - - scheduler = ComboMultiStepLR(optimizers= - { - 'encoder_optimizer': encoder_optimizer, - 'discriminator_optimizer': discriminator_optimizer, - 'decoder_optimizer': decoder_optimizer - }, - milestones=cfg.TRAIN.LEARNING_DECAY_STEPS, - gamma=cfg.TRAIN.LEARNING_DECAY_RATE, - reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES) - - model_dict = { - 'discriminator': discriminator, - 'encoder': encoder, - 'generator': decoder, - 'mapping_fl': mapping_fl, - 'dlatent_avg': dlatent_avg - } - if local_rank == 0: - model_dict['discriminator_s'] = model_s.discriminator - model_dict['encoder_s'] = model_s.encoder - model_dict['generator_s'] = model_s.decoder - model_dict['mapping_fl_s'] = model_s.mapping_fl - - tracker = LossTracker(cfg.OUTPUT_DIR) - - checkpointer = Checkpointer(cfg, - model_dict, - { - 'encoder_optimizer': encoder_optimizer, - 'discriminator_optimizer': discriminator_optimizer, - 'decoder_optimizer': decoder_optimizer, - 'scheduler': scheduler, - 'tracker': tracker - }, - logger=logger, - save=local_rank == 0) - - extra_checkpoint_data = checkpointer.load() - logger.info("Starting from epoch: %d" % (scheduler.start_epoch())) - - arguments.update(extra_checkpoint_data) - - layer_to_resolution = decoder.layer_to_resolution - - dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size, buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS) - - rnd = np.random.RandomState(3456) - latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE) - samplez = torch.tensor(latents).float().cuda() - - lod2batch = lod_driver.LODDriver(cfg, logger, world_size, dataset_size=len(dataset) * world_size) 
- - if cfg.DATASET.SAMPLES_PATH: - path = cfg.DATASET.SAMPLES_PATH - src = [] - with torch.no_grad(): - for filename in list(os.listdir(path))[:32]: - img = np.asarray(Image.open(os.path.join(path, filename))) - if img.shape[2] == 4: - img = img[:, :, :3] - im = img.transpose((2, 0, 1)) - x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1. - if x.shape[0] == 4: - x = x[:3] - src.append(x) - sample = torch.stack(src) - else: - dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32) - sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank)) - sample = (sample / 127.5 - 1.) - - lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, decoder_optimizer]) - - for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS): - model.train() - lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer]) - - logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % ( - lod2batch.get_batch_size(), - lod2batch.get_per_GPU_batch_size(), - lod2batch.lod, - 2 ** lod2batch.get_lod_power2(), - 2 ** lod2batch.get_lod_power2(), - lod2batch.get_blend_factor(), - len(dataset) * world_size)) - - dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size()) - batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank) - - scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod) - - model.train() - - need_permute = False - epoch_start_time = time.time() - - i = 0 - with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof: - for x_orig in tqdm(batches): - i +=1 - with torch.no_grad(): - if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size(): - continue - if need_permute: - x_orig = x_orig.permute(0, 3, 1, 2) - x_orig = (x_orig / 127.5 - 1.) 
- - blend_factor = lod2batch.get_blend_factor() - - needed_resolution = layer_to_resolution[lod2batch.lod] - x = x_orig - - if lod2batch.in_transition: - needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1] - x_prev = F.avg_pool2d(x_orig, 2, 2) - x_prev_2x = F.interpolate(x_prev, needed_resolution) - x = x * blend_factor + x_prev_2x * (1.0 - blend_factor) - - x.requires_grad = True - - loss_d = model(x, lod2batch.lod, blend_factor, d_train=True, ae=False) - tracker.update(dict(loss_d=loss_d)) - loss_d.backward() - discriminator_optimizer.step() - decoder_optimizer.zero_grad() - discriminator_optimizer.zero_grad() - - loss_g = model(x, lod2batch.lod, blend_factor, d_train=False, ae=False) - tracker.update(dict(loss_g=loss_g)) - loss_g.backward() - decoder_optimizer.step() - decoder_optimizer.zero_grad() - discriminator_optimizer.zero_grad() - - lae = model(x, lod2batch.lod, blend_factor, d_train=True, ae=True) - tracker.update(dict(lae=lae)) - (lae).backward() - encoder_optimizer.step() - decoder_optimizer.step() - encoder_optimizer.zero_grad() - decoder_optimizer.zero_grad() - - if local_rank == 0: - betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0)) - model_s.lerp(model, betta) - - epoch_end_time = time.time() - per_epoch_ptime = epoch_end_time - epoch_start_time - - lod_for_saving_model = lod2batch.lod - lod2batch.step() - if local_rank == 0: - if lod2batch.is_time_to_save(): - checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model) - if lod2batch.is_time_to_report(): - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, decoder_optimizer) - - scheduler.step() - - if local_rank == 0: - checkpointer.save("model_tmp_lod%d" % lod_for_saving_model) - save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, decoder_optimizer) - - logger.info("Training finish!... save training results") - if local_rank == 0: - checkpointer.save("model_final").wait() - - -if __name__ == "__main__": - gpu_count = torch.cuda.device_count() - run(train, get_cfg_defaults(), description='StyleGAN', default_config='configs/experiment_celeba_sep.yaml', - world_size=gpu_count)
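
Note on the tracker.py hunk above: the memory leak was addressed by switching matplotlib to the non-interactive Agg backend, drawing through an explicit Figure/Axes pair instead of the implicit pyplot state, and releasing the figure after saving. The sketch below distills that pattern into a standalone helper; the function name save_loss_plot and the sample data are illustrative only and are not part of the repository's LossTracker API.

import os

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; conventionally selected before pyplot is imported
import matplotlib.pyplot as plt


def save_loss_plot(epochs, losses_by_name, output_folder='.'):
    # Draw through an explicit Figure instead of the implicit pyplot state machine.
    fig = plt.figure()
    fig.set_size_inches(12, 8)
    ax = fig.add_subplot(111)

    for name, values in losses_by_name.items():
        ax.plot(epochs, values, label=name)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend(loc=4)
    ax.grid(True)
    fig.tight_layout()

    fig.savefig(os.path.join(output_folder, 'plot.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    fig.clf()
    plt.close(fig)


if __name__ == '__main__':
    # Hypothetical data, for illustration only.
    save_loss_plot([0, 1, 2], {'loss_d': [0.9, 0.7, 0.6], 'loss_g': [1.1, 1.0, 0.8]})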