Skip to content

Commit

Permalink
Demo with sliders
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanislav Pidhorskyi committed Apr 18, 2020
1 parent 697b3e3 commit 3afbac2
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 497 deletions.
52 changes: 0 additions & 52 deletions find_principal_directions.py

This file was deleted.

15 changes: 0 additions & 15 deletions gradient_reversal.py

This file was deleted.

161 changes: 62 additions & 99 deletions interactive_sliders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Stanislav Pidhorskyi
# Copyright 2019-2020 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,41 +13,14 @@
# limitations under the License.
# ==============================================================================

from __future__ import print_function
import torch.utils.data
from scipy import misc
from torch import optim
from torchvision.utils import save_image
from net import *
import numpy as np
import pickle
import time
import random
import os
from model import Model
from net import *
from checkpointer import Checkpointer
from scheduler import ComboMultiStepLR
from model_z_gan import Model
from launcher import run
from defaults import get_cfg_defaults
import lod_driver


from checkpointer import Checkpointer
from scheduler import ComboMultiStepLR

from dlutils import batch_provider
from dlutils.pytorch.cuda_helper import *
from dlutils.pytorch import count_parameters
from defaults import get_cfg_defaults
import argparse
import logging
import sys
import bimpy
import lreq
from skimage.transform import resize
import utils

from PIL import Image
import bimpy
Expand All @@ -56,6 +29,20 @@
lreq.use_implicit_lreq.set(True)


indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]

labels = ["gender",
"smile",
"attractive",
"wavy-hair",
"young",
"big lips",
"big nose",
"chubby",
"glasses",
]


def sample(cfg, logger):
torch.cuda.set_device(0)
model = Model(
Expand Down Expand Up @@ -114,44 +101,24 @@ def encode(x):
return Z

def decode(x):
layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
layer_idx = torch.arange(2 * layer_count)[np.newaxis, :, np.newaxis]
ones = torch.ones(layer_idx.shape, dtype=torch.float32)
coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
# x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
return model.decoder(x, layer_count - 1, 1, noise=True)

path = 'realign1024_2_v'
#path = 'imagenet256x256'
# path = 'realign128x128'
path = 'dataset_samples/faces/realign1024x1024'

paths = list(os.listdir(path))

paths = ['00096.png', '00002.png', '00106.png', '00103.png', '00013.png', '00037.png']#sorted(paths)
#random.seed(3456)
#random.shuffle(paths)
paths.sort()
paths = ['00096.png', '00002.png', '00106.png', '00103.png', '00013.png', '00037.png']
randomize = bimpy.Bool(True)

ctx = bimpy.Context()
v0 = bimpy.Float(0)
v1 = bimpy.Float(0)
v2 = bimpy.Float(0)
v3 = bimpy.Float(0)
v4 = bimpy.Float(0)
v10 = bimpy.Float(0)
v11 = bimpy.Float(0)
v17 = bimpy.Float(0)
v19 = bimpy.Float(0)

w0 = torch.tensor(np.load("direction_%d.npy" % 0), dtype=torch.float32)
w1 = torch.tensor(np.load("direction_%d.npy" % 1), dtype=torch.float32)
w2 = torch.tensor(np.load("direction_%d.npy" % 2), dtype=torch.float32)
w3 = torch.tensor(np.load("direction_%d.npy" % 3), dtype=torch.float32)
w4 = torch.tensor(np.load("direction_%d.npy" % 4), dtype=torch.float32)
w10 = torch.tensor(np.load("direction_%d.npy" % 10), dtype=torch.float32)
w11 = torch.tensor(np.load("direction_%d.npy" % 11), dtype=torch.float32)
w17 = torch.tensor(np.load("direction_%d.npy" % 17), dtype=torch.float32)
w19 = torch.tensor(np.load("direction_%d.npy" % 19), dtype=torch.float32)

_latents = None

attribute_values = [bimpy.Float(0) for i in indices]

W = [torch.tensor(np.load("principal_directions/direction_%d.npy" % i), dtype=torch.float32) for i in indices]

def loadNext():
img = np.asarray(Image.open(path + '/' + paths[0]))
Expand All @@ -163,77 +130,75 @@ def loadNext():
x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
if x.shape[0] == 4:
x = x[:3]
_latents = encode(x[None, ...].cuda())
latents = _latents[0, 0]

needed_resolution = model.decoder.layer_to_resolution[-1]
while x.shape[2] > needed_resolution:
x = F.avg_pool2d(x, 2, 2)
if x.shape[2] != needed_resolution:
x = F.adaptive_avg_pool2d(x, (needed_resolution, needed_resolution))

img_src = ((x * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255).cpu().type(torch.uint8).transpose(0, 2).transpose(0, 1).numpy()

latents_original = encode(x[None, ...].cuda())
latents = latents_original[0, 0]
latents -= model.dlatent_avg.buff.data[0]

v0.value = (latents * w0).sum()
v1.value = (latents * w1).sum()
v2.value = (latents * w2).sum()
v3.value = (latents * w3).sum()
v4.value = (latents * w4).sum()
v10.value = (latents * w10).sum()
v11.value = (latents * w11).sum()
v17.value = (latents * w17).sum()
v19.value = (latents * w19).sum()

latents = latents - v0.value * w0
latents = latents - v1.value * w1
latents = latents - v2.value * w2
latents = latents - v3.value * w3
latents = latents - v10.value * w10
latents = latents - v11.value * w11
latents = latents - v17.value * w17
latents = latents - v19.value * w19
return latents, _latents, img_src

latents, _latents, img_src = loadNext()
for v, w in zip(attribute_values, W):
v.value = (latents * w).sum()

for v, w in zip(attribute_values, W):
latents = latents - v.value * w

return latents, latents_original, img_src

latents, latents_original, img_src = loadNext()

ctx.init(1800, 1600, "Styles")

def update_image(w, _w):
def update_image(w):
with torch.no_grad():
w = w + model.dlatent_avg.buff.data[0]
w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)

layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
cur_layers = (7 + 1) * 2
mixing_cutoff = cur_layers
styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])
styles = torch.where(layer_idx < mixing_cutoff, w, latents_original[0])

