Skip to content

Commit

Permalink
fix: fixes parsing of null _id expression
Browse files Browse the repository at this point in the history
  • Loading branch information
himanshusinghs committed Nov 25, 2024
1 parent 65d9dbe commit d392545
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ public class Repository {
return client.getDatabase("myDatabase")
.getCollection("myCollection")
.aggregate(List.of(
Aggregates.group(null),
Aggregates.group("${'$'}possibleIdField"),
Aggregates.group("${'$'}possibleIdField", Accumulators.sum("totalCount", 1)),
Aggregates.group(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -734,10 +734,10 @@ object JavaDriverDialectParser : DialectParser<PsiElement> {
)
)
)
constant && value != null -> HasValueReference.Constant(
constant -> HasValueReference.Constant(
element,
value,
value.javaClass.toBsonType(value)
value?.javaClass.toBsonType(value)
)
!constant && element is PsiExpression -> HasValueReference.Runtime(
element,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,56 @@ import java.util.List;
import static com.mongodb.client.model.Filters.*;
public final class Aggregation {
private final MongoCollection<Document> collection;
public Aggregation(MongoClient client) {
this.collection = client.getDatabase("simple").getCollection("books");
}
public AggregateIterable<Document> getAllBookTitles(ObjectId id) {
return this.collection.aggregate(List.of(
Aggregates.group(null)
));
}
}
"""
)
fun `should be able to parse a group stage with null _id`(psiFile: PsiFile) {
val aggregate = psiFile.getQueryAtMethod("Aggregation", "getAllBookTitles")
val parsedAggregate = JavaDriverDialect.parser.parse(aggregate)
val hasAggregation = parsedAggregate.component<HasAggregation<PsiElement>>()
assertEquals(1, hasAggregation?.children?.size)

val groupStage = hasAggregation?.children?.get(0)!!

val named = groupStage.component<Named>()!!
assertEquals(Name.GROUP, named.name)

val idFieldRef = groupStage.component<HasFieldReference<PsiElement>>()!!.reference as HasFieldReference.Inferred<PsiElement>
val constantValueRef = groupStage.component<HasValueReference<PsiElement>>()!!.reference as HasValueReference.Constant<PsiElement>
val accumulatedFields = groupStage.component<HasAccumulatedFields<PsiElement>>()!!

assertEquals("_id", idFieldRef.fieldName)
assertEquals(0, accumulatedFields.children.size)
assertEquals(null, constantValueRef.value)
}

@ParsingTest(
fileName = "Aggregation.java",
value = """
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Filters;
import org.bson.Document;
import org.bson.types.ObjectId;
import java.util.List;
import static com.mongodb.client.model.Filters.*;
public final class Aggregation {
private final MongoCollection<Document> collection;
Expand Down

0 comments on commit d392545

Please sign in to comment.