Skip to content

Commit

Permalink
Refactor subquery wrapping pipeline (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
WaVEV authored Jan 9, 2025
1 parent a769278 commit c51c757
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 77 deletions.
86 changes: 11 additions & 75 deletions django_mongodb_backend/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def order_by(self, compiler, connection):
return self.expression.as_mql(compiler, connection)


def query(self, compiler, connection, lookup_name=None):
def query(self, compiler, connection, get_wrapping_pipeline=None):
subquery_compiler = self.get_compiler(connection=connection)
subquery_compiler.pre_sql_setup(with_col_aliases=False)
field_name, expr = subquery_compiler.columns[0]
Expand All @@ -119,76 +119,12 @@ def query(self, compiler, connection, lookup_name=None):
for col, i in subquery_compiler.column_indices.items()
},
}
wrapping_result_pipeline = None
# The result must be a list of values. The output is compressed with an
# aggregation pipeline.
if lookup_name in ("in", "range"):
wrapping_result_pipeline = [
{
"$facet": {
"group": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(subquery_compiler, connection)
},
}
}
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]
if lookup_name == "overlap":
wrapping_result_pipeline = [
{
"$facet": {
"group": [
{"$project": {"tmp_name": expr.as_mql(subquery_compiler, connection)}},
{
"$unwind": "$tmp_name",
},
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]
if wrapping_result_pipeline:
if get_wrapping_pipeline:
# The results from some lookups must be converted to a list of values.
# The output is compressed with an aggregation pipeline.
wrapping_result_pipeline = get_wrapping_pipeline(
subquery_compiler, connection, field_name, expr
)
# If the subquery is a combinator, wrap the result at the end of the
# combinator pipeline...
if subquery.query.combinator:
Expand Down Expand Up @@ -221,13 +157,13 @@ def star(self, compiler, connection): # noqa: ARG001
return {"$literal": True}


def subquery(self, compiler, connection, lookup_name=None):
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)


def exists(self, compiler, connection, lookup_name=None):
def exists(self, compiler, connection, get_wrapping_pipeline=None):
try:
lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
except EmptyResultSet:
return Value(False).as_mql(compiler, connection)
return connection.mongo_operators["isnull"](lhs_mql, False)
Expand Down
35 changes: 35 additions & 0 deletions django_mongodb_backend/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,41 @@ class ArrayExact(ArrayRHSMixin, Exact):
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
lookup_name = "overlap"

def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
return [
{
"$facet": {
"group": [
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
{
"$unwind": "$tmp_name",
},
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
Expand Down
33 changes: 33 additions & 0 deletions django_mongodb_backend/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,38 @@ def in_(self, compiler, connection):
return builtin_lookup(self, compiler, connection)


def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001
return [
{
"$facet": {
"group": [
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": expr.as_mql(compiler, connection)},
}
}
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]


def is_null(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
Expand Down Expand Up @@ -97,6 +129,7 @@ def register_lookups():
field_resolve_expression_parameter
)
In.as_mql = RelatedIn.as_mql = in_
In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline
IsNull.as_mql = is_null
PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value
UUIDTextMixin.as_mql = uuid_text_mixin
6 changes: 4 additions & 2 deletions django_mongodb_backend/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def process_lhs(node, compiler, connection):
def process_rhs(node, compiler, connection):
rhs = node.rhs
if hasattr(rhs, "as_mql"):
if getattr(rhs, "subquery", False):
value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name)
if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"):
value = rhs.as_mql(
compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline
)
else:
value = rhs.as_mql(compiler, connection)
else:
Expand Down

0 comments on commit c51c757

Please sign in to comment.