diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ceb70d9..60ea59c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: hooks: - id: rstcheck additional_dependencies: [sphinx] - args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"] + args: ["--ignore-directives=django-admin,fieldlookup,setting", "--ignore-roles=djadmin,lookup,setting"] # We use the Python version instead of the original version which seems to require Docker # https://github.com/koalaman/shellcheck-precommit diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index b7b2da6e..e708290f 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -3,7 +3,7 @@ be distributed under licenses different than this software. The attached notices are provided for information only. -django-mongodb-backend began by borrowing code from Django non-rel's +django-mongodb-backend and EmbeddedModelField began by borrowing code from django-mongodb-engine (https://github.com/django-nonrel/mongodb-engine), abandoned since 2015 and Django 1.6. diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1608d7d7..afb5d252 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -741,7 +741,7 @@ def execute_sql(self, result_type): elif hasattr(value, "prepare_database_save"): if field.remote_field: value = value.prepare_database_save(field) - else: + elif not hasattr(field, "embedded_model"): raise TypeError( f"Tried to update field {field} with a model " f"instance, {value!r}. Use a value compatible with " diff --git a/django_mongodb_backend/fields/__init__.py b/django_mongodb_backend/fields/__init__.py index cab7071c..569c19be 100644 --- a/django_mongodb_backend/fields/__init__.py +++ b/django_mongodb_backend/fields/__init__.py @@ -1,10 +1,17 @@ from .array import ArrayField from .auto import ObjectIdAutoField from .duration import register_duration_field +from .embedded_model import EmbeddedModelField from .json import register_json_field from .objectid import ObjectIdField -__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"] +__all__ = [ + "register_fields", + "ArrayField", + "EmbeddedModelField", + "ObjectIdAutoField", + "ObjectIdField", +] def register_fields(): diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py new file mode 100644 index 00000000..a5899326 --- /dev/null +++ b/django_mongodb_backend/fields/embedded_model.py @@ -0,0 +1,161 @@ +from django.core import checks +from django.db import models +from django.db.models.fields.related import lazy_related_operation +from django.db.models.lookups import Transform + +from .. import forms + + +class EmbeddedModelField(models.Field): + """Field that stores a model instance.""" + + def __init__(self, embedded_model, *args, **kwargs): + """ + `embedded_model` is the model class of the instance to be stored. + Like other relational fields, it may also be passed as a string. + """ + self.embedded_model = embedded_model + super().__init__(*args, **kwargs) + + def check(self, **kwargs): + errors = super().check(**kwargs) + for field in self.embedded_model._meta.fields: + if field.remote_field: + errors.append( + checks.Error( + "Embedded models cannot have relational fields " + f"({self.embedded_model().__class__.__name__}.{field.name} " + f"is a {field.__class__.__name__}).", + obj=self, + id="django_mongodb.embedded_model.E001", + ) + ) + return errors + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path.startswith("django_mongodb_backend.fields.embedded_model"): + path = path.replace( + "django_mongodb_backend.fields.embedded_model", "django_mongodb_backend.fields" + ) + kwargs["embedded_model"] = self.embedded_model + return name, path, args, kwargs + + def get_internal_type(self): + return "EmbeddedModelField" + + def _set_model(self, model): + """ + Resolve embedded model class once the field knows the model it belongs + to. If __init__()'s embedded_model argument is a string, resolve it to + the actual model class, similar to relation fields. + """ + self._model = model + if model is not None and isinstance(self.embedded_model, str): + + def _resolve_lookup(_, resolved_model): + self.embedded_model = resolved_model + + lazy_related_operation(_resolve_lookup, model, self.embedded_model) + + model = property(lambda self: self._model, _set_model) + + def from_db_value(self, value, expression, connection): + return self.to_python(value) + + def to_python(self, value): + """ + Pass embedded model fields' values through each field's to_python() and + reinstantiate the embedded instance. + """ + if value is None: + return None + if not isinstance(value, dict): + return value + instance = self.embedded_model( + **{ + field.attname: field.to_python(value[field.attname]) + for field in self.embedded_model._meta.fields + if field.attname in value + } + ) + instance._state.adding = False + return instance + + def get_db_prep_save(self, embedded_instance, connection): + """ + Apply pre_save() and get_db_prep_save() of embedded instance fields and + create the {field: value} dict to be saved. + """ + if embedded_instance is None: + return None + if not isinstance(embedded_instance, self.embedded_model): + raise TypeError( + f"Expected instance of type {self.embedded_model!r}, not " + f"{type(embedded_instance)!r}." + ) + field_values = {} + add = embedded_instance._state.adding + for field in embedded_instance._meta.fields: + value = field.get_db_prep_save( + field.pre_save(embedded_instance, add), connection=connection + ) + # Exclude unset primary keys (e.g. {'id': None}). + if field.primary_key and value is None: + continue + field_values[field.attname] = value + # This instance will exist in the database soon. + embedded_instance._state.adding = False + return field_values + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + if self.embedded_model is None: + return + for field in self.embedded_model._meta.fields: + attname = field.attname + field.validate(getattr(value, attname), model_instance) + + def formfield(self, **kwargs): + return super().formfield( + **{ + "form_class": forms.EmbeddedModelField, + "model": self.embedded_model, + "prefix": self.name, + **kwargs, + } + ) + + +class KeyTransform(Transform): + def __init__(self, key_name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.key_name = str(key_name) + + def preprocess_lhs(self, compiler, connection): + key_transforms = [self.key_name] + previous = self.lhs + while isinstance(previous, KeyTransform): + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + mql = previous.as_mql(compiler, connection) + return mql, key_transforms + + def as_mql(self, compiler, connection): + mql, key_transforms = self.preprocess_lhs(compiler, connection) + transforms = ".".join(key_transforms) + return f"{mql}.{transforms}" + + +class KeyTransformFactory: + def __init__(self, key_name): + self.key_name = key_name + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, *args, **kwargs) diff --git a/django_mongodb_backend/forms/__init__.py b/django_mongodb_backend/forms/__init__.py index 96c8775e..2adc8fbe 100644 --- a/django_mongodb_backend/forms/__init__.py +++ b/django_mongodb_backend/forms/__init__.py @@ -1,6 +1,13 @@ -from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget +from .fields import ( + EmbeddedModelField, + ObjectIdField, + SimpleArrayField, + SplitArrayField, + SplitArrayWidget, +) __all__ = [ + "EmbeddedModelField", "SimpleArrayField", "SplitArrayField", "SplitArrayWidget", diff --git a/django_mongodb_backend/forms/fields/__init__.py b/django_mongodb_backend/forms/fields/__init__.py index 298c6b68..03cc2372 100644 --- a/django_mongodb_backend/forms/fields/__init__.py +++ b/django_mongodb_backend/forms/fields/__init__.py @@ -1,7 +1,9 @@ from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget +from .embedded_model import EmbeddedModelField from .objectid import ObjectIdField __all__ = [ + "EmbeddedModelField", "SimpleArrayField", "SplitArrayField", "SplitArrayWidget", diff --git a/django_mongodb_backend/forms/fields/embedded_model.py b/django_mongodb_backend/forms/fields/embedded_model.py new file mode 100644 index 00000000..b86e85e7 --- /dev/null +++ b/django_mongodb_backend/forms/fields/embedded_model.py @@ -0,0 +1,62 @@ +from django import forms +from django.forms.models import modelform_factory +from django.utils.safestring import mark_safe +from django.utils.translation import gettext_lazy as _ + + +class EmbeddedModelWidget(forms.MultiWidget): + def __init__(self, field_names, *args, **kwargs): + self.field_names = field_names + super().__init__(*args, **kwargs) + # The default widget names are "_0", "_1", etc. Use the field names + # instead since that's how they'll be rendered by the model form. + self.widgets_names = ["-" + name for name in field_names] + + def decompress(self, value): + if value is None: + return [] + # Get the data from `value` (a model) for each field. + return [getattr(value, name) for name in self.field_names] + + +class EmbeddedModelBoundField(forms.BoundField): + def __str__(self): + """Render the model form as the representation for this field.""" + form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) + return mark_safe(f"{form.as_div()}") # noqa: S308 + + +class EmbeddedModelField(forms.MultiValueField): + default_error_messages = { + "invalid": _("Enter a list of values."), + "incomplete": _("Enter all required values."), + } + + def __init__(self, model, prefix, *args, **kwargs): + form_kwargs = {} + # To avoid collisions with other fields on the form, each subfield must + # be prefixed with the name of the field. + form_kwargs["prefix"] = prefix + self.form_kwargs = form_kwargs + self.model_form_cls = modelform_factory(model, fields="__all__") + self.model_form = self.model_form_cls(**form_kwargs) + self.field_names = list(self.model_form.fields.keys()) + fields = self.model_form.fields.values() + widgets = [field.widget for field in fields] + widget = EmbeddedModelWidget(self.field_names, widgets) + super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs) + + def compress(self, data_dict): + if not data_dict: + return None + values = dict(zip(self.field_names, data_dict, strict=False)) + return self.model_form._meta.model(**values) + + def get_bound_field(self, form, field_name): + return EmbeddedModelBoundField(form, self, field_name) + + def bound_data(self, data, initial): + if self.disabled: + return initial + # Transform the bound data into a model instance. + return self.compress(data) diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index 8fc18fea..9769df8b 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -5,6 +5,7 @@ from pymongo import ASCENDING, DESCENDING from pymongo.operations import IndexModel +from .fields import EmbeddedModelField from .query import wrap_database_errors from .utils import OperationCollector @@ -29,31 +30,50 @@ def create_model(self, model): if field.remote_field.through._meta.auto_created: self.create_model(field.remote_field.through) - def _create_model_indexes(self, model): + def _create_model_indexes(self, model, column_prefix="", parent_model=None): """ Create all indexes (field indexes & uniques, Meta.index_together, Meta.unique_together, Meta.constraints, Meta.indexes) for the model. + + If this is a recursive call due to an embedded model, `column_prefix` + tracks the path that must be prepended to the index's column, and + `parent_model` tracks the collection to add the index/constraint to. """ if not model._meta.managed or model._meta.proxy or model._meta.swapped: return # Field indexes and uniques for field in model._meta.local_fields: + if isinstance(field, EmbeddedModelField): + new_path = f"{column_prefix}{field.column}." + self._create_model_indexes( + field.embedded_model, parent_model=parent_model or model, column_prefix=new_path + ) if self._field_should_be_indexed(model, field): - self._add_field_index(model, field) + self._add_field_index(parent_model or model, field, column_prefix=column_prefix) elif self._field_should_have_unique(field): - self._add_field_unique(model, field) + self._add_field_unique(parent_model or model, field, column_prefix=column_prefix) # Meta.index_together (RemovedInDjango51Warning) for field_names in model._meta.index_together: - self._add_composed_index(model, field_names) + self._add_composed_index( + model, field_names, column_prefix=column_prefix, parent_model=parent_model + ) # Meta.unique_together if model._meta.unique_together: - self.alter_unique_together(model, [], model._meta.unique_together) + self.alter_unique_together( + model, + [], + model._meta.unique_together, + column_prefix=column_prefix, + parent_model=parent_model, + ) # Meta.constraints for constraint in model._meta.constraints: - self.add_constraint(model, constraint) + self.add_constraint( + model, constraint, column_prefix=column_prefix, parent_model=parent_model + ) # Meta.indexes for index in model._meta.indexes: - self.add_index(model, index) + self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model) def delete_model(self, model): # Delete implicit M2m tables. @@ -72,6 +92,11 @@ def add_field(self, model, field): self.get_collection(model._meta.db_table).update_many( {}, [{"$set": {column: self.effective_default(field)}}] ) + if isinstance(field, EmbeddedModelField): + new_path = f"{field.column}." + self._create_model_indexes( + field.embedded_model, parent_model=model, column_prefix=new_path + ) # Add an index or unique, if required. if self._field_should_be_indexed(model, field): self._add_field_index(model, field) @@ -136,18 +161,70 @@ def remove_field(self, model, field): self._remove_field_index(model, field) elif self._field_should_have_unique(field): self._remove_field_unique(model, field) + if isinstance(field, EmbeddedModelField): + new_path = f"{field.column}." + self._remove_model_indexes( + field.embedded_model, parent_model=model, column_prefix=new_path + ) - def alter_index_together(self, model, old_index_together, new_index_together): + def _remove_model_indexes(self, model, column_prefix="", parent_model=None): + """ + When removing an EmbeddedModelField, the indexes need to be removed + recursively. + """ + if not model._meta.managed or model._meta.proxy or model._meta.swapped: + return + # Field indexes and uniques + for field in model._meta.local_fields: + if isinstance(field, EmbeddedModelField): + new_path = f"{column_prefix}{field.column}." + self._remove_model_indexes( + field.embedded_model, parent_model=parent_model or model, column_prefix=new_path + ) + if self._field_should_be_indexed(model, field): + self._remove_field_index(parent_model or model, field, column_prefix=column_prefix) + elif self._field_should_have_unique(field): + self._remove_field_unique(parent_model or model, field, column_prefix=column_prefix) + # Meta.index_together (RemovedInDjango51Warning) + for field_names in model._meta.index_together: + self._remove_composed_index( + model, + field_names, + {"index": True, "unique": False}, + column_prefix=column_prefix, + parent_model=parent_model, + ) + # Meta.unique_together + if model._meta.unique_together: + self.alter_unique_together( + model, + model._meta.unique_together, + [], + column_prefix=column_prefix, + parent_model=parent_model, + ) + # Meta.constraints + for constraint in model._meta.constraints: + self.remove_constraint(parent_model or model, constraint) + # Meta.indexes + for index in model._meta.indexes: + self.remove_index(parent_model or model, index) + + def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""): olds = {tuple(fields) for fields in old_index_together} news = {tuple(fields) for fields in new_index_together} # Deleted indexes for field_names in olds.difference(news): - self._remove_composed_index(model, field_names, {"index": True, "unique": False}) + self._remove_composed_index( + model, field_names, {"index": True, "unique": False}, column_prefix="" + ) # Created indexes for field_names in news.difference(olds): - self._add_composed_index(model, field_names) + self._add_composed_index(model, field_names, column_prefix=column_prefix) - def alter_unique_together(self, model, old_unique_together, new_unique_together): + def alter_unique_together( + self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None + ): olds = {tuple(fields) for fields in old_unique_together} news = {tuple(fields) for fields in new_unique_together} # Deleted uniques @@ -156,15 +233,25 @@ def alter_unique_together(self, model, old_unique_together, new_unique_together) model, field_names, {"unique": True, "primary_key": False}, + column_prefix=column_prefix, + parent_model=parent_model, ) # Created uniques for field_names in news.difference(olds): columns = [model._meta.get_field(field).column for field in field_names] - name = str(self._unique_constraint_name(model._meta.db_table, columns)) + name = str( + self._unique_constraint_name( + model._meta.db_table, [column_prefix + col for col in columns] + ) + ) constraint = UniqueConstraint(fields=field_names, name=name) - self.add_constraint(model, constraint) + self.add_constraint( + model, constraint, parent_model=parent_model, column_prefix=column_prefix + ) - def add_index(self, model, index, field=None, unique=False): + def add_index( + self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None + ): if index.contains_expressions: return kwargs = {} @@ -176,7 +263,8 @@ def add_index(self, model, index, field=None, unique=False): # Indexing on $type matches the value of most SQL databases by # allowing multiple null values for the unique constraint. if field: - filter_expression[field.column].update({"$type": field.db_type(self.connection)}) + column = column_prefix + field.column + filter_expression[column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) @@ -186,45 +274,51 @@ def add_index(self, model, index, field=None, unique=False): if filter_expression: kwargs["partialFilterExpression"] = filter_expression index_orders = ( - [(field.column, ASCENDING)] + [(column_prefix + field.column, ASCENDING)] if field else [ # order is "" if ASCENDING or "DESC" if DESCENDING (see # django.db.models.indexes.Index.fields_orders). - (model._meta.get_field(field_name).column, ASCENDING if order == "" else DESCENDING) + ( + column_prefix + model._meta.get_field(field_name).column, + ASCENDING if order == "" else DESCENDING, + ) for field_name, order in index.fields_orders ] ) idx = IndexModel(index_orders, name=index.name, **kwargs) + model = parent_model or model self.get_collection(model._meta.db_table).create_indexes([idx]) - def _add_composed_index(self, model, field_names): + def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None): """Add an index on the given list of field_names.""" idx = Index(fields=field_names) idx.set_name_with_model(model) - self.add_index(model, idx) + self.add_index(model, idx, column_prefix=column_prefix, parent_model=parent_model) - def _add_field_index(self, model, field): + def _add_field_index(self, model, field, *, column_prefix=""): """Add an index on a field with db_index=True.""" - index = Index(fields=[field.name]) - index.name = self._create_index_name(model._meta.db_table, [field.column]) - self.add_index(model, index, field=field) + index = Index(fields=[column_prefix + field.name]) + index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column]) + self.add_index(model, index, field=field, column_prefix=column_prefix) def remove_index(self, model, index): if index.contains_expressions: return self.get_collection(model._meta.db_table).drop_index(index.name) - def _remove_composed_index(self, model, field_names, constraint_kwargs): + def _remove_composed_index( + self, model, field_names, constraint_kwargs, column_prefix="", parent_model=None + ): """ Remove the index on the given list of field_names created by index/unique_together, depending on constraint_kwargs. """ meta_constraint_names = {constraint.name for constraint in model._meta.constraints} meta_index_names = {constraint.name for constraint in model._meta.indexes} - columns = [model._meta.get_field(field).column for field in field_names] + columns = [column_prefix + model._meta.get_field(field).column for field in field_names] constraint_names = self._constraint_names( - model, + parent_model or model, columns, exclude=meta_constraint_names | meta_index_names, **constraint_kwargs, @@ -236,16 +330,17 @@ def _remove_composed_index(self, model, field_names, constraint_kwargs): f"Found wrong number ({num_found}) of constraints for " f"{model._meta.db_table}({columns_str})." ) + model = parent_model or model collection = self.get_collection(model._meta.db_table) collection.drop_index(constraint_names[0]) - def _remove_field_index(self, model, field): + def _remove_field_index(self, model, field, column_prefix=""): """Remove a field's db_index=True index.""" collection = self.get_collection(model._meta.db_table) meta_index_names = {index.name for index in model._meta.indexes} index_names = self._constraint_names( model, - [field.column], + [column_prefix + field.column], index=True, # Retrieve only BTREE indexes since this is what's created with # db_index=True. @@ -260,7 +355,7 @@ def _remove_field_index(self, model, field): ) collection.drop_index(index_names[0]) - def add_constraint(self, model, constraint, field=None): + def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( condition=constraint.condition, deferrable=constraint.deferrable, @@ -273,12 +368,21 @@ def add_constraint(self, model, constraint, field=None): name=constraint.name, condition=constraint.condition, ) - self.add_index(model, idx, field=field, unique=True) + self.add_index( + model, + idx, + field=field, + unique=True, + column_prefix=column_prefix, + parent_model=parent_model, + ) - def _add_field_unique(self, model, field): - name = str(self._unique_constraint_name(model._meta.db_table, [field.column])) + def _add_field_unique(self, model, field, column_prefix=""): + name = str( + self._unique_constraint_name(model._meta.db_table, [column_prefix + field.column]) + ) constraint = UniqueConstraint(fields=[field.name], name=name) - self.add_constraint(model, constraint, field=field) + self.add_constraint(model, constraint, field=field, column_prefix=column_prefix) def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and self._unique_supported( @@ -295,12 +399,12 @@ def remove_constraint(self, model, constraint): ) self.remove_index(model, idx) - def _remove_field_unique(self, model, field): + def _remove_field_unique(self, model, field, column_prefix=""): # Find the unique constraint for this field meta_constraint_names = {constraint.name for constraint in model._meta.constraints} constraint_names = self._constraint_names( model, - [field.column], + [column_prefix + field.column], unique=True, primary_key=False, exclude=meta_constraint_names, diff --git a/docs/source/_ext/djangodocs.py b/docs/source/_ext/djangodocs.py index fda464d8..fc89c24c 100644 --- a/docs/source/_ext/djangodocs.py +++ b/docs/source/_ext/djangodocs.py @@ -1,4 +1,9 @@ def setup(app): + app.add_object_type( + directivename="django-admin", + rolename="djadmin", + indextemplate="pair: %s; django-admin command", + ) app.add_crossref_type( directivename="fieldlookup", rolename="lookup", diff --git a/docs/source/embedded-models.rst b/docs/source/embedded-models.rst new file mode 100644 index 00000000..08e6891b --- /dev/null +++ b/docs/source/embedded-models.rst @@ -0,0 +1,55 @@ +Embedded models +=============== + +Use :class:`~django_mongdob.fields.EmbeddedModelField` to structure your data +using `embedded documents +`_. + +The basics +---------- + +Let's consider this example:: + + from django_mongodb_backend.fields import EmbeddedModelField + + class Customer(models.Model): + name = models.CharField(...) + address = EmbeddedModelField("Address") + ... + + class Address(models.Model): + ... + city = models.CharField(...) + + +The API is similar to that of Django's relational fields:: + + >>> Customer.objects.create(name="Bob", address=Address(city="New York", ...), ...) + >>> bob = Customer.objects.get(...) + >>> bob.address + + >>> bob.address.city + 'New York' + +Represented in BSON, Bob's structure looks like this: + +.. code-block:: js + + { + "_id": ObjectId(...), + "name": "Bob", + "address": { + ... + "city": "New York" + }, + ... + } + +Querying ``EmbeddedModelField`` +------------------------------- + +You can query into an embedded model using the same double underscore syntax +as relational fields. For example, to retrieve all customers who have an +address with the city "New York":: + + >>> Customer.objects.filter(address__city="New York") diff --git a/docs/source/fields.rst b/docs/source/fields.rst index 39f965a7..270ea7d7 100644 --- a/docs/source/fields.rst +++ b/docs/source/fields.rst @@ -36,7 +36,8 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. :class:`~django.db.models.OneToOneField` and :class:`~django.db.models.ManyToManyField`) and file fields ( :class:`~django.db.models.FileField` and - :class:`~django.db.models.ImageField`). + :class:`~django.db.models.ImageField`). :class:`EmbeddedModelField` is + also not (yet) supported. It is possible to nest array fields - you can specify an instance of ``ArrayField`` as the ``base_field``. For example:: @@ -210,6 +211,51 @@ transform do not change. For example: These indexes use 0-based indexing. +``EmbeddedModelField`` +---------------------- + +.. class:: EmbeddedModelField(embedded_model, **kwargs) + +Stores a model of type ``embedded_model``. + + .. attribute:: embedded_model + + This is a required argument. + + Specifies the model class to embed. It can be either a concrete model + class or a :ref:`lazy reference ` to a model class. + + The embedded model cannot have relational fields + (:class:`~django.db.models.ForeignKey`, + :class:`~django.db.models.OneToOneField` and + :class:`~django.db.models.ManyToManyField`). + + It is possible to nest embedded models. For example:: + + from django.db import models + from django_mongodb_backend.fields import EmbeddedModelField + + class Address(models.Model): + ... + + class Author(models.Model): + address = EmbeddedModelField(Address) + + class Book(models.Model): + author = EmbeddedModelField(Author) + +See :doc:`embedded-models` for more details and examples. + +.. admonition:: Migrations support is limited + + :djadmin:`makemigrations` does not yet detect changes to embedded models. + + After you create a model with an ``EmbeddedModelField`` or add an + ``EmbeddedModelField`` to an existing model, no further updates to the + embedded model will be made. Using the models above as an example, if you + created these models and then added an indexed field to ``Address``, + the index created in the nested ``Book`` embed is not created. + ``ObjectIdField`` ----------------- diff --git a/docs/source/forms.rst b/docs/source/forms.rst index eb020073..64c42755 100644 --- a/docs/source/forms.rst +++ b/docs/source/forms.rst @@ -5,6 +5,24 @@ Forms API reference Some MongoDB-specific fields are available in ``django_mongodb_backend.forms``. +``EmbeddedModelField`` +---------------------- + +.. class:: EmbeddedModelField(model, prefix, **kwargs) + + A field which maps to a model. The field will render as a + :class:`~django.forms.ModelForm`. + + .. attribute:: model + + This is a required argument that specifies the model class. + + .. attribute:: prefix + + This is a required argument that specifies the prefix that all fields + in this field's subform will have so that the names don't collide with + fields in the main form. + ``ObjectIdField`` ----------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 89a1ab23..4e331bd6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,6 +8,7 @@ django-mongodb-backend 5.0.x documentation fields querysets forms + embedded-models Indices and tables ================== diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 9b2d96ec..9b00665b 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -2,9 +2,10 @@ from django.db import models -from django_mongodb_backend.fields import ArrayField, ObjectIdField +from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField +# ObjectIdField class ObjectIdModel(models.Model): field = ObjectIdField() @@ -17,6 +18,7 @@ class PrimaryKeyObjectIdModel(models.Model): field = ObjectIdField(primary_key=True) +# ArrayField class ArrayFieldSubclass(ArrayField): def __init__(self, *args, **kwargs): super().__init__(models.IntegerField()) @@ -89,3 +91,31 @@ def get_prep_value(self, value): class ArrayEnumModel(models.Model): array_of_enums = ArrayField(EnumField(max_length=20)) + + +# EmbeddedModelField +class Holder(models.Model): + data = EmbeddedModelField("Data", null=True, blank=True) + + +class Data(models.Model): + integer = models.IntegerField(db_column="custom_column") + auto_now = models.DateTimeField(auto_now=True) + auto_now_add = models.DateTimeField(auto_now_add=True) + + +class Address(models.Model): + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField() + address = EmbeddedModelField(Address) + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py new file mode 100644 index 00000000..8b0b53f6 --- /dev/null +++ b/tests/model_fields_/test_embedded_model.py @@ -0,0 +1,125 @@ +from django.core.exceptions import ValidationError +from django.db import models +from django.test import SimpleTestCase, TestCase +from django.test.utils import isolate_apps + +from django_mongodb_backend.fields import EmbeddedModelField + +from .models import ( + Address, + Author, + Book, + Data, + Holder, +) + + +class MethodTests(SimpleTestCase): + def test_deconstruct(self): + field = EmbeddedModelField("Data", null=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb_backend.fields.EmbeddedModelField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"embedded_model": "Data", "null": True}) + + def test_get_db_prep_save_invalid(self): + msg = "Expected instance of type , " "not ." + with self.assertRaisesMessage(TypeError, msg): + Holder(data=42).save() + + def test_validate(self): + obj = Holder(data=Data(integer=None)) + # This isn't quite right because "integer" is the subfield of data + # that's non-null. + msg = "{'data': ['This field cannot be null.']}" + with self.assertRaisesMessage(ValidationError, msg): + obj.full_clean() + + +class ModelTests(TestCase): + def truncate_ms(self, value): + """Truncate microseconds to milliseconds as supported by MongoDB.""" + return value.replace(microsecond=(value.microsecond // 1000) * 1000) + + def test_save_load(self): + Holder.objects.create(data=Data(integer="5")) + obj = Holder.objects.get() + self.assertIsInstance(obj.data, Data) + # get_prep_value() is called, transforming string to int. + self.assertEqual(obj.data.integer, 5) + # Primary keys should not be populated... + self.assertEqual(obj.data.id, None) + # ... unless set explicitly. + obj.data.id = obj.id + obj.save() + obj = Holder.objects.get() + self.assertEqual(obj.data.id, obj.id) + + def test_save_load_null(self): + Holder.objects.create(data=None) + obj = Holder.objects.get() + self.assertIsNone(obj.data) + + def test_pre_save(self): + """Field.pre_save() is called on embedded model fields.""" + obj = Holder.objects.create(data=Data()) + auto_now = self.truncate_ms(obj.data.auto_now) + auto_now_add = self.truncate_ms(obj.data.auto_now_add) + self.assertEqual(auto_now, auto_now_add) + # save() updates auto_now but not auto_now_add. + obj.save() + self.assertEqual(self.truncate_ms(obj.data.auto_now_add), auto_now_add) + auto_now_two = obj.data.auto_now + self.assertGreater(auto_now_two, obj.data.auto_now_add) + # And again, save() updates auto_now but not auto_now_add. + obj = Holder.objects.get() + obj.save() + self.assertEqual(obj.data.auto_now_add, auto_now_add) + self.assertGreater(obj.data.auto_now, auto_now_two) + + +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.objs = [Holder.objects.create(data=Data(integer=x)) for x in range(6)] + + def test_exact(self): + self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + + def test_lt(self): + self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + + def test_lte(self): + self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + + def test_gt(self): + self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + + def test_gte(self): + self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + + def test_nested(self): + obj = Book.objects.create( + author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) + ) + self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + + +@isolate_apps("model_fields_") +class CheckTests(SimpleTestCase): + def test_no_relational_fields(self): + class Target(models.Model): + key = models.ForeignKey("MyModel", models.CASCADE) + + class MyModel(models.Model): + field = EmbeddedModelField(Target) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + # The inner CharField has a non-positive max_length. + self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001") + msg = errors[0].msg + self.assertEqual( + msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)." + ) diff --git a/tests/model_forms_/__init__.py b/tests/model_forms_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_forms_/forms.py b/tests/model_forms_/forms.py new file mode 100644 index 00000000..7bfed3fb --- /dev/null +++ b/tests/model_forms_/forms.py @@ -0,0 +1,9 @@ +from django import forms + +from .models import Author + + +class AuthorForm(forms.ModelForm): + class Meta: + fields = "__all__" + model = Author diff --git a/tests/model_forms_/models.py b/tests/model_forms_/models.py new file mode 100644 index 00000000..d61196ab --- /dev/null +++ b/tests/model_forms_/models.py @@ -0,0 +1,22 @@ +from django.db import models + +from django_mongodb_backend.fields import EmbeddedModelField + + +class Address(models.Model): + po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box") + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField() + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField() + address = EmbeddedModelField(Address) + billing_address = EmbeddedModelField(Address, blank=True, null=True) + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) diff --git a/tests/model_forms_/test_embedded_model.py b/tests/model_forms_/test_embedded_model.py new file mode 100644 index 00000000..240f8c6d --- /dev/null +++ b/tests/model_forms_/test_embedded_model.py @@ -0,0 +1,130 @@ +from django.test import TestCase + +from .forms import AuthorForm +from .models import Address, Author + + +class ModelFormTests(TestCase): + def test_update(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertEqual(author.age, 51) + self.assertEqual(author.address.city, "New York City") + + def test_some_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["Enter all required values."]) + + def test_invalid_field_data(self): + """A field's data (state) is too long.""" + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "TOO LONG", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["address"], + [ + "Ensure this value has at most 2 characters (it has 8).", + "Enter all required values.", + ], + ) + + def test_all_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "", + "address-state": "", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["This field is required."]) + + def test_nullable_field(self): + """A nullable EmbeddedModelField is removed if all fields are empty.""" + author = Author.objects.create( + name="Bob", + age=50, + address=Address(city="NYC", state="NY", zip_code="10001"), + billing_address=Address(city="NYC", state="NY", zip_code="10001"), + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + "billing_address-po_box": "", + "billing_address-city": "", + "billing_address-state": "", + "billing_address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertIsNone(author.billing_address) + + def test_rendering(self): + form = AuthorForm() + self.assertHTMLEqual( + str(form.fields["address"].get_bound_field(form, "address")), + """ +
+ + +
+
+ + +
+
+ + +
+
+ + +
""", + ) diff --git a/tests/schema_/__init__.py b/tests/schema_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/schema_/models.py b/tests/schema_/models.py new file mode 100644 index 00000000..7c0f4533 --- /dev/null +++ b/tests/schema_/models.py @@ -0,0 +1,37 @@ +from django.apps.registry import Apps +from django.db import models + +from django_mongodb_backend.fields import EmbeddedModelField + +# These models are inserted into a separate Apps so the test runner doesn't +# migrate them. + +new_apps = Apps() + + +class Address(models.Model): + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) + uid = models.IntegerField(unique=True) + + class Meta: + apps = new_apps + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField(db_index=True) + address = EmbeddedModelField(Address) + employee_id = models.IntegerField(unique=True) + + class Meta: + apps = new_apps + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) + + class Meta: + apps = new_apps diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py new file mode 100644 index 00000000..faf58ce2 --- /dev/null +++ b/tests/schema_/test_embedded_model.py @@ -0,0 +1,629 @@ +import itertools + +from django.db import connection, models +from django.test import TransactionTestCase, ignore_warnings +from django.test.utils import isolate_apps +from django.utils.deprecation import RemovedInDjango51Warning + +from django_mongodb_backend.fields import EmbeddedModelField + +from .models import Address, Author, Book, new_apps + + +class SchemaTests(TransactionTestCase): + available_apps = [] + models = [Address, Author, Book] + + # Utility functions + + def setUp(self): + # local_models should contain test dependent model classes that will be + # automatically removed from the app cache on test tear down. + self.local_models = [] + # isolated_local_models contains models that are in test methods + # decorated with @isolate_apps. + self.isolated_local_models = [] + + def tearDown(self): + # Delete any tables made for our models + self.delete_tables() + new_apps.clear_cache() + for model in new_apps.get_models(): + model._meta._expire_cache() + if "schema" in new_apps.all_models: + for model in self.local_models: + for many_to_many in model._meta.many_to_many: + through = many_to_many.remote_field.through + if through and through._meta.auto_created: + del new_apps.all_models["schema"][through._meta.model_name] + del new_apps.all_models["schema"][model._meta.model_name] + if self.isolated_local_models: + with connection.schema_editor() as editor: + for model in self.isolated_local_models: + editor.delete_model(model) + + def delete_tables(self): + "Deletes all model tables for our models for a clean test environment" + converter = connection.introspection.identifier_converter + with connection.schema_editor() as editor: + connection.disable_constraint_checking() + table_names = connection.introspection.table_names() + if connection.features.ignores_table_name_case: + table_names = [table_name.lower() for table_name in table_names] + for model in itertools.chain(SchemaTests.models, self.local_models): + tbl = converter(model._meta.db_table) + if connection.features.ignores_table_name_case: + tbl = tbl.lower() + if tbl in table_names: + editor.delete_model(model) + table_names.remove(tbl) + connection.enable_constraint_checking() + + def get_constraints(self, table): + """ + Get the constraints on a table using a new cursor. + """ + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def get_constraints_for_columns(self, model, columns): + constraints = self.get_constraints(model._meta.db_table) + constraints_for_column = [] + for name, details in constraints.items(): + if details["columns"] == columns: + constraints_for_column.append(name) + return sorted(constraints_for_column) + + def assertIndexOrder(self, table, index, order): + constraints = self.get_constraints(table) + self.assertIn(index, constraints) + index_orders = constraints[index]["orders"] + self.assertTrue( + all(val == expected for val, expected in zip(index_orders, order, strict=True)) + ) + + def assertTableExists(self, model): + self.assertIn(model._meta.db_table, connection.introspection.table_names()) + + def assertTableNotExists(self, model): + self.assertNotIn(model._meta.db_table, connection.introspection.table_names()) + + # SchemaEditor.create_model() tests + def test_db_index(self): + """Field(db_index=True) on an embedded model.""" + with connection.schema_editor() as editor: + # Create the table + editor.create_model(Book) + # The table is there + self.assertTableExists(Book) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + ["schema__book_author.age_dc08100b"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + ["schema__book_author.address.zip_code_7b9a9307"], + ) + # Clean up that table + editor.delete_model(Book) + # The table is gone + self.assertTableNotExists(Author) + + def test_unique(self): + """Field(unique=True) on an embedded model.""" + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + ["schema__book_author.employee_id_7d4d3eff_uniq"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + ["schema__book_author.address.uid_8124a01f_uniq"], + ) + # Clean up that table + editor.delete_model(Book) + self.assertTableNotExists(Author) + + @ignore_warnings(category=RemovedInDjango51Warning) + @isolate_apps("schema_") + def test_index_together(self): + """Meta.index_together on an embedded model.""" + + class Address(models.Model): + index_together_one = models.CharField(max_length=10) + index_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_one", "index_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + index_together_three = models.CharField(max_length=10) + index_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_three", "index_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + ["schema__add_index_t_efa93e_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + ["schema__aut_index_t_df32aa_idx"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_unique_together(self): + """Meta.unique_together on an embedded model.""" + + class Address(models.Model): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_one", "unique_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_three", "unique_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [ + "schema__author_author.unique_together_three_author.unique_together_four_39e1cb43_uniq" + ], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [ + "schema__address_author.address.unique_together_one_author.address.unique_together_two_de682e30_uniq" + ], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_indexes(self): + """Meta.indexes on an embedded model.""" + + class Address(models.Model): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_one"])] + + class Author(models.Model): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_two"])] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["schema__aut_indexed_b19137_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["schema__add_indexed_b64972_idx"], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + @isolate_apps("schema_") + def test_constraints(self): + """Meta.constraints on an embedded model.""" + + class Address(models.Model): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_one"], name="unique_one") + ] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_two"], name="unique_two") + ] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + # SchemaEditor.add_field() / remove_field() tests + @isolate_apps("schema_") + def test_add_remove_field_db_index_and_unique(self): + """AddField/RemoveField + EmbeddedModelField + Field(db_index=True) & Field(unique=True).""" + + class Book(models.Model): + name = models.CharField(max_length=100) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table amd add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + ["schema__book_author.age_dc08100b"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + ["schema__book_author.address.zip_code_7b9a9307"], + ) + # Embedded uniques + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + ["schema__book_author.employee_id_7d4d3eff_uniq"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + ["schema__book_author.address.uid_8124a01f_uniq"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Author) + + @ignore_warnings(category=RemovedInDjango51Warning) + @isolate_apps("schema_") + def test_add_remove_field_index_together(self): + """AddField/RemoveField + EmbeddedModelField + Meta.index_together.""" + + class Address(models.Model): + index_together_one = models.CharField(max_length=10) + index_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_one", "index_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + index_together_three = models.CharField(max_length=10) + index_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_three", "index_together_four")] + + class Book(models.Model): + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table amd add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded index_togethers are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + ["schema__add_index_t_efa93e_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + ["schema__aut_index_t_df32aa_idx"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_unique_together(self): + """AddField/RemoveField + EmbeddedModelField + Meta.unique_together.""" + + class Address(models.Model): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_one", "unique_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_three", "unique_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [ + "schema__author_author.unique_together_three_author.unique_together_four_39e1cb43_uniq" + ], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [ + "schema__address_author.address.unique_together_one_author.address.unique_together_two_de682e30_uniq" + ], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_indexes(self): + """AddField/RemoveField + EmbeddedModelField + Meta.indexes.""" + + class Address(models.Model): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_one"])] + + class Author(models.Model): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_two"])] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["schema__aut_indexed_b19137_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["schema__add_indexed_b64972_idx"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + [], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + @isolate_apps("schema_") + def test_add_remove_field_constraints(self): + """AddField/RemoveField + EmbeddedModelField + Meta.constraints.""" + + class Address(models.Model): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_one"], name="unique_one") + ] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_two"], name="unique_two") + ] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded constraints are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.remove_field(Book, new_field) + # Embedded constraints are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + [], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author)