Use ds-specific module id to avoid conflicts
tjruwase committed Dec 10, 2024
1 parent 06f1d36 commit 82cacfc
Showing 3 changed files with 59 additions and 23 deletions.
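
For context, a minimal, hypothetical sketch (not the DeepSpeed implementation) of the conflict this commit resolves: ZeRO-3 hook registration previously stored its module-traversal counter on module.id, clobbering any id attribute a client model had defined for its own purposes; the change keeps the counter under a DeepSpeed-specific module.ds_id instead.

# Illustrative only; names mirror the diff below, but this is not DeepSpeed code.
import torch

class ClientModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.id = 3  # client-owned attribute, e.g. used for device placement
        self.fc = torch.nn.Linear(8, 8)

model = ClientModel()

# Old behavior: the hook walker wrote its counter to module.id, losing the client value.
model.id = 0
assert model.id != 3

# New behavior: the counter lives under a DeepSpeed-specific name, so the client value survives.
model = ClientModel()
model.ds_id = 0
assert model.id == 3 and model.ds_id == 0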
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -240,7 +240,7 @@ def _start_of_forward_hook(module, *args):
self.module.register_forward_pre_hook(_start_of_forward_hook)

#likely one of them should be enough but just to be safe
- self._register_hooks_recursively(self.module)
+ self._register_deepspeed_module(self.module)

# Add top module to stack trace
global FWD_MODULE_STACK
@@ -266,19 +266,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):

return persistent_params

- def _register_hooks_recursively(self, module, count=[0]):
+ def _register_deepspeed_module(self, module, count=[0]):
my_count = count[0]
- module.id = my_count
+ module.ds_id = my_count

#print(f"{module.__class__} : {module.id}")
#print(f"{module.__class__} : {module.ds_id}")

if z3_leaf_module(module):
for param in module.parameters():
param.ds_z3_leaf_module = module
else:
for child in module.children():
count[0] = count[0] + 1
- self._register_hooks_recursively(child, count=count)
+ self._register_deepspeed_module(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
@@ -463,14 +463,16 @@ def pre_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
@@ -485,13 +487,13 @@ def pre_sub_module_backward_function(self, sub_module):
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

self.get_param_coordinator().release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
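
As a rough, standalone sketch (assuming plain torch.nn.Module children and no ZeRO-3 leaf modules), the renamed _register_deepspeed_module numbers submodules depth-first by threading a one-element count list through the recursion and storing each value on module.ds_id:

# Simplified numbering sketch; the real method also registers hooks and
# handles z3_leaf_module specially.
import torch

def assign_ds_ids(module, count=[0]):
    module.ds_id = count[0]
    for child in module.children():
        count[0] += 1
        assign_ds_ids(child, count=count)

net = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()),
)
assign_ds_ids(net, count=[0])
print([(type(m).__name__, m.ds_id) for m in net.modules()])
# [('Sequential', 0), ('Linear', 1), ('Sequential', 2), ('Linear', 3), ('ReLU', 4)]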
24 changes: 12 additions & 12 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -172,17 +172,17 @@ def trace_prologue(self, sub_module: Module) -> None:
# sub_module must match expectation else invalidate trace cache
if len(self.__submodule_order) <= self.__step_id:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
f"cache has only {len(self.__submodule_order)} modules",
force=True)
self._invalidate_trace()
return

if sub_module != self.__submodule_order[self.__step_id]:
- expected_module_id = self.__submodule_order[self.__step_id].id
+ expected_module_id = self.__submodule_order[self.__step_id].ds_id
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}",
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
force=True)
self._invalidate_trace()

@@ -196,7 +196,7 @@ def record_module(self, sub_module: Module) -> None:
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

self.__submodule_order.append(sub_module)
- self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
+ self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id)

def record_parameters(self, sub_module: Module) -> None:
if is_compiling():
@@ -205,7 +205,7 @@ def record_parameters(self, sub_module: Module) -> None:
if not self.is_record_trace():
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

- step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
+ step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft()
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))

@@ -225,7 +225,7 @@ def reset_step(self) -> None:

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
- assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
+ assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order])

if self.is_record_trace():
# Successfully recorded a trace
@@ -238,7 +238,7 @@ def reset_step(self) -> None:
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = ZeRoTraceMode.COMPLETE
print_rank_0(
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
@@ -281,7 +281,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
@@ -294,7 +294,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:

if fetch_numel > 0:
event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
- self._dump_param_ids(event_name, current_submodule.id,
+ self._dump_param_ids(event_name, current_submodule.ds_id,
[p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
self.__profiler.start_event(event_name)
# kick off all gather for params in the immediately required submodule
@@ -310,7 +310,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
self.__profiler.start_event(wait_event_name)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
- param.ds_active_sub_modules.add(current_submodule.id)
+ param.ds_active_sub_modules.add(current_submodule.ds_id)
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
@@ -352,7 +352,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if discarded_from_prefetch_queue != params_not_already_fetched:
raise RuntimeError(
f"tracing error at step {self.__step_id}: \n"
f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n"
f"expected the next {len(params_not_already_fetched)} parameters in the "
f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")
@@ -413,7 +413,7 @@ def release_sub_module(self, submodule: Module) -> None:
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
- param.ds_active_sub_modules.discard(submodule.id)
+ param.ds_active_sub_modules.discard(submodule.ds_id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param)

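
For readers unfamiliar with the trace cache touched above, a simplified sketch (not the coordinator's actual bookkeeping) of why each module needs a stable, DeepSpeed-owned ds_id: the coordinator records submodule execution order on the first pass and, on later passes, checks each incoming module against the recorded one, reporting both ids when it must invalidate the cache.

# Toy version of the check in trace_prologue; assumes ds_id was assigned during hook registration.
def trace_matches(recorded_order, step_id, sub_module):
    if len(recorded_order) <= step_id:
        print(f"Invalidate trace cache @ step {step_id} and module {sub_module.ds_id}: "
              f"cache has only {len(recorded_order)} modules")
        return False
    expected = recorded_order[step_id]
    if sub_module is not expected:
        print(f"Invalidate trace cache @ step {step_id}: "
              f"expected module {expected.ds_id}, but got module {sub_module.ds_id}")
        return False
    return True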
34 changes: 34 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -1673,3 +1673,37 @@ def test(self, prefetch_ratio, zero_stage=3):
with torch.no_grad():
for batch in data_loader:
loss = model(batch[0], batch[1])


# Avoid overwriting client module id
# https://github.com/microsoft/DeepSpeed/issues/6772
class TestZero3ClientModuleID(DistributedTest):
    world_size = 2

    def test_client_module_id(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
            },
            "zero_optimization": {
                "stage": 3
            },
        }

        class MyModel(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.id = 3  # ID arbitrary client usage, e.g. GPU placement
                self.fc = Linear(128, 128)

            def forward(self, x):
                return self.fc(x)

        model = MyModel()
        pre_init_m_id = model.id
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        post_init_m_id = model.id
        assert pre_init_m_id == post_init_m_id
