From 41d57c23e444d568d15b5f87853018955706773e Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 11 Feb 2025 16:04:08 -0500 Subject: [PATCH] Don't send cloud provider enum to server (#2857) * Don't send cloud provider enum to server * Fix test --- modal/_functions.py | 8 +------- modal/_location.py | 29 ----------------------------- modal/sandbox.py | 4 +--- test/conftest.py | 5 +++++ test/function_test.py | 6 +++--- test/gpu_test.py | 10 ++++++++-- test/image_test.py | 4 ++-- 7 files changed, 20 insertions(+), 46 deletions(-) diff --git a/modal/_functions.py b/modal/_functions.py index 20909b1061..dad0071b4d 100644 --- a/modal/_functions.py +++ b/modal/_functions.py @@ -19,7 +19,6 @@ from modal_proto import api_pb2 from modal_proto.modal_api_grpc import ModalClientModal -from ._location import parse_cloud_provider from ._object import _get_environment_name, _Object, live_method, live_method_gen from ._pty import get_pty_info from ._resolver import Resolver @@ -627,10 +626,6 @@ def from_local( if not cloud and not is_builder_function: cloud = config.get("default_cloud") - if cloud: - cloud_provider = parse_cloud_provider(cloud) - else: - cloud_provider = None if is_generator and webhook_config: if webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION: @@ -819,8 +814,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona task_idle_timeout_secs=container_idle_timeout or 0, concurrency_limit=concurrency_limit or 0, pty_info=pty_info, - cloud_provider=cloud_provider, # Deprecated at some point - cloud_provider_str=cloud.upper() if cloud else "", # Supersedes cloud_provider + cloud_provider_str=cloud if cloud else "", warm_pool_size=keep_warm or 0, runtime=config.get("function_runtime"), runtime_debug=config.get("function_runtime_debug"), diff --git a/modal/_location.py b/modal/_location.py index 59237dea44..a5ae122244 100644 --- a/modal/_location.py +++ b/modal/_location.py @@ -1,35 +1,6 @@ # Copyright Modal Labs 2022 -from enum import Enum - import modal_proto.api_pb2 -from .exception import InvalidError - - -class CloudProvider(Enum): - AWS = modal_proto.api_pb2.CLOUD_PROVIDER_AWS - GCP = modal_proto.api_pb2.CLOUD_PROVIDER_GCP - AUTO = modal_proto.api_pb2.CLOUD_PROVIDER_AUTO - OCI = modal_proto.api_pb2.CLOUD_PROVIDER_OCI - - -def parse_cloud_provider(value: str) -> "modal_proto.api_pb2.CloudProvider.V": - try: - cloud_provider = CloudProvider[value.upper()] - except KeyError: - # provider's int identifier may be directly specified - try: - return int(value) # type: ignore - except ValueError: - pass - - raise InvalidError( - f"Invalid cloud provider: {value}. " - f"Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)." - ) - - return cloud_provider.value - def display_location(cloud_provider: "modal_proto.api_pb2.CloudProvider.V") -> str: if cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_GCP: diff --git a/modal/sandbox.py b/modal/sandbox.py index d83a6da1dc..c1426b6efe 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -15,7 +15,6 @@ from modal.volume import _Volume from modal_proto import api_pb2 -from ._location import parse_cloud_provider from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._resources import convert_fn_config_to_resources_config @@ -186,8 +185,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona resources=convert_fn_config_to_resources_config( cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk ), - cloud_provider=parse_cloud_provider(cloud) if cloud else None, # Deprecated at some point - cloud_provider_str=cloud.upper() if cloud else None, # Supersedes cloud_provider + cloud_provider_str=cloud if cloud else None, # Supersedes cloud_provider nfs_mounts=network_file_system_mount_protos(validated_network_file_systems, False), runtime_debug=config.get("function_runtime_debug"), cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts), diff --git a/test/conftest.py b/test/conftest.py index f31da250b1..97fe264249 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -50,6 +50,7 @@ from modal_proto import api_grpc, api_pb2 VALID_GPU_TYPES = ["ANY", "T4", "L4", "A10G", "L40S", "A100", "A100-40GB", "A100-80GB", "H100"] +VALID_CLOUD_PROVIDERS = ["AWS", "GCP", "OCI", "AUTO", "XYZ"] @dataclasses.dataclass @@ -975,6 +976,10 @@ async def FunctionCreate(self, stream): if request.function.resources.gpu_config.count > 0: if request.function.resources.gpu_config.gpu_type not in VALID_GPU_TYPES: raise GRPCError(Status.INVALID_ARGUMENT, "Not a valid GPU type") + if request.function.cloud_provider_str: + if request.function.cloud_provider_str.upper() not in VALID_CLOUD_PROVIDERS: + raise GRPCError(Status.INVALID_ARGUMENT, "Not a valid cloud provider") + if request.existing_function_id: function_id = request.existing_function_id else: diff --git a/test/function_test.py b/test/function_test.py index ef1774d4c3..9ce3d1e226 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -782,14 +782,14 @@ def my_handle(): def test_default_cloud_provider(client, servicer, monkeypatch): app = App() - monkeypatch.setenv("MODAL_DEFAULT_CLOUD", "oci") + monkeypatch.setenv("MODAL_DEFAULT_CLOUD", "xyz") app.function()(dummy) with app.run(client=client): object_id: str = app.registered_functions["dummy"].object_id f = servicer.app_functions[object_id] - assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI - assert f.cloud_provider_str == "OCI" + assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer sent + assert f.cloud_provider_str == "xyz" def test_not_hydrated(): diff --git a/test/gpu_test.py b/test/gpu_test.py index 8301e90f94..efe7252dce 100644 --- a/test/gpu_test.py +++ b/test/gpu_test.py @@ -88,15 +88,21 @@ def test_cloud_provider_selection(client, servicer): assert len(servicer.app_functions) == 1 func_def = next(iter(servicer.app_functions.values())) - assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_GCP - assert func_def.cloud_provider_str == "GCP" + assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer set + assert func_def.cloud_provider_str == "gcp" assert func_def.resources.gpu_config.gpu_type == "A100" assert func_def.resources.gpu_config.count == 1 + +def test_invalid_cloud_provider_selection(client, servicer): + app = App() + # Invalid enum value. with pytest.raises(InvalidError): app.function(cloud="foo")(dummy) + with app.run(client=client): + pass @pytest.mark.parametrize( diff --git a/test/image_test.py b/test/image_test.py index 7955eec93d..572623a6ae 100644 --- a/test/image_test.py +++ b/test/image_test.py @@ -644,8 +644,8 @@ def test_image_run_function_with_cloud_selection(servicer, client): assert len(servicer.app_functions) == 2 func_def = next(iter(servicer.app_functions.values())) - assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI - assert func_def.cloud_provider_str == "OCI" + assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_UNSPECIFIED # No longer set + assert func_def.cloud_provider_str == "oci" def test_poetry(builder_version, servicer, client):