add net implementations for different nets + cleanup
nata1y committed Jun 11, 2021
1 parent a0c1375 commit 9174028
Showing 9 changed files with 143 additions and 77 deletions.
6 changes: 3 additions & 3 deletions configs/experiment.yaml
@@ -1,7 +1,7 @@
---
# Experiment configuration
total_epochs: 100
epochs_per_cycle: 5
total_epochs: 50
epochs_per_cycle: 1
wait_for_clients: true
net: Cifar10CNN
dataset: cifar10
@@ -14,7 +14,7 @@ tensor_board_active: true
clients_per_round: 5
system:
federator:
hostname: '192.168.1.65'
hostname: '10.46.226.83'
nic: 'wlp0s20f3'
clients:
amount: 5
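
For context, the fields changed above come straight from this YAML file. A minimal sketch of loading them with PyYAML (the PyYAML dependency and path are assumptions here; fltk's own config loader may parse them differently):

import yaml

# Read the experiment configuration shown in the diff above.
with open('configs/experiment.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['total_epochs'])                     # 50 after this commit
print(cfg['epochs_per_cycle'])                 # 1 after this commit
print(cfg['system']['federator']['hostname'])  # '10.46.226.83'
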
54 changes: 33 additions & 21 deletions fltk/client_fegan.py
@@ -7,8 +7,9 @@
from torch.autograd import Variable
import logging
import numpy as np

import gc
from fltk.client import Client
from fltk.util.wassdistance import SinkhornDistance
from fltk.util.weight_init import *

from fltk.util.results import EpochData, FeGANEpochData
@@ -53,6 +54,7 @@ def __init__(self, id, log_rref, rank, world_size, config=None):
logging.info(f'Welcome to FE client {id}')
self.latent_dim = 10
self.batch_size = 3000
self.sinkhorn = SinkhornDistance(eps=0.1, max_iter=100)

def return_distribution(self):
labels = self.dataset.load_train_dataset()[1]
@@ -68,7 +70,7 @@ def init_dataloader(self, ):
del lbl
self.finished_init = True

self.batch_size = 100
self.batch_size = 3000

logging.info('Done with init')

@@ -82,12 +84,12 @@ def B_hat(self, disc_out):
return (1 / self.batch_size) * torch.sum(torch.log(torch.clamp(1 - disc_out, min=self.epsilon)))

def wasserstein_loss(self, y_true, y_pred):
return np.mean(y_true * y_pred)
return torch.mean(y_true * y_pred)

def train_fe(self, epoch, net):
generator, discriminator = net
generator.zero_grad()
discriminator.zero_grad()
# generator.zero_grad()
# discriminator.zero_grad()
optimizer_generator = torch.optim.Adam(generator.parameters(),
lr=0.0002,
betas=(0.5, 0.999))
@@ -98,23 +100,23 @@ def train_fe(self, epoch, net):
if self.args.distributed:
self.dataset.train_sampler.set_epoch(epoch)

inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]]).detach()
inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]])

optimizer_generator.zero_grad()
optimizer_discriminator.zero_grad()

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
generated_imgs = generator(noise)
noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))), requires_grad=False)
generated_imgs = generator(noise.detach())
d_generator = discriminator(generated_imgs)
generator_loss = self.J_generator(d_generator)
# generator_loss = torch.tensor([wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())], requires_grad=True)
# generator_loss = torch.tensor([-1 * self.wasserstein_loss(torch.flatten(d_generator).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())], requires_grad=True)
generator_loss.require_grad = True
generator_loss.backward()
# generator_loss, _, _ = self.sinkhorn(d_generator, torch.ones(self.batch_size, 1))
# generator_loss = self.wasserstein_loss(d_generator, torch.ones(self.batch_size))
# generator_loss.require_grad = True
generator_loss.backward(retain_graph=True)

fake_loss = self.B_hat(d_generator)
fake_loss = self.B_hat(d_generator.detach())
real_loss = self.A_hat(discriminator(inputs))

discriminator_loss = 0.5 * (real_loss + fake_loss)
@@ -123,35 +125,45 @@
# torch.zeros(self.batch_size).detach().numpy())
# real_loss = wasserstein_1d(torch.flatten(discriminator(inputs)).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())
# fake_loss = self.wasserstein_loss(torch.flatten(d_generator).detach().numpy(),
# torch.zeros(self.batch_size).detach().numpy())
# real_loss = self.wasserstein_loss(torch.flatten(discriminator(inputs)).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())
# fake_loss, _, _ = self.sinkhorn(d_generator, torch.zeros(self.batch_size, 1))
# real_loss, _, _ = self.sinkhorn(discriminator(inputs), torch.ones(self.batch_size, 1))

# generated_imgs = generator(noise.detach())
# d_generator = discriminator(generated_imgs.detach())
# fake_loss = self.wasserstein_loss(d_generator.detach(),
# (-1) * torch.ones(self.batch_size))
# real_loss = self.wasserstein_loss(discriminator(inputs),
# torch.ones(self.batch_size))
# discriminator_loss = torch.tensor([real_loss + fake_loss], requires_grad=True)
# discriminator_loss.require_grad = True
discriminator_loss.backward()
generator_loss.backward()

optimizer_generator.step()
optimizer_discriminator.step()
# generator_loss._grad = None
# discriminator_loss._grad = None

# save model
if self.args.should_save_model(epoch):
self.save_model(epoch, self.args.get_epoch_save_end_suffix())

del noise, generated_imgs, fake_loss, real_loss, discriminator_loss, generator_loss, \
optimizer_discriminator, optimizer_generator, d_generator, inputs
gc.collect()

return generator, discriminator

def run_epochs(self, num_epoch, net=None):
start_time_train = datetime.datetime.now()
gen, disc = None, None

for e in range(num_epoch):
gen, disc = self.train_fe(self.epoch_counter, net)
net = self.train_fe(self.epoch_counter, net)
self.epoch_counter += 1
gc.collect()
elapsed_time_train = datetime.datetime.now() - start_time_train
train_time_ms = int(elapsed_time_train.total_seconds() * 1000)

data = FeGANEpochData(self.epoch_counter, train_time_ms, 0, (gen, disc), client_id=self.id)
data = FeGANEpochData(self.epoch_counter, train_time_ms, 0, net, client_id=self.id)
# self.epoch_results.append(data)

return data
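
Two changes above are easy to miss but matter for autograd: wasserstein_loss now uses torch.mean instead of np.mean, and the first generator_loss.backward() gains retain_graph=True so the graph survives the later discriminator pass. A standalone sketch of why the torch.mean change is needed (illustrative tensors only, not fltk's data):

import numpy as np
import torch

y_true = torch.ones(8)
y_pred = torch.randn(8, requires_grad=True)

loss = torch.mean(y_true * y_pred)  # stays a differentiable torch scalar
loss.backward()
print(y_pred.grad is not None)      # True

# np.mean would instead force a conversion to NumPy, which raises on
# tensors that require grad and would break the backward() call above.
try:
    np.mean(y_true * y_pred)
except RuntimeError as e:
    print(e)
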
35 changes: 16 additions & 19 deletions fltk/client_mdgan.py
@@ -10,7 +10,7 @@

from fltk.client import Client
from fltk.util.weight_init import *
from fltk.nets.ls_gan import *
from fltk.nets.cifar_ls_gan import *

from fltk.util.results import EpochData, GANEpochData

@@ -76,29 +76,20 @@ def train_md(self, epoch, Xd):
if self.args.distributed:
self.dataset.train_sampler.set_epoch(epoch)

for batch_idx in (0, 1): #range(len(self.imgs) // self.batch_size + 1):
for batch_idx in range(len(self.imgs) // self.batch_size + 1):

rnd_indices = np.random.choice(len(self.imgs), size=self.batch_size)
inputs = torch.from_numpy(self.imgs[rnd_indices])
# try:
# inputs = torch.from_numpy(self.imgs[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size])
# except:
# inputs = torch.from_numpy(self.imgs[batch_idx * self.batch_size:])
try:
inputs = torch.from_numpy(self.imgs[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size])
except:
inputs = torch.from_numpy(self.imgs[batch_idx * self.batch_size:])

# rnd_indices = np.random.choice(len(self.imgs), size=self.batch_size)
# inputs = torch.from_numpy(self.imgs[rnd_indices])

# zero the parameter gradients
self.optimizer.zero_grad()

# outputs = self.discriminator(inputs)

# not sure about loss function - shall we use A_hat / B_hat?
# fake_loss = self.adversarial_loss(self.discriminator(Xd.detach()), fake)
# real_loss = self.adversarial_loss(outputs, labels)
# valid = Variable(torch.FloatTensor(self.batch_size, 1).fill_(1.0), requires_grad=False)
# fake = Variable(torch.FloatTensor(self.batch_size, 1).fill_(0.0), requires_grad=False)
# fake_loss = F.binary_cross_entropy(self.discriminator(Xd.detach()), fake)
# real_loss = F.binary_cross_entropy(self.discriminator(inputs), valid)

disc_out = torch.clamp(1 - self.discriminator(Xd.detach()), min=self.epsilon)
disc_out = torch.clamp(1 - self.discriminator(Xd), min=self.epsilon)
fake_loss = self.B_hat(disc_out)
real_loss = self.A_hat(inputs)

Expand Down Expand Up @@ -140,6 +131,11 @@ def J_generator(self, disc_out):
return (1 / self.batch_size) * torch.sum(torch.log(torch.clamp(1 - disc_out, min=self.epsilon)))

def run_epochs(self, num_epoch, current_epoch=1, Xs=None):
try:
self.loss_g.zero_grad()
except:
pass

start_time_train = datetime.datetime.now()
Xd, Xg = Xs

@@ -149,6 +145,7 @@ def run_epochs(self, num_epoch, current_epoch=1, Xs=None):

d_generator = self.discriminator(Xg)
self.loss_g = self.J_generator(d_generator)
self.loss_g.backward(retain_graph=True)

elapsed_time_train = datetime.datetime.now() - start_time_train
train_time_ms = int(elapsed_time_train.total_seconds()*1000)
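
The train_md loop above goes back to sequential batches, with a try/except around the slice. For what it's worth, NumPy slicing already clips a stop index past the end of an array, so the fallback branch can never trigger; a sketch of the same walk written with ceil division (toy data standing in for self.imgs):

import numpy as np

imgs = np.arange(10)        # stand-in for self.imgs
batch_size = 4
num_batches = -(-len(imgs) // batch_size)   # ceil division -> 3

for batch_idx in range(num_batches):
    # slicing past the end just yields a shorter final batch, no exception
    batch = imgs[batch_idx * batch_size:(batch_idx + 1) * batch_size]
    print(batch_idx, batch)  # last batch prints [8 9]
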
13 changes: 7 additions & 6 deletions fltk/federator_fegan.py
@@ -10,7 +10,7 @@
from fltk.federator import *
from fltk.strategy.client_selection import balanced_sampling
from fltk.util.fed_avg import kl_weighting
from fltk.nets.ls_gan import *
from fltk.nets.cifar_ls_gan import *
from fltk.util.fid_score import calculate_activation_statistics, calculate_frechet_distance
from fltk.util.inception import InceptionV3
from fltk.util.weight_init import *
@@ -30,7 +30,7 @@ def __init__(self, client_id_triple, num_epochs=3, config=None):
self.discriminator = Discriminator()
self.discriminator.apply(weights_init_normal)
self.latent_dim = 10
self.batch_size = 100
self.batch_size = 3000
self.fids = []
self.inceptions = []
self.fic_model = InceptionV3()
@@ -111,22 +111,23 @@ def init_groups(self, c, label_size):
done_q = True

self.groups.append(g)
# TODO: should be hyperparam
if len(self.groups) > 10:
done = True

def test(self, fl_round):
test_batch = 1000
with torch.no_grad():
file_output = f'./{self.config.output_location}'
self.ensure_path_exists(file_output)
fid_z = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.test_imgs.shape[0], self.latent_dim))))
fid_z = Variable(torch.FloatTensor(np.random.normal(0, 1, (test_batch, self.latent_dim))))
gen_imgs = self.generator(fid_z.detach())
mu_gen, sigma_gen = calculate_activation_statistics(gen_imgs, self.fic_model)
mu_test, sigma_test = calculate_activation_statistics(torch.from_numpy(self.test_imgs), self.fic_model)
mu_test, sigma_test = calculate_activation_statistics(torch.from_numpy(self.test_imgs[:test_batch]), self.fic_model)
fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test, sigma_test)
print("FL-round {} FID Score: {}, IS Score: {}".format(fl_round, fid, mu_gen))

self.fids.append(fid)
print(self.fids)
# self.inceptions.append(mu_gen)

def checkpoint(self, fl_round):
@@ -148,7 +149,7 @@ def plot_score_data(self):
plt.xlabel('Federator runs')
plt.ylabel('FID')

filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan_wassloss.png'
filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan_wd4.png'
logging.info(f'Saving data at {filename}')

plt.savefig(filename)
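
test() above scores the generator by fitting Gaussians to Inception activations and comparing them. A self-contained sketch of the Fréchet distance underneath calculate_frechet_distance (the standard formula; fltk's version may add numerical safeguards this sketch omits):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    return diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean.real)

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 8))           # toy activations, not InceptionV3's
fake = rng.normal(loc=0.3, size=(1000, 8))
fid = frechet_distance(real.mean(0), np.cov(real, rowvar=False),
                       fake.mean(0), np.cov(fake, rowvar=False))
print(fid)
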
43 changes: 22 additions & 21 deletions fltk/federator_mdgan.py
@@ -10,7 +10,7 @@
from fltk.client_mdgan import ClientMDGAN, _remote_method_async
from fltk.util.fid_score import calculate_activation_statistics, calculate_frechet_distance
from fltk.util.inception import InceptionV3
from fltk.nets.ls_gan import *
from fltk.nets.cifar_ls_gan import *
from fltk.util.weight_init import *
import logging
import matplotlib.pyplot as plt
@@ -36,7 +36,7 @@ def __init__(self, client_id_triple, num_epochs=3, config=None):
self.latent_dim = 10
# K <= N
self.k = len(client_id_triple) - 1
self.batch_size = math.floor(self.dataset[0].shape[0] / self.k) // 10
self.batch_size = math.floor(self.dataset[0].shape[0] / self.k) // 2
self.introduce_clients()
self.fids = []
self.discriminator = Discriminator(32)
@@ -116,27 +116,28 @@ def remote_run_epoch(self, epochs, fl_round=0):
# broadcast generated datasets to clients and get trained discriminators back
selected_clients = self.select_clients(self.config.clients_per_round)

# X_g = []
# X_d = []
# for i in range(self.k):
# noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
# X_g.append(self.generator(noise.detach()))
#
# noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
# X_d.append(self.generator(noise.detach()))
#
# samples_d = [random.randrange(self.k) for _ in range(len(selected_clients))]
# samples_g = [random.randrange(self.k) for _ in range(len(selected_clients))]
X_g = []
X_d = []
for i in range(self.k):
noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
X_g.append(self.generator(noise.detach()))

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
X_d.append(self.generator(noise.detach()))

samples_d = [random.randrange(self.k) for _ in range(len(selected_clients))]
samples_g = [random.randrange(self.k) for _ in range(len(selected_clients))]

for id, client in enumerate(selected_clients):
# Sample noise as generator input and feed to clients based on their data size
# X_d_i = X_d[samples_d[id]]
# X_g_i = X_g[samples_g[id]]
noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
X_g_i = self.generator(noise)
X_d_i = X_d[samples_d[id]]
X_g_i = X_g[samples_g[id]]

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
X_d_i = self.generator(noise)
# noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
# X_g_i = self.generator(noise)
#
# noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))))
# X_d_i = self.generator(noise)

responses.append((client, _remote_method_async(ClientMDGAN.run_epochs, client.ref,
epochs, fl_round, (X_d_i, X_g_i))))
@@ -159,14 +160,14 @@
self.optimizer.step(G_context)

logging.info('Gradient is updated')
if fl_round % 25 == 0:
if fl_round % 5 == 0:
self.test_generator(fl_round)

def plot_score_data(self):
file_output = f'./{self.config.output_location}'
self.ensure_path_exists(file_output)

plt.plot(range(0, self.config.epochs, 25), self.fids, 'b')
plt.plot(range(0, self.config.epochs, 5), self.fids, 'b')
# plt.plot(range(self.config.epochs), self.inceptions, 'r')
plt.xlabel('Federator runs')
plt.ylabel('FID')
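
remote_run_epoch above re-enables the MD-GAN distribution step: the federator pre-generates k batches and each selected client draws its (X_d, X_g) pair from that pool, instead of receiving freshly generated samples per client. A toy sketch of that assignment, with a placeholder network standing in for fltk's LSGAN generator:

import random
import torch

k, n_clients = 4, 5
batch_size, latent_dim = 16, 10
generator = torch.nn.Linear(latent_dim, 32)  # placeholder, not fltk's Generator

X_g = [generator(torch.randn(batch_size, latent_dim)) for _ in range(k)]
X_d = [generator(torch.randn(batch_size, latent_dim)) for _ in range(k)]

samples_d = [random.randrange(k) for _ in range(n_clients)]
samples_g = [random.randrange(k) for _ in range(n_clients)]

for cid in range(n_clients):
    X_d_i, X_g_i = X_d[samples_d[cid]], X_g[samples_g[cid]]
    # in fltk this pair is shipped to ClientMDGAN.run_epochs over RPC
    print(cid, tuple(X_d_i.shape), tuple(X_g_i.shape))
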
2 changes: 1 addition & 1 deletion fltk/launch.py
@@ -13,7 +13,7 @@

def run_ps(rpc_ids_triple, args):
print(f'Starting the federator...')
fed = FederatorMDGAN(rpc_ids_triple, config=args)
fed = FederatorFeGAN(rpc_ids_triple, config=args)
fed.run()


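
The launch.py change hard-codes which federator run_ps starts. If that swap becomes frequent, a config-driven lookup is one way to avoid editing the file per experiment; a sketch assuming a hypothetical `federator` config key (not present in experiment.yaml today) and imports from the two federator modules shown above:

from fltk.federator_fegan import FederatorFeGAN
from fltk.federator_mdgan import FederatorMDGAN

FEDERATORS = {'fegan': FederatorFeGAN, 'mdgan': FederatorMDGAN}

def run_ps(rpc_ids_triple, args):
    print(f'Starting the federator...')
    # `federator` is a hypothetical config field; launch.py currently hard-codes the class
    fed_cls = FEDERATORS[getattr(args, 'federator', 'fegan')]
    fed = fed_cls(rpc_ids_triple, config=args)
    fed.run()
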