Skip to content

Commit

Permalink
Improved LINQ support for Dictionaries. Closes GH-2651
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremydmiller committed Dec 4, 2023
1 parent 5aae5b6 commit 2150cbf
Show file tree
Hide file tree
Showing 37 changed files with 921 additions and 138 deletions.
6 changes: 3 additions & 3 deletions docs/documents/querying/linq/child-collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void query_against_string_array()
.Select(x => x.Id).ShouldHaveTheSameElementsAs(doc1.Id, doc2.Id);
}
```
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L457-L475' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_against_string_array' title='Start of snippet'>anchor</a></sup>
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L472-L490' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_against_string_array' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

Marten also allows you to query over IEnumerables using the Any method for equality (similar to Contains):
Expand Down Expand Up @@ -142,7 +142,7 @@ public void query_against_number_list_with_any()
.Count(x => x.Numbers.Any()).ShouldBe(3);
}
```
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L573-L597' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_any_string_array' title='Start of snippet'>anchor</a></sup>
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L588-L612' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_any_string_array' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

As of 1.2, you can also query against the `Count()` or `Length` of a child collection with the normal comparison
Expand Down Expand Up @@ -170,7 +170,7 @@ public void query_against_number_list_with_count_method()
.Single(x => x.Numbers.Count() == 4).Id.ShouldBe(doc3.Id);
}
```
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L599-L620' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_against_number_list_with_count_method' title='Start of snippet'>anchor</a></sup>
<sup><a href='https://github.com/JasperFx/marten/blob/master/src/LinqTests/ChildCollections/query_against_child_collections.cs#L614-L635' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_query_against_number_list_with_count_method' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

## IsOneOf
Expand Down
2 changes: 1 addition & 1 deletion docs/events/projections/flat.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public class FlatImportProjection: FlatTableProjection
Project<ImportStarted>(map =>
{
// Set values in the table from the event
map.Map(x => x.ActivityType).NotNull();
map.Map(x => x.ActivityType);
map.Map(x => x.CustomerId);
map.Map(x => x.PlannedSteps, "total_steps")
.DefaultValue(0);
Expand Down
125 changes: 108 additions & 17 deletions src/LinqTests/Acceptance/dictionary_usage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,33 @@ public class dictionary_usage: IntegrationContext
{
private readonly ITestOutputHelper _output;

public dictionary_usage(DefaultStoreFixture fixture, ITestOutputHelper output) : base(fixture)
public dictionary_usage(DefaultStoreFixture fixture, ITestOutputHelper output): base(fixture)
{
_output = output;
theStore.BulkInsert(Target.GenerateRandomData(100).ToArray());
}


// using key0 and value0 for these because the last node, which is deep, should have at least a single dict node

[Fact]
public void dict_string_can_query_using_containskey()
public async Task playing()
{
var results = theSession.Query<Target>().Where(x => x.StringDict.ContainsKey("key0")).ToList();
results.All(r => r.StringDict.ContainsKey("key0")).ShouldBeTrue();
await theStore.Advanced.Clean.DeleteDocumentsByTypeAsync(typeof(Target));

theSession.Logger = new TestOutputMartenLogger(_output);


var targets = Target.GenerateRandomData(10).ToArray();
var number = targets.Select(x => x.StringDict).SelectMany(x => x.Values).Count();
_output.WriteLine(number.ToString());


await theStore.BulkInsertAsync(targets);
var data = await theSession.Query<Target>().Select(x => x.StringDict).ToListAsync();

var count = await theSession.Query<Target>().Where(x => x.StringDict.Any()).Select(x => x.StringDict).ToListAsync();
}

// using key0 and value0 for these because the last node, which is deep, should have at least a single dict node


[Fact]
public async Task dict_guid_can_query_using_containskey()
{
Expand All @@ -45,14 +56,6 @@ public async Task dict_guid_can_query_using_containskey()
results.All(r => r.GuidDict.ContainsKey(guid)).ShouldBeTrue();
}

[Fact]
public void dict_string_can_query_using_containsKVP()
{
var kvp = new KeyValuePair<string, string>("key0", "value0");
var results = theSession.Query<Target>().Where(x => x.StringDict.Contains(kvp)).ToList();
results.All(r => r.StringDict.Contains(kvp)).ShouldBeTrue();
}

[Fact]
public async Task dict_guid_can_query_using_containsKVP()
{
Expand All @@ -65,8 +68,96 @@ public async Task dict_guid_can_query_using_containsKVP()

var kvp = new KeyValuePair<Guid, Guid>(guidk, guidv);
// Only works if the dictionary is in interface form
var results = await theSession.Query<Target>().Where(x => ((IDictionary<Guid, Guid>)x.GuidDict).Contains(kvp)).ToListAsync();
var results = await theSession.Query<Target>().Where(x => ((IDictionary<Guid, Guid>)x.GuidDict).Contains(kvp))
.ToListAsync();
results.All(r => r.GuidDict.Contains(kvp)).ShouldBeTrue();
}

[Fact]
public async Task query_against_values()
{
var queryId = Guid.NewGuid();
var queryId2 = Guid.NewGuid();
var queryId3 = Guid.NewGuid();
var guidDict = new Dictionary<Guid, HashSet<Guid>> { { queryId, new HashSet<Guid>() { queryId3 } } };
var objectDict = new Dictionary<Guid, MyEntity> { { Guid.NewGuid(), new MyEntity(queryId2, string.Empty) } };
var dictEntity = new DictEntity(Guid.NewGuid(), guidDict, objectDict);

theSession.Store(dictEntity);
await theSession.SaveChangesAsync();


var entityExists = await theSession.Query<DictEntity>()
.Where(x => x.GuidDict.Values.Any(hs => hs.Contains(queryId3)))
.AnyAsync();
}

[Fact]
public async Task select_many_against_the_Keys()
{
theSession.Logger = new TestOutputMartenLogger(_output);


await theStore.BulkInsertAsync(Target.GenerateRandomData(1000).ToArray());

var pairs = await theSession.Query<Target>().SelectMany(x => x.StringDict.Keys).ToListAsync();
pairs.Count.ShouldBeGreaterThan(0);
}

[Fact]
public async Task select_many_against_the_values()
{
/*
WITH mt_temp_id_list1CTE as (
select jsonb_path_query(d.data -> 'StringDict', '$.*') ->> 0 as data from public.mt_doc_target as d
)
select data from mt_temp_id_list1CTE
as d where d.data = 'value2' order by d.data;
*/

theSession.Logger = new TestOutputMartenLogger(_output);

await theStore.Advanced.Clean.DeleteDocumentsByTypeAsync(typeof(Target));
var targets = Target.GenerateRandomData(10).ToArray();
await theStore.BulkInsertAsync(targets);

var values = await theSession.Query<Target>().SelectMany(x => x.StringDict.Values).ToListAsync();
values.Count.ShouldBe(targets.SelectMany(x => x.StringDict.Values).Count());
}


[Fact]
public async Task select_many_with_wheres_and_order_by_on_values()
{
theSession.Logger = new TestOutputMartenLogger(_output);


await theStore.BulkInsertAsync(Target.GenerateRandomData(1000).ToArray());

var values = await theSession.Query<Target>().SelectMany(x => x.StringDict.Values)
.Where(x => x == "value2")
.OrderBy(x => x)
.ToListAsync();

values.Count.ShouldBeGreaterThan(0);

/*
WITH mt_temp_id_list1CTE as (
select jsonb_path_query(d.data -> 'StringDict', '$.*') ->> 0 as data from public.mt_doc_target as d
)
select data from mt_temp_id_list1CTE
as d where d.data = 'value2' order by d.data;
*/
}
}

public class EntityWithDict
{
public Guid Id { get; set; }
public Dictionary<int, string> Data { get; set; } = new();
}

public sealed record MyEntity(Guid Id, string Value);

public sealed record DictEntity(Guid Id, Dictionary<Guid, HashSet<Guid>> GuidDict,
Dictionary<Guid, MyEntity> ObjectDict);
19 changes: 19 additions & 0 deletions src/LinqTests/Acceptance/where_clauses.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using LinqTests.Acceptance.Support;
using Xunit.Abstractions;
Expand Down Expand Up @@ -143,6 +144,24 @@ static where_clauses()
@where(x => x.Number > x.AnotherNumber);
@where(x => x.Number <= x.AnotherNumber);
@where(x => x.Number >= x.AnotherNumber);


// Dictionaries
@where(x => x.StringDict.ContainsKey("key0"));

var kvp = new KeyValuePair<string, string>("key0", "value0");
@where(x => x.StringDict.Contains(kvp));
@where(x => x.StringDict.Values.Contains("value3"));
@where(x => x.StringDict.Keys.Contains("key2"));
@where(x => x.StringDict.Values.Any(v => v.EndsWith("3")));
@where(x => x.StringDict.Keys.Any(v => v.EndsWith("3")));
@where(x => x.StringDict.Any());
@where(x => !x.StringDict.Any());
@where(x => x.StringDict.Any(p => p.Key == "key1"));
@where(x => x.StringDict.Any(p => p.Value == "value2"));

@where(x => x.StringDict.Count > 2);
@where(x => x.StringDict.Count() == 2);
}

[Theory]
Expand Down
15 changes: 15 additions & 0 deletions src/LinqTests/ChildCollections/query_against_child_collections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,21 @@ public void query_against_number_array()
.Select(x => x.Id).ShouldHaveTheSameElementsAs(doc1.Id, doc2.Id);
}

[Fact]
public void query_against_number_array_count()
{
var doc1 = new DocWithArrays { Numbers = new[] { 1, 2, 3 } };
var doc2 = new DocWithArrays { Numbers = new[] { 3, 4, 5 } };
var doc3 = new DocWithArrays { Numbers = new[] { 5, 6, 7, 8 } };

theSession.Store(doc1, doc2, doc3);

theSession.SaveChanges();

theSession.Query<DocWithArrays>().Where(x => x.Numbers.Length == 4).ToArray()
.Select(x => x.Id).ShouldHaveTheSameElementsAs(doc3.Id);
}

[Fact]

#region sample_query_against_string_array
Expand Down
1 change: 1 addition & 0 deletions src/LinqTests/Internals/SimpleExpressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Marten;
using Marten.Linq;
using Marten.Linq.Members;
using Marten.Linq.Members.Dictionaries;
using Marten.Linq.Parsing;
using Marten.Schema;
using Marten.Testing.Documents;
Expand Down
2 changes: 1 addition & 1 deletion src/Marten.Testing/Documents/Target.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public static Target Random(bool deep = false)
target.Children[i] = Random();
}

target.StringDict = Enumerable.Range(0, _random.Next(1, 10)).ToDictionary(i => $"key{i}", i => $"value{i}");
target.StringDict = Enumerable.Range(0, _random.Next(0, 10)).ToDictionary(i => $"key{i}", i => $"value{i}");
target.String = _strings[_random.Next(0, 10)];
target.OtherGuid = Guid.NewGuid();
}
Expand Down
4 changes: 3 additions & 1 deletion src/Marten/Linq/Members/ChildCollectionMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ public ISqlFragment ParseWhereForAll(MethodCallExpression method, IReadOnlyStore

public ISqlFragment ParseWhereForContains(MethodCallExpression body, IReadOnlyStoreOptions options)
{
throw new NotImplementedException();
throw new BadLinqExpressionException(
"Marten does not (yet) support contains queries through collections of element type " +
ElementType.FullNameInCode());
}


Expand Down
2 changes: 2 additions & 0 deletions src/Marten/Linq/Members/ChildCollectionWhereClause.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public void Register(ISqlFragment fragment)
_fragment = fragment;
}

public ISqlFragment Fragment => _fragment;

public static bool TryBuildInlineFragment(ISqlFragment fragment, ICollectionMember collectionMember,
ISerializer serializer, out ICollectionAwareFilter filter)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using JasperFx.Core.Reflection;
using Marten.Linq.Members;
using Marten.Linq.Parsing;
using Marten.Linq.SqlGeneration.Filters;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.Parsing.Methods;
namespace Marten.Linq.Members.Dictionaries;

[Obsolete]
internal class DictionaryContainsKey: IMethodCallParser
{
public bool Matches(MethodCallExpression expression)
Expand All @@ -29,6 +27,6 @@ public ISqlFragment Parse(IQueryableMemberCollection memberCollection, IReadOnly
var member = memberCollection.MemberFor(expression.Object);
var constant = expression.Arguments.Single().ReduceToConstant();

return new ContainmentWhereFilter(member, constant, options.Serializer());
return new DictionaryContainsKeyFilter((IDictionaryMember)member, options.Serializer(), constant);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System.Linq.Expressions;
using System.Reflection;
using JasperFx.CodeGeneration;
using Marten.Exceptions;
using Marten.Internal.CompiledQueries;
using Marten.Linq.Parsing;
using Weasel.Postgresql;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.Members.Dictionaries;

internal class DictionaryContainsKeyFilter: ISqlFragment, ICompiledQueryAwareFilter
{
private readonly object _value;
private readonly string _keyText;
private readonly IDictionaryMember _member;

public DictionaryContainsKeyFilter(IDictionaryMember member, ISerializer serializer, ConstantExpression constant)
{
_value = constant.Value;
_keyText = serializer.ToCleanJson(_value);

_member = member;
}

public DictionaryContainsKeyFilter(IDictionaryMember member, ISerializer serializer, object value)
{
_value = value;
_keyText = serializer.ToCleanJson(_value);

_member = member;
}

public void Apply(CommandBuilder builder)
{
builder.Append("d.data #> '{");
foreach (var segment in _member.JsonPathSegments())
{
builder.Append(segment);
builder.Append(", ");
}

builder.Append(_keyText);
builder.Append("}' is not null");
}

public bool Contains(string sqlText)
{
return false;
}

public bool TryMatchValue(object value, MemberInfo member)
{
throw new BadLinqExpressionException("Marten does not (yet) support Dictionary.ContainsKey() in compiled queries");
}

public void GenerateCode(GeneratedMethod method, int parameterIndex)
{
throw new BadLinqExpressionException("Marten does not (yet) support Dictionary.ContainsKey() in compiled queries");
}

public string ParameterName { get; } = "NONE";
}
21 changes: 21 additions & 0 deletions src/Marten/Linq/Members/Dictionaries/DictionaryCountMember.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Linq.Expressions;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.Members.Dictionaries;

internal class DictionaryCountMember: QueryableMember, IComparableMember
{
public DictionaryCountMember(IDictionaryMember parent): base(parent, "Count", typeof(int))
{
RawLocator = TypedLocator = $"jsonb_array_length(jsonb_path_query_array({parent.TypedLocator}, '$.keyvalue()'))";
Parent = parent;
}

public ICollectionMember Parent { get; }

public override ISqlFragment CreateComparison(string op, ConstantExpression constant)
{
var def = new CommandParameter(constant);
return new ComparisonFilter(this, def, op);
}
}
Loading

0 comments on commit 2150cbf

Please sign in to comment.