Skip to content

Commit

Permalink
Add session.AdvancedSql.StreamAsync<>()
Browse files Browse the repository at this point in the history
  • Loading branch information
e-tobi authored and jeremydmiller committed May 28, 2024
1 parent 24db609 commit 4ea682d
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 20 deletions.
52 changes: 52 additions & 0 deletions src/DocumentDbTests/Reading/advanced_sql_query.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,58 @@ limit 2
#endregion
}

[Fact]
public async void can_async_stream_multiple_documents_and_scalar()
{
await using var session = theStore.LightweightSession();
#region sample_advanced_sql_stream_related_documents_and_scalar
session.Store(new DocWithMeta { Id = 1, Name = "Max" });
session.Store(new DocDetailsWithMeta { Id = 1, Detail = "Likes bees" });
session.Store(new DocWithMeta { Id = 2, Name = "Michael" });
session.Store(new DocDetailsWithMeta { Id = 2, Detail = "Is a good chess player" });
session.Store(new DocWithMeta { Id = 3, Name = "Anne" });
session.Store(new DocDetailsWithMeta { Id = 3, Detail = "Hates soap operas" });
session.Store(new DocWithMeta { Id = 4, Name = "Beatrix" });
session.Store(new DocDetailsWithMeta { Id = 4, Detail = "Likes to cook" });
await session.SaveChangesAsync();

var schema = session.DocumentStore.Options.Schema;

var asyncEnumerable = session.AdvancedSql.StreamAsync<DocWithMeta, DocDetailsWithMeta, long>(
$"""
select
row(a.id, a.data, a.mt_version),
row(b.id, b.data, b.mt_version),
row(count(*) over())
from
{schema.For<DocWithMeta>()} a
left join
{schema.For<DocDetailsWithMeta>()} b on a.id = b.id
where
(a.data ->> 'Id')::int > 1
order by
a.data ->> 'Name'
""",
CancellationToken.None);

var collectedResults = new List<(DocWithMeta doc, DocDetailsWithMeta detail, long totalResults)>();
await foreach (var result in asyncEnumerable)
{
collectedResults.Add(result);
}
#endregion
collectedResults.Count.ShouldBe(3);
collectedResults[0].totalResults.ShouldBe(3);
collectedResults[0].doc.Name.ShouldBe("Anne");
collectedResults[0].detail.Detail.ShouldBe("Hates soap operas");
collectedResults[1].totalResults.ShouldBe(3);
collectedResults[1].doc.Name.ShouldBe("Beatrix");
collectedResults[1].detail.Detail.ShouldBe("Likes to cook");
collectedResults[2].totalResults.ShouldBe(3);
collectedResults[2].doc.Name.ShouldBe("Michael");
collectedResults[2].detail.Detail.ShouldBe("Is a good chess player");
}

