diff --git a/find_principal_directions.py b/find_principal_directions.py
deleted file mode 100644
index 6972b354..00000000
--- a/find_principal_directions.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import numpy as np
-import sklearn.svm
-
-
-method = {
-    0:False,
-    1:True,
-    2:True,
-    3:True,
-    4:True,
-    10:True,
-    11:True,
-    17:True,
-    19:True,
-}
-
-for attrib_idx in range(19, 20):
-    try:
-        results = np.load("wspace_att_%d.npy" % attrib_idx).item()
-
-        pruned_indices = list(range(results['latents'].shape[0]))
-        pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
-        keep = int(results['latents'].shape[0] * 0.5)
-        print('Keeping: %d' % keep)
-        pruned_indices = pruned_indices[:keep]
-
-        # Fit SVM to the remaining samples.
-        svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
-        space = 'dlatents'
-
-        svm_inputs = results[space][pruned_indices]
-
-        if method[attrib_idx]:
-            svm = sklearn.svm.LinearSVC(C=0.1)
-            svm.fit(svm_inputs, svm_targets)
-            svm.score(svm_inputs, svm_targets)
-            svm_outputs = svm.predict(svm_inputs)
-
-            w = svm.coef_[0]
-
-            np.save("direction_%d" % attrib_idx, w)
-
-        else:
-            c1 = (svm_inputs * svm_targets[..., None]).mean(axis=0)
-            c2 = (svm_inputs * (1.0 - svm_targets[..., None])).mean(axis=0)
-
-            w = c1 - c2
-            w = w / np.sqrt((w*w).sum())
-
-            np.save("direction_%d" % attrib_idx, w)
-    except:
-        pass
diff --git a/gradient_reversal.py b/gradient_reversal.py
deleted file mode 100644
index 8be39391..00000000
--- a/gradient_reversal.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from torch.autograd import Function
-
-
-class GradReverse(Function):
-    @staticmethod
-    def forward(ctx, x):
-        return x.clone()
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output.neg()
-
-
-def grad_reverse(x):
-    return GradReverse.apply(x)
diff --git a/interactive_sliders.py b/interactive_sliders.py
index 3dc4eccf..34a99647 100644
--- a/interactive_sliders.py
+++ b/interactive_sliders.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Stanislav Pidhorskyi
+# Copyright 2019-2020 Stanislav Pidhorskyi
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,41 +13,14 @@
 # limitations under the License.
 # ==============================================================================
-from __future__ import print_function
 import torch.utils.data
-from scipy import misc
-from torch import optim
-from torchvision.utils import save_image
 from net import *
-import numpy as np
-import pickle
-import time
-import random
-import os
 from model import Model
-from net import *
-from checkpointer import Checkpointer
-from scheduler import ComboMultiStepLR
-from model_z_gan import Model
 from launcher import run
-from defaults import get_cfg_defaults
-import lod_driver
-
-
 from checkpointer import Checkpointer
-from scheduler import ComboMultiStepLR
-
-from dlutils import batch_provider
-from dlutils.pytorch.cuda_helper import *
 from dlutils.pytorch import count_parameters
 from defaults import get_cfg_defaults
-import argparse
-import logging
-import sys
-import bimpy
 import lreq
-from skimage.transform import resize
-import utils
 from PIL import Image
 import bimpy
@@ -56,6 +29,20 @@
 lreq.use_implicit_lreq.set(True)
 
 
+indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
+
+labels = ["gender",
+          "smile",
+          "attractive",
+          "wavy-hair",
+          "young",
+          "big lips",
+          "big nose",
+          "chubby",
+          "glasses",
+          ]
+
+
 def sample(cfg, logger):
     torch.cuda.set_device(0)
     model = Model(
@@ -114,44 +101,24 @@ def encode(x):
         return Z
 
     def decode(x):
-        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
+        layer_idx = torch.arange(2 * layer_count)[np.newaxis, :, np.newaxis]
         ones = torch.ones(layer_idx.shape, dtype=torch.float32)
         coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
         # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
         return model.decoder(x, layer_count - 1, 1, noise=True)
 
-    path = 'realign1024_2_v'
-    #path = 'imagenet256x256'
-    # path = 'realign128x128'
+    path = 'dataset_samples/faces/realign1024x1024'
 
     paths = list(os.listdir(path))
-
-    paths = ['00096.png', '00002.png', '00106.png', '00103.png', '00013.png', '00037.png']#sorted(paths)
-    #random.seed(3456)
-    #random.shuffle(paths)
+    paths.sort()
+    paths = ['00096.png', '00002.png', '00106.png', '00103.png', '00013.png', '00037.png']
 
+    randomize = bimpy.Bool(True)
     ctx = bimpy.Context()
-    v0 = bimpy.Float(0)
-    v1 = bimpy.Float(0)
-    v2 = bimpy.Float(0)
-    v3 = bimpy.Float(0)
-    v4 = bimpy.Float(0)
-    v10 = bimpy.Float(0)
-    v11 = bimpy.Float(0)
-    v17 = bimpy.Float(0)
-    v19 = bimpy.Float(0)
-
-    w0 = torch.tensor(np.load("direction_%d.npy" % 0), dtype=torch.float32)
-    w1 = torch.tensor(np.load("direction_%d.npy" % 1), dtype=torch.float32)
-    w2 = torch.tensor(np.load("direction_%d.npy" % 2), dtype=torch.float32)
-    w3 = torch.tensor(np.load("direction_%d.npy" % 3), dtype=torch.float32)
-    w4 = torch.tensor(np.load("direction_%d.npy" % 4), dtype=torch.float32)
-    w10 = torch.tensor(np.load("direction_%d.npy" % 10), dtype=torch.float32)
-    w11 = torch.tensor(np.load("direction_%d.npy" % 11), dtype=torch.float32)
-    w17 = torch.tensor(np.load("direction_%d.npy" % 17), dtype=torch.float32)
-    w19 = torch.tensor(np.load("direction_%d.npy" % 19), dtype=torch.float32)
-
-    _latents = None
+
+    attribute_values = [bimpy.Float(0) for i in indices]
+
+    W = [torch.tensor(np.load("principal_directions/direction_%d.npy" % i), dtype=torch.float32) for i in indices]
 
     def loadNext():
         img = np.asarray(Image.open(path + '/' + paths[0]))
@@ -163,36 +130,32 @@ def loadNext():
         x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
         if x.shape[0] == 4:
             x = x[:3]
-        _latents = encode(x[None, ...].cuda())
-        latents = _latents[0, 0]
+        needed_resolution = model.decoder.layer_to_resolution[-1]
+        while x.shape[2] > needed_resolution:
+            x = F.avg_pool2d(x, 2, 2)
+        if x.shape[2] != needed_resolution:
+            x = F.adaptive_avg_pool2d(x, (needed_resolution, needed_resolution))
+
+        img_src = ((x * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255).cpu().type(torch.uint8).transpose(0, 2).transpose(0, 1).numpy()
+
+        latents_original = encode(x[None, ...].cuda())
+        latents = latents_original[0, 0]
         latents -= model.dlatent_avg.buff.data[0]
-        v0.value = (latents * w0).sum()
-        v1.value = (latents * w1).sum()
-        v2.value = (latents * w2).sum()
-        v3.value = (latents * w3).sum()
-        v4.value = (latents * w4).sum()
-        v10.value = (latents * w10).sum()
-        v11.value = (latents * w11).sum()
-        v17.value = (latents * w17).sum()
-        v19.value = (latents * w19).sum()
-
-        latents = latents - v0.value * w0
-        latents = latents - v1.value * w1
-        latents = latents - v2.value * w2
-        latents = latents - v3.value * w3
-        latents = latents - v10.value * w10
-        latents = latents - v11.value * w11
-        latents = latents - v17.value * w17
-        latents = latents - v19.value * w19
-        return latents, _latents, img_src
-
-    latents, _latents, img_src = loadNext()
+
+        for v, w in zip(attribute_values, W):
+            v.value = (latents * w).sum()
+
+        for v, w in zip(attribute_values, W):
+            latents = latents - v.value * w
+
+        return latents, latents_original, img_src
+
+    latents, latents_original, img_src = loadNext()
 
     ctx.init(1800, 1600, "Styles")
 
-    def update_image(w, _w):
+    def update_image(w):
         with torch.no_grad():
             w = w + model.dlatent_avg.buff.data[0]
             w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
@@ -200,40 +163,42 @@ def update_image(w, _w):
             layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
             cur_layers = (7 + 1) * 2
             mixing_cutoff = cur_layers
-            styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])
+            styles = torch.where(layer_idx < mixing_cutoff, w, latents_original[0])
 
             x_rec = decode(styles)
             resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)
             resultsample = resultsample.cpu()[0, :, :, :]
             return resultsample.type(torch.uint8).transpose(0, 2).transpose(0, 1)
 
