Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft/RFC: prep for proto serialization of payloads #2893

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 130 additions & 6 deletions modal/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,19 +393,71 @@ def check_valid_cls_constructor_arg(key, obj):
class ParamTypeInfo:
default_field: str
proto_field: str
converter: typing.Callable[[str], typing.Any]
encoder: typing.Callable[[Any], Any]
decoder: typing.Callable[[Any, "modal.client._Client"], Any]


PARAM_TYPE_MAPPING = {
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(default_field="string_default", proto_field="string_value", converter=str),
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(default_field="int_default", proto_field="int_value", converter=int),
PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
str: api_pb2.PARAM_TYPE_STRING,
int: api_pb2.PARAM_TYPE_INT,
list: api_pb2.PARAM_TYPE_LIST,
tuple: api_pb2.PARAM_TYPE_LIST,
dict: api_pb2.PARAM_TYPE_DICT,
}


def encode_list(python_list: typing.Sequence) -> api_pb2.PayloadListValue:
return api_pb2.PayloadListValue(
values=[_python_to_proto_value(python_list_value) for python_list_value in python_list]
)


def decode_list(proto_list_value: api_pb2.PayloadListValue, client: "modal.client._Client") -> list:
return [_proto_to_python_value(proto_value, client) for proto_value in proto_list_value.values]


def encode_dict(python_dict: typing.Mapping[str, Any]) -> api_pb2.PayloadDictValue:
return api_pb2.PayloadDictValue(
entries=[
api_pb2.PayloadDictEntry(name=k, value=_python_to_proto_value(python_dict_value))
for k, python_dict_value in python_dict.items()
]
)


def decode_dict(proto_dict: api_pb2.PayloadDictValue, client: "modal.client._Client") -> dict[str, Any]:
return {entry.name: _proto_to_python_value(entry.value, client) for entry in proto_dict.entries}


PROTO_TYPE_INFO = {
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
default_field="string_default", proto_field="string_value", encoder=str, decoder=lambda data, _: str(data)
),
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
default_field="int_default", proto_field="int_value", encoder=int, decoder=lambda data, _: int(data)
),
api_pb2.PARAM_TYPE_PICKLE: ParamTypeInfo(
default_field="pickle_default", proto_field="pickle_value", encoder=serialize, decoder=deserialize
),
api_pb2.PARAM_TYPE_LIST: ParamTypeInfo(
default_field="list_default",
proto_field="list_value",
encoder=encode_list,
decoder=decode_list,
),
api_pb2.PARAM_TYPE_DICT: ParamTypeInfo(
default_field="dict_default",
proto_field="dict_value",
encoder=encode_dict,
decoder=decode_dict,
),
}


