From 4d9e98fd854e4f2d8a10d78529ab5a8d8591f038 Mon Sep 17 00:00:00 2001
From: Nir Sonnenschein
Date: Wed, 28 Aug 2024 14:53:52 +0300
Subject: [PATCH 1/3] Change compile for pipeline module torch.compile

We have encountered an issue with torch.compile and the pipeline module:
modifying a member of the module during the run causes torch.compile to
restart its analysis and treat the module as dynamic. This happens because
the forward function modifies the micro_offset attribute of the pipeline
module. To bypass this issue without significantly changing how the
pipeline module works, we propose compiling only the layers of the
pipeline module instead of the pipeline module itself. This avoids the
problem while still giving most of the benefit of torch.compiling the
pipeline module.
---
 deepspeed/runtime/pipe/module.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 31fec30be788..730501a59cd2 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -662,3 +662,8 @@ def get_additional_losses(self):
         Return a dictionary of {"loss name": loss_value} or None if no additional losses.
         """
         return None
+
+    def compile(self, *args, **kwargs):
+        for idx, layer in enumerate(self.forward_funcs):
+            new_layer = torch.compile(layer, *args, **kwargs)
+            self.forward_funcs[idx] = new_layer

From 441a3280093e3e8a793c55f1d6ebd6050aa0c163 Mon Sep 17 00:00:00 2001
From: Nir Sonnenschein
Date: Wed, 4 Dec 2024 13:16:33 +0200
Subject: [PATCH 2/3] add additional fix and unit test

---
 deepspeed/runtime/pipe/module.py    | 7 +++++--
 tests/unit/pipe/test_pipe_module.py | 6 ++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 730501a59cd2..9fbd91f750a9 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -665,5 +665,8 @@ def get_additional_losses(self):
 
     def compile(self, *args, **kwargs):
         for idx, layer in enumerate(self.forward_funcs):
-            new_layer = torch.compile(layer, *args, **kwargs)
-            self.forward_funcs[idx] = new_layer
+            if isinstance(layer, nn.Module):
+                layer.compile(*args, **kwargs)
+            else:
+                new_layer = torch.compile(layer, *args, **kwargs)
+                self.forward_funcs[idx] = new_layer

diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py
index 05c6a82ef55a..f9b89780a115 100644
--- a/tests/unit/pipe/test_pipe_module.py
+++ b/tests/unit/pipe/test_pipe_module.py
@@ -62,7 +62,8 @@ class TestPipeModuleSequential(DistributedTest):
     world_size = 2
 
     @pytest.mark.parametrize("activation_checkpoints", [False, True])
-    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
         base_model = copy.deepcopy(sequential_model)
         base_input = batch_input.clone().detach()
         base_output = base_model(base_input)
@@ -71,7 +72,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi
 
         pipe_model = copy.deepcopy(sequential_model)
         pipe_model = PipelineModule(layers=pipe_model, num_stages=2)
-
+        if (use_compile):
+            pipe_model.compile()
         # Ensure all parameters are accounted for.
         my_params = sum(p.numel() for p in pipe_model.parameters())
         total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())

From 40fb3c4ed5a9ffcf8eb02d479b4eb2561743d0ae Mon Sep 17 00:00:00 2001
From: Nir Sonnenschein
Date: Mon, 9 Dec 2024 14:08:08 +0200
Subject: [PATCH 3/3] make test process non-daemonic

Running torch.compile in a daemonic process causes an error, because the
inductor implementation can spawn subprocesses.
---
 tests/unit/pipe/test_pipe_module.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py
index f9b89780a115..2a8a4b9b7d82 100644
--- a/tests/unit/pipe/test_pipe_module.py
+++ b/tests/unit/pipe/test_pipe_module.py
@@ -60,6 +60,8 @@ def batch_input():
 
 class TestPipeModuleSequential(DistributedTest):
     world_size = 2
+    # needs to be set for torch.compile: running torch.compile with daemonic process causes an error
+    non_daemonic_procs = True
 
     @pytest.mark.parametrize("activation_checkpoints", [False, True])
     @pytest.mark.parametrize("use_compile", [False, True])
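
For readers outside the DeepSpeed tree, the pattern these patches implement can be illustrated with a small self-contained sketch. The ToyPipeline class, its layer sizes, and the tensor shapes below are hypothetical stand-ins, not DeepSpeed code; only torch.compile, nn.Module.compile (available in recent PyTorch releases), and the forward_funcs / micro_offset names are taken from the patch. The sketch compiles each entry of forward_funcs individually, so the micro_offset mutation stays in eager Python and never forces torch.compile to re-analyze the whole module.

# Self-contained sketch (not part of the patch); ToyPipeline and its layers are
# hypothetical stand-ins for DeepSpeed's PipelineModule.
import torch
import torch.nn as nn


class ToyPipeline(nn.Module):

    def __init__(self):
        super().__init__()
        # Stand-in for PipelineModule.forward_funcs: a plain list of callables,
        # some of which are nn.Modules and some plain functions.
        self.stage0 = nn.Linear(8, 8)
        self.stage1 = nn.Linear(8, 8)
        self.forward_funcs = [self.stage0, torch.relu, self.stage1]
        self.micro_offset = 0

    def forward(self, x):
        # Mutating an attribute here is what makes compiling the whole module
        # problematic: torch.compile restarts its analysis and treats the
        # module as dynamic.
        self.micro_offset += 1
        for func in self.forward_funcs:
            x = func(x)
        return x

    def compile(self, *args, **kwargs):
        # Same idea as the patch: compile each layer instead of the module,
        # so the micro_offset mutation stays outside the compiled graphs.
        for idx, layer in enumerate(self.forward_funcs):
            if isinstance(layer, nn.Module):
                layer.compile(*args, **kwargs)  # in-place, keeps the list entry valid
            else:
                self.forward_funcs[idx] = torch.compile(layer, *args, **kwargs)


if __name__ == "__main__":
    model = ToyPipeline()
    model.compile()  # per-layer compilation; forward() itself stays eager
    print(model(torch.randn(2, 8)).shape)

Keeping the attribute mutation outside the compiled graphs is the point of the change: the per-layer graphs still capture the bulk of the compute, while the stateful scheduling bookkeeping in forward() remains uncompiled.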