From 7cc1de477bbd4833da541a334459a0b20b79da2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 25 Aug 2022 17:36:14 +0200 Subject: [PATCH] EnumValue: serialize choices with inner field --- src/marshmallow/fields.py | 6 ++++-- tests/test_deserialization.py | 14 +++++++------- tests/test_serialization.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 57610fa8a..45a5572e3 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1905,8 +1905,6 @@ class EnumValue(Field): def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs): super().__init__(**kwargs) - self.enum = enum - self.choices = ", ".join([str(m.value) for m in enum]) try: self.field = resolve_field_instance(cls_or_instance) except FieldInstanceResolutionError as error: @@ -1914,6 +1912,10 @@ def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs): "The enum field must be a subclass or instance of " "marshmallow.base.FieldABC." ) from error + self.enum = enum + self.choices = ", ".join( + [str(self.field._serialize(m.value, None, None)) for m in enum] + ) def _serialize(self, value, attr, obj, **kwargs): if value is None: diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 87227dee9..b922d6a56 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1118,8 +1118,8 @@ def test_enumvalue_field_deserialization(self): assert field.deserialize("black hair") == HairColorEnum.black field = fields.EnumValue(fields.Integer, GenderEnum) assert field.deserialize(1) == GenderEnum.male - field = fields.EnumValue(fields.Date, DateEnum) - assert field.deserialize("2004-02-29") == DateEnum.date_1 + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + assert field.deserialize("29/02/2004") == DateEnum.date_1 def test_enumvalue_field_invalid_value(self): field = fields.EnumValue(fields.String, HairColorEnum) @@ -1131,11 +1131,11 @@ def test_enumvalue_field_invalid_value(self): field = fields.EnumValue(fields.Integer, GenderEnum) with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): field.deserialize(12) - field = fields.EnumValue(fields.Date, DateEnum) + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) with pytest.raises( - ValidationError, match="Must be one of: 2004-02-29, 2008-02-29, 2012-02-29." + ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012." ): - field.deserialize("2004-02-28") + field.deserialize("28/02/2004") def test_enumvalue_field_wrong_type(self): field = fields.EnumValue(fields.String, HairColorEnum) @@ -1144,9 +1144,9 @@ def test_enumvalue_field_wrong_type(self): field = fields.EnumValue(fields.Integer, GenderEnum) with pytest.raises(ValidationError, match="Not a valid integer."): field.deserialize("dummy") - field = fields.EnumValue(fields.Date, DateEnum) + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) with pytest.raises(ValidationError, match="Not a valid date."): - field.deserialize("2004-02-30") + field.deserialize("30/02/2004") def test_deserialization_function_must_be_callable(self): with pytest.raises(TypeError): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 47023b133..8e751e5ce 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -268,8 +268,8 @@ def test_enumvalue_field_serialization(self, user): field = fields.EnumValue(fields.Integer, GenderEnum) assert field.serialize("sex", user) == 1 user.some_date = DateEnum.date_1 - field = fields.EnumValue(fields.Date, DateEnum) - assert field.serialize("some_date", user) == "2004-02-29" + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + assert field.serialize("some_date", user) == "29/02/2004" def test_decimal_field(self, user): user.m1 = 12