Skip to content

Commit

Permalink
feat(auth): use application_default_credentials.json when available
Browse files Browse the repository at this point in the history
akhileshh authored and supersergiy committed Jan 25, 2025
1 parent c899696 commit 4deebbb
Showing 9 changed files with 99 additions and 27 deletions.
1 change: 0 additions & 1 deletion zetta_utils/api/v0.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,6 @@
from zetta_utils.cloud_management.resource_allocation.k8s.deployment import (
deployment_ctx_mngr,
get_deployment,
get_deployment_spec,
get_mazepa_worker_deployment,
)
from zetta_utils.cloud_management.resource_allocation.k8s.eks import (
Original file line number Diff line number Diff line change
@@ -34,4 +34,4 @@
from .pod import get_pod_spec, get_mazepa_pod_spec
from .secret import secrets_ctx_mngr, get_secrets_and_mapping
from .service import get_service, service_ctx_manager
from .volume import get_common_volumes, get_common_volume_mounts
from .volume import ADC_MOUNT_PATH, get_common_volumes, get_common_volume_mounts
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
logger = log.get_logger("zetta_utils")


def get_deployment_spec(
def _get_mazepa_deployment(
name: str,
image: str,
command: str,
@@ -35,6 +35,7 @@ def get_deployment_spec(
resource_requests: Optional[Dict[str, int | float | str]] = None,
provisioning_model: Literal["standard", "spot"] = "spot",
gpu_accelerator_type: str | None = None,
adc_available: bool = False,
) -> k8s_client.V1Deployment:
name = f"run-{name}"
pod_spec = get_mazepa_pod_spec(
@@ -45,6 +46,7 @@ def get_deployment_spec(
provisioning_model=provisioning_model,
resource_requests=resource_requests,
gpu_accelerator_type=gpu_accelerator_type,
adc_available=adc_available,
)

pod_template = k8s_client.V1PodTemplateSpec(
@@ -69,7 +71,6 @@ def get_deployment_spec(
metadata=k8s_client.V1ObjectMeta(name=name, labels=labels),
spec=deployment_spec,
)

return deployment


@@ -87,6 +88,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
semaphores_spec: dict[SemaphoreType, int] | None = None,
provisioning_model: Literal["standard", "spot"] = "spot",
gpu_accelerator_type: str | None = None,
adc_available: bool = False,
):
if labels is None:
labels_final = {"run_id": run_id}
@@ -98,7 +100,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
)
logger.debug(f"Making a deployment with worker command: '{worker_command}'")

return get_deployment_spec(
return _get_mazepa_deployment(
name=run_id,
image=image,
replicas=replicas,
@@ -109,6 +111,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
resource_requests=resource_requests,
provisioning_model=provisioning_model,
gpu_accelerator_type=gpu_accelerator_type,
adc_available=adc_available,
)


17 changes: 11 additions & 6 deletions zetta_utils/cloud_management/resource_allocation/k8s/pod.py
Original file line number Diff line number Diff line change
@@ -8,10 +8,7 @@

from kubernetes import client as k8s_client
from zetta_utils import log
from zetta_utils.cloud_management.resource_allocation.k8s.volume import (
get_common_volume_mounts,
get_common_volumes,
)
from zetta_utils.cloud_management.resource_allocation.k8s import volume

from .secret import get_worker_env_vars

@@ -87,6 +84,7 @@ def get_mazepa_pod_spec(
resource_requests: Optional[Dict[str, int | float | str]] = None,
restart_policy: Literal["Always", "Never"] = "Always",
gpu_accelerator_type: str | None = None,
adc_available: bool = False,
) -> k8s_client.V1PodSpec:
schedule_toleration = k8s_client.V1Toleration(
key="worker-pool", operator="Equal", value="true", effect="NoSchedule"
@@ -96,17 +94,24 @@ def get_mazepa_pod_spec(
if gpu_accelerator_type:
node_selector["cloud.google.com/gke-accelerator"] = gpu_accelerator_type

envs = []
if adc_available:
envs.append(
k8s_client.V1EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS", value=volume.ADC_MOUNT_PATH)
)

return get_pod_spec(
name="zutils-worker",
image=image,
command=["/bin/sh"],
command_args=["-c", command],
resources=resources,
envs=envs,
env_secret_mapping=env_secret_mapping,
node_selector=node_selector,
restart_policy=restart_policy,
tolerations=[schedule_toleration],
volumes=get_common_volumes(),
volume_mounts=get_common_volume_mounts(),
volumes=volume.get_common_volumes(),
volume_mounts=volume.get_common_volume_mounts(),
resource_requests=resource_requests,
)
34 changes: 30 additions & 4 deletions zetta_utils/cloud_management/resource_allocation/k8s/secret.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
Helpers for k8s secrets.
"""

import base64
import os
from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Tuple
@@ -19,8 +20,6 @@

logger = log.get_logger("zetta_utils")

CV_SECRETS_NAME = "cloudvolume-secrets"


def get_worker_env_vars(env_secret_mapping: Optional[Dict[str, str]] = None) -> list:
if env_secret_mapping is None:
@@ -53,11 +52,30 @@ def get_worker_env_vars(env_secret_mapping: Optional[Dict[str, str]] = None) ->
return envs


def _get_user_adc() -> str | None:
"""
Reads credentials file created by ` gcloud auth application-default login`.
"""
file_name = ".config/gcloud/application_default_credentials.json"
home_dir = os.path.expanduser("~")
file_path = os.path.join(home_dir, file_name)
data = None
if os.path.exists(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
data = f.read()
data = base64.b64encode(data.encode()).decode()
except Exception: # pylint: disable=broad-exception-caught
...
return data


def get_secrets_and_mapping(
run_id: str, share_envs: Iterable[str] = ()
) -> Tuple[List[k8s_client.V1Secret], Dict[str, str]]:
) -> Tuple[List[k8s_client.V1Secret], Dict[str, str], bool]:
env_secret_mapping: Dict[str, str] = {}
secrets_kv: Dict[str, str] = {}
adc_available: bool = False

combined_secret_data = {}
for env_k in share_envs:
@@ -88,7 +106,15 @@ def get_secrets_and_mapping(
string_data={"value": v},
)
secrets.append(secret)
return secrets, env_secret_mapping

adc_content = _get_user_adc()
adc_available = not adc_content is None
adc_creds = k8s_client.V1Secret(
metadata=k8s_client.V1ObjectMeta(name=f"run-{run_id}-adc"),
data={"adc.json": adc_content},
)
secrets.append(adc_creds)
return secrets, env_secret_mapping, adc_available


@contextmanager
23 changes: 20 additions & 3 deletions zetta_utils/cloud_management/resource_allocation/k8s/volume.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,14 @@

from __future__ import annotations

from typing import Final

from kubernetes import client as k8s_client
from zetta_utils import run

ADC_MOUNT_PATH: Final[str] = "/etc/secrets/adc.json"
SHM_MOUNT_PATH: Final[str] = "/dev/shm"
TMP_MOUNT_PATH: Final[str] = "/tmp"


def get_common_volumes():
@@ -14,11 +21,21 @@ def get_common_volumes():
tmp = k8s_client.V1Volume(
name="tmp", empty_dir=k8s_client.V1EmptyDirVolumeSource(medium="Memory")
)
return [dshm, tmp]

# application_default_credentials
adc = k8s_client.V1Volume(
name="adc", secret=k8s_client.V1SecretVolumeSource(secret_name=f"run-{run.RUN_ID}-adc")
)
return [dshm, tmp, adc]


def get_common_volume_mounts():
return [
k8s_client.V1VolumeMount(mount_path="/dev/shm", name="dshm"),
k8s_client.V1VolumeMount(mount_path="/tmp", name="tmp"),
k8s_client.V1VolumeMount(mount_path=SHM_MOUNT_PATH, name="dshm"),
k8s_client.V1VolumeMount(mount_path=TMP_MOUNT_PATH, name="tmp"),
k8s_client.V1VolumeMount(
name="adc",
mount_path="/etc/secrets",
read_only=True,
),
]
Original file line number Diff line number Diff line change
@@ -86,7 +86,9 @@ def get_gcp_with_sqs_config(
ctx_managers.append(aws_sqs.sqs_queue_ctx_mngr(execution_id, task_queue))
ctx_managers.append(aws_sqs.sqs_queue_ctx_mngr(execution_id, outcome_queue))

secrets, env_secret_mapping = k8s.get_secrets_and_mapping(execution_id, REQUIRED_ENV_VARS)
secrets, env_secret_mapping, adc_available = k8s.get_secrets_and_mapping(
execution_id, REQUIRED_ENV_VARS
)

if sqs_based_scaling:
worker_command = k8s.get_mazepa_worker_command(
@@ -104,6 +106,7 @@ def get_gcp_with_sqs_config(
provisioning_model=provisioning_model,
resource_requests=worker_resource_requests,
restart_policy="Never",
adc_available=adc_available,
)
job_spec = k8s.get_job_spec(pod_spec=pod_spec)
scaled_job_ctx_mngr = k8s.keda_deprecated.scaled_job_ctx_mngr(
@@ -129,6 +132,7 @@ def get_gcp_with_sqs_config(
num_procs=num_procs,
semaphores_spec=semaphores_spec,
provisioning_model=provisioning_model,
adc_available=adc_available,
)
deployment_ctx_mngr = k8s.deployment_ctx_mngr(
execution_id,
Original file line number Diff line number Diff line change
@@ -93,6 +93,7 @@ def _get_group_taskqueue_and_contexts(
sqs_trigger_name: str,
outcome_queue_spec: dict[str, Any],
env_secret_mapping: dict[str, str],
adc_available: bool = False,
) -> tuple[PushMessageQueue[Task], list[AbstractContextManager]]:
ctx_managers: list[AbstractContextManager] = []
work_queue_name = f"run-{execution_id}-{'-'.join(group.queue_tags)}-work"
@@ -117,6 +118,7 @@ def _get_group_taskqueue_and_contexts(
resource_requests=group.resource_requests,
restart_policy="Never",
gpu_accelerator_type=group.gpu_accelerator_type,
adc_available=adc_available,
)
job_spec = k8s.get_job_spec(pod_spec=pod_spec)
scaled_job_ctx_mngr = k8s.scaled_job_ctx_mngr(
@@ -145,6 +147,7 @@ def _get_group_taskqueue_and_contexts(
semaphores_spec=group.semaphores_spec,
provisioning_model=group.provisioning_model,
gpu_accelerator_type=group.gpu_accelerator_type,
adc_available=adc_available,
)
deployment_ctx_mngr = k8s.deployment_ctx_mngr(
execution_id,
@@ -164,7 +167,9 @@ def get_gcp_with_sqs_config(
ctx_managers: list[AbstractContextManager],
) -> tuple[PushMessageQueue[Task], PullMessageQueue[OutcomeReport], list[AbstractContextManager]]:
task_queues = []
secrets, env_secret_mapping = k8s.get_secrets_and_mapping(execution_id, REQUIRED_ENV_VARS)
secrets, env_secret_mapping, adc_available = k8s.get_secrets_and_mapping(
execution_id, REQUIRED_ENV_VARS
)

outcome_queue_name = f"run-{execution_id}-outcome"
outcome_queue_spec = {"@type": "SQSQueue", "name": outcome_queue_name, "pull_wait_sec": 2.5}
@@ -187,6 +192,7 @@ def get_gcp_with_sqs_config(
sqs_trigger_name=sqs_trigger_name,
outcome_queue_spec=outcome_queue_spec,
env_secret_mapping=env_secret_mapping,
adc_available=adc_available,
)
task_queues.append(task_queue)
ctx_managers.extend(group_ctx_managers)
26 changes: 19 additions & 7 deletions zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
@@ -98,12 +98,16 @@ def lightning_train(
_lightning_train_local(
regime=regime if not isinstance(regime, dict) else builder.build(regime),
trainer=trainer if not isinstance(trainer, dict) else builder.build(trainer),
train_dataloader=train_dataloader
if not isinstance(train_dataloader, dict)
else builder.build(train_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED),
val_dataloader=val_dataloader
if not isinstance(val_dataloader, dict)
else builder.build(val_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED),
train_dataloader=(
train_dataloader
if not isinstance(train_dataloader, dict)
else builder.build(train_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED)
),
val_dataloader=(
val_dataloader
if not isinstance(val_dataloader, dict)
else builder.build(val_dataloader, parallel=builder.PARALLEL_BUILD_ALLOWED)
),
full_state_ckpt_path=full_state_ckpt_path,
)
return
@@ -336,7 +340,7 @@ def _lightning_train_remote(

specs = {"train": train_spec}
vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(cluster_info, specs)
secrets, env_secret_mapping = resource_allocation.k8s.get_secrets_and_mapping(
secrets, env_secret_mapping, adc_available = resource_allocation.k8s.get_secrets_and_mapping(
run.RUN_ID, REQUIRED_ENV_VARS
)

@@ -348,6 +352,14 @@ def _lightning_train_remote(
for key, val in env_vars.items():
envs.append(k8s_client.V1EnvVar(name=key, value=val))

if adc_available:
envs.append(
k8s_client.V1EnvVar(
name="GOOGLE_APPLICATION_CREDENTIALS",
value=resource_allocation.k8s.ADC_MOUNT_PATH,
)
)

ip_env = k8s_client.V1EnvVar(
name="NODE_IP",
value_from=k8s_client.V1EnvVarSource(

0 comments on commit 4deebbb

Please sign in to comment.