Skip to content

Commit

Permalink
make reader use enumerables
Browse files Browse the repository at this point in the history
  • Loading branch information
annelo-msft committed Mar 28, 2024
1 parent 704c617 commit abfa53c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 49 deletions.
18 changes: 9 additions & 9 deletions .dotnet/src/Custom/Assistants/AssistantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public virtual async Task<ClientResult<AssistantThread>> CreateThreadAsync(
return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse());
}

public virtual ClientResult<AssistantThread> GetThread(string threadId)
public virtual ClientResult<AssistantThread> GetThread(string threadId)
{
ClientResult<Internal.Models.ThreadObject> internalResult = ThreadShim.GetThread(threadId);
return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse());
Expand Down Expand Up @@ -268,13 +268,13 @@ public virtual async Task<ClientResult<AssistantThread>> ModifyThreadAsync(
return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse());
}

public virtual ClientResult<bool> DeleteThread(string threadId)
public virtual ClientResult<bool> DeleteThread(string threadId)
{
ClientResult<Internal.Models.DeleteThreadResponse> internalResult = ThreadShim.DeleteThread(threadId);
return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse());
}

public virtual async Task<ClientResult<bool>> DeleteThreadAsync(string threadId)
public virtual async Task<ClientResult<bool>> DeleteThreadAsync(string threadId)
{
ClientResult<Internal.Models.DeleteThreadResponse> internalResult = await ThreadShim.DeleteThreadAsync(threadId).ConfigureAwait(false);
return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse());
Expand Down Expand Up @@ -464,7 +464,7 @@ public virtual StreamingClientResult<StreamingUpdate> CreateRunStreaming(
string assistantId,
RunCreationOptions options = null)
{
PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true);
using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true);
RunShim.Pipeline.Send(message);
return CreateStreamingRunResult(message);
}
Expand All @@ -474,7 +474,7 @@ public virtual async Task<StreamingClientResult<StreamingUpdate>> CreateRunStrea
string assistantId,
RunCreationOptions options = null)
{
PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true);
using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true);
await RunShim.Pipeline.SendAsync(message);
return CreateStreamingRunResult(message);
}
Expand Down Expand Up @@ -506,7 +506,7 @@ public virtual StreamingClientResult<StreamingUpdate> CreateThreadAndRunStreamin
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true);
using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true);
Shim.Pipeline.Send(message);
return CreateStreamingRunResult(message);
}
Expand All @@ -516,7 +516,7 @@ public virtual async Task<StreamingClientResult<StreamingUpdate>> CreateThreadAn
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true);
using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true);
await Shim.Pipeline.SendAsync(message);
return CreateStreamingRunResult(message);
}
Expand Down Expand Up @@ -621,14 +621,14 @@ public virtual async Task<ClientResult<ThreadRun>> SubmitToolOutputsAsync(string

public virtual StreamingClientResult<StreamingUpdate> SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable<ToolOutput> toolOutputs)
{
PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true);
using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true);
Shim.Pipeline.SendAsync(message);
return CreateStreamingRunResult(message);
}

public virtual async Task<StreamingClientResult<StreamingUpdate>> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable<ToolOutput> toolOutputs)
{
PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true);
using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true);
await Shim.Pipeline.SendAsync(message);
return CreateStreamingRunResult(message);
}
Expand Down
43 changes: 11 additions & 32 deletions .dotnet/src/Utility/SseAsyncEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,24 @@ namespace OpenAI;

internal static class SseAsyncEnumerator<T>
{
internal static async IAsyncEnumerable<ServerSentEvent> EnumerateServerSentEvents(
Stream stream,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
try
{
using SseReader sseReader = new(stream);
while (!cancellationToken.IsCancellationRequested)
{
ServerSentEvent? sseEvent = await sseReader.TryGetNextEventAsync(cancellationToken).ConfigureAwait(false);
if (sseEvent is null)
{
break;
}
else
{
yield return sseEvent.Value;
}
}
}
finally
{
// Always dispose the stream immediately once enumeration is complete for any reason
stream.Dispose();
}
}
private static ReadOnlyMemory<char>[] _wellKnownTokens =
[
"[DONE]".AsMemory(),
];

internal static async IAsyncEnumerable<T> EnumerateFromSseJsonStream(
Stream stream,
Func<ReadOnlyMemory<char>, JsonElement, IEnumerable<T>> multiElementDeserializer,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (ServerSentEvent sseEvent in EnumerateServerSentEvents(stream, cancellationToken))
using SseReader reader = new SseReader(stream);

await foreach (ServerSentEvent sseEvent in reader.GetEventsAsync(cancellationToken))
{
if (IsWellKnownDoneToken(sseEvent.Data)) continue;

using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Data);

foreach (T item in multiElementDeserializer(sseEvent.EventName, sseDocument.RootElement))
{
yield return item;
Expand All @@ -55,10 +37,7 @@ internal static async IAsyncEnumerable<T> EnumerateFromSseJsonStream(

private static bool IsWellKnownDoneToken(ReadOnlyMemory<char> data)
{
ReadOnlyMemory<char>[] wellKnownTokens =
[
"[DONE]".AsMemory(),
];
return wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span));
// TODO: Make faster than LINQ.
return _wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span));
}
}
17 changes: 9 additions & 8 deletions .dotnet/src/Utility/SseReader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -26,7 +27,7 @@ public SseReader(Stream stream)
/// <returns>
/// The next <see cref="ServerSentEvent"/> in the stream, or null once no more data can be read from the stream.
/// </returns>
public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default)
public IEnumerable<ServerSentEvent> GetEvents(CancellationToken cancellationToken = default)
{
List<ServerSentEventField> fields = [];

Expand All @@ -36,14 +37,14 @@ public SseReader(Stream stream)
if (line == null)
{
// A null line indicates end of input
return null;
yield break;
}
else if (line.Length == 0)
{
// An empty line should dispatch an event for pending accumulated fields
ServerSentEvent nextEvent = new(fields);
fields = [];
return nextEvent;
yield return nextEvent;
}
else if (line[0] == ':')
{
Expand All @@ -57,7 +58,7 @@ public SseReader(Stream stream)
}
}

return null;
yield break;
}

/// <summary>
Expand All @@ -68,7 +69,7 @@ public SseReader(Stream stream)
/// <returns>
/// The next <see cref="ServerSentEvent"/> in the stream, or null once no more data can be read from the stream.
/// </returns>
public async Task<ServerSentEvent?> TryGetNextEventAsync(CancellationToken cancellationToken = default)
public async IAsyncEnumerable<ServerSentEvent> GetEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
List<ServerSentEventField> fields = [];

Expand All @@ -78,14 +79,14 @@ public SseReader(Stream stream)
if (line == null)
{
// A null line indicates end of input
return null;
yield break;
}
else if (line.Length == 0)
{
// An empty line should dispatch an event for pending accumulated fields
ServerSentEvent nextEvent = new(fields);
fields = [];
return nextEvent;
yield return nextEvent;
}
else if (line[0] == ':')
{
Expand All @@ -99,7 +100,7 @@ public SseReader(Stream stream)
}
}

return null;
yield break;
}

public void Dispose()
Expand Down

0 comments on commit abfa53c

Please sign in to comment.