Skip to content

Commit

Permalink
handle callable types in init mark
Browse files: browse the repository at this point in the history
  • Loading branch information
traincheck-team committed Dec 19, 2024
1 parent d1e7777 commit 62067cc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
45 changes: 21 additions & 24 deletions deepspeed/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -66,48 +66,45 @@ def _parse_version(version_str):
dist = None


def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
if hasattr(trainobj, 'ds_is_inited'):
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
return

trainobj.ds_is_inited = True


def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
if hasattr(trainobj, 'ds_is_inited'):
# we shouldn't hit the assert below, but just in case
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
return True
return False
return getattr(trainobj, 'ds_is_inited', False)


def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
lr_scheduler: Optional[_LRScheduler]):
def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
DeepSpeedOptimizerCallable]],
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
"""Enforce the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
if _is_initialized(model):
if _is_ds_initialized(model):
raise ValueError(
"Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
if optimizer is not None and _is_initialized(optimizer):
if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
raise ValueError(
"Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
)
if lr_scheduler is not None and _is_initialized(lr_scheduler):
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
raise ValueError(
"LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
)


def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
lr_scheduler: Optional[_LRScheduler]):
"""Mark the model, optimizer, and lr_scheduler as initialized."""
_mark_initialized(model)
if optimizer is not None:
_mark_initialized(optimizer)
if lr_scheduler is not None:
_mark_initialized(lr_scheduler)
def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
DeepSpeedOptimizerCallable]],
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
"""Mark the model, optimizer, and lr_scheduler as initialized.
Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked
as they are not stateful and reuse should be permissible.
"""
_mark_ds_initialized(model)
if optimizer is not None and isinstance(optimizer, Optimizer):
_mark_ds_initialized(optimizer)
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
_mark_ds_initialized(lr_scheduler)


def initialize(args=None,
Expand Down
31 changes: 21 additions & 10 deletions tests/unit/runtime/test_ds_initialize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,7 +21,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder
from deepspeed import _assert_trainobjs_not_inited, _is_initialized
from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized


@pytest.mark.parametrize('zero_stage', [0, 3])
Expand Down Expand Up @@ -441,12 +441,22 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
class TestNoRepeatedInitializationAllowed(DistributedTest):
world_size = 1

def test_no_repeated_init(self):
@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
def test(self, optimizer_type):
hidden_dim = 10
model = SimpleModel(hidden_dim)
client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Initialize DeepSpeed configurations for fp16

def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)

config_dict = {'train_batch_size': 1}
if optimizer_type is None:
client_optimizer = None
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = Adam(model.parameters())
else:
client_optimizer = _optimizer_callable

# Initialize DeepSpeed engine
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
Expand All @@ -455,12 +465,13 @@ def test_no_repeated_init(self):
config_params=config_dict)

# arguments should be marked as initialized now
assert _is_initialized(model), "Client model should be marked as initialized"
assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"
assert _is_ds_initialized(model), "Client model should be marked as initialized"
if optimizer_type is Optimizer:
assert _is_ds_initialized(client_optimizer), "Client optimizer should be marked as initialized"

# return values should also be marked as initialized
assert _is_initialized(model_engine), "Model engine should be marked as initialized"
assert _is_initialized(optim), "Optimizer should be marked as initialized"
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

exception_raised = False
try:
Expand All @@ -480,15 +491,15 @@ def test_no_repeated_init(self):

exception_raised = False
try:
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Initialization on ds types should raise an exception"

exception_raised = False
try:
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
except ValueError:
exception_raised = True
assert exception_raised, "Initialization on ds types should raise an exception"

0 comments on commit 62067cc

Please sign in to comment.