diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py index 532a05e5e..24db54050 100644 --- a/zetta_utils/api/v0.py +++ b/zetta_utils/api/v0.py @@ -46,7 +46,6 @@ from zetta_utils.cloud_management.resource_allocation.k8s.deployment import ( deployment_ctx_mngr, get_deployment, - get_deployment_spec, get_mazepa_worker_deployment, ) from zetta_utils.cloud_management.resource_allocation.k8s.eks import ( diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/__init__.py b/zetta_utils/cloud_management/resource_allocation/k8s/__init__.py index 251cdbca5..6e7e5e918 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/__init__.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/__init__.py @@ -34,4 +34,4 @@ from .pod import get_pod_spec, get_mazepa_pod_spec from .secret import secrets_ctx_mngr, get_secrets_and_mapping from .service import get_service, service_ctx_manager -from .volume import get_common_volumes, get_common_volume_mounts +from .volume import ADC_MOUNT_PATH, get_common_volumes, get_common_volume_mounts diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py b/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py index 1d580dfd1..d1f0913ef 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py @@ -24,7 +24,7 @@ logger = log.get_logger("zetta_utils") -def get_deployment_spec( +def _get_mazepa_deployment( name: str, image: str, command: str, @@ -35,6 +35,7 @@ def get_deployment_spec( resource_requests: Optional[Dict[str, int | float | str]] = None, provisioning_model: Literal["standard", "spot"] = "spot", gpu_accelerator_type: str | None = None, + adc_available: bool = False, ) -> k8s_client.V1Deployment: name = f"run-{name}" pod_spec = get_mazepa_pod_spec( @@ -45,6 +46,7 @@ def get_deployment_spec( provisioning_model=provisioning_model, resource_requests=resource_requests, gpu_accelerator_type=gpu_accelerator_type, + adc_available=adc_available, ) pod_template = k8s_client.V1PodTemplateSpec( @@ -69,7 +71,6 @@ def get_deployment_spec( metadata=k8s_client.V1ObjectMeta(name=name, labels=labels), spec=deployment_spec, ) - return deployment @@ -87,6 +88,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals semaphores_spec: dict[SemaphoreType, int] | None = None, provisioning_model: Literal["standard", "spot"] = "spot", gpu_accelerator_type: str | None = None, + adc_available: bool = False, ): if labels is None: labels_final = {"run_id": run_id} @@ -98,7 +100,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals ) logger.debug(f"Making a deployment with worker command: '{worker_command}'") - return get_deployment_spec( + return _get_mazepa_deployment( name=run_id, image=image, replicas=replicas, @@ -109,6 +111,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals resource_requests=resource_requests, provisioning_model=provisioning_model, gpu_accelerator_type=gpu_accelerator_type, + adc_available=adc_available, ) diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py index 3fc9d435a..f12a861f7 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py @@ -8,10 +8,7 @@ from kubernetes import client as k8s_client from zetta_utils import log -from zetta_utils.cloud_management.resource_allocation.k8s.volume import ( - get_common_volume_mounts, - get_common_volumes, -) +from zetta_utils.cloud_management.resource_allocation.k8s import volume from .secret import get_worker_env_vars @@ -87,6 +84,7 @@ def get_mazepa_pod_spec( resource_requests: Optional[Dict[str, int | float | str]] = None, restart_policy: Literal["Always", "Never"] = "Always", gpu_accelerator_type: str | None = None, + adc_available: bool = False, ) -> k8s_client.V1PodSpec: schedule_toleration = k8s_client.V1Toleration( key="worker-pool", operator="Equal", value="true", effect="NoSchedule" @@ -96,17 +94,24 @@ def get_mazepa_pod_spec( if gpu_accelerator_type: node_selector["cloud.google.com/gke-accelerator"] = gpu_accelerator_type + envs = [] + if adc_available: + envs.append( + k8s_client.V1EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS", value=volume.ADC_MOUNT_PATH) + ) + return get_pod_spec( name="zutils-worker", image=image, command=["/bin/sh"], command_args=["-c", command], resources=resources, + envs=envs, env_secret_mapping=env_secret_mapping, node_selector=node_selector, restart_policy=restart_policy, tolerations=[schedule_toleration], - volumes=get_common_volumes(), - volume_mounts=get_common_volume_mounts(), + volumes=volume.get_common_volumes(), + volume_mounts=volume.get_common_volume_mounts(), resource_requests=resource_requests, ) diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/secret.py b/zetta_utils/cloud_management/resource_allocation/k8s/secret.py index 2b731f8bd..8928c4be0 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/secret.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/secret.py @@ -2,6 +2,7 @@ Helpers for k8s secrets. """ +import base64 import os from contextlib import contextmanager from typing import Dict, Iterable, List, Optional, Tuple @@ -19,8 +20,6 @@ logger = log.get_logger("zetta_utils") -CV_SECRETS_NAME = "cloudvolume-secrets" - def get_worker_env_vars(env_secret_mapping: Optional[Dict[str, str]] = None) -> list: if env_secret_mapping is None: @@ -53,11 +52,30 @@ def get_worker_env_vars(env_secret_mapping: Optional[Dict[str, str]] = None) -> return envs +def _get_user_adc() -> str | None: + """ + Reads credentials file created by ` gcloud auth application-default login`. + """ + file_name = ".config/gcloud/application_default_credentials.json" + home_dir = os.path.expanduser("~") + file_path = os.path.join(home_dir, file_name) + data = None + if os.path.exists(file_path): + try: + with open(file_path, "r", encoding="utf-8") as f: + data = f.read() + data = base64.b64encode(data.encode()).decode() + except Exception: # pylint: disable=broad-exception-caught + ... + return data + + def get_secrets_and_mapping( run_id: str, share_envs: Iterable[str] = () -) -> Tuple[List[k8s_client.V1Secret], Dict[str, str]]: +) -> Tuple[List[k8s_client.V1Secret], Dict[str, str], bool]: env_secret_mapping: Dict[str, str] = {} secrets_kv: Dict[str, str] = {} + adc_available: bool = False combined_secret_data = {} for env_k in share_envs: @@ -88,7 +106,15 @@ def get_secrets_and_mapping( string_data={"value": v}, ) secrets.append(secret) - return secrets, env_secret_mapping + + adc_content = _get_user_adc() + adc_available = not adc_content is None + adc_creds = k8s_client.V1Secret( + metadata=k8s_client.V1ObjectMeta(name=f"run-{run_id}-adc"), + data={"adc.json": adc_content}, + ) + secrets.append(adc_creds) + return secrets, env_secret_mapping, adc_available @contextmanager diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/volume.py b/zetta_utils/cloud_management/resource_allocation/k8s/volume.py index d3b321fe8..066eb619d 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/volume.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/volume.py @@ -4,7 +4,14 @@ from __future__ import annotations +from typing import Final + from kubernetes import client as k8s_client +from zetta_utils import run + +ADC_MOUNT_PATH: Final[str] = "/etc/secrets/adc.json" +SHM_MOUNT_PATH: Final[str] = "/dev/shm" +TMP_MOUNT_PATH: Final[str] = "/tmp" def get_common_volumes(): @@ -14,11 +21,21 @@ def get_common_volumes(): tmp = k8s_client.V1Volume( name="tmp", empty_dir=k8s_client.V1EmptyDirVolumeSource(medium="Memory") ) - return [dshm, tmp] + + # application_default_credentials + adc = k8s_client.V1Volume( + name="adc", secret=k8s_client.V1SecretVolumeSource(secret_name=f"run-{run.RUN_ID}-adc") + ) + return [dshm, tmp, adc] def get_common_volume_mounts(): return [ - k8s_client.V1VolumeMount(mount_path="/dev/shm", name="dshm"), - k8s_client.V1VolumeMount(mount_path="/tmp", name="tmp"), + k8s_client.V1VolumeMount(mount_path=SHM_MOUNT_PATH, name="dshm"), + k8s_client.V1VolumeMount(mount_path=TMP_MOUNT_PATH, name="tmp"), + k8s_client.V1VolumeMount( + name="adc", + mount_path="/etc/secrets", + read_only=True, + ), ] diff --git a/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py index 47562d451..1acb290e3 100644 --- a/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/deprecated/execute_on_gcp_with_sqs.py @@ -86,7 +86,9 @@ def get_gcp_with_sqs_config( ctx_managers.append(aws_sqs.sqs_queue_ctx_mngr(execution_id, task_queue)) ctx_managers.append(aws_sqs.sqs_queue_ctx_mngr(execution_id, outcome_queue)) - secrets, env_secret_mapping = k8s.get_secrets_and_mapping(execution_id, REQUIRED_ENV_VARS) + secrets, env_secret_mapping, adc_available = k8s.get_secrets_and_mapping( + execution_id, REQUIRED_ENV_VARS + ) if sqs_based_scaling: worker_command = k8s.get_mazepa_worker_command( @@ -104,6 +106,7 @@ def get_gcp_with_sqs_config( provisioning_model=provisioning_model, resource_requests=worker_resource_requests, restart_policy="Never", + adc_available=adc_available, ) job_spec = k8s.get_job_spec(pod_spec=pod_spec) scaled_job_ctx_mngr = k8s.keda_deprecated.scaled_job_ctx_mngr( @@ -129,6 +132,7 @@ def get_gcp_with_sqs_config( num_procs=num_procs, semaphores_spec=semaphores_spec, provisioning_model=provisioning_model, + adc_available=adc_available, ) deployment_ctx_mngr = k8s.deployment_ctx_mngr( execution_id, diff --git a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py index fbf64af33..94b01dd03 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py @@ -93,6 +93,7 @@ def _get_group_taskqueue_and_contexts( sqs_trigger_name: str, outcome_queue_spec: dict[str, Any], env_secret_mapping: dict[str, str], + adc_available: bool = False, ) -> tuple[PushMessageQueue[Task], list[AbstractContextManager]]: ctx_managers: list[AbstractContextManager] = [] work_queue_name = f"run-{execution_id}-{'-'.join(group.queue_tags)}-work" @@ -117,6 +118,7 @@ def _get_group_taskqueue_and_contexts( resource_requests=group.resource_requests, restart_policy="Never", gpu_accelerator_type=group.gpu_accelerator_type, + adc_available=adc_available, ) job_spec = k8s.get_job_spec(pod_spec=pod_spec) scaled_job_ctx_mngr = k8s.scaled_job_ctx_mngr( @@ -145,6 +147,7 @@ def _get_group_taskqueue_and_contexts( semaphores_spec=group.semaphores_spec, provisioning_model=group.provisioning_model, gpu_accelerator_type=group.gpu_accelerator_type, + adc_available=adc_available, ) deployment_ctx_mngr = k8s.deployment_ctx_mngr( execution_id, @@ -164,7 +167,9 @@ def get_gcp_with_sqs_config( ctx_managers: list[AbstractContextManager], ) -> tuple[PushMessageQueue[Task], PullMessageQueue[OutcomeReport], list[AbstractContextManager]]: task_queues = [] - secrets, env_secret_mapping = k8s.get_secrets_and_mapping(execution_id, REQUIRED_ENV_VARS) + secrets, env_secret_mapping, adc_available = k8s.get_secrets_and_mapping( + execution_id, REQUIRED_ENV_VARS + ) outcome_queue_name = f"run-{execution_id}-outcome" outcome_queue_spec = {"@type": "SQSQueue", "name": outcome_queue_name, "pull_wait_sec": 2.5} @@ -187,6 +192,7 @@ def get_gcp_with_sqs_config( sqs_trigger_name=sqs_trigger_name, outcome_queue_spec=outcome_queue_spec, env_secret_mapping=env_secret_mapping, + adc_available=adc_available, ) task_queues.append(task_queue) ctx_managers.extend(group_ctx_managers) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index 530d59492..a876cd740 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -98,12 +98,16 @@ def lightning_train( _lightning_train_local( regime=regime if not isinstance(regime, dict) else builder.build(regime), trainer=trainer if not isinstance(trainer, dict) else builder.build(trainer), - train_dataloader=train_dataloader - if not isinstance(train_dataloader, dict) - else builder.build(train_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED), - val_dataloader=val_dataloader - if not isinstance(val_dataloader, dict) - else builder.build(val_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED), + train_dataloader=( + train_dataloader + if not isinstance(train_dataloader, dict) + else builder.build(train_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED) + ), + val_dataloader=( + val_dataloader + if not isinstance(val_dataloader, dict) + else builder.build(val_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED) + ), full_state_ckpt_path=full_state_ckpt_path, ) return @@ -336,7 +340,7 @@ def _lightning_train_remote( specs = {"train": train_spec} vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(cluster_info, specs) - secrets, env_secret_mapping = resource_allocation.k8s.get_secrets_and_mapping( + secrets, env_secret_mapping, adc_available = resource_allocation.k8s.get_secrets_and_mapping( run.RUN_ID, REQUIRED_ENV_VARS ) @@ -348,6 +352,14 @@ def _lightning_train_remote( for key, val in env_vars.items(): envs.append(k8s_client.V1EnvVar(name=key, value=val)) + if adc_available: + envs.append( + k8s_client.V1EnvVar( + name="GOOGLE_APPLICATION_CREDENTIALS", + value=resource_allocation.k8s.ADC_MOUNT_PATH, + ) + ) + ip_env = k8s_client.V1EnvVar( name="NODE_IP", value_from=k8s_client.V1EnvVarSource(