-    im = update_image(latents, _latents)
+    im = update_image(latents)
     print(im.shape)
     im = bimpy.Image(im)
 
     display_original = True
 
-    while(not ctx.should_close()):
+    seed = 0
+
+    while not ctx.should_close():
        with ctx:
-            W = latents + w0 * v0.value + w1 * v1.value + w2 * v2.value + w3 * v3.value + w4 * v4.value + w10 * v10.value + w11 * v11.value + w17 * v17.value + w19 * v19.value
+            new_latents = sum([v.value * w for v, w in zip(attribute_values, W)])
 
            if display_original:
                im = bimpy.Image(img_src)
            else:
-                im = bimpy.Image(update_image(W, _latents))
+                im = bimpy.Image(update_image(new_latents))
 
-            # if bimpy.button('Ok'):
            bimpy.image(im)
 
            bimpy.begin("Controls")
-            bimpy.slider_float("female <-> male", v0, -30.0, 30.0)
-            bimpy.slider_float("smile", v1, -30.0, 30.0)
-            bimpy.slider_float("attractive", v2, -30.0, 30.0)
-            bimpy.slider_float("wavy-hair", v3, -30.0, 30.0)
-            bimpy.slider_float("young", v4, -30.0, 30.0)
-            bimpy.slider_float("big lips", v10, -30.0, 30.0)
-            bimpy.slider_float("big nose", v11, -30.0, 30.0)
-            bimpy.slider_float("chubby", v17, -30.0, 30.0)
-            bimpy.slider_float("glasses", v19, -30.0, 30.0)
+
+            for v, label in zip(attribute_values, labels):
+                bimpy.slider_float(label, v, -40.0, 40.0)
+
+            bimpy.checkbox("Randomize noise", randomize)
+
+            if randomize.value:
+                seed += 1
+
+            torch.manual_seed(seed)
 
            if bimpy.button('Next'):
                latents, _latents, img_src = loadNext()
@@ -242,10 +207,8 @@ def update_image(w, _w):
                display_original = False
 
            bimpy.end()
 
-    exit()
-
 
 if __name__ == "__main__":
     gpu_count = 1
