diff --git a/src/middlewared/middlewared/api/base/model.py b/src/middlewared/middlewared/api/base/model.py index 22cecfb95b9d1..bb52f6d8fd39c 100644 --- a/src/middlewared/middlewared/api/base/model.py +++ b/src/middlewared/middlewared/api/base/model.py @@ -13,10 +13,58 @@ __all__ = ["BaseModel", "ForUpdateMetaclass", "query_result", "query_result_item", "added_event_model", - "changed_event_model", "removed_event_model", "single_argument_args", "single_argument_result"] + "changed_event_model", "removed_event_model", "single_argument_args", "single_argument_result", + "NotRequired"] -class BaseModel(PydanticBaseModel): +class _NotRequiredMixin(PydanticBaseModel): + @model_serializer(mode="wrap") + def serialize_basemodel(self, serializer): + obj = serializer(self) + if isinstance(obj, dict): + return { + k: v + for k, v in obj.items() + if v is not undefined + } + return obj + + +NotRequired = undefined +"""Use as the default value for fields that may be excluded from the model.""" + + +class _BaseModelMetaclass(ModelMetaclass): + """Any BaseModel subclass that uses the NotRequired default value on any of its fields receives the appropriate + model serializer.""" + # FIXME: In the future we want to set defaults on all fields that are not required. Remove this metaclass, + # `_NotRequiredMixin`, and `NotRequired` at that time. + + def __new__(mcls, name, bases, namespaces, **kwargs): + skip_patching = kwargs.pop("__BaseModelMetaclass_skip_patching", False) + + cls = super().__new__(mcls, name, bases, namespaces, **kwargs) + + if skip_patching or name == "BaseModel": + return cls + + for field in cls.model_fields.values(): + if getattr(field, "default", None) is undefined: + return create_model( + cls.__name__, + __base__=(cls, _NotRequiredMixin), + __module__=cls.__module__, + __cls_kwargs__={"__BaseModelMetaclass_skip_patching": True}, + **{ + k: (v.annotation, v) + for k, v in cls.model_fields.items() + } + ) + else: + return cls + + +class BaseModel(PydanticBaseModel, metaclass=_BaseModelMetaclass): model_config = ConfigDict( extra="forbid", strict=True, @@ -51,7 +99,7 @@ def model_dump( exclude_none: bool = False, round_trip: bool = False, warnings: bool | typing.Literal['none', 'warn', 'error'] = True, - serialize_as_any: bool = False, + serialize_as_any: bool = True, # so that nested models set to `NotRequired` do not serialize ) -> dict[str, typing.Any]: return self.__pydantic_serializer__.to_python( self, @@ -102,7 +150,7 @@ def to_previous(cls, value): return value -class ForUpdateMetaclass(ModelMetaclass): +class ForUpdateMetaclass(_BaseModelMetaclass): """ Using this metaclass on a model will change all of its fields default values to `undefined`. Such a model might be instantiated with any subset of its fields, which can be useful to validate request bodies @@ -112,7 +160,7 @@ class ForUpdateMetaclass(ModelMetaclass): def __new__(mcls, name, bases, namespaces, **kwargs): skip_patching = kwargs.pop("__ForUpdateMetaclass_skip_patching", False) - cls = super().__new__(mcls, name, bases, namespaces, **kwargs) + cls = ModelMetaclass.__new__(mcls, name, bases, namespaces, **kwargs) if skip_patching: return cls diff --git a/src/middlewared/middlewared/pytest/unit/api/base/test_excluded.py b/src/middlewared/middlewared/pytest/unit/api/base/test_excluded.py index 29f54934f3e89..a4c7a6b6793d2 100644 --- a/src/middlewared/middlewared/pytest/unit/api/base/test_excluded.py +++ b/src/middlewared/middlewared/pytest/unit/api/base/test_excluded.py @@ -1,7 +1,8 @@ +from pydantic import Field, Secret import pytest -from middlewared.api.base import BaseModel, Excluded, excluded_field -from middlewared.api.base.handler.accept import accept_params +from middlewared.api.base import BaseModel, Excluded, excluded_field, ForUpdateMetaclass, NotRequired +from middlewared.api.base.handler.accept import accept_params, validate_model from middlewared.service_exception import ValidationErrors @@ -18,9 +19,114 @@ class CreateArgs(BaseModel): data: CreateObject +def check_serialization(test_model, test_cases): + for args, dump in test_cases: + result = validate_model(test_model, args) + assert result == dump, (args, dump, result) + + def test_excluded_field(): with pytest.raises(ValidationErrors) as ve: - accept_params(CreateObject, [{"id": 1, "name": "Ivan"}]) + accept_params(CreateArgs, [{"id": 1, "name": "Ivan"}]) - assert ve.value.errors[0].attribute == "id" + assert ve.value.errors[0].attribute == "data.id" assert ve.value.errors[0].errmsg == "Extra inputs are not permitted" + + +def test_not_required(): + class NestedModel(BaseModel): + a: int = NotRequired + + class NotRequiredTestModel(BaseModel): + b: int + c: int = 3 + d: int = NotRequired + e: NestedModel + f: NestedModel = Field(default_factory=NestedModel) + # default_factory must be used here + g: NestedModel = NotRequired + h: list[NestedModel] = NotRequired + i_: int = Field(alias="i", default=NotRequired) + j: Secret[int] = NotRequired + + test_cases = ( + ( + {"b": 2, "e": {}}, + {"b": 2, "c": 3, "e": {}, "f": {}} + ), + ( + {"b": 2, "e": {"a": 1}}, + {"b": 2, "c": 3, "e": {"a": 1}, "f": {}} + ), + ( + {"b": 2, "c": -3, "e": {}}, + {"b": 2, "c": -3, "e": {}, "f": {}} + ), + ( + {"b": 2, "d": 4, "e": {}}, + {"b": 2, "c": 3, "d": 4, "e": {}, "f": {}} + ), + ( + {"b": 2, "e": {}, "f": {}}, + {"b": 2, "c": 3, "e": {}, "f": {}} + ), + ( + {"b": 2, "e": {}, "f": {"a": 1}}, + {"b": 2, "c": 3, "e": {}, "f": {"a": 1}} + ), + ( + {"b": 2, "e": {}, "g": {}}, + {"b": 2, "c": 3, "e": {}, "f": {}, "g": {}} + ), + ( + {"b": 2, "e": {}, "g": {"a": 1}}, + {"b": 2, "c": 3, "e": {}, "f": {}, "g": {"a": 1}} + ), + ( + {"b": 2, "e": {}, "h": []}, + {"b": 2, "c": 3, "e": {}, "f": {}, "h": []} + ), + ( + {"b": 2, "e": {}, "h": [{}]}, + {"b": 2, "c": 3, "e": {}, "f": {}, "h": [{}]} + ), + ( + {"b": 2, "e": {}, "h": [{"a": 1}]}, + {"b": 2, "c": 3, "e": {}, "f": {}, "h": [{"a": 1}]} + ), + ( + {"b": 2, "e": {}, "h": [{"a": 1}, {}]}, + {"b": 2, "c": 3, "e": {}, "f": {}, "h": [{"a": 1}, {}]} + ), + ( + {"b": 2, "e": {}, "i": 4}, + {"b": 2, "c": 3, "e": {}, "f": {}, "i": 4} + ), + ( + {"b": 2, "e": {}, "j": 4}, + {"b": 2, "c": 3, "e": {}, "f": {}, "j": 4} + ), + ) + check_serialization(NotRequiredTestModel, test_cases) + + +def test_update_metaclass(): + class NestedModel(BaseModel): + a: int + + class UpdateModel(BaseModel, metaclass=ForUpdateMetaclass): + b: int + c: NestedModel + + test_cases = ( + ( + {}, {} + ), + ( + {"b": 2}, {"b": 2} + ), + ( + {"c": {"a": 1}}, {"c": {"a": 1}} + ), + ) + check_serialization(UpdateModel, test_cases) diff --git a/tests/unit/test_api.py b/src/middlewared/middlewared/pytest/unit/api/handler/result/test_alias.py similarity index 60% rename from tests/unit/test_api.py rename to src/middlewared/middlewared/pytest/unit/api/handler/result/test_alias.py index 745cabf252a86..8a97cba6faa56 100644 --- a/tests/unit/test_api.py +++ b/src/middlewared/middlewared/pytest/unit/api/handler/result/test_alias.py @@ -6,14 +6,12 @@ def test_dump_by_alias(): class AliasModel(BaseModel): - field1_: int = Field(..., alias='field1') + field1_: int = Field(alias='field1') field2: str field3_: bool = Field(alias='field3', default=False) class AliasModelResult(BaseModel): result: AliasModel - result = {'field1': 1, 'field2': 'two'} - dump = serialize_result(AliasModelResult, result, False) - - assert dump == {'field1': 1, 'field2': 'two', 'field3': False} + result = serialize_result(AliasModelResult, {'field1': 1, 'field2': 'two'}, True) + assert result == {'field1': 1, 'field2': 'two', 'field3': False}