Skip to content

Commit

Permalink
prohibit embedded relational fields
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 6, 2025
1 parent 78bea1c commit b61cf56
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 28 deletions.
14 changes: 14 additions & 0 deletions django_mongodb/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.core import checks
from django.db import models
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform
Expand All @@ -17,6 +18,19 @@ def __init__(self, embedded_model, *args, **kwargs):
self.embedded_model = embedded_model
super().__init__(*args, **kwargs)

def check(self, **kwargs):
errors = super().check(**kwargs)
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
checks.Error(
"Embedded models cannot have relational fields.",
obj=self,
id="django_mongodb.embedded_model.E001",
)
)
return errors

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if path.startswith("django_mongodb.fields.embedded_model"):
Expand Down
11 changes: 1 addition & 10 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ class PrimaryKeyObjectIdModel(models.Model):


# EmbeddedModelField
class Target(models.Model):
index = models.IntegerField()


class DecimalModel(models.Model):
decimal = models.DecimalField(max_digits=9, decimal_places=2)

Expand All @@ -29,17 +25,12 @@ class DecimalKey(models.Model):
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)


class DecimalParent(models.Model):
child = models.ForeignKey(DecimalKey, models.CASCADE)


class EmbeddedModelFieldModel(models.Model):
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
decimal_parent = EmbeddedModelField(DecimalKey, null=True, blank=True)


class EmbeddedModel(models.Model):
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
someint = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)
Expand Down
38 changes: 20 additions & 18 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from decimal import Decimal

from django.core.exceptions import ValidationError
from django.db import models
from django.test import SimpleTestCase, TestCase
from django.test.utils import isolate_apps

from django_mongodb.fields import EmbeddedModelField

from .models import (
Address,
Author,
Book,
DecimalKey,
DecimalParent,
EmbeddedModel,
EmbeddedModelFieldModel,
Target,
)


Expand Down Expand Up @@ -82,19 +79,6 @@ def test_pre_save(self):
self.assertEqual(obj.simple.auto_now_add, auto_now_add)
self.assertGreater(obj.simple.auto_now, auto_now_two)

def test_foreign_key_in_embedded_object(self):
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
simple = EmbeddedModelFieldModel.objects.get().simple
self.assertNotIn("some_relation", simple.__dict__)
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
self.assertIsInstance(simple.some_relation, Target)

def test_embedded_field_with_foreign_conversion(self):
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
decimal_parent = DecimalParent.objects.create(child=decimal)
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)


class QueryingTests(TestCase):
@classmethod
Expand Down Expand Up @@ -134,3 +118,21 @@ def test_nested(self):
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
)
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])


@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
class Target(models.Model):
key = models.ForeignKey("MyModel", models.CASCADE)

class MyModel(models.Model):
field = EmbeddedModelField(Target)

model = MyModel()
errors = model.check()
self.assertEqual(len(errors), 1)
# The inner CharField has a non-positive max_length.
self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001")
msg = errors[0].msg
self.assertEqual(msg, "Embedded models cannot have relational fields.")

0 comments on commit b61cf56

Please sign in to comment.