Skip to content

Commit

Permalink
feat: add support for parsing topN/bottomN INTELLIJ-153 (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
himanshusinghs authored Nov 28, 2024
1 parent 9670ea2 commit 08f07cc
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/quality-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
CHANGELOG.md
- uses: mheap/github-action-required-labels@5847eef68201219cf0a4643ea7be61e77837bbce # 5.4.1
if: ${{ steps.verify-changelog-files.outputs.files_changed == 'false' }}
if: steps.verify-changed-files.outputs.files_changed == 'false'
with:
mode: minimum
count: 1
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ MongoDB plugin for IntelliJ IDEA.
## [Unreleased]

### Added
* [INTELLIJ-153](https://jira.mongodb.org/browse/INTELLIJ-153) Add support for parsing, linting and
autocompleting fields in Accumulators.topN and Accumulators.bottomN
* [INTELLIJ-104](https://jira.mongodb.org/browse/INTELLIJ-104) Add support for Spring Criteria
in/nin operator, like in `where(field).in(1, 2, 3)`
* [INTELLIJ-61](https://jira.mongodb.org/browse/INTELLIJ-61) Add support for Spring Criteria
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,13 @@ public class Repository {
<warning descr="Field \"nonExistentField\" does not exist in collection \"myDatabase.myCollection\"">getBadFieldName()</warning>
),
avgCountAcc,
getAvgCountAcc()
getAvgCountAcc(),
Accumulators.topN(
"totalCount",
Sorts.ascending("otherField"),
<warning descr="Field \"nonExistentField\" does not exist in collection \"myDatabase.myCollection\"">getBadFieldName()</warning>,
3
)
)
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,9 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
"first" -> parseKeyValAccumulator(expression, Name.FIRST)
"last" -> parseKeyValAccumulator(expression, Name.LAST)
"top" -> parseLeadingAccumulatorExpression(expression, Name.TOP)
"topN" -> parseLeadingAccumulatorExpression(expression, Name.TOP_N)
"bottom" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM)
"bottomN" -> parseLeadingAccumulatorExpression(expression, Name.BOTTOM_N)
"max" -> parseKeyValAccumulator(expression, Name.MAX)
"min" -> parseKeyValAccumulator(expression, Name.MIN)
"push" -> parseKeyValAccumulator(expression, Name.PUSH)
Expand Down Expand Up @@ -669,6 +671,15 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
val keyExpr = expression.argumentList.expressions.getOrNull(0) ?: return null
val sortExprArgument = expression.argumentList.expressions.getOrNull(1) ?: return null
val valueExpr = expression.argumentList.expressions.getOrNull(2) ?: return null
val hasLimit = expression.argumentList.expressions.getOrNull(3)?.let {
val (wasResolved, value) = it.tryToResolveAsConstant()
val valueAsInt = value as? Int
if (wasResolved && valueAsInt != null) {
listOf(HasLimit(valueAsInt))
} else {
emptyList()
}
} ?: emptyList()

val sortExpr = resolveBsonBuilderCall(sortExprArgument, SORTS_FQN) ?: return null

Expand All @@ -689,7 +700,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
),
HasSorts(sort),
accumulatorExpr
)
) + hasLimit
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import com.mongodb.jbplugin.dialects.javadriver.glossary.JavaDriverDialect
import com.mongodb.jbplugin.mql.components.HasAccumulatedFields
import com.mongodb.jbplugin.mql.components.HasAggregation
import com.mongodb.jbplugin.mql.components.HasFieldReference
import com.mongodb.jbplugin.mql.components.HasLimit
import com.mongodb.jbplugin.mql.components.HasSorts
import com.mongodb.jbplugin.mql.components.HasValueReference
import com.mongodb.jbplugin.mql.components.Name
Expand Down Expand Up @@ -646,4 +647,267 @@ public final class Aggregation {
assertEquals("mySort", sortingField.fieldName)
}
}

