Skip to content

Commit

Permalink
support torch dynamo for deepspeed>=0.14.4 (#3069)
Browse files Browse the repository at this point in the history
* compile after deepspeed 0.14.4

* fix

* fmt

* add test
  • Loading branch information
oraluben authored Oct 10, 2024
1 parent f1f2b4d commit cba3f2d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,9 @@ def _prepare_deepspeed(self, *args):
kwargs["lr_scheduler"] = scheduler

engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs)
if compare_versions("deepspeed", ">=", "0.14.4") and self.state.dynamo_plugin.backend != DynamoBackend.NO:
compile_kwargs = self.state.dynamo_plugin.to_kwargs()
engine.compile(backend=compile_kwargs.pop("backend"), compile_kwargs=compile_kwargs)
if optimizer is not None:
optimizer = DeepSpeedOptimizerWrapper(optimizer)
if scheduler is not None:
Expand Down
37 changes: 37 additions & 0 deletions tests/deepspeed/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,43 @@ def test_basic_run(self):
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd)

def test_basic_dynamo_run(self):
    """Launch a single-process DeepSpeed training run with a dynamo backend and
    verify that torch dynamo was actually activated (via deepspeed's engine.compile).

    The run itself is allowed to fail; the only thing this test pins down is
    whether `torch._dynamo` shows up in the subprocess logs, which proves that
    DeepSpeed routed the model through torch.compile.
    """
    test_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_performance.py")
    with tempfile.TemporaryDirectory() as dirpath:
        # Single-GPU deepspeed launch, no offloading, with an eager dynamo backend
        # so compilation is cheap but still goes through torch._dynamo.
        cmd = [
            "accelerate",
            "launch",
            "--num_processes=1",
            "--num_machines=1",
            "--machine_rank=0",
            "--mixed_precision=fp16",
            "--use_deepspeed",
            "--gradient_accumulation_steps=1",
            "--offload_optimizer_device=none",
            "--offload_param_device=none",
            test_file_path,
            "--dynamo_backend=eager",
            "--model_name_or_path=distilbert-base-uncased",
            "--num_epochs=1",
            f"--output_dir={dirpath}",
        ]
        with patch_environment(omp_num_threads=1):
            dynamo_was_used = False
            # TORCH_LOGS=dynamo makes torch emit "torch._dynamo" log lines,
            # which we grep for in the subprocess output below.
            subprocess_env = {"TORCH_LOGS": "dynamo", **os.environ}
            try:
                completed = execute_subprocess_async(cmd, env=subprocess_env)
                dynamo_was_used = "torch._dynamo" in "\n".join(completed.stderr)
            except RuntimeError as run_error:
                # It's possible that the run fail, but we focus on if dynamo is enabled via deepspeed.
                dynamo_was_used = "torch._dynamo" in run_error.args[0]
            finally:
                # Asserting in `finally` keeps the dynamo check active even when
                # an unexpected exception type escapes the try/except above.
                assert dynamo_was_used


@require_deepspeed
@require_multi_device
Expand Down

0 comments on commit cba3f2d

Please sign in to comment.