diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/springcriteria/SpringCriteriaRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/springcriteria/SpringCriteriaRepository.java index 8d02cc7b..d3229d3b 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/springcriteria/SpringCriteriaRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/springcriteria/SpringCriteriaRepository.java @@ -2,7 +2,9 @@ import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.Fields; import org.springframework.data.mongodb.core.mapping.Document; +import org.springframework.data.mongodb.core.query.Field; import java.util.List; @@ -32,9 +34,8 @@ private List allMoviesWithRatingAtLeast(int rating) { private List allMoviesWithRatingAtLeastAgg(int rating) { return template.aggregate( Aggregation.newAggregation( - Aggregation.match( - where( "tomatoes.viewer.rating").gte(rating) - ) + Aggregation.match(where( "tomatoes.viewer.rating").gte(rating)), + Aggregation.project("asd").andInclude("asd").andExclude("qwe") ), Movie.class, Movie.class diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index d60f294c..603ee161 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -267,7 +267,7 @@ object JavaDriverDialectParser : DialectParser { return null // empty, do nothing } - val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0]) + val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression() // if it's only 2 arguments it can be either: // - in(field, singleElement) -> valid because of varargs, becomes a single element array // - in(field, array) -> valid because of varargs @@ -343,7 +343,7 @@ object JavaDriverDialectParser : DialectParser { return null } val fieldExpression = filter.argumentList.expressions[0] - val fieldReference = resolveFieldNameFromExpression(fieldExpression) + val fieldReference = fieldExpression.resolveFieldNameFromExpression() val valueReference = HasValueReference.Inferred( source = fieldExpression, value = true, @@ -360,7 +360,7 @@ object JavaDriverDialectParser : DialectParser { ) } else if (method.parameters.size == 2) { // If it has two parameters, it's field/value. - val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0]) + val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression() val valueReference = resolveValueFromExpression(filter.argumentList.expressions[1]) return Node( @@ -393,7 +393,7 @@ object JavaDriverDialectParser : DialectParser { ) } else if (method.parameters.size == 2) { // If it has two parameters, it's field/value. - val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0]) + val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression() val valueReference = resolveValueFromExpression(filter.argumentList.expressions[1]) return Node( @@ -410,7 +410,7 @@ object JavaDriverDialectParser : DialectParser { ) } else if (method.parameters.size == 1) { // Updates.unset for example - val fieldReference = resolveFieldNameFromExpression(filter.argumentList.expressions[0]) + val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression() return Node( filter, @@ -613,7 +613,7 @@ object JavaDriverDialectParser : DialectParser { "ascending", "descending" -> expression.getVarArgsOrIterableArgs() .mapNotNull { - val fieldReference = resolveFieldNameFromExpression(it) + val fieldReference = it.resolveFieldNameFromExpression() val methodName = Name.from(methodCall.name) when (fieldReference) { is FromSchema -> Node( @@ -818,16 +818,6 @@ object JavaDriverDialectParser : DialectParser { callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN } - private fun resolveFieldNameFromExpression(expression: PsiExpression): HasFieldReference.FieldReference { - val fieldNameAsString = expression.tryToResolveAsConstantString() - val fieldReference = - fieldNameAsString?.let { - FromSchema(expression, it) - } ?: HasFieldReference.Unknown - - return fieldReference - } - private fun resolveValueFromExpression(expression: PsiExpression): HasValueReference.ValueReference { val (wasResolvedAtCompileTime, resolvedValue) = expression.tryToResolveAsConstant() @@ -907,6 +897,16 @@ object JavaDriverDialectParser : DialectParser { } } +fun PsiExpression.resolveFieldNameFromExpression(): HasFieldReference.FieldReference { + val fieldNameAsString = tryToResolveAsConstantString() + val fieldReference = + fieldNameAsString?.let { + FromSchema(this, it) + } ?: HasFieldReference.Unknown + + return fieldReference +} + fun PsiExpressionList.inferFromSingleArrayArgument(start: Int = 0): HasValueReference.ValueReference { val arrayArg = expressions[start] val (constant, value) = arrayArg.tryToResolveAsConstant() diff --git a/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParser.kt b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParser.kt new file mode 100644 index 00000000..db9ab95f --- /dev/null +++ b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParser.kt @@ -0,0 +1,69 @@ +package com.mongodb.jbplugin.dialects.springcriteria + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiMethodCallExpression +import com.mongodb.jbplugin.dialects.javadriver.glossary.fuzzyResolveMethod +import com.mongodb.jbplugin.dialects.javadriver.glossary.resolveToMethodCallExpression +import com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers.MatchStageParser +import com.mongodb.jbplugin.mql.Node + +/** + * Parser for parsing supported patterns of writing an aggregation pipeline. + * Supported patterns for writing aggregation calls are the following: + * 1. MongoTemplate.aggregate() + * 2. MongoTemplate.aggregateStream() + * + * The AggregationParser concerns itself only with parsing the aggregation related semantics and + * leave the rest as a responsibility for the composing unit. + */ +class AggregationStagesParser(private val matchStageParser: MatchStageParser) { + private fun isStageCall(stageCallMethod: PsiMethod): Boolean { + return matchStageParser.canParse(stageCallMethod) + } + + private fun parseAggregationStages( + newAggregationCall: PsiMethodCallExpression + ): List> { + val newAggregationCallArguments = newAggregationCall.argumentList.expressions + val resolvedStageCalls = newAggregationCallArguments.mapNotNull { stageCallExpression -> + stageCallExpression.resolveToMethodCallExpression { _, stageCallMethod -> + isStageCall(stageCallMethod) + } + } + return resolvedStageCalls.map { stageCall -> + val stageCallMethod = stageCall.fuzzyResolveMethod() ?: return@map Node( + source = stageCall, + components = emptyList() + ) + + if (matchStageParser.canParse(stageCallMethod)) { + matchStageParser.parse(stageCall) + } else { + Node( + source = stageCall, + components = emptyList() + ) + } + } + } + + fun parse(aggregateRootCall: PsiMethodCallExpression): List> { + val aggregateRootCallArguments = aggregateRootCall.argumentList.expressions + + val newAggregationCallExpression = aggregateRootCallArguments.getOrNull(0) + ?: return emptyList() + + // This is the call to Aggregation.newAggregation method which is generally the first + // argument to the root aggregate call. All the aggregation stages are to be found as + // the argument to this method call. + val newAggregationCall = newAggregationCallExpression.resolveToMethodCallExpression { + _, + method + -> + method.name == "newAggregation" + } ?: return emptyList() + + return parseAggregationStages(newAggregationCall) + } +} diff --git a/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/SpringCriteriaDialectParser.kt b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/SpringCriteriaDialectParser.kt index d92b1d93..1241315b 100644 --- a/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/SpringCriteriaDialectParser.kt +++ b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/SpringCriteriaDialectParser.kt @@ -10,6 +10,7 @@ import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtract import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromQueryChain import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromStringTypeParameter import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.or +import com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers.MatchStageParser import com.mongodb.jbplugin.mql.BsonAny import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.Node @@ -18,11 +19,8 @@ import com.mongodb.jbplugin.mql.toBsonType private const val CRITERIA_CLASS_FQN = "org.springframework.data.mongodb.core.query.Criteria" private const val DOCUMENT_FQN = "org.springframework.data.mongodb.core.mapping.Document" -private const val AGGREGATE_FQN = "org.springframework.data.mongodb.core.aggregation.Aggregation" private const val MONGO_TEMPLATE_FQN = "org.springframework.data.mongodb.core.MongoTemplate" -private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( - "match" -) +const val AGGREGATE_FQN = "org.springframework.data.mongodb.core.aggregation.Aggregation" object SpringCriteriaDialectParser : DialectParser { override fun isCandidateForQuery(source: PsiElement) = @@ -218,19 +216,6 @@ object SpringCriteriaDialectParser : DialectParser { ) "aggregate", "aggregateStream" -> { val expressions = mongoOpCall.argumentList.expressions - val newAggregationCall = expressions.getOrNull(0)?.resolveToMethodCallExpression { - _, - method - -> - method.name == "newAggregation" - } - val resolvedStageCallExpression = - newAggregationCall?.getVarArgsOrIterableArgs()?.mapNotNull { - it.resolveToMethodCallExpression { _, method -> - method.containingClass?.qualifiedName == AGGREGATE_FQN && - PARSEABLE_AGGREGATION_STAGE_METHODS.contains(method.name) - } - } ?: emptyList() val collectionExpression = expressions.getOrNull(1) return Node( mongoOpCall, @@ -247,7 +232,9 @@ object SpringCriteriaDialectParser : DialectParser { extractCollectionFromStringTypeParameter(collectionExpression) ), HasAggregation( - parseAggregationStagesFromCurrentCall(resolvedStageCallExpression) + children = AggregationStagesParser( + matchStageParser = MatchStageParser(::parseFilterRecursively) + ).parse(mongoOpCall) ) ) ) @@ -271,7 +258,7 @@ object SpringCriteriaDialectParser : DialectParser { } override fun isReferenceToDatabase(source: PsiElement): Boolean { - return false // databases are in property files and we don't support AC there yet + return false // databases are in property files, and we don't support AC there yet } override fun isReferenceToCollection(source: PsiElement): Boolean { @@ -336,34 +323,6 @@ object SpringCriteriaDialectParser : DialectParser { ) } - private fun parseAggregationStagesFromCurrentCall( - stageCallExpressions: List - ): List> { - return stageCallExpressions.mapNotNull { stageCall -> - val stageMethod = stageCall.fuzzyResolveMethod() ?: return@mapNotNull Node( - source = stageCall, - components = emptyList() - ) - - when (stageMethod.name) { - "match" -> parseMatchStageCall(stageCall) - else -> null - } - } - } - - private fun parseMatchStageCall(matchStageCall: PsiMethodCallExpression): Node { - return Node( - source = matchStageCall, - components = listOf( - Named(Name.MATCH), - HasFilter( - parseFilterRecursively(matchStageCall.argumentList.expressions.getOrNull(0)) - ) - ) - ) - } - private fun parseFilterRecursively( valueFilterExpression: PsiElement? ): List> { diff --git a/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParser.kt b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParser.kt new file mode 100644 index 00000000..f81b1e36 --- /dev/null +++ b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParser.kt @@ -0,0 +1,39 @@ +package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiMethodCallExpression +import com.mongodb.jbplugin.dialects.springcriteria.AGGREGATE_FQN +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named + +class MatchStageParser( + private val parseFilters: (PsiElement) -> List> +) : StageParser { + private fun createMatchStageNode( + source: PsiElement, + filters: List> = emptyList() + ) = Node( + source = source, + components = listOf( + Named(Name.MATCH), + HasFilter(filters), + ) + ) + + override fun canParse(stageCallMethod: PsiMethod): Boolean { + return stageCallMethod.containingClass?.qualifiedName == AGGREGATE_FQN && + stageCallMethod.name == "match" + } + + override fun parse(stageCall: PsiMethodCallExpression): Node { + val filterExpression = stageCall.argumentList.expressions.getOrNull(0) + ?: return createMatchStageNode(source = stageCall) + return createMatchStageNode( + source = stageCall, + filters = parseFilters(filterExpression) + ) + } +} diff --git a/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/StageParser.kt b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/StageParser.kt new file mode 100644 index 00000000..e0f504cc --- /dev/null +++ b/packages/mongodb-dialects/spring-criteria/src/main/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/StageParser.kt @@ -0,0 +1,11 @@ +package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiMethodCallExpression +import com.mongodb.jbplugin.mql.Node + +interface StageParser { + fun canParse(stageCallMethod: PsiMethod): Boolean + fun parse(stageCall: PsiMethodCallExpression): Node +} diff --git a/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/AggregationParserTest.kt b/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParserTest.kt similarity index 94% rename from packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/AggregationParserTest.kt rename to packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParserTest.kt index 7ca58924..c345f32b 100644 --- a/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/AggregationParserTest.kt +++ b/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/AggregationStagesParserTest.kt @@ -1,14 +1,7 @@ -package com.mongodb.jbplugin.dialects.springcriteria.aggregationparser +package com.mongodb.jbplugin.dialects.springcriteria import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile -import com.mongodb.jbplugin.dialects.springcriteria.IntegrationTest -import com.mongodb.jbplugin.dialects.springcriteria.ParsingTest -import com.mongodb.jbplugin.dialects.springcriteria.SpringCriteriaDialectParser -import com.mongodb.jbplugin.dialects.springcriteria.assert -import com.mongodb.jbplugin.dialects.springcriteria.collection -import com.mongodb.jbplugin.dialects.springcriteria.component -import com.mongodb.jbplugin.dialects.springcriteria.getQueryAtMethod import com.mongodb.jbplugin.mql.components.HasAggregation import com.mongodb.jbplugin.mql.components.HasCollectionReference import com.mongodb.jbplugin.mql.components.HasSourceDialect @@ -16,7 +9,7 @@ import com.mongodb.jbplugin.mql.components.IsCommand import org.junit.jupiter.api.Assertions.assertEquals @IntegrationTest -class AggregationParserTest { +class AggregationStagesParserTest { @ParsingTest( fileName = "Book.java", """ diff --git a/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/MatchStageParserTest.kt b/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParserTest.kt similarity index 99% rename from packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/MatchStageParserTest.kt rename to packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParserTest.kt index d5348193..a173bf99 100644 --- a/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationparser/MatchStageParserTest.kt +++ b/packages/mongodb-dialects/spring-criteria/src/test/kotlin/com/mongodb/jbplugin/dialects/springcriteria/aggregationstageparsers/MatchStageParserTest.kt @@ -1,4 +1,4 @@ -package com.mongodb.jbplugin.dialects.springcriteria.aggregationparser +package com.mongodb.jbplugin.dialects.springcriteria.aggregationstageparsers import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile