Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for new accumulators INTELLIJ-128 #98

Merged
merged 12 commits into from
Nov 25, 2024
Merged
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package alt.mongodb.javadriver;

import com.mongodb.client.MongoClient;
import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
Expand Down Expand Up @@ -55,6 +56,12 @@ public List<Document> queryMoviesByYear(String year) {
Aggregates.match(
Filters.eq("year", year)
),
Aggregates.group(
"newField",
Accumulators.avg("test", "$year"),
Accumulators.sum("test2", "$year"),
Accumulators.bottom("field", Sorts.ascending("year"), "$year")
),
Aggregates.project(
Projections.fields(
Projections.include("year", "plot")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ import com.mongodb.jbplugin.mql.BsonArray
import com.mongodb.jbplugin.mql.BsonBoolean
import com.mongodb.jbplugin.mql.BsonInt32
import com.mongodb.jbplugin.mql.BsonType
import com.mongodb.jbplugin.mql.ComputedBsonType
import com.mongodb.jbplugin.mql.Node
import com.mongodb.jbplugin.mql.components.*
import com.mongodb.jbplugin.mql.components.HasFieldReference.Computed
import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema
import com.mongodb.jbplugin.mql.flattenAnyOfReferences
import com.mongodb.jbplugin.mql.toBsonType

Expand All @@ -23,9 +26,11 @@ private const val FILTERS_FQN = "com.mongodb.client.model.Filters"
private const val UPDATES_FQN = "com.mongodb.client.model.Updates"
private const val AGGREGATES_FQN = "com.mongodb.client.model.Aggregates"
private const val PROJECTIONS_FQN = "com.mongodb.client.model.Projections"
private const val ACCUMULATORS_FQN = "com.mongodb.client.model.Accumulators"
private const val SORTS_FQN = "com.mongodb.client.model.Sorts"
private const val JAVA_LIST_FQN = "java.util.List"
private const val JAVA_ARRAYS_FQN = "java.util.Arrays"

private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf(
"match",
"project",
Expand Down Expand Up @@ -225,7 +230,8 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {

return containingClass.qualifiedName == AGGREGATES_FQN ||
containingClass.qualifiedName == PROJECTIONS_FQN ||
containingClass.qualifiedName == SORTS_FQN
containingClass.qualifiedName == SORTS_FQN ||
containingClass.qualifiedName == ACCUMULATORS_FQN
}

private fun resolveBsonBuilderCall(
Expand Down Expand Up @@ -294,7 +300,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
}
val valueExpression = filter.argumentList.expressions[0]
val valueReference = resolveValueFromExpression(valueExpression)
val fieldReference = HasFieldReference.FromSchema(valueExpression, "_id")
val fieldReference = FromSchema(valueExpression, "_id")

return Node(
filter,
Expand Down Expand Up @@ -481,6 +487,33 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {

return nodeWithParsedComponents(parsedComponents)
}
"group" -> {
// the first parameter of group is going to be a string expression
val groupArgument = stageCall.argumentList.expressions.getOrNull(0)
?: return null

val groupFieldValueExpression = parseComputedExpression(groupArgument)

val nodeWithAccumulators: (List<Node<PsiElement>>) -> Node<PsiElement> = { accFields: List<Node<PsiElement>> ->
Node(
stageCall,
listOf(
Named(Name.GROUP),
HasFieldReference(
HasFieldReference.Inferred(groupArgument, "_id", "_id")
),
groupFieldValueExpression,
HasAccumulatedFields(accFields)
)
)
}

val accumulators = stageCall.getVarArgsOrIterableArgs().drop(1)
.mapNotNull { resolveBsonBuilderCall(it, ACCUMULATORS_FQN) }

val parsedAccumulators = accumulators.mapNotNull { parseAccumulatorExpression(it) }
return nodeWithAccumulators(parsedAccumulators)
}

else -> return null
}
Expand All @@ -505,8 +538,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
val fieldReference = resolveFieldNameFromExpression(it)
val methodName = Name.from(methodCall.name)
when (fieldReference) {
is HasFieldReference.Unknown -> null
is HasFieldReference.FromSchema -> Node(
is FromSchema -> Node(
source = it,
components = listOf(
Named(methodName),
Expand All @@ -526,13 +558,83 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
)
)
)
else -> null
}
}

else -> emptyList()
}
}

private fun parseAccumulatorExpression(expression: PsiMethodCallExpression): Node<PsiElement>? {
val methodCall = expression.fuzzyResolveMethod() ?: return null

return when (methodCall.name) {
"sum" -> parseKeyValAccumulator(expression, Name.SUM)
"avg" -> parseKeyValAccumulator(expression, Name.AVG)
"first" -> parseKeyValAccumulator(expression, Name.FIRST)
"last" -> parseKeyValAccumulator(expression, Name.LAST)
"top" -> parseLeadingAccumulatorExpression(expression, Name.TOP)
"bottom" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM)
"max" -> parseKeyValAccumulator(expression, Name.MAX)
"min" -> parseKeyValAccumulator(expression, Name.MIN)
"push" -> parseKeyValAccumulator(expression, Name.PUSH)
"addToSet" -> parseKeyValAccumulator(expression, Name.ADD_TO_SET)
else -> null
}
}