-    run(sample, get_cfg_defaults(), description='StyleGAN', default_config='configs/experiment_ffhq_z.yaml',
+    run(sample, get_cfg_defaults(), description='ALAE-interactive', default_config='configs/ffhq.yaml',
         world_size=gpu_count, write_log=False)
diff --git a/pioneer/9783_orig.jpg b/pioneer/9783_orig.jpg
deleted file mode 100644
index c713a6e8..00000000
Binary files a/pioneer/9783_orig.jpg and /dev/null differ
diff --git a/principal_direction_miner.py b/principal_direction_miner.py
deleted file mode 100644
index a8d75124..00000000
--- a/principal_direction_miner.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
-#
-# This work is licensed under the Creative Commons Attribution-NonCommercial
-# 4.0 International License. To view a copy of this license, visit
-# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
-# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
-
-"""Linear Separability (LS)."""
-import tensorflow as tf
-import torch
-import dnnlib
-import dnnlib.tflib
-import dnnlib.tflib as tflib
-import pickle
-from dataloader import *
-import scipy.linalg
-
-from checkpointer import Checkpointer
-from scheduler import ComboMultiStepLR
-
-from dlutils import batch_provider
-from dlutils.pytorch.cuda_helper import *
-from dlutils.pytorch import count_parameters
-from defaults import get_cfg_defaults
-from model import Model
-import argparse
-import logging
-import sys
-import lreq
-from skimage.transform import resize
-from tqdm import tqdm
-
-from launcher import run
-from PIL import Image
-from matplotlib import pyplot as plt
-import utils
-from net import *
-
-from collections import defaultdict
-import numpy as np
-import sklearn.svm
-import tensorflow as tf
-import dnnlib.tflib as tflib
-
-from metrics import metric_base
-from training import misc
-
-dnnlib.tflib.init_tf()
-tf_config = {'rnd.np_random_seed': 1000}
-
-
-
-#----------------------------------------------------------------------------
-
-classifier_urls = [
-    'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl
-    'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl
-    'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl
-    'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl
-    'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl
-    'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl
-    'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl
-    'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl
-    'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl
-    'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl
-    'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl
-    'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl
-    'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl
-    'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl
-    'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl
-    'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl
-    'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl
-    'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl
-    'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl
-    'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl
-    'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl
-    'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl
-    'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl
-    'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl
-    'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl
-    'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl
-    'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl
-    'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl
-    'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl
-    'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl
-    'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl
-    'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl
-    'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl
-    'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl
-    'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl
-    'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl
-    'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl
-    'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl
-    'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl
-    'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl
-]
-
-#----------------------------------------------------------------------------
-
-def prob_normalize(p):
-    p = np.asarray(p).astype(np.float32)
-    assert len(p.shape) == 2
-    return p / np.sum(p)
-
-def mutual_information(p):
-    p = prob_normalize(p)
-    px = np.sum(p, axis=1)
-    py = np.sum(p, axis=0)
-    result = 0.0
-    for x in range(p.shape[0]):
-        p_x = px[x]
-        for y in range(p.shape[1]):
-            p_xy = p[x][y]
-            p_y = py[y]
-            if p_xy > 0.0:
-                result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
-    return result
-
-def entropy(p):
-    p = prob_normalize(p)
-    result = 0.0
-    for x in range(p.shape[0]):
-        for y in range(p.shape[1]):
-            p_xy = p[x][y]
-            if p_xy > 0.0:
-                result -= p_xy * np.log2(p_xy)
-    return result
-
-def conditional_entropy(p):
-    # H(Y|X) where X corresponds to axis 0, Y to axis 1
-    # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0?
-    p = prob_normalize(p)
-    y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
-    return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up.
-
-#----------------------------------------------------------------------------
-
-def parse_tfrecord_np(record):
-    ex = tf.train.Example()
-    ex.ParseFromString(record)
-    shape = ex.features.feature['shape'].int64_list.value
-    data = ex.features.feature['data'].bytes_list.value[0]
-    dlat = ex.features.feature['dlat'].bytes_list.value[0]
-    lat = ex.features.feature['lat'].bytes_list.value[0]
-    return np.fromstring(data, np.uint8).reshape(shape), np.fromstring(dlat, np.float32), np.fromstring(lat, np.float32)
-
-
-class LS:
-    def __init__(self, cfg, percent_keep, attrib_indices, minibatch_gpu):
-        self.percent_keep = percent_keep
-        self.attrib_indices = attrib_indices
-        self.minibatch_size = minibatch_gpu
-        self.cfg = cfg
-
-    def evaluate(self, logger, mapping, decoder, lod, attrib_idx):
-        result_expr = []
-
-        rnd = np.random.RandomState(5)
-
-        with tf.Graph().as_default(), tf.Session() as sess:
-            ds = tf.data.TFRecordDataset("generated_data.000")
-            ds = ds.batch(self.minibatch_size)
-            batch = ds.make_one_shot_iterator().get_next()
-
-            classifier = misc.load_pkl(classifier_urls[attrib_idx])
-
-            i = 0
-            while True:
-                try:
-                    records = sess.run(batch)
-                    images = []
-                    dlats = []
-                    lats = []
-                    for r in records:
-                        im, dlat, lat = parse_tfrecord_np(r)
-
-                        # plt.imshow(im.transpose(1, 2, 0), interpolation='nearest')
-                        # plt.show()
-
-                        images.append(im)
-                        dlats.append(dlat)
-                        lats.append(lat)
-                    images = np.stack(images)
-                    dlats = np.stack(dlats)
-                    lats = np.stack(lats)
-                    logits = classifier.run(images, None, num_gpus=2, assume_frozen=True)
-                    logits = torch.tensor(logits)
-                    predictions = torch.softmax(torch.cat([logits, -logits], dim=1), dim=1)
-
-                    result_dict = dict(latents=lats, dlatents=dlats)
-                    result_dict[attrib_idx] = predictions.cpu().numpy()
-                    result_expr.append(result_dict)
-                    i += 1
-                except tf.errors.OutOfRangeError:
-                    break
-
-        results = {key: np.concatenate([value[key] for value in result_expr], axis=0) for key in result_expr[0].keys()}
-
-        np.save("wspace_att_%d" % attrib_idx, results)
-
-        exit()
-
-        conditional_entropies = defaultdict(list)
-        idx = attrib_idx
-        for attrib_idx in [idx]: # self.attrib_indices:
-            pruned_indices = list(range(results['latents'].shape[0]))
-            pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
-            keep = int(results['latents'].shape[0] * self.percent_keep)
-            print('Keeping: %d' % keep)
-            pruned_indices = pruned_indices[:keep]
-
-            # Fit SVM to the remaining samples.
-            svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
-            for space in ['latents', 'dlatents']:
-                svm_inputs = results[space][pruned_indices]
-                try:
-                    svm = sklearn.svm.LinearSVC()
-                    svm.fit(svm_inputs, svm_targets)
-                    svm.score(svm_inputs, svm_targets)
-                    svm_outputs = svm.predict(svm_inputs)
-                except:
-                    svm_outputs = svm_targets # assume perfect prediction
-
-                # Calculate conditional entropy.
-                p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in
-                     (0, 1)]
-                conditional_entropies[space].append(conditional_entropy(p))
-
-        scores = {key: 2 ** np.sum(values) for key, values in conditional_entropies.items()}
-
-        logger.info("latents Z Result = %f" % (scores['latents']))
-        logger.info("dlatents W Result = %f" % (scores['dlatents']))
-
-
-        # # Calculate conditional entropy for each attribute.
-        # conditional_entropies = defaultdict(list)
-        # for attrib_idx in self.attrib_indices:
-        #     # Prune the least confident samples.
-        #     pruned_indices = list(range(self.num_samples))
-        #     pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
-        #     pruned_indices = pruned_indices[:self.num_keep]
-        #
-        #     # Fit SVM to the remaining samples.
-        #     svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
-        #     for space in ['latents', 'dlatents']:
-        #         svm_inputs = results[space][pruned_indices]
-        #         try:
-        #             svm = sklearn.svm.LinearSVC()
-        #             svm.fit(svm_inputs, svm_targets)
-        #             svm.score(svm_inputs, svm_targets)
-        #             svm_outputs = svm.predict(svm_inputs)
-        #         except:
-        #             svm_outputs = svm_targets # assume perfect prediction
-        #
-        #         # Calculate conditional entropy.
-        #         p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
-        #         conditional_entropies[space].append(conditional_entropy(p))
-        #
-        # # Calculate separability scores.
-        # scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
-        #
-        # logger.info("latents Z Result = %f" % (scores['latents']))
-        # logger.info("dlatents W Result = %f" % (scores['dlatents']))
-
-
-def sample(cfg, logger):
-    torch.cuda.set_device(0)
-    model = Model(
-        startf=cfg.MODEL.START_CHANNEL_COUNT,
-        layer_count=cfg.MODEL.LAYER_COUNT,
-        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
-        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
-        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
-        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
-        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
-        channels=cfg.MODEL.CHANNELS,
-        generator=cfg.MODEL.GENERATOR,
-        encoder=cfg.MODEL.ENCODER)
-
-    model.cuda(0)
-    model.eval()
-    model.requires_grad_(False)
-
-    decoder = model.decoder
-    encoder = model.encoder
-    mapping_tl = model.mapping_tl
-    mapping_fl = model.mapping_fl
-    dlatent_avg = model.dlatent_avg
-
-    logger.info("Trainable parameters generator:")
-    count_parameters(decoder)
-
-    logger.info("Trainable parameters discriminator:")
-    count_parameters(encoder)
-
-    arguments = dict()
-    arguments["iteration"] = 0
-
-    model_dict = {
-        'discriminator_s': encoder,
-        'generator_s': decoder,
-        'mapping_tl_s': mapping_tl,
-        'mapping_fl_s': mapping_fl,
-        'dlatent_avg': dlatent_avg
-    }
-
-    checkpointer = Checkpointer(cfg,
-                                model_dict,
-                                {},
-                                logger=logger,
-                                save=False)
-
-    checkpointer.load()
-
-    model.eval()
-
-    layer_count = cfg.MODEL.LAYER_COUNT
-
-    logger.info("Evaluating LS metric")
-
-    decoder = nn.DataParallel(decoder)
-
-    with torch.no_grad():
-        ppl = LS(cfg, percent_keep=0.5, attrib_indices=range(40), minibatch_gpu=4)
-        ppl.evaluate(logger, mapping_fl, decoder, cfg.DATASET.MAX_RESOLUTION_LEVEL - 2, 19)
-
-
-if __name__ == "__main__":
-    gpu_count = 1
-    run(sample, get_cfg_defaults(), description='StyleGAN', default_config='configs/experiment_ffhq_z.yaml',
-        world_size=gpu_count, write_log=False)
diff --git a/principal_directions/find_principal_directions.py b/principal_directions/find_principal_directions.py
new file mode 100644
index 00000000..616c794e
--- /dev/null
+++ b/principal_directions/find_principal_directions.py
@@ -0,0 +1,49 @@
+# Copyright 2019-2020 Stanislav Pidhorskyi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import sklearn.svm
+import multiprocessing as mp
+
+indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
+
+
+def run(attrib_idx):
+    results = np.load("principal_directions/wspace_att_%d.npy" % attrib_idx).item()
+
+    pruned_indices = list(range(results['latents'].shape[0]))
+    # pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
+    # keep = int(results['latents'].shape[0] * 0.95)
+    # print('Keeping: %d' % keep)
+    # pruned_indices = pruned_indices[:keep]
+
+    # Fit SVM to the remaining samples.
+    svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
+    space = 'dlatents'
+
+    svm_inputs = results[space][pruned_indices]
+
+    svm = sklearn.svm.LinearSVC(C=1.0, dual=False, max_iter=10000)
+    svm.fit(svm_inputs, svm_targets)
+    svm.score(svm_inputs, svm_targets)
+    svm_outputs = svm.predict(svm_inputs)
+
+    w = svm.coef_[0]
+
+    np.save("principal_directions/direction_%d" % attrib_idx, w)
+
+
+p = mp.Pool(processes=4)
+p.map(run, indices)
diff --git a/utils.py b/utils.py
index c1c33bde..0cb74245 100644
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,18 @@
+# Copyright 2019-2020 Stanislav Pidhorskyi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from torch import nn
 import torch
 import threading