x_rec = decode(styles)
resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)
resultsample = resultsample.cpu()[0, :, :, :]
return resultsample.type(torch.uint8).transpose(0, 2).transpose(0, 1)

im = update_image(latents, _latents)
im = update_image(latents)
print(im.shape)
im = bimpy.Image(im)

display_original = True

while(not ctx.should_close()):
seed = 0

while not ctx.should_close():
with ctx:
W = latents + w0 * v0.value + w1 * v1.value + w2 * v2.value + w3 * v3.value + w4 * v4.value + w10 * v10.value + w11 * v11.value + w17 * v17.value + w19 * v19.value
new_latents = sum([v.value * w for v, w in zip(attribute_values, W)])

if display_original:
im = bimpy.Image(img_src)
else:
im = bimpy.Image(update_image(W, _latents))
im = bimpy.Image(update_image(new_latents))

# if bimpy.button('Ok'):
bimpy.image(im)
bimpy.begin("Controls")
bimpy.slider_float("female <-> male", v0, -30.0, 30.0)
bimpy.slider_float("smile", v1, -30.0, 30.0)
bimpy.slider_float("attractive", v2, -30.0, 30.0)
bimpy.slider_float("wavy-hair", v3, -30.0, 30.0)
bimpy.slider_float("young", v4, -30.0, 30.0)
bimpy.slider_float("big lips", v10, -30.0, 30.0)
bimpy.slider_float("big nose", v11, -30.0, 30.0)
bimpy.slider_float("chubby", v17, -30.0, 30.0)
bimpy.slider_float("glasses", v19, -30.0, 30.0)

for v, label in zip(attribute_values, labels):
bimpy.slider_float(label, v, -40.0, 40.0)

bimpy.checkbox("Randomize noise", randomize)

if randomize.value:
seed += 1

torch.manual_seed(seed)

if bimpy.button('Next'):
latents, _latents, img_src = loadNext()
Expand All @@ -242,10 +207,8 @@ def update_image(w, _w):
display_original = False
bimpy.end()

exit()


if __name__ == "__main__":
gpu_count = 1
run(sample, get_cfg_defaults(), description='StyleGAN', default_config='configs/experiment_ffhq_z.yaml',
run(sample, get_cfg_defaults(), description='ALAE-interactive', default_config='configs/ffhq.yaml',
world_size=gpu_count, write_log=False)
Binary file removed pioneer/9783_orig.jpg
Binary file not shown.
Loading

0 comments on commit 3afbac2

Please sign in to comment.