diff --git a/django_mongodb/features.py b/django_mongodb/features.py index ac062b78..b23fec02 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -91,16 +91,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): "model_fields_.test_arrayfield.TestQuerying.test_icontains", # Field 'field' expected a number but got Value(1). "model_fields_.test_arrayfield.TestQuerying.test_exact_with_expression", - # int() argument must be a string, a bytes-like object or a real number, not 'list' - "model_fields_.test_arrayfield.TestQuerying.test_index_annotation", - # Wrong results - "model_fields_.test_arrayfield.TestQuerying.test_index", - "model_fields_.test_arrayfield.TestQuerying.test_index_chained", - "model_fields_.test_arrayfield.TestQuerying.test_index_nested", - "model_fields_.test_arrayfield.TestQuerying.test_order_by_slice", # $lt treats null values as zero. "model_fields_.test_arrayfield.TestQuerying.test_lt", "model_fields_.test_arrayfield.TestQuerying.test_len", + "model_fields_.test_arrayfield.TestQuerying.test_index_chained", # None is $in None "model_fields_.test_arrayfield.TestQuerying.test_in_as_F_object", } diff --git a/django_mongodb/fields/array.py b/django_mongodb/fields/array.py index 56c8506a..dfba1f69 100644 --- a/django_mongodb/fields/array.py +++ b/django_mongodb/fields/array.py @@ -183,7 +183,6 @@ def get_transform(self, name): except ValueError: pass else: - index += 1 # postgres uses 1-indexing return IndexTransformFactory(index, self.base_field) try: start, end = name.split("_") @@ -306,10 +305,8 @@ def __init__(self, index, base_field, *args, **kwargs): self.base_field = base_field def as_mql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - if not lhs.endswith("]"): - lhs = "(%s)" % lhs - return "%s[%%s]" % lhs, (*params, self.index) + lhs_mql = process_lhs(self, compiler, connection) + return {"$arrayElemAt": [lhs_mql, self.index]} @property def output_field(self): diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index 9dae0cdc..4d2847ad 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -356,7 +356,6 @@ def test_index_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual(NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]) - @unittest.expectedFailure def test_index_used_on_nested_data(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( @@ -388,7 +387,7 @@ def test_slice(self): NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3] ) - def test_order_by_slice(self): + def test_order_by_index(self): more_objs = ( NullableIntegerArrayModel.objects.create(field=[1, 637]), NullableIntegerArrayModel.objects.create(field=[2, 1]), @@ -398,19 +397,18 @@ def test_order_by_slice(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.order_by("field__1"), [ + self.objs[0], + self.objs[1], + self.objs[4], more_objs[2], more_objs[1], more_objs[3], self.objs[2], self.objs[3], more_objs[0], - self.objs[4], - self.objs[1], - self.objs[0], ], ) - @unittest.expectedFailure def test_slice_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual(