From a927f063cd180780a047e92ae2f03f22e945dcd1 Mon Sep 17 00:00:00 2001
From: Holger Roth
Date: Wed, 1 May 2024 19:24:21 -0400
Subject: [PATCH] enable monai loss tracking and fix eval overwriting model

---
 .../examples/mednist/code/fedavg_monai.py     |  19 ++-
 .../mednist/code/monai_mednist_train.py       | 113 +++++++++++-------
 .../fedavg_mednist/config_fed_client.conf     |   2 +-
 .../fedavg_mednist/config_fed_server.conf     |   2 +-
 .../job_templates/fedavg_mednist/meta.conf    |   2 +-
 5 files changed, 90 insertions(+), 48 deletions(-)

diff --git a/integration/monai/examples/mednist/code/fedavg_monai.py b/integration/monai/examples/mednist/code/fedavg_monai.py
index 09350f9e46..95489359a7 100644
--- a/integration/monai/examples/mednist/code/fedavg_monai.py
+++ b/integration/monai/examples/mednist/code/fedavg_monai.py
@@ -43,6 +43,12 @@ class FedAvgMONAI(BaseFedAvg):
             If n is 0 then no persist.
     """
 
+    def param_sum(self, params):
+        # debug helper: scalar sum over all parameters, used to verify that
+        # the global model actually changes from round to round
+        s = 0
+        for k, v in params.items():
+            s += v.sum()
+        return s
+
     def run(self) -> None:
         self.info("Start FedAvg.")
 
@@ -51,21 +57,26 @@ def run(self) -> None:
         init_weights = {}
         for k, v in monai_model.state_dict().items():
             init_weights[k] = v.cpu().numpy()
-        self.model = FLModel(params_type=ParamsType.FULL, params=init_weights)
+        # keep the global model in a local variable instead of `self.model`
+        # so client-side evaluation cannot overwrite it between rounds
+        model = FLModel(params_type=ParamsType.FULL, params=init_weights)
 
         for self._current_round in range(self._num_rounds):
             self.info(f"Round {self._current_round} started.")
+            model.current_round = self._current_round
 
             clients = self.sample_clients(self._min_clients)
 
-            results = self.send_model_and_wait(targets=clients, data=self.model)
+            print("$$$$$$$$$ Server BEGIN ROUND", model.current_round, self.param_sum(model.params))
+            results = self.send_model_and_wait(targets=clients, data=model)
 
             aggregate_results = self.aggregate(
                 results, aggregate_fn=None
             )  # if no `aggregate_fn` is provided, the default `WeightedAggregationHelper` is used
 
-            self.update_model(aggregate_results)
+            model = aggregate_results
+
+            print("$$$$$$$$$ Server END ROUND", model.current_round, self.param_sum(model.params))
 
-        self.save_model()
+        # self.save_model()  # disabled: `self.model` is no longer updated, so this
+        # would persist the initial weights rather than the aggregated ones
 
         self.info("Finished FedAvg.")
 
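For reference on the `aggregate()` call in the loop above: when no `aggregate_fn` is passed, NVFlare falls back to weighted averaging of the client results (`WeightedAggregationHelper`). The sketch below is illustrative only; it assumes each client result is a (params, num_steps) pair, whereas the real helper reads its weights from the returned FLModel metadata.

    # Illustrative sketch of weighted FedAvg aggregation over numpy param dicts.
    import numpy as np

    def weighted_average(results):
        # results: list of (params_dict, num_steps) tuples, one per client
        total_steps = sum(steps for _, steps in results)
        return {
            key: sum(params[key] * (steps / total_steps) for params, steps in results)
            for key in results[0][0]
        }

    # two clients with unequal amounts of local training
    a = {"w": np.ones(2)}
    b = {"w": np.zeros(2)}
    print(weighted_average([(a, 3), (b, 1)]))  # {'w': array([0.75, 0.75])}
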
diff --git a/integration/monai/examples/mednist/code/monai_mednist_train.py b/integration/monai/examples/mednist/code/monai_mednist_train.py
index e3cd842a88..84b83af5c1 100644
--- a/integration/monai/examples/mednist/code/monai_mednist_train.py
+++ b/integration/monai/examples/mednist/code/monai_mednist_train.py
@@ -32,6 +32,7 @@
 import sys
 import tempfile
 import torch
+from copy import deepcopy
 
 from monai.apps import MedNISTDataset
 from monai.config import print_config
@@ -42,15 +43,60 @@
 from monai.networks import eval_mode
 from monai.networks.nets import densenet121
 from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose
+from monai.handlers import TensorBoardStatsHandler
 
 # (1) import nvflare client API
 import nvflare.client as flare
 
 # (optional) metrics
 from nvflare.client.tracking import SummaryWriter
 
 print_config()
 
+
+def param_sum(params):
+    # debug helper: mirrors the server-side check of the global model weights
+    s = 0
+    for k, v in params.items():
+        s += v.sum()
+    return s
+
+
+# (5) wraps evaluation logic into a function so it can be re-used for
+# evaluation on both the trained and the received model
+def evaluate(DEVICE, transform, root_dir, model, input_weights):
+    model.load_state_dict(input_weights, strict=True)
+
+    # Check the prediction on the test dataset
+    dataset_dir = Path(root_dir, "MedNIST")
+    class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
+    testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", download=False,
+                              runtime_cache=True)
+    correct = 0
+    total = 0
+    max_items_to_print = 10
+    _print = 0
+    with eval_mode(model):
+        for item in DataLoader(testdata, batch_size=512, num_workers=0):  # changed to do batch processing
+            prob = np.array(model(item["image"].to(DEVICE)).detach().to("cpu"))
+            pred = [class_names[p] for p in prob.argmax(axis=1)]
+            gt = item["class_name"]
+            # changed the logic a bit from the tutorial to compute accuracy on the full
+            # test set but only print predictions for the first few items
+            for _gt, _pred in zip(gt, pred):
+                if _print < max_items_to_print:
+                    print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
+                    _print += 1
+
+                # compute accuracy
+                total += 1
+                correct += float(_pred == _gt)
+
+    acc = correct / total  # true division; `correct // total` would floor a fractional accuracy to 0
+    print(f"Accuracy of the network on the {total} test images: {100 * acc} %")
+    return acc
+
+
 def main():
     # (2) initializes NVFlare client API
     flare.init()
@@ -80,7 +126,7 @@ def main():
     # If available, we use GPU to speed things up.
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-    max_epochs = 1  # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each.
+    max_epochs = 10  # number of local epochs to run in each FL round.
 
     model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)
     train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)
@@ -94,63 +140,45 @@ def main():
         optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
         loss_function=torch.nn.CrossEntropyLoss(),
         inferer=SimpleInferer(),
-        train_handlers=StatsHandler(),
     )
 
+    # StatsHandler prints the loss at every iteration and metrics at every epoch.
+    # We don't set metrics for the trainer here, so it only prints the loss; users can
+    # customize the print functions and use output_transform to convert
+    # engine.state.output if it is not the loss value.
+    train_stats_handler = StatsHandler(name="trainer")
+    train_stats_handler.attach(trainer)
+
+    # TensorBoardStatsHandler plots the loss at every iteration and metrics at every
+    # epoch, same as StatsHandler. Passing NVFlare's SummaryWriter streams these values
+    # to the FL server instead of writing a local TensorBoard event file.
+    summary_writer = SummaryWriter()
+    train_tensorboard_stats_handler = TensorBoardStatsHandler(
+        summary_writer=summary_writer
+    )
+    train_tensorboard_stats_handler.attach(trainer)
+
     # (optional) calculate total steps
     steps = max_epochs * len(train_loader)
 
     # Run the training
-    summary_writer = SummaryWriter()
     while flare.is_running():
         # (3) receives FLModel from NVFlare
         input_model = flare.receive()
         print(f"current_round={input_model.current_round}")
+        print("$$$$$$$$$ CLIENT BEGIN ROUND", input_model.current_round, param_sum(input_model.params))
 
         # (4) loads model from NVFlare and sends it to GPU
         trainer.network.load_state_dict(input_model.params, strict=True)
         trainer.network.to(DEVICE)
 
-        trainer.run()
+        # extend the engine's max epochs so the same trainer can run another
+        # `max_epochs` local epochs in this round instead of stopping immediately
+        trainer.state.max_epochs = trainer.state.epoch + max_epochs
+        # get the current iteration when a round starts
+        iter_of_start_time = trainer.state.iteration
 
-        # (5) wraps evaluation logic into a method to re-use for
-        # evaluation on both trained and received model
-        def evaluate(input_weights, x=0):
-            model.load_state_dict(input_weights, strict=True)
-
-            # Check the prediction on the test dataset
-            dataset_dir = Path(root_dir, "MedNIST")
-            class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
-            testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", download=False,
-                                      runtime_cache=True)
-            correct = 0
-            total = 0
-            max_items_to_print = 10
-            _print = 0
-            with eval_mode(model):
-                for item in DataLoader(testdata, batch_size=512, num_workers=0):  # changed to do batch processing
-                    prob = np.array(model(item["image"].to(DEVICE)).detach().to("cpu"))
-                    pred = [class_names[p] for p in prob.argmax(axis=1)]
-                    gt = item["class_name"]
-                    # changed the logic a bit from tutorial to compute accuracy on full test set
-                    # but only print for some.
-                    for _gt, _pred in zip(gt, pred):
-                        if _print < max_items_to_print:
-                            print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
-                            _print += 1
-
-                        # compute accuracy
-                        total += 1
-                        correct += float(_pred == _gt)
-
-            #acc = correct // total
-            acc = 1 - 1/(1 + x)
-            print(f"Accuracy of the network on the {total} test images: {100 * acc} %")
-            return acc
+        trainer.run()
 
         # (6) evaluate on received model for model selection
-        accuracy = evaluate(input_model.params, x=input_model.current_round)
-        summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round)
+        # evaluate on a deep copy so loading the received global weights cannot
+        # overwrite the locally trained model held by the trainer
+        accuracy = evaluate(DEVICE, transform, root_dir, deepcopy(model), input_model.params)
+        summary_writer.add_scalar("global_model_accuracy", accuracy, input_model.current_round)
 
         # (7) construct trained FL model
         output_model = flare.FLModel(
@@ -158,9 +186,12 @@ def evaluate(input_weights, x=0):
             metrics={"accuracy": accuracy},
             meta={"NUM_STEPS_CURRENT_ROUND": steps},
         )
+
+        print("$$$$$$$$$ CLIENT END ROUND", input_model.current_round, param_sum(output_model.params))
 
         # (8) send model back to NVFlare
         flare.send(output_model)
 
+
 if __name__ == "__main__":
     main()
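A note on the trainer change above: an Ignite engine stops once `state.epoch` reaches `state.max_epochs`, so re-running the same MONAI trainer each round requires extending `max_epochs` first; this keeps the engine and optimizer state alive across rounds instead of rebuilding the trainer. Below is a minimal stand-alone sketch of the pattern with a bare `ignite` Engine and a no-op step function; exact resume semantics can differ between ignite versions.

    # Illustrative sketch: continuing one Engine across several FL "rounds".
    from ignite.engine import Engine

    engine = Engine(lambda eng, batch: None)  # no-op processing function
    data = list(range(4))  # stand-in for a DataLoader
    local_epochs = 2

    engine.run(data, max_epochs=local_epochs)  # round 0: a normal run
    print("after round 0:", engine.state.epoch)  # 2

    for round_idx in (1, 2):
        # extend the target instead of restarting, as monai_mednist_train.py does
        engine.state.max_epochs = engine.state.epoch + local_epochs
        engine.run(data)
        print(f"after round {round_idx}:", engine.state.epoch)  # 4, then 6
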
diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf
index c6d7b9d61d..63349fcb10 100644
--- a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf
+++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf
@@ -37,7 +37,7 @@
       # if the transfer_type is FULL, then the parameters are sent directly;
       # if the transfer_type is DIFF, then we calculate the difference vs. the
       # received parameters and send only that difference
-      params_transfer_type = "DIFF"
+      params_transfer_type = "FULL"
 
       # if train_with_evaluation is true, the executor will expect
       # the custom code to send back both the trained parameters and the evaluation metric
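On the FULL vs. DIFF transfer types described in the comment above: with DIFF, the client sends per-parameter deltas and the server adds them back onto the parameters it sent out. A small illustrative sketch of the arithmetic with plain numpy dicts (not NVFlare's actual executor code):

    # Illustrative sketch of the DIFF params_transfer_type arithmetic.
    import numpy as np

    received = {"w": np.array([1.0, 2.0])}  # global params sent to the client
    trained = {"w": np.array([1.5, 1.0])}   # params after local training

    diff = {k: trained[k] - received[k] for k in trained}    # client sends this
    restored = {k: received[k] + diff[k] for k in received}  # server reconstructs
    assert all(np.allclose(restored[k], trained[k]) for k in trained)
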
diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf
index a34a70b81f..1841cb76ec 100644
--- a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf
+++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf
@@ -25,7 +25,7 @@
       # min number of clients required for the ScatterAndGather controller to move to
       # the next round during the workflow cycle. The controller waits until results
       # from at least min_clients are returned before moving to the next step.
-      min_clients = 2
+      min_clients = 1
 
       # number of global rounds of training.
       num_rounds = 5
diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf
index 43a7e6a2a3..d8fa9d702b 100644
--- a/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf
+++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf
@@ -5,6 +5,6 @@
     # change deploy map as needed.
     app = ["@ALL"]
   }
-  min_clients = 2
+  min_clients = 1
   mandatory_clients = []
 }
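The loss-tracking part of this patch relies on MONAI's `TensorBoardStatsHandler` being duck-typed: it only needs a writer object with a TensorBoard-style `add_scalar` method, which is why NVFlare's `SummaryWriter`, which streams metrics to the FL server, can be passed in via `summary_writer=`. A minimal sketch of such a stand-in writer follows; the print call is a placeholder for the actual metric streaming.

    # Illustrative sketch: any add_scalar-compatible writer can back
    # TensorBoardStatsHandler instead of torch.utils.tensorboard.SummaryWriter.
    class ForwardingWriter:
        def add_scalar(self, tag, scalar_value, global_step=None, **kwargs):
            # NVFlare's SummaryWriter would stream this to the server;
            # printing just shows the call pattern.
            print(f"scalar {tag}={scalar_value} @ step {global_step}")

        def flush(self):  # some handlers flush/close the writer
            pass

    writer = ForwardingWriter()
    writer.add_scalar("train_loss", 0.42, global_step=100)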