@WithFile(
fileName = "Repository.java",
value = """
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Sorts;import org.bson.Document;
import org.bson.types.ObjectId;
import java.util.List;
import static com.mongodb.client.model.Filters.*;
public final class Aggregation {
private final MongoCollection<Document> collection;
public Aggregation(MongoClient client) {
this.collection = client.getDatabase("simple").getCollection("books");
}
public AggregateIterable<Document> getAllBookTitles(ObjectId id) {
return this.collection.aggregate(List.of(
Aggregates.group(
"${'$'}myField",
Accumulators."|"("myKey", Sorts.ascending("mySort"), "myVal", 3)
)
));
}
}
""",
)
@ParameterizedTest
@CsvSource(
value = [
"method;;expected",
"topN;;TOP_N",
"bottomN;;BOTTOM_N",
],
delimiterString = ";;",
useHeadersInDisplayName = true
)
fun `supports all relevant key-value accumulators with sorting criteria and n value`(
method: String,
expected: Name,
psiFile: PsiFile
) {
WriteCommandAction.runWriteCommandAction(psiFile.project) {
val elementAtCaret = psiFile.caret()
val javaFacade = JavaPsiFacade.getInstance(psiFile.project)
val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null)
elementAtCaret.replace(methodToTest)
}

ApplicationManager.getApplication().runReadAction {
val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles")
val parsedAggregate = JavaDriverDialect.parser.parse(aggregate)
val hasAggregation = parsedAggregate.component<HasAggregation<PsiElement>>()
assertEquals(1, hasAggregation?.children?.size)

val groupStage = hasAggregation?.children?.get(0)!!
val named = groupStage.component<Named>()!!
assertEquals(Name.GROUP, named.name)

val accumulator = groupStage.component<HasAccumulatedFields<PsiElement>>()!!.children[0]
val accumulatorName = accumulator.component<Named>()!!
assertEquals(expected, accumulatorName.name)

val accumulatorField = accumulator.component<HasFieldReference<PsiElement>>()?.reference as HasFieldReference.Computed<PsiElement>
assertEquals("myKey", accumulatorField.fieldName)

val accumulatorComputed = accumulator.component<HasValueReference<PsiElement>>()?.reference as HasValueReference.Computed<PsiElement>
val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("myVal", accumulatorComputedFieldValue.fieldName)

val accumulatorSorting = accumulator.component<HasSorts<PsiElement>>()!!.children[0]
val sortingField = accumulatorSorting.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("mySort", sortingField.fieldName)

val accumulatorLimit = accumulator.component<HasLimit>()!!
assertEquals(3, accumulatorLimit.limit)
}
}

@WithFile(
fileName = "Repository.java",
value = """
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Sorts;import org.bson.Document;
import org.bson.types.ObjectId;
import java.util.List;
import static com.mongodb.client.model.Filters.*;
public final class Aggregation {
private final MongoCollection<Document> collection;
public Aggregation(MongoClient client) {
this.collection = client.getDatabase("simple").getCollection("books");
}
public AggregateIterable<Document> getAllBookTitles(ObjectId id) {
Number nValue = 3;
return this.collection.aggregate(List.of(
Aggregates.group(
"${'$'}myField",
Accumulators."|"("myKey", Sorts.ascending("mySort"), "myVal", nValue)
)
));
}
}
""",
)
@ParameterizedTest
@CsvSource(
value = [
"method;;expected",
"topN;;TOP_N",
"bottomN;;BOTTOM_N",
],
delimiterString = ";;",
useHeadersInDisplayName = true
)
fun `supports all relevant key-value accumulators with sorting criteria and n value as variable`(
method: String,
expected: Name,
psiFile: PsiFile
) {
WriteCommandAction.runWriteCommandAction(psiFile.project) {
val elementAtCaret = psiFile.caret()
val javaFacade = JavaPsiFacade.getInstance(psiFile.project)
val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null)
elementAtCaret.replace(methodToTest)
}