private fun parseKeyValAccumulator(expression: PsiMethodCallExpression, name: Name): Node<PsiElement>? {
val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null
val valueExpr = expression.argumentList.expressions.getOrNull(1) ?: return null

val fieldName = keyExpr.tryToResolveAsConstantString()
val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false)

return Node(
expression,
listOf(
Named(name),
HasFieldReference(
if (fieldName != null) {
Computed(keyExpr, fieldName, fieldName)
} else {
HasFieldReference.Unknown as HasFieldReference.FieldReference<PsiElement>
}
),
accumulatorExpr
)
)
}

private fun parseLeadingAccumulatorExpression(expression: PsiMethodCallExpression, name: Name): Node<PsiElement>? {
val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null
val sortExprArgument = expression.argumentList.expressions.getOrNull(1) ?: return null
val valueExpr = expression.argumentList.expressions.getOrNull(2) ?: return null

val sortExpr = resolveBsonBuilderCall(sortExprArgument, SORTS_FQN) ?: return null

val fieldName = keyExpr.tryToResolveAsConstantString()
val sort = parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN)
val accumulatorExpr = parseComputedExpression(valueExpr, createsNewField = false)

return Node(
expression,
listOf(
Named(name),
HasFieldReference(
if (fieldName != null) {
Computed(keyExpr, fieldName, fieldName)
} else {
HasFieldReference.Unknown as HasFieldReference.FieldReference<PsiElement>
}
),
HasSorts(sort),
accumulatorExpr
)
)
}

private fun hasMongoDbSessionReference(methodCall: PsiMethodCallExpression): Boolean {
val hasEnoughArgs = methodCall.argumentList.expressionCount > 0
if (!hasEnoughArgs) {
Expand All @@ -543,6 +645,44 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
return typeOfFirstArg != null && typeOfFirstArg.equalsToText(SESSION_FQN)
}

private fun parseComputedExpression(element: PsiElement, createsNewField: Boolean = true): HasValueReference<PsiElement> {
val (constant, value) = element.tryToResolveAsConstant()
return HasValueReference(
when {
constant && value is String -> HasValueReference.Computed(
element,
type = ComputedBsonType(
BsonAny,
himanshusinghs marked this conversation as resolved.
Show resolved Hide resolved
Node(
element,
listOf(
if (createsNewField) {
HasFieldReference(
Computed(element, value.trim('$'), value)
)
} else {
HasFieldReference(
FromSchema(element, value.trim('$'), value)
)
}
)
)
)
)
constant && value != null -> HasValueReference.Constant(
element,
value,
value.javaClass.toBsonType(value)
)
!constant && element is PsiExpression -> HasValueReference.Runtime(
element,
element.type?.toBsonType() ?: BsonAny
)
else -> HasValueReference.Unknown as HasValueReference.ValueReference<PsiElement>
}
)
}

private fun isAggregationStageMethodCall(callMethod: PsiMethod?): Boolean {
return PARSEABLE_AGGREGATION_STAGE_METHODS.contains(callMethod?.name) &&
callMethod?.containingClass?.qualifiedName == AGGREGATES_FQN
Expand All @@ -552,7 +692,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
val fieldNameAsString = expression.tryToResolveAsConstantString()
val fieldReference =
fieldNameAsString?.let {
HasFieldReference.FromSchema(expression, it)
FromSchema(expression, it)
} ?: HasFieldReference.Unknown

return fieldReference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ fun PsiElement.tryToResolveAsConstant(): Pair<Boolean, Any?> {
* @return
*/
fun PsiElement.tryToResolveAsConstantString(): String? =
tryToResolveAsConstant().takeIf { it.first }?.second?.toString()
tryToResolveAsConstant().takeIf { it.first }?.second as? String

/**
* Maps a PsiType to its BSON counterpart.
Expand Down
Loading
Loading