Skip to content
This repository has been archived by the owner on Oct 19, 2022. It is now read-only.

Commit

Permalink
Merge pull request #40 from jeffsawatzky/master
Browse files Browse the repository at this point in the history
Create a StrictEnumField
  • Loading branch information
justanr authored Jan 28, 2021
2 parents 253c0d9 + 5b475c5 commit f036748
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ If either `load_by` or `dump_by` are unset, they will follow from `by_value`.
Additionally, there is `EnumField.NAME` to be explicit about the load and dump behavior, this
is the same as leaving both `by_value` and either `load_by` and/or `dump_by` unset.

If you want to ensure that the `load_by` and `dump_by` behaviour is always the same you can use
the `StrictEnumField`.

### Custom Error Message

A custom error message can be provided via the `error` keyword argument. It can accept three
Expand Down
15 changes: 15 additions & 0 deletions marshmallow_enum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,18 @@ def fail(self, key, **kwargs):
raise ValidationError(msg)
else:
raise super(EnumField, self).make_error(key, **kwargs)


class StrictEnumField(EnumField):
"""
Like EnumField but will always load and dump using the same behaviour
Ignores any `load_by` or `dump_by` parameters passed to it
"""

def __init__(
self, enum, by_value=False, error='', *args, **kwargs
):

kwargs.pop('load_by', None)
kwargs.pop('dump_by', None)
super(StrictEnumField, self).__init__(enum, by_value, *args, **kwargs)
72 changes: 71 additions & 1 deletion tests/test_enum_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import marshmallow
from marshmallow import Schema, ValidationError
from marshmallow.fields import List
from marshmallow_enum import EnumField
from marshmallow_enum import EnumField, StrictEnumField

PY2 = sys.version_info.major == 2
MARSHMALLOW_VERSION_MAJOR = int(marshmallow.__version__.split('.')[0])
Expand Down Expand Up @@ -359,3 +359,73 @@ class MyEnumField(EnumField):
EnumField(self.UnicodeEnumTester, error='{values}').fail('by_value')

assert exc_info.value.messages[0] == self.values


class TestStrictEnumFieldByName(object):

def setup(self):
self.field = StrictEnumField(EnumTester)

def test_serialize_enum(self):
assert self.field._serialize(EnumTester.one, None, object()) == 'one'

def test_serialize_none(self):
assert self.field._serialize(None, None, object()) is None

def test_deserialize_enum(self):
assert self.field._deserialize('one', None, {}) == EnumTester.one

def test_deserialize_none(self):
assert self.field._deserialize(None, None, {}) is None

def test_deserialize_nonexistent_member(self):
with pytest.raises(ValidationError):
self.field._deserialize('fred', None, {})


class TestStrictEnumFieldLoadAndDumpByValueIgnored(object):

def setup(self):
self.field = StrictEnumField(
EnumTester,
load_by=EnumField.VALUE,
dump_by=EnumField.VALUE
)

def test_serialize_enum(self):
assert self.field._serialize(EnumTester.one, None, object()) == 'one'

def test_serialize_none(self):
assert self.field._serialize(None, None, object()) is None

def test_deserialize_enum(self):
assert self.field._deserialize('one', None, {}) == EnumTester.one

def test_deserialize_none(self):
assert self.field._deserialize(None, None, {}) is None

def test_deserialize_nonexistent_member(self):
with pytest.raises(ValidationError):
self.field._deserialize('fred', None, {})


class TestStrictEnumFieldValue(object):

def test_deserialize_enum(self):
field = StrictEnumField(EnumTester, by_value=True)

assert field._deserialize(1, None, {}) == EnumTester.one

def test_serialize_enum(self):
field = EnumField(EnumTester, by_value=True)
assert field._serialize(EnumTester.one, None, object()) == 1

def test_serialize_none(self):
field = EnumField(EnumTester, by_value=True)
assert field._serialize(None, None, object()) is None

def test_deserialize_nonexistent_member(self):
field = EnumField(EnumTester, by_value=True)

with pytest.raises(ValidationError):
field._deserialize(4, None, {})

0 comments on commit f036748

Please sign in to comment.