From 28a69f838c9bd9f5012aa3ca70db5855a5ed49ab Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Thu, 21 Nov 2024 18:50:36 +0100 Subject: [PATCH 01/10] chore: poc on aggregations --- .../javadriver/JavaDriverRepository.java | 6 +- .../glossary/JavaDriverDialectParser.kt | 106 +++++++++++++++++- .../mongosh/MongoshDialectFormatter.kt | 4 + .../mql/components/HasAccumulatedFields.kt | 8 ++ .../mql/components/HasFieldReference.kt | 20 ++++ .../mql/components/HasValueReference.kt | 29 ++++- .../mongodb/jbplugin/mql/components/Named.kt | 10 ++ 7 files changed, 178 insertions(+), 5 deletions(-) create mode 100644 packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 2c22b3d0..8a33cddc 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -1,6 +1,7 @@ package alt.mongodb.javadriver; import com.mongodb.client.MongoClient; +import com.mongodb.client.model.Accumulators; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Projections; @@ -52,7 +53,10 @@ public List queryMoviesByYear(String year) { .aggregate( List.of( Aggregates.match( - Filters.eq("year", year) + Filters.eq("_id", year) + ), + Aggregates.group( + "newField", Accumulators.avg("test", "$year") ), Aggregates.project( Projections.fields( diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 2da221c8..078a69b1 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -14,6 +14,8 @@ import com.mongodb.jbplugin.mql.BsonInt32 import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.Node import com.mongodb.jbplugin.mql.components.* +import com.mongodb.jbplugin.mql.components.HasFieldReference.Computed +import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema import com.mongodb.jbplugin.mql.flattenAnyOfReferences import com.mongodb.jbplugin.mql.toBsonType @@ -23,6 +25,7 @@ private const val FILTERS_FQN = "com.mongodb.client.model.Filters" private const val UPDATES_FQN = "com.mongodb.client.model.Updates" private const val AGGREGATES_FQN = "com.mongodb.client.model.Aggregates" private const val PROJECTIONS_FQN = "com.mongodb.client.model.Projections" +private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators" private const val JAVA_LIST_FQN = "java.util.List" private const val JAVA_ARRAYS_FQN = "java.util.Arrays" private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( @@ -458,6 +461,32 @@ object JavaDriverDialectParser : DialectParser { return nodeWithProjections(parsedProjections) } + "group" -> { + // the first parameter of group is going to be a string expression + val groupArgument = stageCall.argumentList.expressions.getOrNull(0) + ?: return null + + val groupFieldValueExpression = parseComputedExpression(groupArgument) + + val nodeWithAccumulators: (List>) -> Node = { accFields: List> -> + Node( + stageCall, + listOf( + HasFieldReference( + HasFieldReference.Inferred(groupArgument, "_id", "_id") + ), + groupFieldValueExpression, + HasAccumulatedFields(accFields) + ) + ) + } + + val accumulators = stageCall.getVarArgsOrIterableArgs().drop(1) + .mapNotNull { resolveToAccumulatorCall(it) } + + val parsedAccumulators = accumulators.mapNotNull { parseAccumulatorExpression(it) } + return nodeWithAccumulators(parsedAccumulators) + } else -> return null } @@ -469,6 +498,37 @@ object JavaDriverDialectParser : DialectParser { } } + private fun resolveToAccumulatorCall(element: PsiElement): PsiMethodCallExpression? { + return element.resolveToMethodCallExpression { _, methodCall -> + methodCall.containingClass?.qualifiedName == ACCUMULATORS_FQN + } + } + + private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + return HasValueReference( + when (val expression = element.tryToResolveAsConstantString()) { + null -> HasValueReference.Unknown as HasValueReference.ValueReference + else -> HasValueReference.Computed( + element, + Node( + element, + listOf( + if (createsNewField) { + HasFieldReference( + Computed(element, expression.trim('$'), expression) + ) + } else { + HasFieldReference( + FromSchema(element, expression.trim('$'), expression) + ) + } + ) + ) + ) + } + ) + } + private fun parseProjectionExpression(expression: PsiMethodCallExpression): List> { val methodCall = expression.resolveMethod() ?: return emptyList() return when (methodCall.name) { @@ -482,8 +542,7 @@ object JavaDriverDialectParser : DialectParser { val fieldReference = resolveFieldNameFromExpression(it) val methodName = Name.from(methodCall.name) when (fieldReference) { - is HasFieldReference.Unknown -> null - is HasFieldReference.FromSchema -> Node( + is FromSchema -> Node( source = it, components = listOf( Named(methodName), @@ -496,7 +555,7 @@ object JavaDriverDialectParser : DialectParser { ) ) ) - ) + ) else -> null } } @@ -504,6 +563,47 @@ object JavaDriverDialectParser : DialectParser { } } + private fun parseAccumulatorExpression(expression: PsiMethodCallExpression): Node? { + val methodCall = expression.fuzzyResolveMethod() ?: return null + + return when (methodCall.name) { + "sum" -> parseKeyValAccumulator(expression, Name.SUM) + "avg" -> parseKeyValAccumulator(expression, Name.AVG) + "first" -> parseKeyValAccumulator(expression, Name.FIRST) + "last" -> parseKeyValAccumulator(expression, Name.LAST) + "top" -> null + "bottom" -> null + "max" -> parseKeyValAccumulator(expression, Name.MAX) + "min" -> parseKeyValAccumulator(expression, Name.MIN) + "push" -> parseKeyValAccumulator(expression, Name.PUSH) + "addToSet" -> parseKeyValAccumulator(expression, Name.ADD_TO_SET) + else -> null + } + } + + private fun parseKeyValAccumulator(expression: PsiMethodCallExpression, name: Name): Node? { + val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null + val valueExpr = expression.argumentList.expressions.getOrNull(1) ?: return null + + val fieldName = keyExpr.tryToResolveAsConstantString() + val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + + return Node( + expression, + listOf( + Named(name), + HasFieldReference( + if (fieldName != null) { + Computed(keyExpr, fieldName, fieldName) + } else { + HasFieldReference.Unknown as HasFieldReference.FieldReference + } + ), + accumulatorExpr + ) + ) + } + private fun hasMongoDbSessionReference(methodCall: PsiMethodCallExpression): Boolean { val hasEnoughArgs = methodCall.argumentList.expressionCount > 0 if (!hasEnoughArgs) { diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt index 3a9eb8b7..e34b591d 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt @@ -6,7 +6,9 @@ import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend import com.mongodb.jbplugin.mql.* import com.mongodb.jbplugin.mql.components.* import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema +import com.mongodb.jbplugin.mql.components.HasFieldReference.Inferred import com.mongodb.jbplugin.mql.components.HasFieldReference.Unknown +import com.mongodb.jbplugin.mql.components.HasValueReference.Computed import com.mongodb.jbplugin.mql.parser.anyError import com.mongodb.jbplugin.mql.parser.components.aggregationStages import com.mongodb.jbplugin.mql.parser.components.allFiltersRecursively @@ -341,6 +343,8 @@ private fun MongoshBackend.resolveValueReference( private fun MongoshBackend.resolveFieldReference(fieldRef: HasFieldReference) = when (val ref = fieldRef.reference) { is FromSchema -> registerConstant(ref.fieldName) + is Inferred -> registerConstant(ref.fieldName) + is HasFieldReference.Computed -> registerConstant(ref.fieldName) is Unknown -> registerVariable("field", BsonAny) } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt new file mode 100644 index 00000000..d56c85bb --- /dev/null +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt @@ -0,0 +1,8 @@ +package com.mongodb.jbplugin.mql.components + +import com.mongodb.jbplugin.mql.HasChildren +import com.mongodb.jbplugin.mql.Node + +data class HasAccumulatedFields( + override val children: List> +) : HasChildren diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt index 18dffeea..13633c5c 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt @@ -23,4 +23,24 @@ data class HasFieldReference( val fieldName: String, val displayName: String = fieldName, ) : FieldReference + + /** + * Encodes a FieldReference that is part of a schema, but it's not defined + * in code. For example, the _id field that is created on { $group: "$expr" }. + */ + data class Inferred( + val source: S, + val fieldName: String, + val displayName: String = fieldName, + ) : FieldReference + + /** + * Encodes a FieldReference that is part of a schema, but it's not defined + * in code. For example, the _id field that is created on { $group: "$expr" }. + */ + data class Computed( + val source: S, + val fieldName: String, + val displayName: String = fieldName, + ) : FieldReference } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 3961a45b..e7c9002c 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,11 +1,14 @@ package com.mongodb.jbplugin.mql.components +import com.mongodb.jbplugin.mql.BsonAny import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.Component +import com.mongodb.jbplugin.mql.HasChildren +import com.mongodb.jbplugin.mql.Node data class HasValueReference( val reference: ValueReference, -) : Component { +) : Component, HasChildren { sealed interface ValueReference @@ -56,4 +59,28 @@ data class HasValueReference( val source: S, val type: BsonType, ) : ValueReference + + /** + * Encodes a ValueReference when the value is computed on the server side. For example + * for $group stages in an aggregation pipeline: + * ``` + * Aggregates.group("$year") //-> Computed(Node(HasFieldReference(...))) + * ``` + * The computedExpression can not be known, so we don't have a BsonType attached to it. + * + * Unless it can be inferred from the expression (not implemented), we will assume it's + * BsonAny + */ + data class Computed( + val source: S, + val expression: Node, + ) : ValueReference { + val type: BsonType = BsonAny + } + + override val children: List> + get() = when (reference) { + is Computed -> listOf(reference.expression) + else -> emptyList() + } } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index 28aa0f97..8c01ec9b 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -52,6 +52,16 @@ enum class Name(val canonical: String) { PROJECT("project"), INCLUDE("include"), EXCLUDE("exclude"), + SUM("sum"), + AVG("avg"), + FIRST("first"), + LAST("last"), + TOP("top"), + BOTTOM("bottom"), + MAX("max"), + MIN("min"), + PUSH("push"), + ADD_TO_SET("addToSet"), UNKNOWN(""), ; From 6c65c9f5edb7985f642cf04130c86c12d1354f0f Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 14:01:44 +0100 Subject: [PATCH 02/10] chore: now missing tests and small fixes --- .../javadriver/JavaDriverRepository.java | 7 +- .../glossary/JavaDriverDialectParser.kt | 81 +++++++++++++++---- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index e32c475d..8bb3f203 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -54,10 +54,13 @@ public List queryMoviesByYear(String year) { .aggregate( List.of( Aggregates.match( - Filters.eq("_id", year) + Filters.eq("year", year) ), Aggregates.group( - "newField", Accumulators.avg("test", "$year") + "newField", + Accumulators.avg("test", "$xxx"), + Accumulators.sum("test2", "$year"), + Accumulators.bottom("field", Sorts.ascending("year"), "$year") ), Aggregates.project( Projections.fields( diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index e0a18fcb..b3d13707 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -29,6 +29,7 @@ private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators" private const val SORTS_FQN = "com.mongodb.client.model.Sorts" private const val JAVA_LIST_FQN = "java.util.List" private const val JAVA_ARRAYS_FQN = "java.util.Arrays" + private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( "match", "project", @@ -228,7 +229,8 @@ object JavaDriverDialectParser : DialectParser { return containingClass.qualifiedName == AGGREGATES_FQN || containingClass.qualifiedName == PROJECTIONS_FQN || - containingClass.qualifiedName == SORTS_FQN + containingClass.qualifiedName == SORTS_FQN || + containingClass.qualifiedName == ACCUMULATORS_FQN } private fun resolveBsonBuilderCall( @@ -297,7 +299,7 @@ object JavaDriverDialectParser : DialectParser { } val valueExpression = filter.argumentList.expressions[0] val valueReference = resolveValueFromExpression(valueExpression) - val fieldReference = HasFieldReference.FromSchema(valueExpression, "_id") + val fieldReference = FromSchema(valueExpression, "_id") return Node( filter, @@ -505,7 +507,7 @@ object JavaDriverDialectParser : DialectParser { } val accumulators = stageCall.getVarArgsOrIterableArgs().drop(1) - .mapNotNull { resolveToAccumulatorCall(it) } + .mapNotNull { resolveBsonBuilderCall(it, ACCUMULATORS_FQN) } val parsedAccumulators = accumulators.mapNotNull { parseAccumulatorExpression(it) } return nodeWithAccumulators(parsedAccumulators) @@ -515,13 +517,10 @@ object JavaDriverDialectParser : DialectParser { } } - private fun resolveToProjectionCall(element: PsiElement): PsiMethodCallExpression? { - return element.resolveToMethodCallExpression { _, methodCall -> - methodCall.containingClass?.qualifiedName == PROJECTIONS_FQN - } - } - - private fun parseProjectionExpression(expression: PsiMethodCallExpression): List> { + private fun parseBsonBuilderCallsSimilarToProjections( + expression: PsiMethodCallExpression, + classQualifiedName: String + ): List> { val methodCall = expression.resolveMethod() ?: return emptyList() return when (methodCall.name) { "fields", @@ -556,7 +555,8 @@ object JavaDriverDialectParser : DialectParser { ) ) ) - ) else -> null + ) + else -> null } } @@ -572,8 +572,8 @@ object JavaDriverDialectParser : DialectParser { "avg" -> parseKeyValAccumulator(expression, Name.AVG) "first" -> parseKeyValAccumulator(expression, Name.FIRST) "last" -> parseKeyValAccumulator(expression, Name.LAST) - "top" -> null - "bottom" -> null + "top" -> parseLeadingAccumulatorExpression(expression, Name.TOP) + "bottom" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM) "max" -> parseKeyValAccumulator(expression, Name.MAX) "min" -> parseKeyValAccumulator(expression, Name.MIN) "push" -> parseKeyValAccumulator(expression, Name.PUSH) @@ -605,6 +605,34 @@ object JavaDriverDialectParser : DialectParser { ) } + private fun parseLeadingAccumulatorExpression(expression: PsiMethodCallExpression, name: Name): Node? { + val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null + val sortExprArgument = expression.argumentList.expressions.getOrNull(1) ?: return null + val valueExpr = expression.argumentList.expressions.getOrNull(2) ?: return null + + val sortExpr = resolveBsonBuilderCall(sortExprArgument, SORTS_FQN) ?: return null + + val fieldName = keyExpr.tryToResolveAsConstantString() + val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN) + val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + + return Node( + expression, + listOf( + Named(name), + HasFieldReference( + if (fieldName != null) { + Computed(keyExpr, fieldName, fieldName) + } else { + HasFieldReference.Unknown as HasFieldReference.FieldReference + } + ), + HasSorts(sort), + accumulatorExpr + ) + ) + } + private fun hasMongoDbSessionReference(methodCall: PsiMethodCallExpression): Boolean { val hasEnoughArgs = methodCall.argumentList.expressionCount > 0 if (!hasEnoughArgs) { @@ -615,6 +643,31 @@ object JavaDriverDialectParser : DialectParser { return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN) } + private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + return HasValueReference( + when (val expression = element.tryToResolveAsConstantString()) { + null -> HasValueReference.Unknown as HasValueReference.ValueReference + else -> HasValueReference.Computed( + element, + Node( + element, + listOf( + if (createsNewField) { + HasFieldReference( + Computed(element, expression.trim('$'), expression) + ) + } else { + HasFieldReference( + FromSchema(element, expression.trim('$'), expression) + ) + } + ) + ) + ) + } + ) + } + private fun isAggregationStageMethodCall(callMethod: PsiMethod?): Boolean { return PARSEABLE_AGGREGATION_STAGE_METHODS.contains(callMethod?.name) && callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN @@ -624,7 +677,7 @@ object JavaDriverDialectParser : DialectParser { val fieldNameAsString = expression.tryToResolveAsConstantString() val fieldReference = fieldNameAsString?.let { - HasFieldReference.FromSchema(expression, it) + FromSchema(expression, it) } ?: HasFieldReference.Unknown return fieldReference From 9bfbddf17ce64cd0ad256ba22380d3c2fce10d9a Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 16:23:33 +0100 Subject: [PATCH 03/10] chore: merge with main and add tests --- .../glossary/JavaDriverDialectParser.kt | 1 + .../aggregationparser/GroupStageParserTest.kt | 158 ++++++++++++++++++ .../mongodb/jbplugin/mql/components/Named.kt | 1 + 3 files changed, 160 insertions(+) create mode 100644 packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index b3d13707..d327f02d 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -497,6 +497,7 @@ object JavaDriverDialectParser : DialectParser { Node( stageCall, listOf( + Named(Name.GROUP), HasFieldReference( HasFieldReference.Inferred(groupArgument, "_id", "_id") ), diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt new file mode 100644 index 00000000..dbcbfd9d --- /dev/null +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -0,0 +1,158 @@ +package com.mongodb.jbplugin.dialects.javadriver.glossary.aggregationparser + +import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest +import com.mongodb.jbplugin.dialects.javadriver.ParsingTest +import com.mongodb.jbplugin.dialects.javadriver.WithFile +import com.mongodb.jbplugin.dialects.javadriver.caret +import com.mongodb.jbplugin.dialects.javadriver.getQueryAtMethod +import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect +import com.mongodb.jbplugin.mql.components.HasAccumulatedFields +import com.mongodb.jbplugin.mql.components.HasAggregation +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource + +@IntegrationTest +class GroupStageParserTest { + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField") + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.Computed + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators."|"("myKey", "myVal")) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + } + } +} diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index bf4e5abc..46f9500a 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -52,6 +52,7 @@ enum class Name(val canonical: String) { PROJECT("project"), INCLUDE("include"), EXCLUDE("exclude"), + GROUP("group"), SUM("sum"), AVG("avg"), FIRST("first"), From b6be79d3e507fd4f290544e8d265f00fc3a886cf Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 16:33:55 +0100 Subject: [PATCH 04/10] chore: added tests left --- .../aggregationparser/GroupStageParserTest.kt | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt index dbcbfd9d..92911302 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -14,6 +14,7 @@ import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect import com.mongodb.jbplugin.mql.components.HasAccumulatedFields import com.mongodb.jbplugin.mql.components.HasAggregation import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasSorts import com.mongodb.jbplugin.mql.components.HasValueReference import com.mongodb.jbplugin.mql.components.Name import com.mongodb.jbplugin.mql.components.Named @@ -153,6 +154,90 @@ public final class Aggregation { val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts;import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators."|"("myKey", Sorts.ascending("mySort"), "myVal")) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "top;;TOP", + "bottom;;BOTTOM", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators with sorting criteria from the driver`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + + val accumulatorSorting = accumulator.component>()!!.children[0] + val sortingField = accumulatorSorting.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("mySort", sortingField.fieldName) } } } From 5a9d4269ffadc0ec0634aba7e9d87820f33eddca Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 17:28:18 +0100 Subject: [PATCH 05/10] chore: use computed type instead of computed expression --- .../javadriver/JavaDriverRepository.java | 2 +- .../glossary/JavaDriverDialectParser.kt | 28 +++++++++++-------- .../aggregationparser/GroupStageParserTest.kt | 6 ++-- .../mongosh/backend/MongoshBackend.kt | 1 + .../com/mongodb/jbplugin/mql/BsonType.kt | 3 ++ .../mql/components/HasValueReference.kt | 10 +++---- 6 files changed, 28 insertions(+), 22 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 8bb3f203..86c11f7c 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -58,7 +58,7 @@ public List queryMoviesByYear(String year) { ), Aggregates.group( "newField", - Accumulators.avg("test", "$xxx"), + Accumulators.avg("test", "$year"), Accumulators.sum("test2", "$year"), Accumulators.bottom("field", Sorts.ascending("year"), "$year") ), diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index d327f02d..22011b10 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -12,6 +12,7 @@ import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 import com.mongodb.jbplugin.mql.BsonType +import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node import com.mongodb.jbplugin.mql.components.* import com.mongodb.jbplugin.mql.components.HasFieldReference.Computed @@ -650,18 +651,21 @@ object JavaDriverDialectParser : DialectParser { null -> HasValueReference.Unknown as HasValueReference.ValueReference else -> HasValueReference.Computed( element, - Node( - element, - listOf( - if (createsNewField) { - HasFieldReference( - Computed(element, expression.trim('$'), expression) - ) - } else { - HasFieldReference( - FromSchema(element, expression.trim('$'), expression) - ) - } + type = ComputedBsonType( + BsonAny, + Node( + element, + listOf( + if (createsNewField) { + HasFieldReference( + Computed(element, expression.trim('$'), expression) + ) + } else { + HasFieldReference( + FromSchema(element, expression.trim('$'), expression) + ) + } + ) ) ) ) diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt index 92911302..ccad05b5 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -72,7 +72,7 @@ public final class Aggregation { assertEquals("_id", idFieldRef.fieldName) assertEquals(0, accumulatedFields.children.size) - val computedExpression = computedValueRef.expression + val computedExpression = computedValueRef.type.expression val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.Computed assertEquals("myField", fieldUsedForComputation.fieldName) @@ -156,7 +156,7 @@ public final class Aggregation { assertEquals("myKey", accumulatorField.fieldName) val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed - val accumulatorComputedFieldValue = accumulatorComputed.expression.component>()!!.reference as HasFieldReference.FromSchema + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema assertEquals("myVal", accumulatorComputedFieldValue.fieldName) } } @@ -232,7 +232,7 @@ public final class Aggregation { assertEquals("myKey", accumulatorField.fieldName) val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed - val accumulatorComputedFieldValue = accumulatorComputed.expression.component>()!!.reference as HasFieldReference.FromSchema + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema assertEquals("myVal", accumulatorComputedFieldValue.fieldName) val accumulatorSorting = accumulator.component>()!!.children[0] diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt index daee46c9..1b106bb5 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt @@ -251,4 +251,5 @@ private fun defaultValueOfBsonType(type: BsonType): Any? = when (type) { is BsonObject -> emptyMap() BsonObjectId -> ObjectId("000000000000000000000000") BsonString -> "" + is ComputedBsonType<*> -> defaultValueOfBsonType(type.baseType) } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt index c538b9d6..bed80c28 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt @@ -179,6 +179,9 @@ data class BsonArray( } } +data class ComputedBsonType(val baseType: BsonType, val expression: Node) : + BsonType by baseType // for now it will behave as baseType + /** * Returns the inferred BSON type of the current Java class, considering it's nullability. * diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index e7c9002c..0f57b0be 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,8 +1,8 @@ package com.mongodb.jbplugin.mql.components -import com.mongodb.jbplugin.mql.BsonAny import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.Component +import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.HasChildren import com.mongodb.jbplugin.mql.Node @@ -73,14 +73,12 @@ data class HasValueReference( */ data class Computed( val source: S, - val expression: Node, - ) : ValueReference { - val type: BsonType = BsonAny - } + val type: ComputedBsonType, + ) : ValueReference override val children: List> get() = when (reference) { - is Computed -> listOf(reference.expression) + is Computed -> listOf(reference.type.expression) else -> emptyList() } } From 6e412c06f5cd46d7b8f47b6168157660548a7576 Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 17:55:24 +0100 Subject: [PATCH 06/10] chore: fix tryToResolveAsConstantString, it should not cast if it's not a string --- .../jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt index 1bd96ad2..a79026c4 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt @@ -353,7 +353,7 @@ fun PsiElement.tryToResolveAsConstant(): Pair { * @return */ fun PsiElement.tryToResolveAsConstantString(): String? = - tryToResolveAsConstant().takeIf { it.first }?.second?.toString() + tryToResolveAsConstant().takeIf { it.first }?.second as? String /** * Maps a PsiType to its BSON counterpart. From c7557ca1df6f163faf78264731f95532329d60d9 Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 18:09:43 +0100 Subject: [PATCH 07/10] chore: be smarter on computed expression, we can use runtime values --- .../glossary/JavaDriverDialectParser.kt | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 22011b10..d17cc51f 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -646,10 +646,10 @@ object JavaDriverDialectParser : DialectParser { } private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + val (constant, value) = element.tryToResolveAsConstant() return HasValueReference( - when (val expression = element.tryToResolveAsConstantString()) { - null -> HasValueReference.Unknown as HasValueReference.ValueReference - else -> HasValueReference.Computed( + when { + constant && value is String -> HasValueReference.Computed( element, type = ComputedBsonType( BsonAny, @@ -658,17 +658,27 @@ object JavaDriverDialectParser : DialectParser { listOf( if (createsNewField) { HasFieldReference( - Computed(element, expression.trim('$'), expression) + Computed(element, value.trim('$'), value) ) } else { HasFieldReference( - FromSchema(element, expression.trim('$'), expression) + FromSchema(element, value.trim('$'), value) ) } ) ) ) ) + constant && value != null -> HasValueReference.Constant( + element, + value, + value.javaClass.toBsonType(value) + ) + !constant && element is PsiExpression -> HasValueReference.Runtime( + element, + element.type?.toBsonType() ?: BsonAny + ) + else -> HasValueReference.Unknown as HasValueReference.ValueReference } ) } From 1f4c4aec575557229c8741f4b618fa7240907a75 Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Mon, 25 Nov 2024 15:13:38 +0100 Subject: [PATCH 08/10] Update packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt Co-authored-by: Himanshu Singh --- .../com/mongodb/jbplugin/mql/components/HasValueReference.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 0f57b0be..62cac7ea 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -8,7 +8,7 @@ import com.mongodb.jbplugin.mql.Node data class HasValueReference( val reference: ValueReference, -) : Component, HasChildren { +) : HasChildren { sealed interface ValueReference From d8a9dc0bf6c9e63b4a38d4e235dc8a6f04e8dd56 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Mon, 25 Nov 2024 17:05:31 +0100 Subject: [PATCH 09/10] chore: tests and minor suggestions for INTELLIJ-128 (#101) --- .../JavaDriverCompletionContributorTest.kt | 131 ++++++ ...erMongoDbAutocompletionPopupHandlerTest.kt | 133 ++++++ ...avaDriverFieldCheckLinterInspectionTest.kt | 121 ++++++ .../glossary/JavaDriverDialectParser.kt | 25 +- .../aggregationparser/GroupStageParserTest.kt | 408 +++++++++++++++++- .../mql/components/HasValueReference.kt | 1 - 6 files changed, 802 insertions(+), 17 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt index 51b51f53..76214b1f 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt @@ -805,4 +805,135 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "" + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", "") + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt index 357c9b5a..2557fa22 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt @@ -818,4 +818,137 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", ) + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index 1da59d4a..eb0d4145 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -9,6 +9,7 @@ import com.mongodb.jbplugin.fixtures.setupConnection import com.mongodb.jbplugin.fixtures.specifyDialect import com.mongodb.jbplugin.mql.BsonDouble import com.mongodb.jbplugin.mql.BsonObject +import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.CollectionSchema import com.mongodb.jbplugin.mql.Namespace import org.mockito.Mockito.`when` @@ -501,4 +502,124 @@ public class Repository { fixture.enableInspections(FieldCheckInspectionBridge::class.java) fixture.testHighlighting() } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.*; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public AggregateIterable goodGroupAggregate1() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group(null), + Aggregates.group("${'$'}possibleIdField"), + Aggregates.group("${'$'}possibleIdField", Accumulators.sum("totalCount", 1)), + Aggregates.group( + "${'$'}possibleIdField", + Accumulators.sum("totalCount", "${'$'}otherField") + ) + )); + } + + private String getOtherField() { + return "${'$'}otherField"; + } + + public AggregateIterable goodGroupAggregate2() { + String fieldName = "${'$'}possibleIdField"; + BsonField totalCountAcc = Accumulators.sum("totalCount", 1); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group(fieldName), + Aggregates.group(fieldName, totalCountAcc), + Aggregates.group( + fieldName, + Accumulators.sum("totalCount", getOtherField()) + ) + )); + } + + private String getBadFieldName() { + return "${'$'}nonExistentField"; + } + + private BsonField getAvgCountAcc() { + return Accumulators.avg( + "avgCount", + getBadFieldName() + ); + } + + public AggregateIterable badGroupAggregate1() { + String badFieldName = "${'$'}nonExistentField"; + BsonField avgCountAcc = Accumulators.avg( + "avgCount", + badFieldName + ); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}nonExistentField" + ), + Aggregates.group( + badFieldName, + Accumulators.sum( + "totalCount", + badFieldName + ), + Accumulators.sum( + "totalCount", + getBadFieldName() + ), + avgCountAcc, + getAvgCountAcc() + ) + )); + } +} + """, + ) + fun `shows an inspection for Aggregates#group call when the field does not exist in the current namespace`( + fixture: CodeInsightTestFixture, + ) { + val (dataSource, readModelProvider) = fixture.setupConnection() + fixture.specifyDialect(JavaDriverDialect) + + `when`( + readModelProvider.slice(eq(dataSource), any()) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + Namespace("myDatabase", "myCollection"), + BsonObject( + mapOf( + "possibleIdField" to BsonString, + "otherField" to BsonString, + ) + ) + ) + ), + ) + + fixture.enableInspections(FieldCheckInspectionBridge::class.java) + fixture.testHighlighting() + } } diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 170dff62..59439f83 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -11,6 +11,7 @@ import com.mongodb.jbplugin.mql.BsonAnyOf import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 +import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node @@ -612,7 +613,7 @@ object JavaDriverDialectParser : DialectParser { val valueExpr = expression.argumentList.expressions.getOrNull(1) ?: return null val fieldName = keyExpr.tryToResolveAsConstantString() - val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + val accumulatorExpr = parseComputedExpression(valueExpr) return Node( expression, @@ -639,7 +640,7 @@ object JavaDriverDialectParser : DialectParser { val fieldName = keyExpr.tryToResolveAsConstantString() val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN) - val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + val accumulatorExpr = parseComputedExpression(valueExpr) return Node( expression, @@ -715,34 +716,28 @@ object JavaDriverDialectParser : DialectParser { return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN) } - private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + private fun parseComputedExpression(element: PsiElement): HasValueReference { val (constant, value) = element.tryToResolveAsConstant() return HasValueReference( when { constant && value is String -> HasValueReference.Computed( element, type = ComputedBsonType( - BsonAny, + BsonString, Node( element, listOf( - if (createsNewField) { - HasFieldReference( - Computed(element, value.trim('$'), value) - ) - } else { - HasFieldReference( - FromSchema(element, value.trim('$'), value) - ) - } + HasFieldReference( + FromSchema(element, value.trim('$'), value) + ) ) ) ) ) - constant && value != null -> HasValueReference.Constant( + constant -> HasValueReference.Constant( element, value, - value.javaClass.toBsonType(value) + value?.javaClass.toBsonType(value) ) !constant && element is PsiExpression -> HasValueReference.Runtime( element, diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt index ccad05b5..1c5a562b 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -39,6 +39,56 @@ import java.util.List; import static com.mongodb.client.model.Filters.*; +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(null) + )); + } +} + """ + ) + fun `should be able to parse a group stage with null _id`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val constantValueRef = groupStage.component>()!!.reference as HasValueReference.Constant + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + assertEquals(null, constantValueRef.value) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + public final class Aggregation { private final MongoCollection collection; @@ -73,7 +123,126 @@ public final class Aggregation { assertEquals(0, accumulatedFields.children.size) val computedExpression = computedValueRef.type.expression - val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.Computed + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String _idField = "${'$'}myField"; + return this.collection.aggregate(List.of( + Aggregates.group(_idField) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getIdField() { + return "${'$'}myField"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(getIdField()) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema assertEquals("myField", fieldUsedForComputation.fieldName) assertEquals("${'$'}myField", fieldUsedForComputation.displayName) @@ -169,6 +338,243 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + BsonField accumulatedExpr = Accumulators."|"("myKey", "myVal"); + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", accumulatedExpr) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is a variable`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private BsonField getAccumulatedExpr() { + return Accumulators."|"("myKey", "myVal"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", getAccumulatedExpr()) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is from a method call`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators.sum("totalCount", 1)) + )); + } +} + """ + ) + fun `supports accumulators also expecting a constant argument`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + + assertEquals(1, accumulatedFields.children.size) + val totalCountNode = accumulatedFields.children.first() + val totalCountFieldRef = + totalCountNode.component>()!!.reference + as HasFieldReference.Computed + val totalCountValueRef = + totalCountNode.component>()!!.reference + as HasValueReference.Constant + + assertEquals("totalCount", totalCountFieldRef.fieldName) + assertEquals(1, totalCountValueRef.value) + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Sorts;import org.bson.Document; import org.bson.types.ObjectId; diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 62cac7ea..8b77fbe9 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,7 +1,6 @@ package com.mongodb.jbplugin.mql.components import com.mongodb.jbplugin.mql.BsonType -import com.mongodb.jbplugin.mql.Component import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.HasChildren import com.mongodb.jbplugin.mql.Node From 7d415e8aaddad84419ecbb859c769e5cd34ed6cc Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Mon, 25 Nov 2024 17:06:21 +0100 Subject: [PATCH 10/10] chore: revert BsonString to BsonAny --- .../dialects/javadriver/glossary/JavaDriverDialectParser.kt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 59439f83..ae4ae6fd 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -11,7 +11,6 @@ import com.mongodb.jbplugin.mql.BsonAnyOf import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 -import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node @@ -723,7 +722,7 @@ object JavaDriverDialectParser : DialectParser { constant && value is String -> HasValueReference.Computed( element, type = ComputedBsonType( - BsonString, + BsonAny, Node( element, listOf(