Skip to content

Commit

Permalink
add EmbeddedModelField
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 14, 2025
1 parent eae01a4 commit f77e03c
Show file tree
Hide file tree
Showing 23 changed files with 1,492 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ repos:
hooks:
- id: rstcheck
additional_dependencies: [sphinx]
args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"]
args: ["--ignore-directives=django-admin,fieldlookup,setting", "--ignore-roles=djadmin,lookup,setting"]

# We use the Python version instead of the original version which seems to require Docker
# https://github.com/koalaman/shellcheck-precommit
Expand Down
2 changes: 1 addition & 1 deletion THIRD-PARTY-NOTICES
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ be distributed under licenses different than this software.

The attached notices are provided for information only.

django-mongodb-backend began by borrowing code from Django non-rel's
django-mongodb-backend and EmbeddedModelField began by borrowing code from
django-mongodb-engine (https://github.com/django-nonrel/mongodb-engine),
abandoned since 2015 and Django 1.6.

Expand Down
2 changes: 1 addition & 1 deletion django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def execute_sql(self, result_type):
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
elif not hasattr(field, "embedded_model"):
raise TypeError(
f"Tried to update field {field} with a model "
f"instance, {value!r}. Use a value compatible with "
Expand Down
9 changes: 8 additions & 1 deletion django_mongodb_backend/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from .array import ArrayField
from .auto import ObjectIdAutoField
from .duration import register_duration_field
from .embedded_model import EmbeddedModelField
from .json import register_json_field
from .objectid import ObjectIdField

__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
__all__ = [
"register_fields",
"ArrayField",
"EmbeddedModelField",
"ObjectIdAutoField",
"ObjectIdField",
]


def register_fields():
Expand Down
161 changes: 161 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from django.core import checks
from django.db import models
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform

from .. import forms


class EmbeddedModelField(models.Field):
"""Field that stores a model instance."""

def __init__(self, embedded_model, *args, **kwargs):
"""
`embedded_model` is the model class of the instance to be stored.
Like other relational fields, it may also be passed as a string.
"""
self.embedded_model = embedded_model
super().__init__(*args, **kwargs)

def check(self, **kwargs):
errors = super().check(**kwargs)
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
checks.Error(
"Embedded models cannot have relational fields "
f"({self.embedded_model().__class__.__name__}.{field.name} "
f"is a {field.__class__.__name__}).",
obj=self,
id="django_mongodb.embedded_model.E001",
)
)
return errors

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if path.startswith("django_mongodb_backend.fields.embedded_model"):
path = path.replace(
"django_mongodb_backend.fields.embedded_model", "django_mongodb_backend.fields"
)
kwargs["embedded_model"] = self.embedded_model
return name, path, args, kwargs

def get_internal_type(self):
return "EmbeddedModelField"

def _set_model(self, model):
"""
Resolve embedded model class once the field knows the model it belongs
to. If __init__()'s embedded_model argument is a string, resolve it to
the actual model class, similar to relation fields.
"""
self._model = model
if model is not None and isinstance(self.embedded_model, str):

def _resolve_lookup(_, resolved_model):
self.embedded_model = resolved_model

lazy_related_operation(_resolve_lookup, model, self.embedded_model)

model = property(lambda self: self._model, _set_model)

def from_db_value(self, value, expression, connection):
return self.to_python(value)

def to_python(self, value):
"""
Pass embedded model fields' values through each field's to_python() and
reinstantiate the embedded instance.
"""
if value is None:
return None
if not isinstance(value, dict):
return value
instance = self.embedded_model(
**{
field.attname: field.to_python(value[field.attname])
for field in self.embedded_model._meta.fields
if field.attname in value
}
)
instance._state.adding = False
return instance

def get_db_prep_save(self, embedded_instance, connection):
"""
Apply pre_save() and get_db_prep_save() of embedded instance fields and
create the {field: value} dict to be saved.
"""
if embedded_instance is None:
return None
if not isinstance(embedded_instance, self.embedded_model):
raise TypeError(
f"Expected instance of type {self.embedded_model!r}, not "
f"{type(embedded_instance)!r}."
)
field_values = {}
add = embedded_instance._state.adding
for field in embedded_instance._meta.fields:
value = field.get_db_prep_save(
field.pre_save(embedded_instance, add), connection=connection
)
# Exclude unset primary keys (e.g. {'id': None}).
if field.primary_key and value is None:
continue
field_values[field.attname] = value
# This instance will exist in the database soon.
embedded_instance._state.adding = False
return field_values

def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)

def validate(self, value, model_instance):
super().validate(value, model_instance)
if self.embedded_model is None:
return
for field in self.embedded_model._meta.fields:
attname = field.attname
field.validate(getattr(value, attname), model_instance)

def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.EmbeddedModelField,
"model": self.embedded_model,
"prefix": self.name,
**kwargs,
}
)


class KeyTransform(Transform):
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)

def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
mql = previous.as_mql(compiler, connection)
return mql, key_transforms

def as_mql(self, compiler, connection):
mql, key_transforms = self.preprocess_lhs(compiler, connection)
transforms = ".".join(key_transforms)
return f"{mql}.{transforms}"


class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)
9 changes: 8 additions & 1 deletion django_mongodb_backend/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget
from .fields import (
EmbeddedModelField,
ObjectIdField,
SimpleArrayField,
SplitArrayField,
SplitArrayWidget,
)

__all__ = [
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
"SplitArrayWidget",
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb_backend/forms/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget
from .embedded_model import EmbeddedModelField
from .objectid import ObjectIdField

__all__ = [
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
"SplitArrayWidget",
Expand Down
62 changes: 62 additions & 0 deletions django_mongodb_backend/forms/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from django import forms
from django.forms.models import modelform_factory
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _


class EmbeddedModelWidget(forms.MultiWidget):
def __init__(self, field_names, *args, **kwargs):
self.field_names = field_names
super().__init__(*args, **kwargs)
# The default widget names are "_0", "_1", etc. Use the field names
# instead since that's how they'll be rendered by the model form.
self.widgets_names = ["-" + name for name in field_names]

def decompress(self, value):
if value is None:
return []
# Get the data from `value` (a model) for each field.
return [getattr(value, name) for name in self.field_names]


class EmbeddedModelBoundField(forms.BoundField):
def __str__(self):
"""Render the model form as the representation for this field."""
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs)
return mark_safe(f"{form.as_div()}") # noqa: S308


class EmbeddedModelField(forms.MultiValueField):
default_error_messages = {
"invalid": _("Enter a list of values."),
"incomplete": _("Enter all required values."),
}

def __init__(self, model, prefix, *args, **kwargs):
form_kwargs = {}
# To avoid collisions with other fields on the form, each subfield must
# be prefixed with the name of the field.
form_kwargs["prefix"] = prefix
self.form_kwargs = form_kwargs
self.model_form_cls = modelform_factory(model, fields="__all__")
self.model_form = self.model_form_cls(**form_kwargs)
self.field_names = list(self.model_form.fields.keys())
fields = self.model_form.fields.values()
widgets = [field.widget for field in fields]
widget = EmbeddedModelWidget(self.field_names, widgets)
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs)

def compress(self, data_dict):
if not data_dict:
return None
values = dict(zip(self.field_names, data_dict, strict=False))
return self.model_form._meta.model(**values)

def get_bound_field(self, form, field_name):
return EmbeddedModelBoundField(form, self, field_name)

def bound_data(self, data, initial):
if self.disabled:
return initial
# Transform the bound data into a model instance.
return self.compress(data)
Loading

0 comments on commit f77e03c

Please sign in to comment.