diff --git a/CHANGELOG.md b/CHANGELOG.md index ea701fb03..c1730c808 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ MongoDB plugin for IntelliJ IDEA. ## [Unreleased] ### Added +* [INTELLIJ-189](https://jira.mongodb.org/browse/INTELLIJ-189) Add support for generating update queries. * [INTELLIJ-177](https://jira.mongodb.org/browse/INTELLIJ-177) Add support for parsing, inspecting and autocompleting in a addFields stage written using `Aggregation.addFields` and chained `AddFieldsOperation`s using `addFieldWithValue`, `addFieldWithValueOf`, `addField().withValue()` and `addField().withValueOf()`. Parsing boxed Java values is not supported yet. * [INTELLIJ-174](https://jira.mongodb.org/browse/INTELLIJ-174) Add support for parsing, inspecting and autocompleting in a sort stage written using `Aggregation.sort` and chained `SortOperation`s using `and`. All the overloads of creating a `Sort` object are supported. * [INTELLIJ-188](https://jira.mongodb.org/browse/INTELLIJ-188) Support for generating sort in the query generator. diff --git a/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/codeActions/impl/RunQueryCodeActionBridge.kt b/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/codeActions/impl/RunQueryCodeActionBridge.kt index fffc6c905..75d8039fb 100644 --- a/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/codeActions/impl/RunQueryCodeActionBridge.kt +++ b/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/codeActions/impl/RunQueryCodeActionBridge.kt @@ -87,7 +87,7 @@ internal object RunQueryCodeAction : MongoDbCodeAction { coroutineScope.launchChildBackground { val outputQuery = MongoshDialect.formatter.formatQuery( query, - QueryContext.empty() + QueryContext.empty(prettyPrint = true) ) if (dataSource?.isConnected() == true) { diff --git a/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/inspections/impl/IndexCheckInspectionBridge.kt b/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/inspections/impl/IndexCheckInspectionBridge.kt index 32ee5eb80..821015045 100644 --- a/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/inspections/impl/IndexCheckInspectionBridge.kt +++ b/packages/jetbrains-plugin/src/main/kotlin/com/mongodb/jbplugin/inspections/impl/IndexCheckInspectionBridge.kt @@ -66,7 +66,8 @@ internal object IndexCheckLinterInspection : MongoDbInspection { QueryContext.ExplainPlanType.FULL } else { QueryContext.ExplainPlanType.SAFE - } + }, + prettyPrint = false ) val readModelProvider by query.source.project.service() diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index 49bf95a03..976f46a06 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -30,6 +30,23 @@ public Document findMovieById(String id) { .first(); } + public void updateMoviesByYear(int year) { + client + .getDatabase("sample_mflix") + .getCollection("movies") + .updateMany( + Filters.and(Filters.eq("year", year), Filters.ne("languages", "Esperanto")), + Updates.combine( + Updates.inc(IMDB_VOTES, 1), + Updates.inc(AWARDS_WINS, 1), + Updates.set("other", 25), + Updates.pull(IMDB_VOTES, Filters.eq("a", 1)), + Updates.pullAll(IMDB_VOTES, List.of("1", "2")), + Updates.push("languages", "Esperanto") + ) + ); + } + public List findMoviesByYear(int year) { return client .getDatabase("sample_mflix") diff --git a/packages/mongodb-access-adapter/datagrip-access-adapter/src/test/kotlin/com/mongodb/jbplugin/accessadapter/datagrip/adapter/DataGripMongoDbDriverTest.kt b/packages/mongodb-access-adapter/datagrip-access-adapter/src/test/kotlin/com/mongodb/jbplugin/accessadapter/datagrip/adapter/DataGripMongoDbDriverTest.kt index 16d85a784..6d0b526c9 100644 --- a/packages/mongodb-access-adapter/datagrip-access-adapter/src/test/kotlin/com/mongodb/jbplugin/accessadapter/datagrip/adapter/DataGripMongoDbDriverTest.kt +++ b/packages/mongodb-access-adapter/datagrip-access-adapter/src/test/kotlin/com/mongodb/jbplugin/accessadapter/datagrip/adapter/DataGripMongoDbDriverTest.kt @@ -273,7 +273,7 @@ class DataGripMongoDbDriverTest { val explainPlanResult = driver.explain( query, - QueryContext(emptyMap(), QueryContext.ExplainPlanType.SAFE) + QueryContext(emptyMap(), QueryContext.ExplainPlanType.SAFE, false) ) assertEquals(ExplainPlan.CollectionScan, explainPlanResult) } @@ -324,7 +324,7 @@ class DataGripMongoDbDriverTest { val explainPlanResult = driver.explain( query, - QueryContext(emptyMap(), QueryContext.ExplainPlanType.SAFE) + QueryContext(emptyMap(), QueryContext.ExplainPlanType.SAFE, false) ) assertEquals(ExplainPlan.IndexScan, explainPlanResult) } diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 8c5620897..cdf02f974 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -392,7 +392,7 @@ object JavaDriverDialectParser : DialectParser { filter, listOf( Named(Name.from(method.name)), - HasFilter( + HasUpdates( filter.argumentList.expressions .mapNotNull { resolveBsonBuilderCall(it, UPDATES_FQN) } .mapNotNull { parseUpdatesExpression(it) }, @@ -402,12 +402,36 @@ object JavaDriverDialectParser : DialectParser { } else if (method.parameters.size == 2) { // If it has two parameters, it's field/value. val fieldReference = filter.argumentList.expressions[0].resolveFieldNameFromExpression() + val named = Named(Name.from(method.name)) + + if (named.name == Name.PULL) { + // the second argument can be either a value or a filter + val filterExpression = filter.argumentList.expressions[1] + val resolvedFilterExpression = resolveBsonBuilderCall(filterExpression, FILTERS_FQN) + + if (resolvedFilterExpression != null) { + val parsedFilter = parseFilterExpression(resolvedFilterExpression) + ?: return null + + return Node( + filter, + listOf( + named, + HasFieldReference( + fieldReference, + ), + HasFilter(listOf(parsedFilter)), + ), + ) + } + } + val valueReference = resolveValueFromExpression(filter.argumentList.expressions[1]) return Node( filter, listOf( - Named(Name.from(method.name)), + named, HasFieldReference( fieldReference, ), diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt index a79026c4d..80998e90b 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt @@ -6,13 +6,41 @@ package com.mongodb.jbplugin.dialects.javadriver.glossary import com.intellij.lang.jvm.JvmModifier import com.intellij.openapi.project.Project -import com.intellij.psi.* +import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiExpression +import com.intellij.psi.PsiField +import com.intellij.psi.PsiLiteralExpression +import com.intellij.psi.PsiLiteralValue +import com.intellij.psi.PsiLocalVariable +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.PsiParenthesizedExpression +import com.intellij.psi.PsiReference +import com.intellij.psi.PsiReferenceExpression +import com.intellij.psi.PsiReturnStatement +import com.intellij.psi.PsiSuperExpression +import com.intellij.psi.PsiThisExpression +import com.intellij.psi.PsiType import com.intellij.psi.search.GlobalSearchScope import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil import com.intellij.psi.util.childrenOfType import com.intellij.psi.util.parentOfType -import com.mongodb.jbplugin.mql.* +import com.mongodb.jbplugin.mql.BsonAny +import com.mongodb.jbplugin.mql.BsonAnyOf +import com.mongodb.jbplugin.mql.BsonArray +import com.mongodb.jbplugin.mql.BsonBoolean +import com.mongodb.jbplugin.mql.BsonDate +import com.mongodb.jbplugin.mql.BsonDecimal128 +import com.mongodb.jbplugin.mql.BsonDouble +import com.mongodb.jbplugin.mql.BsonInt32 +import com.mongodb.jbplugin.mql.BsonInt64 +import com.mongodb.jbplugin.mql.BsonNull +import com.mongodb.jbplugin.mql.BsonObjectId +import com.mongodb.jbplugin.mql.BsonString +import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.components.IsCommand /** @@ -398,6 +426,9 @@ fun String.toBsonType(): BsonType { } else if (this.endsWith("[]")) { val baseType = this.substring(0, this.length - 2) return BsonArray(baseType.toBsonType()) + } else if (this.contains("List") || this.contains("Set")) { + val baseType = this.substringAfter("<").substringBeforeLast(">") + return BsonArray(baseType.toBsonType()) } return BsonAny diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt index 78afc4d9e..5088160de 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt @@ -1709,9 +1709,9 @@ public class Repository { val combine = hasUpdate.children[0] assertEquals(Name.COMBINE, combine.component()!!.name) - assertEquals(2, combine.component>()!!.children.size) + assertEquals(2, combine.component>()!!.children.size) - val updates = combine.component>()!!.children + val updates = combine.component>()!!.children assertEquals(Name.SET, updates[0].component()!!.name) assertEquals(Name.UNSET, updates[1].component()!!.name) } diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt index dfec8543c..bb0970519 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt @@ -28,6 +28,15 @@ class PsiMdbTreeUtilTest { } } + @ParameterizedTest + @MethodSource("stringToBsonType") + fun `should map all known java qualified names to their corresponding bson types`( + javaQualifiedName: String, + expected: BsonType, + ) { + assertEquals(expected, javaQualifiedName.toBsonType()) + } + companion object { @JvmStatic fun psiTypeToBsonType(): Array> = @@ -89,6 +98,38 @@ class PsiMdbTreeUtilTest { BsonAnyOf(BsonDecimal128, BsonNull), ), ) + + @JvmStatic + fun stringToBsonType(): Array> = arrayOf( + arrayOf("org.bson.types.ObjectId", BsonAnyOf(BsonObjectId, BsonNull)), + arrayOf("boolean", BsonBoolean), + arrayOf("java.lang.Boolean", BsonBoolean), + arrayOf("short", BsonInt32), + arrayOf("java.lang.Short", BsonInt32), + arrayOf("int", BsonInt32), + arrayOf("java.lang.Integer", BsonInt32), + arrayOf("long", BsonInt64), + arrayOf("java.lang.Long", BsonInt64), + arrayOf("float", BsonDouble), + arrayOf("java.lang.Float", BsonDouble), + arrayOf("double", BsonDouble), + arrayOf("java.lang.Double", BsonDouble), + arrayOf("java.lang.CharSequence", BsonAnyOf(BsonString, BsonNull)), + arrayOf("java.lang.String", BsonAnyOf(BsonString, BsonNull)), + arrayOf("String", BsonAnyOf(BsonString, BsonNull)), + arrayOf("java.util.Date", BsonAnyOf(BsonDate, BsonNull)), + arrayOf("java.time.LocalDate", BsonAnyOf(BsonDate, BsonNull)), + arrayOf("java.time.LocalDateTime", BsonAnyOf(BsonDate, BsonNull)), + arrayOf("java.math.BigInteger", BsonAnyOf(BsonInt64, BsonNull)), + arrayOf("java.math.BigDecimal", BsonAnyOf(BsonDecimal128, BsonNull)), + arrayOf("int[]", BsonArray(BsonInt32)), + arrayOf("java.lang.Long[]", BsonArray(BsonInt64)), + arrayOf("List", BsonArray(BsonAnyOf(BsonString, BsonNull))), + arrayOf("List[]", BsonArray(BsonArray(BsonInt32))), + arrayOf("Set", BsonArray(BsonAnyOf(BsonString, BsonNull))), + arrayOf("Map", BsonAny), + arrayOf("HashMap", BsonAny), + ) } } diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt index 0a246eb98..71caa7799 100644 --- a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatter.kt @@ -2,23 +2,18 @@ package com.mongodb.jbplugin.dialects.mongosh import com.mongodb.jbplugin.dialects.DialectFormatter import com.mongodb.jbplugin.dialects.OutputQuery +import com.mongodb.jbplugin.dialects.mongosh.aggr.canEmitAggregate +import com.mongodb.jbplugin.dialects.mongosh.aggr.emitAggregateBody +import com.mongodb.jbplugin.dialects.mongosh.aggr.isAggregate import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.dialects.mongosh.query.canUpdateDocuments +import com.mongodb.jbplugin.dialects.mongosh.query.emitCollectionReference +import com.mongodb.jbplugin.dialects.mongosh.query.emitQueryFilter +import com.mongodb.jbplugin.dialects.mongosh.query.emitQueryUpdate +import com.mongodb.jbplugin.dialects.mongosh.query.emitSort +import com.mongodb.jbplugin.dialects.mongosh.query.returnsACursor import com.mongodb.jbplugin.mql.* import com.mongodb.jbplugin.mql.components.* -import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema -import com.mongodb.jbplugin.mql.components.HasFieldReference.Inferred -import com.mongodb.jbplugin.mql.components.HasFieldReference.Unknown -import com.mongodb.jbplugin.mql.parser.anyError -import com.mongodb.jbplugin.mql.parser.components.aggregationStages -import com.mongodb.jbplugin.mql.parser.components.allFiltersRecursively -import com.mongodb.jbplugin.mql.parser.components.hasName -import com.mongodb.jbplugin.mql.parser.components.whenIsCommand -import com.mongodb.jbplugin.mql.parser.count -import com.mongodb.jbplugin.mql.parser.filter -import com.mongodb.jbplugin.mql.parser.map -import com.mongodb.jbplugin.mql.parser.matches -import com.mongodb.jbplugin.mql.parser.nth -import com.mongodb.jbplugin.mql.parser.parse import io.github.z4kn4fein.semver.Version import org.owasp.encoder.Encode @@ -27,13 +22,10 @@ object MongoshDialectFormatter : DialectFormatter { query: Node, queryContext: QueryContext, ): OutputQuery { - val isAggregate = isAggregate(query) - val canEmitAggregate = canEmitAggregate(query) + val isAggregate = query.isAggregate() + val canEmitAggregate = query.canEmitAggregate() - val outputString = MongoshBackend( - prettyPrint = - queryContext.explainPlan != QueryContext.ExplainPlanType.NONE - ).apply { + val outputString = MongoshBackend(prettyPrint = queryContext.prettyPrint).apply { if (isAggregate && !canEmitAggregate) { emitComment("Only aggregates with a single match stage can be converted.") return@apply @@ -59,15 +51,25 @@ object MongoshDialectFormatter : DialectFormatter { } else { emitFunctionName(query.component()?.type?.canonical ?: "find") } - emitFunctionCall(long = true, { - if (isAggregate(query)) { - emitAggregateBody(query) - } else { - emitQueryBody(query, firstCall = true) - } - }) + if (query.canUpdateDocuments() && + queryContext.explainPlan == QueryContext.ExplainPlanType.NONE + ) { + emitFunctionCall(long = true, { + emitQueryFilter(query, firstCall = true) + }, { + emitQueryUpdate(query) + }) + } else { + emitFunctionCall(long = true, { + if (query.isAggregate()) { + emitAggregateBody(query) + } else { + emitQueryFilter(query, firstCall = true) + } + }) + } - if (returnsACursor(query)) { + if (query.returnsACursor()) { emitSort(query) } }.computeOutput() @@ -127,295 +129,3 @@ object MongoshDialectFormatter : DialectFormatter { override fun formatType(type: BsonType) = "" } - -private fun MongoshBackend.emitQueryBody( - node: Node, - firstCall: Boolean = false -): MongoshBackend { - val named = node.component() - val fieldRef = node.component>() - val valueRef = node.component>() - val hasFilter = node.component>() - val isLong = allFiltersRecursively().parse(node).orElse { emptyList() }.size > 3 - - if (hasFilter != null && fieldRef == null && valueRef == null && named == null) { - // 1. has children, nothing else (root node) - if (firstCall) { - emitObjectStart(long = isLong) - } - - hasFilter.children.forEach { - emitQueryBody(it) - emitObjectValueEnd() - } - if (firstCall) { - emitObjectEnd(long = isLong) - } - } else if (hasFilter == null && fieldRef != null && valueRef != null && named == null) { - // 2. no children, only a field: value case - if (firstCall) { - emitObjectStart(long = isLong) - } - emitObjectKey(resolveFieldReference(fieldRef)) - emitContextValue(resolveValueReference(valueRef, fieldRef)) - if (firstCall) { - emitObjectEnd(long = isLong) - } - } else { - named?.let { -// 3. children and named - if (named.name == Name.EQ) { -// normal a: b case - if (firstCall) { - emitObjectStart(long = isLong) - } - if (fieldRef != null) { - emitObjectKey(resolveFieldReference(fieldRef)) - } - - if (valueRef != null) { - emitContextValue(resolveValueReference(valueRef, fieldRef)) - } - - hasFilter?.children?.forEach { - emitQueryBody(it) - emitObjectValueEnd() - } - - if (firstCall) { - emitObjectEnd(long = isLong) - } - } else if (setOf( // 1st basic attempt, to improve in INTELLIJ-76 - Name.GT, - Name.GTE, - Name.LT, - Name.LTE - ).contains(named.name) && - valueRef != null - ) { -// a: { $gt: 1 } - if (firstCall) { - emitObjectStart(long = isLong) - } - - if (fieldRef != null) { - emitObjectKey(resolveFieldReference(fieldRef)) - } - - emitObjectStart() - emitObjectKey(registerConstant('$' + named.name.canonical)) - emitContextValue(resolveValueReference(valueRef, fieldRef)) - emitObjectEnd() - - if (firstCall) { - emitObjectEnd(long = isLong) - } - } else if (setOf( - // 1st basic attempt, to improve in INTELLIJ-77 - Name.AND, - Name.OR, - Name.NOR, - ).contains(named.name) - ) { - if (firstCall) { - emitObjectStart() - } - emitObjectKey(registerConstant('$' + named.name.canonical)) - emitArrayStart(long = true) - hasFilter?.children?.forEach { - emitObjectStart() - emitQueryBody(it) - emitObjectEnd() - emitObjectValueEnd() - if (prettyPrint) { - emitNewLine() - } - } - emitArrayEnd(long = true) - if (firstCall) { - emitObjectEnd() - } - } else if (named.name == Name.NOT && hasFilter?.children?.size == 1) { - // the not operator is a special case - // because we receive it as: - // $not: { $field$: $condition$ } - // and it needs to be: - // $field$: { $not: $condition$ } - // we will do a JIT translation - - var innerChild = hasFilter.children.first() - val operation = innerChild.component() - val fieldRef = innerChild.component>() - val valueRef = innerChild.component>() - - if (fieldRef == null) { // we are in an "and" / "or"... - // so we use $nor instead - emitQueryBody( - Node( - node.source, - node.components().filterNot { it is Named } + Named(Name.NOR) - ) - ) - return@let - } - - if (operation == null && valueRef == null) { - return@let - } - - if (firstCall) { - emitObjectStart() - } - - // emit field name first - emitObjectKey(resolveFieldReference(fieldRef)) - // emit the $not - emitObjectStart() - emitObjectKey(registerConstant('$' + "not")) - emitQueryBody( - Node( - innerChild.source, - listOfNotNull( - operation, - valueRef - ) - ) - ) - emitObjectEnd() - - if (firstCall) { - emitObjectEnd() - } - } else if (named.name != Name.UNKNOWN && fieldRef != null && valueRef != null) { - if (firstCall) { - emitObjectStart(long = isLong) - } - emitObjectKey(resolveFieldReference(fieldRef)) - emitObjectStart(long = isLong) - emitObjectKey(registerConstant('$' + named.name.canonical)) - emitContextValue(resolveValueReference(valueRef, fieldRef)) - emitObjectEnd(long = isLong) - if (firstCall) { - emitObjectEnd(long = isLong) - } - } - } - } - - return this -} - -private fun MongoshBackend.emitAggregateBody(node: Node): MongoshBackend { - // here we can assume that we only have 1 single stage that is a match - val matchStage = node.component>()!!.children[0] - val filter = matchStage.component>()?.children?.getOrNull(0) - val longFilter = filter?.component>()?.children?.size?.let { it > 3 } == true - - emitArrayStart(long = true) - emitObjectStart() - emitObjectKey(registerConstant('$' + "match")) - if (filter != null) { - emitObjectStart(long = longFilter) - emitQueryBody(filter) - emitObjectEnd(long = longFilter) - } else { - emitComment("No filter provided.") - } - emitObjectEnd() - emitArrayEnd(long = true) - - return this -} - -private fun isAggregate(node: Node): Boolean { - return whenIsCommand(IsCommand.CommandType.AGGREGATE) - .map { true } - .parse(node).orElse { false } -} - -private fun returnsACursor(node: Node): Boolean { - return whenIsCommand(IsCommand.CommandType.FIND_MANY) - .map { true } - .parse(node).orElse { false } -} - -private fun canEmitAggregate(node: Node): Boolean { - return aggregationStages() - .matches(count>().filter { it >= 1 }.matches().anyError()) - .nth(0) - .matches(hasName(Name.MATCH)) - .map { true } - .parse(node).orElse { false } -} - -private fun MongoshBackend.resolveValueReference( - valueRef: HasValueReference, - fieldRef: HasFieldReference? -) = when (val ref = valueRef.reference) { - is HasValueReference.Constant -> registerConstant(ref.value) - is HasValueReference.Inferred -> registerConstant(ref.value) - is HasValueReference.Runtime -> registerVariable( - (fieldRef?.reference as? FromSchema)?.fieldName ?: "value", - ref.type - ) - - else -> registerVariable( - "queryField", - BsonAny - ) -} - -private fun MongoshBackend.resolveFieldReference(fieldRef: HasFieldReference) = - when (val ref = fieldRef.reference) { - is FromSchema -> registerConstant(ref.fieldName) - is Inferred -> registerConstant(ref.fieldName) - is HasFieldReference.Computed -> registerConstant(ref.fieldName) - is Unknown -> registerVariable("field", BsonAny) - } - -private fun MongoshBackend.emitCollectionReference(collRef: HasCollectionReference?): MongoshBackend { - when (val ref = collRef?.reference) { - is HasCollectionReference.OnlyCollection -> { - emitDatabaseAccess(registerVariable("database", BsonString)) - emitCollectionAccess(registerConstant(ref.collection)) - } - - is HasCollectionReference.Known -> { - emitDatabaseAccess(registerConstant(ref.namespace.database)) - emitCollectionAccess(registerConstant(ref.namespace.collection)) - } - - else -> { - emitDatabaseAccess(registerVariable("database", BsonString)) - emitCollectionAccess(registerVariable("collection", BsonString)) - } - } - - return this -} - -private fun MongoshBackend.emitSort(query: Node): MongoshBackend { - val sortComponent = query.component>() - if (sortComponent == null) { - return this - } - - fun generateSortKeyVal(node: Node): MongoshBackend { - val fieldRef = node.component>() ?: return this - val valueRef = node.component>() ?: return this - - emitObjectKey(resolveFieldReference(fieldRef)) - emitContextValue(resolveValueReference(valueRef, fieldRef)) - return emitObjectValueEnd() - } - - emitPropertyAccess() - emitFunctionName("sort") - return emitFunctionCall(long = false, { - emitObjectStart(long = false) - for (sortCriteria in sortComponent.children) { - generateSortKeyVal(sortCriteria) - } - emitObjectEnd(long = false) - }) -} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/aggr/Aggregate.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/aggr/Aggregate.kt new file mode 100644 index 000000000..4f65f8bf4 --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/aggr/Aggregate.kt @@ -0,0 +1,56 @@ +package com.mongodb.jbplugin.dialects.mongosh.aggr + +import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.dialects.mongosh.query.emitQueryFilter +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasAggregation +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.IsCommand +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.parser.anyError +import com.mongodb.jbplugin.mql.parser.components.aggregationStages +import com.mongodb.jbplugin.mql.parser.components.hasName +import com.mongodb.jbplugin.mql.parser.components.whenIsCommand +import com.mongodb.jbplugin.mql.parser.count +import com.mongodb.jbplugin.mql.parser.filter +import com.mongodb.jbplugin.mql.parser.map +import com.mongodb.jbplugin.mql.parser.matches +import com.mongodb.jbplugin.mql.parser.nth +import com.mongodb.jbplugin.mql.parser.parse + +fun Node.isAggregate(): Boolean { + return whenIsCommand(IsCommand.CommandType.AGGREGATE) + .map { true } + .parse(this).orElse { false } +} + +fun Node.canEmitAggregate(): Boolean { + return aggregationStages() + .matches(count>().filter { it >= 1 }.matches().anyError()) + .nth(0) + .matches(hasName(Name.MATCH)) + .map { true } + .parse(this).orElse { false } +} + +fun MongoshBackend.emitAggregateBody(node: Node): MongoshBackend { + // here we can assume that we only have 1 single stage that is a match + val matchStage = node.component>()!!.children[0] + val filter = matchStage.component>()?.children?.getOrNull(0) + val longFilter = filter?.component>()?.children?.size?.let { it > 3 } == true + + emitArrayStart(long = true) + emitObjectStart() + emitObjectKey(registerConstant('$' + "match")) + if (filter != null) { + emitObjectStart(long = longFilter) + emitQueryFilter(filter) + emitObjectEnd(long = longFilter) + } else { + emitComment("No filter provided.") + } + emitObjectEnd() + emitArrayEnd(long = true) + + return this +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Find.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Find.kt new file mode 100644 index 000000000..2fd6e6803 --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Find.kt @@ -0,0 +1,192 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.mql.Component +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import com.mongodb.jbplugin.mql.parser.components.allFiltersRecursively +import com.mongodb.jbplugin.mql.parser.parse + +fun MongoshBackend.emitQueryFilter(node: Node, firstCall: Boolean = false): MongoshBackend { + val named = node.component() + val fieldRef = node.component>() + val valueRef = node.component>() + val hasFilter = node.component>() + val isLong = allFiltersRecursively().parse(node).orElse { emptyList() }.size > 3 + + if (firstCall && hasFilter == null && fieldRef == null && valueRef == null) { + emitObjectStart() + emitObjectEnd() + return this + } + + if (hasFilter != null && fieldRef == null && valueRef == null && named == null) { + // 1. has children, nothing else (root node) + if (firstCall) { + emitObjectStart(long = isLong) + } + + hasFilter.children.forEach { + emitQueryFilter(it) + emitObjectValueEnd() + } + if (firstCall) { + emitObjectEnd(long = isLong) + } + } else if (hasFilter == null && fieldRef != null && valueRef != null && named == null) { + // 2. no children, only a field: value case + if (firstCall) { + emitObjectStart(long = isLong) + } + emitObjectKey(resolveFieldReference(fieldRef)) + emitContextValue(resolveValueReference(valueRef, fieldRef)) + if (firstCall) { + emitObjectEnd(long = isLong) + } + } else { + named?.let { +// 3. children and named + if (named.name == Name.EQ) { +// normal a: b case + if (firstCall) { + emitObjectStart(long = isLong) + } + if (fieldRef != null) { + emitObjectKey(resolveFieldReference(fieldRef)) + } + + if (valueRef != null) { + emitContextValue(resolveValueReference(valueRef, fieldRef)) + } + + hasFilter?.children?.forEach { + emitQueryFilter(it) + emitObjectValueEnd() + } + + if (firstCall) { + emitObjectEnd(long = isLong) + } + } else if (setOf( // 1st basic attempt, to improve in INTELLIJ-76 + Name.GT, + Name.GTE, + Name.LT, + Name.LTE + ).contains(named.name) && + valueRef != null + ) { +// a: { $gt: 1 } + if (firstCall) { + emitObjectStart(long = isLong) + } + + if (fieldRef != null) { + emitObjectKey(resolveFieldReference(fieldRef)) + } + + emitObjectStart() + emitObjectKey(registerConstant('$' + named.name.canonical)) + emitContextValue(resolveValueReference(valueRef, fieldRef)) + emitObjectEnd() + + if (firstCall) { + emitObjectEnd(long = isLong) + } + } else if (setOf( + // 1st basic attempt, to improve in INTELLIJ-77 + Name.AND, + Name.OR, + Name.NOR, + ).contains(named.name) + ) { + if (firstCall) { + emitObjectStart() + } + emitObjectKey(registerConstant('$' + named.name.canonical)) + emitArrayStart(long = true) + hasFilter?.children?.forEach { + emitObjectStart() + emitQueryFilter(it) + emitObjectEnd() + emitObjectValueEnd() + if (prettyPrint) { + emitNewLine() + } + } + emitArrayEnd(long = true) + if (firstCall) { + emitObjectEnd() + } + } else if (named.name == Name.NOT && hasFilter?.children?.size == 1) { + // the not operator is a special case + // because we receive it as: + // $not: { $field$: $condition$ } + // and it needs to be: + // $field$: { $not: $condition$ } + // we will do a JIT translation + + var innerChild = hasFilter.children.first() + val operation = innerChild.component() + val fieldRef = innerChild.component>() + val valueRef = innerChild.component>() + + if (fieldRef == null) { // we are in an "and" / "or"... + // so we use $nor instead + emitQueryFilter( + Node( + node.source, + node.components().filterNot { it is Named } + Named(Name.NOR) + ) + ) + return@let + } + + if (operation == null && valueRef == null) { + return@let + } + + if (firstCall) { + emitObjectStart() + } + + // emit field name first + emitObjectKey(resolveFieldReference(fieldRef)) + // emit the $not + emitObjectStart() + emitObjectKey(registerConstant('$' + "not")) + emitQueryFilter( + Node( + innerChild.source, + listOfNotNull( + operation, + valueRef + ) + ) + ) + emitObjectEnd() + + if (firstCall) { + emitObjectEnd() + } + } else if (named.name != Name.UNKNOWN && fieldRef != null && valueRef != null) { + if (firstCall) { + emitObjectStart(long = isLong) + } + emitObjectKey(resolveFieldReference(fieldRef)) + emitObjectStart(long = isLong) + emitObjectKey(registerConstant('$' + named.name.canonical)) + emitContextValue(resolveValueReference(valueRef, fieldRef)) + emitObjectEnd(long = isLong) + if (firstCall) { + emitObjectEnd(long = isLong) + } + } + } + } + + return this +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Query.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Query.kt new file mode 100644 index 000000000..fd9bea2e6 --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Query.kt @@ -0,0 +1,24 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.IsCommand +import com.mongodb.jbplugin.mql.parser.components.whenHasAnyCommand +import com.mongodb.jbplugin.mql.parser.components.whenIsCommand +import com.mongodb.jbplugin.mql.parser.map +import com.mongodb.jbplugin.mql.parser.parse + +fun Node.returnsACursor(): Boolean { + return whenIsCommand(IsCommand.CommandType.FIND_MANY) + .map { true } + .parse(this).orElse { false } +} + +fun Node.canUpdateDocuments(): Boolean { + return whenHasAnyCommand() + .map { it.component()!!.type } + .map { + it == IsCommand.CommandType.UPDATE_ONE || + it == IsCommand.CommandType.UPDATE_MANY || + it == IsCommand.CommandType.FIND_ONE_AND_UPDATE + }.parse(this).orElse { false } +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/References.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/References.kt new file mode 100644 index 000000000..303f47157 --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/References.kt @@ -0,0 +1,57 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.mql.BsonAny +import com.mongodb.jbplugin.mql.BsonString +import com.mongodb.jbplugin.mql.components.HasCollectionReference +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema +import com.mongodb.jbplugin.mql.components.HasFieldReference.Inferred +import com.mongodb.jbplugin.mql.components.HasFieldReference.Unknown +import com.mongodb.jbplugin.mql.components.HasValueReference + +fun MongoshBackend.resolveValueReference( + valueRef: HasValueReference, + fieldRef: HasFieldReference? +) = when (val ref = valueRef.reference) { + is HasValueReference.Constant -> registerConstant(ref.value) + is HasValueReference.Inferred -> registerConstant(ref.value) + is HasValueReference.Runtime -> registerVariable( + (fieldRef?.reference as? FromSchema)?.fieldName ?: "value", + ref.type + ) + + else -> registerVariable( + "queryField", + BsonAny + ) +} + +fun MongoshBackend.resolveFieldReference(fieldRef: HasFieldReference) = + when (val ref = fieldRef.reference) { + is FromSchema -> registerConstant(ref.fieldName) + is Inferred -> registerConstant(ref.fieldName) + is HasFieldReference.Computed -> registerConstant(ref.fieldName) + is Unknown -> registerVariable("field", BsonAny) + } + +fun MongoshBackend.emitCollectionReference(collRef: HasCollectionReference?): MongoshBackend { + when (val ref = collRef?.reference) { + is HasCollectionReference.OnlyCollection -> { + emitDatabaseAccess(registerVariable("database", BsonString)) + emitCollectionAccess(registerConstant(ref.collection)) + } + + is HasCollectionReference.Known -> { + emitDatabaseAccess(registerConstant(ref.namespace.database)) + emitCollectionAccess(registerConstant(ref.namespace.collection)) + } + + else -> { + emitDatabaseAccess(registerVariable("database", BsonString)) + emitCollectionAccess(registerVariable("collection", BsonString)) + } + } + + return this +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Sort.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Sort.kt new file mode 100644 index 000000000..8fe448bee --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Sort.kt @@ -0,0 +1,33 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasSorts +import com.mongodb.jbplugin.mql.components.HasValueReference + +fun MongoshBackend.emitSort(query: Node): MongoshBackend { + val sortComponent = query.component>() + if (sortComponent == null) { + return this + } + + fun generateSortKeyVal(node: Node): MongoshBackend { + val fieldRef = node.component>() ?: return this + val valueRef = node.component>() ?: return this + + emitObjectKey(resolveFieldReference(fieldRef)) + emitContextValue(resolveValueReference(valueRef, fieldRef)) + return emitObjectValueEnd() + } + + emitPropertyAccess() + emitFunctionName("sort") + return emitFunctionCall(long = false, { + emitObjectStart(long = false) + for (sortCriteria in sortComponent.children) { + generateSortKeyVal(sortCriteria) + } + emitObjectEnd(long = false) + }) +} diff --git a/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Update.kt b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Update.kt new file mode 100644 index 000000000..e80f2f8de --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/main/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/Update.kt @@ -0,0 +1,76 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.backend.MongoshBackend +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.HasUpdates +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named + +fun MongoshBackend.emitQueryUpdate(node: Node): MongoshBackend { + val hasUpdates = node.component>() ?: return this + val allUpdates = hasUpdates.children.flatMap { it.recursivelyCollectAllUpdates() } + val groupedUpdates = groupUpdatesByOperator(allUpdates) + + emitObjectStart(long = true) + groupedUpdates.forEach { + emitEachQueryUpdate(it) + emitObjectValueEnd() + } + emitObjectEnd(long = true) + return this +} + +private fun Node.recursivelyCollectAllUpdates(): List> { + val hasUpdates = component>() + if (hasUpdates != null) { + return hasUpdates.children.flatMap { it.recursivelyCollectAllUpdates() } + } + + return listOf(this) +} + +private fun groupUpdatesByOperator(updates: List>): Map>> { + return updates.groupBy { it.component()?.name } + .filter { it.key != null } as Map>> +} + +private fun MongoshBackend.emitEachQueryUpdate(node: Map.Entry>>): MongoshBackend { + val name = node.key + + emitObjectKey(registerConstant("${'$'}${name.canonical}")) + emitObjectStart(long = true) + when (name) { + Name.PULL -> { + for (pullNode in node.value) { + val fieldName = pullNode.component>() ?: continue + val fieldValue = pullNode.component>() + val filter = pullNode.component>() + + if (fieldValue != null) { + emitObjectKey(resolveFieldReference(fieldName)) + emitContextValue(resolveValueReference(fieldValue, fieldName)) + } else if (filter != null && filter.children.isNotEmpty()) { + emitObjectKey(resolveFieldReference(fieldName)) + emitQueryFilter(filter.children[0], firstCall = true) + } + + emitObjectValueEnd() + } + } + else -> { + for (updateNode in node.value) { + val fieldName = updateNode.component>() ?: continue + val fieldValue = updateNode.component>() ?: continue + + emitObjectKey(resolveFieldReference(fieldName)) + emitContextValue(resolveValueReference(fieldValue, fieldName)) + emitObjectValueEnd() + } + } + } + emitObjectEnd(long = true) + return this +} diff --git a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt index 88d50c885..b2426ecb9 100644 --- a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt +++ b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/MongoshDialectFormatterTest.kt @@ -1,8 +1,5 @@ package com.mongodb.jbplugin.dialects.mongosh -import com.mongodb.jbplugin.mql.BsonAnyOf -import com.mongodb.jbplugin.mql.BsonArray -import com.mongodb.jbplugin.mql.BsonInt32 import com.mongodb.jbplugin.mql.BsonString import com.mongodb.jbplugin.mql.Namespace import com.mongodb.jbplugin.mql.Node @@ -11,78 +8,8 @@ import com.mongodb.jbplugin.mql.components.* import org.intellij.lang.annotations.Language import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.MethodSource -import org.junit.jupiter.params.provider.ValueSource class MongoshDialectFormatterTest { - @Test - fun `can format a query without references to a collection reference`() { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).find({"myField": "myVal", }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - HasFilter( - listOf( - Node( - Unit, - listOf( - HasFieldReference( - HasFieldReference.FromSchema(Unit, "myField") - ), - HasValueReference( - HasValueReference.Constant(Unit, "myVal", BsonString) - ) - ) - ) - ) - ) - ) - ) - } - } - - @Test - fun `can format a simple query`() { - val namespace = Namespace("myDb", "myColl") - - assertGeneratedQuery( - """ - db.getSiblingDB("myDb").getCollection("myColl").find({"myField": "myVal", }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - HasCollectionReference(HasCollectionReference.Known(Unit, Unit, namespace)), - HasFilter( - listOf( - Node( - Unit, - listOf( - Named(Name.EQ), - HasFieldReference( - HasFieldReference.FromSchema(Unit, "myField") - ), - HasValueReference( - HasValueReference.Constant(Unit, "myVal", BsonString) - ) - ) - ) - ) - ) - ) - ) - } - } - @Test fun `can format a simple delete query`() { val namespace = Namespace("myDb", "myColl") @@ -125,11 +52,7 @@ class MongoshDialectFormatterTest { var collection = "" var database = "" - db.getSiblingDB(database) - .getCollection(collection) - .explain("queryPlanner").find( - {"myField": "myVal", } - ) + db.getSiblingDB(database).getCollection(collection).explain("queryPlanner").find({"myField": "myVal", }) """.trimIndent(), explain = QueryContext.ExplainPlanType.SAFE ) { @@ -163,11 +86,7 @@ class MongoshDialectFormatterTest { var collection = "" var database = "" - db.getSiblingDB(database) - .getCollection(collection) - .explain("executionStats").find( - {"myField": "myVal", } - ) + db.getSiblingDB(database).getCollection(collection).explain("executionStats").find({"myField": "myVal", }) """.trimIndent(), explain = QueryContext.ExplainPlanType.FULL ) { @@ -252,13 +171,7 @@ class MongoshDialectFormatterTest { var collection = "" var database = "" - db.getSiblingDB(database) - .getCollection(collection) - .explain("queryPlanner").aggregate( - [ - {"${'$'}match": {"myField": "myVal"}} - ] - ) + db.getSiblingDB(database).getCollection(collection).explain("queryPlanner").aggregate([{"${'$'}match": {"myField": "myVal"}}]) """.trimIndent(), explain = QueryContext.ExplainPlanType.SAFE ) { @@ -310,13 +223,7 @@ class MongoshDialectFormatterTest { var collection = "" var database = "" - db.getSiblingDB(database) - .getCollection(collection) - .explain("executionStats").aggregate( - [ - {"${'$'}match": {"myField": "myVal"}} - ] - ) + db.getSiblingDB(database).getCollection(collection).explain("executionStats").aggregate([{"${'$'}match": {"myField": "myVal"}}]) """.trimIndent(), explain = QueryContext.ExplainPlanType.FULL ) { @@ -361,246 +268,6 @@ class MongoshDialectFormatterTest { } } - @ParameterizedTest - @ValueSource(strings = ["and", "or", "nor"]) - fun `can format a query with subquery operators`(operator: String) { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).find({"${"$"}$operator": [{"myField": "myVal"}, ]}) - """.trimIndent() - ) { - Node( - Unit, - listOf( - Named(Name.from(operator)), - HasFilter( - listOf( - Node( - Unit, - listOf( - HasFieldReference( - HasFieldReference.FromSchema(Unit, "myField") - ), - HasValueReference( - HasValueReference.Constant(Unit, "myVal", BsonString) - ) - ) - ) - ) - ) - ) - ) - } - } - - @Test - fun `can format query using the not operator`() { - val namespace = Namespace("myDb", "myColl") - - assertGeneratedQuery( - """ - db.getSiblingDB("myDb").getCollection("myColl").find({"myField": {"${"$"}not": "myVal"}, }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - HasCollectionReference(HasCollectionReference.Known(Unit, Unit, namespace)), - HasFilter( - listOf( - Node( - Unit, - listOf( - Named(Name.NOT), - HasFilter( - listOf( - Node( - Unit, - listOf( - Named(Name.EQ), - HasFieldReference( - HasFieldReference.FromSchema( - Unit, - "myField" - ) - ), - HasValueReference( - HasValueReference.Constant( - Unit, - "myVal", - BsonString - ) - ) - ) - ) - ) - ) - ) - ) - ) - ) - ) - ) - } - } - - @ParameterizedTest - @ValueSource(strings = ["lt", "lte", "gt", "gte"]) - fun `can format a query with range operators`(operator: String) { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).find({"myField": {"${"$"}$operator": "myVal"}, }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - HasFilter( - listOf( - Node( - Unit, - listOf( - Named(Name.from(operator)), - HasFieldReference( - HasFieldReference.FromSchema(Unit, "myField") - ), - HasValueReference( - HasValueReference.Constant(Unit, "myVal", BsonString) - ) - ) - ) - ) - ) - ) - ) - } - } - - @ParameterizedTest - @ValueSource(strings = ["in", "nin"]) - fun `can format a query with the in or nin operator`(operator: String) { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).find({"myField": {"${"$"}$operator": [1, 2]}, }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - HasFilter( - listOf( - Node( - Unit, - listOf( - Named(Name.from(operator)), - HasFieldReference( - HasFieldReference.FromSchema(Unit, "myField") - ), - HasValueReference( - HasValueReference.Constant( - Unit, - listOf(1, 2), - BsonArray(BsonAnyOf(BsonInt32)) - ) - ) - ) - ) - ) - ) - ) - ) - } - } - - @Test - fun `can sort a find query when specified`() { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).find().sort({"a": 1, }) - """.trimIndent() - ) { - Node( - Unit, - listOf( - IsCommand(IsCommand.CommandType.FIND_MANY), - HasSorts( - listOf( - Node( - Unit, - listOf( - HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), - HasValueReference( - HasValueReference.Inferred(Unit, 1, BsonInt32) - ) - ) - ) - ) - ) - ) - ) - } - } - - companion object { - @JvmStatic - fun queryCommandsThatDoNotReturnSortableCursors(): Array { - return arrayOf( - IsCommand.CommandType.COUNT_DOCUMENTS, - IsCommand.CommandType.DELETE_MANY, - IsCommand.CommandType.DELETE_ONE, - IsCommand.CommandType.DISTINCT, - IsCommand.CommandType.ESTIMATED_DOCUMENT_COUNT, - IsCommand.CommandType.FIND_ONE, - IsCommand.CommandType.FIND_ONE_AND_DELETE, - ) - } - } - - @ParameterizedTest - @MethodSource("queryCommandsThatDoNotReturnSortableCursors") - fun `can not sort a query that does not return a cursor`(command: IsCommand.CommandType) { - assertGeneratedQuery( - """ - var collection = "" - var database = "" - - db.getSiblingDB(database).getCollection(collection).${command.canonical}() - """.trimIndent() - ) { - Node( - Unit, - listOf( - IsCommand(command), - HasSorts( - listOf( - Node( - Unit, - listOf( - HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), - HasValueReference( - HasValueReference.Inferred(Unit, 1, BsonInt32) - ) - ) - ) - ) - ) - ) - ) - } - } - @Test fun `generates an index suggestion for a query given its fields`() { assertGeneratedIndex( @@ -649,16 +316,19 @@ class MongoshDialectFormatterTest { } } -private fun assertGeneratedQuery( +internal fun assertGeneratedQuery( @Language("js") js: String, explain: QueryContext.ExplainPlanType = QueryContext.ExplainPlanType.NONE, script: () -> Node ) { - val generated = MongoshDialectFormatter.formatQuery(script(), QueryContext(emptyMap(), explain)) + val generated = MongoshDialectFormatter.formatQuery( + script(), + QueryContext(emptyMap(), explain, false) + ) assertEquals(js, generated.query) } -private fun assertGeneratedIndex( +internal fun assertGeneratedIndex( @Language("js") js: String, script: () -> Node ) { diff --git a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/FindTest.kt b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/FindTest.kt new file mode 100644 index 000000000..b76a62574 --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/FindTest.kt @@ -0,0 +1,331 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.assertGeneratedQuery +import com.mongodb.jbplugin.mql.BsonAnyOf +import com.mongodb.jbplugin.mql.BsonArray +import com.mongodb.jbplugin.mql.BsonInt32 +import com.mongodb.jbplugin.mql.BsonString +import com.mongodb.jbplugin.mql.Namespace +import com.mongodb.jbplugin.mql.Namespace.Companion.invoke +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasCollectionReference +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.HasSorts +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.IsCommand +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.junit.jupiter.params.provider.ValueSource + +class FindTest { + @Test + fun `can format a query without references to a collection reference`() { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find({"myField": "myVal", }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasFilter( + listOf( + Node( + Unit, + listOf( + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myVal", BsonString) + ) + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `can format a simple query`() { + val namespace = Namespace("myDb", "myColl") + + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").find({"myField": "myVal", }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference(HasCollectionReference.Known(Unit, Unit, namespace)), + HasFilter( + listOf( + Node( + Unit, + listOf( + Named(Name.EQ), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myVal", BsonString) + ) + ) + ) + ) + ) + ) + ) + } + } + + @ParameterizedTest + @ValueSource(strings = ["and", "or", "nor"]) + fun `can format a query with subquery operators`(operator: String) { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find({"${"$"}$operator": [{"myField": "myVal"}, ]}) + """.trimIndent() + ) { + Node( + Unit, + listOf( + Named(Name.from(operator)), + HasFilter( + listOf( + Node( + Unit, + listOf( + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myVal", BsonString) + ) + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `can format query using the not operator`() { + val namespace = Namespace("myDb", "myColl") + + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").find({"myField": {"${"$"}not": "myVal"}, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference(HasCollectionReference.Known(Unit, Unit, namespace)), + HasFilter( + listOf( + Node( + Unit, + listOf( + Named(Name.NOT), + HasFilter( + listOf( + Node( + Unit, + listOf( + Named(Name.EQ), + HasFieldReference( + HasFieldReference.FromSchema( + Unit, + "myField" + ) + ), + HasValueReference( + HasValueReference.Constant( + Unit, + "myVal", + BsonString + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + } + } + + @ParameterizedTest + @ValueSource(strings = ["lt", "lte", "gt", "gte"]) + fun `can format a query with range operators`(operator: String) { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find({"myField": {"${"$"}$operator": "myVal"}, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasFilter( + listOf( + Node( + Unit, + listOf( + Named(Name.from(operator)), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myVal", BsonString) + ) + ) + ) + ) + ) + ) + ) + } + } + + @ParameterizedTest + @ValueSource(strings = ["in", "nin"]) + fun `can format a query with the in or nin operator`(operator: String) { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find({"myField": {"${"$"}$operator": [1, 2]}, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasFilter( + listOf( + Node( + Unit, + listOf( + Named(Name.from(operator)), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant( + Unit, + listOf(1, 2), + BsonArray(BsonAnyOf(BsonInt32)) + ) + ) + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `can sort a find query when specified`() { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).find({}).sort({"a": 1, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + IsCommand(IsCommand.CommandType.FIND_MANY), + HasSorts( + listOf( + Node( + Unit, + listOf( + HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), + HasValueReference( + HasValueReference.Inferred(Unit, 1, BsonInt32) + ) + ) + ) + ) + ) + ) + ) + } + } + + companion object { + @JvmStatic + fun queryCommandsThatDoNotReturnSortableCursors(): Array { + return arrayOf( + IsCommand.CommandType.COUNT_DOCUMENTS, + IsCommand.CommandType.DELETE_MANY, + IsCommand.CommandType.DELETE_ONE, + IsCommand.CommandType.DISTINCT, + IsCommand.CommandType.ESTIMATED_DOCUMENT_COUNT, + IsCommand.CommandType.FIND_ONE, + IsCommand.CommandType.FIND_ONE_AND_DELETE, + ) + } + } + + @ParameterizedTest + @MethodSource("queryCommandsThatDoNotReturnSortableCursors") + fun `can not sort a query that does not return a cursor`(command: IsCommand.CommandType) { + assertGeneratedQuery( + """ + var collection = "" + var database = "" + + db.getSiblingDB(database).getCollection(collection).${command.canonical}({}) + """.trimIndent() + ) { + Node( + Unit, + listOf( + IsCommand(command), + HasSorts( + listOf( + Node( + Unit, + listOf( + HasFieldReference(HasFieldReference.FromSchema(Unit, "a")), + HasValueReference( + HasValueReference.Inferred(Unit, 1, BsonInt32) + ) + ) + ) + ) + ) + ) + ) + } + } +} diff --git a/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/UpdateTest.kt b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/UpdateTest.kt new file mode 100644 index 000000000..40d9eccee --- /dev/null +++ b/packages/mongodb-dialects/mongosh/src/test/kotlin/com/mongodb/jbplugin/dialects/mongosh/query/UpdateTest.kt @@ -0,0 +1,273 @@ +package com.mongodb.jbplugin.dialects.mongosh.query + +import com.mongodb.jbplugin.dialects.mongosh.assertGeneratedQuery +import com.mongodb.jbplugin.mql.BsonArray +import com.mongodb.jbplugin.mql.BsonInt32 +import com.mongodb.jbplugin.mql.BsonString +import com.mongodb.jbplugin.mql.Namespace +import com.mongodb.jbplugin.mql.Node +import com.mongodb.jbplugin.mql.components.HasCollectionReference +import com.mongodb.jbplugin.mql.components.HasFieldReference +import com.mongodb.jbplugin.mql.components.HasFilter +import com.mongodb.jbplugin.mql.components.HasUpdates +import com.mongodb.jbplugin.mql.components.HasValueReference +import com.mongodb.jbplugin.mql.components.IsCommand +import com.mongodb.jbplugin.mql.components.Name +import com.mongodb.jbplugin.mql.components.Named +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +class UpdateTest { + @ParameterizedTest + @MethodSource("allSetterLikeOperators") + fun `generates an update one query with a simple setter-like operator`(name: Name) { + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").updateOne({}, {"${'$'}${name.canonical}": {"myField": "myValue", }, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference( + HasCollectionReference.Known(Unit, Unit, Namespace("myDb", "myColl")) + ), + IsCommand(IsCommand.CommandType.UPDATE_ONE), + HasUpdates( + listOf( + Node( + Unit, + listOf( + Named(name), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myValue", BsonString) + ), + ) + ) + ) + ) + ) + ) + } + } + + @ParameterizedTest + @MethodSource("allSetterLikeOperators") + fun `generates an update many query with a simple setter-like operator`(name: Name) { + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").updateMany({}, {"${'$'}${name.canonical}": {"myField": "myValue", }, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference( + HasCollectionReference.Known(Unit, Unit, Namespace("myDb", "myColl")) + ), + IsCommand(IsCommand.CommandType.UPDATE_MANY), + HasUpdates( + listOf( + Node( + Unit, + listOf( + Named(name), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant(Unit, "myValue", BsonString) + ), + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `supports pull with a filter`() { + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").updateMany({}, {"${'$'}pull": {"myField": {"innerField": 123}, }, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference( + HasCollectionReference.Known(Unit, Unit, Namespace("myDb", "myColl")) + ), + IsCommand(IsCommand.CommandType.UPDATE_MANY), + HasUpdates( + listOf( + Node( + Unit, + listOf( + Named(Name.PULL), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasFilter( + listOf( + Node( + Unit, + listOf( + HasFieldReference( + HasFieldReference.FromSchema( + Unit, + "innerField" + ) + ), + HasValueReference( + HasValueReference.Constant( + Unit, + 123, + BsonInt32 + ) + ), + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `supports pullAll with a list of elements`() { + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").updateMany({}, {"${'$'}pullAll": {"myField": [1, 2, 3], }, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference( + HasCollectionReference.Known(Unit, Unit, Namespace("myDb", "myColl")) + ), + IsCommand(IsCommand.CommandType.UPDATE_MANY), + HasUpdates( + listOf( + Node( + Unit, + listOf( + Named(Name.PULL_ALL), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant( + Unit, + listOf(1, 2, 3), + BsonArray(BsonInt32) + ) + ) + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun `combines the same type of operation`() { + assertGeneratedQuery( + """ + db.getSiblingDB("myDb").getCollection("myColl").updateMany({}, {"${'$'}pullAll": {"myField": [1, 2, 3], }, "${'$'}set": {"field1": [1], "field2": "abc", }, }) + """.trimIndent() + ) { + Node( + Unit, + listOf( + HasCollectionReference( + HasCollectionReference.Known(Unit, Unit, Namespace("myDb", "myColl")) + ), + IsCommand(IsCommand.CommandType.UPDATE_MANY), + HasUpdates( + listOf( + Node( + Unit, + listOf( + Named(Name.PULL_ALL), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "myField") + ), + HasValueReference( + HasValueReference.Constant( + Unit, + listOf(1, 2, 3), + BsonArray(BsonInt32) + ) + ) + ), + ), + Node( + Unit, + listOf( + Named(Name.SET), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "field1") + ), + HasValueReference( + HasValueReference.Constant( + Unit, + listOf(1), + BsonArray(BsonInt32) + ) + ) + ), + ), + Node( + Unit, + listOf( + Named(Name.SET), + HasFieldReference( + HasFieldReference.FromSchema(Unit, "field2") + ), + HasValueReference( + HasValueReference.Constant( + Unit, + "abc", + BsonString, + ) + ) + ), + ) + ) + ) + ) + ) + } + } + + companion object { + @JvmStatic + fun allSetterLikeOperators(): Array = arrayOf( + Name.INC, + Name.MIN, + Name.MAX, + Name.SET, + Name.SET_ON_INSERT, + Name.UNSET, + Name.ADD_FIELDS, + Name.POP, + Name.PUSH, + Name.PULL + ) + } +} diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/Node.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/Node.kt index 963c8c47b..54a703d5e 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/Node.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/Node.kt @@ -103,6 +103,7 @@ data class Node( data class QueryContext( val expansions: Map, val explainPlan: ExplainPlanType, + val prettyPrint: Boolean, ) { data class LocalVariable(val type: BsonType, val defaultValue: Any?) enum class ExplainPlanType { @@ -112,6 +113,7 @@ data class QueryContext( } companion object { - fun empty(): QueryContext = QueryContext(emptyMap(), ExplainPlanType.NONE) + fun empty(prettyPrint: Boolean = false): QueryContext = + QueryContext(emptyMap(), ExplainPlanType.NONE, prettyPrint) } } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt index 68951b624..128277ea4 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/Named.kt @@ -33,6 +33,7 @@ enum class Name(val canonical: String) { GT("gt"), GTE("gte"), IN("in"), + INC("inc"), LT("lt"), LTE("lte"), NE("ne"), @@ -44,6 +45,7 @@ enum class Name(val canonical: String) { OR("or"), REGEX("regex"), SET("set"), + SET_ON_INSERT("setOnInsert"), SIZE("size"), TEXT("text"), TYPE("type"), @@ -64,6 +66,9 @@ enum class Name(val canonical: String) { MAX("max"), MIN("min"), PUSH("push"), + PULL("pull"), + PULL_ALL("pullAll"), + POP("pop"), ADD_TO_SET("addToSet"), SORT("sort"), ASCENDING("ascending"),