Skip to content

Commit

Permalink
chore: refactor aggregation parsing logic into dedicated stage based …
Browse files Browse the repository at this point in the history
…parsers (#117)
  • Loading branch information
himanshusinghs authored Jan 16, 2025
1 parent 5afd7c9 commit fda709b
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.Fields;
import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.mongodb.core.query.Field;

import java.util.List;

Expand Down Expand Up @@ -32,9 +34,8 @@ private List<Movie> allMoviesWithRatingAtLeast(int rating) {
private List<Movie> allMoviesWithRatingAtLeastAgg(int rating) {
return template.aggregate(
Aggregation.newAggregation(
Aggregation.match(
where( "tomatoes.viewer.rating").gte(rating)
)
Aggregation.match(where( "tomatoes.viewer.rating").gte(rating)),
Aggregation.project("asd").andInclude("asd").andExclude("qwe")
),
Movie.class,
Movie.class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
return null // empty, do nothing
}

val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0])
val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression()
// if it's only 2 arguments it can be either:
// - in(field, singleElement) -> valid because of varargs, becomes a single element array
// - in(field, array) -> valid because of varargs
Expand Down Expand Up @@ -343,7 +343,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
return null
}
val fieldExpression = filter.argumentList.expressions[0]
val fieldReference = resolveFieldNameFromExpression(fieldExpression)
val fieldReference = fieldExpression.resolveFieldNameFromExpression()
val valueReference = HasValueReference.Inferred(
source = fieldExpression,
value = true,
Expand All @@ -360,7 +360,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
)
} else if (method.parameters.size == 2) {
// If it has two parameters, it's field/value.
val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0])
val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression()
val valueReference = resolveValueFromExpression(filter.argumentList.expressions[1])

return Node(
Expand Down Expand Up @@ -393,7 +393,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
)
} else if (method.parameters.size == 2) {
// If it has two parameters, it's field/value.
val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0])
val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression()
val valueReference = resolveValueFromExpression(filter.argumentList.expressions[1])

return Node(
Expand All @@ -410,7 +410,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
)
} else if (method.parameters.size == 1) {
// Updates.unset for example
val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0])
val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression()

