From ecff84a9456f151f04f61fb1100e4f592a3c0d28 Mon Sep 17 00:00:00 2001
From: Akhilesh Halageri
Date: Fri, 24 Nov 2023 21:21:40 +0000
Subject: [PATCH 1/5] feat(train): create single entrypoint

---
 .../resource_allocation/k8s/pod.py      |   2 +-
 zetta_utils/training/lightning/train.py | 194 +++++++++++-------
 2 files changed, 118 insertions(+), 78 deletions(-)

diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py
index 3080a78e6..f75d548da 100644
--- a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py
+++ b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py
@@ -19,7 +19,7 @@ def get_pod_spec(
     image: str,
     command: List[str],
     command_args: List[str],
-    resources: Dict[str, int | float | str],
+    resources: Optional[Dict[str, int | float | str]] = None,
     dns_policy: Optional[str] = "Default",
     envs: Optional[List[k8s_client.V1EnvVar]] = None,
     env_secret_mapping: Optional[Dict[str, str]] = None,
diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py
index cb6133b23..292e0e073 100644
--- a/zetta_utils/training/lightning/train.py
+++ b/zetta_utils/training/lightning/train.py
@@ -12,8 +12,7 @@
 from torch.distributed.launcher import api as torch_launcher_api
 from kubernetes import client as k8s_client  # type: ignore
 
-from zetta_utils import builder, load_all_modules, log, mazepa, parsing
-from zetta_utils.builder.build import BuilderPartial
+from zetta_utils import builder, load_all_modules, log, mazepa
 from zetta_utils.cloud_management import execution_tracker, resource_allocation
 from zetta_utils.parsing import json
 
@@ -40,9 +39,21 @@ def lightning_train(
     regime: pl.LightningModule,
     trainer: pl.Trainer,
     train_dataloader: torch.utils.data.DataLoader,
-    val_dataloader: torch.utils.data.DataLoader | None = None,
+    val_dataloader: Optional[torch.utils.data.DataLoader] = None,
     full_state_ckpt_path: str = "last",
-):
+    num_nodes: int = 1,
+    nproc_per_node: int = 1,
+    retry_count: int = 3,
+    local_run: bool = False,
+    follow_logs: bool = False,
+    image: Optional[str] = None,
+    cluster_name: Optional[str] = None,
+    cluster_region: Optional[str] = None,
+    cluster_project: Optional[str] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    resource_limits: Optional[dict[str, int | float | str]] = None,
+    resource_requests: Optional[dict[str, int | float | str]] = None,
+) -> None:
     """
     Perform neural net training with Zetta's PytorchLightning integration.
 
@@ -58,7 +69,91 @@
         Must be a full training state checkpoint created by PytorchLightning rather
         than a model checkpoint. If ``full_state_ckpt_path=="last"``, the latest
         checkpoint for the given experiment will be identified and loaded.
+    :param num_nodes: Number of GPU nodes for distributed training.
+    :param nproc_per_node: Number of GPU workers per node.
+    :param retry_count: Max retry count for the master train job;
+        excludes failures due to pod disruptions.
+    :param local_run: If True, run the training locally.
+    :param follow_logs: If True, eagerly print logs from the pod.
+        If False, wait until the job completes successfully.
+    :param image: Container image to use.
+    :param cluster_name: Name of the cluster to run the train job in.
+    :param cluster_region: Region of the cluster.
+    :param cluster_project: Project the cluster belongs to.
+    :param env_vars: Custom env variables to be set on pods.
+    :param resource_limits: K8s resource limits per pod.
+    :param resource_requests: K8s resource requests per pod.
""" + + if local_run: + _lightning_train(regime, trainer, train_dataloader, val_dataloader=val_dataloader) + return + + assert image is not None, "Must provide a container image for remote training." + execution_id = mazepa.id_generation.get_unique_id( + prefix="exec", slug_len=4, add_uuid=False, max_len=50 + ) + + cluster_info = resource_allocation.k8s.parse_cluster_info( + cluster_name=cluster_name, + cluster_region=cluster_region, + cluster_project=cluster_project, + ) + + train_spec = { + "@type": "lightning_train", + "regime": regime, + "trainer": trainer, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + "full_state_ckpt_path": full_state_ckpt_path, + } + + _create_ddp_master_job( + execution_id, + cluster_info=cluster_info, + image=image, + num_nodes=num_nodes, + nproc_per_node=nproc_per_node, + retry_count=retry_count, + train_spec=train_spec, + env_vars=env_vars, + follow_logs=follow_logs, + host_network=num_nodes > 1, + resource_limits=resource_limits, + resource_requests=resource_requests, + ) + + +@builder.register("multinode_train_launch") +@typeguard.typechecked +def multinode_train_launch( + execution_id: str, + num_nodes: int, + nproc_per_node: int, + rdzv_backend: str = "c10d", + **kwargs, # pylint: disable=unused-argument +): + # worker pods have MY_ROLE env set to `worker` + is_worker = os.environ.get("MY_ROLE") == "worker" + config = torch_launcher_api.LaunchConfig( + run_id=execution_id, + min_nodes=num_nodes, + max_nodes=num_nodes, + nproc_per_node=nproc_per_node, + rdzv_backend=rdzv_backend, + rdzv_endpoint="master:29400" if is_worker else "localhost:29400", + ) + torch_launcher_api.elastic_launch(config, _parse_spec_and_train)() + + +def _lightning_train( + regime: pl.LightningModule, + trainer: pl.Trainer, + train_dataloader: torch.utils.data.DataLoader, + val_dataloader: torch.utils.data.DataLoader | None = None, + full_state_ckpt_path: str = "last", +): logger.info("Starting training...") if "CURRENT_BUILD_SPEC" in os.environ: if hasattr(trainer, "log_config"): @@ -83,7 +178,11 @@ def lightning_train( def _parse_spec_and_train(): load_all_modules() - train_spec = json.loads(os.environ["ZETTA_RUN_SPEC"]) + + train_spec = None + with open(os.environ["ZETTA_RUN_SPEC_PATH"], "r", encoding="utf-8") as f: + train_spec = json.load(f) + regime = builder.build(spec=train_spec["regime"]) trainer = builder.build(spec=train_spec["trainer"]) train_dataloader = builder.build(spec=train_spec["train_dataloader"]) @@ -95,27 +194,7 @@ def _parse_spec_and_train(): full_state_ckpt_path = builder.build(spec=train_spec["full_state_ckpt_path"]) except KeyError: full_state_ckpt_path = "last" - lightning_train(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path) - - -@builder.register("multinode_train_launch") -@typeguard.typechecked -def multinode_train_launch( - execution_id: str, - num_nodes: int, - nproc_per_node: int = 1, - rdzv_backend: str = "c10d", - **kwargs, # pylint: disable=unused-argument -): - config = torch_launcher_api.LaunchConfig( - run_id=execution_id, - min_nodes=num_nodes, - max_nodes=num_nodes, - nproc_per_node=nproc_per_node, - rdzv_backend=rdzv_backend, - rdzv_endpoint="master:29400" if os.environ.get("MY_ROLE") else "localhost:29400", - ) - torch_launcher_api.elastic_launch(config, _parse_spec_and_train)() + _lightning_train(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path) def _get_tolerations(role: str) -> List[k8s_client.V1Toleration]: @@ -166,22 +245,28 @@ def _create_ddp_master_job( *, 
     cluster_info: resource_allocation.k8s.ClusterInfo,
     image: str,
-    resource_limits: dict[str, int | float | str],
-    train_spec: dict,
     num_nodes: int,
+    nproc_per_node: int,
     retry_count: int,
+    train_spec: dict,
     env_vars: Optional[Dict[str, str]] = None,
     follow_logs: Optional[bool] = False,
     host_network: Optional[bool] = False,
+    resource_limits: Optional[dict[str, int | float | str]] = None,
     resource_requests: Optional[dict[str, int | float | str]] = None,
 ):  # pylint: disable=too-many-locals
-    zetta_cmd = "zetta run specs/train.cue"
-    env_vars = env_vars or {}
+    """
+    Parse spec and launch single/multinode training accordingly.
+    Creates a volume mount for `train.cue` in `/opt/zetta_utils/specs`.
+    Runs the command `zetta run specs/train.cue` on one or more worker pods.
+    """
+
+    train_spec["local_run"] = True  # run locally on the pod
     if num_nodes > 1:
         train_spec["@type"] = "multinode_train_launch"
         train_spec["execution_id"] = execution_id
         train_spec["num_nodes"] = num_nodes
+        train_spec["nproc_per_node"] = nproc_per_node
         train_spec["trainer"]["num_nodes"] = num_nodes
     specs = {"train": train_spec}
     vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(execution_id, cluster_info, specs)
@@ -203,6 +288,7 @@
     ]
 
     envs = []
+    env_vars = env_vars or {}
     for key, val in env_vars.items():
         envs.append(k8s_client.V1EnvVar(name=key, value=val))
 
@@ -213,6 +299,7 @@
         ),
     )
 
+    zetta_cmd = "zetta run specs/train.cue"
     train_pod_spec = resource_allocation.k8s.get_pod_spec(
         name=execution_id,
         image=image,
@@ -306,50 +393,3 @@
         resource_allocation.k8s.follow_job_logs(train_job, cluster_info)
     else:
         resource_allocation.k8s.wait_for_job_completion(train_job, cluster_info)
-
-
-@builder.register("lightning_train_remote")
-@typeguard.typechecked
-def lightning_train_remote(
-    worker_image: str,
-    worker_resources: dict[str, int | float | str],
-    spec_path: str | dict | BuilderPartial,
-    num_nodes: int = 1,
-    retry_count: int = 3,
-    env_vars: Optional[Dict[str, str]] = None,
-    worker_cluster_name: Optional[str] = None,
-    worker_cluster_region: Optional[str] = None,
-    worker_cluster_project: Optional[str] = None,
-    follow_logs: Optional[bool] = False,
-    worker_resource_requests: Optional[dict[str, int | float | str]] = None,
-) -> None:
-    assert num_nodes > 0
-    cluster_info = resource_allocation.k8s.parse_cluster_info(
-        cluster_name=worker_cluster_name,
-        cluster_region=worker_cluster_region,
-        cluster_project=worker_cluster_project,
-    )
-
-    execution_id = mazepa.id_generation.get_unique_id(
-        prefix="exec", slug_len=4, add_uuid=False, max_len=50
-    )
-    if isinstance(spec_path, str):
-        spec = parsing.cue.load(spec_path)
-    elif isinstance(spec_path, dict):
-        spec = spec_path
-    elif isinstance(spec_path, BuilderPartial):
-        spec = spec_path.spec
-
-    _create_ddp_master_job(
-        execution_id,
-        cluster_info=cluster_info,
-        env_vars=env_vars,
-        follow_logs=follow_logs,
-        image=worker_image,
-        resource_limits=worker_resources,
-        train_spec=spec,
-        num_nodes=num_nodes,
-        host_network=num_nodes > 1,
-        retry_count=retry_count,
-        resource_requests=worker_resource_requests,
-    )
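The patch above collapses local and remote training into the single `lightning_train` entrypoint. A minimal sketch of the calling convention it introduces (later patches in this series make the remote spec JSON-serializable and flip the `local_run` default); the `regime`/`trainer`/`train_dataloader` objects and all image/cluster values below are placeholders, not values from this series:

```python
from zetta_utils.training.lightning.train import lightning_train

# regime, trainer, and train_dataloader are assumed to be builder-constructed
# objects (e.g. regime = builder.build(spec=...)); they are placeholders here.

# Local run: trains in-process on the current machine.
lightning_train(regime, trainer, train_dataloader, local_run=True)

# Remote run: writes the train spec to a ConfigMap and launches a k8s train
# job, here with 2 nodes x 4 GPU workers. Image/cluster values are hypothetical.
lightning_train(
    regime,
    trainer,
    train_dataloader,
    local_run=False,
    num_nodes=2,
    nproc_per_node=4,
    image="us.gcr.io/example-project/zetta_utils:example-tag",
    cluster_name="example-cluster",
    cluster_region="us-east1",
    cluster_project="example-project",
    resource_limits={"nvidia.com/gpu": 4},
)
```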
From f960254011813a90fdd1f388bd021dfe22f56fc0 Mon Sep 17 00:00:00 2001
From: Akhilesh Halageri
Date: Fri, 24 Nov 2023 21:34:07 +0000
Subject: [PATCH 2/5] fix(train): remove deprecated import

---
 zetta_utils/api/v0.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py
index 88a96511c..ece3516a9 100644
--- a/zetta_utils/api/v0.py
+++ b/zetta_utils/api/v0.py
@@ -439,7 +439,7 @@
     NaiveSupervisedRegime,
 )
 from zetta_utils.training.lightning.regimes.noop import NoOpRegime
-from zetta_utils.training.lightning.train import lightning_train, lightning_train_remote
+from zetta_utils.training.lightning.train import lightning_train
 from zetta_utils.training.lightning.trainers.default import (
     ConfigureTraceCallback,
     ZettaDefaultTrainer,

From 95198c18fe0aeadf588947d5f3f94add13c653fd Mon Sep 17 00:00:00 2001
From: Akhilesh Halageri
Date: Sat, 25 Nov 2023 16:35:16 +0000
Subject: [PATCH 3/5] fix(train): better function naming, default values, json
 serializable spec

---
 zetta_utils/training/lightning/train.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py
index 292e0e073..da8f76b48 100644
--- a/zetta_utils/training/lightning/train.py
+++ b/zetta_utils/training/lightning/train.py
@@ -44,8 +44,8 @@ def lightning_train(
     num_nodes: int = 1,
     nproc_per_node: int = 1,
     retry_count: int = 3,
-    local_run: bool = False,
-    follow_logs: bool = False,
+    local_run: bool = True,
+    follow_logs: bool = True,
     image: Optional[str] = None,
     cluster_name: Optional[str] = None,
     cluster_region: Optional[str] = None,
@@ -86,7 +86,7 @@
     """
 
     if local_run:
-        _lightning_train(regime, trainer, train_dataloader, val_dataloader=val_dataloader)
+        _lightning_train_local(regime, trainer, train_dataloader, val_dataloader=val_dataloader)
         return
 
     assert image is not None, "Must provide a container image for remote training."
@@ -102,14 +102,14 @@
 
     train_spec = {
         "@type": "lightning_train",
-        "regime": regime,
-        "trainer": trainer,
-        "train_dataloader": train_dataloader,
-        "val_dataloader": val_dataloader,
+        "regime": builder.get_initial_builder_spec(regime),
+        "trainer": builder.get_initial_builder_spec(trainer),
+        "train_dataloader": builder.get_initial_builder_spec(train_dataloader),
+        "val_dataloader": builder.get_initial_builder_spec(val_dataloader),
         "full_state_ckpt_path": full_state_ckpt_path,
     }
 
-    _create_ddp_master_job(
+    _lightning_train_remote(
         execution_id,
         cluster_info=cluster_info,
         image=image,
@@ -147,7 +147,7 @@
     torch_launcher_api.elastic_launch(config, _parse_spec_and_train)()
 
 
-def _lightning_train(
+def _lightning_train_local(
     regime: pl.LightningModule,
     trainer: pl.Trainer,
     train_dataloader: torch.utils.data.DataLoader,
@@ -194,7 +194,7 @@
         full_state_ckpt_path = builder.build(spec=train_spec["full_state_ckpt_path"])
     except KeyError:
         full_state_ckpt_path = "last"
-    _lightning_train(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path)
+    _lightning_train_local(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path)
 
 
 def _get_tolerations(role: str) -> List[k8s_client.V1Toleration]:
@@ -240,7 +240,7 @@
     return (specs_vol, specs_mount, ctx)
 
 
-def _create_ddp_master_job(
+def _lightning_train_remote(
     execution_id: str,
     *,
     cluster_info: resource_allocation.k8s.ClusterInfo,
@@ -261,7 +261,6 @@
     Runs the command `zetta run specs/train.cue` on one or more worker pods.
     """
 
-    train_spec["local_run"] = True  # run locally on the pod
     if num_nodes > 1:
         train_spec["@type"] = "multinode_train_launch"
         train_spec["execution_id"] = execution_id
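The core change in PATCH 3/5 is that `train_spec` now carries builder specs rather than live objects, since the spec must survive a JSON round trip: it is serialized into the ConfigMap and rebuilt on the pod by `_parse_spec_and_train` via `builder.build`. A hedged sketch of the distinction, with a hypothetical registered type name:

```python
import json

# A builder spec is plain JSON-serializable data describing how to construct
# an object; "ExampleRegime" is a hypothetical registered @type, not a name
# from this series.
regime_spec = {"@type": "ExampleRegime", "lr": 1e-4}
print(json.dumps(regime_spec))  # fine: only dicts, strings, and numbers

# A live object, e.g. a pl.Trainer instance, is not JSON-serializable:
# json.dumps(trainer)  # TypeError: Object of type Trainer is not JSON serializable
# Hence builder.get_initial_builder_spec(obj), which (presumably) returns the
# spec an object was originally built from, so the pod can rebuild it.
```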
""" - train_spec["local_run"] = True # run locally on the pod if num_nodes > 1: train_spec["@type"] = "multinode_train_launch" train_spec["execution_id"] = execution_id From d627afe1539ebfc141e5d8777c32e33c45ad4274 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Sun, 26 Nov 2023 18:49:17 +0000 Subject: [PATCH 4/5] fix(train): assert builder spec object --- zetta_utils/training/lightning/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index da8f76b48..1c32af0d7 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -109,6 +109,9 @@ def lightning_train( "full_state_ckpt_path": full_state_ckpt_path, } + for _key in ["regime", "trainer", "train_dataloader"]: + assert train_spec[_key] is not None, f"{_key} requires builder compatible spec." + _lightning_train_remote( execution_id, cluster_info=cluster_info, From 3f2396669e137f020abe9ef26b6653495e6ac2b5 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Sun, 26 Nov 2023 18:56:13 +0000 Subject: [PATCH 5/5] fix(train): assert resource requests or limits --- zetta_utils/training/lightning/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index 1c32af0d7..8ce873e53 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -90,6 +90,9 @@ def lightning_train( return assert image is not None, "Must provide a container image for remote training." + assert ( + resource_requests or resource_limits + ), "Must provide at least one of resource requests or limits for remote training." execution_id = mazepa.id_generation.get_unique_id( prefix="exec", slug_len=4, add_uuid=False, max_len=50 )