Skip to content

Commit

Permalink
Don't send cloud provider enum to server (#2857)
Browse files Browse the repository at this point in the history
* Don't send cloud provider enum to server

* Fix test
  • Loading branch information
erikbern authored Feb 11, 2025
1 parent 20db6fd commit 41d57c2
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 46 deletions.
8 changes: 1 addition & 7 deletions modal/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from modal_proto import api_pb2
from modal_proto.modal_api_grpc import ModalClientModal

from ._location import parse_cloud_provider
from ._object import _get_environment_name, _Object, live_method, live_method_gen
from ._pty import get_pty_info
from ._resolver import Resolver
Expand Down Expand Up @@ -627,10 +626,6 @@ def from_local(

if not cloud and not is_builder_function:
cloud = config.get("default_cloud")
if cloud:
cloud_provider = parse_cloud_provider(cloud)
else:
cloud_provider = None

if is_generator and webhook_config:
if webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
Expand Down Expand Up @@ -819,8 +814,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
task_idle_timeout_secs=container_idle_timeout or 0,
concurrency_limit=concurrency_limit or 0,
pty_info=pty_info,
cloud_provider=cloud_provider, # Deprecated at some point
cloud_provider_str=cloud.upper() if cloud else "", # Supersedes cloud_provider
cloud_provider_str=cloud if cloud else "",
warm_pool_size=keep_warm or 0,
runtime=config.get("function_runtime"),
runtime_debug=config.get("function_runtime_debug"),
Expand Down
29 changes: 0 additions & 29 deletions modal/_location.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,6 @@
# Copyright Modal Labs 2022
from enum import Enum

import modal_proto.api_pb2

from .exception import InvalidError


class CloudProvider(Enum):
AWS = modal_proto.api_pb2.CLOUD_PROVIDER_AWS
GCP = modal_proto.api_pb2.CLOUD_PROVIDER_GCP
AUTO = modal_proto.api_pb2.CLOUD_PROVIDER_AUTO
OCI = modal_proto.api_pb2.CLOUD_PROVIDER_OCI


def parse_cloud_provider(value: str) -> "modal_proto.api_pb2.CloudProvider.V":
try:
cloud_provider = CloudProvider[value.upper()]
except KeyError:
# provider's int identifier may be directly specified
try:
return int(value) # type: ignore
except ValueError:
pass

raise InvalidError(
f"Invalid cloud provider: {value}. "
f"Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)."
)

return cloud_provider.value


def display_location(cloud_provider: "modal_proto.api_pb2.CloudProvider.V") -> str:
if cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_GCP:
Expand Down
4 changes: 1 addition & 3 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from modal.volume import _Volume
from modal_proto import api_pb2

from ._location import parse_cloud_provider
from ._object import _get_environment_name, _Object
from ._resolver import Resolver
from ._resources import convert_fn_config_to_resources_config
Expand Down Expand Up @@ -186,8 +185,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
resources=convert_fn_config_to_resources_config(
cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
),
cloud_provider=parse_cloud_provider(cloud) if cloud else None, # Deprecated at some point
cloud_provider_str=cloud.upper() if cloud else None, # Supersedes cloud_provider
cloud_provider_str=cloud if cloud else None, # Supersedes cloud_provider
nfs_mounts=network_file_system_mount_protos(validated_network_file_systems, False),
runtime_debug=config.get("function_runtime_debug"),
cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
Expand Down
5 changes: 5 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from modal_proto import api_grpc, api_pb2

VALID_GPU_TYPES = ["ANY", "T4", "L4", "A10G", "L40S", "A100", "A100-40GB", "A100-80GB", "H100"]
VALID_CLOUD_PROVIDERS = ["AWS", "GCP", "OCI", "AUTO", "XYZ"]


@dataclasses.dataclass
Expand Down Expand Up @@ -975,6 +976,10 @@ async def FunctionCreate(self, stream):
if request.function.resources.gpu_config.count > 0:
if request.function.resources.gpu_config.gpu_type not in VALID_GPU_TYPES:
raise GRPCError(Status.INVALID_ARGUMENT, "Not a valid GPU type")
if request.function.cloud_provider_str:
if request.function.cloud_provider_str.upper() not in VALID_CLOUD_PROVIDERS:
raise GRPCError(Status.INVALID_ARGUMENT, "Not a valid cloud provider")

if request.existing_function_id:
function_id = request.existing_function_id
else:
Expand Down
6 changes: 3 additions & 3 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,14 +782,14 @@ def my_handle():
def test_default_cloud_provider(client, servicer, monkeypatch):
app = App()

monkeypatch.setenv("MODAL_DEFAULT_CLOUD", "oci")
monkeypatch.setenv("MODAL_DEFAULT_CLOUD", "xyz")
app.function()(dummy)
with app.run(client=client):
object_id: str = app.registered_functions["dummy"].object_id
f = servicer.app_functions[object_id]

assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI
assert f.cloud_provider_str == "OCI"
assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer sent
assert f.cloud_provider_str == "xyz"


def test_not_hydrated():
Expand Down
10 changes: 8 additions & 2 deletions test/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,21 @@ def test_cloud_provider_selection(client, servicer):

assert len(servicer.app_functions) == 1
func_def = next(iter(servicer.app_functions.values()))
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_GCP
assert func_def.cloud_provider_str == "GCP"
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer set
assert func_def.cloud_provider_str == "gcp"

assert func_def.resources.gpu_config.gpu_type == "A100"
assert func_def.resources.gpu_config.count == 1


def test_invalid_cloud_provider_selection(client, servicer):
app = App()

# Invalid enum value.
with pytest.raises(InvalidError):
app.function(cloud="foo")(dummy)
with app.run(client=client):
pass


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions test/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,8 @@ def test_image_run_function_with_cloud_selection(servicer, client):

assert len(servicer.app_functions) == 2
func_def = next(iter(servicer.app_functions.values()))
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI
assert func_def.cloud_provider_str == "OCI"
assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer set
assert func_def.cloud_provider_str == "oci"


def test_poetry(builder_version, servicer, client):
Expand Down

0 comments on commit 41d57c2

Please sign in to comment.