Skip to content

Commit

Permalink
feat: adds support for parsing Aggregate.match
Browse files Browse the repository at this point in the history
  • Loading branch information
himanshusinghs committed Jan 14, 2025
1 parent d3b4fa3 commit 3c61499
Show file tree
Hide file tree
Showing 8 changed files with 1,324 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package alt.mongodb.springcriteria;

import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.mapping.Document;

import java.util.List;
Expand All @@ -20,4 +21,16 @@ public BookRepository(MongoTemplate template) {
public List<Book> allReleasedBooks() {
return template.query(Book.class).matching(where("released").is(true)).all();
}

public List<Book> allReleasedBooksAgg() {
return template.aggregate(
Aggregation.newAggregation(
Aggregation.match(
where("released").is(true)
)
),
Book.class,
Book.class
).getMappedResults();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package alt.mongodb.springcriteria;

import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.mapping.Document;

import java.util.List;
Expand Down Expand Up @@ -28,6 +29,18 @@ private List<Movie> allMoviesWithRatingAtLeast(int rating) {
);
}

private List<Movie> allMoviesWithRatingAtLeastAgg(int rating) {
return template.aggregate(
Aggregation.newAggregation(
Aggregation.match(
where( "tomatoes.viewer.rating").gte(rating)
)
),
Movie.class,
Movie.class
).getMappedResults();
}

private void updateLanguageOfAllMoviesWithRatingAtLeast(int rating, String newLanguage) {
template.updateMulti(query(where("tomatoes.viewer.rating").gte(rating)), update("key", "value"), Movie.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package com.mongodb.jbplugin.dialects.springcriteria
import com.intellij.psi.*
import com.intellij.psi.util.PsiTypesUtil
import com.mongodb.jbplugin.dialects.javadriver.glossary.fuzzyResolveMethod
import com.mongodb.jbplugin.dialects.javadriver.glossary.meaningfulExpression
import com.mongodb.jbplugin.dialects.javadriver.glossary.tryToResolveAsConstantString
import com.mongodb.jbplugin.mql.components.HasCollectionReference

object QueryTargetCollectionExtractor {
Expand All @@ -23,7 +25,7 @@ object QueryTargetCollectionExtractor {
}

if (currentMethod.name == "query") {
return extractCollectionFromParameter(
return extractCollectionFromClassTypeParameter(
currentMethodCall?.argumentList?.expressions?.getOrNull(0)
)
}
Expand All @@ -35,7 +37,7 @@ object QueryTargetCollectionExtractor {
return unknown
}

fun extractCollectionFromParameter(sourceExpression: PsiExpression?): HasCollectionReference<PsiElement> {
fun extractCollectionFromClassTypeParameter(sourceExpression: PsiExpression?): HasCollectionReference<PsiElement> {
if (sourceExpression == null) {
return unknown
}
Expand All @@ -49,6 +51,12 @@ object QueryTargetCollectionExtractor {
} ?: unknown
}

fun extractCollectionFromStringTypeParameter(sourceExpression: PsiExpression?): HasCollectionReference<PsiElement> {
return sourceExpression?.meaningfulExpression()?.tryToResolveAsConstantString()?.let {
HasCollectionReference(HasCollectionReference.OnlyCollection(sourceExpression, it))
} ?: unknown
}

fun HasCollectionReference<PsiElement>.or(other: HasCollectionReference<PsiElement>): HasCollectionReference<PsiElement> {
return if (this == unknown) {
other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import com.intellij.psi.util.findParentOfType
import com.intellij.psi.util.parentOfType
import com.mongodb.jbplugin.dialects.DialectParser
import com.mongodb.jbplugin.dialects.javadriver.glossary.*
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromParameter
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromClassTypeParameter
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromQueryChain
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.extractCollectionFromStringTypeParameter
import com.mongodb.jbplugin.dialects.springcriteria.QueryTargetCollectionExtractor.or
import com.mongodb.jbplugin.mql.BsonAny
import com.mongodb.jbplugin.mql.BsonArray
Expand All @@ -17,6 +18,10 @@ import com.mongodb.jbplugin.mql.toBsonType

private const val CRITERIA_CLASS_FQN = "org.springframework.data.mongodb.core.query.Criteria"
private const val DOCUMENT_FQN = "org.springframework.data.mongodb.core.mapping.Document"
private const val AGGREGATE_FQN = "org.springframework.data.mongodb.core.aggregation.Aggregation"
private val PARSEABLE_AGGREGATION_STAGE_METHODS = listOf(
"match"
)

object SpringCriteriaDialectParser : DialectParser<PsiElement> {
override fun isCandidateForQuery(source: PsiElement) =
Expand Down Expand Up @@ -59,7 +64,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
actualMethod.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -86,7 +91,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -107,11 +112,11 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
).or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(2)
)
),
Expand All @@ -131,7 +136,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -144,7 +149,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -160,7 +165,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -172,7 +177,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -184,7 +189,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -200,7 +205,7 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand All @@ -210,13 +215,49 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
)
)
)
"aggregate" -> {
val expressions = mongoOpCall.argumentList.expressions
val newAggregationCall = expressions.getOrNull(0)?.resolveToMethodCallExpression {
_,
method
->
method.name == "newAggregation"
}
val resolvedStageCallExpression =
newAggregationCall?.getVarArgsOrIterableArgs()?.mapNotNull {
it.resolveToMethodCallExpression { _, method ->
method.containingClass?.qualifiedName == AGGREGATE_FQN &&
PARSEABLE_AGGREGATION_STAGE_METHODS.contains(method.name)
}
} ?: emptyList()
val collectionExpression = expressions.getOrNull(1)
return Node(
mongoOpCall,
listOf(
sourceDialect,
command,
// MongoTemplate.aggregate accepts both Class and a string for specifying the
// collection where the aggregation will run so we need to account for both
// method signatures while extracting the collection
//
// Note: It is uncommon to have the class type parameter referenced as a variable
// which is why we don't attempt to resolve the parameter as ClassType
extractCollectionFromClassTypeParameter(collectionExpression).or(
extractCollectionFromStringTypeParameter(collectionExpression)
),
HasAggregation(
parseAggregationStagesFromCurrentCall(resolvedStageCallExpression)
)
)
)
}
else -> Node(
mongoOpCall!!,
listOf(
sourceDialect,
command,
inferredFromChain.or(
extractCollectionFromParameter(
extractCollectionFromClassTypeParameter(
mongoOpCall.argumentList.expressions.getOrNull(1)
)
),
Expand Down Expand Up @@ -283,6 +324,34 @@ object SpringCriteriaDialectParser : DialectParser<PsiElement> {
)
}

private fun parseAggregationStagesFromCurrentCall(
stageCallExpressions: List<PsiMethodCallExpression>
): List<Node<PsiElement>> {
return stageCallExpressions.mapNotNull { stageCall ->
val stageMethod = stageCall.fuzzyResolveMethod() ?: return@mapNotNull Node(
source = stageCall,
components = emptyList()
)

when (stageMethod.name) {
"match" -> parseMatchStageCall(stageCall)
else -> null
}
}
}

private fun parseMatchStageCall(matchStageCall: PsiMethodCallExpression): Node<PsiElement> {
return Node(
source = matchStageCall,
components = listOf(
Named(Name.MATCH),
HasFilter(
parseFilterRecursively(matchStageCall.argumentList.expressions.getOrNull(0))
)
)
)
}

private fun parseFilterRecursively(
valueFilterExpression: PsiElement?
): List<Node<PsiElement>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory
import com.mongodb.assertions.Assertions.assertNotNull
import com.mongodb.jbplugin.mql.Component
import com.mongodb.jbplugin.mql.Node
import com.mongodb.jbplugin.mql.components.HasAggregation
import com.mongodb.jbplugin.mql.components.HasCollectionReference
import com.mongodb.jbplugin.mql.components.HasFieldReference
import com.mongodb.jbplugin.mql.components.HasFilter
Expand Down Expand Up @@ -298,6 +299,35 @@ fun Node<PsiElement>.assert(
this.assertions()
}

fun Node<PsiElement>.stageN(
n: Int,
name: Name? = null,
assertions: Node<PsiElement>.() -> Unit = {
}
) {
val stages = component<HasAggregation<PsiElement>>()?.children ?: emptyList()
assertNotEquals(0, stages.size) {
"Expected HasAggregation to have at-least $n stages"
}

val stage = stages.getOrNull(n)
assertNotEquals(null, stage) {
"Expected a stage at index $n, found null"
}

if (name != null) {
val stageName = stage!!.component<Named>()
assertNotEquals(null, stageName) {
"Expected a stage with name $name but null found."
}
assertEquals(name, stageName!!.name) {
"Expected a stage with name $name but ${stageName.name} found."
}
}

stage!!.assertions()
}

fun Node<PsiElement>.filterN(
n: Int,
name: Name? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Repository {
val query = psiFile.getQueryAtMethod("Repository", "allReleasedBooks")
val collection =
(
QueryTargetCollectionExtractor.extractCollectionFromParameter(
QueryTargetCollectionExtractor.extractCollectionFromClassTypeParameter(
(query as? PsiMethodCallExpression)?.argumentList?.expressions?.getOrNull(1)
).reference as HasCollectionReference.OnlyCollection
).collection
Expand Down
Loading

0 comments on commit 3c61499

Please sign in to comment.