def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]) -> bytes:
proto_params: list[api_pb2.ClassParameterValue] = []
for schema_param in schema:
type_info = PARAM_TYPE_MAPPING.get(schema_param.type)
type_info = PROTO_TYPE_INFO.get(schema_param.type)
if not type_info:
raise ValueError(f"Unsupported parameter type: {schema_param.type}")
proto_param = api_pb2.ClassParameterValue(
Expand All @@ -419,7 +471,7 @@ def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequenc
else:
raise ValueError(f"Missing required parameter: {schema_param.name}")
try:
converted_value = type_info.converter(python_value)
converted_value = type_info.encoder(python_value)
except ValueError as exc:
raise ValueError(f"Invalid type for parameter {schema_param.name}: {exc}")
setattr(proto_param, type_info.proto_field, converted_value)
Expand All @@ -428,7 +480,79 @@ def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequenc
return proto_bytes


def _python_to_proto_value(python_value: Any) -> api_pb2.ClassParameterValue:
# TODO: use schema
python_type = type(python_value)
if python_type in PYTHON_TO_PROTO_TYPE:
proto_type = PYTHON_TO_PROTO_TYPE[python_type]
else:
proto_type = api_pb2.PARAM_TYPE_PICKLE

proto_type_info = PROTO_TYPE_INFO[proto_type]
proto_scalar = proto_type_info.encoder(python_value)

return api_pb2.ClassParameterValue(
name="", # this field is unused for payloads and exists for legacy reasons/code reuse with class params
type=proto_type,
**{proto_type_info.proto_field: proto_scalar},
)


def _proto_to_python_value(proto_value: api_pb2.ClassParameterValue, client: "modal.client._Client") -> Any:
proto_type_info = PROTO_TYPE_INFO[proto_value.type]
proto_field = proto_type_info.proto_field
proto_dto = getattr(proto_value, proto_field)
python_value = proto_type_info.decoder(proto_dto, client)
return python_value


def python_to_proto_payload(python_args: tuple[Any, ...], python_kwargs: dict[str, Any]) -> api_pb2.Payload:
"""Serialize a python payload using the input payload type rather than a schema w/ type coercion/validation

This is similar to serialize_proto_params except:
* Doesn't require a prior schema for encoding
* It uses the new api_pb2.Payload container proto to include both args and kwargs
* Doesn't use the `name` field of the ClassParameterValue message (names are encoded as part
of the `kwargs` PayloadDictValue instead)
"""
proto_args = api_pb2.PayloadListValue(values=[])
for python_value in python_args:
proto_value = _python_to_proto_value(python_value)
proto_args.values.append(proto_value)

proto_kwargs = api_pb2.PayloadDictValue(entries=[])
for param_name, python_value in python_kwargs.items():
proto_value = _python_to_proto_value(python_value)
proto_kwargs.entries.append(
api_pb2.PayloadDictEntry(
name=param_name,
value=proto_value,
)
)
return api_pb2.Payload(
args=proto_args,
kwargs=proto_kwargs,
)


def proto_to_python_payload(
proto_payload: api_pb2.Payload, client: "modal.client._Client"
) -> tuple[tuple[Any, ...], dict[str, Any]]:
python_args = []
for proto_value in proto_payload.args.values:
python_value = _proto_to_python_value(proto_value, client)
python_args.append(python_value)

python_kwargs = {}
for proto_dict_entry in proto_payload.kwargs.entries:
python_value = _proto_to_python_value(proto_dict_entry.value, client)
python_kwargs[proto_dict_entry.name] = python_value

return tuple(python_args), python_kwargs


def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
# TODO: this currently requires the schema to decode a payload, but we could separate validation from decoding
proto_struct = api_pb2.ClassParameterSet()
proto_struct.ParseFromString(serialized_params)
value_by_name = {p.name: p for p in proto_struct.parameters}
Expand Down
35 changes: 19 additions & 16 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import modal_proto
from modal_proto import api_pb2

from .._serialization import deserialize, deserialize_data_format, serialize
from .._serialization import PROTO_TYPE_INFO, PYTHON_TO_PROTO_TYPE, deserialize, deserialize_data_format, serialize
from .._traceback import append_modal_tb
from ..config import config, logger
from ..exception import (
Expand All @@ -39,10 +39,7 @@ class FunctionInfoType(Enum):


# TODO(elias): Add support for quoted/str annotations
CLASS_PARAM_TYPE_MAP: dict[type, tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
}
SUPPORTED_CLASS_PARAM_TYPES = [str, int]


class LocalFunctionError(InvalidError):
Expand Down Expand Up @@ -106,6 +103,21 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION


def schema_from_signature(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
modal_parameters: list[api_pb2.ClassParameterSpec] = []
for param in signature.parameters.values():
has_default = param.default is not param.empty
if param.annotation not in SUPPORTED_CLASS_PARAM_TYPES:
raise InvalidError("modal.parameter() currently only support str or int types")
param_type = PYTHON_TO_PROTO_TYPE[param.annotation]
param_type_info = PROTO_TYPE_INFO[param_type]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, param_type_info.default_field, param.default)
modal_parameters.append(class_param_spec)
return modal_parameters


class FunctionInfo:
"""Utility that determines serialization/deserialization mechanisms for functions

Expand Down Expand Up @@ -291,20 +303,11 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
# annotation parameters trigger strictly typed parametrization
# which enables web endpoint for parametrized classes

modal_parameters: list[api_pb2.ClassParameterSpec] = []
signature = _get_class_constructor_signature(self.user_cls)
for param in signature.parameters.values():
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError("modal.parameter() currently only support str or int types")
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, default_field, param.default)
modal_parameters.append(class_param_spec)
schema = schema_from_signature(signature)

return api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=schema
)

def get_entrypoint_mount(self) -> dict[str, _Mount]:
Expand Down
6 changes: 3 additions & 3 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from google.protobuf.message import Message
from grpclib import GRPCError, Status

from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP, FunctionInfo
from modal._utils.function_utils import SUPPORTED_CLASS_PARAM_TYPES, FunctionInfo
from modal_proto import api_pb2

from ._functions import _Function, _parse_retries
Expand Down Expand Up @@ -444,9 +444,9 @@ def validate_construction_mechanism(user_cls):

annotated_params = {k: t for k, t in annotations.items() if k in params}
for k, t in annotated_params.items():
if t not in CLASS_PARAM_TYPE_MAP:
if t not in SUPPORTED_CLASS_PARAM_TYPES:
t_name = getattr(t, "__name__", repr(t))
supported = ", ".join(t.__name__ for t in CLASS_PARAM_TYPE_MAP.keys())
supported = ", ".join(t.__name__ for t in SUPPORTED_CLASS_PARAM_TYPES)
raise InvalidError(
f"{user_cls.__name__}.{k}: {t_name} is not a supported parameter type. Use one of: {supported}"
)
Expand Down
31 changes: 28 additions & 3 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ enum ParameterType {
PARAM_TYPE_STRING = 1;
PARAM_TYPE_INT = 2;
PARAM_TYPE_PICKLE = 3;
PARAM_TYPE_LIST = 4;
PARAM_TYPE_DICT = 5;
}

enum ProgressType {
Expand Down Expand Up @@ -680,7 +682,7 @@ message ClassParameterInfo {
enum ParameterSerializationFormat {
PARAM_SERIALIZATION_FORMAT_UNSPECIFIED = 0;
PARAM_SERIALIZATION_FORMAT_PICKLE = 1; // legacy format - pickle of (args, kwargs) tuple
PARAM_SERIALIZATION_FORMAT_PROTO = 2; // new format using api.FunctionParameterSet
PARAM_SERIALIZATION_FORMAT_PROTO = 2; // new format using api.ClassParameterSet
}
ParameterSerializationFormat format = 1;
repeated ClassParameterSpec schema = 2; // only set for PARAM_SERIALIZATION_FORMAT_PROTO
Expand All @@ -691,6 +693,7 @@ message ClassParameterSet {
// since we use the serialized message as the bound function identifier
// for parameter-bound classes. Modify with *caution*
repeated ClassParameterValue parameters = 1;
// PayloadDictValue kwargs = 2; // could supersede parameters at some point, so we can remove ClassParameterValue.name
}

message ClassParameterSpec {
Expand All @@ -704,19 +707,22 @@ message ClassParameterSpec {
}
}

message ClassParameterValue {
message ClassParameterValue { // TODO: rename to PayloadValue
// NOTE: adding additional *fields* here can invalidate function lookups
// since we use the serialized message as the bound function identifier
// for parameter-bound classes. Modify with *caution*
string name = 1;
optional string name = 1; // TODO: deprecate this eventually once all classes use a PayloadDictValue instead
ParameterType type = 2;
oneof value_oneof {
string string_value = 3;
int64 int_value = 4;
bytes pickle_value = 5;
PayloadListValue list_value = 6;
PayloadDictValue dict_value = 7;
}
}


message ClientHelloResponse {
string warning = 1;
string image_builder_version = 2; // Deprecated, no longer used in client
Expand Down Expand Up @@ -2018,6 +2024,25 @@ message PTYInfo {
PTYType pty_type = 7;
}

message Payload {
// this can be used
PayloadListValue args = 1;
PayloadDictValue kwargs = 2;
}

message PayloadDictEntry {
string name = 1;
ClassParameterValue value = 2;
}

message PayloadDictValue {
repeated PayloadDictEntry entries = 1;
}

message PayloadListValue {
repeated ClassParameterValue values = 1;
}

message PortSpec {
uint32 port = 1;
bool unencrypted = 2;
Expand Down
Loading
Loading