From 65d9dbe75accf5523810d430e3185b6361041ed3 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Mon, 25 Nov 2024 12:06:53 +0100 Subject: [PATCH 1/2] chore: added tests and suggestions from PR feedback --- .../JavaDriverCompletionContributorTest.kt | 131 +++++++ ...erMongoDbAutocompletionPopupHandlerTest.kt | 133 +++++++ ...avaDriverFieldCheckLinterInspectionTest.kt | 120 ++++++ .../glossary/JavaDriverDialectParser.kt | 21 +- .../aggregationparser/GroupStageParserTest.kt | 358 +++++++++++++++++- .../mql/components/HasValueReference.kt | 3 +- 6 files changed, 750 insertions(+), 16 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt index 51b51f53..76214b1f 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt @@ -805,4 +805,135 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "" + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", "") + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt index 357c9b5a..2557fa22 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt @@ -818,4 +818,137 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for _id expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators;import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}year", + Accumulators.sum("totalMovies", ) + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace for accumulator expression in Aggregates#group stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index 1da59d4a..46c0a3c9 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -9,6 +9,7 @@ import com.mongodb.jbplugin.fixtures.setupConnection import com.mongodb.jbplugin.fixtures.specifyDialect import com.mongodb.jbplugin.mql.BsonDouble import com.mongodb.jbplugin.mql.BsonObject +import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.CollectionSchema import com.mongodb.jbplugin.mql.Namespace import org.mockito.Mockito.`when` @@ -501,4 +502,123 @@ public class Repository { fixture.enableInspections(FieldCheckInspectionBridge::class.java) fixture.testHighlighting() } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.*; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public AggregateIterable goodGroupAggregate1() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group("${'$'}possibleIdField"), + Aggregates.group("${'$'}possibleIdField", Accumulators.sum("totalCount", 1)), + Aggregates.group( + "${'$'}possibleIdField", + Accumulators.sum("totalCount", "${'$'}otherField") + ) + )); + } + + private String getOtherField() { + return "${'$'}otherField"; + } + + public AggregateIterable goodGroupAggregate2() { + String fieldName = "${'$'}possibleIdField"; + BsonField totalCountAcc = Accumulators.sum("totalCount", 1); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group(fieldName), + Aggregates.group(fieldName, totalCountAcc), + Aggregates.group( + fieldName, + Accumulators.sum("totalCount", getOtherField()) + ) + )); + } + + private String getBadFieldName() { + return "${'$'}nonExistentField"; + } + + private BsonField getAvgCountAcc() { + return Accumulators.avg( + "avgCount", + getBadFieldName() + ); + } + + public AggregateIterable badGroupAggregate1() { + String badFieldName = "${'$'}nonExistentField"; + BsonField avgCountAcc = Accumulators.avg( + "avgCount", + badFieldName + ); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.group( + "${'$'}nonExistentField" + ), + Aggregates.group( + badFieldName, + Accumulators.sum( + "totalCount", + badFieldName + ), + Accumulators.sum( + "totalCount", + getBadFieldName() + ), + avgCountAcc, + getAvgCountAcc() + ) + )); + } +} + """, + ) + fun `shows an inspection for Aggregates#group call when the field does not exist in the current namespace`( + fixture: CodeInsightTestFixture, + ) { + val (dataSource, readModelProvider) = fixture.setupConnection() + fixture.specifyDialect(JavaDriverDialect) + + `when`( + readModelProvider.slice(eq(dataSource), any()) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + Namespace("myDatabase", "myCollection"), + BsonObject( + mapOf( + "possibleIdField" to BsonString, + "otherField" to BsonString, + ) + ) + ) + ), + ) + + fixture.enableInspections(FieldCheckInspectionBridge::class.java) + fixture.testHighlighting() + } } diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 170dff62..3a4b083a 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -11,6 +11,7 @@ import com.mongodb.jbplugin.mql.BsonAnyOf import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 +import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node @@ -612,7 +613,7 @@ object JavaDriverDialectParser : DialectParser { val valueExpr = expression.argumentList.expressions.getOrNull(1) ?: return null val fieldName = keyExpr.tryToResolveAsConstantString() - val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + val accumulatorExpr = parseComputedExpression(valueExpr) return Node( expression, @@ -639,7 +640,7 @@ object JavaDriverDialectParser : DialectParser { val fieldName = keyExpr.tryToResolveAsConstantString() val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN) - val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + val accumulatorExpr = parseComputedExpression(valueExpr) return Node( expression, @@ -715,26 +716,20 @@ object JavaDriverDialectParser : DialectParser { return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN) } - private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + private fun parseComputedExpression(element: PsiElement): HasValueReference { val (constant, value) = element.tryToResolveAsConstant() return HasValueReference( when { constant && value is String -> HasValueReference.Computed( element, type = ComputedBsonType( - BsonAny, + BsonString, Node( element, listOf( - if (createsNewField) { - HasFieldReference( - Computed(element, value.trim('$'), value) - ) - } else { - HasFieldReference( - FromSchema(element, value.trim('$'), value) - ) - } + HasFieldReference( + FromSchema(element, value.trim('$'), value) + ) ) ) ) diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt index ccad05b5..dfb90040 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -73,7 +73,126 @@ public final class Aggregation { assertEquals(0, accumulatedFields.children.size) val computedExpression = computedValueRef.type.expression - val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.Computed + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String _idField = "${'$'}myField"; + return this.collection.aggregate(List.of( + Aggregates.group(_idField) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getIdField() { + return "${'$'}myField"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(getIdField()) + )); + } +} + """ + ) + fun `should be able to parse a group stage without accumulators, when _id is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema assertEquals("myField", fieldUsedForComputation.fieldName) assertEquals("${'$'}myField", fieldUsedForComputation.displayName) @@ -169,6 +288,243 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + BsonField accumulatedExpr = Accumulators."|"("myKey", "myVal"); + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", accumulatedExpr) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is a variable`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.BsonField; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private BsonField getAccumulatedExpr() { + return Accumulators."|"("myKey", "myVal"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", getAccumulatedExpr()) + )); + } +} + """, + ) + @ParameterizedTest + @CsvSource( + value = [ + "method;;expected", + "sum;;SUM", + "avg;;AVG", + "first;;FIRST", + "last;;LAST", + "max;;MAX", + "min;;MIN", + "push;;PUSH", + "addToSet;;ADD_TO_SET", + ], + delimiterString = ";;", + useHeadersInDisplayName = true + ) + fun `supports all relevant key-value accumulators from the driver, when accumulated expression is from a method call`( + method: String, + expected: Name, + psiFile: PsiFile + ) { + WriteCommandAction.runWriteCommandAction(psiFile.project) { + val elementAtCaret = psiFile.caret() + val javaFacade = JavaPsiFacade.getInstance(psiFile.project) + val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null) + elementAtCaret.replace(methodToTest) + } + + ApplicationManager.getApplication().runReadAction { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val accumulator = groupStage.component>()!!.children[0] + val accumulatorName = accumulator.component()!! + assertEquals(expected, accumulatorName.name) + + val accumulatorField = accumulator.component>()?.reference as HasFieldReference.Computed + assertEquals("myKey", accumulatorField.fieldName) + + val accumulatorComputed = accumulator.component>()?.reference as HasValueReference.Computed + val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component>()!!.reference as HasFieldReference.FromSchema + assertEquals("myVal", accumulatorComputedFieldValue.fieldName) + } + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Accumulators; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group("${'$'}myField", Accumulators.sum("totalCount", 1)) + )); + } +} + """ + ) + fun `supports accumulators also expecting a constant argument`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val computedValueRef = groupStage.component>()!!.reference as HasValueReference.Computed + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + val computedExpression = computedValueRef.type.expression + val fieldUsedForComputation = computedExpression.component>()!!.reference as HasFieldReference.FromSchema + + assertEquals("myField", fieldUsedForComputation.fieldName) + assertEquals("${'$'}myField", fieldUsedForComputation.displayName) + + assertEquals(1, accumulatedFields.children.size) + val totalCountNode = accumulatedFields.children.first() + val totalCountFieldRef = + totalCountNode.component>()!!.reference + as HasFieldReference.Computed + val totalCountValueRef = + totalCountNode.component>()!!.reference + as HasValueReference.Constant + + assertEquals("totalCount", totalCountFieldRef.fieldName) + assertEquals(1, totalCountValueRef.value) + } + + @WithFile( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Accumulators; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Sorts;import org.bson.Document; import org.bson.types.ObjectId; diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 0f57b0be..8b77fbe9 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,14 +1,13 @@ package com.mongodb.jbplugin.mql.components import com.mongodb.jbplugin.mql.BsonType -import com.mongodb.jbplugin.mql.Component import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.HasChildren import com.mongodb.jbplugin.mql.Node data class HasValueReference( val reference: ValueReference, -) : Component, HasChildren { +) : HasChildren { sealed interface ValueReference From d39254509f882f8242fba6eefc884aad51d8d2ad Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Mon, 25 Nov 2024 12:15:51 +0100 Subject: [PATCH 2/2] fix: fixes parsing of null _id expression --- ...avaDriverFieldCheckLinterInspectionTest.kt | 1 + .../glossary/JavaDriverDialectParser.kt | 4 +- .../aggregationparser/GroupStageParserTest.kt | 50 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index 46c0a3c9..eb0d4145 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -527,6 +527,7 @@ public class Repository { return client.getDatabase("myDatabase") .getCollection("myCollection") .aggregate(List.of( + Aggregates.group(null), Aggregates.group("${'$'}possibleIdField"), Aggregates.group("${'$'}possibleIdField", Accumulators.sum("totalCount", 1)), Aggregates.group( diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 3a4b083a..59439f83 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -734,10 +734,10 @@ object JavaDriverDialectParser : DialectParser { ) ) ) - constant && value != null -> HasValueReference.Constant( + constant -> HasValueReference.Constant( element, value, - value.javaClass.toBsonType(value) + value?.javaClass.toBsonType(value) ) !constant && element is PsiExpression -> HasValueReference.Runtime( element, diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt index dfb90040..1c5a562b 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/GroupStageParserTest.kt @@ -39,6 +39,56 @@ import java.util.List; import static com.mongodb.client.model.Filters.*; +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.group(null) + )); + } +} + """ + ) + fun `should be able to parse a group stage with null _id`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val groupStage = hasAggregation?.children?.get(0)!! + + val named = groupStage.component()!! + assertEquals(Name.GROUP, named.name) + + val idFieldRef = groupStage.component>()!!.reference as HasFieldReference.Inferred + val constantValueRef = groupStage.component>()!!.reference as HasValueReference.Constant + val accumulatedFields = groupStage.component>()!! + + assertEquals("_id", idFieldRef.fieldName) + assertEquals(0, accumulatedFields.children.size) + assertEquals(null, constantValueRef.value) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + public final class Aggregation { private final MongoCollection collection;