Skip to content

Commit

Permalink
wip forms support
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Dec 30, 2024
1 parent 4602191 commit 62cafa8
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 0 deletions.
12 changes: 12 additions & 0 deletions django_mongodb/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform

from .. import forms


class EmbeddedModelField(models.Field):
"""Field that stores a model instance."""
Expand Down Expand Up @@ -123,6 +125,16 @@ def validate(self, value, model_instance):
attname = field.attname
field.validate(getattr(value, attname), model_instance)

def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.EmbeddedModelFormField,
"model": self.embedded_model,
"name": self.name,
**kwargs,
}
)


class KeyTransform(Transform):
def __init__(self, key_name, *args, **kwargs):
Expand Down
61 changes: 61 additions & 0 deletions django_mongodb/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from django import forms
from django.forms.models import modelform_factory
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _


class EmbeddedModelWidget(forms.MultiWidget):
def __init__(self, field_names, *args, **kwargs):
self.field_names = field_names
super().__init__(*args, **kwargs)
# The default widget names are "_0", "_1", etc. Use the field names
# instead since that's how they'll be rendered by the model form.
self.widgets_names = ["-" + name for name in field_names]

def decompress(self, value):
if value is None:
return []
# Get the data from `value` (a model) for each field.
return [getattr(value, name) for name in self.field_names]


class EmbeddedModelBoundField(forms.BoundField):
def __str__(self):
"""Render the model form as the representation for this field."""
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs)
return mark_safe(f"{form.as_div()}") # noqa: S308


class EmbeddedModelFormField(forms.MultiValueField):
default_error_messages = {
"invalid": _("Enter a list of values."),
"incomplete": _("Enter all required values."),
}

def __init__(self, model, name, *args, **kwargs):
form_kwargs = {}
# The field must be prefixed with the name of the field.
form_kwargs["prefix"] = name
self.form_kwargs = form_kwargs
self.model_form_cls = modelform_factory(model, fields="__all__")
self.model_form = self.model_form_cls(**form_kwargs)
self.field_names = list(self.model_form.fields.keys())
fields = self.model_form.fields.values()
widgets = [field.widget for field in fields]
widget = EmbeddedModelWidget(self.field_names, widgets)
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs)

def compress(self, data_dict):
if not data_dict:
return None
values = dict(zip(self.field_names, data_dict, strict=False))
return self.model_form._meta.model(**values)

def get_bound_field(self, form, field_name):
return EmbeddedModelBoundField(form, self, field_name)

def bound_data(self, data, initial):
if self.disabled:
return initial
# The bound data must be transformed into a model instance.
return self.compress(data)
Empty file added tests/model_forms_/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions tests/model_forms_/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django import forms

from .models import Author


class AuthorForm(forms.ModelForm):
class Meta:
fields = "__all__"
model = Author
22 changes: 22 additions & 0 deletions tests/model_forms_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from django.db import models

from django_mongodb.fields import EmbeddedModelField


class Address(models.Model):
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField()


class Author(models.Model):
name = models.CharField(max_length=10)
age = models.IntegerField()
address = EmbeddedModelField(Address)
billing_address = EmbeddedModelField(Address, blank=True, null=True)


class Book(models.Model):
name = models.CharField(max_length=100)
author = EmbeddedModelField(Author)
130 changes: 130 additions & 0 deletions tests/model_forms_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from django.test import TestCase

from .forms import AuthorForm
from .models import Address, Author


class ModelFormTests(TestCase):
def test_update(self):
author = Author.objects.create(
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
)
data = {
"name": "Bob",
"age": 51,
"address-po_box": "",
"address-city": "New York City",
"address-state": "NY",
"address-zip_code": "10001",
}
form = AuthorForm(data, instance=author)
self.assertTrue(form.is_valid())
form.save()
author.refresh_from_db()
self.assertEqual(author.age, 51)
self.assertEqual(author.address.city, "New York City")

def test_some_missing_data(self):
author = Author.objects.create(
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
)
data = {
"name": "Bob",
"age": 51,
"address-po_box": "",
"address-city": "New York City",
"address-state": "NY",
"address-zip_code": "",
}
form = AuthorForm(data, instance=author)
self.assertFalse(form.is_valid())
self.assertEqual(form.errors["address"], ["Enter all required values."])

def test_invalid_field_data(self):
"""A field's data (state) is too long."""
author = Author.objects.create(
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
)
data = {
"name": "Bob",
"age": 51,
"address-po_box": "",
"address-city": "New York City",
"address-state": "TOO LONG",
"address-zip_code": "",
}
form = AuthorForm(data, instance=author)
self.assertFalse(form.is_valid())
self.assertEqual(
form.errors["address"],
[
"Ensure this value has at most 2 characters (it has 8).",
"Enter all required values.",
],
)

def test_all_missing_data(self):
author = Author.objects.create(
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
)
data = {
"name": "Bob",
"age": 51,
"address-po_box": "",
"address-city": "",
"address-state": "",
"address-zip_code": "",
}
form = AuthorForm(data, instance=author)
self.assertFalse(form.is_valid())
self.assertEqual(form.errors["address"], ["This field is required."])

def test_nullable_field(self):
"""A nullable EmbeddedModelField is removed if all fields are empty."""
author = Author.objects.create(
name="Bob",
age=50,
address=Address(city="NYC", state="NY", zip_code="10001"),
billing_address=Address(city="NYC", state="NY", zip_code="10001"),
)
data = {
"name": "Bob",
"age": 51,
"address-po_box": "",
"address-city": "New York City",
"address-state": "NY",
"address-zip_code": "10001",
"billing_address-po_box": "",
"billing_address-city": "",
"billing_address-state": "",
"billing_address-zip_code": "",
}
form = AuthorForm(data, instance=author)
self.assertTrue(form.is_valid())
form.save()
author.refresh_from_db()
self.assertIsNone(author.billing_address)

def test_rendering(self):
form = AuthorForm()
self.assertHTMLEqual(
str(form.fields["address"].get_bound_field(form, "address")),
"""
<div>
<label for="id_address-po_box">PO Box:</label>
<input id="id_address-po_box" maxlength="50" name="address-po_box" type="text">
</div>
<div>
<label for="id_address-city">City:</label>
<input type="text" name="address-city" maxlength="20" required id="id_address-city">
</div>
<div>
<label for="id_address-state">State:</label>
<input type="text" name="address-state" maxlength="2" required
id="id_address-state">
</div>
<div>
<label for="id_address-zip_code">Zip code:</label>
<input type="number" name="address-zip_code" required id="id_address-zip_code">
</div>""",
)

0 comments on commit 62cafa8

Please sign in to comment.