From b3a2f620779110bd35fde43948a2af7edcbe9a61 Mon Sep 17 00:00:00 2001 From: Robert Clark Date: Mon, 3 Feb 2025 20:02:19 +0000 Subject: [PATCH] Add support for torchrun to DGXC Executor DGX Cloud uses the PyTorch Training Operator from KubeFlow under the hood to launch jobs. This handles many of the variables necessary for running distributed PyTorch jobs with torchrun, and only a subset of settings is required to launch the job, as the original default settings will conflict with the auto-configured setup from DGX Cloud. Signed-off-by: Robert Clark --- src/nemo_run/core/execution/__init__.py | 2 +- src/nemo_run/core/execution/dgxcloud.py | 8 +++- .../run/torchx_backend/components/torchrun.py | 41 +++++++++++-------- src/nemo_run/run/torchx_backend/packaging.py | 2 + 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/nemo_run/core/execution/__init__.py b/src/nemo_run/core/execution/__init__.py index fb28a46..7f2584b 100644 --- a/src/nemo_run/core/execution/__init__.py +++ b/src/nemo_run/core/execution/__init__.py @@ -17,4 +17,4 @@ from nemo_run.core.execution.skypilot import SkypilotExecutor from nemo_run.core.execution.slurm import SlurmExecutor -__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor"] +__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"] diff --git a/src/nemo_run/core/execution/dgxcloud.py b/src/nemo_run/core/execution/dgxcloud.py index 5cba46c..3f8e73f 100644 --- a/src/nemo_run/core/execution/dgxcloud.py +++ b/src/nemo_run/core/execution/dgxcloud.py @@ -138,7 +138,7 @@ def create_distributed_job( return response def launch(self, name: str, cmd: list[str]) -> tuple[str, str]: - name = name.replace("_", "-") # to meet K8s requirements + name = name.replace("_", "-").replace(".", "-") # to meet K8s requirements token = self.get_auth_token() if not token: raise RuntimeError("Failed to get auth token") @@ -156,6 +156,12 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]: status = 
r_json["actualPhase"] return job_id, status + def nnodes(self) -> int: + return self.nodes + + def nproc_per_node(self) -> int: + return self.gpus_per_node + def status(self, job_id: str) -> Optional[DGXCloudState]: url = f"{self.base_url}/workloads/distributed/{job_id}" token = self.get_auth_token() diff --git a/src/nemo_run/run/torchx_backend/components/torchrun.py b/src/nemo_run/run/torchx_backend/components/torchrun.py index 0d0f179..6a5c4d3 100644 --- a/src/nemo_run/run/torchx_backend/components/torchrun.py +++ b/src/nemo_run/run/torchx_backend/components/torchrun.py @@ -59,6 +59,7 @@ def torchrun( rdzv_backend: str = "c10d", mounts: Optional[list[str]] = None, debug: bool = False, + dgxc: bool = False, ) -> specs.AppDef: """ Distributed data parallel style application (one role, multi-replica). @@ -92,6 +93,7 @@ def torchrun( mounts: mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). See scheduler documentation for more info. debug: whether to run with preset debug flags enabled + dgxc: whether to use a subset of settings for DGX Cloud """ if (script is None) == (m is None): raise ValueError("exactly one of --script and -m must be specified") @@ -130,24 +132,27 @@ def torchrun( if debug: env.update(_TORCH_DEBUG_FLAGS) - cmd = [ - "--rdzv-backend", - rdzv_backend, - "--rdzv-endpoint", - rdzv_endpoint, - "--rdzv-id", - f"{random.randint(1, 10000)}", - "--nnodes", - num_nodes, - "--nproc-per-node", - nproc_per_node, - "--node-rank", - node_rank, - "--tee", - "3", - # "--role", - # "", - ] + if dgxc: + cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node] + else: + cmd = [ + "--rdzv-backend", + rdzv_backend, + "--rdzv-endpoint", + rdzv_endpoint, + "--rdzv-id", + f"{random.randint(1, 10000)}", + "--nnodes", + num_nodes, + "--nproc-per-node", + nproc_per_node, + "--node-rank", + node_rank, + "--tee", + "3", + # "--role", + # "", + ] if script is not None: if no_python: cmd += ["--no-python"] diff --git 
a/src/nemo_run/run/torchx_backend/packaging.py b/src/nemo_run/run/torchx_backend/packaging.py index 3a6dd14..0095ac3 100644 --- a/src/nemo_run/run/torchx_backend/packaging.py +++ b/src/nemo_run/run/torchx_backend/packaging.py @@ -23,6 +23,7 @@ from nemo_run.config import SCRIPTS_DIR, Partial, Script from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun +from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.local import LocalExecutor from nemo_run.core.serialization.yaml import YamlSerializer from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer @@ -139,6 +140,7 @@ def package( mounts=mounts, debug=executor.packager.debug, max_retries=executor.retries, + dgxc=isinstance(executor, DGXCloudExecutor), ) elif launcher and isinstance(launcher, FaultTolerance): app_def = ft_launcher.ft_launcher(