diff --git a/fltk/client_fegan.py b/fltk/client_fegan.py
index fb9f839..061e422 100644
--- a/fltk/client_fegan.py
+++ b/fltk/client_fegan.py
@@ -53,7 +53,7 @@ def __init__(self, id, log_rref, rank, world_size, config=None):
         super().__init__(id, log_rref, rank, world_size, config)
         logging.info(f'Welcome to FE client {id}')
         self.latent_dim = 10
-        self.batch_size = 3000
+        self.batch_size = 300
         self.sinkhorn = SinkhornDistance(eps=0.1, max_iter=100)
 
     def return_distribution(self):
@@ -61,6 +61,10 @@ def return_distribution(self):
         unique, counts = np.unique(labels, return_counts=True)
         return sorted((zip(unique, counts)))
 
+    def weight_scale(self, m):
+        if type(m) == torch.nn.Conv2d:
+            m.weight = torch.nn.Parameter((m.weight * 0.02) / torch.max(m.weight, dim=1, keepdim=True)[0] - 0.01)
+
     def init_dataloader(self, ):
         self.args.distributed = True
         self.args.rank = self.rank
@@ -70,7 +74,7 @@ def init_dataloader(self, ):
         del lbl
 
         self.finished_init = True
-        self.batch_size = 3000
+        self.batch_size = 300
         logging.info('Done with init')
 
@@ -100,48 +104,71 @@ def train_fe(self, epoch, net):
         if self.args.distributed:
             self.dataset.train_sampler.set_epoch(epoch)
 
-        inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]])
-
-        optimizer_generator.zero_grad()
-        optimizer_discriminator.zero_grad()
-
-        noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))), requires_grad=False)
-        generated_imgs = generator(noise.detach())
-        d_generator = discriminator(generated_imgs)
-        generator_loss = self.J_generator(d_generator)
-        # generator_loss = torch.tensor([wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
-        #                                               torch.ones(self.batch_size).detach().numpy())], requires_grad=True)
-        # generator_loss, _, _ = self.sinkhorn(d_generator, torch.ones(self.batch_size, 1))
-        # generator_loss = self.wasserstein_loss(d_generator, torch.ones(self.batch_size))
-        # generator_loss.require_grad = True
-        generator_loss.backward(retain_graph=True)
-
-        fake_loss = self.B_hat(d_generator.detach())
-        real_loss = self.A_hat(discriminator(inputs))
-
-        discriminator_loss = 0.5 * (real_loss + fake_loss)
-
-        # fake_loss = wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
-        #                            torch.zeros(self.batch_size).detach().numpy())
-        # real_loss = wasserstein_1d(torch.flatten(discriminator(inputs)).detach().numpy(),
-        #                            torch.ones(self.batch_size).detach().numpy())
-        # fake_loss, _, _ = self.sinkhorn(d_generator, torch.zeros(self.batch_size, 1))
-        # real_loss, _, _ = self.sinkhorn(discriminator(inputs), torch.ones(self.batch_size, 1))
-
-        # generated_imgs = generator(noise.detach())
-        # d_generator = discriminator(generated_imgs.detach())
-        # fake_loss = self.wasserstein_loss(d_generator.detach(),
-        #                                   (-1) * torch.ones(self.batch_size))
-        # real_loss = self.wasserstein_loss(discriminator(inputs),
-        #                                   torch.ones(self.batch_size))
-        # discriminator_loss = torch.tensor([real_loss + fake_loss], requires_grad=True)
-        discriminator_loss.backward()
-        generator_loss.backward()
-
-        optimizer_generator.step()
-        optimizer_discriminator.step()
-        # generator_loss._grad = None
-        # discriminator_loss._grad = None
+        for batch_idx in range(len(self.inputs) // self.batch_size + 1):
+
+            try:
+                inputs = torch.from_numpy(self.inputs[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size])
+            except:
+                inputs = torch.from_numpy(self.inputs[batch_idx * self.batch_size:])
+
+            # inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]])
+
+            optimizer_generator.zero_grad()
+            optimizer_discriminator.zero_grad()
+
+            # apply weight scale for WGAN loss implementation
+            # with torch.no_grad():
+            #     discriminator = discriminator.apply(self.weight_scale)
+
+            noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))), requires_grad=False)
+            generated_imgs = generator(noise.detach())
+            d_generator = discriminator(generated_imgs)
+
+            # minmax loss
+            generator_loss = self.J_generator(d_generator)
+
+            # wasserstein distance loss
+            # generator_loss = torch.tensor([wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
+            #                                               torch.ones(self.batch_size).detach().numpy())], requires_grad=True)
+
+            # wasserstein distance approximation via sinkhorn
+            # generator_loss, _, _ = self.sinkhorn(d_generator, torch.ones(self.batch_size, 1))
+
+            # wasserstein loss from wgan
+            # generator_loss = self.wasserstein_loss(d_generator, (-1.0) * torch.ones(self.batch_size))
+            # generator_loss.require_grad = True
+            generator_loss.backward(retain_graph=True)
+
+            # minmax loss
+            fake_loss = self.B_hat(d_generator.detach())
+            real_loss = self.A_hat(discriminator(inputs))
+
+            discriminator_loss = 0.5 * (real_loss + fake_loss)
+
+            # wasserstein distance
+            # fake_loss = wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
+            #                            torch.zeros(self.batch_size).detach().numpy())
+            # real_loss = wasserstein_1d(torch.flatten(discriminator(inputs)).detach().numpy(),
+            #                            torch.ones(self.batch_size).detach().numpy())
+
+            # wasserstein distance via sinkhorn
+            # fake_loss, _, _ = self.sinkhorn(d_generator, torch.zeros(self.batch_size, 1))
+            # real_loss, _, _ = self.sinkhorn(discriminator(inputs), torch.ones(self.batch_size, 1))
+
+            # wasserstein loss
+            # generated_imgs = generator(noise.detach())
+            # d_generator = discriminator(generated_imgs.detach())
+            # fake_loss = self.wasserstein_loss(d_generator.detach(),
+            #                                   (-1) * torch.ones(self.batch_size))
+            # real_loss = self.wasserstein_loss(discriminator(inputs),
+            #                                   torch.ones(self.batch_size))
+            # discriminator_loss = real_loss + fake_loss
+
+            discriminator_loss.backward()
+            generator_loss.backward()
+
+            optimizer_generator.step()
+            optimizer_discriminator.step()
 
         # save model
         if self.args.should_save_model(epoch):
diff --git a/fltk/federator_fegan.py b/fltk/federator_fegan.py
index 482dadd..90b3903 100644
--- a/fltk/federator_fegan.py
+++ b/fltk/federator_fegan.py
@@ -30,7 +30,7 @@ def __init__(self, client_id_triple, num_epochs=3, config=None):
         self.discriminator = Discriminator()
         self.discriminator.apply(weights_init_normal)
         self.latent_dim = 10
-        self.batch_size = 3000
+        self.batch_size = 300
         self.fids = []
         self.inceptions = []
         self.fic_model = InceptionV3()
@@ -54,7 +54,6 @@ def preprocess_groups(self, c=0.3):
         self.init_groups(c, len(num_per_class))
 
         # we need entropies for client weighting during kl-divergence calculations
-        # idea from code, not from official paper...
         all_samples = sum(num_per_class)
         rat_per_class = [float(n / all_samples) for n in num_per_class]
         cleaned_distributions = [[c for _, c in dist] for dist in self.class_distributions]
@@ -73,7 +72,6 @@ def init_dataloader(self, ):
         del lbls
         logging.info('Done with init')
 
-    # TODO: cleanup
     def init_groups(self, c, label_size):
         gp_size = math.ceil(c * len(self.clients))
         done = False
@@ -115,7 +113,7 @@ def init_groups(self, c, label_size):
             done = True
 
     def test(self, fl_round):
-        test_batch = 1000
+        test_batch = self.test_imgs.shape[0]
         with torch.no_grad():
             file_output = f'./{self.config.output_location}'
             self.ensure_path_exists(file_output)
@@ -127,7 +125,6 @@ def test(self, fl_round):
             print("FL-round {} FID Score: {}, IS Score: {}".format(fl_round, fid, mu_gen))
 
             self.fids.append(fid)
-            print(self.fids)
             # self.inceptions.append(mu_gen)
 
     def checkpoint(self, fl_round):
@@ -149,7 +146,7 @@ def plot_score_data(self):
         plt.xlabel('Federator runs')
         plt.ylabel('FID')
 
-        filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan_wd4.png'
+        filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan.png'
         logging.info(f'Saving data at {filename}')
         plt.savefig(filename)
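Note on the new batching loop (not part of the patch above): train_fe now walks self.inputs in contiguous slices of self.batch_size instead of drawing a single random batch per call. The sketch below illustrates that slicing pattern on its own; it is a minimal, hypothetical rewrite, not code from this repository, and the helper names (iterate_batches, sample_noise) are invented for illustration. NumPy clamps slice bounds past the end of the array, so the shorter final batch comes back without raising, and sizing the noise from the actual batch keeps the real and generated batches aligned.

    import numpy as np
    import torch


    def iterate_batches(inputs: np.ndarray, batch_size: int):
        """Yield contiguous mini-batches of `inputs`; the last batch may be smaller."""
        for start in range(0, len(inputs), batch_size):
            batch = inputs[start:start + batch_size]  # NumPy clamps the stop index, no IndexError
            if len(batch):
                yield torch.from_numpy(batch)


    def sample_noise(n: int, latent_dim: int) -> torch.Tensor:
        """Draw `n` latent vectors, matching the size of the current real batch."""
        return torch.randn(n, latent_dim)


    # usage sketch inside a training step:
    # for real in iterate_batches(self.inputs, self.batch_size):
    #     noise = sample_noise(real.shape[0], self.latent_dim)
    #     d_fake = discriminator(generator(noise))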