diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt index af876f40..906be52b 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt @@ -543,4 +543,201 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.ascending("") + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#ascending of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.descending("") + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#descending of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + Sorts.descending("") + ) + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#orderBy of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt index 235e7d1d..b2fc5251 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt @@ -552,4 +552,204 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.ascending() + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#ascending of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.descending() + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#descending of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + Sorts.descending() + ) + ) + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in Sorts#orderBy of an Aggregates#sort stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index fa94a0fb..8b36a948 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -367,4 +367,85 @@ public class Repository { fixture.enableInspections(FieldCheckInspectionBridge::class.java) fixture.testHighlighting() } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public AggregateIterable exampleAggregateAscending() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.ascending("nonExistingField") + ) + )); + } + + public AggregateIterable exampleAggregateDescending() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.descending("nonExistingField") + ) + )); + } + + private Bson getAnotherSort() { + return Sorts.descending("nonExistingField"); + } + + public AggregateIterable exampleAggregateOrderBy() { + Bson ascendingSort = Sorts.ascending("nonExistingField"); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + Sorts.descending("nonExistingField"), + ascendingSort, + getAnotherSort() + ) + ) + )); + } +} + """, + ) + fun `shows an inspection for Aggregates#sort call when the field does not exist in the current namespace`( + fixture: CodeInsightTestFixture, + ) { + val (dataSource, readModelProvider) = fixture.setupConnection() + fixture.specifyDialect(JavaDriverDialect) + + `when`( + readModelProvider.slice(eq(dataSource), any()) + ).thenReturn( + GetCollectionSchema( + CollectionSchema(Namespace("myDatabase", "myCollection"), BsonObject(emptyMap())) + ), + ) + + fixture.enableInspections(FieldCheckInspectionBridge::class.java) + fixture.testHighlighting() + } } diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 8a33cddc..e32c475d 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -5,6 +5,7 @@ import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; import org.bson.Document; import java.util.ArrayList; @@ -62,6 +63,11 @@ public List queryMoviesByYear(String year) { Projections.fields( Projections.include("year", "plot") ) + ), + Aggregates.sort( + Sorts.orderBy( + Sorts.ascending("asd", "qwe") + ) ) ) ) diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 078a69b1..e0a18fcb 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -26,6 +26,7 @@ private const val UPDATES_FQN = "com.mongodb.client.model.Updates" private const val AGGREGATES_FQN = "com.mongodb.client.model.Aggregates" private const val PROJECTIONS_FQN = "com.mongodb.client.model.Projections" private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators" +private const val SORTS_FQN = "com.mongodb.client.model.Sorts" private const val JAVA_LIST_FQN = "java.util.List" private const val JAVA_ARRAYS_FQN = "java.util.Arrays" private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( @@ -120,7 +121,7 @@ object JavaDriverDialectParser : DialectParser { // we have at least 1 argument in the current method call // try to get the relevant filter calls, or avoid parsing the query at all - val argumentAsFilters = resolveToFiltersCall(filterExpression) + val argumentAsFilters = resolveBsonBuilderCall(filterExpression, FILTERS_FQN) return argumentAsFilters?.let { val parsedQuery = parseFilterExpression(argumentAsFilters) parsedQuery?.let { @@ -161,7 +162,7 @@ object JavaDriverDialectParser : DialectParser { val updateExpression = currentCall.argumentList.expressions.getOrNull(startIndex) ?: return emptyList() - val argumentAsUpdates = resolveToUpdatesCall(updateExpression) + val argumentAsUpdates = resolveBsonBuilderCall(updateExpression, UPDATES_FQN) // parse only if it's a call to `updates` methods return argumentAsUpdates?.let { val parsedQuery = parseUpdatesExpression(argumentAsUpdates) @@ -226,12 +227,16 @@ object JavaDriverDialectParser : DialectParser { val containingClass = methodCall.resolveMethod()?.containingClass ?: return false return containingClass.qualifiedName == AGGREGATES_FQN || - containingClass.qualifiedName == PROJECTIONS_FQN + containingClass.qualifiedName == PROJECTIONS_FQN || + containingClass.qualifiedName == SORTS_FQN } - private fun resolveToFiltersCall(element: PsiElement): PsiMethodCallExpression? { + private fun resolveBsonBuilderCall( + element: PsiElement, + classQualifiedName: String + ): PsiMethodCallExpression? { return element.resolveToMethodCallExpression { _, methodCall -> - methodCall.containingClass?.qualifiedName == FILTERS_FQN + methodCall.containingClass?.qualifiedName == classQualifiedName } } @@ -281,7 +286,7 @@ object JavaDriverDialectParser : DialectParser { Named(Name.from(method.name)), HasFilter( filter.argumentList.expressions - .mapNotNull { resolveToFiltersCall(it) } + .mapNotNull { resolveBsonBuilderCall(it, FILTERS_FQN) } .mapNotNull { parseFilterExpression(it) }, ), ), @@ -340,12 +345,6 @@ object JavaDriverDialectParser : DialectParser { return null } - private fun resolveToUpdatesCall(element: PsiElement): PsiMethodCallExpression? { - return element.resolveToMethodCallExpression { _, methodCall -> - methodCall.containingClass?.qualifiedName == UPDATES_FQN - } - } - private fun parseUpdatesExpression(filter: PsiMethodCallExpression): Node? { val method = filter.resolveMethod() ?: return null if (method.isVarArgs) { @@ -356,7 +355,7 @@ object JavaDriverDialectParser : DialectParser { Named(Name.from(method.name)), HasFilter( filter.argumentList.expressions - .mapNotNull { resolveToUpdatesCall(it) } + .mapNotNull { resolveBsonBuilderCall(it, UPDATES_FQN) } .mapNotNull { parseUpdatesExpression(it) }, ), ), @@ -417,16 +416,28 @@ object JavaDriverDialectParser : DialectParser { when (stageCallMethod.name) { "match" -> { + val nodeWithFilters: (List>) -> Node = + { filter -> + Node( + source = stageCall, + components = listOf( + Named(Name.MATCH), + HasFilter(filter) + ) + ) + } // There will only be one argument to Aggregates.match and that has to be the Bson // filters. We retrieve that and resolve the values. val filterExpression = stageCall.argumentList.expressions.getOrNull(0) - ?: return null + ?: return nodeWithFilters(emptyList()) - val resolvedFilterExpression = resolveToFiltersCall(filterExpression) - ?: return null + val resolvedFilterExpression = resolveBsonBuilderCall( + filterExpression, + FILTERS_FQN, + ) ?: return nodeWithFilters(emptyList()) val parsedFilter = parseFilterExpression(resolvedFilterExpression) - ?: return null + ?: return nodeWithFilters(emptyList()) return Node( source = stageCall, @@ -437,29 +448,41 @@ object JavaDriverDialectParser : DialectParser { ) } - "project" -> { - val nodeWithProjections: (List>) -> Node = - { projections -> + "project", + "sort" -> { + val stageName = Name.from(stageCallMethod.name) + val nodeWithParsedComponents: (List>) -> Node = + { components -> Node( source = stageCall, components = listOf( - Named(Name.PROJECT), - HasProjections(projections) + Named(stageName), + if (stageName == Name.PROJECT) { + HasProjections(components) + } else { + HasSorts(components) + } ) ) } - // There will only be one argument to Aggregates.project and that has to be the Bson - // projections. We retrieve that and resolve the values. - val projectionExpression = stageCall.argumentList.expressions.getOrNull(0) - ?: return nodeWithProjections(emptyList()) + // There will only be one argument to Aggregates.project and Aggregates.sort and + // that has to be the Bson projections or sorts. We retrieve that and resolve the + // values. + val bsonBuilderExpression = stageCall.argumentList.expressions.getOrNull(0) + ?: return nodeWithParsedComponents(emptyList()) - val resolvedProjectionExpression = resolveToProjectionCall(projectionExpression) - ?: return nodeWithProjections(emptyList()) + val resolvedBsonBuilderExpression = resolveBsonBuilderCall( + bsonBuilderExpression, + if (stageName == Name.PROJECT) PROJECTIONS_FQN else SORTS_FQN, + ) ?: return nodeWithParsedComponents(emptyList()) - val parsedProjections = parseProjectionExpression(resolvedProjectionExpression) + val parsedComponents = parseBsonBuilderCallsSimilarToProjections( + resolvedBsonBuilderExpression, + if (stageName == Name.PROJECT) PROJECTIONS_FQN else SORTS_FQN, + ) - return nodeWithProjections(parsedProjections) + return nodeWithParsedComponents(parsedComponents) } "group" -> { // the first parameter of group is going to be a string expression @@ -498,46 +521,18 @@ object JavaDriverDialectParser : DialectParser { } } - private fun resolveToAccumulatorCall(element: PsiElement): PsiMethodCallExpression? { - return element.resolveToMethodCallExpression { _, methodCall -> - methodCall.containingClass?.qualifiedName == ACCUMULATORS_FQN - } - } - - private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { - return HasValueReference( - when (val expression = element.tryToResolveAsConstantString()) { - null -> HasValueReference.Unknown as HasValueReference.ValueReference - else -> HasValueReference.Computed( - element, - Node( - element, - listOf( - if (createsNewField) { - HasFieldReference( - Computed(element, expression.trim('$'), expression) - ) - } else { - HasFieldReference( - FromSchema(element, expression.trim('$'), expression) - ) - } - ) - ) - ) - } - ) - } - private fun parseProjectionExpression(expression: PsiMethodCallExpression): List> { val methodCall = expression.resolveMethod() ?: return emptyList() return when (methodCall.name) { - "fields" -> expression.getVarArgsOrIterableArgs() - .mapNotNull(::resolveToProjectionCall) - .flatMap(::parseProjectionExpression) + "fields", + "orderBy" -> expression.getVarArgsOrIterableArgs() + .mapNotNull { resolveBsonBuilderCall(it, classQualifiedName) } + .flatMap { parseBsonBuilderCallsSimilarToProjections(it, classQualifiedName) } "include", - "exclude" -> expression.getVarArgsOrIterableArgs() + "exclude", + "ascending", + "descending" -> expression.getVarArgsOrIterableArgs() .mapNotNull { val fieldReference = resolveFieldNameFromExpression(it) val methodName = Name.from(methodCall.name) @@ -550,7 +545,13 @@ object JavaDriverDialectParser : DialectParser { HasValueReference( reference = HasValueReference.Inferred( source = it, - value = if (methodName == Name.INCLUDE) 1 else -1, + value = when (methodName) { + Name.INCLUDE, + Name.ASCENDING -> 1 + Name.EXCLUDE -> 0 + // Else is just the descending case + else -> -1 + }, type = BsonInt32, ) ) diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/MatchStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/MatchStageParserTest.kt index 2083aaf8..b33411b9 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/MatchStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/MatchStageParserTest.kt @@ -34,6 +34,45 @@ import java.util.List; import static com.mongodb.client.model.Filters.*; +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable findBookById(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.match() + )); + } +} + """, + ) + fun `(Aggregates#match call) should be able to parse an empty call`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "findBookById") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(hasAggregation?.children?.size, 1) + val matchStageNode = hasAggregation?.children?.first()!! + assertEquals(Name.MATCH, matchStageNode.component()!!.name) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + public final class Aggregation { private final MongoCollection collection; diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/ProjectStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/ProjectStageParserTest.kt index 20976d4f..7f1f033e 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/ProjectStageParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/ProjectStageParserTest.kt @@ -1202,7 +1202,7 @@ public final class Aggregation { val titleProjectionValueRef = (titleProjection.component>()!!.reference) as Inferred assertEquals( - -1, + 0, titleProjectionValueRef.value ) @@ -1219,7 +1219,7 @@ public final class Aggregation { val yearProjectionValueRef = (yearProjection.component>()!!.reference) as Inferred assertEquals( - -1, + 0, yearProjectionValueRef.value ) @@ -1236,7 +1236,7 @@ public final class Aggregation { val authorProjectionValueRef = (authorProjection.component>()!!.reference) as Inferred assertEquals( - -1, + 0, authorProjectionValueRef.value ) } @@ -1295,7 +1295,7 @@ public final class Aggregation { val publishedProjectionValueRef = (publishedProjection.component>()!!.reference) as Inferred assertEquals( - -1, + 0, publishedProjectionValueRef.value ) @@ -1312,7 +1312,7 @@ public final class Aggregation { val authorProjectionValueRef = (authorProjection.component>()!!.reference) as Inferred assertEquals( - -1, + 0, authorProjectionValueRef.value ) } diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/SortStageParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/SortStageParserTest.kt new file mode 100644 index 00000000..29fe6a01 --- /dev/null +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/SortStageParserTest.kt @@ -0,0 +1,1324 @@ +package com.mongodb.jbplugin.dialects.javadriver.glossary.aggregationparser + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest +import com.mongodb.jbplugin.dialects.javadriver.ParsingTest +import com.mongodb.jbplugin.dialects.javadriver.getQueryAtMethod +import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasAggregation +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema +import com.mongodb.jbplugin.mql.components.HasSorts +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.HasValueReference.Inferred +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Assertions.assertEquals + +@IntegrationTest +class SortStageParserTest { + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort() + )); + } +} + """ + ) + fun `should be able to parse an empty sort call`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + val named = sortStageNode.component()!! + assertEquals(Name.SORT, named.name) + + assertEquals(0, sortStageNode.component>()!!.children.size) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending("title", yearField, getAuthorField()) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with varargs`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + List.of("title", yearField, getAuthorField()) + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with List#of`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + List projectedFields = List.of("title", yearField, getAuthorField()); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + projectedFields + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with List#of when the list is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private List getProjectedFields() { + String yearField = "year"; + return List.of("title", yearField, getAuthorField()); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + getProjectedFields() + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with List#of when the list is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; +import java.util.Arrays; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + Arrays.asList("title", yearField, getAuthorField()) + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with Arrays#asList`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + List projectedFields = Arrays.asList("title", yearField, getAuthorField()); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + projectedFields + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with Arrays#asList when the list is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private List getProjectedFields() { + String yearField = "year"; + return Arrays.asList("title", yearField, getAuthorField()); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.ascending( + getProjectedFields() + ) + ) + )); + } +} + """ + ) + fun `Sorts#ascending - should be able to parse with Arrays#asList when the list is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForAscendingSort(sortStageNode) + } + + // //////////////////////////////////////////////////////////////////////// + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending("title", yearField, getAuthorField()) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with varargs`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + List.of("title", yearField, getAuthorField()) + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with List#of`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + List projectedFields = List.of("title", yearField, getAuthorField()); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + projectedFields + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with List#of when the list is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private List getProjectedFields() { + String yearField = "year"; + return List.of("title", yearField, getAuthorField()); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + getProjectedFields() + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with List#of when the list is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; +import java.util.Arrays; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + Arrays.asList("title", yearField, getAuthorField()) + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with Arrays#asList`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + List projectedFields = Arrays.asList("title", yearField, getAuthorField()); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + projectedFields + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with Arrays#asList when the list is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private List getProjectedFields() { + String yearField = "year"; + return Arrays.asList("title", yearField, getAuthorField()); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.descending( + getProjectedFields() + ) + ) + )); + } +} + """ + ) + fun `Sorts#descending - should be able to parse with Arrays#asList when the list is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForDescendingSort(sortStageNode) + } + + // //////////////////////////////////////////////////// + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private Bson getThirdSort() { + return Sorts.descending(Arrays.asList("published", getAuthorField())); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + Bson secondSort = Sorts.ascending(List.of(yearField)); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + Sorts.ascending("title"), + secondSort, + getThirdSort() + ) + ) + )); + } +} + """ + ) + fun `Sort#orderBy - should be able to parse with varargs`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForOrderBySort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private Bson getThirdSort() { + return Sorts.descending(Arrays.asList("published", getAuthorField())); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + Bson secondSort = Sorts.ascending(List.of(yearField)); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + List.of( + Sorts.ascending("title"), + secondSort, + getThirdSort() + ) + ) + ) + )); + } +} + """ + ) + fun `Sort#orderBy - should be able to parse with List#of`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForOrderBySort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import org.bson.conversions.Bson; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private Bson getThirdSort() { + return Sorts.descending(Arrays.asList("published", getAuthorField())); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String yearField = "year"; + Bson secondSort = Sorts.ascending(List.of(yearField)); + List sorts = List.of( + Sorts.ascending("title"), + secondSort, + getThirdSort() + ); + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + sorts + ) + ) + )); + } +} + """ + ) + fun `Sort#orderBy - should be able to parse with List#of when the list is a variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForOrderBySort(sortStageNode) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import org.bson.conversions.Bson; + +import java.util.Arrays; +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getAuthorField() { + return "author"; + } + + private Bson getThirdSort() { + return Sorts.descending(Arrays.asList("published", getAuthorField())); + } + + private List getSorts() { + String yearField = "year"; + Bson secondSort = Sorts.ascending(List.of(yearField)); + return List.of( + Sorts.ascending("title"), + secondSort, + getThirdSort() + ); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.sort( + Sorts.orderBy( + getSorts() + ) + ) + )); + } +} + """ + ) + fun `Sort#orderBy - should be able to parse with List#of when the list comes from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod( + "Aggregation", + "getAllBookTitles" + ) + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val sortStageNode = hasAggregation?.children?.get(0)!! + + commonAssertionsForOrderBySort(sortStageNode) + } + + companion object { + fun commonAssertionsForAscendingSort(sortStageNode: Node) { + val named = sortStageNode.component()!! + assertEquals(Name.SORT, named.name) + + val sorts = sortStageNode.component>()!! + assertEquals(3, sorts.children.size) + + val titleSort = sorts.children[0] + assertEquals(Name.ASCENDING, titleSort.component()!!.name) + + val titleSortFieldRef = + (titleSort.component>()!!.reference) as FromSchema + assertEquals( + "title", + titleSortFieldRef.fieldName + ) + + val titleSortValueRef = + (titleSort.component>()!!.reference) as Inferred + assertEquals( + 1, + titleSortValueRef.value + ) + + val yearSort = sorts.children[1] + assertEquals(Name.ASCENDING, yearSort.component()!!.name) + + val yearSortFieldRef = + (yearSort.component>()!!.reference) as FromSchema + assertEquals( + "year", + yearSortFieldRef.fieldName + ) + + val yearSortValueRef = + (yearSort.component>()!!.reference) as Inferred + assertEquals( + 1, + yearSortValueRef.value + ) + + val authorSort = sorts.children[2] + assertEquals(Name.ASCENDING, authorSort.component()!!.name) + + val authorSortFieldRef = + (authorSort.component>()!!.reference) as FromSchema + assertEquals( + "author", + authorSortFieldRef.fieldName + ) + + val authorSortValueRef = + (authorSort.component>()!!.reference) as Inferred + assertEquals( + 1, + authorSortValueRef.value + ) + } + + fun commonAssertionsForDescendingSort(sortStageNode: Node) { + val named = sortStageNode.component()!! + assertEquals(Name.SORT, named.name) + + val sorts = sortStageNode.component>()!! + assertEquals(3, sorts.children.size) + + val titleSort = sorts.children[0] + assertEquals(Name.DESCENDING, titleSort.component()!!.name) + + val titleSortFieldRef = + (titleSort.component>()!!.reference) as FromSchema + assertEquals( + "title", + titleSortFieldRef.fieldName + ) + + val titleSortValueRef = + (titleSort.component>()!!.reference) as Inferred + assertEquals( + -1, + titleSortValueRef.value + ) + + val yearSort = sorts.children[1] + assertEquals(Name.DESCENDING, yearSort.component()!!.name) + + val yearSortFieldRef = + (yearSort.component>()!!.reference) as FromSchema + assertEquals( + "year", + yearSortFieldRef.fieldName + ) + + val yearSortValueRef = + (yearSort.component>()!!.reference) as Inferred + assertEquals( + -1, + yearSortValueRef.value + ) + + val authorSort = sorts.children[2] + assertEquals(Name.DESCENDING, authorSort.component()!!.name) + + val authorSortFieldRef = + (authorSort.component>()!!.reference) as FromSchema + assertEquals( + "author", + authorSortFieldRef.fieldName + ) + + val authorSortValueRef = + (authorSort.component>()!!.reference) as Inferred + assertEquals( + -1, + authorSortValueRef.value + ) + } + + fun commonAssertionsForOrderBySort(sortStageNode: Node) { + val named = sortStageNode.component()!! + assertEquals(Name.SORT, named.name) + + val sorts = sortStageNode.component>()!! + assertEquals(4, sorts.children.size) + + val titleSort = sorts.children[0] + assertEquals(Name.ASCENDING, titleSort.component()!!.name) + + val titleSortFieldRef = + (titleSort.component>()!!.reference) as FromSchema + assertEquals( + "title", + titleSortFieldRef.fieldName + ) + + val titleSortValueRef = + (titleSort.component>()!!.reference) as Inferred + assertEquals( + 1, + titleSortValueRef.value + ) + + val yearSort = sorts.children[1] + assertEquals(Name.ASCENDING, yearSort.component()!!.name) + + val yearSortFieldRef = + (yearSort.component>()!!.reference) as FromSchema + assertEquals( + "year", + yearSortFieldRef.fieldName + ) + + val yearSortValueRef = + (yearSort.component>()!!.reference) as Inferred + assertEquals( + 1, + yearSortValueRef.value + ) + + val publishedSort = sorts.children[2] + assertEquals(Name.DESCENDING, publishedSort.component()!!.name) + + val publishedSortFieldRef = + (publishedSort.component>()!!.reference) as FromSchema + assertEquals( + "published", + publishedSortFieldRef.fieldName + ) + + val publishedSortValueRef = + (publishedSort.component>()!!.reference) as Inferred + assertEquals( + -1, + publishedSortValueRef.value + ) + + val authorSort = sorts.children[3] + assertEquals(Name.DESCENDING, authorSort.component()!!.name) + + val authorSortFieldRef = + (authorSort.component>()!!.reference) as FromSchema + assertEquals( + "author", + authorSortFieldRef.fieldName + ) + + val authorSortValueRef = + (authorSort.component>()!!.reference) as Inferred + assertEquals( + -1, + authorSortValueRef.value + ) + } + } +} diff --git a/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt b/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt index cc3714dd..e6166eca 100644 --- a/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt +++ b/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt @@ -8,6 +8,7 @@ import com.mongodb.jbplugin.mql.components.HasCollectionReference import com.mongodb.jbplugin.mql.components.HasFieldReference import com.mongodb.jbplugin.mql.components.HasFilter import com.mongodb.jbplugin.mql.components.HasProjections +import com.mongodb.jbplugin.mql.components.HasSorts import com.mongodb.jbplugin.mql.components.HasValueReference import com.mongodb.jbplugin.mql.components.Name import com.mongodb.jbplugin.mql.components.Named @@ -473,4 +474,75 @@ class FieldCheckingLinterTest { val warning = result.warnings[0] as FieldCheckWarning.FieldDoesNotExist assertEquals("myBoolean", warning.field) } + + @Test + fun `warns about the referenced fields in an Aggregation#sort not in the specified collection`() { + val readModelProvider = mock>() + val collectionNamespace = Namespace("database", "collection") + + `when`(readModelProvider.slice(any(), any())).thenReturn( + GetCollectionSchema( + CollectionSchema( + collectionNamespace, + BsonObject( + mapOf( + "myInt" to BsonInt32, + ), + ), + ), + ), + ) + + val result = + FieldCheckingLinter.lintQuery( + Unit, + readModelProvider, + Node( + null, + listOf( + HasCollectionReference( + HasCollectionReference.Known(null, null, collectionNamespace) + ), + HasAggregation( + children = listOf( + Node( + null, + listOf( + Named(Name.SORT), + HasSorts( + listOf( + Node( + null, + listOf( + Named(Name.ASCENDING), + HasFieldReference( + HasFieldReference.FromSchema( + null, + "myBoolean" + ) + ), + HasValueReference( + HasValueReference.Inferred( + null, + 1, + BsonInt32 + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ), + ), + ) + + assertEquals(1, result.warnings.size) + assertInstanceOf(FieldCheckWarning.FieldDoesNotExist::class.java, result.warnings[0]) + val warning = result.warnings[0] as FieldCheckWarning.FieldDoesNotExist + assertEquals("myBoolean", warning.field) + } } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasSorts.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasSorts.kt new file mode 100644 index 00000000..ca847b2f --- /dev/null +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasSorts.kt @@ -0,0 +1,8 @@ +package com.mongodb.jbplugin.mql.components + +import com.mongodb.jbplugin.mql.HasChildren +import com.mongodb.jbplugin.mql.Node + +data class HasSorts( + override val children: List> +) : HasChildren diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index 8c01ec9b..bf4e5abc 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -62,6 +62,9 @@ enum class Name(val canonical: String) { MIN("min"), PUSH("push"), ADD_TO_SET("addToSet"), + SORT("sort"), + ASCENDING("ascending"), + DESCENDING("descending"), UNKNOWN(""), ;