Skip to content

Commit

Permalink
Support torchrun multi node on local executor
Browse files Browse the repository at this point in the history
Signed-off-by: Hemil Desai <[email protected]>
  • Loading branch information
hemildesai committed Jan 30, 2025
1 parent 5ed6128 commit 7a242ec
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/nemo_run/core/execution/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class LocalExecutor(Executor):

#: Used by components like torchrun to deduce the number of tasks to launch.
ntasks_per_node: int = 1
nodes: int = 1

def assign(
self,
Expand All @@ -51,7 +52,7 @@ def assign(
os.makedirs(self.job_dir, exist_ok=True)

def nnodes(self) -> int:
return 1
return self.nodes

def nproc_per_node(self) -> int:
    # Number of processes launchers like torchrun should start on each node;
    # mirrors the `ntasks_per_node` field (defaults to 1 per the class body above).
    return self.ntasks_per_node
2 changes: 2 additions & 0 deletions src/nemo_run/run/torchx_backend/components/ft_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def ft_launcher(
rank_termination_signal: Optional[str] = None,
log_level: Optional[str] = None,
max_restarts: Optional[int] = None,
use_env: bool = False,
) -> specs.AppDef:
torchrun_component = torchrun.torchrun(
*script_args,
Expand All @@ -67,6 +68,7 @@ def ft_launcher(
mounts=mounts,
debug=debug,
max_retries=max_retries,
use_env=use_env,
)

ft_args = []
Expand Down
24 changes: 14 additions & 10 deletions src/nemo_run/run/torchx_backend/components/torchrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def torchrun(
rdzv_backend: str = "c10d",
mounts: Optional[list[str]] = None,
debug: bool = False,
use_env: bool = False,
) -> specs.AppDef:
"""
Distributed data parallel style application (one role, multi-replica).
Expand Down Expand Up @@ -111,17 +112,20 @@ def torchrun(
nproc_per_node = str(nproc_per_node)
node_rank = "0"
else:
# for multi-node, rely on the rank0_env environment variable set by
# the schedulers (see scheduler implementation for the actual env var this maps to)
# some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
# so default to "localhost" if the env var is not set or is empty
# rdzv_endpoint bash resolves to something to the effect of
# ${TORCHX_RANK0_HOST:=localhost}:29500
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}")
if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"):
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port)
else:
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")

num_nodes = nnodes_rep
nproc_per_node = str(nproc_per_node)
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")

if use_env and os.getenv("NODE_RANK"):
node_rank = os.environ["NODE_RANK"]
else:
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")

if env is None:
env = {}
Expand Down
4 changes: 4 additions & 0 deletions src/nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nemo_run.config import SCRIPTS_DIR, Partial, Script
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.run.torchx_backend.components import ft_launcher, torchrun
Expand Down Expand Up @@ -120,6 +121,7 @@ def package(
entrypoint = fn_or_script.entrypoint

launcher = executor.get_launcher()
use_env = not isinstance(executor, SlurmExecutor)
if launcher and isinstance(launcher, Torchrun):
app_def = torchrun.torchrun(
*args,
Expand All @@ -139,6 +141,7 @@ def package(
mounts=mounts,
debug=executor.packager.debug,
max_retries=executor.retries,
use_env=use_env,
)
elif launcher and isinstance(launcher, FaultTolerance):
app_def = ft_launcher.ft_launcher(
Expand All @@ -165,6 +168,7 @@ def package(
log_level=launcher.log_level,
max_retries=executor.retries,
max_restarts=launcher.max_restarts,
use_env=use_env,
)
else:
app_def = specs.AppDef(
Expand Down
2 changes: 1 addition & 1 deletion test/run/torchx_backend/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_package_torchrun(mock_executor):
"--rdzv-id",
"1",
"--nnodes",
"$$${num_nodes_var}",
"2",
"--nproc-per-node",
"1",
"--node-rank",
Expand Down

0 comments on commit 7a242ec

Please sign in to comment.