From 452fd0b3767b0d5a605fa2ba7c1a456c796b3343 Mon Sep 17 00:00:00 2001
From: Tim Graham
Date: Sat, 16 Nov 2024 20:30:33 -0500
Subject: [PATCH] wip schema changes

---
 django_mongodb/schema.py             |  80 +++++++---
 tests/model_fields_/models.py        |   1 +
 tests/schema_/__init__.py            |   0
 tests/schema_/models.py              |  44 ++++++
 tests/schema_/test_embedded_model.py | 226 +++++++++++++++++++++++++++
 5 files changed, 331 insertions(+), 20 deletions(-)
 create mode 100644 tests/schema_/__init__.py
 create mode 100644 tests/schema_/models.py
 create mode 100644 tests/schema_/test_embedded_model.py

diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py
index 8fc18fea..05b2c756 100644
--- a/django_mongodb/schema.py
+++ b/django_mongodb/schema.py
@@ -5,6 +5,7 @@
 from pymongo import ASCENDING, DESCENDING
 from pymongo.operations import IndexModel
 
+from .fields import EmbeddedModelField
 from .query import wrap_database_errors
 from .utils import OperationCollector
 
@@ -29,25 +30,40 @@ def create_model(self, model):
             if field.remote_field.through._meta.auto_created:
                 self.create_model(field.remote_field.through)
 
-    def _create_model_indexes(self, model):
+    def _create_model_indexes(self, model, column_prefix="", parent_model=None):
         """
         Create all indexes (field indexes & uniques, Meta.index_together,
         Meta.unique_together, Meta.constraints, Meta.indexes) for the model.
+
+        If this is a recursive call due to an embedded model, `column_prefix`
+        tracks the path that must be prepended to the index's column, and
+        `parent_model` tracks the collection to add the index/constraint to.
         """
         if not model._meta.managed or model._meta.proxy or model._meta.swapped:
             return
         # Field indexes and uniques
         for field in model._meta.local_fields:
