Skip to content

Commit

Permalink
typing.Annotated
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Jan 27, 2025
1 parent f21dc99 commit eab8b27
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 12 deletions.
2 changes: 2 additions & 0 deletions modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
try:
from ._runtime.execution_context import current_function_call_id, current_input_id, interact, is_local
from ._tunnel import Tunnel, forward
from ._utils.function_utils import PickleSerialization
from .app import App, Stub
from .client import Client
from .cloud_bucket_mount import CloudBucketMount
Expand Down Expand Up @@ -78,6 +79,7 @@
"interact",
"method",
"parameter",
"PickleSerialization",
"web_endpoint",
"web_server",
"wsgi_app",
Expand Down
22 changes: 16 additions & 6 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncGenerator
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Literal, Optional
from typing import Annotated, Any, Callable, Literal, Optional, get_args, get_origin

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand Down Expand Up @@ -37,11 +37,15 @@ class FunctionInfoType(Enum):
NOTEBOOK = "notebook"


class PickleSerialization:
pass


# TODO(elias): Add support for quoted/str annotations
CLASS_PARAM_TYPE_MAP: dict[type, tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
Any: (api_pb2.PARAM_TYPE_PICKLE, "pickle_default"),
PickleSerialization: (api_pb2.PARAM_TYPE_PICKLE, "pickle_default"),
}


Expand Down Expand Up @@ -296,10 +300,16 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
signature = _get_class_constructor_signature(self.user_cls)
for param in signature.parameters.values():
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
param_type, default_field = CLASS_PARAM_TYPE_MAP[Any]
else:
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
pickle_annotated = (
get_origin(param.annotation) == Annotated and PickleSerialization in get_args(param.annotation)[1:]
)
param_annotation = PickleSerialization if pickle_annotated else param.annotation
if param_annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError(
"To use custom types you must use typing.Annotated[<type>, modal.PickleSerialization],"
+ f" got {param_annotation}."
)
param_type, default_field = CLASS_PARAM_TYPE_MAP[param_annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
type_info = PARAM_TYPE_MAPPING.get(param_type)
Expand Down
15 changes: 11 additions & 4 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import os
import typing
from collections.abc import Collection
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Annotated, Any, Callable, Optional, TypeVar, Union, get_args, get_origin

from google.protobuf.message import Message
from grpclib import GRPCError, Status

from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP
from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP, PickleSerialization
from modal_proto import api_pb2

from ._object import _get_environment_name, _Object
Expand Down Expand Up @@ -474,8 +474,15 @@ def validate_construction_mechanism(user_cls):

annotated_params = {k: t for k, t in annotations.items() if k in params}
for k, t in annotated_params.items():
if t not in CLASS_PARAM_TYPE_MAP:
pass
pickle_annotated = get_origin(t) == Annotated and PickleSerialization in get_args(t)[1:]
param_annotation = PickleSerialization if pickle_annotated else t

if param_annotation not in CLASS_PARAM_TYPE_MAP:
t_name = getattr(t, "__name__", repr(t))
supported = ", ".join(t.__name__ for t in CLASS_PARAM_TYPE_MAP.keys())
raise InvalidError(
f"{user_cls.__name__}.{k}: {t_name} is not a supported parameter type. Use one of: {supported}"
)
# TODO:
# raise if cls has webhooks
# and no default value for pickle parameter
Expand Down
4 changes: 2 additions & 2 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import threading
import typing
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from typing_extensions import assert_type

Expand Down Expand Up @@ -878,7 +878,7 @@ class UsingAnnotationParameters:
a: int = modal.parameter()
b: str = modal.parameter(default="hello")
c: float = modal.parameter(init=False)
d: Any = modal.parameter(default={"foo": "bar"})
d: typing.Annotated[dict, modal.PickleSerialization] = modal.parameter(default={"foo": "bar"})

@method()
def get_value(self):
Expand Down

0 comments on commit eab8b27

Please sign in to comment.