Skip to content

Commit

Permalink
feat(inference): add provision_model to exec on gcp with sqs
Browse files Browse the repository at this point in the history
  • Loading branch information
akhileshh authored and supersergiy committed Dec 6, 2023
1 parent a2027dc commit 635fa19
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, log
Expand Down Expand Up @@ -35,6 +35,7 @@ def get_deployment_spec(
volumes: Optional[List[k8s_client.V1Volume]] = None,
volume_mounts: Optional[List[k8s_client.V1VolumeMount]] = None,
resource_requests: Optional[Dict[str, int | float | str]] = None,
provisioning_model: Literal["standard", "spot"] = "spot",
) -> k8s_client.V1Deployment:
schedule_toleration = k8s_client.V1Toleration(
key="worker-pool", operator="Equal", value="true", effect="NoSchedule"
Expand All @@ -47,6 +48,7 @@ def get_deployment_spec(
command_args=["-c", command],
resources=resources,
env_secret_mapping=env_secret_mapping,
node_selector={"cloud.google.com/gke-provisioning": provisioning_model},
tolerations=[schedule_toleration],
volumes=volumes,
volume_mounts=volume_mounts,
Expand Down Expand Up @@ -91,6 +93,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
resource_requests: Optional[Dict[str, int | float | str]] = None,
num_procs: int = 1,
semaphores_spec: dict[SemaphoreType, int] | None = None,
provisioning_model: Literal["standard", "spot"] = "spot",
):
if labels is None:
labels_final = {"execution_id": execution_id}
Expand All @@ -113,6 +116,7 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
volumes=get_common_volumes(),
volume_mounts=get_common_volume_mounts(),
resource_requests=resource_requests,
provisioning_model=provisioning_model,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import os
from contextlib import AbstractContextManager, ExitStack
from typing import Dict, Final, Iterable, Optional, Union
from typing import Dict, Final, Iterable, Literal, Optional, Union

from zetta_utils import builder, log, mazepa
from zetta_utils.cloud_management import execution_tracker, resource_allocation
Expand Down Expand Up @@ -62,6 +62,7 @@ def get_gcp_with_sqs_config(
worker_resource_requests: Optional[Dict[str, int | float | str]] = None,
num_procs: int = 1,
semaphores_spec: dict[SemaphoreType, int] | None = None,
provisioning_model: Literal["standard", "spot"] = "spot",
) -> tuple[PushMessageQueue[Task], PullMessageQueue[OutcomeReport], list[AbstractContextManager]]:
work_queue_name = f"zzz-{execution_id}-work"
outcome_queue_name = f"zzz-{execution_id}-outcome"
Expand Down Expand Up @@ -101,6 +102,7 @@ def get_gcp_with_sqs_config(
resource_requests=worker_resource_requests,
num_procs=num_procs,
semaphores_spec=semaphores_spec,
provisioning_model=provisioning_model,
)

ctx_managers.append(
Expand Down Expand Up @@ -137,6 +139,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
worker_resource_requests: Optional[Dict[str, int | float | str]] = None,
num_procs: int = 1,
semaphores_spec: dict[SemaphoreType, int] | None = None,
provisioning_model: Literal["standard", "spot"] = "spot",
):
_ensure_required_env_vars()
execution_id = mazepa.id_generation.get_unique_id(
Expand Down Expand Up @@ -182,6 +185,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
worker_resource_requests=worker_resource_requests,
num_procs=num_procs,
semaphores_spec=semaphores_spec,
provisioning_model=provisioning_model,
)

with ExitStack() as stack:
Expand Down

0 comments on commit 635fa19

Please sign in to comment.