Skip to content

Commit

Permalink
Fix count and A100-80GB
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Feb 8, 2025
1 parent ed6e642 commit db8912c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
5 changes: 4 additions & 1 deletion modal/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ class _GPUConfig:

def __init__(self, gpu_type: str, count: int):
name = self.__class__.__name__
deprecation_warning((2025, 2, 7), f'`gpu={name}(...)` is deprecated. Use `gpu="{name}"` instead.')
str_value = gpu_type
if count > 1:
str_value += f":{count}"
deprecation_warning((2025, 2, 7), f'`gpu={name}(...)` is deprecated. Use `gpu="{str_value}"` instead.')
self.gpu_type = gpu_type
self.count = count

Expand Down
15 changes: 9 additions & 6 deletions test/gpu_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Modal Labs 2022
import pytest

import modal.gpu
from modal import App
from modal.exception import InvalidError
from modal_proto import api_pb2
Expand Down Expand Up @@ -58,8 +59,6 @@ def test_invalid_gpu_string_config(client, servicer, gpu_arg):


def test_gpu_config_function(client, servicer):
import modal

app = App()

with pytest.warns(match='gpu="A100"'):
Expand All @@ -72,6 +71,14 @@ def test_gpu_config_function(client, servicer):
assert func_def.resources.gpu_config.count == 1


def test_gpu_config_function_more(client, servicer):
# Make sure some other GPU types also throw warnings
with pytest.warns(match='gpu="A100-80GB"'):
modal.gpu.A100(size="80GB")
with pytest.warns(match='gpu="T4:7"'):
modal.gpu.T4(count=7)


def test_cloud_provider_selection(client, servicer):
app = App()

Expand Down Expand Up @@ -100,8 +107,6 @@ def test_cloud_provider_selection(client, servicer):
],
)
def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type):
import modal

app = App()
with pytest.warns(match='gpu="A100'):
app.function(gpu=modal.gpu.A100(size=memory_arg))(dummy)
Expand All @@ -116,8 +121,6 @@ def test_memory_selection_gpu_variant(client, servicer, memory_arg, gpu_type):


def test_gpu_unsupported_config():
import modal

app = App()

with pytest.raises(ValueError, match="size='20GB' is invalid"):
Expand Down

0 comments on commit db8912c

Please sign in to comment.