Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming Client Result investigation - chat #45

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions .dotnet/src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
/// The number of independent, alternative response choices that should be generated.
/// </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken"> The cancellation token for the operation. </param>
/// <returns> A result for a single chat completion. </returns>
public virtual ClientResult<ChatCompletionCollection> CompleteChat(
IEnumerable<ChatRequestMessage> messages,
Expand Down Expand Up @@ -191,29 +190,25 @@
/// The number of independent, alternative choices that the chat completion request should generate.
/// </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken"> The cancellation token for the operation. </param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual StreamingClientResult<StreamingChatUpdate> CompleteChatStreaming(
IEnumerable<ChatRequestMessage> messages,
int? choiceCount = null,
ChatCompletionOptions options = null)
{
PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options);
requestMessage.BufferResponse = false;
Shim.Pipeline.Send(requestMessage);
PipelineResponse response = requestMessage.ExtractResponse();
PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options);
message.BufferResponse = false;

Shim.Pipeline.Send(message);

PipelineResponse response = message.Response;

if (response.IsError)
{
throw new ClientResultException(response);
}

ClientResult genericResult = ClientResult.FromResponse(response);
return StreamingClientResult<StreamingChatUpdate>.CreateFromResponse(
genericResult,
(responseForEnumeration) => SseAsyncEnumerator<StreamingChatUpdate>.EnumerateFromSseStream(
responseForEnumeration.GetRawResponse().ContentStream,
e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e)));
return new StreamingChatResult(response);
}

/// <summary>
Expand All @@ -229,29 +224,25 @@
/// The number of independent, alternative choices that the chat completion request should generate.
/// </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken"> The cancellation token for the operation. </param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual async Task<StreamingClientResult<StreamingChatUpdate>> CompleteChatStreamingAsync(
IEnumerable<ChatRequestMessage> messages,
int? choiceCount = null,
ChatCompletionOptions options = null)
{
PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options);
requestMessage.BufferResponse = false;
await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false);
PipelineResponse response = requestMessage.ExtractResponse();
PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options);
message.BufferResponse = false;

await Shim.Pipeline.SendAsync(message).ConfigureAwait(false);

PipelineResponse response = message.Response;

if (response.IsError)
{
throw new ClientResultException(response);
}

ClientResult genericResult = ClientResult.FromResponse(response);
return StreamingClientResult<StreamingChatUpdate>.CreateFromResponse(
genericResult,
(responseForEnumeration) => SseAsyncEnumerator<StreamingChatUpdate>.EnumerateFromSseStream(
responseForEnumeration.GetRawResponse().ContentStream,
e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e)));
return new StreamingChatResult(response);
}

