From 57c62bcc7fc6789ed8d04ee318287e5d85aee5ad Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Mon, 25 Nov 2024 18:21:42 +0100 Subject: [PATCH] feat: support for new accumulators INTELLIJ-128 (#98) --- .../JavaDriverCompletionContributorTest.kt | 131 ++++ ...erMongoDbAutocompletionPopupHandlerTest.kt | 133 ++++ ...avaDriverFieldCheckLinterInspectionTest.kt | 121 ++++ .../javadriver/JavaDriverRepository.java | 7 + .../glossary/JavaDriverDialectParser.kt | 145 +++- .../javadriver/glossary/PsiMdbTreeUtil.kt | 2 +- .../aggregationparser/GroupStageParserTest.kt | 649 ++++++++++++++++++ .../mongosh/MongoshDialectFormatter.kt | 6 +- .../mongosh/backend/MongoshBackend.kt | 1 + .../com/mongodb/jbplugin/mql/BsonType.kt | 3 + .../mql/components/HasAccumulatedFields.kt | 8 + .../mql/components/HasFieldReference.kt | 11 + .../mql/components/HasValueReference.kt | 28 +- .../mongodb/jbplugin/mql/components/Named.kt | 11 + 14 files changed, 1245 insertions(+), 11 deletions(-) create mode 100644 packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt create mode 100644 packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt index 51b51f53..76214b1f 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt @@ -805,4 +805,135 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "" + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", "") + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt index 357c9b5a..2557fa22 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt @@ -818,4 +818,137 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", ) + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index 1da59d4a..eb0d4145 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -9,6 +9,7 @@ import com.mongodb.jbplugin.fixtures.setupConnection import com.mongodb.jbplugin.fixtures.specifyDialect import com.mongodb.jbplugin.mql.BsonDouble import com.mongodb.jbplugin.mql.BsonObject +import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.CollectionSchema import com.mongodb.jbplugin.mql.Namespace import org.mockito.Mockito.`when` @@ -501,4 +502,124 @@ public class Repository { fixture.enableInspections(FieldCheckInspectionBridge::class.java) fixture.testHighlighting() } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.*; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public AggregateIterable goodGroupAggregate1() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group(null), + Aggregates.group("${'$'}possibleIdField"), + Aggregates.group("${'$'}possibleIdField", Accumulators.sum("totalCount", 1)), + Aggregates.group( + "${'$'}possibleIdField", + Accumulators.sum("totalCount", "${'$'}otherField") + ) + )); + } + + private String getOtherField() { + return "${'$'}otherField"; + } + + public AggregateIterable goodGroupAggregate2() { + String fieldName = "${'$'}possibleIdField"; + BsonField totalCountAcc = Accumulators.sum("totalCount", 1); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group(fieldName), + Aggregates.group(fieldName, totalCountAcc), + Aggregates.group( + fieldName, + Accumulators.sum("totalCount", getOtherField()) + ) + )); + } + + private String getBadFieldName() { + return "${'$'}nonExistentField"; + } + + private BsonField getAvgCountAcc() { + return Accumulators.avg( + "avgCount", + getBadFieldName() + ); + } + + public AggregateIterable badGroupAggregate1() { + String badFieldName = "${'$'}nonExistentField"; + BsonField avgCountAcc = Accumulators.avg( + "avgCount", + badFieldName + ); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}nonExistentField" + ), + Aggregates.group( + badFieldName, + Accumulators.sum( + "totalCount", + badFieldName + ), + Accumulators.sum( + "totalCount", + getBadFieldName() + ), + avgCountAcc, + getAvgCountAcc() + ) + )); + } +} + """, + ) + fun `shows an inspection for Aggregates#group call when the field does not exist in the current namespace`( + fixture: CodeInsightTestFixture, + ) { + val (dataSource, readModelProvider) = fixture.setupConnection() + fixture.specifyDialect(JavaDriverDialect) + + `when`( + readModelProvider.slice(eq(dataSource), any()) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + Namespace("myDatabase", "myCollection"), + BsonObject( + mapOf( + "possibleIdField" to BsonString, + "otherField" to BsonString, + ) + ) + ) + ), + ) + + fixture.enableInspections(FieldCheckInspectionBridge::class.java) + fixture.testHighlighting() + } } diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 7b89e06c..86c11f7c 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -1,6 +1,7 @@ package alt.mongodb.javadriver; import com.mongodb.client.MongoClient; +import com.mongodb.client.model.Accumulators; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Projections; @@ -55,6 +56,12 @@ public List queryMoviesByYear(String year) { Aggregates.match( Filters.eq("year", year) ), + Aggregates.group( + "newField", + Accumulators.avg("test", "$year"), + Accumulators.sum("test2", "$year"), + Accumulators.bottom("field", Sorts.ascending("year"), "$year") + ), Aggregates.project( Projections.fields( Projections.include("year", "plot") diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 3c92d2dd..ae4ae6fd 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -12,8 +12,11 @@ import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 import com.mongodb.jbplugin.mql.BsonType +import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node import com.mongodb.jbplugin.mql.components.* +import com.mongodb.jbplugin.mql.components.HasFieldReference.Computed +import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema import com.mongodb.jbplugin.mql.flattenAnyOfReferences import com.mongodb.jbplugin.mql.toBsonType @@ -23,10 +26,12 @@ private const val FILTERS_FQN = "com.mongodb.client.model.Filters" private const val UPDATES_FQN = "com.mongodb.client.model.Updates" private const val AGGREGATES_FQN = "com.mongodb.client.model.Aggregates" private const val PROJECTIONS_FQN = "com.mongodb.client.model.Projections" +private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators" private const val SORTS_FQN = "com.mongodb.client.model.Sorts" private const val FIELD_FQN = "com.mongodb.client.model.Field" private const val JAVA_LIST_FQN = "java.util.List" private const val JAVA_ARRAYS_FQN = "java.util.Arrays" + private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( "match", "project", @@ -225,7 +230,8 @@ object JavaDriverDialectParser : DialectParser { val containingClass = methodCall.resolveMethod()?.containingClass ?: return false if ( containingClass.qualifiedName == PROJECTIONS_FQN || - containingClass.qualifiedName == SORTS_FQN + containingClass.qualifiedName == SORTS_FQN || + containingClass.qualifiedName == ACCUMULATORS_FQN ) { return isInAutoCompletableAggregation(methodCall) } @@ -302,7 +308,7 @@ object JavaDriverDialectParser : DialectParser { } val valueExpression = filter.argumentList.expressions[0] val valueReference = resolveValueFromExpression(valueExpression) - val fieldReference = HasFieldReference.FromSchema(valueExpression, "_id") + val fieldReference = FromSchema(valueExpression, "_id") return Node( filter, @@ -489,6 +495,33 @@ object JavaDriverDialectParser : DialectParser { return nodeWithParsedComponents(parsedComponents) } + "group" -> { + // the first parameter of group is going to be a string expression + val groupArgument = stageCall.argumentList.expressions.getOrNull(0) + ?: return null + + val groupFieldValueExpression = parseComputedExpression(groupArgument) + + val nodeWithAccumulators: (List>) -> Node = { accFields: List> -> + Node( + stageCall, + listOf( + Named(Name.GROUP), + HasFieldReference( + HasFieldReference.Inferred(groupArgument, "_id", "_id") + ), + groupFieldValueExpression, + HasAccumulatedFields(accFields) + ) + ) + } + + val accumulators = stageCall.getVarArgsOrIterableArgs().drop(1) + .mapNotNull { resolveBsonBuilderCall(it, ACCUMULATORS_FQN) } + + val parsedAccumulators = accumulators.mapNotNull { parseAccumulatorExpression(it) } + return nodeWithAccumulators(parsedAccumulators) + } "addFields" -> { // .addFields can have varargs of Field objects or Fields objects in an iterable @@ -528,9 +561,7 @@ object JavaDriverDialectParser : DialectParser { val fieldReference = resolveFieldNameFromExpression(it) val methodName = Name.from(methodCall.name) when (fieldReference) { - is HasFieldReference.Computed, - is HasFieldReference.Unknown -> null - is HasFieldReference.FromSchema -> Node( + is FromSchema -> Node( source = it, components = listOf( Named(methodName), @@ -550,6 +581,7 @@ object JavaDriverDialectParser : DialectParser { ) ) ) + else -> null } } @@ -557,6 +589,75 @@ object JavaDriverDialectParser : DialectParser { } } + private fun parseAccumulatorExpression(expression: PsiMethodCallExpression): Node? { + val methodCall = expression.fuzzyResolveMethod() ?: return null + + return when (methodCall.name) { + "sum" -> parseKeyValAccumulator(expression, Name.SUM) + "avg" -> parseKeyValAccumulator(expression, Name.AVG) + "first" -> parseKeyValAccumulator(expression, Name.FIRST) + "last" -> parseKeyValAccumulator(expression, Name.LAST) + "top" -> parseLeadingAccumulatorExpression(expression, Name.TOP) + "bottom" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM) + "max" -> parseKeyValAccumulator(expression, Name.MAX) + "min" -> parseKeyValAccumulator(expression, Name.MIN) + "push" -> parseKeyValAccumulator(expression, Name.PUSH) + "addToSet" -> parseKeyValAccumulator(expression, Name.ADD_TO_SET) + else -> null + } + } + + private fun parseKeyValAccumulator(expression: PsiMethodCallExpression, name: Name): Node? { + val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null + val valueExpr = expression.argumentList.expressions.getOrNull(1) ?: return null + + val fieldName = keyExpr.tryToResolveAsConstantString() + val accumulatorExpr = parseComputedExpression(valueExpr) + + return Node( + expression, + listOf( + Named(name), + HasFieldReference( + if (fieldName != null) { + Computed(keyExpr, fieldName, fieldName) + } else { + HasFieldReference.Unknown as HasFieldReference.FieldReference + } + ), + accumulatorExpr + ) + ) + } + + private fun parseLeadingAccumulatorExpression(expression: PsiMethodCallExpression, name: Name): Node? { + val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null + val sortExprArgument = expression.argumentList.expressions.getOrNull(1) ?: return null + val valueExpr = expression.argumentList.expressions.getOrNull(2) ?: return null + + val sortExpr = resolveBsonBuilderCall(sortExprArgument, SORTS_FQN) ?: return null + + val fieldName = keyExpr.tryToResolveAsConstantString() + val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN) + val accumulatorExpr = parseComputedExpression(valueExpr) + + return Node( + expression, + listOf( + Named(name), + HasFieldReference( + if (fieldName != null) { + Computed(keyExpr, fieldName, fieldName) + } else { + HasFieldReference.Unknown as HasFieldReference.FieldReference + } + ), + HasSorts(sort), + accumulatorExpr + ) + ) + } + private fun resolveAddFieldsArguments( addFieldsCall: PsiMethodCallExpression ): List { @@ -614,6 +715,38 @@ object JavaDriverDialectParser : DialectParser { return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN) } + private fun parseComputedExpression(element: PsiElement): HasValueReference { + val (constant, value) = element.tryToResolveAsConstant() + return HasValueReference( + when { + constant && value is String -> HasValueReference.Computed( + element, + type = ComputedBsonType( + BsonAny, + Node( + element, + listOf( + HasFieldReference( + FromSchema(element, value.trim('$'), value) + ) + ) + ) + ) + ) + constant -> HasValueReference.Constant( + element, + value, + value?.javaClass.toBsonType(value) + ) + !constant && element is PsiExpression -> HasValueReference.Runtime( + element, + element.type?.toBsonType() ?: BsonAny + ) + else -> HasValueReference.Unknown as HasValueReference.ValueReference + } + ) + } + private fun isAggregationStageMethodCall(callMethod: PsiMethod?): Boolean { return PARSEABLE_AGGREGATION_STAGE_METHODS.contains(callMethod?.name) && callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN @@ -623,7 +756,7 @@ object JavaDriverDialectParser : DialectParser { val fieldNameAsString = expression.tryToResolveAsConstantString() val fieldReference = fieldNameAsString?.let { - HasFieldReference.FromSchema(expression, it) + FromSchema(expression, it) } ?: HasFieldReference.Unknown return fieldReference diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt index 1bd96ad2..a79026c4 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt @@ -353,7 +353,7 @@ fun PsiElement.tryToResolveAsConstant(): Pair { * @return */ fun PsiElement.tryToResolveAsConstantString(): String? = - tryToResolveAsConstant().takeIf { it.first }?.second?.toString() + tryToResolveAsConstant().takeIf { it.first }?.second as? String /** * Maps a PsiType to its BSON counterpart. diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt new file mode 100644 index 00000000..1c5a562b --- /dev/null +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -0,0 +1,649 @@ +package com.mongodb.jbplugin.dialects.javadriver.glossary.aggregationparser + +import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest +import com.mongodb.jbplugin.dialects.javadriver.ParsingTest +import com.mongodb.jbplugin.dialects.javadriver.WithFile +import com.mongodb.jbplugin.dialects.javadriver.caret +import com.mongodb.jbplugin.dialects.javadriver.getQueryAtMethod +import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect +import com.mongodb.jbplugin.mql.components.HasAccumulatedFields +import com.mongodb.jbplugin.mql.components.HasAggregation +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasSorts +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource + +@IntegrationTest +class GroupStageParserTest { + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(null) + )); + } +} + """ + ) + fun `should be able to parse a group stage with null _id`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val constantValueRef = groupStage.component>()!!.reference as HasValueReference.Constant + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + assertEquals(null, constantValueRef.value) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField") + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String _idField = "${'$'}myField"; + return this.collection.aggregate(List.of( + Aggregates.group(_idField) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getIdField() { + return "${'$'}myField"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(getIdField()) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators."|"("myKey", "myVal")) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + BsonField accumulatedExpr = Accumulators."|"("myKey", "myVal"); + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", accumulatedExpr) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is a variable`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private BsonField getAccumulatedExpr() { + return Accumulators."|"("myKey", "myVal"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", getAccumulatedExpr()) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is from a method call`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators.sum("totalCount", 1)) + )); + } +} + """ + ) + fun `supports accumulators also expecting a constant argument`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + + assertEquals(1, accumulatedFields.children.size) + val totalCountNode = accumulatedFields.children.first() + val totalCountFieldRef = + totalCountNode.component>()!!.reference + as HasFieldReference.Computed + val totalCountValueRef = + totalCountNode.component>()!!.reference + as HasValueReference.Constant + + assertEquals("totalCount", totalCountFieldRef.fieldName) + assertEquals(1, totalCountValueRef.value) + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts;import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators."|"("myKey", Sorts.ascending("mySort"), "myVal")) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "top;;TOP", + "bottom;;BOTTOM", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators with sorting criteria from the driver`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + + val accumulatorSorting = accumulator.component>()!!.children[0] + val sortingField = accumulatorSorting.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("mySort", sortingField.fieldName) + } + } +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt index 495f1d08..e34b591d 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt @@ -5,9 +5,10 @@ import com.mongodb.jbplugin.dialects.OutputQuery import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend import com.mongodb.jbplugin.mql.* import com.mongodb.jbplugin.mql.components.* -import com.mongodb.jbplugin.mql.components.HasFieldReference.Computed import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema +import com.mongodb.jbplugin.mql.components.HasFieldReference.Inferred import com.mongodb.jbplugin.mql.components.HasFieldReference.Unknown +import com.mongodb.jbplugin.mql.components.HasValueReference.Computed import com.mongodb.jbplugin.mql.parser.anyError import com.mongodb.jbplugin.mql.parser.components.aggregationStages import com.mongodb.jbplugin.mql.parser.components.allFiltersRecursively @@ -341,8 +342,9 @@ private fun MongoshBackend.resolveValueReference( private fun MongoshBackend.resolveFieldReference(fieldRef: HasFieldReference) = when (val ref = fieldRef.reference) { - is Computed -> registerConstant(ref.fieldName) is FromSchema -> registerConstant(ref.fieldName) + is Inferred -> registerConstant(ref.fieldName) + is HasFieldReference.Computed -> registerConstant(ref.fieldName) is Unknown -> registerVariable("field", BsonAny) } diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt index daee46c9..1b106bb5 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt @@ -251,4 +251,5 @@ private fun defaultValueOfBsonType(type: BsonType): Any? = when (type) { is BsonObject -> emptyMap() BsonObjectId -> ObjectId("000000000000000000000000") BsonString -> "" + is ComputedBsonType<*> -> defaultValueOfBsonType(type.baseType) } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt index c538b9d6..bed80c28 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt @@ -179,6 +179,9 @@ data class BsonArray( } } +data class ComputedBsonType(val baseType: BsonType, val expression: Node) : + BsonType by baseType // for now it will behave as baseType + /** * Returns the inferred BSON type of the current Java class, considering it's nullability. * diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt new file mode 100644 index 00000000..d56c85bb --- /dev/null +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasAccumulatedFields.kt @@ -0,0 +1,8 @@ +package com.mongodb.jbplugin.mql.components + +import com.mongodb.jbplugin.mql.HasChildren +import com.mongodb.jbplugin.mql.Node + +data class HasAccumulatedFields( + override val children: List> +) : HasChildren diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt index a9dbe30b..4d17ee6d 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt @@ -24,6 +24,16 @@ data class HasFieldReference( val displayName: String = fieldName, ) : FieldReference + /** + * Encodes a FieldReference that is part of a schema, but it's not defined + * in code. For example, the _id field that is created on { $group: "$expr" }. + */ + data class Inferred( + val source: S, + val fieldName: String, + val displayName: String = fieldName, + ) : FieldReference + /** * Encodes a FieldReference that does not exist in the original schema and the value of which * is computed using some expression. @@ -31,5 +41,6 @@ data class HasFieldReference( data class Computed( val source: S, val fieldName: String, + val displayName: String = fieldName, ) : FieldReference } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 3961a45b..8b77fbe9 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,11 +1,13 @@ package com.mongodb.jbplugin.mql.components import com.mongodb.jbplugin.mql.BsonType -import com.mongodb.jbplugin.mql.Component +import com.mongodb.jbplugin.mql.ComputedBsonType +import com.mongodb.jbplugin.mql.HasChildren +import com.mongodb.jbplugin.mql.Node data class HasValueReference( val reference: ValueReference, -) : Component { +) : HasChildren { sealed interface ValueReference @@ -56,4 +58,26 @@ data class HasValueReference( val source: S, val type: BsonType, ) : ValueReference + + /** + * Encodes a ValueReference when the value is computed on the server side. For example + * for $group stages in an aggregation pipeline: + * ``` + * Aggregates.group("$year") //-> Computed(Node(HasFieldReference(...))) + * ``` + * The computedExpression can not be known, so we don't have a BsonType attached to it. + * + * Unless it can be inferred from the expression (not implemented), we will assume it's + * BsonAny + */ + data class Computed( + val source: S, + val type: ComputedBsonType, + ) : ValueReference + + override val children: List> + get() = when (reference) { + is Computed -> listOf(reference.type.expression) + else -> emptyList() + } } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index dba2cdda..965905d5 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -52,6 +52,17 @@ enum class Name(val canonical: String) { PROJECT("project"), INCLUDE("include"), EXCLUDE("exclude"), + GROUP("group"), + SUM("sum"), + AVG("avg"), + FIRST("first"), + LAST("last"), + TOP("top"), + BOTTOM("bottom"), + MAX("max"), + MIN("min"), + PUSH("push"), + ADD_TO_SET("addToSet"), SORT("sort"), ASCENDING("ascending"), DESCENDING("descending"),