From d64d1c45e00f53880eb721b31af15e0c7bd9a196 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Sat, 8 Feb 2025 02:44:41 +0000 Subject: [PATCH] Fix count and A100-80GB --- modal/gpu.py | 5 ++++- test/gpu_test.py | 17 ++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/modal/gpu.py b/modal/gpu.py index 57af5ffcf..627a299bd 100644 --- a/modal/gpu.py +++ b/modal/gpu.py @@ -13,7 +13,10 @@ class _GPUConfig: def __init__(self, gpu_type: str, count: int): name = self.__class__.__name__ - deprecation_warning((2025, 2, 7), f'`gpu={name}(...)` is deprecated. Use `gpu="{name}"` instead.') + str_value = gpu_type + if count > 1: + str_value += f":{count}" + deprecation_warning((2025, 2, 7), f'`gpu={name}(...)` is deprecated. Use `gpu="{str_value}"` instead.') self.gpu_type = gpu_type self.count = count diff --git a/test/gpu_test.py b/test/gpu_test.py index 8c9173e64..8301e90f9 100644 --- a/test/gpu_test.py +++ b/test/gpu_test.py @@ -1,6 +1,7 @@ # Copyright Modal Labs 2022 import pytest +import modal.gpu from modal import App from modal.exception import InvalidError from modal_proto import api_pb2 @@ -58,11 +59,9 @@ def test_invalid_gpu_string_config(client, servicer, gpu_arg): def test_gpu_config_function(client, servicer): - import modal - app = App() - with pytest.warns(match='gpu="A100"'): + with pytest.warns(match='gpu="A100-40GB"'): app.function(gpu=modal.gpu.A100())(dummy) with app.run(client=client): pass @@ -72,6 +71,14 @@ def test_gpu_config_function(client, servicer): assert func_def.resources.gpu_config.count == 1 +def test_gpu_config_function_more(client, servicer): + # Make sure some other GPU types also throw warnings + with pytest.warns(match='gpu="A100-80GB"'): + modal.gpu.A100(size="80GB") + with pytest.warns(match='gpu="T4:7"'): + modal.gpu.T4(count=7) + + def test_cloud_provider_selection(client, servicer): app = App() @@ -100,8 +107,6 @@ def test_cloud_provider_selection(client, servicer): ], ) def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type): - import modal - app = App() with pytest.warns(match='gpu="A100'): app.function(gpu=modal.gpu.A100(size=memory_arg))(dummy) @@ -116,8 +121,6 @@ def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type): def test_gpu_unsupported_config(): - import modal - app = App() with pytest.raises(ValueError, match="size='20GB' is invalid"):