private Internal.Models.CreateChatCompletionRequest CreateInternalRequest(
Expand All @@ -261,7 +252,7 @@
bool? stream = null)
{
options ??= new();
Internal.Models.CreateChatCompletionRequestResponseFormat? internalFormat = null;

Check warning on line 255 in .dotnet/src/Custom/Chat/ChatClient.cs

View workflow job for this annotation

GitHub Actions / build

The annotation for nullable reference types should only be used in code within a '#nullable' annotations context.
if (options.ResponseFormat is not null)
{
internalFormat = new(options.ResponseFormat switch
Expand Down
98 changes: 98 additions & 0 deletions .dotnet/src/Custom/Chat/StreamingChatResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace OpenAI.Chat;

#nullable enable

/// <summary>
/// A streaming client result that exposes the server-sent events (SSE) of a
/// chat completion response as a sequence of <see cref="StreamingChatUpdate"/>
/// values.
/// </summary>
internal class StreamingChatResult : StreamingClientResult<StreamingChatUpdate>
{
    /// <summary> Creates a streaming result over the given pipeline response. </summary>
    /// <param name="response"> The response whose content stream carries the SSE payload. </param>
    public StreamingChatResult(PipelineResponse response) : base(response)
    {
    }

    /// <summary>
    /// Creates an enumerator that reads chat updates from the response content stream.
    /// </summary>
    /// <param name="cancellationToken">
    /// Token observed by the returned enumerator each time it advances.
    /// </param>
    /// <returns> An enumerator that takes ownership of the response content stream. </returns>
    /// <exception cref="InvalidOperationException"> The response has no content stream. </exception>
    public override IAsyncEnumerator<StreamingChatUpdate> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        // The using declaration disposes the response as soon as this method
        // returns, NOT after enumeration completes. The content stream is
        // detached below so that it survives and is owned by the enumerator.
        // NOTE(review): a second call is expected to fail because the stream
        // has already been detached -- confirm GetRawResponse() returns the
        // same instance each time.
        using PipelineResponse response = GetRawResponse();

        // Extract the content stream from the response to obtain dispose
        // ownership of it. This means the content stream will not be disposed
        // when the response is disposed.
        Stream contentStream = response.ContentStream
            ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream.");
        response.ContentStream = null;

        return new ChatUpdateEnumerator(contentStream, cancellationToken);
    }

    /// <summary>
    /// Reads SSE "data" events from a stream and yields the chat updates
    /// deserialized from each event, one at a time.
    /// </summary>
    private sealed class ChatUpdateEnumerator : IAsyncEnumerator<StreamingChatUpdate>
    {
        private readonly SseReader _sseReader;
        private readonly CancellationToken _cancellationToken;

        // Updates deserialized from the most recently read event, and the
        // position of Current within them (-1 means "before the first one").
        private List<StreamingChatUpdate>? _currentUpdates;
        private int _currentUpdateIndex = -1;

        // Set once the stream ended or the "[DONE]" sentinel was seen, so that
        // later MoveNextAsync calls do not touch the stream again.
        private bool _completed;

        public ChatUpdateEnumerator(Stream stream, CancellationToken cancellationToken = default)
        {
            _sseReader = new(stream);
            _cancellationToken = cancellationToken;
        }

        /// <summary> The update the enumerator is currently positioned on. </summary>
        /// <exception cref="InvalidOperationException">
        /// The enumerator is positioned before the first or after the last update.
        /// </exception>
        public StreamingChatUpdate Current =>
            _currentUpdates is not null
                && _currentUpdateIndex >= 0
                && _currentUpdateIndex < _currentUpdates.Count
                ? _currentUpdates[_currentUpdateIndex]
                : throw new InvalidOperationException(
                    "The enumerator is not positioned on an element; call MoveNextAsync first.");

        /// <summary> Advances to the next update, reading further events as needed. </summary>
        /// <returns>
        /// <c>true</c> when <see cref="Current"/> holds a new update; <c>false</c>
        /// once the stream ends or the "[DONE]" sentinel is received.
        /// </returns>
        /// <exception cref="InvalidDataException"> An event with a field other than "data" was read. </exception>
        /// <exception cref="OperationCanceledException"> The cancellation token was signaled. </exception>
        public async ValueTask<bool> MoveNextAsync()
        {
            if (_completed)
            {
                return false;
            }

            // Loop because a single event may deserialize to zero updates, in
            // which case the next event has to be read before reporting success.
            while (true)
            {
                _cancellationToken.ThrowIfCancellationRequested();

                // Serve the next buffered update, if any remain.
                if (_currentUpdates is not null && _currentUpdateIndex + 1 < _currentUpdates.Count)
                {
                    _currentUpdateIndex++;
                    return true;
                }

                // TODO: flow the CancellationToken into the read once SseReader accepts one.
                SseLine? sseEvent = await _sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
                if (sseEvent is null)
                {
                    // TODO: does this mean we're done or not? Treated as end of stream.
                    return Complete();
                }

                ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
                if (!name.Span.SequenceEqual("data".AsSpan()))
                {
                    throw new InvalidDataException($"Unexpected SSE field '{name}'; expected 'data'.");
                }

                ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
                if (value.Span.SequenceEqual("[DONE]".AsSpan()))
                {
                    // Enumerator semantics are that MoveNextAsync returns false when done.
                    return Complete();
                }

                // TODO: optimize performance using Utf8JsonReader?
                using JsonDocument sseMessageJson = JsonDocument.Parse(value);
                _currentUpdates = StreamingChatUpdate.DeserializeStreamingChatUpdates(sseMessageJson.RootElement);

                // Start before the first element of the fresh batch; the next
                // loop iteration advances onto it (or reads on if it is empty).
                _currentUpdateIndex = -1;
            }
        }

        public ValueTask DisposeAsync()
        {
            // TODO: revisit per platforms where async dispose is available.
            // NOTE(review): presumably this also releases the underlying stream -- confirm in SseReader.
            _sseReader.Dispose();
            return default;
        }

        // Marks the sequence as finished and clears the buffered batch so that
        // Current no longer exposes a stale update.
        private bool Complete()
        {
            _completed = true;
            _currentUpdates = null;
            _currentUpdateIndex = -1;
            return false;
        }
    }
}
9 changes: 7 additions & 2 deletions .dotnet/src/Custom/Chat/StreamingChatUpdate.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
namespace OpenAI.Chat;

using System;
using System.Collections.Generic;
using System.Text.Json;

namespace OpenAI.Chat;

/// <summary>
/// Represents an incremental item of new data in a streaming response to a chat completion request.
/// </summary>
Expand Down Expand Up @@ -184,11 +184,16 @@ internal StreamingChatUpdate(

internal static List<StreamingChatUpdate> DeserializeStreamingChatUpdates(JsonElement element)
{
// TODO: Do we need to validate that we didn't get null or empty?
// What's the contract for the JSON updates?

List<StreamingChatUpdate> results = [];

if (element.ValueKind == JsonValueKind.Null)
{
return results;
}

string id = default;
DateTimeOffset created = default;
string systemFingerprint = null;
Expand Down
59 changes: 0 additions & 59 deletions .dotnet/src/Utility/SseAsyncEnumerator.cs

This file was deleted.

12 changes: 8 additions & 4 deletions .dotnet/src/Utility/SseReader.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
using System;
using System.ClientModel;
using System.ClientModel.Internal;
using System.IO;
using System.Threading.Tasks;

namespace OpenAI;

internal sealed class SseReader : IDisposable
internal sealed class SseReader : IDisposable, IAsyncDisposable
{
private readonly Stream _stream;
private readonly StreamReader _reader;
Expand Down Expand Up @@ -115,4 +113,10 @@ public void Dispose()
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

/// <summary>
/// Asynchronously disposes the reader. No truly asynchronous cleanup is
/// implemented yet, so this runs the synchronous <see cref="Dispose()"/>
/// path and returns an already-completed task.
/// </summary>
/// <returns> A completed <see cref="ValueTask"/>. </returns>
ValueTask IAsyncDisposable.DisposeAsync()
{
    // TODO: revisit per platforms where async dispose is available.
    // Without this call, `await using` would complete without releasing anything.
    Dispose();
    return default;
}
}
27 changes: 27 additions & 0 deletions .dotnet/src/Utility/StreamingClientResultOfT.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Threading;

namespace OpenAI;

#pragma warning disable CS1591 // public XML comments

/// <summary>
/// Represents an operation response with streaming content that can be deserialized and enumerated while the response
/// is still being received.
/// </summary>
/// <remarks>
/// This type only defines the enumeration contract; derived types supply the enumerator that turns the response
/// content into <typeparamref name="T"/> items.
/// </remarks>
/// <typeparam name="T"> The data type representative of distinct, streamable items. </typeparam>
public abstract class StreamingClientResult<T> : ClientResult, IAsyncEnumerable<T>
{
/// <summary>
/// Initializes the result with the response that is handed to the <see cref="ClientResult"/> base class.
/// </summary>
/// <param name="response"> The pipeline response associated with this result. </param>
protected StreamingClientResult(PipelineResponse response) : base(response)
{
}

// Note that if the implementation disposes the stream, the caller can only
// enumerate the results once. I think this makes sense, but we should
// make sure architects agree.
/// <summary>
/// Returns an enumerator that asynchronously iterates over the streamed items.
/// </summary>
/// <param name="cancellationToken"> A token that implementations may observe to stop the enumeration. </param>
/// <returns> An enumerator over the <typeparamref name="T"/> items of this result. </returns>
public abstract IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default);
}

#pragma warning restore CS1591 // public XML comments
Loading
Loading