def encode_dict(python_dict: typing.Mapping[str, Any]) -> api_pb2.PayloadDictValue:
    """Encode a string-keyed mapping as a `PayloadDictValue`.

    Each item becomes one `PayloadDictEntry`: the key is stored in `name` and the
    value is encoded recursively through `_python_to_proto_value`.

    Raises:
        TypeError: if any key is not a `str`. `PayloadDictEntry.name` is a proto
            string field, so other key types can't be represented faithfully.
    """
    entries = []
    for key, python_dict_value in python_dict.items():
        if not isinstance(key, str):
            # Fail early with an actionable message instead of letting the proto
            # constructor reject (or coerce) the key.
            raise TypeError(f"Payload dict keys must be str, got {type(key).__name__}: {key!r}")
        entries.append(api_pb2.PayloadDictEntry(name=key, value=_python_to_proto_value(python_dict_value)))
    return api_pb2.PayloadDictValue(entries=entries)
def _python_to_proto_value(python_value: Any) -> api_pb2.ClassParameterValue:
    """Encode a single Python value as a `ClassParameterValue` without a schema.

    The proto type is chosen from the *exact* Python type via `PYTHON_TO_PROTO_TYPE`;
    any type that is not listed there (including subclasses of listed types) is
    encoded with `PARAM_TYPE_PICKLE`.

    Dicts that contain non-`str` keys are also pickled: `PayloadDictEntry.name` is a
    proto string field, so such dicts can't be expressed as a `PayloadDictValue`.

    NOTE: tuples map to `PARAM_TYPE_LIST` in `PYTHON_TO_PROTO_TYPE`, and `decode_list`
    returns a `list`, so a tuple value does not round-trip as a tuple.
    """
    # TODO: use schema
    proto_type = PYTHON_TO_PROTO_TYPE.get(type(python_value), api_pb2.PARAM_TYPE_PICKLE)
    if proto_type == api_pb2.PARAM_TYPE_DICT and not all(isinstance(key, str) for key in python_value):
        # Only string-keyed dicts fit the proto dict representation; fall back to pickle
        # rather than failing on a perfectly valid Python payload.
        proto_type = api_pb2.PARAM_TYPE_PICKLE

    proto_type_info = PROTO_TYPE_INFO[proto_type]
    proto_scalar = proto_type_info.encoder(python_value)

    return api_pb2.ClassParameterValue(
        name="",  # this field is unused for payloads and exists for legacy reasons/code reuse with class params
        type=proto_type,
        **{proto_type_info.proto_field: proto_scalar},
    )
def schema_from_signature(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
    """Build the proto parameter schema for a constructor signature.

    One `ClassParameterSpec` is produced per parameter, in signature order. When a
    parameter has a default, it is stored in the type-specific default field named by
    `PROTO_TYPE_INFO`.

    Raises:
        InvalidError: if a parameter's annotation is not one of
            `SUPPORTED_CLASS_PARAM_TYPES` (this includes unannotated parameters).
    """
    modal_parameters: list[api_pb2.ClassParameterSpec] = []
    for param in signature.parameters.values():
        if param.annotation not in SUPPORTED_CLASS_PARAM_TYPES:
            # Name the offending parameter so the user knows what to fix. The original
            # message is kept as a prefix so existing message matching still works.
            if param.annotation is param.empty:
                detail = f"parameter {param.name!r} has no type annotation"
            else:
                annotation_name = getattr(param.annotation, "__name__", repr(param.annotation))
                detail = f"parameter {param.name!r} is annotated as {annotation_name}"
            raise InvalidError(f"modal.parameter() currently only support str or int types ({detail})")

        has_default = param.default is not param.empty
        param_type = PYTHON_TO_PROTO_TYPE[param.annotation]
        param_type_info = PROTO_TYPE_INFO[param_type]
        class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
        if has_default:
            setattr(class_param_spec, param_type_info.default_field, param.default)
        modal_parameters.append(class_param_spec)
    return modal_parameters
-from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP, FunctionInfo +from modal._utils.function_utils import SUPPORTED_CLASS_PARAM_TYPES, FunctionInfo from modal_proto import api_pb2 from ._functions import _Function, _parse_retries @@ -444,9 +444,9 @@ def validate_construction_mechanism(user_cls): annotated_params = {k: t for k, t in annotations.items() if k in params} for k, t in annotated_params.items(): - if t not in CLASS_PARAM_TYPE_MAP: + if t not in SUPPORTED_CLASS_PARAM_TYPES: t_name = getattr(t, "__name__", repr(t)) - supported = ", ".join(t.__name__ for t in CLASS_PARAM_TYPE_MAP.keys()) + supported = ", ".join(t.__name__ for t in SUPPORTED_CLASS_PARAM_TYPES) raise InvalidError( f"{user_cls.__name__}.{k}: {t_name} is not a supported parameter type. Use one of: {supported}" ) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 3fcfdd4a9..98beea225 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -171,6 +171,8 @@ enum ParameterType { PARAM_TYPE_STRING = 1; PARAM_TYPE_INT = 2; PARAM_TYPE_PICKLE = 3; + PARAM_TYPE_LIST = 4; + PARAM_TYPE_DICT = 5; } enum ProgressType { @@ -680,7 +682,7 @@ message ClassParameterInfo { enum ParameterSerializationFormat { PARAM_SERIALIZATION_FORMAT_UNSPECIFIED = 0; PARAM_SERIALIZATION_FORMAT_PICKLE = 1; // legacy format - pickle of (args, kwargs) tuple - PARAM_SERIALIZATION_FORMAT_PROTO = 2; // new format using api.FunctionParameterSet + PARAM_SERIALIZATION_FORMAT_PROTO = 2; // new format using api.ClassParameterSet } ParameterSerializationFormat format = 1; repeated ClassParameterSpec schema = 2; // only set for PARAM_SERIALIZATION_FORMAT_PROTO @@ -691,6 +693,7 @@ message ClassParameterSet { // since we use the serialized message as the bound function identifier // for parameter-bound classes. 
message Payload {
  // Container for the arguments of a call: positional values go in `args`
  // and named values go in `kwargs` (names live in PayloadDictEntry.name).
  PayloadListValue args = 1;
  PayloadDictValue kwargs = 2;
}
@pytest.fixture()
def disable_pickle_payloads(monkeypatch):
    """Replace the pickle payload encoder with one that always raises.

    Tests using this fixture fail if encoding a payload falls back to pickling.
    """

    def _fail_on_pickle(_value):
        raise Exception("This test is expected to not use pickling")

    pickle_type_info = PROTO_TYPE_INFO[api_pb2.PARAM_TYPE_PICKLE]
    monkeypatch.setattr(pickle_type_info, "encoder", _fail_on_pickle)
def test_payload_modal_object(client):
    """A modal object passed as a payload argument survives a proto round trip."""
    with modal.Dict.ephemeral(client=client) as dct:
        dct["foo"] = "bar"
        proto_payload = python_to_proto_payload(*_call(dct))
        proto_bytes = proto_payload.SerializeToString()
        recovered_payload = api_pb2.Payload()
        recovered_payload.ParseFromString(proto_bytes)
        assert recovered_payload == proto_payload
        recovered_args, recovered_kwargs = proto_to_python_payload(recovered_payload, client)
        assert recovered_kwargs == {}
        (recovered_dct,) = recovered_args  # exactly one positional argument was sent
        assert recovered_dct.is_hydrated
        assert recovered_dct.object_id == dct.object_id
        # Read through the *recovered* handle: checking `dct` here would only re-test the
        # original object and say nothing about the deserialized one.
        assert recovered_dct["foo"] == "bar"