Cleanup model_z_gan.py
Added GenModel
Stanislav Pidhorskyi committed Apr 15, 2020
1 parent c67d491 commit dd902c3
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions model_z_gan.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Stanislav Pidhorskyi
+# Copyright 2019-2020 Stanislav Pidhorskyi
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,36 +28,33 @@ def __init__(self, dlatent_size, layer_count):

 class Model(nn.Module):
     def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None,
-                 truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="", encoder="", z_regression=False):
+                 truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="",
+                 encoder="", z_regression=False):
         super(Model, self).__init__()
 
         self.layer_count = layer_count
         self.z_regression = z_regression
 
-        #self.mapping_tl = VAEMappingToLatent_old(
         self.mapping_tl = MAPPINGS["MappingToLatent"](
             latent_size=latent_size,
             dlatent_size=latent_size,
             mapping_fmaps=latent_size,
             mapping_layers=3)
 
-        #self.mapping_fl = VAEMappingFromLatent(
         self.mapping_fl = MAPPINGS["MappingFromLatent"](
             num_layers=2 * layer_count,
             latent_size=latent_size,
             dlatent_size=latent_size,
             mapping_fmaps=latent_size,
             mapping_layers=mapping_layers)
 
-        #self.decoder = Generator(
         self.decoder = GENERATORS[generator](
             startf=startf,
             layer_count=layer_count,
             maxf=maxf,
             latent_size=latent_size,
             channels=channels)
 
-        #self.encoder = Encoder_old(
         self.encoder = ENCODERS[encoder](
             startf=startf,
             layer_count=layer_count,
@@ -165,3 +162,49 @@ def lerp(self, other, betta):
         other_param = list(other.mapping_tl.parameters()) + list(other.mapping_fl.parameters()) + list(other.decoder.parameters()) + list(other.encoder.parameters()) + list(other.dlatent_avg.parameters())
         for p, p_other in zip(params, other_param):
             p.data.lerp_(p_other.data, 1.0 - betta)
+
+
+class GenModel(nn.Module):
+    def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None,
+                 truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3, generator="", encoder="", z_regression=False):
+        super(GenModel, self).__init__()
+
+        self.layer_count = layer_count
+
+        self.mapping_fl = MAPPINGS["MappingFromLatent"](
+            num_layers=2 * layer_count,
+            latent_size=latent_size,
+            dlatent_size=latent_size,
+            mapping_fmaps=latent_size,
+            mapping_layers=mapping_layers)
+
+        self.decoder = GENERATORS[generator](
+            startf=startf,
+            layer_count=layer_count,
+            maxf=maxf,
+            latent_size=latent_size,
+            channels=channels)
+
+        self.dlatent_avg = DLatent(latent_size, self.mapping_fl.num_layers)
+        self.latent_size = latent_size
+        self.dlatent_avg_beta = dlatent_avg_beta
+        self.truncation_psi = truncation_psi
+        self.style_mixing_prob = style_mixing_prob
+        self.truncation_cutoff = truncation_cutoff
+
+    def generate(self, lod, blend_factor, z=None):
+        styles = self.mapping_fl(z)[:, 0]
+        s = styles.view(styles.shape[0], 1, styles.shape[1])
+
+        styles = s.repeat(1, self.mapping_fl.num_layers, 1)
+
+        layer_idx = torch.arange(self.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
+        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
+        coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)
+        styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs)
+
+        rec = self.decoder.forward(styles, lod, blend_factor, True)
+        return rec
+
+    def forward(self, x):
+        return self.generate(self.layer_count-1, 1.0, z=x)

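Note on the added class: GenModel is a generation-only wrapper around the mapping network and decoder, with no encoder or mapping_tl. Its generate() applies StyleGAN-style truncation: each per-layer style is lerped toward the running-average dlatent (dlatent_avg) with coefficient truncation_psi for layers below truncation_cutoff, then the decoder synthesizes at the requested level of detail. Below is a minimal usage sketch, not part of the commit; the registry key "GeneratorDefault" and all hyperparameter values are assumptions, so check the repository's config files for the settings actually used.

# Hypothetical usage sketch -- not from this commit. The registry key
# "GeneratorDefault" and every hyperparameter value below are assumptions;
# consult the repository's configs for real values.
import torch

model = GenModel(startf=32, maxf=256, layer_count=6, latent_size=256,
                 mapping_layers=8, truncation_psi=0.7, truncation_cutoff=8,
                 channels=3, generator="GeneratorDefault")
model.eval()

with torch.no_grad():
    z = torch.randn(4, 256)   # batch of 4 latent vectors
    images = model(z)         # forward() = generate(layer_count - 1, 1.0, z),
                              # i.e. full resolution, fully blended

Because forward() takes only the latent z and fixes lod and blend_factor, the wrapper is convenient for inference-only export of a trained generator.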