diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index 379bcbfc..99957d1a 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -1,33 +1,179 @@ package com.mongodb.jbplugin.dialects.javadriver.glossary -import com.intellij.psi.PsiElement -import com.intellij.psi.PsiMethod -import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.* import com.intellij.psi.util.PsiTreeUtil import com.mongodb.jbplugin.dialects.DialectParser +import com.mongodb.jbplugin.mql.Namespace import com.mongodb.jbplugin.mql.Node -import com.mongodb.jbplugin.mql.components.HasCollectionReference +import com.mongodb.jbplugin.mql.components.* +import com.mongodb.jbplugin.mql.toBsonType object JavaDriverDialectParser : DialectParser { - override fun isCandidateForQuery(source: PsiElement): Boolean = - (source as? PsiMethodCallExpression)?.findMongoDbCollectionReference(source.project) != null + override fun isCandidateForQuery(source: PsiElement): Boolean { + if (source !is PsiMethodCallExpression) { +// if it's not a method call, like .find(), it's not a query + return false + } + val sourceMethod = source.resolveMethod() ?: return false - override fun attachment(source: PsiElement): PsiElement = - (source as PsiMethodCallExpression).findMongoDbCollectionReference(source.project)!! + if ( // if the method is of MongoCollection, then we are in a query + sourceMethod.containingClass?.isMongoDbCollectionClass(source.project) == true + ) { + return true + } + + if ( // if it's any driver class, check inner calls + sourceMethod.containingClass?.isMongoDbClass(source.project) == true || + sourceMethod.containingClass?.isMongoDbCursorClass(source.project) == true + ) { + val allChildrenCandidates = PsiTreeUtil.findChildrenOfType(source, PsiMethodCallExpression::class.java) + return allChildrenCandidates.any { isCandidateForQuery(it) } + } + + return false + } + + override fun attachment(source: PsiElement): PsiElement = source.findMongoDbCollectionReference()!! override fun parse(source: PsiElement): Node { - val owningMethod = - PsiTreeUtil.getParentOfType(source, PsiMethod::class.java) - ?: return Node(source, emptyList()) - val namespace = NamespaceExtractor.extractNamespace(owningMethod) - - return Node( - source, - listOf( - namespace?.let { - HasCollectionReference(HasCollectionReference.Known(namespace)) - } ?: HasCollectionReference(HasCollectionReference.Unknown), - ), - ) + val namespace = NamespaceExtractor.extractNamespace(source) + val collectionReference = namespaceComponent(namespace) + + val currentCall = source as PsiMethodCallExpression? ?: return Node(source, listOf(collectionReference)) + + val calledMethod = currentCall.resolveMethod() + if (calledMethod?.containingClass?.isMongoDbCollectionClass(source.project) == true) { + val hasChildren = + if (currentCall.argumentList.expressionCount > 0) { +// we have at least 1 argument in the current method call + val argumentAsFilters = resolveToFiltersCall(currentCall.argumentList.expressions[0]) + argumentAsFilters?.let { + val parsedQuery = parseFilterExpression(argumentAsFilters) // assume it's a Filters call + parsedQuery?.let { + HasChildren( + listOf( + parseFilterExpression( + argumentAsFilters, + )!!, + ), + ) + } ?: HasChildren(emptyList()) + } ?: HasChildren(emptyList()) + } else { + HasChildren(emptyList()) + } + + return Node( + source, + listOf( + collectionReference, + hasChildren, + ), + ) + } else { + calledMethod?.let { +// if it's another class, try to resolve the query from the method body + val allReturns = PsiTreeUtil.findChildrenOfType(calledMethod.body, PsiReturnStatement::class.java) + return allReturns + .mapNotNull { it.returnValue } + .flatMap { + it.collectTypeUntil(PsiMethodCallExpression::class.java, PsiReturnStatement::class.java) + }.firstNotNullOfOrNull { + val innerQuery = parse(it) + if (!innerQuery.hasComponent>()) { + null + } else { + innerQuery + } + } ?: Node(source, listOf(collectionReference)) + } ?: return Node(source, listOf(collectionReference)) + } } + + private fun parseFilterExpression(filter: PsiMethodCallExpression): Node? { + val method = filter.resolveMethod() ?: return null + if (method.isVarArgs) { +// Filters.and, Filters.or... are varargs + return Node( + filter, + listOf( + Named(method.name), + HasChildren( + filter.argumentList.expressions + .mapNotNull { resolveToFiltersCall(it) } + .mapNotNull { parseFilterExpression(it) }, + ), + ), + ) + } else if (method.parameters.size == 2) { +// If it has two parameters, it's field/value. + val fieldNameAsString = filter.argumentList.expressions[0].tryToResolveAsConstantString() + val fieldReference = + fieldNameAsString?.let { + HasFieldReference.Known(filter.argumentList.expressions[0], fieldNameAsString) + } ?: HasFieldReference.Unknown + + val constantValue = filter.argumentList.expressions[1].tryToResolveAsConstant() + val typeOfConstantValue = constantValue?.javaClass?.toBsonType() + + val valueReference = + if (constantValue != null && typeOfConstantValue != null) { + HasValueReference.Constant(constantValue, typeOfConstantValue) + } else { + val psiTypeOfValue = + filter.argumentList.expressions[1] + .type + ?.toBsonType() + psiTypeOfValue?.let { + HasValueReference.Runtime(psiTypeOfValue) + } ?: HasValueReference.Unknown + } + + return Node( + filter, + listOf( + Named(method.name), + HasFieldReference( + fieldReference, + ), + HasValueReference( + valueReference, + ), + ), + ) + } +// here we really don't know much, so just don't attempt to parse the query + return null + } + + private fun resolveToFiltersCall(element: PsiElement): PsiMethodCallExpression? { + when (element) { + is PsiMethodCallExpression -> { + val method = element.resolveMethod() ?: return null + if (method.containingClass?.qualifiedName == "com.mongodb.client.model.Filters") { + return element + } + val allReturns = PsiTreeUtil.findChildrenOfType(method.body, PsiReturnStatement::class.java) + return allReturns.mapNotNull { it.returnValue }.firstNotNullOfOrNull { + resolveToFiltersCall(it) + } + } + is PsiVariable -> { + element.initializer ?: return null + return resolveToFiltersCall(element.initializer!!) + } + + is PsiReferenceExpression -> { + val referredValue = element.resolve() ?: return null + return resolveToFiltersCall(referredValue) + } + + else -> return null + } + } + + private fun namespaceComponent(namespace: Namespace?): HasCollectionReference = + namespace?.let { + HasCollectionReference(HasCollectionReference.Known(namespace)) + } ?: HasCollectionReference(HasCollectionReference.Unknown) } diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/NamespaceExtractor.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/NamespaceExtractor.kt index db71fcf1..7d0a7ab8 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/NamespaceExtractor.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/NamespaceExtractor.kt @@ -15,8 +15,9 @@ private typealias FoundAssignedPsiFields = List = referencesToMongoDbClasses.flatMap { ref -> diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt index 84a38866..17588bae 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtil.kt @@ -1,5 +1,5 @@ /** - * Defines an a set of extension methods to extract metadata from a Psi tree. + * Defines a set of extension methods to extract metadata from a Psi tree. */ package com.mongodb.jbplugin.dialects.javadriver.glossary @@ -12,6 +12,7 @@ import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil import com.intellij.psi.util.childrenOfType import com.intellij.psi.util.parentOfType +import com.mongodb.jbplugin.mql.* /** * Helper extension function to get the containing class of any element. @@ -108,6 +109,35 @@ fun PsiType.isMongoDbClientClass(project: Project): Boolean { return thisClass?.isMongoDbClientClass(project) == true } +/** + * Helper function to check if a type is a MongoDB Cursor + * + * @param project + * @return + */ +fun PsiClass.isMongoDbCursorClass(project: Project): Boolean { + val javaFacade = JavaPsiFacade.getInstance(project) + + val mdbCursorClass = + javaFacade.findClass( + "com.mongodb.client.MongoIterable", + GlobalSearchScope.everythingScope(project), + ) + + return this.isInheritor(mdbCursorClass!!, false) || this == mdbCursorClass +} + +/** + * Helper function to check if a type is a MongoDB Cursor + * + * @param project + * @return + */ +fun PsiType.isMongoDbCursorClass(project: Project): Boolean { + val thisClass = PsiTypesUtil.getPsiClass(this) + return thisClass?.isMongoDbCursorClass(project) == true +} + /** * Helper function to check if a type is a MongoDB Class * @@ -132,18 +162,6 @@ fun PsiClass.isMongoDbClass(project: Project): Boolean = isMongoDbDatabaseClass(project) || isMongoDbClientClass(project) -/** - * Checks if a method is calling a MongoDB driver method. - * - * @return - */ -fun PsiMethod.isUsingMongoDbClasses(): Boolean = - PsiTreeUtil.findChildrenOfType(this, PsiMethodCallExpression::class.java).any { - it.methodExpression.qualifierExpression - ?.type - ?.isMongoDbClass(this.project) == true - } - /** * Finds all references to the MongoDB driver in a method. * @@ -171,7 +189,7 @@ fun PsiMethodCallExpression.findCurrentReferenceToMongoDbObject(): PsiReference? if (resolution is PsiField) { return if (resolution.type.isMongoDbClass(project)) resolution.reference else null } else { - return (methodExpression.resolve() as PsiMethod?)?.findAllReferencesToMongoDbObjects()?.first() + return (methodExpression.resolve() as PsiMethod?)?.findAllReferencesToMongoDbObjects()?.firstOrNull() } } else { if (methodExpression.qualifierExpression is PsiMethodCallExpression) { @@ -217,24 +235,35 @@ fun PsiMethodCallExpression.findMongoDbClassReference(project: Project): PsiExpr } else if (methodExpression.qualifierExpression?.reference?.resolve() is PsiField) { return methodExpression.qualifierExpression } else { - return null + val method = resolveMethod() ?: return null + return method.body + ?.collectTypeUntil(PsiMethodCallExpression::class.java, PsiMethod::class.java) + ?.firstNotNullOfOrNull { it.findMongoDbClassReference(it.project) } } } /** * Returns the reference to a MongoDB driver collection. - * - * @param project */ -fun PsiMethodCallExpression.findMongoDbCollectionReference(project: Project): PsiExpression? { - if (methodExpression.type?.isMongoDbCollectionClass(project) == true) { - return methodExpression - } else if (methodExpression.qualifierExpression is PsiMethodCallExpression) { - return (methodExpression.qualifierExpression as PsiMethodCallExpression).findMongoDbCollectionReference(project) - } else if (methodExpression.qualifierExpression?.reference?.resolve() is PsiField) { - return methodExpression.qualifierExpression - } else { +fun PsiElement.findMongoDbCollectionReference(): PsiExpression? { + if (this is PsiMethodCallExpression) { + if (methodExpression.type?.isMongoDbCollectionClass(project) == true) { + return methodExpression + } else if (methodExpression.qualifierExpression is PsiMethodCallExpression) { + return (methodExpression.qualifierExpression as PsiMethodCallExpression).findMongoDbCollectionReference() + } else if (methodExpression.qualifierExpression?.reference?.resolve() is PsiField) { + return methodExpression.qualifierExpression + } else { + return methodExpression.children.firstNotNullOfOrNull { it.findMongoDbCollectionReference() } + } + } else if (this is PsiExpression) { + if (this.type?.isMongoDbCollectionClass(project) == true) { + return this + } + return null + } else { + return children.firstNotNullOfOrNull { it.findMongoDbCollectionReference() } } } @@ -242,18 +271,65 @@ fun PsiMethodCallExpression.findMongoDbCollectionReference(project: Project): Ps * Resolves to the value of the expression if it can be known at compile time * or null if it can only be known at runtime. */ -fun PsiElement.tryToResolveAsConstantString(): String? { +fun PsiElement.tryToResolveAsConstant(): Any? { if (this is PsiReferenceExpression) { val varRef = this.resolve()!! - return varRef.tryToResolveAsConstantString() + return varRef.tryToResolveAsConstant() } else if (this is PsiLocalVariable) { - return this.initializer?.tryToResolveAsConstantString() + return this.initializer?.tryToResolveAsConstant() } else if (this is PsiLiteralValue) { val facade = JavaPsiFacade.getInstance(this.project) - return facade.constantEvaluationHelper.computeConstantExpression(this) as? String + return facade.constantEvaluationHelper.computeConstantExpression(this) + } else if (this is PsiLiteralExpression) { + val facade = JavaPsiFacade.getInstance(this.project) + return facade.constantEvaluationHelper.computeConstantExpression(this) } else if (this is PsiField && this.hasModifier(JvmModifier.FINAL)) { - return this.initializer?.tryToResolveAsConstantString() + return this.initializer?.tryToResolveAsConstant() } return null } + +/** + * Resolves to the value of the expression to a string + * if it's known at compile time. + * + * @return + */ +fun PsiElement.tryToResolveAsConstantString(): String? = tryToResolveAsConstant()?.toString() + +/** + * Maps a PsiType to its BSON counterpart. + * + * @return + */ +fun PsiType.toBsonType(): BsonType { + if (this.equalsToText("org.bson.types.ObjectId")) { + return BsonAnyOf(BsonObjectId, BsonNull) + } else if (this.equalsToText("boolean") || this.equalsToText("java.lang.Boolean")) { + return BsonBoolean + } else if (this.equalsToText("short") || this.equalsToText("java.lang.Short")) { + return BsonInt32 + } else if (this.equalsToText("int") || this.equalsToText("java.lang.Integer")) { + return BsonInt32 + } else if (this.equalsToText("long") || this.equalsToText("java.lang.Long")) { + return BsonInt64 + } else if (this.equalsToText("float") || this.equalsToText("java.lang.Float")) { + return BsonDouble + } else if (this.equalsToText("double") || this.equalsToText("java.lang.Double")) { + return BsonDouble + } else if (this.equalsToText("java.lang.CharSequence") || this.equalsToText("java.lang.String")) { + return BsonAnyOf(BsonString, BsonNull) + } else if (this.equalsToText("java.util.Date") || + this.equalsToText("java.time.LocalDate") || + this.equalsToText("java.time.LocalDateTime") + ) { + return BsonAnyOf(BsonDate, BsonNull) + } else if (this.equalsToText("java.math.BigInteger")) { + return BsonAnyOf(BsonInt64, BsonNull) + } else if (this.equalsToText("java.math.BigDecimal")) { + return BsonAnyOf(BsonDecimal128, BsonNull) + } + + return BsonAny +} diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/IntegrationTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/IntegrationTest.kt index 46fd7274..97b9ff4b 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/IntegrationTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/IntegrationTest.kt @@ -7,7 +7,6 @@ package com.mongodb.jbplugin.dialects.javadriver import com.intellij.java.library.JavaLibraryUtil import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.DumbService import com.intellij.openapi.project.Project import com.intellij.openapi.project.guessProjectDir @@ -21,6 +20,8 @@ import com.intellij.testFramework.PsiTestUtil import com.intellij.testFramework.fixtures.CodeInsightTestFixture import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory import com.mongodb.client.MongoClient +import com.mongodb.client.model.Filters +import org.bson.types.ObjectId import org.intellij.lang.annotations.Language import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.* @@ -59,6 +60,7 @@ internal class IntegrationTestExtension : ParameterResolver { private val namespace = ExtensionContext.Namespace.create(IntegrationTestExtension::class.java) private val testFixtureKey = "TESTFIXTURE" + private val testPathKey = "TESTPATH" override fun beforeAll(context: ExtensionContext) { val projectFixture = @@ -78,31 +80,46 @@ internal class IntegrationTestExtension : testFixture.setUp() ApplicationManager.getApplication().invokeAndWait { - if (!JavaLibraryUtil.hasLibraryJar(testFixture.module, "org.mongodb:mongodb-driver-sync:5.1.1")) { + val module = testFixture.module + + if (!JavaLibraryUtil.hasLibraryJar(module, "org.mongodb:mongodb-driver-sync:5.1.0")) { runCatching { PsiTestUtil.addProjectLibrary( - testFixture.module, - "mongodb-driver-sync", - listOf( - Path( - pathToJavaDriver(), - ).toAbsolutePath().toString(), - ), + module, + "org.mongodb:mongodb-driver-sync:5.1.0", + listOf(pathToClassJarFile(MongoClient::class.java)), + ) + + PsiTestUtil.addProjectLibrary( + module, + "org.mongodb:mongodb-driver-core:5.1.0", + listOf(pathToClassJarFile(Filters::class.java)), + ) + + PsiTestUtil.addProjectLibrary( + module, + "org.mongodb:bson:5.1.0", + listOf(pathToClassJarFile(ObjectId::class.java)), ) } } } PsiTestUtil.addSourceRoot(testFixture.module, testFixture.project.guessProjectDir()!!) + val tmpRootDir = testFixture.tempDirFixture.getFile(".")!! + PsiTestUtil.addSourceRoot(testFixture.module, tmpRootDir) + context.getStore(namespace).put(testPathKey, tmpRootDir.path) } override fun beforeEach(context: ExtensionContext) { val fixture = context.getStore(namespace).get(testFixtureKey) as CodeInsightTestFixture + val modulePath = context.getStore(namespace).get(testPathKey).toString() ApplicationManager.getApplication().invokeAndWait { - val parsingTest = context.requiredTestMethod.getAnnotation(ParsingTest::class.java) - val path = ModuleUtilCore.getModuleDirPath(fixture.module) - val fileName = Path(path, "src", "main", "java", parsingTest.fileName).absolutePathString() + val parsingTest = context.requiredTestMethod.getAnnotation(ParsingTest::class.java) ?: return@invokeAndWait + + val fileName = Path(modulePath, "src", "main", "java", parsingTest.fileName).absolutePathString() + fixture.configureByText( fileName, parsingTest.value, @@ -119,7 +136,8 @@ internal class IntegrationTestExtension : val finished = AtomicBoolean(false) val fixture = extensionContext.getStore(namespace).get(testFixtureKey) as CodeInsightTestFixture - val dumbService = fixture.project.getService(DumbService::class.java) + val dumbService = DumbService.getInstance(fixture.project) + dumbService.runWhenSmart { val result = runCatching { @@ -140,7 +158,6 @@ internal class IntegrationTestExtension : } throwable.get()?.let { - System.err.println(it.message) it.printStackTrace(System.err) throw it } @@ -178,9 +195,9 @@ internal class IntegrationTestExtension : } } - private fun pathToJavaDriver(): String { + private fun pathToClassJarFile(javaClass: Class<*>): String { val classResource: URL = - MongoClient::class.java.getResource(MongoClient::class.java.getSimpleName() + ".class") + javaClass.getResource(javaClass.getSimpleName() + ".class") ?: throw RuntimeException("class resource is null") val url: String = classResource.toString() if (url.startsWith("jar:file:")) { diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt index 7a7bf538..0e6dd87e 100644 --- a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParserTest.kt @@ -6,9 +6,12 @@ import com.intellij.psi.util.PsiTreeUtil import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest import com.mongodb.jbplugin.dialects.javadriver.ParsingTest import com.mongodb.jbplugin.dialects.javadriver.getQueryAtMethod -import com.mongodb.jbplugin.mql.components.HasCollectionReference -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertTrue +import com.mongodb.jbplugin.mql.BsonAnyOf +import com.mongodb.jbplugin.mql.BsonBoolean +import com.mongodb.jbplugin.mql.BsonNull +import com.mongodb.jbplugin.mql.BsonObjectId +import com.mongodb.jbplugin.mql.components.* +import org.junit.jupiter.api.Assertions.* @IntegrationTest class JavaDriverDialectParserTest { @@ -46,6 +49,32 @@ import com.mongodb.client.MongoCollection; import org.bson.types.ObjectId; import static com.mongodb.client.model.Filters.*; +public final class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public Document getCollection() { + return client.getDatabase("simple").getCollection("books"); + } +} + """, + ) + fun `not a candidate if does not query`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "getCollection") + assertFalse(JavaDriverDialectParser.isCandidateForQuery(query)) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import org.bson.types.ObjectId; +import static com.mongodb.client.model.Filters.*; + public final class Repository { private final MongoCollection collection; @@ -132,4 +161,343 @@ public final class Repository { assertEquals(HasCollectionReference.Unknown, unknownReference) } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.types.ObjectId; +import com.mongodb.client.FindIterable; + +public final class Repository { + private final MongoCollection collection; + + public Repository(MongoCollection collection) { + this.collection = collection; + } + + public FindIterable findBookById(ObjectId id) { + return this.collection.find(Filters.eq("_id", id)); + } +} + """, + ) + fun `can parse a basic Filters query`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findBookById") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val eq = hasChildren.children[0] + assertEquals("eq", eq.component()!!.name) + assertEquals("_id", (eq.component>()!!.reference as HasFieldReference.Known).fieldName) + assertEquals( + BsonAnyOf(BsonObjectId, BsonNull), + (eq.component()!!.reference as HasValueReference.Runtime).type, + ) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(eq("myField", true)); + } +} + """, + ) + fun `can parse a basic Filters query with a constant parameter in a chain of calls`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val eq = hasChildren.children[0] + assertEquals("eq", eq.component()!!.name) + assertEquals( + "myField", + (eq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (eq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(true, (eq.component()!!.reference as HasValueReference.Constant).value) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(and(eq("released", true), eq("hidden", false))); + } +} + """, + ) + fun `supports vararg operators`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val and = hasChildren.children[0] + assertEquals("and", and.component()!!.name) + val andChildren = and.component>()!! + + val firstEq = andChildren.children[0] + assertEquals( + "released", + (firstEq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (firstEq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(true, (firstEq.component()!!.reference as HasValueReference.Constant).value) + + val secondEq = andChildren.children[1] + assertEquals( + "hidden", + (secondEq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (secondEq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(false, (secondEq.component()!!.reference as HasValueReference.Constant).value) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + var isReleased = eq("released", true); + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(isReleased); + } +} + """, + ) + fun `supports references to variables in a query expression`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val eq = hasChildren.children[0] + assertEquals("eq", eq.component()!!.name) + assertEquals( + "released", + (eq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (eq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(true, (eq.component()!!.reference as HasValueReference.Constant).value) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(isReleased()); + } + + private Document isReleased() { + return eq("released", true); + } +} + """, + ) + fun `supports to methods in a query expression`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val eq = hasChildren.children[0] + assertEquals("eq", eq.component()!!.name) + assertEquals( + "released", + (eq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (eq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(true, (eq.component()!!.reference as HasValueReference.Constant).value) + } + + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + return findAllByReleaseFlag(true); + } + + private Document findAllByReleaseFlag(boolean released) { + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(eq("released", released)); + } +} + """, + ) + fun `supports to methods in a custom dsl as in mms`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val eq = hasChildren.children[0] + assertEquals("eq", eq.component()!!.name) + assertEquals( + "released", + (eq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonBoolean, + (eq.component()!!.reference as HasValueReference.Runtime).type, + ) + } + + @Suppress("TOO_LONG_FUNCTION") + @ParsingTest( + fileName = "Repository.java", + value = """ +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoClient; +import org.bson.Document; +import static com.mongodb.client.model.Filters.*; + +public class Repository { + private static final String RELEASED = "released"; + private static final String HIDDEN = "hidden"; + + private final MongoClient client; + + public Repository(MongoClient client) { + this.client = client; + } + + public FindIterable findReleasedBooks() { + var isReleased = eq(RELEASED, true); + var isNotHidden = eq(HIDDEN, false); + + return client.getDatabase("myDatabase") + .getCollection("myCollection") + .find(and(isReleased, isNotHidden)); + } +} + """, + ) + fun `supports vararg operators with references to fields in variables`(psiFile: PsiFile) { + val query = psiFile.getQueryAtMethod("Repository", "findReleasedBooks") + val parsedQuery = JavaDriverDialect.parser.parse(query) + + val hasChildren = + parsedQuery.component>()!! + + val and = hasChildren.children[0] + assertEquals("and", and.component()!!.name) + val andChildren = and.component>()!! + + val firstEq = andChildren.children[0] + assertEquals( + "released", + (firstEq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (firstEq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(true, (firstEq.component()!!.reference as HasValueReference.Constant).value) + + val secondEq = andChildren.children[1] + assertEquals( + "hidden", + (secondEq.component>()!!.reference as HasFieldReference.Known).fieldName, + ) + assertEquals( + BsonAnyOf(BsonNull, BsonBoolean), + (secondEq.component()!!.reference as HasValueReference.Constant).type, + ) + assertEquals(false, (secondEq.component()!!.reference as HasValueReference.Constant).value) + } } diff --git a/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt new file mode 100644 index 00000000..654228e8 --- /dev/null +++ b/packages/mongodb-dialects/java-driver/src/test/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/PsiMdbTreeUtilTest.kt @@ -0,0 +1,84 @@ +package com.mongodb.jbplugin.dialects.javadriver.glossary + +import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.project.Project +import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.PsiType +import com.intellij.psi.PsiTypes +import com.mongodb.jbplugin.dialects.javadriver.IntegrationTest +import com.mongodb.jbplugin.mql.* +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +private typealias PsiTypeProvider = (Project) -> PsiType + +@IntegrationTest +class PsiMdbTreeUtilTest { + @ParameterizedTest + @MethodSource("psiTypeToBsonType") + fun `should map all psi types to their corresponding bson types`( + typeProvider: PsiTypeProvider, + expected: BsonType, + project: Project, + ) { + ApplicationManager.getApplication().invokeAndWait { + val psiType = typeProvider(project) + assertEquals(expected, psiType.toBsonType()) + } + } + + companion object { + @JvmStatic + fun psiTypeToBsonType(): Array> = + arrayOf( + arrayOf({ project: Project -> project.findClass("org.bson.types.ObjectId") }, + BsonAnyOf(BsonObjectId, BsonNull), + ), + arrayOf({ _: Project -> PsiTypes.booleanType() }, + BsonBoolean, + ), + arrayOf({ _: Project -> PsiTypes.shortType() }, + BsonInt32, + ), + arrayOf({ _: Project -> PsiTypes.intType() }, + BsonInt32, + ), + arrayOf({ _: Project -> PsiTypes.longType() }, + BsonInt64, + ), + arrayOf({ _: Project -> PsiTypes.floatType() }, + BsonDouble, + ), + arrayOf({ _: Project -> PsiTypes.doubleType() }, + BsonDouble, + ), + arrayOf({ project: Project -> project.findClass("java.lang.CharSequence") }, + BsonAnyOf(BsonString, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.lang.String") }, + BsonAnyOf(BsonString, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.util.Date") }, + BsonAnyOf(BsonDate, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.time.LocalDate") }, + BsonAnyOf(BsonDate, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.time.LocalDateTime") }, + BsonAnyOf(BsonDate, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.math.BigInteger") }, + BsonAnyOf(BsonInt64, BsonNull), + ), + arrayOf({ project: Project -> project.findClass("java.math.BigDecimal") }, + BsonAnyOf(BsonDecimal128, BsonNull), + ), + ) + } +} + +private fun Project.findClass(name: String): PsiType = + JavaPsiFacade.getElementFactory(this).createTypeByFQClassName( + name, + ) diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt index bcb4cb7a..73b5bede 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/BsonType.kt @@ -45,6 +45,11 @@ data class BsonArray( val schema: BsonType, ) : BsonType +/** + * ObjectId + */ +data object BsonObjectId : BsonType + /** * Boolean */ @@ -109,10 +114,11 @@ fun Class<*>.toBsonType(): BsonType { Int::class.javaObjectType -> BsonAnyOf(BsonNull, BsonInt32) CharSequence::class.java, String::class.java -> BsonAnyOf(BsonNull, BsonString) Date::class.java, Instant::class.java, LocalDate::class.java, LocalDateTime::class.java -> - BsonAnyOf(BsonNull, BsonDate) + BsonAnyOf(BsonNull, BsonDate) BigInteger::class.java -> BsonAnyOf(BsonNull, BsonInt64) BigDecimal::class.java -> BsonAnyOf(BsonNull, BsonDecimal128) - else -> if (Collection::class.java.isAssignableFrom(this)) { + else -> + if (Collection::class.java.isAssignableFrom(this)) { return BsonAnyOf(BsonNull, BsonArray(BsonAny)) // types are lost at runtime } else { val fields = diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt index 4495213e..262e7c70 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasFieldReference.kt @@ -3,18 +3,26 @@ package com.mongodb.jbplugin.mql.components import com.mongodb.jbplugin.mql.Component /** + * @param S * @property reference */ -class HasFieldReference( - val reference: FieldReference, +data class HasFieldReference( + val reference: FieldReference, ) : Component { - data object Unknown : FieldReference - sealed interface FieldReference + data object Unknown : FieldReference /** - * @property fieldName + * @param S */ -data class Known( +sealed interface FieldReference + + /** + * @param S + * @property fieldName + * @property source + */ + data class Known( + val source: S, val fieldName: String, - ) : FieldReference + ) : FieldReference } diff --git a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt index 8f8fd4e0..ebe2eab6 100644 --- a/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt +++ b/packages/mongodb-mql-model/src/main/kotlin/com/mongodb/jbplugin/mql/components/HasValueReference.kt @@ -1,5 +1,6 @@ package com.mongodb.jbplugin.mql.components +import com.mongodb.jbplugin.mql.BsonType import com.mongodb.jbplugin.mql.Component /** @@ -9,21 +10,22 @@ data class HasValueReference( val reference: ValueReference, ) : Component { data object Unknown : ValueReference + sealed interface ValueReference /** - * @property value - * @property type - */ -data class Constant( + * @property value + * @property type + */ + data class Constant( val value: Any, - val type: String, + val type: BsonType, ) : ValueReference /** - * @property type - */ -data class Runtime( - val type: String, + * @property type + */ + data class Runtime( + val type: BsonType, ) : ValueReference } diff --git a/packages/mongodb-mql-model/src/test/kotlin/com/mongodb/jbplugin/mql/NodeTest.kt b/packages/mongodb-mql-model/src/test/kotlin/com/mongodb/jbplugin/mql/NodeTest.kt index 5f0cd662..ea465bee 100644 --- a/packages/mongodb-mql-model/src/test/kotlin/com/mongodb/jbplugin/mql/NodeTest.kt +++ b/packages/mongodb-mql-model/src/test/kotlin/com/mongodb/jbplugin/mql/NodeTest.kt @@ -18,7 +18,7 @@ class NodeTest { @Test fun `returns null if a component does not exist`() { val node = Node(null, listOf(Named("myName"))) - val named = node.component() + val named = node.component>() assertNull(named) } @@ -29,15 +29,16 @@ class NodeTest { Node( null, listOf( - HasFieldReference(HasFieldReference.Known("field1")), + HasFieldReference(HasFieldReference.Known(null, "field1")), HasFieldReference( HasFieldReference.Known( + null, "field2", ), ), ), ) - val fieldReferences = node.components() + val fieldReferences = node.components>() assertEquals("field1", (fieldReferences[0].reference as HasFieldReference.Known).fieldName) assertEquals("field2", (fieldReferences[1].reference as HasFieldReference.Known).fieldName) @@ -48,9 +49,9 @@ class NodeTest { val node = Node( null, - listOf(HasFieldReference(HasFieldReference.Known("field1"))), + listOf(HasFieldReference(HasFieldReference.Known(null, "field1"))), ) - val hasFieldReferences = node.hasComponent() + val hasFieldReferences = node.hasComponent>() assertTrue(hasFieldReferences) } @@ -60,7 +61,7 @@ class NodeTest { val node = Node( null, - listOf(HasFieldReference(HasFieldReference.Known("field1"))), + listOf(HasFieldReference(HasFieldReference.Known(null, "field1"))), ) val hasNamedComponent = node.hasComponent() @@ -90,16 +91,20 @@ class NodeTest { arrayOf( arrayOf(HasChildren(emptyList()), HasChildren::class.java), arrayOf(HasCollectionReference(HasCollectionReference.Unknown), HasCollectionReference::class.java), - arrayOf(HasCollectionReference(HasCollectionReference.Known(Namespace("db", "coll"))), - HasCollectionReference::class.java), - arrayOf(HasCollectionReference(HasCollectionReference.OnlyCollection("coll")), - HasCollectionReference::class.java), + arrayOf( + HasCollectionReference(HasCollectionReference.Known(Namespace("db", "coll"))), + HasCollectionReference::class.java, + ), + arrayOf( + HasCollectionReference(HasCollectionReference.OnlyCollection("coll")), + HasCollectionReference::class.java, + ), arrayOf(HasFieldReference(HasFieldReference.Unknown), HasFieldReference::class.java), - arrayOf(HasFieldReference(HasFieldReference.Known("abc")), HasFieldReference::class.java), + arrayOf(HasFieldReference(HasFieldReference.Known(null, "abc")), HasFieldReference::class.java), arrayOf(HasFilter(Node(null, emptyList())), HasFilter::class.java), arrayOf(HasValueReference(HasValueReference.Unknown), HasValueReference::class.java), - arrayOf(HasValueReference(HasValueReference.Constant(123, "int")), HasValueReference::class.java), - arrayOf(HasValueReference(HasValueReference.Runtime("int")), HasValueReference::class.java), + arrayOf(HasValueReference(HasValueReference.Constant(123, BsonInt32)), HasValueReference::class.java), + arrayOf(HasValueReference(HasValueReference.Runtime(BsonInt32)), HasValueReference::class.java), arrayOf(HasValueReference(HasValueReference.Unknown), HasValueReference::class.java), arrayOf(Named("abc"), Named::class.java), )