[Fact]
public void can_query_synchrounously()
{
Expand Down
53 changes: 53 additions & 0 deletions src/Marten/IAdvancedSql.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#nullable enable
using System.Collections.Generic;
using System.Threading;

namespace Marten;

public interface IAdvancedSql
{
/// <summary>
/// Asynchronously queries the document storage with the supplied SQL.
/// The type parameters can be any document class, scalar or JSON-serializable class.
/// For each result type parameter, the SQL SELECT statement must contain a ROW.
/// For document types, the row must contain the required fields in the correct order,
/// depending on the session type and the metadata the document might use, at least id and data must be
/// provided.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="sql"></param>
/// <param name="parameters"></param>
/// <returns>An async enumerable iterating over the results</returns>
IAsyncEnumerable<T> StreamAsync<T>(string sql, CancellationToken token, params object[] parameters);

/// <summary>
/// Asynchronously queries the document storage with the supplied SQL.
/// The type parameters can be any document class, scalar or JSON-serializable class.
/// For each result type parameter, the SQL SELECT statement must contain a ROW.
/// For document types, the row must contain the required fields in the correct order,
/// depending on the session type and the metadata the document might use, at least id and data must be
/// provided.
/// </summary>
/// <typeparam name="T1"></typeparam>
/// <typeparam name="T2"></typeparam>
/// <param name="sql"></param>
/// <param name="parameters"></param>
/// <returns>An async enumerable iterating over the list of result tuples</returns>
IAsyncEnumerable<(T1, T2)> StreamAsync<T1, T2>(string sql, CancellationToken token, params object[] parameters);

/// <summary>
/// Asynchronously queries the document storage with the supplied SQL.
/// The type parameters can be any document class, scalar or JSON-serializable class.
/// For each result type parameter, the SQL SELECT statement must contain a ROW.
/// For document types, the row must contain the required fields in the correct order,
/// depending on the session type and the metadata the document might use, at least id and data must be
/// provided.
/// </summary>
/// <typeparam name="T1"></typeparam>
/// <typeparam name="T2"></typeparam>
/// <typeparam name="T3"></typeparam>
/// <param name="sql"></param>
/// <param name="parameters"></param>
/// <returns>An async enumerable iterating over the list of result tuples</returns>
IAsyncEnumerable<(T1, T2, T3)> StreamAsync<T1, T2, T3>(string sql, CancellationToken token, params object[] parameters);
}
6 changes: 6 additions & 0 deletions src/Marten/IQuerySession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -716,4 +716,10 @@ Task<DocumentMetadata> MetadataForAsync<T>(T entity,
/// <param name="token"></param>
/// <returns></returns>
Task<DbDataReader> ExecuteReaderAsync(NpgsqlCommand command, CancellationToken token = default);

/// <summary>
/// Advanced sql query methods, to allow you to query the database
/// beyond what you can do with LINQ.
/// </summary>
IAdvancedSql AdvancedSql { get; }
}
73 changes: 73 additions & 0 deletions src/Marten/Internal/Sessions/QuerySession.AdvancedSql.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using Marten.Linq.QueryHandlers;
using Marten.Util;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;

namespace Marten.Internal.Sessions;

public partial class QuerySession: IAdvancedSql
{
public async IAsyncEnumerable<T> StreamAsync<T>(string sql, [EnumeratorCancellation] CancellationToken token,
params object[] parameters)
{
assertNotDisposed();

var handler = new AdvancedSqlQueryHandler<T>(this, sql, parameters);

foreach (var documentType in handler.DocumentTypes)
{
await Database.EnsureStorageExistsAsync(documentType, token).ConfigureAwait(false);
}

var batch = this.BuildCommand(handler);
await using var reader = await ExecuteReaderAsync(batch, token).ConfigureAwait(false);

await foreach (var result in handler.EnumerateResults(reader, token))
{
yield return result;
}
}

public async IAsyncEnumerable<(T1, T2)> StreamAsync<T1, T2>(string sql, [EnumeratorCancellation] CancellationToken token,
params object[] parameters)
{
assertNotDisposed();

var handler = new AdvancedSqlQueryHandler<T1, T2>(this, sql, parameters);

foreach (var documentType in handler.DocumentTypes)
{
await Database.EnsureStorageExistsAsync(documentType, token).ConfigureAwait(false);
}

var batch = this.BuildCommand(handler);
await using var reader = await ExecuteReaderAsync(batch, token).ConfigureAwait(false);

await foreach (var result in handler.EnumerateResults(reader, token))
{
yield return result;
}
}

public async IAsyncEnumerable<(T1, T2, T3)> StreamAsync<T1, T2, T3>(string sql, [EnumeratorCancellation] CancellationToken token,
params object[] parameters)
{
assertNotDisposed();

var handler = new AdvancedSqlQueryHandler<T1, T2, T3>(this, sql, parameters);

foreach (var documentType in handler.DocumentTypes)
{
await Database.EnsureStorageExistsAsync(documentType, token).ConfigureAwait(false);
}

var batch = this.BuildCommand(handler);
await using var reader = await ExecuteReaderAsync(batch, token).ConfigureAwait(false);

await foreach (var result in handler.EnumerateResults(reader, token))
{
yield return result;
}
}
}
2 changes: 2 additions & 0 deletions src/Marten/Internal/Sessions/QuerySession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,6 @@ public IMartenSessionLogger Logger

public int RequestCount { get; set; }
public IDocumentStore DocumentStore => _store;

public IAdvancedSql AdvancedSql => this;
}
48 changes: 28 additions & 20 deletions src/Marten/Linq/QueryHandlers/AdvancedSqlQueryHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using JasperFx.Core.Reflection;
Expand All @@ -15,7 +16,7 @@

namespace Marten.Linq.QueryHandlers;

