diff --git a/src/LinqTests/Bugs/Bug_2850_missed_field_reducing.cs b/src/LinqTests/Bugs/Bug_2850_missed_field_reducing.cs new file mode 100644 index 0000000000..1c9c2badd1 --- /dev/null +++ b/src/LinqTests/Bugs/Bug_2850_missed_field_reducing.cs @@ -0,0 +1,31 @@ +using System.Linq; +using System.Threading.Tasks; +using Marten; +using Marten.Testing.Documents; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Bugs; + +public class Bug_2850_missed_field_reducing : BugIntegrationContext +{ + public async Task RunQuery(bool include, int resultCount) + { + var results = await theSession.Query().Where(x => include || !x.Flag).CountAsync(); + results.ShouldBe(resultCount); + } + + [Fact] + public async Task pass_bool_into_query() + { + await theStore.Advanced.Clean.DeleteDocumentsByTypeAsync(typeof(Target)); + + var targets = Target.GenerateRandomData(100).ToArray(); + await theStore.BulkInsertAsync(targets); + + var count = targets.Count(x => x.Flag); + + await RunQuery(true, 100); + await RunQuery(false, 100 - count); + } +} diff --git a/src/Marten/Linq/Parsing/WhereClauseParser.cs b/src/Marten/Linq/Parsing/WhereClauseParser.cs index afcda8c703..d0888e4232 100644 --- a/src/Marten/Linq/Parsing/WhereClauseParser.cs +++ b/src/Marten/Linq/Parsing/WhereClauseParser.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq.Expressions; using Marten.Exceptions; using Marten.Linq.Members; @@ -49,6 +50,13 @@ protected override Expression VisitMember(MemberExpression node) { if (node.Type == typeof(bool)) { + // Check if it's a literal field. See https://github.com/JasperFx/marten/issues/2850 + if (node.TryToParseConstant(out var constant)) + { + _holder.Register(constant.Value.Equals(true) ? new WhereFragment("true") : new WhereFragment("false")); + return null; + } + var field = _members.MemberFor(node); if (field is IBooleanField b) { @@ -65,18 +73,6 @@ protected override Expression VisitMember(MemberExpression node) return base.VisitMember(node); } - private IQueryableMember? findHoldingMember(SimpleExpression? @object, SimpleExpression[] arguments) - { - if (@object?.Member != null) return @object.Member; - - foreach (var argument in arguments) - { - if (argument.Member != null) return argument.Member; - } - - return null; - } - protected override Expression VisitMethodCall(MethodCallExpression node) { var parser = _options.Linq.FindMethodParser(node); @@ -162,6 +158,10 @@ protected override Expression VisitUnary(UnaryExpression node) return returnValue; } + else if (node.NodeType == ExpressionType.OrElse) + { + Debug.WriteLine(node); + } return null; }