ApplicationManager.getApplication().runReadAction {
val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles")
val parsedAggregate = JavaDriverDialect.parser.parse(aggregate)
val hasAggregation = parsedAggregate.component<HasAggregation<PsiElement>>()
assertEquals(1, hasAggregation?.children?.size)

val groupStage = hasAggregation?.children?.get(0)!!
val named = groupStage.component<Named>()!!
assertEquals(Name.GROUP, named.name)

val accumulator = groupStage.component<HasAccumulatedFields<PsiElement>>()!!.children[0]
val accumulatorName = accumulator.component<Named>()!!
assertEquals(expected, accumulatorName.name)

val accumulatorField = accumulator.component<HasFieldReference<PsiElement>>()?.reference as HasFieldReference.Computed<PsiElement>
assertEquals("myKey", accumulatorField.fieldName)

val accumulatorComputed = accumulator.component<HasValueReference<PsiElement>>()?.reference as HasValueReference.Computed<PsiElement>
val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("myVal", accumulatorComputedFieldValue.fieldName)

val accumulatorSorting = accumulator.component<HasSorts<PsiElement>>()!!.children[0]
val sortingField = accumulatorSorting.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("mySort", sortingField.fieldName)

val accumulatorLimit = accumulator.component<HasLimit>()!!
assertEquals(3, accumulatorLimit.limit)
}
}

@WithFile(
fileName = "Repository.java",
value = """
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Sorts;import org.bson.Document;
import org.bson.types.ObjectId;
import java.util.List;
import static com.mongodb.client.model.Filters.*;
public final class Aggregation {
private final MongoCollection<Document> collection;
public Aggregation(MongoClient client) {
this.collection = client.getDatabase("simple").getCollection("books");
}
private Number getNValue() {
return 3;
}
public AggregateIterable<Document> getAllBookTitles(ObjectId id) {
return this.collection.aggregate(List.of(
Aggregates.group(
"${'$'}myField",
Accumulators."|"("myKey", Sorts.ascending("mySort"), "myVal", getNValue())
)
));
}
}
""",
)
@ParameterizedTest
@CsvSource(
value = [
"method;;expected",
"topN;;TOP_N",
"bottomN;;BOTTOM_N",
],
delimiterString = ";;",
useHeadersInDisplayName = true
)
fun `supports all relevant key-value accumulators with sorting criteria and n value from method call`(
method: String,
expected: Name,
psiFile: PsiFile
) {
WriteCommandAction.runWriteCommandAction(psiFile.project) {
val elementAtCaret = psiFile.caret()
val javaFacade = JavaPsiFacade.getInstance(psiFile.project)
val methodToTest = javaFacade.parserFacade.createReferenceFromText(method, null)
elementAtCaret.replace(methodToTest)
}

ApplicationManager.getApplication().runReadAction {
val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles")
val parsedAggregate = JavaDriverDialect.parser.parse(aggregate)
val hasAggregation = parsedAggregate.component<HasAggregation<PsiElement>>()
assertEquals(1, hasAggregation?.children?.size)

val groupStage = hasAggregation?.children?.get(0)!!
val named = groupStage.component<Named>()!!
assertEquals(Name.GROUP, named.name)

val accumulator = groupStage.component<HasAccumulatedFields<PsiElement>>()!!.children[0]
val accumulatorName = accumulator.component<Named>()!!
assertEquals(expected, accumulatorName.name)

val accumulatorField = accumulator.component<HasFieldReference<PsiElement>>()?.reference as HasFieldReference.Computed<PsiElement>
assertEquals("myKey", accumulatorField.fieldName)

val accumulatorComputed = accumulator.component<HasValueReference<PsiElement>>()?.reference as HasValueReference.Computed<PsiElement>
val accumulatorComputedFieldValue = accumulatorComputed.type.expression.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("myVal", accumulatorComputedFieldValue.fieldName)

val accumulatorSorting = accumulator.component<HasSorts<PsiElement>>()!!.children[0]
val sortingField = accumulatorSorting.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.FromSchema<PsiElement>
assertEquals("mySort", sortingField.fieldName)

val accumulatorLimit = accumulator.component<HasLimit>()!!
assertEquals(3, accumulatorLimit.limit)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.mongodb.jbplugin.mql.components

import com.mongodb.jbplugin.mql.Component

data class HasLimit(val limit: Int) : Component
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ enum class Name(val canonical: String) {
FIRST("first"),
LAST("last"),
TOP("top"),
TOP_N("topN"),
BOTTOM("bottom"),
BOTTOM_N("bottomN"),
MAX("max"),
MIN("min"),
PUSH("push"),
Expand Down

0 comments on commit 08f07cc

Please sign in to comment.