Add support for torchrun to DGXC Executor
DGX Cloud uses the PyTorch Training Operator from KubeFlow under the
hood to launch jobs. The operator handles many of the variables needed
to run distributed PyTorch jobs with torchrun, so only a subset of
torchrun's settings is required to launch the job; the original default
settings would conflict with the setup that DGX Cloud auto-configures.

Signed-Off-By: Robert Clark <[email protected]>
roclark committed Feb 5, 2025
1 parent 61bb965 commit b3a2f62
Showing 4 changed files with 33 additions and 20 deletions.
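For orientation, here is a minimal sketch (not part of the diff) of the two launch styles this commit distinguishes. The endpoint, the NODE_RANK variable, and train.py are illustrative assumptions; the point is that the KubeFlow PyTorch Training Operator injects the coordination settings (e.g. MASTER_ADDR/MASTER_PORT) into each worker pod, so the DGX Cloud path only has to pass the node and per-node process counts.

# Illustrative only: contrasts the full rendezvous command line with the
# trimmed one used when the Training Operator pre-configures the workers.
import os

nodes, gpus_per_node = 2, 8

# Default path: torchrun is told the rendezvous settings explicitly.
default_args = [
    "torchrun",
    "--rdzv-backend", "c10d",
    "--rdzv-endpoint", "worker-0:29500",              # assumed endpoint
    "--nnodes", str(nodes),
    "--nproc-per-node", str(gpus_per_node),
    "--node-rank", os.environ.get("NODE_RANK", "0"),  # assumed env var
    "train.py",                                       # assumed user script
]

# DGX Cloud path: the operator already provides the coordination variables,
# so only the counts are passed.
dgxc_args = [
    "torchrun",
    "--nnodes", str(nodes),
    "--nproc-per-node", str(gpus_per_node),
    "train.py",
]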
2 changes: 1 addition & 1 deletion src/nemo_run/core/execution/__init__.py
@@ -17,4 +17,4 @@
 from nemo_run.core.execution.skypilot import SkypilotExecutor
 from nemo_run.core.execution.slurm import SlurmExecutor
 
-__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor"]
+__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"]
8 changes: 7 additions & 1 deletion src/nemo_run/core/execution/dgxcloud.py
@@ -138,7 +138,7 @@ def create_distributed_job(
         return response
 
     def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
-        name = name.replace("_", "-")  # to meet K8s requirements
+        name = name.replace("_", "-").replace(".", "-")  # to meet K8s requirements
         token = self.get_auth_token()
         if not token:
             raise RuntimeError("Failed to get auth token")
@@ -156,6 +156,12 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
         status = r_json["actualPhase"]
         return job_id, status
 
+    def nnodes(self) -> int:
+        return self.nodes
+
+    def nproc_per_node(self) -> int:
+        return self.gpus_per_node
+
     def status(self, job_id: str) -> Optional[DGXCloudState]:
         url = f"{self.base_url}/workloads/distributed/{job_id}"
         token = self.get_auth_token()
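The two helpers added above expose the executor's node and per-node GPU counts behind a common interface. Below is a small sketch of how a caller could turn them into torchrun flags; count_flags is a hypothetical helper, while nnodes(), nproc_per_node(), nodes, and gpus_per_node come from the diff.

# Sketch, assuming an executor that exposes nnodes() and nproc_per_node(),
# as DGXCloudExecutor now does.
def count_flags(executor) -> list[str]:
    return [
        "--nnodes", str(executor.nnodes()),
        "--nproc-per-node", str(executor.nproc_per_node()),
    ]

For a DGXCloudExecutor whose nodes is 4 and gpus_per_node is 8, this would yield ["--nnodes", "4", "--nproc-per-node", "8"].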
41 changes: 23 additions & 18 deletions src/nemo_run/run/torchx_backend/components/torchrun.py
@@ -59,6 +59,7 @@ def torchrun(
     rdzv_backend: str = "c10d",
     mounts: Optional[list[str]] = None,
     debug: bool = False,
+    dgxc: bool = False,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -92,6 +92,7 @@
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
         debug: whether to run with preset debug flags enabled
+        dgxc: whether to use a subset of settings for DGX Cloud
     """
     if (script is None) == (m is None):
         raise ValueError("exactly one of --script and -m must be specified")
@@ -130,24 +132,27 @@
     if debug:
         env.update(_TORCH_DEBUG_FLAGS)
 
-    cmd = [
-        "--rdzv-backend",
-        rdzv_backend,
-        "--rdzv-endpoint",
-        rdzv_endpoint,
-        "--rdzv-id",
-        f"{random.randint(1, 10000)}",
-        "--nnodes",
-        num_nodes,
-        "--nproc-per-node",
-        nproc_per_node,
-        "--node-rank",
-        node_rank,
-        "--tee",
-        "3",
-        # "--role",
-        # "",
-    ]
+    if dgxc:
+        cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
+    else:
+        cmd = [
+            "--rdzv-backend",
+            rdzv_backend,
+            "--rdzv-endpoint",
+            rdzv_endpoint,
+            "--rdzv-id",
+            f"{random.randint(1, 10000)}",
+            "--nnodes",
+            num_nodes,
+            "--nproc-per-node",
+            nproc_per_node,
+            "--node-rank",
+            node_rank,
+            "--tee",
+            "3",
+            # "--role",
+            # "",
+        ]
     if script is not None:
         if no_python:
             cmd += ["--no-python"]
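Restated outside the component for clarity, the new branching reduces to the following sketch. It is an illustration, not repository code; the component itself uses nnodes_rep in the DGX Cloud branch and num_nodes in the default branch, while a single count argument is used here for brevity.

import random

def torchrun_args(dgxc: bool, nnodes: str, nproc_per_node: int,
                  rdzv_backend: str, rdzv_endpoint: str, node_rank: int) -> list[str]:
    if dgxc:
        # DGX Cloud: the Training Operator supplies the remaining settings.
        return ["--nnodes", nnodes, "--nproc-per-node", str(nproc_per_node)]
    # Default: pass the full rendezvous configuration to torchrun.
    return [
        "--rdzv-backend", rdzv_backend,
        "--rdzv-endpoint", rdzv_endpoint,
        "--rdzv-id", f"{random.randint(1, 10000)}",
        "--nnodes", nnodes,
        "--nproc-per-node", str(nproc_per_node),
        "--node-rank", str(node_rank),
        "--tee", "3",
    ]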
2 changes: 2 additions & 0 deletions src/nemo_run/run/torchx_backend/packaging.py
@@ -23,6 +23,7 @@
 
 from nemo_run.config import SCRIPTS_DIR, Partial, Script
 from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.serialization.yaml import YamlSerializer
 from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
@@ -139,6 +140,7 @@ def package(
             mounts=mounts,
             debug=executor.packager.debug,
             max_retries=executor.retries,
+            dgxc=isinstance(executor, DGXCloudExecutor),
         )
     elif launcher and isinstance(launcher, FaultTolerance):
         app_def = ft_launcher.ft_launcher(
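The packaging change means no user-facing flag is needed: whether the trimmed argument set is used follows purely from the executor type. A minimal sketch of that dispatch; the predicate name is hypothetical, while the import and the isinstance check come from the diff.

from nemo_run.core.execution.dgxcloud import DGXCloudExecutor

def uses_dgxc_args(executor) -> bool:
    # torchrun components get dgxc=True only when packaged for DGX Cloud.
    return isinstance(executor, DGXCloudExecutor)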
