diff --git a/django_mongodb/fields/embedded_model.py b/django_mongodb/fields/embedded_model.py index 6b6fc899..4f54b58d 100644 --- a/django_mongodb/fields/embedded_model.py +++ b/django_mongodb/fields/embedded_model.py @@ -1,3 +1,4 @@ +from django.core import checks from django.db import models from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform @@ -17,6 +18,19 @@ def __init__(self, embedded_model, *args, **kwargs): self.embedded_model = embedded_model super().__init__(*args, **kwargs) + def check(self, **kwargs): + errors = super().check(**kwargs) + for field in self.embedded_model._meta.fields: + if field.remote_field: + errors.append( + checks.Error( + "Embedded models cannot have relational fields.", + obj=self, + id="django_mongodb.embedded_model.E001", + ) + ) + return errors + def deconstruct(self): name, path, args, kwargs = super().deconstruct() if path.startswith("django_mongodb.fields.embedded_model"): diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index c1a2845e..bfdc8bd0 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -17,10 +17,6 @@ class PrimaryKeyObjectIdModel(models.Model): # EmbeddedModelField -class Target(models.Model): - index = models.IntegerField() - - class DecimalModel(models.Model): decimal = models.DecimalField(max_digits=9, decimal_places=2) @@ -29,17 +25,12 @@ class DecimalKey(models.Model): decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True) -class DecimalParent(models.Model): - child = models.ForeignKey(DecimalKey, models.CASCADE) - - class EmbeddedModelFieldModel(models.Model): simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True) - decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True) + decimal_parent = EmbeddedModelField(DecimalKey, null=True, blank=True) class EmbeddedModel(models.Model): - some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True) someint = models.IntegerField(db_column="custom_column") auto_now = models.DateTimeField(auto_now=True) auto_now_add = models.DateTimeField(auto_now_add=True) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 4b19a284..bfbaa178 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -1,7 +1,7 @@ -from decimal import Decimal - from django.core.exceptions import ValidationError +from django.db import models from django.test import SimpleTestCase, TestCase +from django.test.utils import isolate_apps from django_mongodb.fields import EmbeddedModelField @@ -9,11 +9,8 @@ Address, Author, Book, - DecimalKey, - DecimalParent, EmbeddedModel, EmbeddedModelFieldModel, - Target, ) @@ -82,19 +79,6 @@ def test_pre_save(self): self.assertEqual(obj.simple.auto_now_add, auto_now_add) self.assertGreater(obj.simple.auto_now, auto_now_two) - def test_foreign_key_in_embedded_object(self): - simple = EmbeddedModel(some_relation=Target.objects.create(index=1)) - obj = EmbeddedModelFieldModel.objects.create(simple=simple) - simple = EmbeddedModelFieldModel.objects.get().simple - self.assertNotIn("some_relation", simple.__dict__) - self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id)) - self.assertIsInstance(simple.some_relation, Target) - - def test_embedded_field_with_foreign_conversion(self): - decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) - decimal_parent = DecimalParent.objects.create(child=decimal) - EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent) - class QueryingTests(TestCase): @classmethod @@ -134,3 +118,21 @@ def test_nested(self): author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) ) self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + + +@isolate_apps("model_fields_") +class CheckTests(SimpleTestCase): + def test_no_relational_fields(self): + class Target(models.Model): + key = models.ForeignKey("MyModel", models.CASCADE) + + class MyModel(models.Model): + field = EmbeddedModelField(Target) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + # The inner CharField has a non-positive max_length. + self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001") + msg = errors[0].msg + self.assertEqual(msg, "Embedded models cannot have relational fields.")