+            if isinstance(field, EmbeddedModelField):
+                new_path = f"{column_prefix}{field.column}."
+ self._create_model_indexes( + field.embedded_model, parent_model=parent_model or model, column_prefix=new_path + ) if self._field_should_be_indexed(model, field): - self._add_field_index(model, field) + self._add_field_index(parent_model or model, field, column_prefix=column_prefix) elif self._field_should_have_unique(field): - self._add_field_unique(model, field) + self._add_field_unique(parent_model or model, field, column_prefix=column_prefix) # Meta.index_together (RemovedInDjango51Warning) for field_names in model._meta.index_together: self._add_composed_index(model, field_names) # Meta.unique_together if model._meta.unique_together: - self.alter_unique_together(model, [], model._meta.unique_together) + self.alter_unique_together( + model, + [], + model._meta.unique_together, + column_prefix=column_prefix, + parent_model=parent_model, + ) # Meta.constraints for constraint in model._meta.constraints: self.add_constraint(model, constraint) @@ -147,7 +163,9 @@ def alter_index_together(self, model, old_index_together, new_index_together): for field_names in news.difference(olds): self._add_composed_index(model, field_names) - def alter_unique_together(self, model, old_unique_together, new_unique_together): + def alter_unique_together( + self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None + ): olds = {tuple(fields) for fields in old_unique_together} news = {tuple(fields) for fields in new_unique_together} # Deleted uniques @@ -160,11 +178,19 @@ def alter_unique_together(self, model, old_unique_together, new_unique_together) # Created uniques for field_names in news.difference(olds): columns = [model._meta.get_field(field).column for field in field_names] - name = str(self._unique_constraint_name(model._meta.db_table, columns)) + name = str( + self._unique_constraint_name( + model._meta.db_table, [column_prefix + col for col in columns] + ) + ) constraint = UniqueConstraint(fields=field_names, name=name) - 
self.add_constraint(model, constraint) + self.add_constraint( + model, constraint, parent_model=parent_model, column_prefix=column_prefix + ) - def add_index(self, model, index, field=None, unique=False): + def add_index( + self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None + ): if index.contains_expressions: return kwargs = {} @@ -176,7 +202,8 @@ def add_index(self, model, index, field=None, unique=False): # Indexing on $type matches the value of most SQL databases by # allowing multiple null values for the unique constraint. if field: - filter_expression[field.column].update({"$type": field.db_type(self.connection)}) + column = column_prefix + field.column + filter_expression[column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) @@ -186,16 +213,20 @@ def add_index(self, model, index, field=None, unique=False): if filter_expression: kwargs["partialFilterExpression"] = filter_expression index_orders = ( - [(field.column, ASCENDING)] + [(column_prefix + field.column, ASCENDING)] if field else [ # order is "" if ASCENDING or "DESC" if DESCENDING (see # django.db.models.indexes.Index.fields_orders). 
- (model._meta.get_field(field_name).column, ASCENDING if order == "" else DESCENDING) + ( + column_prefix + model._meta.get_field(field_name).column, + ASCENDING if order == "" else DESCENDING, + ) for field_name, order in index.fields_orders ] ) idx = IndexModel(index_orders, name=index.name, **kwargs) + model = parent_model or model self.get_collection(model._meta.db_table).create_indexes([idx]) def _add_composed_index(self, model, field_names): @@ -204,11 +235,11 @@ def _add_composed_index(self, model, field_names): idx.set_name_with_model(model) self.add_index(model, idx) - def _add_field_index(self, model, field): + def _add_field_index(self, model, field, *, column_prefix=""): """Add an index on a field with db_index=True.""" - index = Index(fields=[field.name]) - index.name = self._create_index_name(model._meta.db_table, [field.column]) - self.add_index(model, index, field=field) + index = Index(fields=[column_prefix + field.name]) + index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column]) + self.add_index(model, index, field=field, column_prefix=column_prefix) def remove_index(self, model, index): if index.contains_expressions: @@ -260,7 +291,7 @@ def _remove_field_index(self, model, field): ) collection.drop_index(index_names[0]) - def add_constraint(self, model, constraint, field=None): + def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( condition=constraint.condition, deferrable=constraint.deferrable, @@ -273,12 +304,21 @@ def add_constraint(self, model, constraint, field=None): name=constraint.name, condition=constraint.condition, ) - self.add_index(model, idx, field=field, unique=True) + self.add_index( + model, + idx, + field=field, + unique=True, + column_prefix=column_prefix, + parent_model=parent_model, + ) - def _add_field_unique(self, model, field): - name = 
str(self._unique_constraint_name(model._meta.db_table, [field.column])) + def _add_field_unique(self, model, field, column_prefix=""): + name = str( + self._unique_constraint_name(model._meta.db_table, [column_prefix + field.column]) + ) constraint = UniqueConstraint(fields=[field.name], name=name) - self.add_constraint(model, constraint, field=field) + self.add_constraint(model, constraint, field=field, column_prefix=column_prefix) def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and self._unique_supported( diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 94da7bd8..c1a2845e 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -48,6 +48,7 @@ class EmbeddedModel(models.Model): class Address(models.Model): city = models.CharField(max_length=20) state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) class Author(models.Model): diff --git a/tests/schema_/__init__.py b/tests/schema_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/schema_/models.py b/tests/schema_/models.py new file mode 100644 index 00000000..b16c2629 --- /dev/null +++ b/tests/schema_/models.py @@ -0,0 +1,44 @@ +from django.apps.registry import Apps +from django.db import models + +from django_mongodb.fields import EmbeddedModelField + +# Because we want to test creation and deletion of these as separate things, +# these models are all inserted into a separate Apps so the main test +# runner doesn't migrate them. 
+ +new_apps = Apps() + + +class Address(models.Model): + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) + uid = models.IntegerField(unique=True) + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + apps = new_apps + unique_together = [("unique_together_one", "unique_together_two")] + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField(db_index=True) + address = EmbeddedModelField(Address) + employee_id = models.IntegerField(unique=True) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + apps = new_apps + unique_together = [("unique_together_three", "unique_together_four")] + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) + + class Meta: + apps = new_apps diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py new file mode 100644 index 00000000..d30f1294 --- /dev/null +++ b/tests/schema_/test_embedded_model.py @@ -0,0 +1,226 @@ +import itertools + +from django.db import connection +from django.test import TransactionTestCase + +from .models import Address, Author, Book, new_apps + + +class SchemaTests(TransactionTestCase): + available_apps = [] + models = [Address, Author, Book] + + # Utility functions + + def setUp(self): + # local_models should contain test dependent model classes that will be + # automatically removed from the app cache on test tear down. + self.local_models = [] + # isolated_local_models contains models that are in test methods + # decorated with @isolate_apps. 
+ self.isolated_local_models = [] + + def tearDown(self): + # Delete any tables made for our models + self.delete_tables() + new_apps.clear_cache() + for model in new_apps.get_models(): + model._meta._expire_cache() + if "schema" in new_apps.all_models: + for model in self.local_models: + for many_to_many in model._meta.many_to_many: + through = many_to_many.remote_field.through + if through and through._meta.auto_created: + del new_apps.all_models["schema"][through._meta.model_name] + del new_apps.all_models["schema"][model._meta.model_name] + if self.isolated_local_models: + with connection.schema_editor() as editor: + for model in self.isolated_local_models: + editor.delete_model(model) + + def delete_tables(self): + "Deletes all model tables for our models for a clean test environment" + converter = connection.introspection.identifier_converter + with connection.schema_editor() as editor: + connection.disable_constraint_checking() + table_names = connection.introspection.table_names() + if connection.features.ignores_table_name_case: + table_names = [table_name.lower() for table_name in table_names] + for model in itertools.chain(SchemaTests.models, self.local_models): + tbl = converter(model._meta.db_table) + if connection.features.ignores_table_name_case: + tbl = tbl.lower() + if tbl in table_names: + editor.delete_model(model) + table_names.remove(tbl) + connection.enable_constraint_checking() + + def get_indexes(self, table): + """ + Get the indexes on the table using a new cursor. 
+ """ + with connection.cursor() as cursor: + return [ + c["columns"][0] + for c in connection.introspection.get_constraints(cursor, table).values() + if c["index"] and len(c["columns"]) == 1 + ] + + def get_uniques(self, table): + with connection.cursor() as cursor: + return [ + c["columns"][0] + for c in connection.introspection.get_constraints(cursor, table).values() + if c["unique"] and len(c["columns"]) == 1 + ] + + def get_constraints(self, table): + """ + Get the constraints on a table using a new cursor. + """ + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def get_constraints_for_columns(self, model, columns): + constraints = self.get_constraints(model._meta.db_table) + constraints_for_column = [] + for name, details in constraints.items(): + if details["columns"] == columns: + constraints_for_column.append(name) + return sorted(constraints_for_column) + + def check_added_field_default( + self, + schema_editor, + model, + field, + field_name, + expected_default, + cast_function=None, + ): + schema_editor.add_field(model, field) + database_default = connection.database[model._meta.db_table].find_one().get(field_name) + if cast_function and type(database_default) is not type(expected_default): + database_default = cast_function(database_default) + self.assertEqual(database_default, expected_default) + + def get_constraints_count(self, table, column, fk_to): + """ + Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the + number of foreign keys, unique constraints, and indexes on + `table`.`column`. The `fk_to` argument is a 2-tuple specifying the + expected foreign key relationship's (table, column). 
+ """ + with connection.cursor() as cursor: + constraints = connection.introspection.get_constraints(cursor, table) + counts = {"fks": 0, "uniques": 0, "indexes": 0} + for c in constraints.values(): + if c["columns"] == [column]: + if c["foreign_key"] == fk_to: + counts["fks"] += 1 + if c["unique"]: + counts["uniques"] += 1 + elif c["index"]: + counts["indexes"] += 1 + return counts + + def assertIndexOrder(self, table, index, order): + constraints = self.get_constraints(table) + self.assertIn(index, constraints) + index_orders = constraints[index]["orders"] + self.assertTrue( + all(val == expected for val, expected in zip(index_orders, order, strict=True)) + ) + + def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"): + """ + Fail if the FK constraint on `model.Meta.db_table`.`column` to + `expected_fk_table`.id doesn't exist. + """ + if not connection.features.can_introspect_foreign_keys: + return + constraints = self.get_constraints(model._meta.db_table) + constraint_fk = None + for details in constraints.values(): + if details["columns"] == [column] and details["foreign_key"]: + constraint_fk = details["foreign_key"] + break + self.assertEqual(constraint_fk, (expected_fk_table, field)) + + def assertForeignKeyNotExists(self, model, column, expected_fk_table): + if not connection.features.can_introspect_foreign_keys: + return + with self.assertRaises(AssertionError): + self.assertForeignKeyExists(model, column, expected_fk_table) + + def assertTableExists(self, model): + self.assertIn(model._meta.db_table, connection.introspection.table_names()) + + def assertTableNotExists(self, model): + self.assertNotIn(model._meta.db_table, connection.introspection.table_names()) + + # Tests + def test_db_index(self): + """Field(db_index=True) on an embedded model.""" + with connection.schema_editor() as editor: + # Create the table + editor.create_model(Book) + # The table is there + self.assertTableExists(Book) + # Embedded indexes are created. 
+            self.assertEqual(
+                self.get_constraints_for_columns(Book, ["author.age"]),
+                ["schema__book_author.age_dc08100b"],
+            )
+            self.assertEqual(
+                self.get_constraints_for_columns(Book, ["author.address.zip_code"]),
+                ["schema__book_author.address.zip_code_7b9a9307"],
+            )
+            # Clean up that table
+            editor.delete_model(Book)
+            # The table is gone
+            self.assertTableNotExists(Book)
+
+    def test_unique(self):
+        """Field(unique=True) on an embedded model."""
+        with connection.schema_editor() as editor:
+            editor.create_model(Book)
+            self.assertTableExists(Book)
+            # Embedded uniques are created.
+            self.assertEqual(
+                self.get_constraints_for_columns(Book, ["author.employee_id"]),
+                ["schema__book_author.employee_id_7d4d3eff_uniq"],
+            )
+            self.assertEqual(
+                self.get_constraints_for_columns(Book, ["author.address.uid"]),
+                ["schema__book_author.address.uid_8124a01f_uniq"],
+            )
+            # Clean up that table
+            editor.delete_model(Book)
+            self.assertTableNotExists(Book)
+
+    def test_unique_together(self):
+        """Meta.unique_together on an embedded model."""
+        with connection.schema_editor() as editor:
+            editor.create_model(Book)
+            self.assertTableExists(Book)
+            # Embedded uniques are created.
+            self.assertEqual(
+                self.get_constraints_for_columns(
+                    Book, ["author.unique_together_three", "author.unique_together_four"]
+                ),
+                [
+                    "schema__author_author.unique_together_three_author.unique_together_four_39e1cb43_uniq"
+                ],
+            )
+            self.assertEqual(
+                self.get_constraints_for_columns(
+                    Book,
+                    ["author.address.unique_together_one", "author.address.unique_together_two"],
+                ),
+                [
+                    "schema__address_author.address.unique_together_one_author.address.unique_together_two_de682e30_uniq"
+                ],
+            )
+            editor.delete_model(Book)
+            self.assertTableNotExists(Book)