enable monai loss tracking and fix eval overwriting model
holgerroth committed May 1, 2024
1 parent c0c68f5 commit a927f06
Showing 5 changed files with 90 additions and 48 deletions.
19 changes: 15 additions & 4 deletions integration/monai/examples/mednist/code/fedavg_monai.py
@@ -43,6 +43,12 @@ class FedAvgMONAI(BaseFedAvg):
If n is 0 then no persist.
"""

def param_sum(self, params):
s = 0
for k, v in params.items():
s += v.sum()
return s
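
`param_sum` is only a debugging checksum: summing every tensor in the params dict collapses the whole model into one number that should change between rounds if aggregation is really taking effect. A standalone sketch with made-up numpy arrays (not actual DenseNet weights):

```python
import numpy as np

# Made-up stand-in for an FLModel params dict of numpy arrays.
params = {
    "conv.weight": np.array([[0.25, -0.5], [0.75, 0.5]]),
    "conv.bias": np.array([0.5, -0.25]),
}

# Same reduction as param_sum above: sum of all elements of all tensors.
checksum = sum(v.sum() for v in params.values())
print(checksum)  # 1.25 -- if this stays constant across rounds, the weights never moved
```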

def run(self) -> None:
self.info("Start FedAvg.")

@@ -51,21 +57,26 @@ def run(self) -> None:
init_weights = {}
for k, v in monai_model.state_dict().items():
init_weights[k] = v.cpu().numpy()
self.model = FLModel(params_type=ParamsType.FULL, params=init_weights)
model = FLModel(params_type=ParamsType.FULL, params=init_weights)
model.current_round = self._current_round

for self._current_round in range(self._num_rounds):
self.info(f"Round {self._current_round} started.")

clients = self.sample_clients(self._min_clients)

results = self.send_model_and_wait(targets=clients, data=self.model)
print("$$$$$$$$$ Server BEGIN ROUND", model.current_round, self.param_sum(model.params))
results = self.send_model_and_wait(targets=clients, data=model)

aggregate_results = self.aggregate(
results, aggregate_fn=None
) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used

self.update_model(aggregate_results)
model = aggregate_results
model.current_round = self._current_round

print("$$$$$$$$$ Server END ROUND", model.current_round, self.param_sum(model.params))

self.save_model()
#self.save_model()

self.info("Finished FedAvg.")
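
For reference on the `aggregate(...)` call in the loop above: when no `aggregate_fn` is given, NVFlare falls back to its default `WeightedAggregationHelper`. The sketch below only illustrates the idea of step-weighted averaging over client results, with plain dicts and a hypothetical `steps` weight standing in for FLModel objects and their NUM_STEPS_CURRENT_ROUND metadata; it is not the helper's actual implementation.

```python
import numpy as np

def weighted_average(results):
    """Average client params, weighting each client by its local step count."""
    total_steps = sum(r["steps"] for r in results)
    return {
        key: sum(r["params"][key] * (r["steps"] / total_steps) for r in results)
        for key in results[0]["params"]
    }

# Two fake client updates with different amounts of local training.
clients = [
    {"params": {"w": np.array([1.0, 2.0])}, "steps": 10},
    {"params": {"w": np.array([3.0, 4.0])}, "steps": 30},
]
print(weighted_average(clients))  # {'w': array([2.5, 3.5])}
```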
113 changes: 72 additions & 41 deletions integration/monai/examples/mednist/code/monai_mednist_train.py
@@ -32,6 +32,7 @@
import sys
import tempfile
import torch
from copy import deepcopy

from monai.apps import MedNISTDataset
from monai.config import print_config
@@ -42,15 +43,60 @@
from monai.networks import eval_mode
from monai.networks.nets import densenet121
from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose
from monai.handlers import TensorBoardStatsHandler

# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter
#from nvflare.client.tracking import SummaryWriter
from tracking import SummaryWriter

print_config()


def param_sum(params):
s = 0
for k, v in params.items():
s += v.sum()
return s


# (5) wraps evaluation logic into a method to re-use for
# evaluation on both trained and received model
def evaluate(DEVICE, transform, root_dir, model, input_weights, x=0):
model.load_state_dict(input_weights, strict=True)

# Check the prediction on the test dataset
dataset_dir = Path(root_dir, "MedNIST")
class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", download=False,
runtime_cache=True)
correct = 0
total = 0
max_items_to_print = 10
_print = 0
with eval_mode(model):
for item in DataLoader(testdata, batch_size=512, num_workers=0): # changed to do batch processing
prob = np.array(model(item["image"].to(DEVICE)).detach().to("cpu"))
pred = [class_names[p] for p in prob.argmax(axis=1)]
gt = item["class_name"]
# changed the logic a bit from tutorial to compute accuracy on full test set
# but only print for some.
for _gt, _pred in zip(gt, pred):
if _print < max_items_to_print:
print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
_print += 1

# compute accuracy
total += 1
correct += float(_pred == _gt)

acc = correct // total
print(f"Accuracy of the network on the {total} test images: {100 * acc} %")
return acc
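
A standalone version of the accuracy bookkeeping used in `evaluate`, with dummy class predictions instead of MedNIST batches. Note that it divides with `/`; the `correct // total` floor division in the hunk above truncates the ratio to 0 or 1, so fractional accuracies are lost.

```python
# Dummy predictions / ground truths standing in for the MedNIST test loop.
preds = ["AbdomenCT", "HeadCT", "ChestCT", "HeadCT"]
gts = ["AbdomenCT", "HeadCT", "CXR", "Hand"]

correct = sum(float(p == g) for p, g in zip(preds, gts))
total = len(gts)

acc = correct / total  # true division keeps the fraction (here 0.5)
print(f"Accuracy of the network on the {total} test images: {100 * acc} %")
```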


def main():
# (2) initializes NVFlare client API
flare.init()
@@ -80,7 +126,7 @@ def main():
# If available, we use GPU to speed things up.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

max_epochs = 1 # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each.
max_epochs = 10 # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each.
model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)

train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)
@@ -94,73 +140,58 @@ def main():
optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
loss_function=torch.nn.CrossEntropyLoss(),
inferer=SimpleInferer(),
train_handlers=StatsHandler(),
)

# StatsHandler prints loss at every iteration and print metrics at every epoch,
# we don't set metrics for trainer here, so just print loss, user can also customize print functions
# and can use output_transform to convert engine.state.output if it's not loss value
train_stats_handler = StatsHandler(name="trainer")
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
summary_writer = SummaryWriter()
train_tensorboard_stats_handler = TensorBoardStatsHandler(
summary_writer=summary_writer
)
train_tensorboard_stats_handler.attach(trainer)

# (optional) calculate total steps
steps = max_epochs * len(train_loader)
# Run the training

summary_writer = SummaryWriter()
while flare.is_running():
# (3) receives FLModel from NVFlare
input_model = flare.receive()
print(f"current_round={input_model.current_round}")
print("$$$$$$$$$ CLIENT BEGIN ROUND", input_model.current_round, param_sum(input_model.params))

# (4) loads model from NVFlare and sends it to GPU
trainer.network.load_state_dict(input_model.params, strict=True)
trainer.network.to(DEVICE)

trainer.run()
# set engine state max epochs.
trainer.state.max_epochs = trainer.state.epoch + max_epochs
# get current iteration when a round starts
iter_of_start_time = trainer.state.iteration

# (5) wraps evaluation logic into a method to re-use for
# evaluation on both trained and received model
def evaluate(input_weights, x=0):
model.load_state_dict(input_weights, strict=True)

# Check the prediction on the test dataset
dataset_dir = Path(root_dir, "MedNIST")
class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", download=False,
runtime_cache=True)
correct = 0
total = 0
max_items_to_print = 10
_print = 0
with eval_mode(model):
for item in DataLoader(testdata, batch_size=512, num_workers=0): # changed to do batch processing
prob = np.array(model(item["image"].to(DEVICE)).detach().to("cpu"))
pred = [class_names[p] for p in prob.argmax(axis=1)]
gt = item["class_name"]
# changed the logic a bit from tutorial to compute accuracy on full test set
# but only print for some.
for _gt, _pred in zip(gt, pred):
if _print < max_items_to_print:
print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
_print += 1

# compute accuracy
total += 1
correct += float(_pred == _gt)

#acc = correct // total
acc = 1 - 1/(1 + x)
print(f"Accuracy of the network on the {total} test images: {100 * acc} %")
return acc
trainer.run()

# (6) evaluate on received model for model selection
accuracy = evaluate(input_model.params, x=input_model.current_round)
summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round)
accuracy = evaluate(DEVICE, transform, root_dir, deepcopy(model), input_model.params, x=0)
summary_writer.add_scalar("global_model_accuracy", accuracy, input_model.current_round)

# (7) construct trained FL model
output_model = flare.FLModel(
params=trainer.network.cpu().state_dict(),
metrics={"accuracy": accuracy},
meta={"NUM_STEPS_CURRENT_ROUND": steps},
)

print("$$$$$$$$$ CLIENT END ROUND", output_model.current_round, param_sum(output_model.params))
# (8) send model back to NVFlare
flare.send(output_model)
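
The `deepcopy(model)` passed to `evaluate` above is the fix for the "eval overwriting model" part of this commit: loading the received global weights into the live training network would clobber the state the trainer keeps across rounds, so evaluation runs on a throwaway copy. A minimal illustration of the difference, using a toy torch module in place of the densenet121 network:

```python
from copy import deepcopy
import torch

local_model = torch.nn.Linear(2, 2)  # stands in for the trainer's network
received = {k: torch.zeros_like(v) for k, v in local_model.state_dict().items()}

eval_model = deepcopy(local_model)    # evaluate on a copy ...
eval_model.load_state_dict(received)  # ... so this load

# ... leaves the locally trained weights untouched:
print(torch.equal(local_model.weight, eval_model.weight))  # False (copy is now all zeros)
```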



if __name__ == "__main__":
main()
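
On the "monai loss tracking" part of the commit: `TensorBoardStatsHandler` is pointed at a writer through its `summary_writer` argument, and in this script that writer comes from the local `tracking` module rather than a real TensorBoard `SummaryWriter`, so MONAI's per-iteration loss ends up in the federated metric stream. Any object with a TensorBoard-style interface should work there; the sketch below is a hypothetical minimal writer (the exact methods MONAI calls may vary by version; `add_scalar` and `flush` are assumed).

```python
# Hypothetical stand-in writer: anything exposing a TensorBoard-like API can be
# handed to TensorBoardStatsHandler(summary_writer=...). Methods assumed, not
# taken from MONAI's documented contract.
class PrintWriter:
    def add_scalar(self, tag, scalar_value, global_step=None, **kwargs):
        print(f"step={global_step} {tag}={scalar_value}")

    def flush(self):
        pass

writer = PrintWriter()
writer.add_scalar("train_loss", 0.42, global_step=1)  # step=1 train_loss=0.42
```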
@@ -37,7 +37,7 @@
# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"
params_transfer_type = "FULL"

# if train_with_evaluation is true, the executor will expect
# the custom code need to send back both the trained parameters and the evaluation metric
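
On the FULL vs. DIFF transfer types described in the comments of this config: with DIFF the client sends back only the change relative to the weights it received. A rough sketch of that subtraction over numpy weight dicts (hypothetical helper, not the NVFlare internals):

```python
import numpy as np

def params_diff(trained, received):
    """Element-wise difference the client would send when transfer type is DIFF."""
    return {k: trained[k] - received[k] for k in trained}

received = {"w": np.array([1.0, 1.0, 1.0])}
trained = {"w": np.array([1.1, 0.8, 1.3])}
print(params_diff(trained, received))  # {'w': array([ 0.1, -0.2,  0.3])}
```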
@@ -25,7 +25,7 @@
# min number of clients required for ScatterAndGather controller to move to the next round
# during the workflow cycle. The controller will wait until the min_clients returned from clients
# before move to the next step.
min_clients = 2
min_clients = 1

# number of global round of the training.
num_rounds = 5
@@ -5,6 +5,6 @@
# change deploy map as needed.
app = ["@ALL"]
}
min_clients = 2
min_clients = 1
mandatory_clients = []
}
