Skip to content

Commit

Permalink
wip schema changes
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Dec 30, 2024
1 parent 2e7e277 commit 452fd0b
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 20 deletions.
80 changes: 60 additions & 20 deletions django_mongodb/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pymongo import ASCENDING, DESCENDING
from pymongo.operations import IndexModel

from .fields import EmbeddedModelField
from .query import wrap_database_errors
from .utils import OperationCollector

Expand All @@ -29,25 +30,40 @@ def create_model(self, model):
if field.remote_field.through._meta.auto_created:
self.create_model(field.remote_field.through)

def _create_model_indexes(self, model):
def _create_model_indexes(self, model, column_prefix="", parent_model=None):
"""
Create all indexes (field indexes & uniques, Meta.index_together,
Meta.unique_together, Meta.constraints, Meta.indexes) for the model.
If this is a recursive call to due to an embedded model, `column_prefix`
tracks the path that must be prepended to the index's column, and
`parent_model` tracks the collection to add the index/constraint to.
"""
if not model._meta.managed or model._meta.proxy or model._meta.swapped:
return
# Field indexes and uniques
for field in model._meta.local_fields:
if isinstance(field, EmbeddedModelField):
new_path = f"{column_prefix}{field.column}."
self._create_model_indexes(
field.embedded_model, parent_model=parent_model or model, column_prefix=new_path
)
if self._field_should_be_indexed(model, field):
self._add_field_index(model, field)
self._add_field_index(parent_model or model, field, column_prefix=column_prefix)
elif self._field_should_have_unique(field):
self._add_field_unique(model, field)
self._add_field_unique(parent_model or model, field, column_prefix=column_prefix)
# Meta.index_together (RemovedInDjango51Warning)
for field_names in model._meta.index_together:
self._add_composed_index(model, field_names)
# Meta.unique_together
if model._meta.unique_together:
self.alter_unique_together(model, [], model._meta.unique_together)
self.alter_unique_together(
model,
[],
model._meta.unique_together,
column_prefix=column_prefix,
parent_model=parent_model,
)
# Meta.constraints
for constraint in model._meta.constraints:
self.add_constraint(model, constraint)
Expand Down Expand Up @@ -147,7 +163,9 @@ def alter_index_together(self, model, old_index_together, new_index_together):
for field_names in news.difference(olds):
self._add_composed_index(model, field_names)

def alter_unique_together(self, model, old_unique_together, new_unique_together):
def alter_unique_together(
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
):
olds = {tuple(fields) for fields in old_unique_together}
news = {tuple(fields) for fields in new_unique_together}
# Deleted uniques
Expand All @@ -160,11 +178,19 @@ def alter_unique_together(self, model, old_unique_together, new_unique_together)
# Created uniques
for field_names in news.difference(olds):
columns = [model._meta.get_field(field).column for field in field_names]
name = str(self._unique_constraint_name(model._meta.db_table, columns))
name = str(
self._unique_constraint_name(
model._meta.db_table, [column_prefix + col for col in columns]
)
)
constraint = UniqueConstraint(fields=field_names, name=name)
self.add_constraint(model, constraint)
self.add_constraint(
model, constraint, parent_model=parent_model, column_prefix=column_prefix
)

def add_index(self, model, index, field=None, unique=False):
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
if index.contains_expressions:
return
kwargs = {}
Expand All @@ -176,7 +202,8 @@ def add_index(self, model, index, field=None, unique=False):
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
column = column_prefix + field.column
filter_expression[column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
Expand All @@ -186,16 +213,20 @@ def add_index(self, model, index, field=None, unique=False):
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(field.column, ASCENDING)]
[(column_prefix + field.column, ASCENDING)]
if field
else [
# order is "" if ASCENDING or "DESC" if DESCENDING (see
# django.db.models.indexes.Index.fields_orders).
(model._meta.get_field(field_name).column, ASCENDING if order == "" else DESCENDING)
(
column_prefix + model._meta.get_field(field_name).column,
ASCENDING if order == "" else DESCENDING,
)
for field_name, order in index.fields_orders
]
)
idx = IndexModel(index_orders, name=index.name, **kwargs)
model = parent_model or model
self.get_collection(model._meta.db_table).create_indexes([idx])

def _add_composed_index(self, model, field_names):
Expand All @@ -204,11 +235,11 @@ def _add_composed_index(self, model, field_names):
idx.set_name_with_model(model)
self.add_index(model, idx)

def _add_field_index(self, model, field):
def _add_field_index(self, model, field, *, column_prefix=""):
"""Add an index on a field with db_index=True."""
index = Index(fields=[field.name])
index.name = self._create_index_name(model._meta.db_table, [field.column])
self.add_index(model, index, field=field)
index = Index(fields=[column_prefix + field.name])
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
self.add_index(model, index, field=field, column_prefix=column_prefix)

def remove_index(self, model, index):
if index.contains_expressions:
Expand Down Expand Up @@ -260,7 +291,7 @@ def _remove_field_index(self, model, field):
)
collection.drop_index(index_names[0])

def add_constraint(self, model, constraint, field=None):
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
deferrable=constraint.deferrable,
Expand All @@ -273,12 +304,21 @@ def add_constraint(self, model, constraint, field=None):
name=constraint.name,
condition=constraint.condition,
)
self.add_index(model, idx, field=field, unique=True)
self.add_index(
model,
idx,
field=field,
unique=True,
column_prefix=column_prefix,
parent_model=parent_model,
)

def _add_field_unique(self, model, field):
name = str(self._unique_constraint_name(model._meta.db_table, [field.column]))
def _add_field_unique(self, model, field, column_prefix=""):
name = str(
self._unique_constraint_name(model._meta.db_table, [column_prefix + field.column])
)
constraint = UniqueConstraint(fields=[field.name], name=name)
self.add_constraint(model, constraint, field=field)
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)

def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
Expand Down
1 change: 1 addition & 0 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class EmbeddedModel(models.Model):
class Address(models.Model):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)


class Author(models.Model):
Expand Down
Empty file added tests/schema_/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/schema_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from django.apps.registry import Apps
from django.db import models

from django_mongodb.fields import EmbeddedModelField

# Because we want to test creation and deletion of these as separate things,
# these models are all inserted into a separate Apps so the main test
# runner doesn't migrate them.

new_apps = Apps()


class Address(models.Model):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)
uid = models.IntegerField(unique=True)
unique_together_one = models.CharField(max_length=10)
unique_together_two = models.CharField(max_length=10)

class Meta:
apps = new_apps
unique_together = [("unique_together_one", "unique_together_two")]


class Author(models.Model):
name = models.CharField(max_length=10)
age = models.IntegerField(db_index=True)
address = EmbeddedModelField(Address)
employee_id = models.IntegerField(unique=True)
unique_together_three = models.CharField(max_length=10)
unique_together_four = models.CharField(max_length=10)

class Meta:
apps = new_apps
unique_together = [("unique_together_three", "unique_together_four")]


class Book(models.Model):
name = models.CharField(max_length=100)
author = EmbeddedModelField(Author)

class Meta:
apps = new_apps
Loading

0 comments on commit 452fd0b

Please sign in to comment.