From ed0b1c9b4a2402969314d622885e0ed370129ec2 Mon Sep 17 00:00:00 2001 From: theirix Date: Fri, 9 Oct 2020 00:13:50 +0300 Subject: [PATCH] Add serialization with timezone to AwareDateTime Allows AwareDateTime field to optionally serialize naive datetime with default timezone. --- src/marshmallow/fields.py | 10 ++++++++++ tests/test_serialization.py | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 8f0ac273e..2f8a229d4 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1284,6 +1284,8 @@ class AwareDateTime(DateTime): :param default_timezone: Used on deserialization. If `None`, naive datetimes are rejected. If not `None`, naive datetimes are set this timezone. + :param bool use_serialization: If `True`, naive datetimes are set to + `default_timezone` on serialization. If `False`, datetimes are not changed. :param kwargs: The same keyword arguments that :class:`Field` receives. .. versionadded:: 3.0.0rc9 @@ -1296,10 +1298,12 @@ def __init__( format: typing.Optional[str] = None, *, default_timezone: typing.Optional[dt.tzinfo] = None, + use_serialization: bool = False, **kwargs ): super().__init__(format=format, **kwargs) self.default_timezone = default_timezone + self.use_serialization = use_serialization def _deserialize(self, value, attr, data, **kwargs): ret = super()._deserialize(value, attr, data, **kwargs) @@ -1313,6 +1317,12 @@ def _deserialize(self, value, attr, data, **kwargs): ret = ret.replace(tzinfo=self.default_timezone) return ret + def _serialize(self, value, attr, obj, **kwargs): + if self.use_serialization: + if value is not None and not is_aware(value): + value = value.replace(tzinfo=self.default_timezone) + return super()._serialize(value, attr, obj, **kwargs) + class Time(Field): """ISO8601-formatted time string. diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 9b6480547..c4475d93f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -533,6 +533,26 @@ def test_datetime_field_format(self, user): field = fields.DateTime(format=format) assert field.serialize("created", user) == user.created.strftime(format) + @pytest.mark.parametrize( + ("timezone", "value", "expected"), + [ + (None, dt.datetime(2013, 11, 10, 1, 23, 45), "2013-11-10T01:23:45"), + ( + dt.timezone.utc, + dt.datetime(2013, 11, 10, 1, 23, 45), + "2013-11-10T01:23:45+00:00", + ), + ( + dt.timezone(offset=dt.timedelta(hours=3)), + dt.datetime(2013, 11, 10, 1, 23, 45), + "2013-11-10T01:23:45+03:00", + ), + ], + ) + def test_aware_datetime_use_serialization(self, timezone, value, expected): + field = fields.AwareDateTime(default_timezone=timezone, use_serialization=True) + assert field.serialize("d", {"d": value}) == expected + def test_string_field(self): field = fields.String() user = User(name=b"foo")