Enhance lightning api (NVIDIA#3225)
Fixes # .

### Description

Enhance the Lightning client API to warn when unexpected or missing keys are
encountered while loading the global state dictionary back into the local module.
Add an `update_fit_loop` option that skips the per-round fit-loop updates, to
support customized Lightning trainers such as NeMo and BioNeMo (see the usage
sketch below).

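A hypothetical usage sketch of the new option (import path taken from the changed file `nvflare/app_opt/lightning/api.py`; the trainer setup is illustrative, and the snippet assumes it runs inside an NVFlare client job):

```python
import pytorch_lightning as pl

from nvflare.app_opt.lightning.api import patch

trainer = pl.Trainer(max_epochs=1)

# Default behavior (update_fit_loop=True): the FLCallback extends
# fit_loop.max_epochs / epoch_loop.max_steps at each FL round.
# For customized trainers (e.g., NeMo / BioNeMo) that manage their own
# fit loop, pass update_fit_loop=False to leave those limits untouched.
patch(trainer, update_fit_loop=False)
```
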
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
holgerroth authored Feb 13, 2025
1 parent 4efdf78 commit e98a21c
Showing 1 changed file with 30 additions and 8 deletions.
```diff
diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py
--- a/nvflare/app_opt/lightning/api.py
+++ b/nvflare/app_opt/lightning/api.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Dict
 
 import pytorch_lightning as pl
@@ -29,7 +30,9 @@
 FL_META_KEY = "__fl_meta__"
 
 
-def patch(trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True):
+def patch(
+    trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True, update_fit_loop: bool = True
+):
     """Patches the PyTorch Lightning Trainer for usage with NVFlare.
 
     Args:
@@ -39,6 +42,8 @@ def patch(trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True):
         load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()`
             used to load the received model. Defaults to `True`.
             See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
+        update_fit_loop: whether to increase `trainer.fit_loop.max_epochs` and `trainer.fit_loop.epoch_loop.max_steps` each FL round.
+            Defaults to `True` which is suitable for most PyTorch Lightning applications.
 
     Example:
 
@@ -75,7 +80,9 @@ def __init__(self):
         callbacks = []
 
     if not any(isinstance(cb, FLCallback) for cb in callbacks):
-        fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
+        fl_callback = FLCallback(
+            rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict, update_fit_loop=update_fit_loop
+        )
         callbacks.append(fl_callback)
 
     if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks):
@@ -85,14 +92,16 @@
 
 
 class FLCallback(Callback):
-    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
+    def __init__(self, rank: int = 0, load_state_dict_strict: bool = True, update_fit_loop: bool = True):
         """FL callback for lightning API.
 
         Args:
             rank: global rank of the PyTorch Lightning trainer.
             load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()`
                 used to load the received model. Defaults to `True`.
                 See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details.
+            update_fit_loop: whether to increase `trainer.fit_loop.max_epochs` and `trainer.fit_loop.epoch_loop.max_steps` each FL round.
+                Defaults to `True` which is suitable for most PyTorch Lightning applications.
         """
         super(FLCallback, self).__init__()
         init(rank=str(rank))
@@ -108,6 +117,9 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
         self._is_evaluation = False
         self._is_submit_model = False
         self._load_state_dict_strict = load_state_dict_strict
+        self._update_fit_loop = update_fit_loop
+
+        self.logger = logging.getLogger(self.__class__.__name__)
 
     def reset_state(self, trainer):
         """Resets the state.
@@ -130,10 +142,12 @@ def reset_state(self, trainer):
 
         # for next round
         trainer.num_sanity_val_steps = 0  # Turn off sanity validation steps in following rounds of FL
-        if self.total_local_epochs and self.max_epochs_per_round is not None:
-            trainer.fit_loop.max_epochs = self.max_epochs_per_round + self.total_local_epochs
-        if self.total_local_steps and self.max_steps_per_round is not None:
-            trainer.fit_loop.epoch_loop.max_steps = self.max_steps_per_round + self.total_local_steps
+
+        if self._update_fit_loop:
+            if self.total_local_epochs and self.max_epochs_per_round is not None:
+                trainer.fit_loop.max_epochs = self.max_epochs_per_round + self.total_local_epochs
+            if self.total_local_steps and self.max_steps_per_round is not None:
+                trainer.fit_loop.epoch_loop.max_steps = self.max_steps_per_round + self.total_local_steps
 
         # resets attributes
         self.metrics = None
@@ -184,7 +198,15 @@ def _receive_and_update_model(self, trainer, pl_module):
         model = self._receive_model(trainer)
         if model:
             if model.params:
-                pl_module.load_state_dict(model.params, strict=self._load_state_dict_strict)
+                missing_keys, unexpected_keys = pl_module.load_state_dict(
+                    model.params, strict=self._load_state_dict_strict
+                )
+                if len(missing_keys) > 0:
+                    self.logger.warning(f"There were missing keys when loading the global state_dict: {missing_keys}")
+                if len(unexpected_keys) > 0:
+                    self.logger.warning(
+                        f"There were unexpected keys when loading the global state_dict: {unexpected_keys}"
+                    )
             if model.current_round is not None:
                 self.current_round = model.current_round
```
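For context on the new warnings: `torch.nn.Module.load_state_dict` returns a named tuple of `missing_keys` and `unexpected_keys`, which is what the callback now logs. The lists are only non-empty when loading with `strict=False`, since `strict=True` raises on a mismatch. A minimal standalone sketch with a toy `torch.nn.Linear` model and hypothetical key names:

```python
import logging

import torch

logger = logging.getLogger("FLCallback")

model = torch.nn.Linear(4, 2)  # real state_dict keys: "weight", "bias"

# Global params with one renamed key: "weight" becomes a missing key,
# "weight_renamed" becomes an unexpected key.
params = {"weight_renamed": torch.zeros(2, 4), "bias": torch.zeros(2)}

# With strict=False, load_state_dict reports mismatches instead of raising.
missing_keys, unexpected_keys = model.load_state_dict(params, strict=False)
if len(missing_keys) > 0:
    logger.warning(f"There were missing keys when loading the global state_dict: {missing_keys}")
if len(unexpected_keys) > 0:
    logger.warning(f"There were unexpected keys when loading the global state_dict: {unexpected_keys}")
```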

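And a rough, standalone sketch of the fit-loop bookkeeping that `update_fit_loop=True` performs (simplified: in the real callback the epoch/step counts are assumed to come from the trainer's own counters; the values here are illustrative):

```python
# Simplified sketch: how fit_loop.max_epochs must grow so trainer.fit()
# can run again each FL round.
max_epochs_per_round = 1  # e.g., pl.Trainer(max_epochs=1)
total_local_epochs = 0    # epochs completed in previous rounds

for fl_round in range(3):
    # Lightning stops fitting once current_epoch reaches fit_loop.max_epochs,
    # so the ceiling is raised to the epochs already run plus one round's budget.
    fit_loop_max_epochs = max_epochs_per_round + total_local_epochs
    print(f"round {fl_round}: fit_loop.max_epochs = {fit_loop_max_epochs}")
    total_local_epochs += max_epochs_per_round  # local training ran this round
```

With `update_fit_loop=False`, none of these limits are touched, which is what NeMo-style trainers that manage their own loop limits need.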