Skip to content

Commit

Permalink
add field name validation to EmbeddedModelField lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 23, 2025
1 parent 5156577 commit a1c6b39
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
40 changes: 36 additions & 4 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import difflib

from django.core import checks
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform
Expand Down Expand Up @@ -123,7 +126,8 @@ def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)
field = self.embedded_model._meta.get_field(name)
return KeyTransformFactory(name, field)

def validate(self, value, model_instance):
super().validate(value, model_instance)
Expand All @@ -145,9 +149,36 @@ def formfield(self, **kwargs):


class KeyTransform(Transform):
def __init__(self, key_name, *args, **kwargs):
def __init__(self, key_name, ref_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
self.ref_field = ref_field

def get_transform(self, name):
"""
Validate that `name` is either a field of an embedded model or a
lookup on an embedded model's field.
"""
result = None
if isinstance(self.ref_field, EmbeddedModelField):
opts = self.ref_field.embedded_model._meta
new_field = opts.get_field(name)
result = KeyTransformFactory(name, new_field)
else:
if self.ref_field.get_transform(name) is None:
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
if suggested_lookups:
suggested_lookups = " or ".join(suggested_lookups)
suggestion = f", perhaps you meant {suggested_lookups}?"
else:
suggestion = "."
raise FieldDoesNotExist(
f"Unsupported lookup '{name}' for "
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
f"{suggestion}"
)
result = KeyTransformFactory(name, self.ref_field)
return result

def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
Expand All @@ -165,8 +196,9 @@ def as_mql(self, compiler, connection):


class KeyTransformFactory:
def __init__(self, key_name):
def __init__(self, key_name, ref_field):
self.key_name = key_name
self.ref_field = ref_field

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
37 changes: 36 additions & 1 deletion tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import operator

from django.core.exceptions import ValidationError
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.db import models
from django.db.models import ExpressionWrapper, F, Max, Sum
from django.test import SimpleTestCase, TestCase
Expand Down Expand Up @@ -147,6 +147,41 @@ def test_nested(self):
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])


class InvalidLookupTests(SimpleTestCase):
def test_invalid_field(self):
msg = "Author has no field named 'first_name'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__first_name="Bob")

def test_invalid_field_nested(self):
msg = "Address has no field named 'floor'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__address__floor="NYC")

def test_invalid_lookup(self):
msg = "Unsupported lookup 'foo' for CharField 'city'."
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Book.objects.filter(author__address__city__foo="NYC")

def test_invalid_lookup_with_suggestions(self):
msg = (
"Unsupported lookup '{lookup}' for CharField 'name', "
"perhaps you meant {suggested_lookups}?"
)
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="exactly", suggested_lookups="exact or iexact")
):
Book.objects.filter(author__name__exactly="NYC")
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="gti", suggested_lookups="gt or gte")
):
Book.objects.filter(author__name__gti="NYC")
with self.assertRaisesMessage(
FieldDoesNotExist, msg.format(lookup="is_null", suggested_lookups="isnull")
):
Book.objects.filter(author__name__is_null="NYC")


@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
Expand Down

0 comments on commit a1c6b39

Please sign in to comment.