Skip to content

Commit

Permalink
feat(query-generation): support for generating queries with sort INTE…
Browse files Browse the repository at this point in the history
…LLIJ-186 (#122)
  • Loading branch information
kmruiz authored Jan 21, 2025
1 parent 229aac7 commit 9f68d15
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ MongoDB plugin for IntelliJ IDEA.
## [Unreleased]

### Added
* [INTELLIJ-188](https://jira.mongodb.org/browse/INTELLIJ-188) Support for generating sort in the query generator.
* [INTELLIJ-186](https://jira.mongodb.org/browse/INTELLIJ-186) Support for parsing Sorts in the Java Driver.
* [INTELLIJ-187](https://jira.mongodb.org/browse/INTELLIJ-187) Use safe execution plans by default. Allow full execution
plans through a Plugin settings flag.
* [INTELLIJ-180](https://jira.mongodb.org/browse/INTELLIJ-180) Telemetry when inspections are shown and resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public List<Document> findMoviesByYear(int year) {
.getDatabase("sample_mflix")
.getCollection("movies")
.find(Filters.eq("year", year))
// .sort(Sorts.ascending(IMDB_VOTES))
.sort(Sorts.orderBy(Sorts.ascending(IMDB_VOTES), Sorts.descending("_id")))
.into(new ArrayList<>());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.mongodb.jbplugin.mql.BsonArray
import com.mongodb.jbplugin.mql.BsonBoolean
import com.mongodb.jbplugin.mql.BsonInt32
import com.mongodb.jbplugin.mql.BsonType
import com.mongodb.jbplugin.mql.Component
import com.mongodb.jbplugin.mql.ComputedBsonType
import com.mongodb.jbplugin.mql.Node
import com.mongodb.jbplugin.mql.components.*
Expand Down Expand Up @@ -60,6 +61,9 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
?: return Node(source, listOf(sourceDialect, collectionReference))
val command = methodCallToCommand(currentCall)

val additionalMetadataMethods = collectAllMetadataMethods(currentCall)
val additionalMetadata = processFindQueryAdditionalMetadata(additionalMetadataMethods)

/**
* We might come across a FIND_ONE command and in that case we need to be pointing to the
* right method call, find() and not find().first(), in order to parse the filter arguments
Expand All @@ -80,7 +84,7 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
hasFilters,
hasUpdates,
hasAggregation,
),
) + additionalMetadata,
)
} else {
commandCallMethod?.let {
Expand Down Expand Up @@ -110,9 +114,13 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
sourceDialect,
collectionReference,
command
)
) + additionalMetadata
)
} ?: return Node(source, listOf(sourceDialect, collectionReference))
}
?: return Node(
source,
listOf(sourceDialect, collectionReference) + additionalMetadata
)
}
}

Expand Down Expand Up @@ -895,6 +903,53 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
}
)
}

private fun collectAllMetadataMethods(methodCall: PsiMethodCallExpression?): List<PsiMethodCallExpression> {
val allParentMethodExpressions = methodCall?.collectTypeUntil(
PsiMethodCallExpression::class.java,
PsiReturnStatement::class.java
)?.filter {
// filter out ourselves
!it.isEquivalentTo(methodCall)
} ?: emptyList()

val allChildrenMethodExpressions = methodCall?.findAllChildrenOfType(
PsiMethodCallExpression::class.java
)?.filter {
// filter out ourselves
!it.isEquivalentTo(methodCall)
} ?: emptyList()

return allParentMethodExpressions + allChildrenMethodExpressions
}

private fun processFindQueryAdditionalMetadata(methodCalls: List<PsiMethodCallExpression>): List<Component> {
return methodCalls.flatMap { methodCall ->
val method = methodCall.fuzzyResolveMethod()
if (method != null && isMethodPartOfTheCursorClass(method)) {
val currentMetadata = when (method.name) {
"sort" -> {
val sortArg =
methodCall.argumentList.expressions.getOrNull(0) ?: return emptyList()
val sortExpr =
resolveBsonBuilderCall(sortArg, SORTS_FQN) ?: return emptyList()
listOf(
HasSorts(parseBsonBuilderCallsSimilarToProjections(sortExpr, SORTS_FQN))
)
}
else -> emptyList()
}

return currentMetadata
} else {
emptyList()
}
}
}

private fun isMethodPartOfTheCursorClass(method: PsiMethod?): Boolean =
method?.containingClass?.qualifiedName?.contains("FindIterable") == true ||
method?.containingClass?.qualifiedName?.contains("MongoIterable") == true
}

fun PsiExpression.resolveFieldNameFromExpression(): HasFieldReference.FieldReference<out Any> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2184,4 +2184,37 @@ public final class Repository {
val command = parsedQuery.component<IsCommand>()
assertEquals(IsCommand.CommandType.FIND_ONE, command?.type)
}

@ParsingTest(
fileName = "Repository.java",
value = """
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Sorts;import org.bson.types.ObjectId;
import java.util.ArrayList;
import static com.mongodb.client.model.Filters.*;
public final class Repository {
private final MongoCollection<Document> collection;
public Repository(MongoClient client) {
this.collection = client.getDatabase("simple").getCollection("books");
}
public Document findBookById(ObjectId id) {
return this.collection.find(eq("_id", id)).sort(Sorts.ascending("myField")).first();
}
}
""",
)
fun `correctly parses FindIterable#sort as a SORT component`(psiFile: PsiFile) {
val query = psiFile.getQueryAtMethod("Repository", "findBookById")
val parsedQuery = JavaDriverDialect.parser.parse(query)
val sorting = parsedQuery.component<HasSorts<PsiElement>>()!!
val sortByMyFieldName = sorting.children[0].component<HasFieldReference<PsiElement>>()!!.reference
val sortByMyFieldOrder = sorting.children[0].component<HasValueReference<PsiElement>>()!!.reference

assertEquals("myField", (sortByMyFieldName as HasFieldReference.FromSchema).fieldName)
assertEquals(1, (sortByMyFieldOrder as HasValueReference.Inferred).value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ object MongoshDialectFormatter : DialectFormatter {
emitQueryBody(query, firstCall = true)
}
})

if (returnsACursor(query)) {
emitSort(query)
}
}.computeOutput()

val ref = query.component<HasCollectionReference<S>>()?.reference
Expand Down Expand Up @@ -329,6 +333,12 @@ private fun <S> isAggregate(node: Node<S>): Boolean {
.parse(node).orElse { false }
}

private fun <S> returnsACursor(node: Node<S>): Boolean {
return whenIsCommand<S>(IsCommand.CommandType.FIND_MANY)
.map { true }
.parse(node).orElse { false }
}

private fun <S> canEmitAggregate(node: Node<S>): Boolean {
return aggregationStages<S>()
.matches(count<Node<S>>().filter { it >= 1 }.matches().anyError())
Expand All @@ -343,6 +353,7 @@ private fun <S> MongoshBackend.resolveValueReference(
fieldRef: HasFieldReference<S>?
) = when (val ref = valueRef.reference) {
is HasValueReference.Constant -> registerConstant(ref.value)
is HasValueReference.Inferred -> registerConstant(ref.value)
is HasValueReference.Runtime -> registerVariable(
(fieldRef?.reference as? FromSchema)?.fieldName ?: "value",
ref.type
Expand Down Expand Up @@ -382,3 +393,29 @@ private fun <S> MongoshBackend.emitCollectionReference(collRef: HasCollectionRef

return this
}

private fun <S> MongoshBackend.emitSort(query: Node<S>): MongoshBackend {
val sortComponent = query.component<HasSorts<S>>()
if (sortComponent == null) {
return this
}

fun generateSortKeyVal(node: Node<S>): MongoshBackend {
val fieldRef = node.component<HasFieldReference<S>>() ?: return this
val valueRef = node.component<HasValueReference<S>>() ?: return this

emitObjectKey(resolveFieldReference(fieldRef))
emitContextValue(resolveValueReference(valueRef, fieldRef))
return emitObjectValueEnd()
}

emitPropertyAccess()
emitFunctionName("sort")
return emitFunctionCall(long = false, {
emitObjectStart(long = false)
for (sortCriteria in sortComponent.children) {
generateSortKeyVal(sortCriteria)
}
emitObjectEnd(long = false)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,11 @@ class MongoshBackend(
}

private fun serializePrimitive(value: Any?, isPlaceholder: Boolean): String = when (value) {
is Byte, Short, Int, Long, Float, Double -> Encode.forJavaScript(value.toString())
is BigInteger -> "Decimal128(\"$value\")"
is BigDecimal -> "Decimal128(\"$value\")"
is Byte, is Short, is Int, is Long, is Float, is Double, is Number -> value.toString()
is Boolean -> value.toString()
is ObjectId -> "ObjectId(\"${Encode.forJavaScript(value.toHexString())}\")"
is Number -> Encode.forJavaScript(value.toString())
is String -> '"' + Encode.forJavaScript(value) + '"'
is Date, is Instant, is LocalDate, is LocalDateTime -> if (isPlaceholder) {
"ISODate(\"$MONGODB_FIRST_RELEASE\")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.intellij.lang.annotations.Language
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.junit.jupiter.params.provider.ValueSource

class MongoshDialectFormatterTest {
Expand Down Expand Up @@ -520,6 +521,86 @@ class MongoshDialectFormatterTest {
}
}

@Test
fun `can sort a find query when specified`() {
assertGeneratedQuery(
"""
var collection = ""
var database = ""
db.getSiblingDB(database).getCollection(collection).find().sort({"a": 1, })
""".trimIndent()
) {
Node(
Unit,
listOf(
IsCommand(IsCommand.CommandType.FIND_MANY),
HasSorts(
listOf(
Node(
Unit,
listOf(
HasFieldReference(HasFieldReference.FromSchema(Unit, "a")),
HasValueReference(
HasValueReference.Inferred(Unit, 1, BsonInt32)
)
)
)
)
)
)
)
}
}

companion object {
@JvmStatic
fun queryCommandsThatDoNotReturnSortableCursors(): Array<IsCommand.CommandType> {
return arrayOf(
IsCommand.CommandType.COUNT_DOCUMENTS,
IsCommand.CommandType.DELETE_MANY,
IsCommand.CommandType.DELETE_ONE,
IsCommand.CommandType.DISTINCT,
IsCommand.CommandType.ESTIMATED_DOCUMENT_COUNT,
IsCommand.CommandType.FIND_ONE,
IsCommand.CommandType.FIND_ONE_AND_DELETE,
)
}
}

@ParameterizedTest
@MethodSource("queryCommandsThatDoNotReturnSortableCursors")
fun `can not sort a query that does not return a cursor`(command: IsCommand.CommandType) {
assertGeneratedQuery(
"""
var collection = ""
var database = ""
db.getSiblingDB(database).getCollection(collection).${command.canonical}()
""".trimIndent()
) {
Node(
Unit,
listOf(
IsCommand(command),
HasSorts(
listOf(
Node(
Unit,
listOf(
HasFieldReference(HasFieldReference.FromSchema(Unit, "a")),
HasValueReference(
HasValueReference.Inferred(Unit, 1, BsonInt32)
)
)
)
)
)
)
)
}
}

@Test
fun `generates an index suggestion for a query given its fields`() {
assertGeneratedIndex(
Expand Down

0 comments on commit 9f68d15

Please sign in to comment.