From afda7cb1edfcaef418f3b8bfca02f34c0c0e6afb Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Sun, 8 Sep 2019 10:45:32 -0400 Subject: [PATCH] Add missing_values parameter to field Allows specifying which values are treated as "missing". Addresses #713 --- src/marshmallow/fields.py | 19 +++++++++++--- tests/test_deserialization.py | 49 ++++++++++++++++++++++++++++++++++- tests/test_serialization.py | 2 +- 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 527583985..4735e127c 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -134,6 +134,8 @@ class Field(FieldABC): "validator_failed": "Invalid value.", } + default_missing_values = tuple() + def __init__( self, *, @@ -147,6 +149,7 @@ def __init__( load_only=False, dump_only=False, error_messages=None, + missing_values=None, **metadata ): self.default = default @@ -168,9 +171,14 @@ def __init__( "or a collection of callables." ) - # If missing=None, None should be considered valid by default + self.missing_values = ( + missing_values + if missing_values is not None + else self.default_missing_values + ) + # If missing=None or None is in missing_values, None should be considered valid by default if allow_none is None: - if missing is None: + if missing is None or self._is_missing_value(None): self.allow_none = True else: self.allow_none = False @@ -223,6 +231,9 @@ def get_value(self, obj, attr, accessor=None, default=missing_): check_key = attr if attribute is None else attribute return accessor_func(obj, check_key, default) + def _is_missing_value(self, value): + return value is missing_ or value in self.missing_values + def _validate(self, value): """Perform validation on ``value``. Raise a :exc:`ValidationError` if validation does not succeed. @@ -279,7 +290,7 @@ def _validate_missing(self, value): """Validate missing values. Raise a :exc:`ValidationError` if `value` should be considered missing. """ - if value is missing_: + if self._is_missing_value(value): if hasattr(self, "required") and self.required: raise self.make_error("required") if value is None: @@ -319,7 +330,7 @@ def deserialize(self, value, attr=None, data=None, **kwargs): # Validate required fields, deserialize, then validate # deserialized value self._validate_missing(value) - if value is missing_: + if self._is_missing_value(value): _miss = self.missing return _miss() if callable(_miss) else _miss if getattr(self, "allow_none", False) is True and value is None: diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 30fc0a3e8..f8ca0ef3a 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -25,7 +25,7 @@ def test_fields_dont_allow_none_by_default(self, FieldClass): with pytest.raises(ValidationError, match="Field may not be null."): field.deserialize(None) - def test_allow_none_is_true_if_missing_is_true(self): + def test_allow_none_is_true_if_missing_is_none(self): field = fields.Field(missing=None) assert field.allow_none is True field.deserialize(None) is None @@ -1338,6 +1338,53 @@ class AliasingUserSerializer(Schema): assert result["name"] == "Mick" assert result["years"] is None + # https://github.com/marshmallow-code/marshmallow/issues/713 + @pytest.mark.parametrize( + ("missing", "missing_values", "input_data", "expected"), + [ + (None, {None}, {"name": None}, {"name": None}), + (None, {None}, {"name": ""}, {"name": ""}), + (None, {""}, {"name": ""}, {"name": None}), + (None, {""}, {}, {"name": None}), + ("", {""}, {"name": ""}, {"name": ""}), + ("", {None}, {"name": None}, {"name": ""}), + ("", {None}, {}, {"name": ""}), + ], + ) + def test_deserialize_with_custom_missing_values( + self, missing, missing_values, input_data, expected + ): + class ArtistSchema(Schema): + name = fields.String(missing=missing, missing_values=missing_values) + + schema = ArtistSchema() + assert schema.load(input_data) == expected + + def test_deserialize_required_field_with_custom_missing_values(self): + class ArtistSchema(Schema): + album_names = fields.List( + fields.Str(), required=True, missing_values=([], ()) + ) + + with pytest.raises(ValidationError, match="required"): + ArtistSchema().load({"album_names": []}) + + def test_setting_default_missing_values(self, monkeypatch): + monkeypatch.setattr(fields.Field, "default_missing_values", ("",)) + monkeypatch.setattr(fields.List, "default_missing_values", ([], ())) + + class ArtistSchema(Schema): + name = fields.String(missing=None) + dob = fields.DateTime(missing=None) + album_names = fields.List(fields.Str(), required=True) + + schema = ArtistSchema() + loaded = schema.load({"name": "", "dob": "", "album_names": ["Hunky Dory"]}) + assert loaded == {"name": None, "dob": None, "album_names": ["Hunky Dory"]} + + with pytest.raises(ValidationError, match="required"): + assert schema.load({"name": "", "dob": "", "album_names": []}) + def test_deserialization_raises_with_errors(self): bad_data = {"email": "invalid-email", "colors": "burger", "age": -1} v = Validator() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index f88e28afc..292668588 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -81,7 +81,7 @@ def test_function_field_load_only(self): field = fields.Function(deserialize=lambda obj: None) assert field.load_only - def test_function_field_passed_serialize_with_context(self, user, monkeypatch): + def test_function_field_passed_serialize_with_context(self, user): class Parent(Schema): pass