
Commit

Fixed confusing names. #38, #35
Fixed a memory leak related to matplotlib #17

Hopefully this did not break anything. If you have any issues, try a revision before this commit.
Stanislav Pidhorskyi committed Dec 7, 2020
1 parent 64accd1 commit d43e061
Showing 31 changed files with 156 additions and 788 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@
> **Adversarial Latent Autoencoders**<br>
> Stanislav Pidhorskyi, Donald Adjeroh, Gianfranco Doretto<br>
>
- > **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond, the capabilities of a generator-only type of architecture.*
+ > **Abstract:** *Autoencoder networks are unsupervised approaches aiming at combining generative and representational properties by learning simultaneously an encoder-generator map. Although studied extensively, the issues of whether they have the same generative power of GANs, or learn disentangled representations, have not been fully addressed. We introduce an autoencoder that tackles these issues jointly, which we call Adversarial Latent Autoencoder (ALAE). It is a general architecture that can leverage recent improvements on GAN training procedures. We designed two autoencoders: one based on a MLP encoder, and another based on a StyleGAN generator, which we call StyleALAE. We verify the disentanglement properties of both architectures. We show that StyleALAE can not only generate 1024x1024 face images with comparable quality of StyleGAN, but at the same resolution can also produce face reconstructions and manipulations based on real images. This makes ALAE the first autoencoder able to compare with, and go beyond the capabilities of a generator-only type of architecture.*
## Citation
* Stanislav Pidhorskyi, Donald A. Adjeroh, and Gianfranco Doretto. Adversarial Latent Autoencoders. In *Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR)*, 2020. [to appear]
2 changes: 1 addition & 1 deletion checkpointer.py
@@ -108,7 +108,7 @@ def load(self, ignore_last_checkpoint=False, file_name=None):
self.auxiliary[name].load_state_dict(checkpoint["optimizers"].pop(name))
if name in checkpoint:
self.auxiliary[name].load_state_dict(checkpoint.pop(name))
- except IndexError:
+ except (IndexError, ValueError):
self.logger.warning('%s\nFailed to load: %s\n%s' % ('!' * 160, name, '!' * 160))
checkpoint.pop('auxiliary')
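
The broadened `except` clause above matters because `torch.optim.Optimizer.load_state_dict` raises `ValueError` (not only `IndexError`) when the saved optimizer state no longer matches the current parameter groups, which can happen after modules are renamed or restructured between a checkpoint and the current model. A minimal, self-contained sketch (toy modules, not from this repository) of the failure mode the checkpointer now tolerates:

```python
import torch

# Optimizer state saved against an "old" model with two parameters (weight + bias).
old_model = torch.nn.Linear(4, 4)
old_opt = torch.optim.Adam(old_model.parameters())
saved_state = old_opt.state_dict()

# The "new" model has a different parameter layout (here: bias removed).
new_model = torch.nn.Linear(4, 4, bias=False)
new_opt = torch.optim.Adam(new_model.parameters())

try:
    new_opt.load_state_dict(saved_state)   # parameter groups no longer line up
except (IndexError, ValueError) as e:
    # The checkpointer logs a warning in this situation instead of crashing.
    print("Failed to load optimizer state:", e)
```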

32 changes: 0 additions & 32 deletions configs/celeba_ablation_nostyle.yaml

This file was deleted.

33 changes: 0 additions & 33 deletions configs/celeba_ablation_separate.yaml

This file was deleted.

32 changes: 0 additions & 32 deletions configs/celeba_ablation_z_regression.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/mnist_fc.yaml
@@ -18,7 +18,7 @@ MODEL:
CHANNELS: 1
GENERATOR: "GeneratorFC"
ENCODER: "EncoderFC"
- MAPPING_TO_LATENT: "MappingToLatentNoStyle"
+ MAPPING_D: "MappingDNoStyle"

OUTPUT_DIR: mnist_results_fc2z_2
TRAIN:
44 changes: 0 additions & 44 deletions configs/mnist_fc_no_lreq_js_loss.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions defaults.py
@@ -51,8 +51,8 @@
_C.MODEL.CHANNELS = 3
_C.MODEL.GENERATOR = "GeneratorDefault"
_C.MODEL.ENCODER = "EncoderDefault"
- _C.MODEL.MAPPING_TO_LATENT = "MappingToLatent"
- _C.MODEL.MAPPING_FROM_LATENT = "MappingFromLatent"
+ _C.MODEL.MAPPING_D = "MappingD"
+ _C.MODEL.MAPPING_F = "MappingF"
_C.MODEL.Z_REGRESSION = False

_C.TRAIN = CN()
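
For context, these defaults are yacs nodes, and the YAML configs override them by key name, which is why `configs/mnist_fc.yaml` above had to switch to the new keys as well. A minimal illustrative sketch (not the repository's actual `defaults.py`, which defines many more options):

```python
from yacs.config import CfgNode as CN

_C = CN()
_C.MODEL = CN()
_C.MODEL.MAPPING_D = "MappingD"   # renamed from MAPPING_TO_LATENT
_C.MODEL.MAPPING_F = "MappingF"   # renamed from MAPPING_FROM_LATENT

cfg = _C.clone()
# A YAML file such as configs/mnist_fc.yaml overrides defaults by key,
# so it must now use MAPPING_D / MAPPING_F; merge_from_list is used here
# only to keep the sketch self-contained.
cfg.merge_from_list(["MODEL.MAPPING_D", "MappingDNoStyle"])
cfg.freeze()
print(cfg.MODEL.MAPPING_D)        # -> MappingDNoStyle
```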
10 changes: 5 additions & 5 deletions interactive_demo.py
@@ -62,8 +62,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -97,7 +97,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
@@ -186,9 +186,9 @@ def loadRandom():
def update_image(w, latents_original):
with torch.no_grad():
w = w + model.dlatent_avg.buff.data[0]
- w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
+ w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1)

- layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
+ layer_idx = torch.arange(model.mapping_f.num_layers)[np.newaxis, :, np.newaxis]
cur_layers = (7 + 1) * 2
mixing_cutoff = cur_layers
styles = torch.where(layer_idx < mixing_cutoff, w, latents_original)
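
The reshaping above broadcasts a single edited `w` vector across all style layers and then mixes it with the original latents up to a per-layer cutoff. A small self-contained shape sketch (tensor sizes are hypothetical, chosen only to illustrate the broadcasting):

```python
import numpy as np
import torch

num_layers = 18          # stands in for model.mapping_f.num_layers
latent_size = 512

w = torch.randn(latent_size)                               # one edited w vector
w = w[None, None, ...].repeat(1, num_layers, 1)            # -> (1, num_layers, 512)

latents_original = torch.randn(1, num_layers, latent_size)
layer_idx = torch.arange(num_layers)[np.newaxis, :, np.newaxis]  # (1, num_layers, 1)

mixing_cutoff = (7 + 1) * 2                                # 16, as in the code above
styles = torch.where(layer_idx < mixing_cutoff, w, latents_original)
print(styles.shape)                                        # torch.Size([1, 18, 512])
```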
4 changes: 2 additions & 2 deletions make_figures/make_generation_figure.py
@@ -65,8 +65,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f

dlatent_avg = model.dlatent_avg

6 changes: 3 additions & 3 deletions make_figures/make_recon_figure_celeba_pioneer.py
@@ -73,8 +73,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -108,7 +108,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
6 changes: 3 additions & 3 deletions make_figures/make_recon_figure_ffhq_real.py
@@ -69,8 +69,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -104,7 +104,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
8 changes: 4 additions & 4 deletions make_figures/make_recon_figure_interpolation.py
@@ -70,8 +70,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -105,7 +105,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
@@ -144,7 +144,7 @@ def open_image(filename):

def make(w):
with torch.no_grad():
- w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
+ w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1)
x_rec = decode(w)
return x_rec

6 changes: 3 additions & 3 deletions make_figures/make_recon_figure_multires.py
@@ -71,8 +71,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -106,7 +106,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
6 changes: 3 additions & 3 deletions make_figures/make_recon_figure_paged.py
@@ -72,8 +72,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -107,7 +107,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
10 changes: 5 additions & 5 deletions make_figures/make_traversarls.py
@@ -48,8 +48,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -83,7 +83,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):
@@ -122,9 +122,9 @@ def do_attribute_traversal(path, attrib_idx, start, end):
def update_image(w):
with torch.no_grad():
w = w + model.dlatent_avg.buff.data[0]
- w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
+ w = w[None, None, ...].repeat(1, model.mapping_f.num_layers, 1)

- layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
+ layer_idx = torch.arange(model.mapping_f.num_layers)[np.newaxis, :, np.newaxis]
cur_layers = (7 + 1) * 2
mixing_cutoff = cur_layers
styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])
6 changes: 3 additions & 3 deletions make_figures/old/make_recon_figure_bed.py
@@ -74,8 +74,8 @@ def sample(cfg, logger):

decoder = model.decoder
encoder = model.encoder
- mapping_tl = model.mapping_tl
- mapping_fl = model.mapping_fl
+ mapping_tl = model.mapping_d
+ mapping_fl = model.mapping_f
dlatent_avg = model.dlatent_avg

logger.info("Trainable parameters generator:")
@@ -109,7 +109,7 @@ def sample(cfg, logger):

def encode(x):
Z, _ = model.encode(x, layer_count - 1, 1)
- Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
+ Z = Z.repeat(1, model.mapping_f.num_layers, 1)
return Z

def decode(x):