Skip to content

Commit

Permalink
Merge pull request #145 from roclark/roclark/dgxc-executor-torchrun
Browse files (browse the repository at this point in the history)
Add support for torchrun to DGXC Executor
  • Loading branch information
pablo-garay authored Feb 5, 2025
2 parents 61bb965 + b3a2f62 commit 570f577
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/nemo_run/core/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# DGXCloudExecutor is listed in __all__ below, so it must be bound in this
# module; otherwise `from nemo_run.core.execution import *` raises AttributeError.
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.skypilot import SkypilotExecutor
from nemo_run.core.execution.slurm import SlurmExecutor

# Public executor API of this package. Every name listed here has to be
# imported into this module's namespace.
__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"]

Check failure

Code scanning / CodeQL

Error: Explicit export is not defined

The name 'DGXCloudExecutor' is exported by __all__ but is not defined.
8 changes: 7 additions & 1 deletion src/nemo_run/core/execution/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def create_distributed_job(
return response

def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
name = name.replace("_", "-") # to meet K8s requirements
name = name.replace("_", "-").replace(".", "-") # to meet K8s requirements
token = self.get_auth_token()
if not token:
raise RuntimeError("Failed to get auth token")
Expand All @@ -156,6 +156,12 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
status = r_json["actualPhase"]
return job_id, status

def nnodes(self) -> int:
    """Return the node count stored on this executor (``self.nodes``)."""
    node_count = self.nodes
    return node_count

def nproc_per_node(self) -> int:
    """Return the per-node process count, taken from ``self.gpus_per_node``."""
    procs_per_node = self.gpus_per_node
    return procs_per_node

def status(self, job_id: str) -> Optional[DGXCloudState]:
url = f"{self.base_url}/workloads/distributed/{job_id}"
token = self.get_auth_token()
Expand Down
41 changes: 23 additions & 18 deletions src/nemo_run/run/torchx_backend/components/torchrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def torchrun(
rdzv_backend: str = "c10d",
mounts: Optional[list[str]] = None,
debug: bool = False,
dgxc: bool = False,
) -> specs.AppDef:
"""
Distributed data parallel style application (one role, multi-replica).
Expand Down Expand Up @@ -92,6 +93,7 @@ def torchrun(
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
See scheduler documentation for more info.
debug: whether to run with preset debug flags enabled
dgxc: whether to use a subset of settings for DGX Cloud
"""
if (script is None) == (m is None):
raise ValueError("exactly one of --script and -m must be specified")
Expand Down Expand Up @@ -130,24 +132,27 @@ def torchrun(
if debug:
env.update(_TORCH_DEBUG_FLAGS)

cmd = [
"--rdzv-backend",
rdzv_backend,
"--rdzv-endpoint",
rdzv_endpoint,
"--rdzv-id",
f"{random.randint(1, 10000)}",
"--nnodes",
num_nodes,
"--nproc-per-node",
nproc_per_node,
"--node-rank",
node_rank,
"--tee",
"3",
# "--role",
# "",
]
if dgxc:
cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
else:
cmd = [
"--rdzv-backend",
rdzv_backend,
"--rdzv-endpoint",
rdzv_endpoint,
"--rdzv-id",
f"{random.randint(1, 10000)}",
"--nnodes",
num_nodes,
"--nproc-per-node",
nproc_per_node,
"--node-rank",
node_rank,
"--tee",
"3",
# "--role",
# "",
]
if script is not None:
if no_python:
cmd += ["--no-python"]
Expand Down
2 changes: 2 additions & 0 deletions src/nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from nemo_run.config import SCRIPTS_DIR, Partial, Script
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
Expand Down Expand Up @@ -139,6 +140,7 @@ def package(
mounts=mounts,
debug=executor.packager.debug,
max_retries=executor.retries,
dgxc=isinstance(executor, DGXCloudExecutor),
)
elif launcher and isinstance(launcher, FaultTolerance):
app_def = ft_launcher.ft_launcher(
Expand Down

0 comments on commit 570f577

Please sign in to comment.