From 9f68d155c3ca631d7f36c5c069826330e5b15767 Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Tue, 21 Jan 2025 14:31:35 +0100 Subject: [PATCH] feat(query-generation): support for generating queries with sort INTELLIJ-186 (#122) --- CHANGELOG.md | 2 + .../javadriver/JavaDriverRepository.java | 2 + .../glossary/JavaDriverDialectParser.kt | 61 +++++++++++++- .../glossary/JavaDriverDialectParserTest.kt | 33 ++++++++ .../mongosh/MongoshDialectFormatter.kt | 37 +++++++++ .../mongosh/backend/MongoshBackend.kt | 3 +- .../mongosh/MongoshDialectFormatterTest.kt | 81 +++++++++++++++++++ 7 files changed, 214 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 066f66139..732060055 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ MongoDB plugin for IntelliJ IDEA. ## [Unreleased] ### Added +* [INTELLIJ-188](https://jira.mongodb.org/browse/INTELLIJ-188) Support for generating sort in the query generator. +* [INTELLIJ-186](https://jira.mongodb.org/browse/INTELLIJ-186) Support for parsing Sorts in the Java Driver. * [INTELLIJ-187](https://jira.mongodb.org/browse/INTELLIJ-187) Use safe execution plans by default. Allow full execution plans through a Plugin settings flag. * [INTELLIJ-180](https://jira.mongodb.org/browse/INTELLIJ-180) Telemetry when inspections are shown and resolved. diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index e8499d6da..49bf95a03 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -35,6 +35,8 @@ public List findMoviesByYear(int year) { .getDatabase("sample_mflix") .getCollection("movies") .find(Filters.eq("year", year)) +// .sort(Sorts.ascending(IMDB_VOTES)) + .sort(Sorts.orderBy(Sorts.ascending(IMDB_VOTES), Sorts.descending("_id"))) .into(new ArrayList<>()); } diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 603ee1612..8c5620897 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -12,6 +12,7 @@ import com.mongodb.jbplugin.mql.BsonArray import com.mongodb.jbplugin.mql.BsonBoolean import com.mongodb.jbplugin.mql.BsonInt32 import com.mongodb.jbplugin.mql.BsonType +import com.mongodb.jbplugin.mql.Component import com.mongodb.jbplugin.mql.ComputedBsonType import com.mongodb.jbplugin.mql.Node import com.mongodb.jbplugin.mql.components.* @@ -60,6 +61,9 @@ object JavaDriverDialectParser : DialectParser { ?: return Node(source, listOf(sourceDialect, collectionReference)) val command = methodCallToCommand(currentCall) + val additionalMetadataMethods = collectAllMetadataMethods(currentCall) + val additionalMetadata = processFindQueryAdditionalMetadata(additionalMetadataMethods) + /** * We might come across a FIND_ONE command and in that case we need to be pointing to the * right method call, find() and not find().first(), in order to parse the filter arguments @@ -80,7 +84,7 @@ object JavaDriverDialectParser : DialectParser { hasFilters, hasUpdates, hasAggregation, - ), + ) + additionalMetadata, ) } else { commandCallMethod?.let { @@ -110,9 +114,13 @@ object JavaDriverDialectParser : DialectParser { sourceDialect, collectionReference, command - ) + ) + additionalMetadata ) - } ?: return Node(source, listOf(sourceDialect, collectionReference)) + } + ?: return Node( + source, + listOf(sourceDialect, collectionReference) + additionalMetadata + ) } } @@ -895,6 +903,53 @@ object JavaDriverDialectParser : DialectParser { } ) } + + private fun collectAllMetadataMethods(methodCall: PsiMethodCallExpression?): List { + val allParentMethodExpressions = methodCall?.collectTypeUntil( + PsiMethodCallExpression::class.java, + PsiReturnStatement::class.java + )?.filter { + // filter out ourselves + !it.isEquivalentTo(methodCall) + } ?: emptyList() + + val allChildrenMethodExpressions = methodCall?.findAllChildrenOfType( + PsiMethodCallExpression::class.java + )?.filter { + // filter out ourselves + !it.isEquivalentTo(methodCall) + } ?: emptyList() + + return allParentMethodExpressions + allChildrenMethodExpressions + } + + private fun processFindQueryAdditionalMetadata(methodCalls: List): List { + return methodCalls.flatMap { methodCall -> + val method = methodCall.fuzzyResolveMethod() + if (method != null && isMethodPartOfTheCursorClass(method)) { + val currentMetadata = when (method.name) { + "sort" -> { + val sortArg = + methodCall.argumentList.expressions.getOrNull(0) ?: return emptyList() + val sortExpr = + resolveBsonBuilderCall(sortArg, SORTS_FQN) ?: return emptyList() + listOf( + HasSorts(parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN)) + ) + } + else -> emptyList() + } + + return currentMetadata + } else { + emptyList() + } + } + } + + private fun isMethodPartOfTheCursorClass(method: PsiMethod?): Boolean = + method?.containingClass?.qualifiedName?.contains("FindIterable") == true || + method?.containingClass?.qualifiedName?.contains("MongoIterable") == true } fun PsiExpression.resolveFieldNameFromExpression(): HasFieldReference.FieldReference { diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt index cb3aeb4c9..78afc4d9e 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt @@ -2184,4 +2184,37 @@ public final class Repository { val command = parsedQuery.component() assertEquals(IsCommand.CommandType.FIND_ONE, command?.type) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Sorts;import org.bson.types.ObjectId; +import java.util.ArrayList; +import static com.mongodb.client.model.Filters.*; + +public final class Repository { + private final MongoCollection collection; + + public Repository(MongoClient client) { + this.collection = client.getDatabase("simple").getCollection("books"); + } + + public Document findBookById(ObjectId id) { + return this.collection.find(eq("_id", id)).sort(Sorts.ascending("myField")).first(); + } +} + """, + ) + fun `correctly parses FindIterable#sort as a SORT component`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findBookById") + val parsedQuery = JavaDriverDialect.parser.parse(query) + val sorting = parsedQuery.component>()!! + val sortByMyFieldName = sorting.children[0].component>()!!.reference + val sortByMyFieldOrder = sorting.children[0].component>()!!.reference + + assertEquals("myField", (sortByMyFieldName as HasFieldReference.FromSchema).fieldName) + assertEquals(1, (sortByMyFieldOrder as HasValueReference.Inferred).value) + } } diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt index 703af6cf8..0a246eb98 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt @@ -66,6 +66,10 @@ object MongoshDialectFormatter : DialectFormatter { emitQueryBody(query, firstCall = true) } }) + + if (returnsACursor(query)) { + emitSort(query) + } }.computeOutput() val ref = query.component>()?.reference @@ -329,6 +333,12 @@ private fun isAggregate(node: Node): Boolean { .parse(node).orElse { false } } +private fun returnsACursor(node: Node): Boolean { + return whenIsCommand(IsCommand.CommandType.FIND_MANY) + .map { true } + .parse(node).orElse { false } +} + private fun canEmitAggregate(node: Node): Boolean { return aggregationStages() .matches(count>().filter { it >= 1 }.matches().anyError()) @@ -343,6 +353,7 @@ private fun MongoshBackend.resolveValueReference( fieldRef: HasFieldReference? ) = when (val ref = valueRef.reference) { is HasValueReference.Constant -> registerConstant(ref.value) + is HasValueReference.Inferred -> registerConstant(ref.value) is HasValueReference.Runtime -> registerVariable( (fieldRef?.reference as? FromSchema)?.fieldName ?: "value", ref.type @@ -382,3 +393,29 @@ private fun MongoshBackend.emitCollectionReference(collRef: HasCollectionRef return this } + +private fun MongoshBackend.emitSort(query: Node): MongoshBackend { + val sortComponent = query.component>() + if (sortComponent == null) { + return this + } + + fun generateSortKeyVal(node: Node): MongoshBackend { + val fieldRef = node.component>() ?: return this + val valueRef = node.component>() ?: return this + + emitObjectKey(resolveFieldReference(fieldRef)) + emitContextValue(resolveValueReference(valueRef, fieldRef)) + return emitObjectValueEnd() + } + + emitPropertyAccess() + emitFunctionName("sort") + return emitFunctionCall(long = false, { + emitObjectStart(long = false) + for (sortCriteria in sortComponent.children) { + generateSortKeyVal(sortCriteria) + } + emitObjectEnd(long = false) + }) +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt index f8e9fc38a..c493aca57 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/backend/MongoshBackend.kt @@ -225,12 +225,11 @@ class MongoshBackend( } private fun serializePrimitive(value: Any?, isPlaceholder: Boolean): String = when (value) { - is Byte, Short, Int, Long, Float, Double -> Encode.forJavaScript(value.toString()) is BigInteger -> "Decimal128(\"$value\")" is BigDecimal -> "Decimal128(\"$value\")" + is Byte, is Short, is Int, is Long, is Float, is Double, is Number -> value.toString() is Boolean -> value.toString() is ObjectId -> "ObjectId(\"${Encode.forJavaScript(value.toHexString())}\")" - is Number -> Encode.forJavaScript(value.toString()) is String -> '"' + Encode.forJavaScript(value) + '"' is Date, is Instant, is LocalDate, is LocalDateTime -> if (isPlaceholder) { "ISODate(\"$MONGODB_FIRST_RELEASE\")" diff --git a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt index dfc6c4974..88d50c885 100644 --- a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt +++ b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt @@ -12,6 +12,7 @@ import org.intellij.lang.annotations.Language import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource import org.junit.jupiter.params.provider.ValueSource class MongoshDialectFormatterTest { @@ -520,6 +521,86 @@ class MongoshDialectFormatterTest { } } + @Test + fun `can sort a find query when specified`() { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find().sort({"a": 1, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + IsCommand(IsCommand.CommandType.FIND_MANY), + HasSorts( + listOf( + Node( + Unit, + listOf( + HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), + HasValueReference( + HasValueReference.Inferred(Unit, 1, BsonInt32) + ) + ) + ) + ) + ) + ) + ) + } + } + + companion object { + @JvmStatic + fun queryCommandsThatDoNotReturnSortableCursors(): Array { + return arrayOf( + IsCommand.CommandType.COUNT_DOCUMENTS, + IsCommand.CommandType.DELETE_MANY, + IsCommand.CommandType.DELETE_ONE, + IsCommand.CommandType.DISTINCT, + IsCommand.CommandType.ESTIMATED_DOCUMENT_COUNT, + IsCommand.CommandType.FIND_ONE, + IsCommand.CommandType.FIND_ONE_AND_DELETE, + ) + } + } + + @ParameterizedTest + @MethodSource("queryCommandsThatDoNotReturnSortableCursors") + fun `can not sort a query that does not return a cursor`(command: IsCommand.CommandType) { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).${command.canonical}() + """.trimIndent() + ) { + Node( + Unit, + listOf( + IsCommand(command), + HasSorts( + listOf( + Node( + Unit, + listOf( + HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), + HasValueReference( + HasValueReference.Inferred(Unit, 1, BsonInt32) + ) + ) + ) + ) + ) + ) + ) + } + } + @Test fun `generates an index suggestion for a query given its fields`() { assertGeneratedIndex(