internal class AdvancedSqlQueryHandler<T>: AdvancedSqlQueryHandlerBase, IQueryHandler<IReadOnlyList<T>>
internal class AdvancedSqlQueryHandler<T>: AdvancedSqlQueryHandlerBase<T>, IQueryHandler<IReadOnlyList<T>>
{
public AdvancedSqlQueryHandler(IMartenSession session, string sql, object[] parameters):base(sql, parameters)
{
Expand All @@ -33,20 +34,17 @@ public IReadOnlyList<T> Handle(DbDataReader reader, IMartenSession session)
return list;
}

public async Task<IReadOnlyList<T>> HandleAsync(DbDataReader reader, IMartenSession session,
CancellationToken token)
public override async IAsyncEnumerable<T> EnumerateResults(DbDataReader reader,
[EnumeratorCancellation] CancellationToken token)
{
var list = new List<T>();
while (await reader.ReadAsync(token).ConfigureAwait(false))
{
var item = await ((ISelector<T>)Selectors[0]).ResolveAsync(reader, token).ConfigureAwait(false);
list.Add(item);
yield return await ((ISelector<T>)Selectors[0]).ResolveAsync(reader, token).ConfigureAwait(false);
}
return list;
}
}

internal class AdvancedSqlQueryHandler<T1, T2>: AdvancedSqlQueryHandlerBase, IQueryHandler<IReadOnlyList<(T1, T2)>>
internal class AdvancedSqlQueryHandler<T1, T2>: AdvancedSqlQueryHandlerBase<(T1, T2)>, IQueryHandler<IReadOnlyList<(T1, T2)>>
{
public AdvancedSqlQueryHandler(IMartenSession session, string sql, object[] parameters) : base(sql, parameters)
{
Expand All @@ -66,20 +64,18 @@ public AdvancedSqlQueryHandler(IMartenSession session, string sql, object[] para
return list;
}

public async Task<IReadOnlyList<(T1, T2)>> HandleAsync(DbDataReader reader, IMartenSession session,
CancellationToken token)
public override async IAsyncEnumerable<(T1, T2)> EnumerateResults(DbDataReader reader,
[EnumeratorCancellation] CancellationToken token)
{
var list = new List<(T1, T2)>();
while (await reader.ReadAsync(token).ConfigureAwait(false))
{
var item1 = await ReadNestedRowAsync<T1>(reader, 0, token).ConfigureAwait(false);
var item2 = await ReadNestedRowAsync<T2>(reader, 1, token).ConfigureAwait(false);
list.Add((item1, item2));
yield return (item1, item2);
}
return list;
}
}
internal class AdvancedSqlQueryHandler<T1, T2, T3>: AdvancedSqlQueryHandlerBase, IQueryHandler<IReadOnlyList<(T1, T2, T3)>>
internal class AdvancedSqlQueryHandler<T1, T2, T3>: AdvancedSqlQueryHandlerBase<(T1, T2, T3)>, IQueryHandler<IReadOnlyList<(T1, T2, T3)>>
{
public AdvancedSqlQueryHandler(IMartenSession session, string sql, object[] parameters) : base(sql, parameters)
{
Expand All @@ -101,22 +97,20 @@ public AdvancedSqlQueryHandler(IMartenSession session, string sql, object[] para
return list;
}

public async Task<IReadOnlyList<(T1, T2, T3)>> HandleAsync(DbDataReader reader, IMartenSession session,
CancellationToken token)
public override async IAsyncEnumerable<(T1, T2, T3)> EnumerateResults(DbDataReader reader,
[EnumeratorCancellation] CancellationToken token)
{
var list = new List<(T1, T2, T3)>();
while (await reader.ReadAsync(token).ConfigureAwait(false))
{
var item1 = await ReadNestedRowAsync<T1>(reader, 0, token).ConfigureAwait(false);
var item2 = await ReadNestedRowAsync<T2>(reader, 1, token).ConfigureAwait(false);
var item3 = await ReadNestedRowAsync<T3>(reader, 2, token).ConfigureAwait(false);
list.Add((item1, item2, item3));
yield return (item1, item2, item3);
}
return list;
}
}

internal class AdvancedSqlQueryHandlerBase
internal abstract class AdvancedSqlQueryHandlerBase<TResult>
{
protected readonly object[] Parameters;
protected readonly string Sql;
Expand Down Expand Up @@ -226,4 +220,18 @@ public Task<int> StreamJson(Stream stream, DbDataReader reader, CancellationToke
{
throw new NotImplementedException();
}

public async Task<IReadOnlyList<TResult>> HandleAsync(DbDataReader reader, IMartenSession session,
CancellationToken token)
{
var list = new List<TResult>();
await foreach (var result in EnumerateResults(reader, token).ConfigureAwait(false))
{
list.Add(result);
}

return list;
}

public abstract IAsyncEnumerable<TResult> EnumerateResults(DbDataReader reader, CancellationToken token);
}

0 comments on commit 4ea682d

Please sign in to comment.