
Commit

Autoencoder
Stanislav Pidhorskyi committed Oct 9, 2019
1 parent b94c1e7 commit 6ab4cec
Showing 24 changed files with 2,258 additions and 502 deletions.
562 changes: 301 additions & 261 deletions VAE.py

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions checkpointer.py
@@ -0,0 +1,113 @@
# Copyright 2019 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
from torch import nn
import torch
import utils


def get_model_dict(x):
    if x is None:
        return None
    if isinstance(x, nn.DataParallel):
        return x.module.state_dict()
    else:
        return x.state_dict()


def load_model(x, state_dict):
    if isinstance(x, nn.DataParallel):
        x.module.load_state_dict(state_dict)
    else:
        x.load_state_dict(state_dict)


class Checkpointer(object):
    def __init__(self, cfg, models, auxiliary=None, logger=None, save=True):
        self.models = models
        self.auxiliary = auxiliary
        self.cfg = cfg
        self.logger = logger
        self._save = save

    def save(self, _name, **kwargs):
        if not self._save:
            return
        data = dict()
        data["models"] = dict()
        data["auxiliary"] = dict()
        for name, model in self.models.items():
            data["models"][name] = get_model_dict(model)

        if self.auxiliary is not None:
            for name, item in self.auxiliary.items():
                data["auxiliary"][name] = item.state_dict()
        data.update(kwargs)

        @utils.async_func
        def save_data():
            save_file = os.path.join(self.cfg.OUTPUT_DIR, "%s.pth" % _name)
            self.logger.info("Saving checkpoint to %s" % save_file)
            torch.save(data, save_file)
            self.tag_last_checkpoint(save_file)

        return save_data()

    def load(self, ignore_last_checkpoint=False, file_name=None):
        save_file = os.path.join(self.cfg.OUTPUT_DIR, "last_checkpoint")
        try:
            with open(save_file, "r") as last_checkpoint:
                f = last_checkpoint.read().strip()
        except IOError:
            self.logger.info("No checkpoint found. Initializing model from scratch")
            if file_name is None:
                return {}

        if ignore_last_checkpoint:
            self.logger.info("Forced to initialize model from scratch")
            return {}
        if file_name is not None:
            f = file_name

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = torch.load(f, map_location=torch.device("cpu"))
        for name, model in self.models.items():
            if name in checkpoint["models"]:
                model_dict = checkpoint["models"].pop(name)
                if model_dict is not None:
                    self.models[name].load_state_dict(model_dict, strict=False)
                else:
                    self.logger.warning("State dict for model \"%s\" is None" % name)
            else:
                self.logger.warning("No state dict for model: %s" % name)
        checkpoint.pop('models')
        if "auxiliary" in checkpoint and self.auxiliary:
            self.logger.info("Loading auxiliary from {}".format(f))
            for name, item in self.auxiliary.items():
                if name in checkpoint["auxiliary"]:
                    self.auxiliary[name].load_state_dict(checkpoint["auxiliary"].pop(name))
                if "optimizers" in checkpoint and name in checkpoint["optimizers"]:
                    self.auxiliary[name].load_state_dict(checkpoint["optimizers"].pop(name))
                if name in checkpoint:
                    self.auxiliary[name].load_state_dict(checkpoint.pop(name))
            checkpoint.pop('auxiliary')

        return checkpoint

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.cfg.OUTPUT_DIR, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)
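
Note (editorial, not part of the commit): the new Checkpointer keeps a dict of models plus optional auxiliary objects (anything with state_dict/load_state_dict, such as optimizers), saves them asynchronously via utils.async_func, and records the latest file in <OUTPUT_DIR>/last_checkpoint. A minimal, hypothetical usage sketch, assuming it runs from the repository root; the cfg object, logger, and the "encoder"/"optimizer" keys are illustrative names, not taken from this commit:

import logging
from types import SimpleNamespace

import torch
from torch import nn

from checkpointer import Checkpointer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example")

cfg = SimpleNamespace(OUTPUT_DIR=".")   # Checkpointer only reads cfg.OUTPUT_DIR
model = nn.Linear(8, 4)                 # stand-in for a real network
optimizer = torch.optim.Adam(model.parameters())

checkpointer = Checkpointer(cfg,
                            models={"encoder": model},
                            auxiliary={"optimizer": optimizer},
                            logger=logger,
                            save=True)

# Writes ./model_final.pth (asynchronously) and updates ./last_checkpoint.
checkpointer.save("model_final", epoch=10)

# On a later run: restores model and optimizer state and returns the extra
# keys that were passed to save(), e.g. {"epoch": 10}.
extra = checkpointer.load()
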
25 changes: 25 additions & 0 deletions configs/experiment.yaml
@@ -0,0 +1,25 @@
DATASET:
  PART_COUNT: 16
  SIZE: 202576
  PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 7
MODEL:
  LATENT_SPACE_SIZE: 256
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  # MAPPING_LAYERS: 5
  MAPPING_LAYERS: 8
OUTPUT_DIR: results_ae_2
TRAIN:
  BASE_LEARNING_RATE: 0.002
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 44
  # 4 8 16 32 64 128 256
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
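
Note (editorial, not part of the commit): this and the following experiment configs share the same schema; the LOD_2_BATCH_*GPU lists map a level-of-detail index to the batch size used for that GPU count, with index i corresponding to resolution 4 * 2**i per the "# 4 8 16 32 64 128 256" comment. A generic sketch of reading one of these files, assuming PyYAML (the repository's own config loader is not part of this commit):

import yaml

# Load the experiment file into nested dicts.
with open("configs/experiment.yaml") as f:
    cfg = yaml.safe_load(f)

def batch_size_for(cfg, lod, gpu_count):
    # e.g. lod=3 -> 4 * 2**3 = 32x32 images
    return cfg["TRAIN"]["LOD_2_BATCH_%dGPU" % gpu_count][lod]

print(cfg["MODEL"]["LATENT_SPACE_SIZE"])        # 256
print(batch_size_for(cfg, lod=3, gpu_count=4))  # 64
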
25 changes: 25 additions & 0 deletions configs/experiment_celeba.yaml
@@ -0,0 +1,25 @@
DATASET:
  PART_COUNT: 16
  SIZE: 202576
  PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 7
MODEL:
  # LATENT_SPACE_SIZE: 256
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  # MAPPING_LAYERS: 5
  MAPPING_LAYERS: 8
OUTPUT_DIR: /data/celeba/results
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 44
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]
25 changes: 25 additions & 0 deletions configs/experiment_celeba_tiny.yaml
@@ -0,0 +1,25 @@
DATASET:
  PART_COUNT: 16
  SIZE: 202576
  PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 7
MODEL:
  # LATENT_SPACE_SIZE: 256
  LATENT_SPACE_SIZE: 128
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 128
  START_CHANNEL_COUNT: 16
  DLATENT_AVG_BETA: 0.995
  # MAPPING_LAYERS: 5
  MAPPING_LAYERS: 8
OUTPUT_DIR: /data/celeba_tiny/results
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 44
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]
53 changes: 53 additions & 0 deletions configs/experiment_ffhq.yaml
@@ -0,0 +1,53 @@
# Running locally

DATASET:
  PART_COUNT: 16
  SIZE: 70000
  FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords
  PATH: /data/datasets/ffhq-dataset/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 10
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 9
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 16
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 8
OUTPUT_DIR: /data/ffhq_3/results
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 16
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 112
  # 4 8 16 32 64 128 256
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32]
  LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 32, 16]
  LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]

# Running on server
#DATASET:
#  PART_COUNT: 16
#  SIZE: 70000
#  FFHQ_SOURCE: /data/datasets/ffhq-dataset/tfrecords/ffhq/ffhq-r%02d.tfrecords
#  PATH: /data/datasets/ffhq-dataset/tfrecords/ffhq/splitted/ffhq-r%02d.tfrecords.%03d
#  MAX_RESOLUTION_LEVEL: 10
#MODEL:
#  LATENT_SPACE_SIZE: 256
#  LAYER_COUNT: 7
#  MAX_CHANNEL_COUNT: 512
#  START_CHANNEL_COUNT: 32
#  DLATENT_AVG_BETA: 0.995
#  MAPPING_LAYERS: 5
#OUTPUT_DIR: /data/ffhq/results
#TRAIN:
#  BASE_LEARNING_RATE: 0.0015
#  EPOCHS_PER_LOD: 16
#  LEARNING_DECAY_RATE: 0.1
#  LEARNING_DECAY_STEPS: []
#  TRAIN_EPOCHS: 112
#  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32]
#  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 32, 32]
#  LOD_2_BATCH_2GPU: [512, 256, 128, 64, 32, 32, 16]
#  LOD_2_BATCH_1GPU: [512, 256, 128, 64, 32, 16]
21 changes: 21 additions & 0 deletions configs/experiment_large.yaml
@@ -0,0 +1,21 @@
DATASET:
  PART_COUNT: 8
  PATH: /efs/user/pidhorss/celeba/data_fold_%d_lod_%d.pkl
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 128
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 8
OUTPUT_DIR: results
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 36
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]
24 changes: 24 additions & 0 deletions configs/experiment_mnist.yaml
@@ -0,0 +1,24 @@
DATASET:
  PART_COUNT: 1
  SIZE: 60000
  MAX_RESOLUTION_LEVEL: 5
MODEL:
  # LATENT_SPACE_SIZE: 256
  LATENT_SPACE_SIZE: 16
  LAYER_COUNT: 4
  MAX_CHANNEL_COUNT: 256
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  # MAPPING_LAYERS: 5
  MAPPING_LAYERS: 4
OUTPUT_DIR: mnist_results2
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 10
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 40
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [256, 256, 128, 64, 32, 16]
21 changes: 21 additions & 0 deletions configs/experiment_small.yaml
@@ -0,0 +1,21 @@
DATASET:
  PART_COUNT: 8
  PATH: /efs/user/pidhorss/celeba/data_fold_%d_lod_%d.pkl
MODEL:
  LATENT_SPACE_SIZE: 128
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 16
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 5
OUTPUT_DIR: results
TRAIN:
  BASE_LEARNING_RATE: 0.001
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 40
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]
22 changes: 22 additions & 0 deletions configs/experiment_stylegan.yaml
@@ -0,0 +1,22 @@
DATASET:
  PART_COUNT: 8
  PATH: /efs/user/pidhorss/celeba/data_fold_%d_lod_%d.pkl
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 9
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 16
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 8
OUTPUT_DIR: results
TRAIN:
  BASE_LEARNING_RATE: 0.0015
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 40
  # 4 8 16 32 64 128
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32]
  LOD_2_BATCH_4GPU: [512, 256, 128, 64, 32, 16]
  LOD_2_BATCH_2GPU: [256, 256, 128, 64, 32, 16]
  LOD_2_BATCH_1GPU: [128, 128, 128, 64, 32, 16]