chore: the index analyzer now supports match stages INTELLIJ-170 (#107)
kmruiz authored Nov 29, 2024
1 parent 08f07cc commit 5120fa0
Showing 5 changed files with 262 additions and 33 deletions.
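
The gist of the change: IndexAnalyzer now also inspects aggregation pipelines and, when the first stage is a $match, suggests an index over the schema fields referenced in that stage; queries without an aggregation keep using their plain filters. A minimal sketch of the intended behaviour, condensed from the new tests at the bottom of this diff (Node, the components, and IndexAnalyzer.analyze are the plugin's own types; the namespace and field names are simply the ones the tests use):

val query = Node(
    Unit,
    listOf(
        HasCollectionReference(Known(Unit, Unit, Namespace("myDb", "myColl"))),
        HasAggregation(
            listOf(
                Node(
                    Unit,
                    listOf(
                        Named(Name.MATCH), // a leading $match stage...
                        HasFilter(
                            listOf(
                                Node(
                                    Unit,
                                    listOf(HasFieldReference(HasFieldReference.FromSchema(Unit, "myField")))
                                )
                            )
                        )
                    )
                )
            )
        )
    )
)

// ...now contributes its fields to the suggested index:
val index = IndexAnalyzer.analyze(query) as IndexAnalyzer.SuggestedIndex.MongoDbIndex
// index.fields == listOf(IndexAnalyzer.SuggestedIndex.MongoDbIndexField("myField", Unit))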
@@ -43,8 +43,14 @@ object MongoshDialectFormatter : DialectFormatter {
emitFunctionName("explain")
emitFunctionCall()
emitPropertyAccess()
if (isAggregate) {
emitFunctionName("aggregate")
} else {
emitFunctionName("find")
}
} else {
emitFunctionName(query.component<IsCommand>()?.type?.canonical ?: "find")
}
emitFunctionName(query.component<IsCommand>()?.type?.canonical ?: "find")
emitFunctionCall(long = true, {
if (isAggregate(query)) {
emitAggregateBody(query)
@@ -12,9 +12,23 @@
package com.mongodb.jbplugin.mql

import com.mongodb.jbplugin.mql.components.HasCollectionReference
import com.mongodb.jbplugin.mql.components.HasFieldReference
import com.mongodb.jbplugin.mql.components.HasFilter
import com.mongodb.jbplugin.mql.components.HasValueReference
import com.mongodb.jbplugin.mql.components.Name
import com.mongodb.jbplugin.mql.parser.anyError
import com.mongodb.jbplugin.mql.parser.components.NoFieldReference
import com.mongodb.jbplugin.mql.parser.components.aggregationStages
import com.mongodb.jbplugin.mql.parser.components.allNodesWithSchemaFieldReferences
import com.mongodb.jbplugin.mql.parser.components.hasName
import com.mongodb.jbplugin.mql.parser.components.schemaFieldReference
import com.mongodb.jbplugin.mql.parser.filter
import com.mongodb.jbplugin.mql.parser.first
import com.mongodb.jbplugin.mql.parser.flatMap
import com.mongodb.jbplugin.mql.parser.map
import com.mongodb.jbplugin.mql.parser.mapError
import com.mongodb.jbplugin.mql.parser.mapMany
import com.mongodb.jbplugin.mql.parser.matches
import com.mongodb.jbplugin.mql.parser.nth
import com.mongodb.jbplugin.mql.parser.parse
import com.mongodb.jbplugin.mql.parser.requireNonNull

/**
* The IndexAnalyzer service itself. It's stateless and can be used directly.
@@ -39,41 +53,38 @@ object IndexAnalyzer {
}

private fun <S> Node<S>.allFieldReferences(): List<Pair<String, S>> {
val hasFilter = component<HasFilter<S>>()
val otherRefs = hasFilter?.children?.flatMap { it.allFieldReferences() } ?: emptyList()
val fieldRef = component<HasFieldReference<S>>()?.reference ?: return otherRefs
val valueRef = component<HasValueReference<S>>()?.reference
return if (fieldRef is HasFieldReference.FromSchema) {
otherRefs + (
valueRef?.let { reference ->
when (reference) {
is HasValueReference.Constant<S> -> Pair(
fieldRef.fieldName,
fieldRef.source
)
val extractFieldReference = schemaFieldReference<S>()
.map { it.fieldName to it.source }
.mapError { NoFieldReference }

is HasValueReference.Runtime<S> -> Pair(
fieldRef.fieldName,
fieldRef.source
)
val extractAllFieldReferencesWithValues = allNodesWithSchemaFieldReferences<S>()
.mapMany(extractFieldReference)

else -> null
}
} ?: Pair(
fieldRef.fieldName,
fieldRef.source,
)
)
} else {
otherRefs
}
val extractFromFirstMatchStage = aggregationStages<S>()
.nth(0)
.matches(hasName(Name.MATCH))
.flatMap(extractAllFieldReferencesWithValues)
.anyError()

val extractFromFiltersWhenNoAggregation = requireNonNull<Node<S>, Any>(Unit)
.matches(aggregationStages<S>().filter { it.isEmpty() }.matches())
.flatMap(extractAllFieldReferencesWithValues)
.anyError()

val findIndexableFieldReferences = first(
extractFromFirstMatchStage,
extractFromFiltersWhenNoAggregation,
)

return findIndexableFieldReferences
.parse(this)
.orElse { emptyList() }
}

/**
* @param S
*/
sealed interface SuggestedIndex<S> {
@Suppress("UNCHECKED_CAST")
data object NoIndex : SuggestedIndex<Any> {
fun <S> cast(): SuggestedIndex<S> = this as SuggestedIndex<S>
}
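In short, the rewritten allFieldReferences above builds a small parser pipeline instead of walking components by hand: it first tries to read the fields of the first aggregation stage when that stage is a $match, and otherwise falls back to the plain query filters. The same code, re-stated with my own comments on what each combinator contributes:

// Succeeds only for aggregations whose first stage is a $match; later $match
// stages are deliberately ignored (see the new tests further down).
val extractFromFirstMatchStage = aggregationStages<S>()
    .nth(0)                                       // take the first stage, if any
    .matches(hasName(Name.MATCH))                 // require it to be the $match stage
    .flatMap(extractAllFieldReferencesWithValues) // collect its schema field references
    .anyError()                                   // collapse the nested error types introduced by flatMap

// Fallback for find-style queries: only applies when there are no aggregation stages at all.
val extractFromFiltersWhenNoAggregation = requireNonNull<Node<S>, Any>(Unit)
    .matches(aggregationStages<S>().filter { it.isEmpty() }.matches())
    .flatMap(extractAllFieldReferencesWithValues)
    .anyError()

// Try the match-stage path first, then the plain-filter path; if neither applies,
// the query simply yields no indexable fields.
first(extractFromFirstMatchStage, extractFromFiltersWhenNoAggregation)
    .parse(this)
    .orElse { emptyList() }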
@@ -90,6 +90,24 @@ fun <I, E, O, OO> Parser<I, E, O>.map(mapFn: (O) -> OO): Parser<I, E, OO> {
}
}

/**
 * Returns a new parser that feeds the output of this parser into a mapping function which can
 * itself fail, so the result is either an error or a success value.
 */
fun <I, E, O, EE, OO> Parser<I, E, O>.flatMap(
mapFn: suspend (O) -> Either<EE, OO>
): Parser<I, Either<E, EE>, OO> {
return { input ->
when (val result = this(input)) {
is Either.Left -> Either.left(Either.left(result.value))
is Either.Right -> when (val mappingResult = mapFn(result.value)) {
is Either.Left -> Either.left(Either.right(mappingResult.value))
is Either.Right -> Either.right(mappingResult.value)
}
}
}
}

/**
* Returns a new parser that maps the output to a new type.
*/
@@ -283,6 +301,16 @@ fun <I, E, O> Parser<I, E, O>.parse(input: I): Either<E, O> {
}
}

fun <I, E> requireNonNull(err: E): Parser<I?, E, I> {
return { input ->
if (input == null) {
Either.left(err)
} else {
Either.right(input)
}
}
}
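
requireNonNull is a small building block that lifts a nullable input into the parser pipeline, failing with the supplied error when the input is null; the IndexAnalyzer change above seeds its no-aggregation fallback with requireNonNull<Node<S>, Any>(Unit). A minimal illustration (the error message is made up):

val nonNull = requireNonNull<String, String>("input was null")

nonNull.parse("hello") // Either.right("hello")
nonNull.parse(null)    // Either.left("input was null")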

@Deprecated("Only use in development.")
fun <I, E, O> Parser<I, E, O>.debug(message: String): Parser<I, E, O> {
return { input ->
@@ -1,6 +1,5 @@
package com.mongodb.jbplugin.mql.parser.components

import com.mongodb.jbplugin.mql.HasChildren
import com.mongodb.jbplugin.mql.Node
import com.mongodb.jbplugin.mql.adt.Either
import com.mongodb.jbplugin.mql.components.HasFieldReference
@@ -30,7 +29,7 @@ fun <S> allNodesWithSchemaFieldReferences(): Parser<Node<S>, NoFieldReferences,
val currentNode = if (isSchemaFieldReference) listOf(node) else emptyList()

val childNodes = node.componentsWithChildren()
.flatMap { (it as HasChildren<S>).children }
.flatMap { it.children }
.flatMap(::gatherSchemaFieldReferenceNodes)

return currentNode + childNodes
@@ -1,9 +1,12 @@
package com.mongodb.jbplugin.mql

import com.mongodb.jbplugin.mql.components.HasAggregation
import com.mongodb.jbplugin.mql.components.HasCollectionReference
import com.mongodb.jbplugin.mql.components.HasCollectionReference.Known
import com.mongodb.jbplugin.mql.components.HasFieldReference
import com.mongodb.jbplugin.mql.components.HasFilter
import com.mongodb.jbplugin.mql.components.Name
import com.mongodb.jbplugin.mql.components.Named
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test

@@ -95,4 +98,186 @@ class IndexAnalyzerTest {
result.fields[1]
)
}

@Test
fun `considers aggregation pipelines match stages`() {
val collectionReference =
HasCollectionReference(Known(Unit, Unit, Namespace("myDb", "myColl")))
val query = Node(
Unit,
listOf(
collectionReference,
HasAggregation(
listOf(
Node(
Unit,
listOf(
Named(Name.MATCH),
HasFilter(
listOf(
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(
Unit,
"mySecondField"
)
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
)
)
)
)
)
)
)
)
)

val result = IndexAnalyzer.analyze(query) as IndexAnalyzer.SuggestedIndex.MongoDbIndex

assertEquals(2, result.fields.size)
assertEquals(collectionReference, result.collectionReference)
assertEquals(
IndexAnalyzer.SuggestedIndex.MongoDbIndexField("myField", Unit),
result.fields[0]
)
assertEquals(
IndexAnalyzer.SuggestedIndex.MongoDbIndexField("mySecondField", Unit),
result.fields[1]
)
}

@Test
fun `does not consider aggregation pipelines match stages in the second position`() {
val collectionReference =
HasCollectionReference(Known(Unit, Unit, Namespace("myDb", "myColl")))
val query = Node(
Unit,
listOf(
collectionReference,
HasAggregation(
listOf(
Node(Unit, listOf()),
Node(
Unit,
listOf(
Named(Name.MATCH),
HasFilter(
listOf(
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(
Unit,
"mySecondField"
)
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
)
)
)
)
)
)
)
)
)

val result = IndexAnalyzer.analyze(query) as IndexAnalyzer.SuggestedIndex.MongoDbIndex

assertEquals(0, result.fields.size)
}

@Test
fun `does not consider aggregation pipelines stages that are not match`() {
val collectionReference =
HasCollectionReference(Known(Unit, Unit, Namespace("myDb", "myColl")))
val query = Node(
Unit,
listOf(
collectionReference,
HasAggregation(
listOf(
Node(
Unit,
listOf(
Named(Name.GROUP),
HasFilter(
listOf(
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(
Unit,
"mySecondField"
)
)
)
),
Node(
Unit,
listOf(
HasFieldReference(
HasFieldReference.FromSchema(Unit, "myField")
)
)
)
)
)
)
)
)
)
)
)

val result = IndexAnalyzer.analyze(query) as IndexAnalyzer.SuggestedIndex.MongoDbIndex

assertEquals(0, result.fields.size)
assertEquals(collectionReference, result.collectionReference)
}
}
