diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt index 76214b1f..23979f50 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverCompletionContributorTest.kt @@ -936,4 +936,67 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind("") + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in an Aggregates#unwind stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt index 2557fa22..abdb8c04 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/autocomplete/JavaDriverMongoDbAutocompletionPopupHandlerTest.kt @@ -951,4 +951,68 @@ public class Repository { }, ) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; +import static com.mongodb.client.model.Updates.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public void exampleFind() { + client.getDatabase("myDatabase").getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind() + )); + } +} + """, + ) + fun `should autocomplete fields from the current namespace in an Aggregates#unwind stage`( + fixture: CodeInsightTestFixture, + ) { + fixture.specifyDialect(JavaDriverDialect) + + val (dataSource, readModelProvider) = fixture.setupConnection() + val namespace = Namespace("myDatabase", "myCollection") + + `when`( + readModelProvider.slice(eq(dataSource), eq(GetCollectionSchema.Slice(namespace))) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + namespace, + BsonObject( + mapOf( + "myField" to BsonString, + ), + ), + ), + ), + ) + + fixture.type('"') + val elements = fixture.completeBasic() + + assertTrue( + elements.containsElements { + it.lookupString == "myField" + }, + ) + } } diff --git a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt index eb0d4145..99dcf9fb 100644 --- a/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt +++ b/packages/jetbrains-plugin/src/test/kotlin/com/mongodb/jbplugin/inspections/impl/JavaDriverFieldCheckLinterInspectionTest.kt @@ -622,4 +622,98 @@ public class Repository { fixture.enableInspections(FieldCheckInspectionBridge::class.java) fixture.testHighlighting() } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Projections; +import com.mongodb.client.model.Sorts; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import java.util.List; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public AggregateIterable exampleGoodUnwind() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind( + "${'$'}existingField" + ) + )); + } + + public AggregateIterable exampleUnwind1() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind( + "${'$'}nonExistingField" + ) + )); + } + + public AggregateIterable exampleUnwind2() { + String fieldName = "${'$'}nonExistingField"; + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind( + fieldName + ) + )); + } + + private String getField() { + return "${'$'}nonExistingField"; + } + + public AggregateIterable exampleUnwind3() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .aggregate(List.of( + Aggregates.unwind( + getField() + ) + )); + } +} + """, + ) + fun `shows an inspection for Aggregates#unwind call when the field does not exist in the current namespace`( + fixture: CodeInsightTestFixture, + ) { + val (dataSource, readModelProvider) = fixture.setupConnection() + fixture.specifyDialect(JavaDriverDialect) + + `when`( + readModelProvider.slice(eq(dataSource), any()) + ).thenReturn( + GetCollectionSchema( + CollectionSchema( + Namespace("myDatabase", "myCollection"), + BsonObject( + mapOf( + "existingField" to BsonString + ) + ) + ) + ), + ) + + fixture.enableInspections(FieldCheckInspectionBridge::class.java) + fixture.testHighlighting() + } } diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 86c11f7c..5e8dc079 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -1,6 +1,7 @@ package alt.mongodb.javadriver; import com.mongodb.client.MongoClient; +import com.mongodb.client.model.*; import com.mongodb.client.model.Accumulators; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; @@ -71,6 +72,10 @@ public List queryMoviesByYear(String year) { Sorts.orderBy( Sorts.ascending("asd", "qwe") ) + ), + Aggregates.unwind( + "asd", + new UnwindOptions() ) ) ) diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index ae4ae6fd..a68fd263 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -538,6 +538,41 @@ object JavaDriverDialectParser : DialectParser { ) } + "unwind" -> { + val fieldExpression = stageCall.argumentList.expressions.getOrNull(0) + ?: return Node( + source = stageCall, + components = listOf( + Named(Name.UNWIND) + ) + ) + + val fieldName = fieldExpression.tryToResolveAsConstantString() + ?: return Node( + source = stageCall, + components = listOf( + Named(Name.UNWIND), + HasFieldReference( + HasFieldReference.Unknown + ) + ) + ) + + return Node( + source = stageCall, + components = listOf( + Named(Name.UNWIND), + HasFieldReference( + HasFieldReference.FromSchema( + source = fieldExpression, + fieldName = fieldName.trim('$'), + displayName = fieldName, + ) + ), + ) + ) + } + else -> return null } } diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/UnwindStageParser.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/UnwindStageParser.kt new file mode 100644 index 00000000..3cfb5cac --- /dev/null +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/aggregationparser/UnwindStageParser.kt @@ -0,0 +1,262 @@ +package com.mongodb.jbplugin.dialects.javadriver.glossary.aggregationparser + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest +import com.mongodb.jbplugin.dialects.javadriver.ParsingTest +import com.mongodb.jbplugin.dialects.javadriver.getQueryAtMethod +import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect +import com.mongodb.jbplugin.mql.components.HasAggregation +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Assertions.assertEquals + +@IntegrationTest +class UnwindStageParser { + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.unwind() + )); + } +} + """ + ) + fun `should be able to parse an empty unwind call`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val unwindStageNode = hasAggregation?.children?.get(0)!! + assertEquals(1, unwindStageNode.components.size) + + val named = unwindStageNode.component()!! + assertEquals(Name.UNWIND, named.name) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.unwind("${'$'}name") + )); + } +} + """ + ) + fun `should be able to parse an unwind call with fieldName`(psiFile: PsiFile) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val unwindStageNode = hasAggregation?.children?.get(0)!! + assertEquals(2, unwindStageNode.components.size) + + val named = unwindStageNode.component()!! + assertEquals(Name.UNWIND, named.name) + + val fieldReference = unwindStageNode.component>()!!.reference + as HasFieldReference.FromSchema + assertEquals("name", fieldReference.fieldName) + assertEquals("${'$'}name", fieldReference.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + String fieldName = "${'$'}name"; + return this.collection.aggregate(List.of( + Aggregates.unwind(fieldName) + )); + } +} + """ + ) + fun `should be able to parse an unwind call with fieldName where fieldName is variable`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val unwindStageNode = hasAggregation?.children?.get(0)!! + assertEquals(2, unwindStageNode.components.size) + + val named = unwindStageNode.component()!! + assertEquals(Name.UNWIND, named.name) + + val fieldReference = unwindStageNode.component>()!!.reference + as HasFieldReference.FromSchema + assertEquals("name", fieldReference.fieldName) + assertEquals("${'$'}name", fieldReference.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + private String getUnwindField() { + return "${'$'}name"; + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.unwind(getUnwindField()) + )); + } +} + """ + ) + fun `should be able to parse an unwind call with fieldName where fieldName is from a method call`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val unwindStageNode = hasAggregation?.children?.get(0)!! + assertEquals(2, unwindStageNode.components.size) + + val named = unwindStageNode.component()!! + assertEquals(Name.UNWIND, named.name) + + val fieldReference = unwindStageNode.component>()!!.reference + as HasFieldReference.FromSchema + assertEquals("name", fieldReference.fieldName) + assertEquals("${'$'}name", fieldReference.displayName) + } + + @ParsingTest( + fileName = "Aggregation.java", + value = """ +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; + +import java.util.List; + +import static com.mongodb.client.model.Filters.*; + +public final class Aggregation { + private final MongoCollection collection; + + public Aggregation(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public AggregateIterable getAllBookTitles(ObjectId id) { + return this.collection.aggregate(List.of( + Aggregates.unwind("${'$'}name", new UnwindOptions()) + )); + } +} + """ + ) + fun `should be able to parse an unwind call with fieldName while ignoring the UnwindOptions`( + psiFile: PsiFile + ) { + val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles") + val parsedAggregate = JavaDriverDialect.parser.parse(aggregate) + val hasAggregation = parsedAggregate.component>() + assertEquals(1, hasAggregation?.children?.size) + + val unwindStageNode = hasAggregation?.children?.get(0)!! + assertEquals(2, unwindStageNode.components.size) + + val named = unwindStageNode.component()!! + assertEquals(Name.UNWIND, named.name) + + val fieldReference = unwindStageNode.component>()!!.reference + as HasFieldReference.FromSchema + assertEquals("name", fieldReference.fieldName) + assertEquals("${'$'}name", fieldReference.displayName) + } +} diff --git a/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt b/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt index b364faad..b24a961e 100644 --- a/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt +++ b/packages/mongodb-linting-engine/src/test/kotlin/com/mongodb/jbplugin/linting/FieldCheckingLinterTest.kt @@ -613,4 +613,56 @@ class FieldCheckingLinterTest { assertEquals(0, result.warnings.size) } + + @Test + fun `should not warn about the referenced fields in an Aggregation#unwind`() { + val readModelProvider = mock>() + val collectionNamespace = Namespace("database", "collection") + + `when`(readModelProvider.slice(any(), any())).thenReturn( + GetCollectionSchema( + CollectionSchema( + collectionNamespace, + BsonObject( + mapOf( + "myInt" to BsonInt32, + ), + ), + ), + ), + ) + + val result = + FieldCheckingLinter.lintQuery( + Unit, + readModelProvider, + Node( + null, + listOf( + HasCollectionReference( + HasCollectionReference.Known(null, null, collectionNamespace) + ), + HasAggregation( + children = listOf( + Node( + null, + listOf( + Named(Name.UNWIND), + HasFieldReference( + HasFieldReference.FromSchema( + null, + "myBoolean", + "${'$'}myBoolean" + ) + ) + ) + ) + ) + ) + ), + ), + ) + + assertEquals(1, result.warnings.size) + } } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index 965905d5..d75aae54 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -67,6 +67,7 @@ enum class Name(val canonical: String) { ASCENDING("ascending"), DESCENDING("descending"), ADD_FIELDS("addFields"), + UNWIND("unwind"), UNKNOWN(""), ;