From 7a242ec7746630fa5c03943c1ae9c84a9a1e9f8b Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 30 Jan 2025 03:36:50 -0800 Subject: [PATCH] Support torchrun multi node on local executor Signed-off-by: Hemil Desai --- src/nemo_run/core/execution/local.py | 3 ++- .../torchx_backend/components/ft_launcher.py | 2 ++ .../run/torchx_backend/components/torchrun.py | 24 +++++++++++-------- src/nemo_run/run/torchx_backend/packaging.py | 4 ++++ test/run/torchx_backend/test_packaging.py | 2 +- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/nemo_run/core/execution/local.py b/src/nemo_run/core/execution/local.py index c68260d..eba3d49 100644 --- a/src/nemo_run/core/execution/local.py +++ b/src/nemo_run/core/execution/local.py @@ -37,6 +37,7 @@ class LocalExecutor(Executor): #: Used by components like torchrun to deduce the number of tasks to launch. ntasks_per_node: int = 1 + nodes: int = 1 def assign( self, @@ -51,7 +52,7 @@ def assign( os.makedirs(self.job_dir, exist_ok=True) def nnodes(self) -> int: - return 1 + return self.nodes def nproc_per_node(self) -> int: return self.ntasks_per_node diff --git a/src/nemo_run/run/torchx_backend/components/ft_launcher.py b/src/nemo_run/run/torchx_backend/components/ft_launcher.py index 8ebd7c0..2f880f8 100644 --- a/src/nemo_run/run/torchx_backend/components/ft_launcher.py +++ b/src/nemo_run/run/torchx_backend/components/ft_launcher.py @@ -48,6 +48,7 @@ def ft_launcher( rank_termination_signal: Optional[str] = None, log_level: Optional[str] = None, max_restarts: Optional[int] = None, + use_env: bool = False, ) -> specs.AppDef: torchrun_component = torchrun.torchrun( *script_args, @@ -67,6 +68,7 @@ def ft_launcher( mounts=mounts, debug=debug, max_retries=max_retries, + use_env=use_env, ) ft_args = [] diff --git a/src/nemo_run/run/torchx_backend/components/torchrun.py b/src/nemo_run/run/torchx_backend/components/torchrun.py index 0d0f179..e4673f7 100644 --- a/src/nemo_run/run/torchx_backend/components/torchrun.py +++ b/src/nemo_run/run/torchx_backend/components/torchrun.py @@ -59,6 +59,7 @@ def torchrun( rdzv_backend: str = "c10d", mounts: Optional[list[str]] = None, debug: bool = False, + use_env: bool = False, ) -> specs.AppDef: """ Distributed data parallel style application (one role, multi-replica). @@ -111,17 +112,20 @@ def torchrun( nproc_per_node = str(nproc_per_node) node_rank = "0" else: - # for multi-node, rely on the rank0_env environment variable set by - # the schedulers (see scheduler implementation for the actual env var this maps to) - # some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0 - # so default to "localhost" if the env var is not set or is empty - # rdzv_endpoint bash resolves to something to the effect of - # ${TORCHX_RANK0_HOST:=localhost}:29500 - # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument) - rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}") - num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}") + if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"): + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port) + else: + rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}") + + num_nodes = nnodes_rep nproc_per_node = str(nproc_per_node) - node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}") + + if use_env and os.getenv("NODE_RANK"): + node_rank = os.environ["NODE_RANK"] + else: + node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}") if env is None: env = {} diff --git a/src/nemo_run/run/torchx_backend/packaging.py b/src/nemo_run/run/torchx_backend/packaging.py index 3a6dd14..185d1ce 100644 --- a/src/nemo_run/run/torchx_backend/packaging.py +++ b/src/nemo_run/run/torchx_backend/packaging.py @@ -24,6 +24,7 @@ from nemo_run.config import SCRIPTS_DIR, Partial, Script from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun from nemo_run.core.execution.local import LocalExecutor +from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.core.serialization.yaml import YamlSerializer from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer from nemo_run.run.torchx_backend.components import ft_launcher, torchrun @@ -120,6 +121,7 @@ def package( entrypoint = fn_or_script.entrypoint launcher = executor.get_launcher() + use_env = not isinstance(executor, SlurmExecutor) if launcher and isinstance(launcher, Torchrun): app_def = torchrun.torchrun( *args, @@ -139,6 +141,7 @@ def package( mounts=mounts, debug=executor.packager.debug, max_retries=executor.retries, + use_env=use_env, ) elif launcher and isinstance(launcher, FaultTolerance): app_def = ft_launcher.ft_launcher( @@ -165,6 +168,7 @@ def package( log_level=launcher.log_level, max_retries=executor.retries, max_restarts=launcher.max_restarts, + use_env=use_env, ) else: app_def = specs.AppDef( diff --git a/test/run/torchx_backend/test_packaging.py b/test/run/torchx_backend/test_packaging.py index 2c61925..6f5b30e 100644 --- a/test/run/torchx_backend/test_packaging.py +++ b/test/run/torchx_backend/test_packaging.py @@ -206,7 +206,7 @@ def test_package_torchrun(mock_executor): "--rdzv-id", "1", "--nnodes", - "$$${num_nodes_var}", + "2", "--nproc-per-node", "1", "--node-rank",