Merge pull request #3 from nata1y/main
add comments + custom local methods
nata1y authored Jun 13, 2021
2 parents f0032d3 + e23b59f commit 8e2618b
Showing 2 changed files with 74 additions and 50 deletions.
115 changes: 71 additions & 44 deletions fltk/client_fegan.py
@@ -53,14 +53,18 @@ def __init__(self, id, log_rref, rank, world_size, config=None):
super().__init__(id, log_rref, rank, world_size, config)
logging.info(f'Welcome to FE client {id}')
self.latent_dim = 10
self.batch_size = 3000
self.batch_size = 300
self.sinkhorn = SinkhornDistance(eps=0.1, max_iter=100)

def return_distribution(self):
labels = self.dataset.load_train_dataset()[1]
unique, counts = np.unique(labels, return_counts=True)
return sorted((zip(unique, counts)))

def weight_scale(self, m):
if type(m) == torch.nn.Conv2d:
m.weight = torch.nn.Parameter((m.weight * 0.02) / torch.max(m.weight, dim=1, keepdim=True)[0] - 0.01)

def init_dataloader(self, ):
self.args.distributed = True
self.args.rank = self.rank
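
The new weight_scale helper above is one of the "custom local methods" mentioned in the commit message: it rescales Conv2d weights into a small band around zero, in the spirit of WGAN-style weight clipping, and the commented-out discriminator.apply(self.weight_scale) call further down in train_fe shows the intended use. The following is a minimal, self-contained sketch of that usage; the toy convolutional discriminator is an assumption, not the repository's model:

import torch
import torch.nn as nn

def weight_scale(m):
    # Rescale Conv2d weights into a small range around zero,
    # similar in spirit to WGAN weight clipping.
    if type(m) == nn.Conv2d:
        m.weight = nn.Parameter(
            (m.weight * 0.02) / torch.max(m.weight, dim=1, keepdim=True)[0] - 0.01
        )

# Toy discriminator used only for illustration.
discriminator = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(8 * 28 * 28, 1),
    nn.Sigmoid(),
)

with torch.no_grad():
    discriminator.apply(weight_scale)  # rescale in place before the next update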
@@ -70,7 +74,7 @@ def init_dataloader(self, ):
del lbl
self.finished_init = True

self.batch_size = 3000
self.batch_size = 300

logging.info('Done with init')

@@ -100,48 +104,71 @@ def train_fe(self, epoch, net):
if self.args.distributed:
self.dataset.train_sampler.set_epoch(epoch)

inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]])

optimizer_generator.zero_grad()
optimizer_discriminator.zero_grad()

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))), requires_grad=False)
generated_imgs = generator(noise.detach())
d_generator = discriminator(generated_imgs)
generator_loss = self.J_generator(d_generator)
# generator_loss = torch.tensor([wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())], requires_grad=True)
# generator_loss, _, _ = self.sinkhorn(d_generator, torch.ones(self.batch_size, 1))
# generator_loss = self.wasserstein_loss(d_generator, torch.ones(self.batch_size))
# generator_loss.require_grad = True
generator_loss.backward(retain_graph=True)

fake_loss = self.B_hat(d_generator.detach())
real_loss = self.A_hat(discriminator(inputs))

discriminator_loss = 0.5 * (real_loss + fake_loss)

# fake_loss = wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
# torch.zeros(self.batch_size).detach().numpy())
# real_loss = wasserstein_1d(torch.flatten(discriminator(inputs)).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())
# fake_loss, _, _ = self.sinkhorn(d_generator, torch.zeros(self.batch_size, 1))
# real_loss, _, _ = self.sinkhorn(discriminator(inputs), torch.ones(self.batch_size, 1))

# generated_imgs = generator(noise.detach())
# d_generator = discriminator(generated_imgs.detach())
# fake_loss = self.wasserstein_loss(d_generator.detach(),
# (-1) * torch.ones(self.batch_size))
# real_loss = self.wasserstein_loss(discriminator(inputs),
# torch.ones(self.batch_size))
# discriminator_loss = torch.tensor([real_loss + fake_loss], requires_grad=True)
discriminator_loss.backward()
generator_loss.backward()

optimizer_generator.step()
optimizer_discriminator.step()
# generator_loss._grad = None
# discriminator_loss._grad = None
for batch_idx in range(len(self.inputs) // self.batch_size + 1):

try:
inputs = torch.from_numpy(self.inputs[batch_idx * self.batch_size:(batch_idx + 1) * self.batch_size])
except:
inputs = torch.from_numpy(self.inputs[batch_idx * self.batch_size:])

# inputs = torch.from_numpy(self.inputs[[random.randrange(self.inputs.shape[0]) for _ in range(self.batch_size)]])

optimizer_generator.zero_grad()
optimizer_discriminator.zero_grad()

# apply weight scale for WGAN loss implementation
# with torch.no_grad():
# discriminator = discriminator.apply(self.weight_scale)

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (self.batch_size, self.latent_dim))), requires_grad=False)
generated_imgs = generator(noise.detach())
d_generator = discriminator(generated_imgs)

# minmax loss
generator_loss = self.J_generator(d_generator)

# wasserstein distance loss
# generator_loss = torch.tensor([wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())], requires_grad=True)

# wasserstein distance approximation via sinkhorn
# generator_loss, _, _ = self.sinkhorn(d_generator, torch.ones(self.batch_size, 1))

# wasserstein loss from wgan
# generator_loss = self.wasserstein_loss(d_generator, (-1.0) * torch.ones(self.batch_size))
# generator_loss.require_grad = True
generator_loss.backward(retain_graph=True)

# minmax loss
fake_loss = self.B_hat(d_generator.detach())
real_loss = self.A_hat(discriminator(inputs))

discriminator_loss = 0.5 * (real_loss + fake_loss)

# wasserstein distance
# fake_loss = wasserstein_1d(torch.flatten(d_generator).detach().numpy(),
# torch.zeros(self.batch_size).detach().numpy())
# real_loss = wasserstein_1d(torch.flatten(discriminator(inputs)).detach().numpy(),
# torch.ones(self.batch_size).detach().numpy())

# wasserstein distance via sinkhorn
# fake_loss, _, _ = self.sinkhorn(d_generator, torch.zeros(self.batch_size, 1))
# real_loss, _, _ = self.sinkhorn(discriminator(inputs), torch.ones(self.batch_size, 1))

# wasserstein loss
# generated_imgs = generator(noise.detach())
# d_generator = discriminator(generated_imgs.detach())
# fake_loss = self.wasserstein_loss(d_generator.detach(),
# (-1) * torch.ones(self.batch_size))
# real_loss = self.wasserstein_loss(discriminator(inputs),
# torch.ones(self.batch_size))
# discriminator_loss = real_loss + fake_loss

discriminator_loss.backward()
generator_loss.backward()

optimizer_generator.step()
optimizer_discriminator.step()

# save model
if self.args.should_save_model(epoch):
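
Outside the diff view, the new per-round training flow in train_fe reduces to the sketch below. It keeps only the min-max loss path that the commit leaves active; the networks and the data array are toy stand-ins, and the explicit -log formulas stand in for the repository's J_generator, A_hat and B_hat helpers, whose definitions are not shown in this diff:

import numpy as np
import torch
import torch.nn as nn

latent_dim, batch_size = 10, 300
data = np.random.rand(1000, 28 * 28).astype(np.float32)  # stand-in for self.inputs

generator = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(28 * 28, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

for batch_idx in range(len(data) // batch_size + 1):
    # Sequential slices replace the single random batch of the old code;
    # a short (or empty) final slice is what the slicing naturally returns.
    inputs = torch.from_numpy(data[batch_idx * batch_size:(batch_idx + 1) * batch_size])
    if len(inputs) == 0:
        break

    opt_g.zero_grad()
    opt_d.zero_grad()

    noise = torch.randn(len(inputs), latent_dim)
    d_fake = discriminator(generator(noise))

    # Non-saturating min-max generator loss: push D(G(z)) towards 1.
    generator_loss = -torch.log(d_fake + 1e-8).mean()
    generator_loss.backward(retain_graph=True)

    # Discriminator loss: real samples towards 1, generated samples towards 0.
    d_real = discriminator(inputs)
    real_loss = -torch.log(d_real + 1e-8).mean()
    fake_loss = -torch.log(1.0 - d_fake.detach() + 1e-8).mean()
    discriminator_loss = 0.5 * (real_loss + fake_loss)
    discriminator_loss.backward()

    opt_g.step()
    opt_d.step()

The behavioural change relative to the removed code is that each call now walks over self.inputs in sequential slices of batch_size (now 300) instead of drawing a single random batch of 3000 samples.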
9 changes: 3 additions & 6 deletions fltk/federator_fegan.py
@@ -30,7 +30,7 @@ def __init__(self, client_id_triple, num_epochs=3, config=None):
self.discriminator = Discriminator()
self.discriminator.apply(weights_init_normal)
self.latent_dim = 10
self.batch_size = 3000
self.batch_size = 300
self.fids = []
self.inceptions = []
self.fic_model = InceptionV3()
@@ -54,7 +54,6 @@ def preprocess_groups(self, c=0.3):
self.init_groups(c, len(num_per_class))

# we need entropies for client weighting during kl-divergence calculations
# idea from code, not from official paper...
all_samples = sum(num_per_class)
rat_per_class = [float(n / all_samples) for n in num_per_class]
cleaned_distributions = [[c for _, c in dist] for dist in self.class_distributions]
@@ -73,7 +72,6 @@ def init_dataloader(self, ):
del lbls
logging.info('Done with init')

# TODO: cleanup
def init_groups(self, c, label_size):
gp_size = math.ceil(c * len(self.clients))
done = False
@@ -115,7 +113,7 @@ def init_groups(self, c, label_size):
done = True

def test(self, fl_round):
test_batch = 1000
test_batch = self.test_imgs.shape[0]
with torch.no_grad():
file_output = f'./{self.config.output_location}'
self.ensure_path_exists(file_output)
@@ -127,7 +125,6 @@ def test(self, fl_round):
print("FL-round {} FID Score: {}, IS Score: {}".format(fl_round, fid, mu_gen))

self.fids.append(fid)
print(self.fids)
# self.inceptions.append(mu_gen)

def checkpoint(self, fl_round):
@@ -149,7 +146,7 @@ def plot_score_data(self):
plt.xlabel('Federator runs')
plt.ylabel('FID')

filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan_wd4.png'
filename = f'{file_output}/{self.config.epochs}_epochs_fe_gan.png'
logging.info(f'Saving data at {filename}')

plt.savefig(filename)
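
On the federator side, the comment in preprocess_groups about entropies for client weighting during KL-divergence calculations suggests that clients are weighted by how their label distribution relates to the global class ratio collected via return_distribution(). The exact weighting rule is not part of this diff, so the snippet below is only an illustrative sketch with made-up counts:

import numpy as np

# Hypothetical per-client label counts, e.g. gathered from each client's return_distribution().
class_distributions = [
    [50, 30, 20],
    [10, 10, 80],
    [34, 33, 33],
]

num_per_class = np.sum(class_distributions, axis=0)
all_samples = num_per_class.sum()
rat_per_class = num_per_class / all_samples  # global class ratios, as in preprocess_groups

def kl_divergence(p, q, eps=1e-12):
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum(p * np.log(p / q)))

# One plausible weighting: clients whose label distribution is close to the
# global ratio (small KL divergence) receive a larger weight.
scores = np.array([1.0 / (1.0 + kl_divergence(dist, rat_per_class))
                   for dist in class_distributions])
weights = scores / scores.sum()
print(weights)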