Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add EmbeddedModelField #151

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ repos:
hooks:
- id: rstcheck
additional_dependencies: [sphinx]
args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"]
args: ["--ignore-directives=django-admin,fieldlookup,setting", "--ignore-roles=djadmin,lookup,setting"]

# We use the Python version instead of the original version which seems to require Docker
# https://github.com/koalaman/shellcheck-precommit
Expand Down
2 changes: 1 addition & 1 deletion THIRD-PARTY-NOTICES
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ be distributed under licenses different than this software.

The attached notices are provided for information only.

django-mongodb-backend began by borrowing code from Django non-rel's
django-mongodb-backend and EmbeddedModelField began by borrowing code from
django-mongodb-engine (https://github.com/django-nonrel/mongodb-engine),
abandoned since 2015 and Django 1.6.

Expand Down
2 changes: 1 addition & 1 deletion django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def execute_sql(self, result_type):
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
elif not hasattr(field, "embedded_model"):
raise TypeError(
f"Tried to update field {field} with a model "
f"instance, {value!r}. Use a value compatible with "
Expand Down
9 changes: 8 additions & 1 deletion django_mongodb_backend/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from .array import ArrayField
from .auto import ObjectIdAutoField
from .duration import register_duration_field
from .embedded_model import EmbeddedModelField
from .json import register_json_field
from .objectid import ObjectIdField

__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
__all__ = [
"register_fields",
"ArrayField",
"EmbeddedModelField",
"ObjectIdAutoField",
"ObjectIdField",
]


def register_fields():
Expand Down
161 changes: 161 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from django.core import checks
from django.db import models
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform

from .. import forms


class EmbeddedModelField(models.Field):
"""Field that stores a model instance."""

def __init__(self, embedded_model, *args, **kwargs):
"""
`embedded_model` is the model class of the instance to be stored.
Like other relational fields, it may also be passed as a string.
"""
self.embedded_model = embedded_model
super().__init__(*args, **kwargs)

def check(self, **kwargs):
errors = super().check(**kwargs)
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
checks.Error(
"Embedded models cannot have relational fields "
f"({self.embedded_model().__class__.__name__}.{field.name} "
f"is a {field.__class__.__name__}).",
obj=self,
id="django_mongodb.embedded_model.E001",
)
)
return errors

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if path.startswith("django_mongodb_backend.fields.embedded_model"):
path = path.replace(
"django_mongodb_backend.fields.embedded_model", "django_mongodb_backend.fields"
)
kwargs["embedded_model"] = self.embedded_model
return name, path, args, kwargs

def get_internal_type(self):
return "EmbeddedModelField"

def _set_model(self, model):
"""
Resolve embedded model class once the field knows the model it belongs
to. If __init__()'s embedded_model argument is a string, resolve it to
the actual model class, similar to relation fields.
"""
self._model = model
if model is not None and isinstance(self.embedded_model, str):

def _resolve_lookup(_, resolved_model):
self.embedded_model = resolved_model

lazy_related_operation(_resolve_lookup, model, self.embedded_model)

model = property(lambda self: self._model, _set_model)

def from_db_value(self, value, expression, connection):
return self.to_python(value)

def to_python(self, value):
"""
Pass embedded model fields' values through each field's to_python() and
reinstantiate the embedded instance.
"""
if value is None:
return None
if not isinstance(value, dict):
return value
instance = self.embedded_model(
**{
field.attname: field.to_python(value[field.attname])
for field in self.embedded_model._meta.fields
if field.attname in value
}
)
instance._state.adding = False
return instance

def get_db_prep_save(self, embedded_instance, connection):
"""
Apply pre_save() and get_db_prep_save() of embedded instance fields and
create the {field: value} dict to be saved.
"""
if embedded_instance is None:
return None
if not isinstance(embedded_instance, self.embedded_model):
raise TypeError(
f"Expected instance of type {self.embedded_model!r}, not "
f"{type(embedded_instance)!r}."
)
field_values = {}
add = embedded_instance._state.adding
for field in embedded_instance._meta.fields:
value = field.get_db_prep_save(
field.pre_save(embedded_instance, add), connection=connection
)
# Exclude unset primary keys (e.g. {'id': None}).
if field.primary_key and value is None:
continue
field_values[field.attname] = value
# This instance will exist in the database soon.
embedded_instance._state.adding = False
return field_values

def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)

def validate(self, value, model_instance):
super().validate(value, model_instance)
if self.embedded_model is None:
return
for field in self.embedded_model._meta.fields:
attname = field.attname
field.validate(getattr(value, attname), model_instance)

def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.EmbeddedModelField,
"model": self.embedded_model,
"prefix": self.name,
**kwargs,
}
)


class KeyTransform(Transform):
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)

def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
mql = previous.as_mql(compiler, connection)
return mql, key_transforms

def as_mql(self, compiler, connection):
mql, key_transforms = self.preprocess_lhs(compiler, connection)
transforms = ".".join(key_transforms)
return f"{mql}.{transforms}"


class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)
9 changes: 8 additions & 1 deletion django_mongodb_backend/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget
from .fields import (
EmbeddedModelField,
ObjectIdField,
SimpleArrayField,
SplitArrayField,
SplitArrayWidget,
)

__all__ = [
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
"SplitArrayWidget",
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb_backend/forms/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget
from .embedded_model import EmbeddedModelField
from .objectid import ObjectIdField

__all__ = [
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
"SplitArrayWidget",
Expand Down
62 changes: 62 additions & 0 deletions django_mongodb_backend/forms/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from django import forms
from django.forms.models import modelform_factory
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _


class EmbeddedModelWidget(forms.MultiWidget):
def __init__(self, field_names, *args, **kwargs):
self.field_names = field_names
super().__init__(*args, **kwargs)
# The default widget names are "_0", "_1", etc. Use the field names
# instead since that's how they'll be rendered by the model form.
self.widgets_names = ["-" + name for name in field_names]

def decompress(self, value):
if value is None:
return []
# Get the data from `value` (a model) for each field.
return [getattr(value, name) for name in self.field_names]


class EmbeddedModelBoundField(forms.BoundField):
def __str__(self):
"""Render the model form as the representation for this field."""
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs)
return mark_safe(f"{form.as_div()}") # noqa: S308


class EmbeddedModelField(forms.MultiValueField):
default_error_messages = {
"invalid": _("Enter a list of values."),
"incomplete": _("Enter all required values."),
}

def __init__(self, model, prefix, *args, **kwargs):
form_kwargs = {}
# To avoid collisions with other fields on the form, each subfield must
# be prefixed with the name of the field.
form_kwargs["prefix"] = prefix
self.form_kwargs = form_kwargs
self.model_form_cls = modelform_factory(model, fields="__all__")
self.model_form = self.model_form_cls(**form_kwargs)
self.field_names = list(self.model_form.fields.keys())
fields = self.model_form.fields.values()
widgets = [field.widget for field in fields]
widget = EmbeddedModelWidget(self.field_names, widgets)
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs)

def compress(self, data_dict):
if not data_dict:
return None
values = dict(zip(self.field_names, data_dict, strict=False))
return self.model_form._meta.model(**values)

def get_bound_field(self, form, field_name):
return EmbeddedModelBoundField(form, self, field_name)

def bound_data(self, data, initial):
if self.disabled:
return initial
# Transform the bound data into a model instance.
return self.compress(data)
Loading
Loading