Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Index checking for aggregates that start with $match INTELLIJ-139 #94

Merged
merged 7 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,55 @@ import com.mongodb.jbplugin.mql.*
import com.mongodb.jbplugin.mql.components.*
import com.mongodb.jbplugin.mql.components.HasFieldReference.FromSchema
import com.mongodb.jbplugin.mql.components.HasFieldReference.Unknown
import com.mongodb.jbplugin.mql.parser.anyError
import com.mongodb.jbplugin.mql.parser.components.aggregationStages
import com.mongodb.jbplugin.mql.parser.components.allFiltersRecursively
import com.mongodb.jbplugin.mql.parser.components.hasName
import com.mongodb.jbplugin.mql.parser.components.whenIsCommand
import com.mongodb.jbplugin.mql.parser.count
import com.mongodb.jbplugin.mql.parser.filter
import com.mongodb.jbplugin.mql.parser.map
import com.mongodb.jbplugin.mql.parser.matches
import com.mongodb.jbplugin.mql.parser.nth
import com.mongodb.jbplugin.mql.parser.parse
import io.github.z4kn4fein.semver.Version
import org.owasp.encoder.Encode

object MongoshDialectFormatter : DialectFormatter {
override fun <S> formatQuery(query: Node<S>, explain: Boolean): OutputQuery {
val outputString = MongoshBackend().apply {
override fun <S> formatQuery(
query: Node<S>,
explain: Boolean,
): OutputQuery {
val isAggregate = isAggregate(query)
val canEmitAggregate = canEmitAggregate(query)

val outputString = MongoshBackend(prettyPrint = explain).apply {
if (isAggregate && !canEmitAggregate) {
emitComment("Only aggregates with a single match stage can be converted.")
return@apply
}

emitDbAccess()
emitCollectionReference(query.component<HasCollectionReference<S>>())
emitFunctionName("find")
emitFunctionCall({
emitQueryBody(query, firstCall = true)
})
if (explain) {
emitPropertyAccess()
emitFunctionName("explain")
emitFunctionCall()
emitPropertyAccess()
}
emitFunctionName(query.component<IsCommand>()?.type?.canonical ?: "find")
emitFunctionCall(long = true, {
if (isAggregate(query)) {
emitAggregateBody(query)
} else {
emitQueryBody(query, firstCall = true)
}
})
}.computeOutput()

return when (val ref = query.component<HasCollectionReference<S>>()?.reference) {
is HasCollectionReference.Known -> if (ref.namespace.isValid) {
val ref = query.component<HasCollectionReference<S>>()?.reference
return when {
isAggregate && !canEmitAggregate -> OutputQuery.Incomplete(outputString)
ref is HasCollectionReference.Known -> if (ref.namespace.isValid) {
OutputQuery.CanBeRun(outputString)
} else {
OutputQuery.Incomplete(outputString)
Expand Down Expand Up @@ -88,37 +116,38 @@ private fun <S> MongoshBackend.emitQueryBody(
val fieldRef = node.component<HasFieldReference<S>>()
val valueRef = node.component<HasValueReference<S>>()
val hasFilter = node.component<HasFilter<S>>()
val isLong = allFiltersRecursively<S>().parse(node).orElse { emptyList() }.size > 3

if (hasFilter != null && fieldRef == null && valueRef == null && named == null) {
// 1. has children, nothing else (root node)
if (firstCall) {
emitObjectStart()
emitObjectStart(long = isLong)
}

hasFilter.children.forEach {
emitQueryBody(it)
emitObjectValueEnd()
}
if (firstCall) {
emitObjectEnd()
emitObjectEnd(long = isLong)
}
} else if (hasFilter == null && fieldRef != null && valueRef != null && named == null) {
// 2. no children, only a field: value case
if (firstCall) {
emitObjectStart()
emitObjectStart(long = isLong)
}
emitObjectKey(resolveFieldReference(fieldRef))
emitContextValue(resolveValueReference(valueRef, fieldRef))
if (firstCall) {
emitObjectEnd()
emitObjectEnd(long = isLong)
}
} else {
named?.let {
// 3. children and named
if (named.name == Name.EQ) {
// normal a: b case
if (firstCall) {
emitObjectStart()
emitObjectStart(long = isLong)
}
if (fieldRef != null) {
emitObjectKey(resolveFieldReference(fieldRef))
Expand All @@ -134,7 +163,7 @@ private fun <S> MongoshBackend.emitQueryBody(
}

if (firstCall) {
emitObjectEnd()
emitObjectEnd(long = isLong)
}
} else if (setOf( // 1st basic attempt, to improve in INTELLIJ-76
Name.GT,
Expand All @@ -146,7 +175,7 @@ private fun <S> MongoshBackend.emitQueryBody(
) {
// a: { $gt: 1 }
if (firstCall) {
emitObjectStart()
emitObjectStart(long = isLong)
}

if (fieldRef != null) {
Expand All @@ -159,9 +188,10 @@ private fun <S> MongoshBackend.emitQueryBody(
emitObjectEnd()

if (firstCall) {
emitObjectEnd()
emitObjectEnd(long = isLong)
}
} else if (setOf( // 1st basic attempt, to improve in INTELLIJ-77
} else if (setOf(
// 1st basic attempt, to improve in INTELLIJ-77
Name.AND,
Name.OR,
Name.NOR,
Expand All @@ -171,14 +201,17 @@ private fun <S> MongoshBackend.emitQueryBody(
emitObjectStart()
}
emitObjectKey(registerConstant('$' + named.name.canonical))
emitArrayStart()
emitArrayStart(long = true)
hasFilter?.children?.forEach {
emitObjectStart()
emitQueryBody(it)
emitObjectEnd()
emitObjectValueEnd()
if (prettyPrint) {
emitNewLine()
}
}
emitArrayEnd()
emitArrayEnd(long = true)
if (firstCall) {
emitObjectEnd()
}
Expand Down Expand Up @@ -235,15 +268,15 @@ private fun <S> MongoshBackend.emitQueryBody(
}
} else if (named.name != Name.UNKNOWN && fieldRef != null && valueRef != null) {
if (firstCall) {
emitObjectStart()
emitObjectStart(long = isLong)
}
emitObjectKey(resolveFieldReference(fieldRef))
emitObjectStart()
emitObjectStart(long = isLong)
emitObjectKey(registerConstant('$' + named.name.canonical))
emitContextValue(resolveValueReference(valueRef, fieldRef))
emitObjectEnd()
emitObjectEnd(long = isLong)
if (firstCall) {
emitObjectEnd()
emitObjectEnd(long = isLong)
}
}
}
Expand All @@ -252,6 +285,43 @@ private fun <S> MongoshBackend.emitQueryBody(
return this
}

private fun <S> MongoshBackend.emitAggregateBody(node: Node<S>): MongoshBackend {
// here we can assume that we only have 1 single stage that is a match
val matchStage = node.component<HasAggregation<S>>()!!.children[0]
val filter = matchStage.component<HasFilter<S>>()?.children?.getOrNull(0)
val longFilter = filter?.component<HasFilter<S>>()?.children?.size?.let { it > 3 } == true

emitArrayStart(long = true)
emitObjectStart()
emitObjectKey(registerConstant('$' + "match"))
if (filter != null) {
emitObjectStart(long = longFilter)
emitQueryBody(filter)
emitObjectEnd(long = longFilter)
} else {
emitComment("No filter provided.")
}
emitObjectEnd()
emitArrayEnd(long = true)

return this
}

private fun <S> isAggregate(node: Node<S>): Boolean {
return whenIsCommand<S>(IsCommand.CommandType.AGGREGATE)
.map { true }
.parse(node).orElse { false }
}

private fun <S> canEmitAggregate(node: Node<S>): Boolean {
return aggregationStages<S>()
.matches(count<Node<S>>().filter { it >= 1 }.matches().anyError())
.nth(0)
.matches(hasName(Name.MATCH))
.map { true }
.parse(node).orElse { false }
}

private fun <S> MongoshBackend.resolveValueReference(
valueRef: HasValueReference<S>,
fieldRef: HasFieldReference<S>?
Expand Down
Loading
Loading