From 6c65c9f5edb7985f642cf04130c86c12d1354f0f Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 22 Nov 2024 14:01:44 +0100 Subject: [PATCH] chore: now missing tests and small fixes --- .../javadriver/JavaDriverRepository.java | 7 +- .../glossary/JavaDriverDialectParser.kt | 81 +++++++++++++++---- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java index e32c475d..8bb3f203 100644 --- a/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java +++ b/packages/jetbrains-plugin/src/test/resources/project-fixtures/basic-java-project-with-mongodb/src/main/java/alt/mongodb/javadriver/JavaDriverRepository.java @@ -54,10 +54,13 @@ public List queryMoviesByYear(String year) { .aggregate( List.of( Aggregates.match( - Filters.eq("_id", year) + Filters.eq("year", year) ), Aggregates.group( - "newField", Accumulators.avg("test", "$year") + "newField", + Accumulators.avg("test", "$xxx"), + Accumulators.sum("test2", "$year"), + Accumulators.bottom("field", Sorts.ascending("year"), "$year") ), Aggregates.project( Projections.fields( diff --git a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt index e0a18fcb..b3d13707 100644 --- a/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt +++ b/packages/mongodb-dialects/java-driver/src/main/kotlin/com/mongodb/jbplugin/dialects/javadriver/glossary/JavaDriverDialectParser.kt @@ -29,6 +29,7 @@ private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators" private const val SORTS_FQN = "com.mongodb.client.model.Sorts" private const val JAVA_LIST_FQN = "java.util.List" private const val JAVA_ARRAYS_FQN = "java.util.Arrays" + private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf( "match", "project", @@ -228,7 +229,8 @@ object JavaDriverDialectParser : DialectParser { return containingClass.qualifiedName == AGGREGATES_FQN || containingClass.qualifiedName == PROJECTIONS_FQN || - containingClass.qualifiedName == SORTS_FQN + containingClass.qualifiedName == SORTS_FQN || + containingClass.qualifiedName == ACCUMULATORS_FQN } private fun resolveBsonBuilderCall( @@ -297,7 +299,7 @@ object JavaDriverDialectParser : DialectParser { } val valueExpression = filter.argumentList.expressions[0] val valueReference = resolveValueFromExpression(valueExpression) - val fieldReference = HasFieldReference.FromSchema(valueExpression, "_id") + val fieldReference = FromSchema(valueExpression, "_id") return Node( filter, @@ -505,7 +507,7 @@ object JavaDriverDialectParser : DialectParser { } val accumulators = stageCall.getVarArgsOrIterableArgs().drop(1) - .mapNotNull { resolveToAccumulatorCall(it) } + .mapNotNull { resolveBsonBuilderCall(it, ACCUMULATORS_FQN) } val parsedAccumulators = accumulators.mapNotNull { parseAccumulatorExpression(it) } return nodeWithAccumulators(parsedAccumulators) @@ -515,13 +517,10 @@ object JavaDriverDialectParser : DialectParser { } } - private fun resolveToProjectionCall(element: PsiElement): PsiMethodCallExpression? { - return element.resolveToMethodCallExpression { _, methodCall -> - methodCall.containingClass?.qualifiedName == PROJECTIONS_FQN - } - } - - private fun parseProjectionExpression(expression: PsiMethodCallExpression): List> { + private fun parseBsonBuilderCallsSimilarToProjections( + expression: PsiMethodCallExpression, + classQualifiedName: String + ): List> { val methodCall = expression.resolveMethod() ?: return emptyList() return when (methodCall.name) { "fields", @@ -556,7 +555,8 @@ object JavaDriverDialectParser : DialectParser { ) ) ) - ) else -> null + ) + else -> null } } @@ -572,8 +572,8 @@ object JavaDriverDialectParser : DialectParser { "avg" -> parseKeyValAccumulator(expression, Name.AVG) "first" -> parseKeyValAccumulator(expression, Name.FIRST) "last" -> parseKeyValAccumulator(expression, Name.LAST) - "top" -> null - "bottom" -> null + "top" -> parseLeadingAccumulatorExpression(expression, Name.TOP) + "bottom" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM) "max" -> parseKeyValAccumulator(expression, Name.MAX) "min" -> parseKeyValAccumulator(expression, Name.MIN) "push" -> parseKeyValAccumulator(expression, Name.PUSH) @@ -605,6 +605,34 @@ object JavaDriverDialectParser : DialectParser { ) } + private fun parseLeadingAccumulatorExpression(expression: PsiMethodCallExpression, name: Name): Node? { + val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null + val sortExprArgument = expression.argumentList.expressions.getOrNull(1) ?: return null + val valueExpr = expression.argumentList.expressions.getOrNull(2) ?: return null + + val sortExpr = resolveBsonBuilderCall(sortExprArgument, SORTS_FQN) ?: return null + + val fieldName = keyExpr.tryToResolveAsConstantString() + val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN) + val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false) + + return Node( + expression, + listOf( + Named(name), + HasFieldReference( + if (fieldName != null) { + Computed(keyExpr, fieldName, fieldName) + } else { + HasFieldReference.Unknown as HasFieldReference.FieldReference + } + ), + HasSorts(sort), + accumulatorExpr + ) + ) + } + private fun hasMongoDbSessionReference(methodCall: PsiMethodCallExpression): Boolean { val hasEnoughArgs = methodCall.argumentList.expressionCount > 0 if (!hasEnoughArgs) { @@ -615,6 +643,31 @@ object JavaDriverDialectParser : DialectParser { return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN) } + private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference { + return HasValueReference( + when (val expression = element.tryToResolveAsConstantString()) { + null -> HasValueReference.Unknown as HasValueReference.ValueReference + else -> HasValueReference.Computed( + element, + Node( + element, + listOf( + if (createsNewField) { + HasFieldReference( + Computed(element, expression.trim('$'), expression) + ) + } else { + HasFieldReference( + FromSchema(element, expression.trim('$'), expression) + ) + } + ) + ) + ) + } + ) + } + private fun isAggregationStageMethodCall(callMethod: PsiMethod?): Boolean { return PARSEABLE_AGGREGATION_STAGE_METHODS.contains(callMethod?.name) && callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN @@ -624,7 +677,7 @@ object JavaDriverDialectParser : DialectParser { val fieldNameAsString = expression.tryToResolveAsConstantString() val fieldReference = fieldNameAsString?.let { - HasFieldReference.FromSchema(expression, it) + FromSchema(expression, it) } ?: HasFieldReference.Unknown return fieldReference