return Node(
filter,
Expand Down Expand Up @@ -613,7 +613,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
"ascending",
"descending" -> expression.getVarArgsOrIterableArgs()
.mapNotNull {
val fieldReference = resolveFieldNameFromExpression(it)
val fieldReference = it.resolveFieldNameFromExpression()
val methodName = Name.from(methodCall.name)
when (fieldReference) {
is FromSchema -> Node(
Expand Down Expand Up @@ -818,16 +818,6 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN
}

private fun resolveFieldNameFromExpression(expression: PsiExpression): HasFieldReference.FieldReference<out Any> {
val fieldNameAsString = expression.tryToResolveAsConstantString()
val fieldReference =
fieldNameAsString?.let {
FromSchema(expression, it)
} ?: HasFieldReference.Unknown

return fieldReference
}

private fun resolveValueFromExpression(expression: PsiExpression): HasValueReference.ValueReference<PsiElement> {
val (wasResolvedAtCompileTime, resolvedValue) = expression.tryToResolveAsConstant()

Expand Down Expand Up @@ -907,6 +897,16 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
}
}

fun PsiExpression.resolveFieldNameFromExpression(): HasFieldReference.FieldReference<out Any> {
val fieldNameAsString = tryToResolveAsConstantString()
val fieldReference =
fieldNameAsString?.let {
FromSchema(this, it)
} ?: HasFieldReference.Unknown

return fieldReference
}

fun PsiExpressionList.inferFromSingleArrayArgument(start: Int = 0): HasValueReference.ValueReference<PsiElement> {
val arrayArg = expressions[start]
val (constant, value) = arrayArg.tryToResolveAsConstant()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package com.mongodb.jbplugin.dialects.springcriteria

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodCallExpression
import com.mongodb.jbplugin.dialects.javadriver.glossary.fuzzyResolveMethod
import com.mongodb.jbplugin.dialects.javadriver.glossary.resolveToMethodCallExpression
import com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers.MatchStageParser
import com.mongodb.jbplugin.mql.Node

/**
* Parser for parsing supported patterns of writing an aggregation pipeline.
* Supported patterns for writing aggregation calls are the following:
* 1. MongoTemplate.aggregate()
* 2. MongoTemplate.aggregateStream()
*
* The AggregationParser concerns itself only with parsing the aggregation related semantics and
* leave the rest as a responsibility for the composing unit.
*/
class AggregationStagesParser(private val matchStageParser: MatchStageParser) {
private fun isStageCall(stageCallMethod: PsiMethod): Boolean {
return matchStageParser.canParse(stageCallMethod)
}

private fun parseAggregationStages(
newAggregationCall: PsiMethodCallExpression
): List<Node<PsiElement>> {
val newAggregationCallArguments = newAggregationCall.argumentList.expressions
val resolvedStageCalls = newAggregationCallArguments.mapNotNull { stageCallExpression ->
stageCallExpression.resolveToMethodCallExpression { _, stageCallMethod ->
isStageCall(stageCallMethod)
}
}
return resolvedStageCalls.map { stageCall ->
val stageCallMethod = stageCall.fuzzyResolveMethod() ?: return@map Node(
source = stageCall,
components = emptyList()
)

if (matchStageParser.canParse(stageCallMethod)) {
matchStageParser.parse(stageCall)
} else {
Node(
source = stageCall,
components = emptyList()
)
}
}
}

fun parse(aggregateRootCall: PsiMethodCallExpression): List<Node<PsiElement>> {
val aggregateRootCallArguments = aggregateRootCall.argumentList.expressions

val newAggregationCallExpression = aggregateRootCallArguments.getOrNull(0)
?: return emptyList()

// This is the call to Aggregation.newAggregation method which is generally the first
// argument to the root aggregate call. All the aggregation stages are to be found as
// the argument to this method call.
val newAggregationCall = newAggregationCallExpression.resolveToMethodCallExpression {
_,
method
->
method.name == "newAggregation"
} ?: return emptyList()

return parseAggregationStages(newAggregationCall)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtract
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromQueryChain
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromStringTypeParameter
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.or
import com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers.MatchStageParser
import com.mongodb.jbplugin.mql.BsonAny
import com.mongodb.jbplugin.mql.BsonArray
import com.mongodb.jbplugin.mql.Node
Expand All @@ -18,11 +19,8 @@ import com.mongodb.jbplugin.mql.toBsonType

private const val CRITERIA_CLASS_FQN = "org.springframework.data.mongodb.core.query.Criteria"
private const val DOCUMENT_FQN = "org.springframework.data.mongodb.core.mapping.Document"
private const val AGGREGATE_FQN = "org.springframework.data.mongodb.core.aggregation.Aggregation"
private const val MONGO_TEMPLATE_FQN = "org.springframework.data.mongodb.core.MongoTemplate"
private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf(
"match"
)
const val AGGREGATE_FQN = "org.springframework.data.mongodb.core.aggregation.Aggregation"

object SpringCriteriaDialectParser : DialectParser<PsiElement> {
override fun isCandidateForQuery(source: PsiElement) =
Expand Down Expand Up @@ -218,19 +216,6 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
)
"aggregate", "aggregateStream" -> {
val expressions = mongoOpCall.argumentList.expressions
val newAggregationCall = expressions.getOrNull(0)?.resolveToMethodCallExpression {
_,
method
->
method.name == "newAggregation"
}
val resolvedStageCallExpression =
newAggregationCall?.getVarArgsOrIterableArgs()?.mapNotNull {
it.resolveToMethodCallExpression { _, method ->
method.containingClass?.qualifiedName == AGGREGATE_FQN &&
PARSEABLE_AGGREGATION_STAGE_METHODS.contains(method.name)
}
} ?: emptyList()
val collectionExpression = expressions.getOrNull(1)
return Node(
mongoOpCall,
Expand All @@ -247,7 +232,9 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
extractCollectionFromStringTypeParameter(collectionExpression)
),
HasAggregation(
parseAggregationStagesFromCurrentCall(resolvedStageCallExpression)
children = AggregationStagesParser(
matchStageParser = MatchStageParser(::parseFilterRecursively)
).parse(mongoOpCall)
)
)
)
Expand All @@ -271,7 +258,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
}

override fun isReferenceToDatabase(source: PsiElement): Boolean {
return false // databases are in property files and we don't support AC there yet
return false // databases are in property files, and we don't support AC there yet
}

override fun isReferenceToCollection(source: PsiElement): Boolean {
Expand Down Expand Up @@ -336,34 +323,6 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
)
}

private fun parseAggregationStagesFromCurrentCall(
stageCallExpressions: List<PsiMethodCallExpression>
): List<Node<PsiElement>> {
return stageCallExpressions.mapNotNull { stageCall ->
val stageMethod = stageCall.fuzzyResolveMethod() ?: return@mapNotNull Node(
source = stageCall,
components = emptyList()
)

when (stageMethod.name) {
"match" -> parseMatchStageCall(stageCall)
else -> null
}
}
}

private fun parseMatchStageCall(matchStageCall: PsiMethodCallExpression): Node<PsiElement> {
return Node(
source = matchStageCall,
components = listOf(
Named(Name.MATCH),
HasFilter(
parseFilterRecursively(matchStageCall.argumentList.expressions.getOrNull(0))
)
)
)
}

private fun parseFilterRecursively(
valueFilterExpression: PsiElement?
): List<Node<PsiElement>> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodCallExpression
import com.mongodb.jbplugin.dialects.springcriteria.AGGREGATE_FQN
import com.mongodb.jbplugin.mql.Node
import com.mongodb.jbplugin.mql.components.HasFilter
import com.mongodb.jbplugin.mql.components.Name
import com.mongodb.jbplugin.mql.components.Named

class MatchStageParser(
private val parseFilters: (PsiElement) -> List<Node<PsiElement>>
) : StageParser {
private fun createMatchStageNode(
source: PsiElement,
filters: List<Node<PsiElement>> = emptyList()
) = Node(
source = source,
components = listOf(
Named(Name.MATCH),
HasFilter(filters),
)
)

override fun canParse(stageCallMethod: PsiMethod): Boolean {
return stageCallMethod.containingClass?.qualifiedName == AGGREGATE_FQN &&
stageCallMethod.name == "match"
}

override fun parse(stageCall: PsiMethodCallExpression): Node<PsiElement> {
val filterExpression = stageCall.argumentList.expressions.getOrNull(0)
?: return createMatchStageNode(source = stageCall)
return createMatchStageNode(
source = stageCall,
filters = parseFilters(filterExpression)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodCallExpression
import com.mongodb.jbplugin.mql.Node

interface StageParser {
fun canParse(stageCallMethod: PsiMethod): Boolean
fun parse(stageCall: PsiMethodCallExpression): Node<PsiElement>
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
package com.mongodb.jbplugin.dialects.springcriteria.aggregationparser
package com.mongodb.jbplugin.dialects.springcriteria

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.mongodb.jbplugin.dialects.springcriteria.IntegrationTest
import com.mongodb.jbplugin.dialects.springcriteria.ParsingTest
import com.mongodb.jbplugin.dialects.springcriteria.SpringCriteriaDialectParser
import com.mongodb.jbplugin.dialects.springcriteria.assert
import com.mongodb.jbplugin.dialects.springcriteria.collection
import com.mongodb.jbplugin.dialects.springcriteria.component
import com.mongodb.jbplugin.dialects.springcriteria.getQueryAtMethod
import com.mongodb.jbplugin.mql.components.HasAggregation
import com.mongodb.jbplugin.mql.components.HasCollectionReference
import com.mongodb.jbplugin.mql.components.HasSourceDialect
import com.mongodb.jbplugin.mql.components.IsCommand
import org.junit.jupiter.api.Assertions.assertEquals

@IntegrationTest
class AggregationParserTest {
class AggregationStagesParserTest {
@ParsingTest(
fileName = "Book.java",
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.mongodb.jbplugin.dialects.springcriteria.aggregationparser
package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
Expand Down

0 comments on commit fda709b

Please sign in to comment.