diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c704e5f6..2ceb70d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: hooks: - id: rstcheck additional_dependencies: [sphinx] + args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"] # We use the Python version instead of the original version which seems to require Docker # https://github.com/koalaman/shellcheck-precommit diff --git a/django_mongodb_backend/expressions.py b/django_mongodb_backend/expressions.py index 2d14f3d1..8e8c1815 100644 --- a/django_mongodb_backend/expressions.py +++ b/django_mongodb_backend/expressions.py @@ -95,7 +95,7 @@ def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, lookup_name=None): +def query(self, compiler, connection, get_wrapping_pipeline=None): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -119,48 +119,20 @@ def query(self, compiler, connection, lookup_name=None): for col, i in subquery_compiler.column_indices.items() }, } - # The result must be a list of values. The output is compressed with an - # aggregation pipeline. - if lookup_name in ("in", "range"): - if subquery.aggregation_pipeline is None: - subquery.aggregation_pipeline = [] - wrapping_result_pipeline = [ - { - "$facet": { - "group": [ - { - "$group": { - "_id": None, - "tmp_name": { - "$addToSet": expr.as_mql(subquery_compiler, connection) - }, - } - } - ] - } - }, - { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } - } - }, - ] + if get_wrapping_pipeline: + # The results from some lookups must be converted to a list of values. + # The output is compressed with an aggregation pipeline. + wrapping_result_pipeline = get_wrapping_pipeline( + subquery_compiler, connection, field_name, expr + ) # If the subquery is a combinator, wrap the result at the end of the # combinator pipeline... if subquery.query.combinator: subquery.combinator_pipeline.extend(wrapping_result_pipeline) # ... otherwise put at the end of subquery's pipeline. else: + if subquery.aggregation_pipeline is None: + subquery.aggregation_pipeline = [] subquery.aggregation_pipeline.extend(wrapping_result_pipeline) # Erase project_fields since the required value is projected above. subquery.project_fields = None @@ -185,13 +157,13 @@ def star(self, compiler, connection): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, lookup_name=None): - return self.query.as_mql(compiler, connection, lookup_name=lookup_name) +def subquery(self, compiler, connection, get_wrapping_pipeline=None): + return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) -def exists(self, compiler, connection, lookup_name=None): +def exists(self, compiler, connection, get_wrapping_pipeline=None): try: - lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name) + lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) except EmptyResultSet: return Value(False).as_mql(compiler, connection) return connection.mongo_operators["isnull"](lhs_mql, False) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 1fba9523..59f867e0 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -80,6 +80,13 @@ class DatabaseFeatures(BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", + # icontains doesn't work on ArrayField: + # Unsupported conversion from array to string in $convert + "model_fields_.test_arrayfield.QueryingTests.test_icontains", + # ArrayField's contained_by lookup crashes with Exists: "both operands " + # of $setIsSubset must be arrays. Second argument is of type: null" + # https://jira.mongodb.org/browse/SERVER-99186 + "model_fields_.test_arrayfield.QueryingTests.test_contained_by_subquery", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { diff --git a/django_mongodb_backend/fields/__init__.py b/django_mongodb_backend/fields/__init__.py index 9eb2518d..cab7071c 100644 --- a/django_mongodb_backend/fields/__init__.py +++ b/django_mongodb_backend/fields/__init__.py @@ -1,9 +1,10 @@ +from .array import ArrayField from .auto import ObjectIdAutoField from .duration import register_duration_field from .json import register_json_field from .objectid import ObjectIdField -__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"] +__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"] def register_fields(): diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py new file mode 100644 index 00000000..49c9e6ad --- /dev/null +++ b/django_mongodb_backend/fields/array.py @@ -0,0 +1,391 @@ +import json + +from django.contrib.postgres.validators import ArrayMaxLengthValidator +from django.core import checks, exceptions +from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value +from django.db.models.fields.mixins import CheckFieldDefaultMixin +from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup +from django.utils.translation import gettext_lazy as _ + +from ..forms import SimpleArrayField +from ..query_utils import process_lhs, process_rhs +from ..utils import prefix_validation_error + +__all__ = ["ArrayField"] + + +class AttributeSetter: + def __init__(self, name, value): + setattr(self, name, value) + + +class ArrayField(CheckFieldDefaultMixin, Field): + empty_strings_allowed = False + default_error_messages = { + "item_invalid": _("Item %(nth)s in the array did not validate:"), + "nested_array_mismatch": _("Nested arrays must have the same length."), + } + _default_hint = ("list", "[]") + + def __init__(self, base_field, size=None, **kwargs): + self.base_field = base_field + self.size = size + if self.size: + self.default_validators = [ + *self.default_validators, + ArrayMaxLengthValidator(self.size), + ] + # For performance, only add a from_db_value() method if the base field + # implements it. + if hasattr(self.base_field, "from_db_value"): + self.from_db_value = self._from_db_value + super().__init__(**kwargs) + + @property + def model(self): + try: + return self.__dict__["model"] + except KeyError: + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) from None + + @model.setter + def model(self, model): + self.__dict__["model"] = model + self.base_field.model = model + + @classmethod + def _choices_is_value(cls, value): + return isinstance(value, list | tuple) or super()._choices_is_value(value) + + def check(self, **kwargs): + errors = super().check(**kwargs) + if self.base_field.remote_field: + errors.append( + checks.Error( + "Base field for array cannot be a related field.", + obj=self, + id="django_mongodb_backend.array.E002", + ) + ) + else: + base_checks = self.base_field.check() + if base_checks: + error_messages = "\n ".join( + f"{base_check.msg} ({base_check.id})" + for base_check in base_checks + if isinstance(base_check, checks.Error) + ) + if error_messages: + errors.append( + checks.Error( + f"Base field for array has errors:\n {error_messages}", + obj=self, + id="django_mongodb_backend.array.E001", + ) + ) + warning_messages = "\n ".join( + f"{base_check.msg} ({base_check.id})" + for base_check in base_checks + if isinstance(base_check, checks.Warning) + ) + if warning_messages: + errors.append( + checks.Warning( + f"Base field for array has warnings:\n {warning_messages}", + obj=self, + id="django_mongodb_backend.array.W004", + ) + ) + return errors + + def set_attributes_from_name(self, name): + super().set_attributes_from_name(name) + self.base_field.set_attributes_from_name(name) + + @property + def description(self): + return f"Array of {self.base_field.description}" + + def db_type(self, connection): + return "array" + + def get_db_prep_value(self, value, connection, prepared=False): + if isinstance(value, list | tuple): + # Workaround for https://code.djangoproject.com/ticket/35982 + # (fixed in Django 5.2). + if isinstance(self.base_field, DecimalField): + return [self.base_field.get_db_prep_save(i, connection) for i in value] + return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] + return value + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path == "django_mongodb_backend.fields.array.ArrayField": + path = "django_mongodb_backend.fields.ArrayField" + kwargs.update( + { + "base_field": self.base_field.clone(), + "size": self.size, + } + ) + return name, path, args, kwargs + + def to_python(self, value): + if isinstance(value, str): + # Assume value is being deserialized. + vals = json.loads(value) + value = [self.base_field.to_python(val) for val in vals] + return value + + def _from_db_value(self, value, expression, connection): + if value is None: + return value + return [self.base_field.from_db_value(item, expression, connection) for item in value] + + def value_to_string(self, obj): + values = [] + vals = self.value_from_object(obj) + base_field = self.base_field + + for val in vals: + if val is None: + values.append(None) + else: + obj = AttributeSetter(base_field.attname, val) + values.append(base_field.value_to_string(obj)) + return json.dumps(values) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + if "_" not in name: + try: + index = int(name) + except ValueError: + pass + else: + return IndexTransformFactory(index, self.base_field) + try: + start, end = name.split("_") + start = int(start) + end = int(end) + except ValueError: + pass + else: + return SliceTransformFactory(start, end) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + for index, part in enumerate(value): + try: + self.base_field.validate(part, model_instance) + except exceptions.ValidationError as error: + raise prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) from None + if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1: + raise exceptions.ValidationError( + self.error_messages["nested_array_mismatch"], + code="nested_array_mismatch", + ) + + def run_validators(self, value): + super().run_validators(value) + for index, part in enumerate(value): + try: + self.base_field.run_validators(part) + except exceptions.ValidationError as error: + raise prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) from None + + def formfield(self, **kwargs): + return super().formfield( + **{ + "form_class": SimpleArrayField, + "base_field": self.base_field.formfield(), + "max_length": self.size, + **kwargs, + } + ) + + +class Array(Func): + def as_mql(self, compiler, connection): + return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] + + +class ArrayRHSMixin: + def __init__(self, lhs, rhs): + if isinstance(rhs, tuple | list): + expressions = [] + for value in rhs: + if not hasattr(value, "resolve_expression"): + field = lhs.output_field + value = Value(field.base_field.get_prep_value(value)) + expressions.append(value) + rhs = Array(*expressions) + super().__init__(lhs, rhs) + + +@ArrayField.register_lookup +class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): + lookup_name = "contains" + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$ne": [value, None]}, + {"$setIsSubset": [value, lhs_mql]}, + ] + } + + +@ArrayField.register_lookup +class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): + lookup_name = "contained_by" + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$ne": [value, None]}, + {"$setIsSubset": [lhs_mql, value]}, + ] + } + + +@ArrayField.register_lookup +class ArrayExact(ArrayRHSMixin, Exact): + pass + + +@ArrayField.register_lookup +class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): + lookup_name = "overlap" + + def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): + return [ + { + "$facet": { + "group": [ + {"$project": {"tmp_name": expr.as_mql(compiler, connection)}}, + { + "$unwind": "$tmp_name", + }, + { + "$group": { + "_id": None, + "tmp_name": {"$addToSet": "$tmp_name"}, + } + }, + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": "tmp_name", + } + }, + [], + ] + } + } + }, + ] + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return { + "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] + } + + +@ArrayField.register_lookup +class ArrayLenTransform(Transform): + lookup_name = "len" + output_field = IntegerField() + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} + + +@ArrayField.register_lookup +class ArrayInLookup(In): + def get_prep_lookup(self): + values = super().get_prep_lookup() + if hasattr(values, "resolve_expression"): + return values + # process_rhs() expects hashable values, so convert lists to tuples. + prepared_values = [] + for value in values: + if hasattr(value, "resolve_expression"): + prepared_values.append(value) + else: + prepared_values.append(tuple(value)) + return prepared_values + + +class IndexTransform(Transform): + def __init__(self, index, base_field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.index = index + self.base_field = base_field + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return {"$arrayElemAt": [lhs_mql, self.index]} + + @property + def output_field(self): + return self.base_field + + +class IndexTransformFactory: + def __init__(self, index, base_field): + self.index = index + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return IndexTransform(self.index, self.base_field, *args, **kwargs) + + +class SliceTransform(Transform): + def __init__(self, start, end, *args, **kwargs): + super().__init__(*args, **kwargs) + self.start = start + self.end = end + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return {"$slice": [lhs_mql, self.start, self.end]} + + +class SliceTransformFactory: + def __init__(self, start, end): + self.start = start + self.end = end + + def __call__(self, *args, **kwargs): + return SliceTransform(self.start, self.end, *args, **kwargs) diff --git a/django_mongodb_backend/forms/__init__.py b/django_mongodb_backend/forms/__init__.py index 9009a3ee..96c8775e 100644 --- a/django_mongodb_backend/forms/__init__.py +++ b/django_mongodb_backend/forms/__init__.py @@ -1,3 +1,8 @@ -from .fields import ObjectIdField +from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget -__all__ = ["ObjectIdField"] +__all__ = [ + "SimpleArrayField", + "SplitArrayField", + "SplitArrayWidget", + "ObjectIdField", +] diff --git a/django_mongodb_backend/forms/fields/__init__.py b/django_mongodb_backend/forms/fields/__init__.py new file mode 100644 index 00000000..298c6b68 --- /dev/null +++ b/django_mongodb_backend/forms/fields/__init__.py @@ -0,0 +1,9 @@ +from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget +from .objectid import ObjectIdField + +__all__ = [ + "SimpleArrayField", + "SplitArrayField", + "SplitArrayWidget", + "ObjectIdField", +] diff --git a/django_mongodb_backend/forms/fields/array.py b/django_mongodb_backend/forms/fields/array.py new file mode 100644 index 00000000..0de48dff --- /dev/null +++ b/django_mongodb_backend/forms/fields/array.py @@ -0,0 +1,242 @@ +import copy +from itertools import chain + +from django import forms +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ + +from ...utils import prefix_validation_error +from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator + + +class SimpleArrayField(forms.CharField): + default_error_messages = { + "item_invalid": _("Item %(nth)s in the array did not validate:"), + } + + def __init__(self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs): + self.base_field = base_field + self.delimiter = delimiter + super().__init__(**kwargs) + if min_length is not None: + self.min_length = min_length + self.validators.append(ArrayMinLengthValidator(int(min_length))) + if max_length is not None: + self.max_length = max_length + self.validators.append(ArrayMaxLengthValidator(int(max_length))) + + def clean(self, value): + value = super().clean(value) + return [self.base_field.clean(val) for val in value] + + def prepare_value(self, value): + if isinstance(value, list): + return self.delimiter.join(str(self.base_field.prepare_value(v)) for v in value) + return value + + def to_python(self, value): + if isinstance(value, list): + items = value + elif value: + items = value.split(self.delimiter) + else: + items = [] + errors = [] + values = [] + for index, item in enumerate(items): + try: + values.append(self.base_field.to_python(item)) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + return values + + def validate(self, value): + super().validate(value) + errors = [] + for index, item in enumerate(value): + try: + self.base_field.validate(item) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + + def run_validators(self, value): + super().run_validators(value) + errors = [] + for index, item in enumerate(value): + try: + self.base_field.run_validators(item) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + + def has_changed(self, initial, data): + try: + value = self.to_python(data) + except ValidationError: + pass + else: + if initial in self.empty_values and value in self.empty_values: + return False + return super().has_changed(initial, data) + + +class SplitArrayWidget(forms.Widget): + template_name = "mongodb/widgets/split_array.html" + + def __init__(self, widget, size, **kwargs): + self.widget = widget() if isinstance(widget, type) else widget + self.size = size + super().__init__(**kwargs) + + @property + def is_hidden(self): + return self.widget.is_hidden + + def value_from_datadict(self, data, files, name): + return [ + self.widget.value_from_datadict(data, files, f"{name}_{index}") + for index in range(self.size) + ] + + def value_omitted_from_data(self, data, files, name): + return all( + self.widget.value_omitted_from_data(data, files, f"{name}_{index}") + for index in range(self.size) + ) + + def id_for_label(self, id_): + # See the comment for RadioSelect.id_for_label() + if id_: + id_ += "_0" + return id_ + + def get_context(self, name, value, attrs=None): + attrs = {} if attrs is None else attrs + context = super().get_context(name, value, attrs) + if self.is_localized: + self.widget.is_localized = self.is_localized + value = value or [] + context["widget"]["subwidgets"] = [] + final_attrs = self.build_attrs(attrs) + id_ = final_attrs.get("id") + for i in range(max(len(value), self.size)): + try: + widget_value = value[i] + except IndexError: + widget_value = None + if id_: + final_attrs = {**final_attrs, "id": f"{id_}_{i}"} + context["widget"]["subwidgets"].append( + self.widget.get_context(name + "_%s" % i, widget_value, final_attrs)["widget"] + ) + return context + + @property + def media(self): + return self.widget.media + + def __deepcopy__(self, memo): + obj = super().__deepcopy__(memo) + obj.widget = copy.deepcopy(self.widget) + return obj + + @property + def needs_multipart_form(self): + return self.widget.needs_multipart_form + + +class SplitArrayField(forms.Field): + default_error_messages = { + "item_invalid": _("Item %(nth)s in the array did not validate:"), + } + + def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs): + self.base_field = base_field + self.size = size + self.remove_trailing_nulls = remove_trailing_nulls + widget = SplitArrayWidget(widget=base_field.widget, size=size) + kwargs.setdefault("widget", widget) + super().__init__(**kwargs) + + def _remove_trailing_nulls(self, values): + index = None + if self.remove_trailing_nulls: + for i, value in reversed(list(enumerate(values))): + if value in self.base_field.empty_values: + index = i + else: + break + if index is not None: + values = values[:index] + return values, index + + def to_python(self, value): + value = super().to_python(value) + return [self.base_field.to_python(item) for item in value] + + def clean(self, value): + cleaned_data = [] + errors = [] + if not any(value) and self.required: + raise ValidationError(self.error_messages["required"]) + max_size = max(self.size, len(value)) + for index in range(max_size): + item = value[index] + try: + cleaned_data.append(self.base_field.clean(item)) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + cleaned_data.append(None) + else: + errors.append(None) + cleaned_data, null_index = self._remove_trailing_nulls(cleaned_data) + if null_index is not None: + errors = errors[:null_index] + errors = list(filter(None, errors)) + if errors: + raise ValidationError(list(chain.from_iterable(errors))) + return cleaned_data + + def has_changed(self, initial, data): + try: + data = self.to_python(data) + except ValidationError: + pass + else: + data, _ = self._remove_trailing_nulls(data) + if initial in self.empty_values and data in self.empty_values: + return False + return super().has_changed(initial, data) diff --git a/django_mongodb_backend/forms/fields.py b/django_mongodb_backend/forms/fields/objectid.py similarity index 100% rename from django_mongodb_backend/forms/fields.py rename to django_mongodb_backend/forms/fields/objectid.py diff --git a/django_mongodb_backend/jinja2/mongodb/widgets/split_array.html b/django_mongodb_backend/jinja2/mongodb/widgets/split_array.html new file mode 100644 index 00000000..32fda826 --- /dev/null +++ b/django_mongodb_backend/jinja2/mongodb/widgets/split_array.html @@ -0,0 +1 @@ +{% include 'django/forms/widgets/multiwidget.html' %} diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index c651dd6a..519a03c9 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -45,6 +45,38 @@ def in_(self, compiler, connection): return builtin_lookup(self, compiler, connection) +def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 + return [ + { + "$facet": { + "group": [ + { + "$group": { + "_id": None, + "tmp_name": {"$addToSet": expr.as_mql(compiler, connection)}, + } + } + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": "tmp_name", + } + }, + [], + ] + } + } + }, + ] + + def is_null(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") @@ -97,6 +129,7 @@ def register_lookups(): field_resolve_expression_parameter ) In.as_mql = RelatedIn.as_mql = in_ + In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline IsNull.as_mql = is_null PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/operations.py b/django_mongodb_backend/operations.py index 71b38350..cb1e93db 100644 --- a/django_mongodb_backend/operations.py +++ b/django_mongodb_backend/operations.py @@ -9,7 +9,7 @@ from django.db import DataError from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import TextField -from django.db.models.expressions import Combinable +from django.db.models.expressions import Combinable, Expression from django.db.models.functions import Cast from django.utils import timezone from django.utils.regex_helper import _lazy_re_compile @@ -77,10 +77,26 @@ def adapt_timefield_value(self, value): raise ValueError("MongoDB backend does not support timezone-aware times.") return datetime.datetime.combine(datetime.datetime.min.date(), value) + def _get_arrayfield_converter(self, converter, *args, **kwargs): + # Return a database converter that can be applied to a list of values. + def convert_value(value, expression, connection): + return [converter(x, expression, connection) for x in value] + + return convert_value + def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == "DateField": + if internal_type == "ArrayField": + converters.extend( + [ + self._get_arrayfield_converter(converter) + for converter in self.get_db_converters( + Expression(output_field=expression.output_field.base_field) + ) + ] + ) + elif internal_type == "DateField": converters.append(self.convert_datefield_value) elif internal_type == "DateTimeField": if settings.USE_TZ: diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index ff98a1ed..dd7042c7 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -28,8 +28,10 @@ def process_lhs(node, compiler, connection): def process_rhs(node, compiler, connection): rhs = node.rhs if hasattr(rhs, "as_mql"): - if getattr(rhs, "subquery", False): - value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name) + if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"): + value = rhs.as_mql( + compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline + ) else: value = rhs.as_mql(compiler, connection) else: diff --git a/django_mongodb_backend/templates/mongodb/widgets/split_array.html b/django_mongodb_backend/templates/mongodb/widgets/split_array.html new file mode 100644 index 00000000..32fda826 --- /dev/null +++ b/django_mongodb_backend/templates/mongodb/widgets/split_array.html @@ -0,0 +1 @@ +{% include 'django/forms/widgets/multiwidget.html' %} diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py index 5b2051a8..c389d93b 100644 --- a/django_mongodb_backend/utils.py +++ b/django_mongodb_backend/utils.py @@ -3,8 +3,10 @@ import django from django.conf import settings -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.db.backends.utils import logger +from django.utils.functional import SimpleLazyObject +from django.utils.text import format_lazy from django.utils.version import get_version_tuple from pymongo.uri_parser import parse_uri as pymongo_parse_uri @@ -58,6 +60,32 @@ def parse_uri(uri, conn_max_age=0, test=None): return settings_dict +def prefix_validation_error(error, prefix, code, params): + """ + Prefix a validation error message while maintaining the existing + validation data structure. + """ + if error.error_list == [error]: + error_params = error.params or {} + return ValidationError( + # Messages can't simply be concatenated since they might require + # their associated parameters to be expressed correctly which is + # not something format_lazy() does. For example, proxied + # ngettext calls require a count parameter and are converted + # to an empty string if they are missing it. + message=format_lazy( + "{} {}", + SimpleLazyObject(lambda: prefix % params), + SimpleLazyObject(lambda: error.message % error_params), + ), + code=code, + params={**error_params, **params}, + ) + return ValidationError( + [prefix_validation_error(e, prefix, code, params) for e in error.error_list] + ) + + def set_wrapped_methods(cls): """Initialize the wrapped methods on cls.""" if hasattr(cls, "logging_wrapper"): diff --git a/django_mongodb_backend/validators.py b/django_mongodb_backend/validators.py new file mode 100644 index 00000000..6005152e --- /dev/null +++ b/django_mongodb_backend/validators.py @@ -0,0 +1,18 @@ +from django.core.validators import MaxLengthValidator, MinLengthValidator +from django.utils.translation import ngettext_lazy + + +class ArrayMaxLengthValidator(MaxLengthValidator): + message = ngettext_lazy( + "List contains %(show_value)d item, it should contain no more than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no more than %(limit_value)d.", + "show_value", + ) + + +class ArrayMinLengthValidator(MinLengthValidator): + message = ngettext_lazy( + "List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.", + "show_value", + ) diff --git a/docs/source/_ext/djangodocs.py b/docs/source/_ext/djangodocs.py new file mode 100644 index 00000000..fda464d8 --- /dev/null +++ b/docs/source/_ext/djangodocs.py @@ -0,0 +1,11 @@ +def setup(app): + app.add_crossref_type( + directivename="fieldlookup", + rolename="lookup", + indextemplate="pair: %s; field lookup type", + ) + app.add_crossref_type( + directivename="setting", + rolename="setting", + indextemplate="pair: %s; setting", + ) diff --git a/docs/source/conf.py b/docs/source/conf.py index f81b9335..55bb84e1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,7 +7,14 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information from __future__ import annotations +import sys from importlib.metadata import version as _version +from pathlib import Path + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.append(str((Path(__file__).parent / "_ext").resolve())) project = "django_mongodb_backend" copyright = "2024, The MongoDB Python Team" @@ -22,6 +29,7 @@ add_module_names = False extensions = [ + "djangodocs", "sphinx.ext.intersphinx", ] diff --git a/docs/source/fields.rst b/docs/source/fields.rst index 58f30a62..39f965a7 100644 --- a/docs/source/fields.rst +++ b/docs/source/fields.rst @@ -5,6 +5,211 @@ Model field reference Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. +``ArrayField`` +-------------- + +.. class:: ArrayField(base_field, size=None, **options) + + A field for storing lists of data. Most field types can be used, and you + pass another field instance as the :attr:`base_field + `. You may also specify a :attr:`size + `. ``ArrayField`` can be nested to store multi-dimensional + arrays. + + If you give the field a :attr:`~django.db.models.Field.default`, ensure + it's a callable such as ``list`` (for an empty default) or a callable that + returns a list (such as a function). Incorrectly using ``default=[]`` + creates a mutable default that is shared between all instances of + ``ArrayField``. + + .. attribute:: base_field + + This is a required argument. + + Specifies the underlying data type and behavior for the array. It + should be an instance of a subclass of + :class:`~django.db.models.Field`. For example, it could be an + :class:`~django.db.models.IntegerField` or a + :class:`~django.db.models.CharField`. Most field types are permitted, + with the exception of those handling relational data + (:class:`~django.db.models.ForeignKey`, + :class:`~django.db.models.OneToOneField` and + :class:`~django.db.models.ManyToManyField`) and file fields ( + :class:`~django.db.models.FileField` and + :class:`~django.db.models.ImageField`). + + It is possible to nest array fields - you can specify an instance of + ``ArrayField`` as the ``base_field``. For example:: + + from django.db import models + from django_mongodb_backend.fields import ArrayField + + + class ChessBoard(models.Model): + board = ArrayField( + ArrayField( + models.CharField(max_length=10, blank=True), + size=8, + ), + size=8, + ) + + Transformation of values between the database and the model, validation + of data and configuration, and serialization are all delegated to the + underlying base field. + + .. attribute:: size + + This is an optional argument. + + If passed, the array will have a maximum size as specified, validated + only by forms. + +Querying ``ArrayField`` +~~~~~~~~~~~~~~~~~~~~~~~ + +There are a number of custom lookups and transforms for :class:`ArrayField`. +We will use the following example model:: + + from django.db import models + from django_mongodb_backend.fields import ArrayField + + + class Post(models.Model): + name = models.CharField(max_length=200) + tags = ArrayField(models.CharField(max_length=200), blank=True) + + def __str__(self): + return self.name + +.. fieldlookup:: arrayfield.contains + +``contains`` +^^^^^^^^^^^^ + +The :lookup:`contains` lookup is overridden on :class:`ArrayField`. The +returned objects will be those where the values passed are a subset of the +data. It uses the ``$setIntersection`` operator. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts"]) + >>> Post.objects.create(name="Third post", tags=["tutorial", "django"]) + + >>> Post.objects.filter(tags__contains=["thoughts"]) + , ]> + + >>> Post.objects.filter(tags__contains=["django"]) + , ]> + + >>> Post.objects.filter(tags__contains=["django", "thoughts"]) + ]> + +``contained_by`` +~~~~~~~~~~~~~~~~ + +This is the inverse of the :lookup:`contains ` lookup - +the objects returned will be those where the data is a subset of the values +passed. It uses the ``$setIntersection`` operator. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts"]) + >>> Post.objects.create(name="Third post", tags=["tutorial", "django"]) + + >>> Post.objects.filter(tags__contained_by=["thoughts", "django"]) + , ]> + + >>> Post.objects.filter(tags__contained_by=["thoughts", "django", "tutorial"]) + , , ]> + +.. fieldlookup:: arrayfield.overlap + +``overlap`` +~~~~~~~~~~~ + +Returns objects where the data shares any results with the values passed. It +uses the ``$setIntersection`` operator. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts", "tutorial"]) + >>> Post.objects.create(name="Third post", tags=["tutorial", "django"]) + + >>> Post.objects.filter(tags__overlap=["thoughts"]) + , ]> + + >>> Post.objects.filter(tags__overlap=["thoughts", "tutorial"]) + , , ]> + +.. fieldlookup:: arrayfield.len + +``len`` +^^^^^^^ + +Returns the length of the array. The lookups available afterward are those +available for :class:`~django.db.models.IntegerField`. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts"]) + + >>> Post.objects.filter(tags__len=1) + ]> + +.. fieldlookup:: arrayfield.index + +Index transforms +^^^^^^^^^^^^^^^^ + +Index transforms index into the array. Any non-negative integer can be used. +There are no errors if it exceeds the :attr:`size ` of the +array. The lookups available after the transform are those from the +:attr:`base_field `. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts"]) + + >>> Post.objects.filter(tags__0="thoughts") + , ]> + + >>> Post.objects.filter(tags__1__iexact="Django") + ]> + + >>> Post.objects.filter(tags__276="javascript") + + +These indexes use 0-based indexing. + +.. fieldlookup:: arrayfield.slice + +Slice transforms +^^^^^^^^^^^^^^^^ + +Slice transforms take a slice of the array. Any two non-negative integers can +be used, separated by a single underscore. The lookups available after the +transform do not change. For example: + +.. code-block:: pycon + + >>> Post.objects.create(name="First post", tags=["thoughts", "django"]) + >>> Post.objects.create(name="Second post", tags=["thoughts"]) + >>> Post.objects.create(name="Third post", tags=["django", "python", "thoughts"]) + + >>> Post.objects.filter(tags__0_1=["thoughts"]) + , ]> + + >>> Post.objects.filter(tags__0_2__contains=["thoughts"]) + , ]> + +These indexes use 0-based indexing. + ``ObjectIdField`` ----------------- diff --git a/docs/source/forms.rst b/docs/source/forms.rst index f3fa342b..eb020073 100644 --- a/docs/source/forms.rst +++ b/docs/source/forms.rst @@ -11,3 +11,143 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.forms``. .. class:: ObjectIdField Stores an :class:`~bson.objectid.ObjectId`. + +``SimpleArrayField`` +-------------------- + +.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None) + + A field which maps to an array. It is represented by an HTML ````. + + .. attribute:: base_field + + This is a required argument. + + It specifies the underlying form field for the array. This is not used + to render any HTML, but it is used to process the submitted data and + validate it. For example: + + .. code-block:: pycon + + >>> from django import forms + >>> from django_mongodb_backend.forms import SimpleArrayField + + >>> class NumberListForm(forms.Form): + ... numbers = SimpleArrayField(forms.IntegerField()) + ... + + >>> form = NumberListForm({"numbers": "1,2,3"}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'numbers': [1, 2, 3]} + + >>> form = NumberListForm({"numbers": "1,2,a"}) + >>> form.is_valid() + False + + .. attribute:: delimiter + + This is an optional argument which defaults to a comma: ``,``. This + value is used to split the submitted data. It allows you to chain + ``SimpleArrayField`` for multidimensional data: + + .. code-block:: pycon + + >>> from django import forms + >>> from django_mongodb_backend.forms import SimpleArrayField + + >>> class GridForm(forms.Form): + ... places = SimpleArrayField(SimpleArrayField(IntegerField()), delimiter="|") + ... + + >>> form = GridForm({"places": "1,2|2,1|4,3"}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'places': [[1, 2], [2, 1], [4, 3]]} + + .. note:: + + The field does not support escaping of the delimiter, so be careful + in cases where the delimiter is a valid character in the underlying + field. The delimiter does not need to be only one character. + + .. attribute:: max_length + + This is an optional argument which validates that the array does not + exceed the stated length. + + .. attribute:: min_length + + This is an optional argument which validates that the array reaches at + least the stated length. + + .. admonition:: User friendly forms + + ``SimpleArrayField`` is not particularly user friendly in most cases, + however it is a useful way to format data from a client-side widget for + submission to the server. + +``SplitArrayField`` +------------------- + +.. class:: SplitArrayField(base_field, size, remove_trailing_nulls=False) + + This field handles arrays by reproducing the underlying field a fixed + number of times. + + The template for this widget is located in + ``django_mongodb_backend/templates/mongodb/widgets``. Don't forget to + configure template loading appropriately, for example, by using a + :class:`~django.template.backends.django.DjangoTemplates` engine with + :setting:`APP_DIRS=True ` and + ``"django_mongodb_backend"`` in :setting:`INSTALLED_APPS`. + + .. attribute:: base_field + + This is a required argument. It specifies the form field to be + repeated. + + .. attribute:: size + + This is the fixed number of times the underlying field will be used. + + .. attribute:: remove_trailing_nulls + + By default, this is set to ``False``. When ``False``, each value from + the repeated fields is stored. When set to ``True``, any trailing + values which are blank will be stripped from the result. If the + underlying field has ``required=True``, but ``remove_trailing_nulls`` + is ``True``, then null values are only allowed at the end, and will be + stripped. + + Some examples:: + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=False) + + ["1", "2", "3"] # -> [1, 2, 3] + ["1", "2", ""] # -> ValidationError - third entry required. + ["1", "", "3"] # -> ValidationError - second entry required. + ["", "2", ""] # -> ValidationError - first and third entries required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=False) + + ["1", "2", "3"] # -> [1, 2, 3] + ["1", "2", ""] # -> [1, 2, None] + ["1", "", "3"] # -> [1, None, 3] + ["", "2", ""] # -> [None, 2, None] + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=True) + + ["1", "2", "3"] # -> [1, 2, 3] + ["1", "2", ""] # -> [1, 2] + ["1", "", "3"] # -> ValidationError - second entry required. + ["", "2", ""] # -> ValidationError - first entry required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=True) + + ["1", "2", "3"] # -> [1, 2, 3] + ["1", "2", ""] # -> [1, 2] + ["1", "", "3"] # -> [1, None, 3] + ["", "2", ""] # -> [None, 2] diff --git a/tests/forms_tests_/test_array.py b/tests/forms_tests_/test_array.py new file mode 100644 index 00000000..58ab5566 --- /dev/null +++ b/tests/forms_tests_/test_array.py @@ -0,0 +1,405 @@ +from django import forms +from django.core import exceptions +from django.db import models +from django.test import SimpleTestCase +from django.test.utils import modify_settings +from forms_tests.widget_tests.base import WidgetTest + +from django_mongodb_backend.fields import ArrayField +from django_mongodb_backend.forms import SimpleArrayField, SplitArrayField, SplitArrayWidget + + +class IntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField(), default=list, blank=True) + + +class SimpleArrayFieldTests(SimpleTestCase): + def test_valid(self): + field = SimpleArrayField(forms.CharField()) + value = field.clean("a,b,c") + self.assertEqual(value, ["a", "b", "c"]) + + def test_to_python_fail(self): + field = SimpleArrayField(forms.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("a,b,9") + self.assertEqual( + cm.exception.messages[0], + "Item 1 in the array did not validate: Enter a whole number.", + ) + + def test_validate_fail(self): + field = SimpleArrayField(forms.CharField(required=True)) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("a,b,") + self.assertEqual( + cm.exception.messages[0], + "Item 3 in the array did not validate: This field is required.", + ) + + def test_validate_fail_base_field_error_params(self): + field = SimpleArrayField(forms.CharField(max_length=2)) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("abc,c,defg") + errors = cm.exception.error_list + self.assertEqual(len(errors), 2) + first_error = errors[0] + self.assertEqual( + first_error.message, + "Item 1 in the array did not validate: Ensure this value has at most 2 " + "characters (it has 3).", + ) + self.assertEqual(first_error.code, "item_invalid") + self.assertEqual( + first_error.params, + {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3}, + ) + second_error = errors[1] + self.assertEqual( + second_error.message, + "Item 3 in the array did not validate: Ensure this value has at most 2 " + "characters (it has 4).", + ) + self.assertEqual(second_error.code, "item_invalid") + self.assertEqual( + second_error.params, + {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4}, + ) + + def test_validators_fail(self): + field = SimpleArrayField(forms.RegexField("[a-e]{2}")) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("a,bc,de") + self.assertEqual( + cm.exception.messages[0], + "Item 1 in the array did not validate: Enter a valid value.", + ) + + def test_delimiter(self): + field = SimpleArrayField(forms.CharField(), delimiter="|") + value = field.clean("a|b|c") + self.assertEqual(value, ["a", "b", "c"]) + + def test_delimiter_with_nesting(self): + field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|") + value = field.clean("a,b|c,d") + self.assertEqual(value, [["a", "b"], ["c", "d"]]) + + def test_prepare_value(self): + field = SimpleArrayField(forms.CharField()) + value = field.prepare_value(["a", "b", "c"]) + self.assertEqual(value, "a,b,c") + + def test_max_length(self): + field = SimpleArrayField(forms.CharField(), max_length=2) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("a,b,c") + self.assertEqual( + cm.exception.messages[0], + "List contains 3 items, it should contain no more than 2.", + ) + + def test_min_length(self): + field = SimpleArrayField(forms.CharField(), min_length=4) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("a,b,c") + self.assertEqual( + cm.exception.messages[0], + "List contains 3 items, it should contain no fewer than 4.", + ) + + def test_min_length_singular(self): + field = SimpleArrayField(forms.IntegerField(), min_length=2) + field.clean([1, 2]) + msg = "List contains 1 item, it should contain no fewer than 2." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1]) + + def test_required(self): + field = SimpleArrayField(forms.CharField(), required=True) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean("") + self.assertEqual(cm.exception.messages[0], "This field is required.") + + def test_model_field_formfield(self): + model_field = ArrayField(models.CharField(max_length=27)) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertIsInstance(form_field.base_field, forms.CharField) + self.assertEqual(form_field.base_field.max_length, 27) + + def test_model_field_formfield_size(self): + model_field = ArrayField(models.CharField(max_length=27), size=4) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertEqual(form_field.max_length, 4) + + def test_model_field_choices(self): + model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B")))) + form_field = model_field.formfield() + self.assertEqual(form_field.clean("1,2"), [1, 2]) + + def test_already_converted_value(self): + field = SimpleArrayField(forms.CharField()) + vals = ["a", "b", "c"] + self.assertEqual(field.clean(vals), vals) + + def test_has_changed(self): + field = SimpleArrayField(forms.IntegerField()) + self.assertIs(field.has_changed([1, 2], [1, 2]), False) + self.assertIs(field.has_changed([1, 2], "1,2"), False) + self.assertIs(field.has_changed([1, 2], "1,2,3"), True) + self.assertIs(field.has_changed([1, 2], "a,b"), True) + + def test_has_changed_empty(self): + field = SimpleArrayField(forms.CharField()) + self.assertIs(field.has_changed(None, None), False) + self.assertIs(field.has_changed(None, ""), False) + self.assertIs(field.has_changed(None, []), False) + self.assertIs(field.has_changed([], None), False) + self.assertIs(field.has_changed([], ""), False) + + +# To locate the widget's template. +@modify_settings(INSTALLED_APPS={"append": "django_mongodb_backend"}) +class SplitFormFieldTests(SimpleTestCase): + def test_valid(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {"array_0": "a", "array_1": "b", "array_2": "c"} + form = SplitForm(data) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]}) + + def test_required(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), required=True, size=3) + + data = {"array_0": "", "array_1": "", "array_2": ""} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors, {"array": ["This field is required."]}) + + def test_remove_trailing_nulls(self): + class SplitForm(forms.Form): + array = SplitArrayField( + forms.CharField(required=False), size=5, remove_trailing_nulls=True + ) + + data = { + "array_0": "a", + "array_1": "", + "array_2": "b", + "array_3": "", + "array_4": "", + } + form = SplitForm(data) + self.assertTrue(form.is_valid(), form.errors) + self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]}) + + def test_remove_trailing_nulls_not_required(self): + class SplitForm(forms.Form): + array = SplitArrayField( + forms.CharField(required=False), + size=2, + remove_trailing_nulls=True, + required=False, + ) + + data = {"array_0": "", "array_1": ""} + form = SplitForm(data) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data, {"array": []}) + + def test_required_field(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {"array_0": "a", "array_1": "b", "array_2": ""} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors, + {"array": ["Item 3 in the array did not validate: This field is required."]}, + ) + + def test_invalid_integer(self): + msg = ( + "Item 2 in the array did not validate: Ensure this value is less than or " + "equal to 100." + ) + with self.assertRaisesMessage(exceptions.ValidationError, msg): + SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101]) + + def test_rendering(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + self.assertHTMLEqual( + str(SplitForm()), + """ +
+ + + + +
+ """, + ) + + def test_invalid_char_length(self): + field = SplitArrayField(forms.CharField(max_length=2), size=3) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(["abc", "c", "defg"]) + self.assertEqual( + cm.exception.messages, + [ + "Item 1 in the array did not validate: Ensure this value has at most 2 " + "characters (it has 3).", + "Item 3 in the array did not validate: Ensure this value has at most 2 " + "characters (it has 4).", + ], + ) + + def test_splitarraywidget_value_omitted_from_data(self): + class Form(forms.ModelForm): + field = SplitArrayField(forms.IntegerField(), required=False, size=2) + + class Meta: + model = IntegerArrayModel + fields = ("field",) + + form = Form({"field_0": "1", "field_1": "2"}) + self.assertEqual(form.errors, {}) + obj = form.save(commit=False) + self.assertEqual(obj.field, [1, 2]) + + def test_splitarrayfield_has_changed(self): + class Form(forms.ModelForm): + field = SplitArrayField(forms.IntegerField(), required=False, size=2) + + class Meta: + model = IntegerArrayModel + fields = ("field",) + + tests = [ + ({}, {"field_0": "", "field_1": ""}, True), + ({"field": None}, {"field_0": "", "field_1": ""}, True), + ({"field": [1]}, {"field_0": "", "field_1": ""}, True), + ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True), + ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False), + ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True), + ] + for initial, data, expected_result in tests: + with self.subTest(initial=initial, data=data): + obj = IntegerArrayModel(**initial) + form = Form(data, instance=obj) + self.assertIs(form.has_changed(), expected_result) + + def test_splitarrayfield_remove_trailing_nulls_has_changed(self): + class Form(forms.ModelForm): + field = SplitArrayField( + forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True + ) + + class Meta: + model = IntegerArrayModel + fields = ("field",) + + tests = [ + ({}, {"field_0": "", "field_1": ""}, False), + ({"field": None}, {"field_0": "", "field_1": ""}, False), + ({"field": []}, {"field_0": "", "field_1": ""}, False), + ({"field": [1]}, {"field_0": "1", "field_1": ""}, False), + ] + for initial, data, expected_result in tests: + with self.subTest(initial=initial, data=data): + obj = IntegerArrayModel(**initial) + form = Form(data, instance=obj) + self.assertIs(form.has_changed(), expected_result) + + +# To locate the widget's template. +@modify_settings(INSTALLED_APPS={"append": "django_mongodb_backend"}) +class SplitArrayWidgetTests(WidgetTest, SimpleTestCase): + def test_get_context(self): + self.assertEqual( + SplitArrayWidget(forms.TextInput(), size=2).get_context("name", ["val1", "val2"]), + { + "widget": { + "name": "name", + "is_hidden": False, + "required": False, + "value": "['val1', 'val2']", + "attrs": {}, + "template_name": "mongodb/widgets/split_array.html", + "subwidgets": [ + { + "name": "name_0", + "is_hidden": False, + "required": False, + "value": "val1", + "attrs": {}, + "template_name": "django/forms/widgets/text.html", + "type": "text", + }, + { + "name": "name_1", + "is_hidden": False, + "required": False, + "value": "val2", + "attrs": {}, + "template_name": "django/forms/widgets/text.html", + "type": "text", + }, + ], + } + }, + ) + + def test_checkbox_get_context_attrs(self): + context = SplitArrayWidget( + forms.CheckboxInput(), + size=2, + ).get_context("name", [True, False]) + self.assertEqual(context["widget"]["value"], "[True, False]") + self.assertEqual( + [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]], + [{"checked": True}, {}], + ) + + def test_render(self): + self.check_html( + SplitArrayWidget(forms.TextInput(), size=2), + "array", + None, + """ + + + """, + ) + + def test_render_attrs(self): + self.check_html( + SplitArrayWidget(forms.TextInput(), size=2), + "array", + ["val1", "val2"], + attrs={"id": "foo"}, + html=( + """ + + + """ + ), + ) + + def test_value_omitted_from_data(self): + widget = SplitArrayWidget(forms.TextInput(), size=2) + self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True) + self.assertIs(widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False) + self.assertIs(widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False) + self.assertIs( + widget.value_omitted_from_data({"field_0": "value", "field_1": "value"}, {}, "field"), + False, + ) diff --git a/tests/model_fields_/array_default_migrations/0001_initial.py b/tests/model_fields_/array_default_migrations/0001_initial.py new file mode 100644 index 00000000..4faaed19 --- /dev/null +++ b/tests/model_fields_/array_default_migrations/0001_initial.py @@ -0,0 +1,30 @@ +from django.db import migrations, models + +import django_mongodb_backend + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + migrations.CreateModel( + name="IntegerArrayDefaultModel", + fields=[ + ( + "id", + django_mongodb_backend.fields.ObjectIdAutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "field", + django_mongodb_backend.fields.ArrayField(models.IntegerField(), size=None), + ), + ], + options={}, + bases=(models.Model,), + ), + ] diff --git a/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py new file mode 100644 index 00000000..90f49499 --- /dev/null +++ b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py @@ -0,0 +1,20 @@ +from django.db import migrations, models + +import django_mongodb_backend + + +class Migration(migrations.Migration): + dependencies = [ + ("model_fields_", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="integerarraydefaultmodel", + name="field_2", + field=django_mongodb_backend.fields.ArrayField( + models.IntegerField(), default=[], size=None + ), + preserve_default=False, + ), + ] diff --git a/tests/model_fields_/array_default_migrations/__init__.py b/tests/model_fields_/array_default_migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_fields_/array_index_migrations/0001_initial.py b/tests/model_fields_/array_index_migrations/0001_initial.py new file mode 100644 index 00000000..a32b0529 --- /dev/null +++ b/tests/model_fields_/array_index_migrations/0001_initial.py @@ -0,0 +1,37 @@ +from django.db import migrations, models + +import django_mongodb_backend + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + migrations.CreateModel( + name="CharTextArrayIndexModel", + fields=[ + ( + "id", + django_mongodb_backend.fields.ObjectIdAutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "char", + django_mongodb_backend.fields.ArrayField( + models.CharField(max_length=10), db_index=True, size=100 + ), + ), + ("char2", models.CharField(max_length=11, db_index=True)), + ( + "text", + django_mongodb_backend.fields.ArrayField(models.TextField(), db_index=True), + ), + ], + options={}, + bases=(models.Model,), + ), + ] diff --git a/tests/model_fields_/array_index_migrations/__init__.py b/tests/model_fields_/array_index_migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 983827ed..9b2d96ec 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -1,6 +1,8 @@ +import enum + from django.db import models -from django_mongodb_backend.fields import ObjectIdField +from django_mongodb_backend.fields import ArrayField, ObjectIdField class ObjectIdModel(models.Model): @@ -13,3 +15,77 @@ class NullableObjectIdModel(models.Model): class PrimaryKeyObjectIdModel(models.Model): field = ObjectIdField(primary_key=True) + + +class ArrayFieldSubclass(ArrayField): + def __init__(self, *args, **kwargs): + super().__init__(models.IntegerField()) + + +class Tag: + def __init__(self, tag_id): + self.tag_id = tag_id + + def __eq__(self, other): + return isinstance(other, Tag) and self.tag_id == other.tag_id + + +class TagField(models.SmallIntegerField): + def from_db_value(self, value, expression, connection): + if value is None: + return value + return Tag(int(value)) + + def to_python(self, value): + if isinstance(value, Tag): + return value + if value is None: + return value + return Tag(int(value)) + + def get_prep_value(self, value): + return value.tag_id + + +class IntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField(), default=list, blank=True) + + +class NullableIntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField(), blank=True, null=True) + field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True) + order = models.IntegerField(null=True) + + def __str__(self): + return str(self.field) + + +class CharArrayModel(models.Model): + field = ArrayField(models.CharField(max_length=10)) + + +class DateTimeArrayModel(models.Model): + datetimes = ArrayField(models.DateTimeField()) + dates = ArrayField(models.DateField()) + times = ArrayField(models.TimeField()) + + +class NestedIntegerArrayModel(models.Model): + field = ArrayField(ArrayField(models.IntegerField())) + + +class OtherTypesArrayModel(models.Model): + ips = ArrayField(models.GenericIPAddressField(), default=list) + uuids = ArrayField(models.UUIDField(), default=list) + decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list) + tags = ArrayField(TagField(), blank=True, null=True) + json = ArrayField(models.JSONField(default=dict), default=list) + + +class EnumField(models.CharField): + def get_prep_value(self, value): + return value.value if isinstance(value, enum.Enum) else value + + +class ArrayEnumModel(models.Model): + array_of_enums = ArrayField(EnumField(max_length=20)) diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py new file mode 100644 index 00000000..3b3c4c0c --- /dev/null +++ b/tests/model_fields_/test_arrayfield.py @@ -0,0 +1,882 @@ +import decimal +import enum +import json +import unittest +import uuid + +from django.contrib.admin.utils import display_for_field +from django.core import checks, exceptions, serializers, validators +from django.core.exceptions import FieldError +from django.core.management import call_command +from django.db import IntegrityError, connection, models +from django.db.models.expressions import Exists, OuterRef, Value +from django.db.models.functions import Upper +from django.test import ( + SimpleTestCase, + TestCase, + TransactionTestCase, + override_settings, +) +from django.test.utils import isolate_apps +from django.utils import timezone + +from django_mongodb_backend.fields import ArrayField + +from .models import ( + ArrayEnumModel, + ArrayFieldSubclass, + CharArrayModel, + DateTimeArrayModel, + IntegerArrayModel, + NestedIntegerArrayModel, + NullableIntegerArrayModel, + OtherTypesArrayModel, + Tag, +) + + +@isolate_apps("model_fields_") +class BasicTests(SimpleTestCase): + def test_get_field_display(self): + class MyModel(models.Model): + field = ArrayField( + models.CharField(max_length=16), + choices=[ + ["Media", [(["vinyl", "cd"], "Audio")]], + (("mp3", "mp4"), "Digital"), + ], + ) + + tests = ( + (["vinyl", "cd"], "Audio"), + (("mp3", "mp4"), "Digital"), + (("a", "b"), "('a', 'b')"), + (["c", "d"], "['c', 'd']"), + ) + for value, display in tests: + with self.subTest(value=value, display=display): + instance = MyModel(field=value) + self.assertEqual(instance.get_field_display(), display) + + def test_get_field_display_nested_array(self): + class MyModel(models.Model): + field = ArrayField( + ArrayField(models.CharField(max_length=16)), + choices=[ + [ + "Media", + [([["vinyl", "cd"], ("x",)], "Audio")], + ], + ((["mp3"], ("mp4",)), "Digital"), + ], + ) + + tests = ( + ([["vinyl", "cd"], ("x",)], "Audio"), + ((["mp3"], ("mp4",)), "Digital"), + ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"), + ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"), + ) + for value, display in tests: + with self.subTest(value=value, display=display): + instance = MyModel(field=value) + self.assertEqual(instance.get_field_display(), display) + + def test_deconstruct(self): + field = ArrayField(models.IntegerField()) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(type(new.base_field), type(field.base_field)) + self.assertIsNot(new.base_field, field.base_field) + + def test_deconstruct_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.size, field.size) + + def test_deconstruct_args(self): + field = ArrayField(models.CharField(max_length=20)) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.base_field.max_length, field.base_field.max_length) + + def test_subclass_deconstruct(self): + field = ArrayField(models.IntegerField()) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb_backend.fields.ArrayField") + + field = ArrayFieldSubclass() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "model_fields_.models.ArrayFieldSubclass") + + +class SaveLoadTests(TestCase): + def test_integer(self): + instance = IntegerArrayModel(field=[1, 2, 3]) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_char(self): + instance = CharArrayModel(field=["hello", "goodbye"]) + instance.save() + loaded = CharArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_dates(self): + instance = DateTimeArrayModel( + datetimes=[timezone.now().replace(microsecond=0)], + dates=[timezone.now().date()], + times=[timezone.now().time().replace(microsecond=0)], + ) + instance.save() + loaded = DateTimeArrayModel.objects.get() + self.assertEqual(instance.datetimes, loaded.datetimes) + self.assertEqual(instance.dates, loaded.dates) + self.assertEqual(instance.times, loaded.times) + + def test_tuples(self): + instance = IntegerArrayModel(field=(1,)) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertSequenceEqual(instance.field, loaded.field) + + def test_integers_passed_as_strings(self): + # This checks that get_prep_value() is deferred properly. + instance = IntegerArrayModel(field=["1"]) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(loaded.field, [1]) + + def test_default_null(self): + instance = NullableIntegerArrayModel() + instance.save() + loaded = NullableIntegerArrayModel.objects.get(pk=instance.pk) + self.assertIsNone(loaded.field) + self.assertEqual(instance.field, loaded.field) + + def test_null_handling(self): + instance = NullableIntegerArrayModel(field=None) + instance.save() + loaded = NullableIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_save_null_in_non_null(self): + instance = IntegerArrayModel(field=None) + msg = "You can't set field (a non-nullable field) to None." + with self.assertRaisesMessage(IntegrityError, msg): + instance.save() + + def test_nested(self): + instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]]) + instance.save() + loaded = NestedIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_other_array_types(self): + instance = OtherTypesArrayModel( + ips=["192.168.0.1", "::1"], + uuids=[uuid.uuid4()], + decimals=[decimal.Decimal(1.25), 1.75], + tags=[Tag(1), Tag(2), Tag(3)], + json=[{"a": 1}, {"b": 2}], + ) + instance.save() + loaded = OtherTypesArrayModel.objects.get() + self.assertEqual(instance.ips, loaded.ips) + self.assertEqual(instance.uuids, loaded.uuids) + self.assertEqual(instance.decimals, loaded.decimals) + self.assertEqual(instance.tags, loaded.tags) + self.assertEqual(instance.json, loaded.json) + + def test_null_from_db_value_handling(self): + instance = OtherTypesArrayModel.objects.create( + ips=["192.168.0.1", "::1"], + uuids=[uuid.uuid4()], + decimals=[decimal.Decimal(1.25), 1.75], + tags=None, + ) + instance.refresh_from_db() + self.assertIsNone(instance.tags) + self.assertEqual(instance.json, []) + + def test_model_set_on_base_field(self): + instance = IntegerArrayModel() + field = instance._meta.get_field("field") + self.assertEqual(field.model, IntegerArrayModel) + self.assertEqual(field.base_field.model, IntegerArrayModel) + + def test_nested_nullable_base_field(self): + instance = NullableIntegerArrayModel.objects.create( + field_nested=[[None, None], [None, None]], + ) + self.assertEqual(instance.field_nested, [[None, None], [None, None]]) + + +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.objs = NullableIntegerArrayModel.objects.bulk_create( + [ + NullableIntegerArrayModel(order=1, field=[1]), + NullableIntegerArrayModel(order=2, field=[2]), + NullableIntegerArrayModel(order=3, field=[2, 3]), + NullableIntegerArrayModel(order=4, field=[20, 30, 40]), + NullableIntegerArrayModel(order=5, field=None), + ] + ) + + def test_empty_list(self): + NullableIntegerArrayModel.objects.create(field=[]) + obj = ( + NullableIntegerArrayModel.objects.annotate( + empty_array=models.Value([], output_field=ArrayField(models.IntegerField())), + ) + .filter(field=models.F("empty_array")) + .get() + ) + self.assertEqual(obj.field, []) + self.assertEqual(obj.empty_array, []) + + def test_exact(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] + ) + + def test_exact_null_only_array(self): + obj = NullableIntegerArrayModel.objects.create(field=[None], field_nested=[None, None]) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[None]), [obj] + ) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field_nested__exact=[None, None]), + [obj], + ) + + def test_exact_null_only_nested_array(self): + obj1 = NullableIntegerArrayModel.objects.create(field_nested=[[None, None]]) + obj2 = NullableIntegerArrayModel.objects.create( + field_nested=[[None, None], [None, None]], + ) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None]], + ), + [obj1], + ) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None], [None, None]], + ), + [obj2], + ) + + def test_exact_with_expression(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), + self.objs[:1], + ) + + def test_exact_charfield(self): + instance = CharArrayModel.objects.create(field=["text"]) + self.assertSequenceEqual(CharArrayModel.objects.filter(field=["text"]), [instance]) + + def test_exact_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance] + ) + + def test_isnull(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] + ) + + def test_gt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4] + ) + + def test_lt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] + ) + + def test_in(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), + self.objs[:2], + ) + + def test_in_subquery(self): + IntegerArrayModel.objects.create(field=[2, 3]) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__in=IntegerArrayModel.objects.values_list("field", flat=True) + ), + self.objs[2:3], + ) + + @unittest.expectedFailure + def test_in_including_F_object(self): + # Array objects passed to filters can be constructed to contain + # F objects. This doesn't work on PostgreSQL either (#27095). + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]), + self.objs[:2], + ) + + def test_in_as_F_object(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]), + # Unlike PostgreSQL, MongoDB returns documents with field=null, + # i.e. null is in [null]. It seems okay to leave this alone rather + # than filtering out null in all $in queries. Feel free to + # reconsider this decision if the behavior is problematic in some + # other query. + self.objs, + ) + + def test_contained_by(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), + self.objs[:2], + ) + + def test_contained_by_including_F_object(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F("order"), 2]), + self.objs[:3], + ) + + def test_contains(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[2]), + self.objs[1:3], + ) + + def test_contains_subquery(self): + IntegerArrayModel.objects.create(field=[2, 3]) + inner_qs = IntegerArrayModel.objects.values_list("field", flat=True) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=inner_qs[:1]), + self.objs[2:3], + ) + inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef("field")) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(Exists(inner_qs)), + self.objs[1:3], + ) + + def test_contained_by_subquery(self): + IntegerArrayModel.objects.create(field=[2, 3]) + inner_qs = IntegerArrayModel.objects.values_list("field", flat=True) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=inner_qs[:1]), + self.objs[1:3], + ) + IntegerArrayModel.objects.create(field=[2]) + inner_qs = IntegerArrayModel.objects.filter(field__contained_by=OuterRef("field")) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(Exists(inner_qs)), + self.objs[1:3], + ) + + def test_contains_including_expression(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__contains=[2, Value(6) / Value(2)], + ), + self.objs[2:3], + ) + + def test_icontains(self): + instance = CharArrayModel.objects.create(field=["FoO"]) + self.assertSequenceEqual(CharArrayModel.objects.filter(field__icontains="foo"), [instance]) + + def test_contains_charfield(self): + self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), []) + + def test_contained_by_charfield(self): + self.assertSequenceEqual(CharArrayModel.objects.filter(field__contained_by=["text"]), []) + + def test_overlap_charfield(self): + self.assertSequenceEqual(CharArrayModel.objects.filter(field__overlap=["text"]), []) + + def test_overlap_charfield_including_expression(self): + obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"]) + obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"]) + CharArrayModel.objects.create(field=["lower text", "text"]) + self.assertSequenceEqual( + CharArrayModel.objects.filter( + field__overlap=[ + Upper(Value("text")), + "other", + ] + ), + [obj_1, obj_2], + ) + + def test_overlap_values(self): + qs = NullableIntegerArrayModel.objects.filter(order__lt=3) + self.assertCountEqual( + NullableIntegerArrayModel.objects.filter( + field__overlap=qs.values_list("field"), + ), + self.objs[:3], + ) + self.assertCountEqual( + NullableIntegerArrayModel.objects.filter( + field__overlap=qs.values("field"), + ), + self.objs[:3], + ) + + def test_index(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3] + ) + + def test_index_chained(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3] + ) + + def test_index_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual(NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]) + + def test_index_used_on_nested_data(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] + ) + + def test_overlap(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), + self.objs[0:3], + ) + + def test_index_annotation(self): + qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1")) + self.assertCountEqual( + qs.values_list("second", flat=True), + [None, None, None, 3, 30], + ) + + def test_len(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3] + ) + + def test_len_empty_array(self): + obj = NullableIntegerArrayModel.objects.create(field=[]) + self.assertSequenceEqual(NullableIntegerArrayModel.objects.filter(field__len=0), [obj]) + + def test_slice(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3] + ) + + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3] + ) + + def test_order_by_index(self): + more_objs = ( + NullableIntegerArrayModel.objects.create(field=[1, 637]), + NullableIntegerArrayModel.objects.create(field=[2, 1]), + NullableIntegerArrayModel.objects.create(field=[3, -98123]), + NullableIntegerArrayModel.objects.create(field=[4, 2]), + ) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.order_by("field__1"), + [ + self.objs[0], + self.objs[1], + self.objs[4], + more_objs[2], + more_objs[1], + more_objs[3], + self.objs[2], + self.objs[3], + more_objs[0], + ], + ) + + def test_slice_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance] + ) + + def test_slice_annotation(self): + qs = NullableIntegerArrayModel.objects.annotate( + first_two=models.F("field__0_2"), + ) + self.assertCountEqual( + qs.values_list("first_two", flat=True), + [None, [1], [2], [2, 3], [20, 30]], + ) + + def test_usage_in_subquery(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + id__in=NullableIntegerArrayModel.objects.filter(field__len=3) + ), + [self.objs[3]], + ) + + def test_enum_lookup(self): + class TestEnum(enum.Enum): + VALUE_1 = "value_1" + + instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1]) + self.assertSequenceEqual( + ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]), + [instance], + ) + + def test_unsupported_lookup(self): + msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not " "permitted." + with self.assertRaisesMessage(FieldError, msg): + list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2])) + + msg = "Unsupported lookup '0bar' for ArrayField or join on the field not " "permitted." + with self.assertRaisesMessage(FieldError, msg): + list(NullableIntegerArrayModel.objects.filter(field__0bar=[2])) + + +class DateTimeExactQueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + now = timezone.now() + cls.datetimes = [now] + cls.dates = [now.date()] + cls.times = [now.time()] + cls.objs = [ + DateTimeArrayModel.objects.create( + datetimes=cls.datetimes, dates=cls.dates, times=cls.times + ), + ] + + def test_exact_datetimes(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs + ) + + def test_exact_dates(self): + self.assertSequenceEqual(DateTimeArrayModel.objects.filter(dates=self.dates), self.objs) + + def test_exact_times(self): + self.assertSequenceEqual(DateTimeArrayModel.objects.filter(times=self.times), self.objs) + + +class OtherTypesExactQueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.ips = ["192.168.0.1", "::1"] + cls.uuids = [uuid.uuid4()] + cls.decimals = [decimal.Decimal(1.25), 1.75] + cls.tags = [Tag(1), Tag(2), Tag(3)] + cls.objs = [ + OtherTypesArrayModel.objects.create( + ips=cls.ips, + uuids=cls.uuids, + decimals=cls.decimals, + tags=cls.tags, + ) + ] + + def test_exact_ip_addresses(self): + self.assertSequenceEqual(OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs) + + def test_exact_uuids(self): + self.assertSequenceEqual(OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs) + + def test_exact_decimals(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs + ) + + def test_exact_tags(self): + self.assertSequenceEqual(OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs) + + +@isolate_apps("model_fields_") +class CheckTests(SimpleTestCase): + def test_base_field_errors(self): + class MyModel(models.Model): + field = ArrayField(models.CharField(max_length=-1)) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + # The inner CharField has a non-positive max_length. + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E001") + msg = errors[0].msg + self.assertIn("Base field for array has errors:", msg) + self.assertIn("'max_length' must be a positive integer. (fields.E121)", msg) + + def test_base_field_warnings(self): + class WarningField(models.IntegerField): + def check(self): + return [checks.Warning("Test warning", obj=self, id="test.E001")] + + class MyModel(models.Model): + field = ArrayField(WarningField(), default=None) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.W004") + msg = errors[0].msg + self.assertIn("Base field for array has warnings:", msg) + self.assertIn("Test warning (test.E001)", msg) + + def test_invalid_base_fields(self): + class MyModel(models.Model): + field = ArrayField(models.ManyToManyField("model_fields_.IntegerArrayModel")) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E002") + + def test_invalid_default(self): + class MyModel(models.Model): + field = ArrayField(models.IntegerField(), default=[]) + + model = MyModel() + self.assertEqual( + model.check(), + [ + checks.Warning( + msg=( + "ArrayField default should be a callable instead of an " + "instance so that it's not shared between all field " + "instances." + ), + hint="Use a callable instead, e.g., use `list` instead of `[]`.", + obj=MyModel._meta.get_field("field"), + id="fields.E010", + ) + ], + ) + + def test_valid_default(self): + class MyModel(models.Model): + field = ArrayField(models.IntegerField(), default=list) + + model = MyModel() + self.assertEqual(model.check(), []) + + def test_valid_default_none(self): + class MyModel(models.Model): + field = ArrayField(models.IntegerField(), default=None) + + model = MyModel() + self.assertEqual(model.check(), []) + + def test_nested_field_checks(self): + """ + Nested ArrayFields are permitted. + """ + + class MyModel(models.Model): + field = ArrayField(ArrayField(models.CharField(max_length=-1))) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + # The inner CharField has a non-positive max_length. + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E001") + self.assertIn("max_length", errors[0].msg) + + def test_choices_tuple_list(self): + class MyModel(models.Model): + field = ArrayField( + models.CharField(max_length=16), + choices=[ + [ + "Media", + [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")], + ], + (["mp3", "mp4"], "Digital"), + ], + ) + + self.assertEqual(MyModel._meta.get_field("field").check(), []) + + +@isolate_apps("model_fields_") +class MigrationsTests(TransactionTestCase): + available_apps = ["model_fields_"] + + @override_settings( + MIGRATION_MODULES={ + "model_fields_": "model_fields_.array_default_migrations", + } + ) + def test_adding_field_with_default(self): + class IntegerArrayDefaultModel(models.Model): + field = ArrayField(models.IntegerField(), size=None) + + table_name = "model_fields__integerarraydefaultmodel" + self.assertNotIn(table_name, connection.introspection.table_names(None)) + # Create collection + call_command("migrate", "model_fields_", "0001", verbosity=0) + self.assertIn(table_name, connection.introspection.table_names(None)) + obj = IntegerArrayDefaultModel.objects.create(field=[1, 2, 3]) + # Add `field2 to IntegerArrayDefaultModel. + call_command("migrate", "model_fields_", "0002", verbosity=0) + + class UpdatedIntegerArrayDefaultModel(models.Model): + field = ArrayField(models.IntegerField(), size=None) + field_2 = ArrayField(models.IntegerField(), default=[], size=None) + + class Meta: + db_table = "model_fields__integerarraydefaultmodel" + + obj = UpdatedIntegerArrayDefaultModel.objects.get() + # The default is populated on existing documents. + self.assertEqual(obj.field_2, []) + # Cleanup. + call_command("migrate", "model_fields_", "zero", verbosity=0) + self.assertNotIn(table_name, connection.introspection.table_names(None)) + + @override_settings( + MIGRATION_MODULES={ + "model_fields_": "model_fields_.array_index_migrations", + } + ) + def test_adding_arrayfield_with_index(self): + table_name = "model_fields__chartextarrayindexmodel" + call_command("migrate", "model_fields_", verbosity=0) + # All fields should have indexes. + indexes = [ + c["columns"][0] + for c in connection.introspection.get_constraints(None, table_name).values() + if c["index"] and len(c["columns"]) == 1 + ] + self.assertIn("char", indexes) + self.assertIn("char2", indexes) + self.assertIn("text", indexes) + call_command("migrate", "model_fields_", "zero", verbosity=0) + self.assertNotIn(table_name, connection.introspection.table_names(None)) + + +class SerializationTests(SimpleTestCase): + test_data = ( + '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, ' + '"model": "model_fields_.integerarraymodel", "pk": null}]' + ) + + def test_dumping(self): + instance = IntegerArrayModel(field=[1, 2, None]) + data = serializers.serialize("json", [instance]) + self.assertEqual(json.loads(data), json.loads(self.test_data)) + + def test_loading(self): + instance = next(serializers.deserialize("json", self.test_data)).object + self.assertEqual(instance.field, [1, 2, None]) + + +class ValidationTests(SimpleTestCase): + def test_unbounded(self): + field = ArrayField(models.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, None], None) + self.assertEqual(cm.exception.code, "item_invalid") + self.assertEqual( + cm.exception.message % cm.exception.params, + "Item 2 in the array did not validate: This field cannot be null.", + ) + + def test_blank_true(self): + field = ArrayField(models.IntegerField(blank=True, null=True)) + # This should not raise a validation error + field.clean([1, None], None) + + def test_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + field.clean([1, 2, 3], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, 2, 3, 4], None) + self.assertEqual( + cm.exception.messages[0], + "List contains 4 items, it should contain no more than 3.", + ) + + def test_with_size_singular(self): + field = ArrayField(models.IntegerField(), size=1) + field.clean([1], None) + msg = "List contains 2 items, it should contain no more than 1." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1, 2], None) + + def test_nested_array_mismatch(self): + field = ArrayField(ArrayField(models.IntegerField())) + field.clean([[1, 2], [3, 4]], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([[1, 2], [3, 4, 5]], None) + self.assertEqual(cm.exception.code, "nested_array_mismatch") + self.assertEqual(cm.exception.messages[0], "Nested arrays must have the same length.") + + def test_with_base_field_error_params(self): + field = ArrayField(models.CharField(max_length=2)) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(["abc"], None) + self.assertEqual(len(cm.exception.error_list), 1) + exception = cm.exception.error_list[0] + self.assertEqual( + exception.message, + "Item 1 in the array did not validate: Ensure this value has at most 2 " + "characters (it has 3).", + ) + self.assertEqual(exception.code, "item_invalid") + self.assertEqual( + exception.params, + {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3}, + ) + + def test_with_validators(self): + field = ArrayField(models.IntegerField(validators=[validators.MinValueValidator(1)])) + field.clean([1, 2], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([0], None) + self.assertEqual(len(cm.exception.error_list), 1) + exception = cm.exception.error_list[0] + self.assertEqual( + exception.message, + "Item 1 in the array did not validate: Ensure this value is greater than " + "or equal to 1.", + ) + self.assertEqual(exception.code, "item_invalid") + self.assertEqual( + exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0} + ) + + +class AdminUtilsTests(SimpleTestCase): + empty_value = "-empty-" + + def test_array_display_for_field(self): + array_field = ArrayField(models.IntegerField()) + display_value = display_for_field([1, 2], array_field, self.empty_value) + self.assertEqual(display_value, "1, 2") + + def test_array_with_choices_display_for_field(self): + array_field = ArrayField( + models.IntegerField(), + choices=[ + ([1, 2, 3], "1st choice"), + ([1, 2], "2nd choice"), + ], + ) + display_value = display_for_field([1, 2], array_field, self.empty_value) + self.assertEqual(display_value, "2nd choice") + display_value = display_for_field([99, 99], array_field, self.empty_value) + self.assertEqual(display_value, self.empty_value)