From 6a87d8416720e692e06aff0c38e8775d39c26e73 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 27 Jan 2025 19:48:43 -0500 Subject: [PATCH] add support for JSONField lookups in an embedded model --- .../fields/embedded_model.py | 22 ++++++++--- django_mongodb_backend/fields/json.py | 38 ++++++++++--------- tests/model_fields_/models.py | 1 + tests/model_fields_/test_embedded_model.py | 25 ++++++++++++ 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 6d3d3580..d9dd5b6c 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -7,6 +7,7 @@ from django.db.models.lookups import Transform from .. import forms +from .json import build_json_mql_path class EmbeddedModelField(models.Field): @@ -181,18 +182,27 @@ def get_transform(self, name): return result def preprocess_lhs(self, compiler, connection): - key_transforms = [self.key_name] - previous = self.lhs + previous = self + embedded_key_transforms = [] + json_key_transforms = [] while isinstance(previous, KeyTransform): - key_transforms.insert(0, previous.key_name) + if isinstance(previous.ref_field, EmbeddedModelField): + embedded_key_transforms.insert(0, previous.key_name) + else: + json_key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - return mql, key_transforms + # The first json_key_transform is the field name. + embedded_key_transforms.append(json_key_transforms.pop(0)) + return mql, embedded_key_transforms, json_key_transforms def as_mql(self, compiler, connection): - mql, key_transforms = self.preprocess_lhs(compiler, connection) + mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection) transforms = ".".join(key_transforms) - return f"{mql}.{transforms}" + result = f"{mql}.{transforms}" + if json_key_transforms: + result = build_json_mql_path(result, json_key_transforms) + return result class KeyTransformFactory: diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 218ae649..b7cf49dc 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -17,6 +17,26 @@ from ..query_utils import process_lhs, process_rhs +def build_json_mql_path(lhs, key_transforms): + # Build the MQL path using the collected key transforms. + result = lhs + for key in key_transforms: + get_field = {"$getField": {"input": result, "field": key}} + # Handle array indexing if the key is a digit. If key is something + # like '001', it's not an array index despite isdigit() returning True. + if key.isdigit() and str(int(key)) == key: + result = { + "$cond": { + "if": {"$isArray": result}, + "then": {"$arrayElemAt": [result, int(key)]}, + "else": get_field, + } + } + else: + result = get_field + return result + + def contained_by(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("contained_by lookup is not supported on this database backend.") @@ -89,23 +109,7 @@ def key_transform(self, compiler, connection): key_transforms.insert(0, previous.key_name) previous = previous.lhs lhs_mql = previous.as_mql(compiler, connection) - result = lhs_mql - # Build the MQL path using the collected key transforms. - for key in key_transforms: - get_field = {"$getField": {"input": result, "field": key}} - # Handle array indexing if the key is a digit. If key is something - # like '001', it's not an array index despite isdigit() returning True. - if key.isdigit() and str(int(key)) == key: - result = { - "$cond": { - "if": {"$isArray": result}, - "then": {"$arrayElemAt": [result, int(key)]}, - "else": get_field, - } - } - else: - result = get_field - return result + return build_json_mql_path(lhs_mql, key_transforms) def key_transform_in(self, compiler, connection): diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 0420249c..b25b94a1 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -103,6 +103,7 @@ class Data(EmbeddedModel): integer = models.IntegerField(db_column="custom_column") auto_now = models.DateTimeField(auto_now=True) auto_now_add = models.DateTimeField(auto_now_add=True) + json_value = models.JSONField() class Address(EmbeddedModel): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index dfb3a579..4beb2e25 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -114,6 +114,31 @@ def test_order_by_embedded_field(self): qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer") self.assertSequenceEqual(qs, list(reversed(self.objs[4:]))) + def test_embedded_json_field_lookups(self): + objs = [ + Holder.objects.create( + data=Data(json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}}) + ) + for i in range(4) + ] + self.assertCountEqual( + Holder.objects.filter(data__json_value__field2__0__value__0=0), + objs[1:], + ) + self.assertCountEqual( + Holder.objects.filter(data__json_value__field2__0__value__1=1), + objs[2:], + ) + self.assertCountEqual(Holder.objects.filter(data__json_value__field2__0__value__1=5), []) + self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), objs) + self.assertCountEqual(Holder.objects.filter(data__json_value__field1__gt=100), []) + self.assertCountEqual( + Holder.objects.filter( + data__json_value__field1__gte=5, data__json_value__field1__lte=10 + ), + objs[1:3], + ) + def test_order_and_group_by_embedded_field(self): # Create and sort test data by `data__integer`. expected_objs = sorted(