diff --git a/modal/gpu.py b/modal/gpu.py
index 4f235226b..36cc2a8f7 100644
--- a/modal/gpu.py
+++ b/modal/gpu.py
@@ -1,17 +1,25 @@
 # Copyright Modal Labs 2022
-from dataclasses import dataclass
 from typing import Union
 
 from modal_proto import api_pb2
 
+from ._utils.deprecation import deprecation_warning
 from .exception import InvalidError
 
 
-@dataclass(frozen=True)
 class _GPUConfig:
     gpu_type: str
     count: int
 
+    def __init__(self, gpu_type: str, count: int):
+        name = self.__class__.__name__
+        str_value = gpu_type
+        if count > 1:
+            str_value += f":{count}"
+        deprecation_warning((2025, 2, 7), f'`gpu={name}(...)` is deprecated. Use `gpu="{str_value}"` instead.')
+        self.gpu_type = gpu_type
+        self.count = count
+
     def _to_proto(self) -> api_pb2.GPUConfig:
         """Convert this GPU config to an internal protobuf representation."""
         return api_pb2.GPUConfig(
@@ -175,10 +183,19 @@ def __repr__(self):
 def my_gpu_function():
     ...  # This will have 4 A100-80GB with each container
 ```
+
+**Deprecation notes**
+
+An older deprecated way to configure GPU is also still supported,
+but will be removed in future versions of Modal. Examples:
+
+- `gpu=modal.gpu.H100()` will attach 1 H100 GPU to each container
+- `gpu=modal.gpu.T4(count=4)` will attach 4 T4 GPUs to each container
+- `gpu=modal.gpu.A100()` will attach 1 A100-40GB GPUs to each container
+- `gpu=modal.gpu.A100(size="80GB")` will attach 1 A100-80GB GPUs to each container
 """
 
-# bool will be deprecated in future versions.
-GPU_T = Union[None, bool, str, _GPUConfig]
+GPU_T = Union[None, str, _GPUConfig]
 
 
 def parse_gpu_config(value: GPU_T) -> api_pb2.GPUConfig:
@@ -197,7 +214,9 @@ def parse_gpu_config(value: GPU_T) -> api_pb2.GPUConfig:
             gpu_type=gpu_type,
             count=count,
         )
-    elif value is None or value is False:
+    elif value is None:
         return api_pb2.GPUConfig()
     else:
-        raise InvalidError(f"Invalid GPU config: {value}. Value must be a string, a `GPUConfig` object, or `None`.")
+        raise InvalidError(
+            f"Invalid GPU config: {value}. "
+            "Value must be a string or `None` (or a deprecated `modal.gpu` object)"
+        )
diff --git a/test/gpu_fallbacks_test.py b/test/gpu_fallbacks_test.py
index 7764cafa0..59ea3e8f9 100644
--- a/test/gpu_fallbacks_test.py
+++ b/test/gpu_fallbacks_test.py
@@ -1,5 +1,4 @@
 # Copyright Modal Labs 2024
-import modal
 from modal import App
 
 from modal_proto import api_pb2
@@ -16,7 +15,7 @@ def f2():
     pass
 
 
-@app.function(gpu=["h100:2", modal.gpu.A100(count=2, size="80GB")])
+@app.function(gpu=["h100:2", "a100-80gb:2"])
 def f3():
     pass
 
diff --git a/test/gpu_test.py b/test/gpu_test.py
index c1742446d..8301e90f9 100644
--- a/test/gpu_test.py
+++ b/test/gpu_test.py
@@ -1,6 +1,7 @@
 # Copyright Modal Labs 2022
 import pytest
 
+import modal.gpu
 from modal import App
 from modal.exception import InvalidError
 from modal_proto import api_pb2
@@ -58,11 +59,10 @@ def test_invalid_gpu_string_config(client, servicer, gpu_arg):
 
 
 def test_gpu_config_function(client, servicer):
-    import modal
-
     app = App()
 
-    app.function(gpu=modal.gpu.A100())(dummy)
+    with pytest.warns(match='gpu="A100-40GB"'):
+        app.function(gpu=modal.gpu.A100())(dummy)
     with app.run(client=client):
         pass
 
@@ -71,12 +71,18 @@ def test_gpu_config_function(client, servicer):
     assert func_def.resources.gpu_config.count == 1
 
 
-def test_cloud_provider_selection(client, servicer):
-    import modal
+def test_gpu_config_function_more(client, servicer):
+    # Make sure some other GPU types also throw warnings
+    with pytest.warns(match='gpu="A100-80GB"'):
+        modal.gpu.A100(size="80GB")
+    with pytest.warns(match='gpu="T4:7"'):
+        modal.gpu.T4(count=7)
+
 
+def test_cloud_provider_selection(client, servicer):
     app = App()
 
-    app.function(gpu=modal.gpu.A100(), cloud="gcp")(dummy)
+    app.function(gpu="A100", cloud="gcp")(dummy)
     with app.run(client=client):
         pass
 
@@ -85,6 +91,7 @@ def test_cloud_provider_selection(client, servicer):
 
     assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_GCP
     assert func_def.cloud_provider_str == "GCP"
+    assert func_def.resources.gpu_config.gpu_type == "A100"
     assert func_def.resources.gpu_config.count == 1
 
     # Invalid enum value.
@@ -100,13 +107,9 @@ def test_cloud_provider_selection(client, servicer):
     ],
 )
 def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type):
-    import modal
-
     app = App()
-    if isinstance(memory_arg, str):
+    with pytest.warns(match='gpu="A100'):
         app.function(gpu=modal.gpu.A100(size=memory_arg))(dummy)
-    else:
-        raise RuntimeError(f"Unexpected test parametrization arg type {type(memory_arg)}")
     with app.run(client=client):
         pass
 
@@ -118,8 +121,6 @@ def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type):
 
 
 def test_gpu_unsupported_config():
-    import modal
-
     app = App()
 
     with pytest.raises(ValueError, match="size='20GB' is invalid"):
@@ -128,12 +129,10 @@ def test_gpu_unsupported_config():
 
 @pytest.mark.parametrize("count", [1, 2, 3, 4])
 def test_gpu_type_selection_from_count(client, servicer, count):
-    import modal
-
     app = App()
 
     # Task type does not change when user asks more than 1 GPU on an A100.
-    app.function(gpu=modal.gpu.A100(count=count))(dummy)
+    app.function(gpu=f"A100:{count}")(dummy)
     with app.run(client=client):
         pass
 
diff --git a/test/image_test.py b/test/image_test.py
index 5898e9933..1fa87a45e 100644
--- a/test/image_test.py
+++ b/test/image_test.py
@@ -13,7 +13,7 @@
 from unittest import mock
 
 import modal
-from modal import App, Dict, Image, Secret, build, environments, gpu, method
+from modal import App, Dict, Image, Secret, build, environments, method
 from modal._serialization import serialize
 from modal._utils.async_utils import synchronizer
 from modal.client import Client
@@ -1001,7 +1001,7 @@ def test_image_gpu(builder_version, servicer, client):
         layers = get_image_layers(app.image.object_id, servicer)
         assert layers[0].gpu_config.gpu_type == "ANY"
 
-    app = App(image=Image.debian_slim().run_commands("echo 2", gpu=gpu.A10G()))
+    app = App(image=Image.debian_slim().run_commands("echo 2", gpu="a10g"))
     app.function()(dummy)
     with app.run(client=client):
         layers = get_image_layers(app.image.object_id, servicer)