From 238ba1f38ea3ac107374d58b33a1ae6c63dcc2ed Mon Sep 17 00:00:00 2001
From: traincheck-team
Date: Sun, 15 Dec 2024 19:11:41 -0500
Subject: [PATCH] fix: forbid repeated deepspeed.initialize on training objects

---
 deepspeed/__init__.py                    | 51 ++++++++++++++++++++++++
 tests/unit/runtime/test_ds_initialize.py | 39 ++++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index a8d15cd5332b..6bc5642ec8ef 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -66,6 +66,50 @@ def _parse_version(version_str):
 dist = None
 
 
+def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+    """Mark a trainobj as initialized by setting its ds_is_inited attribute to True."""
+    # we shouldn't hit the assert below, but just in case
+    assert not hasattr(
+        trainobj, 'ds_is_inited'
+    ), "Training object has already been initialized, please make sure to only call deepspeed.initialize on it once."
+    trainobj.ds_is_inited = True
+
+
+def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+    """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
+    if hasattr(trainobj, 'ds_is_inited'):
+        # we shouldn't hit the assert below, but just in case
+        assert trainobj.ds_is_inited, "`ds_is_inited` should never be False once set; make sure you didn't set it to False or call deepspeed.initialize on the object more than once."
+        return True
+    return False
+
+
+def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
+                                 lr_scheduler: Optional[_LRScheduler]):
+    """Enforce that the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
+    if _is_initialized(model):
+        raise ValueError(
+            "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
+    if optimizer is not None and _is_initialized(optimizer):
+        raise ValueError(
+            "Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
+        )
+    if lr_scheduler is not None and _is_initialized(lr_scheduler):
+        raise ValueError(
+            "LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
+        )
+
+
+def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
+                                lr_scheduler: Optional[_LRScheduler]):
+    """Mark the model, optimizer, and lr_scheduler as initialized."""
+    _mark_initialized(model)
+    if optimizer is not None:
+        _mark_initialized(optimizer)
+    if lr_scheduler is not None:
+        _mark_initialized(lr_scheduler)
+
+
 def initialize(args=None,
                model: torch.nn.Module = None,
                optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
@@ -137,6 +181,10 @@ def initialize(args=None,
             zero.partition_parameters.shutdown_init_context()
 
     assert model is not None, "deepspeed.initialize requires a model"
+    # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
+    _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
+    # mark model, optimizer, and lr_scheduler as initialized
+    _mark_trainobjs_initialized(model, optimizer, lr_scheduler)
 
     global dist
     from deepspeed import comm as dist
@@ -221,6 +269,9 @@ def initialize(args=None,
     # Restore zero.Init context if necessary
     zero.partition_parameters.restore_init_context()
 
+    # mark engine, optimizer, and lr_scheduler as initialized
+    _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
+
     return_items = [
         engine,
         engine.optimizer,
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index a30f81cedde9..2c9ad701bfff 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -21,6 +21,7 @@
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import FusedAdamBuilder
+from deepspeed import _assert_trainobjs_not_inited, _is_initialized
 
 
 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -434,3 +435,41 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
     else:
         # callable
         assert isinstance(ds_lr_scheduler, OneCycleLR)
+
+
+# https://github.com/microsoft/DeepSpeed/issues/6770
+class TestNoRepeatedInitializationAllowed(DistributedTest):
+    world_size = 1
+
+    def test_no_repeated_init(self):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+        model = SimpleModel(hidden_dim)
+        # minimal DeepSpeed configuration for this test
+        config_dict = {'train_batch_size': 1}
+
+        client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
+        # none of the training objects should be marked as initialized yet
+        _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
+        model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model,
+                                                                          optimizer=client_optimizer,
+                                                                          config_params=config_dict)
+
+        # arguments should be marked as initialized now
+        assert _is_initialized(model), "Client model should be marked as initialized"
+        assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"
+
+        # return values should also be marked as initialized
+        assert _is_initialized(model_engine), "Model engine should be marked as initialized"
+        assert _is_initialized(optim), "Optimizer should be marked as initialized"
+        assert _is_initialized(scheduler), "Scheduler should be marked as initialized"
+
+        exception_raised = False
+        try:
+            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
+        except ValueError:
+            exception_raised = True
+
+        assert exception_raised, "Repeated initialization should raise an exception"
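
For context, a minimal usage sketch (not part of the patch) of the behavior the change enforces: a second deepspeed.initialize call on the same model or optimizer now raises ValueError. The toy model, Adam optimizer, and config values below are illustrative assumptions, and a DeepSpeed-capable launch environment is assumed.

# Illustrative sketch only; model, optimizer, and config are assumptions, not part of the patch.
import torch
import deepspeed

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
config = {'train_batch_size': 1}

# First call succeeds and tags model/optimizer with ds_is_inited = True.
engine, ds_optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)

# Second call on the same objects is now rejected.
try:
    deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)
except ValueError as err:
    print(f"repeated deepspeed.initialize rejected: {err}")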