Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(train): create single entrypoint #553

Merged
merged 5 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zetta_utils/api/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@
NaiveSupervisedRegime,
)
from zetta_utils.training.lightning.regimes.noop import NoOpRegime
from zetta_utils.training.lightning.train import lightning_train, lightning_train_remote
from zetta_utils.training.lightning.train import lightning_train
from zetta_utils.training.lightning.trainers.default import (
ConfigureTraceCallback,
ZettaDefaultTrainer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_pod_spec(
image: str,
command: List[str],
command_args: List[str],
resources: Dict[str, int | float | str],
resources: Optional[Dict[str, int | float | str]] = None,
dns_policy: Optional[str] = "Default",
envs: Optional[List[k8s_client.V1EnvVar]] = None,
env_secret_mapping: Optional[Dict[str, str]] = None,
Expand Down
201 changes: 123 additions & 78 deletions zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from torch.distributed.launcher import api as torch_launcher_api

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, load_all_modules, log, mazepa, parsing
from zetta_utils.builder.build import BuilderPartial
from zetta_utils import builder, load_all_modules, log, mazepa
from zetta_utils.cloud_management import execution_tracker, resource_allocation
from zetta_utils.parsing import json

Expand All @@ -40,9 +39,21 @@ def lightning_train(
regime: pl.LightningModule,
trainer: pl.Trainer,
train_dataloader: torch.utils.data.DataLoader,
val_dataloader: torch.utils.data.DataLoader | None = None,
val_dataloader: Optional[torch.utils.data.DataLoader] = None,
full_state_ckpt_path: str = "last",
):
num_nodes: int = 1,
nproc_per_node: int = 1,
retry_count: int = 3,
local_run: bool = True,
follow_logs: bool = True,
image: Optional[str] = None,
cluster_name: Optional[str] = None,
cluster_region: Optional[str] = None,
cluster_project: Optional[str] = None,
env_vars: Optional[Dict[str, str]] = None,
resource_limits: Optional[dict[str, int | float | str]] = None,
resource_requests: Optional[dict[str, int | float | str]] = None,
) -> None:
"""
Perform neural net trainig with Zetta's PytorchLightning integration.

Expand All @@ -58,7 +69,97 @@ def lightning_train(
Must be a full training state checkpoint created by PytorchLightning rather
than a model checkpoint. If ``full_state_ckpt_path=="last"``, the latest
checkpoint for the given experiment will be identified and loaded.
:param num_nodes: Number of GPU nodes for distributed training.
:param nproc_per_node: Number of GPU workers per node.
:param retry_count: Max retry count for the master train job;
excludes failures due to pod distruptions.
:param local_run: If True run the training locally.
:param follow_logs: If True, eagerly print logs from the pod.
If False, will wait until job completes successfully.
:param image: Container image to use.
:param cluster_name: Cluster configuration.
:param cluster_region: Cluster configuration.
:param cluster_project: Cluster configuration.
:param env_vars: Custom env variables to be set on pods.
:param resource_limits: K8s reource limits per pod.
:param resource_requests: K8s resource requests per pod.
"""

if local_run:
_lightning_train_local(regime, trainer, train_dataloader, val_dataloader=val_dataloader)
return

assert image is not None, "Must provide a container image for remote training."
assert (
resource_requests or resource_limits
), "Must provide at least one of resource requests or limits for remote training."
execution_id = mazepa.id_generation.get_unique_id(
prefix="exec", slug_len=4, add_uuid=False, max_len=50
)

cluster_info = resource_allocation.k8s.parse_cluster_info(
cluster_name=cluster_name,
cluster_region=cluster_region,
cluster_project=cluster_project,
)

train_spec = {
"@type": "lightning_train",
"regime": builder.get_initial_builder_spec(regime),
"trainer": builder.get_initial_builder_spec(trainer),
"train_dataloader": builder.get_initial_builder_spec(train_dataloader),
"val_dataloader": builder.get_initial_builder_spec(val_dataloader),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we gotta check if any of these are None here

"full_state_ckpt_path": full_state_ckpt_path,
}

for _key in ["regime", "trainer", "train_dataloader"]:
assert train_spec[_key] is not None, f"{_key} requires builder compatible spec."

_lightning_train_remote(
execution_id,
cluster_info=cluster_info,
image=image,
num_nodes=num_nodes,
nproc_per_node=nproc_per_node,
retry_count=retry_count,
train_spec=train_spec,
env_vars=env_vars,
follow_logs=follow_logs,
host_network=num_nodes > 1,
resource_limits=resource_limits,
resource_requests=resource_requests,
)


@builder.register("multinode_train_launch")
@typeguard.typechecked
def multinode_train_launch(
execution_id: str,
num_nodes: int,
nproc_per_node: int,
rdzv_backend: str = "c10d",
**kwargs, # pylint: disable=unused-argument
):
# worker pods have MY_ROLE env set to `worker`
is_worker = os.environ.get("MY_ROLE") == "worker"
config = torch_launcher_api.LaunchConfig(
run_id=execution_id,
min_nodes=num_nodes,
max_nodes=num_nodes,
nproc_per_node=nproc_per_node,
rdzv_backend=rdzv_backend,
rdzv_endpoint="master:29400" if is_worker else "localhost:29400",
)
torch_launcher_api.elastic_launch(config, _parse_spec_and_train)()


def _lightning_train_local(
regime: pl.LightningModule,
trainer: pl.Trainer,
train_dataloader: torch.utils.data.DataLoader,
val_dataloader: torch.utils.data.DataLoader | None = None,
full_state_ckpt_path: str = "last",
):
logger.info("Starting training...")
if "CURRENT_BUILD_SPEC" in os.environ:
if hasattr(trainer, "log_config"):
Expand All @@ -83,7 +184,11 @@ def lightning_train(

def _parse_spec_and_train():
load_all_modules()
train_spec = json.loads(os.environ["ZETTA_RUN_SPEC"])

train_spec = None
with open(os.environ["ZETTA_RUN_SPEC_PATH"], "r", encoding="utf-8") as f:
train_spec = json.load(f)

regime = builder.build(spec=train_spec["regime"])
trainer = builder.build(spec=train_spec["trainer"])
train_dataloader = builder.build(spec=train_spec["train_dataloader"])
Expand All @@ -95,27 +200,7 @@ def _parse_spec_and_train():
full_state_ckpt_path = builder.build(spec=train_spec["full_state_ckpt_path"])
except KeyError:
full_state_ckpt_path = "last"
lightning_train(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path)


@builder.register("multinode_train_launch")
@typeguard.typechecked
def multinode_train_launch(
execution_id: str,
num_nodes: int,
nproc_per_node: int = 1,
rdzv_backend: str = "c10d",
**kwargs, # pylint: disable=unused-argument
):
config = torch_launcher_api.LaunchConfig(
run_id=execution_id,
min_nodes=num_nodes,
max_nodes=num_nodes,
nproc_per_node=nproc_per_node,
rdzv_backend=rdzv_backend,
rdzv_endpoint="master:29400" if os.environ.get("MY_ROLE") else "localhost:29400",
)
torch_launcher_api.elastic_launch(config, _parse_spec_and_train)()
_lightning_train_local(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path)


def _get_tolerations(role: str) -> List[k8s_client.V1Toleration]:
Expand Down Expand Up @@ -161,27 +246,32 @@ def _spec_configmap_vol_and_ctx(
return (specs_vol, specs_mount, ctx)


def _create_ddp_master_job(
def _lightning_train_remote(
execution_id: str,
*,
cluster_info: resource_allocation.k8s.ClusterInfo,
image: str,
resource_limits: dict[str, int | float | str],
train_spec: dict,
num_nodes: int,
nproc_per_node: int,
retry_count: int,
train_spec: dict,
env_vars: Optional[Dict[str, str]] = None,
follow_logs: Optional[bool] = False,
host_network: Optional[bool] = False,
resource_limits: Optional[dict[str, int | float | str]] = None,
resource_requests: Optional[dict[str, int | float | str]] = None,
): # pylint: disable=too-many-locals
zetta_cmd = "zetta run specs/train.cue"
env_vars = env_vars or {}
"""
Parse spec and launch single/multinode training accordingly.
Creates a volume mount for `train.cue` in `/opt/zetta_utils/specs`.
Runs the command `zetta run specs/train.cue` on one or more worker pods.
"""

if num_nodes > 1:
train_spec["@type"] = "multinode_train_launch"
train_spec["execution_id"] = execution_id
train_spec["num_nodes"] = num_nodes
train_spec["nproc_per_node"] = nproc_per_node
train_spec["trainer"]["num_nodes"] = num_nodes
specs = {"train": train_spec}
vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(execution_id, cluster_info, specs)
Expand All @@ -203,6 +293,7 @@ def _create_ddp_master_job(
]

envs = []
env_vars = env_vars or {}
for key, val in env_vars.items():
envs.append(k8s_client.V1EnvVar(name=key, value=val))

Expand All @@ -213,6 +304,7 @@ def _create_ddp_master_job(
),
)

zetta_cmd = "zetta run specs/train.cue"
train_pod_spec = resource_allocation.k8s.get_pod_spec(
name=execution_id,
image=image,
Expand Down Expand Up @@ -306,50 +398,3 @@ def _create_ddp_master_job(
resource_allocation.k8s.follow_job_logs(train_job, cluster_info)
else:
resource_allocation.k8s.wait_for_job_completion(train_job, cluster_info)


@builder.register("lightning_train_remote")
@typeguard.typechecked
def lightning_train_remote(
worker_image: str,
worker_resources: dict[str, int | float | str],
spec_path: str | dict | BuilderPartial,
num_nodes: int = 1,
retry_count: int = 3,
env_vars: Optional[Dict[str, str]] = None,
worker_cluster_name: Optional[str] = None,
worker_cluster_region: Optional[str] = None,
worker_cluster_project: Optional[str] = None,
follow_logs: Optional[bool] = False,
worker_resource_requests: Optional[dict[str, int | float | str]] = None,
) -> None:
assert num_nodes > 0
cluster_info = resource_allocation.k8s.parse_cluster_info(
cluster_name=worker_cluster_name,
cluster_region=worker_cluster_region,
cluster_project=worker_cluster_project,
)

execution_id = mazepa.id_generation.get_unique_id(
prefix="exec", slug_len=4, add_uuid=False, max_len=50
)
if isinstance(spec_path, str):
spec = parsing.cue.load(spec_path)
elif isinstance(spec_path, dict):
spec = spec_path
elif isinstance(spec_path, BuilderPartial):
spec = spec_path.spec

_create_ddp_master_job(
execution_id,
cluster_info=cluster_info,
env_vars=env_vars,
follow_logs=follow_logs,
image=worker_image,
resource_limits=worker_resources,
train_spec=spec,
num_nodes=num_nodes,
host_network=num_nodes > 1,
retry_count=retry_count,
resource_requests=worker_resource_requests,
)
Loading