From c7369cc4323ae3c781d61d3ad1b5aa75cbaacfb1 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Fri, 15 Mar 2024 20:40:52 -0700 Subject: [PATCH 01/15] initial streaming assistants support --- .../Assistants/AssistantClient.Protocol.cs | 85 +++++++++ .../src/Custom/Assistants/AssistantClient.cs | 98 +++++++++- .../Assistants/FunctionToolDefinition.cs | 3 +- .../src/Custom/Assistants/FunctionToolInfo.cs | 4 +- .../src/Custom/Assistants/MessageContent.cs | 1 + .../Custom/Assistants/StreamingRunUpdate.cs | 146 +++++++++++++++ .../Models/CreateRunRequest.Serialization.cs | 17 +- .../src/Generated/Models/CreateRunRequest.cs | 12 +- ...CreateThreadAndRunRequest.Serialization.cs | 17 +- .../Models/CreateThreadAndRunRequest.cs | 12 +- .../Models/MessageObject.Serialization.cs | 70 +++++++- .dotnet/src/Generated/Models/MessageObject.cs | 28 ++- ...geObjectIncompleteDetails.Serialization.cs | 120 +++++++++++++ .../Models/MessageObjectIncompleteDetails.cs | 71 ++++++++ .../Generated/Models/MessageObjectStatus.cs | 49 +++++ .dotnet/src/Generated/OpenAIModelFactory.cs | 32 +++- .dotnet/src/OpenAI.csproj | 1 + .dotnet/src/Utility/SseAsyncEnumerator.cs | 18 +- .dotnet/src/Utility/SseEvent.cs | 64 +++++++ .dotnet/src/Utility/SseEventField.cs | 64 +++++++ .dotnet/src/Utility/SseEventFieldType.cs | 11 ++ .dotnet/src/Utility/SseLine.cs | 29 --- .dotnet/src/Utility/SseReader.cs | 167 ++++++++---------- .dotnet/tests/OpenAI.Tests.csproj | 23 --- .dotnet/tests/TestScenarios/AssistantTests.cs | 29 ++- messages/models.tsp | 27 +++ runs/models.tsp | 12 ++ tsp-output/@typespec/openapi3/openapi.yaml | 41 +++++ 28 files changed, 1081 insertions(+), 170 deletions(-) create mode 100644 .dotnet/src/Custom/Assistants/StreamingRunUpdate.cs create mode 100644 .dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs create mode 100644 .dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs create mode 100644 .dotnet/src/Generated/Models/MessageObjectStatus.cs create mode 100644 .dotnet/src/Utility/SseEvent.cs create mode 100644 .dotnet/src/Utility/SseEventField.cs create mode 100644 .dotnet/src/Utility/SseEventFieldType.cs delete mode 100644 .dotnet/src/Utility/SseLine.cs diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs index 1ac71c23b..009a3ccd8 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs @@ -1,6 +1,8 @@ +using System; using System.ClientModel; using System.ClientModel.Primitives; using System.ComponentModel; +using System.Text; using System.Threading.Tasks; namespace OpenAI.Assistants; @@ -318,6 +320,15 @@ public virtual ClientResult CreateRun( RequestOptions options = null) => RunShim.CreateRun(threadId, content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CreateRunStreaming(string threadId, BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task CreateRunAsync( @@ -326,6 +337,15 @@ public virtual async Task CreateRunAsync( RequestOptions options = null) => await RunShim.CreateRunAsync(threadId, content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CreateRunStreamingAsync(string threadId, BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual ClientResult CreateThreadAndRun( @@ -333,6 +353,15 @@ public virtual ClientResult CreateThreadAndRun( RequestOptions options = null) => RunShim.CreateThreadAndRun(content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CreateThreadAndRunStreaming(BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task CreateThreadAndRunAsync( @@ -340,6 +369,15 @@ public virtual async Task CreateThreadAndRunAsync( RequestOptions options = null) => await RunShim.CreateThreadAndRunAsync(content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CreateThreadAndRunStreamingAsync(BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual ClientResult GetRun( @@ -467,4 +505,51 @@ public virtual async Task GetRunStepsAsync( string subsequentStepId, RequestOptions options) => await RunShim.GetRunStepsAsync(threadId, runId, maxResults, createdSortOrder, previousStepId, subsequentStepId, options).ConfigureAwait(false); + + internal PipelineMessage CreateCreateRunRequest(string threadId, BinaryContent content, bool? stream = null, RequestOptions options = null) + { + PipelineMessage message = Shim.Pipeline.CreateMessage(); + message.ResponseClassifier = ResponseErrorClassifier200; + if (stream == true) + { + message.BufferResponse = false; + } + PipelineRequest request = message.Request; + request.Method = "POST"; + UriBuilder uriBuilder = new(_clientConnector.Endpoint.AbsoluteUri); + StringBuilder path = new(); + path.Append("/threads/"); + path.Append(threadId); + path.Append("/runs"); + uriBuilder.Path += path.ToString(); + uriBuilder.Query += $"?thread_id={threadId}"; + request.Uri = uriBuilder.Uri; + request.Headers.Set("Content-Type", "application/json"); + request.Headers.Set("Accept", "text/event-stream"); + request.Content = content; + message.Apply(options ?? new()); + return message; + } + + internal PipelineMessage CreateCreateThreadAndRunRequest(BinaryContent content, bool? stream = null, RequestOptions options = null) + { + PipelineMessage message = Shim.Pipeline.CreateMessage(); + message.ResponseClassifier = ResponseErrorClassifier200; + if (stream == true) + { + message.BufferResponse = false; + } + PipelineRequest request = message.Request; + request.Method = "POST"; + UriBuilder uriBuilder = new(_clientConnector.Endpoint.AbsoluteUri); + StringBuilder path = new(); + path.Append("/threads/runs"); + uriBuilder.Path += path.ToString(); + request.Uri = uriBuilder.Uri; + request.Headers.Set("Content-Type", "application/json"); + request.Headers.Set("Accept", "text/event-stream"); + request.Content = content; + message.Apply(options ?? new()); + return message; + } } diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index e20089177..8703e169b 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -1,8 +1,11 @@ +using OpenAI.Chat; using OpenAI.ClientShared.Internal; +using OpenAI.Internal; using System; using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; +using System.Text; using System.Threading.Tasks; namespace OpenAI.Assistants; @@ -506,6 +509,26 @@ public virtual async Task> CreateRunAsync( return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult CreateRunStreaming( + string threadId, + string assistantId, + RunCreationOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options); + RunShim.Pipeline.Send(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> CreateRunStreamingAsync( + string threadId, + string assistantId, + RunCreationOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options); + await RunShim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + public virtual ClientResult CreateThreadAndRun( string assistantId, ThreadCreationOptions threadOptions = null, @@ -529,6 +552,26 @@ Internal.Models.CreateThreadAndRunRequest request return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult CreateThreadAndRunStreaming( + string assistantId, + ThreadCreationOptions threadOptions = null, + RunCreationOptions runOptions = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + Shim.Pipeline.Send(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> CreateThreadAndRunStreamingAsync( + string assistantId, + ThreadCreationOptions threadOptions = null, + RunCreationOptions runOptions = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + await Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + public virtual ClientResult GetRun(string threadId, string runId) { ClientResult internalResult = RunShim.GetRun(threadId, runId); @@ -630,6 +673,50 @@ public virtual async Task> SubmitToolOutputsAsync(string return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + internal PipelineMessage CreateCreateRunRequest(string threadId, string assistantId, RunCreationOptions runOptions, bool? stream = null) + { + Internal.Models.CreateRunRequest internalCreateRunRequest = CreateInternalCreateRunRequest(assistantId, runOptions, stream); + BinaryContent requestBody = BinaryContent.Create(internalCreateRunRequest); + return CreateCreateRunRequest(threadId, requestBody, stream: true); + } + + internal PipelineMessage CreateCreateThreadAndRunRequest( + string assistantId, + ThreadCreationOptions threadOptions, + RunCreationOptions runOptions, + bool? stream = null) + { + Internal.Models.CreateThreadAndRunRequest internalRequest + = CreateInternalCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + BinaryContent content = BinaryContent.Create(internalRequest); + return CreateCreateThreadAndRunRequest(content, stream: true); + } + + internal static StreamingClientResult CreateStreamingRunResult(PipelineMessage message) + { + if (message.Response.IsError) + { + throw new ClientResultException(message.Response); + } + PipelineResponse response = null; + try + { + response = message.ExtractResponse(); + ClientResult genericResult = ClientResult.FromResponse(response); + StreamingClientResult streamingResult = StreamingClientResult.CreateFromResponse( + genericResult, + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + responseForEnumeration.GetRawResponse().ContentStream, + e => StreamingRunUpdate.DeserializeStreamingRunUpdates(e))); + response = null; + return streamingResult; + } + finally + { + response?.Dispose(); + } + } + internal static Internal.Models.CreateAssistantRequest CreateInternalCreateAssistantRequest( string modelName, AssistantCreationOptions options) @@ -670,7 +757,8 @@ internal static Internal.Models.CreateThreadRequest CreateInternalCreateThreadRe internal static Internal.Models.CreateRunRequest CreateInternalCreateRunRequest( string assistantId, - RunCreationOptions options = null) + RunCreationOptions options = null, + bool? stream = null) { options ??= new(); return new( @@ -680,13 +768,15 @@ internal static Internal.Models.CreateRunRequest CreateInternalCreateRunRequest( options.AdditionalInstructions, ToInternalBinaryDataList(options.OverrideTools), options.Metadata, + stream, serializedAdditionalRawData: null); } internal static Internal.Models.CreateThreadAndRunRequest CreateInternalCreateThreadAndRunRequest( string assistantId, ThreadCreationOptions threadOptions, - RunCreationOptions runOptions) + RunCreationOptions runOptions, + bool? stream = null) { threadOptions ??= new(); runOptions ??= new(); @@ -698,6 +788,7 @@ internal static Internal.Models.CreateThreadAndRunRequest CreateInternalCreateTh runOptions.OverrideInstructions, ToInternalBinaryDataList(runOptions?.OverrideTools), runOptions?.Metadata, + stream, serializedAdditionalRawData: null); } @@ -766,4 +857,7 @@ internal virtual async Task>> GetListQueryPageAsyn ListQueryPage convertedValue = ListQueryPage.Create(internalResult.Value) as ListQueryPage; return ClientResult.FromValue(convertedValue, internalResult.GetRawResponse()); } + + private static PipelineMessageClassifier _responseErrorClassifier200; + private static PipelineMessageClassifier ResponseErrorClassifier200 => _responseErrorClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); } diff --git a/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs b/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs index 57592ac9e..d0d061ce7 100644 --- a/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs +++ b/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs @@ -82,7 +82,8 @@ internal override void WriteDerived(Utf8JsonWriter writer, ModelReaderWriterOpti if (OptionalProperty.IsDefined(Parameters)) { writer.WritePropertyName("parameters"u8); - writer.WriteRawValue(Parameters.ToString()); + using JsonDocument parametersJson = JsonDocument.Parse(Parameters); + parametersJson.WriteTo(writer); } writer.WriteEndObject(); } diff --git a/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs b/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs index a85cd9158..527779480 100644 --- a/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs +++ b/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs @@ -75,7 +75,9 @@ internal override void WriteDerived(Utf8JsonWriter writer, ModelReaderWriterOpti } if (OptionalProperty.IsDefined(Parameters)) { - writer.WriteRawValue(Parameters.ToString()); + writer.WritePropertyName("parameters"u8); + using JsonDocument parametersJson = JsonDocument.Parse(Parameters); + parametersJson.WriteTo(writer); } writer.WriteEndObject(); } diff --git a/.dotnet/src/Custom/Assistants/MessageContent.cs b/.dotnet/src/Custom/Assistants/MessageContent.cs index 928137621..9c41b4e49 100644 --- a/.dotnet/src/Custom/Assistants/MessageContent.cs +++ b/.dotnet/src/Custom/Assistants/MessageContent.cs @@ -3,4 +3,5 @@ namespace OpenAI.Assistants; public abstract partial class MessageContent { + public virtual string GetText() => (this as MessageTextContent).Text; } diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs new file mode 100644 index 000000000..59af9e031 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs @@ -0,0 +1,146 @@ +namespace OpenAI.Assistants; + +using System.Collections.Generic; +using System.Text.Json; + +/// +/// Represents an incremental item of new data in a streaming response when running a thread with an assistant. +/// +public partial class StreamingRunUpdate +{ + // To be implemented: properties/pattern for non-streaming objects (e.g. thread and run creation) + //public string AssistantId { get; } + //public string ThreadId { get; } + //public DateTimeOffset? ThreadCreatedAt { get; } + //public IReadOnlyDictionary ThreadMetadata; + //public string RunId { get; } + //public DateTimeOffset? RunCreatedAt { get; } + //public DateTimeOffset? RunStartedAt { get; } + //public RunStatus? RunStatus { get; } + + public string AssistantId { get; } + public string ThreadId { get; } + public string MessageId { get; } + public string RunId { get; } + public string RunStepId { get; } + public MessageRole? MessageRole { get; } + public MessageContent MessageContentUpdate { get; } + public int? MessageContentIndex { get; } + + internal static List DeserializeStreamingRunUpdates(JsonElement element) + { + List results = []; + if (element.ValueKind == JsonValueKind.Null) + { + return results; + } + string objectLabel = null; + string assistantId = null; + string threadId = null; + string messageId = null; + string runId = null; + string runStepId = null; + MessageRole? role = null; + List<(MessageContent, int?)?> deltaContentItems = [null]; + foreach (JsonProperty property in element.EnumerateObject()) + { + if (property.NameEquals("object"u8)) + { + objectLabel = property.Value.GetString(); + continue; + } + if (property.NameEquals("id"u8)) + { + if (objectLabel?.Contains("run.step") == true) + { + runStepId = property.Value.GetString(); + } + else if (objectLabel?.Contains("run") == true) + { + runId = property.Value.GetString(); + } + else if (objectLabel?.Contains("message") == true) + { + messageId = property.Value.GetString(); + } + else if (objectLabel?.Contains("thread") == true) + { + threadId = property.Value.GetString(); + } + continue; + } + if (property.NameEquals("assistant_id")) + { + assistantId = property.Value.GetString(); + continue; + } + if (property.NameEquals("role")) + { + string roleText = property.Value.GetString(); + if (roleText == "user") + { + role = Assistants.MessageRole.User; + } + else if (roleText == "assistant") + { + role = Assistants.MessageRole.Assistant; + } + continue; + } + if (property.NameEquals("delta"u8)) + { + foreach (JsonProperty messageDeltaProperty in property.Value.EnumerateObject()) + { + if (messageDeltaProperty.NameEquals("content"u8)) + { + deltaContentItems.Clear(); + int? contentIndex = null; + foreach (JsonElement contentArrayElement in messageDeltaProperty.Value.EnumerateArray()) + { + foreach (JsonProperty contentItemProperty in contentArrayElement.EnumerateObject()) + { + if (contentItemProperty.NameEquals("index"u8)) + { + contentIndex = contentItemProperty.Value.GetInt32(); + continue; + } + } + deltaContentItems.Add((MessageContent.DeserializeMessageContent(contentArrayElement), contentIndex)); + } + continue; + } + } + continue; + } + } + foreach ((MessageContent, int?)? deltaContentItem in deltaContentItems) + { + results.Add( + new StreamingRunUpdate( + assistantId, + threadId, + runId, + runStepId, + role, + deltaContentItem)); + } + return results; + } + + internal StreamingRunUpdate( + string assistantId, + string threadId, + string runId, + string runStepId, + MessageRole? role, + (MessageContent, int?)? indexedContentUpdate) + { + AssistantId = assistantId; + ThreadId = threadId; + RunId = runId; + RunStepId = runStepId; + MessageRole = role; + MessageContentUpdate = indexedContentUpdate?.Item1 ?? null; + MessageContentIndex = indexedContentUpdate?.Item2 ?? null; + } +} diff --git a/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs index 3bdb98f68..ef9615bfa 100644 --- a/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs @@ -104,6 +104,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriter writer.WriteNull("metadata"); } } + if (OptionalProperty.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -148,6 +153,7 @@ internal static CreateRunRequest DeserializeCreateRunRequest(JsonElement element OptionalProperty additionalInstructions = default; OptionalProperty> tools = default; OptionalProperty> metadata = default; + OptionalProperty stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -222,13 +228,22 @@ internal static CreateRunRequest DeserializeCreateRunRequest(JsonElement element metadata = dictionary; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); } } serializedAdditionalRawData = additionalPropertiesDictionary; - return new CreateRunRequest(assistantId, model.Value, instructions.Value, additionalInstructions.Value, OptionalProperty.ToList(tools), OptionalProperty.ToDictionary(metadata), serializedAdditionalRawData); + return new CreateRunRequest(assistantId, model.Value, instructions.Value, additionalInstructions.Value, OptionalProperty.ToList(tools), OptionalProperty.ToDictionary(metadata), OptionalProperty.ToNullable(stream), serializedAdditionalRawData); } BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) diff --git a/.dotnet/src/Generated/Models/CreateRunRequest.cs b/.dotnet/src/Generated/Models/CreateRunRequest.cs index e7bf9f448..dd63ba6f5 100644 --- a/.dotnet/src/Generated/Models/CreateRunRequest.cs +++ b/.dotnet/src/Generated/Models/CreateRunRequest.cs @@ -77,8 +77,12 @@ public CreateRunRequest(string assistantId) /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// Keeps track of any properties unknown to the library. - internal CreateRunRequest(string assistantId, string model, string instructions, string additionalInstructions, IList tools, IDictionary metadata, IDictionary serializedAdditionalRawData) + internal CreateRunRequest(string assistantId, string model, string instructions, string additionalInstructions, IList tools, IDictionary metadata, bool? stream, IDictionary serializedAdditionalRawData) { AssistantId = assistantId; Model = model; @@ -86,6 +90,7 @@ internal CreateRunRequest(string assistantId, string model, string instructions, AdditionalInstructions = additionalInstructions; Tools = tools; Metadata = metadata; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -150,5 +155,10 @@ internal CreateRunRequest() /// characters long and values can be a maxium of 512 characters long. /// public IDictionary Metadata { get; set; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs index b3281386f..2c8307f9c 100644 --- a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs @@ -97,6 +97,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelRea writer.WriteNull("metadata"); } } + if (OptionalProperty.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -141,6 +146,7 @@ internal static CreateThreadAndRunRequest DeserializeCreateThreadAndRunRequest(J OptionalProperty instructions = default; OptionalProperty> tools = default; OptionalProperty> metadata = default; + OptionalProperty stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -214,13 +220,22 @@ internal static CreateThreadAndRunRequest DeserializeCreateThreadAndRunRequest(J metadata = dictionary; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); } } serializedAdditionalRawData = additionalPropertiesDictionary; - return new CreateThreadAndRunRequest(assistantId, thread.Value, model.Value, instructions.Value, OptionalProperty.ToList(tools), OptionalProperty.ToDictionary(metadata), serializedAdditionalRawData); + return new CreateThreadAndRunRequest(assistantId, thread.Value, model.Value, instructions.Value, OptionalProperty.ToList(tools), OptionalProperty.ToDictionary(metadata), OptionalProperty.ToNullable(stream), serializedAdditionalRawData); } BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) diff --git a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs index 942ac10f1..853b19fe8 100644 --- a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs +++ b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs @@ -74,8 +74,12 @@ public CreateThreadAndRunRequest(string assistantId) /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// Keeps track of any properties unknown to the library. - internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest thread, string model, string instructions, IList tools, IDictionary metadata, IDictionary serializedAdditionalRawData) + internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest thread, string model, string instructions, IList tools, IDictionary metadata, bool? stream, IDictionary serializedAdditionalRawData) { AssistantId = assistantId; Thread = thread; @@ -83,6 +87,7 @@ internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest threa Instructions = instructions; Tools = tools; Metadata = metadata; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -144,5 +149,10 @@ internal CreateThreadAndRunRequest() /// characters long and values can be a maxium of 512 characters long. /// public IDictionary Metadata { get; set; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs index 27a044c39..1edf85faa 100644 --- a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs +++ b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs @@ -27,6 +27,35 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOpt writer.WriteNumberValue(CreatedAt, "U"); writer.WritePropertyName("thread_id"u8); writer.WriteStringValue(ThreadId); + writer.WritePropertyName("status"u8); + writer.WriteStringValue(Status.ToString()); + if (IncompleteDetails != null) + { + writer.WritePropertyName("incomplete_details"u8); + writer.WriteObjectValue(IncompleteDetails); + } + else + { + writer.WriteNull("incomplete_details"); + } + if (CompletedAt != null) + { + writer.WritePropertyName("completed_at"u8); + writer.WriteStringValue(CompletedAt.Value, "O"); + } + else + { + writer.WriteNull("completed_at"); + } + if (IncompleteAt != null) + { + writer.WritePropertyName("incomplete_at"u8); + writer.WriteStringValue(IncompleteAt.Value, "O"); + } + else + { + writer.WriteNull("incomplete_at"); + } writer.WritePropertyName("role"u8); writer.WriteStringValue(Role.ToString()); writer.WritePropertyName("content"u8); @@ -130,6 +159,10 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode MessageObjectObject @object = default; DateTimeOffset createdAt = default; string threadId = default; + MessageObjectStatus status = default; + MessageObjectIncompleteDetails incompleteDetails = default; + DateTimeOffset? completedAt = default; + DateTimeOffset? incompleteAt = default; MessageObjectRole role = default; IReadOnlyList content = default; string assistantId = default; @@ -160,6 +193,41 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode threadId = property.Value.GetString(); continue; } + if (property.NameEquals("status"u8)) + { + status = new MessageObjectStatus(property.Value.GetString()); + continue; + } + if (property.NameEquals("incomplete_details"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + incompleteDetails = null; + continue; + } + incompleteDetails = MessageObjectIncompleteDetails.DeserializeMessageObjectIncompleteDetails(property.Value); + continue; + } + if (property.NameEquals("completed_at"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + completedAt = null; + continue; + } + completedAt = property.Value.GetDateTimeOffset("O"); + continue; + } + if (property.NameEquals("incomplete_at"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + incompleteAt = null; + continue; + } + incompleteAt = property.Value.GetDateTimeOffset("O"); + continue; + } if (property.NameEquals("role"u8)) { role = new MessageObjectRole(property.Value.GetString()); @@ -233,7 +301,7 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode } } serializedAdditionalRawData = additionalPropertiesDictionary; - return new MessageObject(id, @object, createdAt, threadId, role, content, assistantId, runId, fileIds, metadata, serializedAdditionalRawData); + return new MessageObject(id, @object, createdAt, threadId, status, incompleteDetails, completedAt, incompleteAt, role, content, assistantId, runId, fileIds, metadata, serializedAdditionalRawData); } BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) diff --git a/.dotnet/src/Generated/Models/MessageObject.cs b/.dotnet/src/Generated/Models/MessageObject.cs index f796b78be..f1d857220 100644 --- a/.dotnet/src/Generated/Models/MessageObject.cs +++ b/.dotnet/src/Generated/Models/MessageObject.cs @@ -46,6 +46,10 @@ internal partial class MessageObject /// The identifier, which can be referenced in API endpoints. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -67,7 +71,7 @@ internal partial class MessageObject /// characters long and values can be a maxium of 512 characters long. /// /// , , or is null. - internal MessageObject(string id, DateTimeOffset createdAt, string threadId, MessageObjectRole role, IEnumerable content, string assistantId, string runId, IEnumerable fileIds, IReadOnlyDictionary metadata) + internal MessageObject(string id, DateTimeOffset createdAt, string threadId, MessageObjectStatus status, MessageObjectIncompleteDetails incompleteDetails, DateTimeOffset? completedAt, DateTimeOffset? incompleteAt, MessageObjectRole role, IEnumerable content, string assistantId, string runId, IEnumerable fileIds, IReadOnlyDictionary metadata) { if (id is null) throw new ArgumentNullException(nameof(id)); if (threadId is null) throw new ArgumentNullException(nameof(threadId)); @@ -77,6 +81,10 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes Id = id; CreatedAt = createdAt; ThreadId = threadId; + Status = status; + IncompleteDetails = incompleteDetails; + CompletedAt = completedAt; + IncompleteAt = incompleteAt; Role = role; Content = content.ToList(); AssistantId = assistantId; @@ -90,6 +98,10 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes /// The object type, which is always `thread.message`. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -111,12 +123,16 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes /// characters long and values can be a maxium of 512 characters long. /// /// Keeps track of any properties unknown to the library. - internal MessageObject(string id, MessageObjectObject @object, DateTimeOffset createdAt, string threadId, MessageObjectRole role, IReadOnlyList content, string assistantId, string runId, IReadOnlyList fileIds, IReadOnlyDictionary metadata, IDictionary serializedAdditionalRawData) + internal MessageObject(string id, MessageObjectObject @object, DateTimeOffset createdAt, string threadId, MessageObjectStatus status, MessageObjectIncompleteDetails incompleteDetails, DateTimeOffset? completedAt, DateTimeOffset? incompleteAt, MessageObjectRole role, IReadOnlyList content, string assistantId, string runId, IReadOnlyList fileIds, IReadOnlyDictionary metadata, IDictionary serializedAdditionalRawData) { Id = id; Object = @object; CreatedAt = createdAt; ThreadId = threadId; + Status = status; + IncompleteDetails = incompleteDetails; + CompletedAt = completedAt; + IncompleteAt = incompleteAt; Role = role; Content = content; AssistantId = assistantId; @@ -140,6 +156,14 @@ internal MessageObject() public DateTimeOffset CreatedAt { get; } /// The [thread](/docs/api-reference/threads) ID that this message belongs to. public string ThreadId { get; } + /// The status of the message, which can be either in_progress, incomplete, or completed. + public MessageObjectStatus Status { get; } + /// On an incomplete message, details about why the message is incomplete. + public MessageObjectIncompleteDetails IncompleteDetails { get; } + /// The Unix timestamp at which the message was completed. + public DateTimeOffset? CompletedAt { get; } + /// The Unix timestamp at which the message was marked as incomplete. + public DateTimeOffset? IncompleteAt { get; } /// The entity that produced the message. One of `user` or `assistant`. public MessageObjectRole Role { get; } /// diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs new file mode 100644 index 000000000..02445140a --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs @@ -0,0 +1,120 @@ +// + +using System; +using OpenAI.ClientShared.Internal; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace OpenAI.Internal.Models +{ + internal partial class MessageObjectIncompleteDetails : IJsonModel + { + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{format}' format."); + } + + writer.WriteStartObject(); + writer.WritePropertyName("reason"u8); + writer.WriteStringValue(Reason); + if (options.Format != "W" && _serializedAdditionalRawData != null) + { + foreach (var item in _serializedAdditionalRawData) + { + writer.WritePropertyName(item.Key); +#if NET6_0_OR_GREATER + writer.WriteRawValue(item.Value); +#else + using (JsonDocument document = JsonDocument.Parse(item.Value)) + { + JsonSerializer.Serialize(writer, document.RootElement); + } +#endif + } + } + writer.WriteEndObject(); + } + + MessageObjectIncompleteDetails IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return DeserializeMessageObjectIncompleteDetails(document.RootElement, options); + } + + internal static MessageObjectIncompleteDetails DeserializeMessageObjectIncompleteDetails(JsonElement element, ModelReaderWriterOptions options = null) + { + options ??= new ModelReaderWriterOptions("W"); + + if (element.ValueKind == JsonValueKind.Null) + { + return null; + } + string reason = default; + IDictionary serializedAdditionalRawData = default; + Dictionary additionalPropertiesDictionary = new Dictionary(); + foreach (var property in element.EnumerateObject()) + { + if (property.NameEquals("reason"u8)) + { + reason = property.Value.GetString(); + continue; + } + if (options.Format != "W") + { + additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); + } + } + serializedAdditionalRawData = additionalPropertiesDictionary; + return new MessageObjectIncompleteDetails(reason, serializedAdditionalRawData); + } + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + return ModelReaderWriter.Write(this, options); + default: + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{options.Format}' format."); + } + } + + MessageObjectIncompleteDetails IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return DeserializeMessageObjectIncompleteDetails(document.RootElement, options); + } + default: + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{options.Format}' format."); + } + } + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + /// Deserializes the model from a raw response. + /// The result to deserialize the model from. + internal static MessageObjectIncompleteDetails FromResponse(PipelineResponse response) + { + using var document = JsonDocument.Parse(response.Content); + return DeserializeMessageObjectIncompleteDetails(document.RootElement); + } + } +} diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs new file mode 100644 index 000000000..f07321bd9 --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs @@ -0,0 +1,71 @@ +// + +using System; +using OpenAI.ClientShared.Internal; +using System.Collections.Generic; + +namespace OpenAI.Internal.Models +{ + /// The MessageObjectIncompleteDetails. + internal partial class MessageObjectIncompleteDetails + { + /// + /// Keeps track of any properties unknown to the library. + /// + /// To assign an object to the value of this property use . + /// + /// + /// To assign an already formatted json string to this property use . + /// + /// + /// Examples: + /// + /// + /// BinaryData.FromObjectAsJson("foo") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromString("\"foo\"") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromObjectAsJson(new { key = "value" }) + /// Creates a payload of { "key": "value" }. + /// + /// + /// BinaryData.FromString("{\"key\": \"value\"}") + /// Creates a payload of { "key": "value" }. + /// + /// + /// + /// + private IDictionary _serializedAdditionalRawData; + + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// is null. + internal MessageObjectIncompleteDetails(string reason) + { + if (reason is null) throw new ArgumentNullException(nameof(reason)); + + Reason = reason; + } + + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// Keeps track of any properties unknown to the library. + internal MessageObjectIncompleteDetails(string reason, IDictionary serializedAdditionalRawData) + { + Reason = reason; + _serializedAdditionalRawData = serializedAdditionalRawData; + } + + /// Initializes a new instance of for deserialization. + internal MessageObjectIncompleteDetails() + { + } + + /// The reason the message is incomplete. + public string Reason { get; } + } +} diff --git a/.dotnet/src/Generated/Models/MessageObjectStatus.cs b/.dotnet/src/Generated/Models/MessageObjectStatus.cs new file mode 100644 index 000000000..b1ca494f3 --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectStatus.cs @@ -0,0 +1,49 @@ +// + +using System; +using System.ComponentModel; + +namespace OpenAI.Internal.Models +{ + /// Enum for status in MessageObject. + internal readonly partial struct MessageObjectStatus : IEquatable + { + private readonly string _value; + + /// Initializes a new instance of . + /// is null. + public MessageObjectStatus(string value) + { + _value = value ?? throw new ArgumentNullException(nameof(value)); + } + + private const string InProgressValue = "in_progress"; + private const string IncompleteValue = "incomplete"; + private const string CompletedValue = "completed"; + + /// in_progress. + public static MessageObjectStatus InProgress { get; } = new MessageObjectStatus(InProgressValue); + /// incomplete. + public static MessageObjectStatus Incomplete { get; } = new MessageObjectStatus(IncompleteValue); + /// completed. + public static MessageObjectStatus Completed { get; } = new MessageObjectStatus(CompletedValue); + /// Determines if two values are the same. + public static bool operator ==(MessageObjectStatus left, MessageObjectStatus right) => left.Equals(right); + /// Determines if two values are not the same. + public static bool operator !=(MessageObjectStatus left, MessageObjectStatus right) => !left.Equals(right); + /// Converts a string to a . + public static implicit operator MessageObjectStatus(string value) => new MessageObjectStatus(value); + + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public override bool Equals(object obj) => obj is MessageObjectStatus other && Equals(other); + /// + public bool Equals(MessageObjectStatus other) => string.Equals(_value, other._value, StringComparison.InvariantCultureIgnoreCase); + + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public override int GetHashCode() => _value?.GetHashCode() ?? 0; + /// + public override string ToString() => _value; + } +} diff --git a/.dotnet/src/Generated/OpenAIModelFactory.cs b/.dotnet/src/Generated/OpenAIModelFactory.cs index cc63a0fba..917163315 100644 --- a/.dotnet/src/Generated/OpenAIModelFactory.cs +++ b/.dotnet/src/Generated/OpenAIModelFactory.cs @@ -363,6 +363,10 @@ public static CreateMessageRequest CreateMessageRequest(CreateMessageRequestRole /// The object type, which is always `thread.message`. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -384,13 +388,21 @@ public static CreateMessageRequest CreateMessageRequest(CreateMessageRequestRole /// characters long and values can be a maxium of 512 characters long. /// /// A new instance for mocking. - public static MessageObject MessageObject(string id = null, MessageObjectObject @object = default, DateTimeOffset createdAt = default, string threadId = null, MessageObjectRole role = default, IEnumerable content = null, string assistantId = null, string runId = null, IEnumerable fileIds = null, IReadOnlyDictionary metadata = null) + public static MessageObject MessageObject(string id = null, MessageObjectObject @object = default, DateTimeOffset createdAt = default, string threadId = null, MessageObjectStatus status = default, MessageObjectIncompleteDetails incompleteDetails = null, DateTimeOffset? completedAt = null, DateTimeOffset? incompleteAt = null, MessageObjectRole role = default, IEnumerable content = null, string assistantId = null, string runId = null, IEnumerable fileIds = null, IReadOnlyDictionary metadata = null) { content ??= new List(); fileIds ??= new List(); metadata ??= new Dictionary(); - return new MessageObject(id, @object, createdAt, threadId, role, content?.ToList(), assistantId, runId, fileIds?.ToList(), metadata, serializedAdditionalRawData: null); + return new MessageObject(id, @object, createdAt, threadId, status, incompleteDetails, completedAt, incompleteAt, role, content?.ToList(), assistantId, runId, fileIds?.ToList(), metadata, serializedAdditionalRawData: null); + } + + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// A new instance for mocking. + public static MessageObjectIncompleteDetails MessageObjectIncompleteDetails(string reason = null) + { + return new MessageObjectIncompleteDetails(reason, serializedAdditionalRawData: null); } /// Initializes a new instance of . @@ -485,13 +497,17 @@ public static DeleteModelResponse DeleteModelResponse(string id = null, bool del /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// A new instance for mocking. - public static CreateThreadAndRunRequest CreateThreadAndRunRequest(string assistantId = null, CreateThreadRequest thread = null, string model = null, string instructions = null, IEnumerable tools = null, IDictionary metadata = null) + public static CreateThreadAndRunRequest CreateThreadAndRunRequest(string assistantId = null, CreateThreadRequest thread = null, string model = null, string instructions = null, IEnumerable tools = null, IDictionary metadata = null, bool? stream = null) { tools ??= new List(); metadata ??= new Dictionary(); - return new CreateThreadAndRunRequest(assistantId, thread, model, instructions, tools?.ToList(), metadata, serializedAdditionalRawData: null); + return new CreateThreadAndRunRequest(assistantId, thread, model, instructions, tools?.ToList(), metadata, stream, serializedAdditionalRawData: null); } /// Initializes a new instance of . @@ -624,13 +640,17 @@ public static RunCompletionUsage RunCompletionUsage(long completionTokens = defa /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// A new instance for mocking. - public static CreateRunRequest CreateRunRequest(string assistantId = null, string model = null, string instructions = null, string additionalInstructions = null, IEnumerable tools = null, IDictionary metadata = null) + public static CreateRunRequest CreateRunRequest(string assistantId = null, string model = null, string instructions = null, string additionalInstructions = null, IEnumerable tools = null, IDictionary metadata = null, bool? stream = null) { tools ??= new List(); metadata ??= new Dictionary(); - return new CreateRunRequest(assistantId, model, instructions, additionalInstructions, tools?.ToList(), metadata, serializedAdditionalRawData: null); + return new CreateRunRequest(assistantId, model, instructions, additionalInstructions, tools?.ToList(), metadata, stream, serializedAdditionalRawData: null); } /// Initializes a new instance of . diff --git a/.dotnet/src/OpenAI.csproj b/.dotnet/src/OpenAI.csproj index 39523f012..ed91914a5 100644 --- a/.dotnet/src/OpenAI.csproj +++ b/.dotnet/src/OpenAI.csproj @@ -2,6 +2,7 @@ This is the OpenAI client library for developing .NET applications with rich experience. SDK Code Generation OpenAI + 1.0.0-beta.1 OpenAI netstandard2.0 latest diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index 743a1bedd..96e0b9995 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -19,20 +19,18 @@ internal static async IAsyncEnumerable EnumerateFromSseStream( using SseReader sseReader = new(stream); while (!cancellationToken.IsCancellationRequested) { - SseLine? sseEvent = await sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent is not null) + SseEvent? sseEvent = await sseReader.TryGetNextEventAsync(cancellationToken).ConfigureAwait(false); + if (sseEvent is null) { - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - { - throw new InvalidDataException(); - } - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) + break; + } + else + { + if (sseEvent.Value.Data.Span.SequenceEqual("[DONE]".AsSpan())) { break; } - using JsonDocument sseMessageJson = JsonDocument.Parse(value); + using JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.Data); IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); foreach (T item in newItems) { diff --git a/.dotnet/src/Utility/SseEvent.cs b/.dotnet/src/Utility/SseEvent.cs new file mode 100644 index 000000000..96552f392 --- /dev/null +++ b/.dotnet/src/Utility/SseEvent.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct SseEvent +{ + public ReadOnlyMemory EventType { get; } + public ReadOnlyMemory Data { get; } + public ReadOnlyMemory LastEventId { get; } + public TimeSpan? ReconnectionTime { get; } + + private readonly IReadOnlyList _fields; + private readonly string _multiLineData; + + internal SseEvent(IReadOnlyList fields) + { + _fields = fields; + StringBuilder multiLineDataBuilder = null; + for (int i = 0; i < _fields.Count; i++) + { + ReadOnlyMemory fieldValue = _fields[i].Value; + switch (_fields[i].FieldType) + { + case SseEventFieldType.Event: + EventType = fieldValue; + break; + case SseEventFieldType.Data: + { + if (multiLineDataBuilder != null) + { + multiLineDataBuilder.Append(fieldValue); + } + else if (Data.IsEmpty) + { + Data = fieldValue; + } + else + { + multiLineDataBuilder ??= new(); + multiLineDataBuilder.Append(fieldValue); + Data = null; + } + break; + } + case SseEventFieldType.Id: + LastEventId = fieldValue; + break; + case SseEventFieldType.Retry: + ReconnectionTime = Int32.TryParse(fieldValue.ToString(), out int retry) ? TimeSpan.FromMilliseconds(retry) : null; + break; + default: + break; + } + if (multiLineDataBuilder != null) + { + _multiLineData = multiLineDataBuilder.ToString(); + Data = _multiLineData.AsMemory(); + } + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseEventField.cs b/.dotnet/src/Utility/SseEventField.cs new file mode 100644 index 000000000..0293a0a87 --- /dev/null +++ b/.dotnet/src/Utility/SseEventField.cs @@ -0,0 +1,64 @@ +using System; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct SseEventField +{ + public SseEventFieldType FieldType { get; } + + // TODO: we should not expose UTF16 publicly + public ReadOnlyMemory Value + { + get + { + if (_valueStartIndex >= _original.Length) + { + return ReadOnlyMemory.Empty; + } + else + { + return _original.AsMemory(_valueStartIndex); + } + } + } + + private readonly string _original; + private readonly int _valueStartIndex; + + internal SseEventField(string line) + { + _original = line; + int colonIndex = _original.AsSpan().IndexOf(':'); + + ReadOnlyMemory fieldName = colonIndex < 0 ? _original.AsMemory(): _original.AsMemory(0, colonIndex); + FieldType = fieldName.Span switch + { + var x when x.SequenceEqual(s_eventFieldName.Span) => SseEventFieldType.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => SseEventFieldType.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => SseEventFieldType.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => SseEventFieldType.Retry, + _ => SseEventFieldType.Ignored, + }; + + if (colonIndex < 0) + { + _valueStartIndex = _original.Length; + } + else if (colonIndex + 1 < _original.Length && _original[colonIndex + 1] == ' ') + { + _valueStartIndex = colonIndex + 2; + } + else + { + _valueStartIndex = colonIndex + 1; + } + } + + public override string ToString() => _original; + + private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); + private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); + private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); + private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseEventFieldType.cs b/.dotnet/src/Utility/SseEventFieldType.cs new file mode 100644 index 000000000..d10bd404a --- /dev/null +++ b/.dotnet/src/Utility/SseEventFieldType.cs @@ -0,0 +1,11 @@ +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal enum SseEventFieldType +{ + Event, + Data, + Id, + Retry, + Ignored +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseLine.cs b/.dotnet/src/Utility/SseLine.cs deleted file mode 100644 index 4d82315f9..000000000 --- a/.dotnet/src/Utility/SseLine.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; - -namespace OpenAI; - -// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseLine -{ - private readonly string _original; - private readonly int _colonIndex; - private readonly int _valueIndex; - - public static SseLine Empty { get; } = new SseLine(string.Empty, 0, false); - - internal SseLine(string original, int colonIndex, bool hasSpaceAfterColon) - { - _original = original; - _colonIndex = colonIndex; - _valueIndex = colonIndex + (hasSpaceAfterColon ? 2 : 1); - } - - public bool IsEmpty => _original.Length == 0; - public bool IsComment => !IsEmpty && _original[0] == ':'; - - // TODO: we should not expose UTF16 publicly - public ReadOnlyMemory FieldName => _original.AsMemory(0, _colonIndex); - public ReadOnlyMemory FieldValue => _original.AsMemory(_valueIndex); - - public override string ToString() => _original; -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index cf0301408..2f5facda8 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,118 +1,105 @@ using System; -using System.ClientModel; -using System.ClientModel.Internal; +using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace OpenAI; internal sealed class SseReader : IDisposable +{ + private readonly Stream _stream; + private readonly StreamReader _reader; + private bool _disposedValue; + + public SseReader(Stream stream) { - private readonly Stream _stream; - private readonly StreamReader _reader; - private bool _disposedValue; + _stream = stream; + _reader = new StreamReader(stream); + } - public SseReader(Stream stream) - { - _stream = stream; - _reader = new StreamReader(stream); - } + public SseEvent? TryGetNextEvent() + { + List fields = []; - public SseLine? TryReadSingleFieldEvent() + while (true) { - while (true) + string line = _reader.ReadLine(); + if (line == null) { - SseLine? line = TryReadLine(); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = TryReadLine(); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // A null line indicates end of input + return null; } - } - - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadSingleFieldEventAsync() - { - while (true) + else if (line.Length == 0) { - SseLine? line = await TryReadLineAsync().ConfigureAwait(false); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = await TryReadLineAsync().ConfigureAwait(false); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // An empty line should dispatch an event for pending accumulated fields + SseEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new SseEventField(line)); } } + } - public SseLine? TryReadLine() - { - string lineText = _reader.ReadLine(); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + List fields = []; - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadLineAsync() + while (!cancellationToken.IsCancellationRequested) { - string lineText = await _reader.ReadLineAsync().ConfigureAwait(false); - if (lineText == null) + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + // A null line indicates end of input return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } - - private static bool TryParseLine(string lineText, out SseLine line) - { - if (lineText.Length == 0) + } + else if (line.Length == 0) { - line = default; - return false; + // An empty line should dispatch an event for pending accumulated fields + SseEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new SseEventField(line)); } - - ReadOnlySpan lineSpan = lineText.AsSpan(); - int colonIndex = lineSpan.IndexOf(':'); - ReadOnlySpan fieldValue = lineSpan.Slice(colonIndex + 1); - - bool hasSpace = false; - if (fieldValue.Length > 0 && fieldValue[0] == ' ') - hasSpace = true; - line = new SseLine(lineText, colonIndex, hasSpace); - return true; } + return null; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } - private void Dispose(bool disposing) + private void Dispose(bool disposing) + { + if (!_disposedValue) { - if (!_disposedValue) + if (disposing) { - if (disposing) - { - _reader.Dispose(); - _stream.Dispose(); - } - - _disposedValue = true; + _reader.Dispose(); + _stream.Dispose(); } + + _disposedValue = true; } - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } \ No newline at end of file + } +} \ No newline at end of file diff --git a/.dotnet/tests/OpenAI.Tests.csproj b/.dotnet/tests/OpenAI.Tests.csproj index 9bb493c12..aaf8385b2 100644 --- a/.dotnet/tests/OpenAI.Tests.csproj +++ b/.dotnet/tests/OpenAI.Tests.csproj @@ -16,27 +16,4 @@ - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - \ No newline at end of file diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index aea1c3614..031b93b0e 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -126,6 +126,32 @@ public async Task BasicFunctionToolWorks() Assert.That(runResult.Value.Status, Is.Not.EqualTo(RunStatus.RequiresAction)); } + [Test] + public async Task StreamingRunWorks() + { + AssistantClient client = GetTestClient(); + Assistant assistant = await CreateCommonTestAssistantAsync(); + + StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( + assistant.Id, + new ThreadCreationOptions() + { + Messages = + { + "Hello, assistant! Can you help me?", + } + }); + Assert.That(runUpdateResult, Is.Not.Null); + await foreach (StreamingRunUpdate runUpdate in runUpdateResult) + { + if (runUpdate.MessageRole.HasValue) + { + Console.Write($"[{runUpdate.MessageRole}]"); + } + Console.Write(runUpdate.MessageContentUpdate?.GetText()); + } + } + private async Task CreateCommonTestAssistantAsync() { AssistantClient client = new(); @@ -141,7 +167,8 @@ private async Task CreateCommonTestAssistantAsync() return newAssistantResult.Value; } - private async Task DeleteRecentTestThings() + [TearDown] + protected async Task DeleteRecentTestThings() { AssistantClient client = new(); foreach(Assistant assistant in client.GetAssistants().Value) diff --git a/messages/models.tsp b/messages/models.tsp index 67d78db29..16c059d4b 100644 --- a/messages/models.tsp +++ b/messages/models.tsp @@ -72,6 +72,33 @@ model MessageObject { /** The [thread](/docs/api-reference/threads) ID that this message belongs to. */ thread_id: string; + /** + * The status of the message, which can be either in_progress, incomplete, or completed. + */ + status: "in_progress" | "incomplete" | "completed"; + + /** + * On an incomplete message, details about why the message is incomplete. + */ + incomplete_details: { + /** + * The reason the message is incomplete. + */ + reason: string; + } | null; + + /** + * The Unix timestamp at which the message was completed. + */ + @encode("unixTimestamp", int32) + completed_at: utcDateTime | null; + + /** + * The Unix timestamp at which the message was marked as incomplete. + */ + @encode("unixTimestamp", int32) + incomplete_at: utcDateTime | null; + /** The entity that produced the message. One of `user` or `assistant`. */ role: "user" | "assistant"; diff --git a/runs/models.tsp b/runs/models.tsp index e29f762e0..7619a731d 100644 --- a/runs/models.tsp +++ b/runs/models.tsp @@ -42,6 +42,12 @@ model CreateRunRequest { */ @extension("x-oaiTypeLabel", "map") metadata?: Record | null; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state with a data: [DONE] message. + */ + stream?: boolean; } model ModifyRunRequest { @@ -87,6 +93,12 @@ model CreateThreadAndRunRequest { */ @extension("x-oaiTypeLabel", "map") metadata?: Record | null; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state with a data: [DONE] message. + */ + stream?: boolean; } model SubmitToolOutputsRunRequest { diff --git a/tsp-output/@typespec/openapi3/openapi.yaml b/tsp-output/@typespec/openapi3/openapi.yaml index 394ac492c..e19b6143a 100644 --- a/tsp-output/@typespec/openapi3/openapi.yaml +++ b/tsp-output/@typespec/openapi3/openapi.yaml @@ -3690,6 +3690,11 @@ components: additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. x-oaiTypeLabel: map + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state with a data: [DONE] message. CreateRunRequestTool: oneOf: - $ref: '#/components/schemas/AssistantToolsCode' @@ -3793,6 +3798,11 @@ components: additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. x-oaiTypeLabel: map + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state with a data: [DONE] message. CreateThreadRequest: type: object properties: @@ -4704,6 +4714,10 @@ components: - object - created_at - thread_id + - status + - incomplete_details + - completed_at + - incomplete_at - role - content - assistant_id @@ -4726,6 +4740,33 @@ components: thread_id: type: string description: The [thread](/docs/api-reference/threads) ID that this message belongs to. + status: + type: string + enum: + - in_progress + - incomplete + - completed + description: The status of the message, which can be either in_progress, incomplete, or completed. + incomplete_details: + type: object + properties: + reason: + type: string + description: The reason the message is incomplete. + required: + - reason + nullable: true + description: On an incomplete message, details about why the message is incomplete. + completed_at: + type: string + format: date-time + nullable: true + description: The Unix timestamp at which the message was completed. + incomplete_at: + type: string + format: date-time + nullable: true + description: The Unix timestamp at which the message was marked as incomplete. role: type: string enum: From 8c519269554f440e8e1e50dc5da772156991850a Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 18 Mar 2024 10:36:40 -0700 Subject: [PATCH 02/15] merge main --- .../MessageObjectIncompleteDetails.Serialization.cs | 10 ++++++++++ .../Generated/Models/MessageObjectIncompleteDetails.cs | 4 ++-- .dotnet/tests/OpenAI.Tests.csproj | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs index 02445140a..d7683486e 100644 --- a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs @@ -2,9 +2,11 @@ using System; using OpenAI.ClientShared.Internal; +using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; using System.Text.Json; +using OpenAI; namespace OpenAI.Internal.Models { @@ -116,5 +118,13 @@ internal static MessageObjectIncompleteDetails FromResponse(PipelineResponse res using var document = JsonDocument.Parse(response.Content); return DeserializeMessageObjectIncompleteDetails(document.RootElement); } + + /// Convert into a Utf8JsonRequestBody. + internal virtual BinaryContent ToRequestBody() + { + var content = new Utf8JsonRequestBody(); + content.JsonWriter.WriteObjectValue(this); + return content; + } } } diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs index f07321bd9..4033843d0 100644 --- a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs @@ -1,8 +1,8 @@ // using System; -using OpenAI.ClientShared.Internal; using System.Collections.Generic; +using OpenAI; namespace OpenAI.Internal.Models { @@ -46,7 +46,7 @@ internal partial class MessageObjectIncompleteDetails /// is null. internal MessageObjectIncompleteDetails(string reason) { - if (reason is null) throw new ArgumentNullException(nameof(reason)); + Argument.AssertNotNull(reason, nameof(reason)); Reason = reason; } diff --git a/.dotnet/tests/OpenAI.Tests.csproj b/.dotnet/tests/OpenAI.Tests.csproj index 61cb32339..aaf8385b2 100644 --- a/.dotnet/tests/OpenAI.Tests.csproj +++ b/.dotnet/tests/OpenAI.Tests.csproj @@ -1,6 +1,6 @@ - net8.0 + net7.0 $(NoWarn);CS1591 latest From e4916bb38c9199c46ac45100df5945accf3dfb94 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 18 Mar 2024 16:02:06 -0700 Subject: [PATCH 03/15] flush updates to use polymorphic types --- .dotnet/scripts/Add-Customizations.ps1 | 12 +- .../src/Custom/Assistants/AssistantClient.cs | 4 +- .../Assistants/MessageRole.Serialization.cs | 29 ++++ ...treamingMessageCompletion.Serialization.cs | 13 ++ .../Assistants/StreamingMessageCompletion.cs | 15 ++ .../StreamingMessageCreation.Serialization.cs | 84 +++++++++++ .../Assistants/StreamingMessageCreation.cs | 37 +++++ .../StreamingMessageUpdate.Serialization.cs | 63 ++++++++ .../Assistants/StreamingMessageUpdate.cs | 18 +++ .../StreamingRunUpdate.Serialization.cs | 37 +++++ .../Custom/Assistants/StreamingRunUpdate.cs | 136 +----------------- .dotnet/src/Custom/Chat/ChatClient.cs | 8 +- .../src/Custom/Chat/StreamingChatUpdate.cs | 6 +- .../Models/MessageObject.Serialization.cs | 8 +- .../Models/RunObject.Serialization.cs | 20 +++ .dotnet/src/OpenAI.csproj | 1 - .dotnet/src/Utility/SseAsyncEnumerator.cs | 32 ++--- .dotnet/tests/OpenAI.Tests.csproj | 3 + .dotnet/tests/TestScenarios/AssistantTests.cs | 14 +- 19 files changed, 373 insertions(+), 167 deletions(-) create mode 100644 .dotnet/src/Custom/Assistants/MessageRole.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageCreation.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs diff --git a/.dotnet/scripts/Add-Customizations.ps1 b/.dotnet/scripts/Add-Customizations.ps1 index 30e795553..8f81c806d 100644 --- a/.dotnet/scripts/Add-Customizations.ps1 +++ b/.dotnet/scripts/Add-Customizations.ps1 @@ -34,11 +34,15 @@ function Set-LangVersionToLatest { $xml.Save($filePath) } -function Edit-RunObjectSerialization { +function Edit-DateTimeOffsetSerialization { + param( + [string]$filename + ) + $root = Split-Path $PSScriptRoot -Parent $directory = Join-Path -Path $root -ChildPath "src\Generated\Models" - $file = Get-ChildItem -Path $directory -Filter "RunObject.Serialization.cs" + $file = Get-ChildItem -Path $directory -Filter $filename $content = Get-Content -Path $file -Raw Write-Output "Editing $($file.FullName)" @@ -48,6 +52,7 @@ function Edit-RunObjectSerialization { $content = $content -creplace "cancelledAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // cancelledAt = property.Value.GetDateTimeOffset(`"O`");`r`n cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content = $content -creplace "failedAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // failedAt = property.Value.GetDateTimeOffset(`"O`");`r`n failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content = $content -creplace "completedAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // completedAt = property.Value.GetDateTimeOffset(`"O`");`r`n completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" + $content = $content -creplace "incompleteAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // completedAt = property.Value.GetDateTimeOffset(`"O`");`r`n completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content | Set-Content -Path $file.FullName -NoNewline } @@ -55,4 +60,5 @@ function Edit-RunObjectSerialization { Update-SystemTextJsonPackage Update-MicrosoftBclAsyncInterfacesPackage Set-LangVersionToLatest -Edit-RunObjectSerialization \ No newline at end of file +Edit-DateTimeOffsetSerialization -filename "RunObject.Serialization.cs" +Edit-DateTimeOffsetSerialization -filename "MessageObject.Serialization.cs" diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index 4ca68aac8..b1a575680 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -705,9 +705,9 @@ internal static StreamingClientResult CreateStreamingRunResu ClientResult genericResult = ClientResult.FromResponse(response); StreamingClientResult streamingResult = StreamingClientResult.CreateFromResponse( genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingRunUpdate.DeserializeStreamingRunUpdates(e))); + StreamingRunUpdate.DeserializeSseRunUpdates)); response = null; return streamingResult; } diff --git a/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs b/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs new file mode 100644 index 000000000..eb1a01972 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs @@ -0,0 +1,29 @@ +using System; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace OpenAI.Assistants; + +internal class MessageRoleSerialization +{ + public static MessageRole DeserializeMessageRole(JsonElement jsonElement, ModelReaderWriterOptions options = default) + { + if (jsonElement.ValueKind != JsonValueKind.String) + { + throw new ArgumentException(nameof(jsonElement)); + } + string roleName = jsonElement.GetString(); + if (roleName == "assistant") + { + return MessageRole.Assistant; + } + else if (roleName == "user") + { + return MessageRole.User; + } + else + { + throw new NotImplementedException(roleName); + } + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs new file mode 100644 index 000000000..eab1a3f01 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs @@ -0,0 +1,13 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Text.Json; + +public partial class StreamingMessageCompletion : StreamingRunUpdate +{ + internal static StreamingMessageCompletion DeserializeStreamingMessageCompletion(JsonElement element, ModelReaderWriterOptions options = default) + { + return new StreamingMessageCompletion( + new ThreadMessage(Internal.Models.MessageObject.DeserializeMessageObject(element, options))); + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs new file mode 100644 index 000000000..daede59df --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs @@ -0,0 +1,15 @@ +namespace OpenAI.Assistants; + +using System; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageCompletion : StreamingRunUpdate +{ + public ThreadMessage Message { get; } + + internal StreamingMessageCompletion(ThreadMessage message) + { + Message = message; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs new file mode 100644 index 000000000..7f45b7339 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs @@ -0,0 +1,84 @@ +namespace OpenAI.Assistants; + +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageCreation +{ + internal static StreamingRunUpdate DeserializeSseMessageCreation( + JsonElement sseDataJson, + ModelReaderWriterOptions options = null) + { + string id = null; + string assistantId = null; + string threadId = null; + string runId = null; + DateTimeOffset? createdAt = null; + MessageRole? role = null; + List contentItems = []; + List fileIds = []; + + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("assistant_id"u8)) + { + assistantId = property.Value.GetString(); + continue; + } + if (property.NameEquals("thread_id"u8)) + { + threadId = property.Value.GetString(); + continue; + } + if (property.NameEquals("run_id"u8)) + { + runId = property.Value.GetString(); + continue; + } + if (property.NameEquals("created_at"u8)) + { + createdAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("role"u8)) + { + role = MessageRoleSerialization.DeserializeMessageRole(property.Value); + continue; + } + if (property.NameEquals("content"u8)) + { + foreach (JsonElement contentElement in property.Value.EnumerateArray()) + { + contentItems.Add(MessageContent.DeserializeMessageContent(contentElement)); + } + continue; + } + if (property.NameEquals("file_ids"u8)) + { + foreach (JsonElement fileIdElement in property.Value.EnumerateArray()) + { + fileIds.Add(fileIdElement.GetString()); + } + continue; + } + } + + return new StreamingMessageCreation( + id, + assistantId, + threadId, + runId, + createdAt.Value, + role.Value, + contentItems, + fileIds); + + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs new file mode 100644 index 000000000..71ae8c651 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs @@ -0,0 +1,37 @@ +namespace OpenAI.Assistants; + +using System; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageCreation : StreamingRunUpdate +{ + public string Id { get; } + public string AssistantId { get; } + public string ThreadId { get; } + public string RunId { get; } + public DateTimeOffset CreatedAt { get; } + public MessageRole Role { get; } + public IReadOnlyList ContentItems { get; } + public IReadOnlyList FileIds { get; } + + internal StreamingMessageCreation( + string id, + string assistantId, + string threadId, + string runId, + DateTimeOffset createdAt, + MessageRole role, + IReadOnlyList contentItems, + IReadOnlyList fileIds) + { + Id = id; + AssistantId = assistantId; + ThreadId = threadId; + RunId = runId; + CreatedAt = createdAt; + Role = role; + ContentItems = contentItems; + FileIds = fileIds; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs new file mode 100644 index 000000000..f59e22f07 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs @@ -0,0 +1,63 @@ +namespace OpenAI.Assistants; + +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageUpdate +{ + internal static IEnumerable DeserializeSseMessageUpdates( + JsonElement sseDataJson, + ModelReaderWriterOptions options = default) + { + List results = []; + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return results; + } + string id = null; + List<(int Index, MessageContent Content)> indexedContentItems = []; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("delta"u8)) + { + foreach (JsonProperty messageDeltaProperty in property.Value.EnumerateObject()) + { + if (messageDeltaProperty.NameEquals("content"u8)) + { + foreach (JsonElement contentItemElement in messageDeltaProperty.Value.EnumerateArray()) + { + MessageContent contentItem = MessageContent.DeserializeMessageContent(contentItemElement); + foreach (JsonProperty contentItemProperty in contentItemElement.EnumerateObject()) + { + if (contentItemProperty.NameEquals("index"u8)) + { + indexedContentItems.Add((contentItemProperty.Value.GetInt32(), contentItem)); + continue; + } + } + } + continue; + } + } + continue; + } + } + foreach((int index, MessageContent contentItem) in indexedContentItems) + { + results.Add(new StreamingMessageUpdate(contentItem, index)); + } + return results; + } + + internal StreamingMessageUpdate() : base() + { + + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs new file mode 100644 index 000000000..0ffe94bb5 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs @@ -0,0 +1,18 @@ +namespace OpenAI.Assistants; + +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageUpdate : StreamingRunUpdate +{ + public MessageContent ContentUpdate { get; } + public int? ContentUpdateIndex { get; } + + internal StreamingMessageUpdate( + MessageContent contentUpdate, + int? contentIndex) + { + ContentUpdate = contentUpdate; + ContentUpdateIndex = contentIndex; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs new file mode 100644 index 000000000..acedb5984 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs @@ -0,0 +1,37 @@ +namespace OpenAI.Assistants; + +using System; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingRunUpdate +{ + internal static IEnumerable DeserializeSseRunUpdates( + ReadOnlyMemory sseEventName, + JsonElement sseDataJson) + { + IEnumerable results = []; + if (sseEventName.Span.SequenceEqual(s_threadMessageCreationEventName.Span)) + { + results = [StreamingMessageCreation.DeserializeSseMessageCreation(sseDataJson)]; + } + else if (sseEventName.Span.SequenceEqual(s_threadMessageDeltaEventName.Span)) + { + results = StreamingMessageUpdate.DeserializeSseMessageUpdates(sseDataJson); + } + else if (sseEventName.Span.SequenceEqual(s_threadMessageCompletionEventName.Span)) + { + results = [StreamingMessageCompletion.DeserializeStreamingMessageCompletion(sseDataJson)]; + } + + foreach (StreamingRunUpdate baseUpdate in results) + { + baseUpdate._originalJson = sseDataJson.Clone(); + } + return results; + } + + private static readonly ReadOnlyMemory s_threadMessageCreationEventName = "thread.message.created".AsMemory(); + private static readonly ReadOnlyMemory s_threadMessageDeltaEventName = "thread.message.delta".AsMemory(); + private static readonly ReadOnlyMemory s_threadMessageCompletionEventName = "thread.message.completed".AsMemory(); +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs index 59af9e031..778d23f95 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs +++ b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs @@ -1,5 +1,6 @@ namespace OpenAI.Assistants; +using System; using System.Collections.Generic; using System.Text.Json; @@ -8,139 +9,10 @@ namespace OpenAI.Assistants; /// public partial class StreamingRunUpdate { - // To be implemented: properties/pattern for non-streaming objects (e.g. thread and run creation) - //public string AssistantId { get; } - //public string ThreadId { get; } - //public DateTimeOffset? ThreadCreatedAt { get; } - //public IReadOnlyDictionary ThreadMetadata; - //public string RunId { get; } - //public DateTimeOffset? RunCreatedAt { get; } - //public DateTimeOffset? RunStartedAt { get; } - //public RunStatus? RunStatus { get; } + private JsonElement _originalJson; + public JsonElement GetRawSseEvent() => _originalJson; - public string AssistantId { get; } - public string ThreadId { get; } - public string MessageId { get; } - public string RunId { get; } - public string RunStepId { get; } - public MessageRole? MessageRole { get; } - public MessageContent MessageContentUpdate { get; } - public int? MessageContentIndex { get; } - - internal static List DeserializeStreamingRunUpdates(JsonElement element) - { - List results = []; - if (element.ValueKind == JsonValueKind.Null) - { - return results; - } - string objectLabel = null; - string assistantId = null; - string threadId = null; - string messageId = null; - string runId = null; - string runStepId = null; - MessageRole? role = null; - List<(MessageContent, int?)?> deltaContentItems = [null]; - foreach (JsonProperty property in element.EnumerateObject()) - { - if (property.NameEquals("object"u8)) - { - objectLabel = property.Value.GetString(); - continue; - } - if (property.NameEquals("id"u8)) - { - if (objectLabel?.Contains("run.step") == true) - { - runStepId = property.Value.GetString(); - } - else if (objectLabel?.Contains("run") == true) - { - runId = property.Value.GetString(); - } - else if (objectLabel?.Contains("message") == true) - { - messageId = property.Value.GetString(); - } - else if (objectLabel?.Contains("thread") == true) - { - threadId = property.Value.GetString(); - } - continue; - } - if (property.NameEquals("assistant_id")) - { - assistantId = property.Value.GetString(); - continue; - } - if (property.NameEquals("role")) - { - string roleText = property.Value.GetString(); - if (roleText == "user") - { - role = Assistants.MessageRole.User; - } - else if (roleText == "assistant") - { - role = Assistants.MessageRole.Assistant; - } - continue; - } - if (property.NameEquals("delta"u8)) - { - foreach (JsonProperty messageDeltaProperty in property.Value.EnumerateObject()) - { - if (messageDeltaProperty.NameEquals("content"u8)) - { - deltaContentItems.Clear(); - int? contentIndex = null; - foreach (JsonElement contentArrayElement in messageDeltaProperty.Value.EnumerateArray()) - { - foreach (JsonProperty contentItemProperty in contentArrayElement.EnumerateObject()) - { - if (contentItemProperty.NameEquals("index"u8)) - { - contentIndex = contentItemProperty.Value.GetInt32(); - continue; - } - } - deltaContentItems.Add((MessageContent.DeserializeMessageContent(contentArrayElement), contentIndex)); - } - continue; - } - } - continue; - } - } - foreach ((MessageContent, int?)? deltaContentItem in deltaContentItems) - { - results.Add( - new StreamingRunUpdate( - assistantId, - threadId, - runId, - runStepId, - role, - deltaContentItem)); - } - return results; - } - - internal StreamingRunUpdate( - string assistantId, - string threadId, - string runId, - string runStepId, - MessageRole? role, - (MessageContent, int?)? indexedContentUpdate) + protected StreamingRunUpdate() { - AssistantId = assistantId; - ThreadId = threadId; - RunId = runId; - RunStepId = runStepId; - MessageRole = role; - MessageContentUpdate = indexedContentUpdate?.Item1 ?? null; - MessageContentIndex = indexedContentUpdate?.Item2 ?? null; } } diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index bfab9fc10..fde5594d2 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -271,9 +271,9 @@ public virtual StreamingClientResult CompleteChatStreaming( ClientResult genericResult = ClientResult.FromResponse(response); return StreamingClientResult.CreateFromResponse( genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + StreamingChatUpdate.DeserializeSseChatUpdates)); } /// @@ -309,9 +309,9 @@ public virtual async Task> CompleteCh ClientResult genericResult = ClientResult.FromResponse(response); return StreamingClientResult.CreateFromResponse( genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + StreamingChatUpdate.DeserializeSseChatUpdates)); } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index c1540897b..f7e9e9ec4 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -182,17 +182,17 @@ internal StreamingChatUpdate( LogProbabilities = logProbabilities; } - internal static List DeserializeStreamingChatUpdates(JsonElement element) + internal static IEnumerable DeserializeSseChatUpdates(ReadOnlyMemory _, JsonElement sseDataJson) { List results = []; - if (element.ValueKind == JsonValueKind.Null) + if (sseDataJson.ValueKind == JsonValueKind.Null) { return results; } string id = default; DateTimeOffset created = default; string systemFingerprint = null; - foreach (JsonProperty property in element.EnumerateObject()) + foreach (JsonProperty property in sseDataJson.EnumerateObject()) { if (property.NameEquals("id"u8)) { diff --git a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs index 61f9a92ca..0bbbc7db0 100644 --- a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs +++ b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs @@ -217,7 +217,9 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode completedAt = null; continue; } - completedAt = property.Value.GetDateTimeOffset("O"); + // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // completedAt = property.Value.GetDateTimeOffset("O"); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("incomplete_at"u8)) @@ -227,7 +229,9 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode incompleteAt = null; continue; } - incompleteAt = property.Value.GetDateTimeOffset("O"); + // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // completedAt = property.Value.GetDateTimeOffset("O"); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("role"u8)) diff --git a/.dotnet/src/Generated/Models/RunObject.Serialization.cs b/.dotnet/src/Generated/Models/RunObject.Serialization.cs index 08d10f0d4..0c6258ec7 100644 --- a/.dotnet/src/Generated/Models/RunObject.Serialization.cs +++ b/.dotnet/src/Generated/Models/RunObject.Serialization.cs @@ -269,8 +269,12 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // expiresAt = property.Value.GetDateTimeOffset("O"); expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("started_at"u8)) @@ -281,8 +285,12 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // startedAt = property.Value.GetDateTimeOffset("O"); startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("cancelled_at"u8)) @@ -293,8 +301,12 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // cancelledAt = property.Value.GetDateTimeOffset("O"); cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("failed_at"u8)) @@ -305,8 +317,12 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // failedAt = property.Value.GetDateTimeOffset("O"); failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("completed_at"u8)) @@ -317,8 +333,12 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // completedAt = property.Value.GetDateTimeOffset("O"); completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("model"u8)) diff --git a/.dotnet/src/OpenAI.csproj b/.dotnet/src/OpenAI.csproj index ed91914a5..b7f96c9a7 100644 --- a/.dotnet/src/OpenAI.csproj +++ b/.dotnet/src/OpenAI.csproj @@ -11,6 +11,5 @@ - diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index 96e0b9995..a5b89e278 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Text.Json; using System.Threading; @@ -9,9 +11,9 @@ namespace OpenAI; internal static class SseAsyncEnumerator { - internal static async IAsyncEnumerable EnumerateFromSseStream( + internal static async IAsyncEnumerable EnumerateFromSseJsonStream( Stream stream, - Func> multiElementDeserializer, + Func, JsonElement, IEnumerable> multiElementDeserializer, [EnumeratorCancellation] CancellationToken cancellationToken = default) { try @@ -26,13 +28,9 @@ internal static async IAsyncEnumerable EnumerateFromSseStream( } else { - if (sseEvent.Value.Data.Span.SequenceEqual("[DONE]".AsSpan())) - { - break; - } - using JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.Data); - IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); - foreach (T item in newItems) + if (IsWellKnownDoneToken(sseEvent.Value.Data)) continue; + using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Value.Data); + foreach (T item in multiElementDeserializer(sseEvent.Value.EventType, sseDocument.RootElement)) { yield return item; } @@ -46,12 +44,12 @@ internal static async IAsyncEnumerable EnumerateFromSseStream( } } - internal static IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func elementDeserializer, - CancellationToken cancellationToken = default) - => EnumerateFromSseStream( - stream, - (element) => new T[] { elementDeserializer.Invoke(element) }, - cancellationToken); + private static bool IsWellKnownDoneToken(ReadOnlyMemory data) + { + ReadOnlyMemory[] wellKnownTokens = + [ + "[DONE]".AsMemory(), + ]; + return wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); + } } \ No newline at end of file diff --git a/.dotnet/tests/OpenAI.Tests.csproj b/.dotnet/tests/OpenAI.Tests.csproj index aaf8385b2..6271699d0 100644 --- a/.dotnet/tests/OpenAI.Tests.csproj +++ b/.dotnet/tests/OpenAI.Tests.csproj @@ -4,6 +4,9 @@ $(NoWarn);CS1591 latest + latest + latest + latest diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index 031b93b0e..f64f38bab 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -144,11 +144,19 @@ public async Task StreamingRunWorks() Assert.That(runUpdateResult, Is.Not.Null); await foreach (StreamingRunUpdate runUpdate in runUpdateResult) { - if (runUpdate.MessageRole.HasValue) + if (runUpdate is StreamingMessageCreation messageCreation) { - Console.Write($"[{runUpdate.MessageRole}]"); + Console.WriteLine($"Message created, id={messageCreation.Id}"); + } + if (runUpdate is StreamingMessageUpdate messageUpdate) + { + Console.Write(messageUpdate.ContentUpdate.GetText()); + } + if (runUpdate is StreamingMessageCompletion messageCompletion) + { + Console.WriteLine(); + Console.WriteLine($"Message complete: {messageCompletion.Message.ContentItems[0].GetText()}"); } - Console.Write(runUpdate.MessageContentUpdate?.GetText()); } } From 716dea2f9a79379f1d035243d3e69bad7202daf1 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 18 Mar 2024 19:23:41 -0700 Subject: [PATCH 04/15] more streaming types + submit tool outputs --- .../Assistants/AssistantClient.Protocol.cs | 62 +++++++++-------- .../src/Custom/Assistants/AssistantClient.cs | 34 ++++++++-- .../StreamingRequiredAction.Serialization.cs | 19 ++++++ .../Assistants/StreamingRequiredAction.cs | 11 +++ .../StreamingRunCreation.Serialization.cs | 16 +++++ .../Custom/Assistants/StreamingRunCreation.cs | 15 ++++ .../StreamingRunUpdate.Serialization.cs | 18 ++++- .../Models/RunObject.Serialization.cs | 20 ------ ...bmitToolOutputsRunRequest.Serialization.cs | 17 ++++- .../Models/SubmitToolOutputsRunRequest.cs | 12 +++- .dotnet/src/OpenAI.csproj | 1 + .dotnet/tests/OpenAI.Tests.csproj | 3 - .dotnet/tests/TestScenarios/AssistantTests.cs | 68 ++++++++++++++++++- runs/models.tsp | 6 ++ tsp-output/@typespec/openapi3/openapi.yaml | 5 ++ 15 files changed, 249 insertions(+), 58 deletions(-) create mode 100644 .dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingRequiredAction.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs create mode 100644 .dotnet/src/Custom/Assistants/StreamingRunCreation.cs diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs index 009a3ccd8..063cdab2e 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs @@ -459,6 +459,19 @@ public virtual ClientResult SubmitToolOutputs( RequestOptions options = null) => RunShim.SubmitToolOuputsToRun(threadId, runId, content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult SubmitToolOutputsStreaming( + string threadId, + string runId, + BinaryContent content, + RequestOptions options = null) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task SubmitToolOutputsAsync( @@ -468,6 +481,19 @@ public virtual async Task SubmitToolOutputsAsync( RequestOptions options = null) => await RunShim.SubmitToolOuputsToRunAsync(threadId, runId, content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task SubmitToolOutputsStreamingAsync( + string threadId, + string runId, + BinaryContent content, + RequestOptions options = null) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// public virtual ClientResult GetRunStep( string threadId, @@ -507,31 +533,15 @@ public virtual async Task GetRunStepsAsync( => await RunShim.GetRunStepsAsync(threadId, runId, maxResults, createdSortOrder, previousStepId, subsequentStepId, options).ConfigureAwait(false); internal PipelineMessage CreateCreateRunRequest(string threadId, BinaryContent content, bool? stream = null, RequestOptions options = null) - { - PipelineMessage message = Shim.Pipeline.CreateMessage(); - message.ResponseClassifier = ResponseErrorClassifier200; - if (stream == true) - { - message.BufferResponse = false; - } - PipelineRequest request = message.Request; - request.Method = "POST"; - UriBuilder uriBuilder = new(_clientConnector.Endpoint.AbsoluteUri); - StringBuilder path = new(); - path.Append("/threads/"); - path.Append(threadId); - path.Append("/runs"); - uriBuilder.Path += path.ToString(); - uriBuilder.Query += $"?thread_id={threadId}"; - request.Uri = uriBuilder.Uri; - request.Headers.Set("Content-Type", "application/json"); - request.Headers.Set("Accept", "text/event-stream"); - request.Content = content; - message.Apply(options ?? new()); - return message; - } + => CreateAssistantProtocolRequest($"/threads/{threadId}/runs", content, stream, options); internal PipelineMessage CreateCreateThreadAndRunRequest(BinaryContent content, bool? stream = null, RequestOptions options = null) + => CreateAssistantProtocolRequest($"/threads/runs", content, stream, options); + + internal PipelineMessage CreateSubmitToolOutputsRequest(string threadId, string runId, BinaryContent content, bool? stream = null, RequestOptions options = null) + => CreateAssistantProtocolRequest($"/threads/{threadId}/runs/{runId}/submit_tool_outputs", content, stream, options); + + internal PipelineMessage CreateAssistantProtocolRequest(string path, BinaryContent content, bool? stream = null, RequestOptions options = null) { PipelineMessage message = Shim.Pipeline.CreateMessage(); message.ResponseClassifier = ResponseErrorClassifier200; @@ -542,12 +552,10 @@ internal PipelineMessage CreateCreateThreadAndRunRequest(BinaryContent content, PipelineRequest request = message.Request; request.Method = "POST"; UriBuilder uriBuilder = new(_clientConnector.Endpoint.AbsoluteUri); - StringBuilder path = new(); - path.Append("/threads/runs"); - uriBuilder.Path += path.ToString(); + uriBuilder.Path += path; request.Uri = uriBuilder.Uri; request.Headers.Set("Content-Type", "application/json"); - request.Headers.Set("Accept", "text/event-stream"); + request.Headers.Set("Accept", stream == true ? "text/event-stream" : "application/json"); request.Content = content; message.Apply(options ?? new()); return message; diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index b1a575680..2efcb37b1 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -514,7 +514,7 @@ public virtual StreamingClientResult CreateRunStreaming( string assistantId, RunCreationOptions options = null) { - PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options); + PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); RunShim.Pipeline.Send(message); return CreateStreamingRunResult(message); } @@ -524,7 +524,7 @@ public virtual async Task> CreateRunSt string assistantId, RunCreationOptions options = null) { - PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options); + PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); await RunShim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } @@ -654,7 +654,7 @@ public virtual ClientResult SubmitToolOutputs(string threadId, string requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); } - Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null); + Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null, serializedAdditionalRawData: null); ClientResult internalResult = RunShim.SubmitToolOuputsToRun(threadId, runId, request); return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } @@ -668,11 +668,25 @@ public virtual async Task> SubmitToolOutputsAsync(string requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); } - Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null); + Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null, serializedAdditionalRawData: null); ClientResult internalResult = await RunShim.SubmitToolOuputsToRunAsync(threadId, runId, request).ConfigureAwait(false); return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable toolOutputs) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable toolOutputs) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + await Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + internal PipelineMessage CreateCreateRunRequest(string threadId, string assistantId, RunCreationOptions runOptions, bool? stream = null) { Internal.Models.CreateRunRequest internalCreateRunRequest = CreateInternalCreateRunRequest(assistantId, runOptions, stream); @@ -692,6 +706,18 @@ Internal.Models.CreateThreadAndRunRequest internalRequest return CreateCreateThreadAndRunRequest(content, stream: true); } + internal PipelineMessage CreateSubmitToolOutputsRequest(string threadId, string runId, IEnumerable toolOutputs, bool? stream) + { + List requestToolOutputs = []; + foreach (ToolOutput toolOutput in toolOutputs) + { + requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); + } + Internal.Models.SubmitToolOutputsRunRequest internalRequest = new(requestToolOutputs, stream, serializedAdditionalRawData: null); + BinaryContent content = BinaryContent.Create(internalRequest); + return CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true); + } + internal static StreamingClientResult CreateStreamingRunResult(PipelineMessage message) { if (message.Response.IsError) diff --git a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs new file mode 100644 index 000000000..e89c6e0c1 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs @@ -0,0 +1,19 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingRequiredAction +{ + internal static IEnumerable DeserializeStreamingRequiredActions(JsonElement sseDataJson, ModelReaderWriterOptions options = null) + { + ThreadRun run = new(Internal.Models.RunObject.DeserializeRunObject(sseDataJson, options)); + List results = []; + foreach (RunRequiredAction deserializedAction in run.RequiredActions) + { + results.Add(new(deserializedAction)); + } + return results; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs new file mode 100644 index 000000000..a9ab14074 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs @@ -0,0 +1,11 @@ +namespace OpenAI.Assistants; + +public partial class StreamingRequiredAction : StreamingRunUpdate +{ + public RunRequiredAction RequiredAction { get; } + + internal StreamingRequiredAction(RunRequiredAction requiredAction) + { + RequiredAction = requiredAction; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs new file mode 100644 index 000000000..49938bd32 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs @@ -0,0 +1,16 @@ +namespace OpenAI.Assistants; + +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingRunCreation +{ + internal static StreamingRunUpdate DeserializeStreamingRunCreation(JsonElement element, ModelReaderWriterOptions options = null) + { + Internal.Models.RunObject internalRun = Internal.Models.RunObject.DeserializeRunObject(element, options); + ThreadRun run = new(internalRun); + return new StreamingRunCreation(run); + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs new file mode 100644 index 000000000..7013fc89c --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs @@ -0,0 +1,15 @@ +namespace OpenAI.Assistants; + +using System; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingRunCreation : StreamingRunUpdate +{ + public ThreadRun Run { get; } + + internal StreamingRunCreation(ThreadRun run) + { + Run = run; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs index acedb5984..cc477b2ac 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs @@ -23,10 +23,24 @@ internal static IEnumerable DeserializeSseRunUpdates( { results = [StreamingMessageCompletion.DeserializeStreamingMessageCompletion(sseDataJson)]; } + else if (sseEventName.Span.SequenceEqual(s_runCreatedEventName.Span)) + { + results = [StreamingRunCreation.DeserializeStreamingRunCreation(sseDataJson)]; + } + + else if (sseEventName.Span.SequenceEqual(s_runRequiredActionEventName.Span)) + { + results = StreamingRequiredAction.DeserializeStreamingRequiredActions(sseDataJson); + } + else + { + results = [new StreamingRunUpdate()]; + } + JsonElement rawElementClone = sseDataJson.Clone(); foreach (StreamingRunUpdate baseUpdate in results) { - baseUpdate._originalJson = sseDataJson.Clone(); + baseUpdate._originalJson = rawElementClone; } return results; } @@ -34,4 +48,6 @@ internal static IEnumerable DeserializeSseRunUpdates( private static readonly ReadOnlyMemory s_threadMessageCreationEventName = "thread.message.created".AsMemory(); private static readonly ReadOnlyMemory s_threadMessageDeltaEventName = "thread.message.delta".AsMemory(); private static readonly ReadOnlyMemory s_threadMessageCompletionEventName = "thread.message.completed".AsMemory(); + private static readonly ReadOnlyMemory s_runCreatedEventName = "thread.run.created".AsMemory(); + private static readonly ReadOnlyMemory s_runRequiredActionEventName = "thread.run.requires_action".AsMemory(); } diff --git a/.dotnet/src/Generated/Models/RunObject.Serialization.cs b/.dotnet/src/Generated/Models/RunObject.Serialization.cs index 0c6258ec7..08d10f0d4 100644 --- a/.dotnet/src/Generated/Models/RunObject.Serialization.cs +++ b/.dotnet/src/Generated/Models/RunObject.Serialization.cs @@ -269,12 +269,8 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // expiresAt = property.Value.GetDateTimeOffset("O"); expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - expiresAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("started_at"u8)) @@ -285,12 +281,8 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // startedAt = property.Value.GetDateTimeOffset("O"); startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - startedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("cancelled_at"u8)) @@ -301,12 +293,8 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // cancelledAt = property.Value.GetDateTimeOffset("O"); cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("failed_at"u8)) @@ -317,12 +305,8 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // failedAt = property.Value.GetDateTimeOffset("O"); failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("completed_at"u8)) @@ -333,12 +317,8 @@ internal static RunObject DeserializeRunObject(JsonElement element, ModelReaderW continue; } // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 - // // BUG: https://github.com/Azure/autorest.csharp/issues/4296 // completedAt = property.Value.GetDateTimeOffset("O"); completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); continue; } if (property.NameEquals("model"u8)) diff --git a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs index d38502d9b..be9df651c 100644 --- a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs @@ -28,6 +28,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelR writer.WriteObjectValue(item); } writer.WriteEndArray(); + if (Optional.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -67,6 +72,7 @@ internal static SubmitToolOutputsRunRequest DeserializeSubmitToolOutputsRunReque return null; } IList toolOutputs = default; + bool? stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -81,13 +87,22 @@ internal static SubmitToolOutputsRunRequest DeserializeSubmitToolOutputsRunReque toolOutputs = array; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); } } serializedAdditionalRawData = additionalPropertiesDictionary; - return new SubmitToolOutputsRunRequest(toolOutputs, serializedAdditionalRawData); + return new SubmitToolOutputsRunRequest(toolOutputs, stream, serializedAdditionalRawData); } BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) diff --git a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs index 2bed85b35..a4395b14c 100644 --- a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs +++ b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs @@ -54,10 +54,15 @@ public SubmitToolOutputsRunRequest(IEnumerable Initializes a new instance of . /// A list of tools for which the outputs are being submitted. + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state. + /// /// Keeps track of any properties unknown to the library. - internal SubmitToolOutputsRunRequest(IList toolOutputs, IDictionary serializedAdditionalRawData) + internal SubmitToolOutputsRunRequest(IList toolOutputs, bool? stream, IDictionary serializedAdditionalRawData) { ToolOutputs = toolOutputs; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -68,5 +73,10 @@ internal SubmitToolOutputsRunRequest() /// A list of tools for which the outputs are being submitted. public IList ToolOutputs { get; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/OpenAI.csproj b/.dotnet/src/OpenAI.csproj index b7f96c9a7..ed91914a5 100644 --- a/.dotnet/src/OpenAI.csproj +++ b/.dotnet/src/OpenAI.csproj @@ -11,5 +11,6 @@ + diff --git a/.dotnet/tests/OpenAI.Tests.csproj b/.dotnet/tests/OpenAI.Tests.csproj index 6271699d0..aaf8385b2 100644 --- a/.dotnet/tests/OpenAI.Tests.csproj +++ b/.dotnet/tests/OpenAI.Tests.csproj @@ -4,9 +4,6 @@ $(NoWarn);CS1591 latest - latest - latest - latest diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index f64f38bab..460e887ae 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -1,7 +1,9 @@ using NUnit.Framework; using OpenAI.Assistants; +using OpenAI.Chat; using System; using System.ClientModel; +using System.Collections.Generic; using System.Threading.Tasks; using static OpenAI.Tests.TestHelpers; @@ -127,7 +129,7 @@ public async Task BasicFunctionToolWorks() } [Test] - public async Task StreamingRunWorks() + public async Task SimpleStreamingRunWorks() { AssistantClient client = GetTestClient(); Assistant assistant = await CreateCommonTestAssistantAsync(); @@ -160,6 +162,70 @@ public async Task StreamingRunWorks() } } + [Test] + public async Task StreamingWithToolsWorks() + { + AssistantClient client = GetTestClient(); + ClientResult assistantResult = await client.CreateAssistantAsync("gpt-3.5-turbo", new AssistantCreationOptions() + { + Instructions = "You are a helpful math assistant that helps with visualizing equations. Use the code interpreter tool when asked to generate images. Use provided functions to resolve appropriate unknown values", + Tools = + { + new CodeInterpreterToolDefinition(), + new FunctionToolDefinition("get_boilerplate_equation", "Retrieves a predefined 'boilerplate equation' from the caller."), + }, + Metadata = { [s_cleanupMetadataKey] = "true" }, + }); + Assistant assistant = assistantResult.Value; + Assert.That(assistant, Is.Not.Null); + + ClientResult threadResult = await client.CreateThreadAsync(new ThreadCreationOptions() + { + Messages = + { + "Please make a graph for my boilerplate equation", + }, + }); + AssistantThread thread = threadResult.Value; + Assert.That(thread, Is.Not.Null); + + StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); + Assert.That(streamingResult, Is.Not.Null); + List requiredActions = []; + ThreadRun initialStreamedRun = null; + await foreach (StreamingRunUpdate streamingUpdate in streamingResult) + { + if (streamingUpdate is StreamingRunCreation streamingRunCreation) + { + initialStreamedRun = streamingRunCreation.Run; + } + if (streamingUpdate is StreamingRequiredAction streamedRequiredAction) + { + requiredActions.Add(streamedRequiredAction.RequiredAction); + } + Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); + } + Assert.That(initialStreamedRun?.Id, Is.Not.Null.Or.Empty); + Assert.That(requiredActions, Is.Not.Empty); + + List toolOutputs = []; + foreach (RunRequiredAction requiredAction in requiredActions) + { + if (requiredAction is RequiredFunctionToolCall functionCall) + { + if (functionCall.Name == "get_boilerplate_equation") + { + toolOutputs.Add(new(functionCall, "y = 14x - 3")); + } + } + } + streamingResult = await client.SubmitToolOutputsStreamingAsync(thread.Id, initialStreamedRun.Id, toolOutputs); + await foreach (StreamingRunUpdate streamingUpdate in streamingResult) + { + Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); + } + } + private async Task CreateCommonTestAssistantAsync() { AssistantClient client = new(); diff --git a/runs/models.tsp b/runs/models.tsp index 7619a731d..15ab33b54 100644 --- a/runs/models.tsp +++ b/runs/models.tsp @@ -112,6 +112,12 @@ model SubmitToolOutputsRunRequest { /** The output of the tool call to be submitted to continue the run. */ output?: string; }[]; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state. + */ + stream?: boolean; } model ListRunsResponse { diff --git a/tsp-output/@typespec/openapi3/openapi.yaml b/tsp-output/@typespec/openapi3/openapi.yaml index ea72756be..ade9e3fd7 100644 --- a/tsp-output/@typespec/openapi3/openapi.yaml +++ b/tsp-output/@typespec/openapi3/openapi.yaml @@ -5557,6 +5557,11 @@ components: type: string description: The output of the tool call to be submitted to continue the run. description: A list of tools for which the outputs are being submitted. + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state. SuffixString: type: string minLength: 1 From f01eb1c9e58dc4187653c8306badca12d62a69a3 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Tue, 19 Mar 2024 10:19:53 -0700 Subject: [PATCH 05/15] small consistency pass for easier pattern discussion --- .../src/Custom/Assistants/AssistantClient.cs | 20 ++--- ...treamingMessageCompletion.Serialization.cs | 2 +- .../Assistants/StreamingMessageCompletion.cs | 6 +- .../StreamingMessageCreation.Serialization.cs | 77 ++----------------- .../Assistants/StreamingMessageCreation.cs | 34 +------- .../StreamingMessageUpdate.Serialization.cs | 10 +-- .../Assistants/StreamingMessageUpdate.cs | 5 +- .../Assistants/StreamingRequiredAction.cs | 2 +- .../StreamingRunCreation.Serialization.cs | 4 +- .../Custom/Assistants/StreamingRunCreation.cs | 6 +- ...on.cs => StreamingUpdate.Serialization.cs} | 10 +-- ...reamingRunUpdate.cs => StreamingUpdate.cs} | 7 +- .dotnet/tests/TestScenarios/AssistantTests.cs | 12 +-- 13 files changed, 40 insertions(+), 155 deletions(-) rename .dotnet/src/Custom/Assistants/{StreamingRunUpdate.Serialization.cs => StreamingUpdate.Serialization.cs} (87%) rename .dotnet/src/Custom/Assistants/{StreamingRunUpdate.cs => StreamingUpdate.cs} (71%) diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index 2efcb37b1..2bc2c9b72 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -509,7 +509,7 @@ public virtual async Task> CreateRunAsync( return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } - public virtual StreamingClientResult CreateRunStreaming( + public virtual StreamingClientResult CreateRunStreaming( string threadId, string assistantId, RunCreationOptions options = null) @@ -519,7 +519,7 @@ public virtual StreamingClientResult CreateRunStreaming( return CreateStreamingRunResult(message); } - public virtual async Task> CreateRunStreamingAsync( + public virtual async Task> CreateRunStreamingAsync( string threadId, string assistantId, RunCreationOptions options = null) @@ -552,7 +552,7 @@ Internal.Models.CreateThreadAndRunRequest request return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } - public virtual StreamingClientResult CreateThreadAndRunStreaming( + public virtual StreamingClientResult CreateThreadAndRunStreaming( string assistantId, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) @@ -562,7 +562,7 @@ public virtual StreamingClientResult CreateThreadAndRunStrea return CreateStreamingRunResult(message); } - public virtual async Task> CreateThreadAndRunStreamingAsync( + public virtual async Task> CreateThreadAndRunStreamingAsync( string assistantId, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) @@ -673,14 +673,14 @@ public virtual async Task> SubmitToolOutputsAsync(string return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } - public virtual StreamingClientResult SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable toolOutputs) + public virtual StreamingClientResult SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable toolOutputs) { PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); Shim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } - public virtual async Task> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable toolOutputs) + public virtual async Task> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable toolOutputs) { PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); await Shim.Pipeline.SendAsync(message); @@ -718,7 +718,7 @@ internal PipelineMessage CreateSubmitToolOutputsRequest(string threadId, string return CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true); } - internal static StreamingClientResult CreateStreamingRunResult(PipelineMessage message) + internal static StreamingClientResult CreateStreamingRunResult(PipelineMessage message) { if (message.Response.IsError) { @@ -729,11 +729,11 @@ internal static StreamingClientResult CreateStreamingRunResu { response = message.ExtractResponse(); ClientResult genericResult = ClientResult.FromResponse(response); - StreamingClientResult streamingResult = StreamingClientResult.CreateFromResponse( + StreamingClientResult streamingResult = StreamingClientResult.CreateFromResponse( genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( responseForEnumeration.GetRawResponse().ContentStream, - StreamingRunUpdate.DeserializeSseRunUpdates)); + StreamingUpdate.DeserializeSseRunUpdates)); response = null; return streamingResult; } diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs index eab1a3f01..e5869c3d9 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs @@ -3,7 +3,7 @@ namespace OpenAI.Assistants; using System.ClientModel.Primitives; using System.Text.Json; -public partial class StreamingMessageCompletion : StreamingRunUpdate +public partial class StreamingMessageCompletion : StreamingUpdate { internal static StreamingMessageCompletion DeserializeStreamingMessageCompletion(JsonElement element, ModelReaderWriterOptions options = default) { diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs index daede59df..212e97910 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs @@ -1,10 +1,6 @@ namespace OpenAI.Assistants; -using System; -using System.Collections.Generic; -using System.Text.Json; - -public partial class StreamingMessageCompletion : StreamingRunUpdate +public partial class StreamingMessageCompletion : StreamingUpdate { public ThreadMessage Message { get; } diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs index 7f45b7339..170df7cd1 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs @@ -1,84 +1,17 @@ namespace OpenAI.Assistants; -using System; using System.ClientModel.Primitives; -using System.Collections.Generic; using System.Text.Json; public partial class StreamingMessageCreation { - internal static StreamingRunUpdate DeserializeSseMessageCreation( + internal static StreamingUpdate DeserializeSseMessageCreation( JsonElement sseDataJson, ModelReaderWriterOptions options = null) { - string id = null; - string assistantId = null; - string threadId = null; - string runId = null; - DateTimeOffset? createdAt = null; - MessageRole? role = null; - List contentItems = []; - List fileIds = []; - - foreach (JsonProperty property in sseDataJson.EnumerateObject()) - { - if (property.NameEquals("id"u8)) - { - id = property.Value.GetString(); - continue; - } - if (property.NameEquals("assistant_id"u8)) - { - assistantId = property.Value.GetString(); - continue; - } - if (property.NameEquals("thread_id"u8)) - { - threadId = property.Value.GetString(); - continue; - } - if (property.NameEquals("run_id"u8)) - { - runId = property.Value.GetString(); - continue; - } - if (property.NameEquals("created_at"u8)) - { - createdAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - continue; - } - if (property.NameEquals("role"u8)) - { - role = MessageRoleSerialization.DeserializeMessageRole(property.Value); - continue; - } - if (property.NameEquals("content"u8)) - { - foreach (JsonElement contentElement in property.Value.EnumerateArray()) - { - contentItems.Add(MessageContent.DeserializeMessageContent(contentElement)); - } - continue; - } - if (property.NameEquals("file_ids"u8)) - { - foreach (JsonElement fileIdElement in property.Value.EnumerateArray()) - { - fileIds.Add(fileIdElement.GetString()); - } - continue; - } - } - - return new StreamingMessageCreation( - id, - assistantId, - threadId, - runId, - createdAt.Value, - role.Value, - contentItems, - fileIds); - + Internal.Models.MessageObject internalMessage + = Internal.Models.MessageObject.DeserializeMessageObject(sseDataJson, options); + ThreadMessage message = new(internalMessage); + return new StreamingMessageCreation(message); } } diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs index 71ae8c651..c6641903d 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs @@ -1,37 +1,11 @@ namespace OpenAI.Assistants; -using System; -using System.Collections.Generic; -using System.Text.Json; - -public partial class StreamingMessageCreation : StreamingRunUpdate +public partial class StreamingMessageCreation : StreamingUpdate { - public string Id { get; } - public string AssistantId { get; } - public string ThreadId { get; } - public string RunId { get; } - public DateTimeOffset CreatedAt { get; } - public MessageRole Role { get; } - public IReadOnlyList ContentItems { get; } - public IReadOnlyList FileIds { get; } + public ThreadMessage Message { get; } - internal StreamingMessageCreation( - string id, - string assistantId, - string threadId, - string runId, - DateTimeOffset createdAt, - MessageRole role, - IReadOnlyList contentItems, - IReadOnlyList fileIds) + internal StreamingMessageCreation(ThreadMessage message) { - Id = id; - AssistantId = assistantId; - ThreadId = threadId; - RunId = runId; - CreatedAt = createdAt; - Role = role; - ContentItems = contentItems; - FileIds = fileIds; + Message = message; } } diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs index f59e22f07..d31176fdb 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs @@ -1,17 +1,16 @@ namespace OpenAI.Assistants; -using System; using System.ClientModel.Primitives; using System.Collections.Generic; using System.Text.Json; public partial class StreamingMessageUpdate { - internal static IEnumerable DeserializeSseMessageUpdates( + internal static IEnumerable DeserializeSseMessageUpdates( JsonElement sseDataJson, ModelReaderWriterOptions options = default) { - List results = []; + List results = []; if (sseDataJson.ValueKind == JsonValueKind.Null) { return results; @@ -55,9 +54,4 @@ internal static IEnumerable DeserializeSseMessageUpdates( } return results; } - - internal StreamingMessageUpdate() : base() - { - - } } diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs index 0ffe94bb5..30b5dacbd 100644 --- a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs @@ -1,9 +1,6 @@ namespace OpenAI.Assistants; -using System.Collections.Generic; -using System.Text.Json; - -public partial class StreamingMessageUpdate : StreamingRunUpdate +public partial class StreamingMessageUpdate : StreamingUpdate { public MessageContent ContentUpdate { get; } public int? ContentUpdateIndex { get; } diff --git a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs index a9ab14074..ccd037154 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs +++ b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs @@ -1,6 +1,6 @@ namespace OpenAI.Assistants; -public partial class StreamingRequiredAction : StreamingRunUpdate +public partial class StreamingRequiredAction : StreamingUpdate { public RunRequiredAction RequiredAction { get; } diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs index 49938bd32..14bb26f94 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs @@ -1,13 +1,11 @@ namespace OpenAI.Assistants; -using System; using System.ClientModel.Primitives; -using System.Collections.Generic; using System.Text.Json; public partial class StreamingRunCreation { - internal static StreamingRunUpdate DeserializeStreamingRunCreation(JsonElement element, ModelReaderWriterOptions options = null) + internal static StreamingUpdate DeserializeStreamingRunCreation(JsonElement element, ModelReaderWriterOptions options = null) { Internal.Models.RunObject internalRun = Internal.Models.RunObject.DeserializeRunObject(element, options); ThreadRun run = new(internalRun); diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs index 7013fc89c..89e628e43 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs @@ -1,10 +1,6 @@ namespace OpenAI.Assistants; -using System; -using System.Collections.Generic; -using System.Text.Json; - -public partial class StreamingRunCreation : StreamingRunUpdate +public partial class StreamingRunCreation : StreamingUpdate { public ThreadRun Run { get; } diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs similarity index 87% rename from .dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs rename to .dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs index cc477b2ac..7ef50fa57 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.Serialization.cs +++ b/.dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs @@ -4,13 +4,13 @@ namespace OpenAI.Assistants; using System.Collections.Generic; using System.Text.Json; -public partial class StreamingRunUpdate +public partial class StreamingUpdate { - internal static IEnumerable DeserializeSseRunUpdates( + internal static IEnumerable DeserializeSseRunUpdates( ReadOnlyMemory sseEventName, JsonElement sseDataJson) { - IEnumerable results = []; + IEnumerable results = []; if (sseEventName.Span.SequenceEqual(s_threadMessageCreationEventName.Span)) { results = [StreamingMessageCreation.DeserializeSseMessageCreation(sseDataJson)]; @@ -34,11 +34,11 @@ internal static IEnumerable DeserializeSseRunUpdates( } else { - results = [new StreamingRunUpdate()]; + results = [new StreamingUpdate()]; } JsonElement rawElementClone = sseDataJson.Clone(); - foreach (StreamingRunUpdate baseUpdate in results) + foreach (StreamingUpdate baseUpdate in results) { baseUpdate._originalJson = rawElementClone; } diff --git a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingUpdate.cs similarity index 71% rename from .dotnet/src/Custom/Assistants/StreamingRunUpdate.cs rename to .dotnet/src/Custom/Assistants/StreamingUpdate.cs index 778d23f95..a25f7eba4 100644 --- a/.dotnet/src/Custom/Assistants/StreamingRunUpdate.cs +++ b/.dotnet/src/Custom/Assistants/StreamingUpdate.cs @@ -1,18 +1,15 @@ namespace OpenAI.Assistants; - -using System; -using System.Collections.Generic; using System.Text.Json; /// /// Represents an incremental item of new data in a streaming response when running a thread with an assistant. /// -public partial class StreamingRunUpdate +public partial class StreamingUpdate { private JsonElement _originalJson; public JsonElement GetRawSseEvent() => _originalJson; - protected StreamingRunUpdate() + protected StreamingUpdate() { } } diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index 460e887ae..696d6490b 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -134,7 +134,7 @@ public async Task SimpleStreamingRunWorks() AssistantClient client = GetTestClient(); Assistant assistant = await CreateCommonTestAssistantAsync(); - StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( + StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( assistant.Id, new ThreadCreationOptions() { @@ -144,11 +144,11 @@ public async Task SimpleStreamingRunWorks() } }); Assert.That(runUpdateResult, Is.Not.Null); - await foreach (StreamingRunUpdate runUpdate in runUpdateResult) + await foreach (StreamingUpdate runUpdate in runUpdateResult) { if (runUpdate is StreamingMessageCreation messageCreation) { - Console.WriteLine($"Message created, id={messageCreation.Id}"); + Console.WriteLine($"Message created, id={messageCreation.Message.Id}"); } if (runUpdate is StreamingMessageUpdate messageUpdate) { @@ -189,11 +189,11 @@ public async Task StreamingWithToolsWorks() AssistantThread thread = threadResult.Value; Assert.That(thread, Is.Not.Null); - StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); + StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); Assert.That(streamingResult, Is.Not.Null); List requiredActions = []; ThreadRun initialStreamedRun = null; - await foreach (StreamingRunUpdate streamingUpdate in streamingResult) + await foreach (StreamingUpdate streamingUpdate in streamingResult) { if (streamingUpdate is StreamingRunCreation streamingRunCreation) { @@ -220,7 +220,7 @@ public async Task StreamingWithToolsWorks() } } streamingResult = await client.SubmitToolOutputsStreamingAsync(thread.Id, initialStreamedRun.Id, toolOutputs); - await foreach (StreamingRunUpdate streamingUpdate in streamingResult) + await foreach (StreamingUpdate streamingUpdate in streamingResult) { Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); } From e82c55a7835ca343720c181d91ef258688109a4e Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Tue, 19 Mar 2024 10:41:11 -0700 Subject: [PATCH 06/15] SSE tidy pass --- .../{SseEvent.cs => ServerSentEvent.cs} | 18 ++++++------ ...eEventField.cs => ServerSentEventField.cs} | 16 +++++----- ...eldType.cs => ServerSentEventFieldKind.cs} | 2 +- .dotnet/src/Utility/SseAsyncEnumerator.cs | 29 ++++++++++++------- .dotnet/src/Utility/SseReader.cs | 21 ++++++++------ 5 files changed, 49 insertions(+), 37 deletions(-) rename .dotnet/src/Utility/{SseEvent.cs => ServerSentEvent.cs} (79%) rename .dotnet/src/Utility/{SseEventField.cs => ServerSentEventField.cs} (84%) rename .dotnet/src/Utility/{SseEventFieldType.cs => ServerSentEventFieldKind.cs} (82%) diff --git a/.dotnet/src/Utility/SseEvent.cs b/.dotnet/src/Utility/ServerSentEvent.cs similarity index 79% rename from .dotnet/src/Utility/SseEvent.cs rename to .dotnet/src/Utility/ServerSentEvent.cs index 96552f392..b47e3bceb 100644 --- a/.dotnet/src/Utility/SseEvent.cs +++ b/.dotnet/src/Utility/ServerSentEvent.cs @@ -5,17 +5,17 @@ namespace OpenAI; // SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseEvent +internal readonly struct ServerSentEvent { - public ReadOnlyMemory EventType { get; } + public ReadOnlyMemory EventName { get; } public ReadOnlyMemory Data { get; } public ReadOnlyMemory LastEventId { get; } public TimeSpan? ReconnectionTime { get; } - private readonly IReadOnlyList _fields; + private readonly IReadOnlyList _fields; private readonly string _multiLineData; - internal SseEvent(IReadOnlyList fields) + internal ServerSentEvent(IReadOnlyList fields) { _fields = fields; StringBuilder multiLineDataBuilder = null; @@ -24,10 +24,10 @@ internal SseEvent(IReadOnlyList fields) ReadOnlyMemory fieldValue = _fields[i].Value; switch (_fields[i].FieldType) { - case SseEventFieldType.Event: - EventType = fieldValue; + case ServerSentEventFieldKind.Event: + EventName = fieldValue; break; - case SseEventFieldType.Data: + case ServerSentEventFieldKind.Data: { if (multiLineDataBuilder != null) { @@ -45,10 +45,10 @@ internal SseEvent(IReadOnlyList fields) } break; } - case SseEventFieldType.Id: + case ServerSentEventFieldKind.Id: LastEventId = fieldValue; break; - case SseEventFieldType.Retry: + case ServerSentEventFieldKind.Retry: ReconnectionTime = Int32.TryParse(fieldValue.ToString(), out int retry) ? TimeSpan.FromMilliseconds(retry) : null; break; default: diff --git a/.dotnet/src/Utility/SseEventField.cs b/.dotnet/src/Utility/ServerSentEventField.cs similarity index 84% rename from .dotnet/src/Utility/SseEventField.cs rename to .dotnet/src/Utility/ServerSentEventField.cs index 0293a0a87..c6032c931 100644 --- a/.dotnet/src/Utility/SseEventField.cs +++ b/.dotnet/src/Utility/ServerSentEventField.cs @@ -3,9 +3,9 @@ namespace OpenAI; // SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseEventField +internal readonly struct ServerSentEventField { - public SseEventFieldType FieldType { get; } + public ServerSentEventFieldKind FieldType { get; } // TODO: we should not expose UTF16 publicly public ReadOnlyMemory Value @@ -26,7 +26,7 @@ public ReadOnlyMemory Value private readonly string _original; private readonly int _valueStartIndex; - internal SseEventField(string line) + internal ServerSentEventField(string line) { _original = line; int colonIndex = _original.AsSpan().IndexOf(':'); @@ -34,11 +34,11 @@ internal SseEventField(string line) ReadOnlyMemory fieldName = colonIndex < 0 ? _original.AsMemory(): _original.AsMemory(0, colonIndex); FieldType = fieldName.Span switch { - var x when x.SequenceEqual(s_eventFieldName.Span) => SseEventFieldType.Event, - var x when x.SequenceEqual(s_dataFieldName.Span) => SseEventFieldType.Data, - var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => SseEventFieldType.Id, - var x when x.SequenceEqual(s_retryFieldName.Span) => SseEventFieldType.Retry, - _ => SseEventFieldType.Ignored, + var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, + _ => ServerSentEventFieldKind.Ignored, }; if (colonIndex < 0) diff --git a/.dotnet/src/Utility/SseEventFieldType.cs b/.dotnet/src/Utility/ServerSentEventFieldKind.cs similarity index 82% rename from .dotnet/src/Utility/SseEventFieldType.cs rename to .dotnet/src/Utility/ServerSentEventFieldKind.cs index d10bd404a..c3597b0ff 100644 --- a/.dotnet/src/Utility/SseEventFieldType.cs +++ b/.dotnet/src/Utility/ServerSentEventFieldKind.cs @@ -1,7 +1,7 @@ namespace OpenAI; // SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal enum SseEventFieldType +internal enum ServerSentEventFieldKind { Event, Data, diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index a5b89e278..5bd8c7efa 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -3,7 +3,6 @@ using System.IO; using System.Linq; using System.Runtime.CompilerServices; -using System.Text; using System.Text.Json; using System.Threading; @@ -11,9 +10,8 @@ namespace OpenAI; internal static class SseAsyncEnumerator { - internal static async IAsyncEnumerable EnumerateFromSseJsonStream( + internal static async IAsyncEnumerable EnumerateServerSentEvents( Stream stream, - Func, JsonElement, IEnumerable> multiElementDeserializer, [EnumeratorCancellation] CancellationToken cancellationToken = default) { try @@ -21,19 +19,14 @@ internal static async IAsyncEnumerable EnumerateFromSseJsonStream( using SseReader sseReader = new(stream); while (!cancellationToken.IsCancellationRequested) { - SseEvent? sseEvent = await sseReader.TryGetNextEventAsync(cancellationToken).ConfigureAwait(false); + ServerSentEvent? sseEvent = await sseReader.TryGetNextEventAsync(cancellationToken).ConfigureAwait(false); if (sseEvent is null) { break; } else { - if (IsWellKnownDoneToken(sseEvent.Value.Data)) continue; - using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Value.Data); - foreach (T item in multiElementDeserializer(sseEvent.Value.EventType, sseDocument.RootElement)) - { - yield return item; - } + yield return sseEvent.Value; } } } @@ -44,6 +37,22 @@ internal static async IAsyncEnumerable EnumerateFromSseJsonStream( } } + internal static async IAsyncEnumerable EnumerateFromSseJsonStream( + Stream stream, + Func, JsonElement, IEnumerable> multiElementDeserializer, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (ServerSentEvent sseEvent in EnumerateServerSentEvents(stream, cancellationToken)) + { + if (IsWellKnownDoneToken(sseEvent.Data)) continue; + using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Data); + foreach (T item in multiElementDeserializer(sseEvent.EventName, sseDocument.RootElement)) + { + yield return item; + } + } + } + private static bool IsWellKnownDoneToken(ReadOnlyMemory data) { ReadOnlyMemory[] wellKnownTokens = diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index 2f5facda8..a9de0c30d 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -18,11 +18,11 @@ public SseReader(Stream stream) _reader = new StreamReader(stream); } - public SseEvent? TryGetNextEvent() + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) { - List fields = []; + List fields = []; - while (true) + while (!cancellationToken.IsCancellationRequested) { string line = _reader.ReadLine(); if (line == null) @@ -33,7 +33,7 @@ public SseReader(Stream stream) else if (line.Length == 0) { // An empty line should dispatch an event for pending accumulated fields - SseEvent nextEvent = new(fields); + ServerSentEvent nextEvent = new(fields); fields = []; return nextEvent; } @@ -45,14 +45,16 @@ public SseReader(Stream stream) else { // Otherwise, process the the field + value and accumulate it for the next dispatched event - fields.Add(new SseEventField(line)); + fields.Add(new ServerSentEventField(line)); } } + + return null; } - public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) { - List fields = []; + List fields = []; while (!cancellationToken.IsCancellationRequested) { @@ -65,7 +67,7 @@ public SseReader(Stream stream) else if (line.Length == 0) { // An empty line should dispatch an event for pending accumulated fields - SseEvent nextEvent = new(fields); + ServerSentEvent nextEvent = new(fields); fields = []; return nextEvent; } @@ -77,9 +79,10 @@ public SseReader(Stream stream) else { // Otherwise, process the the field + value and accumulate it for the next dispatched event - fields.Add(new SseEventField(line)); + fields.Add(new ServerSentEventField(line)); } } + return null; } From 3996600277a2d318777214b654e50c934e0b3eb5 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Tue, 19 Mar 2024 10:52:46 -0700 Subject: [PATCH 07/15] SSE tidy pass 2 --- .dotnet/src/Utility/ServerSentEvent.cs | 4 ++++ .dotnet/src/Utility/SseReader.cs | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/.dotnet/src/Utility/ServerSentEvent.cs b/.dotnet/src/Utility/ServerSentEvent.cs index b47e3bceb..ea91b1889 100644 --- a/.dotnet/src/Utility/ServerSentEvent.cs +++ b/.dotnet/src/Utility/ServerSentEvent.cs @@ -7,9 +7,13 @@ namespace OpenAI; // SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream internal readonly struct ServerSentEvent { + // Gets the value of the SSE "event type" buffer, used to distinguish between event kinds. public ReadOnlyMemory EventName { get; } + // Gets the value of the SSE "data" buffer, which holds the payload of the server-sent event. public ReadOnlyMemory Data { get; } + // Gets the value of the "last event ID" buffer, with which a user agent can reestablish a session. public ReadOnlyMemory LastEventId { get; } + // If present, gets the defined "retry" value for the event, which represents the delay before reconnecting. public TimeSpan? ReconnectionTime { get; } private readonly IReadOnlyList _fields; diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index a9de0c30d..cab725caf 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -18,6 +18,14 @@ public SseReader(Stream stream) _reader = new StreamReader(stream); } + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) { List fields = []; @@ -52,6 +60,14 @@ public SseReader(Stream stream) return null; } + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) { List fields = []; From 4c355d5bd93b72598040fc4a5bd040ecb15d81e0 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 14:13:57 -0700 Subject: [PATCH 08/15] updates --- .dotnet/src/Custom/Assistants/AssistantClient.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index b16329231..7de8404cf 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -11,7 +11,6 @@ namespace OpenAI.Assistants; /// /// The service client for OpenAI assistants. /// -[Experimental("OPENAI001")] public partial class AssistantClient { private OpenAIClientConnector _clientConnector; From 4f9e8322f7ba4e7a193496cb5c58798a81467068 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 14:26:59 -0700 Subject: [PATCH 09/15] first steps --- .dotnet/src/Custom/Chat/ChatClient.cs | 67 ++++++------- .../src/Custom/Chat/StreamingChatUpdate.cs | 12 ++- .dotnet/src/Utility/SseReader.cs | 8 +- .../src/Utility/StreamingClientResultOfT.cs | 28 ++++++ .dotnet/src/Utility/StreamingResult.cs | 95 ------------------- 5 files changed, 73 insertions(+), 137 deletions(-) create mode 100644 .dotnet/src/Utility/StreamingClientResultOfT.cs delete mode 100644 .dotnet/src/Utility/StreamingResult.cs diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index eeeeeacf9..a4b7b1352 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -40,8 +40,8 @@ public ChatClient(string model, ApiKeyCredential credential = default, OpenAICli /// The user message to provide as a prompt for chat completion. /// Additional options for the chat completion request. /// A result for a single chat completion. - public virtual ClientResult CompleteChat(string message, ChatCompletionOptions options = null) - => CompleteChat(new List() { new ChatRequestUserMessage(message) }, options); + public virtual ClientResult CompleteChat(string message, ChatCompletionOptions options = null) + => CompleteChat(new List() { new ChatRequestUserMessage(message) }, options); /// /// Generates a single chat completion result for a single, simple user message. @@ -49,9 +49,9 @@ public virtual ClientResult CompleteChat(string message, ChatCom /// The user message to provide as a prompt for chat completion. /// Additional options for the chat completion request. /// A result for a single chat completion. - public virtual Task> CompleteChatAsync(string message, ChatCompletionOptions options = null) - => CompleteChatAsync( - new List() { new ChatRequestUserMessage(message) }, options); + public virtual Task> CompleteChatAsync(string message, ChatCompletionOptions options = null) + => CompleteChatAsync( + new List() { new ChatRequestUserMessage(message) }, options); /// /// Generates a single chat completion result for a provided set of input chat messages. @@ -63,7 +63,7 @@ public virtual ClientResult CompleteChat( IEnumerable messages, ChatCompletionOptions options = null) { - Internal.Models.CreateChatCompletionRequest request = CreateInternalRequest(messages, options); + Internal.Models.CreateChatCompletionRequest request = CreateInternalRequest(messages, options); ClientResult response = Shim.CreateChatCompletion(request); ChatCompletion chatCompletion = new(response.Value, internalChoiceIndex: 0); return ClientResult.FromValue(chatCompletion, response.GetRawResponse()); @@ -93,7 +93,6 @@ public virtual async Task> CompleteChatAsync( /// The number of independent, alternative response choices that should be generated. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A result for a single chat completion. public virtual ClientResult CompleteChat( IEnumerable messages, @@ -147,14 +146,14 @@ public virtual async Task> CompleteChatAs /// /// Additional options for the chat completion request. /// A streaming result with incremental chat completion updates. - public virtual StreamingClientResult CompleteChatStreaming( - string message, - int? choiceCount = null, - ChatCompletionOptions options = null) - => CompleteChatStreaming( - new List { new ChatRequestUserMessage(message) }, - choiceCount, - options); + public virtual StreamingClientResult CompleteChatStreaming( + string message, + int? choiceCount = null, + ChatCompletionOptions options = null) + => CompleteChatStreaming( + new List { new ChatRequestUserMessage(message) }, + choiceCount, + options); /// /// Begins a streaming response for a chat completion request using a single, simple user message as input. @@ -191,29 +190,25 @@ public virtual Task> CompleteChatStre /// The number of independent, alternative choices that the chat completion request should generate. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A streaming result with incremental chat completion updates. public virtual StreamingClientResult CompleteChatStreaming( IEnumerable messages, int? choiceCount = null, ChatCompletionOptions options = null) { - PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); - requestMessage.BufferResponse = false; - Shim.Pipeline.Send(requestMessage); - PipelineResponse response = requestMessage.ExtractResponse(); + PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options); + message.BufferResponse = false; + + Shim.Pipeline.Send(message); + + PipelineResponse response = message.Response; if (response.IsError) { throw new ClientResultException(response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( - responseForEnumeration.GetRawResponse().ContentStream, - StreamingChatUpdate.DeserializeSseChatUpdates)); + return new StreamingChatResult(response); } /// @@ -229,29 +224,25 @@ public virtual StreamingClientResult CompleteChatStreaming( /// The number of independent, alternative choices that the chat completion request should generate. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A streaming result with incremental chat completion updates. public virtual async Task> CompleteChatStreamingAsync( IEnumerable messages, int? choiceCount = null, ChatCompletionOptions options = null) { - PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); - requestMessage.BufferResponse = false; - await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false); - PipelineResponse response = requestMessage.ExtractResponse(); + PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options); + message.BufferResponse = false; + + await Shim.Pipeline.SendAsync(message).ConfigureAwait(false); + + PipelineResponse response = message.Response; if (response.IsError) { throw new ClientResultException(response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( - responseForEnumeration.GetRawResponse().ContentStream, - StreamingChatUpdate.DeserializeSseChatUpdates)); + return new StreamingChatResult(response); } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( @@ -326,4 +317,4 @@ private PipelineMessage CreateCustomRequestMessage(IEnumerable _responseErrorClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); -} +} \ No newline at end of file diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index f7e9e9ec4..11e50eabe 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -1,9 +1,9 @@ -namespace OpenAI.Chat; - using System; using System.Collections.Generic; using System.Text.Json; +namespace OpenAI.Chat; + /// /// Represents an incremental item of new data in a streaming response to a chat completion request. /// @@ -184,11 +184,17 @@ internal StreamingChatUpdate( internal static IEnumerable DeserializeSseChatUpdates(ReadOnlyMemory _, JsonElement sseDataJson) { + // TODO: would another enumerable implementation be more performant than list? List results = []; + + // TODO: Do we need to validate that we didn't get null or empty? + // What's the contract for the JSON updates? + if (sseDataJson.ValueKind == JsonValueKind.Null) { return results; } + string id = default; DateTimeOffset created = default; string systemFingerprint = null; @@ -333,4 +339,4 @@ Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs } return results; } -} +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index cab725caf..fef0d6d7b 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -6,7 +6,7 @@ namespace OpenAI; -internal sealed class SseReader : IDisposable +internal sealed class SseReader : IDisposable, IAsyncDisposable { private readonly Stream _stream; private readonly StreamReader _reader; @@ -121,4 +121,10 @@ private void Dispose(bool disposing) _disposedValue = true; } } + + ValueTask IAsyncDisposable.DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + return new ValueTask(); + } } \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingClientResultOfT.cs b/.dotnet/src/Utility/StreamingClientResultOfT.cs new file mode 100644 index 000000000..6d2d68225 --- /dev/null +++ b/.dotnet/src/Utility/StreamingClientResultOfT.cs @@ -0,0 +1,28 @@ +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; + +namespace OpenAI; + +#pragma warning disable CS1591 // public XML comments + +/// +/// Represents an operation response with streaming content that can be deserialized and enumerated while the response +/// is still being received. +/// +/// The data type representative of distinct, streamable items. +// TODO: Revisit the IDisposable question +public abstract class StreamingClientResult : ClientResult, IAsyncEnumerable +{ + protected StreamingClientResult(PipelineResponse response) : base(response) + { + } + + // Note that if the implementation disposes the stream, the caller can only + // enumerate the results once. I think this makes sense, but we should + // make sure architects agree. + public abstract IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default); +} + +#pragma warning restore CS1591 // public XML comments \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingResult.cs b/.dotnet/src/Utility/StreamingResult.cs deleted file mode 100644 index a1b6ff538..000000000 --- a/.dotnet/src/Utility/StreamingResult.cs +++ /dev/null @@ -1,95 +0,0 @@ -using System.ClientModel; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Threading; -using System.Collections.Generic; -using System; - -namespace OpenAI; - -/// -/// Represents an operation response with streaming content that can be deserialized and enumerated while the response -/// is still being received. -/// -/// The data type representative of distinct, streamable items. -public class StreamingClientResult - : IDisposable - , IAsyncEnumerable -{ - private ClientResult _rawResult { get; } - private IAsyncEnumerable _asyncEnumerableSource { get; } - private bool _disposedValue { get; set; } - - private StreamingClientResult() { } - - private StreamingClientResult( - ClientResult rawResult, - Func> asyncEnumerableProcessor) - { - _rawResult = rawResult; - _asyncEnumerableSource = asyncEnumerableProcessor.Invoke(rawResult); - } - - /// - /// Creates a new instance of using the provided underlying HTTP response. The - /// provided function will be used to resolve the response into an asynchronous enumeration of streamed response - /// items. - /// - /// The HTTP response. - /// - /// The function that will resolve the provided response into an IAsyncEnumerable. - /// - /// - /// A new instance of that will be capable of asynchronous enumeration of - /// items from the HTTP response. - /// - internal static StreamingClientResult CreateFromResponse( - ClientResult result, - Func> asyncEnumerableProcessor) - { - return new(result, asyncEnumerableProcessor); - } - - /// - /// Gets the underlying instance that this may enumerate - /// over. - /// - /// The instance attached to this . - public PipelineResponse GetRawResponse() => _rawResult.GetRawResponse(); - - /// - /// Gets the asynchronously enumerable collection of distinct, streamable items in the response. - /// - /// - /// The return value of this method may be used with the "await foreach" statement. - /// - /// As explicitly implements , callers may - /// enumerate a instance directly instead of calling this method. - /// - /// - /// - public IAsyncEnumerable EnumerateValues() => this; - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _rawResult?.GetRawResponse()?.Dispose(); - } - _disposedValue = true; - } - } - - IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) - => _asyncEnumerableSource.GetAsyncEnumerator(cancellationToken); -} \ No newline at end of file From 704c617859729389d37629a22601256f85d4ea2a Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 14:29:15 -0700 Subject: [PATCH 10/15] add prior chat implementation --- .../src/Custom/Chat/StreamingChatResult.cs | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 .dotnet/src/Custom/Chat/StreamingChatResult.cs diff --git a/.dotnet/src/Custom/Chat/StreamingChatResult.cs b/.dotnet/src/Custom/Chat/StreamingChatResult.cs new file mode 100644 index 000000000..3a5be633b --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatResult.cs @@ -0,0 +1,100 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI.Chat; + +#nullable enable + +internal class StreamingChatResult : StreamingClientResult +{ + public StreamingChatResult(PipelineResponse response) : base(response) + { + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + // Note: this implementation disposes the stream after the caller has + // enumerated the elements obtained from the stream. That is to say, + // the `await foreach` loop can only happen once -- if it is tried a + // second time, the caller will get an ObjectDisposedException trying + // to access a disposed Stream. + using PipelineResponse response = GetRawResponse(); + + // Extract the content stream from the response to obtain dispose + // ownership of it. This means the content stream will not be disposed + // when the response is disposed. + Stream contentStream = response.ContentStream ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream."); + response.ContentStream = null; + + return new ChatUpdateEnumerator(contentStream); + } + + private class ChatUpdateEnumerator : IAsyncEnumerator + { + private readonly SseReader _sseReader; + + private List? _currentUpdates; + private int _currentUpdateIndex; + + public ChatUpdateEnumerator(Stream stream) + { + _sseReader = new(stream); + } + + public StreamingChatUpdate Current => throw new NotImplementedException(); + + public async ValueTask MoveNextAsync() + { + // TODO: How to handle the CancellationToken? + + if (_currentUpdates is not null && _currentUpdateIndex < _currentUpdates.Count) + { + _currentUpdateIndex++; + return true; + } + + // We either don't have any stored updates, or we've exceeded the + // count of the ones we have. Get the next set. + + // TODO: Call different configure await variant in this context, or no? + + // TODO: Update to new reader APIs + SseLine? sseEvent = await _sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); + if (sseEvent is null) + { + // TODO: does this mean we're done or not? + return false; + } + + ReadOnlyMemory name = sseEvent.Value.FieldName; + if (!name.Span.SequenceEqual("data".AsSpan())) + { + throw new InvalidDataException(); + } + + ReadOnlyMemory value = sseEvent.Value.FieldValue; + if (value.Span.SequenceEqual("[DONE]".AsSpan())) + { + // enumerator semantics are that MoveNextAsync returns false when done. + return false; + } + + // TODO:optimize performance using Utf8JsonReader? + using JsonDocument sseMessageJson = JsonDocument.Parse(value); + _currentUpdates = StreamingChatUpdate.DeserializeSseChatUpdates(/*TODO: update*/ sseMessageJson.RootElement); + return true; + } + + public ValueTask DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + _sseReader?.Dispose(); + return new ValueTask(); + } + } +} \ No newline at end of file From abfa53cf7aae852264a08e52803c8f17ea9f409d Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 14:41:37 -0700 Subject: [PATCH 11/15] make reader use enumerables --- .../src/Custom/Assistants/AssistantClient.cs | 18 ++++---- .dotnet/src/Utility/SseAsyncEnumerator.cs | 43 +++++-------------- .dotnet/src/Utility/SseReader.cs | 17 ++++---- 3 files changed, 29 insertions(+), 49 deletions(-) diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index 7de8404cf..3f05ed243 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -233,7 +233,7 @@ public virtual async Task> CreateThreadAsync( return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); } - public virtual ClientResult GetThread(string threadId) + public virtual ClientResult GetThread(string threadId) { ClientResult internalResult = ThreadShim.GetThread(threadId); return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); @@ -268,13 +268,13 @@ public virtual async Task> ModifyThreadAsync( return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); } - public virtual ClientResult DeleteThread(string threadId) + public virtual ClientResult DeleteThread(string threadId) { ClientResult internalResult = ThreadShim.DeleteThread(threadId); return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse()); } - public virtual async Task> DeleteThreadAsync(string threadId) + public virtual async Task> DeleteThreadAsync(string threadId) { ClientResult internalResult = await ThreadShim.DeleteThreadAsync(threadId).ConfigureAwait(false); return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse()); @@ -464,7 +464,7 @@ public virtual StreamingClientResult CreateRunStreaming( string assistantId, RunCreationOptions options = null) { - PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); + using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); RunShim.Pipeline.Send(message); return CreateStreamingRunResult(message); } @@ -474,7 +474,7 @@ public virtual async Task> CreateRunStrea string assistantId, RunCreationOptions options = null) { - PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); + using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); await RunShim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } @@ -506,7 +506,7 @@ public virtual StreamingClientResult CreateThreadAndRunStreamin ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) { - PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); Shim.Pipeline.Send(message); return CreateStreamingRunResult(message); } @@ -516,7 +516,7 @@ public virtual async Task> CreateThreadAn ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) { - PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); await Shim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } @@ -621,14 +621,14 @@ public virtual async Task> SubmitToolOutputsAsync(string public virtual StreamingClientResult SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable toolOutputs) { - PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); Shim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } public virtual async Task> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable toolOutputs) { - PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); await Shim.Pipeline.SendAsync(message); return CreateStreamingRunResult(message); } diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index 5bd8c7efa..31e2c522d 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -10,42 +10,24 @@ namespace OpenAI; internal static class SseAsyncEnumerator { - internal static async IAsyncEnumerable EnumerateServerSentEvents( - Stream stream, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - try - { - using SseReader sseReader = new(stream); - while (!cancellationToken.IsCancellationRequested) - { - ServerSentEvent? sseEvent = await sseReader.TryGetNextEventAsync(cancellationToken).ConfigureAwait(false); - if (sseEvent is null) - { - break; - } - else - { - yield return sseEvent.Value; - } - } - } - finally - { - // Always dispose the stream immediately once enumeration is complete for any reason - stream.Dispose(); - } - } + private static ReadOnlyMemory[] _wellKnownTokens = + [ + "[DONE]".AsMemory(), + ]; internal static async IAsyncEnumerable EnumerateFromSseJsonStream( Stream stream, Func, JsonElement, IEnumerable> multiElementDeserializer, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (ServerSentEvent sseEvent in EnumerateServerSentEvents(stream, cancellationToken)) + using SseReader reader = new SseReader(stream); + + await foreach (ServerSentEvent sseEvent in reader.GetEventsAsync(cancellationToken)) { if (IsWellKnownDoneToken(sseEvent.Data)) continue; + using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Data); + foreach (T item in multiElementDeserializer(sseEvent.EventName, sseDocument.RootElement)) { yield return item; @@ -55,10 +37,7 @@ internal static async IAsyncEnumerable EnumerateFromSseJsonStream( private static bool IsWellKnownDoneToken(ReadOnlyMemory data) { - ReadOnlyMemory[] wellKnownTokens = - [ - "[DONE]".AsMemory(), - ]; - return wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); + // TODO: Make faster than LINQ. + return _wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); } } \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index fef0d6d7b..f7d3a96a7 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -26,7 +27,7 @@ public SseReader(Stream stream) /// /// The next in the stream, or null once no more data can be read from the stream. /// - public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + public IEnumerable GetEvents(CancellationToken cancellationToken = default) { List fields = []; @@ -36,14 +37,14 @@ public SseReader(Stream stream) if (line == null) { // A null line indicates end of input - return null; + yield break; } else if (line.Length == 0) { // An empty line should dispatch an event for pending accumulated fields ServerSentEvent nextEvent = new(fields); fields = []; - return nextEvent; + yield return nextEvent; } else if (line[0] == ':') { @@ -57,7 +58,7 @@ public SseReader(Stream stream) } } - return null; + yield break; } /// @@ -68,7 +69,7 @@ public SseReader(Stream stream) /// /// The next in the stream, or null once no more data can be read from the stream. /// - public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { List fields = []; @@ -78,14 +79,14 @@ public SseReader(Stream stream) if (line == null) { // A null line indicates end of input - return null; + yield break; } else if (line.Length == 0) { // An empty line should dispatch an event for pending accumulated fields ServerSentEvent nextEvent = new(fields); fields = []; - return nextEvent; + yield return nextEvent; } else if (line[0] == ':') { @@ -99,7 +100,7 @@ public SseReader(Stream stream) } } - return null; + yield break; } public void Dispose() From 32c122b4b25b53e008bbed96fd2a03f1f7fc7eca Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 15:45:50 -0700 Subject: [PATCH 12/15] backup of current idea-in-progress --- .../src/Custom/Chat/StreamingChatResult.cs | 43 ++-- .../src/Custom/Chat/StreamingChatUpdate.cs | 172 ++----------- .../Chat/StreamingChatUpdateCollection.cs | 177 ++++++++++++++ .dotnet/src/Utility/SseAsyncEnumerator.cs | 5 +- .dotnet/src/Utility/SseReader.cs | 230 ++++++++++++------ .../src/Utility/StreamedEventCollection.cs | 33 +++ .dotnet/tests/TestScenarios/AssistantTests.cs | 10 +- 7 files changed, 410 insertions(+), 260 deletions(-) create mode 100644 .dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs create mode 100644 .dotnet/src/Utility/StreamedEventCollection.cs diff --git a/.dotnet/src/Custom/Chat/StreamingChatResult.cs b/.dotnet/src/Custom/Chat/StreamingChatResult.cs index 3a5be633b..542470b00 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatResult.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatResult.cs @@ -2,7 +2,6 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.IO; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -36,22 +35,24 @@ public override IAsyncEnumerator GetAsyncEnumerator(Cancell private class ChatUpdateEnumerator : IAsyncEnumerator { - private readonly SseReader _sseReader; + private readonly IAsyncEnumerator _sseEvents; private List? _currentUpdates; private int _currentUpdateIndex; public ChatUpdateEnumerator(Stream stream) { - _sseReader = new(stream); + AsyncSseReader reader = new AsyncSseReader(stream); + + // TODO: Pass CancellationToken. + _sseEvents = reader.GetAsyncEnumerator(); } public StreamingChatUpdate Current => throw new NotImplementedException(); public async ValueTask MoveNextAsync() { - // TODO: How to handle the CancellationToken? - + // Still have leftovers from the last event we pulled from the reader. if (_currentUpdates is not null && _currentUpdateIndex < _currentUpdates.Count) { _currentUpdateIndex++; @@ -62,38 +63,30 @@ public async ValueTask MoveNextAsync() // count of the ones we have. Get the next set. // TODO: Call different configure await variant in this context, or no? - - // TODO: Update to new reader APIs - SseLine? sseEvent = await _sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent is null) + if (!await _sseEvents.MoveNextAsync().ConfigureAwait(false)) { - // TODO: does this mean we're done or not? + // Done with events from the stream. return false; } - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - { - throw new InvalidDataException(); - } + ServerSentEvent ssEvent = _sseEvents.Current; - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - // enumerator semantics are that MoveNextAsync returns false when done. - return false; - } + // TODO: optimize + BinaryData data = BinaryData.FromString(new string(ssEvent.Data.ToArray())); + + // TODO: don't instantiate every time. + StreamingChatUpdateCollection justToCreate = new StreamingChatUpdateCollection(); + + _currentUpdates = justToCreate.Create(data, ModelReaderWriterOptions.Json); + _currentUpdateIndex = 0; - // TODO:optimize performance using Utf8JsonReader? - using JsonDocument sseMessageJson = JsonDocument.Parse(value); - _currentUpdates = StreamingChatUpdate.DeserializeSseChatUpdates(/*TODO: update*/ sseMessageJson.RootElement); return true; } public ValueTask DisposeAsync() { // TODO: revisit per platforms where async dispose is available. - _sseReader?.Dispose(); + _sseEvents?.DisposeAsync(); return new ValueTask(); } } diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index 11e50eabe..604ff5110 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -1,4 +1,5 @@ using System; +using System.ClientModel.Primitives; using System.Collections.Generic; using System.Text.Json; @@ -7,7 +8,7 @@ namespace OpenAI.Chat; /// /// Represents an incremental item of new data in a streaming response to a chat completion request. /// -public partial class StreamingChatUpdate +public partial class StreamingChatUpdate : IJsonModel { /// /// Gets a unique identifier associated with this streamed Chat Completions response. @@ -182,161 +183,28 @@ internal StreamingChatUpdate( LogProbabilities = logProbabilities; } - internal static IEnumerable DeserializeSseChatUpdates(ReadOnlyMemory _, JsonElement sseDataJson) + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) { - // TODO: would another enumerable implementation be more performant than list? - List results = []; + throw new NotImplementedException(); + } - // TODO: Do we need to validate that we didn't get null or empty? - // What's the contract for the JSON updates? + public StreamingChatUpdate Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } - if (sseDataJson.ValueKind == JsonValueKind.Null) - { - return results; - } + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } - string id = default; - DateTimeOffset created = default; - string systemFingerprint = null; - foreach (JsonProperty property in sseDataJson.EnumerateObject()) - { - if (property.NameEquals("id"u8)) - { - id = property.Value.GetString(); - continue; - } - if (property.NameEquals("created"u8)) - { - created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - continue; - } - if (property.NameEquals("system_fingerprint")) - { - systemFingerprint = property.Value.GetString(); - continue; - } - if (property.NameEquals("choices"u8)) - { - foreach (JsonElement choiceElement in property.Value.EnumerateArray()) - { - ChatRole? role = null; - string contentUpdate = null; - string functionName = null; - string functionArgumentsUpdate = null; - int choiceIndex = 0; - ChatFinishReason? finishReason = null; - List toolCallUpdates = []; - ChatLogProbabilityCollection logProbabilities = null; + public StreamingChatUpdate Create(BinaryData data, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } - foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) - { - if (choiceProperty.NameEquals("index"u8)) - { - choiceIndex = choiceProperty.Value.GetInt32(); - continue; - } - if (choiceProperty.NameEquals("finish_reason"u8)) - { - if (choiceProperty.Value.ValueKind == JsonValueKind.Null) - { - finishReason = null; - continue; - } - finishReason = choiceProperty.Value.GetString() switch - { - "stop" => ChatFinishReason.Stopped, - "length" => ChatFinishReason.Length, - "tool_calls" => ChatFinishReason.ToolCalls, - "function_call" => ChatFinishReason.FunctionCall, - "content_filter" => ChatFinishReason.ContentFilter, - _ => throw new ArgumentException(nameof(finishReason)), - }; - continue; - } - if (choiceProperty.NameEquals("delta"u8)) - { - foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) - { - if (deltaProperty.NameEquals("role"u8)) - { - role = deltaProperty.Value.GetString() switch - { - "system" => ChatRole.System, - "user" => ChatRole.User, - "assistant" => ChatRole.Assistant, - "tool" => ChatRole.Tool, - "function" => ChatRole.Function, - _ => throw new ArgumentException(nameof(role)), - }; - continue; - } - if (deltaProperty.NameEquals("content"u8)) - { - contentUpdate = deltaProperty.Value.GetString(); - continue; - } - if (deltaProperty.NameEquals("function_call"u8)) - { - foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) - { - if (functionProperty.NameEquals("name"u8)) - { - functionName = functionProperty.Value.GetString(); - continue; - } - if (functionProperty.NameEquals("arguments"u8)) - { - functionArgumentsUpdate = functionProperty.Value.GetString(); - } - } - } - if (deltaProperty.NameEquals("tool_calls")) - { - foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) - { - toolCallUpdates.Add( - StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); - } - } - } - } - if (choiceProperty.NameEquals("logprobs"u8)) - { - Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs - = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( - choiceProperty.Value); - logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); - } - } - // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate - // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. - if (toolCallUpdates.Count == 0) - { - toolCallUpdates.Add(null); - } - foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) - { - results.Add(new StreamingChatUpdate( - id, - created, - systemFingerprint, - choiceIndex, - role, - contentUpdate, - finishReason, - functionName, - functionArgumentsUpdate, - toolCallUpdate, - logProbabilities)); - } - } - continue; - } - } - if (results.Count == 0) - { - results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); - } - return results; + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); } } \ No newline at end of file diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs new file mode 100644 index 000000000..8b7466848 --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs @@ -0,0 +1,177 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +namespace OpenAI.Chat; + +internal class StreamingChatUpdateCollection : StreamedEventCollection +{ + internal static StreamingChatUpdateCollection DeserializeSseChatUpdates(ReadOnlyMemory _, JsonElement sseDataJson) + { + // TODO: would another enumerable implementation be more performant than list? + StreamingChatUpdateCollection results = []; + + // TODO: Do we need to validate that we didn't get null or empty? + // What's the contract for the JSON updates? + + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return results; + } + + string id = default; + DateTimeOffset created = default; + string systemFingerprint = null; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("created"u8)) + { + created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("system_fingerprint")) + { + systemFingerprint = property.Value.GetString(); + continue; + } + if (property.NameEquals("choices"u8)) + { + foreach (JsonElement choiceElement in property.Value.EnumerateArray()) + { + ChatRole? role = null; + string contentUpdate = null; + string functionName = null; + string functionArgumentsUpdate = null; + int choiceIndex = 0; + ChatFinishReason? finishReason = null; + List toolCallUpdates = []; + ChatLogProbabilityCollection logProbabilities = null; + + foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) + { + if (choiceProperty.NameEquals("index"u8)) + { + choiceIndex = choiceProperty.Value.GetInt32(); + continue; + } + if (choiceProperty.NameEquals("finish_reason"u8)) + { + if (choiceProperty.Value.ValueKind == JsonValueKind.Null) + { + finishReason = null; + continue; + } + finishReason = choiceProperty.Value.GetString() switch + { + "stop" => ChatFinishReason.Stopped, + "length" => ChatFinishReason.Length, + "tool_calls" => ChatFinishReason.ToolCalls, + "function_call" => ChatFinishReason.FunctionCall, + "content_filter" => ChatFinishReason.ContentFilter, + _ => throw new ArgumentException(nameof(finishReason)), + }; + continue; + } + if (choiceProperty.NameEquals("delta"u8)) + { + foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) + { + if (deltaProperty.NameEquals("role"u8)) + { + role = deltaProperty.Value.GetString() switch + { + "system" => ChatRole.System, + "user" => ChatRole.User, + "assistant" => ChatRole.Assistant, + "tool" => ChatRole.Tool, + "function" => ChatRole.Function, + _ => throw new ArgumentException(nameof(role)), + }; + continue; + } + if (deltaProperty.NameEquals("content"u8)) + { + contentUpdate = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("function_call"u8)) + { + foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) + { + if (functionProperty.NameEquals("name"u8)) + { + functionName = functionProperty.Value.GetString(); + continue; + } + if (functionProperty.NameEquals("arguments"u8)) + { + functionArgumentsUpdate = functionProperty.Value.GetString(); + } + } + } + if (deltaProperty.NameEquals("tool_calls")) + { + foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) + { + toolCallUpdates.Add( + StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); + } + } + } + } + if (choiceProperty.NameEquals("logprobs"u8)) + { + Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs + = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( + choiceProperty.Value); + logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); + } + } + // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate + // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. + if (toolCallUpdates.Count == 0) + { + toolCallUpdates.Add(null); + } + foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) + { + results.Add(new StreamingChatUpdate( + id, + created, + systemFingerprint, + choiceIndex, + role, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + toolCallUpdate, + logProbabilities)); + } + } + continue; + } + } + if (results.Count == 0) + { + results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); + } + return results; + } + + public override StreamedEventCollection Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public override StreamedEventCollection Create(BinaryData data, ModelReaderWriterOptions options) + { + using JsonDocument doc = JsonDocument.Parse(data); + return DeserializeSseChatUpdates(null, doc.RootElement); + } +} diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index 31e2c522d..7588136b3 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -20,12 +20,15 @@ internal static async IAsyncEnumerable EnumerateFromSseJsonStream( Func, JsonElement, IEnumerable> multiElementDeserializer, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using SseReader reader = new SseReader(stream); + using AsyncSseReader reader = new AsyncSseReader(stream); await foreach (ServerSentEvent sseEvent in reader.GetEventsAsync(cancellationToken)) { + // TODO: does `continue` mean we keep reading from the Stream after + // the [DONE] event? If so, figure out if this is what we want here. if (IsWellKnownDoneToken(sseEvent.Data)) continue; + // TODO: Make faster with Utf8JsonReader, IModel? using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Data); foreach (T item in multiElementDeserializer(sseEvent.EventName, sseDocument.RootElement)) diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index f7d3a96a7..cdefe3b3d 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,106 +1,183 @@ using System; using System.Collections.Generic; using System.IO; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace OpenAI; -internal sealed class SseReader : IDisposable, IAsyncDisposable +#nullable enable + +// TODO: Can this type move its Dispose implementation into the +// Enumerator? +internal sealed class AsyncSseReader : IAsyncEnumerable, IDisposable, IAsyncDisposable { private readonly Stream _stream; - private readonly StreamReader _reader; + private bool _disposedValue; - public SseReader(Stream stream) + public AsyncSseReader(Stream stream) { _stream = stream; - _reader = new StreamReader(stream); } - /// - /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. - /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public IEnumerable GetEvents(CancellationToken cancellationToken = default) - { - List fields = []; + // TODO: Provide sync version + //// TODO: reuse code across sync and async. - while (!cancellationToken.IsCancellationRequested) - { - string line = _reader.ReadLine(); - if (line == null) - { - // A null line indicates end of input - yield break; - } - else if (line.Length == 0) - { - // An empty line should dispatch an event for pending accumulated fields - ServerSentEvent nextEvent = new(fields); - fields = []; - yield return nextEvent; - } - else if (line[0] == ':') - { - // A line beginning with a colon is a comment and should be ignored - continue; - } - else - { - // Otherwise, process the the field + value and accumulate it for the next dispatched event - fields.Add(new ServerSentEventField(line)); - } - } + ///// + ///// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + ///// available and returning null once no further data is present on the stream. + ///// + ///// An optional cancellation token that can abort subsequent reads. + ///// + ///// The next in the stream, or null once no more data can be read from the stream. + ///// + //public IEnumerable GetEvents(CancellationToken cancellationToken = default) + //{ + // List fields = []; + + // while (!cancellationToken.IsCancellationRequested) + // { + // string line = _reader.ReadLine(); + // if (line == null) + // { + // // A null line indicates end of input + // yield break; + // } + // else if (line.Length == 0) + // { + // // An empty line should dispatch an event for pending accumulated fields + // ServerSentEvent nextEvent = new(fields); + // fields = []; + // yield return nextEvent; + // } + // else if (line[0] == ':') + // { + // // A line beginning with a colon is a comment and should be ignored + // continue; + // } + // else + // { + // // Otherwise, process the the field + value and accumulate it for the next dispatched event + // fields.Add(new ServerSentEventField(line)); + // } + // } + + // yield break; + //} - yield break; + ///// + ///// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + ///// available and returning null once no further data is present on the stream. + ///// + ///// An optional cancellation token that can abort subsequent reads. + ///// + ///// The next in the stream, or null once no more data can be read from the stream. + ///// + //public async IAsyncEnumerable GetEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + //{ + // List fields = []; + + // while (!cancellationToken.IsCancellationRequested) + // { + // string line = await _reader.ReadLineAsync().ConfigureAwait(false); + // if (line == null) + // { + // // A null line indicates end of input + // yield break; + // } + // else if (line.Length == 0) + // { + // // An empty line should dispatch an event for pending accumulated fields + // ServerSentEvent nextEvent = new(fields); + // fields = []; + // yield return nextEvent; + // } + // else if (line[0] == ':') + // { + // // A line beginning with a colon is a comment and should be ignored + // continue; + // } + // else + // { + // // Otherwise, process the the field + value and accumulate it for the next dispatched event + // fields.Add(new ServerSentEventField(line)); + // } + // } + + // yield break; + //} + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new AsyncSseReaderEnumerator(_stream); } - /// - /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. - /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public async IAsyncEnumerable GetEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + private class AsyncSseReaderEnumerator : IAsyncEnumerator { - List fields = []; + private readonly Stream _stream; + private readonly StreamReader _reader; - while (!cancellationToken.IsCancellationRequested) + private ServerSentEvent? _current; + + public AsyncSseReaderEnumerator(Stream stream) { - string line = await _reader.ReadLineAsync().ConfigureAwait(false); - if (line == null) - { - // A null line indicates end of input - yield break; - } - else if (line.Length == 0) - { - // An empty line should dispatch an event for pending accumulated fields - ServerSentEvent nextEvent = new(fields); - fields = []; - yield return nextEvent; - } - else if (line[0] == ':') - { - // A line beginning with a colon is a comment and should be ignored - continue; - } - else + _stream = stream; + _reader = new StreamReader(stream); + } + + // TODO: recall proper semantics for Current and fix null issues. + public ServerSentEvent Current => _current!.Value; + + public async ValueTask MoveNextAsync() + { + // TODO: don't reallocate every call to MoveNext + // TODO: UTF-8 all the way down possible here? + List fields = []; + + // TODO: How to handle the CancellationToken? + + // TODO: Call different ConfigureAwait variant in this context, or no? + while (true) { - // Otherwise, process the the field + value and accumulate it for the next dispatched event - fields.Add(new ServerSentEventField(line)); + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + + if (line == null) + { + // A null line indicates end of input + return false; + } + + // TODO: another way to rework this for perf, clarity, or is + // this optimal? + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + _current = nextEvent; + return true; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } } } - yield break; + public ValueTask DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + _stream?.Dispose(); + _reader?.Dispose(); + return new ValueTask(); + } } public void Dispose() @@ -115,7 +192,6 @@ private void Dispose(bool disposing) { if (disposing) { - _reader.Dispose(); _stream.Dispose(); } diff --git a/.dotnet/src/Utility/StreamedEventCollection.cs b/.dotnet/src/Utility/StreamedEventCollection.cs new file mode 100644 index 000000000..d9584fa1d --- /dev/null +++ b/.dotnet/src/Utility/StreamedEventCollection.cs @@ -0,0 +1,33 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace OpenAI; + +// TODO: Make it work for non-JSON models too. +// TODO: The list part is a hack - pull it out +internal abstract class StreamedEventCollection : List, IJsonModel> + // This just promises me I can deserialize a collection of serializable models. + // Should it be more general? + where TModel : IJsonModel +{ + public abstract StreamedEventCollection Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options); + + public abstract StreamedEventCollection Create(BinaryData data, ModelReaderWriterOptions options); + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index aa19fc2d8..b8a5a662c 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -135,7 +135,7 @@ public async Task SimpleStreamingRunWorks() AssistantClient client = GetTestClient(); Assistant assistant = await CreateCommonTestAssistantAsync(); - StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( + StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( assistant.Id, new ThreadCreationOptions() { @@ -145,7 +145,7 @@ public async Task SimpleStreamingRunWorks() } }); Assert.That(runUpdateResult, Is.Not.Null); - await foreach (StreamingUpdate runUpdate in runUpdateResult) + await foreach (StreamedEventCollection runUpdate in runUpdateResult) { if (runUpdate is StreamingMessageCreation messageCreation) { @@ -190,11 +190,11 @@ public async Task StreamingWithToolsWorks() AssistantThread thread = threadResult.Value; Assert.That(thread, Is.Not.Null); - StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); + StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); Assert.That(streamingResult, Is.Not.Null); List requiredActions = []; ThreadRun initialStreamedRun = null; - await foreach (StreamingUpdate streamingUpdate in streamingResult) + await foreach (StreamedEventCollection streamingUpdate in streamingResult) { if (streamingUpdate is StreamingRunCreation streamingRunCreation) { @@ -221,7 +221,7 @@ public async Task StreamingWithToolsWorks() } } streamingResult = await client.SubmitToolOutputsStreamingAsync(thread.Id, initialStreamedRun.Id, toolOutputs); - await foreach (StreamingUpdate streamingUpdate in streamingResult) + await foreach (StreamedEventCollection streamingUpdate in streamingResult) { Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); } From 25d28fce4798e63fffbd64ef9faac7dd6db97336 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 16:05:54 -0700 Subject: [PATCH 13/15] add quick and dirty StreamingAssistantResult subtype ... --- .../src/Custom/Assistants/AssistantClient.cs | 16 ++- .../Assistants/StreamingAssistantResult.cs | 97 +++++++++++++++++++ 2 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 .dotnet/src/Custom/Assistants/StreamingAssistantResult.cs diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index 3f05ed243..96082108f 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -670,18 +670,14 @@ internal static StreamingClientResult CreateStreamingRunResult( { throw new ClientResultException(message.Response); } - PipelineResponse response = null; + + // TODO: why do we need to wrap this in a try-catch? + // Would putting the message in a `using` block suffice? + PipelineResponse response = message.Response; try { - response = message.ExtractResponse(); - ClientResult genericResult = ClientResult.FromResponse(response); - StreamingClientResult streamingResult = StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseJsonStream( - responseForEnumeration.GetRawResponse().ContentStream, - StreamingUpdate.DeserializeSseRunUpdates)); - response = null; - return streamingResult; + // TODO: dust this part up... + return new StreamingAssistantResult(response); } finally { diff --git a/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs new file mode 100644 index 000000000..4ad61d8ec --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs @@ -0,0 +1,97 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI.Assistants; + +#nullable enable + +internal class StreamingAssistantResult : StreamingClientResult +{ + public StreamingAssistantResult(PipelineResponse response) : base(response) + { + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + // Note: this implementation disposes the stream after the caller has + // enumerated the elements obtained from the stream. That is to say, + // the `await foreach` loop can only happen once -- if it is tried a + // second time, the caller will get an ObjectDisposedException trying + // to access a disposed Stream. + using PipelineResponse response = GetRawResponse(); + + // Extract the content stream from the response to obtain dispose + // ownership of it. This means the content stream will not be disposed + // when the response is disposed. + Stream contentStream = response.ContentStream ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream."); + response.ContentStream = null; + + return new ChatUpdateEnumerator(contentStream); + } + + private class ChatUpdateEnumerator : IAsyncEnumerator + { + private readonly IAsyncEnumerator _sseEvents; + + private List? _currentUpdates; + private int _currentUpdateIndex; + + public ChatUpdateEnumerator(Stream stream) + { + AsyncSseReader reader = new AsyncSseReader(stream); + + // TODO: Pass CancellationToken. + _sseEvents = reader.GetAsyncEnumerator(); + } + + public StreamingUpdate Current => throw new NotImplementedException(); + + public async ValueTask MoveNextAsync() + { + // TODO: Can we wrap the boilerplate parts of this up into a special SSE base type for this? + // would that be public/internal/instantiated based on configuration given to the + // generator from the TSP? + + // Still have leftovers from the last event we pulled from the reader. + if (_currentUpdates is not null && _currentUpdateIndex < _currentUpdates.Count) + { + _currentUpdateIndex++; + return true; + } + + // We either don't have any stored updates, or we've exceeded the + // count of the ones we have. Get the next set. + + // TODO: Call different configure await variant in this context, or no? + if (!await _sseEvents.MoveNextAsync().ConfigureAwait(false)) + { + // Done with events from the stream. + return false; + } + + ServerSentEvent ssEvent = _sseEvents.Current; + + // TODO: optimize + BinaryData data = BinaryData.FromString(new string(ssEvent.Data.ToArray())); + + // TODO: don't instantiate every time. + StreamingUpdateCollection justToCreate = new StreamingUpdateCollection(); + + _currentUpdates = justToCreate.Create(data, ModelReaderWriterOptions.Json); + _currentUpdateIndex = 0; + + return true; + } + + public ValueTask DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + _sseEvents?.DisposeAsync(); + return new ValueTask(); + } + } +} \ No newline at end of file From 6f0acd5b821a46f4b6d78660195aa909353c0b52 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 17:58:54 -0700 Subject: [PATCH 14/15] random thoughts and questions --- .../src/Custom/Assistants/StreamingAssistantResult.cs | 6 +++--- .../tests/Samples/Chat/Sample02_StreamingChatAsync.cs | 5 +++++ .dotnet/tests/TestScenarios/AssistantTests.cs | 10 +++++----- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs index 4ad61d8ec..060942470 100644 --- a/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs +++ b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs @@ -30,17 +30,17 @@ public override IAsyncEnumerator GetAsyncEnumerator(Cancellatio Stream contentStream = response.ContentStream ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream."); response.ContentStream = null; - return new ChatUpdateEnumerator(contentStream); + return new AssistantRunUpdateEnumerator(contentStream); } - private class ChatUpdateEnumerator : IAsyncEnumerator + private class AssistantRunUpdateEnumerator : IAsyncEnumerator { private readonly IAsyncEnumerator _sseEvents; private List? _currentUpdates; private int _currentUpdateIndex; - public ChatUpdateEnumerator(Stream stream) + public AssistantRunUpdateEnumerator(Stream stream) { AsyncSseReader reader = new AsyncSseReader(stream); diff --git a/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs b/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs index 6ce1a96f0..e59ff628e 100644 --- a/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs +++ b/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs @@ -1,6 +1,8 @@ using NUnit.Framework; using OpenAI.Chat; using System; +using System.ClientModel.Primitives; +using System.IO; using System.Threading.Tasks; namespace OpenAI.Samples @@ -21,6 +23,9 @@ public async Task Sample02_StreamingChatAsync() { Console.Write(chatUpdate.ContentUpdate); } + + PipelineResponse response = result.GetRawResponse(); + Stream stream = response.ContentStream; } } } diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index b8a5a662c..aa19fc2d8 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -135,7 +135,7 @@ public async Task SimpleStreamingRunWorks() AssistantClient client = GetTestClient(); Assistant assistant = await CreateCommonTestAssistantAsync(); - StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( + StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( assistant.Id, new ThreadCreationOptions() { @@ -145,7 +145,7 @@ public async Task SimpleStreamingRunWorks() } }); Assert.That(runUpdateResult, Is.Not.Null); - await foreach (StreamedEventCollection runUpdate in runUpdateResult) + await foreach (StreamingUpdate runUpdate in runUpdateResult) { if (runUpdate is StreamingMessageCreation messageCreation) { @@ -190,11 +190,11 @@ public async Task StreamingWithToolsWorks() AssistantThread thread = threadResult.Value; Assert.That(thread, Is.Not.Null); - StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); + StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); Assert.That(streamingResult, Is.Not.Null); List requiredActions = []; ThreadRun initialStreamedRun = null; - await foreach (StreamedEventCollection streamingUpdate in streamingResult) + await foreach (StreamingUpdate streamingUpdate in streamingResult) { if (streamingUpdate is StreamingRunCreation streamingRunCreation) { @@ -221,7 +221,7 @@ public async Task StreamingWithToolsWorks() } } streamingResult = await client.SubmitToolOutputsStreamingAsync(thread.Id, initialStreamedRun.Id, toolOutputs); - await foreach (StreamedEventCollection streamingUpdate in streamingResult) + await foreach (StreamingUpdate streamingUpdate in streamingResult) { Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); } From ecf286f68cf92e83cb2e2c546ceb01c1172e00d8 Mon Sep 17 00:00:00 2001 From: Anne Thompson Date: Thu, 28 Mar 2024 18:04:00 -0700 Subject: [PATCH 15/15] update - this is important I think --- .dotnet/src/Utility/StreamingClientResultOfT.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.dotnet/src/Utility/StreamingClientResultOfT.cs b/.dotnet/src/Utility/StreamingClientResultOfT.cs index 6d2d68225..7483dad71 100644 --- a/.dotnet/src/Utility/StreamingClientResultOfT.cs +++ b/.dotnet/src/Utility/StreamingClientResultOfT.cs @@ -14,6 +14,9 @@ namespace OpenAI; /// The data type representative of distinct, streamable items. // TODO: Revisit the IDisposable question public abstract class StreamingClientResult : ClientResult, IAsyncEnumerable + // TODO: Note that constraining the T means the implementation can use + // ModelReaderWriter for deserialization. + where T : IPersistableModel { protected StreamingClientResult(PipelineResponse response) : base(response) {