-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from django import forms | ||
from django.forms.models import modelform_factory | ||
from django.utils.safestring import mark_safe | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
|
||
class EmbeddedModelWidget(forms.MultiWidget): | ||
def __init__(self, field_names, *args, **kwargs): | ||
self.field_names = field_names | ||
super().__init__(*args, **kwargs) | ||
# The default widget names are "_0", "_1", etc. Use the field names | ||
# instead since that's how they'll be rendered by the model form. | ||
self.widgets_names = ["-" + name for name in field_names] | ||
|
||
def decompress(self, value): | ||
if value is None: | ||
return [] | ||
# Get the data from `value` (a model) for each field. | ||
return [getattr(value, name) for name in self.field_names] | ||
|
||
|
||
class EmbeddedModelBoundField(forms.BoundField): | ||
def __str__(self): | ||
"""Render the model form as the representation for this field.""" | ||
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) | ||
return mark_safe(f"{form.as_div()}") # noqa: S308 | ||
|
||
|
||
class EmbeddedModelFormField(forms.MultiValueField): | ||
default_error_messages = { | ||
"invalid": _("Enter a list of values."), | ||
"incomplete": _("Enter all required values."), | ||
} | ||
|
||
def __init__(self, model, name, *args, **kwargs): | ||
form_kwargs = {} | ||
# The field must be prefixed with the name of the field. | ||
form_kwargs["prefix"] = name | ||
self.form_kwargs = form_kwargs | ||
self.model_form_cls = modelform_factory(model, fields="__all__") | ||
self.model_form = self.model_form_cls(**form_kwargs) | ||
self.field_names = list(self.model_form.fields.keys()) | ||
fields = self.model_form.fields.values() | ||
widgets = [field.widget for field in fields] | ||
widget = EmbeddedModelWidget(self.field_names, widgets) | ||
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs) | ||
|
||
def compress(self, data_dict): | ||
if not data_dict: | ||
return None | ||
values = dict(zip(self.field_names, data_dict, strict=False)) | ||
return self.model_form._meta.model(**values) | ||
|
||
def get_bound_field(self, form, field_name): | ||
return EmbeddedModelBoundField(form, self, field_name) | ||
|
||
def bound_data(self, data, initial): | ||
if self.disabled: | ||
return initial | ||
# The bound data must be transformed into a model instance. | ||
return self.compress(data) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from django import forms | ||
|
||
from .models import Author | ||
|
||
|
||
class AuthorForm(forms.ModelForm): | ||
class Meta: | ||
fields = "__all__" | ||
model = Author |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from django.db import models | ||
|
||
from django_mongodb.fields import EmbeddedModelField | ||
|
||
|
||
class Address(models.Model): | ||
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box") | ||
city = models.CharField(max_length=20) | ||
state = models.CharField(max_length=2) | ||
zip_code = models.IntegerField() | ||
|
||
|
||
class Author(models.Model): | ||
name = models.CharField(max_length=10) | ||
age = models.IntegerField() | ||
address = EmbeddedModelField(Address) | ||
billing_address = EmbeddedModelField(Address, blank=True, null=True) | ||
|
||
|
||
class Book(models.Model): | ||
name = models.CharField(max_length=100) | ||
author = EmbeddedModelField(Author) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from django.test import TestCase | ||
|
||
from .forms import AuthorForm | ||
from .models import Address, Author | ||
|
||
|
||
class ModelFormTests(TestCase): | ||
def test_update(self): | ||
author = Author.objects.create( | ||
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") | ||
) | ||
data = { | ||
"name": "Bob", | ||
"age": 51, | ||
"address-po_box": "", | ||
"address-city": "New York City", | ||
"address-state": "NY", | ||
"address-zip_code": "10001", | ||
} | ||
form = AuthorForm(data, instance=author) | ||
self.assertTrue(form.is_valid()) | ||
form.save() | ||
author.refresh_from_db() | ||
self.assertEqual(author.age, 51) | ||
self.assertEqual(author.address.city, "New York City") | ||
|
||
def test_some_missing_data(self): | ||
author = Author.objects.create( | ||
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") | ||
) | ||
data = { | ||
"name": "Bob", | ||
"age": 51, | ||
"address-po_box": "", | ||
"address-city": "New York City", | ||
"address-state": "NY", | ||
"address-zip_code": "", | ||
} | ||
form = AuthorForm(data, instance=author) | ||
self.assertFalse(form.is_valid()) | ||
self.assertEqual(form.errors["address"], ["Enter all required values."]) | ||
|
||
def test_invalid_field_data(self): | ||
"""A field's data (state) is too long.""" | ||
author = Author.objects.create( | ||
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") | ||
) | ||
data = { | ||
"name": "Bob", | ||
"age": 51, | ||
"address-po_box": "", | ||
"address-city": "New York City", | ||
"address-state": "TOO LONG", | ||
"address-zip_code": "", | ||
} | ||
form = AuthorForm(data, instance=author) | ||
self.assertFalse(form.is_valid()) | ||
self.assertEqual( | ||
form.errors["address"], | ||
[ | ||
"Ensure this value has at most 2 characters (it has 8).", | ||
"Enter all required values.", | ||
], | ||
) | ||
|
||
def test_all_missing_data(self): | ||
author = Author.objects.create( | ||
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") | ||
) | ||
data = { | ||
"name": "Bob", | ||
"age": 51, | ||
"address-po_box": "", | ||
"address-city": "", | ||
"address-state": "", | ||
"address-zip_code": "", | ||
} | ||
form = AuthorForm(data, instance=author) | ||
self.assertFalse(form.is_valid()) | ||
self.assertEqual(form.errors["address"], ["This field is required."]) | ||
|
||
def test_nullable_field(self): | ||
"""A nullable EmbeddedModelField is removed if all fields are empty.""" | ||
author = Author.objects.create( | ||
name="Bob", | ||
age=50, | ||
address=Address(city="NYC", state="NY", zip_code="10001"), | ||
billing_address=Address(city="NYC", state="NY", zip_code="10001"), | ||
) | ||
data = { | ||
"name": "Bob", | ||
"age": 51, | ||
"address-po_box": "", | ||
"address-city": "New York City", | ||
"address-state": "NY", | ||
"address-zip_code": "10001", | ||
"billing_address-po_box": "", | ||
"billing_address-city": "", | ||
"billing_address-state": "", | ||
"billing_address-zip_code": "", | ||
} | ||
form = AuthorForm(data, instance=author) | ||
self.assertTrue(form.is_valid()) | ||
form.save() | ||
author.refresh_from_db() | ||
self.assertIsNone(author.billing_address) | ||
|
||
def test_rendering(self): | ||
form = AuthorForm() | ||
self.assertHTMLEqual( | ||
str(form.fields["address"].get_bound_field(form, "address")), | ||
""" | ||
<div> | ||
<label for="id_address-po_box">PO Box:</label> | ||
<input id="id_address-po_box" maxlength="50" name="address-po_box" type="text"> | ||
</div> | ||
<div> | ||
<label for="id_address-city">City:</label> | ||
<input type="text" name="address-city" maxlength="20" required id="id_address-city"> | ||
</div> | ||
<div> | ||
<label for="id_address-state">State:</label> | ||
<input type="text" name="address-state" maxlength="2" required | ||
id="id_address-state"> | ||
</div> | ||
<div> | ||
<label for="id_address-zip_code">Zip code:</label> | ||
<input type="number" name="address-zip_code" required id="id_address-zip_code"> | ||
</div>""", | ||
) |