diff --git a/extras_mongoengine/fields.py b/extras_mongoengine/fields.py index 5871ddf..6906d47 100644 --- a/extras_mongoengine/fields.py +++ b/extras_mongoengine/fields.py @@ -1,6 +1,14 @@ +import operator + from datetime import timedelta from mongoengine.base import BaseField -from mongoengine.fields import IntField, StringField, EmailField +from mongoengine.fields import EmailField, IntField, ListField, StringField + +try: + from functools import reduce +except ImportError: + # reduce is a builtin in Python2 + pass class TimedeltaField(BaseField): @@ -105,6 +113,33 @@ class IntEnumField(EnumField, IntField): pass +class IntFlagField(ListField): + def __init__(self, enum, **kwargs): + super(IntFlagField, self).__init__(IntEnumField(enum), **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + + return self.field.enum(reduce( + operator.or_, instance._data.get(self.name, []), 0)) + + def __set__(self, instance, value): + # copy mongoengine + if value is None: + if self.null: + value = None + elif self.default is not None: + value = self.default + if callable(value): + value = value() + + if value is not None and not isinstance(value, list): + value = [i for i in self.field.enum if i and i & value == i] + + super(IntFlagField, self).__set__(instance, value) + + class StringEnumField(EnumField, StringField): """A variation on :class:`EnumField` for only string containing enumeration. """