From e98a21c4086416cac415ff580ae877445692f86d Mon Sep 17 00:00:00 2001
From: Holger Roth <6304754+holgerroth@users.noreply.github.com>
Date: Thu, 13 Feb 2025 17:45:16 -0500
Subject: [PATCH] Enhance lightning api (#3225)

Fixes # .

### Description

Enhance the lightning client API to warn when there are unexpected or
missing keys while loading back the global state dictionary. Add an
`update_fit_loop` argument that can skip updating the fit loop, to
support customized lightning trainers such as NeMo & BioNeMo. A usage
sketch of the new argument follows the diff below.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
---
 nvflare/app_opt/lightning/api.py | 38 +++++++++++++++++++++++++-------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py
index 4e674e5915..45629a6b42 100644
--- a/nvflare/app_opt/lightning/api.py
+++ b/nvflare/app_opt/lightning/api.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Dict
 
 import pytorch_lightning as pl
@@ -29,7 +30,9 @@
 FL_META_KEY = "__fl_meta__"
 
 
-def patch(trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True):
+def patch(
+    trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True, update_fit_loop: bool = True
+):
     """Patches the PyTorch Lightning Trainer for usage with NVFlare.
 
     Args:
@@ -39,6 +42,8 @@ def patch(trainer: pl.Trainer, restore_state: bool = True, load_state_dict_stric
         load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()`
             used to load the received model. Defaults to `True`.
             See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
+        update_fit_loop: whether to increase `trainer.fit_loop.max_epochs` and `trainer.fit_loop.epoch_loop.max_steps` each FL round.
+            Defaults to `True` which is suitable for most PyTorch Lightning applications.
 
     Example:
 
@@ -75,7 +80,9 @@ def __init__(self):
         callbacks = []
 
     if not any(isinstance(cb, FLCallback) for cb in callbacks):
-        fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
+        fl_callback = FLCallback(
+            rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict, update_fit_loop=update_fit_loop
+        )
         callbacks.append(fl_callback)
 
     if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks):
@@ -85,7 +92,7 @@ def __init__(self):
 
 
 class FLCallback(Callback):
-    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
+    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True, update_fit_loop: bool = True):
         """FL callback for lightning API.
 
         Args:
@@ -93,6 +100,8 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
             load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()`
                 used to load the received model. Defaults to `True`.
                 See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
+            update_fit_loop: whether to increase `trainer.fit_loop.max_epochs` and `trainer.fit_loop.epoch_loop.max_steps` each FL round.
+                Defaults to `True` which is suitable for most PyTorch Lightning applications.
         """
         super(FLCallback, self).__init__()
         init(rank=str(rank))
@@ -108,6 +117,9 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
         self._is_evaluation = False
         self._is_submit_model = False
         self._load_state_dict_strict = load_state_dict_strict
+        self._update_fit_loop = update_fit_loop
+
+        self.logger = logging.getLogger(self.__class__.__name__)
 
     def reset_state(self, trainer):
         """Resets the state.
@@ -130,10 +142,12 @@ def reset_state(self, trainer):
 
             # for next round
             trainer.num_sanity_val_steps = 0  # Turn off sanity validation steps in following rounds of FL
-            if self.total_local_epochs and self.max_epochs_per_round is not None:
-                trainer.fit_loop.max_epochs = self.max_epochs_per_round + self.total_local_epochs
-            if self.total_local_steps and self.max_steps_per_round is not None:
-                trainer.fit_loop.epoch_loop.max_steps = self.max_steps_per_round + self.total_local_steps
+
+            if self._update_fit_loop:
+                if self.total_local_epochs and self.max_epochs_per_round is not None:
+                    trainer.fit_loop.max_epochs = self.max_epochs_per_round + self.total_local_epochs
+                if self.total_local_steps and self.max_steps_per_round is not None:
+                    trainer.fit_loop.epoch_loop.max_steps = self.max_steps_per_round + self.total_local_steps
 
         # resets attributes
         self.metrics = None
@@ -184,7 +198,15 @@ def _receive_and_update_model(self, trainer, pl_module):
         model = self._receive_model(trainer)
         if model:
             if model.params:
-                pl_module.load_state_dict(model.params, strict=self._load_state_dict_strict)
+                missing_keys, unexpected_keys = pl_module.load_state_dict(
+                    model.params, strict=self._load_state_dict_strict
+                )
+                if len(missing_keys) > 0:
+                    self.logger.warning(f"There were missing keys when loading the global state_dict: {missing_keys}")
+                if len(unexpected_keys) > 0:
+                    self.logger.warning(
+                        f"There were unexpected keys when loading the global state_dict: {unexpected_keys}"
+                    )
             if model.current_round is not None:
                 self.current_round = model.current_round
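
For reviewers, a minimal usage sketch of the new `update_fit_loop` flag (not part of the patch itself). It assumes the `nvflare.client.lightning` import path and `flare.is_running()` round loop used in NVFlare's client examples; `LitRegressor` and the random dataset are hypothetical stand-ins for a real application.

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Assumed import path, as used in NVFlare's client API examples.
import nvflare.client.lightning as flare


class LitRegressor(pl.LightningModule):
    """Hypothetical stand-in for a real LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    train_loader = DataLoader(dataset, batch_size=16)

    model = LitRegressor()
    trainer = pl.Trainer(max_epochs=1)

    # update_fit_loop=False leaves trainer.fit_loop.max_epochs and
    # .epoch_loop.max_steps untouched each round, for customized trainers
    # (e.g., NeMo/BioNeMo) that manage their own fit loop.
    # load_state_dict_strict=False means key mismatches are logged as
    # warnings by this patch instead of raising an error.
    flare.patch(trainer, update_fit_loop=False, load_state_dict_strict=False)

    while flare.is_running():
        # Each FL round: receive the global weights, train locally, send the update.
        trainer.fit(model, train_dataloaders=train_loader)
```

Note that with `update_fit_loop=False`, responsibility for permitting further epochs/steps in later rounds moves from `FLCallback.reset_state` to the customized trainer, which is the reason the flag exists for NeMo/BioNeMo-style trainers.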